source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5193

Last change on this file since 5193 was 5193, checked in by cameron, 3 years ago

Fixes for NVPTX (but strideBlocks needs further investigation), u8u16

File size: 27.2 KB
RevLine 
[4924]1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
[5063]7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
[5135]11#include <llvm/IR/TypeBuilder.h>
[5174]12#include <llvm/Support/ErrorHandling.h>
[5135]13#include <toolchain.h>
[4924]14
[4959]15using namespace llvm;
[5063]16using namespace kernel;
[4959]17
[5063]18KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
19                                 std::string kernelName,
20                                 std::vector<StreamSetBinding> stream_inputs,
21                                 std::vector<StreamSetBinding> stream_outputs,
22                                 std::vector<ScalarBinding> scalar_parameters,
23                                 std::vector<ScalarBinding> scalar_outputs,
24                                 std::vector<ScalarBinding> internal_scalars) :
[5076]25    KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {}
[4974]26
[5063]27void KernelBuilder::addScalar(Type * t, std::string scalarName) {
28    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
[5174]29        llvm::report_fatal_error("Illegal addition of kernel field after kernel state finalized: " + scalarName);
[4924]30    }
[5063]31    unsigned index = mKernelFields.size();
32    mKernelFields.push_back(t);
[5104]33    mInternalStateNameMap.emplace(scalarName, index);
[4924]34}
[4968]35
[5076]36void KernelBuilder::prepareKernel() {
[5142]37    unsigned blockSize = iBuilder->getBitBlockWidth();
[5133]38    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
[5174]39        llvm::report_fatal_error("Kernel preparation: Incorrect number of input buffers");
[5133]40    }
41    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
[5174]42        llvm::report_fatal_error("Kernel preparation: Incorrect number of output buffers");
[5133]43    }
[5106]44    addScalar(iBuilder->getSizeTy(), blockNoScalar);
[5174]45    addScalar(iBuilder->getSizeTy(), logicalSegmentNoScalar);
46    addScalar(iBuilder->getSizeTy(), processedItemCount);
47    addScalar(iBuilder->getSizeTy(), producedItemCount);
48    addScalar(iBuilder->getInt1Ty(), terminationSignal);
[5104]49    int streamSetNo = 0;
[5133]50    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
51        if (!(mStreamSetInputBuffers[i]->getBufferStreamSetType() == mStreamSetInputs[i].ssType)) {
[5174]52             llvm::report_fatal_error("Kernel preparation: Incorrect input buffer type");
[5133]53        }
[5174]54        if ((mStreamSetInputBuffers[i]->getBufferSize() > 0) && (mStreamSetInputBuffers[i]->getBufferSize() < codegen::SegmentSize + (blockSize + mLookAheadPositions - 1)/blockSize)) {
[5185]55             errs() << " buffer size = " << mStreamSetInputBuffers[i]->getBufferSize() << "\n";
56             llvm::report_fatal_error("Kernel preparation: Buffer size too small " + mStreamSetInputs[i].ssName);
[5142]57        }
[5185]58        mScalarInputs.push_back(ScalarBinding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].ssName + structPtrSuffix});
[5133]59        mStreamSetNameMap.emplace(mStreamSetInputs[i].ssName, streamSetNo);
[5104]60        streamSetNo++;
[5086]61    }
[5133]62    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
63        if (!(mStreamSetOutputBuffers[i]->getBufferStreamSetType() == mStreamSetOutputs[i].ssType)) {
[5185]64             llvm::report_fatal_error("Kernel preparation: Incorrect output buffer type " + mStreamSetOutputs[i].ssName);
[5133]65        }
[5185]66        mScalarInputs.push_back(ScalarBinding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].ssName + structPtrSuffix});
[5133]67        mStreamSetNameMap.emplace(mStreamSetOutputs[i].ssName, streamSetNo);
[5104]68        streamSetNo++;
[5086]69    }
[5076]70    for (auto binding : mScalarInputs) {
71        addScalar(binding.scalarType, binding.scalarName);
72    }
73    for (auto binding : mScalarOutputs) {
74        addScalar(binding.scalarType, binding.scalarName);
75    }
76    for (auto binding : mInternalScalars) {
77        addScalar(binding.scalarType, binding.scalarName);
78    }
[5175]79    mKernelStateType = StructType::create(iBuilder->getContext(), mKernelFields, mKernelName);
[4970]80}
81
[5133]82std::unique_ptr<Module> KernelBuilder::createKernelModule(std::vector<StreamSetBuffer *> input_buffers, std::vector<StreamSetBuffer *> output_buffers) {
[5063]83    Module * saveModule = iBuilder->getModule();
84    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
[5175]85    std::unique_ptr<Module> theModule = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), iBuilder->getContext());
[5063]86    Module * m = theModule.get();
87    iBuilder->setModule(m);
[5133]88    generateKernel(input_buffers, output_buffers);
[5063]89    iBuilder->setModule(saveModule);
90    iBuilder->restoreIP(savePoint);
91    return theModule;
[4970]92}
93
[5133]94void KernelBuilder::generateKernel(std::vector<StreamSetBuffer *> input_buffers, std::vector<StreamSetBuffer*> output_buffers) {
[5074]95    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
[5063]96    Module * m = iBuilder->getModule();
[5133]97    mStreamSetInputBuffers = input_buffers;
98    mStreamSetOutputBuffers = output_buffers;
[5076]99    prepareKernel();  // possibly overriden by the KernelBuilder subtype
[5074]100    KernelInterface::addKernelDeclarations(m);
101    generateDoBlockMethod();     // must be implemented by the KernelBuilder subtype
102    generateFinalBlockMethod();  // possibly overriden by the KernelBuilder subtype
[5086]103    generateDoSegmentMethod();
[5074]104
[5063]105    // Implement the accumulator get functions
106    for (auto binding : mScalarOutputs) {
107        auto fnName = mKernelName + accumulator_infix + binding.scalarName;
108        Function * accumFn = m->getFunction(fnName);
109        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.scalarName, accumFn, 0));
110        Value * self = &*(accumFn->arg_begin());
111        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
112        Value * retVal = iBuilder->CreateLoad(ptr);
113        iBuilder->CreateRet(retVal);
[4995]114    }
[5063]115    // Implement the initializer function
116    Function * initFunction = m->getFunction(mKernelName + init_suffix);
117    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));
118   
119    Function::arg_iterator args = initFunction->arg_begin();
[5051]120    Value * self = &*(args++);
121    iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), self);
[5063]122    for (auto binding : mScalarInputs) {
[5051]123        Value * parm = &*(args++);
[5063]124        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.scalarName)});
125        iBuilder->CreateStore(parm, ptr);
[5051]126    }
127    iBuilder->CreateRetVoid();
[5063]128    iBuilder->restoreIP(savePoint);
[5051]129}
130
[5074]131//  The default finalBlock method simply dispatches to the doBlock routine.
132void KernelBuilder::generateFinalBlockMethod() {
[5063]133    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
[5074]134    Module * m = iBuilder->getModule();
[5063]135    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
136    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
137    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
138    // Final Block arguments: self, remaining, then the standard DoBlock args.
139    Function::arg_iterator args = finalBlockFunction->arg_begin();
140    Value * self = &*(args++);
141    /* Skip "remaining" arg */ args++;
142    std::vector<Value *> doBlockArgs = {self};
143    while (args != finalBlockFunction->arg_end()){
144        doBlockArgs.push_back(&*args++);
145    }
[5115]146    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
[5111]147    iBuilder->CreateRetVoid();
[5063]148    iBuilder->restoreIP(savePoint);
[4986]149}
[4924]150
[5185]151// Note: this may be overridden to incorporate doBlock logic directly into
152// the doSegment function.
[5174]153void KernelBuilder::generateDoBlockLogic(Value * self, Value * blockNo) {
154    Function * doBlockFunction = iBuilder->getModule()->getFunction(mKernelName + doBlock_suffix);
155    iBuilder->CreateCall(doBlockFunction, {self});
156}
157
158//  The default doSegment method dispatches to the doBlock routine for
159//  each block of the given number of blocksToDo, and then updates counts.
[5086]160void KernelBuilder::generateDoSegmentMethod() {
161    IDISA::IDISA_Builder::InsertPoint savePoint = iBuilder->saveIP();
162    Module * m = iBuilder->getModule();
163    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
164    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
165    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
[5165]166    BasicBlock * blockLoopCond = BasicBlock::Create(iBuilder->getContext(), "blockLoopCond", doSegmentFunction, 0);
167    BasicBlock * blockLoopBody = BasicBlock::Create(iBuilder->getContext(), "blockLoopBody", doSegmentFunction, 0);
[5086]168    BasicBlock * blocksDone = BasicBlock::Create(iBuilder->getContext(), "blocksDone", doSegmentFunction, 0);
[5188]169    BasicBlock * checkFinalBlock = BasicBlock::Create(iBuilder->getContext(), "checkFinalBlock", doSegmentFunction, 0);
170    BasicBlock * callFinalBlock = BasicBlock::Create(iBuilder->getContext(), "callFinalBlock", doSegmentFunction, 0);
171    BasicBlock * segmentDone = BasicBlock::Create(iBuilder->getContext(), "segmentDone", doSegmentFunction, 0);
[5165]172    Type * const size_ty = iBuilder->getSizeTy();
[5183]173    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
[5174]174    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
[5086]175   
176    Function::arg_iterator args = doSegmentFunction->arg_begin();
177    Value * self = &*(args++);
178    Value * blocksToDo = &*(args);
[5174]179    Value * segmentNo = getLogicalSegmentNo(self);
[5188]180   
[5183]181    std::vector<Value *> inbufProducerPtrs;
[5188]182    std::vector<Value *> endSignalPtrs;
[5183]183    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
[5185]184        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].ssName);
185        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
[5188]186        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->hasEndOfInputPtr(ssStructPtr));
[5183]187    }
188   
[5188]189    std::vector<Value *> producerPos;
[5183]190    /* Determine the actually available data examining all input stream sets. */
[5192]191    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
[5188]192    producerPos.push_back(p);
193    Value * availablePos = producerPos[0];
[5183]194    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
[5192]195       
196        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
[5188]197        producerPos.push_back(p);
[5183]198        /* Set the available position to be the minimum of availablePos and producerPos. */
[5188]199        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
[5183]200    }
201    Value * processed = getProcessedItemCount(self);
202    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
203#ifndef NDEBUG
204    iBuilder->CallPrintInt(mKernelName + "_itemsAvail", itemsAvail);
205#endif
206    Value * blocksAvail = iBuilder->CreateUDiv(itemsAvail, stride);
207    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
[5188]208    Value * lessThanFullSegment = iBuilder->CreateICmpULT(blocksAvail, blocksToDo);
209    blocksToDo = iBuilder->CreateSelect(lessThanFullSegment, blocksAvail, blocksToDo);
[5185]210    //iBuilder->CallPrintInt(mKernelName + "_blocksAvail", blocksAvail);
[5165]211    iBuilder->CreateBr(blockLoopCond);
212
213    iBuilder->SetInsertPoint(blockLoopCond);
214    PHINode * blocksRemaining = iBuilder->CreatePHI(size_ty, 2, "blocksRemaining");
[5111]215    blocksRemaining->addIncoming(blocksToDo, entryBlock);
[5165]216    Value * notDone = iBuilder->CreateICmpUGT(blocksRemaining, ConstantInt::get(size_ty, 0));
217    iBuilder->CreateCondBr(notDone, blockLoopBody, blocksDone);
218
219    iBuilder->SetInsertPoint(blockLoopBody);
220    Value * blockNo = getScalarField(self, blockNoScalar);   
[5185]221
[5174]222    generateDoBlockLogic(self, blockNo);
223    setBlockNo(self, iBuilder->CreateAdd(blockNo, strideBlocks));
[5193]224    blocksRemaining->addIncoming(iBuilder->CreateSub(blocksRemaining, ConstantInt::get(size_ty, 1)), blockLoopBody);
[5165]225    iBuilder->CreateBr(blockLoopCond);
[5111]226   
227    iBuilder->SetInsertPoint(blocksDone);
[5183]228    processed = iBuilder->CreateAdd(processed, iBuilder->CreateMul(blocksToDo, stride));
229    setProcessedItemCount(self, processed);
[5188]230    iBuilder->CreateCondBr(lessThanFullSegment, checkFinalBlock, segmentDone);
231   
232    iBuilder->SetInsertPoint(checkFinalBlock);
233   
234    /* We had less than a full segment of data; we may have reached the end of input
235       on one of the stream sets.  */
236   
237    Value * endOfInput = iBuilder->CreateLoad(endSignalPtrs[0]);
238    if (endSignalPtrs.size() > 1) {
239        /* If there is more than one input stream set, then we need to confirm that one of
240           them has both the endSignal set and the length = to availablePos. */
241        endOfInput = iBuilder->CreateAnd(endOfInput, iBuilder->CreateICmpEQ(availablePos, producerPos[0]));
242        for (unsigned i = 1; i < endSignalPtrs.size(); i++) {
243            Value * e = iBuilder->CreateAnd(iBuilder->CreateLoad(endSignalPtrs[i]), iBuilder->CreateICmpEQ(availablePos, producerPos[i]));
244            endOfInput = iBuilder->CreateOr(endOfInput, e);
245        }
246    }
247    iBuilder->CreateCondBr(endOfInput, callFinalBlock, segmentDone);
248   
249    iBuilder->SetInsertPoint(callFinalBlock);
250   
251    Value * remainingItems = iBuilder->CreateURem(availablePos, stride);
252    createFinalBlockCall(self, remainingItems);
253    setProcessedItemCount(self, availablePos);
254   
255    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
256        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].ssName);
257        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
258    }
259   
260    iBuilder->CreateBr(segmentDone);
261   
262    iBuilder->SetInsertPoint(segmentDone);
[5183]263    Value * produced = getProducedItemCount(self);
264#ifndef NDEBUG
265    iBuilder->CallPrintInt(mKernelName + "_produced", produced);
266#endif
267    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5185]268        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].ssName);
269        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
[5192]270        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
[5183]271    }
272
[5174]273    // Must be the last action, for synchronization.
274    setLogicalSegmentNo(self, iBuilder->CreateAdd(segmentNo, ConstantInt::get(size_ty, 1)));
275
[5111]276    iBuilder->CreateRetVoid();
[5086]277    iBuilder->restoreIP(savePoint);
278}
279
[5063]280Value * KernelBuilder::getScalarIndex(std::string fieldName) {
281    const auto f = mInternalStateNameMap.find(fieldName);
282    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
[5174]283        llvm::report_fatal_error("Kernel does not contain internal state: " + fieldName);
[5000]284    }
[5104]285    return iBuilder->getInt32(f->second);
[4959]286}
[4924]287
[5109]288
289
[5063]290Value * KernelBuilder::getScalarField(Value * self, std::string fieldName) {
291    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
292    return iBuilder->CreateLoad(ptr);
[4924]293}
294
[5063]295void KernelBuilder::setScalarField(Value * self, std::string fieldName, Value * newFieldVal) {
296    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
297    iBuilder->CreateStore(newFieldVal, ptr);
[5008]298}
[5063]299
[5174]300Value * KernelBuilder::getLogicalSegmentNo(Value * self) { 
301    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
[5192]302    LoadInst * segNo = iBuilder->CreateAtomicLoadAcquire(ptr);
[5174]303    return segNo;
304}
305
306Value * KernelBuilder::getProcessedItemCount(Value * self) { 
307    return getScalarField(self, processedItemCount);
308}
309
310Value * KernelBuilder::getProducedItemCount(Value * self) {
311    return getScalarField(self, producedItemCount);
312}
313
314//  By default, kernels do not terminate early. 
315Value * KernelBuilder::getTerminationSignal(Value * self) {
316    return ConstantInt::getNullValue(iBuilder->getInt1Ty());
317}
318
319
320void KernelBuilder::setLogicalSegmentNo(Value * self, Value * newCount) {
321    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
[5192]322    iBuilder->CreateAtomicStoreRelease(newCount, ptr);
[5174]323}
324
325void KernelBuilder::setProcessedItemCount(Value * self, Value * newCount) {
326    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(processedItemCount)});
327    iBuilder->CreateStore(newCount, ptr);
328}
329
330void KernelBuilder::setProducedItemCount(Value * self, Value * newCount) {
331    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(producedItemCount)});
332    iBuilder->CreateStore(newCount, ptr);
333}
334
335void KernelBuilder::setTerminationSignal(Value * self, Value * newFieldVal) {
336    llvm::report_fatal_error("This kernel type does not support setTerminationSignal.");
337}
338
339
[5165]340Value * KernelBuilder::getBlockNo(Value * self) {
341    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
[5174]342    LoadInst * blockNo = iBuilder->CreateLoad(ptr);
[5165]343    return blockNo;
344}
[5063]345
[5165]346void KernelBuilder::setBlockNo(Value * self, Value * newFieldVal) {
347    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
[5174]348    iBuilder->CreateStore(newFieldVal, ptr);
[5165]349}
350
351
[5063]352Value * KernelBuilder::getParameter(Function * f, std::string paramName) {
353    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
354        Value * arg = &*argIter;
355        if (arg->getName() == paramName) return arg;
[5051]356    }
[5174]357    llvm::report_fatal_error("Method does not have parameter: " + paramName);
[5051]358}
[5008]359
[5104]360unsigned KernelBuilder::getStreamSetIndex(std::string ssName) {
361    const auto f = mStreamSetNameMap.find(ssName);
362    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
[5174]363        llvm::report_fatal_error("Kernel does not contain stream set: " + ssName);
[5104]364    }
365    return f->second;
366}
[5063]367
[5109]368size_t KernelBuilder::getStreamSetBufferSize(Value * self, std::string ssName) {
369    unsigned ssIndex = getStreamSetIndex(ssName);
370    if (ssIndex < mStreamSetInputs.size()) {
[5133]371        return mStreamSetInputBuffers[ssIndex]->getBufferSize();
[5109]372    }
373    else {
[5133]374        return mStreamSetOutputBuffers[ssIndex - mStreamSetInputs.size()]->getBufferSize();
[5109]375    }
376}
377
[5185]378Value * KernelBuilder::getStreamSetStructPtr(Value * self, std::string ssName) {
379    return getScalarField(self, ssName + structPtrSuffix);
[5104]380}
381
382Value * KernelBuilder::getStreamSetBlockPtr(Value * self, std::string ssName, Value * blockNo) {
[5185]383    Value * ssStructPtr = getStreamSetStructPtr(self, ssName);
[5104]384    unsigned ssIndex = getStreamSetIndex(ssName);
385    if (ssIndex < mStreamSetInputs.size()) {
[5185]386        return mStreamSetInputBuffers[ssIndex]->getStreamSetBlockPointer(ssStructPtr, blockNo);
[5104]387    }
388    else {
[5185]389        return mStreamSetOutputBuffers[ssIndex - mStreamSetInputs.size()]->getStreamSetBlockPointer(ssStructPtr, blockNo);
[5104]390    }
391}
392
[5133]393Value * KernelBuilder::createInstance(std::vector<Value *> args) {
[5185]394    Value * kernelInstance = iBuilder->CreateCacheAlignedAlloca(mKernelStateType);
[5133]395    Module * m = iBuilder->getModule();
396    std::vector<Value *> init_args = {kernelInstance};
397    for (auto a : args) {
398        init_args.push_back(a);
399    }
400    for (auto b : mStreamSetInputBuffers) { 
[5135]401        init_args.push_back(b->getStreamSetStructPtr());
[5133]402    }
403    for (auto b : mStreamSetOutputBuffers) { 
[5135]404        init_args.push_back(b->getStreamSetStructPtr());
[5133]405    }
406    std::string initFnName = mKernelName + init_suffix;
407    Function * initMethod = m->getFunction(initFnName);
408    if (!initMethod) {
[5174]409        llvm::report_fatal_error("Cannot find " + initFnName);
[5133]410    }
411    iBuilder->CreateCall(initMethod, init_args);
412    return kernelInstance;
413}
[5104]414
[5135]415Function * KernelBuilder::generateThreadFunction(std::string name){
416    Module * m = iBuilder->getModule();
417    Type * const voidTy = Type::getVoidTy(m->getContext());
418    Type * const voidPtrTy = TypeBuilder<void *, false>::get(m->getContext());
419    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
420    Type * const int1ty = iBuilder->getInt1Ty();
[5104]421
[5135]422    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
423    threadFunc->setCallingConv(CallingConv::C);
424    Function::arg_iterator args = threadFunc->arg_begin();
[5104]425
[5135]426    Value * const arg = &*(args++);
427    arg->setName("args");
[5133]428
[5135]429    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
430
431    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
432
433    std::vector<Value *> inbufProducerPtrs;
434    std::vector<Value *> inbufConsumerPtrs;
435    std::vector<Value *> outbufProducerPtrs;
436    std::vector<Value *> outbufConsumerPtrs;   
437    std::vector<Value *> endSignalPtrs;
438
439    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
[5185]440        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].ssName);
441        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
442        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getConsumerPosPtr(ssStructPtr));
443        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->hasEndOfInputPtr(ssStructPtr));
[5135]444    }
445    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5185]446        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].ssName);
447        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr));
448        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getConsumerPosPtr(ssStructPtr));
[5135]449    }
450
451    const unsigned segmentBlocks = codegen::SegmentSize;
452    const unsigned bufferSegments = codegen::BufferSegments;
453    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
454    Type * const size_ty = iBuilder->getSizeTy();
455
456    Value * segSize = ConstantInt::get(size_ty, segmentSize);
457    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
458    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
459   
460    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
461    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
462   
463    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
464    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
465    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
466    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
467    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
468
469    iBuilder->CreateBr(outputCheckBlock);
470
471    iBuilder->SetInsertPoint(outputCheckBlock);
472
473    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
474    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
[5192]475        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(outbufProducerPtrs[i]);
[5135]476        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
[5192]477        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(outbufConsumerPtrs[i]);
[5135]478        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
479        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
480    }
481   
482    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
483
484    iBuilder->SetInsertPoint(inputCheckBlock); 
485
[5174]486    Value * requiredSize = segSize;
487    if (mLookAheadPositions > 0) {
488        requiredSize = iBuilder->CreateAdd(segSize, ConstantInt::get(size_ty, mLookAheadPositions));
489    }
[5135]490    waitCondTest = ConstantInt::get(int1ty, 1); 
491    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
[5192]492        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
[5135]493        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
[5192]494        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(inbufConsumerPtrs[i]);
[5135]495        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
[5174]496        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, requiredSize), producerPos));
[5135]497    }
498
499    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
500   
501    iBuilder->SetInsertPoint(endSignalCheckBlock);
502   
[5188]503    LoadInst * endSignal = iBuilder->CreateLoad(endSignalPtrs[0]);
[5135]504    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
[5188]505        LoadInst * endSignal_next = iBuilder->CreateLoad(endSignalPtrs[i]);
[5135]506        iBuilder->CreateAnd(endSignal, endSignal_next);
507    }
508       
[5188]509    iBuilder->CreateCondBr(endSignal, endBlock, inputCheckBlock);
[5135]510   
511    iBuilder->SetInsertPoint(doSegmentBlock);
512 
513    createDoSegmentCall(self, segBlocks);
514
515    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
516        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
[5192]517        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
[5135]518    }
[5174]519   
520    Value * produced = getProducedItemCount(self);
[5135]521    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
[5192]522        iBuilder->CreateAtomicStoreRelease(produced, outbufProducerPtrs[i]);
[5135]523    }
524   
[5174]525    Value * earlyEndSignal = getTerminationSignal(self);
526    if (earlyEndSignal != ConstantInt::getNullValue(iBuilder->getInt1Ty())) {
527        BasicBlock * earlyEndBlock = BasicBlock::Create(iBuilder->getContext(), "earlyEndSignal", threadFunc, 0);
528        iBuilder->CreateCondBr(earlyEndSignal, earlyEndBlock, outputCheckBlock);
529
530        iBuilder->SetInsertPoint(earlyEndBlock);
531        for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5185]532            Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].ssName);
533            mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
[5174]534        }       
535    }
[5135]536    iBuilder->CreateBr(outputCheckBlock);
537     
538    iBuilder->SetInsertPoint(endBlock);
539    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
540    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
541    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
542    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
543    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
544    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
545
546    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
547
548    iBuilder->SetInsertPoint(doFinalSegBlock);
549
550    createDoSegmentCall(self, blocks);
551
552    iBuilder->CreateBr(doFinalBlock);
553
554    iBuilder->SetInsertPoint(doFinalBlock);
555
556    createFinalBlockCall(self, finalBlockRemainingBytes);
557
558    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
559        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
[5192]560        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
[5135]561    }
562    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
[5192]563        iBuilder->CreateAtomicStoreRelease(producerPos, outbufProducerPtrs[i]);
[5135]564    }
565
566    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
[5185]567        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].ssName);
568        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
[5135]569    }
570
571    Value * nullVal = Constant::getNullValue(voidPtrTy);
572    Function * pthreadExitFunc = m->getFunction("pthread_exit");
573    CallInst * exitThread = iBuilder->CreateCall(pthreadExitFunc, {nullVal}); 
574    exitThread->setDoesNotReturn();
575    iBuilder->CreateRetVoid();
576
577    return threadFunc;
578
579}
Note: See TracBrowser for help on using the repository browser.