source: icGREP/icgrep-devel/icgrep/kernels/interface.cpp @ 5401

Last change on this file since 5401 was 5401, checked in by nmedfort, 2 years ago

Updated all projects to use ParabixDriver. Deprecated original pipeline generation methods. Enabled LLVM optimizations, IR and ASM printing for Kernel modules. Enabled object cache by default. Began work on moving consumed position information back to producing kernels.

File size: 7.3 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "interface.h"
7#include <llvm/IR/Value.h>         // for Value
8#include <llvm/IR/CallingConv.h>   // for ::C
9#include <llvm/IR/DerivedTypes.h>  // for FunctionType (ptr only), PointerType
10#include <llvm/IR/Function.h>      // for Function, Function::arg_iterator
11#include <llvm/IR/Module.h>
12#include <IR_Gen/idisa_builder.h>
13namespace llvm { class Module; }
14namespace llvm { class Type; }
15
// Name fragments appended to a kernel's name to form the symbols of its
// generated entry points (see addKernelDeclarations and the getters below).
namespace {

constexpr const char * INIT_SUFFIX = "_Init";

constexpr const char * DO_SEGMENT_SUFFIX = "_DoSegment";

constexpr const char * ACCUMULATOR_INFIX = "_get_";

}
21
22using namespace llvm;
23
24ProcessingRate FixedRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems, std::string && referenceStreamSet) {
25    return ProcessingRate(ProcessingRate::ProcessingRateKind::Fixed, strmItemsPer, perPrincipalInputItems, std::move(referenceStreamSet));
26}
27
28ProcessingRate MaxRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems, std::string && referenceStreamSet) {
29    return ProcessingRate(ProcessingRate::ProcessingRateKind::Max, strmItemsPer, perPrincipalInputItems, std::move(referenceStreamSet));
30}
31
32ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string && referenceStreamSet) {
33    return ProcessingRate(ProcessingRate::ProcessingRateKind::RoundUp, itemMultiple, itemMultiple, std::move(referenceStreamSet));
34}
35
36ProcessingRate Add1(std::string && referenceStreamSet) {
37    return ProcessingRate(ProcessingRate::ProcessingRateKind::Add1, 0, 0, std::move(referenceStreamSet));
38}
39
40ProcessingRate UnknownRate() {
41    return ProcessingRate(ProcessingRate::ProcessingRateKind::Unknown, 0, 0, "");
42}
43
44Value * ProcessingRate::CreateRatioCalculation(IDISA::IDISA_Builder * b, Value * principalInputItems, Value * doFinal) const {
45    if (mKind == ProcessingRate::ProcessingRateKind::Fixed || mKind == ProcessingRate::ProcessingRateKind::Max) {
46        if (mRatioNumerator == 1) {
47            return principalInputItems;
48        }
49        Type * const T = principalInputItems->getType();
50        Constant * const numerator = ConstantInt::get(T, mRatioNumerator);
51        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
52        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
53        Value * strmItems = b->CreateMul(principalInputItems, numerator);
54        return b->CreateUDiv(b->CreateAdd(denominatorLess1, strmItems), denominator);
55    }
56    if (mKind == ProcessingRate::ProcessingRateKind::RoundUp) {
57        Type * const T = principalInputItems->getType();
58        Constant * const denominator = ConstantInt::get(T, mRatioDenominator);
59        Constant * const denominatorLess1 = ConstantInt::get(T, mRatioDenominator - 1);
60        return b->CreateMul(b->CreateUDiv(b->CreateAdd(principalInputItems, denominatorLess1), denominator), denominator);
61    }
62    if (mKind == ProcessingRate::ProcessingRateKind::Add1) {
63        if (doFinal) {
64            Type * const T = principalInputItems->getType();
65            principalInputItems = b->CreateAdd(principalInputItems, b->CreateZExt(doFinal, T));
66        }
67        return principalInputItems;
68    }
69    return nullptr;
70}
71
72void KernelInterface::addKernelDeclarations(Module * client) {
73    Module * saveModule = iBuilder->getModule();
74    auto savePoint = iBuilder->saveIP();
75    iBuilder->setModule(client);
76    if (mKernelStateType == nullptr) {
77        throw std::runtime_error("Kernel interface " + getName() + " not yet finalized.");
78    }
79    PointerType * selfType = PointerType::getUnqual(mKernelStateType);
80
81    // Create the initialization function prototype
82    std::vector<Type *> initParameters = {selfType};
83    for (auto binding : mScalarInputs) {
84        initParameters.push_back(binding.type);
85    }
86    FunctionType * initType = FunctionType::get(iBuilder->getVoidTy(), initParameters, false);
87    Function * init = Function::Create(initType, GlobalValue::ExternalLinkage, getName() + INIT_SUFFIX, client);
88    init->setCallingConv(CallingConv::C);
89    init->setDoesNotThrow();
90    auto args = init->arg_begin();
91    args->setName("self");
92    for (auto binding : mScalarInputs) {
93        (++args)->setName(binding.name);
94    }
95
96    /// INVESTIGATE: should we explicitly mark whether to track a kernel output's consumed amount? It would have
97    /// to be done at the binding level using the current architecture. It would reduce the number of arguments
98    /// passed between kernels.
99
100    // Create the doSegment function prototype.   
101    IntegerType * const sizeTy = iBuilder->getSizeTy();
102
103    std::vector<Type *> params = {selfType, iBuilder->getInt1Ty()};
104    params.insert(params.end(), mStreamSetInputs.size() + mStreamSetOutputs.size(), sizeTy);
105
106    Type * retType = nullptr;
107    if (mStreamSetInputs.empty()) {
108        retType = iBuilder->getVoidTy();
109    } else {
110        retType = ArrayType::get(sizeTy, mStreamSetInputs.size());
111    }
112
113    FunctionType * const doSegmentType = FunctionType::get(retType, params, false);
114    Function * doSegment = Function::Create(doSegmentType, GlobalValue::ExternalLinkage, getName() + DO_SEGMENT_SUFFIX, client);
115    doSegment->setCallingConv(CallingConv::C);
116    doSegment->setDoesNotThrow();
117    doSegment->setDoesNotCapture(1); // for self parameter only.
118    args = doSegment->arg_begin();
119    args->setName("self");
120    (++args)->setName("doFinal");
121    for (const Binding & input : mStreamSetInputs) {
122        (++args)->setName(input.name + "_availableItems");
123    }
124    for (const Binding & output : mStreamSetOutputs) {
125        (++args)->setName(output.name + "_consumedItems");
126    }
127
128    /// INVESTIGATE: replace the accumulator methods with a single Exit method that handles any clean up and returns
129    /// a struct containing all scalar outputs?
130
131    // Create the accumulator get function prototypes
132    for (const auto & binding : mScalarOutputs) {
133        FunctionType * accumFnType = FunctionType::get(binding.type, {selfType}, false);
134        Function * accumFn = Function::Create(accumFnType, GlobalValue::ExternalLinkage, getName() + ACCUMULATOR_INFIX + binding.name, client);
135        accumFn->setCallingConv(CallingConv::C);
136        accumFn->setDoesNotThrow();
137        accumFn->setDoesNotCapture(1);
138        auto args = accumFn->arg_begin();
139        args->setName("self");
140    }
141
142    iBuilder->setModule(saveModule);
143    iBuilder->restoreIP(savePoint);
144}
145
146void KernelInterface::setInitialArguments(std::vector<Value *> args) {
147    mInitialArguments = args;
148}
149
150llvm::Function * KernelInterface::getAccumulatorFunction(const std::string & accumName) const {
151    const auto name = getName() + ACCUMULATOR_INFIX + accumName;
152    Function * f = iBuilder->getModule()->getFunction(name);
153    if (LLVM_UNLIKELY(f == nullptr)) {
154        llvm::report_fatal_error("Cannot find " + name);
155    }
156    return f;
157}
158
159Function * KernelInterface::getInitFunction() const {
160    const auto name = getName() + INIT_SUFFIX;
161    Function * f = iBuilder->getModule()->getFunction(name);
162    if (LLVM_UNLIKELY(f == nullptr)) {
163        llvm::report_fatal_error("Cannot find " + name);
164    }
165    return f;
166}
167
168Function * KernelInterface::getDoSegmentFunction() const {
169    const auto name = getName() + DO_SEGMENT_SUFFIX;
170    Function * f = iBuilder->getModule()->getFunction(name);
171    if (LLVM_UNLIKELY(f == nullptr)) {
172        llvm::report_fatal_error("Cannot find " + name);
173    }
174    return f;
175}
Note: See TracBrowser for help on using the repository browser.