source: icGREP/icgrep-devel/icgrep/kernels/pipeline/pipeline_compiler.hpp @ 6186

Last change on this file since 6186 was 6186, checked in by cameron, 5 months ago

Various clean-ups

File size: 21.2 KB
Line 
1#ifndef PIPELINE_COMPILER_HPP
2#define PIPELINE_COMPILER_HPP
3
4#include <kernels/pipeline_kernel.h>
5#include <kernels/streamset.h>
6#include <kernels/kernel_builder.h>
7#include <toolchain/toolchain.h>
8#include <boost/container/flat_set.hpp>
9#include <boost/container/flat_map.hpp>
10#include <boost/graph/adjacency_list.hpp>
11//#include <boost/graph/topological_sort.hpp>
12#include <boost/math/common_factor_rt.hpp>
13#include <llvm/IR/Module.h>
14#include <llvm/Support/raw_ostream.h>
15#include <llvm/ADT/STLExtras.h>
16#include <queue>
17
18// #define PRINT_DEBUG_MESSAGES
19
20using namespace boost;
21using namespace boost::math;
22using boost::container::flat_set;
23using boost::container::flat_map;
24using namespace llvm;
25
/** ------------------------------------------------------------------------------------------------------------- *
 * @brief floor_log2
 *
 * Returns the index of the most significant set bit of v, i.e. floor(log2(v)). v must be non-zero; this is
 * only checked by an assertion. The constant 31 matches a 32-bit unsigned.
 ** ------------------------------------------------------------------------------------------------------------- */
