source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5247

Last change on this file since 5247 was 5247, checked in by cameron, 3 years ago

Separate processedItemCounts and producedItemCounts for each stream set

File size: 28.3 KB
RevLine 
[4924]1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
[5063]7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
[5174]11#include <llvm/Support/ErrorHandling.h>
[5135]12#include <toolchain.h>
[4924]13
[4959]14using namespace llvm;
[5063]15using namespace kernel;
[4959]16
[5063]17KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
[5246]18                             std::string kernelName,
19                             std::vector<Binding> stream_inputs,
20                             std::vector<Binding> stream_outputs,
21                             std::vector<Binding> scalar_parameters,
22                             std::vector<Binding> scalar_outputs,
23                             std::vector<Binding> internal_scalars)
[5227]24: KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {
[4974]25
[5227]26}
27
[5246]28unsigned KernelBuilder::addScalar(Type * type, const std::string & name) {
[5063]29    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
[5227]30        llvm::report_fatal_error("Cannot add kernel field " + name + " after kernel state finalized");
[4924]31    }
[5227]32    const auto index = mKernelFields.size();
33    mKernelMap.emplace(name, index);
34    mKernelFields.push_back(type);
35    return index;
[4924]36}
[4968]37
[5076]38void KernelBuilder::prepareKernel() {
[5246]39    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
40        llvm::report_fatal_error("Cannot prepare kernel after kernel state finalized");
41    }
[5142]42    unsigned blockSize = iBuilder->getBitBlockWidth();
[5133]43    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
[5202]44        std::string tmp;
45        raw_string_ostream out(tmp);
46        out << "kernel contains " << mStreamSetInputBuffers.size() << " input buffers for "
47            << mStreamSetInputs.size() << " input stream sets.";
[5217]48        throw std::runtime_error(out.str());
[5133]49    }
50    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
[5202]51        std::string tmp;
52        raw_string_ostream out(tmp);
53        out << "kernel contains " << mStreamSetOutputBuffers.size() << " output buffers for "
54            << mStreamSetOutputs.size() << " output stream sets.";
[5217]55        throw std::runtime_error(out.str());
[5133]56    }
[5104]57    int streamSetNo = 0;
[5133]58    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
[5174]59        if ((mStreamSetInputBuffers[i]->getBufferSize() > 0) && (mStreamSetInputBuffers[i]->getBufferSize() < codegen::SegmentSize + (blockSize + mLookAheadPositions - 1)/blockSize)) {
[5202]60             llvm::report_fatal_error("Kernel preparation: Buffer size too small " + mStreamSetInputs[i].name);
[5142]61        }
[5202]62        mScalarInputs.push_back(Binding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].name + structPtrSuffix});
63        mStreamSetNameMap.emplace(mStreamSetInputs[i].name, streamSetNo);
[5247]64        addScalar(iBuilder->getSizeTy(), mStreamSetInputs[i].name + processedItemCountSuffix);
[5104]65        streamSetNo++;
[5086]66    }
[5133]67    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5202]68        mScalarInputs.push_back(Binding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].name + structPtrSuffix});
69        mStreamSetNameMap.emplace(mStreamSetOutputs[i].name, streamSetNo);
[5247]70        addScalar(iBuilder->getSizeTy(), mStreamSetOutputs[i].name + producedItemCountSuffix);
[5104]71        streamSetNo++;
[5086]72    }
[5076]73    for (auto binding : mScalarInputs) {
[5202]74        addScalar(binding.type, binding.name);
[5076]75    }
76    for (auto binding : mScalarOutputs) {
[5202]77        addScalar(binding.type, binding.name);
[5076]78    }
79    for (auto binding : mInternalScalars) {
[5202]80        addScalar(binding.type, binding.name);
[5076]81    }
[5227]82    addScalar(iBuilder->getSizeTy(), blockNoScalar);
83    addScalar(iBuilder->getSizeTy(), logicalSegmentNoScalar);
84    addScalar(iBuilder->getInt1Ty(), terminationSignal);
[5175]85    mKernelStateType = StructType::create(iBuilder->getContext(), mKernelFields, mKernelName);
[4970]86}
87
[5246]88std::unique_ptr<Module> KernelBuilder::createKernelModule(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
89    auto saveModule = iBuilder->getModule();
[5202]90    auto savePoint = iBuilder->saveIP();
[5246]91    auto module = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), iBuilder->getContext());
92    iBuilder->setModule(module.get());
93    generateKernel(inputs, outputs);
[5063]94    iBuilder->setModule(saveModule);
95    iBuilder->restoreIP(savePoint);
[5246]96    return module;
[4970]97}
98
[5246]99void KernelBuilder::generateKernel(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
[5202]100    auto savePoint = iBuilder->saveIP();
[5227]101    Module * const m = iBuilder->getModule();
[5246]102    mStreamSetInputBuffers.assign(inputs.begin(), inputs.end());
103    mStreamSetOutputBuffers.assign(outputs.begin(), outputs.end());
104    prepareKernel();            // possibly overridden by the KernelBuilder subtype
[5227]105    addKernelDeclarations(m);
[5246]106    generateDoBlockMethod();    // must be implemented by the KernelBuilder subtype
107    generateFinalBlockMethod(); // possibly overridden by the KernelBuilder subtype
[5086]108    generateDoSegmentMethod();
[5074]109
[5063]110    // Implement the accumulator get functions
111    for (auto binding : mScalarOutputs) {
[5202]112        auto fnName = mKernelName + accumulator_infix + binding.name;
[5063]113        Function * accumFn = m->getFunction(fnName);
[5202]114        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.name, accumFn, 0));
[5063]115        Value * self = &*(accumFn->arg_begin());
[5202]116        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
[5063]117        Value * retVal = iBuilder->CreateLoad(ptr);
118        iBuilder->CreateRet(retVal);
[4995]119    }
[5246]120
[5063]121    // Implement the initializer function
122    Function * initFunction = m->getFunction(mKernelName + init_suffix);
[5246]123    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));   
[5063]124    Function::arg_iterator args = initFunction->arg_begin();
[5051]125    Value * self = &*(args++);
[5246]126    initializeKernelState(self);    // possibly overridden by the KernelBuilder subtype
[5063]127    for (auto binding : mScalarInputs) {
[5246]128        Value * param = &*(args++);
[5202]129        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
[5246]130        iBuilder->CreateStore(param, ptr);
[5051]131    }
132    iBuilder->CreateRetVoid();
[5063]133    iBuilder->restoreIP(savePoint);
[5051]134}
135
[5246]136void KernelBuilder::initializeKernelState(Value * self) const {
137    iBuilder->CreateStore(ConstantAggregateZero::get(mKernelStateType), self);
[5233]138}
139
[5074]140//  The default finalBlock method simply dispatches to the doBlock routine.
[5246]141void KernelBuilder::generateFinalBlockMethod() const {
[5202]142    auto savePoint = iBuilder->saveIP();
[5074]143    Module * m = iBuilder->getModule();
[5063]144    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
145    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
146    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
147    // Final Block arguments: self, remaining, then the standard DoBlock args.
148    Function::arg_iterator args = finalBlockFunction->arg_begin();
149    Value * self = &*(args++);
150    /* Skip "remaining" arg */ args++;
151    std::vector<Value *> doBlockArgs = {self};
152    while (args != finalBlockFunction->arg_end()){
153        doBlockArgs.push_back(&*args++);
154    }
[5115]155    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
[5111]156    iBuilder->CreateRetVoid();
[5063]157    iBuilder->restoreIP(savePoint);
[4986]158}
[4924]159
[5185]160// Note: this may be overridden to incorporate doBlock logic directly into
161// the doSegment function.
[5246]162void KernelBuilder::generateDoBlockLogic(Value * self, Value * /* blockNo */) const {
[5174]163    Function * doBlockFunction = iBuilder->getModule()->getFunction(mKernelName + doBlock_suffix);
[5246]164    iBuilder->CreateCall(doBlockFunction, self);
[5174]165}
166
167//  The default doSegment method dispatches to the doBlock routine for
168//  each block of the given number of blocksToDo, and then updates counts.
[5246]169void KernelBuilder::generateDoSegmentMethod() const {
[5202]170    auto savePoint = iBuilder->saveIP();
[5086]171    Module * m = iBuilder->getModule();
172    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
173    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
174    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
[5194]175    BasicBlock * strideLoopCond = BasicBlock::Create(iBuilder->getContext(), "strideLoopCond", doSegmentFunction, 0);
176    BasicBlock * strideLoopBody = BasicBlock::Create(iBuilder->getContext(), "strideLoopBody", doSegmentFunction, 0);
177    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", doSegmentFunction, 0);
178    BasicBlock * checkFinalStride = BasicBlock::Create(iBuilder->getContext(), "checkFinalStride", doSegmentFunction, 0);
179    BasicBlock * checkEndSignals = BasicBlock::Create(iBuilder->getContext(), "checkEndSignals", doSegmentFunction, 0);
[5188]180    BasicBlock * callFinalBlock = BasicBlock::Create(iBuilder->getContext(), "callFinalBlock", doSegmentFunction, 0);
181    BasicBlock * segmentDone = BasicBlock::Create(iBuilder->getContext(), "segmentDone", doSegmentFunction, 0);
[5194]182    BasicBlock * finalExit = BasicBlock::Create(iBuilder->getContext(), "finalExit", doSegmentFunction, 0);
[5165]183    Type * const size_ty = iBuilder->getSizeTy();
[5183]184    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
[5174]185    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
[5086]186   
187    Function::arg_iterator args = doSegmentFunction->arg_begin();
188    Value * self = &*(args++);
189    Value * blocksToDo = &*(args);
[5188]190   
[5183]191    std::vector<Value *> inbufProducerPtrs;
[5188]192    std::vector<Value *> endSignalPtrs;
[5183]193    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
[5246]194        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
195        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
196        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
[5183]197    }
198   
[5188]199    std::vector<Value *> producerPos;
[5183]200    /* Determine the actually available data examining all input stream sets. */
[5192]201    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
[5188]202    producerPos.push_back(p);
203    Value * availablePos = producerPos[0];
[5183]204    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
[5192]205        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
[5188]206        producerPos.push_back(p);
[5183]207        /* Set the available position to be the minimum of availablePos and producerPos. */
[5188]208        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
[5183]209    }
[5247]210    Value * processed = getProcessedItemCount(self, mStreamSetInputs[0].name);
[5183]211    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
[5202]212//#ifndef NDEBUG
213//    iBuilder->CallPrintInt(mKernelName + "_itemsAvail", itemsAvail);
214//#endif
[5194]215    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
216    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
[5183]217    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
[5194]218    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
219    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
220    //iBuilder->CallPrintInt(mKernelName + "_stridesAvail", stridesAvail);
221    iBuilder->CreateBr(strideLoopCond);
[5165]222
[5194]223    iBuilder->SetInsertPoint(strideLoopCond);
224    PHINode * stridesRemaining = iBuilder->CreatePHI(size_ty, 2, "stridesRemaining");
225    stridesRemaining->addIncoming(stridesToDo, entryBlock);
226    Value * notDone = iBuilder->CreateICmpUGT(stridesRemaining, ConstantInt::get(size_ty, 0));
227    iBuilder->CreateCondBr(notDone, strideLoopBody, stridesDone);
[5165]228
[5194]229    iBuilder->SetInsertPoint(strideLoopBody);
[5165]230    Value * blockNo = getScalarField(self, blockNoScalar);   
[5185]231
[5174]232    generateDoBlockLogic(self, blockNo);
233    setBlockNo(self, iBuilder->CreateAdd(blockNo, strideBlocks));
[5194]234    stridesRemaining->addIncoming(iBuilder->CreateSub(stridesRemaining, ConstantInt::get(size_ty, 1)), strideLoopBody);
235    iBuilder->CreateBr(strideLoopCond);
[5111]236   
[5194]237    iBuilder->SetInsertPoint(stridesDone);
238    processed = iBuilder->CreateAdd(processed, iBuilder->CreateMul(stridesToDo, stride));
[5247]239    setProcessedItemCount(self, mStreamSetInputs[0].name, processed);
[5194]240    iBuilder->CreateCondBr(lessThanFullSegment, checkFinalStride, segmentDone);
[5188]241   
[5194]242    iBuilder->SetInsertPoint(checkFinalStride);
[5188]243   
244    /* We had less than a full segment of data; we may have reached the end of input
245       on one of the stream sets.  */
246   
[5194]247    Value * alreadyDone = getTerminationSignal(self);
248    iBuilder->CreateCondBr(alreadyDone, finalExit, checkEndSignals);
249   
250    iBuilder->SetInsertPoint(checkEndSignals);
[5188]251    Value * endOfInput = iBuilder->CreateLoad(endSignalPtrs[0]);
252    if (endSignalPtrs.size() > 1) {
253        /* If there is more than one input stream set, then we need to confirm that one of
254           them has both the endSignal set and the length = to availablePos. */
255        endOfInput = iBuilder->CreateAnd(endOfInput, iBuilder->CreateICmpEQ(availablePos, producerPos[0]));
256        for (unsigned i = 1; i < endSignalPtrs.size(); i++) {
257            Value * e = iBuilder->CreateAnd(iBuilder->CreateLoad(endSignalPtrs[i]), iBuilder->CreateICmpEQ(availablePos, producerPos[i]));
258            endOfInput = iBuilder->CreateOr(endOfInput, e);
259        }
260    }
261    iBuilder->CreateCondBr(endOfInput, callFinalBlock, segmentDone);
262   
263    iBuilder->SetInsertPoint(callFinalBlock);
264   
[5194]265    Value * remainingItems = iBuilder->CreateSub(availablePos, processed);
[5188]266    createFinalBlockCall(self, remainingItems);
[5247]267    setProcessedItemCount(self, mStreamSetInputs[0].name, availablePos);
[5188]268   
269    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5202]270        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
[5188]271        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
272    }
[5194]273    setTerminationSignal(self);
[5188]274    iBuilder->CreateBr(segmentDone);
275   
276    iBuilder->SetInsertPoint(segmentDone);
[5202]277//#ifndef NDEBUG
278//    iBuilder->CallPrintInt(mKernelName + "_produced", produced);
279//#endif
[5183]280    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5247]281        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
[5202]282        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
[5185]283        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
[5192]284        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
[5183]285    }
[5194]286    iBuilder->CreateBr(finalExit);
287    iBuilder->SetInsertPoint(finalExit);
[5174]288
[5111]289    iBuilder->CreateRetVoid();
[5086]290    iBuilder->restoreIP(savePoint);
291}
292
[5227]293ConstantInt * KernelBuilder::getScalarIndex(const std::string & name) const {
294    const auto f = mKernelMap.find(name);
295    if (LLVM_UNLIKELY(f == mKernelMap.end())) {
[5246]296        llvm::report_fatal_error("Kernel does not contain scalar: " + name);
[5000]297    }
[5104]298    return iBuilder->getInt32(f->second);
[4959]299}
[4924]300
[5227]301unsigned KernelBuilder::getScalarCount() const {
302    return mKernelFields.size();
303}
304
[5246]305Value * KernelBuilder::getScalarFieldPtr(Value * self, const std::string & fieldName) const {
[5202]306    return iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
307}
[5109]308
[5246]309Value * KernelBuilder::getScalarField(Value * self, const std::string & fieldName) const {
[5202]310    return iBuilder->CreateLoad(getScalarFieldPtr(self, fieldName));
[4924]311}
312
[5246]313void KernelBuilder::setScalarField(Value * self, const std::string & fieldName, Value * newFieldVal) const {
[5202]314    iBuilder->CreateStore(newFieldVal, getScalarFieldPtr(self, fieldName));
[5008]315}
[5063]316
[5246]317LoadInst * KernelBuilder::acquireLogicalSegmentNo(Value * self) const {
[5174]318    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
[5246]319    return iBuilder->CreateAtomicLoadAcquire(ptr);
[5174]320}
321
[5247]322Value * KernelBuilder::getProcessedItemCount(Value * self, const std::string & ssName) const {
323    return getScalarField(self, ssName + processedItemCountSuffix);
[5174]324}
325
[5247]326Value * KernelBuilder::getProducedItemCount(Value * self, const std::string & ssName) const {
327    return getScalarField(self, ssName + producedItemCountSuffix);
[5174]328}
329
[5246]330Value * KernelBuilder::getTerminationSignal(Value * self) const {
[5194]331    return getScalarField(self, terminationSignal);
[5174]332}
333
[5246]334void KernelBuilder::releaseLogicalSegmentNo(Value * self, Value * newCount) const {
[5174]335    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
[5192]336    iBuilder->CreateAtomicStoreRelease(newCount, ptr);
[5174]337}
338
[5247]339void KernelBuilder::setProcessedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
340    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + processedItemCountSuffix)});
[5174]341    iBuilder->CreateStore(newCount, ptr);
342}
343
[5247]344void KernelBuilder::setProducedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
345    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + producedItemCountSuffix)});
[5174]346    iBuilder->CreateStore(newCount, ptr);
347}
348
[5246]349void KernelBuilder::setTerminationSignal(Value * self) const {
[5194]350    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(terminationSignal)});
351    iBuilder->CreateStore(ConstantInt::get(iBuilder->getInt1Ty(), 1), ptr);
[5174]352}
353
[5246]354Value * KernelBuilder::getBlockNo(Value * self) const {
[5165]355    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
[5246]356    return iBuilder->CreateLoad(ptr);
[5165]357}
[5063]358
[5246]359void KernelBuilder::setBlockNo(Value * self, Value * newFieldVal) const {
[5165]360    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
[5174]361    iBuilder->CreateStore(newFieldVal, ptr);
[5165]362}
363
364
[5246]365Value * KernelBuilder::getParameter(Function * f, const std::string & paramName) const {
[5063]366    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
367        Value * arg = &*argIter;
368        if (arg->getName() == paramName) return arg;
[5051]369    }
[5174]370    llvm::report_fatal_error("Method does not have parameter: " + paramName);
[5051]371}
[5008]372
[5246]373unsigned KernelBuilder::getStreamSetIndex(const std::string & name) const {
[5202]374    const auto f = mStreamSetNameMap.find(name);
[5104]375    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
[5202]376        llvm::report_fatal_error("Kernel does not contain stream set: " + name);
[5104]377    }
378    return f->second;
379}
[5063]380
[5246]381size_t KernelBuilder::getStreamSetBufferSize(Value * /* self */, const std::string & name) const {
[5202]382    const unsigned index = getStreamSetIndex(name);
383    StreamSetBuffer * buf = nullptr;
384    if (index < mStreamSetInputs.size()) {
385        buf = mStreamSetInputBuffers[index];
386    } else {
387        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
[5109]388    }
[5202]389    return buf->getBufferSize();
[5109]390}
391
[5246]392Value * KernelBuilder::getStreamSetStructPtr(Value * self, const std::string & name) const {
[5202]393    return getScalarField(self, name + structPtrSuffix);
[5104]394}
395
[5246]396Value * KernelBuilder::getStreamSetBlockPtr(Value * self, const std::string &name, Value * blockNo) const {
[5202]397    Value * const structPtr = getStreamSetStructPtr(self, name);
398    const unsigned index = getStreamSetIndex(name);
399    StreamSetBuffer * buf = nullptr;
400    if (index < mStreamSetInputs.size()) {
401        buf = mStreamSetInputBuffers[index];
402    } else {
403        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
404    }   
405    return buf->getStreamSetBlockPointer(structPtr, blockNo);
[5104]406}
407
[5246]408Value * KernelBuilder::getStream(Value * self, const std::string & name, Value * blockNo, Value * index) {
409    return iBuilder->CreateGEP(getStreamSetBlockPtr(self, name, blockNo), {iBuilder->getInt32(0), index});
410}
411
[5220]412void KernelBuilder::createInstance() {
[5246]413    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
414        llvm::report_fatal_error("Cannot create kernel instance before calling prepareKernel()");
415    }
[5220]416    mKernelInstance = iBuilder->CreateCacheAlignedAlloca(mKernelStateType);
[5133]417    Module * m = iBuilder->getModule();
[5220]418    std::vector<Value *> init_args = {mKernelInstance};
419    for (auto a : mInitialArguments) {
[5133]420        init_args.push_back(a);
421    }
[5202]422    for (auto b : mStreamSetInputBuffers) {
[5135]423        init_args.push_back(b->getStreamSetStructPtr());
[5133]424    }
[5202]425    for (auto b : mStreamSetOutputBuffers) {
[5135]426        init_args.push_back(b->getStreamSetStructPtr());
[5133]427    }
428    std::string initFnName = mKernelName + init_suffix;
429    Function * initMethod = m->getFunction(initFnName);
430    if (!initMethod) {
[5174]431        llvm::report_fatal_error("Cannot find " + initFnName);
[5133]432    }
433    iBuilder->CreateCall(initMethod, init_args);
434}
[5104]435
[5246]436Function * KernelBuilder::generateThreadFunction(const std::string & name) const {
437    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
438        llvm::report_fatal_error("Cannot generate thread function before calling prepareKernel()");
439    }
[5135]440    Module * m = iBuilder->getModule();
[5230]441    Type * const voidTy = iBuilder->getVoidTy();
[5227]442    Type * const voidPtrTy = iBuilder->getVoidPtrTy();
[5135]443    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
444    Type * const int1ty = iBuilder->getInt1Ty();
[5104]445
[5135]446    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
447    threadFunc->setCallingConv(CallingConv::C);
448    Function::arg_iterator args = threadFunc->arg_begin();
[5104]449
[5135]450    Value * const arg = &*(args++);
451    arg->setName("args");
[5133]452
[5135]453    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
454
455    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
456
457    std::vector<Value *> inbufProducerPtrs;
458    std::vector<Value *> inbufConsumerPtrs;
459    std::vector<Value *> outbufProducerPtrs;
460    std::vector<Value *> outbufConsumerPtrs;   
461    std::vector<Value *> endSignalPtrs;
462
463    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
[5202]464        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
[5185]465        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
466        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getConsumerPosPtr(ssStructPtr));
[5217]467        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(ssStructPtr));
[5135]468    }
469    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5202]470        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
[5185]471        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr));
472        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getConsumerPosPtr(ssStructPtr));
[5135]473    }
474
475    const unsigned segmentBlocks = codegen::SegmentSize;
476    const unsigned bufferSegments = codegen::BufferSegments;
477    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
478    Type * const size_ty = iBuilder->getSizeTy();
479
480    Value * segSize = ConstantInt::get(size_ty, segmentSize);
481    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
482    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
483   
484    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
485    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
486   
487    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
488    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
489    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
490    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
491    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
492
493    iBuilder->CreateBr(outputCheckBlock);
494
495    iBuilder->SetInsertPoint(outputCheckBlock);
496
497    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
498    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
[5192]499        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(outbufProducerPtrs[i]);
[5135]500        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
[5192]501        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(outbufConsumerPtrs[i]);
[5135]502        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
503        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
504    }
505   
506    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
507
508    iBuilder->SetInsertPoint(inputCheckBlock); 
509
[5174]510    Value * requiredSize = segSize;
511    if (mLookAheadPositions > 0) {
512        requiredSize = iBuilder->CreateAdd(segSize, ConstantInt::get(size_ty, mLookAheadPositions));
513    }
[5135]514    waitCondTest = ConstantInt::get(int1ty, 1); 
515    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
[5192]516        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
[5135]517        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
[5192]518        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(inbufConsumerPtrs[i]);
[5135]519        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
[5174]520        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, requiredSize), producerPos));
[5135]521    }
522
523    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
524   
525    iBuilder->SetInsertPoint(endSignalCheckBlock);
526   
[5188]527    LoadInst * endSignal = iBuilder->CreateLoad(endSignalPtrs[0]);
[5135]528    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
[5188]529        LoadInst * endSignal_next = iBuilder->CreateLoad(endSignalPtrs[i]);
[5135]530        iBuilder->CreateAnd(endSignal, endSignal_next);
531    }
532       
[5188]533    iBuilder->CreateCondBr(endSignal, endBlock, inputCheckBlock);
[5135]534   
535    iBuilder->SetInsertPoint(doSegmentBlock);
536 
537    createDoSegmentCall(self, segBlocks);
538
539    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
540        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
[5192]541        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
[5135]542    }
[5174]543   
[5135]544    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
[5247]545        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
[5192]546        iBuilder->CreateAtomicStoreRelease(produced, outbufProducerPtrs[i]);
[5135]547    }
548   
[5174]549    Value * earlyEndSignal = getTerminationSignal(self);
550    if (earlyEndSignal != ConstantInt::getNullValue(iBuilder->getInt1Ty())) {
551        BasicBlock * earlyEndBlock = BasicBlock::Create(iBuilder->getContext(), "earlyEndSignal", threadFunc, 0);
552        iBuilder->CreateCondBr(earlyEndSignal, earlyEndBlock, outputCheckBlock);
553
554        iBuilder->SetInsertPoint(earlyEndBlock);
555        for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5202]556            Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
[5185]557            mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
[5174]558        }       
559    }
[5135]560    iBuilder->CreateBr(outputCheckBlock);
561     
562    iBuilder->SetInsertPoint(endBlock);
563    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
564    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
565    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
566    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
567    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
568    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
569
570    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
571
572    iBuilder->SetInsertPoint(doFinalSegBlock);
573
574    createDoSegmentCall(self, blocks);
575
576    iBuilder->CreateBr(doFinalBlock);
577
578    iBuilder->SetInsertPoint(doFinalBlock);
579
580    createFinalBlockCall(self, finalBlockRemainingBytes);
581
582    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
583        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
[5192]584        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
[5135]585    }
586    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
[5192]587        iBuilder->CreateAtomicStoreRelease(producerPos, outbufProducerPtrs[i]);
[5135]588    }
589
590    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5202]591        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
[5185]592        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
[5135]593    }
594
[5242]595    iBuilder->CreatePThreadExitCall(Constant::getNullValue(voidPtrTy));
[5135]596    iBuilder->CreateRetVoid();
597
598    return threadFunc;
599
600}
[5246]601
602KernelBuilder::~KernelBuilder() {
603}
Note: See TracBrowser for help on using the repository browser.