source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5133

Last change on this file since 5133 was 5133, checked in by cameron, 3 years ago

Defer binding of buffers to stream sets until kernel generation

File size: 11.6 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
11
12using namespace llvm;
13using namespace kernel;
14
15KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
16                                 std::string kernelName,
17                                 std::vector<StreamSetBinding> stream_inputs,
18                                 std::vector<StreamSetBinding> stream_outputs,
19                                 std::vector<ScalarBinding> scalar_parameters,
20                                 std::vector<ScalarBinding> scalar_outputs,
21                                 std::vector<ScalarBinding> internal_scalars) :
22    KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {}
23
24void KernelBuilder::addScalar(Type * t, std::string scalarName) {
25    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
26        throw std::runtime_error("Illegal addition of kernel field after kernel state finalized: " + scalarName);
27    }
28    unsigned index = mKernelFields.size();
29    mKernelFields.push_back(t);
30    mInternalStateNameMap.emplace(scalarName, index);
31}
32
33void KernelBuilder::prepareKernel() {
34    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
35        throw std::runtime_error("Kernel preparation: Incorrect number of input buffers");
36    }
37    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
38        throw std::runtime_error("Kernel preparation: Incorrect number of input buffers");
39    }
40    addScalar(iBuilder->getSizeTy(), blockNoScalar);
41    int streamSetNo = 0;
42    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
43        if (!(mStreamSetInputBuffers[i]->getBufferStreamSetType() == mStreamSetInputs[i].ssType)) {
44             throw std::runtime_error("Kernel preparation: Incorrect input buffer type");
45        }
46        mScalarInputs.push_back(ScalarBinding{mStreamSetInputBuffers[i]->getStreamBufferPointerType(), mStreamSetInputs[i].ssName + basePtrSuffix});
47        mStreamSetNameMap.emplace(mStreamSetInputs[i].ssName, streamSetNo);
48        streamSetNo++;
49    }
50    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
51        if (!(mStreamSetOutputBuffers[i]->getBufferStreamSetType() == mStreamSetOutputs[i].ssType)) {
52             throw std::runtime_error("Kernel preparation: Incorrect input buffer type");
53        }
54        mScalarInputs.push_back(ScalarBinding{mStreamSetOutputBuffers[i]->getStreamBufferPointerType(), mStreamSetOutputs[i].ssName + basePtrSuffix});
55        mStreamSetNameMap.emplace(mStreamSetOutputs[i].ssName, streamSetNo);
56        streamSetNo++;
57    }
58    for (auto binding : mScalarInputs) {
59        addScalar(binding.scalarType, binding.scalarName);
60    }
61    for (auto binding : mScalarOutputs) {
62        addScalar(binding.scalarType, binding.scalarName);
63    }
64    for (auto binding : mInternalScalars) {
65        addScalar(binding.scalarType, binding.scalarName);
66    }
67    mKernelStateType = StructType::create(getGlobalContext(), mKernelFields, mKernelName);
68}
69
70std::unique_ptr<Module> KernelBuilder::createKernelModule(std::vector<StreamSetBuffer *> input_buffers, std::vector<StreamSetBuffer *> output_buffers) {
71    Module * saveModule = iBuilder->getModule();
72    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
73    std::unique_ptr<Module> theModule = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), getGlobalContext());
74    Module * m = theModule.get();
75    iBuilder->setModule(m);
76    generateKernel(input_buffers, output_buffers);
77    iBuilder->setModule(saveModule);
78    iBuilder->restoreIP(savePoint);
79    return theModule;
80}
81
82void KernelBuilder::generateKernel(std::vector<StreamSetBuffer *> input_buffers, std::vector<StreamSetBuffer*> output_buffers) {
83    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
84    Module * m = iBuilder->getModule();
85    mStreamSetInputBuffers = input_buffers;
86    mStreamSetOutputBuffers = output_buffers;
87    prepareKernel();  // possibly overriden by the KernelBuilder subtype
88    KernelInterface::addKernelDeclarations(m);
89    generateDoBlockMethod();     // must be implemented by the KernelBuilder subtype
90    generateFinalBlockMethod();  // possibly overriden by the KernelBuilder subtype
91    generateDoSegmentMethod();
92
93    // Implement the accumulator get functions
94    for (auto binding : mScalarOutputs) {
95        auto fnName = mKernelName + accumulator_infix + binding.scalarName;
96        Function * accumFn = m->getFunction(fnName);
97        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.scalarName, accumFn, 0));
98        Value * self = &*(accumFn->arg_begin());
99        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
100        Value * retVal = iBuilder->CreateLoad(ptr);
101        iBuilder->CreateRet(retVal);
102    }
103    // Implement the initializer function
104    Function * initFunction = m->getFunction(mKernelName + init_suffix);
105    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));
106   
107    Function::arg_iterator args = initFunction->arg_begin();
108    Value * self = &*(args++);
109    iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), self);
110    for (auto binding : mScalarInputs) {
111        Value * parm = &*(args++);
112        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
113        iBuilder->CreateStore(parm, ptr);
114    }
115    iBuilder->CreateRetVoid();
116    iBuilder->restoreIP(savePoint);
117}
118
119//  The default finalBlock method simply dispatches to the doBlock routine.
120void KernelBuilder::generateFinalBlockMethod() {
121    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
122    Module * m = iBuilder->getModule();
123    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
124    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
125    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
126    // Final Block arguments: self, remaining, then the standard DoBlock args.
127    Function::arg_iterator args = finalBlockFunction->arg_begin();
128    Value * self = &*(args++);
129    /* Skip "remaining" arg */ args++;
130    std::vector<Value *> doBlockArgs = {self};
131    while (args != finalBlockFunction->arg_end()){
132        doBlockArgs.push_back(&*args++);
133    }
134    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
135    iBuilder->CreateRetVoid();
136    iBuilder->restoreIP(savePoint);
137}
138
139//  The default doSegment method simply dispatches to the doBlock routine.
140void KernelBuilder::generateDoSegmentMethod() {
141    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
142    Module * m = iBuilder->getModule();
143    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
144    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
145    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
146    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
147    BasicBlock * blockLoop = BasicBlock::Create(iBuilder->getContext(), "blockLoop", doSegmentFunction, 0);
148    BasicBlock * blocksDone = BasicBlock::Create(iBuilder->getContext(), "blocksDone", doSegmentFunction, 0);
149
150   
151    Function::arg_iterator args = doSegmentFunction->arg_begin();
152    Value * self = &*(args++);
153    Value * blocksToDo = &*(args);
154   
155    iBuilder->CreateBr(blockLoop);
156   
157    iBuilder->SetInsertPoint(blockLoop);
158    PHINode * blocksRemaining = iBuilder->CreatePHI(iBuilder->getSizeTy(), 2, "blocksRemaining");
159    blocksRemaining->addIncoming(blocksToDo, entryBlock);
160   
161    Value * blockNo = getScalarField(self, blockNoScalar);
162   
163    iBuilder->CreateCall(doBlockFunction, {self});
164    setScalarField(self, blockNoScalar, iBuilder->CreateAdd(blockNo, ConstantInt::get(iBuilder->getSizeTy(), iBuilder->getStride() / iBuilder->getBitBlockWidth())));
165    blocksToDo = iBuilder->CreateSub(blocksRemaining, ConstantInt::get(iBuilder->getSizeTy(), 1));
166    blocksRemaining->addIncoming(blocksToDo, blockLoop);
167    Value * notDone = iBuilder->CreateICmpUGT(blocksToDo, ConstantInt::get(iBuilder->getSizeTy(), 0));
168    iBuilder->CreateCondBr(notDone, blockLoop, blocksDone);
169   
170    iBuilder->SetInsertPoint(blocksDone);
171    iBuilder->CreateRetVoid();
172    iBuilder->restoreIP(savePoint);
173}
174
175Value * KernelBuilder::getScalarIndex(std::string fieldName) {
176    const auto f = mInternalStateNameMap.find(fieldName);
177    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
178        throw std::runtime_error("Kernel does not contain internal state: " + fieldName);
179    }
180    return iBuilder->getInt32(f->second);
181}
182
183
184
185Value * KernelBuilder::getScalarField(Value * self, std::string fieldName) {
186    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
187    return iBuilder->CreateLoad(ptr);
188}
189
190void KernelBuilder::setScalarField(Value * self, std::string fieldName, Value * newFieldVal) {
191    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
192    iBuilder->CreateStore(newFieldVal, ptr);
193}
194
195
196Value * KernelBuilder::getParameter(Function * f, std::string paramName) {
197    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
198        Value * arg = &*argIter;
199        if (arg->getName() == paramName) return arg;
200    }
201    throw std::runtime_error("Method does not have parameter: " + paramName);
202}
203
204unsigned KernelBuilder::getStreamSetIndex(std::string ssName) {
205    const auto f = mStreamSetNameMap.find(ssName);
206    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
207        throw std::runtime_error("Kernel does not contain stream set: " + ssName);
208    }
209    return f->second;
210}
211
212size_t KernelBuilder::getStreamSetBufferSize(Value * self, std::string ssName) {
213    unsigned ssIndex = getStreamSetIndex(ssName);
214    if (ssIndex < mStreamSetInputs.size()) {
215        return mStreamSetInputBuffers[ssIndex]->getBufferSize();
216    }
217    else {
218        return mStreamSetOutputBuffers[ssIndex - mStreamSetInputs.size()]->getBufferSize();
219    }
220}
221
222Value * KernelBuilder::getStreamSetBasePtr(Value * self, std::string ssName) {
223    return getScalarField(self, ssName + basePtrSuffix);
224}
225
226Value * KernelBuilder::getStreamSetBlockPtr(Value * self, std::string ssName, Value * blockNo) {
227    Value * basePtr = getStreamSetBasePtr(self, ssName);
228    unsigned ssIndex = getStreamSetIndex(ssName);
229    if (ssIndex < mStreamSetInputs.size()) {
230        return mStreamSetInputBuffers[ssIndex]->getStreamSetBlockPointer(basePtr, blockNo);
231    }
232    else {
233        return mStreamSetOutputBuffers[ssIndex - mStreamSetInputs.size()]->getStreamSetBlockPointer(basePtr, blockNo);
234    }
235}
236
237Value * KernelBuilder::createInstance(std::vector<Value *> args) {
238    Value * kernelInstance = iBuilder->CreateAlloca(mKernelStateType);
239    Module * m = iBuilder->getModule();
240    std::vector<Value *> init_args = {kernelInstance};
241    for (auto a : args) {
242        init_args.push_back(a);
243    }
244    for (auto b : mStreamSetInputBuffers) { 
245        init_args.push_back(b->getStreamSetBasePtr());
246    }
247    for (auto b : mStreamSetOutputBuffers) { 
248        init_args.push_back(b->getStreamSetBasePtr());
249    }
250    std::string initFnName = mKernelName + init_suffix;
251    Function * initMethod = m->getFunction(initFnName);
252    if (!initMethod) {
253        throw std::runtime_error("Cannot find " + initFnName);
254    }
255    iBuilder->CreateCall(initMethod, init_args);
256    return kernelInstance;
257}
258
259
260
261
Note: See TracBrowser for help on using the repository browser.