inline static unsigned floor_log2(const unsigned v) {
    assert ("log2(0) is undefined!" && v != 0);
    const unsigned leadingZeros = static_cast<unsigned>(__builtin_clz(v));
    return 31U - leadingZeros;
}
30
31namespace kernel {
32
// Shorthand aliases for types nested in Kernel, Attribute, ProcessingRate and PipelineKernel, so the rest of
// this header can refer to them without qualification.
using Port = Kernel::Port;
using StreamPort = Kernel::StreamSetPort;
using AttrId = Attribute::KindId;
using RateValue = ProcessingRate::RateValue;
using RateId = ProcessingRate::KindId;
using Scalars = PipelineKernel::Scalars;
using Kernels = PipelineKernel::Kernels;
using CallBinding = PipelineKernel::CallBinding;
41
42#warning TODO: these graphs are similar; look into streamlining their generation.
43
44struct BufferNode { // use boost::variant instead of union? std::variant is c17+
45    Kernel * kernel = nullptr;
46    StreamSetBuffer * buffer = nullptr;
47    RateValue lower;
48    RateValue upper;
49};
50
51struct BufferRateData {   
52
53    RateValue minimum;
54    RateValue maximum;
55    unsigned port;
56
57    BufferRateData(const unsigned port = 0) : port(port) { }
58
59    BufferRateData(const unsigned port, RateValue min, RateValue max)
60    : minimum(std::move(min)), maximum(std::move(max)), port(port) { }
61};
62
// Bidirectional graph (vector-backed vertex and edge storage) whose vertices carry a BufferNode and whose
// edges carry a BufferRateData.
using BufferGraph = adjacency_list<vecS, vecS, bidirectionalS, BufferNode, BufferRateData>; // unsigned>;

// Sorted-vector map from a Relationship pointer to a graph vertex descriptor of the given type.
template <typename vertex_descriptor>
using RelationshipMap = flat_map<const Relationship *, vertex_descriptor>;

// RelationshipMap whose values are BufferGraph vertex descriptors.
using BufferMap = RelationshipMap<BufferGraph::vertex_descriptor>;
69
70struct ConsumerNode {
71    Value * consumed = nullptr;
72    PHINode * phiNode = nullptr;
73};
74
// Bidirectional graph whose vertices carry a ConsumerNode and whose edges carry an unsigned value.
using ConsumerGraph = adjacency_list<vecS, vecS, bidirectionalS, ConsumerNode, unsigned>;

// Sorted-vector map keyed by a (const) StreamSetBuffer pointer.
template <typename Value>
using StreamSetBufferMap = flat_map<const StreamSetBuffer *, Value>;

// Sorted-vector map keyed by a (const) Kernel pointer.
template <typename Value>
using KernelMap = flat_map<const Kernel *, Value>;
82
83struct Channel {
84    Channel() = default;
85    Channel(const RateValue & ratio, const StreamSetBuffer * const buffer = nullptr, const unsigned operand = 0)
86    : ratio(ratio), buffer(buffer), portIndex(operand) {
87
88    }
89    RateValue               ratio;
90    const StreamSetBuffer * buffer;
91    unsigned                portIndex;
92};
93
// Bidirectional graph whose vertices carry a const Kernel pointer and whose edges carry a Channel.
using ChannelGraph = adjacency_list<vecS, vecS, bidirectionalS, const Kernel *, Channel>;

// Bidirectional graph with hash-set edge storage (hash_setS) whose vertices carry an LLVM Value pointer.
using TerminationGraph = adjacency_list<hash_setS, vecS, bidirectionalS, Value *>;

// Bidirectional graph whose vertices carry an LLVM Value pointer and whose edges carry an unsigned value.
using ScalarDependencyGraph = adjacency_list<vecS, vecS, bidirectionalS, Value *, unsigned>;

// Bidirectional graph with no vertex property whose edges carry a RateId.
using PortDependencyGraph = adjacency_list<vecS, vecS, bidirectionalS, no_property, RateId>;
101
// Pair of overflow sizes recorded per buffer: `copyBack` and `facsimile`.
//
// Both fields default to zero so that a default-initialized instance (e.g. `OverflowRequirement r;`) is
// well-defined; previously the defaulted constructor left them indeterminate unless the object happened to
// be value-initialized.
struct OverflowRequirement {
    unsigned copyBack = 0;
    unsigned facsimile = 0;
    OverflowRequirement() = default;
    // copyBack is stored as-is; copyForward is stored in `facsimile`.
    OverflowRequirement(const unsigned copyBack, const unsigned copyForward)
    : copyBack(copyBack), facsimile(copyForward) { }
};
109
// Map from a StreamSetBuffer pointer to its OverflowRequirement.
using OverflowRequirements = StreamSetBufferMap<OverflowRequirement>;

// Directed graph (out-edges only) whose vertices carry an LLVM Value pointer.
using PopCountStreamDependencyGraph = adjacency_list<vecS, vecS, directedS, Value *>;

// Vertex descriptor type of PopCountStreamDependencyGraph.
using PopCountStreamDependencyVertex = PopCountStreamDependencyGraph::vertex_descriptor;
115
116struct PopCountData {
117    unsigned hasConstructedArray = -1;
118    PopCountStreamDependencyVertex vertex = 0;
119    Value * initial = nullptr;
120    Value * baseIndex = nullptr;
121    AllocaInst * individualCountArray = nullptr;
122    AllocaInst * partialSumArray = nullptr;
123    Value * strideCapacity = nullptr;
124    Value * finalPartialSum = nullptr;
125    Value * maximumNumOfStrides = nullptr;
126};
127
// Sorted-vector map from a (StreamSetBuffer pointer, bool) pair to its PopCountData.
// NOTE(review): the bool presumably distinguishes the negated pop-count variant (cf. the `negated`
// parameters of the pop-count member functions) -- confirm against the implementation.
using PopCountDataMap = flat_map<std::pair<const StreamSetBuffer *, bool>, PopCountData>;
129
/** ------------------------------------------------------------------------------------------------------------- *
 * @brief PipelineCompiler
 *
 * Declares the code-generation entry points for a PipelineKernel (the public generate*Method functions), the
 * helper routines they are built from, and the state shared between those helpers.
 *
 * NOTE: the constructor's initializer list builds mBufferGraph, mConsumerGraph, mTerminationGraph and
 * mScalarDependencyGraph by calling member functions of this class. Data members are initialized in
 * declaration order, so reordering the members at the bottom of this class can change which members are
 * already constructed when those functions run.
 ** ------------------------------------------------------------------------------------------------------------- */
class PipelineCompiler {
public:

    // The builder is always passed as a const reference to the unique_ptr that owns it.
    using BuilderRef = const std::unique_ptr<kernel::KernelBuilder> &;

    PipelineCompiler(BuilderRef b, PipelineKernel * const pipelineKernel);

    void addHandlesToPipelineKernel(BuilderRef b);
    void generateInitializeMethod(BuilderRef b);
    void generateSingleThreadKernelMethod(BuilderRef b);
    void generateMultiThreadKernelMethod(BuilderRef b, const unsigned numOfThreads);
    void generateFinalizeMethod(BuilderRef b);
    std::vector<Value *> getFinalOutputScalars(BuilderRef b);

protected:

// main pipeline functions

    void start(BuilderRef b, Value * const initialSegNo);
    void setActiveKernel(BuilderRef b, const unsigned index);
    void synchronize(BuilderRef b);
    void executeKernel(BuilderRef b);
    void end(BuilderRef b, const unsigned step);

// inter-kernel functions

    Value * checkForSufficientInputDataAndOutputSpace(BuilderRef b);
    Value * determineNumOfLinearStrides(BuilderRef b);
    void calculateNonFinalItemCounts(BuilderRef b, Value * const numOfStrides);
    void calculateFinalItemCounts(BuilderRef b);
    void provideAllInputAndOutputSpace(BuilderRef b);
    void writeKernelCall(BuilderRef b);
    void writeCopyBackLogic(BuilderRef b);
    void writeCopyForwardLogic(BuilderRef b);
    void allocateThreadLocalState(BuilderRef b, const Port port, const unsigned i);
    void deallocateThreadLocalState(BuilderRef b, const Port port, const unsigned i);

    void checkIfAllProducingKernelsHaveTerminated(BuilderRef b);
    void zeroFillPartiallyWrittenOutputStreams(BuilderRef b);
    void initializeKernelCallPhis(BuilderRef b);
    void initializeKernelExitPhis(BuilderRef b);
    void storeCopyForwardProducedItemCounts(BuilderRef b);
    void storeCopyBackProducedItemCounts(BuilderRef b);
    void computeMinimumConsumedItemCounts(BuilderRef b);
    void writeFinalConsumedItemCounts(BuilderRef b);
    void readCurrentProducedItemCounts(BuilderRef b);
    void releaseCurrentSegment(BuilderRef b);
    void writeCopyToOverflowLogic(BuilderRef b);
    void checkForSufficientInputData(BuilderRef b, const unsigned index);
    void checkForSufficientOutputSpaceOrExpand(BuilderRef b, const unsigned index);

// intra-kernel functions

    void branchToTargetOrLoopExit(BuilderRef b, Value * const cond, BasicBlock * const target);
    void expandOutputBuffer(BuilderRef b, Value * const hasEnough, const unsigned index, BasicBlock * const target);
    Value * getInputStrideLength(BuilderRef b, const unsigned index);
    Value * getOutputStrideLength(BuilderRef b, const unsigned index);
    Value * getInitialStrideLength(BuilderRef b, const Binding & binding);
    Value * calculateNumOfLinearItems(BuilderRef b, const Binding & binding, Value * const numOfStrides);
    Value * getAccessibleInputItems(BuilderRef b, const unsigned index);
    Value * getNumOfAccessibleStrides(BuilderRef b, const unsigned index);
    Value * getNumOfWritableStrides(BuilderRef b, const unsigned index);
    Value * getWritableOutputItems(BuilderRef b, const unsigned index);
    Value * calculateBufferExpansionSize(BuilderRef b, const unsigned index);
    Value * addLookahead(BuilderRef b, const unsigned index, Value * itemCount) const;
    Value * subtractLookahead(BuilderRef b, const unsigned index, Value * itemCount) const;
    Value * getFullyProcessedItemCount(BuilderRef b, const Binding & input) const;
    Value * getTotalItemCount(BuilderRef b, const StreamSetBuffer * buffer) const;
    Value * isTerminated(BuilderRef b) const;
    void setTerminated(BuilderRef b);

// pop-count functions

    void initializePopCounts(BuilderRef b);
    PopCountStreamDependencyGraph makePopCountStreamDependencyGraph(BuilderRef b);
    void addPopCountStreamDependency(BuilderRef b, const unsigned index, const Binding & binding, PopCountStreamDependencyGraph & G);

    Value * getInitialNumOfLinearPopCountItems(BuilderRef b, const ProcessingRate & rate);
    Value * getMaximumNumOfPopCountStrides(BuilderRef b, const ProcessingRate & rate);
    Value * getNumOfLinearPopCountItems(BuilderRef b, const ProcessingRate & rate, Value * const numOfStrides);
    Value * getPopCountArray(BuilderRef b, const unsigned index);
    Value * getNegatedPopCountArray(BuilderRef b, const unsigned index);
    void allocateLocalPopCountArray(BuilderRef b, const ProcessingRate & rate);
    void deallocateLocalPopCountArray(BuilderRef b, const ProcessingRate & rate);
    void storePopCountSourceItemCount(BuilderRef b, const Port port, const unsigned index, Value * const offset, Value * const processable);

    PopCountData & findOrAddPopCountData(BuilderRef b, const ProcessingRate & rate);
    PopCountData & findOrAddPopCountData(BuilderRef b, const unsigned index, const bool negated);
    Value * getInitialNumOfLinearPopCountItems(BuilderRef b, PopCountData & pc, const unsigned index, const bool negated);
    PopCountData & makePopCountArray(BuilderRef b, const ProcessingRate & rate);
    PopCountData & makePopCountArray(BuilderRef b, const unsigned index, const bool negated);
    Value * getMinimumNumOfSourceItems(BuilderRef b, const PopCountData & pc);
    Value * getSourceMarkers(BuilderRef b, PopCountData & pc, const unsigned index, Value * const offset) const;

// consumer recording

    ConsumerGraph makeConsumerGraph() const;
    void createConsumedPhiNodes(BuilderRef b);
    void initializeConsumedItemCount(BuilderRef b, const unsigned bufferVertex, Value * const produced);
    void setConsumedItemCount(BuilderRef b, const unsigned bufferVertex, Value * const consumed) const;
    Value * getConsumedItemCount(BuilderRef b, const unsigned index) const;

// buffer analysis/management functions

    // Called from the constructor's initializer list to produce mBufferGraph.
    BufferGraph makeBufferGraph(BuilderRef b);
    void enumerateBufferProducerBindings(const unsigned producer, const Bindings & bindings, BufferGraph & G, BufferMap & M);
    void enumerateBufferConsumerBindings(const unsigned consumer, const Bindings & bindings, BufferGraph & G, BufferMap & M);
    BufferRateData getBufferRateData(const unsigned index, const unsigned port, bool input);

    void constructBuffers(BuilderRef b);
    void loadBufferHandles(BuilderRef b);
    void releaseBuffers(BuilderRef b);
    LLVM_READNONE bool requiresCopyBack(const StreamSetBuffer * const buffer) const;
    LLVM_READNONE bool requiresFacsimile(const StreamSetBuffer * const buffer) const;
    LLVM_READNONE unsigned getCopyBack(const StreamSetBuffer * const buffer) const;
    LLVM_READNONE unsigned getFacsimile(const StreamSetBuffer * const buffer) const;
    LLVM_READNONE bool isPipelineIO(const StreamSetBuffer * const buffer) const;

    Value * getLogicalInputBaseAddress(BuilderRef b, const unsigned index) const;
    Value * getLogicalOutputBaseAddress(BuilderRef b, const unsigned index) const;
    Value * calculateLogicalBaseAddress(BuilderRef b, const Binding & binding, const StreamSetBuffer * const buffer, Value * const itemCount) const;

// cycle counter functions

    void startOptionalCycleCounter(BuilderRef b);
    void updateOptionalCycleCounter(BuilderRef b);
    void printOptionalCycleCounter(BuilderRef b);

// analysis functions

    // The three functions below are called from the constructor's initializer list to produce
    // mTerminationGraph and mScalarDependencyGraph; makePortDependencyGraph is not used there.
    PortDependencyGraph makePortDependencyGraph() const;
    TerminationGraph makeTerminationGraph() const;
    ScalarDependencyGraph makeScalarDependencyGraph() const;

// misc. functions

    Value * getFunctionFromKernelState(BuilderRef b, Type * const type, const std::string & suffix) const;

    Value * getInitializationFunction(BuilderRef b) const;

    Value * getDoSegmentFunction(BuilderRef b) const;

    Value * getFinalizeFunction(BuilderRef b) const;

    // Forwards to PipelineKernel::makeKernelName for the kernel at position kernelIndex of mPipeline.
    std::string makeKernelName(const unsigned kernelIndex) const;

    // Forwards to PipelineKernel::makeBufferName for the kernel at position kernelIndex of mPipeline.
    std::string makeBufferName(const unsigned kernelIndex, const Binding & binding) const;

    // Both lookups scan the edges of vertex mKernelIndex in mBufferGraph for an edge whose port equals index.
    StreamSetBuffer * getInputBuffer(const unsigned index) const;

    StreamSetBuffer * getOutputBuffer(const unsigned index) const;

    // Returns the i-th stream set binding of the current kernel (mKernel): the input binding when port is
    // Port::Input, the output binding for any other port value.
    const Binding & getBinding(const Port port, const unsigned i) const {
        if (port == Port::Input) {
            return mKernel->getInputStreamSetBinding(i);
        } else {
            return mKernel->getOutputStreamSetBinding(i);
        }
    }

    void printBufferGraph(const BufferGraph & G, raw_ostream & out);

    LLVM_READNONE const Binding & getInputBinding(const Kernel * const producer, const unsigned index) const;

    LLVM_READNONE const Binding & getOutputBinding(const Kernel * const consumer, const unsigned index) const;

    void writeOutputScalars(BuilderRef b, const unsigned u, std::vector<Value *> & args);

    void itemCountSanityCheck(BuilderRef b, const Binding & binding, const std::string & presentLabel, const std::string & pastLabel,
                              Value * const itemCount, Value * const expected, Value * const terminated);

// data members (the access level is already protected here; the specifier is simply repeated)
protected:

    // Both are set by the constructor: the pipeline kernel being compiled and a reference to its kernel list.
    PipelineKernel * const                      mPipelineKernel;
    const Kernels &                             mPipeline;

    OwnedStreamSetBuffers                       mOwnedBuffers;
    // Position of the current kernel and the kernel itself; getInputBuffer / getOutputBuffer / getBinding
    // read these. NOTE(review): presumably updated by setActiveKernel -- confirm in the implementation.
    unsigned                                    mKernelIndex = 0;
    const Kernel *                              mKernel = nullptr;

    // pipeline state
    PHINode *                                   mTerminatedPhi = nullptr;
    PHINode *                                   mSegNo = nullptr;
    BasicBlock *                                mPipelineLoop = nullptr;
    BasicBlock *                                mKernelEntry = nullptr;
    BasicBlock *                                mKernelLoopEntry = nullptr;
    BasicBlock *                                mKernelLoopCall = nullptr;
    BasicBlock *                                mKernelLoopExit = nullptr;
    BasicBlock *                                mKernelLoopExitPhiCatch = nullptr;
    BasicBlock *                                mKernelExit = nullptr;
    BasicBlock *                                mPipelineEnd = nullptr;

    // pipeline state (values keyed by stream set buffer, plus the output scalars)
    StreamSetBufferMap<Value *>                 mInputConsumedItemCountPhi;
    StreamSetBufferMap<Value *>                 mTotalItemCount;
    StreamSetBufferMap<Value *>                 mConsumedItemCount;
    std::vector<Value *>                        mOutputScalars;

    // kernel state
    Value *                                     mNoMore = nullptr;
    Value *                                     mNumOfLinearStrides = nullptr;
    PHINode *                                   mNumOfLinearStridesPhi = nullptr;
    Value *                                     mNonFinal = nullptr;
    PHINode *                                   mIsFinalPhi = nullptr;

    std::vector<Value *>                        mInputStrideLength;
    std::vector<Value *>                        mAccessibleInputItems;
    std::vector<PHINode *>                      mAccessibleInputItemsPhi;
    std::vector<Value *>                        mInputStreamHandle;

    std::vector<Value *>                        mOutputStrideLength;
    std::vector<Value *>                        mWritableOutputItems;
    std::vector<Value *>                        mCopyForwardProducedOutputItems;
    std::vector<Value *>                        mAnteriorProcessedItemCount;
    std::vector<Value *>                        mCopyBackProducedOutputItems;
    std::vector<PHINode *>                      mWritableOutputItemsPhi;
    std::vector<Value *>                        mOutputStreamHandle;

    // debug + misc state
    Value *                                     mCycleCountStart = nullptr;
    PHINode *                                   mDeadLockCounter = nullptr;
    Value *                                     mPipelineProgress = nullptr;
    PHINode *                                   mHasProgressedPhi = nullptr;
    PHINode *                                   mAlreadyProgressedPhi = nullptr;

    // popcount state
    PopCountStreamDependencyGraph               mPopCountDependencyGraph;
    PopCountDataMap                             mPopCountDataMap;


    // analysis state
    // The four graphs are filled in by the constructor's initializer list, in this order; mIsPipelineIO and
    // mOverflowRequirements are declared before them and are therefore constructed first.
    flat_set<const StreamSetBuffer *>           mIsPipelineIO;
    OverflowRequirements                        mOverflowRequirements;
    BufferGraph                                 mBufferGraph;
    ConsumerGraph                               mConsumerGraph;
    TerminationGraph                            mTerminationGraph;
    ScalarDependencyGraph                       mScalarDependencyGraph;

};
369
370inline PipelineCompiler::PipelineCompiler(BuilderRef b, PipelineKernel * const pipelineKernel)
371: mPipelineKernel(pipelineKernel)
372, mPipeline(pipelineKernel->mKernels)
373, mBufferGraph(makeBufferGraph(b))
374, mConsumerGraph(makeConsumerGraph())
375, mTerminationGraph(makeTerminationGraph())
376, mScalarDependencyGraph(makeScalarDependencyGraph()) {
377
378
379
380
381}
382
383/** ------------------------------------------------------------------------------------------------------------- *
384 * @brief getInputBuffer
385 ** ------------------------------------------------------------------------------------------------------------- */
386inline StreamSetBuffer * PipelineCompiler::getInputBuffer(const unsigned index) const {
387    for (const auto e : make_iterator_range(in_edges(mKernelIndex, mBufferGraph))) {
388        if (mBufferGraph[e].port == index) {
389            return mBufferGraph[source(e, mBufferGraph)].buffer;
390        }
391    }
392    llvm_unreachable("input buffer not found");
393    return nullptr;
394}
395
396/** ------------------------------------------------------------------------------------------------------------- *
397 * @brief getOutputBuffer
398 ** ------------------------------------------------------------------------------------------------------------- */
399inline StreamSetBuffer * PipelineCompiler::getOutputBuffer(const unsigned index) const {
400    for (const auto e : make_iterator_range(out_edges(mKernelIndex, mBufferGraph))) {
401        if (mBufferGraph[e].port == index) {
402            return mBufferGraph[target(e, mBufferGraph)].buffer;
403        }
404    }
405    llvm_unreachable("output buffer not found");
406    return nullptr;
407}
408
409/** ------------------------------------------------------------------------------------------------------------- *
410 * @brief storedInKernel
411 ** ------------------------------------------------------------------------------------------------------------- */
412inline LLVM_READNONE bool storedInNestedKernel(const Binding & output) {
413    return output.getRate().isUnknown() || output.hasAttribute(AttrId::ManagedBuffer);
414}
415
416/** ------------------------------------------------------------------------------------------------------------- *
417 * @brief upperBound
418 ** ------------------------------------------------------------------------------------------------------------- */
419inline LLVM_READNONE RateValue upperBound(not_null<const Kernel *> kernel, const Binding & binding) {
420    assert (kernel->getStride() > 0);
421//    const auto ub = kernel->getUpperBound(binding);
422//    const auto stride = kernel->getStride();
423//    return (ub == 0) ? stride : ub * stride;
424    return kernel->getUpperBound(binding) * kernel->getStride();
425}
426
427/** ------------------------------------------------------------------------------------------------------------- *
428 * @brief lowerBound
429 ** ------------------------------------------------------------------------------------------------------------- */
430inline LLVM_READNONE RateValue lowerBound(not_null<const Kernel *> kernel, const Binding & binding) {
431    assert (kernel->getStride() > 0);
432    return kernel->getLowerBound(binding) * kernel->getStride();
433}
434
435/** ------------------------------------------------------------------------------------------------------------- *
436 * @brief makeKernelName
437 ** ------------------------------------------------------------------------------------------------------------- */
438inline LLVM_READNONE std::string PipelineCompiler::makeKernelName(const unsigned kernelIndex) const {
439    return PipelineKernel::makeKernelName(mPipeline[kernelIndex], kernelIndex);
440}
441
442/** ------------------------------------------------------------------------------------------------------------- *
443 * @brief makeBufferName
444 ** ------------------------------------------------------------------------------------------------------------- */
445inline LLVM_READNONE std::string PipelineCompiler::makeBufferName(const unsigned kernelIndex, const Binding & binding) const {
446    return PipelineKernel::makeBufferName(mPipeline[kernelIndex], kernelIndex, binding);
447}
448
449/** ------------------------------------------------------------------------------------------------------------- *
450 * @brief getRelationship
451 ** ------------------------------------------------------------------------------------------------------------- */
452inline const Relationship * getRelationship(not_null<const Relationship *> r) {
453    return r.get();
454}
455
456/** ------------------------------------------------------------------------------------------------------------- *
457 * @brief getRelationship
458 ** ------------------------------------------------------------------------------------------------------------- */
459inline const Relationship * getRelationship(const Binding & b) {
460    return getRelationship(b.getRelationship());
461}
462
463inline unsigned LLVM_READNONE getItemWidth(const Type * ty ) {
464    if (LLVM_LIKELY(isa<ArrayType>(ty))) {
465        ty = ty->getArrayElementType();
466    }
467    return cast<IntegerType>(ty->getVectorElementType())->getBitWidth();
468}
469
470template <typename Graph>
471inline typename graph_traits<Graph>::edge_descriptor in_edge(const typename graph_traits<Graph>::vertex_descriptor u, const Graph & G) {
472    assert (in_degree(u, G) == 1);
473    return *in_edges(u, G).first;
474}
475
476} // end of namespace
477
478#endif // PIPELINE_COMPILER_HPP
Note: See TracBrowser for help on using the repository browser.