source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5135

Last change on this file since 5135 was 5135, checked in by lindanl, 3 years ago

Add pipeline parallel strategy to the framework.

File size: 19.4 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
11#include <llvm/IR/TypeBuilder.h>
12#include <toolchain.h>
13
14using namespace llvm;
15using namespace kernel;
16
17KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
18                                 std::string kernelName,
19                                 std::vector<StreamSetBinding> stream_inputs,
20                                 std::vector<StreamSetBinding> stream_outputs,
21                                 std::vector<ScalarBinding> scalar_parameters,
22                                 std::vector<ScalarBinding> scalar_outputs,
23                                 std::vector<ScalarBinding> internal_scalars) :
24    KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {}
25
26void KernelBuilder::addScalar(Type * t, std::string scalarName) {
27    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
28        throw std::runtime_error("Illegal addition of kernel field after kernel state finalized: " + scalarName);
29    }
30    unsigned index = mKernelFields.size();
31    mKernelFields.push_back(t);
32    mInternalStateNameMap.emplace(scalarName, index);
33}
34
35void KernelBuilder::prepareKernel() {
36    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
37        throw std::runtime_error("Kernel preparation: Incorrect number of input buffers");
38    }
39    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
40        throw std::runtime_error("Kernel preparation: Incorrect number of input buffers");
41    }
42    addScalar(iBuilder->getSizeTy(), blockNoScalar);
43    int streamSetNo = 0;
44    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
45        if (!(mStreamSetInputBuffers[i]->getBufferStreamSetType() == mStreamSetInputs[i].ssType)) {
46             throw std::runtime_error("Kernel preparation: Incorrect input buffer type");
47        }
48        mScalarInputs.push_back(ScalarBinding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].ssName + basePtrSuffix});
49        mStreamSetNameMap.emplace(mStreamSetInputs[i].ssName, streamSetNo);
50        streamSetNo++;
51    }
52    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
53        if (!(mStreamSetOutputBuffers[i]->getBufferStreamSetType() == mStreamSetOutputs[i].ssType)) {
54             throw std::runtime_error("Kernel preparation: Incorrect input buffer type");
55        }
56        mScalarInputs.push_back(ScalarBinding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].ssName + basePtrSuffix});
57        mStreamSetNameMap.emplace(mStreamSetOutputs[i].ssName, streamSetNo);
58        streamSetNo++;
59    }
60    for (auto binding : mScalarInputs) {
61        addScalar(binding.scalarType, binding.scalarName);
62    }
63    for (auto binding : mScalarOutputs) {
64        addScalar(binding.scalarType, binding.scalarName);
65    }
66    for (auto binding : mInternalScalars) {
67        addScalar(binding.scalarType, binding.scalarName);
68    }
69    mKernelStateType = StructType::create(getGlobalContext(), mKernelFields, mKernelName);
70}
71
72std::unique_ptr<Module> KernelBuilder::createKernelModule(std::vector<StreamSetBuffer *> input_buffers, std::vector<StreamSetBuffer *> output_buffers) {
73    Module * saveModule = iBuilder->getModule();
74    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
75    std::unique_ptr<Module> theModule = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), getGlobalContext());
76    Module * m = theModule.get();
77    iBuilder->setModule(m);
78    generateKernel(input_buffers, output_buffers);
79    iBuilder->setModule(saveModule);
80    iBuilder->restoreIP(savePoint);
81    return theModule;
82}
83
84void KernelBuilder::generateKernel(std::vector<StreamSetBuffer *> input_buffers, std::vector<StreamSetBuffer*> output_buffers) {
85    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
86    Module * m = iBuilder->getModule();
87    mStreamSetInputBuffers = input_buffers;
88    mStreamSetOutputBuffers = output_buffers;
89    prepareKernel();  // possibly overriden by the KernelBuilder subtype
90    KernelInterface::addKernelDeclarations(m);
91    generateDoBlockMethod();     // must be implemented by the KernelBuilder subtype
92    generateFinalBlockMethod();  // possibly overriden by the KernelBuilder subtype
93    generateDoSegmentMethod();
94
95    // Implement the accumulator get functions
96    for (auto binding : mScalarOutputs) {
97        auto fnName = mKernelName + accumulator_infix + binding.scalarName;
98        Function * accumFn = m->getFunction(fnName);
99        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.scalarName, accumFn, 0));
100        Value * self = &*(accumFn->arg_begin());
101        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
102        Value * retVal = iBuilder->CreateLoad(ptr);
103        iBuilder->CreateRet(retVal);
104    }
105    // Implement the initializer function
106    Function * initFunction = m->getFunction(mKernelName + init_suffix);
107    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));
108   
109    Function::arg_iterator args = initFunction->arg_begin();
110    Value * self = &*(args++);
111    iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), self);
112    for (auto binding : mScalarInputs) {
113        Value * parm = &*(args++);
114        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
115        iBuilder->CreateStore(parm, ptr);
116    }
117    iBuilder->CreateRetVoid();
118    iBuilder->restoreIP(savePoint);
119}
120
121//  The default finalBlock method simply dispatches to the doBlock routine.
122void KernelBuilder::generateFinalBlockMethod() {
123    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
124    Module * m = iBuilder->getModule();
125    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
126    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
127    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
128    // Final Block arguments: self, remaining, then the standard DoBlock args.
129    Function::arg_iterator args = finalBlockFunction->arg_begin();
130    Value * self = &*(args++);
131    /* Skip "remaining" arg */ args++;
132    std::vector<Value *> doBlockArgs = {self};
133    while (args != finalBlockFunction->arg_end()){
134        doBlockArgs.push_back(&*args++);
135    }
136    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
137    iBuilder->CreateRetVoid();
138    iBuilder->restoreIP(savePoint);
139}
140
141//  The default doSegment method simply dispatches to the doBlock routine.
142void KernelBuilder::generateDoSegmentMethod() {
143    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
144    Module * m = iBuilder->getModule();
145    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
146    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
147    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
148    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
149    BasicBlock * blockLoop = BasicBlock::Create(iBuilder->getContext(), "blockLoop", doSegmentFunction, 0);
150    BasicBlock * blocksDone = BasicBlock::Create(iBuilder->getContext(), "blocksDone", doSegmentFunction, 0);
151
152   
153    Function::arg_iterator args = doSegmentFunction->arg_begin();
154    Value * self = &*(args++);
155    Value * blocksToDo = &*(args);
156   
157    iBuilder->CreateBr(blockLoop);
158   
159    iBuilder->SetInsertPoint(blockLoop);
160    PHINode * blocksRemaining = iBuilder->CreatePHI(iBuilder->getSizeTy(), 2, "blocksRemaining");
161    blocksRemaining->addIncoming(blocksToDo, entryBlock);
162   
163    Value * blockNo = getScalarField(self, blockNoScalar);
164   
165    iBuilder->CreateCall(doBlockFunction, {self});
166    setScalarField(self, blockNoScalar, iBuilder->CreateAdd(blockNo, ConstantInt::get(iBuilder->getSizeTy(), iBuilder->getStride() / iBuilder->getBitBlockWidth())));
167    blocksToDo = iBuilder->CreateSub(blocksRemaining, ConstantInt::get(iBuilder->getSizeTy(), 1));
168    blocksRemaining->addIncoming(blocksToDo, blockLoop);
169    Value * notDone = iBuilder->CreateICmpUGT(blocksToDo, ConstantInt::get(iBuilder->getSizeTy(), 0));
170    iBuilder->CreateCondBr(notDone, blockLoop, blocksDone);
171   
172    iBuilder->SetInsertPoint(blocksDone);
173    iBuilder->CreateRetVoid();
174    iBuilder->restoreIP(savePoint);
175}
176
177Value * KernelBuilder::getScalarIndex(std::string fieldName) {
178    const auto f = mInternalStateNameMap.find(fieldName);
179    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
180        throw std::runtime_error("Kernel does not contain internal state: " + fieldName);
181    }
182    return iBuilder->getInt32(f->second);
183}
184
185
186
187Value * KernelBuilder::getScalarField(Value * self, std::string fieldName) {
188    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
189    return iBuilder->CreateLoad(ptr);
190}
191
192void KernelBuilder::setScalarField(Value * self, std::string fieldName, Value * newFieldVal) {
193    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
194    iBuilder->CreateStore(newFieldVal, ptr);
195}
196
197
198Value * KernelBuilder::getParameter(Function * f, std::string paramName) {
199    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
200        Value * arg = &*argIter;
201        if (arg->getName() == paramName) return arg;
202    }
203    throw std::runtime_error("Method does not have parameter: " + paramName);
204}
205
206unsigned KernelBuilder::getStreamSetIndex(std::string ssName) {
207    const auto f = mStreamSetNameMap.find(ssName);
208    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
209        throw std::runtime_error("Kernel does not contain stream set: " + ssName);
210    }
211    return f->second;
212}
213
214size_t KernelBuilder::getStreamSetBufferSize(Value * self, std::string ssName) {
215    unsigned ssIndex = getStreamSetIndex(ssName);
216    if (ssIndex < mStreamSetInputs.size()) {
217        return mStreamSetInputBuffers[ssIndex]->getBufferSize();
218    }
219    else {
220        return mStreamSetOutputBuffers[ssIndex - mStreamSetInputs.size()]->getBufferSize();
221    }
222}
223
224Value * KernelBuilder::getStreamSetBasePtr(Value * self, std::string ssName) {
225    return getScalarField(self, ssName + basePtrSuffix);
226}
227
228Value * KernelBuilder::getStreamSetBlockPtr(Value * self, std::string ssName, Value * blockNo) {
229    Value * basePtr = getStreamSetBasePtr(self, ssName);
230    unsigned ssIndex = getStreamSetIndex(ssName);
231    if (ssIndex < mStreamSetInputs.size()) {
232        return mStreamSetInputBuffers[ssIndex]->getStreamSetBlockPointer(basePtr, blockNo);
233    }
234    else {
235        return mStreamSetOutputBuffers[ssIndex - mStreamSetInputs.size()]->getStreamSetBlockPointer(basePtr, blockNo);
236    }
237}
238
239Value * KernelBuilder::createInstance(std::vector<Value *> args) {
240    Value * kernelInstance = iBuilder->CreateAlloca(mKernelStateType);
241    Module * m = iBuilder->getModule();
242    std::vector<Value *> init_args = {kernelInstance};
243    for (auto a : args) {
244        init_args.push_back(a);
245    }
246    for (auto b : mStreamSetInputBuffers) { 
247        init_args.push_back(b->getStreamSetStructPtr());
248    }
249    for (auto b : mStreamSetOutputBuffers) { 
250        init_args.push_back(b->getStreamSetStructPtr());
251    }
252    std::string initFnName = mKernelName + init_suffix;
253    Function * initMethod = m->getFunction(initFnName);
254    if (!initMethod) {
255        throw std::runtime_error("Cannot find " + initFnName);
256    }
257    iBuilder->CreateCall(initMethod, init_args);
258    return kernelInstance;
259}
260
261Function * KernelBuilder::generateThreadFunction(std::string name){
262    Module * m = iBuilder->getModule();
263    Type * const voidTy = Type::getVoidTy(m->getContext());
264    Type * const voidPtrTy = TypeBuilder<void *, false>::get(m->getContext());
265    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
266    Type * const int1ty = iBuilder->getInt1Ty();
267
268    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
269    threadFunc->setCallingConv(CallingConv::C);
270    Function::arg_iterator args = threadFunc->arg_begin();
271
272    Value * const arg = &*(args++);
273    arg->setName("args");
274
275    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
276
277    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
278
279    std::vector<Value *> inbufProducerPtrs;
280    std::vector<Value *> inbufConsumerPtrs;
281    std::vector<Value *> outbufProducerPtrs;
282    std::vector<Value *> outbufConsumerPtrs;   
283    std::vector<Value *> endSignalPtrs;
284
285    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
286        Value * basePtr = getStreamSetBasePtr(self, mStreamSetInputs[i].ssName);
287        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(basePtr));
288        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getComsumerPosPtr(basePtr));
289        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->hasEndOfInputPtr(basePtr));
290    }
291    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
292        Value * basePtr = getStreamSetBasePtr(self, mStreamSetOutputs[i].ssName);
293        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(basePtr));
294        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getComsumerPosPtr(basePtr));
295    }
296
297    const unsigned segmentBlocks = codegen::SegmentSize;
298    const unsigned bufferSegments = codegen::BufferSegments;
299    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
300    Type * const size_ty = iBuilder->getSizeTy();
301
302    Value * segSize = ConstantInt::get(size_ty, segmentSize);
303    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
304    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
305   
306    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
307    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
308   
309    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
310    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
311    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
312    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
313    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
314
315    iBuilder->CreateBr(outputCheckBlock);
316
317    iBuilder->SetInsertPoint(outputCheckBlock);
318
319    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
320    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
321        LoadInst * producerPos = iBuilder->CreateAlignedLoad(outbufProducerPtrs[i], 8);
322        producerPos->setOrdering(Acquire);
323        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
324        LoadInst * consumerPos = iBuilder->CreateAlignedLoad(outbufConsumerPtrs[i], 8);
325        consumerPos->setOrdering(Acquire);
326        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
327        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
328    }
329   
330    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
331
332    iBuilder->SetInsertPoint(inputCheckBlock); 
333
334    waitCondTest = ConstantInt::get(int1ty, 1); 
335    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
336        LoadInst * producerPos = iBuilder->CreateAlignedLoad(inbufProducerPtrs[i], 8);
337        producerPos->setOrdering(Acquire);
338        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
339        LoadInst * consumerPos = iBuilder->CreateAlignedLoad(inbufConsumerPtrs[i], 8);
340        consumerPos->setOrdering(Acquire);
341        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
342        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, segSize), producerPos));
343    }
344
345    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
346   
347    iBuilder->SetInsertPoint(endSignalCheckBlock);
348   
349    LoadInst * endSignal = iBuilder->CreateAlignedLoad(endSignalPtrs[0], 8);
350    // iBuilder->CallPrintInt(name + ":endSignal", endSignal);
351    endSignal->setOrdering(Acquire);
352    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
353        LoadInst * endSignal_next = iBuilder->CreateAlignedLoad(endSignalPtrs[i], 8);
354        endSignal_next->setOrdering(Acquire);
355        iBuilder->CreateAnd(endSignal, endSignal_next);
356    }
357       
358    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(endSignal, ConstantInt::get(iBuilder->getInt8Ty(), 1)), endBlock, inputCheckBlock);
359   
360    iBuilder->SetInsertPoint(doSegmentBlock);
361 
362    createDoSegmentCall(self, segBlocks);
363
364    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
365        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
366        iBuilder->CreateAlignedStore(consumerPos, inbufConsumerPtrs[i], 8)->setOrdering(Release);
367    }
368    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
369        Value * producerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(outbufProducerPtrs[i]), segSize);
370        iBuilder->CreateAlignedStore(producerPos, outbufProducerPtrs[i], 8)->setOrdering(Release);
371    }
372   
373    iBuilder->CreateBr(outputCheckBlock);
374     
375    iBuilder->SetInsertPoint(endBlock);
376    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
377    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
378    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
379    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
380    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
381    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
382
383    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
384
385    iBuilder->SetInsertPoint(doFinalSegBlock);
386
387    createDoSegmentCall(self, blocks);
388
389    iBuilder->CreateBr(doFinalBlock);
390
391    iBuilder->SetInsertPoint(doFinalBlock);
392
393    createFinalBlockCall(self, finalBlockRemainingBytes);
394
395    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
396        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
397        iBuilder->CreateAlignedStore(consumerPos, inbufConsumerPtrs[i], 8)->setOrdering(Release);
398    }
399    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
400        Value * producerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(outbufProducerPtrs[i]), remainingBytes);
401        iBuilder->CreateAlignedStore(producerPos, outbufProducerPtrs[i], 8)->setOrdering(Release);
402    }
403
404    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
405        Value * basePtr = getStreamSetBasePtr(self, mStreamSetOutputs[i].ssName);
406        mStreamSetOutputBuffers[i]->setEndOfInput(basePtr);
407    }
408
409    Value * nullVal = Constant::getNullValue(voidPtrTy);
410    Function * pthreadExitFunc = m->getFunction("pthread_exit");
411    CallInst * exitThread = iBuilder->CreateCall(pthreadExitFunc, {nullVal}); 
412    exitThread->setDoesNotReturn();
413    iBuilder->CreateRetVoid();
414
415    return threadFunc;
416
417}
418
419
Note: See TracBrowser for help on using the repository browser.