source: icGREP/icgrep-devel/icgrep/kernels/interface.h @ 5436

Last change on this file since 5436 was 5436, checked in by nmedfort, 2 years ago

Continued refactoring work. PabloKernel is now an abstract base type with a 'generatePabloMethod' hook to generate Pablo code.

File size: 8.7 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#ifndef KERNEL_INTERFACE_H
7#define KERNEL_INTERFACE_H
8
#include <llvm/IR/Constants.h>
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
12
13namespace IDISA { class IDISA_Builder; }
14namespace kernel { class KernelBuilder; }
15
// Processing rate attributes are required for all stream set bindings for a kernel.
// These attributes describe the number of items that are processed or produced as
// a ratio in comparison to the principal input stream set (or the principal output
// stream set if there is no input).
//
// The default ratio is FixedRatio(1), which means that there is one item processed or
// produced for every item of the principal input or output stream.
// FixedRatio(m, n) means that for every group of n items of the principal stream,
// there are m items in the bound stream set (rounding up).
//
// Kernels which produce a variable number of items use MaxRatio(n), for a maximum
// of n items produced or consumed per principal input or output item.  MaxRatio(m, n)
// means there are at most m items for every n items of the principal stream.
//
// RoundUpToMultiple(n) means that the number of items produced is the same as the
// number of input items, rounded up to an exact multiple of n.
//
// Most rate factories also take an optional reference stream set name (empty by
// default); presumably a non-empty name makes that stream set, rather than the
// principal stream, the basis of the ratio (TODO: confirm in the implementation).
33
34struct ProcessingRate  {
35    enum class ProcessingRateKind : uint8_t { FixedRatio, RoundUp, Add1, MaxRatio, Unknown };
36    ProcessingRateKind getKind() const {return mKind;}
37    bool isFixedRatio() const {return mKind == ProcessingRateKind::FixedRatio;}
38    bool isMaxRatio() const {return mKind == ProcessingRateKind::MaxRatio;}
39    bool isExact() const {return (mKind == ProcessingRateKind::FixedRatio)||(mKind == ProcessingRateKind::RoundUp)||(mKind == ProcessingRateKind::Add1) ;}
40    bool isUnknownRate() const { return mKind == ProcessingRateKind::Unknown; }
41    llvm::Value * CreateRatioCalculation(IDISA::IDISA_Builder * const b, llvm::Value * principalInputItems, llvm::Value * doFinal = nullptr) const;
42    llvm::Value * CreateMaxReferenceItemsCalculation(IDISA::IDISA_Builder * const b, llvm::Value * outputItems, llvm::Value * doFinal) const;
43    friend ProcessingRate FixedRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems, std::string && referenceStreamSet);
44    friend ProcessingRate MaxRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems, std::string && referenceStreamSet);
45    friend ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string && referenceStreamSet);
46    friend ProcessingRate Add1(std::string && referenceStreamSet);
47    friend ProcessingRate UnknownRate();
48    uint16_t getRatioNumerator() const { return mRatioNumerator;}
49    uint16_t getRatioDenominator() const { return mRatioDenominator;}
50    const std::string & referenceStreamSet() const { return mReferenceStreamSet;}
51protected:
52    ProcessingRate(ProcessingRateKind k, unsigned numerator, unsigned denominator, std::string && referenceStreamSet)
53    : mKind(k), mRatioNumerator(numerator), mRatioDenominator(denominator), mReferenceStreamSet(referenceStreamSet) {}
54private:
55    const ProcessingRateKind mKind;
56    const uint16_t mRatioNumerator;
57    const uint16_t mRatioDenominator;
58    const std::string mReferenceStreamSet;
59}; 
60
61ProcessingRate FixedRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems = 1, std::string && referenceStreamSet = "");
62ProcessingRate MaxRatio(unsigned strmItemsPer, unsigned perPrincipalInputItems = 1, std::string && referenceStreamSet = "");
63ProcessingRate RoundUpToMultiple(unsigned itemMultiple, std::string &&referenceStreamSet = "");
64ProcessingRate Add1(std::string && referenceStreamSet = "");
65ProcessingRate UnknownRate();
66
67struct Binding {
68    Binding(llvm::Type * type, const std::string & name, ProcessingRate r = FixedRatio(1))
69    : type(type), name(name), rate(r) { }
70    llvm::Type * const        type;
71    const std::string         name;
72    const ProcessingRate      rate;
73};
74
75class KernelInterface {
76public:
77    /*
78     
79     This class defines the methods to be used to generate the code 
80     necessary for declaring, allocating, calling and synchronizing
81     kernels.   The methods to be used for constructing kernels are defined
82     within the KernelBuilder class of kernel.h
83     
84     */
85   
86    const std::string & getName() const { return mKernelName; }
87
88    void setName(std::string newName) { mKernelName = newName; }
89       
90    virtual bool isCachable() const = 0;
91
92    virtual std::string makeSignature() = 0;
93
94    const std::vector<Binding> & getStreamInputs() const {
95        return mStreamSetInputs;
96    }
97
98    const Binding & getStreamInput(const unsigned i) const {
99        return mStreamSetInputs[i];
100    }
101
102    const std::vector<Binding> & getStreamOutputs() const {
103        return mStreamSetOutputs;
104    }
105
106    const Binding & getStreamOutput(const unsigned i) const {
107        return mStreamSetOutputs[i];
108    }
109
110    const std::vector<Binding> & getScalarInputs() const {
111        return mScalarInputs;
112    }
113
114    const Binding & getScalarInput(const unsigned i) const {
115        return mScalarInputs[i];
116    }
117
118    const std::vector<Binding> & getScalarOutputs() const {
119        return mScalarOutputs;
120    }
121
122    const Binding & getScalarOutput(const unsigned i) const {
123        return mScalarOutputs[i];
124    }
125
126    // Add ExternalLinkage method declarations for the kernel to a given client module.
127    void addKernelDeclarations();
128
129    virtual void linkExternalMethods() = 0;
130
131    virtual llvm::Value * createInstance() = 0;
132
133    virtual void initializeInstance() = 0;
134
135    virtual void finalizeInstance() = 0;
136
137    void setInitialArguments(std::vector<llvm::Value *> && args) {
138        mInitialArguments.swap(args);
139    }
140
141    llvm::Value * getInstance() const {
142        return mKernelInstance;
143    }
144
145    unsigned getLookAhead() const {
146        return mLookAheadPositions;
147    }
148
149    void setLookAhead(const unsigned lookAheadPositions) {
150        mLookAheadPositions = lookAheadPositions;
151    }
152
153    kernel::KernelBuilder * getBuilder() const {
154        return iBuilder;
155    }
156
157    void setBuilder(const std::unique_ptr<kernel::KernelBuilder> & builder) {
158        iBuilder = builder.get();
159    }
160
161protected:
162
163    virtual llvm::Value * getProducedItemCount(const std::string & name, llvm::Value * doFinal = nullptr) const = 0;
164
165    virtual void setProducedItemCount(const std::string & name, llvm::Value * value) const = 0;
166
167    virtual llvm::Value * getProcessedItemCount(const std::string & name) const = 0;
168
169    virtual void setProcessedItemCount(const std::string & name, llvm::Value * value) const = 0;
170
171    virtual llvm::Value * getConsumedItemCount(const std::string & name) const = 0;
172
173    virtual void setConsumedItemCount(const std::string & name, llvm::Value * value) const = 0;
174
175    virtual llvm::Value * getTerminationSignal() const = 0;
176
177    virtual void setTerminationSignal() const = 0;
178
179    llvm::Function * getInitFunction(llvm::Module * const module) const;
180
181    llvm::Function * getDoSegmentFunction(llvm::Module * const module) const;
182
183    llvm::Function * getTerminateFunction(llvm::Module * const module) const;
184
185    KernelInterface(std::string kernelName,
186                    std::vector<Binding> && stream_inputs,
187                    std::vector<Binding> && stream_outputs,
188                    std::vector<Binding> && scalar_inputs,
189                    std::vector<Binding> && scalar_outputs,
190                    std::vector<Binding> && internal_scalars)
191    : iBuilder(nullptr)
192    , mModule(nullptr)
193    , mKernelInstance(nullptr)
194    , mKernelStateType(nullptr)
195    , mLookAheadPositions(0)
196    , mKernelName(kernelName)
197    , mStreamSetInputs(stream_inputs)
198    , mStreamSetOutputs(stream_outputs)
199    , mScalarInputs(scalar_inputs)
200    , mScalarOutputs(scalar_outputs)
201    , mInternalScalars(internal_scalars)
202    {
203
204    }
205   
206    void setInstance(llvm::Value * const instance) {
207        assert ("kernel instance cannot be null!" && instance);
208        assert ("kernel instance must point to a valid kernel state type!" && (instance->getType()->getPointerElementType() == mKernelStateType));
209        mKernelInstance = instance;
210    }
211
212protected:
213   
214    kernel::KernelBuilder *                 iBuilder;
215    llvm::Module *                          mModule;
216
217    llvm::Value *                           mKernelInstance;
218    llvm::StructType *                      mKernelStateType;
219    unsigned                                mLookAheadPositions;
220    std::string                             mKernelName;
221    std::vector<llvm::Value *>              mInitialArguments;
222    std::vector<Binding>                    mStreamSetInputs;
223    std::vector<Binding>                    mStreamSetOutputs;
224    std::vector<Binding>                    mScalarInputs;
225    std::vector<Binding>                    mScalarOutputs;
226    std::vector<Binding>                    mInternalScalars;
227};
228
229#endif
Note: See TracBrowser for help on using the repository browser.