source: icGREP/icgrep-devel/icgrep/kernels/interface.cpp @ 5699

Last change on this file since 5699 was 5646, checked in by nmedfort, 22 months ago

Minor clean up. Bug fix for object cache when the same cached kernel is used twice in a single run. Improvement to RE Minimizer.

File size: 9.9 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "interface.h"
7#include <llvm/IR/Value.h>         // for Value
8#include <llvm/IR/CallingConv.h>   // for ::C
9#include <llvm/IR/DerivedTypes.h>  // for FunctionType (ptr only), PointerType
10#include <llvm/IR/Function.h>      // for Function, Function::arg_iterator
11#include <llvm/IR/Module.h>
12#include <kernels/kernel_builder.h>
13
// Suffixes appended to a kernel's name (see getName()) to form the symbol
// names of its generated init, doSegment and terminate functions.
static const char * const INIT_SUFFIX = "_Init";
static const char * const DO_SEGMENT_SUFFIX = "_DoSegment";
static const char * const TERMINATE_SUFFIX = "_Terminate";
19
20using namespace llvm;
21
22ProcessingRate FixedRatio(unsigned strmItems, unsigned referenceItems, std::string && referenceStreamSet) {
23    return ProcessingRate(ProcessingRate::ProcessingRateKind::FixedRatio, strmItems, referenceItems, std::move(referenceStreamSet));
24}
25
26ProcessingRate MaxRatio(unsigned strmItems, unsigned referenceItems, std::string && referenceStreamSet) {
27    return ProcessingRate(ProcessingRate::ProcessingRateKind::MaxRatio, strmItems, referenceItems, std::move(referenceStreamSet));
28}
29
30ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string && referenceStreamSet) {
31    return ProcessingRate(ProcessingRate::ProcessingRateKind::RoundUp, itemMultiple, itemMultiple, std::move(referenceStreamSet));
32}
33
34ProcessingRate Add1(std::string && referenceStreamSet) {
35    return ProcessingRate(ProcessingRate::ProcessingRateKind::Add1, 0, 1, std::move(referenceStreamSet));
36}
37
38ProcessingRate UnknownRate() {
39    return ProcessingRate(ProcessingRate::ProcessingRateKind::Unknown, 0, 1, "");
40}
41
42unsigned ProcessingRate::calculateRatio(unsigned referenceItems, bool doFinal) const {
43    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
44        if (mRatioNumerator == mRatioDenominator) {
45            return referenceItems;
46        }
47        unsigned strmItems = referenceItems * mRatioNumerator;
48        return (strmItems + mRatioDenominator - 1) / mRatioDenominator;
49    }
50    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
51        return ((referenceItems + mRatioDenominator - 1) / mRatioDenominator) * mRatioDenominator;
52    }
53    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
54        return doFinal ? referenceItems + 1 : referenceItems;
55    }
56    report_fatal_error("Processing rate calculation attempted for variable or unknown rate.");
57}
58
59Value * ProcessingRate::CreateRatioCalculation(IDISA::IDISA_Builder * const b, Value * referenceItems, Value * doFinal) const {
60    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
61        if (mRatioNumerator == mRatioDenominator) {
62            return referenceItems;
63        }
64        Type * const T = referenceItems->getType();
65        Constant * const numerator = ConstantInt::get(T, mRatioNumerator);
66        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
67        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
68        Value * strmItems = b->CreateMul(referenceItems, numerator);
69        return b->CreateUDiv(b->CreateAdd(denominatorLess1, strmItems), denominator);
70    }
71    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
72        Type * const T = referenceItems->getType();
73        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
74        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
75        return b->CreateMul(b->CreateUDiv(b->CreateAdd(referenceItems, denominatorLess1), denominator), denominator);
76    }
77    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
78        if (doFinal) {
79            Type * const T = referenceItems->getType();
80            referenceItems = b->CreateAdd(referenceItems, b->CreateZExt(doFinal, T));
81        }
82        return referenceItems;
83    }
84    report_fatal_error("Processing rate calculation attempted for variable or unknown rate.");
85}
86
87unsigned ProcessingRate::calculateMaxReferenceItems(const unsigned outputItems, const bool doFinal) const {
88    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
89        if (mRatioNumerator == mRatioDenominator) {
90            return outputItems;
91        }
92        return (outputItems / mRatioNumerator) * mRatioDenominator;
93    }
94    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
95        return (outputItems / mRatioDenominator) * mRatioDenominator;
96    }
97    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
98        return outputItems - (doFinal ? 1 : 0);
99    }
100    report_fatal_error("Inverse processing rate calculation attempted for unknown rate.");
101}
102
103Value * ProcessingRate::CreateMaxReferenceItemsCalculation(IDISA::IDISA_Builder * const b, Value * outputItems, Value * doFinal) const {
104    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
105        if (mRatioNumerator == mRatioDenominator) {
106            return outputItems;
107        }
108        Type * const T = outputItems->getType();
109        Constant * const numerator = ConstantInt::get(T, mRatioNumerator);
110        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
111        return b->CreateMul(b->CreateUDiv(outputItems, numerator), denominator);
112    }
113    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
114        Type * const T = outputItems->getType();
115        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
116        return b->CreateMul(b->CreateUDiv(outputItems, denominator), denominator);
117    }
118    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
119        Type * const T = outputItems->getType();
120        if (doFinal) {
121            return b->CreateSub(outputItems, b->CreateZExt(doFinal, T));
122        }
123        return b->CreateSub(outputItems, ConstantInt::get(T, 1));
124    }
125    report_fatal_error("Inverse processing rate calculation attempted for unknown rate.");
126}
127
128void KernelInterface::addKernelDeclarations(const std::unique_ptr<kernel::KernelBuilder> & idb) {
129
130    if (mKernelStateType == nullptr) {
131        throw std::runtime_error("Kernel interface " + getName() + " not yet finalized.");
132    }
133
134    Module * const module = idb->getModule();
135    PointerType * const selfType = mKernelStateType->getPointerTo();
136    IntegerType * const sizeTy = idb->getSizeTy();
137    PointerType * const consumerTy = StructType::get(sizeTy, sizeTy->getPointerTo()->getPointerTo(), nullptr)->getPointerTo();
138    Type * const voidTy = idb->getVoidTy();
139
140    // Create the initialization function prototype
141    std::vector<Type *> initParameters = {selfType};
142    for (auto binding : mScalarInputs) {
143        initParameters.push_back(binding.type);
144    }
145    initParameters.insert(initParameters.end(), mStreamSetOutputs.size(), consumerTy);
146
147    FunctionType * const initType = FunctionType::get(voidTy, initParameters, false);
148    Function * const initFunc = Function::Create(initType, GlobalValue::ExternalLinkage, getName() + INIT_SUFFIX, module);
149    initFunc->setCallingConv(CallingConv::C);
150    initFunc->setDoesNotThrow();
151    auto args = initFunc->arg_begin();
152    args->setName("self");
153    for (auto binding : mScalarInputs) {
154        (++args)->setName(binding.name);
155    }
156    for (auto binding : mStreamSetOutputs) {
157        (++args)->setName(binding.name + "ConsumerLocks");
158    }
159
160    // Create the doSegment function prototype.
161    std::vector<Type *> params = {selfType, idb->getInt1Ty()};
162    params.insert(params.end(), mStreamSetInputs.size(), sizeTy);
163
164    FunctionType * const doSegmentType = FunctionType::get(voidTy, params, false);
165    Function * const doSegment = Function::Create(doSegmentType, GlobalValue::ExternalLinkage, getName() + DO_SEGMENT_SUFFIX, module);
166    doSegment->setCallingConv(CallingConv::C);
167    doSegment->setDoesNotThrow();
168    doSegment->setDoesNotCapture(1); // for self parameter only.
169    args = doSegment->arg_begin();
170    args->setName("self");
171    (++args)->setName("doFinal");
172    for (const Binding & input : mStreamSetInputs) {
173        (++args)->setName(input.name + "AvailableItems");
174    }
175
176    // Create the terminate function prototype
177    Type * resultType = nullptr;
178    if (mScalarOutputs.empty()) {
179        resultType = idb->getVoidTy();
180    } else {
181        const auto n = mScalarOutputs.size();
182        Type * outputType[n];
183        for (unsigned i = 0; i < n; ++i) {
184            outputType[i] = mScalarOutputs[i].type;
185        }
186        if (n == 1) {
187            resultType = outputType[0];
188        } else {
189            resultType = StructType::get(idb->getContext(), ArrayRef<Type *>(outputType, n));
190        }
191    }
192    FunctionType * const terminateType = FunctionType::get(resultType, {selfType}, false);
193    Function * const terminateFunc = Function::Create(terminateType, GlobalValue::ExternalLinkage, getName() + TERMINATE_SUFFIX, module);
194    terminateFunc->setCallingConv(CallingConv::C);
195    terminateFunc->setDoesNotThrow();
196    terminateFunc->setDoesNotCapture(1);
197    args = terminateFunc->arg_begin();
198    args->setName("self");
199
200    linkExternalMethods(idb);
201}
202
203Function * KernelInterface::getInitFunction(Module * const module) const {
204    const auto name = getName() + INIT_SUFFIX;
205    Function * f = module->getFunction(name);
206    if (LLVM_UNLIKELY(f == nullptr)) {
207        llvm::report_fatal_error("Cannot find " + name);
208    }
209    return f;
210}
211
212Function * KernelInterface::getDoSegmentFunction(llvm::Module * const module) const {
213    const auto name = getName() + DO_SEGMENT_SUFFIX;
214    Function * f = module->getFunction(name);
215    if (LLVM_UNLIKELY(f == nullptr)) {
216        llvm::report_fatal_error("Cannot find " + name);
217    }
218    return f;
219}
220
221Function * KernelInterface::getTerminateFunction(Module * const module) const {
222    const auto name = getName() + TERMINATE_SUFFIX;
223    Function * f = module->getFunction(name);
224    if (LLVM_UNLIKELY(f == nullptr)) {
225        llvm::report_fatal_error("Cannot find " + name);
226    }
227    return f;
228}
Note: See TracBrowser for help on using the repository browser.