source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5255

Last change on this file since 5255 was 5252, checked in by cameron, 3 years ago

Separate doSegment/final segment processing in pipeline loop; check optional NoTerminateAttribute?

File size: 28.5 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
11#include <llvm/Support/ErrorHandling.h>
12#include <toolchain.h>
13
14using namespace llvm;
15using namespace kernel;
16
17KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
18                             std::string kernelName,
19                             std::vector<Binding> stream_inputs,
20                             std::vector<Binding> stream_outputs,
21                             std::vector<Binding> scalar_parameters,
22                             std::vector<Binding> scalar_outputs,
23                             std::vector<Binding> internal_scalars)
24: KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars),
25mNoTerminateAttribute(false) {
26
27}
28
29unsigned KernelBuilder::addScalar(Type * type, const std::string & name) {
30    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
31        llvm::report_fatal_error("Cannot add kernel field " + name + " after kernel state finalized");
32    }
33    const auto index = mKernelFields.size();
34    mKernelMap.emplace(name, index);
35    mKernelFields.push_back(type);
36    return index;
37}
38
39void KernelBuilder::prepareKernel() {
40    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
41        llvm::report_fatal_error("Cannot prepare kernel after kernel state finalized");
42    }
43    unsigned blockSize = iBuilder->getBitBlockWidth();
44    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
45        std::string tmp;
46        raw_string_ostream out(tmp);
47        out << "kernel contains " << mStreamSetInputBuffers.size() << " input buffers for "
48            << mStreamSetInputs.size() << " input stream sets.";
49        throw std::runtime_error(out.str());
50    }
51    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
52        std::string tmp;
53        raw_string_ostream out(tmp);
54        out << "kernel contains " << mStreamSetOutputBuffers.size() << " output buffers for "
55            << mStreamSetOutputs.size() << " output stream sets.";
56        throw std::runtime_error(out.str());
57    }
58    int streamSetNo = 0;
59    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
60        if ((mStreamSetInputBuffers[i]->getBufferSize() > 0) && (mStreamSetInputBuffers[i]->getBufferSize() < codegen::SegmentSize + (blockSize + mLookAheadPositions - 1)/blockSize)) {
61             llvm::report_fatal_error("Kernel preparation: Buffer size too small " + mStreamSetInputs[i].name);
62        }
63        mScalarInputs.push_back(Binding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].name + structPtrSuffix});
64        mStreamSetNameMap.emplace(mStreamSetInputs[i].name, streamSetNo);
65        addScalar(iBuilder->getSizeTy(), mStreamSetInputs[i].name + processedItemCountSuffix);
66        streamSetNo++;
67    }
68    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
69        mScalarInputs.push_back(Binding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].name + structPtrSuffix});
70        mStreamSetNameMap.emplace(mStreamSetOutputs[i].name, streamSetNo);
71        addScalar(iBuilder->getSizeTy(), mStreamSetOutputs[i].name + producedItemCountSuffix);
72        streamSetNo++;
73    }
74    for (auto binding : mScalarInputs) {
75        addScalar(binding.type, binding.name);
76    }
77    for (auto binding : mScalarOutputs) {
78        addScalar(binding.type, binding.name);
79    }
80    for (auto binding : mInternalScalars) {
81        addScalar(binding.type, binding.name);
82    }
83    addScalar(iBuilder->getSizeTy(), blockNoScalar);
84    addScalar(iBuilder->getSizeTy(), logicalSegmentNoScalar);
85    addScalar(iBuilder->getInt1Ty(), terminationSignal);
86    mKernelStateType = StructType::create(iBuilder->getContext(), mKernelFields, mKernelName);
87}
88
89std::unique_ptr<Module> KernelBuilder::createKernelModule(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
90    auto saveModule = iBuilder->getModule();
91    auto savePoint = iBuilder->saveIP();
92    auto module = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), iBuilder->getContext());
93    iBuilder->setModule(module.get());
94    generateKernel(inputs, outputs);
95    iBuilder->setModule(saveModule);
96    iBuilder->restoreIP(savePoint);
97    return module;
98}
99
100void KernelBuilder::generateKernel(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
101    auto savePoint = iBuilder->saveIP();
102    Module * const m = iBuilder->getModule();
103    mStreamSetInputBuffers.assign(inputs.begin(), inputs.end());
104    mStreamSetOutputBuffers.assign(outputs.begin(), outputs.end());
105    prepareKernel();            // possibly overridden by the KernelBuilder subtype
106    addKernelDeclarations(m);
107    generateInitMethod();       // possibly overridden by the KernelBuilder subtype
108    generateDoBlockMethod();    // must be implemented by the KernelBuilder subtype
109    generateFinalBlockMethod(); // possibly overridden by the KernelBuilder subtype
110    generateDoSegmentMethod();
111
112    // Implement the accumulator get functions
113    for (auto binding : mScalarOutputs) {
114        auto fnName = mKernelName + accumulator_infix + binding.name;
115        Function * accumFn = m->getFunction(fnName);
116        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.name, accumFn, 0));
117        Value * self = &*(accumFn->arg_begin());
118        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
119        Value * retVal = iBuilder->CreateLoad(ptr);
120        iBuilder->CreateRet(retVal);
121    }
122    generateInitMethod();
123    iBuilder->restoreIP(savePoint);
124}
125
126// Default init method, possibly overridden if special init actions required.
127void KernelBuilder::generateInitMethod() const {
128    auto savePoint = iBuilder->saveIP();
129    Module * const m = iBuilder->getModule();
130    Function * initFunction = m->getFunction(mKernelName + init_suffix);
131    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));   
132    Function::arg_iterator args = initFunction->arg_begin();
133    Value * self = &*(args++);
134    iBuilder->CreateStore(ConstantAggregateZero::get(mKernelStateType), self);
135    for (auto binding : mScalarInputs) {
136        Value * param = &*(args++);
137        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
138        iBuilder->CreateStore(param, ptr);
139    }
140    iBuilder->CreateRetVoid();
141    iBuilder->restoreIP(savePoint);
142}
143
144//  The default finalBlock method simply dispatches to the doBlock routine.
145void KernelBuilder::generateFinalBlockMethod() const {
146    auto savePoint = iBuilder->saveIP();
147    Module * m = iBuilder->getModule();
148    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
149    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
150    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
151    // Final Block arguments: self, remaining, then the standard DoBlock args.
152    Function::arg_iterator args = finalBlockFunction->arg_begin();
153    Value * self = &*(args++);
154    /* Skip "remaining" arg */ args++;
155    std::vector<Value *> doBlockArgs = {self};
156    while (args != finalBlockFunction->arg_end()){
157        doBlockArgs.push_back(&*args++);
158    }
159    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
160    iBuilder->CreateRetVoid();
161    iBuilder->restoreIP(savePoint);
162}
163
164// Note: this may be overridden to incorporate doBlock logic directly into
165// the doSegment function.
166void KernelBuilder::generateDoBlockLogic(Value * self, Value * /* blockNo */) const {
167    Function * doBlockFunction = iBuilder->getModule()->getFunction(mKernelName + doBlock_suffix);
168    iBuilder->CreateCall(doBlockFunction, self);
169}
170
171//  The default doSegment method dispatches to the doBlock routine for
172//  each block of the given number of blocksToDo, and then updates counts.
173void KernelBuilder::generateDoSegmentMethod() const {
174    auto savePoint = iBuilder->saveIP();
175    Module * m = iBuilder->getModule();
176    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
177    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
178    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
179    BasicBlock * strideLoopCond = BasicBlock::Create(iBuilder->getContext(), "strideLoopCond", doSegmentFunction, 0);
180    BasicBlock * strideLoopBody = BasicBlock::Create(iBuilder->getContext(), "strideLoopBody", doSegmentFunction, 0);
181    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", doSegmentFunction, 0);
182    BasicBlock * checkFinalStride = BasicBlock::Create(iBuilder->getContext(), "checkFinalStride", doSegmentFunction, 0);
183    BasicBlock * checkEndSignals = BasicBlock::Create(iBuilder->getContext(), "checkEndSignals", doSegmentFunction, 0);
184    BasicBlock * callFinalBlock = BasicBlock::Create(iBuilder->getContext(), "callFinalBlock", doSegmentFunction, 0);
185    BasicBlock * segmentDone = BasicBlock::Create(iBuilder->getContext(), "segmentDone", doSegmentFunction, 0);
186    BasicBlock * finalExit = BasicBlock::Create(iBuilder->getContext(), "finalExit", doSegmentFunction, 0);
187    Type * const size_ty = iBuilder->getSizeTy();
188    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
189    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
190   
191    Function::arg_iterator args = doSegmentFunction->arg_begin();
192    Value * self = &*(args++);
193    Value * blocksToDo = &*(args);
194   
195    std::vector<Value *> inbufProducerPtrs;
196    std::vector<Value *> endSignalPtrs;
197    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
198        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
199        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
200        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
201    }
202   
203    std::vector<Value *> producerPos;
204    /* Determine the actually available data examining all input stream sets. */
205    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
206    producerPos.push_back(p);
207    Value * availablePos = producerPos[0];
208    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
209        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
210        producerPos.push_back(p);
211        /* Set the available position to be the minimum of availablePos and producerPos. */
212        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
213    }
214    Value * processed = getProcessedItemCount(self, mStreamSetInputs[0].name);
215    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
216//#ifndef NDEBUG
217//    iBuilder->CallPrintInt(mKernelName + "_itemsAvail", itemsAvail);
218//#endif
219    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
220    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
221    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
222    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
223    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
224    //iBuilder->CallPrintInt(mKernelName + "_stridesAvail", stridesAvail);
225    iBuilder->CreateBr(strideLoopCond);
226
227    iBuilder->SetInsertPoint(strideLoopCond);
228    PHINode * stridesRemaining = iBuilder->CreatePHI(size_ty, 2, "stridesRemaining");
229    stridesRemaining->addIncoming(stridesToDo, entryBlock);
230    Value * notDone = iBuilder->CreateICmpUGT(stridesRemaining, ConstantInt::get(size_ty, 0));
231    iBuilder->CreateCondBr(notDone, strideLoopBody, stridesDone);
232
233    iBuilder->SetInsertPoint(strideLoopBody);
234    Value * blockNo = getScalarField(self, blockNoScalar);   
235
236    generateDoBlockLogic(self, blockNo);
237    setBlockNo(self, iBuilder->CreateAdd(blockNo, strideBlocks));
238    stridesRemaining->addIncoming(iBuilder->CreateSub(stridesRemaining, ConstantInt::get(size_ty, 1)), strideLoopBody);
239    iBuilder->CreateBr(strideLoopCond);
240   
241    iBuilder->SetInsertPoint(stridesDone);
242    processed = iBuilder->CreateAdd(processed, iBuilder->CreateMul(stridesToDo, stride));
243    setProcessedItemCount(self, mStreamSetInputs[0].name, processed);
244    iBuilder->CreateCondBr(lessThanFullSegment, checkFinalStride, segmentDone);
245   
246    iBuilder->SetInsertPoint(checkFinalStride);
247   
248    /* We had less than a full segment of data; we may have reached the end of input
249       on one of the stream sets.  */
250   
251    Value * alreadyDone = getTerminationSignal(self);
252    iBuilder->CreateCondBr(alreadyDone, finalExit, checkEndSignals);
253   
254    iBuilder->SetInsertPoint(checkEndSignals);
255    Value * endOfInput = iBuilder->CreateLoad(endSignalPtrs[0]);
256    if (endSignalPtrs.size() > 1) {
257        /* If there is more than one input stream set, then we need to confirm that one of
258           them has both the endSignal set and the length = to availablePos. */
259        endOfInput = iBuilder->CreateAnd(endOfInput, iBuilder->CreateICmpEQ(availablePos, producerPos[0]));
260        for (unsigned i = 1; i < endSignalPtrs.size(); i++) {
261            Value * e = iBuilder->CreateAnd(iBuilder->CreateLoad(endSignalPtrs[i]), iBuilder->CreateICmpEQ(availablePos, producerPos[i]));
262            endOfInput = iBuilder->CreateOr(endOfInput, e);
263        }
264    }
265    iBuilder->CreateCondBr(endOfInput, callFinalBlock, segmentDone);
266   
267    iBuilder->SetInsertPoint(callFinalBlock);
268   
269    Value * remainingItems = iBuilder->CreateSub(availablePos, processed);
270    createFinalBlockCall(self, remainingItems);
271    setProcessedItemCount(self, mStreamSetInputs[0].name, availablePos);
272   
273    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
274        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
275        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
276    }
277    setTerminationSignal(self);
278    iBuilder->CreateBr(segmentDone);
279   
280    iBuilder->SetInsertPoint(segmentDone);
281//#ifndef NDEBUG
282//    iBuilder->CallPrintInt(mKernelName + "_produced", produced);
283//#endif
284    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
285        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
286        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
287        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
288        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
289    }
290    iBuilder->CreateBr(finalExit);
291    iBuilder->SetInsertPoint(finalExit);
292
293    iBuilder->CreateRetVoid();
294    iBuilder->restoreIP(savePoint);
295}
296
297ConstantInt * KernelBuilder::getScalarIndex(const std::string & name) const {
298    const auto f = mKernelMap.find(name);
299    if (LLVM_UNLIKELY(f == mKernelMap.end())) {
300        llvm::report_fatal_error("Kernel does not contain scalar: " + name);
301    }
302    return iBuilder->getInt32(f->second);
303}
304
305unsigned KernelBuilder::getScalarCount() const {
306    return mKernelFields.size();
307}
308
309Value * KernelBuilder::getScalarFieldPtr(Value * self, const std::string & fieldName) const {
310    return iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
311}
312
313Value * KernelBuilder::getScalarField(Value * self, const std::string & fieldName) const {
314    return iBuilder->CreateLoad(getScalarFieldPtr(self, fieldName));
315}
316
317void KernelBuilder::setScalarField(Value * self, const std::string & fieldName, Value * newFieldVal) const {
318    iBuilder->CreateStore(newFieldVal, getScalarFieldPtr(self, fieldName));
319}
320
321LoadInst * KernelBuilder::acquireLogicalSegmentNo(Value * self) const {
322    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
323    return iBuilder->CreateAtomicLoadAcquire(ptr);
324}
325
326Value * KernelBuilder::getProcessedItemCount(Value * self, const std::string & ssName) const {
327    return getScalarField(self, ssName + processedItemCountSuffix);
328}
329
330Value * KernelBuilder::getProducedItemCount(Value * self, const std::string & ssName) const {
331    return getScalarField(self, ssName + producedItemCountSuffix);
332}
333
334Value * KernelBuilder::getTerminationSignal(Value * self) const {
335    return getScalarField(self, terminationSignal);
336}
337
338void KernelBuilder::releaseLogicalSegmentNo(Value * self, Value * newCount) const {
339    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
340    iBuilder->CreateAtomicStoreRelease(newCount, ptr);
341}
342
343void KernelBuilder::setProcessedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
344    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + processedItemCountSuffix)});
345    iBuilder->CreateStore(newCount, ptr);
346}
347
348void KernelBuilder::setProducedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
349    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + producedItemCountSuffix)});
350    iBuilder->CreateStore(newCount, ptr);
351}
352
353void KernelBuilder::setTerminationSignal(Value * self) const {
354    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(terminationSignal)});
355    iBuilder->CreateStore(ConstantInt::get(iBuilder->getInt1Ty(), 1), ptr);
356}
357
358Value * KernelBuilder::getBlockNo(Value * self) const {
359    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
360    return iBuilder->CreateLoad(ptr);
361}
362
363void KernelBuilder::setBlockNo(Value * self, Value * newFieldVal) const {
364    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
365    iBuilder->CreateStore(newFieldVal, ptr);
366}
367
368
369Value * KernelBuilder::getParameter(Function * f, const std::string & paramName) const {
370    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
371        Value * arg = &*argIter;
372        if (arg->getName() == paramName) return arg;
373    }
374    llvm::report_fatal_error("Method does not have parameter: " + paramName);
375}
376
377unsigned KernelBuilder::getStreamSetIndex(const std::string & name) const {
378    const auto f = mStreamSetNameMap.find(name);
379    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
380        llvm::report_fatal_error("Kernel does not contain stream set: " + name);
381    }
382    return f->second;
383}
384
385size_t KernelBuilder::getStreamSetBufferSize(Value * /* self */, const std::string & name) const {
386    const unsigned index = getStreamSetIndex(name);
387    StreamSetBuffer * buf = nullptr;
388    if (index < mStreamSetInputs.size()) {
389        buf = mStreamSetInputBuffers[index];
390    } else {
391        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
392    }
393    return buf->getBufferSize();
394}
395
396Value * KernelBuilder::getStreamSetStructPtr(Value * self, const std::string & name) const {
397    return getScalarField(self, name + structPtrSuffix);
398}
399
400Value * KernelBuilder::getStreamSetBlockPtr(Value * self, const std::string &name, Value * blockNo) const {
401    Value * const structPtr = getStreamSetStructPtr(self, name);
402    const unsigned index = getStreamSetIndex(name);
403    StreamSetBuffer * buf = nullptr;
404    if (index < mStreamSetInputs.size()) {
405        buf = mStreamSetInputBuffers[index];
406    } else {
407        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
408    }   
409    return buf->getStreamSetBlockPointer(structPtr, blockNo);
410}
411
412Value * KernelBuilder::getStream(Value * self, const std::string & name, Value * blockNo, Value * index) {
413    return iBuilder->CreateGEP(getStreamSetBlockPtr(self, name, blockNo), {iBuilder->getInt32(0), index});
414}
415
416void KernelBuilder::createInstance() {
417    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
418        llvm::report_fatal_error("Cannot create kernel instance before calling prepareKernel()");
419    }
420    mKernelInstance = iBuilder->CreateCacheAlignedAlloca(mKernelStateType);
421    Module * m = iBuilder->getModule();
422    std::vector<Value *> init_args = {mKernelInstance};
423    for (auto a : mInitialArguments) {
424        init_args.push_back(a);
425    }
426    for (auto b : mStreamSetInputBuffers) {
427        init_args.push_back(b->getStreamSetStructPtr());
428    }
429    for (auto b : mStreamSetOutputBuffers) {
430        init_args.push_back(b->getStreamSetStructPtr());
431    }
432    std::string initFnName = mKernelName + init_suffix;
433    Function * initMethod = m->getFunction(initFnName);
434    if (!initMethod) {
435        llvm::report_fatal_error("Cannot find " + initFnName);
436    }
437    iBuilder->CreateCall(initMethod, init_args);
438}
439
440Function * KernelBuilder::generateThreadFunction(const std::string & name) const {
441    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
442        llvm::report_fatal_error("Cannot generate thread function before calling prepareKernel()");
443    }
444    Module * m = iBuilder->getModule();
445    Type * const voidTy = iBuilder->getVoidTy();
446    Type * const voidPtrTy = iBuilder->getVoidPtrTy();
447    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
448    Type * const int1ty = iBuilder->getInt1Ty();
449
450    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
451    threadFunc->setCallingConv(CallingConv::C);
452    Function::arg_iterator args = threadFunc->arg_begin();
453
454    Value * const arg = &*(args++);
455    arg->setName("args");
456
457    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
458
459    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
460
461    std::vector<Value *> inbufProducerPtrs;
462    std::vector<Value *> inbufConsumerPtrs;
463    std::vector<Value *> outbufProducerPtrs;
464    std::vector<Value *> outbufConsumerPtrs;   
465    std::vector<Value *> endSignalPtrs;
466
467    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
468        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
469        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
470        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getConsumerPosPtr(ssStructPtr));
471        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(ssStructPtr));
472    }
473    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
474        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
475        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr));
476        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getConsumerPosPtr(ssStructPtr));
477    }
478
479    const unsigned segmentBlocks = codegen::SegmentSize;
480    const unsigned bufferSegments = codegen::BufferSegments;
481    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
482    Type * const size_ty = iBuilder->getSizeTy();
483
484    Value * segSize = ConstantInt::get(size_ty, segmentSize);
485    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
486    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
487   
488    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
489    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
490   
491    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
492    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
493    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
494    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
495    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
496
497    iBuilder->CreateBr(outputCheckBlock);
498
499    iBuilder->SetInsertPoint(outputCheckBlock);
500
501    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
502    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
503        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(outbufProducerPtrs[i]);
504        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
505        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(outbufConsumerPtrs[i]);
506        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
507        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
508    }
509   
510    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
511
512    iBuilder->SetInsertPoint(inputCheckBlock); 
513
514    Value * requiredSize = segSize;
515    if (mLookAheadPositions > 0) {
516        requiredSize = iBuilder->CreateAdd(segSize, ConstantInt::get(size_ty, mLookAheadPositions));
517    }
518    waitCondTest = ConstantInt::get(int1ty, 1); 
519    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
520        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
521        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
522        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(inbufConsumerPtrs[i]);
523        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
524        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, requiredSize), producerPos));
525    }
526
527    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
528   
529    iBuilder->SetInsertPoint(endSignalCheckBlock);
530   
531    LoadInst * endSignal = iBuilder->CreateLoad(endSignalPtrs[0]);
532    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
533        LoadInst * endSignal_next = iBuilder->CreateLoad(endSignalPtrs[i]);
534        iBuilder->CreateAnd(endSignal, endSignal_next);
535    }
536       
537    iBuilder->CreateCondBr(endSignal, endBlock, inputCheckBlock);
538   
539    iBuilder->SetInsertPoint(doSegmentBlock);
540 
541    createDoSegmentCall(self, segBlocks);
542
543    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
544        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
545        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
546    }
547   
548    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
549        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
550        iBuilder->CreateAtomicStoreRelease(produced, outbufProducerPtrs[i]);
551    }
552   
553    Value * earlyEndSignal = getTerminationSignal(self);
554    if (earlyEndSignal != ConstantInt::getNullValue(iBuilder->getInt1Ty())) {
555        BasicBlock * earlyEndBlock = BasicBlock::Create(iBuilder->getContext(), "earlyEndSignal", threadFunc, 0);
556        iBuilder->CreateCondBr(earlyEndSignal, earlyEndBlock, outputCheckBlock);
557
558        iBuilder->SetInsertPoint(earlyEndBlock);
559        for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
560            Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
561            mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
562        }       
563    }
564    iBuilder->CreateBr(outputCheckBlock);
565     
566    iBuilder->SetInsertPoint(endBlock);
567    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
568    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
569    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
570    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
571    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
572    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
573
574    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
575
576    iBuilder->SetInsertPoint(doFinalSegBlock);
577
578    createDoSegmentCall(self, blocks);
579
580    iBuilder->CreateBr(doFinalBlock);
581
582    iBuilder->SetInsertPoint(doFinalBlock);
583
584    createFinalBlockCall(self, finalBlockRemainingBytes);
585
586    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
587        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
588        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
589    }
590    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
591        iBuilder->CreateAtomicStoreRelease(producerPos, outbufProducerPtrs[i]);
592    }
593
594    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
595        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
596        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
597    }
598
599    iBuilder->CreatePThreadExitCall(Constant::getNullValue(voidPtrTy));
600    iBuilder->CreateRetVoid();
601
602    return threadFunc;
603
604}
605
606KernelBuilder::~KernelBuilder() {
607}
Note: See TracBrowser for help on using the repository browser.