source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5008

Last change on this file since 5008 was 5008, checked in by nmedfort, 3 years ago

Potential fix for Mac compilers.

File size: 15.3 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
7#include <pablo/function.h>
8#include <IDISA/idisa_builder.h>
9#include <kernels/instance.h>
10#include <tuple>
11#include <boost/functional/hash_fwd.hpp>
12#include <unordered_map>
13
14using namespace llvm;
15using namespace pablo;
16
17namespace kernel {
18
19// sets name & sets internal state to the kernel superclass state
20KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder, std::string && name, const unsigned defaultBufferSize)
21: iBuilder(builder)
22, mKernelName(name)
23, mDefaultBufferSize(defaultBufferSize)
24, mBitBlockType(builder->getBitBlockType())
25, mBlockNoIndex(0) {
26    assert (mDefaultBufferSize > 0);
27}
28
29/** ------------------------------------------------------------------------------------------------------------- *
30 * @brief addInternalState
31 ** ------------------------------------------------------------------------------------------------------------- */
32unsigned KernelBuilder::addInternalState(Type * const type) {
33    assert (type);
34    const unsigned index = mInternalState.size();
35    mInternalState.push_back(type);
36    return index;
37}
38
39unsigned KernelBuilder::addInternalState(llvm::Type * const type, std::string && name) {
40    if (LLVM_UNLIKELY(mInternalStateNameMap.count(name) != 0)) {
41        throw std::runtime_error("Kernel already contains internal state '" + name + "'");
42    }
43    const unsigned index = addInternalState(type);
44    mInternalStateNameMap.emplace(name, iBuilder->getInt32(index));
45    return index;
46}
47
48/** ------------------------------------------------------------------------------------------------------------- *
49 * @brief getInternalState
50 ** ------------------------------------------------------------------------------------------------------------- */
51Value * KernelBuilder::getInternalStateInternal(Value * const kernelState, const std::string & name) {
52    const auto f = mInternalStateNameMap.find(name);
53    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
54        throw std::runtime_error("Kernel does not contain internal state " + name);
55    }
56    return getInternalStateInternal(kernelState, f->second);
57}
58
59Value * KernelBuilder::getInternalStateInternal(Value * const kernelState, disable_implicit_conversion<Value *> index) {
60    assert (index->getType()->isIntegerTy());
61    assert (kernelState->getType()->getPointerElementType() == mKernelStateType);
62    return iBuilder->CreateGEP(kernelState, {iBuilder->getInt32(0), index});
63}
64
65/** ------------------------------------------------------------------------------------------------------------- *
66 * @brief setInternalState
67 ** ------------------------------------------------------------------------------------------------------------- */
68void KernelBuilder::setInternalStateInternal(Value * const kernelState, const std::string & name, Value * const value) {
69    Value * ptr = getInternalStateInternal(kernelState, name);
70    assert (ptr->getType()->getPointerElementType() == value->getType());
71    if (value->getType() == iBuilder->getBitBlockType()) {
72        iBuilder->CreateBlockAlignedStore(value, ptr);
73    } else {
74        iBuilder->CreateStore(value, ptr);
75    }
76}
77
78void KernelBuilder::setInternalStateInternal(Value * const kernelState, disable_implicit_conversion<Value *> index, Value * const value) {
79    Value * ptr = getInternalStateInternal(kernelState, index);
80    assert (ptr->getType()->getPointerElementType() == value->getType());
81    if (value->getType() == iBuilder->getBitBlockType()) {
82        iBuilder->CreateBlockAlignedStore(value, ptr);
83    } else {
84        iBuilder->CreateStore(value, ptr);
85    }
86}
87
88/** ------------------------------------------------------------------------------------------------------------- *
89 * @brief addInputStream
90 ** ------------------------------------------------------------------------------------------------------------- */
91void KernelBuilder::addInputStream(const unsigned fields, std::string && name) {
92    assert (fields > 0 && !name.empty());
93    mInputStreamName.push_back(name);
94    if (fields == 1) {
95        mInputStream.push_back(mBitBlockType);
96    } else {
97        mInputStream.push_back(ArrayType::get(mBitBlockType, fields));
98    }
99}
100
101void KernelBuilder::addInputStream(const unsigned fields) {
102    addInputStream(fields, std::move(mKernelName + "_InputStream_" + std::to_string(mInputStream.size())));
103}
104
105/** ------------------------------------------------------------------------------------------------------------- *
106 * @brief getInputStream
107 ** ------------------------------------------------------------------------------------------------------------- */
108Value * KernelBuilder::getInputStreamInternal(Value * const inputStreamSet, disable_implicit_conversion<Value *> index) {
109    assert ("Parameters cannot be null!" && (inputStreamSet != nullptr && index != nullptr));
110    assert ("Stream index must be an integer!" && index->getType()->isIntegerTy());
111    assert ("Illegal input stream set provided!" && inputStreamSet->getType()->getPointerElementType() == mInputStreamType);
112    if (LLVM_LIKELY(isa<ConstantInt>(index.get()) || getInputStreamType()->isArrayTy())) {
113        return iBuilder->CreateGEP(inputStreamSet, { iBuilder->getInt32(0), index });
114    }
115    throw std::runtime_error("Cannot access the input stream with a non-constant value unless all input stream types are identical!");
116}
117
118/** ------------------------------------------------------------------------------------------------------------- *
119 * @brief addInputScalar
120 ** ------------------------------------------------------------------------------------------------------------- */
121void KernelBuilder::addInputScalar(Type * const type, std::string && name) {
122    assert (type && !name.empty());
123    mInputScalarName.push_back(name);
124    mInputScalar.push_back(type);
125}
126
127void KernelBuilder::addInputScalar(Type * const type) {
128    addInputScalar(type, std::move(mKernelName + "_InputScalar_" + std::to_string(mInputScalar.size())));
129}
130
131/** ------------------------------------------------------------------------------------------------------------- *
132 * @brief getInputScalar
133 ** ------------------------------------------------------------------------------------------------------------- */
134Value * KernelBuilder::getInputScalarInternal(Value * const inputScalarSet, disable_implicit_conversion<Value *>) {
135    assert (inputScalarSet);
136    throw std::runtime_error("currently not supported!");
137}
138
139/** ------------------------------------------------------------------------------------------------------------- *
140 * @brief addOutputStream
141 ** ------------------------------------------------------------------------------------------------------------- */
142unsigned KernelBuilder::addOutputStream(const unsigned fields) {
143    assert (fields > 0);
144    const unsigned index = mOutputStream.size();
145    mOutputStream.push_back((fields == 1) ? mBitBlockType : ArrayType::get(mBitBlockType, fields));
146    return index;
147}
148
149/** ------------------------------------------------------------------------------------------------------------- *
150 * @brief addOutputScalar
151 ** ------------------------------------------------------------------------------------------------------------- */
152unsigned KernelBuilder::addOutputScalar(Type * const type) {
153    assert (type);
154    const unsigned index = mOutputScalar.size();
155    mOutputScalar.push_back(type);
156    return index;
157}
158
159/** ------------------------------------------------------------------------------------------------------------- *
160 * @brief getOutputStream
161 ** ------------------------------------------------------------------------------------------------------------- */
162Value * KernelBuilder::getOutputStreamInternal(Value * const outputStreamSet, disable_implicit_conversion<Value *> index) {
163    assert ("Parameters cannot be null!" && (outputStreamSet != nullptr && index != nullptr));
164    assert ("Stream index must be an integer!" && index->getType()->isIntegerTy());
165    assert ("Illegal output stream set provided!" && outputStreamSet->getType()->getPointerElementType() == getOutputStreamType());
166    if (LLVM_LIKELY(isa<ConstantInt>(index.get()) || getOutputStreamType()->isArrayTy())) {
167        return iBuilder->CreateGEP(outputStreamSet, { iBuilder->getInt32(0), index });
168    }
169    throw std::runtime_error("Cannot access the output stream with a non-constant value unless all output stream types are identical!");
170}
171
172/** ------------------------------------------------------------------------------------------------------------- *
173 * @brief getOutputScalar
174 ** ------------------------------------------------------------------------------------------------------------- */
175Value * KernelBuilder::getOutputScalarInternal(Value * const, disable_implicit_conversion<Value *> ) {
176    throw std::runtime_error("currently not supported!");
177}
178
179/** ------------------------------------------------------------------------------------------------------------- *
180 * @brief packDataTypes
181 ** ------------------------------------------------------------------------------------------------------------- */
182Type * KernelBuilder::packDataTypes(const std::vector<llvm::Type *> & types) {
183    if (types.empty()) {
184        return nullptr;
185    }
186    for (Type * type : types) {
187        if (type != types.front()) { // use canLosslesslyBitcastInto ?
188            return StructType::get(iBuilder->getContext(), types);
189        }
190    }
191    return ArrayType::get(types.front(), types.size());
192}
193
194/** ------------------------------------------------------------------------------------------------------------- *
195 * @brief prepareFunction
196 ** ------------------------------------------------------------------------------------------------------------- */
197Function * KernelBuilder::prepareFunction(std::vector<unsigned> && inputStreamOffsets) {
198
199    mBlockNoIndex = iBuilder->getInt32(addInternalState(iBuilder->getInt64Ty(), "BlockNo"));
200
201    mKernelStateType = StructType::create(iBuilder->getContext(), mInternalState, mKernelName);
202    mInputScalarType = packDataTypes(mInputScalar);
203    mInputStreamType = packDataTypes(mInputStream);
204    mOutputScalarType = packDataTypes(mInputScalar);
205    mOutputStreamType = packDataTypes(mOutputStream);
206    mInputStreamOffsets = inputStreamOffsets;
207
208    std::vector<Type *> params;
209    params.push_back(mKernelStateType->getPointerTo());
210    if (mInputScalarType) {
211        params.push_back(mInputScalarType->getPointerTo());
212    }
213    if (mInputStreamType) {
214        for (unsigned i = 0; i < mInputStreamOffsets.size(); ++i) {
215            params.push_back(mInputStreamType->getPointerTo());
216        }
217    }
218    if (mOutputScalarType) {
219        params.push_back(mOutputScalarType->getPointerTo());
220    }
221    if (mOutputStreamType) {
222        params.push_back(mOutputStreamType->getPointerTo());
223    }
224
225    // A pointer value is captured if the function makes a copy of any part of the pointer that outlives
226    // the call (e.g., stored in a global or, depending on the context, when returned by the function.)
227    // Since this does not occur in either our DoBlock or Constructor, all parameters are marked nocapture.
228
229    FunctionType * const functionType = FunctionType::get(iBuilder->getVoidTy(), params, false);
230    mDoBlock = Function::Create(functionType, GlobalValue::ExternalLinkage, mKernelName + "_DoBlock", iBuilder->getModule());
231    mDoBlock->setCallingConv(CallingConv::C);
232    for (unsigned i = 1; i <= params.size(); ++i) {
233        mDoBlock->setDoesNotCapture(i);
234    }
235    mDoBlock->setDoesNotThrow();
236    Function::arg_iterator args = mDoBlock->arg_begin();
237    mKernelStateParam = args++;
238    mKernelStateParam->setName("this");
239    if (mInputScalarType) {
240        mInputScalarParam = args++;
241        mInputScalarParam->setName("inputScalars");
242    }
243    if (mInputStreamType) {
244        for (const unsigned offset : mInputStreamOffsets) {
245            Value * const inputStreamSet = args++;
246            inputStreamSet->setName("inputStreamSet" + std::to_string(offset));
247            mInputStreamParam.emplace(offset, inputStreamSet);
248        }
249    }
250    if (mOutputScalarType) {
251        mOutputScalarParam = args++;
252        mOutputScalarParam->setName("outputScalars");
253    }
254    if (mOutputStreamType) {
255        mOutputStreamParam = args;
256        mOutputStreamParam->setName("outputStreamSet");
257    }
258    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", mDoBlock, 0));
259    return mDoBlock;
260}
261
262/** ------------------------------------------------------------------------------------------------------------- *
263 * @brief finalize
264 ** ------------------------------------------------------------------------------------------------------------- */
265void KernelBuilder::finalize() {
266    // Finish the actual function
267    Value * blockNo = getBlockNo();
268    Value * value = iBuilder->CreateLoad(blockNo);
269    value = iBuilder->CreateAdd(value, ConstantInt::get(value->getType(), 1));
270    iBuilder->CreateStore(value, blockNo);
271    iBuilder->CreateRetVoid();
272
273    mKernelStateParam = nullptr;
274    mInputScalarParam = nullptr;
275    mInputStreamParam.clear();
276    mOutputScalarParam = nullptr;
277    mOutputStreamParam = nullptr;
278    iBuilder->ClearInsertionPoint();
279}
280
281/** ------------------------------------------------------------------------------------------------------------- *
282 * @brief instantiate
283 *
 * Allocate this kernel's state (zero-initialized) and, when declared, its output scalars and output stream buffers (not zero-initialized)
285 ** ------------------------------------------------------------------------------------------------------------- */
286Instance * KernelBuilder::instantiate(std::pair<Value *, unsigned> && inputStreamSet, const unsigned outputBufferSize) {
287    AllocaInst * const kernelState = iBuilder->CreateAlloca(mKernelStateType);
288    iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), kernelState);
289    AllocaInst * outputScalars = nullptr;
290    if (mOutputScalarType) {
291        outputScalars = iBuilder->CreateAlloca(mOutputScalarType);
292    }
293    AllocaInst * outputStreamSets = nullptr;
294    if (mOutputStreamType) {
295        outputStreamSets = iBuilder->CreateAlloca(mOutputStreamType, iBuilder->getInt32(outputBufferSize));
296    }
297    return new Instance(this, kernelState, nullptr, std::get<0>(inputStreamSet), std::get<1>(inputStreamSet), outputScalars, outputStreamSets, outputBufferSize);
298}
299
300/** ------------------------------------------------------------------------------------------------------------- *
301 * @brief instantiate
302 *
303 * Generate a new instance of this kernel and call the default constructor to initialize it
304 ** ------------------------------------------------------------------------------------------------------------- */
305Instance * KernelBuilder::instantiate(std::initializer_list<llvm::Value *> inputStreams) {   
306    AllocaInst * inputStruct = iBuilder->CreateAlloca(mInputStreamType);
307    unsigned i = 0;
308    for (Value * inputStream : inputStreams) {
309        Value * ptr = iBuilder->CreateGEP(inputStruct, { iBuilder->getInt32(0), iBuilder->getInt32(i++)});
310        iBuilder->CreateStore(inputStream, ptr);
311    }
312    return instantiate(std::make_pair(inputStruct, 0));
313}
314
315Value * KernelBuilder::getInputStreamParam(const unsigned streamOffset) const {
316    const auto f = mInputStreamParam.find(streamOffset);
317    if (LLVM_UNLIKELY(f == mInputStreamParam.end())) {
318        throw std::runtime_error("Kernel compilation error: No input stream parameter for stream offset " + std::to_string(streamOffset));
319    }
320    return f->second;
321}
322
323} // end of namespace kernel
Note: See TracBrowser for help on using the repository browser.