source: icGREP/icgrep-devel/icgrep/kernels/kernel.cpp @ 5008

Last change on this file since 5008 was 5008, checked in by nmedfort, 3 years ago

Potential fix for Mac compilers.

File size: 15.3 KB
RevLine 
[4924]1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "kernel.h"
[4959]7#include <pablo/function.h>
8#include <IDISA/idisa_builder.h>
[4974]9#include <kernels/instance.h>
[4991]10#include <tuple>
11#include <boost/functional/hash_fwd.hpp>
12#include <unordered_map>
[4924]13
[4959]14using namespace llvm;
15using namespace pablo;
16
[4974]17namespace kernel {
18
[4924]19// sets name & sets internal state to the kernel superclass state
[5000]20KernelBuilder::KernelBuilder(IDISA::IDISA_Builder * builder, std::string && name, const unsigned defaultBufferSize)
21: iBuilder(builder)
[4924]22, mKernelName(name)
[5000]23, mDefaultBufferSize(defaultBufferSize)
24, mBitBlockType(builder->getBitBlockType())
[4986]25, mBlockNoIndex(0) {
[5000]26    assert (mDefaultBufferSize > 0);
[4924]27}
28
[4968]29/** ------------------------------------------------------------------------------------------------------------- *
[4970]30 * @brief addInternalState
[4968]31 ** ------------------------------------------------------------------------------------------------------------- */
[4970]32unsigned KernelBuilder::addInternalState(Type * const type) {
[4968]33    assert (type);
[4974]34    const unsigned index = mInternalState.size();
35    mInternalState.push_back(type);
[4968]36    return index;
[4924]37}
38
[4974]39unsigned KernelBuilder::addInternalState(llvm::Type * const type, std::string && name) {
40    if (LLVM_UNLIKELY(mInternalStateNameMap.count(name) != 0)) {
[4991]41        throw std::runtime_error("Kernel already contains internal state '" + name + "'");
[4970]42    }
43    const unsigned index = addInternalState(type);
[5000]44    mInternalStateNameMap.emplace(name, iBuilder->getInt32(index));
[4970]45    return index;
46}
47
[4968]48/** ------------------------------------------------------------------------------------------------------------- *
[4974]49 * @brief getInternalState
[4968]50 ** ------------------------------------------------------------------------------------------------------------- */
[5008]51Value * KernelBuilder::getInternalStateInternal(Value * const kernelState, const std::string & name) {
[4974]52    const auto f = mInternalStateNameMap.find(name);
53    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
54        throw std::runtime_error("Kernel does not contain internal state " + name);
55    }
[5008]56    return getInternalStateInternal(kernelState, f->second);
[4974]57}
58
[5008]59Value * KernelBuilder::getInternalStateInternal(Value * const kernelState, disable_implicit_conversion<Value *> index) {
60    assert (index->getType()->isIntegerTy());
61    assert (kernelState->getType()->getPointerElementType() == mKernelStateType);
62    return iBuilder->CreateGEP(kernelState, {iBuilder->getInt32(0), index});
63}
64
[4968]65/** ------------------------------------------------------------------------------------------------------------- *
[4974]66 * @brief setInternalState
[4968]67 ** ------------------------------------------------------------------------------------------------------------- */
[5008]68void KernelBuilder::setInternalStateInternal(Value * const kernelState, const std::string & name, Value * const value) {
69    Value * ptr = getInternalStateInternal(kernelState, name);
[4974]70    assert (ptr->getType()->getPointerElementType() == value->getType());
71    if (value->getType() == iBuilder->getBitBlockType()) {
72        iBuilder->CreateBlockAlignedStore(value, ptr);
73    } else {
74        iBuilder->CreateStore(value, ptr);
75    }
[4924]76}
[4968]77
[5008]78void KernelBuilder::setInternalStateInternal(Value * const kernelState, disable_implicit_conversion<Value *> index, Value * const value) {
79    Value * ptr = getInternalStateInternal(kernelState, index);
[4974]80    assert (ptr->getType()->getPointerElementType() == value->getType());
81    if (value->getType() == iBuilder->getBitBlockType()) {
82        iBuilder->CreateBlockAlignedStore(value, ptr);
83    } else {
84        iBuilder->CreateStore(value, ptr);
85    }
86}
87
[4968]88/** ------------------------------------------------------------------------------------------------------------- *
89 * @brief addInputStream
90 ** ------------------------------------------------------------------------------------------------------------- */
[4974]91void KernelBuilder::addInputStream(const unsigned fields, std::string && name) {
[4970]92    assert (fields > 0 && !name.empty());
[4974]93    mInputStreamName.push_back(name);
[4986]94    if (fields == 1) {
[4974]95        mInputStream.push_back(mBitBlockType);
[4959]96    } else {
[4974]97        mInputStream.push_back(ArrayType::get(mBitBlockType, fields));
[4924]98    }
99}
[4968]100
[4970]101void KernelBuilder::addInputStream(const unsigned fields) {
[4986]102    addInputStream(fields, std::move(mKernelName + "_InputStream_" + std::to_string(mInputStream.size())));
[4970]103}
104
[4968]105/** ------------------------------------------------------------------------------------------------------------- *
[4970]106 * @brief getInputStream
107 ** ------------------------------------------------------------------------------------------------------------- */
[5008]108Value * KernelBuilder::getInputStreamInternal(Value * const inputStreamSet, disable_implicit_conversion<Value *> index) {
[5000]109    assert ("Parameters cannot be null!" && (inputStreamSet != nullptr && index != nullptr));
110    assert ("Stream index must be an integer!" && index->getType()->isIntegerTy());
111    assert ("Illegal input stream set provided!" && inputStreamSet->getType()->getPointerElementType() == mInputStreamType);
112    if (LLVM_LIKELY(isa<ConstantInt>(index.get()) || getInputStreamType()->isArrayTy())) {
113        return iBuilder->CreateGEP(inputStreamSet, { iBuilder->getInt32(0), index });
[4992]114    }
[5000]115    throw std::runtime_error("Cannot access the input stream with a non-constant value unless all input stream types are identical!");
[4970]116}
117
118/** ------------------------------------------------------------------------------------------------------------- *
[4974]119 * @brief addInputScalar
[4970]120 ** ------------------------------------------------------------------------------------------------------------- */
[4974]121void KernelBuilder::addInputScalar(Type * const type, std::string && name) {
122    assert (type && !name.empty());
123    mInputScalarName.push_back(name);
124    mInputScalar.push_back(type);
[4970]125}
126
[4974]127void KernelBuilder::addInputScalar(Type * const type) {
[4986]128    addInputScalar(type, std::move(mKernelName + "_InputScalar_" + std::to_string(mInputScalar.size())));
[4970]129}
130
131/** ------------------------------------------------------------------------------------------------------------- *
[4974]132 * @brief getInputScalar
[4970]133 ** ------------------------------------------------------------------------------------------------------------- */
[5008]134Value * KernelBuilder::getInputScalarInternal(Value * const inputScalarSet, disable_implicit_conversion<Value *>) {
[5000]135    assert (inputScalarSet);
[4970]136    throw std::runtime_error("currently not supported!");
137}
138
139/** ------------------------------------------------------------------------------------------------------------- *
[4974]140 * @brief addOutputStream
[4970]141 ** ------------------------------------------------------------------------------------------------------------- */
[4974]142unsigned KernelBuilder::addOutputStream(const unsigned fields) {
143    assert (fields > 0);
144    const unsigned index = mOutputStream.size();
145    mOutputStream.push_back((fields == 1) ? mBitBlockType : ArrayType::get(mBitBlockType, fields));
146    return index;
[4970]147}
148
[4974]149/** ------------------------------------------------------------------------------------------------------------- *
150 * @brief addOutputScalar
151 ** ------------------------------------------------------------------------------------------------------------- */
152unsigned KernelBuilder::addOutputScalar(Type * const type) {
153    assert (type);
154    const unsigned index = mOutputScalar.size();
155    mOutputScalar.push_back(type);
156    return index;
[4970]157}
158
159/** ------------------------------------------------------------------------------------------------------------- *
[4974]160 * @brief getOutputStream
[4970]161 ** ------------------------------------------------------------------------------------------------------------- */
[5008]162Value * KernelBuilder::getOutputStreamInternal(Value * const outputStreamSet, disable_implicit_conversion<Value *> index) {
[5000]163    assert ("Parameters cannot be null!" && (outputStreamSet != nullptr && index != nullptr));
164    assert ("Stream index must be an integer!" && index->getType()->isIntegerTy());
165    assert ("Illegal output stream set provided!" && outputStreamSet->getType()->getPointerElementType() == getOutputStreamType());
166    if (LLVM_LIKELY(isa<ConstantInt>(index.get()) || getOutputStreamType()->isArrayTy())) {
167        return iBuilder->CreateGEP(outputStreamSet, { iBuilder->getInt32(0), index });
[4995]168    }
169    throw std::runtime_error("Cannot access the output stream with a non-constant value unless all output stream types are identical!");
[4992]170}
171
[4970]172/** ------------------------------------------------------------------------------------------------------------- *
[4974]173 * @brief getOutputScalar
174 ** ------------------------------------------------------------------------------------------------------------- */
[5008]175Value * KernelBuilder::getOutputScalarInternal(Value * const, disable_implicit_conversion<Value *> ) {
[4974]176    throw std::runtime_error("currently not supported!");
[4970]177}
178
[4959]179/** ------------------------------------------------------------------------------------------------------------- *
[4995]180 * @brief packDataTypes
181 ** ------------------------------------------------------------------------------------------------------------- */
[5000]182Type * KernelBuilder::packDataTypes(const std::vector<llvm::Type *> & types) {
183    if (types.empty()) {
184        return nullptr;
185    }
[4995]186    for (Type * type : types) {
187        if (type != types.front()) { // use canLosslesslyBitcastInto ?
[5000]188            return StructType::get(iBuilder->getContext(), types);
[4995]189        }
190    }
[5000]191    return ArrayType::get(types.front(), types.size());
[4995]192}
193
194/** ------------------------------------------------------------------------------------------------------------- *
[4959]195 * @brief prepareFunction
196 ** ------------------------------------------------------------------------------------------------------------- */
[5000]197Function * KernelBuilder::prepareFunction(std::vector<unsigned> && inputStreamOffsets) {
[4968]198
[5001]199    mBlockNoIndex = iBuilder->getInt32(addInternalState(iBuilder->getInt64Ty(), "BlockNo"));
200
[5000]201    mKernelStateType = StructType::create(iBuilder->getContext(), mInternalState, mKernelName);
202    mInputScalarType = packDataTypes(mInputScalar);
203    mInputStreamType = packDataTypes(mInputStream);
204    mOutputScalarType = packDataTypes(mInputScalar);
[4995]205    mOutputStreamType = packDataTypes(mOutputStream);
[5000]206    mInputStreamOffsets = inputStreamOffsets;
[4930]207
[5000]208    std::vector<Type *> params;
209    params.push_back(mKernelStateType->getPointerTo());
210    if (mInputScalarType) {
211        params.push_back(mInputScalarType->getPointerTo());
212    }
213    if (mInputStreamType) {
214        for (unsigned i = 0; i < mInputStreamOffsets.size(); ++i) {
215            params.push_back(mInputStreamType->getPointerTo());
216        }
217    }
218    if (mOutputScalarType) {
219        params.push_back(mOutputScalarType->getPointerTo());
220    }
221    if (mOutputStreamType) {
222        params.push_back(mOutputStreamType->getPointerTo());
223    }
224
225    // A pointer value is captured if the function makes a copy of any part of the pointer that outlives
226    // the call (e.g., stored in a global or, depending on the context, when returned by the function.)
227    // Since this does not occur in either our DoBlock or Constructor, all parameters are marked nocapture.
228
229    FunctionType * const functionType = FunctionType::get(iBuilder->getVoidTy(), params, false);
230    mDoBlock = Function::Create(functionType, GlobalValue::ExternalLinkage, mKernelName + "_DoBlock", iBuilder->getModule());
231    mDoBlock->setCallingConv(CallingConv::C);
232    for (unsigned i = 1; i <= params.size(); ++i) {
233        mDoBlock->setDoesNotCapture(i);
234    }
[4991]235    mDoBlock->setDoesNotThrow();
[4986]236    Function::arg_iterator args = mDoBlock->arg_begin();
[5000]237    mKernelStateParam = args++;
238    mKernelStateParam->setName("this");
239    if (mInputScalarType) {
240        mInputScalarParam = args++;
241        mInputScalarParam->setName("inputScalars");
242    }
243    if (mInputStreamType) {
244        for (const unsigned offset : mInputStreamOffsets) {
245            Value * const inputStreamSet = args++;
246            inputStreamSet->setName("inputStreamSet" + std::to_string(offset));
247            mInputStreamParam.emplace(offset, inputStreamSet);
248        }
249    }
250    if (mOutputScalarType) {
251        mOutputScalarParam = args++;
252        mOutputScalarParam->setName("outputScalars");
253    }
254    if (mOutputStreamType) {
255        mOutputStreamParam = args;
256        mOutputStreamParam->setName("outputStreamSet");
257    }
258    iBuilder->SetInsertPoint(BasicBlock::Create(iBuilder->getContext(), "entry", mDoBlock, 0));
[4986]259    return mDoBlock;
[4924]260}
261
[4959]262/** ------------------------------------------------------------------------------------------------------------- *
263 * @brief finalize
264 ** ------------------------------------------------------------------------------------------------------------- */
265void KernelBuilder::finalize() {
266    // Finish the actual function
[4986]267    Value * blockNo = getBlockNo();
268    Value * value = iBuilder->CreateLoad(blockNo);
[4970]269    value = iBuilder->CreateAdd(value, ConstantInt::get(value->getType(), 1));
[4986]270    iBuilder->CreateStore(value, blockNo);
[4924]271    iBuilder->CreateRetVoid();
272
[5000]273    mKernelStateParam = nullptr;
274    mInputScalarParam = nullptr;
275    mInputStreamParam.clear();
276    mOutputScalarParam = nullptr;
277    mOutputStreamParam = nullptr;
[4974]278    iBuilder->ClearInsertionPoint();
[4986]279}
[4924]280
[4986]281/** ------------------------------------------------------------------------------------------------------------- *
282 * @brief instantiate
283 *
[5000]284 * Allocate and zero initialize the memory for this kernel and its output scalars and streams
[4986]285 ** ------------------------------------------------------------------------------------------------------------- */
[5000]286Instance * KernelBuilder::instantiate(std::pair<Value *, unsigned> && inputStreamSet, const unsigned outputBufferSize) {
287    AllocaInst * const kernelState = iBuilder->CreateAlloca(mKernelStateType);
288    iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), kernelState);
289    AllocaInst * outputScalars = nullptr;
290    if (mOutputScalarType) {
291        outputScalars = iBuilder->CreateAlloca(mOutputScalarType);
292    }
293    AllocaInst * outputStreamSets = nullptr;
294    if (mOutputStreamType) {
295        outputStreamSets = iBuilder->CreateAlloca(mOutputStreamType, iBuilder->getInt32(outputBufferSize));
296    }
297    return new Instance(this, kernelState, nullptr, std::get<0>(inputStreamSet), std::get<1>(inputStreamSet), outputScalars, outputStreamSets, outputBufferSize);
[4959]298}
[4924]299
[4959]300/** ------------------------------------------------------------------------------------------------------------- *
[4974]301 * @brief instantiate
302 *
303 * Generate a new instance of this kernel and call the default constructor to initialize it
[4959]304 ** ------------------------------------------------------------------------------------------------------------- */
[5000]305Instance * KernelBuilder::instantiate(std::initializer_list<llvm::Value *> inputStreams) {   
[5001]306    AllocaInst * inputStruct = iBuilder->CreateAlloca(mInputStreamType);
307    unsigned i = 0;
308    for (Value * inputStream : inputStreams) {
309        Value * ptr = iBuilder->CreateGEP(inputStruct, { iBuilder->getInt32(0), iBuilder->getInt32(i++)});
310        iBuilder->CreateStore(inputStream, ptr);
311    }
312    return instantiate(std::make_pair(inputStruct, 0));
[4924]313}
314
[5008]315Value * KernelBuilder::getInputStreamParam(const unsigned streamOffset) const {
316    const auto f = mInputStreamParam.find(streamOffset);
317    if (LLVM_UNLIKELY(f == mInputStreamParam.end())) {
318        throw std::runtime_error("Kernel compilation error: No input stream parameter for stream offset " + std::to_string(streamOffset));
319    }
320    return f->second;
321}
322
[4974]323} // end of namespace kernel
Note: See TracBrowser for help on using the repository browser.