source: icGREP/icgrep-devel/icgrep/kernels/pipeline_builder.h

Last change to this file was revision 6288, checked in by cameron 6 months ago

Repeat of prior check-in

File size: 6.1 KB
Line 
1#ifndef PIPELINE_BUILDER_H
2#define PIPELINE_BUILDER_H
3
4#include <kernels/pipeline_kernel.h>
5
6class BaseDriver;
7
8namespace kernel {
9
10class OptimizationBranchBuilder;
11
12class PipelineBuilder {
13    friend class PipelineKernel;
14    friend class OptimizationBranchBuilder;
15public:
16
17    using Kernels = PipelineKernel::Kernels;
18    using CallBinding = PipelineKernel::CallBinding;
19    using CallBindings = PipelineKernel::CallBindings;
20    using NestedBuilders = std::vector<std::shared_ptr<PipelineBuilder>>;
21
22    template<typename KernelType, typename... Args>
23    Kernel * CreateKernelCall(Args &&... args) {
24        return initializeKernel(new KernelType(mDriver.getBuilder(), std::forward<Args>(args) ...));
25    }
26
27    std::shared_ptr<OptimizationBranchBuilder>
28        CreateOptimizationBranch(Relationship * const condition,
29                                 Bindings && stream_inputs = {}, Bindings && stream_outputs = {},
30                                 Bindings && scalar_inputs = {}, Bindings && scalar_outputs = {});
31
32    StreamSet * CreateStreamSet(const unsigned NumElements = 1, const unsigned FieldWidth = 1) {
33        return mDriver.CreateStreamSet(NumElements, FieldWidth);
34    }
35
36    Scalar * CreateConstant(llvm::Constant * value) {
37        return mDriver.CreateConstant(value);
38    }
39
40    template <typename ExternalFunctionType>
41    void CreateCall(llvm::StringRef name, ExternalFunctionType & functionPtr, std::initializer_list<Scalar *> args) {
42        llvm::FunctionType * const type = FunctionTypeBuilder<ExternalFunctionType>::get(mDriver.getContext());
43        assert ("FunctionTypeBuilder did not resolve a function type." && type);
44        assert ("Function was not provided the correct number of args" && type->getNumParams() == args.size());
45        // Since the pipeline kernel module has not been made yet, just record the function info and its arguments.
46        mCallBindings.emplace_back(name, type, reinterpret_cast<void *>(&functionPtr), std::move(args));
47    }
48
49    Scalar * getInputScalar(const unsigned i) {
50        return llvm::cast<Scalar>(mInputScalars[i].getRelationship());
51    }
52
53    Scalar * getInputScalar(const std::string & name);
54
55    void setInputScalar(const std::string & name, Scalar * value);
56
57    Scalar * getOutputScalar(const unsigned i) {
58        return llvm::cast<Scalar>(mOutputScalars[i].getRelationship());
59    }
60
61    Scalar * getOutputScalar(const std::string & name);
62
63    void setOutputScalar(const std::string & name, Scalar * value);
64
65    PipelineBuilder(BaseDriver & driver,
66                    Bindings && stream_inputs, Bindings && stream_outputs,
67                    Bindings && scalar_inputs, Bindings && scalar_outputs,
68                    const unsigned numOfThreads = 1);
69
70    virtual ~PipelineBuilder() {}
71
72protected:
73
74
75    // Internal pipeline constructor uses a zero-length tag struct to prevent
76    // overloading errors. This paramater will be dropped by the compiler.
77    struct Internal {};
78    PipelineBuilder(Internal, BaseDriver & driver,
79                    Bindings stream_inputs, Bindings stream_outputs,
80                    Bindings scalar_inputs, Bindings scalar_outputs,
81                    const unsigned numOfThreads = 1);
82
83    virtual Kernel * makeKernel();
84
85    Kernel * initializeKernel(Kernel * const kernel);
86
87    void addInputScalar(llvm::Type * type, std::string name);
88
89protected:
90
91    BaseDriver &        mDriver;
92    // eventual pipeline configuration
93    unsigned            mNumOfThreads;
94    Bindings            mInputStreamSets;
95    Bindings            mOutputStreamSets;
96    Bindings            mInputScalars;
97    Bindings            mOutputScalars;
98    Bindings            mInternalScalars;
99    Kernels             mKernels;
100    CallBindings        mCallBindings;
101    NestedBuilders      mNestedBuilders;
102};
103
104/** ------------------------------------------------------------------------------------------------------------- *
105 * @brief ProgramBuilder
106 ** ------------------------------------------------------------------------------------------------------------- */
107class ProgramBuilder : public PipelineBuilder {
108    friend class PipelineBuilder;
109public:
110
111    void * compile();
112
113    void setNumOfThreads(const unsigned threads) {
114        mNumOfThreads = threads;
115    }
116
117    ProgramBuilder(BaseDriver & driver,
118                   Bindings && stream_inputs, Bindings && stream_outputs,
119                   Bindings && scalar_inputs, Bindings && scalar_outputs);
120
121private:
122
123};
124
125/** ------------------------------------------------------------------------------------------------------------- *
126 * @brief PipelineBranchBuilder
127 ** ------------------------------------------------------------------------------------------------------------- */
128class OptimizationBranchBuilder final : public PipelineBuilder {
129    friend class PipelineKernel;
130    friend class PipelineBuilder;
131public:
132
133    const std::unique_ptr<PipelineBuilder> & getNonZeroBranch() const {
134        return mNonZeroBranch;
135    }
136
137    const std::unique_ptr<PipelineBuilder> & getAllZeroBranch() const {
138        return mAllZeroBranch;
139    }
140
141    ~OptimizationBranchBuilder();
142
143protected:
144
145    OptimizationBranchBuilder(BaseDriver & driver, Relationship * const condition,
146                              Bindings && stream_inputs, Bindings && stream_outputs,
147                              Bindings && scalar_inputs, Bindings && scalar_outputs);
148
149    Kernel * makeKernel() override;
150
151private:
152    Relationship * const             mCondition;
153    std::unique_ptr<PipelineBuilder> mNonZeroBranch;
154    std::unique_ptr<PipelineBuilder> mAllZeroBranch;
155};
156
157inline std::shared_ptr<OptimizationBranchBuilder> PipelineBuilder::CreateOptimizationBranch (
158        Relationship * const condition,
159        Bindings && stream_inputs, Bindings && stream_outputs,
160        Bindings && scalar_inputs, Bindings && scalar_outputs) {
161    std::shared_ptr<OptimizationBranchBuilder> branch(
162        new OptimizationBranchBuilder(mDriver, condition,
163            std::move(stream_inputs), std::move(stream_outputs),
164            std::move(scalar_inputs), std::move(scalar_outputs)));
165    mNestedBuilders.emplace_back(std::static_pointer_cast<PipelineBuilder>(branch));
166    return branch;
167}
168
169}
170
171#endif // PIPELINE_BUILDER_H
Note: See TracBrowser for help on using the repository browser.