source: icGREP/icgrep-devel/icgrep/kernels/pipeline_builder.h @ 6184

Last change on this file since 6184 was 6184, checked in by nmedfort, 7 months ago

Initial version of PipelineKernel + revised StreamSet model.

File size: 3.0 KB
Line 
1#ifndef PIPELINE_BUILDER_H
2#define PIPELINE_BUILDER_H
3
#include <cassert>
#include <initializer_list>
#include <string>
#include <type_traits>
#include <utility>

#include <kernels/pipeline_kernel.h>
5
6class BaseDriver;
7
8namespace kernel {
9
10class PipelineBuilder {
11public:
12
13    using Kernels = PipelineKernel::Kernels;
14    using CallBinding = PipelineKernel::CallBinding;
15    using CallBindings = PipelineKernel::CallBindings;
16
17    template<typename KernelType, typename... Args>
18    Kernel * CreateKernelCall(Args &&... args) {
19        return initializeKernel(new KernelType(mDriver.getBuilder(), std::forward<Args>(args) ...));
20    }
21
22    StreamSet * CreateStreamSet(const unsigned NumElements = 1, const unsigned FieldWidth = 1) {
23        return mDriver.CreateStreamSet(NumElements, FieldWidth);
24    }
25
26    Scalar * CreateConstant(llvm::Constant * value) {
27        return mDriver.CreateConstant(value);
28    }
29
30    template <typename ExternalFunctionType>
31    void CreateCall(llvm::StringRef name, ExternalFunctionType & functionPtr, std::initializer_list<Scalar *> args) {
32        llvm::FunctionType * const type = FunctionTypeBuilder<ExternalFunctionType>::get(mDriver.getContext());
33        assert ("FunctionTypeBuilder did not resolve a function type." && type);
34        assert ("Function was not provided the correct number of args" && type->getNumParams() == args.size());
35        // Since the pipeline kernel module has not been made yet, just record the function info and its arguments.
36        mCallBindings.emplace_back(name, type, reinterpret_cast<void *>(&functionPtr), std::move(args));
37    }
38
39    void * compile();
40
41    void setNumOfThreads(const unsigned threads) {
42        mNumOfThreads = threads;
43    }
44
45    Scalar * getInputScalar(const unsigned i) {
46        return llvm::cast<Scalar>(mInputScalars[i].getRelationship());
47    }
48
49    Scalar * getInputScalar(const std::string & name);
50
51    void setInputScalar(const std::string & name, Scalar * value);
52
53    Scalar * getOutputScalar(const unsigned i) {
54        return llvm::cast<Scalar>(mOutputScalars[i].getRelationship());
55    }
56
57    Scalar * getOutputScalar(const std::string & name);
58
59    void setOutputScalar(const std::string & name, Scalar * value);
60
61    PipelineBuilder(BaseDriver & driver,
62                    Bindings && stream_inputs, Bindings && stream_outputs,
63                    Bindings && scalar_inputs, Bindings && scalar_outputs);
64
65protected:
66
67    PipelineKernel * makePipelineKernel();
68
69    Kernel * initializeKernel(Kernel * const kernel);
70
71    void addInputScalar(llvm::Type * type, std::string name);
72
73    llvm::Function * addOrDeclareMainFunction(PipelineKernel * const pk);
74
75private:
76
77    BaseDriver &       mDriver;
78
79    // eventual pipeline configuration
80    unsigned                        mNumOfThreads;
81    Bindings                        mInputStreamSets;
82    Bindings                        mOutputStreamSets;
83    Bindings                        mInputScalars;
84    Bindings                        mOutputScalars;
85    Bindings                        mInternalScalars;
86    Kernels                         mKernels;
87    CallBindings                    mCallBindings;
88
89};
90
91}
92
93#endif // PIPELINE_BUILDER_H
Note: See TracBrowser for help on using the repository browser.