source: icGREP/icgrep-devel/icgrep/kernels/kernel_builder.cpp @ 6237

Last change on this file since 6237 was 6237, checked in by nmedfort, 4 months ago

Re-enabled segment pipeline parallelism; moved logical segment number into pipeline kernel.

File size: 16.7 KB
Line 
1#include "kernel_builder.h"
2#include <toolchain/toolchain.h>
3#include <kernels/kernel.h>
4#include <kernels/streamset.h>
5#include <llvm/Support/raw_ostream.h>
6#include <llvm/IR/Module.h>
7
8using namespace llvm;
9
/// Returns true iff @p n is a non-zero value with exactly one bit set.
inline static bool is_power_2(const uint64_t n) {
    if (n == 0) {
        return false; // zero is not a power of two
    }
    // clearing the lowest set bit leaves nothing only when a single bit was set
    return (n & (n - 1)) == 0;
}
13
14namespace kernel {
15
16using Port = Kernel::Port;
17
18inline Value * KernelBuilder::getScalarFieldPtr(Value * const handle, Value * const index) {
19    assert ("handle cannot be null" && handle);
20    assert ("index cannot be null" && index);
21    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
22        CreateAssert(handle, "getScalarFieldPtr: handle cannot be null!");
23    }
24    #ifndef NDEBUG
25    const Function * const handleFunction = isa<Argument>(handle) ? cast<Argument>(handle)->getParent() : cast<Instruction>(handle)->getParent()->getParent();
26    const Function * const builderFunction = GetInsertBlock()->getParent();
27    assert ("handle is not from the current function." && handleFunction == builderFunction);
28    #endif
29    return CreateGEP(handle, {getInt32(0), index});
30}
31
32#warning TODO: make get scalar field able to get I/O scalars
33
34inline Value * KernelBuilder::getScalarFieldPtr(Value * const handle, const std::string & fieldName) {
35    ConstantInt * const index = getInt32(mKernel->getScalarIndex(fieldName));
36    return getScalarFieldPtr(handle, index);
37}
38
39Value * KernelBuilder::getScalarFieldPtr(Value * const index) {
40    return getScalarFieldPtr(mKernel->getHandle(), index);
41}
42
43Value * KernelBuilder::getScalarFieldPtr(const std::string & fieldName) {
44    return getScalarFieldPtr(mKernel->getHandle(), fieldName);
45}
46
47Value * KernelBuilder::getScalarField(const std::string & fieldName) {
48    Value * const ptr = getScalarFieldPtr(fieldName);
49    return CreateLoad(ptr, fieldName);
50}
51
52void KernelBuilder::setScalarField(const std::string & fieldName, Value * const value) {
53    Value * const ptr = getScalarFieldPtr(fieldName);
54    CreateStore(value, ptr);
55}
56
57Value * KernelBuilder::getCycleCountPtr() {
58    return getScalarFieldPtr(CYCLECOUNT_SCALAR);
59}
60
61/** ------------------------------------------------------------------------------------------------------------- *
62 * @brief getNamedItemCount
63 ** ------------------------------------------------------------------------------------------------------------- */
64Value * KernelBuilder::getNamedItemCount(const std::string & name, const std::string & suffix) {
65    const ProcessingRate & rate = mKernel->getStreamBinding(name).getRate();
66    Value * itemCount = nullptr;
67    if (LLVM_UNLIKELY(rate.isRelative())) {
68        Port port; unsigned index;
69        std::tie(port, index) = mKernel->getStreamPort(rate.getReference());
70        if (port == Port::Input) {
71            itemCount = getProcessedItemCount(rate.getReference());
72        } else {
73            itemCount = getProducedItemCount(rate.getReference());
74        }
75        itemCount = CreateMul2(itemCount, rate.getRate());
76    } else {
77        itemCount = getScalarField(name + suffix);
78    }
79    return itemCount;
80}
81
82/** ------------------------------------------------------------------------------------------------------------- *
83 * @brief setNamedItemCount
84 ** ------------------------------------------------------------------------------------------------------------- */
85void KernelBuilder::setNamedItemCount(const std::string & name, const std::string & suffix, Value * const value) {
86    const ProcessingRate & rate = mKernel->getStreamBinding(name).getRate();
87    if (LLVM_UNLIKELY(rate.isDerived())) {
88        report_fatal_error("cannot set item count: " + name + " is a derived rate stream");
89    }
90    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
91        CallPrintInt(mKernel->getName() + ": " + name + suffix, value);
92    }
93    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
94        Value * const current = getScalarField(name + suffix);
95        CreateAssert(CreateICmpUGE(value, current), name + suffix + " must be monotonically non-decreasing");
96    }
97    setScalarField(name + suffix, value);
98}
99
100/** ------------------------------------------------------------------------------------------------------------- *
101 * @brief getAvailableItemCount
102 ** ------------------------------------------------------------------------------------------------------------- */
103Value * KernelBuilder::getAvailableItemCount(const std::string & name) {
104    return mKernel->getAvailableInputItems(name);
105}
106
107/** ------------------------------------------------------------------------------------------------------------- *
108 * @brief getAccessibleItemCount
109 ** ------------------------------------------------------------------------------------------------------------- */
110Value * KernelBuilder::getAccessibleItemCount(const std::string & name) {
111    return mKernel->getAccessibleInputItems(name);
112}
113
114/** ------------------------------------------------------------------------------------------------------------- *
115 * @brief getTerminationSignal
116 ** ------------------------------------------------------------------------------------------------------------- */
117Value * KernelBuilder::getTerminationSignal() {
118    Value * const ptr = mKernel->getTerminationSignalPtr();
119    if (ptr) {
120        return CreateLoad(ptr);
121    } else {
122        return getFalse();
123    }
124}
125
126/** ------------------------------------------------------------------------------------------------------------- *
127 * @brief setTerminationSignal
128 ** ------------------------------------------------------------------------------------------------------------- */
129void KernelBuilder::setTerminationSignal(Value * const value) {
130    assert (value);
131    assert (value->getType() == getInt1Ty());
132    if (codegen::DebugOptionIsSet(codegen::TraceCounts)) {
133        CallPrintInt(mKernel->getName() + ": setTerminationSignal", value);
134    }
135    Value * const ptr = mKernel->getTerminationSignalPtr();
136    if (LLVM_UNLIKELY(ptr == nullptr)) {
137        llvm::report_fatal_error(mKernel->getName() + " does not have CanTerminateEarly or MustExplicitlyTerminate set.");
138    }
139    CreateStore(value, ptr);
140}
141
142Value * KernelBuilder::getInputStreamBlockPtr(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
143    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
144    Value * const processed = getProcessedItemCount(name);
145    Value * blockIndex = CreateLShr(processed, std::log2(getBitBlockWidth()));
146    if (blockOffset) {
147        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
148    }
149    return buf->getStreamBlockPtr(this, streamIndex, blockIndex);
150}
151
152Value * KernelBuilder::getInputStreamPackPtr(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
153    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
154    Value * const processed = getProcessedItemCount(name);
155    Value * blockIndex = CreateLShr(processed, std::log2(getBitBlockWidth()));
156    if (blockOffset) {
157        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
158    }
159    return buf->getStreamPackPtr(this, streamIndex, blockIndex, packIndex);
160}
161
162Value * KernelBuilder::loadInputStreamBlock(const std::string & name, Value * const streamIndex, Value * const blockOffset) {
163    return CreateBlockAlignedLoad(getInputStreamBlockPtr(name, streamIndex, blockOffset));
164}
165
166Value * KernelBuilder::loadInputStreamPack(const std::string & name, Value * const streamIndex, Value * const packIndex, Value * const blockOffset) {
167    return CreateBlockAlignedLoad(getInputStreamPackPtr(name, streamIndex, packIndex, blockOffset));
168}
169
170Value * KernelBuilder::getInputStreamSetCount(const std::string & name) {
171    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
172    return buf->getStreamSetCount(this);
173}
174
175Value * KernelBuilder::getOutputStreamBlockPtr(const std::string & name, Value * streamIndex, Value * const blockOffset) {
176    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
177    Value * const produced = getProducedItemCount(name);
178    Value * blockIndex = CreateLShr(produced, std::log2(getBitBlockWidth()));
179    if (blockOffset) {
180        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
181    }
182    return buf->getStreamBlockPtr(this, streamIndex, blockIndex);
183}
184
185Value * KernelBuilder::getOutputStreamPackPtr(const std::string & name, Value * streamIndex, Value * packIndex, Value * blockOffset) {
186    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
187    Value * const produced = getProducedItemCount(name);
188    Value * blockIndex = CreateLShr(produced, std::log2(getBitBlockWidth()));
189    if (blockOffset) {
190        blockIndex = CreateAdd(blockIndex, CreateZExtOrTrunc(blockOffset, blockIndex->getType()));
191    }
192    return buf->getStreamPackPtr(this, streamIndex, blockIndex, packIndex);
193}
194
195StoreInst * KernelBuilder::storeOutputStreamBlock(const std::string & name, Value * streamIndex, Value * blockOffset, Value * toStore) {
196    Value * const ptr = getOutputStreamBlockPtr(name, streamIndex, blockOffset);
197    Type * const storeTy = toStore->getType();
198    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
199    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
200        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
201            toStore = CreateBitCast(toStore, ptrElemTy);
202        } else {
203            std::string tmp;
204            raw_string_ostream out(tmp);
205            out << "invalid type conversion when calling storeOutputStreamBlock on " <<  name << ": ";
206            ptrElemTy->print(out);
207            out << " vs. ";
208            storeTy->print(out);
209        }
210    }
211    return CreateBlockAlignedStore(toStore, ptr);
212}
213
214StoreInst * KernelBuilder::storeOutputStreamPack(const std::string & name, Value * streamIndex, Value * packIndex, Value * blockOffset, Value * toStore) {
215    Value * const ptr = getOutputStreamPackPtr(name, streamIndex, packIndex, blockOffset);
216    Type * const storeTy = toStore->getType();
217    Type * const ptrElemTy = ptr->getType()->getPointerElementType();
218    if (LLVM_UNLIKELY(storeTy != ptrElemTy)) {
219        if (LLVM_LIKELY(storeTy->canLosslesslyBitCastTo(ptrElemTy))) {
220            toStore = CreateBitCast(toStore, ptrElemTy);
221        } else {
222            std::string tmp;
223            raw_string_ostream out(tmp);
224            out << "invalid type conversion when calling storeOutputStreamPack on " <<  name << ": ";
225            ptrElemTy->print(out);
226            out << " vs. ";
227            storeTy->print(out);
228        }
229    }
230    return CreateBlockAlignedStore(toStore, ptr);
231}
232
233Value * KernelBuilder::getOutputStreamSetCount(const std::string & name) {
234    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
235    return buf->getStreamSetCount(this);
236}
237
238Value * KernelBuilder::getRawInputPointer(const std::string & name, Value * absolutePosition) {
239    const StreamSetBuffer * const buf = mKernel->getInputStreamSetBuffer(name);
240    return buf->getRawItemPointer(this, absolutePosition);
241}
242
243Value * KernelBuilder::getRawOutputPointer(const std::string & name, Value * absolutePosition) {
244    const StreamSetBuffer * const buf = mKernel->getOutputStreamSetBuffer(name);
245    return buf->getRawItemPointer(this, absolutePosition);
246}
247
248Value * KernelBuilder::getBaseAddress(const std::string & name) {
249    return mKernel->getStreamSetBuffer(name)->getBaseAddress(this);
250}
251
252void KernelBuilder::setBaseAddress(const std::string & name, Value * const addr) {
253    return mKernel->getStreamSetBuffer(name)->setBaseAddress(this, addr);
254}
255
256Value * KernelBuilder::getCapacity(const std::string & name) {
257    return mKernel->getStreamSetBuffer(name)->getCapacity(this);
258}
259
260void KernelBuilder::setCapacity(const std::string & name, Value * capacity) {
261    mKernel->getStreamSetBuffer(name)->setCapacity(this, capacity);
262}
263
264/** ------------------------------------------------------------------------------------------------------------- *
265 * @brief CreateUDiv2
266 ** ------------------------------------------------------------------------------------------------------------- */
267Value * KernelBuilder::CreateUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
268    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
269        return number;
270    }
271    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
272    if (LLVM_LIKELY(divisor.denominator() == 1)) {
273        return CreateUDiv(number, n, Name);
274    } else {
275        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
276        return CreateUDiv(CreateMul(number, d), n);
277    }
278}
279
280/** ------------------------------------------------------------------------------------------------------------- *
281 * @brief CreateCeilUDiv2
282 ** ------------------------------------------------------------------------------------------------------------- */
283Value * KernelBuilder::CreateCeilUDiv2(Value * const number, const ProcessingRate::RateValue & divisor, const Twine & Name) {
284    if (divisor.numerator() == 1 && divisor.denominator() == 1) {
285        return number;
286    }
287    Constant * const n = ConstantInt::get(number->getType(), divisor.numerator());
288    if (LLVM_LIKELY(divisor.denominator() == 1)) {
289        return CreateCeilUDiv(number, n, Name);
290    } else {
291        //   âŒŠ(num + ratio - 1) / ratio⌋
292        // = ⌊(num - 1) / (n/d)⌋ + (ratio/ratio)
293        // = ⌊(d * (num - 1)) / n⌋ + 1
294        Constant * const ONE = ConstantInt::get(number->getType(), 1);
295        Constant * const d = ConstantInt::get(number->getType(), divisor.denominator());
296        return CreateAdd(CreateUDiv(CreateMul(CreateSub(number, ONE), d), n), ONE, Name);
297    }
298}
299
300/** ------------------------------------------------------------------------------------------------------------- *
301 * @brief CreateMul2
302 ** ------------------------------------------------------------------------------------------------------------- */
303Value * KernelBuilder::CreateMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
304    if (factor.numerator() == 1 && factor.denominator() == 1) {
305        return number;
306    }
307    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
308    if (LLVM_LIKELY(factor.denominator() == 1)) {
309        return CreateMul(number, n, Name);
310    } else {
311        Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
312        return CreateUDiv(CreateMul(number, n), d, Name);
313    }
314}
315
316/** ------------------------------------------------------------------------------------------------------------- *
317 * @brief CreateMulCeil2
318 ** ------------------------------------------------------------------------------------------------------------- */
319Value * KernelBuilder::CreateCeilUMul2(Value * const number, const ProcessingRate::RateValue & factor, const Twine & Name) {
320    if (factor.denominator() == 1) {
321        return CreateMul2(number, factor, Name);
322    }
323    Constant * const n = ConstantInt::get(number->getType(), factor.numerator());
324    Constant * const d = ConstantInt::get(number->getType(), factor.denominator());
325    return CreateCeilUDiv(CreateMul(number, n), d, Name);
326}
327
328/** ------------------------------------------------------------------------------------------------------------- *
329 * @brief resolveStreamSetType
330 ** ------------------------------------------------------------------------------------------------------------- */
331Type * KernelBuilder::resolveStreamSetType(Type * const streamSetType) {
332    // TODO: Should this function be here? in StreamSetBuffer? or in Binding?
333    unsigned numElements = 1;
334    Type * type = streamSetType;
335    if (LLVM_LIKELY(type->isArrayTy())) {
336        numElements = type->getArrayNumElements();
337        type = type->getArrayElementType();
338    }
339    if (LLVM_LIKELY(type->isVectorTy() && type->getVectorNumElements() == 0)) {
340        type = type->getVectorElementType();
341        if (LLVM_LIKELY(type->isIntegerTy())) {
342            const auto fieldWidth = cast<IntegerType>(type)->getBitWidth();
343            type = getBitBlockType();
344            if (fieldWidth != 1) {
345                type = ArrayType::get(type, fieldWidth);
346            }
347            return ArrayType::get(type, numElements);
348        }
349    }
350    std::string tmp;
351    raw_string_ostream out(tmp);
352    streamSetType->print(out);
353    out << " is an unvalid stream set buffer type.";
354    report_fatal_error(out.str());
355}
356
357/** ------------------------------------------------------------------------------------------------------------- *
358 * @brief getKernelName
359 ** ------------------------------------------------------------------------------------------------------------- */
360std::string KernelBuilder::getKernelName() const {
361    return mKernel->getName();
362}
363
364}
Note: See TracBrowser for help on using the repository browser.