source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 6034

Last change on this file since 6034 was 6008, checked in by nmedfort, 17 months ago

Removed temporary buffers from pipeline and placed them in the source kernels.

File size: 27.8 KB
Line 
1#include "kernel_builder.h"
2#include <toolchain/toolchain.h>
3#include <kernels/kernel.h>
4#include <kernels/streamset.h>
5#include <llvm/Support/raw_ostream.h>
6#include <llvm/IR/Module.h>
7
8using namespace llvm;
9using namespace parabix;
10
// True iff n is a non-zero power of two, i.e. exactly one bit of n is set.
inline static bool is_power_2(const uint64_t n) {
    if (n == 0) {
        return false;
    }
    // Clearing the lowest set bit leaves zero only when a single bit was set.
    return (n & (n - 1)) == 0;
}
14
15namespace kernel {
16
17using Port = Kernel::Port;
18
19Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const instance, Value * const index) {
20    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
21        CreateAssert(instance, "getScalarFieldPtr: instance cannot be null!");
22    }
23    return CreateGEP(instance, {getInt32(0), index});
24}
25
26Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const handle, const std::string & fieldName) {
27    return getScalarFieldPtr(handle, getInt32(mKernel->getScalarIndex(fieldName)));
28}
29
30llvm::Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const index) {
31    return getScalarFieldPtr(mKernel->getInstance(), index);
32}
33
34llvm::Value *KernelBuilder:: getScalarFieldPtr(const std::string & fieldName) {
35    return getScalarFieldPtr(mKernel->getInstance(), fieldName);
36}
37
38Value * KernelBuilder::getScalarField(const std::string & fieldName) {
39    return CreateLoad(getScalarFieldPtr(fieldName), fieldName);
40}
41
42void KernelBuilder::setScalarField(const std::string & fieldName, Value * value) {
43    CreateStore(value, getScalarFieldPtr(fieldName));
44}
45
46Value * KernelBuilder::getStreamHandle(const std::string & name) {
47    Value * const ptr = getScalarField(name + BUFFER_SUFFIX);
48    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
49        CreateAssert(ptr, name + " handle cannot be null!");
50    }
51    return ptr;
52}
53
54LoadInst * KernelBuilder::acquireLogicalSegmentNo() {
55    return CreateAtomicLoadAcquire(getScalarFieldPtr(LOGICAL_SEGMENT_NO_SCALAR));
56}
57
58void KernelBuilder::releaseLogicalSegmentNo(Value * const nextSegNo) {
59    CreateAtomicStoreRelease(nextSegNo, getScalarFieldPtr(LOGICAL_SEGMENT_NO_SCALAR));
60}
61
62Value * KernelBuilder::getCycleCountPtr() {
63    return getScalarFieldPtr(CYCLECOUNT_SCALAR);
64}
65
66Value * KernelBuilder::getNamedItemCount(const std::string & name, const std::string & suffix) {
67    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
68    Value * itemCount = nullptr;
69    if (LLVM_UNLIKELY(rate.isRelative())) {
70        Port port; unsigned index;
71        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
72        if (port == Port::Input) {
73            itemCount = getProcessedItemCount(rate.getReference());
74        } else {
75            itemCount = getProducedItemCount(rate.getReference());
76        }
77        itemCount = CreateMul2(itemCount, rate.getRate());
78    } else {
79        itemCount = getScalarField(name + suffix);
80    }
81    return itemCount;
82}
83
84void KernelBuilder::setNamedItemCount(const std::string & name, const std::string & suffix, llvm::Value * const value) {
85    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
86    const auto safetyCheck = mKernel->treatUnsafeKernelOperationsAsErrors();
87    if (LLVM_UNLIKELY(rate.isDerived() && safetyCheck)) {
88        report_fatal_error("Cannot set item count: " + name + " is a derived rate stream");
89    }
90    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
91        CallPrintIntToStderr(mKernel->getName() + ": " + name + suffix, value);
92    }
93    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts) && safetyCheck)) {
94        Value * const current = getScalarField(name + suffix);
95        CreateAssert(CreateICmpUGE(value, current), name + " " + suffix + " must be monotonically non-decreasing");
96    }
97    setScalarField(name + suffix, value);
98}
99
100
101Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
102    const auto & inputs = mKernel->getStreamInputs();
103    for (unsigned i = 0; i < inputs.size(); ++i) {
104        if (inputs[i].getName() == name) {
105            return mKernel->getAvailableItemCount(i);
106        }
107    }
108    return nullptr;
109}
110
111Value * KernelBuilder::getTerminationSignal() {
112    return CreateICmpNE(getScalarField(TERMINATION_SIGNAL), getSize(0));
113}
114
115void KernelBuilder::setTerminationSignal(llvm::Value * const value) {
116    assert (value->getType() == getInt1Ty());
117    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
118        CallPrintIntToStderr(mKernel->getName() + ": setTerminationSignal", value);
119    }
120    setScalarField(TERMINATION_SIGNAL, CreateZExt(value, getSizeTy()));
121}
122
123Value * KernelBuilder::getLinearlyAccessibleItems(const std::string & name, Value * fromPosition, Value * avail, bool reverse) {
124    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
125    return buf->getLinearlyAccessibleItems(this, getStreamHandle(name), fromPosition, avail, reverse);
126}
127
128Value * KernelBuilder::getLinearlyWritableItems(const std::string & name, Value * fromPosition, bool reverse) {
129    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
130    return buf->getLinearlyWritableItems(this, getStreamHandle(name), fromPosition, getConsumedItemCount(name), reverse);
131}
132
133/** ------------------------------------------------------------------------------------------------------------- *
134 * @brief CreatePrepareOverflow
135 ** ------------------------------------------------------------------------------------------------------------- */
136void KernelBuilder::CreatePrepareOverflow(const std::string & name) {
137    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
138    assert (buf->supportsCopyBack());
139    Constant * const overflowSize = ConstantExpr::getSizeOf(buf->getType());
140    Value * const handle = getStreamHandle(name);
141    // TODO: handle non constant stream set counts
142    assert (isa<Constant>(buf->getStreamSetCount(this, handle)));
143    Value * const base = buf->getBaseAddress(this, handle);
144    Value * const overflow = buf->getOverflowAddress(this, handle);
145    const auto blockSize = getBitBlockWidth() / 8;
146    CreateMemZero(overflow, overflowSize, blockSize);
147    CreateMemZero(base, overflowSize, blockSize);
148}
149
150/** ------------------------------------------------------------------------------------------------------------- *
151 * @brief getItemWidth
152 ** ------------------------------------------------------------------------------------------------------------- */
153inline unsigned LLVM_READNONE getItemWidth(const Type * ty ) {
154    if (LLVM_LIKELY(isa<ArrayType>(ty))) {
155        ty = ty->getArrayElementType();
156    }
157    return cast<IntegerType>(ty->getVectorElementType())->getBitWidth();
158}
159
160/** ------------------------------------------------------------------------------------------------------------- *
161 * @brief CreateNonLinearCopyFromOverflow
162 ** ------------------------------------------------------------------------------------------------------------- */
163void KernelBuilder::CreateNonLinearCopyFromOverflow(const Binding & output, llvm::Value * const itemsToCopy, Value * overflowOffset) {
164
165    Value * const handle = getStreamHandle(output.getName());
166    Type * const bitBlockPtrTy = getBitBlockType()->getPointerTo();
167    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(output.getName());
168    assert (buf->supportsCopyBack());
169    Value * const target = CreatePointerCast(buf->getBaseAddress(this, handle), bitBlockPtrTy);
170    Value * const source = CreatePointerCast(buf->getOverflowAddress(this, handle), bitBlockPtrTy);
171    const auto blockSize = getBitBlockWidth() / 8;
172    Constant * const BLOCK_WIDTH = getSize(getBitBlockWidth());
173    Constant * const ITEM_WIDTH = getSize(getItemWidth(buf->getBaseType()));
174    Value * const streamCount = buf->getStreamSetCount(this, handle);
175
176    // If we have a computed overflow position, the base and overflow regions were not speculatively zeroed out prior
177    // to the kernel writing over them. To handle them, we compute a mask of valid items and exclude any bit not in
178    // them before OR-ing together the streams.
179    if (overflowOffset) {
180
181        overflowOffset = CreateMul(overflowOffset, ITEM_WIDTH);
182        Value * targetMask = bitblock_mask_from(CreateURem(overflowOffset, BLOCK_WIDTH));
183        Value * sourceMask = CreateNot(targetMask);
184        Value * const overflowBlockCount = CreateUDiv(overflowOffset, BLOCK_WIDTH);
185        Value * const blockOffset = CreateMul(overflowBlockCount, streamCount);
186        Value * const fullCopyLength = CreateMul(blockOffset, getSize(blockSize));
187        CreateMemCpy(target, source, fullCopyLength, blockSize);
188
189        BasicBlock * const partialCopyEntry = GetInsertBlock();
190        BasicBlock * const partialCopyLoop = CreateBasicBlock();
191        BasicBlock * const partialCopyExit = CreateBasicBlock();
192
193        Value * const partialBlockCount = CreateAdd(blockOffset, streamCount);
194        CreateBr(partialCopyLoop);
195
196        SetInsertPoint(partialCopyLoop);
197        PHINode * const blockIndex = CreatePHI(getSizeTy(), 2);
198        blockIndex->addIncoming(blockOffset, partialCopyEntry);
199        Value * const sourcePtr = CreateGEP(source, blockIndex);
200        Value * sourceValue = CreateBlockAlignedLoad(sourcePtr);
201        sourceValue = CreateAnd(sourceValue, sourceMask);
202        Value * const targetPtr = CreateGEP(target, blockIndex);
203        Value * targetValue = CreateBlockAlignedLoad(targetPtr);
204        targetValue = CreateAnd(targetValue, targetMask);
205        targetValue = CreateOr(targetValue, sourceValue);
206        CreateBlockAlignedStore(targetValue, targetPtr);
207        Value * const nextBlockIndex = CreateAdd(blockIndex, getSize(1));
208        blockIndex->addIncoming(nextBlockIndex, partialCopyLoop);
209        CreateCondBr(CreateICmpNE(nextBlockIndex, partialBlockCount), partialCopyLoop, partialCopyExit);
210
211        SetInsertPoint(partialCopyExit);
212
213    } else {
214
215        BasicBlock * const mergeCopyEntry = GetInsertBlock();
216        BasicBlock * const mergeCopyLoop = CreateBasicBlock();
217        BasicBlock * const mergeCopyExit = CreateBasicBlock();
218
219        Value * blocksToCopy = CreateCeilUDiv(itemsToCopy, BLOCK_WIDTH);
220        blocksToCopy = CreateMul(blocksToCopy, ITEM_WIDTH);
221        blocksToCopy = CreateMul(blocksToCopy, streamCount);
222
223        CreateBr(mergeCopyLoop);
224
225        SetInsertPoint(mergeCopyLoop);
226        PHINode * const blockIndex = CreatePHI(getSizeTy(), 2);
227        blockIndex->addIncoming(getSize(0), mergeCopyEntry);
228        Value * const sourcePtr = CreateGEP(source, blockIndex);
229        Value * const sourceValue = CreateBlockAlignedLoad(sourcePtr);
230        Value * const targetPtr = CreateGEP(target, blockIndex);
231        Value * targetValue = CreateBlockAlignedLoad(targetPtr);
232        targetValue = CreateOr(targetValue, sourceValue);
233        CreateBlockAlignedStore(targetValue, targetPtr);
234        Value * const nextBlockIndex = CreateAdd(blockIndex, getSize(1));
235        blockIndex->addIncoming(nextBlockIndex, mergeCopyLoop);
236        CreateCondBr(CreateICmpNE(nextBlockIndex, blocksToCopy), mergeCopyLoop, mergeCopyExit);
237
238        SetInsertPoint(mergeCopyExit);
239    }
240
241
242
243}
244
245/** ------------------------------------------------------------------------------------------------------------- *
246 * @brief CreateCopyFromOverflow
247 ** ------------------------------------------------------------------------------------------------------------- */
248void KernelBuilder::CreateCopyFromOverflow(const Binding & output, llvm::Value * const itemsToCopy) {
249
250    Value * const handle = getStreamHandle(output.getName());
251    Type * const bitBlockPtrTy = getBitBlockType()->getPointerTo();
252    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(output.getName());
253    assert (buf->supportsCopyBack());
254    Value * const target = CreatePointerCast(buf->getBaseAddress(this, handle), bitBlockPtrTy);
255    Value * const source = CreatePointerCast(buf->getOverflowAddress(this, handle), bitBlockPtrTy);
256    Constant * const BLOCK_WIDTH = getSize(getBitBlockWidth());
257    Constant * const ITEM_WIDTH = getSize(getItemWidth(buf->getBaseType()));
258    Value * const streamCount = buf->getStreamSetCount(this, handle);
259
260    BasicBlock * const mergeCopyEntry = GetInsertBlock();
261    BasicBlock * const mergeCopyLoop = CreateBasicBlock();
262    BasicBlock * const mergeCopyExit = CreateBasicBlock();
263
264    Value * blocksToCopy = CreateCeilUDiv(itemsToCopy, BLOCK_WIDTH);
265    blocksToCopy = CreateMul(blocksToCopy, ITEM_WIDTH);
266    blocksToCopy = CreateMul(blocksToCopy, streamCount);
267
268    CreateCondBr(CreateICmpEQ(blocksToCopy, getSize(0)), mergeCopyExit, mergeCopyLoop);
269
270    SetInsertPoint(mergeCopyLoop);
271    PHINode * const blockIndex = CreatePHI(getSizeTy(), 2);
272    blockIndex->addIncoming(getSize(0), mergeCopyEntry);
273    Value * const sourcePtr = CreateGEP(source, blockIndex);
274    Value * const sourceValue = CreateBlockAlignedLoad(sourcePtr);
275    Value * const targetPtr = CreateGEP(target, blockIndex);
276    CreateBlockAlignedStore(sourceValue, targetPtr);
277    Value * const nextBlockIndex = CreateAdd(blockIndex, getSize(1));
278    blockIndex->addIncoming(nextBlockIndex, mergeCopyLoop);
279    CreateCondBr(CreateICmpNE(nextBlockIndex, blocksToCopy), mergeCopyLoop, mergeCopyExit);
280
281    SetInsertPoint(mergeCopyExit);
282}
283
284
285/** ------------------------------------------------------------------------------------------------------------- *
286 * @brief CreateCopyToOverflow
287 ** ------------------------------------------------------------------------------------------------------------- */
288void KernelBuilder::CreateCopyToOverflow(const std::string & name) {
289    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
290    assert (buf->supportsCopyBack());
291    Value * const handle = getStreamHandle(name);
292    // TODO: handle non constant stream set counts
293    assert (isa<Constant>(buf->getStreamSetCount(this, handle)));
294    Value * const target = buf->getBaseAddress(this, handle);
295    Value * const source = buf->getOverflowAddress(this, handle);
296    Constant * const overflowSize = ConstantExpr::getSizeOf(buf->getType());
297    CreateMemCpy(target, source, overflowSize, getBitBlockWidth() / 8);
298}
299
300Value * KernelBuilder::getConsumerLock(const std::string & name) {
301    return getScalarField(name + CONSUMER_SUFFIX);
302}
303
304void KernelBuilder::setConsumerLock(const std::string & name, Value * const value) {
305    setScalarField(name + CONSUMER_SUFFIX, value);
306}
307
308Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
309    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
310    Value * blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
311    if (blockOffset) {
312        assert (blockOffset->getType() == blockIndex->getType());
313        blockIndex = CreateAdd(blockIndex, blockOffset);
314    }
315    return buf->getStreamBlockPtr(this, getStreamHandle(name), streamIndex, blockIndex, true);
316}
317
318Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
319    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
320    Value * blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
321    if (blockOffset) {
322        assert (blockOffset->getType() == blockIndex->getType());
323        blockIndex = CreateAdd(blockIndex, blockOffset);
324    }
325    return buf->getStreamPackPtr(this, getStreamHandle(name), streamIndex, blockIndex, packIndex, true);
326}
327
328Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
329    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex, blockOffset));
330}
331
332Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
333    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex, blockOffset));
334}
335
336Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
337    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
338    return buf->getStreamSetCount(this, getStreamHandle(name));
339}
340
341Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex, Value * const blockOffset) {
342    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
343    Value * blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
344    if (blockOffset) {
345        assert (blockOffset->getType() == blockIndex->getType());
346        blockIndex = CreateAdd(blockIndex, blockOffset);
347    }
348    return buf->getStreamBlockPtr(this, getStreamHandle(name), streamIndex, blockIndex, false);
349}
350
351Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex, llvm::Value * blockOffset) {
352    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
353    Value * blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
354    if (blockOffset) {
355        assert (blockOffset->getType() == blockIndex->getType());
356        blockIndex = CreateAdd(blockIndex, blockOffset);
357    }
358    return buf->getStreamPackPtr(this, getStreamHandle(name), streamIndex, blockIndex, packIndex, false);
359}
360
361
362StoreInst * KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, llvm::Value * blockOffset, Value * toStore) {
363    Value * const ptr = getOutputStreamBlockPtr(name, streamIndex, blockOffset);
364    Type * const storeTy = toStore->getType();
365    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
366    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
367        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
368            toStore = CreateBitCast(toStore, ptrElemTy);
369        } else {
370            std::string tmp;
371            raw_string_ostream out(tmp);
372            out << "invalid type conversion when calling storeOutputStreamBlock on " <<  name << ": ";
373            ptrElemTy->print(out);
374            out << " vs. ";
375            storeTy->print(out);
376        }
377    }
378    return CreateBlockAlignedStore(toStore, ptr);
379}
380
381StoreInst * KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, llvm::Value * blockOffset, Value * toStore) {
382    Value * const ptr = getOutputStreamPackPtr(name, streamIndex, packIndex, blockOffset);
383    Type * const storeTy = toStore->getType();
384    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
385    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
386        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
387            toStore = CreateBitCast(toStore, ptrElemTy);
388        } else {
389            std::string tmp;
390            raw_string_ostream out(tmp);
391            out << "invalid type conversion when calling storeOutputStreamPack on " <<  name << ": ";
392            ptrElemTy->print(out);
393            out << " vs. ";
394            storeTy->print(out);
395        }
396    }
397    return CreateBlockAlignedStore(toStore, ptr);
398}
399
400Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
401    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
402    return buf->getStreamSetCount(this, getStreamHandle(name));
403}
404
405Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * absolutePosition) {
406    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
407    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
408}
409
410Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * absolutePosition) {
411    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
412    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
413}
414
415Value * KernelBuilder::getBaseAddress(const std::string & name) {
416    return mKernel->getAnyStreamSetBuffer(name)->getBaseAddress(this, getStreamHandle(name));
417}
418
419void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
420    return mKernel->getAnyStreamSetBuffer(name)->setBaseAddress(this, getStreamHandle(name), addr);
421}
422
423Value * KernelBuilder::getBufferedSize(const std::string & name) {
424    return mKernel->getAnyStreamSetBuffer(name)->getBufferedSize(this, getStreamHandle(name));
425}
426
427void KernelBuilder::setBufferedSize(const std::string & name, Value * size) {
428    mKernel->getAnyStreamSetBuffer(name)->setBufferedSize(this, getStreamHandle(name), size);
429}
430
431Value * KernelBuilder::getCapacity(const std::string & name) {
432    return mKernel->getAnyStreamSetBuffer(name)->getCapacity(this, getStreamHandle(name));
433}
434
435void KernelBuilder::setCapacity(const std::string & name, Value * c) {
436    mKernel->getAnyStreamSetBuffer(name)->setCapacity(this, getStreamHandle(name), c);
437}
438
439void KernelBuilder::protectOutputStream(const std::string & name, const bool readOnly) {
440    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
441    Value * const handle = getStreamHandle(name);
442    Value * const base = buf->getBaseAddress(this, handle);
443    Value * sz = ConstantExpr::getSizeOf(buf->getType());
444    sz = CreateMul(sz, getInt64(buf->getBufferBlocks()));
445    sz = CreateMul(sz, CreateZExt(buf->getStreamSetCount(this, handle), getInt64Ty()));
446    CreateMProtect(base, sz, readOnly ? CBuilder::READ : (CBuilder::READ | CBuilder::WRITE));
447}
448   
449CallInst * KernelBuilder::createDoSegmentCall(const std::vector<Value *> & args) {
450    return mKernel->makeDoSegmentCall(*this, args);
451}
452
453Value * KernelBuilder::getAccumulator(const std::string & accumName) {
454    auto results = mKernel->mOutputScalarResult;
455    if (LLVM_UNLIKELY(results == nullptr)) {
456        report_fatal_error("Cannot get accumulator " + accumName + " until " + mKernel->getName() + " has terminated.");
457    }
458    const auto & outputs = mKernel->getScalarOutputs();
459    const auto n = outputs.size();
460    if (LLVM_UNLIKELY(n == 0)) {
461        report_fatal_error(mKernel->getName() + " has no output scalars.");
462    } else {
463        for (unsigned i = 0; i < n; ++i) {
464            const Binding & b = outputs[i];
465            if (b.getName() == accumName) {
466                if (n == 1) {
467                    return results;
468                } else {
469                    return CreateExtractValue(results, {i});
470                }
471            }
472        }
473        report_fatal_error(mKernel->getName() + " has no output scalar named " + accumName);
474    }
475}
476
477void KernelBuilder::doubleCapacity(const std::string & name) {
478    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
479    return buf->doubleCapacity(this, getStreamHandle(name));
480}
481
482BasicBlock * KernelBuilder::CreateConsumerWait() {
483    const auto consumers = mKernel->getStreamOutputs();
484    BasicBlock * const entry = GetInsertBlock();
485    if (consumers.empty()) {
486        return entry;
487    } else {
488        Function * const parent = entry->getParent();
489        IntegerType * const sizeTy = getSizeTy();
490        ConstantInt * const zero = getInt32(0);
491        ConstantInt * const one = getInt32(1);
492        ConstantInt * const size0 = getSize(0);
493
494        Value * const segNo = acquireLogicalSegmentNo();
495        const auto n = consumers.size();
496        BasicBlock * load[n + 1];
497        BasicBlock * wait[n];
498        for (unsigned i = 0; i < n; ++i) {
499            load[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Load", parent);
500            wait[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Wait", parent);
501        }
502        load[n] = BasicBlock::Create(getContext(), "Resume", parent);
503        CreateBr(load[0]);
504        for (unsigned i = 0; i < n; ++i) {
505
506            SetInsertPoint(load[i]);
507            Value * const outputConsumers = getConsumerLock(consumers[i].getName());
508
509            Value * const consumerCount = CreateLoad(CreateGEP(outputConsumers, {zero, zero}));
510            Value * const consumerPtr = CreateLoad(CreateGEP(outputConsumers, {zero, one}));
511            Value * const noConsumers = CreateICmpEQ(consumerCount, size0);
512            CreateUnlikelyCondBr(noConsumers, load[i + 1], wait[i]);
513
514            SetInsertPoint(wait[i]);
515            PHINode * const consumerPhi = CreatePHI(sizeTy, 2);
516            consumerPhi->addIncoming(size0, load[i]);
517
518            Value * const conSegPtr = CreateLoad(CreateGEP(consumerPtr, consumerPhi));
519            Value * const processedSegmentCount = CreateAtomicLoadAcquire(conSegPtr);
520            Value * const ready = CreateICmpEQ(segNo, processedSegmentCount);
521            assert (ready->getType() == getInt1Ty());
522            Value * const nextConsumerIdx = CreateAdd(consumerPhi, CreateZExt(ready, sizeTy));
523            consumerPhi->addIncoming(nextConsumerIdx, wait[i]);
524            Value * const next = CreateICmpEQ(nextConsumerIdx, consumerCount);
525            CreateCondBr(next, load[i + 1], wait[i]);
526        }
527
528        BasicBlock * const exit = load[n];
529        SetInsertPoint(exit);
530        return exit;
531    }
532}
533
534/** ------------------------------------------------------------------------------------------------------------- *
535 * @brief CreateUDiv2
536 ** ------------------------------------------------------------------------------------------------------------- */
537Value * KernelBuilder::CreateUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
538    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
539        return number;
540    }
541    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
542    if (LLVM_LIKELY(divisor.denominator() == 1)) {
543        return CreateUDiv(number, n, Name);
544    } else {
545        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
546        return CreateUDiv(CreateMul(number, d), n);
547    }
548}
549
550/** ------------------------------------------------------------------------------------------------------------- *
551 * @brief CreateCeilUDiv2
552 ** ------------------------------------------------------------------------------------------------------------- */
553Value * KernelBuilder::CreateCeilUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
554    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
555        return number;
556    }
557    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
558    if (LLVM_LIKELY(divisor.denominator() == 1)) {
559        return CreateCeilUDiv(number, n, Name);
560    } else {
561        //   âŒŠ(num + ratio - 1) / ratio⌋
562        // = ⌊(num - 1) / (n/d)⌋ + (ratio/ratio)
563        // = ⌊(d * (num - 1)) / n⌋ + 1
564        Constant * const ONE = ConstantInt::get(number->getType(), 1);
565        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
566        return CreateAdd(CreateUDiv(CreateMul(CreateSub(number, ONE), d), n), ONE, Name);
567    }
568}
569
570/** ------------------------------------------------------------------------------------------------------------- *
571 * @brief CreateMul2
572 ** ------------------------------------------------------------------------------------------------------------- */
573Value * KernelBuilder::CreateMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
574    if (factor.numerator() == 1 && factor.denominator() == 1) {
575        return number;
576    }
577    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
578    if (LLVM_LIKELY(factor.denominator() == 1)) {
579        return CreateMul(number, n, Name);
580    } else {
581        Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
582        return CreateUDiv(CreateMul(number, n), d, Name);
583    }
584}
585
586/** ------------------------------------------------------------------------------------------------------------- *
587 * @brief CreateMulCeil2
588 ** ------------------------------------------------------------------------------------------------------------- */
589Value * KernelBuilder::CreateCeilUMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
590    if (factor.denominator() == 1) {
591        return CreateMul2(number, factor, Name);
592    }
593    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
594    Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
595    return CreateCeilUDiv(CreateMul(number, n), d, Name);
596}
597
598}
Note: See TracBrowser for help on using the repository browser.