source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 6233

Last change on this file since 6233 was 6233, checked in by nmedfort, 4 months ago

Moved termination signals into pipeline kernel

File size: 17.0 KB
Line 
1#include "kernel_builder.h"
2#include <toolchain/toolchain.h>
3#include <kernels/kernel.h>
4#include <kernels/streamset.h>
5#include <llvm/Support/raw_ostream.h>
6#include <llvm/IR/Module.h>
7
8using namespace llvm;
9
/**
 * Returns true iff `n` is a non-zero power of two, i.e. exactly one bit of `n` is set.
 */
inline static bool is_power_2(const uint64_t n) {
    // n & (n - 1) clears the lowest set bit, leaving zero only when a single bit was set;
    // zero itself is rejected up front.
    return n != 0 && (n & (n - 1)) == 0;
}
13
14namespace kernel {
15
16using Port = Kernel::Port;
17
18inline Value * KernelBuilder::getScalarFieldPtr(Value * const handle, Value * const index) {
19    assert ("handle cannot be null" && handle);
20    assert ("index cannot be null" && index);
21    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
22        CreateAssert(handle, "getScalarFieldPtr: handle cannot be null!");
23    }
24    #ifndef NDEBUG
25    const Function * const handleFunction = isa<Argument>(handle) ? cast<Argument>(handle)->getParent() : cast<Instruction>(handle)->getParent()->getParent();
26    const Function * const builderFunction = GetInsertBlock()->getParent();
27    assert ("handle is not from the current function." && handleFunction == builderFunction);
28    #endif
29    return CreateGEP(handle, {getInt32(0), index});
30}
31
32#warning TODO: make get scalar field able to get I/O scalars
33
34inline Value * KernelBuilder::getScalarFieldPtr(Value * const handle, const std::string & fieldName) {
35    ConstantInt * const index = getInt32(mKernel->getScalarIndex(fieldName));
36    return getScalarFieldPtr(handle, index);
37}
38
39Value * KernelBuilder::getScalarFieldPtr(Value * const index) {
40    return getScalarFieldPtr(mKernel->getHandle(), index);
41}
42
43Value * KernelBuilder::getScalarFieldPtr(const std::string & fieldName) {
44    return getScalarFieldPtr(mKernel->getHandle(), fieldName);
45}
46
47Value * KernelBuilder::getScalarField(const std::string & fieldName) {
48    Value * const ptr = getScalarFieldPtr(fieldName);
49    return CreateLoad(ptr, fieldName);
50}
51
52void KernelBuilder::setScalarField(const std::string & fieldName, Value * const value) {
53    Value * const ptr = getScalarFieldPtr(fieldName);
54    CreateStore(value, ptr);
55}
56
57LoadInst * KernelBuilder::acquireLogicalSegmentNo() {
58    return CreateAtomicLoadAcquire(getScalarFieldPtr(LOGICAL_SEGMENT_NO_SCALAR));
59}
60
61void KernelBuilder::releaseLogicalSegmentNo(Value * const nextSegNo) {
62    CreateAtomicStoreRelease(nextSegNo, getScalarFieldPtr(LOGICAL_SEGMENT_NO_SCALAR));
63}
64
65Value * KernelBuilder::getCycleCountPtr() {
66    return getScalarFieldPtr(CYCLECOUNT_SCALAR);
67}
68
69/** ------------------------------------------------------------------------------------------------------------- *
70 * @brief getNamedItemCount
71 ** ------------------------------------------------------------------------------------------------------------- */
72Value * KernelBuilder::getNamedItemCount(const std::string & name, const std::string & suffix) {
73    const ProcessingRate & rate = mKernel->getStreamBinding(name).getRate();
74    Value * itemCount = nullptr;
75    if (LLVM_UNLIKELY(rate.isRelative())) {
76        Port port; unsigned index;
77        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
78        if (port == Port::Input) {
79            itemCount = getProcessedItemCount(rate.getReference());
80        } else {
81            itemCount = getProducedItemCount(rate.getReference());
82        }
83        itemCount = CreateMul2(itemCount, rate.getRate());
84    } else {
85        itemCount = getScalarField(name + suffix);
86    }
87    return itemCount;
88}
89
90/** ------------------------------------------------------------------------------------------------------------- *
91 * @brief setNamedItemCount
92 ** ------------------------------------------------------------------------------------------------------------- */
93void KernelBuilder::setNamedItemCount(const std::string & name, const std::string & suffix, Value * const value) {
94    const ProcessingRate & rate = mKernel->getStreamBinding(name).getRate();
95    if (LLVM_UNLIKELY(rate.isDerived())) {
96        report_fatal_error("cannot set item count: " + name + " is a derived rate stream");
97    }
98    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
99        CallPrintInt(mKernel->getName() + ": " + name + suffix, value);
100    }
101    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
102        Value * const current = getScalarField(name + suffix);
103        CreateAssert(CreateICmpUGE(value, current), name + suffix + " must be monotonically non-decreasing");
104    }
105    setScalarField(name + suffix, value);
106}
107
108/** ------------------------------------------------------------------------------------------------------------- *
109 * @brief getAvailableItemCount
110 ** ------------------------------------------------------------------------------------------------------------- */
111Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
112    return mKernel->getAvailableInputItems(name);
113}
114
115/** ------------------------------------------------------------------------------------------------------------- *
116 * @brief getAccessibleItemCount
117 ** ------------------------------------------------------------------------------------------------------------- */
118Value * KernelBuilder::getAccessibleItemCount(const std::string & name) {
119    return mKernel->getAccessibleInputItems(name);
120}
121
122/** ------------------------------------------------------------------------------------------------------------- *
123 * @brief getTerminationSignal
124 ** ------------------------------------------------------------------------------------------------------------- */
125Value * KernelBuilder::getTerminationSignal() {
126    Value * const ptr = mKernel->getTerminationSignalPtr();
127    if (ptr) {
128        return CreateLoad(ptr);
129    } else {
130        return getFalse();
131    }
132}
133
134/** ------------------------------------------------------------------------------------------------------------- *
135 * @brief setTerminationSignal
136 ** ------------------------------------------------------------------------------------------------------------- */
137void KernelBuilder::setTerminationSignal(Value * const value) {
138    assert (value);
139    assert (value->getType() == getInt1Ty());
140    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
141        CallPrintInt(mKernel->getName() + ": setTerminationSignal", value);
142    }
143    Value * const ptr = mKernel->getTerminationSignalPtr();
144    if (LLVM_UNLIKELY(ptr == nullptr)) {
145        llvm::report_fatal_error(mKernel->getName() + " does not have CanTerminateEarly or MustExplicitlyTerminate set.");
146    }
147    CreateStore(value, ptr);
148}
149
150Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
151    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
152    Value * const processed = getProcessedItemCount(name);
153    Value * blockIndex = CreateLShr(processed, std::log2(getBitBlockWidth()));
154    if (blockOffset) {
155        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
156    }
157    return buf->getStreamBlockPtr(this, streamIndex, blockIndex);
158}
159
160Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
161    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
162    Value * const processed = getProcessedItemCount(name);
163    Value * blockIndex = CreateLShr(processed, std::log2(getBitBlockWidth()));
164    if (blockOffset) {
165        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
166    }
167    return buf->getStreamPackPtr(this, streamIndex, blockIndex, packIndex);
168}
169
170Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
171    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex, blockOffset));
172}
173
174Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
175    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex, blockOffset));
176}
177
178Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
179    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
180    return buf->getStreamSetCount(this);
181}
182
183Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex, Value * const blockOffset) {
184    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
185    Value * const produced = getProducedItemCount(name);
186    Value * blockIndex = CreateLShr(produced, std::log2(getBitBlockWidth()));
187    if (blockOffset) {
188        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
189    }
190    return buf->getStreamBlockPtr(this, streamIndex, blockIndex);
191}
192
193Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex, Value * blockOffset) {
194    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
195    Value * const produced = getProducedItemCount(name);
196    Value * blockIndex = CreateLShr(produced, std::log2(getBitBlockWidth()));
197    if (blockOffset) {
198        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
199    }
200    return buf->getStreamPackPtr(this, streamIndex, blockIndex, packIndex);
201}
202
203StoreInst * KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, Value * blockOffset, Value * toStore) {
204    Value * const ptr = getOutputStreamBlockPtr(name, streamIndex, blockOffset);
205    Type * const storeTy = toStore->getType();
206    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
207    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
208        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
209            toStore = CreateBitCast(toStore, ptrElemTy);
210        } else {
211            std::string tmp;
212            raw_string_ostream out(tmp);
213            out << "invalid type conversion when calling storeOutputStreamBlock on " <<  name << ": ";
214            ptrElemTy->print(out);
215            out << " vs. ";
216            storeTy->print(out);
217        }
218    }
219    return CreateBlockAlignedStore(toStore, ptr);
220}
221
222StoreInst * KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, Value * blockOffset, Value * toStore) {
223    Value * const ptr = getOutputStreamPackPtr(name, streamIndex, packIndex, blockOffset);
224    Type * const storeTy = toStore->getType();
225    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
226    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
227        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
228            toStore = CreateBitCast(toStore, ptrElemTy);
229        } else {
230            std::string tmp;
231            raw_string_ostream out(tmp);
232            out << "invalid type conversion when calling storeOutputStreamPack on " <<  name << ": ";
233            ptrElemTy->print(out);
234            out << " vs. ";
235            storeTy->print(out);
236        }
237    }
238    return CreateBlockAlignedStore(toStore, ptr);
239}
240
241Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
242    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
243    return buf->getStreamSetCount(this);
244}
245
246Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * absolutePosition) {
247    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
248    return buf->getRawItemPointer(this, absolutePosition);
249}
250
251Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * absolutePosition) {
252    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
253    return buf->getRawItemPointer(this, absolutePosition);
254}
255
256Value * KernelBuilder::getBaseAddress(const std::string & name) {
257    return mKernel->getStreamSetBuffer(name)->getBaseAddress(this);
258}
259
260void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
261    return mKernel->getStreamSetBuffer(name)->setBaseAddress(this, addr);
262}
263
264Value * KernelBuilder::getCapacity(const std::string & name) {
265    return mKernel->getStreamSetBuffer(name)->getCapacity(this);
266}
267
268void KernelBuilder::setCapacity(const std::string & name, Value * capacity) {
269    mKernel->getStreamSetBuffer(name)->setCapacity(this, capacity);
270}
271
272/** ------------------------------------------------------------------------------------------------------------- *
273 * @brief CreateUDiv2
274 ** ------------------------------------------------------------------------------------------------------------- */
275Value * KernelBuilder::CreateUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
276    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
277        return number;
278    }
279    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
280    if (LLVM_LIKELY(divisor.denominator() == 1)) {
281        return CreateUDiv(number, n, Name);
282    } else {
283        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
284        return CreateUDiv(CreateMul(number, d), n);
285    }
286}
287
288/** ------------------------------------------------------------------------------------------------------------- *
289 * @brief CreateCeilUDiv2
290 ** ------------------------------------------------------------------------------------------------------------- */
291Value * KernelBuilder::CreateCeilUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
292    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
293        return number;
294    }
295    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
296    if (LLVM_LIKELY(divisor.denominator() == 1)) {
297        return CreateCeilUDiv(number, n, Name);
298    } else {
299        //   âŒŠ(num + ratio - 1) / ratio⌋
300        // = ⌊(num - 1) / (n/d)⌋ + (ratio/ratio)
301        // = ⌊(d * (num - 1)) / n⌋ + 1
302        Constant * const ONE = ConstantInt::get(number->getType(), 1);
303        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
304        return CreateAdd(CreateUDiv(CreateMul(CreateSub(number, ONE), d), n), ONE, Name);
305    }
306}
307
308/** ------------------------------------------------------------------------------------------------------------- *
309 * @brief CreateMul2
310 ** ------------------------------------------------------------------------------------------------------------- */
311Value * KernelBuilder::CreateMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
312    if (factor.numerator() == 1 && factor.denominator() == 1) {
313        return number;
314    }
315    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
316    if (LLVM_LIKELY(factor.denominator() == 1)) {
317        return CreateMul(number, n, Name);
318    } else {
319        Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
320        return CreateUDiv(CreateMul(number, n), d, Name);
321    }
322}
323
324/** ------------------------------------------------------------------------------------------------------------- *
325 * @brief CreateMulCeil2
326 ** ------------------------------------------------------------------------------------------------------------- */
327Value * KernelBuilder::CreateCeilUMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
328    if (factor.denominator() == 1) {
329        return CreateMul2(number, factor, Name);
330    }
331    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
332    Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
333    return CreateCeilUDiv(CreateMul(number, n), d, Name);
334}
335
336/** ------------------------------------------------------------------------------------------------------------- *
337 * @brief resolveStreamSetType
338 ** ------------------------------------------------------------------------------------------------------------- */
339Type * KernelBuilder::resolveStreamSetType(Type * const streamSetType) {
340    // TODO: Should this function be here? in StreamSetBuffer? or in Binding?
341    unsigned numElements = 1;
342    Type * type = streamSetType;
343    if (LLVM_LIKELY(type->isArrayTy())) {
344        numElements = type->getArrayNumElements();
345        type = type->getArrayElementType();
346    }
347    if (LLVM_LIKELY(type->isVectorTy() && type->getVectorNumElements() == 0)) {
348        type = type->getVectorElementType();
349        if (LLVM_LIKELY(type->isIntegerTy())) {
350            const auto fieldWidth = cast<IntegerType>(type)->getBitWidth();
351            type = getBitBlockType();
352            if (fieldWidth != 1) {
353                type = ArrayType::get(type, fieldWidth);
354            }
355            return ArrayType::get(type, numElements);
356        }
357    }
358    std::string tmp;
359    raw_string_ostream out(tmp);
360    streamSetType->print(out);
361    out << " is an unvalid stream set buffer type.";
362    report_fatal_error(out.str());
363}
364
365/** ------------------------------------------------------------------------------------------------------------- *
366 * @brief getKernelName
367 ** ------------------------------------------------------------------------------------------------------------- */
368std::string KernelBuilder::getKernelName() const {
369    return mKernel->getName();
370}
371
372}
Note: See TracBrowser for help on using the repository browser.