source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5246

Last change on this file since 5246 was 5246, checked in by nmedfort, 2 years ago

Code clean-up to enforce the proper calling order of KernelBuilder methods

File size: 28.0 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
11#include <llvm/Support/ErrorHandling.h>
12#include <toolchain.h>
13
14using namespace llvm;
15using namespace kernel;
16
17KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
18                             std::string kernelName,
19                             std::vector<Binding> stream_inputs,
20                             std::vector<Binding> stream_outputs,
21                             std::vector<Binding> scalar_parameters,
22                             std::vector<Binding> scalar_outputs,
23                             std::vector<Binding> internal_scalars)
24: KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars) {
25
26}
27
28unsigned KernelBuilder::addScalar(Type * type, const std::string & name) {
29    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
30        llvm::report_fatal_error("Cannot add kernel field " + name + " after kernel state finalized");
31    }
32    const auto index = mKernelFields.size();
33    mKernelMap.emplace(name, index);
34    mKernelFields.push_back(type);
35    return index;
36}
37
38void KernelBuilder::prepareKernel() {
39    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
40        llvm::report_fatal_error("Cannot prepare kernel after kernel state finalized");
41    }
42    unsigned blockSize = iBuilder->getBitBlockWidth();
43    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
44        std::string tmp;
45        raw_string_ostream out(tmp);
46        out << "kernel contains " << mStreamSetInputBuffers.size() << " input buffers for "
47            << mStreamSetInputs.size() << " input stream sets.";
48        throw std::runtime_error(out.str());
49    }
50    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
51        std::string tmp;
52        raw_string_ostream out(tmp);
53        out << "kernel contains " << mStreamSetOutputBuffers.size() << " output buffers for "
54            << mStreamSetOutputs.size() << " output stream sets.";
55        throw std::runtime_error(out.str());
56    }
57    int streamSetNo = 0;
58    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
59        if ((mStreamSetInputBuffers[i]->getBufferSize() > 0) && (mStreamSetInputBuffers[i]->getBufferSize() < codegen::SegmentSize + (blockSize + mLookAheadPositions - 1)/blockSize)) {
60             llvm::report_fatal_error("Kernel preparation: Buffer size too small " + mStreamSetInputs[i].name);
61        }
62        mScalarInputs.push_back(Binding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].name + structPtrSuffix});
63        mStreamSetNameMap.emplace(mStreamSetInputs[i].name, streamSetNo);
64        streamSetNo++;
65    }
66    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
67        mScalarInputs.push_back(Binding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].name + structPtrSuffix});
68        mStreamSetNameMap.emplace(mStreamSetOutputs[i].name, streamSetNo);
69        streamSetNo++;
70    }
71    for (auto binding : mScalarInputs) {
72        addScalar(binding.type, binding.name);
73    }
74    for (auto binding : mScalarOutputs) {
75        addScalar(binding.type, binding.name);
76    }
77    for (auto binding : mInternalScalars) {
78        addScalar(binding.type, binding.name);
79    }
80    addScalar(iBuilder->getSizeTy(), blockNoScalar);
81    addScalar(iBuilder->getSizeTy(), logicalSegmentNoScalar);
82    addScalar(iBuilder->getSizeTy(), processedItemCount);
83    addScalar(iBuilder->getSizeTy(), producedItemCount);
84    addScalar(iBuilder->getInt1Ty(), terminationSignal);
85    mKernelStateType = StructType::create(iBuilder->getContext(), mKernelFields, mKernelName);
86}
87
88std::unique_ptr<Module> KernelBuilder::createKernelModule(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
89    auto saveModule = iBuilder->getModule();
90    auto savePoint = iBuilder->saveIP();
91    auto module = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), iBuilder->getContext());
92    iBuilder->setModule(module.get());
93    generateKernel(inputs, outputs);
94    iBuilder->setModule(saveModule);
95    iBuilder->restoreIP(savePoint);
96    return module;
97}
98
99void KernelBuilder::generateKernel(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
100    auto savePoint = iBuilder->saveIP();
101    Module * const m = iBuilder->getModule();
102    mStreamSetInputBuffers.assign(inputs.begin(), inputs.end());
103    mStreamSetOutputBuffers.assign(outputs.begin(), outputs.end());
104    prepareKernel();            // possibly overridden by the KernelBuilder subtype
105    addKernelDeclarations(m);
106    generateDoBlockMethod();    // must be implemented by the KernelBuilder subtype
107    generateFinalBlockMethod(); // possibly overridden by the KernelBuilder subtype
108    generateDoSegmentMethod();
109
110    // Implement the accumulator get functions
111    for (auto binding : mScalarOutputs) {
112        auto fnName = mKernelName + accumulator_infix + binding.name;
113        Function * accumFn = m->getFunction(fnName);
114        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.name, accumFn, 0));
115        Value * self = &*(accumFn->arg_begin());
116        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
117        Value * retVal = iBuilder->CreateLoad(ptr);
118        iBuilder->CreateRet(retVal);
119    }
120
121    // Implement the initializer function
122    Function * initFunction = m->getFunction(mKernelName + init_suffix);
123    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));   
124    Function::arg_iterator args = initFunction->arg_begin();
125    Value * self = &*(args++);
126    initializeKernelState(self);    // possibly overridden by the KernelBuilder subtype
127    for (auto binding : mScalarInputs) {
128        Value * param = &*(args++);
129        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
130        iBuilder->CreateStore(param, ptr);
131    }
132    iBuilder->CreateRetVoid();
133    iBuilder->restoreIP(savePoint);
134}
135
136void KernelBuilder::initializeKernelState(Value * self) const {
137    iBuilder->CreateStore(ConstantAggregateZero::get(mKernelStateType), self);
138}
139
140//  The default finalBlock method simply dispatches to the doBlock routine.
141void KernelBuilder::generateFinalBlockMethod() const {
142    auto savePoint = iBuilder->saveIP();
143    Module * m = iBuilder->getModule();
144    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
145    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
146    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
147    // Final Block arguments: self, remaining, then the standard DoBlock args.
148    Function::arg_iterator args = finalBlockFunction->arg_begin();
149    Value * self = &*(args++);
150    /* Skip "remaining" arg */ args++;
151    std::vector<Value *> doBlockArgs = {self};
152    while (args != finalBlockFunction->arg_end()){
153        doBlockArgs.push_back(&*args++);
154    }
155    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
156    iBuilder->CreateRetVoid();
157    iBuilder->restoreIP(savePoint);
158}
159
160// Note: this may be overridden to incorporate doBlock logic directly into
161// the doSegment function.
162void KernelBuilder::generateDoBlockLogic(Value * self, Value * /* blockNo */) const {
163    Function * doBlockFunction = iBuilder->getModule()->getFunction(mKernelName + doBlock_suffix);
164    iBuilder->CreateCall(doBlockFunction, self);
165}
166
167//  The default doSegment method dispatches to the doBlock routine for
168//  each block of the given number of blocksToDo, and then updates counts.
169void KernelBuilder::generateDoSegmentMethod() const {
170    auto savePoint = iBuilder->saveIP();
171    Module * m = iBuilder->getModule();
172    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
173    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
174    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
175    BasicBlock * strideLoopCond = BasicBlock::Create(iBuilder->getContext(), "strideLoopCond", doSegmentFunction, 0);
176    BasicBlock * strideLoopBody = BasicBlock::Create(iBuilder->getContext(), "strideLoopBody", doSegmentFunction, 0);
177    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", doSegmentFunction, 0);
178    BasicBlock * checkFinalStride = BasicBlock::Create(iBuilder->getContext(), "checkFinalStride", doSegmentFunction, 0);
179    BasicBlock * checkEndSignals = BasicBlock::Create(iBuilder->getContext(), "checkEndSignals", doSegmentFunction, 0);
180    BasicBlock * callFinalBlock = BasicBlock::Create(iBuilder->getContext(), "callFinalBlock", doSegmentFunction, 0);
181    BasicBlock * segmentDone = BasicBlock::Create(iBuilder->getContext(), "segmentDone", doSegmentFunction, 0);
182    BasicBlock * finalExit = BasicBlock::Create(iBuilder->getContext(), "finalExit", doSegmentFunction, 0);
183    Type * const size_ty = iBuilder->getSizeTy();
184    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
185    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
186   
187    Function::arg_iterator args = doSegmentFunction->arg_begin();
188    Value * self = &*(args++);
189    Value * blocksToDo = &*(args);
190   
191    std::vector<Value *> inbufProducerPtrs;
192    std::vector<Value *> endSignalPtrs;
193    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
194        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
195        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
196        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
197    }
198   
199    std::vector<Value *> producerPos;
200    /* Determine the actually available data examining all input stream sets. */
201    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
202    producerPos.push_back(p);
203    Value * availablePos = producerPos[0];
204    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
205        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
206        producerPos.push_back(p);
207        /* Set the available position to be the minimum of availablePos and producerPos. */
208        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
209    }
210    Value * processed = getProcessedItemCount(self);
211    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
212//#ifndef NDEBUG
213//    iBuilder->CallPrintInt(mKernelName + "_itemsAvail", itemsAvail);
214//#endif
215    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
216    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
217    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
218    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
219    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
220    //iBuilder->CallPrintInt(mKernelName + "_stridesAvail", stridesAvail);
221    iBuilder->CreateBr(strideLoopCond);
222
223    iBuilder->SetInsertPoint(strideLoopCond);
224    PHINode * stridesRemaining = iBuilder->CreatePHI(size_ty, 2, "stridesRemaining");
225    stridesRemaining->addIncoming(stridesToDo, entryBlock);
226    Value * notDone = iBuilder->CreateICmpUGT(stridesRemaining, ConstantInt::get(size_ty, 0));
227    iBuilder->CreateCondBr(notDone, strideLoopBody, stridesDone);
228
229    iBuilder->SetInsertPoint(strideLoopBody);
230    Value * blockNo = getScalarField(self, blockNoScalar);   
231
232    generateDoBlockLogic(self, blockNo);
233    setBlockNo(self, iBuilder->CreateAdd(blockNo, strideBlocks));
234    stridesRemaining->addIncoming(iBuilder->CreateSub(stridesRemaining, ConstantInt::get(size_ty, 1)), strideLoopBody);
235    iBuilder->CreateBr(strideLoopCond);
236   
237    iBuilder->SetInsertPoint(stridesDone);
238    processed = iBuilder->CreateAdd(processed, iBuilder->CreateMul(stridesToDo, stride));
239    setProcessedItemCount(self, processed);
240    iBuilder->CreateCondBr(lessThanFullSegment, checkFinalStride, segmentDone);
241   
242    iBuilder->SetInsertPoint(checkFinalStride);
243   
244    /* We had less than a full segment of data; we may have reached the end of input
245       on one of the stream sets.  */
246   
247    Value * alreadyDone = getTerminationSignal(self);
248    iBuilder->CreateCondBr(alreadyDone, finalExit, checkEndSignals);
249   
250    iBuilder->SetInsertPoint(checkEndSignals);
251    Value * endOfInput = iBuilder->CreateLoad(endSignalPtrs[0]);
252    if (endSignalPtrs.size() > 1) {
253        /* If there is more than one input stream set, then we need to confirm that one of
254           them has both the endSignal set and the length = to availablePos. */
255        endOfInput = iBuilder->CreateAnd(endOfInput, iBuilder->CreateICmpEQ(availablePos, producerPos[0]));
256        for (unsigned i = 1; i < endSignalPtrs.size(); i++) {
257            Value * e = iBuilder->CreateAnd(iBuilder->CreateLoad(endSignalPtrs[i]), iBuilder->CreateICmpEQ(availablePos, producerPos[i]));
258            endOfInput = iBuilder->CreateOr(endOfInput, e);
259        }
260    }
261    iBuilder->CreateCondBr(endOfInput, callFinalBlock, segmentDone);
262   
263    iBuilder->SetInsertPoint(callFinalBlock);
264   
265    Value * remainingItems = iBuilder->CreateSub(availablePos, processed);
266    createFinalBlockCall(self, remainingItems);
267    setProcessedItemCount(self, availablePos);
268   
269    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
270        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
271        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
272    }
273    setTerminationSignal(self);
274    iBuilder->CreateBr(segmentDone);
275   
276    iBuilder->SetInsertPoint(segmentDone);
277    Value * produced = getProducedItemCount(self);
278//#ifndef NDEBUG
279//    iBuilder->CallPrintInt(mKernelName + "_produced", produced);
280//#endif
281    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
282        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
283        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
284        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
285    }
286    iBuilder->CreateBr(finalExit);
287    iBuilder->SetInsertPoint(finalExit);
288
289    iBuilder->CreateRetVoid();
290    iBuilder->restoreIP(savePoint);
291}
292
293ConstantInt * KernelBuilder::getScalarIndex(const std::string & name) const {
294    const auto f = mKernelMap.find(name);
295    if (LLVM_UNLIKELY(f == mKernelMap.end())) {
296        llvm::report_fatal_error("Kernel does not contain scalar: " + name);
297    }
298    return iBuilder->getInt32(f->second);
299}
300
301unsigned KernelBuilder::getScalarCount() const {
302    return mKernelFields.size();
303}
304
305Value * KernelBuilder::getScalarFieldPtr(Value * self, const std::string & fieldName) const {
306    return iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
307}
308
309Value * KernelBuilder::getScalarField(Value * self, const std::string & fieldName) const {
310    return iBuilder->CreateLoad(getScalarFieldPtr(self, fieldName));
311}
312
313void KernelBuilder::setScalarField(Value * self, const std::string & fieldName, Value * newFieldVal) const {
314    iBuilder->CreateStore(newFieldVal, getScalarFieldPtr(self, fieldName));
315}
316
317LoadInst * KernelBuilder::acquireLogicalSegmentNo(Value * self) const {
318    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
319    return iBuilder->CreateAtomicLoadAcquire(ptr);
320}
321
322Value * KernelBuilder::getProcessedItemCount(Value * self) const {
323    return getScalarField(self, processedItemCount);
324}
325
326Value * KernelBuilder::getProducedItemCount(Value * self) const {
327    return getScalarField(self, producedItemCount);
328}
329
330Value * KernelBuilder::getTerminationSignal(Value * self) const {
331    return getScalarField(self, terminationSignal);
332}
333
334void KernelBuilder::releaseLogicalSegmentNo(Value * self, Value * newCount) const {
335    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
336    iBuilder->CreateAtomicStoreRelease(newCount, ptr);
337}
338
339void KernelBuilder::setProcessedItemCount(Value * self, Value * newCount) const {
340    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(processedItemCount)});
341    iBuilder->CreateStore(newCount, ptr);
342}
343
344void KernelBuilder::setProducedItemCount(Value * self, Value * newCount) const {
345    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(producedItemCount)});
346    iBuilder->CreateStore(newCount, ptr);
347}
348
349void KernelBuilder::setTerminationSignal(Value * self) const {
350    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(terminationSignal)});
351    iBuilder->CreateStore(ConstantInt::get(iBuilder->getInt1Ty(), 1), ptr);
352}
353
354Value * KernelBuilder::getBlockNo(Value * self) const {
355    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
356    return iBuilder->CreateLoad(ptr);
357}
358
359void KernelBuilder::setBlockNo(Value * self, Value * newFieldVal) const {
360    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
361    iBuilder->CreateStore(newFieldVal, ptr);
362}
363
364
365Value * KernelBuilder::getParameter(Function * f, const std::string & paramName) const {
366    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
367        Value * arg = &*argIter;
368        if (arg->getName() == paramName) return arg;
369    }
370    llvm::report_fatal_error("Method does not have parameter: " + paramName);
371}
372
373unsigned KernelBuilder::getStreamSetIndex(const std::string & name) const {
374    const auto f = mStreamSetNameMap.find(name);
375    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
376        llvm::report_fatal_error("Kernel does not contain stream set: " + name);
377    }
378    return f->second;
379}
380
381size_t KernelBuilder::getStreamSetBufferSize(Value * /* self */, const std::string & name) const {
382    const unsigned index = getStreamSetIndex(name);
383    StreamSetBuffer * buf = nullptr;
384    if (index < mStreamSetInputs.size()) {
385        buf = mStreamSetInputBuffers[index];
386    } else {
387        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
388    }
389    return buf->getBufferSize();
390}
391
392Value * KernelBuilder::getStreamSetStructPtr(Value * self, const std::string & name) const {
393    return getScalarField(self, name + structPtrSuffix);
394}
395
396Value * KernelBuilder::getStreamSetBlockPtr(Value * self, const std::string &name, Value * blockNo) const {
397    Value * const structPtr = getStreamSetStructPtr(self, name);
398    const unsigned index = getStreamSetIndex(name);
399    StreamSetBuffer * buf = nullptr;
400    if (index < mStreamSetInputs.size()) {
401        buf = mStreamSetInputBuffers[index];
402    } else {
403        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
404    }   
405    return buf->getStreamSetBlockPointer(structPtr, blockNo);
406}
407
408Value * KernelBuilder::getStream(Value * self, const std::string & name, Value * blockNo, Value * index) {
409    return iBuilder->CreateGEP(getStreamSetBlockPtr(self, name, blockNo), {iBuilder->getInt32(0), index});
410}
411
412void KernelBuilder::createInstance() {
413    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
414        llvm::report_fatal_error("Cannot create kernel instance before calling prepareKernel()");
415    }
416    mKernelInstance = iBuilder->CreateCacheAlignedAlloca(mKernelStateType);
417    Module * m = iBuilder->getModule();
418    std::vector<Value *> init_args = {mKernelInstance};
419    for (auto a : mInitialArguments) {
420        init_args.push_back(a);
421    }
422    for (auto b : mStreamSetInputBuffers) {
423        init_args.push_back(b->getStreamSetStructPtr());
424    }
425    for (auto b : mStreamSetOutputBuffers) {
426        init_args.push_back(b->getStreamSetStructPtr());
427    }
428    std::string initFnName = mKernelName + init_suffix;
429    Function * initMethod = m->getFunction(initFnName);
430    if (!initMethod) {
431        llvm::report_fatal_error("Cannot find " + initFnName);
432    }
433    iBuilder->CreateCall(initMethod, init_args);
434}
435
436Function * KernelBuilder::generateThreadFunction(const std::string & name) const {
437    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
438        llvm::report_fatal_error("Cannot generate thread function before calling prepareKernel()");
439    }
440    Module * m = iBuilder->getModule();
441    Type * const voidTy = iBuilder->getVoidTy();
442    Type * const voidPtrTy = iBuilder->getVoidPtrTy();
443    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
444    Type * const int1ty = iBuilder->getInt1Ty();
445
446    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
447    threadFunc->setCallingConv(CallingConv::C);
448    Function::arg_iterator args = threadFunc->arg_begin();
449
450    Value * const arg = &*(args++);
451    arg->setName("args");
452
453    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
454
455    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
456
457    std::vector<Value *> inbufProducerPtrs;
458    std::vector<Value *> inbufConsumerPtrs;
459    std::vector<Value *> outbufProducerPtrs;
460    std::vector<Value *> outbufConsumerPtrs;   
461    std::vector<Value *> endSignalPtrs;
462
463    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
464        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
465        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
466        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getConsumerPosPtr(ssStructPtr));
467        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(ssStructPtr));
468    }
469    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
470        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
471        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr));
472        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getConsumerPosPtr(ssStructPtr));
473    }
474
475    const unsigned segmentBlocks = codegen::SegmentSize;
476    const unsigned bufferSegments = codegen::BufferSegments;
477    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
478    Type * const size_ty = iBuilder->getSizeTy();
479
480    Value * segSize = ConstantInt::get(size_ty, segmentSize);
481    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
482    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
483   
484    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
485    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
486   
487    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
488    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
489    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
490    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
491    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
492
493    iBuilder->CreateBr(outputCheckBlock);
494
495    iBuilder->SetInsertPoint(outputCheckBlock);
496
497    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
498    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
499        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(outbufProducerPtrs[i]);
500        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
501        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(outbufConsumerPtrs[i]);
502        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
503        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
504    }
505   
506    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
507
508    iBuilder->SetInsertPoint(inputCheckBlock); 
509
510    Value * requiredSize = segSize;
511    if (mLookAheadPositions > 0) {
512        requiredSize = iBuilder->CreateAdd(segSize, ConstantInt::get(size_ty, mLookAheadPositions));
513    }
514    waitCondTest = ConstantInt::get(int1ty, 1); 
515    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
516        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
517        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
518        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(inbufConsumerPtrs[i]);
519        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
520        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, requiredSize), producerPos));
521    }
522
523    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
524   
525    iBuilder->SetInsertPoint(endSignalCheckBlock);
526   
527    LoadInst * endSignal = iBuilder->CreateLoad(endSignalPtrs[0]);
528    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
529        LoadInst * endSignal_next = iBuilder->CreateLoad(endSignalPtrs[i]);
530        iBuilder->CreateAnd(endSignal, endSignal_next);
531    }
532       
533    iBuilder->CreateCondBr(endSignal, endBlock, inputCheckBlock);
534   
535    iBuilder->SetInsertPoint(doSegmentBlock);
536 
537    createDoSegmentCall(self, segBlocks);
538
539    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
540        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
541        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
542    }
543   
544    Value * produced = getProducedItemCount(self);
545    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
546        iBuilder->CreateAtomicStoreRelease(produced, outbufProducerPtrs[i]);
547    }
548   
549    Value * earlyEndSignal = getTerminationSignal(self);
550    if (earlyEndSignal != ConstantInt::getNullValue(iBuilder->getInt1Ty())) {
551        BasicBlock * earlyEndBlock = BasicBlock::Create(iBuilder->getContext(), "earlyEndSignal", threadFunc, 0);
552        iBuilder->CreateCondBr(earlyEndSignal, earlyEndBlock, outputCheckBlock);
553
554        iBuilder->SetInsertPoint(earlyEndBlock);
555        for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
556            Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
557            mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
558        }       
559    }
560    iBuilder->CreateBr(outputCheckBlock);
561     
562    iBuilder->SetInsertPoint(endBlock);
563    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
564    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
565    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
566    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
567    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
568    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
569
570    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
571
572    iBuilder->SetInsertPoint(doFinalSegBlock);
573
574    createDoSegmentCall(self, blocks);
575
576    iBuilder->CreateBr(doFinalBlock);
577
578    iBuilder->SetInsertPoint(doFinalBlock);
579
580    createFinalBlockCall(self, finalBlockRemainingBytes);
581
582    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
583        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
584        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
585    }
586    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
587        iBuilder->CreateAtomicStoreRelease(producerPos, outbufProducerPtrs[i]);
588    }
589
590    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
591        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
592        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
593    }
594
595    iBuilder->CreatePThreadExitCall(Constant::getNullValue(voidPtrTy));
596    iBuilder->CreateRetVoid();
597
598    return threadFunc;
599
600}
601
602KernelBuilder::~KernelBuilder() {
603}
Note: See TracBrowser for help on using the repository browser.