source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5063

Last change on this file since 5063 was 5063, checked in by cameron, 3 years ago

New kernel infrastructure

File size: 5.8 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
11
12using namespace llvm;
13using namespace kernel;
14
15KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
16                                 std::string kernelName,
17                                 std::vector<StreamSetBinding> stream_inputs,
18                                 std::vector<StreamSetBinding> stream_outputs,
19                                 std::vector<ScalarBinding> scalar_parameters,
20                                 std::vector<ScalarBinding> scalar_outputs,
21                                 std::vector<ScalarBinding> internal_scalars) :
22    KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {
23   
24    for (auto binding : scalar_parameters) {
25        addScalar(binding.scalarType, binding.scalarName);
26    }
27    for (auto binding : scalar_outputs) {
28        addScalar(binding.scalarType, binding.scalarName);
29    }
30    for (auto binding : internal_scalars) {
31        addScalar(binding.scalarType, binding.scalarName);
32    }
33}
34
35void KernelBuilder::addScalar(Type * t, std::string scalarName) {
36    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
37        throw std::runtime_error("Illegal addition of kernel field after kernel state finalized: " + scalarName);
38    }
39    unsigned index = mKernelFields.size();
40    mKernelFields.push_back(t);
41    mInternalStateNameMap.emplace(scalarName, iBuilder->getInt32(index));
42}
43
44void KernelBuilder::finalizeKernelStateType() {
45    mKernelStateType = StructType::create(getGlobalContext(), mKernelFields, mKernelName);
46}
47
48std::unique_ptr<Module> KernelBuilder::createKernelModule() {
49    Module * saveModule = iBuilder->getModule();
50    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
51    if (mKernelStateType == nullptr) finalizeKernelStateType();
52    std::unique_ptr<Module> theModule = make_unique<Module>(mKernelName, getGlobalContext());
53    Module * m = theModule.get();
54    iBuilder->setModule(m);
55    generateKernel();
56    iBuilder->setModule(saveModule);
57    iBuilder->restoreIP(savePoint);
58    return theModule;
59}
60
61void KernelBuilder::generateKernel() {
62    Module * m = iBuilder->getModule();
63    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
64    addKernelDeclarations(m);
65    // Implement the accumulator get functions
66    for (auto binding : mScalarOutputs) {
67        auto fnName = mKernelName + accumulator_infix + binding.scalarName;
68        Function * accumFn = m->getFunction(fnName);
69        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.scalarName, accumFn, 0));
70        Value * self = &*(accumFn->arg_begin());
71        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
72        Value * retVal = iBuilder->CreateLoad(ptr);
73        iBuilder->CreateRet(retVal);
74    }
75    // Implement the initializer function
76    Function * initFunction = m->getFunction(mKernelName + init_suffix);
77    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));
78   
79    Function::arg_iterator args = initFunction->arg_begin();
80    Value * self = &*(args++);
81    iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), self);
82    for (auto binding : mScalarInputs) {
83        Value * parm = &*(args++);
84        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
85        iBuilder->CreateStore(parm, ptr);
86    }
87    iBuilder->CreateRetVoid();
88    iBuilder->restoreIP(savePoint);
89}
90
91void KernelBuilder::addTrivialFinalBlockMethod(Module * m) {
92    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
93    Module * saveModule = iBuilder->getModule();
94    iBuilder->setModule(m);
95    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
96    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
97    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
98    // Final Block arguments: self, remaining, then the standard DoBlock args.
99    Function::arg_iterator args = finalBlockFunction->arg_begin();
100    Value * self = &*(args++);
101    /* Skip "remaining" arg */ args++;
102    std::vector<Value *> doBlockArgs = {self};
103    while (args != finalBlockFunction->arg_end()){
104        doBlockArgs.push_back(&*args++);
105    }
106    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
107    iBuilder->CreateRetVoid();
108    iBuilder->setModule(saveModule);
109    iBuilder->restoreIP(savePoint);
110}
111
112Value * KernelBuilder::getScalarIndex(std::string fieldName) {
113    const auto f = mInternalStateNameMap.find(fieldName);
114    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
115        throw std::runtime_error("Kernel does not contain internal state: " + fieldName);
116    }
117    return f->second;
118}
119
120Value * KernelBuilder::getScalarField(Value * self, std::string fieldName) {
121    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
122    return iBuilder->CreateLoad(ptr);
123}
124
125void KernelBuilder::setScalarField(Value * self, std::string fieldName, Value * newFieldVal) {
126    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
127    iBuilder->CreateStore(newFieldVal, ptr);
128}
129
130
131Value * KernelBuilder::getParameter(Function * f, std::string paramName) {
132    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
133        Value * arg = &*argIter;
134        if (arg->getName() == paramName) return arg;
135    }
136    throw std::runtime_error("Method does not have parameter: " + paramName);
137}
138
139
140
Note: See TracBrowser for help on using the repository browser.