source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 5757

Last change on this file since 5757 was 5757, checked in by nmedfort, 16 months ago

Bug fixes + more assertions to prevent similar errors.

File size: 27.6 KB
Line 
1#include "kernel_builder.h"
2#include <toolchain/toolchain.h>
3#include <kernels/kernel.h>
4#include <kernels/streamset.h>
5#include <llvm/Support/raw_ostream.h>
6#include <llvm/IR/Module.h>
7
8using namespace llvm;
9using namespace parabix;
10
11inline static bool is_power_2(const uint64_t n) {
12    return ((n & (n - 1)) == 0) && n;
13}
14
15namespace kernel {
16
17using Port = Kernel::Port;
18
19Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const instance, Value * const index) {
20    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
21        CreateAssert(instance, "getScalarFieldPtr: instance cannot be null!");
22    }
23    return CreateGEP(instance, {getInt32(0), index});
24}
25
26Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const handle, const std::string & fieldName) {
27    return getScalarFieldPtr(handle, getInt32(mKernel->getScalarIndex(fieldName)));
28}
29
30llvm::Value * KernelBuilder::getScalarFieldPtr(llvm::Value * const index) {
31    return getScalarFieldPtr(mKernel->getInstance(), index);
32}
33
34llvm::Value *KernelBuilder:: getScalarFieldPtr(const std::string & fieldName) {
35    return getScalarFieldPtr(mKernel->getInstance(), fieldName);
36}
37
38Value * KernelBuilder::getScalarField(const std::string & fieldName) {
39    return CreateLoad(getScalarFieldPtr(fieldName), fieldName);
40}
41
42void KernelBuilder::setScalarField(const std::string & fieldName, Value * value) {
43    CreateStore(value, getScalarFieldPtr(fieldName));
44}
45
46Value * KernelBuilder::getStreamHandle(const std::string & name) {
47    Value * const ptr = getScalarField(name + Kernel::BUFFER_PTR_SUFFIX);
48    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
49        CreateAssert(ptr, name + " handle cannot be null!");
50    }
51    return ptr;
52}
53
54LoadInst * KernelBuilder::acquireLogicalSegmentNo() {
55    return CreateAtomicLoadAcquire(getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
56}
57
58void KernelBuilder::releaseLogicalSegmentNo(Value * nextSegNo) {
59    CreateAtomicStoreRelease(nextSegNo, getScalarFieldPtr(Kernel::LOGICAL_SEGMENT_NO_SCALAR));
60}
61
62Value * KernelBuilder::getCycleCountPtr() {
63    return getScalarFieldPtr(Kernel::CYCLECOUNT_SCALAR);
64}
65
66Value * KernelBuilder::getInternalItemCount(const std::string & name, const std::string & suffix) {
67    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
68    Value * itemCount = nullptr;
69    if (LLVM_UNLIKELY(rate.isRelative())) {
70        Port port; unsigned index;
71        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
72        if (port == Port::Input) {
73            itemCount = getProcessedItemCount(rate.getReference());
74        } else {
75            itemCount = getProducedItemCount(rate.getReference());
76        }
77        const auto & r = rate.getRate();
78        if (r.numerator() != 1) {
79            itemCount = CreateMul(itemCount, ConstantInt::get(itemCount->getType(), r.numerator()));
80        }
81        if (r.denominator() != 1) {
82            itemCount = CreateExactUDiv(itemCount, ConstantInt::get(itemCount->getType(), r.denominator()));
83        }
84    } else {
85        itemCount = getScalarField(name + suffix);
86    }
87    return itemCount;
88}
89
90void KernelBuilder::setInternalItemCount(const std::string & name, const std::string & suffix, llvm::Value * const value) {
91    const ProcessingRate & rate = mKernel->getBinding(name).getRate();
92    if (LLVM_UNLIKELY(rate.isDerived())) {
93        report_fatal_error("Cannot set item count: " + name + " is a Derived rate");
94    }
95    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
96        CallPrintIntToStderr(mKernel->getName() + ": " + name + suffix, value);
97    }
98    setScalarField(name + suffix, value);
99}
100
101
102Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
103    const auto & inputs = mKernel->getStreamInputs();
104    for (unsigned i = 0; i < inputs.size(); ++i) {
105        if (inputs[i].getName() == name) {
106            return mKernel->getAvailableItemCount(i);
107        }
108    }
109    return nullptr;
110}
111
112Value * KernelBuilder::getTerminationSignal() {
113    if (mKernel->hasNoTerminateAttribute()) {
114        return getFalse();
115    }
116    return getScalarField(Kernel::TERMINATION_SIGNAL);
117}
118
119void KernelBuilder::setTerminationSignal(llvm::Value * const value) {
120    assert (!mKernel->hasNoTerminateAttribute());
121    assert (value->getType() == getInt1Ty());
122    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
123        CallPrintIntToStderr(mKernel->getName() + ": setTerminationSignal", value);
124    }
125    setScalarField(Kernel::TERMINATION_SIGNAL, value);
126}
127
128Value * KernelBuilder::getLinearlyAccessibleItems(const std::string & name, Value * fromPosition, Value * avail, bool reverse) {
129    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
130    return buf->getLinearlyAccessibleItems(this, getStreamHandle(name), fromPosition, avail, reverse);
131}
132
133Value * KernelBuilder::getLinearlyWritableItems(const std::string & name, Value * fromPosition, bool reverse) {
134    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
135    return buf->getLinearlyWritableItems(this, getStreamHandle(name), fromPosition, reverse);
136}
137
138//Value * KernelBuilder::getLinearlyCopyableItems(const std::string & name, Value * fromPosition, bool reverse) {
139//    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
140//    return buf->getLinearlyCopyableItems(this, getStreamHandle(name), fromPosition, reverse);
141//}
142
143/** ------------------------------------------------------------------------------------------------------------- *
144 * @brief isConstantZero
145 ** ------------------------------------------------------------------------------------------------------------- */
146inline bool isConstantZero(Value * const v) {
147    return isa<ConstantInt>(v) && cast<ConstantInt>(v)->isNullValue();
148}
149
150/** ------------------------------------------------------------------------------------------------------------- *
151 * @brief isConstantOne
152 ** ------------------------------------------------------------------------------------------------------------- */
153inline bool isConstantOne(Value * const v) {
154    return isa<ConstantInt>(v) && cast<ConstantInt>(v)->isOne();
155}
156
157/** ------------------------------------------------------------------------------------------------------------- *
158 * @brief getItemWidth
159 ** ------------------------------------------------------------------------------------------------------------- */
160inline unsigned getItemWidth(const Type * ty) {
161    if (LLVM_LIKELY(isa<ArrayType>(ty))) {
162        ty = ty->getArrayElementType();
163    }
164    return cast<IntegerType>(ty->getVectorElementType())->getBitWidth();
165}
166
167/** ------------------------------------------------------------------------------------------------------------- *
168 * @brief getFieldWidth
169 ** ------------------------------------------------------------------------------------------------------------- */
170inline unsigned getFieldWidth(const unsigned bitWidth, const unsigned blockWidth) {
171    for (unsigned k = 16; k <= blockWidth; k *= 2) {
172        if ((bitWidth & (k - 1)) != 0) {
173            return k / 2;
174        }
175    }
176    return blockWidth;
177}
178
179/** ------------------------------------------------------------------------------------------------------------- *
180 * @brief CreateStreamCpy
181 ** ------------------------------------------------------------------------------------------------------------- */
182void KernelBuilder::CreateStreamCpy(const std::string & name, Value * target, Value * targetOffset, Value * source, Value * sourceOffset, Value * itemsToCopy, const unsigned itemAlignment) {
183
184    assert (target && targetOffset);
185    assert (source && sourceOffset);
186    assert (target->getType() == source->getType());
187    assert (target->getType()->isPointerTy());
188    assert (isConstantZero(targetOffset) || isConstantZero(sourceOffset));
189
190    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
191    const auto itemWidth = getItemWidth(buf->getBaseType());
192    assert ("invalid item width" && is_power_2(itemWidth));
193    const auto blockWidth = getBitBlockWidth();
194    // Although our item width may be n bits, if we know we're always processing m items per block, our field width
195    // (w.r.t the stream copy) would be n*m. By taking this into account we can optimize and simplify the copy code.
196    const auto fieldWidth = getFieldWidth(itemWidth * itemAlignment, blockWidth);
197    assert ("overflow error" && is_power_2(fieldWidth) && (itemWidth <= fieldWidth) && (fieldWidth <= blockWidth));
198
199    if (LLVM_LIKELY(itemWidth < fieldWidth)) {
200        Constant * const factor = getSize(fieldWidth / itemWidth);
201        CreateAssertZero(CreateURem(targetOffset, factor), "target offset is not a multiple of its field width");
202        targetOffset = CreateUDiv(targetOffset, factor);
203        CreateAssertZero(CreateURem(sourceOffset, factor), "source offset is not a multiple of its field width");
204        sourceOffset = CreateUDiv(sourceOffset, factor);
205    }
206
207    /*
208       Streams are conceptually modelled as:
209
210                                            BLOCKS
211
212                                      A     B     C     D
213           STREAM SET ELEMENT   1  |aaaaa|bbbbb|ccccc|dddd |
214                                2  |eeeee|fffff|ggggg|hhhh |
215                                3  |iiiii|jjjjj|kkkkk|llll |
216
217       But the memory layout is actually:
218
219           A_1   A_2   A_3   B_1   B_2   B_3   C_1   C_2   C_3   D_1   D_2   D_3
220
221         |aaaaa|eeeee|iiiii|bbbbb|fffff|jjjjj|ccccc|ggggg|kkkkk|dddd |hhhh |llll |
222
223
224       So if we're copying the entire stream set block or our stream set has one element, we can use memcpy.
225
226    */
227
228    const auto alignment = (fieldWidth + 7) / 8;
229
230    Type * const fieldWidthTy = getIntNTy(fieldWidth);
231
232    Value * const n = buf->getStreamSetCount(this, getStreamHandle(name));
233    if (isConstantOne(n) || fieldWidth == blockWidth || (isConstantZero(targetOffset) && isConstantZero(sourceOffset))) {
234        if (isConstantOne(n)) {
235            if (LLVM_LIKELY(itemWidth < 8)) {
236                itemsToCopy = CreateUDivCeil(itemsToCopy, getSize(8 / itemWidth));
237            } else if (LLVM_UNLIKELY(itemWidth > 8)) {
238                itemsToCopy = CreateMul(itemsToCopy, getSize(itemWidth / 8));
239            }
240        } else {
241            if (LLVM_LIKELY(blockWidth > (itemWidth * 8))) {
242                itemsToCopy = CreateUDivCeil(itemsToCopy, getSize(blockWidth / (8 * itemWidth)));
243            } else if (LLVM_LIKELY(blockWidth < (itemWidth * 8))) {
244                itemsToCopy = CreateUDivCeil(CreateMul(itemsToCopy, getSize(8)), getSize(blockWidth / itemWidth));
245            }
246            itemsToCopy = CreateMul(itemsToCopy, n);
247        }
248        PointerType * const ptrTy = fieldWidthTy->getPointerTo();
249        target = CreateGEP(CreatePointerCast(target, ptrTy), targetOffset);
250        source = CreateGEP(CreatePointerCast(source, ptrTy), sourceOffset);
251        CreateMemCpy(target, source, itemsToCopy, alignment);
252
253    } else { // either the target offset or source offset is non-zero but not both
254
255        VectorType * const blockTy = getBitBlockType();
256        PointerType * const blockPtrTy = blockTy->getPointerTo();
257
258        target = CreatePointerCast(target, blockPtrTy);
259        source = CreatePointerCast(source, blockPtrTy);
260
261        assert ((blockWidth % fieldWidth) == 0);
262
263        VectorType * const shiftTy = VectorType::get(fieldWidthTy, blockWidth / fieldWidth);
264        Constant * const width = getSize(blockWidth / itemWidth);
265        BasicBlock * const entry = GetInsertBlock();
266
267        if (isConstantZero(targetOffset)) {
268
269            /*
270                                                BLOCKS
271
272                                          A     B     C     D
273               SOURCE STREAM        1  |aaa--|bbbBB|cccCC|  dDD|
274                                    2  |eee--|fffFF|gggGG|  hHH|
275                                    3  |iii--|jjjJJ|kkkKK|  lLL|
276
277
278                                          A     B     C     D
279               TARGET STREAM        1  |BBaaa|CCbbb|DDccc|    d|
280                                    2  |FFeee|GGfff|HHggg|    h|
281                                    3  |JJiii|KKjjj|LLkkk|    l|
282             */
283
284            Value * const blocksToCopy = CreateMul(CreateUDiv(itemsToCopy, width), n);
285            Value * const offset = CreateURem(sourceOffset, width);
286            Value * const remaining = CreateSub(width, offset);
287            Value * const trailing = CreateURem(CreateAdd(sourceOffset, itemsToCopy), width);
288
289            BasicBlock * const streamCopy = CreateBasicBlock(name + "StreamCopy");
290            BasicBlock * const streamCopyRemaining = CreateBasicBlock(name + "StreamCopyRemaining");
291            BasicBlock * const streamCopyEnd = CreateBasicBlock(name + "StreamCopyEnd");
292
293            CreateCondBr(CreateICmpNE(blocksToCopy, getSize(0)), streamCopy, streamCopyRemaining);
294
295            SetInsertPoint(streamCopy);
296            PHINode * const i = CreatePHI(getSizeTy(), 2);
297            i->addIncoming(n, entry);
298            Value * prior = CreateAlignedLoad(CreateGEP(source, CreateSub(i, n)), alignment);
299            prior = CreateLShr(CreateBitCast(prior, shiftTy), offset);
300            Value * value = CreateAlignedLoad(CreateGEP(source, i), alignment);
301            value = CreateShl(CreateBitCast(value, shiftTy), remaining);
302            Value * const result = CreateBitCast(CreateOr(value, prior), blockTy);
303            CreateAlignedStore(result, CreateGEP(target, i), alignment);
304            Value * const next_i = CreateAdd(i, getSize(1));
305            i->addIncoming(next_i, streamCopy);
306            CreateCondBr(CreateICmpNE(next_i, blocksToCopy), streamCopy, streamCopyRemaining);
307
308            SetInsertPoint(streamCopyRemaining);
309            PHINode * const j = CreatePHI(getSizeTy(), 2);
310            j->addIncoming(getSize(0), streamCopy);
311            Value * k = CreateAdd(blocksToCopy, j);
312            Value * final = CreateAlignedLoad(CreateGEP(source, k), alignment);
313            final = CreateLShr(CreateBitCast(prior, shiftTy), trailing);
314            CreateAlignedStore(final, CreateGEP(target, k), alignment);
315            Value * const next_j = CreateAdd(i, getSize(1));
316            i->addIncoming(next_j, streamCopyRemaining);
317            CreateCondBr(CreateICmpNE(next_j, n), streamCopyRemaining, streamCopyEnd);
318
319            SetInsertPoint(streamCopyEnd);
320
321        } else if (isConstantZero(sourceOffset)) {
322
323            /*
324                                                BLOCKS
325
326                                          A     B     C     D
327               SOURCE STREAM        1  |AAAaa|BBBaa|CCCcc|    d|
328                                    2  |EEEee|FFFff|GGGgg|    h|
329                                    3  |IIIii|JJJjj|KKKkk|    l|
330
331
332                                          A     B     C     D
333               TARGET STREAM        1  |aa---|bbAAA|ccBBB| dCCC|
334                                    2  |ee---|ffEEE|ggFFF| hGGG|
335                                    3  |ii---|jjIII|kkJJJ| lKKK|
336
337            */
338
339            BasicBlock * const streamCopy = CreateBasicBlock(name + "StreamCopy");
340            BasicBlock * const streamCopyRemainingCond = CreateBasicBlock(name + "StreamCopyRemainingCond");
341            BasicBlock * const streamCopyRemaining = CreateBasicBlock(name + "StreamCopyRemaining");
342            BasicBlock * const streamCopyEnd = CreateBasicBlock(name + "StreamCopyEnd");
343
344            Value * const offset = CreateURem(targetOffset, width);
345            Value * const copied = CreateSub(width, offset);
346            Value * const mask = CreateLShr(Constant::getAllOnesValue(shiftTy), copied);
347
348            SetInsertPoint(streamCopy);
349            PHINode * const i = CreatePHI(getSizeTy(), 2);
350            i->addIncoming(getSize(0), entry);
351            Value * targetValue = CreateAlignedLoad(CreateGEP(target, i), alignment);
352            targetValue = CreateAnd(CreateBitCast(targetValue, shiftTy), mask);
353            Value * sourceValue = CreateAlignedLoad(CreateGEP(source, i), alignment);
354            sourceValue = CreateShl(CreateBitCast(sourceValue, shiftTy), offset);
355            CreateAlignedStore(CreateOr(sourceValue, targetValue), CreateGEP(source, i), alignment);
356            Value * const next_i = CreateAdd(i, getSize(1));
357            i->addIncoming(next_i, streamCopy);
358            CreateCondBr(CreateICmpNE(next_i, n), streamCopy, streamCopyRemainingCond);
359
360            SetInsertPoint(streamCopyRemainingCond);
361            Value * const blocksToCopy = CreateMul(CreateUDiv(CreateSub(itemsToCopy, copied), width), n);
362            CreateCondBr(CreateICmpULT(copied, itemsToCopy), streamCopyRemaining, streamCopyEnd);
363
364            SetInsertPoint(streamCopyRemaining);
365            PHINode * const j = CreatePHI(getSizeTy(), 2);
366            j->addIncoming(n, entry);
367            Value * prior = CreateAlignedLoad(CreateGEP(source, CreateSub(j, n)), alignment);
368            prior = CreateShl(CreateBitCast(prior, shiftTy), offset);
369            Value * value = CreateAlignedLoad(CreateGEP(source, j), alignment);
370            value = CreateLShr(CreateBitCast(value, shiftTy), copied);
371            Value * const result = CreateBitCast(CreateOr(value, prior), blockTy);
372            CreateAlignedStore(result, CreateGEP(target, j), alignment);
373            Value * const next_j = CreateAdd(j, getSize(1));
374            j->addIncoming(next_j, streamCopy);
375            CreateCondBr(CreateICmpNE(next_j, blocksToCopy), streamCopyRemaining, streamCopyEnd);
376
377            SetInsertPoint(streamCopyEnd);
378        }
379
380    }
381}
382
383void KernelBuilder::CreateCopyBack(const std::string & name, llvm::Value * from, llvm::Value * to) {
384    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
385    buf->genCopyBackLogic(this, getStreamHandle(name), from, to, name);
386}
387
388Value * KernelBuilder::getConsumerLock(const std::string & name) {
389    return getScalarField(name + Kernel::CONSUMER_SUFFIX);
390}
391
392void KernelBuilder::setConsumerLock(const std::string & name, Value * value) {
393    setScalarField(name + Kernel::CONSUMER_SUFFIX, value);
394}
395
396Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * streamIndex) {
397    Value * const addr = mKernel->getStreamSetInputAddress(name);
398    if (addr) {
399        return CreateGEP(addr, {getInt32(0), streamIndex});
400    } else {
401        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
402        Value * const blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
403        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, true);
404    }
405}
406
407Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * streamIndex) {
408    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex));
409}
410
411Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
412    Value * const addr = mKernel->getStreamSetInputAddress(name);
413    if (addr) {
414        return CreateGEP(addr, {getInt32(0), streamIndex, packIndex});
415    } else {
416        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
417        Value * const blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
418        return buf->getStreamPackPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, packIndex, true);
419    }
420}
421
422Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex) {
423
424
425
426    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex));
427}
428
429Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
430    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
431    return buf->getStreamSetCount(this, getStreamHandle(name));
432}
433
434Value * KernelBuilder::getAdjustedInputStreamBlockPtr(Value * blockAdjustment, const std::string & name, Value * streamIndex) {
435    Value * const addr = mKernel->getStreamSetInputAddress(name);
436    if (addr) {
437        return CreateGEP(addr, {blockAdjustment, streamIndex});
438    } else {
439        const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
440        Value * blockIndex = CreateLShr(getProcessedItemCount(name), std::log2(getBitBlockWidth()));
441        blockIndex = CreateAdd(blockIndex, blockAdjustment);
442        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, true);
443    }
444}
445
446Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex) {
447    Value * const addr = mKernel->getStreamSetOutputAddress(name);
448    if (addr) {
449        return CreateGEP(addr, {getInt32(0), streamIndex});
450    } else {
451        const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
452        Value * const blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
453        return buf->getStreamBlockPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, false);
454    }
455}
456
457StoreInst * KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, Value * toStore) {
458    return CreateBlockAlignedStore(toStore, getOutputStreamBlockPtr(name, streamIndex));
459}
460
461Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex) {
462    Value * const addr = mKernel->getStreamSetOutputAddress(name);
463    if (addr) {
464        return CreateGEP(addr, {getInt32(0), streamIndex, packIndex});
465    } else {
466        const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
467        Value * const blockIndex = CreateLShr(getProducedItemCount(name), std::log2(getBitBlockWidth()));
468        return buf->getStreamPackPtr(this, getStreamHandle(name), getBaseAddress(name), streamIndex, blockIndex, packIndex, false);
469    }
470}
471
472StoreInst * KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, Value * toStore) {
473    return CreateBlockAlignedStore(toStore, getOutputStreamPackPtr(name, streamIndex, packIndex));
474}
475
476Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
477    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
478    return buf->getStreamSetCount(this, getStreamHandle(name));
479}
480
481Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * absolutePosition) {
482    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
483    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
484}
485
486Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * absolutePosition) {
487    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
488    return buf->getRawItemPointer(this, getStreamHandle(name), absolutePosition);
489}
490
491Value * KernelBuilder::getBaseAddress(const std::string & name) {
492    return mKernel->getAnyStreamSetBuffer(name)->getBaseAddress(this, getStreamHandle(name));
493}
494
495void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
496    return mKernel->getAnyStreamSetBuffer(name)->setBaseAddress(this, getStreamHandle(name), addr);
497}
498
499Value * KernelBuilder::getBufferedSize(const std::string & name) {
500    return mKernel->getAnyStreamSetBuffer(name)->getBufferedSize(this, getStreamHandle(name));
501}
502
503void KernelBuilder::setBufferedSize(const std::string & name, Value * size) {
504    mKernel->getAnyStreamSetBuffer(name)->setBufferedSize(this, getStreamHandle(name), size);
505}
506
507Value * KernelBuilder::getCapacity(const std::string & name) {
508    return mKernel->getAnyStreamSetBuffer(name)->getCapacity(this, getStreamHandle(name));
509}
510
511void KernelBuilder::setCapacity(const std::string & name, Value * c) {
512    mKernel->getAnyStreamSetBuffer(name)->setCapacity(this, getStreamHandle(name), c);
513}
514
515Value * KernelBuilder::getBlockAddress(const std::string & name, Value * blockIndex) {
516    const StreamSetBuffer * const buf = mKernel->getAnyStreamSetBuffer(name);
517    return buf->getBlockAddress(this, getStreamHandle(name), blockIndex);
518}
519
520void KernelBuilder::protectOutputStream(const std::string & name, const bool readOnly) {
521    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
522    Value * const handle = getStreamHandle(name);
523    Value * const base = buf->getBaseAddress(this, handle);
524    Value * sz = ConstantExpr::getSizeOf(buf->getType());
525    sz = CreateMul(sz, getInt64(buf->getBufferBlocks()));
526    sz = CreateMul(sz, CreateZExt(buf->getStreamSetCount(this, handle), getInt64Ty()));
527    CreateMProtect(base, sz, readOnly ? CBuilder::READ : (CBuilder::READ | CBuilder::WRITE));
528}
529   
530CallInst * KernelBuilder::createDoSegmentCall(const std::vector<Value *> & args) {
531    return mKernel->makeDoSegmentCall(*this, args);
532}
533
534Value * KernelBuilder::getAccumulator(const std::string & accumName) {
535    auto results = mKernel->mOutputScalarResult;
536    if (LLVM_UNLIKELY(results == nullptr)) {
537        report_fatal_error("Cannot get accumulator " + accumName + " until " + mKernel->getName() + " has terminated.");
538    }
539    const auto & outputs = mKernel->getScalarOutputs();
540    const auto n = outputs.size();
541    if (LLVM_UNLIKELY(n == 0)) {
542        report_fatal_error(mKernel->getName() + " has no output scalars.");
543    } else {
544        for (unsigned i = 0; i < n; ++i) {
545            const Binding & b = outputs[i];
546            if (b.getName() == accumName) {
547                if (n == 1) {
548                    return results;
549                } else {
550                    return CreateExtractValue(results, {i});
551                }
552            }
553        }
554        report_fatal_error(mKernel->getName() + " has no output scalar named " + accumName);
555    }
556}
557
558BasicBlock * KernelBuilder::CreateConsumerWait() {
559    const auto consumers = mKernel->getStreamOutputs();
560    BasicBlock * const entry = GetInsertBlock();
561    if (consumers.empty()) {
562        return entry;
563    } else {
564        Function * const parent = entry->getParent();
565        IntegerType * const sizeTy = getSizeTy();
566        ConstantInt * const zero = getInt32(0);
567        ConstantInt * const one = getInt32(1);
568        ConstantInt * const size0 = getSize(0);
569
570        Value * const segNo = acquireLogicalSegmentNo();
571        const auto n = consumers.size();
572        BasicBlock * load[n + 1];
573        BasicBlock * wait[n];
574        for (unsigned i = 0; i < n; ++i) {
575            load[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Load", parent);
576            wait[i] = BasicBlock::Create(getContext(), consumers[i].getName() + "Wait", parent);
577        }
578        load[n] = BasicBlock::Create(getContext(), "Resume", parent);
579        CreateBr(load[0]);
580        for (unsigned i = 0; i < n; ++i) {
581
582            SetInsertPoint(load[i]);
583            Value * const outputConsumers = getConsumerLock(consumers[i].getName());
584
585            Value * const consumerCount = CreateLoad(CreateGEP(outputConsumers, {zero, zero}));
586            Value * const consumerPtr = CreateLoad(CreateGEP(outputConsumers, {zero, one}));
587            Value * const noConsumers = CreateICmpEQ(consumerCount, size0);
588            CreateUnlikelyCondBr(noConsumers, load[i + 1], wait[i]);
589
590            SetInsertPoint(wait[i]);
591            PHINode * const consumerPhi = CreatePHI(sizeTy, 2);
592            consumerPhi->addIncoming(size0, load[i]);
593
594            Value * const conSegPtr = CreateLoad(CreateGEP(consumerPtr, consumerPhi));
595            Value * const processedSegmentCount = CreateAtomicLoadAcquire(conSegPtr);
596            Value * const ready = CreateICmpEQ(segNo, processedSegmentCount);
597            assert (ready->getType() == getInt1Ty());
598            Value * const nextConsumerIdx = CreateAdd(consumerPhi, CreateZExt(ready, sizeTy));
599            consumerPhi->addIncoming(nextConsumerIdx, wait[i]);
600            Value * const next = CreateICmpEQ(nextConsumerIdx, consumerCount);
601            CreateCondBr(next, load[i + 1], wait[i]);
602        }
603
604        BasicBlock * const exit = load[n];
605        SetInsertPoint(exit);
606        return exit;
607    }
608}
609
610}
Note: See TracBrowser for help on using the repository browser.