source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 4988

Last change on this file since 4988 was 4988, checked in by cameron, 4 years ago

casefold sample application/pipeline

File size: 20.6 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <pablo/function.h>
8#include <IDISA/idisa_builder.h>
9#include <kernels/instance.h>
10
11using namespace llvm;
12using namespace pablo;
13
// True iff x is a nonzero power of two, i.e. exactly one bit of x is set.
inline bool isPowerOfTwo(const unsigned x) {
    const bool singleBit = !(x & (x - 1));
    return x != 0 && singleBit;
}
17
18namespace kernel {
19
// Field indices of the top-level kernel state struct assembled in prepareFunction():
//   { internal state, input stream set, output stream buffer, output scalars }
enum : unsigned {
    INTERNAL_STATE,
    INPUT_STREAM_SET,
    OUTPUT_STREAM_SET,
    OUTPUT_SCALAR_SET
};
26
27// sets name & sets internal state to the kernel superclass state
28KernelBuilder::KernelBuilder(std::string name, Module * m, IDISA::IDISA_Builder * b, const unsigned bufferSize)
29: mMod(m)
30, iBuilder(b)
31, mKernelName(name)
32, mBitBlockType(b->getBitBlockType())
33, mBufferSize(bufferSize)
34, mBlockNoIndex(0) {
35    assert (mBufferSize > 0);
36    mBlockNoIndex = addInternalState(b->getInt64Ty(), "BlockNo");
37}
38
// Out-of-line definition of the static slab allocator member Instance::mAllocator.
// NOTE(review): this definition would more naturally live in an "instance.cpp".
SlabAllocator<Instance> Instance::mAllocator;
40
41/** ------------------------------------------------------------------------------------------------------------- *
42 * @brief addInternalState
43 ** ------------------------------------------------------------------------------------------------------------- */
44unsigned KernelBuilder::addInternalState(Type * const type) {
45    assert (type);
46    const unsigned index = mInternalState.size();
47    mInternalState.push_back(type);
48    return index;
49}
50
51unsigned KernelBuilder::addInternalState(llvm::Type * const type, std::string && name) {
52    if (LLVM_UNLIKELY(mInternalStateNameMap.count(name) != 0)) {
53        throw std::runtime_error("Kernel already contains internal state " + name);
54    }
55    const unsigned index = addInternalState(type);
56    mInternalStateNameMap.emplace(name, index);
57    return index;
58}
59
60/** ------------------------------------------------------------------------------------------------------------- *
61 * @brief getInternalState
62 ** ------------------------------------------------------------------------------------------------------------- */
63Value * KernelBuilder::getInternalState(Value * const instance, const unsigned index) {
64    Value* indices[] = {iBuilder->getInt64(0),
65                        iBuilder->getInt32(INTERNAL_STATE),
66                        iBuilder->getInt32(index)};
67    return iBuilder->CreateGEP(instance, indices);
68}
69
70Value * KernelBuilder::getInternalState(Value * const instance, const std::string & name) {
71    const auto f = mInternalStateNameMap.find(name);
72    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
73        throw std::runtime_error("Kernel does not contain internal state " + name);
74    }
75    return getInternalState(instance, f->second);
76}
77
78/** ------------------------------------------------------------------------------------------------------------- *
79 * @brief setInternalState
80 ** ------------------------------------------------------------------------------------------------------------- */
81void KernelBuilder::setInternalState(Value * const instance, const std::string & name, Value * const value) {
82    Value * ptr = getInternalState(instance, name);
83    assert (ptr->getType()->getPointerElementType() == value->getType());
84    if (value->getType() == iBuilder->getBitBlockType()) {
85        iBuilder->CreateBlockAlignedStore(value, ptr);
86    } else {
87        iBuilder->CreateStore(value, ptr);
88    }
89}
90
91void KernelBuilder::setInternalState(Value * const instance, const unsigned index, Value * const value) {
92    Value * ptr = getInternalState(instance, index);
93    assert (ptr->getType()->getPointerElementType() == value->getType());
94    if (value->getType() == iBuilder->getBitBlockType()) {
95        iBuilder->CreateBlockAlignedStore(value, ptr);
96    } else {
97        iBuilder->CreateStore(value, ptr);
98    }
99}
100
101/** ------------------------------------------------------------------------------------------------------------- *
102 * @brief addInputStream
103 ** ------------------------------------------------------------------------------------------------------------- */
104void KernelBuilder::addInputStream(const unsigned fields, std::string && name) {
105    assert (fields > 0 && !name.empty());
106    mInputStreamName.push_back(name);
107    if (fields == 1) {
108        mInputStream.push_back(mBitBlockType);
109    } else {
110        mInputStream.push_back(ArrayType::get(mBitBlockType, fields));
111    }
112}
113
114void KernelBuilder::addInputStream(const unsigned fields) {
115    addInputStream(fields, std::move(mKernelName + "_InputStream_" + std::to_string(mInputStream.size())));
116}
117
118/** ------------------------------------------------------------------------------------------------------------- *
119 * @brief getInputStream
120 ** ------------------------------------------------------------------------------------------------------------- */
121Value * KernelBuilder::getInputStream(Value * const instance, const unsigned index, const unsigned streamOffset) {
122    assert (instance);
123    Value * inputStream = iBuilder->CreateLoad(iBuilder->CreateGEP(instance,
124        {iBuilder->getInt32(0), iBuilder->getInt32(INPUT_STREAM_SET), iBuilder->getInt32(0)}));
125    Value * modFunction = iBuilder->CreateLoad(iBuilder->CreateGEP(instance,
126        {iBuilder->getInt32(0), iBuilder->getInt32(INPUT_STREAM_SET), iBuilder->getInt32(1)}));
127    Value * offset = iBuilder->CreateLoad(getBlockNo(instance));
128    if (streamOffset) {
129        offset = iBuilder->CreateAdd(offset, ConstantInt::get(offset->getType(), streamOffset));
130    }   
131    offset = iBuilder->CreateCall(modFunction, offset, "offset");
132    return iBuilder->CreateGEP(inputStream, { offset, iBuilder->getInt32(index) });
133}
134
135/** ------------------------------------------------------------------------------------------------------------- *
136 * @brief addInputScalar
137 ** ------------------------------------------------------------------------------------------------------------- */
138void KernelBuilder::addInputScalar(Type * const type, std::string && name) {
139    assert (type && !name.empty());
140    mInputScalarName.push_back(name);
141    mInputScalar.push_back(type);
142}
143
144void KernelBuilder::addInputScalar(Type * const type) {
145    addInputScalar(type, std::move(mKernelName + "_InputScalar_" + std::to_string(mInputScalar.size())));
146}
147
148/** ------------------------------------------------------------------------------------------------------------- *
149 * @brief getInputScalar
150 ** ------------------------------------------------------------------------------------------------------------- */
151Value * KernelBuilder::getInputScalar(Value * const instance, const unsigned) {
152    throw std::runtime_error("currently not supported!");
153}
154
155/** ------------------------------------------------------------------------------------------------------------- *
156 * @brief addOutputStream
157 ** ------------------------------------------------------------------------------------------------------------- */
158unsigned KernelBuilder::addOutputStream(const unsigned fields) {
159    assert (fields > 0);
160    const unsigned index = mOutputStream.size();
161    mOutputStream.push_back((fields == 1) ? mBitBlockType : ArrayType::get(mBitBlockType, fields));
162    return index;
163}
164
165/** ------------------------------------------------------------------------------------------------------------- *
166 * @brief addOutputScalar
167 ** ------------------------------------------------------------------------------------------------------------- */
168unsigned KernelBuilder::addOutputScalar(Type * const type) {
169    assert (type);
170    const unsigned index = mOutputScalar.size();
171    mOutputScalar.push_back(type);
172    return index;
173}
174
175/** ------------------------------------------------------------------------------------------------------------- *
176 * @brief getOutputStream
177 ** ------------------------------------------------------------------------------------------------------------- */
178Value * KernelBuilder::getOutputStream(Value * const instance, const unsigned index, const unsigned streamOffset) {
179    assert (instance);
180    Value * offset = getOffset(instance, streamOffset);
181    Value * const indices[] = {iBuilder->getInt32(0), iBuilder->getInt32(OUTPUT_STREAM_SET), offset, iBuilder->getInt32(index)};
182    return iBuilder->CreateGEP(instance, indices);
183}
184
185/** ------------------------------------------------------------------------------------------------------------- *
186 * @brief getOutputScalar
187 ** ------------------------------------------------------------------------------------------------------------- */
188Value * KernelBuilder::getOutputScalar(Value * const instance, const unsigned) {
189    throw std::runtime_error("currently not supported!");
190}
191
192/** ------------------------------------------------------------------------------------------------------------- *
193 * @brief prepareFunction
194 ** ------------------------------------------------------------------------------------------------------------- */
195Function * KernelBuilder::prepareFunction() {
196
197    PointerType * modFunctionType = PointerType::get(FunctionType::get(iBuilder->getInt64Ty(), {iBuilder->getInt64Ty()}, false), 0);
198    mInputStreamType = PointerType::get(StructType::get(mMod->getContext(), mInputStream), 0);
199    mInputScalarType = PointerType::get(StructType::get(mMod->getContext(), mInputScalar), 0);
200    mOutputStreamType = StructType::get(mMod->getContext(), mOutputStream);
201    Type * outputScalarType = StructType::get(mMod->getContext(), mOutputScalar);
202    Type * internalStateType = StructType::create(mMod->getContext(), mInternalState);
203    Type * inputStateType = StructType::create(mMod->getContext(), { mInputStreamType, modFunctionType});
204
205    Type * outputBufferType = ArrayType::get(mOutputStreamType, mBufferSize);
206    mKernelStateType = StructType::create(mMod->getContext(), {internalStateType, inputStateType, outputBufferType, outputScalarType}, mKernelName);
207
208    FunctionType * const functionType = FunctionType::get(iBuilder->getVoidTy(), {PointerType::get(mKernelStateType, 0)}, false);
209    mDoBlock = Function::Create(functionType, GlobalValue::ExternalLinkage, mKernelName + "_DoBlock", mMod);
210    mDoBlock->setCallingConv(CallingConv::C);
211  //  mDoBlock->addAttribute(1, Attribute::NoCapture);
212 //   mDoBlock->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
213 //   mDoBlock->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind);
214
215    Function::arg_iterator args = mDoBlock->arg_begin();
216    mKernelParam = args++;
217    mKernelParam->setName("this");
218
219    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", mDoBlock, 0));
220
221//    mLocalBlockNo = iBuilder->CreateLoad(getBlockNo());
222//    Value * blockNo = iBuilder->CreateLoad(getBlockNo());
223//    iBuilder->CallPrintInt(mKernelName + "_BlockNo", blockNo);
224//    Value * modFunction = iBuilder->CreateLoad(iBuilder->CreateGEP(mKernelParam, {iBuilder->getInt32(0), iBuilder->getInt32(INPUT_STREAM_SET), iBuilder->getInt32(1)}));
225//    blockNo = iBuilder->CreateCall(modFunction, blockNo);
226//    iBuilder->CallPrintInt(mKernelName + "_Offset", blockNo);
227
228
229    return mDoBlock;
230}
231
232/** ------------------------------------------------------------------------------------------------------------- *
233 * @brief finalize
234 ** ------------------------------------------------------------------------------------------------------------- */
235void KernelBuilder::finalize() {
236
237    // Finish the actual function
238    Value * blockNo = getBlockNo();
239    Value * value = iBuilder->CreateLoad(blockNo);
240    value = iBuilder->CreateAdd(value, ConstantInt::get(value->getType(), 1));
241    iBuilder->CreateStore(value, blockNo);
242    iBuilder->CreateRetVoid();
243
244    // Generate the zero initializer
245    PointerType * modFunctionType = PointerType::get(FunctionType::get(iBuilder->getInt64Ty(), {iBuilder->getInt64Ty()}, false), 0);
246    FunctionType * constructorType = FunctionType::get(iBuilder->getVoidTy(), {PointerType::get(mKernelStateType, 0), mInputStreamType, modFunctionType}, false);
247
248    mConstructor = Function::Create(constructorType, GlobalValue::ExternalLinkage, mKernelName + "_Constructor", mMod);
249    mConstructor->setCallingConv(CallingConv::C);
250    mConstructor->addAttribute(1, Attribute::NoCapture);
251    //mConstructor->addAttribute(AttributeSet::FunctionIndex, Attribute::InlineHint);
252   // mConstructor->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
253    //mConstructor->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind);
254    auto args = mConstructor->arg_begin();
255    mKernelParam = args++;
256    mKernelParam->setName("this");
257    Value * const inputStream = args++;
258    inputStream->setName("inputStream");
259    Value * const modFunction = args++;
260    modFunction->setName("modFunction");
261
262    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", mConstructor, 0));
263    for (unsigned i = 0; i < mInternalState.size(); ++i) {
264        Type * const type = mInternalState[i];
265        if (type->isIntegerTy() || type->isArrayTy() || type->isVectorTy()) {
266            setInternalState(i, Constant::getNullValue(type));
267        } else {
268            Value * const ptr = getInternalState(i);
269            Value * const size = iBuilder->CreatePtrDiff(iBuilder->CreateGEP(ptr, iBuilder->getInt32(1)), ptr);
270            iBuilder->CreateMemSet(ptr, iBuilder->getInt8(0), size, 4);
271        }
272    }
273
274    Value * const input = iBuilder->CreateGEP(mKernelParam, {iBuilder->getInt32(0), iBuilder->getInt32(INPUT_STREAM_SET)});
275    iBuilder->CreateStore(inputStream, iBuilder->CreateGEP(input, {iBuilder->getInt32(0), iBuilder->getInt32(0)}));
276    iBuilder->CreateStore(modFunction, iBuilder->CreateGEP(input, {iBuilder->getInt32(0), iBuilder->getInt32(1)}));
277    iBuilder->CreateRetVoid();
278
279//    if (mOutputStreamType->getStructNumElements()) {
280//        PointerType * outputStreamType = PointerType::get(mOutputStreamType, 0);
281//        FunctionType * type = FunctionType::get(outputStreamType, {outputStreamType, PointerType::get(blockNo->getType(), 0)}, false);
282//        mStreamSetFunction = Function::Create(type, Function::ExternalLinkage, mKernelName + "_StreamSet", mMod);
283//        auto arg = mStreamSetFunction->arg_begin();
284//        Value * stream = arg++;
285//        stream->setName("stream");
286//        mStreamSetFunction->addAttribute(1, Attribute::NoCapture);
287//        mStreamSetFunction->addAttribute(2, Attribute::NoCapture);
288//        mStreamSetFunction->addAttribute(AttributeSet::FunctionIndex, Attribute::InlineHint);
289//        mStreamSetFunction->addAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
290//        mStreamSetFunction->addAttribute(AttributeSet::FunctionIndex, Attribute::NoUnwind);
291//        Value * offset = arg;
292//        BasicBlock * entry = BasicBlock::Create(mMod->getContext(), "entry", mStreamSetFunction);
293//        iBuilder->SetInsertPoint(entry);
294//        if (mBufferSize != 1) {
295//            offset = iBuilder->CreateLoad(offset);
296//            if (isPowerOfTwo(mBufferSize)) {
297//                offset = iBuilder->CreateAnd(offset, iBuilder->getInt64(mBufferSize - 1));
298//            } else if (mBufferSize > 2) {
299//                offset = iBuilder->CreateURem(offset, iBuilder->getInt64(mBufferSize));
300//            }
301//            stream = iBuilder->CreateGEP(stream, offset);
302//        }
303//        iBuilder->CreateRet(stream);
304//    }
305
306    iBuilder->ClearInsertionPoint();
307}
308
309/** ------------------------------------------------------------------------------------------------------------- *
310 * @brief instantiate
311 *
312 * Generate a new instance of this kernel and call the default constructor to initialize it
313 ** ------------------------------------------------------------------------------------------------------------- */
314Instance * KernelBuilder::instantiate(std::pair<Value *, unsigned> && inputStream) {
315    AllocaInst * const memory = iBuilder->CreateAlloca(mKernelStateType);
316    Value * const indices[] = {iBuilder->getInt32(0), iBuilder->getInt32(OUTPUT_STREAM_SET)};
317    Value * ptr = iBuilder->CreateGEP(std::get<0>(inputStream), indices);
318    iBuilder->CreateCall3(mConstructor, memory, iBuilder->CreatePointerCast(ptr, mInputStreamType), CreateModFunction(std::get<1>(inputStream)));
319    return new Instance(this, memory);
320}
321
322/** ------------------------------------------------------------------------------------------------------------- *
323 * @brief instantiate
324 *
325 * Generate a new instance of this kernel and call the default constructor to initialize it
326 ** ------------------------------------------------------------------------------------------------------------- */
327Instance * KernelBuilder::instantiate(llvm::Value * const inputStream) {
328    AllocaInst * const memory = iBuilder->CreateAlloca(mKernelStateType);
329    Value * ptr = inputStream;
330    iBuilder->CreateCall3(mConstructor, memory, iBuilder->CreatePointerCast(ptr, mInputStreamType), CreateModFunction(0));
331    return new Instance(this, memory);
332}
333
334/** ------------------------------------------------------------------------------------------------------------- *
335 * @brief CreateDoBlockCall
336 ** ------------------------------------------------------------------------------------------------------------- */
337void KernelBuilder::CreateDoBlockCall(Value * const instance) {
338    assert (mDoBlock && instance);
339    iBuilder->CreateCall(mDoBlock, instance);
340}
341
342/** ------------------------------------------------------------------------------------------------------------- *
343 * @brief clearOutputStream
344 ** ------------------------------------------------------------------------------------------------------------- */
345void KernelBuilder::clearOutputStream(Value * const instance, const unsigned streamOffset) {
346    Value * const indices[] = {iBuilder->getInt32(0), iBuilder->getInt32(OUTPUT_STREAM_SET), getOffset(instance, streamOffset)};
347    Value * ptr = iBuilder->CreateGEP(instance, indices, "ptr");
348    unsigned size = 0;
349    for (unsigned i = 0; i < mOutputStreamType->getStructNumElements(); ++i) {
350        size += mOutputStreamType->getStructElementType(i)->getPrimitiveSizeInBits();
351    }
352    iBuilder->CreateMemSet(ptr, iBuilder->getInt8(0), size / 8, 4);
353}
354
355/** ------------------------------------------------------------------------------------------------------------- *
356 * @brief offset
357 *
358 * Compute the stream index of the given offset value.
359 ** ------------------------------------------------------------------------------------------------------------- */
360Value * KernelBuilder::getOffset(Value * const instance, const unsigned value) {
361    Value * offset = nullptr;
362    if (mBufferSize > 1) {
363        offset = iBuilder->CreateLoad(getBlockNo(instance));
364        if (value) {
365            offset = iBuilder->CreateAdd(offset, iBuilder->getInt64(value));
366        }
367        if (isPowerOfTwo(mBufferSize)) {
368            offset = iBuilder->CreateAnd(offset, iBuilder->getInt64(mBufferSize - 1));
369        } else {
370            offset = iBuilder->CreateURem(offset, iBuilder->getInt64(mBufferSize));
371        }
372    } else {
373        offset = iBuilder->getInt64(value);
374    }
375    return offset;
376}
377
378/** ------------------------------------------------------------------------------------------------------------- *
379 * @brief CreateModFunction
380 *
381 * Generate a "modulo" function that dictates the local offset of a given blockNo
382 ** ------------------------------------------------------------------------------------------------------------- */
383inline Function * KernelBuilder::CreateModFunction(const unsigned size) {
384    const std::string name((size == 0) ? "continuous" : "finite" + std::to_string(size));
385    Function * function = mMod->getFunction(name);
386    if (function) {
387        return function;
388    }
389    const auto ip = iBuilder->saveIP();
390    FunctionType * type = FunctionType::get(iBuilder->getInt64Ty(), {iBuilder->getInt64Ty()}, false);
391    function = Function::Create(type, Function::ExternalLinkage, name, mMod);
392    Value * offset = function->arg_begin();
393    offset->setName("index");
394    BasicBlock * entry = BasicBlock::Create(mMod->getContext(), "entry", function);
395    iBuilder->SetInsertPoint(entry);
396    if (size) {
397        if (size == 1) {
398            offset = iBuilder->getInt64(0);
399        } else if (isPowerOfTwo(size)) {
400            offset = iBuilder->CreateAnd(offset, iBuilder->getInt64(size - 1));
401        } else {
402            offset = iBuilder->CreateURem(offset, iBuilder->getInt64(size));
403        }
404    }
405    iBuilder->CreateRet(offset);
406    iBuilder->restoreIP(ip);
407    return function;
408}
409
410/** ------------------------------------------------------------------------------------------------------------- *
411 * @brief setLongestLookaheadAmount
412 ** ------------------------------------------------------------------------------------------------------------- */
413void KernelBuilder::setLongestLookaheadAmount(const unsigned bits) {
414    const unsigned blockWidth = iBuilder->getBitBlockWidth();
415    const unsigned lookaheadBlocks = (bits + blockWidth - 1) / blockWidth;
416    mBufferSize = (lookaheadBlocks + 1);
417}
418
419} // end of namespace kernel
Note: See TracBrowser for help on using the repository browser.