source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 5846

Last change on this file since 5846 was 5846, checked in by xwa163, 16 months ago

Fix bug of KernelBuilder::CreateStreamCpy? when handling copy back of StreamSet? buffer

File size: 28.0 KB
Line 
1#include "kernel_builder.h"
2#include <toolchain/toolchain.h>
3#include <kernels/kernel.h>
4#include <kernels/streamset.h>
5#include <llvm/Support/raw_ostream.h>
6#include <llvm/IR/Module.h>
7
8using namespace llvm;
9using namespace parabix;
10
11inline static bool is_power_2(const uint64_t n) {
12    return ((n & (n - 1)) == 0) && n;
13}
14
15namespace kernel {
16
17using Port = Kernel::Port;
18
19Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const instance, Value * const index) {
20    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
21        CreateAssert(instance, "getScalarFieldPtr: instance cannot be null!");
22    }
23    return CreateGEP(instance, {getInt32(0), index});
24}
25
26Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const handle, const std::string & fieldName) {
27    return getScalarFieldPtr(handle, getInt32(mKernel->getScalarIndex(fieldName)));
28}
29
30llvm::Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const index) {
31    return getScalarFieldPtr(mKernel->getInstance(), index);
32}
33
34llvm::Value *KernelBuilder:: getScalarFieldPtr(const std::string & fieldName) {
35    return getScalarFieldPtr(mKernel->getInstance(), fieldName);
36}
37
38Value * KernelBuilder::getScalarField(const std::string & fieldName) {
39    return CreateLoad(getScalarFieldPtr(fieldName), fieldName);
40}
41
42void KernelBuilder::setScalarField(const std::string & fieldName, Value * value) {
43    CreateStore(value, getScalarFieldPtr(fieldName));
44}
45
46Value * KernelBuilder::getStreamHandle(const std::string & name) {
47    Value * const ptr = getScalarField(name + Kernel::BUFFER_PTR_SUFFIX);
48    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
49        CreateAssert(ptr, name + " handle cannot be null!");
50    }
51    return ptr;
52}
53
54LoadInst * KernelBuilder::acquireLogicalSegmentNo() {
55    return CreateAtomicLoadAcquire(getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
56}
57
58void KernelBuilder::releaseLogicalSegmentNo(Value * nextSegNo) {
59    CreateAtomicStoreRelease(nextSegNo, getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
60}
61
62Value * KernelBuilder::getCycleCountPtr() {
63    return getScalarFieldPtr(Kernel::CYCLECOUNT_SCALAR);
64}
65
66Value * KernelBuilder::getInternalItemCount(const std::string & name, const std::string & suffix) {
67    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
68    Value * itemCount = nullptr;
69    if (LLVM_UNLIKELY(rate.isRelative())) {
70        Port port; unsigned index;
71        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
72        if (port == Port::Input) {
73            itemCount = getProcessedItemCount(rate.getReference());
74        } else {
75            itemCount = getProducedItemCount(rate.getReference());
76        }
77        const auto & r = rate.getRate();
78        if (r.numerator() != 1) {
79            itemCount = CreateMul(itemCount, ConstantInt::get(itemCount->getType(), r.numerator()));
80        }
81        if (r.denominator() != 1) {
82            itemCount = CreateExactUDiv(itemCount, ConstantInt::get(itemCount->getType(), r.denominator()));
83        }
84    } else if (LLVM_UNLIKELY(rate.isPopCount())) {
85        Port port; unsigned index;
86        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
87
88
89
90
91    } else {
92        itemCount = getScalarField(name + suffix);
93    }
94    return itemCount;
95}
96
97void KernelBuilder::setInternalItemCount(const std::string & name, const std::string & suffix, llvm::Value * const value) {
98    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
99    if (LLVM_UNLIKELY(rate.isDerived())) {
100        report_fatal_error("Cannot set item count: " + name + " is a Derived rate");
101    }
102    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
103        CallPrintIntToStderr(mKernel->getName() + ": " + name + suffix, value);
104    }
105    setScalarField(name + suffix, value);
106}
107
108
109Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
110    const auto & inputs = mKernel->getStreamInputs();
111    for (unsigned i = 0; i < inputs.size(); ++i) {
112        if (inputs[i].getName() == name) {
113            return mKernel->getAvailableItemCount(i);
114        }
115    }
116    return nullptr;
117}
118
119Value * KernelBuilder::getTerminationSignal() {
120    return CreateICmpNE(getScalarField(Kernel::TERMINATION_SIGNAL), getSize(0));
121}
122
123void KernelBuilder::setTerminationSignal(llvm::Value * const value) {
124    assert (value->getType() == getInt1Ty());
125    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
126        CallPrintIntToStderr(mKernel->getName() + ": setTerminationSignal", value);
127    }
128    setScalarField(Kernel::TERMINATION_SIGNAL, CreateZExt(value, getSizeTy()));
129}
130
131Value * KernelBuilder::getLinearlyAccessibleItems(const std::string & name, Value * fromPosition, Value * avail, bool reverse) {
132    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
133    return buf->getLinearlyAccessibleItems(this, getStreamHandle(name), fromPosition, avail, reverse);
134}
135
136Value * KernelBuilder::getLinearlyWritableItems(const std::string & name, Value * fromPosition, bool reverse) {
137    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
138    return buf->getLinearlyWritableItems(this, getStreamHandle(name), fromPosition, getConsumedItemCount(name), reverse);
139}
140
141/** ------------------------------------------------------------------------------------------------------------- *
142 * @brief isConstantZero
143 ** ------------------------------------------------------------------------------------------------------------- */
144inline bool isConstantZero(Value * const v) {
145    return isa<ConstantInt>(v) && cast<ConstantInt>(v)->isNullValue();
146}
147
148/** ------------------------------------------------------------------------------------------------------------- *
149 * @brief isConstantOne
150 ** ------------------------------------------------------------------------------------------------------------- */
151inline bool isConstantOne(Value * const v) {
152    return isa<ConstantInt>(v) && cast<ConstantInt>(v)->isOne();
153}
154
155/** ------------------------------------------------------------------------------------------------------------- *
156 * @brief getItemWidth
157 ** ------------------------------------------------------------------------------------------------------------- */
158inline unsigned getItemWidth(const Type * ty) {
159    if (LLVM_LIKELY(isa<ArrayType>(ty))) {
160        ty = ty->getArrayElementType();
161    }
162    return cast<IntegerType>(ty->getVectorElementType())->getBitWidth();
163}
164
165/** ------------------------------------------------------------------------------------------------------------- *
166 * @brief getFieldWidth
167 ** ------------------------------------------------------------------------------------------------------------- */
168inline unsigned getFieldWidth(const unsigned bitWidth, const unsigned blockWidth) {
169    for (unsigned k = 16; k <= blockWidth; k *= 2) {
170        if ((bitWidth & (k - 1)) != 0) {
171            return k / 2;
172        }
173    }
174    return blockWidth;
175}
176
177/** ------------------------------------------------------------------------------------------------------------- *
178 * @brief CreateStreamCpy
179 ** ------------------------------------------------------------------------------------------------------------- */
180void KernelBuilder::CreateStreamCpy(const std::string & name, Value * target, Value * targetOffset, Value * source, Value * sourceOffset, Value * itemsToCopy, const unsigned itemAlignment) {
181
182    assert (target && targetOffset);
183    assert (source && sourceOffset);
184    assert (target->getType() == source->getType());
185    assert (target->getType()->isPointerTy());
186    assert (isConstantZero(targetOffset) || isConstantZero(sourceOffset));
187
188    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
189
190    const auto itemWidth = getItemWidth(buf->getBaseType());
191    assert ("invalid item width" && is_power_2(itemWidth));
192    const auto blockWidth = getBitBlockWidth();
193    // Although our item width may be n bits, if we know we're always processing m items per block, our field width
194    // (w.r.t the stream copy) would be n*m. By taking this into account we can optimize and simplify the copy code.
195    const auto fieldWidth = getFieldWidth(itemWidth * itemAlignment, blockWidth);
196    const auto alignment = (fieldWidth + 7) / 8;
197
198    if (LLVM_LIKELY(itemWidth < fieldWidth)) {
199        const auto factor = fieldWidth / itemWidth;
200        Constant * const FACTOR = getSize(factor);
201        if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
202            ConstantInt * const ALIGNMENT = getSize(alignment);
203            const auto kernelName = mKernel->getName()+ ": " + name;
204            CreateAssertZero(CreateURem(CreatePtrToInt(target, getSizeTy()), ALIGNMENT), kernelName + " target is misaligned (" + std::to_string(alignment) + ")");
205            CreateAssertZero(CreateURem(targetOffset, FACTOR), kernelName + " target offset is misaligned (" + std::to_string(factor) + ")");
206            CreateAssertZero(CreateURem(CreatePtrToInt(source, getSizeTy()), ALIGNMENT), kernelName + " source is misaligned (" + std::to_string(alignment) + ")");
207            CreateAssertZero(CreateURem(sourceOffset, FACTOR), kernelName + " source offset is misaligned (" + std::to_string(factor) + ")");
208        }
209        targetOffset = CreateUDiv(targetOffset, FACTOR);
210        sourceOffset = CreateUDiv(sourceOffset, FACTOR);
211    }
212
213    /*
214       Streams are conceptually modelled as:
215
216                                            BLOCKS
217
218                                      A     B     C     D
219           STREAM SET ELEMENT   1  |aaaaa|bbbbb|ccccc|dddd |
220                                2  |eeeee|fffff|ggggg|hhhh |
221                                3  |iiiii|jjjjj|kkkkk|llll |
222
223       But the memory layout is actually:
224
225           A_1   A_2   A_3   B_1   B_2   B_3   C_1   C_2   C_3   D_1   D_2   D_3
226
227         |aaaaa|eeeee|iiiii|bbbbb|fffff|jjjjj|ccccc|ggggg|kkkkk|dddd |hhhh |llll |
228
229
230       So if we're copying the entire stream set block or our stream set has one element, we can use memcpy.
231
232       One compilication here is when the BlockSize of a stream is not equal to the BitBlockWidth.
233
234
235    */
236
237    Type * const fieldWidthTy = getIntNTy(fieldWidth);
238
239    Value * n = buf->getStreamSetCount(this, getStreamHandle(name));
240
241    if (isConstantOne(n) || fieldWidth == blockWidth || (isConstantZero(targetOffset) && isConstantZero(sourceOffset))) {
242        if (isConstantOne(n)) {
243            if (LLVM_LIKELY(itemWidth < 8)) {
244                itemsToCopy = CreateUDivCeil(itemsToCopy, getSize(8 / itemWidth));
245            } else if (LLVM_UNLIKELY(itemWidth > 8)) {
246                itemsToCopy = CreateMul(itemsToCopy, getSize(itemWidth / 8));
247            }
248        } else {
249            if (LLVM_LIKELY(blockWidth > (itemWidth * 8))) {
250                itemsToCopy = CreateUDivCeil(itemsToCopy, getSize(blockWidth / (8 * itemWidth)));
251            } else if (LLVM_LIKELY(blockWidth < (itemWidth * 8))) {
252                itemsToCopy = CreateUDivCeil(CreateMul(itemsToCopy, getSize(8)), getSize(blockWidth / itemWidth));
253            }
254            itemsToCopy = CreateMul(itemsToCopy, n);
255        }
256        PointerType * const ptrTy = fieldWidthTy->getPointerTo();
257        target = CreateGEP(CreatePointerCast(target, ptrTy), targetOffset);
258        source = CreateGEP(CreatePointerCast(source, ptrTy), sourceOffset);
259        CreateMemCpy(target, source, itemsToCopy, alignment);
260
261    } else { // either the target offset or source offset is non-zero but not both
262        auto t = getIntNTy(fieldWidth * buf->getNumOfStreams());
263        PointerType * const ptrTy = t->getPointerTo();
264        target = CreateGEP(CreatePointerCast(target, ptrTy), targetOffset);
265        source = CreateGEP(CreatePointerCast(source, ptrTy), sourceOffset);
266        n = this->CreateUDiv(n, this->getSize(buf->getNumOfStreams()));
267
268        VectorType * const blockTy = getBitBlockType();
269        PointerType * const blockPtrTy = blockTy->getPointerTo();
270
271        target = CreatePointerCast(target, blockPtrTy, "target");
272        source = CreatePointerCast(source, blockPtrTy, "source");
273
274        assert ((blockWidth % fieldWidth) == 0);
275
276        VectorType * const shiftTy = VectorType::get(fieldWidthTy, blockWidth / fieldWidth);
277        Constant * const width = getSize(blockWidth / itemWidth);
278        Constant * const ZERO = getSize(0);
279        Constant * const ONE = getSize(1);
280        BasicBlock * const entry = GetInsertBlock();
281
282        if (isConstantZero(targetOffset)) {
283
284            /*
285                                                BLOCKS
286
287                                          A     B     C     D
288               SOURCE STREAM        1  |aaa--|bbbBB|cccCC|  dDD|
289                                    2  |eee--|fffFF|gggGG|  hHH|
290                                    3  |iii--|jjjJJ|kkkKK|  lLL|
291
292
293                                          A     B     C     D
294               TARGET STREAM        1  |BBaaa|CCbbb|DDccc|    d|
295                                    2  |FFeee|GGfff|HHggg|    h|
296                                    3  |JJiii|KKjjj|LLkkk|    l|
297             */
298
299            Value * const blocksToCopy = CreateMul(CreateUDiv(itemsToCopy, width), n);
300            Value * const offset = CreateURem(sourceOffset, width);
301            Value * const offsetVector = simd_fill(fieldWidth, CreateTrunc(offset, fieldWidthTy));
302            Value * const remaining = CreateSub(width, offset);
303            Value * const remainingVector = simd_fill(fieldWidth, CreateTrunc(remaining, fieldWidthTy));
304
305            BasicBlock * const streamCopy = CreateBasicBlock(name + "PullCopy");
306            BasicBlock * const streamCopyRemaining = CreateBasicBlock(name + "PullCopyRemaining");
307            BasicBlock * const streamCopyEnd = CreateBasicBlock(name + "PullCopyEnd");
308
309            CreateCondBr(CreateICmpNE(blocksToCopy, ZERO), streamCopy, streamCopyRemaining);
310
311            SetInsertPoint(streamCopy);
312            PHINode * const i = CreatePHI(getSizeTy(), 2);
313            i->addIncoming(n, entry);
314            Value * prior = CreateAlignedLoad(CreateGEP(source, CreateSub(i, n)), alignment);
315            prior = CreateBitCast(CreateLShr(CreateBitCast(prior, shiftTy), offsetVector), blockTy);
316            Value * value = CreateAlignedLoad(CreateGEP(source, i), alignment);
317            value = CreateBitCast(CreateShl(CreateBitCast(value, shiftTy), remainingVector), blockTy);
318            CreateAlignedStore(CreateOr(value, prior), CreateGEP(target, i), alignment);
319            Value * const next_i = CreateAdd(i, ONE);
320            i->addIncoming(next_i, streamCopy);
321            CreateCondBr(CreateICmpNE(next_i, blocksToCopy), streamCopy, streamCopyRemaining);
322
323            SetInsertPoint(streamCopyRemaining);
324            PHINode * const j = CreatePHI(getSizeTy(), 2);
325            j->addIncoming(blocksToCopy, entry);
326            j->addIncoming(blocksToCopy, streamCopy);
327            Value * final = CreateAlignedLoad(CreateGEP(source, j), alignment);
328            final = CreateBitCast(CreateLShr(CreateBitCast(final, shiftTy), offsetVector), blockTy);
329            CreateAlignedStore(final, CreateGEP(target, j), alignment);
330            Value * const next_j = CreateAdd(j, ONE);
331            j->addIncoming(next_j, streamCopyRemaining);
332            CreateCondBr(CreateICmpNE(next_j, CreateAdd(blocksToCopy, n)), streamCopyRemaining, streamCopyEnd);
333
334            SetInsertPoint(streamCopyEnd);
335
336        } else if (isConstantZero(sourceOffset)) {
337
338            /*
339                                                BLOCKS
340
341                                          A     B     C     D
342               SOURCE STREAM        1  |AAAaa|BBBaa|CCCcc|    d|
343                                    2  |EEEee|FFFff|GGGgg|    h|
344                                    3  |IIIii|JJJjj|KKKkk|    l|
345
346
347                                          A     B     C     D
348               TARGET STREAM        1  |aa---|bbAAA|ccBBB| dCCC|
349                                    2  |ee---|ffEEE|ggFFF| hGGG|
350                                    3  |ii---|jjIII|kkJJJ| lKKK|
351
352            */
353
354            BasicBlock * const streamCopy = CreateBasicBlock(name + "PushCopy");
355            BasicBlock * const streamCopyRemainingCond = CreateBasicBlock(name + "PushCopyRemainingCond");
356            BasicBlock * const streamCopyRemaining = CreateBasicBlock(name + "PushCopyRemaining");
357            BasicBlock * const streamCopyEnd = CreateBasicBlock(name + "PushCopyEnd");
358
359            Value * const pos = CreateURem(targetOffset, width);
360            Value * const copied = CreateSub(width, pos);
361            Value * const copiedVector = simd_fill(fieldWidth, CreateTrunc(copied, fieldWidthTy));
362            Value * const mask = CreateLShr(Constant::getAllOnesValue(shiftTy), copiedVector);
363            Value * const offsetVector = simd_fill(fieldWidth, CreateTrunc(pos, fieldWidthTy));
364
365            CreateBr(streamCopy);
366
367            SetInsertPoint(streamCopy);
368            PHINode * const i = CreatePHI(getSizeTy(), 2);
369            i->addIncoming(ZERO, entry);
370            Value * priorTargetValue = CreateAlignedLoad(CreateGEP(target, i), alignment);
371            priorTargetValue = CreateBitCast(CreateAnd(CreateBitCast(priorTargetValue, shiftTy), mask), blockTy);
372            Value * sourceValue = CreateAlignedLoad(CreateGEP(source, i), alignment);
373            sourceValue = CreateBitCast(CreateShl(CreateBitCast(sourceValue, shiftTy), offsetVector), blockTy);
374            CreateAlignedStore(CreateOr(sourceValue, priorTargetValue), CreateGEP(target, i), alignment);
375            Value * const next_i = CreateAdd(i, ONE);
376            i->addIncoming(next_i, streamCopy);
377            CreateCondBr(CreateICmpNE(next_i, n), streamCopy, streamCopyRemainingCond);
378
379            SetInsertPoint(streamCopyRemainingCond);
380            Value * const blocksToCopy = CreateMul(CreateUDiv(CreateSub(itemsToCopy, copied), width), n);
381            CreateCondBr(CreateICmpULT(copied, itemsToCopy), streamCopyRemaining, streamCopyEnd);
382
383            SetInsertPoint(streamCopyRemaining);
384            PHINode * const j = CreatePHI(getSizeTy(), 2);
385            j->addIncoming(n, streamCopyRemainingCond);
386            Value * prior = CreateAlignedLoad(CreateGEP(source, CreateSub(j, n)), alignment);
387            prior = CreateBitCast(CreateShl(CreateBitCast(prior, shiftTy), offsetVector), blockTy);
388            Value * value = CreateAlignedLoad(CreateGEP(source, j), alignment);
389            value = CreateBitCast(CreateLShr(CreateBitCast(value, shiftTy), copiedVector), blockTy);
390            CreateAlignedStore(CreateOr(value, prior), CreateGEP(target, j), alignment);
391            Value * const next_j = CreateAdd(j, ONE);
392            j->addIncoming(next_j, streamCopyRemaining);
393            CreateCondBr(CreateICmpNE(next_j, blocksToCopy), streamCopyRemaining, streamCopyEnd);
394
395            SetInsertPoint(streamCopyEnd);
396        }
397    }
398}
399
400Value * KernelBuilder::getConsumerLock(const std::string & name) {
401    return getScalarField(name + Kernel::CONSUMER_SUFFIX);
402}
403
404void KernelBuilder::setConsumerLock(const std::string & name, Value * value) {
405    setScalarField(name + Kernel::CONSUMER_SUFFIX, value);
406}
407
408Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * streamIndex) {
409    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex));
410}
411
412Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
413    Value * const addr = mKernel->getStreamSetInputAddress(name);
414    if (addr) {
415        return CreateGEP(addr, {getInt32(0), streamIndex, packIndex});
416    } else {
417        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
418        Value * const blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
419        return buf->getStreamPackPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, packIndex, true);
420    }
421}
422
423Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex) {
424    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex));
425}
426
427Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
428    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
429    return buf->getStreamSetCount(this, getStreamHandle(name));
430}
431
432Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
433    Value * const addr = mKernel->getStreamSetInputAddress(name);
434    if (addr) {
435        return CreateGEP(addr, {blockOffset, streamIndex});
436    } else {
437        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
438        Value * blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
439        blockIndex = CreateAdd(blockIndex, blockOffset);
440        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, true);
441    }
442}
443
444Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex, Value * const blockOffset) {
445    Value * const addr = mKernel->getStreamSetOutputAddress(name);
446    if (addr) {
447        return CreateGEP(addr, {blockOffset, streamIndex});
448    } else {
449        const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
450        Value * const blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
451        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, false);
452    }
453}
454
455StoreInst * KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, Value * toStore) {
456    return CreateBlockAlignedStore(toStore, getOutputStreamBlockPtr(name, streamIndex));
457}
458
459Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
460    Value * const addr = mKernel->getStreamSetOutputAddress(name);
461    if (addr) {
462        return CreateGEP(addr, {getInt32(0), streamIndex, packIndex});
463    } else {
464        const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
465        Value * const blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
466        return buf->getStreamPackPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, packIndex, false);
467    }
468}
469
470StoreInst * KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, Value * toStore) {
471    return CreateBlockAlignedStore(toStore, getOutputStreamPackPtr(name, streamIndex, packIndex));
472}
473
474Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
475    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
476    return buf->getStreamSetCount(this, getStreamHandle(name));
477}
478
479Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * absolutePosition) {
480    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
481    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
482}
483
484Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * absolutePosition) {
485    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
486    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
487}
488
489Value * KernelBuilder::getBaseAddress(const std::string & name) {
490    return mKernel->getAnyStreamSetBuffer(name)->getBaseAddress(this, getStreamHandle(name));
491}
492
493void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
494    return mKernel->getAnyStreamSetBuffer(name)->setBaseAddress(this, getStreamHandle(name), addr);
495}
496
497Value * KernelBuilder::getBufferedSize(const std::string & name) {
498    return mKernel->getAnyStreamSetBuffer(name)->getBufferedSize(this, getStreamHandle(name));
499}
500
501void KernelBuilder::setBufferedSize(const std::string & name, Value * size) {
502    mKernel->getAnyStreamSetBuffer(name)->setBufferedSize(this, getStreamHandle(name), size);
503}
504
505Value * KernelBuilder::getCapacity(const std::string & name) {
506    return mKernel->getAnyStreamSetBuffer(name)->getCapacity(this, getStreamHandle(name));
507}
508
509void KernelBuilder::setCapacity(const std::string & name, Value * c) {
510    mKernel->getAnyStreamSetBuffer(name)->setCapacity(this, getStreamHandle(name), c);
511}
512
513Value * KernelBuilder::getBlockAddress(const std::string & name, Value * blockIndex) {
514    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
515    return buf->getBlockAddress(this, getStreamHandle(name), blockIndex);
516}
517
518void KernelBuilder::protectOutputStream(const std::string & name, const bool readOnly) {
519    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
520    Value * const handle = getStreamHandle(name);
521    Value * const base = buf->getBaseAddress(this, handle);
522    Value * sz = ConstantExpr::getSizeOf(buf->getType());
523    sz = CreateMul(sz, getInt64(buf->getBufferBlocks()));
524    sz = CreateMul(sz, CreateZExt(buf->getStreamSetCount(this, handle), getInt64Ty()));
525    CreateMProtect(base, sz, readOnly ? CBuilder::READ : (CBuilder::READ | CBuilder::WRITE));
526}
527   
528CallInst * KernelBuilder::createDoSegmentCall(const std::vector<Value *> & args) {
529    return mKernel->makeDoSegmentCall(*this, args);
530}
531
532Value * KernelBuilder::getAccumulator(const std::string & accumName) {
533    auto results = mKernel->mOutputScalarResult;
534    if (LLVM_UNLIKELY(results == nullptr)) {
535        report_fatal_error("Cannot get accumulator " + accumName + " until " + mKernel->getName() + " has terminated.");
536    }
537    const auto & outputs = mKernel->getScalarOutputs();
538    const auto n = outputs.size();
539    if (LLVM_UNLIKELY(n == 0)) {
540        report_fatal_error(mKernel->getName() + " has no output scalars.");
541    } else {
542        for (unsigned i = 0; i < n; ++i) {
543            const Binding & b = outputs[i];
544            if (b.getName() == accumName) {
545                if (n == 1) {
546                    return results;
547                } else {
548                    return CreateExtractValue(results, {i});
549                }
550            }
551        }
552        report_fatal_error(mKernel->getName() + " has no output scalar named " + accumName);
553    }
554}
555
556BasicBlock * KernelBuilder::CreateConsumerWait() {
557    const auto consumers = mKernel->getStreamOutputs();
558    BasicBlock * const entry = GetInsertBlock();
559    if (consumers.empty()) {
560        return entry;
561    } else {
562        Function * const parent = entry->getParent();
563        IntegerType * const sizeTy = getSizeTy();
564        ConstantInt * const zero = getInt32(0);
565        ConstantInt * const one = getInt32(1);
566        ConstantInt * const size0 = getSize(0);
567
568        Value * const segNo = acquireLogicalSegmentNo();
569        const auto n = consumers.size();
570        BasicBlock * load[n + 1];
571        BasicBlock * wait[n];
572        for (unsigned i = 0; i < n; ++i) {
573            load[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Load", parent);
574            wait[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Wait", parent);
575        }
576        load[n] = BasicBlock::Create(getContext(), "Resume", parent);
577        CreateBr(load[0]);
578        for (unsigned i = 0; i < n; ++i) {
579
580            SetInsertPoint(load[i]);
581            Value * const outputConsumers = getConsumerLock(consumers[i].getName());
582
583            Value * const consumerCount = CreateLoad(CreateGEP(outputConsumers, {zero, zero}));
584            Value * const consumerPtr = CreateLoad(CreateGEP(outputConsumers, {zero, one}));
585            Value * const noConsumers = CreateICmpEQ(consumerCount, size0);
586            CreateUnlikelyCondBr(noConsumers, load[i + 1], wait[i]);
587
588            SetInsertPoint(wait[i]);
589            PHINode * const consumerPhi = CreatePHI(sizeTy, 2);
590            consumerPhi->addIncoming(size0, load[i]);
591
592            Value * const conSegPtr = CreateLoad(CreateGEP(consumerPtr, consumerPhi));
593            Value * const processedSegmentCount = CreateAtomicLoadAcquire(conSegPtr);
594            Value * const ready = CreateICmpEQ(segNo, processedSegmentCount);
595            assert (ready->getType() == getInt1Ty());
596            Value * const nextConsumerIdx = CreateAdd(consumerPhi, CreateZExt(ready, sizeTy));
597            consumerPhi->addIncoming(nextConsumerIdx, wait[i]);
598            Value * const next = CreateICmpEQ(nextConsumerIdx, consumerCount);
599            CreateCondBr(next, load[i + 1], wait[i]);
600        }
601
602        BasicBlock * const exit = load[n];
603        SetInsertPoint(exit);
604        return exit;
605    }
606}
607
608}
Note: See TracBrowser for help on using the repository browser.