source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 5793

Last change on this file since 5793 was 5793, checked in by nmedfort, 16 months ago

Bug fix for pipeline: it was terminating too early when there was insufficient output space to process all of the input for a kernel.

File size: 27.7 KB
Line 
1#include "kernel_builder.h"
2#include <toolchain/toolchain.h>
3#include <kernels/kernel.h>
4#include <kernels/streamset.h>
5#include <llvm/Support/raw_ostream.h>
6#include <llvm/IR/Module.h>
7
8using namespace llvm;
9using namespace parabix;
10
11inline static bool is_power_2(const uint64_t n) {
12    return ((n & (n - 1)) == 0) && n;
13}
14
15namespace kernel {
16
17using Port = Kernel::Port;
18
19Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const instance, Value * const index) {
20    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
21        CreateAssert(instance, "getScalarFieldPtr: instance cannot be null!");
22    }
23    return CreateGEP(instance, {getInt32(0), index});
24}
25
26Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const handle, const std::string & fieldName) {
27    return getScalarFieldPtr(handle, getInt32(mKernel->getScalarIndex(fieldName)));
28}
29
30llvm::Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const index) {
31    return getScalarFieldPtr(mKernel->getInstance(), index);
32}
33
34llvm::Value *KernelBuilder:: getScalarFieldPtr(const std::string & fieldName) {
35    return getScalarFieldPtr(mKernel->getInstance(), fieldName);
36}
37
38Value * KernelBuilder::getScalarField(const std::string & fieldName) {
39    return CreateLoad(getScalarFieldPtr(fieldName), fieldName);
40}
41
42void KernelBuilder::setScalarField(const std::string & fieldName, Value * value) {
43    CreateStore(value, getScalarFieldPtr(fieldName));
44}
45
46Value * KernelBuilder::getStreamHandle(const std::string & name) {
47    Value * const ptr = getScalarField(name + Kernel::BUFFER_PTR_SUFFIX);
48    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
49        CreateAssert(ptr, name + " handle cannot be null!");
50    }
51    return ptr;
52}
53
54LoadInst * KernelBuilder::acquireLogicalSegmentNo() {
55    return CreateAtomicLoadAcquire(getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
56}
57
58void KernelBuilder::releaseLogicalSegmentNo(Value * nextSegNo) {
59    CreateAtomicStoreRelease(nextSegNo, getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
60}
61
62Value * KernelBuilder::getCycleCountPtr() {
63    return getScalarFieldPtr(Kernel::CYCLECOUNT_SCALAR);
64}
65
66Value * KernelBuilder::getInternalItemCount(const std::string & name, const std::string & suffix) {
67    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
68    Value * itemCount = nullptr;
69    if (LLVM_UNLIKELY(rate.isRelative())) {
70        Port port; unsigned index;
71        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
72        if (port == Port::Input) {
73            itemCount = getProcessedItemCount(rate.getReference());
74        } else {
75            itemCount = getProducedItemCount(rate.getReference());
76        }
77        const auto & r = rate.getRate();
78        if (r.numerator() != 1) {
79            itemCount = CreateMul(itemCount, ConstantInt::get(itemCount->getType(), r.numerator()));
80        }
81        if (r.denominator() != 1) {
82            itemCount = CreateExactUDiv(itemCount, ConstantInt::get(itemCount->getType(), r.denominator()));
83        }
84    } else if (LLVM_UNLIKELY(rate.isPopCount())) {
85        Port port; unsigned index;
86        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
87
88
89
90
91    } else {
92        itemCount = getScalarField(name + suffix);
93    }
94    return itemCount;
95}
96
97void KernelBuilder::setInternalItemCount(const std::string & name, const std::string & suffix, llvm::Value * const value) {
98    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
99    if (LLVM_UNLIKELY(rate.isDerived())) {
100        report_fatal_error("Cannot set item count: " + name + " is a Derived rate");
101    }
102    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
103        CallPrintIntToStderr(mKernel->getName() + ": " + name + suffix, value);
104    }
105    setScalarField(name + suffix, value);
106}
107
108
109Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
110    const auto & inputs = mKernel->getStreamInputs();
111    for (unsigned i = 0; i < inputs.size(); ++i) {
112        if (inputs[i].getName() == name) {
113            return mKernel->getAvailableItemCount(i);
114        }
115    }
116    return nullptr;
117}
118
119Value * KernelBuilder::getTerminationSignal() {
120    return CreateICmpNE(getScalarField(Kernel::TERMINATION_SIGNAL), getSize(0));
121}
122
123void KernelBuilder::setTerminationSignal(llvm::Value * const value) {
124    assert (value->getType() == getInt1Ty());
125    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
126        CallPrintIntToStderr(mKernel->getName() + ": setTerminationSignal", value);
127    }
128    setScalarField(Kernel::TERMINATION_SIGNAL, CreateZExt(value, getSizeTy()));
129}
130
131Value * KernelBuilder::getLinearlyAccessibleItems(const std::string & name, Value * fromPosition, Value * avail, bool reverse) {
132    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
133    return buf->getLinearlyAccessibleItems(this, getStreamHandle(name), fromPosition, avail, reverse);
134}
135
136Value * KernelBuilder::getLinearlyWritableItems(const std::string & name, Value * fromPosition, bool reverse) {
137    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
138    return buf->getLinearlyWritableItems(this, getStreamHandle(name), fromPosition, getConsumedItemCount(name), reverse);
139}
140
141/** ------------------------------------------------------------------------------------------------------------- *
142 * @brief isConstantZero
143 ** ------------------------------------------------------------------------------------------------------------- */
144inline bool isConstantZero(Value * const v) {
145    return isa<ConstantInt>(v) && cast<ConstantInt>(v)->isNullValue();
146}
147
148/** ------------------------------------------------------------------------------------------------------------- *
149 * @brief isConstantOne
150 ** ------------------------------------------------------------------------------------------------------------- */
151inline bool isConstantOne(Value * const v) {
152    return isa<ConstantInt>(v) && cast<ConstantInt>(v)->isOne();
153}
154
155/** ------------------------------------------------------------------------------------------------------------- *
156 * @brief getItemWidth
157 ** ------------------------------------------------------------------------------------------------------------- */
158inline unsigned getItemWidth(const Type * ty) {
159    if (LLVM_LIKELY(isa<ArrayType>(ty))) {
160        ty = ty->getArrayElementType();
161    }
162    return cast<IntegerType>(ty->getVectorElementType())->getBitWidth();
163}
164
165/** ------------------------------------------------------------------------------------------------------------- *
166 * @brief getFieldWidth
167 ** ------------------------------------------------------------------------------------------------------------- */
168inline unsigned getFieldWidth(const unsigned bitWidth, const unsigned blockWidth) {
169    for (unsigned k = 16; k <= blockWidth; k *= 2) {
170        if ((bitWidth & (k - 1)) != 0) {
171            return k / 2;
172        }
173    }
174    return blockWidth;
175}
176
177/** ------------------------------------------------------------------------------------------------------------- *
178 * @brief CreateStreamCpy
179 ** ------------------------------------------------------------------------------------------------------------- */
180void KernelBuilder::CreateStreamCpy(const std::string & name, Value * target, Value * targetOffset, Value * source, Value * sourceOffset, Value * itemsToCopy, const unsigned itemAlignment) {
181
182    assert (target && targetOffset);
183    assert (source && sourceOffset);
184    assert (target->getType() == source->getType());
185    assert (target->getType()->isPointerTy());
186    assert (isConstantZero(targetOffset) || isConstantZero(sourceOffset));
187
188    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
189
190    const auto itemWidth = getItemWidth(buf->getBaseType());
191    assert ("invalid item width" && is_power_2(itemWidth));
192    const auto blockWidth = getBitBlockWidth();
193    // Although our item width may be n bits, if we know we're always processing m items per block, our field width
194    // (w.r.t the stream copy) would be n*m. By taking this into account we can optimize and simplify the copy code.
195    const auto fieldWidth = getFieldWidth(itemWidth * itemAlignment, blockWidth);
196    const auto alignment = (fieldWidth + 7) / 8;
197    if (LLVM_LIKELY(itemWidth < fieldWidth)) {
198        const auto factor = fieldWidth / itemWidth;
199        Constant * const FACTOR = getSize(factor);
200        if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
201            ConstantInt * const ALIGNMENT = getSize(alignment);
202            const auto kernelName = mKernel->getName()+ ": " + name;
203            CreateAssertZero(CreateURem(CreatePtrToInt(target, getSizeTy()), ALIGNMENT), kernelName + " target is misaligned (" + std::to_string(alignment) + ")");
204            CreateAssertZero(CreateURem(targetOffset, FACTOR), kernelName + " target offset is misaligned (" + std::to_string(factor) + ")");
205            CreateAssertZero(CreateURem(CreatePtrToInt(source, getSizeTy()), ALIGNMENT), kernelName + " source is misaligned (" + std::to_string(alignment) + ")");
206            CreateAssertZero(CreateURem(sourceOffset, FACTOR), kernelName + " source offset is misaligned (" + std::to_string(factor) + ")");
207        }
208        targetOffset = CreateUDiv(targetOffset, FACTOR);
209        sourceOffset = CreateUDiv(sourceOffset, FACTOR);
210    }
211
212    /*
213       Streams are conceptually modelled as:
214
215                                            BLOCKS
216
217                                      A     B     C     D
218           STREAM SET ELEMENT   1  |aaaaa|bbbbb|ccccc|dddd |
219                                2  |eeeee|fffff|ggggg|hhhh |
220                                3  |iiiii|jjjjj|kkkkk|llll |
221
222       But the memory layout is actually:
223
224           A_1   A_2   A_3   B_1   B_2   B_3   C_1   C_2   C_3   D_1   D_2   D_3
225
226         |aaaaa|eeeee|iiiii|bbbbb|fffff|jjjjj|ccccc|ggggg|kkkkk|dddd |hhhh |llll |
227
228
229       So if we're copying the entire stream set block or our stream set has one element, we can use memcpy.
230
231       One compilication here is when the BlockSize of a stream is not equal to the BitBlockWidth.
232
233
234    */
235
236    Type * const fieldWidthTy = getIntNTy(fieldWidth);
237
238    Value * const n = buf->getStreamSetCount(this, getStreamHandle(name));
239    if (isConstantOne(n) || fieldWidth == blockWidth || (isConstantZero(targetOffset) && isConstantZero(sourceOffset))) {
240        if (isConstantOne(n)) {
241            if (LLVM_LIKELY(itemWidth < 8)) {
242                itemsToCopy = CreateUDivCeil(itemsToCopy, getSize(8 / itemWidth));
243            } else if (LLVM_UNLIKELY(itemWidth > 8)) {
244                itemsToCopy = CreateMul(itemsToCopy, getSize(itemWidth / 8));
245            }
246        } else {
247            if (LLVM_LIKELY(blockWidth > (itemWidth * 8))) {
248                itemsToCopy = CreateUDivCeil(itemsToCopy, getSize(blockWidth / (8 * itemWidth)));
249            } else if (LLVM_LIKELY(blockWidth < (itemWidth * 8))) {
250                itemsToCopy = CreateUDivCeil(CreateMul(itemsToCopy, getSize(8)), getSize(blockWidth / itemWidth));
251            }
252            itemsToCopy = CreateMul(itemsToCopy, n);
253        }
254        PointerType * const ptrTy = fieldWidthTy->getPointerTo();
255        target = CreateGEP(CreatePointerCast(target, ptrTy), targetOffset);
256        source = CreateGEP(CreatePointerCast(source, ptrTy), sourceOffset);
257        CreateMemCpy(target, source, itemsToCopy, alignment);
258
259    } else { // either the target offset or source offset is non-zero but not both
260
261        VectorType * const blockTy = getBitBlockType();
262        PointerType * const blockPtrTy = blockTy->getPointerTo();
263
264        target = CreatePointerCast(target, blockPtrTy);
265        source = CreatePointerCast(source, blockPtrTy);
266
267        assert ((blockWidth % fieldWidth) == 0);
268
269        VectorType * const shiftTy = VectorType::get(fieldWidthTy, blockWidth / fieldWidth);
270        Constant * const width = getSize(blockWidth / itemWidth);
271        BasicBlock * const entry = GetInsertBlock();
272
273        if (isConstantZero(targetOffset)) {
274
275            /*
276                                                BLOCKS
277
278                                          A     B     C     D
279               SOURCE STREAM        1  |aaa--|bbbBB|cccCC|  dDD|
280                                    2  |eee--|fffFF|gggGG|  hHH|
281                                    3  |iii--|jjjJJ|kkkKK|  lLL|
282
283
284                                          A     B     C     D
285               TARGET STREAM        1  |BBaaa|CCbbb|DDccc|    d|
286                                    2  |FFeee|GGfff|HHggg|    h|
287                                    3  |JJiii|KKjjj|LLkkk|    l|
288             */
289
290            Value * const blocksToCopy = CreateMul(CreateUDiv(itemsToCopy, width), n);
291            Value * const offset = CreateURem(sourceOffset, width);
292            Value * const remaining = CreateSub(width, offset);
293            Value * const trailing = CreateURem(CreateAdd(sourceOffset, itemsToCopy), width);
294
295            BasicBlock * const streamCopy = CreateBasicBlock(name + "StreamCopy");
296            BasicBlock * const streamCopyRemaining = CreateBasicBlock(name + "StreamCopyRemaining");
297            BasicBlock * const streamCopyEnd = CreateBasicBlock(name + "StreamCopyEnd");
298
299            CreateCondBr(CreateICmpNE(blocksToCopy, getSize(0)), streamCopy, streamCopyRemaining);
300
301            SetInsertPoint(streamCopy);
302            PHINode * const i = CreatePHI(getSizeTy(), 2);
303            i->addIncoming(n, entry);
304            Value * prior = CreateAlignedLoad(CreateGEP(source, CreateSub(i, n)), alignment);
305            prior = CreateLShr(CreateBitCast(prior, shiftTy), offset);
306            Value * value = CreateAlignedLoad(CreateGEP(source, i), alignment);
307            value = CreateShl(CreateBitCast(value, shiftTy), remaining);
308            Value * const result = CreateBitCast(CreateOr(value, prior), blockTy);
309            CreateAlignedStore(result, CreateGEP(target, i), alignment);
310            Value * const next_i = CreateAdd(i, getSize(1));
311            i->addIncoming(next_i, streamCopy);
312            CreateCondBr(CreateICmpNE(next_i, blocksToCopy), streamCopy, streamCopyRemaining);
313
314            SetInsertPoint(streamCopyRemaining);
315            PHINode * const j = CreatePHI(getSizeTy(), 2);
316            j->addIncoming(getSize(0), streamCopy);
317            Value * k = CreateAdd(blocksToCopy, j);
318            Value * final = CreateAlignedLoad(CreateGEP(source, k), alignment);
319            final = CreateLShr(CreateBitCast(prior, shiftTy), trailing);
320            CreateAlignedStore(final, CreateGEP(target, k), alignment);
321            Value * const next_j = CreateAdd(i, getSize(1));
322            i->addIncoming(next_j, streamCopyRemaining);
323            CreateCondBr(CreateICmpNE(next_j, n), streamCopyRemaining, streamCopyEnd);
324
325            SetInsertPoint(streamCopyEnd);
326
327        } else if (isConstantZero(sourceOffset)) {
328
329            /*
330                                                BLOCKS
331
332                                          A     B     C     D
333               SOURCE STREAM        1  |AAAaa|BBBaa|CCCcc|    d|
334                                    2  |EEEee|FFFff|GGGgg|    h|
335                                    3  |IIIii|JJJjj|KKKkk|    l|
336
337
338                                          A     B     C     D
339               TARGET STREAM        1  |aa---|bbAAA|ccBBB| dCCC|
340                                    2  |ee---|ffEEE|ggFFF| hGGG|
341                                    3  |ii---|jjIII|kkJJJ| lKKK|
342
343            */
344
345            BasicBlock * const streamCopy = CreateBasicBlock(name + "StreamCopy");
346            BasicBlock * const streamCopyRemainingCond = CreateBasicBlock(name + "StreamCopyRemainingCond");
347            BasicBlock * const streamCopyRemaining = CreateBasicBlock(name + "StreamCopyRemaining");
348            BasicBlock * const streamCopyEnd = CreateBasicBlock(name + "StreamCopyEnd");
349
350            Value * const offset = CreateURem(targetOffset, width);
351            Value * const copied = CreateSub(width, offset);
352            Value * const mask = CreateLShr(Constant::getAllOnesValue(shiftTy), copied);
353
354            SetInsertPoint(streamCopy);
355            PHINode * const i = CreatePHI(getSizeTy(), 2);
356            i->addIncoming(getSize(0), entry);
357            Value * targetValue = CreateAlignedLoad(CreateGEP(target, i), alignment);
358            targetValue = CreateAnd(CreateBitCast(targetValue, shiftTy), mask);
359            Value * sourceValue = CreateAlignedLoad(CreateGEP(source, i), alignment);
360            sourceValue = CreateShl(CreateBitCast(sourceValue, shiftTy), offset);
361            CreateAlignedStore(CreateOr(sourceValue, targetValue), CreateGEP(source, i), alignment);
362            Value * const next_i = CreateAdd(i, getSize(1));
363            i->addIncoming(next_i, streamCopy);
364            CreateCondBr(CreateICmpNE(next_i, n), streamCopy, streamCopyRemainingCond);
365
366            SetInsertPoint(streamCopyRemainingCond);
367            Value * const blocksToCopy = CreateMul(CreateUDiv(CreateSub(itemsToCopy, copied), width), n);
368            CreateCondBr(CreateICmpULT(copied, itemsToCopy), streamCopyRemaining, streamCopyEnd);
369
370            SetInsertPoint(streamCopyRemaining);
371            PHINode * const j = CreatePHI(getSizeTy(), 2);
372            j->addIncoming(n, entry);
373            Value * prior = CreateAlignedLoad(CreateGEP(source, CreateSub(j, n)), alignment);
374            prior = CreateShl(CreateBitCast(prior, shiftTy), offset);
375            Value * value = CreateAlignedLoad(CreateGEP(source, j), alignment);
376            value = CreateLShr(CreateBitCast(value, shiftTy), copied);
377            Value * const result = CreateBitCast(CreateOr(value, prior), blockTy);
378            CreateAlignedStore(result, CreateGEP(target, j), alignment);
379            Value * const next_j = CreateAdd(j, getSize(1));
380            j->addIncoming(next_j, streamCopy);
381            CreateCondBr(CreateICmpNE(next_j, blocksToCopy), streamCopyRemaining, streamCopyEnd);
382
383            SetInsertPoint(streamCopyEnd);
384        }
385
386    }
387}
388
389Value * KernelBuilder::getConsumerLock(const std::string & name) {
390    return getScalarField(name + Kernel::CONSUMER_SUFFIX);
391}
392
393void KernelBuilder::setConsumerLock(const std::string & name, Value * value) {
394    setScalarField(name + Kernel::CONSUMER_SUFFIX, value);
395}
396
397Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * streamIndex) {
398    Value * const addr = mKernel->getStreamSetInputAddress(name);
399    if (addr) {
400        return CreateGEP(addr, {getInt32(0), streamIndex});
401    } else {
402        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
403        Value * const blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
404        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, true);
405    }
406}
407
408Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * streamIndex) {
409    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex));
410}
411
412Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
413    Value * const addr = mKernel->getStreamSetInputAddress(name);
414    if (addr) {
415        return CreateGEP(addr, {getInt32(0), streamIndex, packIndex});
416    } else {
417        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
418        Value * const blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
419        return buf->getStreamPackPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, packIndex, true);
420    }
421}
422
423Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex) {
424    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex));
425}
426
427Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
428    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
429    return buf->getStreamSetCount(this, getStreamHandle(name));
430}
431
432Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
433    Value * const addr = mKernel->getStreamSetInputAddress(name);
434    if (addr) {
435        return CreateGEP(addr, {blockOffset, streamIndex});
436    } else {
437        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
438        Value * blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
439        blockIndex = CreateAdd(blockIndex, blockOffset);
440        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, true);
441    }
442}
443
444Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex) {
445    Value * const addr = mKernel->getStreamSetOutputAddress(name);
446    if (addr) {
447        return CreateGEP(addr, {getInt32(0), streamIndex});
448    } else {
449        const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
450        Value * const blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
451        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, false);
452    }
453}
454
455StoreInst * KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, Value * toStore) {
456    return CreateBlockAlignedStore(toStore, getOutputStreamBlockPtr(name, streamIndex));
457}
458
459Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
460    Value * const addr = mKernel->getStreamSetOutputAddress(name);
461    if (addr) {
462        return CreateGEP(addr, {getInt32(0), streamIndex, packIndex});
463    } else {
464        const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
465        Value * const blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
466        return buf->getStreamPackPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, packIndex, false);
467    }
468}
469
470StoreInst * KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, Value * toStore) {
471    return CreateBlockAlignedStore(toStore, getOutputStreamPackPtr(name, streamIndex, packIndex));
472}
473
474Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
475    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
476    return buf->getStreamSetCount(this, getStreamHandle(name));
477}
478
479Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * absolutePosition) {
480    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
481    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
482}
483
484Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * absolutePosition) {
485    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
486    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
487}
488
489Value * KernelBuilder::getBaseAddress(const std::string & name) {
490    return mKernel->getAnyStreamSetBuffer(name)->getBaseAddress(this, getStreamHandle(name));
491}
492
493void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
494    return mKernel->getAnyStreamSetBuffer(name)->setBaseAddress(this, getStreamHandle(name), addr);
495}
496
497Value * KernelBuilder::getBufferedSize(const std::string & name) {
498    return mKernel->getAnyStreamSetBuffer(name)->getBufferedSize(this, getStreamHandle(name));
499}
500
501void KernelBuilder::setBufferedSize(const std::string & name, Value * size) {
502    mKernel->getAnyStreamSetBuffer(name)->setBufferedSize(this, getStreamHandle(name), size);
503}
504
505Value * KernelBuilder::getCapacity(const std::string & name) {
506    return mKernel->getAnyStreamSetBuffer(name)->getCapacity(this, getStreamHandle(name));
507}
508
509void KernelBuilder::setCapacity(const std::string & name, Value * c) {
510    mKernel->getAnyStreamSetBuffer(name)->setCapacity(this, getStreamHandle(name), c);
511}
512
513Value * KernelBuilder::getBlockAddress(const std::string & name, Value * blockIndex) {
514    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
515    return buf->getBlockAddress(this, getStreamHandle(name), blockIndex);
516}
517
518void KernelBuilder::protectOutputStream(const std::string & name, const bool readOnly) {
519    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
520    Value * const handle = getStreamHandle(name);
521    Value * const base = buf->getBaseAddress(this, handle);
522    Value * sz = ConstantExpr::getSizeOf(buf->getType());
523    sz = CreateMul(sz, getInt64(buf->getBufferBlocks()));
524    sz = CreateMul(sz, CreateZExt(buf->getStreamSetCount(this, handle), getInt64Ty()));
525    CreateMProtect(base, sz, readOnly ? CBuilder::READ : (CBuilder::READ | CBuilder::WRITE));
526}
527   
528CallInst * KernelBuilder::createDoSegmentCall(const std::vector<Value *> & args) {
529    return mKernel->makeDoSegmentCall(*this, args);
530}
531
532Value * KernelBuilder::getAccumulator(const std::string & accumName) {
533    auto results = mKernel->mOutputScalarResult;
534    if (LLVM_UNLIKELY(results == nullptr)) {
535        report_fatal_error("Cannot get accumulator " + accumName + " until " + mKernel->getName() + " has terminated.");
536    }
537    const auto & outputs = mKernel->getScalarOutputs();
538    const auto n = outputs.size();
539    if (LLVM_UNLIKELY(n == 0)) {
540        report_fatal_error(mKernel->getName() + " has no output scalars.");
541    } else {
542        for (unsigned i = 0; i < n; ++i) {
543            const Binding & b = outputs[i];
544            if (b.getName() == accumName) {
545                if (n == 1) {
546                    return results;
547                } else {
548                    return CreateExtractValue(results, {i});
549                }
550            }
551        }
552        report_fatal_error(mKernel->getName() + " has no output scalar named " + accumName);
553    }
554}
555
556BasicBlock * KernelBuilder::CreateConsumerWait() {
557    const auto consumers = mKernel->getStreamOutputs();
558    BasicBlock * const entry = GetInsertBlock();
559    if (consumers.empty()) {
560        return entry;
561    } else {
562        Function * const parent = entry->getParent();
563        IntegerType * const sizeTy = getSizeTy();
564        ConstantInt * const zero = getInt32(0);
565        ConstantInt * const one = getInt32(1);
566        ConstantInt * const size0 = getSize(0);
567
568        Value * const segNo = acquireLogicalSegmentNo();
569        const auto n = consumers.size();
570        BasicBlock * load[n + 1];
571        BasicBlock * wait[n];
572        for (unsigned i = 0; i < n; ++i) {
573            load[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Load", parent);
574            wait[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Wait", parent);
575        }
576        load[n] = BasicBlock::Create(getContext(), "Resume", parent);
577        CreateBr(load[0]);
578        for (unsigned i = 0; i < n; ++i) {
579
580            SetInsertPoint(load[i]);
581            Value * const outputConsumers = getConsumerLock(consumers[i].getName());
582
583            Value * const consumerCount = CreateLoad(CreateGEP(outputConsumers, {zero, zero}));
584            Value * const consumerPtr = CreateLoad(CreateGEP(outputConsumers, {zero, one}));
585            Value * const noConsumers = CreateICmpEQ(consumerCount, size0);
586            CreateUnlikelyCondBr(noConsumers, load[i + 1], wait[i]);
587
588            SetInsertPoint(wait[i]);
589            PHINode * const consumerPhi = CreatePHI(sizeTy, 2);
590            consumerPhi->addIncoming(size0, load[i]);
591
592            Value * const conSegPtr = CreateLoad(CreateGEP(consumerPtr, consumerPhi));
593            Value * const processedSegmentCount = CreateAtomicLoadAcquire(conSegPtr);
594            Value * const ready = CreateICmpEQ(segNo, processedSegmentCount);
595            assert (ready->getType() == getInt1Ty());
596            Value * const nextConsumerIdx = CreateAdd(consumerPhi, CreateZExt(ready, sizeTy));
597            consumerPhi->addIncoming(nextConsumerIdx, wait[i]);
598            Value * const next = CreateICmpEQ(nextConsumerIdx, consumerCount);
599            CreateCondBr(next, load[i + 1], wait[i]);
600        }
601
602        BasicBlock * const exit = load[n];
603        SetInsertPoint(exit);
604        return exit;
605    }
606}
607
608}
Note: See TracBrowser for help on using the repository browser.