source: icGREP/icgrep-devel/icgrep/kernels/interface.cpp @ 5706

Last change on this file since 5706 was 5706, checked in by nmedfort, 16 months ago

First stage of MultiBlockKernel and pipeline restructuring

File size: 5.8 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "interface.h"
7#include <llvm/IR/Value.h>         // for Value
8#include <llvm/IR/CallingConv.h>   // for ::C
9#include <llvm/IR/DerivedTypes.h>  // for FunctionType (ptr only), PointerType
10#include <llvm/IR/Function.h>      // for Function, Function::arg_iterator
11#include <llvm/IR/Module.h>
12#include <kernels/kernel_builder.h>
13
// Suffixes appended to a kernel's name to form the symbol names of the
// functions declared for it (initialization, segment processing, termination).
static const char * const INIT_SUFFIX = "_Init";
static const char * const DO_SEGMENT_SUFFIX = "_DoSegment";
static const char * const TERMINATE_SUFFIX = "_Terminate";
19
20using namespace llvm;
21
22namespace kernel {
23
24void KernelInterface::addKernelDeclarations(const std::unique_ptr<kernel::KernelBuilder> & idb) {
25
26    if (mKernelStateType == nullptr) {
27        throw std::runtime_error("Kernel interface " + getName() + " not yet finalized.");
28    }
29
30    Module * const module = idb->getModule();
31    PointerType * const selfType = mKernelStateType->getPointerTo();
32    IntegerType * const sizeTy = idb->getSizeTy();
33    PointerType * const consumerTy = StructType::get(sizeTy, sizeTy->getPointerTo()->getPointerTo(), nullptr)->getPointerTo();
34    Type * const voidTy = idb->getVoidTy();
35
36    // Create the initialization function prototype
37    std::vector<Type *> initParameters = {selfType};
38    for (auto binding : mScalarInputs) {
39        initParameters.push_back(binding.getType());
40    }
41    initParameters.insert(initParameters.end(), mStreamSetOutputs.size(), consumerTy);
42
43    FunctionType * const initType = FunctionType::get(voidTy, initParameters, false);
44    Function * const initFunc = Function::Create(initType, GlobalValue::ExternalLinkage, getName() + INIT_SUFFIX, module);
45    initFunc->setCallingConv(CallingConv::C);
46    initFunc->setDoesNotThrow();
47    auto args = initFunc->arg_begin();
48    args->setName("self");
49    for (const Binding & binding : mScalarInputs) {
50        (++args)->setName(binding.getName());
51    }
52    for (const Binding & binding : mStreamSetOutputs) {
53        (++args)->setName(binding.getName() + "ConsumerLocks");
54    }
55
56    // Create the doSegment function prototype.
57    std::vector<Type *> params = {selfType, idb->getInt1Ty()};
58
59    const auto count = mStreamSetInputs.size();
60    params.insert(params.end(), count, sizeTy);
61
62    FunctionType * const doSegmentType = FunctionType::get(voidTy, params, false);
63    Function * const doSegment = Function::Create(doSegmentType, GlobalValue::ExternalLinkage, getName() + DO_SEGMENT_SUFFIX, module);
64    doSegment->setCallingConv(CallingConv::C);
65    doSegment->setDoesNotThrow();
66    doSegment->setDoesNotCapture(1); // for self parameter only.
67    args = doSegment->arg_begin();
68    args->setName("self");
69    (++args)->setName("doFinal");
70//    if (mHasPrincipleItemCount) {
71//        (++args)->setName("principleAvailableItemCount");
72//    }
73    for (const Binding & input : mStreamSetInputs) {
74        //const ProcessingRate & r = input.getRate();
75        //if (!r.isDerived()) {
76            (++args)->setName(input.getName() + "AvailableItems");
77        //}
78    }
79
80    // Create the terminate function prototype
81    Type * resultType = nullptr;
82    if (mScalarOutputs.empty()) {
83        resultType = idb->getVoidTy();
84    } else {
85        const auto n = mScalarOutputs.size();
86        Type * outputType[n];
87        for (unsigned i = 0; i < n; ++i) {
88            outputType[i] = mScalarOutputs[i].getType();
89        }
90        if (n == 1) {
91            resultType = outputType[0];
92        } else {
93            resultType = StructType::get(idb->getContext(), ArrayRef<Type *>(outputType, n));
94        }
95    }
96    FunctionType * const terminateType = FunctionType::get(resultType, {selfType}, false);
97    Function * const terminateFunc = Function::Create(terminateType, GlobalValue::ExternalLinkage, getName() + TERMINATE_SUFFIX, module);
98    terminateFunc->setCallingConv(CallingConv::C);
99    terminateFunc->setDoesNotThrow();
100    terminateFunc->setDoesNotCapture(1);
101    args = terminateFunc->arg_begin();
102    args->setName("self");
103
104    linkExternalMethods(idb);
105}
106
107void  KernelInterface::setInstance(Value * const instance) {
108    assert ("kernel instance cannot be null!" && instance);
109    assert ("kernel instance must point to a valid kernel state type!" && (instance->getType()->getPointerElementType() == mKernelStateType));
110    mKernelInstance = instance;
111}
112
113Function * KernelInterface::getInitFunction(Module * const module) const {
114    const auto name = getName() + INIT_SUFFIX;
115    Function * f = module->getFunction(name);
116    if (LLVM_UNLIKELY(f == nullptr)) {
117        llvm::report_fatal_error("Cannot find " + name);
118    }
119    return f;
120}
121
122Function * KernelInterface::getDoSegmentFunction(Module * const module) const {
123    const auto name = getName() + DO_SEGMENT_SUFFIX;
124    Function * f = module->getFunction(name);
125    if (LLVM_UNLIKELY(f == nullptr)) {
126        llvm::report_fatal_error("Cannot find " + name);
127    }
128    return f;
129}
130
131Function * KernelInterface::getTerminateFunction(Module * const module) const {
132    const auto name = getName() + TERMINATE_SUFFIX;
133    Function * f = module->getFunction(name);
134    if (LLVM_UNLIKELY(f == nullptr)) {
135        llvm::report_fatal_error("Cannot find " + name);
136    }
137    return f;
138}
139
140CallInst * KernelInterface::makeDoSegmentCall(kernel::KernelBuilder & idb, const std::vector<llvm::Value *> & args) const {
141    Function * const doSegment = getDoSegmentFunction(idb.getModule());
142    assert (doSegment->getArgumentList().size() <= args.size());
143    return idb.CreateCall(doSegment, args);
144}
145
146void Binding::addAttribute(Attribute attribute) {
147    for (Attribute & attr : attributes) {
148        if (attr.getKind() == attribute.getKind()) {
149            return;
150        }
151    }
152    attributes.emplace_back(attribute);
153}
154
155void KernelInterface::normalizeStreamProcessingRates() {
156
157}
158
159}
Note: See TracBrowser for help on using the repository browser.