source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5251

Last change on this file since 5251 was 5250, checked in by cameron, 3 years ago

Allow for override of kernel init method.

File size: 28.5 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
11#include <llvm/Support/ErrorHandling.h>
12#include <toolchain.h>
13
14using namespace llvm;
15using namespace kernel;
16
17KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
18                             std::string kernelName,
19                             std::vector<Binding> stream_inputs,
20                             std::vector<Binding> stream_outputs,
21                             std::vector<Binding> scalar_parameters,
22                             std::vector<Binding> scalar_outputs,
23                             std::vector<Binding> internal_scalars)
24: KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {
25
26}
27
28unsigned KernelBuilder::addScalar(Type * type, const std::string & name) {
29    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
30        llvm::report_fatal_error("Cannot add kernel field " + name + " after kernel state finalized");
31    }
32    const auto index = mKernelFields.size();
33    mKernelMap.emplace(name, index);
34    mKernelFields.push_back(type);
35    return index;
36}
37
38void KernelBuilder::prepareKernel() {
39    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
40        llvm::report_fatal_error("Cannot prepare kernel after kernel state finalized");
41    }
42    unsigned blockSize = iBuilder->getBitBlockWidth();
43    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
44        std::string tmp;
45        raw_string_ostream out(tmp);
46        out << "kernel contains " << mStreamSetInputBuffers.size() << " input buffers for "
47            << mStreamSetInputs.size() << " input stream sets.";
48        throw std::runtime_error(out.str());
49    }
50    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
51        std::string tmp;
52        raw_string_ostream out(tmp);
53        out << "kernel contains " << mStreamSetOutputBuffers.size() << " output buffers for "
54            << mStreamSetOutputs.size() << " output stream sets.";
55        throw std::runtime_error(out.str());
56    }
57    int streamSetNo = 0;
58    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
59        if ((mStreamSetInputBuffers[i]->getBufferSize() > 0) && (mStreamSetInputBuffers[i]->getBufferSize() < codegen::SegmentSize + (blockSize + mLookAheadPositions - 1)/blockSize)) {
60             llvm::report_fatal_error("Kernel preparation: Buffer size too small " + mStreamSetInputs[i].name);
61        }
62        mScalarInputs.push_back(Binding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].name + structPtrSuffix});
63        mStreamSetNameMap.emplace(mStreamSetInputs[i].name, streamSetNo);
64        addScalar(iBuilder->getSizeTy(), mStreamSetInputs[i].name + processedItemCountSuffix);
65        streamSetNo++;
66    }
67    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
68        mScalarInputs.push_back(Binding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].name + structPtrSuffix});
69        mStreamSetNameMap.emplace(mStreamSetOutputs[i].name, streamSetNo);
70        addScalar(iBuilder->getSizeTy(), mStreamSetOutputs[i].name + producedItemCountSuffix);
71        streamSetNo++;
72    }
73    for (auto binding : mScalarInputs) {
74        addScalar(binding.type, binding.name);
75    }
76    for (auto binding : mScalarOutputs) {
77        addScalar(binding.type, binding.name);
78    }
79    for (auto binding : mInternalScalars) {
80        addScalar(binding.type, binding.name);
81    }
82    addScalar(iBuilder->getSizeTy(), blockNoScalar);
83    addScalar(iBuilder->getSizeTy(), logicalSegmentNoScalar);
84    addScalar(iBuilder->getInt1Ty(), terminationSignal);
85    mKernelStateType = StructType::create(iBuilder->getContext(), mKernelFields, mKernelName);
86}
87
88std::unique_ptr<Module> KernelBuilder::createKernelModule(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
89    auto saveModule = iBuilder->getModule();
90    auto savePoint = iBuilder->saveIP();
91    auto module = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), iBuilder->getContext());
92    iBuilder->setModule(module.get());
93    generateKernel(inputs, outputs);
94    iBuilder->setModule(saveModule);
95    iBuilder->restoreIP(savePoint);
96    return module;
97}
98
99void KernelBuilder::generateKernel(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
100    auto savePoint = iBuilder->saveIP();
101    Module * const m = iBuilder->getModule();
102    mStreamSetInputBuffers.assign(inputs.begin(), inputs.end());
103    mStreamSetOutputBuffers.assign(outputs.begin(), outputs.end());
104    prepareKernel();            // possibly overridden by the KernelBuilder subtype
105    addKernelDeclarations(m);
106    generateInitMethod();       // possibly overridden by the KernelBuilder subtype
107    generateDoBlockMethod();    // must be implemented by the KernelBuilder subtype
108    generateFinalBlockMethod(); // possibly overridden by the KernelBuilder subtype
109    generateDoSegmentMethod();
110
111    // Implement the accumulator get functions
112    for (auto binding : mScalarOutputs) {
113        auto fnName = mKernelName + accumulator_infix + binding.name;
114        Function * accumFn = m->getFunction(fnName);
115        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.name, accumFn, 0));
116        Value * self = &*(accumFn->arg_begin());
117        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
118        Value * retVal = iBuilder->CreateLoad(ptr);
119        iBuilder->CreateRet(retVal);
120    }
121    generateInitMethod();
122    iBuilder->restoreIP(savePoint);
123}
124
125// Default init method, possibly overridden if special init actions required.
126void KernelBuilder::generateInitMethod() const {
127    auto savePoint = iBuilder->saveIP();
128    Module * const m = iBuilder->getModule();
129    Function * initFunction = m->getFunction(mKernelName + init_suffix);
130    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));   
131    Function::arg_iterator args = initFunction->arg_begin();
132    Value * self = &*(args++);
133    iBuilder->CreateStore(ConstantAggregateZero::get(mKernelStateType), self);
134    for (auto binding : mScalarInputs) {
135        Value * param = &*(args++);
136        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
137        iBuilder->CreateStore(param, ptr);
138    }
139    iBuilder->CreateRetVoid();
140    iBuilder->restoreIP(savePoint);
141}
142
143//  The default finalBlock method simply dispatches to the doBlock routine.
144void KernelBuilder::generateFinalBlockMethod() const {
145    auto savePoint = iBuilder->saveIP();
146    Module * m = iBuilder->getModule();
147    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
148    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
149    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
150    // Final Block arguments: self, remaining, then the standard DoBlock args.
151    Function::arg_iterator args = finalBlockFunction->arg_begin();
152    Value * self = &*(args++);
153    /* Skip "remaining" arg */ args++;
154    std::vector<Value *> doBlockArgs = {self};
155    while (args != finalBlockFunction->arg_end()){
156        doBlockArgs.push_back(&*args++);
157    }
158    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
159    iBuilder->CreateRetVoid();
160    iBuilder->restoreIP(savePoint);
161}
162
163// Note: this may be overridden to incorporate doBlock logic directly into
164// the doSegment function.
165void KernelBuilder::generateDoBlockLogic(Value * self, Value * /* blockNo */) const {
166    Function * doBlockFunction = iBuilder->getModule()->getFunction(mKernelName + doBlock_suffix);
167    iBuilder->CreateCall(doBlockFunction, self);
168}
169
170//  The default doSegment method dispatches to the doBlock routine for
171//  each block of the given number of blocksToDo, and then updates counts.
172void KernelBuilder::generateDoSegmentMethod() const {
173    auto savePoint = iBuilder->saveIP();
174    Module * m = iBuilder->getModule();
175    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
176    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
177    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
178    BasicBlock * strideLoopCond = BasicBlock::Create(iBuilder->getContext(), "strideLoopCond", doSegmentFunction, 0);
179    BasicBlock * strideLoopBody = BasicBlock::Create(iBuilder->getContext(), "strideLoopBody", doSegmentFunction, 0);
180    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", doSegmentFunction, 0);
181    BasicBlock * checkFinalStride = BasicBlock::Create(iBuilder->getContext(), "checkFinalStride", doSegmentFunction, 0);
182    BasicBlock * checkEndSignals = BasicBlock::Create(iBuilder->getContext(), "checkEndSignals", doSegmentFunction, 0);
183    BasicBlock * callFinalBlock = BasicBlock::Create(iBuilder->getContext(), "callFinalBlock", doSegmentFunction, 0);
184    BasicBlock * segmentDone = BasicBlock::Create(iBuilder->getContext(), "segmentDone", doSegmentFunction, 0);
185    BasicBlock * finalExit = BasicBlock::Create(iBuilder->getContext(), "finalExit", doSegmentFunction, 0);
186    Type * const size_ty = iBuilder->getSizeTy();
187    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
188    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
189   
190    Function::arg_iterator args = doSegmentFunction->arg_begin();
191    Value * self = &*(args++);
192    Value * blocksToDo = &*(args);
193   
194    std::vector<Value *> inbufProducerPtrs;
195    std::vector<Value *> endSignalPtrs;
196    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
197        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
198        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
199        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
200    }
201   
202    std::vector<Value *> producerPos;
203    /* Determine the actually available data examining all input stream sets. */
204    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
205    producerPos.push_back(p);
206    Value * availablePos = producerPos[0];
207    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
208        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
209        producerPos.push_back(p);
210        /* Set the available position to be the minimum of availablePos and producerPos. */
211        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
212    }
213    Value * processed = getProcessedItemCount(self, mStreamSetInputs[0].name);
214    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
215//#ifndef NDEBUG
216//    iBuilder->CallPrintInt(mKernelName + "_itemsAvail", itemsAvail);
217//#endif
218    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
219    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
220    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
221    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
222    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
223    //iBuilder->CallPrintInt(mKernelName + "_stridesAvail", stridesAvail);
224    iBuilder->CreateBr(strideLoopCond);
225
226    iBuilder->SetInsertPoint(strideLoopCond);
227    PHINode * stridesRemaining = iBuilder->CreatePHI(size_ty, 2, "stridesRemaining");
228    stridesRemaining->addIncoming(stridesToDo, entryBlock);
229    Value * notDone = iBuilder->CreateICmpUGT(stridesRemaining, ConstantInt::get(size_ty, 0));
230    iBuilder->CreateCondBr(notDone, strideLoopBody, stridesDone);
231
232    iBuilder->SetInsertPoint(strideLoopBody);
233    Value * blockNo = getScalarField(self, blockNoScalar);   
234
235    generateDoBlockLogic(self, blockNo);
236    setBlockNo(self, iBuilder->CreateAdd(blockNo, strideBlocks));
237    stridesRemaining->addIncoming(iBuilder->CreateSub(stridesRemaining, ConstantInt::get(size_ty, 1)), strideLoopBody);
238    iBuilder->CreateBr(strideLoopCond);
239   
240    iBuilder->SetInsertPoint(stridesDone);
241    processed = iBuilder->CreateAdd(processed, iBuilder->CreateMul(stridesToDo, stride));
242    setProcessedItemCount(self, mStreamSetInputs[0].name, processed);
243    iBuilder->CreateCondBr(lessThanFullSegment, checkFinalStride, segmentDone);
244   
245    iBuilder->SetInsertPoint(checkFinalStride);
246   
247    /* We had less than a full segment of data; we may have reached the end of input
248       on one of the stream sets.  */
249   
250    Value * alreadyDone = getTerminationSignal(self);
251    iBuilder->CreateCondBr(alreadyDone, finalExit, checkEndSignals);
252   
253    iBuilder->SetInsertPoint(checkEndSignals);
254    Value * endOfInput = iBuilder->CreateLoad(endSignalPtrs[0]);
255    if (endSignalPtrs.size() > 1) {
256        /* If there is more than one input stream set, then we need to confirm that one of
257           them has both the endSignal set and the length = to availablePos. */
258        endOfInput = iBuilder->CreateAnd(endOfInput, iBuilder->CreateICmpEQ(availablePos, producerPos[0]));
259        for (unsigned i = 1; i < endSignalPtrs.size(); i++) {
260            Value * e = iBuilder->CreateAnd(iBuilder->CreateLoad(endSignalPtrs[i]), iBuilder->CreateICmpEQ(availablePos, producerPos[i]));
261            endOfInput = iBuilder->CreateOr(endOfInput, e);
262        }
263    }
264    iBuilder->CreateCondBr(endOfInput, callFinalBlock, segmentDone);
265   
266    iBuilder->SetInsertPoint(callFinalBlock);
267   
268    Value * remainingItems = iBuilder->CreateSub(availablePos, processed);
269    createFinalBlockCall(self, remainingItems);
270    setProcessedItemCount(self, mStreamSetInputs[0].name, availablePos);
271   
272    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
273        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
274        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
275    }
276    setTerminationSignal(self);
277    iBuilder->CreateBr(segmentDone);
278   
279    iBuilder->SetInsertPoint(segmentDone);
280//#ifndef NDEBUG
281//    iBuilder->CallPrintInt(mKernelName + "_produced", produced);
282//#endif
283    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
284        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
285        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
286        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
287        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
288    }
289    iBuilder->CreateBr(finalExit);
290    iBuilder->SetInsertPoint(finalExit);
291
292    iBuilder->CreateRetVoid();
293    iBuilder->restoreIP(savePoint);
294}
295
296ConstantInt * KernelBuilder::getScalarIndex(const std::string & name) const {
297    const auto f = mKernelMap.find(name);
298    if (LLVM_UNLIKELY(f == mKernelMap.end())) {
299        llvm::report_fatal_error("Kernel does not contain scalar: " + name);
300    }
301    return iBuilder->getInt32(f->second);
302}
303
304unsigned KernelBuilder::getScalarCount() const {
305    return mKernelFields.size();
306}
307
308Value * KernelBuilder::getScalarFieldPtr(Value * self, const std::string & fieldName) const {
309    return iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
310}
311
312Value * KernelBuilder::getScalarField(Value * self, const std::string & fieldName) const {
313    return iBuilder->CreateLoad(getScalarFieldPtr(self, fieldName));
314}
315
316void KernelBuilder::setScalarField(Value * self, const std::string & fieldName, Value * newFieldVal) const {
317    iBuilder->CreateStore(newFieldVal, getScalarFieldPtr(self, fieldName));
318}
319
320LoadInst * KernelBuilder::acquireLogicalSegmentNo(Value * self) const {
321    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
322    return iBuilder->CreateAtomicLoadAcquire(ptr);
323}
324
325Value * KernelBuilder::getProcessedItemCount(Value * self, const std::string & ssName) const {
326    return getScalarField(self, ssName + processedItemCountSuffix);
327}
328
329Value * KernelBuilder::getProducedItemCount(Value * self, const std::string & ssName) const {
330    return getScalarField(self, ssName + producedItemCountSuffix);
331}
332
333Value * KernelBuilder::getTerminationSignal(Value * self) const {
334    return getScalarField(self, terminationSignal);
335}
336
337void KernelBuilder::releaseLogicalSegmentNo(Value * self, Value * newCount) const {
338    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
339    iBuilder->CreateAtomicStoreRelease(newCount, ptr);
340}
341
342void KernelBuilder::setProcessedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
343    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + processedItemCountSuffix)});
344    iBuilder->CreateStore(newCount, ptr);
345}
346
347void KernelBuilder::setProducedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
348    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + producedItemCountSuffix)});
349    iBuilder->CreateStore(newCount, ptr);
350}
351
352void KernelBuilder::setTerminationSignal(Value * self) const {
353    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(terminationSignal)});
354    iBuilder->CreateStore(ConstantInt::get(iBuilder->getInt1Ty(), 1), ptr);
355}
356
357Value * KernelBuilder::getBlockNo(Value * self) const {
358    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
359    return iBuilder->CreateLoad(ptr);
360}
361
362void KernelBuilder::setBlockNo(Value * self, Value * newFieldVal) const {
363    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
364    iBuilder->CreateStore(newFieldVal, ptr);
365}
366
367
368Value * KernelBuilder::getParameter(Function * f, const std::string & paramName) const {
369    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
370        Value * arg = &*argIter;
371        if (arg->getName() == paramName) return arg;
372    }
373    llvm::report_fatal_error("Method does not have parameter: " + paramName);
374}
375
376unsigned KernelBuilder::getStreamSetIndex(const std::string & name) const {
377    const auto f = mStreamSetNameMap.find(name);
378    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
379        llvm::report_fatal_error("Kernel does not contain stream set: " + name);
380    }
381    return f->second;
382}
383
384size_t KernelBuilder::getStreamSetBufferSize(Value * /* self */, const std::string & name) const {
385    const unsigned index = getStreamSetIndex(name);
386    StreamSetBuffer * buf = nullptr;
387    if (index < mStreamSetInputs.size()) {
388        buf = mStreamSetInputBuffers[index];
389    } else {
390        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
391    }
392    return buf->getBufferSize();
393}
394
395Value * KernelBuilder::getStreamSetStructPtr(Value * self, const std::string & name) const {
396    return getScalarField(self, name + structPtrSuffix);
397}
398
399Value * KernelBuilder::getStreamSetBlockPtr(Value * self, const std::string &name, Value * blockNo) const {
400    Value * const structPtr = getStreamSetStructPtr(self, name);
401    const unsigned index = getStreamSetIndex(name);
402    StreamSetBuffer * buf = nullptr;
403    if (index < mStreamSetInputs.size()) {
404        buf = mStreamSetInputBuffers[index];
405    } else {
406        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
407    }   
408    return buf->getStreamSetBlockPointer(structPtr, blockNo);
409}
410
411Value * KernelBuilder::getStream(Value * self, const std::string & name, Value * blockNo, Value * index) {
412    return iBuilder->CreateGEP(getStreamSetBlockPtr(self, name, blockNo), {iBuilder->getInt32(0), index});
413}
414
415void KernelBuilder::createInstance() {
416    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
417        llvm::report_fatal_error("Cannot create kernel instance before calling prepareKernel()");
418    }
419    mKernelInstance = iBuilder->CreateCacheAlignedAlloca(mKernelStateType);
420    Module * m = iBuilder->getModule();
421    std::vector<Value *> init_args = {mKernelInstance};
422    for (auto a : mInitialArguments) {
423        init_args.push_back(a);
424    }
425    for (auto b : mStreamSetInputBuffers) {
426        init_args.push_back(b->getStreamSetStructPtr());
427    }
428    for (auto b : mStreamSetOutputBuffers) {
429        init_args.push_back(b->getStreamSetStructPtr());
430    }
431    std::string initFnName = mKernelName + init_suffix;
432    Function * initMethod = m->getFunction(initFnName);
433    if (!initMethod) {
434        llvm::report_fatal_error("Cannot find " + initFnName);
435    }
436    iBuilder->CreateCall(initMethod, init_args);
437}
438
439Function * KernelBuilder::generateThreadFunction(const std::string & name) const {
440    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
441        llvm::report_fatal_error("Cannot generate thread function before calling prepareKernel()");
442    }
443    Module * m = iBuilder->getModule();
444    Type * const voidTy = iBuilder->getVoidTy();
445    Type * const voidPtrTy = iBuilder->getVoidPtrTy();
446    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
447    Type * const int1ty = iBuilder->getInt1Ty();
448
449    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
450    threadFunc->setCallingConv(CallingConv::C);
451    Function::arg_iterator args = threadFunc->arg_begin();
452
453    Value * const arg = &*(args++);
454    arg->setName("args");
455
456    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
457
458    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
459
460    std::vector<Value *> inbufProducerPtrs;
461    std::vector<Value *> inbufConsumerPtrs;
462    std::vector<Value *> outbufProducerPtrs;
463    std::vector<Value *> outbufConsumerPtrs;   
464    std::vector<Value *> endSignalPtrs;
465
466    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
467        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
468        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
469        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getConsumerPosPtr(ssStructPtr));
470        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(ssStructPtr));
471    }
472    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
473        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
474        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr));
475        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getConsumerPosPtr(ssStructPtr));
476    }
477
478    const unsigned segmentBlocks = codegen::SegmentSize;
479    const unsigned bufferSegments = codegen::BufferSegments;
480    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
481    Type * const size_ty = iBuilder->getSizeTy();
482
483    Value * segSize = ConstantInt::get(size_ty, segmentSize);
484    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
485    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
486   
487    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
488    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
489   
490    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
491    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
492    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
493    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
494    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
495
496    iBuilder->CreateBr(outputCheckBlock);
497
498    iBuilder->SetInsertPoint(outputCheckBlock);
499
500    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
501    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
502        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(outbufProducerPtrs[i]);
503        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
504        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(outbufConsumerPtrs[i]);
505        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
506        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
507    }
508   
509    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
510
511    iBuilder->SetInsertPoint(inputCheckBlock); 
512
513    Value * requiredSize = segSize;
514    if (mLookAheadPositions > 0) {
515        requiredSize = iBuilder->CreateAdd(segSize, ConstantInt::get(size_ty, mLookAheadPositions));
516    }
517    waitCondTest = ConstantInt::get(int1ty, 1); 
518    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
519        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
520        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
521        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(inbufConsumerPtrs[i]);
522        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
523        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, requiredSize), producerPos));
524    }
525
526    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
527   
528    iBuilder->SetInsertPoint(endSignalCheckBlock);
529   
530    LoadInst * endSignal = iBuilder->CreateLoad(endSignalPtrs[0]);
531    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
532        LoadInst * endSignal_next = iBuilder->CreateLoad(endSignalPtrs[i]);
533        iBuilder->CreateAnd(endSignal, endSignal_next);
534    }
535       
536    iBuilder->CreateCondBr(endSignal, endBlock, inputCheckBlock);
537   
538    iBuilder->SetInsertPoint(doSegmentBlock);
539 
540    createDoSegmentCall(self, segBlocks);
541
542    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
543        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
544        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
545    }
546   
547    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
548        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
549        iBuilder->CreateAtomicStoreRelease(produced, outbufProducerPtrs[i]);
550    }
551   
552    Value * earlyEndSignal = getTerminationSignal(self);
553    if (earlyEndSignal != ConstantInt::getNullValue(iBuilder->getInt1Ty())) {
554        BasicBlock * earlyEndBlock = BasicBlock::Create(iBuilder->getContext(), "earlyEndSignal", threadFunc, 0);
555        iBuilder->CreateCondBr(earlyEndSignal, earlyEndBlock, outputCheckBlock);
556
557        iBuilder->SetInsertPoint(earlyEndBlock);
558        for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
559            Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
560            mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
561        }       
562    }
563    iBuilder->CreateBr(outputCheckBlock);
564     
565    iBuilder->SetInsertPoint(endBlock);
566    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
567    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
568    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
569    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
570    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
571    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
572
573    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
574
575    iBuilder->SetInsertPoint(doFinalSegBlock);
576
577    createDoSegmentCall(self, blocks);
578
579    iBuilder->CreateBr(doFinalBlock);
580
581    iBuilder->SetInsertPoint(doFinalBlock);
582
583    createFinalBlockCall(self, finalBlockRemainingBytes);
584
585    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
586        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
587        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
588    }
589    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
590        iBuilder->CreateAtomicStoreRelease(producerPos, outbufProducerPtrs[i]);
591    }
592
593    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
594        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
595        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
596    }
597
598    iBuilder->CreatePThreadExitCall(Constant::getNullValue(voidPtrTy));
599    iBuilder->CreateRetVoid();
600
601    return threadFunc;
602
603}
604
605KernelBuilder::~KernelBuilder() {
606}
Note: See TracBrowser for help on using the repository browser.