source: icGREP/icgrep-devel/icgrep/kernels/kernel.h

Last change on this file was 6296, checked in by cameron, 3 months ago

Merge branch 'master' of https://cs-git-research.cs.surrey.sfu.ca/cameron/parabix-devel

File size: 24.9 KB
Line 
1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#ifndef KERNEL_H
7#define KERNEL_H
8
9#include <kernels/binding.h>
10#include <kernels/relationship.h>
11#include <kernels/streamset.h>
12#include <util/not_null.h>
13#include <llvm/ADT/StringRef.h>
14#include <llvm/Support/Compiler.h>
15#include <llvm/ADT/StringMap.h>
16#include <memory>
17#include <string>
18#include <vector>
19
20namespace llvm { class AllocaInst; }
21namespace llvm { class BasicBlock; }
22namespace llvm { class CallInst; }
23namespace llvm { class Constant; }
24namespace llvm { class Function; }
25namespace llvm { class IntegerType; }
26namespace llvm { class IndirectBrInst; }
27namespace llvm { class Module; }
28namespace llvm { class PHINode; }
29namespace llvm { class StructType; }
30namespace llvm { class LoadInst; }
31namespace llvm { class Type; }
32namespace llvm { class Value; }
33
34class BaseDriver;
35
36const static std::string BUFFER_HANDLE_SUFFIX = "_buffer";
37
38namespace kernel {
39
40class KernelBuilder;
41class StreamSetBuffer;
42class StreamSet;
43
44class Kernel : public AttributeSet {
45    friend class KernelBuilder;
46    friend class PipelineBuilder;
47    friend class PipelineCompiler;
48    friend class PipelineKernel;
49    friend class OptimizationBranch;
50    friend class OptimizationBranchCompiler;
51    friend class BaseDriver;
52public:
53
54    enum class TypeId {
55        SegmentOriented
56        , MultiBlock
57        , BlockOriented
58        , Pipeline
59        , OptimizationBranch
60    };
61
62    static bool classof(const Kernel *) { return true; }
63
64    static bool classof(const void *) { return false; }
65
66    LLVM_READNONE TypeId getTypeId() const {
67        return mTypeId;
68    }
69
70public:
71
72    enum class ScalarType { Input, Output, Internal, Local };
73
74    struct ScalarField {
75        ScalarType    Type;
76        unsigned      Index;
77
78        ScalarField(const ScalarType type, const unsigned index)
79        : Type(type), Index(index) {
80
81        }
82        constexpr ScalarField(const ScalarField & other) = default;
83        ScalarField & operator=(ScalarField && other) = default;
84    };
85
86    using ScalarFieldMap = llvm::StringMap<ScalarField>;
87
88    enum class Port { Input, Output };
89
90    using StreamSetPort = std::pair<Port, unsigned>;
91
92    using StreamSetMap = llvm::StringMap<StreamSetPort>;
93
94    // Kernel Signatures and Module IDs
95    //
96    // A kernel signature uniquely identifies a kernel and its full functionality.
97    // In the event that a particular kernel instance is to be generated and compiled
98    // to produce object code, and we have a cached kernel object code instance with
99    // the same signature and targetting the same IDISA architecture, then the cached
100    // object code may safely be used to avoid recompilation.
101    //
102    // A kernel signature is a byte string of arbitrary length.
103    //
104    // Kernel developers should take responsibility for designing appropriate signature
105    // mechanisms that are short, inexpensive to compute and guarantee uniqueness
106    // based on the semantics of the kernel.
107    //
108    // If no other mechanism is available, the default makeSignature() method uses the
109    // full LLVM IR (before optimization) of the kernel instance.
110    //
111    // A kernel Module ID is short string that is used as a name for a particular kernel
112    // instance.  Kernel Module IDs are used to look up and retrieve cached kernel
113    // instances and so should be highly likely to uniquely identify a kernel instance.
114    //
115    // The ideal case is that a kernel Module ID serves as a full kernel signature thus
116    // guaranteeing uniqueness.  In this case, hasSignature() should return false.
117    //
118
119    //
120    // Kernel builder subtypes define their logic of kernel construction
121    // in terms of 3 virtual methods for
122    // (a) preparing the Kernel state data structure
123    // (c) defining the logic of the finalBlock function.
124    //
125    // Note: the kernel state data structure must only be finalized after
126    // all scalar fields have been added.   If there are no fields to
127    // be added, the default method for preparing kernel state may be used.
128
129    LLVM_READNONE virtual const std::string getName() const {
130        return mKernelName;
131    }
132
133    LLVM_READNONE virtual bool hasFamilyName() const {
134        return false;
135    }
136
137    LLVM_READNONE virtual const std::string getFamilyName() const {
138        if (hasFamilyName()) {
139            return getDefaultFamilyName();
140        } else {
141            return getName();
142        }
143    }
144
145    virtual std::string makeSignature(const std::unique_ptr<KernelBuilder> & b);
146
147    virtual bool hasSignature() const { return true; }
148
149    virtual bool isCachable() const { return false; }
150
151    LLVM_READNONE bool isStateful() const;
152
153    unsigned getStride() const { return mStride; }
154
155    LLVM_READNONE const Bindings & getInputStreamSetBindings() const {
156        return mInputStreamSets;
157    }
158
159    LLVM_READNONE const Binding & getInputStreamSetBinding(const unsigned i) const {
160        assert (i < getNumOfStreamInputs());
161        return mInputStreamSets[i];
162    }
163
164    LLVM_READNONE const Binding & getInputStreamSetBinding(const llvm::StringRef name) const {
165        const auto port = getStreamPort(name);
166        assert (port.first == Port::Input);
167        return getInputStreamSetBinding(port.second);
168    }
169
170    LLVM_READNONE StreamSet * getInputStreamSet(const unsigned i) const {
171        return llvm::cast<StreamSet>(getInputStreamSetBinding(i).getRelationship());
172    }
173
174    LLVM_READNONE StreamSet * getInputStreamSet(const llvm::StringRef name) const {
175        return llvm::cast<StreamSet>(getInputStreamSetBinding(name).getRelationship());
176    }
177
178    void setInputStreamSet(const llvm::StringRef name, StreamSet * value) {
179        const auto port = getStreamPort(name);
180        assert (port.first == Port::Input);
181        setInputStreamSetAt(port.second, value);
182    }
183
184    LLVM_READNONE unsigned getNumOfStreamInputs() const {
185        return mInputStreamSets.size();
186    }
187
188    LLVM_READNONE const OwnedStreamSetBuffers & getInputStreamSetBuffers() const {
189        return mStreamSetInputBuffers;
190    }
191
192    LLVM_READNONE StreamSetBuffer * getInputStreamSetBuffer(const unsigned i) const {
193        assert (i < mStreamSetInputBuffers.size());
194        assert (mStreamSetInputBuffers[i]);
195        return mStreamSetInputBuffers[i].get();
196    }
197
198    LLVM_READNONE StreamSetBuffer * getInputStreamSetBuffer(const llvm::StringRef name) const {
199        const auto port = getStreamPort(name);
200        assert (port.first == Port::Input);
201        return getInputStreamSetBuffer(port.second);
202    }
203
204    LLVM_READNONE const Binding & getOutputStreamSetBinding(const unsigned i) const {
205        assert (i < getNumOfStreamOutputs());
206        return mOutputStreamSets[i];
207    }
208
209    LLVM_READNONE const Binding & getOutputStreamSetBinding(const llvm::StringRef name) const {
210        const auto port = getStreamPort(name);
211        assert (port.first == Port::Output);
212        return getOutputStreamSetBinding(port.second);
213    }
214
215    LLVM_READNONE StreamSet * getOutputStreamSet(const unsigned i) const {
216        return llvm::cast<StreamSet>(getOutputStreamSetBinding(i).getRelationship());
217    }
218
219    LLVM_READNONE StreamSet * getOutputStreamSet(const llvm::StringRef name) const {
220        return llvm::cast<StreamSet>(getOutputStreamSetBinding(name).getRelationship());
221    }
222
223    void setOutputStreamSet(const llvm::StringRef name, StreamSet * value) {
224        const auto port = getStreamPort(name);
225        assert (port.first == Port::Output);
226        setOutputStreamSetAt(port.second, value);
227    }
228
229    const Bindings & getOutputStreamSetBindings() const {
230        return mOutputStreamSets;
231    }
232
233    unsigned getNumOfStreamOutputs() const {
234        return mOutputStreamSets.size();
235    }
236
237    LLVM_READNONE const OwnedStreamSetBuffers & getOutputStreamSetBuffers() const {
238        return mStreamSetOutputBuffers;
239    }
240
241    LLVM_READNONE StreamSetBuffer * getOutputStreamSetBuffer(const unsigned i) const {
242        assert (i < mStreamSetOutputBuffers.size());
243        assert (mStreamSetOutputBuffers[i]);
244        return mStreamSetOutputBuffers[i].get();
245    }
246
247    LLVM_READNONE StreamSetBuffer * getOutputStreamSetBuffer(const llvm::StringRef name) const {
248        const auto port = getStreamPort(name);
249        assert (port.first == Port::Output);
250        return getOutputStreamSetBuffer(port.second);
251    }
252
253    const Bindings & getInputScalarBindings() const {
254        return mInputScalars;
255    }
256
257    Binding & getInputScalarBinding(const unsigned i) {
258        assert (i < mInputScalars.size());
259        return mInputScalars[i];
260    }
261
262    LLVM_READNONE Binding & getInputScalarBinding(const llvm::StringRef name);
263
264    LLVM_READNONE const Binding & getInputScalarBinding(const llvm::StringRef name) const {
265        return const_cast<Kernel *>(this)->getInputScalarBinding(name);
266    }
267
268    LLVM_READNONE Scalar * getInputScalar(const unsigned i) {
269        return llvm::cast<Scalar>(getInputScalarBinding(i).getRelationship());
270    }
271
272    LLVM_READNONE Scalar * getInputScalar(const llvm::StringRef name) {
273        return llvm::cast<Scalar>(getInputScalarBinding(name).getRelationship());
274    }
275
276    void setInputScalar(const llvm::StringRef name, Scalar * value) {
277        const auto & field = getScalarField(name);
278        assert(field.Type == ScalarType::Input);
279        setInputScalarAt(field.Index, value);
280    }
281
282    LLVM_READNONE unsigned getNumOfScalarInputs() const {
283        return mInputScalars.size();
284    }
285
286    const Bindings & getOutputScalarBindings() const {
287        return mOutputScalars;
288    }
289
290    Binding & getOutputScalarBinding(const unsigned i) {
291        assert (i < mInputScalars.size());
292        return mOutputScalars[i];
293    }
294
295    LLVM_READNONE Binding & getOutputScalarBinding(const llvm::StringRef name);
296
297    LLVM_READNONE const Binding & getOutputScalarBinding(const llvm::StringRef name) const {
298        return const_cast<Kernel *>(this)->getOutputScalarBinding(name);
299    }
300
301    LLVM_READNONE Scalar * getOutputScalar(const llvm::StringRef name) {
302        return llvm::cast<Scalar>(getOutputScalarBinding(name).getRelationship());
303    }
304
305    LLVM_READNONE Scalar * getOutputScalar(const unsigned i) {
306        return llvm::cast<Scalar>(getOutputScalarBinding(i).getRelationship());
307    }
308
309    void setOutputScalar(const llvm::StringRef name, Scalar * value) {
310        const auto & field = getScalarField(name);
311        assert(field.Type == ScalarType::Output);
312        setOutputScalarAt(field.Index, value);
313    }
314
315    LLVM_READNONE unsigned getNumOfScalarOutputs() const {
316        return mOutputScalars.size();
317    }
318
319    void addInternalScalar(llvm::Type * type, const llvm::StringRef name);
320
321    void addLocalScalar(llvm::Type * type, const llvm::StringRef name);
322
323    llvm::Value * getHandle() const {
324        return mHandle;
325    }
326
327    void setHandle(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const instance);
328
329    llvm::Module * setModule(llvm::Module * const module);
330
331    llvm::Module * getModule() const {
332        return mModule;
333    }
334
335    llvm::StructType * getKernelType() const {
336        return mKernelStateType;
337    }
338
339    LLVM_READNONE const StreamSetBuffer * getStreamSetBuffer(const llvm::StringRef name) const {
340        unsigned index; Port port;
341        std::tie(port, index) = getStreamPort(name);
342        if (port == Port::Input) {
343            return getInputStreamSetBuffer(index);
344        } else {
345            return getOutputStreamSetBuffer(index);
346        }
347    }
348
349    llvm::Module * makeModule(const std::unique_ptr<KernelBuilder> & b);
350
351    // Add ExternalLinkage method declarations for the kernel to a given client module.
352    virtual void addKernelDeclarations(const std::unique_ptr<KernelBuilder> & b);
353
354    llvm::Value * createInstance(const std::unique_ptr<KernelBuilder> & b) const;
355
356    virtual void initializeInstance(const std::unique_ptr<KernelBuilder> & b, std::vector<llvm::Value *> & args);
357
358    llvm::Value * finalizeInstance(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const handle) const;
359
360    void generateKernel(const std::unique_ptr<KernelBuilder> & b);
361
362    void prepareKernel(const std::unique_ptr<KernelBuilder> & b);
363
364    void prepareCachedKernel(const std::unique_ptr<KernelBuilder> & b);
365
366    LLVM_READNONE std::string getCacheName(const std::unique_ptr<KernelBuilder> & b) const;
367
368    LLVM_READNONE StreamSetPort getStreamPort(const llvm::StringRef name) const;
369
370    LLVM_READNONE const Binding & getStreamBinding(const llvm::StringRef name) const;
371
372    LLVM_READNONE ProcessingRate::RateValue getLowerBound(const Binding & binding) const;
373
374    LLVM_READNONE ProcessingRate::RateValue getUpperBound(const Binding & binding) const;
375
376    LLVM_READNONE bool isCountable(const Binding & binding) const;
377
378    LLVM_READNONE bool isAddressable(const Binding & binding) const;
379
380    LLVM_READNONE bool requiresOverflow(const Binding & binding) const;
381
382    /* Fill in any generated names / attributes for the kernel if their initialization is dependent on
383     * settings / bindings added after construction. */
384    virtual void finalizeKernel() { }
385
386    void initializeBindings(BaseDriver & driver);
387
388    virtual ~Kernel() = 0;
389
390protected:
391
392    static std::string getStringHash(const llvm::StringRef str);
393
394    LLVM_READNONE std::string getDefaultFamilyName() const;
395
396    virtual void addInternalKernelProperties(const std::unique_ptr<KernelBuilder> &) { }
397
398    virtual void linkExternalMethods(const std::unique_ptr<KernelBuilder> &) { }
399
400    virtual void generateInitializeMethod(const std::unique_ptr<KernelBuilder> &) { }
401
402    virtual void generateKernelMethod(const std::unique_ptr<KernelBuilder> &) = 0;
403
404    virtual void generateFinalizeMethod(const std::unique_ptr<KernelBuilder> &) { }
405
406    virtual void addAdditionalFunctions(const std::unique_ptr<KernelBuilder> &) { }
407
408    virtual void setInputStreamSetAt(const unsigned i, StreamSet * value);
409
410    virtual void setOutputStreamSetAt(const unsigned i, StreamSet * value);
411
412    virtual void setInputScalarAt(const unsigned i, Scalar * value);
413
414    virtual void setOutputScalarAt(const unsigned i, Scalar * value);
415
416    virtual std::vector<llvm::Value *> getFinalOutputScalars(const std::unique_ptr<KernelBuilder> & b);
417
418    void setStride(unsigned stride) { mStride = stride; }
419
420    LLVM_READNONE llvm::Value * getAccessibleInputItems(const llvm::StringRef name) const {
421        Port port; unsigned index;
422        std::tie(port, index) = getStreamPort(name);
423        assert (port == Port::Input);
424        return getAccessibleInputItems(index);
425    }
426
427    LLVM_READNONE llvm::Value * getAccessibleInputItems(const unsigned index) const {
428        assert (index < mAccessibleInputItems.size());
429        return mAccessibleInputItems[index];
430    }
431
432    LLVM_READNONE llvm::Value * getAvailableInputItems(const llvm::StringRef name) const {
433        Port port; unsigned index;
434        std::tie(port, index) = getStreamPort(name);
435        assert (port == Port::Input);
436        return getAvailableInputItems(index);
437    }
438
439    LLVM_READNONE llvm::Value * getAvailableInputItems(const unsigned index) const {
440        assert (index < mAvailableInputItems.size());
441        return mAvailableInputItems[index];
442    }
443
444    LLVM_READNONE bool canSetTerminateSignal() const {
445        return hasAttribute(Attribute::KindId::CanTerminateEarly) || hasAttribute(Attribute::KindId::MustExplicitlyTerminate);
446    }
447
448    LLVM_READNONE llvm::Value * getTerminationSignalPtr() const {
449        return mTerminationSignalPtr;
450    }
451
452    LLVM_READNONE llvm::Value * getProcessedInputItemsPtr(const llvm::StringRef name) const {
453        Port port; unsigned index;
454        std::tie(port, index) = getStreamPort(name);
455        assert (port == Port::Input);
456        return getProcessedInputItemsPtr(index);
457    }
458
459    LLVM_READNONE llvm::Value * getProcessedInputItemsPtr(const unsigned index) const {
460        return mProcessedInputItemPtr[index];
461    }
462
463    LLVM_READNONE llvm::Value * getProducedOutputItemsPtr(const llvm::StringRef name) const {
464        Port port; unsigned index;
465        std::tie(port, index) = getStreamPort(name);
466        assert (port == Port::Output);
467        return getProducedOutputItemsPtr(index);
468    }
469
470    LLVM_READNONE llvm::Value * getProducedOutputItemsPtr(const unsigned index) const {
471        return mProducedOutputItemPtr[index];
472    }
473
474    LLVM_READNONE llvm::Value * getWritableOutputItems(const llvm::StringRef name) const {
475        Port port; unsigned index;
476        std::tie(port, index) = getStreamPort(name);
477        assert (port == Port::Output);
478        return getWritableOutputItems(index);
479    }
480
481    LLVM_READNONE llvm::Value * getWritableOutputItems(const unsigned index) const {
482        return mWritableOutputItems[index];
483    }
484
485    LLVM_READNONE llvm::Value * getConsumedOutputItems(const llvm::StringRef name) const {
486        Port port; unsigned index;
487        std::tie(port, index) = getStreamPort(name);
488        assert (port == Port::Output);
489        return getConsumedOutputItems(index);
490    }
491
492    LLVM_READNONE llvm::Value * getConsumedOutputItems(const unsigned index) const {
493        return mConsumedOutputItems[index];
494    }
495
496
497    LLVM_READNONE llvm::Value * isFinal() const {
498        return mIsFinal;
499    }
500
501    // Constructor
502    Kernel(const std::unique_ptr<KernelBuilder> & b,
503           const TypeId typeId, std::string && kernelName,
504           Bindings && stream_inputs, Bindings && stream_outputs,
505           Bindings && scalar_inputs, Bindings && scalar_outputs,
506           Bindings && internal_scalars);
507
508private:
509
510    void initializeLocalScalarValues(const std::unique_ptr<KernelBuilder> & b);
511
512    void addInitializeDeclaration(const std::unique_ptr<KernelBuilder> & b);
513
514    void callGenerateInitializeMethod(const std::unique_ptr<KernelBuilder> & b);
515
516    void addDoSegmentDeclaration(const std::unique_ptr<KernelBuilder> & b);
517
518    std::vector<llvm::Type *> getDoSegmentFields(const std::unique_ptr<KernelBuilder> & b) const;
519
520    void callGenerateDoSegmentMethod(const std::unique_ptr<KernelBuilder> & b);
521
522    void setDoSegmentProperties(const std::unique_ptr<KernelBuilder> & b, const std::vector<llvm::Value *> & args);
523
524    std::vector<llvm::Value *> getDoSegmentProperties(const std::unique_ptr<KernelBuilder> & b) const;
525
526    void addFinalizeDeclaration(const std::unique_ptr<KernelBuilder> & b);
527
528    void callGenerateFinalizeMethod(const std::unique_ptr<KernelBuilder> & b);
529
530    void addScalarToMap(const llvm::StringRef name, const ScalarType scalarType, const unsigned index);
531
532    void addStreamToMap(const llvm::StringRef name, const Port port, const unsigned index);
533
534    LLVM_READNONE const ScalarField & getScalarField(const llvm::StringRef name) const;
535
536    llvm::Value * getScalarFieldPtr(KernelBuilder & b, const llvm::StringRef name) const;
537
538    void addBaseKernelProperties(const std::unique_ptr<KernelBuilder> & b);
539
540    llvm::Function * getInitFunction(llvm::Module * const module) const;
541
542    llvm::Function * getDoSegmentFunction(llvm::Module * const module) const;
543
544    llvm::Function * getTerminateFunction(llvm::Module * const module) const;
545
546protected:
547
548    mutable bool                    mIsGenerated;
549
550    llvm::Value *                   mHandle;
551    llvm::Module *                  mModule;
552    llvm::StructType *              mKernelStateType;
553
554    Bindings                        mInputStreamSets;
555    Bindings                        mOutputStreamSets;
556
557    Bindings                        mInputScalars;
558    Bindings                        mOutputScalars;
559    Bindings                        mInternalScalars;
560    Bindings                        mLocalScalars;
561
562    llvm::Function *                mCurrentMethod;
563    unsigned                        mStride;
564
565    llvm::Value *                   mTerminationSignalPtr;
566    llvm::Value *                   mIsFinal;
567    llvm::Value *                   mNumOfStrides;
568
569    std::vector<llvm::Value *>      mLocalScalarPtr;
570
571    std::vector<llvm::Value *>      mUpdatableProcessedInputItemPtr;
572    std::vector<llvm::Value *>      mProcessedInputItemPtr;
573
574    std::vector<llvm::Value *>      mAccessibleInputItems;
575    std::vector<llvm::Value *>      mAvailableInputItems;
576    std::vector<llvm::Value *>      mPopCountRateArray;
577    std::vector<llvm::Value *>      mNegatedPopCountRateArray;
578
579    std::vector<llvm::Value *>      mUpdatableProducedOutputItemPtr;
580    std::vector<llvm::Value *>      mProducedOutputItemPtr;
581
582    std::vector<llvm::Value *>      mWritableOutputItems;
583    std::vector<llvm::Value *>      mConsumedOutputItems;
584
585    ScalarFieldMap                  mScalarMap;
586    StreamSetMap                    mStreamSetMap;
587
588    const std::string               mKernelName;
589    const TypeId                    mTypeId;
590
591    OwnedStreamSetBuffers           mStreamSetInputBuffers;
592    OwnedStreamSetBuffers           mStreamSetOutputBuffers;
593
594};
595
596class SegmentOrientedKernel : public Kernel {
597public:
598
599    static bool classof(const Kernel * const k) {
600        return k->getTypeId() == TypeId::SegmentOriented;
601    }
602
603    static bool classof(const void *) { return false; }
604
605protected:
606
607    SegmentOrientedKernel(const std::unique_ptr<KernelBuilder> & b,
608                          std::string && kernelName,
609                          Bindings && stream_inputs,
610                          Bindings && stream_outputs,
611                          Bindings && scalar_parameters,
612                          Bindings && scalar_outputs,
613                          Bindings && internal_scalars);
614public:
615
616    virtual void generateDoSegmentMethod(const std::unique_ptr<KernelBuilder> & b) = 0;
617
618protected:
619
620    void generateKernelMethod(const std::unique_ptr<KernelBuilder> & b) final;
621
622};
623
624class MultiBlockKernel : public Kernel {
625    friend class BlockOrientedKernel;
626    friend class OptimizationBranch;
627public:
628
629    static bool classof(const Kernel * const k) {
630        return k->getTypeId() == TypeId::MultiBlock;
631    }
632
633    static bool classof(const void *) { return false; }
634
635protected:
636
637    MultiBlockKernel(const std::unique_ptr<KernelBuilder> & b,
638                     std::string && kernelName,
639                     Bindings && stream_inputs,
640                     Bindings && stream_outputs,
641                     Bindings && scalar_parameters,
642                     Bindings && scalar_outputs,
643                     Bindings && internal_scalars);
644
645    virtual void generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) = 0;
646
647private:
648
649    MultiBlockKernel(const std::unique_ptr<KernelBuilder> & b,
650                     const TypeId kernelTypId,
651                     std::string && kernelName,
652                     Bindings && stream_inputs,
653                     Bindings && stream_outputs,
654                     Bindings && scalar_parameters,
655                     Bindings && scalar_outputs,
656                     Bindings && internal_scalars);
657
658private:
659
660    void generateKernelMethod(const std::unique_ptr<KernelBuilder> & b) final;
661
662};
663
664
665class BlockOrientedKernel : public MultiBlockKernel {
666public:
667
668    static bool classof(const Kernel * const k) {
669        return k->getTypeId() == TypeId::BlockOriented;
670    }
671
672    static bool classof(const void *) { return false; }
673
674protected:
675
676    void CreateDoBlockMethodCall(const std::unique_ptr<KernelBuilder> & b);
677
678    // Each kernel builder subtype must provide its own logic for generating
679    // doBlock calls.
680    virtual void generateDoBlockMethod(const std::unique_ptr<KernelBuilder> & b) = 0;
681
682    // Each kernel builder subtypre must also specify the logic for processing the
683    // final block of stream data, if there is any special processing required
684    // beyond simply calling the doBlock function.   In the case that the final block
685    // processing may be trivially implemented by dispatching to the doBlock method
686    // without additional preparation, the default generateFinalBlockMethod need
687    // not be overridden.
688
689    virtual void generateFinalBlockMethod(const std::unique_ptr<KernelBuilder> & b, llvm::Value * remainingItems);
690
691    BlockOrientedKernel(const std::unique_ptr<KernelBuilder> & b,
692                        std::string && kernelName,
693                        Bindings && stream_inputs,
694                        Bindings && stream_outputs,
695                        Bindings && scalar_parameters,
696                        Bindings && scalar_outputs,
697                        Bindings && internal_scalars);
698
699    llvm::Value * getRemainingItems(const std::unique_ptr<KernelBuilder> & b);
700
701private:
702
703    void annotateInputBindingsWithPopCountArrayAttributes();
704
705    void incrementCountableItemCounts(const std::unique_ptr<KernelBuilder> & b);
706
707    llvm::Value * getPopCountRateItemCount(const std::unique_ptr<KernelBuilder> & b, const ProcessingRate & rate);
708
709    void generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) final;
710
711    void writeDoBlockMethod(const std::unique_ptr<KernelBuilder> & b);
712
713    void writeFinalBlockMethod(const std::unique_ptr<KernelBuilder> & b, llvm::Value * remainingItems);
714
715private:
716
717    llvm::Function *            mDoBlockMethod;
718    llvm::BasicBlock *          mStrideLoopBody;
719    llvm::IndirectBrInst *      mStrideLoopBranch;
720    llvm::PHINode *             mStrideLoopTarget;
721    llvm::PHINode *             mStrideBlockIndex;
722};
723
724}
725
726#endif
Note: See TracBrowser for help on using the repository browser.