source: icGREP/icgrep-devel/icgrep/kernels/interface.cpp @ 5615

Last change on this file since 5615 was 5611, checked in by cameron, 22 months ago

MaxReferenceItems? for MaxRatio?

File size: 9.9 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "interface.h"
7#include <llvm/IR/Value.h>         // for Value
8#include <llvm/IR/CallingConv.h>   // for ::C
9#include <llvm/IR/DerivedTypes.h>  // for FunctionType (ptr only), PointerType
10#include <llvm/IR/Function.h>      // for Function, Function::arg_iterator
11#include <llvm/IR/Module.h>
12#include <kernels/kernel_builder.h>
13
14static const auto INIT_SUFFIX = "_Init";
15
16static const auto DO_SEGMENT_SUFFIX = "_DoSegment";
17
18static const auto TERMINATE_SUFFIX = "_Terminate";
19
20using namespace llvm;
21
22ProcessingRate FixedRatio(unsigned strmItems, unsigned referenceItems, std::string && referenceStreamSet) {
23    return ProcessingRate(ProcessingRate::ProcessingRateKind::FixedRatio, strmItems, referenceItems, std::move(referenceStreamSet));
24}
25
26ProcessingRate MaxRatio(unsigned strmItems, unsigned referenceItems, std::string && referenceStreamSet) {
27    return ProcessingRate(ProcessingRate::ProcessingRateKind::MaxRatio, strmItems, referenceItems, std::move(referenceStreamSet));
28}
29
30ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string && referenceStreamSet) {
31    return ProcessingRate(ProcessingRate::ProcessingRateKind::RoundUp, itemMultiple, itemMultiple, std::move(referenceStreamSet));
32}
33
34ProcessingRate Add1(std::string && referenceStreamSet) {
35    return ProcessingRate(ProcessingRate::ProcessingRateKind::Add1, 0, 1, std::move(referenceStreamSet));
36}
37
38ProcessingRate UnknownRate() {
39    return ProcessingRate(ProcessingRate::ProcessingRateKind::Unknown, 0, 1, "");
40}
41
42unsigned ProcessingRate::calculateRatio(unsigned referenceItems, bool doFinal) const {
43    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
44        if (mRatioNumerator == mRatioDenominator) {
45            return referenceItems;
46        }
47        unsigned strmItems = referenceItems * mRatioNumerator;
48        return (strmItems + mRatioDenominator - 1) / mRatioDenominator;
49    }
50    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
51        return ((referenceItems + mRatioDenominator - 1) / mRatioDenominator) * mRatioDenominator;
52    }
53    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
54        return doFinal ? referenceItems + 1 : referenceItems;
55    }
56    report_fatal_error("Processing rate calculation attempted for variable or unknown rate.");
57}
58
59Value * ProcessingRate::CreateRatioCalculation(IDISA::IDISA_Builder * const b, Value * referenceItems, Value * doFinal) const {
60    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
61        if (mRatioNumerator == mRatioDenominator) {
62            return referenceItems;
63        }
64        Type * const T = referenceItems->getType();
65        Constant * const numerator = ConstantInt::get(T, mRatioNumerator);
66        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
67        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
68        Value * strmItems = b->CreateMul(referenceItems, numerator);
69        return b->CreateUDiv(b->CreateAdd(denominatorLess1, strmItems), denominator);
70    }
71    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
72        Type * const T = referenceItems->getType();
73        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
74        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
75        return b->CreateMul(b->CreateUDiv(b->CreateAdd(referenceItems, denominatorLess1), denominator), denominator);
76    }
77    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
78        if (doFinal) {
79            Type * const T = referenceItems->getType();
80            referenceItems = b->CreateAdd(referenceItems, b->CreateZExt(doFinal, T));
81        }
82        return referenceItems;
83    }
84    report_fatal_error("Processing rate calculation attempted for variable or unknown rate.");
85}
86
87unsigned ProcessingRate::calculateMaxReferenceItems(unsigned outputItems, bool doFinal) const {
88    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
89        if (mRatioNumerator == mRatioDenominator) {
90            return outputItems;
91        }
92        return (outputItems / mRatioNumerator) * mRatioDenominator;
93    }
94    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
95        return (outputItems / mRatioDenominator) * mRatioDenominator;
96    }
97    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
98        return doFinal ? outputItems - 1 : outputItems;
99    }
100    report_fatal_error("Inverse processing rate calculation attempted for unknown rate.");
101}
102
103Value * ProcessingRate::CreateMaxReferenceItemsCalculation(IDISA::IDISA_Builder * const b, Value * outputItems, Value * doFinal) const {
104    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
105        if (mRatioNumerator == mRatioDenominator) {
106            return outputItems;
107        }
108        Type * const T = outputItems->getType();
109        Constant * const numerator = ConstantInt::get(T, mRatioNumerator);
110        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
111        return b->CreateMul(b->CreateUDiv(outputItems, numerator), denominator);
112    }
113    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
114        Type * const T = outputItems->getType();
115        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
116        return b->CreateMul(b->CreateUDiv(outputItems, denominator), denominator);
117    }
118    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
119        Type * const T = outputItems->getType();
120        if (doFinal) {
121            return b->CreateSub(outputItems, b->CreateZExt(doFinal, T));
122        }
123        return b->CreateSub(outputItems, ConstantInt::get(T, 1));
124    }
125    report_fatal_error("Inverse processing rate calculation attempted for unknown rate.");
126}
127
128void KernelInterface::addKernelDeclarations(const std::unique_ptr<kernel::KernelBuilder> & idb) {
129
130    if (mKernelStateType == nullptr) {
131        throw std::runtime_error("Kernel interface " + getName() + " not yet finalized.");
132    }
133
134    Module * const module = idb->getModule();
135    PointerType * const selfType = mKernelStateType->getPointerTo();
136    IntegerType * const sizeTy = idb->getSizeTy();
137    PointerType * const consumerTy = StructType::get(sizeTy, sizeTy->getPointerTo()->getPointerTo(), nullptr)->getPointerTo();
138    Type * const voidTy = idb->getVoidTy();
139
140    // Create the initialization function prototype
141    std::vector<Type *> initParameters = {selfType};
142    for (auto binding : mScalarInputs) {
143        initParameters.push_back(binding.type);
144    }
145    initParameters.insert(initParameters.end(), mStreamSetOutputs.size(), consumerTy);
146
147    FunctionType * const initType = FunctionType::get(voidTy, initParameters, false);
148    Function * const initFunc = Function::Create(initType, GlobalValue::ExternalLinkage, getName() + INIT_SUFFIX, module);
149    initFunc->setCallingConv(CallingConv::C);
150    initFunc->setDoesNotThrow();
151    auto args = initFunc->arg_begin();
152    args->setName("self");
153    for (auto binding : mScalarInputs) {
154        (++args)->setName(binding.name);
155    }
156    for (auto binding : mStreamSetOutputs) {
157        (++args)->setName(binding.name + "ConsumerLocks");
158    }
159
160    // Create the doSegment function prototype.
161    std::vector<Type *> params = {selfType, idb->getInt1Ty()};
162    params.insert(params.end(), mStreamSetInputs.size(), sizeTy);
163
164    FunctionType * const doSegmentType = FunctionType::get(voidTy, params, false);
165    Function * const doSegment = Function::Create(doSegmentType, GlobalValue::ExternalLinkage, getName() + DO_SEGMENT_SUFFIX, module);
166    doSegment->setCallingConv(CallingConv::C);
167    doSegment->setDoesNotThrow();
168    doSegment->setDoesNotCapture(1); // for self parameter only.
169    args = doSegment->arg_begin();
170    args->setName("self");
171    (++args)->setName("doFinal");
172    for (const Binding & input : mStreamSetInputs) {
173        (++args)->setName(input.name + "AvailableItems");
174    }
175
176    // Create the terminate function prototype
177    Type * resultType = nullptr;
178    if (mScalarOutputs.empty()) {
179        resultType = idb->getVoidTy();
180    } else {
181        const auto n = mScalarOutputs.size();
182        Type * outputType[n];
183        for (unsigned i = 0; i < n; ++i) {
184            outputType[i] = mScalarOutputs[i].type;
185        }
186        if (n == 1) {
187            resultType = outputType[0];
188        } else {
189            resultType = StructType::get(idb->getContext(), ArrayRef<Type *>(outputType, n));
190        }
191    }
192    FunctionType * const terminateType = FunctionType::get(resultType, {selfType}, false);
193    Function * const terminateFunc = Function::Create(terminateType, GlobalValue::ExternalLinkage, getName() + TERMINATE_SUFFIX, module);
194    terminateFunc->setCallingConv(CallingConv::C);
195    terminateFunc->setDoesNotThrow();
196    terminateFunc->setDoesNotCapture(1);
197    args = terminateFunc->arg_begin();
198    args->setName("self");
199
200    linkExternalMethods(idb);
201}
202
203Function * KernelInterface::getInitFunction(Module * const module) const {
204    const auto name = getName() + INIT_SUFFIX;
205    Function * f = module->getFunction(name);
206    if (LLVM_UNLIKELY(f == nullptr)) {
207        llvm::report_fatal_error("Cannot find " + name);
208    }
209    return f;
210}
211
212Function * KernelInterface::getDoSegmentFunction(llvm::Module * const module) const {
213    const auto name = getName() + DO_SEGMENT_SUFFIX;
214    Function * f = module->getFunction(name);
215    if (LLVM_UNLIKELY(f == nullptr)) {
216        llvm::report_fatal_error("Cannot find " + name);
217    }
218    return f;
219}
220
221Function * KernelInterface::getTerminateFunction(Module * const module) const {
222    const auto name = getName() + TERMINATE_SUFFIX;
223    Function * f = module->getFunction(name);
224    if (LLVM_UNLIKELY(f == nullptr)) {
225        llvm::report_fatal_error("Cannot find " + name);
226    }
227    return f;
228}
Note: See TracBrowser for help on using the repository browser.