source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 6135

Last change on this file since 6135 was 6047, checked in by nmedfort, 10 months ago

Major refactoring of buffer types: static buffers replace the Circular and CircularCopyback buffers, and external buffers unify the Source and External buffers.

File size: 26.7 KB
Line 
1#include "kernel_builder.h"
2#include <toolchain/toolchain.h>
3#include <kernels/kernel.h>
4#include <kernels/streamset.h>
5#include <llvm/Support/raw_ostream.h>
6#include <llvm/IR/Module.h>
7
8using namespace llvm;
9using namespace parabix;
10
/// Returns true iff `n` is a non-zero power of two.
inline static bool is_power_2(const uint64_t n) {
    // A power of two has exactly one set bit, so clearing its lowest set bit must leave zero.
    return n != 0 && (n & (n - 1)) == 0;
}
14
15namespace kernel {
16
17using Port = Kernel::Port;
18
19Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const instance, Value * const index) {
20    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
21        CreateAssert(instance, "getScalarFieldPtr: instance cannot be null!");
22    }
23    return CreateGEP(instance, {getInt32(0), index});
24}
25
26Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const handle, const std::string & fieldName) {
27    return getScalarFieldPtr(handle, getInt32(mKernel->getScalarIndex(fieldName)));
28}
29
30llvm::Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const index) {
31    return getScalarFieldPtr(mKernel->getInstance(), index);
32}
33
34llvm::Value *KernelBuilder:: getScalarFieldPtr(const std::string & fieldName) {
35    return getScalarFieldPtr(mKernel->getInstance(), fieldName);
36}
37
38Value * KernelBuilder::getScalarField(const std::string & fieldName) {
39    return CreateLoad(getScalarFieldPtr(fieldName), fieldName);
40}
41
42void KernelBuilder::setScalarField(const std::string & fieldName, Value * value) {
43    CreateStore(value, getScalarFieldPtr(fieldName));
44}
45
46Value * KernelBuilder::getStreamHandle(const std::string & name) {
47    Value * const ptr = getScalarField(name + BUFFER_SUFFIX);
48    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
49        CreateAssert(ptr, name + " handle cannot be null!");
50    }
51    return ptr;
52}
53
54LoadInst * KernelBuilder::acquireLogicalSegmentNo() {
55    return CreateAtomicLoadAcquire(getScalarFieldPtr(LOGICAL_SEGMENT_NO_SCALAR));
56}
57
58void KernelBuilder::releaseLogicalSegmentNo(Value * const nextSegNo) {
59    CreateAtomicStoreRelease(nextSegNo, getScalarFieldPtr(LOGICAL_SEGMENT_NO_SCALAR));
60}
61
62Value * KernelBuilder::getCycleCountPtr() {
63    return getScalarFieldPtr(CYCLECOUNT_SCALAR);
64}
65
66Value * KernelBuilder::getNamedItemCount(const std::string & name, const std::string & suffix) {
67    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
68    Value * itemCount = nullptr;
69    if (LLVM_UNLIKELY(rate.isRelative())) {
70        Port port; unsigned index;
71        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
72        if (port == Port::Input) {
73            itemCount = getProcessedItemCount(rate.getReference());
74        } else {
75            itemCount = getProducedItemCount(rate.getReference());
76        }
77        itemCount = CreateMul2(itemCount, rate.getRate());
78    } else {
79        itemCount = getScalarField(name + suffix);
80    }
81    return itemCount;
82}
83
84void KernelBuilder::setNamedItemCount(const std::string & name, const std::string & suffix, llvm::Value * const value) {
85    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
86    const auto safetyCheck = mKernel->treatUnsafeKernelOperationsAsErrors();
87    if (LLVM_UNLIKELY(rate.isDerived() && safetyCheck)) {
88        report_fatal_error("Cannot set item count: " + name + " is a derived rate stream");
89    }
90    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
91        CallPrintIntToStderr(mKernel->getName() + ": " + name + suffix, value);
92    }
93    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts) && safetyCheck)) {
94        Value * const current = getScalarField(name + suffix);
95        CreateAssert(CreateICmpUGE(value, current), name + " " + suffix + " must be monotonically non-decreasing");
96    }
97    setScalarField(name + suffix, value);
98}
99
100
101Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
102    const auto & inputs = mKernel->getStreamInputs();
103    for (unsigned i = 0; i < inputs.size(); ++i) {
104        if (inputs[i].getName() == name) {
105            return mKernel->getAvailableItemCount(i);
106        }
107    }
108    return nullptr;
109}
110
111Value * KernelBuilder::getTerminationSignal() {
112    return CreateICmpNE(getScalarField(TERMINATION_SIGNAL), getSize(0));
113}
114
115void KernelBuilder::setTerminationSignal(llvm::Value * const value) {
116    assert (value->getType() == getInt1Ty());
117    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
118        CallPrintIntToStderr(mKernel->getName() + ": setTerminationSignal", value);
119    }
120    setScalarField(TERMINATION_SIGNAL, CreateZExt(value, getSizeTy()));
121}
122
123Value * KernelBuilder::getLinearlyAccessibleItems(const std::string & name, Value * fromPosition, Value * avail, bool reverse) {
124    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
125    return buf->getLinearlyAccessibleItems(this, getStreamHandle(name), fromPosition, avail, reverse);
126}
127
128Value * KernelBuilder::getLinearlyWritableItems(const std::string & name, Value * fromPosition, bool reverse) {
129    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
130    return buf->getLinearlyWritableItems(this, getStreamHandle(name), fromPosition, getConsumedItemCount(name), reverse);
131}
132
133/** ------------------------------------------------------------------------------------------------------------- *
134 * @brief CreatePrepareOverflow
135 ** ------------------------------------------------------------------------------------------------------------- */
136void KernelBuilder::CreatePrepareOverflow(const std::string & name) {
137    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
138    assert (buf->supportsCopyBack());
139    Constant * const overflowSize = ConstantExpr::getSizeOf(buf->getType());
140    Value * const handle = getStreamHandle(name);
141    // TODO: handle non constant stream set counts
142    assert (isa<Constant>(buf->getStreamSetCount(this, handle)));
143    Value * const base = buf->getBaseAddress(this, handle);
144    Value * const overflow = buf->getOverflowAddress(this, handle);
145    const auto blockSize = getBitBlockWidth() / 8;
146    CreateMemZero(overflow, overflowSize, blockSize);
147    CreateMemZero(base, overflowSize, blockSize);
148}
149
150/** ------------------------------------------------------------------------------------------------------------- *
151 * @brief getItemWidth
152 ** ------------------------------------------------------------------------------------------------------------- */
153inline unsigned LLVM_READNONE getItemWidth(const Type * ty ) {
154    if (LLVM_LIKELY(isa<ArrayType>(ty))) {
155        ty = ty->getArrayElementType();
156    }
157    return cast<IntegerType>(ty->getVectorElementType())->getBitWidth();
158}
159
160/** ------------------------------------------------------------------------------------------------------------- *
161 * @brief CreateNonLinearCopyFromOverflow
162 ** ------------------------------------------------------------------------------------------------------------- */
163void KernelBuilder::CreateNonLinearCopyFromOverflow(const Binding & output, llvm::Value * const itemsToCopy, Value * overflowOffset) {
164
165    Value * const handle = getStreamHandle(output.getName());
166    Type * const bitBlockPtrTy = getBitBlockType()->getPointerTo();
167    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(output.getName());
168    assert (buf->supportsCopyBack());
169    Value * const target = CreatePointerCast(buf->getBaseAddress(this, handle), bitBlockPtrTy);
170    Value * const source = CreatePointerCast(buf->getOverflowAddress(this, handle), bitBlockPtrTy);
171    const auto blockSize = getBitBlockWidth() / 8;
172    Constant * const BLOCK_WIDTH = getSize(getBitBlockWidth());
173    Constant * const ITEM_WIDTH = getSize(getItemWidth(buf->getBaseType()));
174    Value * const streamCount = buf->getStreamSetCount(this, handle);
175
176    // If we have a computed overflow position, the base and overflow regions were not speculatively zeroed out prior
177    // to the kernel writing over them. To handle them, we compute a mask of valid items and exclude any bit not in
178    // them before OR-ing together the streams.
179    if (overflowOffset) {
180
181        overflowOffset = CreateMul(overflowOffset, ITEM_WIDTH);
182        Value * targetMask = bitblock_mask_from(CreateURem(overflowOffset, BLOCK_WIDTH));
183        Value * sourceMask = CreateNot(targetMask);
184        Value * const overflowBlockCount = CreateUDiv(overflowOffset, BLOCK_WIDTH);
185        Value * const blockOffset = CreateMul(overflowBlockCount, streamCount);
186        Value * const fullCopyLength = CreateMul(blockOffset, getSize(blockSize));
187        CreateMemCpy(target, source, fullCopyLength, blockSize);
188
189        BasicBlock * const partialCopyEntry = GetInsertBlock();
190        BasicBlock * const partialCopyLoop = CreateBasicBlock();
191        BasicBlock * const partialCopyExit = CreateBasicBlock();
192
193        Value * const partialBlockCount = CreateAdd(blockOffset, streamCount);
194        CreateBr(partialCopyLoop);
195
196        SetInsertPoint(partialCopyLoop);
197        PHINode * const blockIndex = CreatePHI(getSizeTy(), 2);
198        blockIndex->addIncoming(blockOffset, partialCopyEntry);
199        Value * const sourcePtr = CreateGEP(source, blockIndex);
200        Value * sourceValue = CreateBlockAlignedLoad(sourcePtr);
201        sourceValue = CreateAnd(sourceValue, sourceMask);
202        Value * const targetPtr = CreateGEP(target, blockIndex);
203        Value * targetValue = CreateBlockAlignedLoad(targetPtr);
204        targetValue = CreateAnd(targetValue, targetMask);
205        targetValue = CreateOr(targetValue, sourceValue);
206        CreateBlockAlignedStore(targetValue, targetPtr);
207        Value * const nextBlockIndex = CreateAdd(blockIndex, getSize(1));
208        blockIndex->addIncoming(nextBlockIndex, partialCopyLoop);
209        CreateCondBr(CreateICmpNE(nextBlockIndex, partialBlockCount), partialCopyLoop, partialCopyExit);
210
211        SetInsertPoint(partialCopyExit);
212
213    } else {
214
215        BasicBlock * const mergeCopyEntry = GetInsertBlock();
216        BasicBlock * const mergeCopyLoop = CreateBasicBlock();
217        BasicBlock * const mergeCopyExit = CreateBasicBlock();
218
219        Value * blocksToCopy = CreateCeilUDiv(itemsToCopy, BLOCK_WIDTH);
220        blocksToCopy = CreateMul(blocksToCopy, ITEM_WIDTH);
221        blocksToCopy = CreateMul(blocksToCopy, streamCount);
222
223        CreateBr(mergeCopyLoop);
224
225        SetInsertPoint(mergeCopyLoop);
226        PHINode * const blockIndex = CreatePHI(getSizeTy(), 2);
227        blockIndex->addIncoming(getSize(0), mergeCopyEntry);
228        Value * const sourcePtr = CreateGEP(source, blockIndex);
229        Value * const sourceValue = CreateBlockAlignedLoad(sourcePtr);
230        Value * const targetPtr = CreateGEP(target, blockIndex);
231        Value * targetValue = CreateBlockAlignedLoad(targetPtr);
232        targetValue = CreateOr(targetValue, sourceValue);
233        CreateBlockAlignedStore(targetValue, targetPtr);
234        Value * const nextBlockIndex = CreateAdd(blockIndex, getSize(1));
235        blockIndex->addIncoming(nextBlockIndex, mergeCopyLoop);
236        CreateCondBr(CreateICmpNE(nextBlockIndex, blocksToCopy), mergeCopyLoop, mergeCopyExit);
237
238        SetInsertPoint(mergeCopyExit);
239    }
240
241
242
243}
244
245/** ------------------------------------------------------------------------------------------------------------- *
246 * @brief CreateCopyFromOverflow
247 ** ------------------------------------------------------------------------------------------------------------- */
248void KernelBuilder::CreateCopyFromOverflow(const Binding & output, llvm::Value * const itemsToCopy) {
249
250    Value * const handle = getStreamHandle(output.getName());
251    Type * const bitBlockPtrTy = getBitBlockType()->getPointerTo();
252    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(output.getName());
253    assert (buf->supportsCopyBack());
254    Value * const target = CreatePointerCast(buf->getBaseAddress(this, handle), bitBlockPtrTy);
255    Value * const source = CreatePointerCast(buf->getOverflowAddress(this, handle), bitBlockPtrTy);
256    Constant * const BLOCK_WIDTH = getSize(getBitBlockWidth());
257    Constant * const ITEM_WIDTH = getSize(getItemWidth(buf->getBaseType()));
258    Value * const streamCount = buf->getStreamSetCount(this, handle);
259
260    BasicBlock * const mergeCopyEntry = GetInsertBlock();
261    BasicBlock * const mergeCopyLoop = CreateBasicBlock();
262    BasicBlock * const mergeCopyExit = CreateBasicBlock();
263
264    Value * blocksToCopy = CreateCeilUDiv(itemsToCopy, BLOCK_WIDTH);
265    blocksToCopy = CreateMul(blocksToCopy, ITEM_WIDTH);
266    blocksToCopy = CreateMul(blocksToCopy, streamCount);
267
268    CreateCondBr(CreateICmpEQ(blocksToCopy, getSize(0)), mergeCopyExit, mergeCopyLoop);
269
270    SetInsertPoint(mergeCopyLoop);
271    PHINode * const blockIndex = CreatePHI(getSizeTy(), 2);
272    blockIndex->addIncoming(getSize(0), mergeCopyEntry);
273    Value * const sourcePtr = CreateGEP(source, blockIndex);
274    Value * const sourceValue = CreateBlockAlignedLoad(sourcePtr);
275    Value * const targetPtr = CreateGEP(target, blockIndex);
276    CreateBlockAlignedStore(sourceValue, targetPtr);
277    Value * const nextBlockIndex = CreateAdd(blockIndex, getSize(1));
278    blockIndex->addIncoming(nextBlockIndex, mergeCopyLoop);
279    CreateCondBr(CreateICmpNE(nextBlockIndex, blocksToCopy), mergeCopyLoop, mergeCopyExit);
280
281    SetInsertPoint(mergeCopyExit);
282}
283
284
285/** ------------------------------------------------------------------------------------------------------------- *
286 * @brief CreateCopyToOverflow
287 ** ------------------------------------------------------------------------------------------------------------- */
288void KernelBuilder::CreateCopyToOverflow(const std::string & name) {
289    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
290    assert (buf->supportsCopyBack());
291    Value * const handle = getStreamHandle(name);
292    // TODO: handle non constant stream set counts
293    assert (isa<Constant>(buf->getStreamSetCount(this, handle)));
294    Value * const target = buf->getBaseAddress(this, handle);
295    Value * const source = buf->getOverflowAddress(this, handle);
296    Constant * const overflowSize = ConstantExpr::getSizeOf(buf->getType());
297    CreateMemCpy(target, source, overflowSize, getBitBlockWidth() / 8);
298}
299
300Value * KernelBuilder::getConsumerLock(const std::string & name) {
301    return getScalarField(name + CONSUMER_SUFFIX);
302}
303
304void KernelBuilder::setConsumerLock(const std::string & name, Value * const value) {
305    setScalarField(name + CONSUMER_SUFFIX, value);
306}
307
308Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
309    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
310    Value * blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
311    if (blockOffset) {
312        assert (blockOffset->getType() == blockIndex->getType());
313        blockIndex = CreateAdd(blockIndex, blockOffset);
314    }
315    return buf->getStreamBlockPtr(this, getStreamHandle(name), streamIndex, blockIndex, true);
316}
317
318Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
319    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
320    Value * blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
321    if (blockOffset) {
322        assert (blockOffset->getType() == blockIndex->getType());
323        blockIndex = CreateAdd(blockIndex, blockOffset);
324    }
325    return buf->getStreamPackPtr(this, getStreamHandle(name), streamIndex, blockIndex, packIndex, true);
326}
327
328Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
329    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex, blockOffset));
330}
331
332Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
333    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex, blockOffset));
334}
335
336Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
337    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
338    return buf->getStreamSetCount(this, getStreamHandle(name));
339}
340
341Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex, Value * const blockOffset) {
342    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
343    Value * blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
344    if (blockOffset) {
345        assert (blockOffset->getType() == blockIndex->getType());
346        blockIndex = CreateAdd(blockIndex, blockOffset);
347    }
348    return buf->getStreamBlockPtr(this, getStreamHandle(name), streamIndex, blockIndex, false);
349}
350
351Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex, llvm::Value * blockOffset) {
352    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
353    Value * blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
354    if (blockOffset) {
355        assert (blockOffset->getType() == blockIndex->getType());
356        blockIndex = CreateAdd(blockIndex, blockOffset);
357    }
358    return buf->getStreamPackPtr(this, getStreamHandle(name), streamIndex, blockIndex, packIndex, false);
359}
360
361
362StoreInst * KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, llvm::Value * blockOffset, Value * toStore) {
363    Value * const ptr = getOutputStreamBlockPtr(name, streamIndex, blockOffset);
364    Type * const storeTy = toStore->getType();
365    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
366    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
367        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
368            toStore = CreateBitCast(toStore, ptrElemTy);
369        } else {
370            std::string tmp;
371            raw_string_ostream out(tmp);
372            out << "invalid type conversion when calling storeOutputStreamBlock on " <<  name << ": ";
373            ptrElemTy->print(out);
374            out << " vs. ";
375            storeTy->print(out);
376        }
377    }
378    return CreateBlockAlignedStore(toStore, ptr);
379}
380
381StoreInst * KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, llvm::Value * blockOffset, Value * toStore) {
382    Value * const ptr = getOutputStreamPackPtr(name, streamIndex, packIndex, blockOffset);
383    Type * const storeTy = toStore->getType();
384    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
385    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
386        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
387            toStore = CreateBitCast(toStore, ptrElemTy);
388        } else {
389            std::string tmp;
390            raw_string_ostream out(tmp);
391            out << "invalid type conversion when calling storeOutputStreamPack on " <<  name << ": ";
392            ptrElemTy->print(out);
393            out << " vs. ";
394            storeTy->print(out);
395        }
396    }
397    return CreateBlockAlignedStore(toStore, ptr);
398}
399
400Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
401    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
402    return buf->getStreamSetCount(this, getStreamHandle(name));
403}
404
405Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * absolutePosition) {
406    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
407    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
408}
409
410Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * absolutePosition) {
411    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
412    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
413}
414
415Value * KernelBuilder::getBaseAddress(const std::string & name) {
416    return mKernel->getAnyStreamSetBuffer(name)->getBaseAddress(this, getStreamHandle(name));
417}
418
419void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
420    return mKernel->getAnyStreamSetBuffer(name)->setBaseAddress(this, getStreamHandle(name), addr);
421}
422
423Value * KernelBuilder::getCapacity(const std::string & name) {
424    return mKernel->getAnyStreamSetBuffer(name)->getCapacity(this, getStreamHandle(name));
425}
426
427void KernelBuilder::setCapacity(const std::string & name, Value * c) {
428    mKernel->getAnyStreamSetBuffer(name)->setCapacity(this, getStreamHandle(name), c);
429}
430   
431CallInst * KernelBuilder::createDoSegmentCall(const std::vector<Value *> & args) {
432    return mKernel->makeDoSegmentCall(*this, args);
433}
434
435Value * KernelBuilder::getAccumulator(const std::string & accumName) {
436    auto results = mKernel->mOutputScalarResult;
437    if (LLVM_UNLIKELY(results == nullptr)) {
438        report_fatal_error("Cannot get accumulator " + accumName + " until " + mKernel->getName() + " has terminated.");
439    }
440    const auto & outputs = mKernel->getScalarOutputs();
441    const auto n = outputs.size();
442    if (LLVM_UNLIKELY(n == 0)) {
443        report_fatal_error(mKernel->getName() + " has no output scalars.");
444    } else {
445        for (unsigned i = 0; i < n; ++i) {
446            const Binding & b = outputs[i];
447            if (b.getName() == accumName) {
448                if (n == 1) {
449                    return results;
450                } else {
451                    return CreateExtractValue(results, {i});
452                }
453            }
454        }
455        report_fatal_error(mKernel->getName() + " has no output scalar named " + accumName);
456    }
457}
458
459BasicBlock * KernelBuilder::CreateConsumerWait() {
460    const auto consumers = mKernel->getStreamOutputs();
461    BasicBlock * const entry = GetInsertBlock();
462    if (consumers.empty()) {
463        return entry;
464    } else {
465        Function * const parent = entry->getParent();
466        IntegerType * const sizeTy = getSizeTy();
467        ConstantInt * const zero = getInt32(0);
468        ConstantInt * const one = getInt32(1);
469        ConstantInt * const size0 = getSize(0);
470
471        Value * const segNo = acquireLogicalSegmentNo();
472        const auto n = consumers.size();
473        BasicBlock * load[n + 1];
474        BasicBlock * wait[n];
475        for (unsigned i = 0; i < n; ++i) {
476            load[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Load", parent);
477            wait[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Wait", parent);
478        }
479        load[n] = BasicBlock::Create(getContext(), "Resume", parent);
480        CreateBr(load[0]);
481        for (unsigned i = 0; i < n; ++i) {
482
483            SetInsertPoint(load[i]);
484            Value * const outputConsumers = getConsumerLock(consumers[i].getName());
485
486            Value * const consumerCount = CreateLoad(CreateGEP(outputConsumers, {zero, zero}));
487            Value * const consumerPtr = CreateLoad(CreateGEP(outputConsumers, {zero, one}));
488            Value * const noConsumers = CreateICmpEQ(consumerCount, size0);
489            CreateUnlikelyCondBr(noConsumers, load[i + 1], wait[i]);
490
491            SetInsertPoint(wait[i]);
492            PHINode * const consumerPhi = CreatePHI(sizeTy, 2);
493            consumerPhi->addIncoming(size0, load[i]);
494
495            Value * const conSegPtr = CreateLoad(CreateGEP(consumerPtr, consumerPhi));
496            Value * const processedSegmentCount = CreateAtomicLoadAcquire(conSegPtr);
497            Value * const ready = CreateICmpEQ(segNo, processedSegmentCount);
498            assert (ready->getType() == getInt1Ty());
499            Value * const nextConsumerIdx = CreateAdd(consumerPhi, CreateZExt(ready, sizeTy));
500            consumerPhi->addIncoming(nextConsumerIdx, wait[i]);
501            Value * const next = CreateICmpEQ(nextConsumerIdx, consumerCount);
502            CreateCondBr(next, load[i + 1], wait[i]);
503        }
504
505        BasicBlock * const exit = load[n];
506        SetInsertPoint(exit);
507        return exit;
508    }
509}
510
511/** ------------------------------------------------------------------------------------------------------------- *
512 * @brief CreateUDiv2
513 ** ------------------------------------------------------------------------------------------------------------- */
514Value * KernelBuilder::CreateUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
515    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
516        return number;
517    }
518    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
519    if (LLVM_LIKELY(divisor.denominator() == 1)) {
520        return CreateUDiv(number, n, Name);
521    } else {
522        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
523        return CreateUDiv(CreateMul(number, d), n);
524    }
525}
526
527/** ------------------------------------------------------------------------------------------------------------- *
528 * @brief CreateCeilUDiv2
529 ** ------------------------------------------------------------------------------------------------------------- */
530Value * KernelBuilder::CreateCeilUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
531    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
532        return number;
533    }
534    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
535    if (LLVM_LIKELY(divisor.denominator() == 1)) {
536        return CreateCeilUDiv(number, n, Name);
537    } else {
538        //   âŒŠ(num + ratio - 1) / ratio⌋
539        // = ⌊(num - 1) / (n/d)⌋ + (ratio/ratio)
540        // = ⌊(d * (num - 1)) / n⌋ + 1
541        Constant * const ONE = ConstantInt::get(number->getType(), 1);
542        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
543        return CreateAdd(CreateUDiv(CreateMul(CreateSub(number, ONE), d), n), ONE, Name);
544    }
545}
546
547/** ------------------------------------------------------------------------------------------------------------- *
548 * @brief CreateMul2
549 ** ------------------------------------------------------------------------------------------------------------- */
550Value * KernelBuilder::CreateMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
551    if (factor.numerator() == 1 && factor.denominator() == 1) {
552        return number;
553    }
554    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
555    if (LLVM_LIKELY(factor.denominator() == 1)) {
556        return CreateMul(number, n, Name);
557    } else {
558        Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
559        return CreateUDiv(CreateMul(number, n), d, Name);
560    }
561}
562
563/** ------------------------------------------------------------------------------------------------------------- *
564 * @brief CreateMulCeil2
565 ** ------------------------------------------------------------------------------------------------------------- */
566Value * KernelBuilder::CreateCeilUMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
567    if (factor.denominator() == 1) {
568        return CreateMul2(number, factor, Name);
569    }
570    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
571    Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
572    return CreateCeilUDiv(CreateMul(number, n), d, Name);
573}
574
575}
Note: See TracBrowser for help on using the repository browser.