source: icGREP/icgrep-devel/icgrep/kernels/interface.h @ 5435

Last change on this file since 5435 was 5435, checked in by nmedfort, 2 years ago

Continued refactoring work.

File size: 8.7 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#ifndef KERNEL_INTERFACE_H
7#define KERNEL_INTERFACE_H
8
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include <llvm/IR/Constants.h>

namespace IDISA { class IDISA_Builder; }
14
15// Processing rate attributes are required for all stream set bindings for a kernel.
// These attributes describe the number of items that are processed or produced as
// a ratio in comparison to the principal input stream set (or the principal output
// stream set if there is no input).
19//
// The default ratio is FixedRatio(1), which means that one item is processed or
// produced for every item of the principal input or output stream.
22// FixedRatio(m, n) means that for every group of n items of the principal stream,
23// there are m items in the output stream (rounding up).
24//
25// Kernels which produce a variable number of items use MaxRatio(n), for a maximum
26// of n items produced or consumed per principal input or output item.  MaxRatio(m, n)
27// means there are at most m items for every n items of the principal stream.
28//
// RoundUpToMultiple(n) means that the number of items produced is the same as the
// number of input items, rounded up to an exact multiple of n.
31//
32
33struct ProcessingRate  {
34    enum class ProcessingRateKind : uint8_t { FixedRatio, RoundUp, Add1, MaxRatio, Unknown };
35    ProcessingRateKind getKind() const {return mKind;}
36    bool isFixedRatio() const {return mKind == ProcessingRateKind::FixedRatio;}
37    bool isMaxRatio() const {return mKind == ProcessingRateKind::MaxRatio;}
38    bool isExact() const {return (mKind == ProcessingRateKind::FixedRatio)||(mKind == ProcessingRateKind::RoundUp)||(mKind == ProcessingRateKind::Add1) ;}
39    bool isUnknownRate() const { return mKind == ProcessingRateKind::Unknown; }
40    llvm::Value * CreateRatioCalculation(IDISA::IDISA_Builder * const b, llvm::Value * principalInputItems, llvm::Value * doFinal = nullptr) const;
41    llvm::Value * CreateMaxReferenceItemsCalculation(IDISA::IDISA_Builder * const b, llvm::Value * outputItems, llvm::Value * doFinal) const;
42    friend ProcessingRate FixedRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems, std::string && referenceStreamSet);
43    friend ProcessingRate MaxRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems, std::string && referenceStreamSet);
44    friend ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string && referenceStreamSet);
45    friend ProcessingRate Add1(std::string && referenceStreamSet);
46    friend ProcessingRate UnknownRate();
47    uint16_t getRatioNumerator() const { return mRatioNumerator;}
48    uint16_t getRatioDenominator() const { return mRatioDenominator;}
49    const std::string & referenceStreamSet() const { return mReferenceStreamSet;}
50protected:
51    ProcessingRate(ProcessingRateKind k, unsigned numerator, unsigned denominator, std::string && referenceStreamSet)
52    : mKind(k), mRatioNumerator(numerator), mRatioDenominator(denominator), mReferenceStreamSet(referenceStreamSet) {}
53private:
54    const ProcessingRateKind mKind;
55    const uint16_t mRatioNumerator;
56    const uint16_t mRatioDenominator;
57    const std::string mReferenceStreamSet;
58}; 
59
60ProcessingRate FixedRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems = 1, std::string && referenceStreamSet = "");
61ProcessingRate MaxRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems = 1, std::string && referenceStreamSet = "");
62ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string &&referenceStreamSet = "");
63ProcessingRate Add1(std::string && referenceStreamSet = "");
64ProcessingRate UnknownRate();
65
66struct Binding {
67    Binding(llvm::Type * type, const std::string & name, ProcessingRate r = FixedRatio(1))
68    : type(type), name(name), rate(r) { }
69    llvm::Type * const        type;
70    const std::string         name;
71    const ProcessingRate      rate;
72};
73
74class KernelInterface {
75public:
76    /*
77     
78     This class defines the methods to be used to generate the code 
79     necessary for declaring, allocating, calling and synchronizing
80     kernels.   The methods to be used for constructing kernels are defined
81     within the KernelBuilder class of kernel.h
82     
83     */
84   
85    const std::string & getName() const { return mKernelName; }
86
87    void setName(std::string newName) { mKernelName = newName; }
88       
89    virtual bool isCachable() const = 0;
90
91    virtual std::string makeSignature() = 0;
92
93    const std::vector<Binding> & getStreamInputs() const {
94        return mStreamSetInputs;
95    }
96
97    const Binding & getStreamInput(const unsigned i) const {
98        return mStreamSetInputs[i];
99    }
100
101    const std::vector<Binding> & getStreamOutputs() const {
102        return mStreamSetOutputs;
103    }
104
105    const Binding & getStreamOutput(const unsigned i) const {
106        return mStreamSetOutputs[i];
107    }
108
109    const std::vector<Binding> & getScalarInputs() const {
110        return mScalarInputs;
111    }
112
113    const Binding & getScalarInput(const unsigned i) const {
114        return mScalarInputs[i];
115    }
116
117    const std::vector<Binding> & getScalarOutputs() const {
118        return mScalarOutputs;
119    }
120
121    const Binding & getScalarOutput(const unsigned i) const {
122        return mScalarOutputs[i];
123    }
124
125    // Add ExternalLinkage method declarations for the kernel to a given client module.
126    void addKernelDeclarations();
127
128    virtual void linkExternalMethods() = 0;
129
130    virtual llvm::Value * createInstance() = 0;
131
132    virtual void initializeInstance() = 0;
133
134    virtual void finalizeInstance() = 0;
135
136    void setInitialArguments(std::vector<llvm::Value *> && args) {
137        mInitialArguments.swap(args);
138    }
139    llvm::Value * getInstance() const {
140        return mKernelInstance;
141    }
142
143    unsigned getLookAhead() const {
144        return mLookAheadPositions;
145    }
146
147    void setLookAhead(const unsigned lookAheadPositions) {
148        mLookAheadPositions = lookAheadPositions;
149    }
150
151    IDISA::IDISA_Builder * getBuilder() const {
152        return iBuilder;
153    }
154
155    void setBuilder(IDISA::IDISA_Builder * const builder) {
156        iBuilder = builder;
157    }
158
159protected:
160
161    virtual llvm::Value * getProducedItemCount(const std::string & name, llvm::Value * doFinal = nullptr) const = 0;
162
163    virtual void setProducedItemCount(const std::string & name, llvm::Value * value) const = 0;
164
165    virtual llvm::Value * getProcessedItemCount(const std::string & name) const = 0;
166
167    virtual void setProcessedItemCount(const std::string & name, llvm::Value * value) const = 0;
168
169    virtual llvm::Value * getConsumedItemCount(const std::string & name) const = 0;
170
171    virtual void setConsumedItemCount(const std::string & name, llvm::Value * value) const = 0;
172
173    virtual llvm::Value * getTerminationSignal() const = 0;
174
175    virtual void setTerminationSignal() const = 0;
176
177    llvm::Function * getInitFunction(llvm::Module * const module) const;
178
179    llvm::Function * getDoSegmentFunction(llvm::Module * const module) const;
180
181    llvm::Function * getTerminateFunction(llvm::Module * const module) const;
182
183    KernelInterface(std::string kernelName,
184                    std::vector<Binding> && stream_inputs,
185                    std::vector<Binding> && stream_outputs,
186                    std::vector<Binding> && scalar_inputs,
187                    std::vector<Binding> && scalar_outputs,
188                    std::vector<Binding> && internal_scalars)
189    : iBuilder(nullptr)
190    , mModule(nullptr)
191    , mKernelInstance(nullptr)
192    , mKernelStateType(nullptr)
193    , mLookAheadPositions(0)
194    , mKernelName(kernelName)
195    , mStreamSetInputs(stream_inputs)
196    , mStreamSetOutputs(stream_outputs)
197    , mScalarInputs(scalar_inputs)
198    , mScalarOutputs(scalar_outputs)
199    , mInternalScalars(internal_scalars)
200    {
201
202    }
203   
204    void setInstance(llvm::Value * const instance) {
205        assert ("kernel instance cannot be null!" && instance);
206        assert ("kernel instance must point to a valid kernel state type!" && (instance->getType()->getPointerElementType() == mKernelStateType));
207        mKernelInstance = instance;
208    }
209
210protected:
211   
212    IDISA::IDISA_Builder *                  iBuilder;
213    llvm::Module *                          mModule;
214
215    llvm::Value *                           mKernelInstance;
216    llvm::StructType *                      mKernelStateType;
217    unsigned                                mLookAheadPositions;
218    std::string                             mKernelName;
219    std::vector<llvm::Value *>              mInitialArguments;
220    std::vector<Binding>                    mStreamSetInputs;
221    std::vector<Binding>                    mStreamSetOutputs;
222    std::vector<Binding>                    mScalarInputs;
223    std::vector<Binding>                    mScalarOutputs;
224    std::vector<Binding>                    mInternalScalars;
225};
226
227#endif
Note: See TracBrowser for help on using the repository browser.