source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5260

Last change on this file since 5260 was 5260, checked in by nmedfort, 9 months ago

Changes working towards simplifying access to stream elements, plus some modifications to simplify includes / forward declarations within the CodeGen library.

File size: 31.6 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <llvm/IR/Value.h>               // for Value
8#include <llvm/Support/ErrorHandling.h>  // for report_fatal_error
9#include <toolchain.h>                   // for BufferSegments, SegmentSize
10#include "IR_Gen/idisa_builder.h"        // for IDISA_Builder
11#include "kernels/streamset.h"           // for StreamSetBuffer
12#include "llvm/ADT/StringRef.h"          // for StringRef, operator==
13#include "llvm/IR/CallingConv.h"         // for ::C
14#include "llvm/IR/Constant.h"            // for Constant
15#include "llvm/IR/Constants.h"           // for ConstantInt
16#include "llvm/IR/Function.h"            // for Function, Function::arg_iter...
17#include "llvm/IR/Instructions.h"        // for LoadInst (ptr only), PHINode
18#include "llvm/Support/Compiler.h"       // for LLVM_UNLIKELY
19namespace llvm { class BasicBlock; }
20namespace llvm { class Module; }
21namespace llvm { class Type; }
22
23using namespace llvm;
24using namespace kernel;
25using namespace parabix;
26
27KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder,
28                             std::string kernelName,
29                             std::vector<Binding> stream_inputs,
30                             std::vector<Binding> stream_outputs,
31                             std::vector<Binding> scalar_parameters,
32                             std::vector<Binding> scalar_outputs,
33                             std::vector<Binding> internal_scalars)
34: KernelInterface(builder, kernelName, stream_inputs, stream_outputs, scalar_parameters, scalar_outputs, internal_scalars),
35mNoTerminateAttribute(false) {
36
37}
38
39unsigned KernelBuilder::addScalar(Type * type, const std::string & name) {
40    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
41        llvm::report_fatal_error("Cannot add kernel field " + name + " after kernel state finalized");
42    }
43    const auto index = mKernelFields.size();
44    mKernelMap.emplace(name, index);
45    mKernelFields.push_back(type);
46    return index;
47}
48
49void KernelBuilder::prepareKernel() {
50    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
51        llvm::report_fatal_error("Cannot prepare kernel after kernel state finalized");
52    }
53    unsigned blockSize = iBuilder->getBitBlockWidth();
54    if (mStreamSetInputs.size() != mStreamSetInputBuffers.size()) {
55        std::string tmp;
56        raw_string_ostream out(tmp);
57        out << "kernel contains " << mStreamSetInputBuffers.size() << " input buffers for "
58            << mStreamSetInputs.size() << " input stream sets.";
59        throw std::runtime_error(out.str());
60    }
61    if (mStreamSetOutputs.size() != mStreamSetOutputBuffers.size()) {
62        std::string tmp;
63        raw_string_ostream out(tmp);
64        out << "kernel contains " << mStreamSetOutputBuffers.size() << " output buffers for "
65            << mStreamSetOutputs.size() << " output stream sets.";
66        throw std::runtime_error(out.str());
67    }
68    int streamSetNo = 0;
69    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
70        if ((mStreamSetInputBuffers[i]->getBufferSize() > 0) && (mStreamSetInputBuffers[i]->getBufferSize() < codegen::SegmentSize + (blockSize + mLookAheadPositions - 1)/blockSize)) {
71             llvm::report_fatal_error("Kernel preparation: Buffer size too small " + mStreamSetInputs[i].name);
72        }
73        mScalarInputs.push_back(Binding{mStreamSetInputBuffers[i]->getStreamSetStructPointerType(), mStreamSetInputs[i].name + structPtrSuffix});
74        mStreamSetNameMap.emplace(mStreamSetInputs[i].name, streamSetNo);
75        addScalar(iBuilder->getSizeTy(), mStreamSetInputs[i].name + processedItemCountSuffix);
76        streamSetNo++;
77    }
78    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
79        mScalarInputs.push_back(Binding{mStreamSetOutputBuffers[i]->getStreamSetStructPointerType(), mStreamSetOutputs[i].name + structPtrSuffix});
80        mStreamSetNameMap.emplace(mStreamSetOutputs[i].name, streamSetNo);
81        addScalar(iBuilder->getSizeTy(), mStreamSetOutputs[i].name + producedItemCountSuffix);
82        streamSetNo++;
83    }
84    for (auto binding : mScalarInputs) {
85        addScalar(binding.type, binding.name);
86    }
87    for (auto binding : mScalarOutputs) {
88        addScalar(binding.type, binding.name);
89    }
90    for (auto binding : mInternalScalars) {
91        addScalar(binding.type, binding.name);
92    }
93    addScalar(iBuilder->getSizeTy(), blockNoScalar);
94    addScalar(iBuilder->getSizeTy(), logicalSegmentNoScalar);
95    addScalar(iBuilder->getInt1Ty(), terminationSignal);
96    mKernelStateType = StructType::create(iBuilder->getContext(), mKernelFields, mKernelName);
97}
98
99std::unique_ptr<Module> KernelBuilder::createKernelModule(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
100    auto saveModule = iBuilder->getModule();
101    auto savePoint = iBuilder->saveIP();
102    auto module = make_unique<Module>(mKernelName + "_" + iBuilder->getBitBlockTypeName(), iBuilder->getContext());
103    iBuilder->setModule(module.get());
104    generateKernel(inputs, outputs);
105    iBuilder->setModule(saveModule);
106    iBuilder->restoreIP(savePoint);
107    return module;
108}
109
110void KernelBuilder::generateKernel(const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
111    auto savePoint = iBuilder->saveIP();
112    Module * const m = iBuilder->getModule();
113    mStreamSetInputBuffers.assign(inputs.begin(), inputs.end());
114    mStreamSetOutputBuffers.assign(outputs.begin(), outputs.end());
115    prepareKernel();            // possibly overridden by the KernelBuilder subtype
116    addKernelDeclarations(m);
117    generateInitMethod();       // possibly overridden by the KernelBuilder subtype
118    generateDoBlockMethod();    // must be implemented by the KernelBuilder subtype
119    generateFinalBlockMethod(); // possibly overridden by the KernelBuilder subtype
120    generateDoSegmentMethod();
121    generateFinalSegmentMethod();
122
123    // Implement the accumulator get functions
124    for (auto binding : mScalarOutputs) {
125        auto fnName = mKernelName + accumulator_infix + binding.name;
126        Function * accumFn = m->getFunction(fnName);
127        iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "get_" + binding.name, accumFn, 0));
128        Value * self = &*(accumFn->arg_begin());
129        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
130        Value * retVal = iBuilder->CreateLoad(ptr);
131        iBuilder->CreateRet(retVal);
132    }
133    generateInitMethod();
134    iBuilder->restoreIP(savePoint);
135}
136
137// Default init method, possibly overridden if special init actions required.
138void KernelBuilder::generateInitMethod() const {
139    auto savePoint = iBuilder->saveIP();
140    Module * const m = iBuilder->getModule();
141    Function * initFunction = m->getFunction(mKernelName + init_suffix);
142    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "Init_entry", initFunction, 0));   
143    Function::arg_iterator args = initFunction->arg_begin();
144    Value * self = &*(args++);
145    iBuilder->CreateStore(ConstantAggregateZero::get(mKernelStateType), self);
146    for (auto binding : mScalarInputs) {
147        Value * param = &*(args++);
148        Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(binding.name)});
149        iBuilder->CreateStore(param, ptr);
150    }
151    iBuilder->CreateRetVoid();
152    iBuilder->restoreIP(savePoint);
153}
154
155//  The default finalBlock method simply dispatches to the doBlock routine.
156void KernelBuilder::generateFinalBlockMethod() const {
157    auto savePoint = iBuilder->saveIP();
158    Module * m = iBuilder->getModule();
159    Function * doBlockFunction = m->getFunction(mKernelName + doBlock_suffix);
160    Function * finalBlockFunction = m->getFunction(mKernelName + finalBlock_suffix);
161    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "fb_entry", finalBlockFunction, 0));
162    // Final Block arguments: self, remaining, then the standard DoBlock args.
163    Function::arg_iterator args = finalBlockFunction->arg_begin();
164    Value * self = &*(args++);
165    /* Skip "remaining" arg */ args++;
166    std::vector<Value *> doBlockArgs = {self};
167    while (args != finalBlockFunction->arg_end()){
168        doBlockArgs.push_back(&*args++);
169    }
170    iBuilder->CreateCall(doBlockFunction, doBlockArgs);
171    iBuilder->CreateRetVoid();
172    iBuilder->restoreIP(savePoint);
173}
174
175// Note: this may be overridden to incorporate doBlock logic directly into
176// the doSegment function.
177void KernelBuilder::generateDoBlockLogic(Value * self, Value * /* blockNo */) const {
178    Function * doBlockFunction = iBuilder->getModule()->getFunction(mKernelName + doBlock_suffix);
179    iBuilder->CreateCall(doBlockFunction, self);
180}
181
182
183//  The default doSegment method dispatches to the doBlock routine for
184//  each block of the given number of blocksToDo, and then updates counts.
185void KernelBuilder::generateDoSegmentMethod() const {
186    auto savePoint = iBuilder->saveIP();
187    Module * m = iBuilder->getModule();
188    Function * doSegmentFunction = m->getFunction(mKernelName + doSegment_suffix);
189    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", doSegmentFunction, 0));
190    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
191    BasicBlock * strideLoopCond = BasicBlock::Create(iBuilder->getContext(), "strideLoopCond", doSegmentFunction, 0);
192    BasicBlock * strideLoopBody = BasicBlock::Create(iBuilder->getContext(), "strideLoopBody", doSegmentFunction, 0);
193    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", doSegmentFunction, 0);
194    BasicBlock * segmentDone = BasicBlock::Create(iBuilder->getContext(), "segmentDone", doSegmentFunction, 0);
195    BasicBlock * finalExit = BasicBlock::Create(iBuilder->getContext(), "finalExit", doSegmentFunction, 0);
196    Type * const size_ty = iBuilder->getSizeTy();
197    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
198    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
199   
200    Function::arg_iterator args = doSegmentFunction->arg_begin();
201    Value * self = &*(args++);
202    Value * blocksToDo = &*(args);
203   
204    std::vector<Value *> inbufProducerPtrs;
205    std::vector<Value *> endSignalPtrs;
206    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
207        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
208        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
209        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
210    }
211   
212    std::vector<Value *> producerPos;
213    /* Determine the actually available data examining all input stream sets. */
214    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
215    producerPos.push_back(p);
216    Value * availablePos = producerPos[0];
217    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
218        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
219        producerPos.push_back(p);
220        /* Set the available position to be the minimum of availablePos and producerPos. */
221        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
222    }
223    Value * processed = getProcessedItemCount(self, mStreamSetInputs[0].name);
224    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
225//#ifndef NDEBUG
226//    iBuilder->CallPrintInt(mKernelName + "_itemsAvail", itemsAvail);
227//#endif
228    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
229    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
230    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
231    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
232    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
233    //iBuilder->CallPrintInt(mKernelName + "_stridesAvail", stridesAvail);
234    iBuilder->CreateBr(strideLoopCond);
235
236    iBuilder->SetInsertPoint(strideLoopCond);
237    PHINode * stridesRemaining = iBuilder->CreatePHI(size_ty, 2, "stridesRemaining");
238    stridesRemaining->addIncoming(stridesToDo, entryBlock);
239    Value * notDone = iBuilder->CreateICmpUGT(stridesRemaining, ConstantInt::get(size_ty, 0));
240    iBuilder->CreateCondBr(notDone, strideLoopBody, stridesDone);
241
242    iBuilder->SetInsertPoint(strideLoopBody);
243    Value * blockNo = getScalarField(self, blockNoScalar);   
244
245    generateDoBlockLogic(self, blockNo);
246    setBlockNo(self, iBuilder->CreateAdd(blockNo, strideBlocks));
247    stridesRemaining->addIncoming(iBuilder->CreateSub(stridesRemaining, ConstantInt::get(size_ty, 1)), strideLoopBody);
248    iBuilder->CreateBr(strideLoopCond);
249   
250    iBuilder->SetInsertPoint(stridesDone);
251    processed = iBuilder->CreateAdd(processed, iBuilder->CreateMul(stridesToDo, stride));
252    setProcessedItemCount(self, mStreamSetInputs[0].name, processed);
253    iBuilder->CreateBr(segmentDone);
254    iBuilder->SetInsertPoint(segmentDone);
255//#ifndef NDEBUG
256//    iBuilder->CallPrintInt(mKernelName + "_processed", processed);
257//#endif
258    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
259        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
260        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
261        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
262        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
263    }
264    iBuilder->CreateBr(finalExit);
265    iBuilder->SetInsertPoint(finalExit);
266
267    iBuilder->CreateRetVoid();
268    iBuilder->restoreIP(savePoint);
269}
270
271void KernelBuilder::generateFinalSegmentMethod() const {
272    auto savePoint = iBuilder->saveIP();
273    Module * m = iBuilder->getModule();
274    Function * finalSegmentFunction = m->getFunction(mKernelName + finalSegment_suffix);
275    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", finalSegmentFunction, 0));
276    BasicBlock * doStrides = BasicBlock::Create(iBuilder->getContext(), "doStrides", finalSegmentFunction, 0);
277    BasicBlock * stridesDone = BasicBlock::Create(iBuilder->getContext(), "stridesDone", finalSegmentFunction, 0);
278    Type * const size_ty = iBuilder->getSizeTy();
279    Constant * stride = ConstantInt::get(size_ty, iBuilder->getStride());
280    Value * strideBlocks = ConstantInt::get(size_ty, iBuilder->getStride() / iBuilder->getBitBlockWidth());
281    Function::arg_iterator args = finalSegmentFunction->arg_begin();
282    Value * self = &*(args++);
283    Value * blocksToDo = &*(args);
284    std::vector<Value *> inbufProducerPtrs;
285    std::vector<Value *> endSignalPtrs;
286    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
287        Value * param = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
288        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(param));
289        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(param));
290    }
291   
292    std::vector<Value *> producerPos;
293    /* Determine the actually available data examining all input stream sets. */
294    LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[0]);
295    producerPos.push_back(p);
296    Value * availablePos = producerPos[0];
297    for (unsigned i = 1; i < inbufProducerPtrs.size(); i++) {
298        LoadInst * p = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
299        producerPos.push_back(p);
300        /* Set the available position to be the minimum of availablePos and producerPos. */
301        availablePos = iBuilder->CreateSelect(iBuilder->CreateICmpULT(availablePos, p), availablePos, p);
302    }
303    Value * processed = getProcessedItemCount(self, mStreamSetInputs[0].name);
304    Value * itemsAvail = iBuilder->CreateSub(availablePos, processed);
305//#ifndef NDEBUG
306//    iBuilder->CallPrintInt(mKernelName + "_itemsAvail final", itemsAvail);
307//#endif
308    Value * stridesToDo = iBuilder->CreateUDiv(blocksToDo, strideBlocks);
309    Value * stridesAvail = iBuilder->CreateUDiv(itemsAvail, stride);
310    /* Adjust the number of full blocks to do, based on the available data, if necessary. */
311    Value * lessThanFullSegment = iBuilder->CreateICmpULT(stridesAvail, stridesToDo);
312    stridesToDo = iBuilder->CreateSelect(lessThanFullSegment, stridesAvail, stridesToDo);
313    Value * notDone = iBuilder->CreateICmpUGT(stridesToDo, ConstantInt::get(size_ty, 0));
314    iBuilder->CreateCondBr(notDone, doStrides, stridesDone);
315   
316    iBuilder->SetInsertPoint(doStrides);
317    createDoSegmentCall(self, blocksToDo);
318    iBuilder->CreateBr(stridesDone);
319   
320    iBuilder->SetInsertPoint(stridesDone);
321    /* Now at most a partial block remains. */
322   
323    processed = getProcessedItemCount(self, mStreamSetInputs[0].name);   
324    Value * remainingItems = iBuilder->CreateSub(producerPos[0], processed);
325    //iBuilder->CallPrintInt(mKernelName + " remainingItems", remainingItems);
326       
327    createFinalBlockCall(self, remainingItems);
328    processed = iBuilder->CreateAdd(processed, remainingItems);
329    setProcessedItemCount(self, mStreamSetInputs[0].name, processed);       
330//#ifndef NDEBUG
331//    iBuilder->CallPrintInt(mKernelName + "_processed final", processed);
332//#endif
333    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
334        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
335        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
336        Value * producerPosPtr = mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr);
337        iBuilder->CreateAtomicStoreRelease(produced, producerPosPtr);
338    }
339
340    iBuilder->CreateRetVoid();
341
342    iBuilder->restoreIP(savePoint);
343}
344
345
346
347ConstantInt * KernelBuilder::getScalarIndex(const std::string & name) const {
348    const auto f = mKernelMap.find(name);
349    if (LLVM_UNLIKELY(f == mKernelMap.end())) {
350        llvm::report_fatal_error("Kernel does not contain scalar: " + name);
351    }
352    return iBuilder->getInt32(f->second);
353}
354
355unsigned KernelBuilder::getScalarCount() const {
356    return mKernelFields.size();
357}
358
359Value * KernelBuilder::getScalarFieldPtr(Value * self, const std::string & fieldName) const {
360    return iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
361}
362
363Value * KernelBuilder::getScalarField(Value * self, const std::string & fieldName) const {
364    return iBuilder->CreateLoad(getScalarFieldPtr(self, fieldName));
365}
366
367void KernelBuilder::setScalarField(Value * self, const std::string & fieldName, Value * newFieldVal) const {
368    iBuilder->CreateStore(newFieldVal, getScalarFieldPtr(self, fieldName));
369}
370
371LoadInst * KernelBuilder::acquireLogicalSegmentNo(Value * self) const {
372    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
373    return iBuilder->CreateAtomicLoadAcquire(ptr);
374}
375
376Value * KernelBuilder::getProcessedItemCount(Value * self, const std::string & ssName) const {
377    return getScalarField(self, ssName + processedItemCountSuffix);
378}
379
380Value * KernelBuilder::getProducedItemCount(Value * self, const std::string & ssName) const {
381    return getScalarField(self, ssName + producedItemCountSuffix);
382}
383
384Value * KernelBuilder::getTerminationSignal(Value * self) const {
385    return getScalarField(self, terminationSignal);
386}
387
388void KernelBuilder::releaseLogicalSegmentNo(Value * self, Value * newCount) const {
389    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(logicalSegmentNoScalar)});
390    iBuilder->CreateAtomicStoreRelease(newCount, ptr);
391}
392
393void KernelBuilder::setProcessedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
394    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + processedItemCountSuffix)});
395    iBuilder->CreateStore(newCount, ptr);
396}
397
398void KernelBuilder::setProducedItemCount(Value * self, const std::string & ssName, Value * newCount) const {
399    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(ssName + producedItemCountSuffix)});
400    iBuilder->CreateStore(newCount, ptr);
401}
402
403void KernelBuilder::setTerminationSignal(Value * self) const {
404    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(terminationSignal)});
405    iBuilder->CreateStore(ConstantInt::get(iBuilder->getInt1Ty(), 1), ptr);
406}
407
408Value * KernelBuilder::getBlockNo(Value * self) const {
409    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
410    return iBuilder->CreateLoad(ptr);
411}
412
413void KernelBuilder::setBlockNo(Value * self, Value * value) const {
414    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(blockNoScalar)});
415    iBuilder->CreateStore(value, ptr);
416}
417
418
419Value * KernelBuilder::getParameter(Function * f, const std::string & paramName) const {
420    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
421        Value * arg = &*argIter;
422        if (arg->getName() == paramName) return arg;
423    }
424    llvm::report_fatal_error("Method does not have parameter: " + paramName);
425}
426
427unsigned KernelBuilder::getStreamSetIndex(const std::string & name) const {
428    const auto f = mStreamSetNameMap.find(name);
429    if (LLVM_UNLIKELY(f == mStreamSetNameMap.end())) {
430        llvm::report_fatal_error("Kernel does not contain stream set: " + name);
431    }
432    return f->second;
433}
434
435Value * KernelBuilder::getStreamSetStructPtr(Value * self, const std::string & name) const {
436    return getScalarField(self, name + structPtrSuffix);
437}
438
439inline const StreamSetBuffer * KernelBuilder::getStreamSetBuffer(const std::string & name) const {
440    const unsigned structIdx = getStreamSetIndex(name);
441    if (structIdx < mStreamSetInputs.size()) {
442        return mStreamSetInputBuffers[structIdx];
443    } else {
444        return mStreamSetOutputBuffers[structIdx - mStreamSetInputs.size()];
445    }
446}
447
448Value * KernelBuilder::getStreamSetPtr(Value * self, const std::string & name, Value * blockNo) const {
449    return getStreamSetBuffer(name)->getStreamSetPtr(getStreamSetStructPtr(self, name), blockNo);
450}
451
452Value * KernelBuilder::getStream(Value * self, const std::string & name, Value * blockNo, Value * index) const {
453    return getStreamSetBuffer(name)->getStream(getStreamSetStructPtr(self, name), blockNo, index);
454}
455
456Value * KernelBuilder::getStream(Value * self, const std::string & name, Value * blockNo, Value * index1, Value * index2) const {
457    assert (index1->getType() == index2->getType());
458    return getStreamSetBuffer(name)->getStream(getStreamSetStructPtr(self, name), blockNo, index1, index2);
459}
460
461Value * KernelBuilder::getStreamView(Value * self, const std::string & name, Value * blockNo, Value * index) const {
462    return getStreamSetBuffer(name)->getStreamView(getStreamSetStructPtr(self, name), blockNo, index);
463}
464
465Value * KernelBuilder::getStreamView(llvm::Type * type, Value * self, const std::string & name, Value * blockNo, Value * index) const {
466    return getStreamSetBuffer(name)->getStreamView(type, getStreamSetStructPtr(self, name), blockNo, index);
467}
468
469void KernelBuilder::createInstance() {
470    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
471        llvm::report_fatal_error("Cannot create kernel instance before calling prepareKernel()");
472    }
473    mKernelInstance = iBuilder->CreateCacheAlignedAlloca(mKernelStateType);
474    Module * m = iBuilder->getModule();
475    std::vector<Value *> init_args = {mKernelInstance};
476    for (auto a : mInitialArguments) {
477        init_args.push_back(a);
478    }
479    for (auto b : mStreamSetInputBuffers) {
480        init_args.push_back(b->getStreamSetStructPtr());
481    }
482    for (auto b : mStreamSetOutputBuffers) {
483        init_args.push_back(b->getStreamSetStructPtr());
484    }
485    std::string initFnName = mKernelName + init_suffix;
486    Function * initMethod = m->getFunction(initFnName);
487    if (!initMethod) {
488        llvm::report_fatal_error("Cannot find " + initFnName);
489    }
490    iBuilder->CreateCall(initMethod, init_args);
491}
492
493Function * KernelBuilder::generateThreadFunction(const std::string & name) const {
494    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
495        llvm::report_fatal_error("Cannot generate thread function before calling prepareKernel()");
496    }
497    Module * m = iBuilder->getModule();
498    Type * const voidTy = iBuilder->getVoidTy();
499    Type * const voidPtrTy = iBuilder->getVoidPtrTy();
500    Type * const int8PtrTy = iBuilder->getInt8PtrTy();
501    Type * const int1ty = iBuilder->getInt1Ty();
502
503    Function * const threadFunc = cast<Function>(m->getOrInsertFunction(name, voidTy, int8PtrTy, nullptr));
504    threadFunc->setCallingConv(CallingConv::C);
505    Function::arg_iterator args = threadFunc->arg_begin();
506
507    Value * const arg = &*(args++);
508    arg->setName("args");
509
510    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", threadFunc,0));
511
512    Value * self = iBuilder->CreateBitCast(arg, PointerType::get(mKernelStateType, 0));
513
514    std::vector<Value *> inbufProducerPtrs;
515    std::vector<Value *> inbufConsumerPtrs;
516    std::vector<Value *> outbufProducerPtrs;
517    std::vector<Value *> outbufConsumerPtrs;   
518    std::vector<Value *> endSignalPtrs;
519
520    for (unsigned i = 0; i < mStreamSetInputs.size(); i++) {
521        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetInputs[i].name);
522        inbufProducerPtrs.push_back(mStreamSetInputBuffers[i]->getProducerPosPtr(ssStructPtr));
523        inbufConsumerPtrs.push_back(mStreamSetInputBuffers[i]->getConsumerPosPtr(ssStructPtr));
524        endSignalPtrs.push_back(mStreamSetInputBuffers[i]->getEndOfInputPtr(ssStructPtr));
525    }
526    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
527        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
528        outbufProducerPtrs.push_back(mStreamSetOutputBuffers[i]->getProducerPosPtr(ssStructPtr));
529        outbufConsumerPtrs.push_back(mStreamSetOutputBuffers[i]->getConsumerPosPtr(ssStructPtr));
530    }
531
532    const unsigned segmentBlocks = codegen::SegmentSize;
533    const unsigned bufferSegments = codegen::BufferSegments;
534    const unsigned segmentSize = segmentBlocks * iBuilder->getBitBlockWidth();
535    Type * const size_ty = iBuilder->getSizeTy();
536
537    Value * segSize = ConstantInt::get(size_ty, segmentSize);
538    Value * bufferSize = ConstantInt::get(size_ty, segmentSize * (bufferSegments - 1));
539    Value * segBlocks = ConstantInt::get(size_ty, segmentBlocks);
540   
541    BasicBlock * outputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "outputCheck", threadFunc, 0);
542    BasicBlock * inputCheckBlock = BasicBlock::Create(iBuilder->getContext(), "inputCheck", threadFunc, 0);
543   
544    BasicBlock * endSignalCheckBlock = BasicBlock::Create(iBuilder->getContext(), "endSignalCheck", threadFunc, 0);
545    BasicBlock * doSegmentBlock = BasicBlock::Create(iBuilder->getContext(), "doSegment", threadFunc, 0);
546    BasicBlock * endBlock = BasicBlock::Create(iBuilder->getContext(), "end", threadFunc, 0);
547    BasicBlock * doFinalSegBlock = BasicBlock::Create(iBuilder->getContext(), "doFinalSeg", threadFunc, 0);
548    BasicBlock * doFinalBlock = BasicBlock::Create(iBuilder->getContext(), "doFinal", threadFunc, 0);
549
550    iBuilder->CreateBr(outputCheckBlock);
551
552    iBuilder->SetInsertPoint(outputCheckBlock);
553
554    Value * waitCondTest = ConstantInt::get(int1ty, 1);   
555    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
556        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(outbufProducerPtrs[i]);
557        // iBuilder->CallPrintInt(name + ":output producerPos", producerPos);
558        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(outbufConsumerPtrs[i]);
559        // iBuilder->CallPrintInt(name + ":output consumerPos", consumerPos);
560        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(producerPos, iBuilder->CreateAdd(consumerPos, bufferSize)));
561    }
562   
563    iBuilder->CreateCondBr(waitCondTest, inputCheckBlock, outputCheckBlock); 
564
565    iBuilder->SetInsertPoint(inputCheckBlock); 
566
567    Value * requiredSize = segSize;
568    if (mLookAheadPositions > 0) {
569        requiredSize = iBuilder->CreateAdd(segSize, ConstantInt::get(size_ty, mLookAheadPositions));
570    }
571    waitCondTest = ConstantInt::get(int1ty, 1); 
572    for (unsigned i = 0; i < inbufProducerPtrs.size(); i++) {
573        LoadInst * producerPos = iBuilder->CreateAtomicLoadAcquire(inbufProducerPtrs[i]);
574        // iBuilder->CallPrintInt(name + ":input producerPos", producerPos);
575        LoadInst * consumerPos = iBuilder->CreateAtomicLoadAcquire(inbufConsumerPtrs[i]);
576        // iBuilder->CallPrintInt(name + ":input consumerPos", consumerPos);
577        waitCondTest = iBuilder->CreateAnd(waitCondTest, iBuilder->CreateICmpULE(iBuilder->CreateAdd(consumerPos, requiredSize), producerPos));
578    }
579
580    iBuilder->CreateCondBr(waitCondTest, doSegmentBlock, endSignalCheckBlock);
581   
582    iBuilder->SetInsertPoint(endSignalCheckBlock);
583   
584    LoadInst * endSignal = iBuilder->CreateLoad(endSignalPtrs[0]);
585    for (unsigned i = 1; i < endSignalPtrs.size(); i++){
586        LoadInst * endSignal_next = iBuilder->CreateLoad(endSignalPtrs[i]);
587        iBuilder->CreateAnd(endSignal, endSignal_next);
588    }
589       
590    iBuilder->CreateCondBr(endSignal, endBlock, inputCheckBlock);
591   
592    iBuilder->SetInsertPoint(doSegmentBlock);
593 
594    createDoSegmentCall(self, segBlocks);
595
596    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
597        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), segSize);
598        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
599    }
600   
601    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
602        Value * produced = getProducedItemCount(self, mStreamSetOutputs[i].name);
603        iBuilder->CreateAtomicStoreRelease(produced, outbufProducerPtrs[i]);
604    }
605   
606    Value * earlyEndSignal = getTerminationSignal(self);
607    if (earlyEndSignal != ConstantInt::getNullValue(iBuilder->getInt1Ty())) {
608        BasicBlock * earlyEndBlock = BasicBlock::Create(iBuilder->getContext(), "earlyEndSignal", threadFunc, 0);
609        iBuilder->CreateCondBr(earlyEndSignal, earlyEndBlock, outputCheckBlock);
610
611        iBuilder->SetInsertPoint(earlyEndBlock);
612        for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
613            Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
614            mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
615        }       
616    }
617    iBuilder->CreateBr(outputCheckBlock);
618     
619    iBuilder->SetInsertPoint(endBlock);
620    LoadInst * producerPos = iBuilder->CreateLoad(inbufProducerPtrs[0]);
621    LoadInst * consumerPos = iBuilder->CreateLoad(inbufConsumerPtrs[0]);
622    Value * remainingBytes = iBuilder->CreateSub(producerPos, consumerPos);
623    Value * blockSize = ConstantInt::get(size_ty, iBuilder->getBitBlockWidth());
624    Value * blocks = iBuilder->CreateUDiv(remainingBytes, blockSize);
625    Value * finalBlockRemainingBytes = iBuilder->CreateURem(remainingBytes, blockSize);
626
627    iBuilder->CreateCondBr(iBuilder->CreateICmpEQ(blocks, ConstantInt::get(size_ty, 0)), doFinalBlock, doFinalSegBlock);
628
629    iBuilder->SetInsertPoint(doFinalSegBlock);
630
631    createDoSegmentCall(self, blocks);
632
633    iBuilder->CreateBr(doFinalBlock);
634
635    iBuilder->SetInsertPoint(doFinalBlock);
636
637    createFinalBlockCall(self, finalBlockRemainingBytes);
638
639    for (unsigned i = 0; i < inbufConsumerPtrs.size(); i++) {
640        Value * consumerPos = iBuilder->CreateAdd(iBuilder->CreateLoad(inbufConsumerPtrs[i]), remainingBytes);
641        iBuilder->CreateAtomicStoreRelease(consumerPos, inbufConsumerPtrs[i]);
642    }
643    for (unsigned i = 0; i < outbufProducerPtrs.size(); i++) {
644        iBuilder->CreateAtomicStoreRelease(producerPos, outbufProducerPtrs[i]);
645    }
646
647    for (unsigned i = 0; i < mStreamSetOutputs.size(); i++) {
648        Value * ssStructPtr = getStreamSetStructPtr(self, mStreamSetOutputs[i].name);
649        mStreamSetOutputBuffers[i]->setEndOfInput(ssStructPtr);
650    }
651
652    iBuilder->CreatePThreadExitCall(Constant::getNullValue(voidPtrTy));
653    iBuilder->CreateRetVoid();
654
655    return threadFunc;
656
657}
658
659KernelBuilder::~KernelBuilder() {
660}
Note: See TracBrowser for help on using the repository browser.