source: icGREP/icgrep-devel/icgrep/kernels/interface.h @ 5755

Last change on this file since 5755 was 5755, checked in by nmedfort, 16 months ago

Bug fixes and simplified MultiBlockKernel? logic

File size: 6.2 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#ifndef KERNEL_INTERFACE_H
7#define KERNEL_INTERFACE_H
8
9#include <kernels/processing_rate.h>
10#include <kernels/attributes.h>
11#include <memory>
12#include <string>
13#include <vector>
14
15namespace IDISA { class IDISA_Builder; }
16namespace kernel { class Kernel; }
17namespace kernel { class KernelBuilder; }
18namespace llvm { class CallInst; }
19namespace llvm { class Function; }
20namespace llvm { class Value; }
21namespace llvm { class Module; }
22namespace llvm { class StructType; }
23namespace llvm { class Type; }
24
25namespace kernel {
26
27struct Binding : public AttributeSet {
28
29    Binding(llvm::Type * type, const std::string & name, ProcessingRate r = FixedRate(1))
30    : AttributeSet()
31    , mType(type), mName(name), mRate(std::move(r)) { }
32
33
34    Binding(llvm::Type * type, const std::string & name, ProcessingRate r, Attribute && attribute)
35    : AttributeSet({std::move(attribute)})
36    , mType(type), mName(name), mRate(std::move(r)) { }
37
38
39    Binding(llvm::Type * type, const std::string & name, ProcessingRate r, std::initializer_list<Attribute> attributes)
40    : AttributeSet(attributes)
41    , mType(type), mName(name), mRate(std::move(r)) { }
42
43    llvm::Type * getType() const {
44        return mType;
45    }
46
47    const std::string & getName() const {
48        return mName;
49    }
50
51    const ProcessingRate & getRate() const {
52        return mRate;
53    }
54
55    ProcessingRate & getRate() {
56        return mRate;
57    }
58
59    bool isPrincipal() const {
60        return hasAttribute(Attribute::KindId::Principal);
61    }
62
63    bool notDeferred() const {
64        return !hasAttribute(Attribute::KindId::Deferred);
65    }
66
67private:
68    llvm::Type * const          mType;
69    const std::string           mName;
70    ProcessingRate              mRate;
71};
72
73using Bindings = std::vector<Binding>;
74
75class KernelInterface : public AttributeSet {
76public:
77    /*
78     
79     This class defines the methods to be used to generate the code 
80     necessary for declaring, allocating, calling and synchronizing
81     kernels.   The methods to be used for constructing kernels are defined
82     within the KernelBuilder class of kernel.h
83     
84     */
85   
86    const std::string & getName() const {
87        return mKernelName;
88    }
89       
90    virtual bool isCachable() const = 0;
91
92    virtual std::string makeSignature(const std::unique_ptr<kernel::KernelBuilder> & idb) = 0;
93
94    const std::vector<Binding> & getStreamInputs() const {
95        return mStreamSetInputs;
96    }
97
98    const Binding & getStreamInput(const unsigned i) const {
99        assert (i < getNumOfStreamInputs());
100        return mStreamSetInputs[i];
101    }
102
103    unsigned getNumOfStreamInputs() const {
104        return mStreamSetInputs.size();
105    }
106
107    const std::vector<Binding> & getStreamOutputs() const {
108        return mStreamSetOutputs;
109    }
110
111    unsigned getNumOfStreamOutputs() const {
112        return mStreamSetOutputs.size();
113    }
114
115    const Binding & getStreamOutput(const unsigned i) const {
116        assert (i < getNumOfStreamOutputs());
117        return mStreamSetOutputs[i];
118    }
119
120    const std::vector<Binding> & getScalarInputs() const {
121        return mScalarInputs;
122    }
123
124    const Binding & getScalarInput(const unsigned i) const {
125        return mScalarInputs[i];
126    }
127
128    const std::vector<Binding> & getScalarOutputs() const {
129        return mScalarOutputs;
130    }
131
132    const Binding & getScalarOutput(const unsigned i) const {
133        return mScalarOutputs[i];
134    }
135
136    // Add ExternalLinkage method declarations for the kernel to a given client module.
137    void addKernelDeclarations(const std::unique_ptr<kernel::KernelBuilder> & idb);
138
139    virtual void linkExternalMethods(const std::unique_ptr<kernel::KernelBuilder> & idb) = 0;
140
141    virtual llvm::Value * createInstance(const std::unique_ptr<kernel::KernelBuilder> & idb) = 0;
142
143    virtual void initializeInstance(const std::unique_ptr<kernel::KernelBuilder> & idb) = 0;
144
145    virtual void finalizeInstance(const std::unique_ptr<kernel::KernelBuilder> & idb) = 0;
146
147    void setInitialArguments(std::vector<llvm::Value *> && args) {
148        mInitialArguments.swap(args);
149    }
150
151    llvm::Value * getInstance() const {
152        return mKernelInstance;
153    }
154
155    void setInstance(llvm::Value * const instance);
156
157    bool hasPrincipalItemCount() const {
158        return mHasPrincipalItemCount;
159    }
160
161    unsigned getLookAhead(const unsigned i) const {
162        return 0;
163    }
164
165    void setLookAhead(const unsigned i, const unsigned lookAheadPositions) {
166
167    }
168
169protected:
170
171    llvm::Function * getInitFunction(llvm::Module * const module) const;
172
173    llvm::Function * getDoSegmentFunction(llvm::Module * const module) const;
174
175    llvm::Function * getTerminateFunction(llvm::Module * const module) const;
176
177    llvm::CallInst * makeDoSegmentCall(KernelBuilder & idb, const std::vector<llvm::Value *> & args) const;
178
179    KernelInterface(const std::string && kernelName,
180                    std::vector<Binding> && stream_inputs,
181                    std::vector<Binding> && stream_outputs,
182                    std::vector<Binding> && scalar_inputs,
183                    std::vector<Binding> && scalar_outputs,
184                    std::vector<Binding> && internal_scalars)
185    : mKernelInstance(nullptr)
186    , mModule(nullptr)
187    , mKernelStateType(nullptr)
188    , mHasPrincipalItemCount(false)
189    , mKernelName(kernelName)
190    , mStreamSetInputs(stream_inputs)
191    , mStreamSetOutputs(stream_outputs)
192    , mScalarInputs(scalar_inputs)
193    , mScalarOutputs(scalar_outputs)
194    , mInternalScalars(internal_scalars) {
195
196    }
197   
198protected:
199
200    llvm::Value *                           mKernelInstance;
201    llvm::Module *                          mModule;
202    llvm::StructType *                      mKernelStateType;
203    bool                                    mHasPrincipalItemCount;
204    const std::string                       mKernelName;
205    std::vector<llvm::Value *>              mInitialArguments;
206    std::vector<Binding>                    mStreamSetInputs;
207    std::vector<Binding>                    mStreamSetOutputs;
208    std::vector<Binding>                    mScalarInputs;
209    std::vector<Binding>                    mScalarOutputs;
210    std::vector<Binding>                    mInternalScalars;
211};
212
213}
214
215#endif
Note: See TracBrowser for help on using the repository browser.