source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5127

Last change on this file since 5127 was 5127, checked in by lindanl, 3 years ago

Block number increased by Stride Blocks

File size: 9.7 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
#include "kernel.h"
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Value.h>
#include <llvm/Support/raw_ostream.h>
11
12using namespace llvm;
13using namespace kernel;
14
15KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
16                                 std::string kernelName,
17                                 std::vector<StreamSetBinding> stream_inputs,
18                                 std::vector<StreamSetBinding> stream_outputs,
19                                 std::vector<ScalarBinding> scalar_parameters,
20                                 std::vector<ScalarBinding> scalar_outputs,
21                                 std::vector<ScalarBinding> internal_scalars) :
22    KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {}
23
24void KernelBuilder::addScalar(Type * t, std::string scalarName) {
25    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
26        throw std::runtime_error("Illegal addition of kernel field after kernel state finalized: " + scalarName);
27    }
28    unsigned index = mKernelFields.size();
29    mKernelFields.push_back(t);
30    mInternalStateNameMap.emplace(scalarName, index);
31}
32
33void KernelBuilder::prepareKernel() {
34    addScalar(iBuilder->getSizeTy(), blockNoScalar);
35    int streamSetNo = 0;
36    for (auto sSet : mStreamSetInputs) {
37        mScalarInputs.push_back(ScalarBinding{sSet.ssType.getStreamBufferPointerType(), sSet.ssName + basePtrSuffix});
38        mStreamSetNameMap.emplace(sSet.ssName, streamSetNo);
39        streamSetNo++;
40    }
41    for (auto sSet : mStreamSetOutputs) {
42        mScalarInputs.push_back(ScalarBinding{sSet.ssType.getStreamBufferPointerType(), sSet.ssName + basePtrSuffix});
43        mStreamSetNameMap.emplace(sSet.ssName, streamSetNo);
44        streamSetNo++;
45    }
46    for (auto binding : mScalarInputs) {
47        addScalar(binding.scalarType, binding.scalarName);
48    }
49    for (auto binding : mScalarOutputs) {
50        addScalar(binding.scalarType, binding.scalarName);
51    }
52    for (auto binding : mInternalScalars) {
53        addScalar(binding.scalarType, binding.scalarName);
54    }
55    mKernelStateType = StructType::create(getGlobalContext(), mKernelFields, mKernelName);
56}
57
58std::unique_ptr<Module> KernelBuilder::createKernelModule() {
59    Module * saveModule = iBuilder->getModule();
60    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
61    std::unique_ptr<Module> theModule = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), getGlobalContext());
62    Module * m = theModule.get();
63    iBuilder->setModule(m);
64    generateKernel();
65    iBuilder->setModule(saveModule);
66    iBuilder->restoreIP(savePoint);
67    return theModule;
68}
69
70void KernelBuilder::generateKernel() {
71    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
72    Module * m = iBuilder->getModule();
73
74    prepareKernel();  // possibly overriden by the KernelBuilder subtype
75    KernelInterface::addKernelDeclarations(m);
76    generateDoBlockMethod();     // must be implemented by the KernelBuilder subtype
77    generateFinalBlockMethod();  // possibly overriden by the KernelBuilder subtype
78    generateDoSegmentMethod();
79
80    // Implement the accumulator get functions
81    for (auto binding : mScalarOutputs) {
82        auto fnName = mKernelName + accumulator_infix + binding.scalarName;
83        Function * accumFn = m->getFunction(fnName);
84        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.scalarName, accumFn, 0));
85        Value * self = &*(accumFn->arg_begin());
86        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
87        Value * retVal = iBuilder->CreateLoad(ptr);
88        iBuilder->CreateRet(retVal);
89    }
90    // Implement the initializer function
91    Function * initFunction = m->getFunction(mKernelName + init_suffix);
92    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));
93   
94    Function::arg_iterator args = initFunction->arg_begin();
95    Value * self = &*(args++);
96    iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), self);
97    for (auto binding : mScalarInputs) {
98        Value * parm = &*(args++);
99        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
100        iBuilder->CreateStore(parm, ptr);
101    }
102    iBuilder->CreateRetVoid();
103    iBuilder->restoreIP(savePoint);
104}
105
106//  The default finalBlock method simply dispatches to the doBlock routine.
107void KernelBuilder::generateFinalBlockMethod() {
108    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
109    Module * m = iBuilder->getModule();
110    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
111    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
112    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
113    // Final Block arguments: self, remaining, then the standard DoBlock args.
114    Function::arg_iterator args = finalBlockFunction->arg_begin();
115    Value * self = &*(args++);
116    /* Skip "remaining" arg */ args++;
117    std::vector<Value *> doBlockArgs = {self};
118    while (args != finalBlockFunction->arg_end()){
119        doBlockArgs.push_back(&*args++);
120    }
121    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
122    iBuilder->CreateRetVoid();
123    iBuilder->restoreIP(savePoint);
124}
125
126//  The default doSegment method simply dispatches to the doBlock routine.
127void KernelBuilder::generateDoSegmentMethod() {
128    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
129    Module * m = iBuilder->getModule();
130    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
131    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
132    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
133    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
134    BasicBlock * blockLoop = BasicBlock::Create(iBuilder->getContext(), "blockLoop", doSegmentFunction, 0);
135    BasicBlock * blocksDone = BasicBlock::Create(iBuilder->getContext(), "blocksDone", doSegmentFunction, 0);
136
137   
138    Function::arg_iterator args = doSegmentFunction->arg_begin();
139    Value * self = &*(args++);
140    Value * blocksToDo = &*(args);
141   
142    iBuilder->CreateBr(blockLoop);
143   
144    iBuilder->SetInsertPoint(blockLoop);
145    PHINode * blocksRemaining = iBuilder->CreatePHI(iBuilder->getSizeTy(), 2, "blocksRemaining");
146    blocksRemaining->addIncoming(blocksToDo, entryBlock);
147   
148    Value * blockNo = getScalarField(self, blockNoScalar);
149   
150    iBuilder->CreateCall(doBlockFunction, {self});
151    setScalarField(self, blockNoScalar, iBuilder->CreateAdd(blockNo, ConstantInt::get(iBuilder->getSizeTy(), iBuilder->getStride() / iBuilder->getBitBlockWidth())));
152    blocksToDo = iBuilder->CreateSub(blocksRemaining, ConstantInt::get(iBuilder->getSizeTy(), 1));
153    blocksRemaining->addIncoming(blocksToDo, blockLoop);
154    Value * notDone = iBuilder->CreateICmpUGT(blocksToDo, ConstantInt::get(iBuilder->getSizeTy(), 0));
155    iBuilder->CreateCondBr(notDone, blockLoop, blocksDone);
156   
157    iBuilder->SetInsertPoint(blocksDone);
158    iBuilder->CreateRetVoid();
159    iBuilder->restoreIP(savePoint);
160}
161
162Value * KernelBuilder::getScalarIndex(std::string fieldName) {
163    const auto f = mInternalStateNameMap.find(fieldName);
164    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
165        throw std::runtime_error("Kernel does not contain internal state: " + fieldName);
166    }
167    return iBuilder->getInt32(f->second);
168}
169
170
171
172Value * KernelBuilder::getScalarField(Value * self, std::string fieldName) {
173    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
174    return iBuilder->CreateLoad(ptr);
175}
176
177void KernelBuilder::setScalarField(Value * self, std::string fieldName, Value * newFieldVal) {
178    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
179    iBuilder->CreateStore(newFieldVal, ptr);
180}
181
182
183Value * KernelBuilder::getParameter(Function * f, std::string paramName) {
184    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
185        Value * arg = &*argIter;
186        if (arg->getName() == paramName) return arg;
187    }
188    throw std::runtime_error("Method does not have parameter: " + paramName);
189}
190
191unsigned KernelBuilder::getStreamSetIndex(std::string ssName) {
192    const auto f = mStreamSetNameMap.find(ssName);
193    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
194        throw std::runtime_error("Kernel does not contain stream set: " + ssName);
195    }
196    return f->second;
197}
198
199size_t KernelBuilder::getStreamSetBufferSize(Value * self, std::string ssName) {
200    unsigned ssIndex = getStreamSetIndex(ssName);
201    if (ssIndex < mStreamSetInputs.size()) {
202        return mStreamSetInputs[ssIndex].ssType.getBufferSize();
203    }
204    else {
205        return mStreamSetOutputs[ssIndex - mStreamSetInputs.size()].ssType.getBufferSize();
206    }
207}
208
209Value * KernelBuilder::getStreamSetBasePtr(Value * self, std::string ssName) {
210    return getScalarField(self, ssName + basePtrSuffix);
211}
212
213Value * KernelBuilder::getStreamSetBlockPtr(Value * self, std::string ssName, Value * blockNo) {
214    Value * basePtr = getStreamSetBasePtr(self, ssName);
215    unsigned ssIndex = getStreamSetIndex(ssName);
216    if (ssIndex < mStreamSetInputs.size()) {
217        return mStreamSetInputs[ssIndex].ssType.getStreamSetBlockPointer(basePtr, blockNo);
218    }
219    else {
220        return mStreamSetOutputs[ssIndex - mStreamSetInputs.size()].ssType.getStreamSetBlockPointer(basePtr, blockNo);
221    }
222}
223
224
225
226
Note: See TracBrowser for help on using the repository browser.