source: icGREP/icgrep-devel/icgrep/kernels/interface.cpp @ 5433

Last change on this file since 5433 was 5433, checked in by cameron, 2 years ago

MaxReferenceItemsCalculation? initial check-in

File size: 8.2 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "interface.h"
7#include <llvm/IR/Value.h>         // for Value
8#include <llvm/IR/CallingConv.h>   // for ::C
9#include <llvm/IR/DerivedTypes.h>  // for FunctionType (ptr only), PointerType
10#include <llvm/IR/Function.h>      // for Function, Function::arg_iterator
11#include <llvm/IR/Module.h>
12#include <IR_Gen/idisa_builder.h>
13
14static const auto INIT_SUFFIX = "_Init";
15
16static const auto DO_SEGMENT_SUFFIX = "_DoSegment";
17
18static const auto TERMINATE_SUFFIX = "_Terminate";
19
20using namespace llvm;
21
22ProcessingRate FixedRatio(unsigned strmItems, unsigned referenceItems, std::string && referenceStreamSet) {
23    return ProcessingRate(ProcessingRate::ProcessingRateKind::FixedRatio, strmItems, referenceItems, std::move(referenceStreamSet));
24}
25
26ProcessingRate MaxRatio(unsigned strmItems, unsigned referenceItems, std::string && referenceStreamSet) {
27    return ProcessingRate(ProcessingRate::ProcessingRateKind::MaxRatio, strmItems, referenceItems, std::move(referenceStreamSet));
28}
29
30ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string && referenceStreamSet) {
31    return ProcessingRate(ProcessingRate::ProcessingRateKind::RoundUp, itemMultiple, itemMultiple, std::move(referenceStreamSet));
32}
33
34ProcessingRate Add1(std::string && referenceStreamSet) {
35    return ProcessingRate(ProcessingRate::ProcessingRateKind::Add1, 0, 0, std::move(referenceStreamSet));
36}
37
38ProcessingRate UnknownRate() {
39    return ProcessingRate(ProcessingRate::ProcessingRateKind::Unknown, 0, 0, "");
40}
41
42Value * ProcessingRate::CreateRatioCalculation(IDISA::IDISA_Builder * b, Value * referenceItems, Value * doFinal) const {
43    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio || mKind == ProcessingRate::ProcessingRateKind::MaxRatio) {
44        if (mRatioNumerator == mRatioDenominator) {
45            return referenceItems;
46        }
47        Type * const T = referenceItems->getType();
48        Constant * const numerator = ConstantInt::get(T, mRatioNumerator);
49        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
50        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
51        Value * strmItems = b->CreateMul(referenceItems, numerator);
52        return b->CreateUDiv(b->CreateAdd(denominatorLess1, strmItems), denominator);
53    }
54    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
55        Type * const T = referenceItems->getType();
56        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
57        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
58        return b->CreateMul(b->CreateUDiv(b->CreateAdd(referenceItems, denominatorLess1), denominator), denominator);
59    }
60    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
61        if (doFinal) {
62            Type * const T = referenceItems->getType();
63            referenceItems = b->CreateAdd(referenceItems, b->CreateZExt(doFinal, T));
64        }
65        return referenceItems;
66    }
67    return nullptr;
68}
69
70Value * ProcessingRate::CreateMaxReferenceItemsCalculation(IDISA::IDISA_Builder * b, Value * outputItems, Value * doFinal) const {
71    if (mKind == ProcessingRate::ProcessingRateKind::FixedRatio) {
72        if (mRatioNumerator == mRatioDenominator) {
73            return outputItems;
74        }
75        Type * const T = outputItems->getType();
76        Constant * const numerator = ConstantInt::get(T, mRatioNumerator);
77        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
78        return b->CreateMul(b->CreateUDiv(outputItems, numerator), denominator);
79    }
80    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
81        Type * const T = outputItems->getType();
82        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
83        return b->CreateMul(b->CreateUDiv(outputItems, denominator), denominator);
84    }
85    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
86        Type * const T = outputItems->getType();
87        if (doFinal) {
88            return b->CreateSub(outputItems, b->CreateZExt(doFinal, T));
89        }
90        return b->CreateSub(outputItems, ConstantInt::get(T, 1));
91    }
92    return nullptr;
93}
94
95void KernelInterface::addKernelDeclarations() {
96
97    if (mKernelStateType == nullptr) {
98        throw std::runtime_error("Kernel interface " + getName() + " not yet finalized.");
99    }
100
101    Module * const module = iBuilder->getModule();
102    PointerType * const selfType = mKernelStateType->getPointerTo();
103    IntegerType * const sizeTy = iBuilder->getSizeTy();
104    PointerType * const consumerTy = StructType::get(sizeTy, sizeTy->getPointerTo()->getPointerTo(), nullptr)->getPointerTo();
105    Type * const voidTy = iBuilder->getVoidTy();
106
107    // Create the initialization function prototype
108    std::vector<Type *> initParameters = {selfType};
109    for (auto binding : mScalarInputs) {
110        initParameters.push_back(binding.type);
111    }
112    initParameters.insert(initParameters.end(), mStreamSetOutputs.size(), consumerTy);
113
114    FunctionType * const initType = FunctionType::get(voidTy, initParameters, false);
115    Function * const initFunc = Function::Create(initType, GlobalValue::ExternalLinkage, getName() + INIT_SUFFIX, module);
116    initFunc->setCallingConv(CallingConv::C);
117    initFunc->setDoesNotThrow();
118    auto args = initFunc->arg_begin();
119    args->setName("self");
120    for (auto binding : mScalarInputs) {
121        (++args)->setName(binding.name);
122    }
123    for (auto binding : mStreamSetOutputs) {
124        (args++)->setName(binding.name + "ConsumerLocks");
125    }
126
127    // Create the doSegment function prototype.
128    std::vector<Type *> params = {selfType, iBuilder->getInt1Ty()};
129    params.insert(params.end(), mStreamSetInputs.size(), sizeTy);
130
131    FunctionType * const doSegmentType = FunctionType::get(voidTy, params, false);
132    Function * const doSegment = Function::Create(doSegmentType, GlobalValue::ExternalLinkage, getName() + DO_SEGMENT_SUFFIX, module);
133    doSegment->setCallingConv(CallingConv::C);
134    doSegment->setDoesNotThrow();
135    doSegment->setDoesNotCapture(1); // for self parameter only.
136    args = doSegment->arg_begin();
137    args->setName("self");
138    (++args)->setName("doFinal");
139    for (const Binding & input : mStreamSetInputs) {
140        (++args)->setName(input.name + "AvailableItems");
141    }
142
143    // Create the terminate function prototype
144    Type * resultType = nullptr;
145    if (mScalarOutputs.empty()) {
146        resultType = iBuilder->getVoidTy();
147    } else {
148        const auto n = mScalarOutputs.size();
149        Type * outputType[n];
150        for (unsigned i = 0; i < n; ++i) {
151            outputType[i] = mScalarOutputs[i].type;
152        }
153        if (n == 1) {
154            resultType = outputType[0];
155        } else {
156            resultType = StructType::get(iBuilder->getContext(), ArrayRef<Type *>(outputType, n));
157        }
158    }
159    FunctionType * const terminateType = FunctionType::get(resultType, {selfType}, false);
160    Function * const terminateFunc = Function::Create(terminateType, GlobalValue::ExternalLinkage, getName() + TERMINATE_SUFFIX, module);
161    terminateFunc->setCallingConv(CallingConv::C);
162    terminateFunc->setDoesNotThrow();
163    terminateFunc->setDoesNotCapture(1);
164    args = terminateFunc->arg_begin();
165    args->setName("self");
166
167    linkExternalMethods();
168}
169
170void KernelInterface::setInitialArguments(std::vector<Value *> args) {
171    mInitialArguments = args;
172}
173
174Function * KernelInterface::getInitFunction(Module * const module) const {
175    const auto name = getName() + INIT_SUFFIX;
176    Function * f = module->getFunction(name);
177    if (LLVM_UNLIKELY(f == nullptr)) {
178        llvm::report_fatal_error("Cannot find " + name);
179    }
180    return f;
181}
182
183Function * KernelInterface::getDoSegmentFunction(llvm::Module * const module) const {
184    const auto name = getName() + DO_SEGMENT_SUFFIX;
185    Function * f = module->getFunction(name);
186    if (LLVM_UNLIKELY(f == nullptr)) {
187        llvm::report_fatal_error("Cannot find " + name);
188    }
189    return f;
190}
191
192Function * KernelInterface::getTerminateFunction(Module * const module) const {
193    const auto name = getName() + TERMINATE_SUFFIX;
194    Function * f = module->getFunction(name);
195    if (LLVM_UNLIKELY(f == nullptr)) {
196        llvm::report_fatal_error("Cannot find " + name);
197    }
198    return f;
199}
Note: See TracBrowser for help on using the repository browser.