Changeset 6296


Ignore:
Timestamp:
Jan 23, 2019, 11:18:13 AM (5 months ago)
Author:
cameron
Message:

Merge branch 'master' of https://cs-git-research.cs.surrey.sfu.ca/cameron/parabix-devel

Location:
icGREP/icgrep-devel/icgrep
Files:
13 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/kernels/kernel.cpp

    r6288 r6296  
    247247
    248248/** ------------------------------------------------------------------------------------------------------------- *
    249  * @brief isParamAddressable
    250  ** ------------------------------------------------------------------------------------------------------------- */
    251 inline bool isParamAddressable(const Binding & binding) {
    252     if (binding.isDeferred()) {
    253         return true;
    254     }
    255     const ProcessingRate & rate = binding.getRate();
    256     return (rate.isBounded() || rate.isUnknown());
    257 }
    258 
    259 /** ------------------------------------------------------------------------------------------------------------- *
    260  * @brief isParamConstant
    261  ** ------------------------------------------------------------------------------------------------------------- */
    262 inline bool isParamConstant(const Binding & binding) {
    263     assert (!binding.isDeferred());
    264     const ProcessingRate & rate = binding.getRate();
    265     return rate.isFixed() || rate.isPopCount() || rate.isNegatedPopCount();
    266 }
    267 
    268 /** ------------------------------------------------------------------------------------------------------------- *
    269249 * @brief hasParam
    270250 ** ------------------------------------------------------------------------------------------------------------- */
     
    339319        // processed input items
    340320        const Binding & input = mInputStreamSets[i];
    341         if (isParamAddressable(input)) {
     321        if (isAddressable(input)) {
    342322            fields.push_back(sizePtrTy); // updatable
    343         }  else if (isParamConstant(input)) {
     323        }  else if (isCountable(input)) {
    344324            fields.push_back(sizeTy);  // constant
    345325        }
     
    364344        }
    365345        // produced output items
    366         if (canTerminate || isParamAddressable(output)) {
     346        if (canTerminate || isAddressable(output)) {
    367347            fields.push_back(sizePtrTy); // updatable
    368         } else if (isParamConstant(output)) {
     348        } else if (isCountable(output)) {
    369349            fields.push_back(sizeTy); // constant
    370350        }
     
    491471
    492472        Value * processed = nullptr;
    493         if (isParamAddressable(input)) {
     473        if (isAddressable(input)) {
    494474            assert (arg != args.end());
    495475            mUpdatableProcessedInputItemPtr[i] = *arg++;
    496476            processed = b->CreateLoad(mUpdatableProcessedInputItemPtr[i]);
    497         } else if (LLVM_LIKELY(isParamConstant(input))) {
     477        } else if (LLVM_LIKELY(isCountable(input))) {
    498478            assert (arg != args.end());
    499479            processed = *arg++;
     
    566546        /// ----------------------------------------------------
    567547        Value * produced = nullptr;
    568         if (LLVM_LIKELY(canTerminate || isParamAddressable(output))) {
     548        if (LLVM_LIKELY(canTerminate || isAddressable(output))) {
    569549            assert (arg != args.end());
    570550            mUpdatableProducedOutputItemPtr[i] = *arg++;
    571551            produced = b->CreateLoad(mUpdatableProducedOutputItemPtr[i]);
    572         } else if (LLVM_LIKELY(isParamConstant(output))) {
     552        } else if (LLVM_LIKELY(isCountable(output))) {
    573553            assert (arg != args.end());
    574554            produced = *arg++;
     
    593573        /// consumed or writable item count
    594574        /// ----------------------------------------------------
    595         Value * const items = *arg++;
    596575        if (LLVM_UNLIKELY(isLocalBuffer(output))) {
    597             mConsumedOutputItems[i] = items;
     576            Value * const consumed = *arg++;
     577            mConsumedOutputItems[i] = consumed;
    598578        } else {
    599             mWritableOutputItems[i] = items;
    600             Value * const capacity = b->CreateAdd(produced, items);
     579            Value * writable = *arg++;
     580            mWritableOutputItems[i] = writable;
     581            Value * const capacity = b->CreateAdd(produced, writable);
    601582            buffer->setCapacity(b.get(), capacity);
    602583        }
     
    638619        /// ----------------------------------------------------
    639620        const Binding & input = mInputStreamSets[i];
    640         if (isParamAddressable(input)) {
     621        if (isAddressable(input)) {
    641622            props.push_back(mProcessedInputItemPtr[i]);
    642         } else if (LLVM_LIKELY(isParamConstant(input))) {
     623        } else if (LLVM_LIKELY(isCountable(input))) {
    643624            props.push_back(b->CreateLoad(mProcessedInputItemPtr[i]));
    644625        }
     
    676657        /// produced item count
    677658        /// ----------------------------------------------------
    678         if (LLVM_LIKELY(canTerminate || isParamAddressable(output))) {
     659        if (LLVM_LIKELY(canTerminate || isAddressable(output))) {
    679660            props.push_back(mProducedOutputItemPtr[i]);
    680         } else if (LLVM_LIKELY(isParamConstant(output))) {
     661        } else if (LLVM_LIKELY(isCountable(output))) {
    681662            props.push_back(b->CreateLoad(mProducedOutputItemPtr[i]));
    682663        }
     
    980961 * @brief createInstance
    981962 ** ------------------------------------------------------------------------------------------------------------- */
    982 Value * Kernel::createInstance(const std::unique_ptr<KernelBuilder> & b) {
     963Value * Kernel::createInstance(const std::unique_ptr<KernelBuilder> & b) const {
    983964    if (LLVM_UNLIKELY(mKernelStateType == nullptr)) {
    984965        llvm_unreachable("Kernel state must be constructed prior to calling createInstance");
     
    1014995 * @brief finalizeInstance
    1015996 ** ------------------------------------------------------------------------------------------------------------- */
    1016 Value * Kernel::finalizeInstance(const std::unique_ptr<KernelBuilder> & b) {
     997Value * Kernel::finalizeInstance(const std::unique_ptr<KernelBuilder> & b, Value * const handle) const {
    1017998    Value * result = nullptr;
    1018999    Function * const termFunc = getTerminateFunction(b->getModule());
    10191000    if (LLVM_LIKELY(isStateful())) {
    1020         result = b->CreateCall(termFunc, { mHandle });
     1001        result = b->CreateCall(termFunc, { handle });
    10211002    } else {
    10221003        result = b->CreateCall(termFunc);
    10231004    }
    1024     mHandle = nullptr;
    10251005    if (mOutputScalars.empty()) {
    10261006        assert (!result || result->getType()->isVoidTy());
     
    11591139 ** ------------------------------------------------------------------------------------------------------------- */
    11601140bool Kernel::isCountable(const Binding & binding) const {
     1141    if (LLVM_UNLIKELY(binding.isDeferred())) {
     1142        return false;
     1143    }
    11611144    const ProcessingRate & rate = binding.getRate();
    1162     if (rate.isFixed() || rate.isPopCount() || rate.isNegatedPopCount()) {
     1145    return rate.isFixed() || rate.isPopCount() || rate.isNegatedPopCount();
     1146}
     1147
     1148/** ------------------------------------------------------------------------------------------------------------- *
     1149 * @brief isAddressable
     1150 ** ------------------------------------------------------------------------------------------------------------- */
     1151bool Kernel::isAddressable(const Binding & binding) const {
     1152    if (LLVM_UNLIKELY(binding.isDeferred())) {
    11631153        return true;
    1164     } else if (rate.isRelative()) {
    1165         return isCountable(getStreamBinding(rate.getReference()));
    1166     } else {
    1167         return false;
    1168     }
    1169 }
    1170 
    1171 /** ------------------------------------------------------------------------------------------------------------- *
    1172  * @brief isCalculable
    1173  ** ------------------------------------------------------------------------------------------------------------- */
    1174 bool Kernel::isCalculable(const Binding & binding) const {
     1154    }
    11751155    const ProcessingRate & rate = binding.getRate();
    1176     if (rate.isFixed() || rate.isBounded()) {
    1177         return true;
    1178     } else if (rate.isRelative()) {
    1179         return isCalculable(getStreamBinding(rate.getReference()));
    1180     } else {
    1181         return false;
    1182     }
     1156    return rate.isBounded() || rate.isUnknown();
    11831157}
    11841158
     
    11941168    } else {
    11951169        return true;
    1196     }
    1197 }
    1198 
    1199 /** ------------------------------------------------------------------------------------------------------------- *
    1200  * @brief isUnknownRate
    1201  ** ------------------------------------------------------------------------------------------------------------- */
    1202 bool Kernel::isUnknownRate(const Binding & binding) const {
    1203     const ProcessingRate & rate = binding.getRate();
    1204     if (rate.isUnknown()) {
    1205         return true;
    1206     } else if (rate.isRelative()) {
    1207         return isUnknownRate(getStreamBinding(rate.getReference()));
    1208     } else {
    1209         return false;
    12101170    }
    12111171}
  • icGREP/icgrep-devel/icgrep/kernels/kernel.h

    r6288 r6296  
    4848    friend class PipelineKernel;
    4949    friend class OptimizationBranch;
     50    friend class OptimizationBranchCompiler;
    5051    friend class BaseDriver;
    5152public:
     
    351352    virtual void addKernelDeclarations(const std::unique_ptr<KernelBuilder> & b);
    352353
    353     llvm::Value * createInstance(const std::unique_ptr<KernelBuilder> & b);
     354    llvm::Value * createInstance(const std::unique_ptr<KernelBuilder> & b) const;
    354355
    355356    virtual void initializeInstance(const std::unique_ptr<KernelBuilder> & b, std::vector<llvm::Value *> & args);
    356357
    357     llvm::Value * finalizeInstance(const std::unique_ptr<KernelBuilder> & b);
     358    llvm::Value * finalizeInstance(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const handle) const;
    358359
    359360    void generateKernel(const std::unique_ptr<KernelBuilder> & b);
     
    375376    LLVM_READNONE bool isCountable(const Binding & binding) const;
    376377
    377     LLVM_READNONE bool isCalculable(const Binding & binding) const;
     378    LLVM_READNONE bool isAddressable(const Binding & binding) const;
    378379
    379380    LLVM_READNONE bool requiresOverflow(const Binding & binding) const;
    380 
    381     LLVM_READNONE bool isUnknownRate(const Binding & binding) const;
    382381
    383382    /* Fill in any generated names / attributes for the kernel if their initialization is dependent on
  • icGREP/icgrep-devel/icgrep/kernels/multiblock_kernel.cpp

    r6288 r6296  
    3636 ** ------------------------------------------------------------------------------------------------------------- */
    3737void MultiBlockKernel::generateKernelMethod(const std::unique_ptr<KernelBuilder> & b) {
    38     assert (mIsFinal);
    39     assert (mNumOfStrides);
    40     Value * const numOfStrides = b->CreateSelect(mIsFinal, b->getSize(1), mNumOfStrides);
    41     generateMultiBlockLogic(b, numOfStrides);
     38    generateMultiBlockLogic(b, b->CreateSelect(mIsFinal, b->getSize(1), mNumOfStrides));
    4239}
    4340
  • icGREP/icgrep-devel/icgrep/kernels/optimizationbranch.h

    r6289 r6296  
    1212class OptimizationBranch final : public Kernel {
    1313    friend class OptimizationBranchBuilder;
     14    friend class OptimizationBranchCompiler;
    1415public:
    1516
     
    2324    }
    2425
    25     const static std::string CONDITION_TAG;
     26    const Kernel * getAllZeroKernel() const {
     27        return mAllZeroKernel;
     28    }
    2629
    27     ~OptimizationBranch();
     30    const Kernel * getNonZeroKernel() const {
     31        return mNonZeroKernel;
     32    }
     33
     34    const Relationship * getCondition() const {
     35        return mCondition;
     36    }
    2837
    2938protected:
     
    3948                       Bindings && scalar_outputs);
    4049
     50    void addInternalKernelProperties(const std::unique_ptr<kernel::KernelBuilder> & b) final;
     51
    4152    void addKernelDeclarations(const std::unique_ptr<KernelBuilder> & b) final;
    4253
     
    4960private:
    5061
    51     llvm::Value * getItemCountIncrement(const std::unique_ptr<KernelBuilder> & b, const Binding & binding,
    52                                         llvm::Value * const first, llvm::Value * const last,
    53                                         llvm::Value * const defaultValue = nullptr) const;
    54 
    55     void callKernel(const std::unique_ptr<KernelBuilder> & b,
    56                     const Kernel * const kernel, llvm::Value * const first, llvm::Value * const last,
    57                     llvm::PHINode * const terminatedPhi);
    58 
    59 private:
    60 
    61     Relationship * const                                mCondition;
    62     Kernel * const                                      mNonZeroKernel;
    63     Kernel * const                                      mAllZeroKernel;
    64     mutable std::unique_ptr<OptimizationBranchCompiler> mCompiler;
     62    Relationship * const                        mCondition;
     63    Kernel * const                              mNonZeroKernel;
     64    Kernel * const                              mAllZeroKernel;
     65    std::unique_ptr<OptimizationBranchCompiler> mCompiler;
    6566};
    6667
  • icGREP/icgrep-devel/icgrep/kernels/optimizationbranch/optimizationbranch.cpp

    r6289 r6296  
    2323using ScalarDependencyMap = flat_map<const Relationship *, ScalarVertex>;
    2424
    25 const std::string OptimizationBranch::CONDITION_TAG = "@condition";
    26 
    27 const static std::string BRANCH_PREFIX = "@B";
    28 
    2925/** ------------------------------------------------------------------------------------------------------------- *
    30  * @brief isParamConstant
     26 * @brief addInternalKernelProperties
    3127 ** ------------------------------------------------------------------------------------------------------------- */
    32 inline bool isParamConstant(const Binding & binding) {
    33     if (binding.isDeferred()) {
    34         return false;
    35     }
    36     const ProcessingRate & rate = binding.getRate();
    37     return rate.isFixed() || rate.isPopCount() || rate.isNegatedPopCount();
     28void OptimizationBranch::addInternalKernelProperties(const std::unique_ptr<kernel::KernelBuilder> & b) {
     29    mCompiler = llvm::make_unique<OptimizationBranchCompiler>(this);
     30    mCompiler->addBranchProperties(b);
    3831}
    3932
    4033/** ------------------------------------------------------------------------------------------------------------- *
    41  * @brief isParamAddressable
     34 * @brief generateInitializeMethod
    4235 ** ------------------------------------------------------------------------------------------------------------- */
    43 inline bool isParamAddressable(const Binding & binding) {
    44     if (binding.isDeferred()) {
    45         return true;
    46     }
    47     const ProcessingRate & rate = binding.getRate();
    48     return (rate.isBounded() || rate.isUnknown());
    49 }
    50 
    51 /** ------------------------------------------------------------------------------------------------------------- *
    52  * @brief isLocalBuffer
    53  ** ------------------------------------------------------------------------------------------------------------- */
    54 inline bool isLocalBuffer(const Binding & output) {
    55     return output.getRate().isUnknown() || output.hasAttribute(AttrId::ManagedBuffer);
    56 }
    57 
    58 /** ------------------------------------------------------------------------------------------------------------- *
    59  * @brief loadKernelHandle
    60  ** ------------------------------------------------------------------------------------------------------------- */
    61 void loadHandle(const std::unique_ptr<KernelBuilder> & b, Kernel * const kernel, const std::string suffix) {
    62     if (LLVM_LIKELY(kernel->isStateful())) {
    63         Value * handle = b->getScalarField(BRANCH_PREFIX + suffix);
    64         if (kernel->hasFamilyName()) {
    65             handle = b->CreatePointerCast(handle, kernel->getKernelType()->getPointerTo());
    66         }
    67         kernel->setHandle(b, handle);
    68     }
     36void OptimizationBranch::generateInitializeMethod(const std::unique_ptr<KernelBuilder> & b) {
     37    mCompiler->generateInitializeMethod(b);
    6938}
    7039
     
    7342 ** ------------------------------------------------------------------------------------------------------------- */
    7443void OptimizationBranch::generateKernelMethod(const std::unique_ptr<KernelBuilder> & b) {
    75 
    76 
    77     BasicBlock * const nonZeroPath = b->CreateBasicBlock("nonZeroPath");
    78     BasicBlock * const allZeroPath = b->CreateBasicBlock("allZeroPath");
    79     BasicBlock * const exit = b->CreateBasicBlock("exit");
    80 
    81 
    82 
    83     Constant * const ZERO = b->getSize(0);
    84     Constant * const ONE = b->getSize(1);
    85 
    86     loadHandle(b, mAllZeroKernel, "0");
    87     loadHandle(b, mNonZeroKernel, "1");
    88 
    89 
    90     if (LLVM_LIKELY(isa<StreamSet>(mCondition))) {
    91 
    92         BasicBlock * const entry = b->GetInsertBlock();
    93         BasicBlock * const loopCond = b->CreateBasicBlock("cond", nonZeroPath);
    94         BasicBlock * const summarizeOneStride = b->CreateBasicBlock("summarizeOneStride", nonZeroPath);
    95         BasicBlock * const checkStride = b->CreateBasicBlock("checkStride", nonZeroPath);
    96         BasicBlock * const processStrides = b->CreateBasicBlock("processStrides", nonZeroPath);
    97         BasicBlock * const mergePaths = b->CreateBasicBlock("mergePaths", nonZeroPath);
    98 
    99         b->CreateBr(loopCond);
    100 
    101         b->SetInsertPoint(loopCond);
    102         IntegerType * const sizeTy = b->getSizeTy();
    103         IntegerType * const boolTy = b->getInt1Ty();
    104         PHINode * const currentFirstIndex = b->CreatePHI(sizeTy, 3, "firstStride");
    105         currentFirstIndex->addIncoming(ZERO, entry);
    106         PHINode * const currentLastIndex = b->CreatePHI(sizeTy, 3, "lastStride");
    107         currentLastIndex->addIncoming(ZERO, entry);
    108         PHINode * const currentState = b->CreatePHI(boolTy, 3);
    109         currentState->addIncoming(UndefValue::get(boolTy), entry);
    110 
    111 
    112         Constant * const blocksPerStride = b->getSize(getStride() / b->getBitBlockWidth());
    113         Value * const numOfConditionStreams = b->getInputStreamSetCount(CONDITION_TAG);
    114         Value * const numOfConditionBlocks = b->CreateMul(numOfConditionStreams, blocksPerStride);
    115 
    116         Value * const offset = b->CreateMul(currentLastIndex, blocksPerStride);
    117         Value * basePtr = b->getInputStreamBlockPtr(CONDITION_TAG, ZERO, offset);
    118         Type * const BitBlockTy = b->getBitBlockType();
    119         basePtr = b->CreatePointerCast(basePtr, BitBlockTy->getPointerTo());
    120         b->CreateBr(summarizeOneStride);
    121 
    122         // Predeclare some phi nodes
    123 
    124         b->SetInsertPoint(nonZeroPath);
    125         PHINode * const firstNonZeroIndex = b->CreatePHI(sizeTy, 2);
    126         PHINode * const lastNonZeroIndex = b->CreatePHI(sizeTy, 2);
    127         PHINode * const allZeroAfterNonZero = b->CreatePHI(boolTy, 2);
    128 
    129         b->SetInsertPoint(allZeroPath);
    130         PHINode * const firstAllZeroIndex = b->CreatePHI(sizeTy, 2);
    131         PHINode * const lastAllZeroIndex = b->CreatePHI(sizeTy, 2);
    132         PHINode * const nonZeroAfterAllZero = b->CreatePHI(boolTy, 2);
    133 
    134         PHINode * terminatedPhi = nullptr;
    135         if (canSetTerminateSignal()) {
    136             b->SetInsertPoint(mergePaths);
    137             terminatedPhi = b->CreatePHI(b->getInt1Ty(), 2);
    138         }
    139 
    140         // OR together every condition block in this stride
    141         b->SetInsertPoint(summarizeOneStride);
    142         PHINode * const iteration = b->CreatePHI(b->getSizeTy(), 2);
    143         iteration->addIncoming(ZERO, loopCond);
    144         PHINode * const merged = b->CreatePHI(BitBlockTy, 2);
    145         merged->addIncoming(Constant::getNullValue(BitBlockTy), loopCond);
    146         Value * value = b->CreateBlockAlignedLoad(basePtr, iteration);
    147         value = b->CreateOr(value, merged);
    148         merged->addIncoming(value, summarizeOneStride);
    149         Value * const nextIteration = b->CreateAdd(iteration, ONE);
    150         Value * const more = b->CreateICmpNE(nextIteration, numOfConditionBlocks);
    151         iteration->addIncoming(nextIteration, b->GetInsertBlock());
    152         b->CreateCondBr(more, summarizeOneStride, checkStride);
    153 
    154         // Check the merged value of our condition block(s); if it differs from
    155         // the prior value or this is our last stride, then process the strides.
    156         // Note, however, initially state is "indeterminate" so we silently
    157         // ignore the first stride unless it is also our last.
    158         b->SetInsertPoint(checkStride);
    159         Value * const nextState = b->bitblock_any(value);
    160         Value * const sameState = b->CreateICmpEQ(nextState, currentState);
    161         Value * const firstStride = b->CreateICmpEQ(currentLastIndex, ZERO);
    162         Value * const continuation = b->CreateOr(sameState, firstStride);
    163         Value * const nextIndex = b->CreateAdd(currentLastIndex, ONE);
    164 //        Value * const lastStrideIndex = b->CreateUMin(mNumOfStrides, ONE);
    165         Value * const notLastStride = b->CreateICmpULT(nextIndex, mNumOfStrides);
    166         Value * const checkNextStride = b->CreateAnd(continuation, notLastStride);
    167         currentLastIndex->addIncoming(nextIndex, checkStride);
    168         currentFirstIndex->addIncoming(currentFirstIndex, checkStride);
    169         currentState->addIncoming(nextState, checkStride);
    170         b->CreateLikelyCondBr(checkNextStride, loopCond, processStrides);
    171 
    172         // Process every stride between [first, last)
    173         b->SetInsertPoint(processStrides);
    174 
    175         // state is implicitly "indeterminate" during our first stride
    176         Value * const selectedPath = b->CreateSelect(firstStride, nextState, currentState);
    177         firstNonZeroIndex->addIncoming(currentFirstIndex, processStrides);
    178         firstAllZeroIndex->addIncoming(currentFirstIndex, processStrides);
    179         // When we reach the last (but not necessarily final) stride of this kernel,
    180         // we will either "append" the final stride to the current run or complete
    181         // the current run then perform one more iteration for the final stride, depending
    182         // whether it flips the branch selection state.
    183 
    184 //        b->CallPrintInt(" &&& nextState", nextState);
    185 //        b->CallPrintInt(" &&& firstStride", firstStride);
    186 //        b->CallPrintInt(" &&& continuation", continuation);
    187 //        b->CallPrintInt(" &&& notLastStride", notLastStride);
    188 
    189 //        b->CallPrintInt(" &&& nextIndex", nextIndex);
    190 //        b->CallPrintInt(" &&& currentLastIndex", currentLastIndex);
    191 
    192 //        b->CallPrintInt(" &&& mNumOfStrides", mNumOfStrides);
    193 //        b->CallPrintInt(" &&& nextIndex", nextIndex);
    194 
    195         Value * const nextLast = b->CreateSelect(continuation, mNumOfStrides, nextIndex);
    196 
    197 //        b->CallPrintInt(" &&& nextLast", nextLast);
    198 
    199         Value * const nextFirst = b->CreateSelect(continuation, mNumOfStrides, currentLastIndex);
    200 
    201 //        b->CallPrintInt(" &&& nextFirst", nextFirst);
    202 
    203 
    204         lastNonZeroIndex->addIncoming(nextFirst, processStrides);
    205         lastAllZeroIndex->addIncoming(nextFirst, processStrides);
    206         Value * finished = b->CreateNot(notLastStride);
    207         Value * const flipLastState = b->CreateAnd(finished, b->CreateNot(continuation));
    208         nonZeroAfterAllZero->addIncoming(flipLastState, processStrides);
    209         allZeroAfterNonZero->addIncoming(flipLastState, processStrides);
    210         b->CreateCondBr(selectedPath, nonZeroPath, allZeroPath);
    211 
    212         // make the actual calls and take any potential termination signal
    213         b->SetInsertPoint(nonZeroPath);
    214         callKernel(b, mNonZeroKernel, firstNonZeroIndex, lastNonZeroIndex, terminatedPhi);
    215         BasicBlock * const nonZeroPathExit = b->GetInsertBlock();
    216         firstAllZeroIndex->addIncoming(nextFirst, nonZeroPathExit);
    217         lastAllZeroIndex->addIncoming(nextLast, nonZeroPathExit);
    218         nonZeroAfterAllZero->addIncoming(b->getFalse(), nonZeroPathExit);
    219         b->CreateUnlikelyCondBr(allZeroAfterNonZero, allZeroPath, mergePaths);
    220 
    221         b->SetInsertPoint(allZeroPath);
    222         callKernel(b, mAllZeroKernel, firstAllZeroIndex, lastAllZeroIndex, terminatedPhi);
    223         BasicBlock * const allZeroPathExit = b->GetInsertBlock();
    224         firstNonZeroIndex->addIncoming(nextFirst, allZeroPathExit);
    225         lastNonZeroIndex->addIncoming(nextLast, allZeroPathExit);
    226         allZeroAfterNonZero->addIncoming(b->getFalse(), allZeroPathExit);
    227         b->CreateUnlikelyCondBr(nonZeroAfterAllZero, nonZeroPath, mergePaths);
    228 
    229         b->SetInsertPoint(mergePaths);
    230         currentFirstIndex->addIncoming(nextFirst, mergePaths);
    231         currentLastIndex->addIncoming(nextLast, mergePaths);
    232         currentState->addIncoming(nextState, mergePaths);
    233         if (terminatedPhi) {
    234             finished = b->CreateOr(finished, terminatedPhi);
    235         }
    236         b->CreateLikelyCondBr(finished, exit, loopCond);
    237 
    238     } else {
    239 
    240 //        Value * const cond = b->getScalarField(CONDITION_TAG);
    241 //        b->CreateCondBr(b->CreateIsNotNull(cond), nonZeroPath, allZeroPath);
    242 
    243 //        // make the actual calls and take any potential termination signal
    244 //        b->SetInsertPoint(nonZeroPath);
    245 //        callKernel(b, mNonZeroKernel, ZERO, mNumOfStrides, b->getTrue(), nullptr);
    246 //        b->CreateBr(exit);
    247 
    248 //        b->SetInsertPoint(allZeroPath);
    249 //        callKernel(b, mAllZeroKernel, ZERO, mNumOfStrides, b->getTrue(), nullptr);
    250 //        b->CreateBr(exit);
    251     }
    252 
    253     b->SetInsertPoint(exit);
    254 
    255 }
    256 
    257 /** ------------------------------------------------------------------------------------------------------------- *
    258  * @brief callKernel
    259  ** ------------------------------------------------------------------------------------------------------------- */
    260 void OptimizationBranch::callKernel(const std::unique_ptr<KernelBuilder> & b,
    261                                     const Kernel * const kernel,
    262                                     Value * const first, Value * const last,
    263                                     PHINode * const terminatedPhi) {
    264 
    265 //    b->CallPrintInt(" &&& " + kernel->getName() + "_first", first);
    266 //    b->CallPrintInt(" &&& " + kernel->getName() + "_last", last);
    267 
    268     if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
    269         Value * const nonZeroLength = b->CreateICmpULT(first, last);
    270         Value * valid = b->CreateOr(nonZeroLength, b->CreateIsNull(last));
    271         b->CreateAssert(valid,
    272             "Branch cannot execute 0 strides unless this is the final stride");
    273     }
    274 
    275     Function * const doSegment = kernel->getDoSegmentFunction(b->getModule());
    276 
    277     BasicBlock * incrementItemCounts = nullptr;
    278     BasicBlock * kernelExit = nullptr;
    279     if (kernel->canSetTerminateSignal()) {
    280         incrementItemCounts = b->CreateBasicBlock("incrementItemCounts");
    281         kernelExit = b->CreateBasicBlock("kernelExit");
    282     }
    283 
    284     std::vector<Value *> args;
    285     args.reserve(doSegment->arg_size());
    286     if (kernel->isStateful()) {
    287         args.push_back(kernel->getHandle()); // handle
    288     }
    289     Value * const numOfStrides = b->CreateSub(last, first);
    290     args.push_back(numOfStrides); // numOfStrides
    291     const auto numOfInputs = kernel->getNumOfStreamInputs();
    292 
    293     Value * const isFinal = b->CreateIsNull(last);
    294 
    295     for (unsigned i = 0; i < numOfInputs; i++) {
    296 
    297         const Binding & input = kernel->getInputStreamSetBinding(i);
    298         const auto & buffer = mStreamSetInputBuffers[i];
    299         // logical base input address
    300         args.push_back(buffer->getBaseAddress(b.get()));
    301         // processed input items
    302         Value * processed = mProcessedInputItemPtr[i];
    303         if (isParamConstant(input)) {
    304             processed = b->CreateLoad(processed);
    305         }
    306         args.push_back(processed);
    307         // accessible input items (after non-deferred processed item count)
    308         Value * accessible = getItemCountIncrement(b, input, first, last, mAccessibleInputItems[i]);
    309         accessible = b->CreateSelect(isFinal, mAccessibleInputItems[i], accessible);
    310         args.push_back(accessible);
    311         // TODO: What if one of the branches requires this but the other doesn't?
    312         if (LLVM_UNLIKELY(input.hasAttribute(AttrId::RequiresPopCountArray))) {
    313             args.push_back(b->CreateGEP(mPopCountRateArray[i], first));
    314         }
    315         if (LLVM_UNLIKELY(input.hasAttribute(AttrId::RequiresNegatedPopCountArray))) {
    316             args.push_back(b->CreateGEP(mNegatedPopCountRateArray[i], first));
    317         }
    318     }
    319 
    320     const auto numOfOutputs = kernel->getNumOfStreamOutputs();
    321     for (unsigned i = 0; i < numOfOutputs; ++i) {
    322         const Binding & output = kernel->getOutputStreamSetBinding(i);
    323         const auto & buffer = mStreamSetOutputBuffers[i];
    324         args.push_back(buffer->getBaseAddress(b.get()));
    325         // produced
    326         Value * produced = mProducedOutputItemPtr[i];
    327         if (isParamConstant(output)) {
    328             produced = b->CreateLoad(produced);
    329         }
    330         args.push_back(produced);
    331         Value * writable = getItemCountIncrement(b, output, first, last, mWritableOutputItems[i]);
    332         writable = b->CreateSelect(isFinal, mWritableOutputItems[i], writable);
    333         args.push_back(writable);
    334     }
    335 
    336     Value * const terminated = b->CreateCall(doSegment, args);
    337     if (incrementItemCounts) {
    338         b->CreateUnlikelyCondBr(terminated, kernelExit, incrementItemCounts);
    339 
    340         b->SetInsertPoint(incrementItemCounts);
    341     }
    342 
    343     for (unsigned i = 0; i < numOfInputs; ++i) {
    344         const Binding & input = kernel->getInputStreamSetBinding(i);
    345         if (isParamConstant(input)) {
    346             Value * const processed = b->CreateLoad(mProcessedInputItemPtr[i]);
    347             Value * const itemCount = getItemCountIncrement(b, input, first, last);
    348             Value * const updatedInputCount = b->CreateAdd(processed, itemCount);
    349             b->CreateStore(updatedInputCount, mProcessedInputItemPtr[i]);
    350         }
    351 //        Value * const processed = b->CreateLoad(mProcessedInputItemPtr[i]);
    352 //        b->CallPrintInt(" &&& " + input.getName() + "_processed'", processed);
    353     }
    354 
    355     for (unsigned i = 0; i < numOfOutputs; ++i) {
    356         const Binding & output = kernel->getOutputStreamSetBinding(i);
    357         if (isParamConstant(output)) {
    358             Value * const produced = b->CreateLoad(mProducedOutputItemPtr[i]);
    359             Value * const itemCount = getItemCountIncrement(b, output, first, last);
    360             Value * const updatedOutputCount = b->CreateAdd(produced, itemCount);
    361             b->CreateStore(updatedOutputCount, mProducedOutputItemPtr[i]);
    362         }
    363 //        Value * const processed = b->CreateLoad(mProducedOutputItemPtr[i]);
    364 //        b->CallPrintInt(" &&& " + output.getName() + "_produced'", processed);
    365     }
    366 
    367     if (incrementItemCounts) {
    368         terminatedPhi->addIncoming(terminated, b->GetInsertBlock());
    369         b->CreateBr(kernelExit);
    370         b->SetInsertPoint(kernelExit);
    371     }
    372 
    373 }
    374 
    375 /** ------------------------------------------------------------------------------------------------------------- *
    376  * @brief getItemCountIncrement
    377  ** ------------------------------------------------------------------------------------------------------------- */
    378 Value * OptimizationBranch::getItemCountIncrement(const std::unique_ptr<KernelBuilder> & b, const Binding & binding,
    379                                                   Value * const first, Value * const last, Value * const defaultValue) const {
    380     const ProcessingRate & rate = binding.getRate();
    381     if (rate.isFixed() || rate.isBounded()) {
    382         Constant * const strideLength = b->getSize(ceiling(getUpperBound(binding) * getStride()));
    383         Value * const numOfStrides = b->CreateSub(last, first);
    384         return b->CreateMul(numOfStrides, strideLength);
    385     } else if (rate.isPopCount() || rate.isNegatedPopCount()) {
    386         Port refPort;
    387         unsigned refIndex = 0;
    388         std::tie(refPort, refIndex) = getStreamPort(rate.getReference());
    389         assert (refPort == Port::Input);
    390         Value * array = nullptr;
    391         if (rate.isNegatedPopCount()) {
    392             array = mNegatedPopCountRateArray[refIndex];
    393         } else {
    394             array = mPopCountRateArray[refIndex];
    395         }
    396         Constant * const ONE = b->getSize(1);
    397         Value * const currentIndex = b->CreateSub(last, ONE);
    398         Value * const currentSum = b->CreateLoad(b->CreateGEP(array, currentIndex));
    399         Value * const priorIndex = b->CreateSub(first, ONE);
    400         Value * const priorSum = b->CreateLoad(b->CreateGEP(array, priorIndex));
    401         return b->CreateSub(currentSum, priorSum);
    402     }
    403     return defaultValue;
    404 }
    405 
    406 // TODO: abstract this. it's a near copy of the pipeline kernel logic
    407 
    408 void enumerateScalarProducerBindings(const std::unique_ptr<KernelBuilder> & b,
    409                                      const ScalarVertex producer,
    410                                      const Bindings & bindings,
    411                                      ScalarDependencyGraph & G,
    412                                      ScalarDependencyMap & M) {
    413     const auto n = bindings.size();
    414     for (unsigned i = 0; i < n; ++i) {
    415         const Binding & binding = bindings[i];
    416         const Relationship * const rel = binding.getRelationship();
    417         assert (M.count(rel) == 0);
    418         Value * const value = b->getScalarField(binding.getName());
    419         const auto buffer = add_vertex(value, G);
    420         add_edge(producer, buffer, i, G);
    421         M.emplace(rel, buffer);
    422     }
    423 }
    424 
    425 ScalarVertex makeIfConstant(const Binding & binding,
    426                             ScalarDependencyGraph & G,
    427                             ScalarDependencyMap & M) {
    428     const Relationship * const rel = binding.getRelationship();
    429     const auto f = M.find(rel);
    430     if (LLVM_LIKELY(f != M.end())) {
    431         return f->second;
    432     } else if (LLVM_LIKELY(isa<ScalarConstant>(rel))) {
    433         const auto bufferVertex = add_vertex(cast<ScalarConstant>(rel)->value(), G);
    434         M.emplace(rel, bufferVertex);
    435         return bufferVertex;
    436     } else {
    437         report_fatal_error("unknown scalar value");
    438     }
    439 }
    440 
    441 void enumerateScalarConsumerBindings(const ScalarVertex consumer,
    442                                      const Bindings & bindings,
    443                                      ScalarDependencyGraph & G,
    444                                      ScalarDependencyMap & M) {
    445     const auto n = bindings.size();
    446     for (unsigned i = 0; i < n; ++i) {
    447         const auto buffer = makeIfConstant(bindings[i], G, M);
    448         assert (buffer < num_vertices(G));
    449         add_edge(buffer, consumer, i, G);
    450     }
    451 }
    452 
    453 /** ------------------------------------------------------------------------------------------------------------- *
    454  * @brief initKernel
    455  ** ------------------------------------------------------------------------------------------------------------- */
    456 Value * initKernel(const std::unique_ptr<KernelBuilder> & b,
    457                    const unsigned index,
    458                    Kernel * const kernel,
    459                    Function * const initializer,
    460                    const ScalarDependencyGraph & G) {
    461     std::vector<Value *> args;
    462     const auto hasHandle = kernel->isStateful() ? 1U : 0U;
    463     args.resize(hasHandle + in_degree(index, G));
    464     if (LLVM_LIKELY(hasHandle)) {
    465         Value * handle = kernel->createInstance(b);
    466         if (LLVM_UNLIKELY(kernel->hasFamilyName())) {
    467             handle = b->CreatePointerCast(handle, b->getVoidPtrTy());
    468         }
    469         b->setScalarField(BRANCH_PREFIX + std::to_string(index - 1), handle);
    470         args[0] = handle;
    471     }
    472     for (const auto e : make_iterator_range(in_edges(index, G))) {
    473         const auto j = hasHandle + G[e];
    474         const auto scalar = source(e, G);
    475         args[j] = G[scalar];
    476     }
    477     return b->CreateCall(initializer, args);
    478 }
    479 
    480 /** ------------------------------------------------------------------------------------------------------------- *
    481  * @brief generateInitializeMethod
    482  ** ------------------------------------------------------------------------------------------------------------- */
    483 void OptimizationBranch::generateInitializeMethod(const std::unique_ptr<KernelBuilder> & b) {
    484 
    485     ScalarDependencyGraph G(3);
    486     ScalarDependencyMap M;
    487 
    488     enumerateScalarProducerBindings(b, 0, getInputScalarBindings(), G, M);
    489     enumerateScalarConsumerBindings(1, mAllZeroKernel->getInputScalarBindings(), G, M);
    490     enumerateScalarConsumerBindings(2, mNonZeroKernel->getInputScalarBindings(), G, M);
    491 
    492     Module * const m = b->getModule();
    493     Value * const term2 = initKernel(b, 1, mAllZeroKernel, mAllZeroKernel->getInitFunction(m), G);
    494     Value * const term1 = initKernel(b, 2, mNonZeroKernel, mNonZeroKernel->getInitFunction(m), G);
    495     b->CreateStore(b->CreateOr(term1, term2), mTerminationSignalPtr);
    496 }
    497 
    498 /** ------------------------------------------------------------------------------------------------------------- *
    499  * @brief generateFinalizeMethod
    500  ** ------------------------------------------------------------------------------------------------------------- */
    501 inline Value * callTerminate(const std::unique_ptr<KernelBuilder> & b, Kernel * kernel, const std::string suffix) {
    502     loadHandle(b, kernel, suffix);
    503     return kernel->finalizeInstance(b);
     44    mCompiler->generateKernelMethod(b);
    50445}
    50546
     
    50849 ** ------------------------------------------------------------------------------------------------------------- */
    50950void OptimizationBranch::generateFinalizeMethod(const std::unique_ptr<KernelBuilder> & b) {
    510     Value * allZeroResult = callTerminate(b, mAllZeroKernel, "0");
    511     Value * nonZeroResult = callTerminate(b, mNonZeroKernel, "1");
    512     if (LLVM_UNLIKELY(nonZeroResult || allZeroResult)) {
    513         report_fatal_error("OptimizationBranch does not support output scalars yet");
    514     }
    515 
     51    mCompiler->generateFinalizeMethod(b);
    51652}
    51753
     
    52359    mAllZeroKernel->addKernelDeclarations(b);
    52460    Kernel::addKernelDeclarations(b);
    525 }
    526 
    527 void addHandle(const std::unique_ptr<KernelBuilder> & b, const Kernel * const kernel, Bindings & scalars, const std::string suffix) {
    528     if (LLVM_LIKELY(kernel->isStateful())) {
    529         Type * handleType = nullptr;
    530         if (LLVM_UNLIKELY(kernel->hasFamilyName())) {
    531             handleType = b->getVoidPtrTy();
    532         } else {
    533             handleType = kernel->getKernelType()->getPointerTo();
    534         }
    535         scalars.emplace_back(handleType, BRANCH_PREFIX + suffix);
    536     }
    53761}
    53862
     
    55377, mNonZeroKernel(nonZeroKernel.get())
    55478, mAllZeroKernel(allZeroKernel.get()) {
    555     addHandle(b, mAllZeroKernel, mInternalScalars, "0");
    556     addHandle(b, mNonZeroKernel, mInternalScalars, "1");
    557 }
    558 
    559 OptimizationBranch::~OptimizationBranch() {
    56079
    56180}
  • icGREP/icgrep-devel/icgrep/kernels/optimizationbranch/optimizationbranch_compiler.hpp

    r6287 r6296  
    55#include <kernels/streamset.h>
    66#include <kernels/kernel_builder.h>
     7#include <boost/container/flat_map.hpp>
     8#include <boost/graph/adjacency_list.hpp>
     9#include <llvm/Support/raw_ostream.h>
     10
     11using namespace llvm;
     12using namespace boost;
     13using namespace boost::container;
     14
     15using StreamSetGraph = adjacency_list<vecS, vecS, bidirectionalS, no_property, unsigned>;
     16
     17struct RelationshipRef {
     18    unsigned Index;
     19    StringRef Name;
     20    RelationshipRef() : Index(0), Name() { }
     21    RelationshipRef(const unsigned index, StringRef name) : Index(index), Name(name) { }
     22};
     23
     24const static std::string BRANCH_PREFIX = "@B";
     25
     26using RelationshipGraph = adjacency_list<vecS, vecS, bidirectionalS, no_property, RelationshipRef>;
     27
     28using RelationshipCache = flat_map<RelationshipGraph::vertex_descriptor, Value *>;
    729
    830namespace kernel {
    931
    10     struct OptimizationBranchCompiler {
    11 
    12         OptimizationBranchCompiler(OptimizationBranch * const branch);
    13 
    14 
    15 
    16 
    17     private:
    18 
    19         OptimizationBranch * const          mBranch;
    20 
    21 
    22 
    23         std::vector<llvm::Value *>          mProcessedInputItems;
    24         std::vector<llvm::PHINode *>        mAccessibleInputItemPhi;
    25 
    26         std::vector<llvm::Value *>          mProducedOutputItems;
    27         std::vector<llvm::PHINode *>        mWritableOrConsumedOutputItemPhi;
    28 
     32using Port = Kernel::Port;
     33using StreamPort = Kernel::StreamSetPort;
     34using BuilderRef = const std::unique_ptr<kernel::KernelBuilder> &;
     35using AttrId = Attribute::KindId;
     36
     37enum : unsigned {
     38    BRANCH_INPUT = 0
     39    , ALL_ZERO_BRANCH = 1
     40    , NON_ZERO_BRANCH = 2
     41    , BRANCH_OUTPUT = 3
     42    , CONDITION_VARIABLE = 4
     43// ----------------------
     44    , INITIAL_GRAPH_SIZE = 5
     45};
     46
     47static_assert (ALL_ZERO_BRANCH < NON_ZERO_BRANCH, "invalid branch type ordering");
     48
     49class OptimizationBranchCompiler {
     50
     51    enum class RelationshipType : unsigned {
     52        StreamSet
     53        , Scalar
    2954    };
    3055
    31 }
    32 
     56public:
     57    OptimizationBranchCompiler(OptimizationBranch * const branch);
     58
     59    void addBranchProperties(BuilderRef b);
     60    void generateInitializeMethod(BuilderRef b);
     61    void generateKernelMethod(BuilderRef b);
     62    void generateFinalizeMethod(BuilderRef b);
     63    std::vector<Value *> getFinalOutputScalars(BuilderRef b);
     64
     65private:
     66
     67    Value * loadHandle(BuilderRef b, const unsigned branchType) const;
     68
     69    Value * getInputScalar(BuilderRef b, const unsigned scalar);
     70
     71    inline unsigned getNumOfInputBindings(const Kernel * const kernel, const RelationshipType type) const;
     72
     73    const Binding & getInputBinding(const Kernel * const kernel, const RelationshipType type, const unsigned i) const;
     74
     75    unsigned getNumOfOutputBindings(const Kernel * const kernel, const RelationshipType type) const;
     76
     77    const Binding & getOutputBinding(const Kernel * const kernel, const RelationshipType type, const unsigned i) const;
     78
     79    void generateStreamSetBranchMethod(BuilderRef b);
     80
     81    void executeBranch(BuilderRef b, const unsigned branchType, Value * const first, Value * const last);
     82
     83    Value * calculateAccessibleOrWritableItems(BuilderRef b, const Kernel * const kernel, const Binding & binding, Value * const first, Value * const last, Value * const defaultValue) const;
     84
     85    RelationshipGraph makeRelationshipGraph(const RelationshipType type) const;
     86
     87private:
     88
     89    OptimizationBranch * const          mBranch;
     90    const std::vector<const Kernel *>   mBranches;
     91
     92    PHINode *                           mTerminatedPhi;
     93
     94
     95
     96
     97    const RelationshipGraph             mStreamSetGraph;
     98    const RelationshipGraph             mScalarGraph;
     99    RelationshipCache                   mScalarCache;
     100
     101};
     102
     103template <typename Graph>
     104LLVM_READNONE
     105inline typename graph_traits<Graph>::edge_descriptor in_edge(const typename graph_traits<Graph>::vertex_descriptor u, const Graph & G) {
     106    assert (in_degree(u, G) == 1);
     107    return *in_edges(u, G).first;
     108}
     109
     110template <typename Graph>
     111LLVM_READNONE
     112inline typename graph_traits<Graph>::edge_descriptor preceding(const typename graph_traits<Graph>::edge_descriptor & e, const Graph & G) {
     113    return in_edge(source(e, G), G);
     114}
     115
     116template <typename Graph>
     117LLVM_READNONE
     118inline typename graph_traits<Graph>::edge_descriptor out_edge(const typename graph_traits<Graph>::vertex_descriptor u, const Graph & G) {
     119    assert (out_degree(u, G) == 1);
     120    return *out_edges(u, G).first;
     121}
     122
     123template <typename Graph>
     124LLVM_READNONE
     125inline typename graph_traits<Graph>::edge_descriptor descending(const typename graph_traits<Graph>::edge_descriptor & e, const Graph & G) {
     126    return out_edge(target(e, G), G);
     127}
     128
     129inline unsigned OptimizationBranchCompiler::getNumOfInputBindings(const Kernel * const kernel, const RelationshipType type) const {
     130    return (type == RelationshipType::StreamSet) ? kernel->getNumOfStreamInputs() : kernel->getNumOfScalarInputs();
     131}
     132
     133inline const Binding & OptimizationBranchCompiler::getInputBinding(const Kernel * const kernel, const RelationshipType type, const unsigned i) const {
     134    return (type == RelationshipType::StreamSet) ? kernel->mInputStreamSets[i] : kernel->mInputScalars[i];
     135}
     136
     137inline unsigned OptimizationBranchCompiler::getNumOfOutputBindings(const Kernel * const kernel, const RelationshipType type) const {
     138    return (type == RelationshipType::StreamSet) ? kernel->getNumOfStreamOutputs() : kernel->getNumOfScalarOutputs();
     139}
     140
     141inline const Binding & OptimizationBranchCompiler::getOutputBinding(const Kernel * const kernel, const RelationshipType type, const unsigned i) const {
     142    return (type == RelationshipType::StreamSet) ? kernel->mOutputStreamSets[i] : kernel->mOutputScalars[i];
     143}
     144
     145/** ------------------------------------------------------------------------------------------------------------- *
     146 * @brief makeRelationshipGraph
     147 ** ------------------------------------------------------------------------------------------------------------- */
     148RelationshipGraph OptimizationBranchCompiler::makeRelationshipGraph(const RelationshipType type) const {
     149
     150    using Vertex = RelationshipGraph::vertex_descriptor;
     151    using Map = flat_map<const Relationship *, Vertex>;
     152
     153    RelationshipGraph G(INITIAL_GRAPH_SIZE);
     154    Map M;
     155
     156    auto addRelationship = [&](const Relationship * const rel) {
     157        const auto f = M.find(rel);
     158        if (LLVM_UNLIKELY(f != M.end())) {
     159            return f->second;
     160        } else {
     161            const auto x = add_vertex(G);
     162            M.emplace(rel, x);
     163            return x;
     164        }
     165    };
     166
     167    const auto numOfInputs = getNumOfInputBindings(mBranch, type);
     168    for (unsigned i = 0; i < numOfInputs; ++i) {
     169        const auto & input = getInputBinding(mBranch, type, i);
     170        const auto r = addRelationship(input.getRelationship());
     171        add_edge(BRANCH_INPUT, r, RelationshipRef{i, input.getName()}, G);
     172    }
     173
     174    const auto numOfOutputs = getNumOfOutputBindings(mBranch, type);
     175    for (unsigned i = 0; i < numOfOutputs; ++i) {
     176        const auto & output = getOutputBinding(mBranch, type, i);
     177        const auto r = addRelationship(output.getRelationship());
     178        add_edge(r, BRANCH_OUTPUT, RelationshipRef{i, output.getName()}, G);
     179    }
     180
     181    if (type == RelationshipType::StreamSet && isa<StreamSet>(mBranch->getCondition())) {
     182        const auto r = addRelationship(mBranch->getCondition());
     183        add_edge(r, CONDITION_VARIABLE, RelationshipRef{}, G);
     184    }
     185
     186    if (type == RelationshipType::Scalar && isa<Scalar>(mBranch->getCondition())) {
     187        const auto r = addRelationship(mBranch->getCondition());
     188        add_edge(r, CONDITION_VARIABLE, RelationshipRef{}, G);
     189    }
     190
     191    auto findRelationship = [&](const Kernel * kernel, const Binding & binding) {
     192        const Relationship * const rel = binding.getRelationship();
     193        const auto f = M.find(rel);
     194        if (LLVM_UNLIKELY(f == M.end())) {
     195            if (LLVM_LIKELY(rel->isConstant())) {
     196                const auto x = add_vertex(G);
     197                M.emplace(rel, x);
     198                return x;
     199            } else {
     200                std::string tmp;
     201                raw_string_ostream msg(tmp);
     202                msg << "Branch was not provided a ";
     203                if (type == RelationshipType::StreamSet) {
     204                    msg << "StreamSet";
     205                } else if (type == RelationshipType::Scalar) {
     206                    msg << "Scalar";
     207                }
     208                msg << " binding for "
     209                    << kernel->getName()
     210                    << '.'
     211                    << binding.getName();
     212                report_fatal_error(msg.str());
     213            }
     214        }
     215        return f->second;
     216    };
     217
     218    auto linkRelationships = [&](const Kernel * const kernel, const Vertex branch) {
     219
     220        const auto numOfInputs = getNumOfInputBindings(kernel, type);
     221        for (unsigned i = 0; i < numOfInputs; ++i) {
     222            const auto & input = getInputBinding(mBranch, type, i);
     223            const auto r = findRelationship(kernel, input);
     224            add_edge(r, branch, RelationshipRef{i, input.getName()}, G);
     225        }
     226
     227        const auto numOfOutputs = getNumOfOutputBindings(kernel, type);
     228        for (unsigned i = 0; i < numOfOutputs; ++i) {
     229            const auto & output = getOutputBinding(kernel, type, i);
     230            const auto r = findRelationship(kernel, output);
     231            add_edge(branch, r, RelationshipRef{i, output.getName()}, G);
     232        }
     233    };
     234
     235    linkRelationships(mBranch->getAllZeroKernel(), ALL_ZERO_BRANCH);
     236    linkRelationships(mBranch->getNonZeroKernel(), NON_ZERO_BRANCH);
     237
     238    return G;
     239}
     240
     241/** ------------------------------------------------------------------------------------------------------------- *
     242 * @brief generateInitializeMethod
     243 ** ------------------------------------------------------------------------------------------------------------- */
     244void OptimizationBranchCompiler::addBranchProperties(BuilderRef b) {
     245    for (unsigned i = ALL_ZERO_BRANCH; i <= NON_ZERO_BRANCH; ++i) {
     246        const Kernel * const kernel = mBranches[i];
     247        if (LLVM_LIKELY(kernel->isStateful())) {
     248            Type * handlePtrType = nullptr;
     249            if (LLVM_UNLIKELY(kernel->hasFamilyName())) {
     250                handlePtrType = b->getVoidPtrTy();
     251            } else {
     252                handlePtrType = kernel->getKernelType()->getPointerTo();
     253            }
     254            mBranch->addInternalScalar(handlePtrType, BRANCH_PREFIX + std::to_string(i));
     255        }
     256    }
     257}
     258
     259/** ------------------------------------------------------------------------------------------------------------- *
     260 * @brief generateInitializeMethod
     261 ** ------------------------------------------------------------------------------------------------------------- */
     262void OptimizationBranchCompiler::generateInitializeMethod(BuilderRef b) {
     263    mScalarCache.clear();
     264    for (unsigned i = ALL_ZERO_BRANCH; i <= NON_ZERO_BRANCH; ++i) {
     265        const Kernel * const kernel = mBranches[i];
     266        if (kernel->isStateful() && !kernel->hasFamilyName()) {
     267            Value * const handle = kernel->createInstance(b);
     268            b->setScalarField(BRANCH_PREFIX + std::to_string(i), handle);
     269        }
     270    }
     271    std::vector<Value *> args;
     272    Module * const m = b->getModule();
     273    Value * terminated = b->getFalse();
     274    for (unsigned i = ALL_ZERO_BRANCH; i <= NON_ZERO_BRANCH; ++i) {
     275        const Kernel * const kernel = mBranches[i];
     276        const auto hasHandle = kernel->isStateful() ? 1U : 0U;
     277        args.resize(hasHandle + in_degree(i, mScalarGraph));
     278        if (LLVM_LIKELY(hasHandle)) {
     279            args[0] = b->getScalarField(BRANCH_PREFIX + std::to_string(i));
     280        }
     281        for (const auto e : make_iterator_range(in_edges(i, mScalarGraph))) {
     282            const RelationshipRef & ref = mScalarGraph[e];
     283            const auto j = ref.Index + hasHandle;
     284            args[j] = getInputScalar(b, source(e, mScalarGraph));
     285        }
     286        Value * const terminatedOnInit = b->CreateCall(kernel->getInitFunction(m), args);
     287        terminated = b->CreateOr(terminated, terminatedOnInit);
     288    }
     289    b->CreateStore(terminated, mBranch->getTerminationSignalPtr());
     290}
     291
     292/** ------------------------------------------------------------------------------------------------------------- *
     293 * @brief loadKernelHandle
     294 ** ------------------------------------------------------------------------------------------------------------- */
     295Value * OptimizationBranchCompiler::loadHandle(BuilderRef b, const unsigned branchType) const {
     296    const Kernel * const kernel = mBranches[branchType];
     297    Value * handle = nullptr;
     298    if (LLVM_LIKELY(kernel->isStateful())) {
     299        handle = b->getScalarField(BRANCH_PREFIX + std::to_string(branchType));
     300        if (kernel->hasFamilyName()) {
     301            handle = b->CreatePointerCast(handle, kernel->getKernelType()->getPointerTo());
     302        }
     303    }
     304    return handle;
     305}
     306
     307/** ------------------------------------------------------------------------------------------------------------- *
     308 * @brief generateKernelMethod
     309 ** ------------------------------------------------------------------------------------------------------------- */
     310void OptimizationBranchCompiler::generateKernelMethod(BuilderRef b) {
     311    if (LLVM_LIKELY(isa<StreamSet>(mBranch->getCondition()))) {
     312        generateStreamSetBranchMethod(b);
     313    } else {
     314
     315    }
     316}
     317
     318/** ------------------------------------------------------------------------------------------------------------- *
     319 * @brief getConditionRef
     320 ** ------------------------------------------------------------------------------------------------------------- */
     321inline const RelationshipRef & getConditionRef(const RelationshipGraph & G) {
     322    return G[preceding(in_edge(CONDITION_VARIABLE, G), G)];
     323}
     324
     325/** ------------------------------------------------------------------------------------------------------------- *
     326 * @brief generateStreamSetBranchMethod
     327 ** ------------------------------------------------------------------------------------------------------------- */
     328inline void OptimizationBranchCompiler::generateStreamSetBranchMethod(BuilderRef b) {
     329
     330    Constant * const ZERO = b->getSize(0);
     331    Constant * const ONE = b->getSize(1);
     332    Constant * const BLOCKS_PER_STRIDE = b->getSize(mBranch->getStride() / b->getBitBlockWidth());
     333
     334    BasicBlock * const entry = b->GetInsertBlock();
     335    BasicBlock * const loopCond = b->CreateBasicBlock("cond");
     336    BasicBlock * const summarizeOneStride = b->CreateBasicBlock("summarizeOneStride");
     337    BasicBlock * const checkStride = b->CreateBasicBlock("checkStride");
     338    BasicBlock * const processStrides = b->CreateBasicBlock("processStrides");
     339    BasicBlock * const mergePaths = b->CreateBasicBlock("mergePaths");
     340    BasicBlock * const nonZeroPath = b->CreateBasicBlock("nonZeroPath");
     341    BasicBlock * const allZeroPath = b->CreateBasicBlock("allZeroPath");
     342    BasicBlock * const exit = b->CreateBasicBlock("exit");
     343    b->CreateBr(loopCond);
     344
     345    b->SetInsertPoint(loopCond);
     346    IntegerType * const sizeTy = b->getSizeTy();
     347    IntegerType * const boolTy = b->getInt1Ty();
     348    PHINode * const currentFirstIndex = b->CreatePHI(sizeTy, 3, "firstStride");
     349    currentFirstIndex->addIncoming(ZERO, entry);
     350    PHINode * const currentLastIndex = b->CreatePHI(sizeTy, 3, "lastStride");
     351    currentLastIndex->addIncoming(ZERO, entry);
     352    PHINode * const currentState = b->CreatePHI(boolTy, 3);
     353    currentState->addIncoming(UndefValue::get(boolTy), entry);
     354
     355    const RelationshipRef & condRef = getConditionRef(mStreamSetGraph);
     356    Value * const numOfConditionStreams = b->getInputStreamSetCount(condRef.Name);
     357    Value * const numOfConditionBlocks = b->CreateMul(numOfConditionStreams, BLOCKS_PER_STRIDE);
     358
     359    Value * const offset = b->CreateMul(currentLastIndex, BLOCKS_PER_STRIDE);
     360    Value * basePtr = b->getInputStreamBlockPtr(condRef.Name, ZERO, offset);
     361    Type * const BitBlockTy = b->getBitBlockType();
     362    basePtr = b->CreatePointerCast(basePtr, BitBlockTy->getPointerTo());
     363    b->CreateBr(summarizeOneStride);
     364
     365    // Predeclare some phi nodes
     366
     367    b->SetInsertPoint(nonZeroPath);
     368    PHINode * const firstNonZeroIndex = b->CreatePHI(sizeTy, 2);
     369    PHINode * const lastNonZeroIndex = b->CreatePHI(sizeTy, 2);
     370    PHINode * const allZeroAfterNonZero = b->CreatePHI(boolTy, 2);
     371
     372    b->SetInsertPoint(allZeroPath);
     373    PHINode * const firstAllZeroIndex = b->CreatePHI(sizeTy, 2);
     374    PHINode * const lastAllZeroIndex = b->CreatePHI(sizeTy, 2);
     375    PHINode * const nonZeroAfterAllZero = b->CreatePHI(boolTy, 2);
     376
     377    mTerminatedPhi = nullptr;
     378    if (mBranch->canSetTerminateSignal()) {
     379        b->SetInsertPoint(mergePaths);
     380        mTerminatedPhi = b->CreatePHI(boolTy, 2);
     381    }
     382
     383    Value * const numOfStrides = mBranch->mNumOfStrides;
     384
     385    // OR together every condition block in this stride
     386    b->SetInsertPoint(summarizeOneStride);
     387    PHINode * const iteration = b->CreatePHI(sizeTy, 2);
     388    iteration->addIncoming(ZERO, loopCond);
     389    PHINode * const merged = b->CreatePHI(BitBlockTy, 2);
     390    merged->addIncoming(Constant::getNullValue(BitBlockTy), loopCond);
     391    Value * value = b->CreateBlockAlignedLoad(basePtr, iteration);
     392    value = b->CreateOr(value, merged);
     393    merged->addIncoming(value, summarizeOneStride);
     394    Value * const nextIteration = b->CreateAdd(iteration, ONE);
     395    Value * const more = b->CreateICmpNE(nextIteration, numOfConditionBlocks);
     396    iteration->addIncoming(nextIteration, b->GetInsertBlock());
     397    b->CreateCondBr(more, summarizeOneStride, checkStride);
     398
     399    // Check the merged value of our condition block(s); if it differs from
     400    // the prior value or this is our last stride, then process the strides.
     401    // Note, however, initially state is "indeterminate" so we silently
     402    // ignore the first stride unless it is also our last.
     403    b->SetInsertPoint(checkStride);
     404    Value * const nextState = b->bitblock_any(value);
     405    Value * const sameState = b->CreateICmpEQ(nextState, currentState);
     406    Value * const firstStride = b->CreateICmpEQ(currentLastIndex, ZERO);
     407    Value * const continuation = b->CreateOr(sameState, firstStride);
     408    Value * const nextIndex = b->CreateAdd(currentLastIndex, ONE);
     409    Value * const notLastStride = b->CreateICmpULT(nextIndex, numOfStrides);
     410    Value * const checkNextStride = b->CreateAnd(continuation, notLastStride);
     411    currentLastIndex->addIncoming(nextIndex, checkStride);
     412    currentFirstIndex->addIncoming(currentFirstIndex, checkStride);
     413    currentState->addIncoming(nextState, checkStride);
     414    b->CreateLikelyCondBr(checkNextStride, loopCond, processStrides);
     415
     416    // Process every stride between [first, last)
     417    b->SetInsertPoint(processStrides);
     418
     419    // state is implicitly "indeterminate" during our first stride
     420    Value * const selectedPath = b->CreateSelect(firstStride, nextState, currentState);
     421    firstNonZeroIndex->addIncoming(currentFirstIndex, processStrides);
     422    firstAllZeroIndex->addIncoming(currentFirstIndex, processStrides);
     423    // When we reach the last (but not necessarily final) stride of this kernel,
     424    // we will either "append" the final stride to the current run or complete
     425    // the current run then perform one more iteration for the final stride, depending
     426    // whether it flips the branch selection state.
     427    Value * const nextLast = b->CreateSelect(continuation, numOfStrides, nextIndex);
     428    Value * const nextFirst = b->CreateSelect(continuation, numOfStrides, currentLastIndex);
     429
     430    lastNonZeroIndex->addIncoming(nextFirst, processStrides);
     431    lastAllZeroIndex->addIncoming(nextFirst, processStrides);
     432    Value * finished = b->CreateNot(notLastStride);
     433    Value * const flipLastState = b->CreateAnd(finished, b->CreateNot(continuation));
     434    nonZeroAfterAllZero->addIncoming(flipLastState, processStrides);
     435    allZeroAfterNonZero->addIncoming(flipLastState, processStrides);
     436    b->CreateCondBr(selectedPath, nonZeroPath, allZeroPath);
     437
     438    // make the actual calls and take any potential termination signal
     439    b->SetInsertPoint(nonZeroPath);
     440    executeBranch(b, NON_ZERO_BRANCH, firstNonZeroIndex, lastNonZeroIndex);
     441    BasicBlock * const nonZeroPathExit = b->GetInsertBlock();
     442    firstAllZeroIndex->addIncoming(nextFirst, nonZeroPathExit);
     443    lastAllZeroIndex->addIncoming(nextLast, nonZeroPathExit);
     444    nonZeroAfterAllZero->addIncoming(b->getFalse(), nonZeroPathExit);
     445    b->CreateUnlikelyCondBr(allZeroAfterNonZero, allZeroPath, mergePaths);
     446
     447    b->SetInsertPoint(allZeroPath);
     448    executeBranch(b, ALL_ZERO_BRANCH, firstAllZeroIndex, lastAllZeroIndex);
     449    BasicBlock * const allZeroPathExit = b->GetInsertBlock();
     450    firstNonZeroIndex->addIncoming(nextFirst, allZeroPathExit);
     451    lastNonZeroIndex->addIncoming(nextLast, allZeroPathExit);
     452    allZeroAfterNonZero->addIncoming(b->getFalse(), allZeroPathExit);
     453    b->CreateUnlikelyCondBr(nonZeroAfterAllZero, nonZeroPath, mergePaths);
     454
     455    b->SetInsertPoint(mergePaths);
     456    currentFirstIndex->addIncoming(nextFirst, mergePaths);
     457    currentLastIndex->addIncoming(nextLast, mergePaths);
     458    currentState->addIncoming(nextState, mergePaths);
     459    if (mTerminatedPhi) {
     460        finished = b->CreateOr(finished, mTerminatedPhi);
     461    }
     462    b->CreateLikelyCondBr(finished, exit, loopCond);
     463
     464    b->SetInsertPoint(exit);
     465
     466}
     467
     468/** ------------------------------------------------------------------------------------------------------------- *
     469 * @brief callKernel
     470 ** ------------------------------------------------------------------------------------------------------------- */
     471void OptimizationBranchCompiler::executeBranch(BuilderRef b,
     472                                               const unsigned branchType,
     473                                               Value * const first,
     474                                               Value * const last) {
     475
     476
     477    const Kernel * const kernel = mBranches[branchType];
     478
     479    Function * const doSegment = kernel->getDoSegmentFunction(b->getModule());
     480
     481    BasicBlock * incrementItemCounts = nullptr;
     482    BasicBlock * kernelExit = nullptr;
     483    if (kernel->canSetTerminateSignal()) {
     484        incrementItemCounts = b->CreateBasicBlock("incrementItemCounts");
     485        kernelExit = b->CreateBasicBlock("kernelExit");
     486    }
     487
     488    Value * const handle = loadHandle(b, branchType);
     489    // Last can only be 0 if this is the branch's final stride.
     490    Value * const isFinal = b->CreateIsNull(last);
     491
     492    const auto numOfInputs = in_degree(branchType, mStreamSetGraph);
     493
     494    std::vector<Value *> baseInputAddress(numOfInputs, nullptr);
     495    std::vector<Value *> processedInputItemPtr(numOfInputs, nullptr);
     496    std::vector<Value *> processedInputItem(numOfInputs, nullptr);
     497    std::vector<Value *> accessibleInputItem(numOfInputs, nullptr);
     498    std::vector<Value *> popCountRateArray(numOfInputs, nullptr);
     499    std::vector<Value *> negatedPopCountRateArray(numOfInputs, nullptr);
     500
     501    for (const auto & e : make_iterator_range(in_edges(branchType, mStreamSetGraph))) {
     502        const RelationshipRef & host = mStreamSetGraph[e];
     503        const auto & buffer = mBranch->getInputStreamSetBuffer(host.Index);
     504        const Binding & input = kernel->getInputStreamSetBinding(host.Index);
     505        const RelationshipRef & path = mStreamSetGraph[preceding(e, mStreamSetGraph)];
     506        // logical base input address
     507        baseInputAddress[path.Index] = buffer->getBaseAddress(b.get());
     508        // processed input items
     509        Value * processed = mBranch->getProcessedInputItemsPtr(path.Index);
     510        if (kernel->isCountable(input)) {
     511            processedInputItemPtr[path.Index] = processed;
     512            processed = b->CreateLoad(processed);
     513        }
     514        processedInputItem[path.Index] = processed;
     515
     516        // accessible input items (after non-deferred processed item count)
     517        Value * const accessible = mBranch->getAccessibleInputItems(path.Index);
     518        Value * const provided = calculateAccessibleOrWritableItems(b, kernel, input, first, last, accessible);
     519        accessibleInputItem[path.Index] = b->CreateSelect(isFinal, accessible, provided);
     520
     521        if (LLVM_UNLIKELY(input.hasAttribute(AttrId::RequiresPopCountArray))) {
     522            popCountRateArray[path.Index] = b->CreateGEP(mBranch->mPopCountRateArray[host.Index], first);
     523        }
     524        if (LLVM_UNLIKELY(input.hasAttribute(AttrId::RequiresNegatedPopCountArray))) {
     525            negatedPopCountRateArray[path.Index] = b->CreateGEP(mBranch->mNegatedPopCountRateArray[host.Index], first);
     526        }
     527    }
     528
     529    const auto numOfOutputs = out_degree(branchType, mStreamSetGraph);
     530
     531    std::vector<Value *> baseOutputAddress(numOfOutputs, nullptr);
     532    std::vector<Value *> producedOutputItemPtr(numOfOutputs, nullptr);
     533    std::vector<Value *> producedOutputItem(numOfOutputs, nullptr);
     534    std::vector<Value *> writableOutputItem(numOfInputs, nullptr);
     535
     536    for (const auto & e : make_iterator_range(out_edges(branchType, mStreamSetGraph))) {
     537        const RelationshipRef & host = mStreamSetGraph[e];
     538        const auto & buffer = mBranch->getOutputStreamSetBuffer(host.Index);
     539        const Binding & output = kernel->getOutputStreamSetBinding(host.Index);
     540        const RelationshipRef & path = mStreamSetGraph[descending(e, mStreamSetGraph)];
     541        // logical base input address
     542        baseOutputAddress[path.Index] = buffer->getBaseAddress(b.get());
     543        // produced output items
     544        Value * produced = mBranch->getProducedOutputItemsPtr(path.Index);
     545        if (kernel->isCountable(output)) {
     546            producedOutputItemPtr[path.Index] = produced;
     547            produced = b->CreateLoad(produced);
     548        }
     549        producedOutputItem[path.Index] = produced;
     550        Value * const writable = mBranch->getWritableOutputItems(path.Index);
     551        Value * const provided = calculateAccessibleOrWritableItems(b, kernel, output, first, last, writable);
     552        writableOutputItem[path.Index] = b->CreateSelect(isFinal, writable, provided);
     553    }
     554
     555    std::vector<Value *> args;
     556    args.reserve(doSegment->arg_size());
     557    if (handle) {
     558        args.push_back(handle);
     559    }
     560    args.push_back(b->CreateSub(last, first)); // numOfStrides
     561    for (unsigned i = 0; i < numOfInputs; ++i) {
     562        args.push_back(baseInputAddress[i]);
     563        args.push_back(processedInputItem[i]);
     564        args.push_back(accessibleInputItem[i]);
     565        if (popCountRateArray[i]) {
     566            args.push_back(popCountRateArray[i]);
     567        }
     568        if (negatedPopCountRateArray[i]) {
     569            args.push_back(negatedPopCountRateArray[i]);
     570        }
     571    }
     572    for (unsigned i = 0; i < numOfOutputs; ++i) {
     573        args.push_back(baseOutputAddress[i]);
     574        args.push_back(producedOutputItem[i]);
     575        args.push_back(writableOutputItem[i]);
     576    }
     577
     578    Value * const terminated = b->CreateCall(doSegment, args);
     579
     580    // TODO: if either of these kernels "share" an output scalar, copy the scalar value from the
     581    // branch we took to the state of the branch we avoided. This requires that the branch pipeline
     582    // exposes them.
     583
     584
     585    if (incrementItemCounts) {
     586        b->CreateUnlikelyCondBr(terminated, kernelExit, incrementItemCounts);
     587
     588        b->SetInsertPoint(incrementItemCounts);
     589    }
     590
     591    for (unsigned i = 0; i < numOfInputs; ++i) {
     592        if (processedInputItemPtr[i]) {
     593            Value * const processed = processedInputItem[i];
     594            Value * const itemCount = accessibleInputItem[i];
     595            Value * const updatedInputCount = b->CreateAdd(processed, itemCount);
     596            b->CreateStore(updatedInputCount, processedInputItemPtr[i]);
     597        }
     598    }
     599
     600    for (unsigned i = 0; i < numOfOutputs; ++i) {
     601        if (producedOutputItemPtr[i]) {
     602            Value * const produced = producedOutputItem[i];
     603            Value * const itemCount = writableOutputItem[i];
     604            Value * const updatedOutputCount = b->CreateAdd(produced, itemCount);
     605            b->CreateStore(updatedOutputCount, producedOutputItemPtr[i]);
     606        }
     607    }
     608
     609    if (incrementItemCounts) {
     610        mTerminatedPhi->addIncoming(terminated, b->GetInsertBlock());
     611        b->CreateBr(kernelExit);
     612        b->SetInsertPoint(kernelExit);
     613    }
     614
     615}
     616
     617/** ------------------------------------------------------------------------------------------------------------- *
     618 * @brief calculateAccessibleOrWritableItems
     619 ** ------------------------------------------------------------------------------------------------------------- */
     620Value * OptimizationBranchCompiler::calculateAccessibleOrWritableItems(BuilderRef b,
     621                                                                       const Kernel * const kernel, const Binding & binding,
     622                                                                       Value * const first, Value * const last,
     623                                                                       Value * const defaultValue) const {
     624    const ProcessingRate & rate = binding.getRate();
     625    if (LLVM_LIKELY(rate.isFixed() || rate.isBounded())) {
     626        Constant * const strideLength = b->getSize(ceiling(rate.getUpperBound() * kernel->getStride()));
     627        Value * const numOfStrides = b->CreateSub(last, first);
     628        return b->CreateMul(numOfStrides, strideLength);
     629    } else if (rate.isPopCount() || rate.isNegatedPopCount()) {
     630        Port refPort;
     631        unsigned refIndex = 0;
     632        std::tie(refPort, refIndex) = mBranch->getStreamPort(rate.getReference());
     633        assert (refPort == Port::Input);
     634        Value * array = nullptr;
     635        if (rate.isNegatedPopCount()) {
     636            array = mBranch->mNegatedPopCountRateArray[refIndex];
     637        } else {
     638            array = mBranch->mPopCountRateArray[refIndex];
     639        }
     640        Constant * const ONE = b->getSize(1);
     641        Value * const currentIndex = b->CreateSub(last, ONE);
     642        Value * const currentSum = b->CreateLoad(b->CreateGEP(array, currentIndex));
     643        Value * const priorIndex = b->CreateSub(first, ONE);
     644        Value * const priorSum = b->CreateLoad(b->CreateGEP(array, priorIndex));
     645        return b->CreateSub(currentSum, priorSum);
     646    }
     647    return defaultValue;
     648}
     649
     650
     651/** ------------------------------------------------------------------------------------------------------------- *
     652 * @brief getScalar
     653 ** ------------------------------------------------------------------------------------------------------------- */
     654inline Value * OptimizationBranchCompiler::getInputScalar(BuilderRef b, const unsigned scalar) {
     655    const auto f = mScalarCache.find(scalar);
     656    if (LLVM_UNLIKELY(f != mScalarCache.end())) {
     657        return f->second;
     658    }
     659    const auto e = in_edge(scalar, mScalarGraph);
     660    const RelationshipRef & ref = mScalarGraph[e];
     661    Value * const value = b->getScalarField(ref.Name);
     662    mScalarCache.emplace(scalar, value);
     663    return value;
     664}
     665
     666inline std::vector<const Kernel *> makeBranches(const OptimizationBranch * const branch) {
     667    std::vector<const Kernel *> branches(4);
     668    branches[BRANCH_INPUT] = branch;
     669    branches[ALL_ZERO_BRANCH] = branch->getAllZeroKernel();
     670    branches[NON_ZERO_BRANCH] = branch->getNonZeroKernel();
     671    branches[BRANCH_OUTPUT] = branch;
     672    return branches;
     673}
     674
     675/** ------------------------------------------------------------------------------------------------------------- *
     676 * @brief generateFinalizeMethod
     677 ** ------------------------------------------------------------------------------------------------------------- */
     678void OptimizationBranchCompiler::generateFinalizeMethod(BuilderRef b) {
     679    for (unsigned i = ALL_ZERO_BRANCH; i <= NON_ZERO_BRANCH; ++i) {
     680        const Kernel * const kernel = mBranches[i];
     681        kernel->finalizeInstance(b, loadHandle(b, i));
     682    }
     683}
     684
     685/** ------------------------------------------------------------------------------------------------------------- *
     686 * @brief constructor
     687 ** ------------------------------------------------------------------------------------------------------------- */
     688OptimizationBranchCompiler::OptimizationBranchCompiler(OptimizationBranch * const branch)
     689: mBranch(branch)
     690, mBranches(makeBranches(branch))
     691, mStreamSetGraph(makeRelationshipGraph(RelationshipType::StreamSet))
     692, mScalarGraph(makeRelationshipGraph(RelationshipType::Scalar)) {
     693
     694}
     695
     696}
    33697
    34698#endif // OPTIMIZATIONBRANCH_COMPILER_HPP
  • icGREP/icgrep-devel/icgrep/kernels/pipeline/core_logic.hpp

    r6288 r6296  
    199199 * @brief isParamAddressable
    200200 ** ------------------------------------------------------------------------------------------------------------- */
    201 inline bool isParamAddressable(const Binding & binding) {
     201inline bool isAddressable(const Binding & binding) {
    202202    if (binding.isDeferred()) {
    203203        return true;
  • icGREP/icgrep-devel/icgrep/kernels/pipeline/kernel_logic.hpp

    r6288 r6296  
    463463    const auto numOfOutputs = mKernel->getNumOfStreamOutputs();
    464464
    465 #warning TODO: add MProtect to buffers and their handles.
    466 
    467 #warning TODO: send in the # of output items we want in the external buffers
     465    // TODO: add MProtect to buffers and their handles.
     466
     467    // TODO: send in the # of output items we want in the external buffers
    468468
    469469    b->setKernel(mPipelineKernel);
     
    505505        args.push_back(epoch(b, input, getInputBuffer(i), processed, inputItems));
    506506        mReturnedProcessedItemCountPtr[i] = addItemCountArg(b, input, deferred, processed, args);
    507 
    508507        args.push_back(inputItems); assert (inputItems);
    509508
  • icGREP/icgrep-devel/icgrep/kernels/pipeline/pipeline_builder.cpp

    r6288 r6296  
    446446                             allZero->getOutputStreamSetBindings(),
    447447                             mOutputStreamSets);
    448 
    449     if (isa<StreamSet>(mCondition)) {
    450         mInputStreamSets.emplace_back(OptimizationBranch::CONDITION_TAG, mCondition);
    451     } else {
    452         mInputScalars.emplace_back(OptimizationBranch::CONDITION_TAG, mCondition);
    453     }
    454448
    455449    OptimizationBranch * const br =
  • icGREP/icgrep-devel/icgrep/kernels/pipeline/pipeline_kernel.cpp

    r6288 r6296  
    201201        b->CreateCall(doSegment, segmentArgs);
    202202        // call and return the final output value(s)
    203         b->CreateRet(finalizeInstance(b));
     203        b->CreateRet(finalizeInstance(b, handle));
    204204    }
    205205
  • icGREP/icgrep-devel/icgrep/kernels/pipeline/pipeline_logic.hpp

    r6288 r6296  
    7777                mPipelineKernel->addInternalScalar(sizeTy, prefix + DEFERRED_ITEM_COUNT_SUFFIX);
    7878            }
    79 //            if (LLVM_UNLIKELY(onlyOne && isPipelineInput(kernelIndex, i))) {
    80 //                mPipelineKernel->addLocalScalar(sizeTy, prefix + ITEM_COUNT_SUFFIX);
    81 //            } else {
    82                 mPipelineKernel->addInternalScalar(sizeTy, prefix + ITEM_COUNT_SUFFIX);
    83 //            }
     79            mPipelineKernel->addInternalScalar(sizeTy, prefix + ITEM_COUNT_SUFFIX);
    8480        }
    8581
     
    8884            const Binding & output = kernel->getOutputStreamSetBinding(i);
    8985            const auto prefix = makeBufferName(kernelIndex, output);
    90 //            if (LLVM_UNLIKELY(isPipelineOutput(kernelIndex, i))) {
    91 //                mPipelineKernel->addLocalScalar(sizeTy, prefix + ITEM_COUNT_SUFFIX);
    92 //            } else {
    93                 mPipelineKernel->addInternalScalar(sizeTy, prefix + ITEM_COUNT_SUFFIX);
    94 //            }
     86            mPipelineKernel->addInternalScalar(sizeTy, prefix + ITEM_COUNT_SUFFIX);
    9587        }
    9688    }
     
    9991        // if this is a family kernel, its handle will be passed into the kernel
    10092        // methods rather than stored within the pipeline state
    101         PointerType * kernelPtrTy = kernel->getKernelType()->getPointerTo(0);
    102         mPipelineKernel->addInternalScalar(kernelPtrTy, makeKernelName(kernelIndex));
     93        PointerType * const handlePtrTy = kernel->getKernelType()->getPointerTo(0);
     94        mPipelineKernel->addInternalScalar(handlePtrTy, makeKernelName(kernelIndex));
    10395    }
    10496}
  • icGREP/icgrep-devel/icgrep/kernels/relationship.h

    r6184 r6296  
    5454    void* operator new (std::size_t size, Allocator & A) noexcept {
    5555        return A.allocate<uint8_t>(size);
     56    }
     57
     58    bool isConstant() const {
     59        return mClassTypeId == ClassTypeId::ScalarConstant;
    5660    }
    5761
  • icGREP/icgrep-devel/icgrep/u8u16.cpp

    r6291 r6296  
    362362
    363363    auto B = P->CreateOptimizationBranch(nonAscii,
    364         {Binding{"ByteStream", ByteStream}}, {Binding{"u16bytes", u16bytes, BoundedRate(0, 1)}});
     364        {Binding{"ByteStream", ByteStream}, Binding{"condition", nonAscii}}, {Binding{"u16bytes", u16bytes, BoundedRate(0, 1)}});
     365
     366    makeAllAsciiBranch(B->getAllZeroBranch(), ByteStream, u16bytes);
    365367
    366368    makeNonAsciiBranch(B->getNonZeroBranch(), b->getBitBlockWidth() / 16, ByteStream, u16bytes);
    367 
    368     makeAllAsciiBranch(B->getAllZeroBranch(), ByteStream, u16bytes);
    369369
    370370    Scalar * outputFileName = P->getInputScalar("outputFileName");
Note: See TracChangeset for help on using the changeset viewer.