source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5257

Last change on this file since 5257 was 5257, checked in by cameron, 2 years ago

finalSegment kernel methods initial check-in

File size: 30.5 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Module.h>
8#include <llvm/IR/Type.h>
9#include <llvm/IR/Value.h>
10#include <llvm/Support/raw_ostream.h>
11#include <llvm/Support/ErrorHandling.h>
12#include <toolchain.h>
13
14using namespace llvm;
15using namespace kernel;
16
17KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
18                             std::string kernelName,
19                             std::vector<Binding> stream_inputs,
20                             std::vector<Binding> stream_outputs,
21                             std::vector<Binding> scalar_parameters,
22                             std::vector<Binding> scalar_outputs,
23                             std::vector<Binding> internal_scalars)
24: KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars),
25mNoTerminateAttribute(false) {
26
27}
28
29unsigned KernelBuilder::addScalar(Type * type, const std::string & name) {
30    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
31        llvm::report_fatal_error("Cannot add kernel field " + name + " after kernel state finalized");
32    }
33    const auto index = mKernelFields.size();
34    mKernelMap.emplace(name, index);
35    mKernelFields.push_back(type);
36    return index;
37}
38
39void KernelBuilder::prepareKernel() {
40    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
41        llvm::report_fatal_error("Cannot prepare kernel after kernel state finalized");
42    }
43    unsigned blockSize = iBuilder->getBitBlockWidth();
44    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
45        std::string tmp;
46        raw_string_ostream out(tmp);
47        out << "kernel contains " << mStreamSetInputBuffers.size() << " input buffers for "
48            << mStreamSetInputs.size() << " input stream sets.";
49        throw std::runtime_error(out.str());
50    }
51    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
52        std::string tmp;
53        raw_string_ostream out(tmp);
54        out << "kernel contains " << mStreamSetOutputBuffers.size() << " output buffers for "
55            << mStreamSetOutputs.size() << " output stream sets.";
56        throw std::runtime_error(out.str());
57    }
58    int streamSetNo = 0;
59    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
60        if ((mStreamSetInputBuffers[i]->getBufferSize() > 0) && (mStreamSetInputBuffers[i]->getBufferSize() < codegen::SegmentSize + (blockSize + mLookAheadPositions - 1)/blockSize)) {
61             llvm::report_fatal_error("Kernel preparation: Buffer size too small " + mStreamSetInputs[i].name);
62        }
63        mScalarInputs.push_back(Binding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].name + structPtrSuffix});
64        mStreamSetNameMap.emplace(mStreamSetInputs[i].name, streamSetNo);
65        addScalar(iBuilder->getSizeTy(), mStreamSetInputs[i].name + processedItemCountSuffix);
66        streamSetNo++;
67    }
68    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
69        mScalarInputs.push_back(Binding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].name + structPtrSuffix});
70        mStreamSetNameMap.emplace(mStreamSetOutputs[i].name, streamSetNo);
71        addScalar(iBuilder->getSizeTy(), mStreamSetOutputs[i].name + producedItemCountSuffix);
72        streamSetNo++;
73    }
74    for (auto binding : mScalarInputs) {
75        addScalar(binding.type, binding.name);
76    }
77    for (auto binding : mScalarOutputs) {
78        addScalar(binding.type, binding.name);
79    }
80    for (auto binding : mInternalScalars) {
81        addScalar(binding.type, binding.name);
82    }
83    addScalar(iBuilder->getSizeTy(), blockNoScalar);
84    addScalar(iBuilder->getSizeTy(), logicalSegmentNoScalar);
85    addScalar(iBuilder->getInt1Ty(), terminationSignal);
86    mKernelStateType = StructType::create(iBuilder->getContext(), mKernelFields, mKernelName);
87}
88
89std::unique_ptr<Module> KernelBuilder::createKernelModule(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
90    auto saveModule = iBuilder->getModule();
91    auto savePoint = iBuilder->saveIP();
92    auto module = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), iBuilder->getContext());
93    iBuilder->setModule(module.get());
94    generateKernel(inputs, outputs);
95    iBuilder->setModule(saveModule);
96    iBuilder->restoreIP(savePoint);
97    return module;
98}
99
100void KernelBuilder::generateKernel(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
101    auto savePoint = iBuilder->saveIP();
102    Module * const m = iBuilder->getModule();
103    mStreamSetInputBuffers.assign(inputs.begin(), inputs.end());
104    mStreamSetOutputBuffers.assign(outputs.begin(), outputs.end());
105    prepareKernel();            // possibly overridden by the KernelBuilder subtype
106    addKernelDeclarations(m);
107    generateInitMethod();       // possibly overridden by the KernelBuilder subtype
108    generateDoBlockMethod();    // must be implemented by the KernelBuilder subtype
109    generateFinalBlockMethod(); // possibly overridden by the KernelBuilder subtype
110    generateDoSegmentMethod();
111    generateFinalSegmentMethod();
112
113    // Implement the accumulator get functions
114    for (auto binding : mScalarOutputs) {
115        auto fnName = mKernelName + accumulator_infix + binding.name;
116        Function * accumFn = m->getFunction(fnName);
117        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.name, accumFn, 0));
118        Value * self = &*(accumFn->arg_begin());
119        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
120        Value * retVal = iBuilder->CreateLoad(ptr);
121        iBuilder->CreateRet(retVal);
122    }
123    generateInitMethod();
124    iBuilder->restoreIP(savePoint);
125}
126
127// Default init method, possibly overridden if special init actions required.
128void KernelBuilder::generateInitMethod() const {
129    auto savePoint = iBuilder->saveIP();
130    Module * const m = iBuilder->getModule();
131    Function * initFunction = m->getFunction(mKernelName + init_suffix);
132    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));   
133    Function::arg_iterator args = initFunction->arg_begin();
134    Value * self = &*(args++);
135    iBuilder->CreateStore(ConstantAggregateZero::get(mKernelStateType), self);
136    for (auto binding : mScalarInputs) {
137        Value * param = &*(args++);
138        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
139        iBuilder->CreateStore(param, ptr);
140    }
141    iBuilder->CreateRetVoid();
142    iBuilder->restoreIP(savePoint);
143}
144
145//  The default finalBlock method simply dispatches to the doBlock routine.
146void KernelBuilder::generateFinalBlockMethod() const {
147    auto savePoint = iBuilder->saveIP();
148    Module * m = iBuilder->getModule();
149    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
150    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
151    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
152    // Final Block arguments: self, remaining, then the standard DoBlock args.
153    Function::arg_iterator args = finalBlockFunction->arg_begin();
154    Value * self = &*(args++);
155    /* Skip "remaining" arg */ args++;
156    std::vector<Value *> doBlockArgs = {self};
157    while (args != finalBlockFunction->arg_end()){
158        doBlockArgs.push_back(&*args++);
159    }
160    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
161    iBuilder->CreateRetVoid();
162    iBuilder->restoreIP(savePoint);
163}
164
165// Note: this may be overridden to incorporate doBlock logic directly into
166// the doSegment function.
167void KernelBuilder::generateDoBlockLogic(Value * self, Value * /* blockNo */) const {
168    Function * doBlockFunction = iBuilder->getModule()->getFunction(mKernelName + doBlock_suffix);
169    iBuilder->CreateCall(doBlockFunction, self);
170}
171
172
173//  The default doSegment method dispatches to the doBlock routine for
174//  each block of the given number of blocksToDo, and then updates counts.
175void KernelBuilder::generateDoSegmentMethod() const {
176    auto savePoint = iBuilder->saveIP();
177    Module * m = iBuilder->getModule();
178    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
179    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
180    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
181    BasicBlock * strideLoopCond = BasicBlock::Create(iBuilder->getContext(), "strideLoopCond", doSegmentFunction, 0);
182    BasicBlock * strideLoopBody = BasicBlock::Create(iBuilder->getContext(), "strideLoopBody", doSegmentFunction, 0);
183    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", doSegmentFunction, 0);
184    BasicBlock * segmentDone = BasicBlock::Create(iBuilder->getContext(), "segmentDone", doSegmentFunction, 0);
185    BasicBlock * finalExit = BasicBlock::Create(iBuilder->getContext(), "finalExit", doSegmentFunction, 0);
186    Type * const size_ty = iBuilder->getSizeTy();
187    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
188    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
189   
190    Function::arg_iterator args = doSegmentFunction->arg_begin();
191    Value * self = &*(args++);
192    Value * blocksToDo = &*(args);
193   
194    std::vector<Value *> inbufProducerPtrs;
195    std::vector<Value *> endSignalPtrs;
196    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
197        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
198        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
199        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
200    }
201   
202    std::vector<Value *> producerPos;
203    /* Determine the actually available data examining all input stream sets. */
204    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
205    producerPos.push_back(p);
206    Value * availablePos = producerPos[0];
207    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
208        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
209        producerPos.push_back(p);
210        /* Set the available position to be the minimum of availablePos and producerPos. */
211        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
212    }
213    Value * processed = getProcessedItemCount(self, mStreamSetInputs[0].name);
214    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
215#ifndef NDEBUG
216    iBuilder->CallPrintInt(mKernelName + "_itemsAvail", itemsAvail);
217#endif
218    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
219    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
220    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
221    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
222    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
223    //iBuilder->CallPrintInt(mKernelName + "_stridesAvail", stridesAvail);
224    iBuilder->CreateBr(strideLoopCond);
225
226    iBuilder->SetInsertPoint(strideLoopCond);
227    PHINode * stridesRemaining = iBuilder->CreatePHI(size_ty, 2, "stridesRemaining");
228    stridesRemaining->addIncoming(stridesToDo, entryBlock);
229    Value * notDone = iBuilder->CreateICmpUGT(stridesRemaining, ConstantInt::get(size_ty, 0));
230    iBuilder->CreateCondBr(notDone, strideLoopBody, stridesDone);
231
232    iBuilder->SetInsertPoint(strideLoopBody);
233    Value * blockNo = getScalarField(self, blockNoScalar);   
234
235    generateDoBlockLogic(self, blockNo);
236    setBlockNo(self, iBuilder->CreateAdd(blockNo, strideBlocks));
237    stridesRemaining->addIncoming(iBuilder->CreateSub(stridesRemaining, ConstantInt::get(size_ty, 1)), strideLoopBody);
238    iBuilder->CreateBr(strideLoopCond);
239   
240    iBuilder->SetInsertPoint(stridesDone);
241    processed = iBuilder->CreateAdd(processed, iBuilder->CreateMul(stridesToDo, stride));
242    setProcessedItemCount(self, mStreamSetInputs[0].name, processed);
243    iBuilder->CreateBr(segmentDone);
244    iBuilder->SetInsertPoint(segmentDone);
245#ifndef NDEBUG
246    iBuilder->CallPrintInt(mKernelName + "_processed", processed);
247#endif
248    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
249        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
250        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
251        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
252        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
253    }
254    iBuilder->CreateBr(finalExit);
255    iBuilder->SetInsertPoint(finalExit);
256
257    iBuilder->CreateRetVoid();
258    iBuilder->restoreIP(savePoint);
259}
260
261void KernelBuilder::generateFinalSegmentMethod() const {
262    auto savePoint = iBuilder->saveIP();
263    Module * m = iBuilder->getModule();
264    Function * finalSegmentFunction = m->getFunction(mKernelName + finalSegment_suffix);
265    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", finalSegmentFunction, 0));
266    BasicBlock * doStrides = BasicBlock::Create(iBuilder->getContext(), "doStrides", finalSegmentFunction, 0);
267    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", finalSegmentFunction, 0);
268    Type * const size_ty = iBuilder->getSizeTy();
269    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
270    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
271    Function::arg_iterator args = finalSegmentFunction->arg_begin();
272    Value * self = &*(args++);
273    Value * blocksToDo = &*(args);
274    std::vector<Value *> inbufProducerPtrs;
275    std::vector<Value *> endSignalPtrs;
276    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
277        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
278        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
279        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
280    }
281   
282    std::vector<Value *> producerPos;
283    /* Determine the actually available data examining all input stream sets. */
284    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
285    producerPos.push_back(p);
286    Value * availablePos = producerPos[0];
287    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
288        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
289        producerPos.push_back(p);
290        /* Set the available position to be the minimum of availablePos and producerPos. */
291        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
292    }
293    Value * processed = getProcessedItemCount(self, mStreamSetInputs[0].name);
294    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
295#ifndef NDEBUG
296    iBuilder->CallPrintInt(mKernelName + "_itemsAvail final", itemsAvail);
297#endif
298    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
299    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
300    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
301    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
302    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
303    Value * notDone = iBuilder->CreateICmpUGT(stridesToDo, ConstantInt::get(size_ty, 0));
304    iBuilder->CreateCondBr(notDone, doStrides, stridesDone);
305   
306    iBuilder->SetInsertPoint(doStrides);
307    createDoSegmentCall(self, blocksToDo);
308    iBuilder->CreateBr(stridesDone);
309   
310    iBuilder->SetInsertPoint(stridesDone);
311    /* Now at most a partial block remains. */
312   
313    processed = getProcessedItemCount(self, mStreamSetInputs[0].name);   
314    Value * remainingItems = iBuilder->CreateSub(producerPos[0], processed);
315    //iBuilder->CallPrintInt(mKernelName + " remainingItems", remainingItems);
316       
317    createFinalBlockCall(self, remainingItems);
318    processed = iBuilder->CreateAdd(processed, remainingItems);
319    setProcessedItemCount(self, mStreamSetInputs[0].name, processed);
320       
321#ifndef NDEBUG
322    iBuilder->CallPrintInt(mKernelName + "_processed final", processed);
323#endif
324    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
325        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
326        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
327        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
328        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
329    }
330
331    iBuilder->CreateRetVoid();
332
333    iBuilder->restoreIP(savePoint);
334}
335
336
337
338ConstantInt * KernelBuilder::getScalarIndex(const std::string & name) const {
339    const auto f = mKernelMap.find(name);
340    if (LLVM_UNLIKELY(f == mKernelMap.end())) {
341        llvm::report_fatal_error("Kernel does not contain scalar: " + name);
342    }
343    return iBuilder->getInt32(f->second);
344}
345
346unsigned KernelBuilder::getScalarCount() const {
347    return mKernelFields.size();
348}
349
350Value * KernelBuilder::getScalarFieldPtr(Value * self, const std::string & fieldName) const {
351    return iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
352}
353
354Value * KernelBuilder::getScalarField(Value * self, const std::string & fieldName) const {
355    return iBuilder->CreateLoad(getScalarFieldPtr(self, fieldName));
356}
357
358void KernelBuilder::setScalarField(Value * self, const std::string & fieldName, Value * newFieldVal) const {
359    iBuilder->CreateStore(newFieldVal, getScalarFieldPtr(self, fieldName));
360}
361
362LoadInst * KernelBuilder::acquireLogicalSegmentNo(Value * self) const {
363    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
364    return iBuilder->CreateAtomicLoadAcquire(ptr);
365}
366
367Value * KernelBuilder::getProcessedItemCount(Value * self, const std::string & ssName) const {
368    return getScalarField(self, ssName + processedItemCountSuffix);
369}
370
371Value * KernelBuilder::getProducedItemCount(Value * self, const std::string & ssName) const {
372    return getScalarField(self, ssName + producedItemCountSuffix);
373}
374
375Value * KernelBuilder::getTerminationSignal(Value * self) const {
376    return getScalarField(self, terminationSignal);
377}
378
379void KernelBuilder::releaseLogicalSegmentNo(Value * self, Value * newCount) const {
380    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
381    iBuilder->CreateAtomicStoreRelease(newCount, ptr);
382}
383
384void KernelBuilder::setProcessedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
385    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + processedItemCountSuffix)});
386    iBuilder->CreateStore(newCount, ptr);
387}
388
389void KernelBuilder::setProducedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
390    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + producedItemCountSuffix)});
391    iBuilder->CreateStore(newCount, ptr);
392}
393
394void KernelBuilder::setTerminationSignal(Value * self) const {
395    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(terminationSignal)});
396    iBuilder->CreateStore(ConstantInt::get(iBuilder->getInt1Ty(), 1), ptr);
397}
398
399Value * KernelBuilder::getBlockNo(Value * self) const {
400    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
401    return iBuilder->CreateLoad(ptr);
402}
403
404void KernelBuilder::setBlockNo(Value * self, Value * newFieldVal) const {
405    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
406    iBuilder->CreateStore(newFieldVal, ptr);
407}
408
409
410Value * KernelBuilder::getParameter(Function * f, const std::string & paramName) const {
411    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
412        Value * arg = &*argIter;
413        if (arg->getName() == paramName) return arg;
414    }
415    llvm::report_fatal_error("Method does not have parameter: " + paramName);
416}
417
418unsigned KernelBuilder::getStreamSetIndex(const std::string & name) const {
419    const auto f = mStreamSetNameMap.find(name);
420    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
421        llvm::report_fatal_error("Kernel does not contain stream set: " + name);
422    }
423    return f->second;
424}
425
426size_t KernelBuilder::getStreamSetBufferSize(Value * /* self */, const std::string & name) const {
427    const unsigned index = getStreamSetIndex(name);
428    StreamSetBuffer * buf = nullptr;
429    if (index < mStreamSetInputs.size()) {
430        buf = mStreamSetInputBuffers[index];
431    } else {
432        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
433    }
434    return buf->getBufferSize();
435}
436
437Value * KernelBuilder::getStreamSetStructPtr(Value * self, const std::string & name) const {
438    return getScalarField(self, name + structPtrSuffix);
439}
440
441Value * KernelBuilder::getStreamSetBlockPtr(Value * self, const std::string &name, Value * blockNo) const {
442    Value * const structPtr = getStreamSetStructPtr(self, name);
443    const unsigned index = getStreamSetIndex(name);
444    StreamSetBuffer * buf = nullptr;
445    if (index < mStreamSetInputs.size()) {
446        buf = mStreamSetInputBuffers[index];
447    } else {
448        buf = mStreamSetOutputBuffers[index - mStreamSetInputs.size()];
449    }   
450    return buf->getStreamSetBlockPointer(structPtr, blockNo);
451}
452
453Value * KernelBuilder::getStream(Value * self, const std::string & name, Value * blockNo, Value * index) {
454    return iBuilder->CreateGEP(getStreamSetBlockPtr(self, name, blockNo), {iBuilder->getInt32(0), index});
455}
456
457void KernelBuilder::createInstance() {
458    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
459        llvm::report_fatal_error("Cannot create kernel instance before calling prepareKernel()");
460    }
461    mKernelInstance = iBuilder->CreateCacheAlignedAlloca(mKernelStateType);
462    Module * m = iBuilder->getModule();
463    std::vector<Value *> init_args = {mKernelInstance};
464    for (auto a : mInitialArguments) {
465        init_args.push_back(a);
466    }
467    for (auto b : mStreamSetInputBuffers) {
468        init_args.push_back(b->getStreamSetStructPtr());
469    }
470    for (auto b : mStreamSetOutputBuffers) {
471        init_args.push_back(b->getStreamSetStructPtr());
472    }
473    std::string initFnName = mKernelName + init_suffix;
474    Function * initMethod = m->getFunction(initFnName);
475    if (!initMethod) {
476        llvm::report_fatal_error("Cannot find " + initFnName);
477    }
478    iBuilder->CreateCall(initMethod, init_args);
479}
480
481Function * KernelBuilder::generateThreadFunction(const std::string & name) const {
482    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
483        llvm::report_fatal_error("Cannot generate thread function before calling prepareKernel()");
484    }
485    Module * m = iBuilder->getModule();
486    Type * const voidTy = iBuilder->getVoidTy();
487    Type * const voidPtrTy = iBuilder->getVoidPtrTy();
488    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
489    Type * const int1ty = iBuilder->getInt1Ty();
490
491    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
492    threadFunc->setCallingConv(CallingConv::C);
493    Function::arg_iterator args = threadFunc->arg_begin();
494
495    Value * const arg = &*(args++);
496    arg->setName("args");
497
498    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
499
500    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
501
502    std::vector<Value *> inbufProducerPtrs;
503    std::vector<Value *> inbufConsumerPtrs;
504    std::vector<Value *> outbufProducerPtrs;
505    std::vector<Value *> outbufConsumerPtrs;   
506    std::vector<Value *> endSignalPtrs;
507
508    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
509        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
510        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
511        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getConsumerPosPtr(ssStructPtr));
512        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(ssStructPtr));
513    }
514    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
515        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
516        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr));
517        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getConsumerPosPtr(ssStructPtr));
518    }
519
520    const unsigned segmentBlocks = codegen::SegmentSize;
521    const unsigned bufferSegments = codegen::BufferSegments;
522    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
523    Type * const size_ty = iBuilder->getSizeTy();
524
525    Value * segSize = ConstantInt::get(size_ty, segmentSize);
526    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
527    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
528   
529    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
530    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
531   
532    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
533    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
534    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
535    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
536    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
537
538    iBuilder->CreateBr(outputCheckBlock);
539
540    iBuilder->SetInsertPoint(outputCheckBlock);
541
542    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
543    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
544        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(outbufProducerPtrs[i]);
545        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
546        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(outbufConsumerPtrs[i]);
547        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
548        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
549    }
550   
551    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
552
553    iBuilder->SetInsertPoint(inputCheckBlock); 
554
555    Value * requiredSize = segSize;
556    if (mLookAheadPositions > 0) {
557        requiredSize = iBuilder->CreateAdd(segSize, ConstantInt::get(size_ty, mLookAheadPositions));
558    }
559    waitCondTest = ConstantInt::get(int1ty, 1); 
560    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
561        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
562        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
563        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(inbufConsumerPtrs[i]);
564        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
565        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, requiredSize), producerPos));
566    }
567
568    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
569   
570    iBuilder->SetInsertPoint(endSignalCheckBlock);
571   
572    LoadInst * endSignal = iBuilder->CreateLoad(endSignalPtrs[0]);
573    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
574        LoadInst * endSignal_next = iBuilder->CreateLoad(endSignalPtrs[i]);
575        iBuilder->CreateAnd(endSignal, endSignal_next);
576    }
577       
578    iBuilder->CreateCondBr(endSignal, endBlock, inputCheckBlock);
579   
580    iBuilder->SetInsertPoint(doSegmentBlock);
581 
582    createDoSegmentCall(self, segBlocks);
583
584    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
585        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
586        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
587    }
588   
589    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
590        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
591        iBuilder->CreateAtomicStoreRelease(produced, outbufProducerPtrs[i]);
592    }
593   
594    Value * earlyEndSignal = getTerminationSignal(self);
595    if (earlyEndSignal != ConstantInt::getNullValue(iBuilder->getInt1Ty())) {
596        BasicBlock * earlyEndBlock = BasicBlock::Create(iBuilder->getContext(), "earlyEndSignal", threadFunc, 0);
597        iBuilder->CreateCondBr(earlyEndSignal, earlyEndBlock, outputCheckBlock);
598
599        iBuilder->SetInsertPoint(earlyEndBlock);
600        for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
601            Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
602            mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
603        }       
604    }
605    iBuilder->CreateBr(outputCheckBlock);
606     
607    iBuilder->SetInsertPoint(endBlock);
608    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
609    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
610    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
611    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
612    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
613    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
614
615    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
616
617    iBuilder->SetInsertPoint(doFinalSegBlock);
618
619    createDoSegmentCall(self, blocks);
620
621    iBuilder->CreateBr(doFinalBlock);
622
623    iBuilder->SetInsertPoint(doFinalBlock);
624
625    createFinalBlockCall(self, finalBlockRemainingBytes);
626
627    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
628        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
629        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
630    }
631    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
632        iBuilder->CreateAtomicStoreRelease(producerPos, outbufProducerPtrs[i]);
633    }
634
635    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
636        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
637        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
638    }
639
640    iBuilder->CreatePThreadExitCall(Constant::getNullValue(voidPtrTy));
641    iBuilder->CreateRetVoid();
642
643    return threadFunc;
644
645}
646
647KernelBuilder::~KernelBuilder() {
648}
Note: See TracBrowser for help on using the repository browser.