Changeset 6289


Ignore:
Timestamp:
Jan 19, 2019, 2:44:54 PM (3 months ago)
Author:
cameron
Message:

Initial version of working OptimizationBranch?

Location:
icGREP/icgrep-devel/icgrep/kernels
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/kernels/optimizationbranch.h

    r6288 r6289  
    1010struct OptimizationBranchCompiler;
    1111
    12 class OptimizationBranch final : public MultiBlockKernel {
     12class OptimizationBranch final : public Kernel {
    1313    friend class OptimizationBranchBuilder;
    1414public:
     
    1616    static bool classof(const Kernel * const k) {
    1717        switch (k->getTypeId()) {
    18             case TypeId::MultiBlock:
    1918            case TypeId::OptimizationBranch:
    2019                return true;
     
    4443    void generateInitializeMethod(const std::unique_ptr<KernelBuilder> & b) final;
    4544
    46     void generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) final;
     45    void generateKernelMethod(const std::unique_ptr<KernelBuilder> & b) final;
    4746
    4847    void generateFinalizeMethod(const std::unique_ptr<KernelBuilder> & b) final;
  • icGREP/icgrep-devel/icgrep/kernels/optimizationbranch/optimizationbranch.cpp

    r6288 r6289  
    55#include <boost/container/flat_map.hpp>
    66#include <llvm/Support/raw_ostream.h>
    7 
     7#include <toolchain/toolchain.h>
    88
    99#warning at compilation, this must verify that the I/O rates of the branch permits the rates of the branches
     
    7070
    7171/** ------------------------------------------------------------------------------------------------------------- *
    72  * @brief generateMultiBlockLogic
    73  ** ------------------------------------------------------------------------------------------------------------- */
    74 void OptimizationBranch::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, Value * const numOfStrides) {
    75 
    76     BasicBlock * const loopCond = b->CreateBasicBlock("cond");
     72 * @brief generateDoSegmentMethod
     73 ** ------------------------------------------------------------------------------------------------------------- */
     74void OptimizationBranch::generateKernelMethod(const std::unique_ptr<KernelBuilder> & b) {
     75
     76
    7777    BasicBlock * const nonZeroPath = b->CreateBasicBlock("nonZeroPath");
    7878    BasicBlock * const allZeroPath = b->CreateBasicBlock("allZeroPath");
    79     BasicBlock * const mergePaths = b->CreateBasicBlock("mergePaths");
    8079    BasicBlock * const exit = b->CreateBasicBlock("exit");
     80
    8181
    8282
     
    8787    loadHandle(b, mNonZeroKernel, "1");
    8888
    89     BasicBlock * const entry = b->GetInsertBlock();
    90     b->CreateBr(loopCond);
    91 
    92     PHINode * terminatedPhi = nullptr;
    93     if (canSetTerminateSignal()) {
    94         b->SetInsertPoint(mergePaths);
    95         terminatedPhi = b->CreatePHI(b->getInt1Ty(), 2);
    96     }
    97 
    98     b->SetInsertPoint(loopCond);
    99     IntegerType * const sizeTy = b->getSizeTy();
    100     PHINode * const first = b->CreatePHI(sizeTy, 3, "firstStride");
    101     first->addIncoming(ZERO, entry);
    102     PHINode * const last = b->CreatePHI(sizeTy, 3, "lastStride");
    103     PHINode * const currentState = b->CreatePHI(b->getInt1Ty(), 3);
    104     currentState->addIncoming(UndefValue::get(b->getInt1Ty()), entry);
    105     Value * finished = nullptr;
    10689
    10790    if (LLVM_LIKELY(isa<StreamSet>(mCondition))) {
    10891
    109         last->addIncoming(ZERO, entry);
    110 
     92        BasicBlock * const entry = b->GetInsertBlock();
     93        BasicBlock * const loopCond = b->CreateBasicBlock("cond", nonZeroPath);
    11194        BasicBlock * const summarizeOneStride = b->CreateBasicBlock("summarizeOneStride", nonZeroPath);
    11295        BasicBlock * const checkStride = b->CreateBasicBlock("checkStride", nonZeroPath);
    11396        BasicBlock * const processStrides = b->CreateBasicBlock("processStrides", nonZeroPath);
     97        BasicBlock * const mergePaths = b->CreateBasicBlock("mergePaths", nonZeroPath);
     98
     99        b->CreateBr(loopCond);
     100
     101        b->SetInsertPoint(loopCond);
     102        IntegerType * const sizeTy = b->getSizeTy();
     103        IntegerType * const boolTy = b->getInt1Ty();
     104        PHINode * const currentFirstIndex = b->CreatePHI(sizeTy, 3, "firstStride");
     105        currentFirstIndex->addIncoming(ZERO, entry);
     106        PHINode * const currentLastIndex = b->CreatePHI(sizeTy, 3, "lastStride");
     107        currentLastIndex->addIncoming(ZERO, entry);
     108        PHINode * const currentState = b->CreatePHI(boolTy, 3);
     109        currentState->addIncoming(UndefValue::get(boolTy), entry);
     110
    114111
    115112        Constant * const blocksPerStride = b->getSize(getStride() / b->getBitBlockWidth());
     
    117114        Value * const numOfConditionBlocks = b->CreateMul(numOfConditionStreams, blocksPerStride);
    118115
    119         Value * const offset = b->CreateMul(last, blocksPerStride);
     116        Value * const offset = b->CreateMul(currentLastIndex, blocksPerStride);
    120117        Value * basePtr = b->getInputStreamBlockPtr(CONDITION_TAG, ZERO, offset);
    121118        Type * const BitBlockTy = b->getBitBlockType();
    122119        basePtr = b->CreatePointerCast(basePtr, BitBlockTy->getPointerTo());
    123120        b->CreateBr(summarizeOneStride);
     121
     122        // Predeclare some phi nodes
     123
     124        b->SetInsertPoint(nonZeroPath);
     125        PHINode * const firstNonZeroIndex = b->CreatePHI(sizeTy, 2);
     126        PHINode * const lastNonZeroIndex = b->CreatePHI(sizeTy, 2);
     127        PHINode * const allZeroAfterNonZero = b->CreatePHI(boolTy, 2);
     128
     129        b->SetInsertPoint(allZeroPath);
     130        PHINode * const firstAllZeroIndex = b->CreatePHI(sizeTy, 2);
     131        PHINode * const lastAllZeroIndex = b->CreatePHI(sizeTy, 2);
     132        PHINode * const nonZeroAfterAllZero = b->CreatePHI(boolTy, 2);
     133
     134        PHINode * terminatedPhi = nullptr;
     135        if (canSetTerminateSignal()) {
     136            b->SetInsertPoint(mergePaths);
     137            terminatedPhi = b->CreatePHI(b->getInt1Ty(), 2);
     138        }
    124139
    125140        // OR together every condition block in this stride
     
    144159        Value * const nextState = b->bitblock_any(value);
    145160        Value * const sameState = b->CreateICmpEQ(nextState, currentState);
    146         Value * const firstStride = b->CreateICmpEQ(last, ZERO);
     161        Value * const firstStride = b->CreateICmpEQ(currentLastIndex, ZERO);
    147162        Value * const continuation = b->CreateOr(sameState, firstStride);
    148         Value * const nextIndex = b->CreateAdd(last, ONE);
    149         Value * const notLastStride = b->CreateICmpNE(nextIndex, numOfStrides);
     163        Value * const nextIndex = b->CreateAdd(currentLastIndex, ONE);
     164//        Value * const lastStrideIndex = b->CreateUMin(mNumOfStrides, ONE);
     165        Value * const notLastStride = b->CreateICmpULT(nextIndex, mNumOfStrides);
    150166        Value * const checkNextStride = b->CreateAnd(continuation, notLastStride);
    151         last->addIncoming(nextIndex, checkStride);
    152         first->addIncoming(first, checkStride);
     167        currentLastIndex->addIncoming(nextIndex, checkStride);
     168        currentFirstIndex->addIncoming(currentFirstIndex, checkStride);
    153169        currentState->addIncoming(nextState, checkStride);
    154170        b->CreateLikelyCondBr(checkNextStride, loopCond, processStrides);
    155171
    156         // Process every stride between [first, index)
     172        // Process every stride between [first, last)
    157173        b->SetInsertPoint(processStrides);
     174
    158175        // state is implicitly "indeterminate" during our first stride
    159176        Value * const selectedPath = b->CreateSelect(firstStride, nextState, currentState);
    160         finished = b->CreateNot(notLastStride);
     177        firstNonZeroIndex->addIncoming(currentFirstIndex, processStrides);
     178        firstAllZeroIndex->addIncoming(currentFirstIndex, processStrides);
     179        // When we reach the last (but not necessarily final) stride of this kernel,
     180        // we will either "append" the final stride to the current run or complete
     181        // the current run then perform one more iteration for the final stride, depending
     182        // whether it flips the branch selection state.
     183
     184//        b->CallPrintInt(" &&& nextState", nextState);
     185//        b->CallPrintInt(" &&& firstStride", firstStride);
     186//        b->CallPrintInt(" &&& continuation", continuation);
     187//        b->CallPrintInt(" &&& notLastStride", notLastStride);
     188
     189//        b->CallPrintInt(" &&& nextIndex", nextIndex);
     190//        b->CallPrintInt(" &&& currentLastIndex", currentLastIndex);
     191
     192//        b->CallPrintInt(" &&& mNumOfStrides", mNumOfStrides);
     193//        b->CallPrintInt(" &&& nextIndex", nextIndex);
     194
     195        Value * const nextLast = b->CreateSelect(continuation, mNumOfStrides, nextIndex);
     196
     197//        b->CallPrintInt(" &&& nextLast", nextLast);
     198
     199        Value * const nextFirst = b->CreateSelect(continuation, mNumOfStrides, currentLastIndex);
     200
     201//        b->CallPrintInt(" &&& nextFirst", nextFirst);
     202
     203
     204        lastNonZeroIndex->addIncoming(nextFirst, processStrides);
     205        lastAllZeroIndex->addIncoming(nextFirst, processStrides);
     206        Value * finished = b->CreateNot(notLastStride);
     207        Value * const flipLastState = b->CreateAnd(finished, b->CreateNot(continuation));
     208        nonZeroAfterAllZero->addIncoming(flipLastState, processStrides);
     209        allZeroAfterNonZero->addIncoming(flipLastState, processStrides);
    161210        b->CreateCondBr(selectedPath, nonZeroPath, allZeroPath);
    162211
    163         first->addIncoming(last, mergePaths);
    164         last->addIncoming(nextIndex, mergePaths);
     212        // make the actual calls and take any potential termination signal
     213        b->SetInsertPoint(nonZeroPath);
     214        callKernel(b, mNonZeroKernel, firstNonZeroIndex, lastNonZeroIndex, terminatedPhi);
     215        BasicBlock * const nonZeroPathExit = b->GetInsertBlock();
     216        firstAllZeroIndex->addIncoming(nextFirst, nonZeroPathExit);
     217        lastAllZeroIndex->addIncoming(nextLast, nonZeroPathExit);
     218        nonZeroAfterAllZero->addIncoming(b->getFalse(), nonZeroPathExit);
     219        b->CreateUnlikelyCondBr(allZeroAfterNonZero, allZeroPath, mergePaths);
     220
     221        b->SetInsertPoint(allZeroPath);
     222        callKernel(b, mAllZeroKernel, firstAllZeroIndex, lastAllZeroIndex, terminatedPhi);
     223        BasicBlock * const allZeroPathExit = b->GetInsertBlock();
     224        firstNonZeroIndex->addIncoming(nextFirst, allZeroPathExit);
     225        lastNonZeroIndex->addIncoming(nextLast, allZeroPathExit);
     226        allZeroAfterNonZero->addIncoming(b->getFalse(), allZeroPathExit);
     227        b->CreateUnlikelyCondBr(nonZeroAfterAllZero, nonZeroPath, mergePaths);
     228
     229        b->SetInsertPoint(mergePaths);
     230        currentFirstIndex->addIncoming(nextFirst, mergePaths);
     231        currentLastIndex->addIncoming(nextLast, mergePaths);
    165232        currentState->addIncoming(nextState, mergePaths);
     233        if (terminatedPhi) {
     234            finished = b->CreateOr(finished, terminatedPhi);
     235        }
     236        b->CreateLikelyCondBr(finished, exit, loopCond);
     237
    166238    } else {
    167         Value * const cond = b->getScalarField(CONDITION_TAG);
    168         b->CreateCondBr(b->CreateIsNotNull(cond), nonZeroPath, allZeroPath);
    169 
    170         last->addIncoming(numOfStrides, entry);
    171         last->addIncoming(numOfStrides, mergePaths);
    172         first->addIncoming(ZERO, mergePaths);
    173         currentState->addIncoming(b->getFalse(), mergePaths);
    174         finished = b->getTrue();
    175     }
    176 
    177     // make the actual calls and take any potential termination signal
    178     b->SetInsertPoint(nonZeroPath);
    179     callKernel(b, mNonZeroKernel, first, last, terminatedPhi);
    180     b->CreateBr(mergePaths);
    181 
    182     b->SetInsertPoint(allZeroPath);
    183     callKernel(b, mAllZeroKernel, first, last, terminatedPhi);
    184     b->CreateBr(mergePaths);
    185 
    186     b->SetInsertPoint(mergePaths);
    187     // Value * finished = b->CreateICmpEQ(last, numOfStrides);
    188     if (terminatedPhi) {
    189         finished = b->CreateOr(finished, terminatedPhi);
    190     }
    191     b->CreateLikelyCondBr(finished, exit, loopCond);
     239
     240//        Value * const cond = b->getScalarField(CONDITION_TAG);
     241//        b->CreateCondBr(b->CreateIsNotNull(cond), nonZeroPath, allZeroPath);
     242
     243//        // make the actual calls and take any potential termination signal
     244//        b->SetInsertPoint(nonZeroPath);
     245//        callKernel(b, mNonZeroKernel, ZERO, mNumOfStrides, b->getTrue(), nullptr);
     246//        b->CreateBr(exit);
     247
     248//        b->SetInsertPoint(allZeroPath);
     249//        callKernel(b, mAllZeroKernel, ZERO, mNumOfStrides, b->getTrue(), nullptr);
     250//        b->CreateBr(exit);
     251    }
    192252
    193253    b->SetInsertPoint(exit);
    194 
    195     b->CallPrintInt("branch_exit --------------------", numOfStrides);
    196254
    197255}
     
    204262                                    Value * const first, Value * const last,
    205263                                    PHINode * const terminatedPhi) {
     264
     265//    b->CallPrintInt(" &&& " + kernel->getName() + "_first", first);
     266//    b->CallPrintInt(" &&& " + kernel->getName() + "_last", last);
     267
     268    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
     269        Value * const nonZeroLength = b->CreateICmpULT(first, last);
     270        Value * valid = b->CreateOr(nonZeroLength, b->CreateIsNull(last));
     271        b->CreateAssert(valid,
     272            "Branch cannot execute 0 strides unless this is the final stride");
     273    }
    206274
    207275    Function * const doSegment = kernel->getDoSegmentFunction(b->getModule());
     
    214282    }
    215283
    216 
    217     b->CallPrintInt("branch_first -------------------", first);
    218     b->CallPrintInt("branch_last --------------------", last);
    219 
    220284    std::vector<Value *> args;
    221285    args.reserve(doSegment->arg_size());
     
    226290    args.push_back(numOfStrides); // numOfStrides
    227291    const auto numOfInputs = kernel->getNumOfStreamInputs();
     292
     293    Value * const isFinal = b->CreateIsNull(last);
    228294
    229295    for (unsigned i = 0; i < numOfInputs; i++) {
     
    240306        args.push_back(processed);
    241307        // accessible input items (after non-deferred processed item count)
    242         args.push_back(getItemCountIncrement(b, input, first, last, mAccessibleInputItems[i]));
     308        Value * accessible = getItemCountIncrement(b, input, first, last, mAccessibleInputItems[i]);
     309        accessible = b->CreateSelect(isFinal, mAccessibleInputItems[i], accessible);
     310        args.push_back(accessible);
    243311        // TODO: What if one of the branches requires this but the other doesn't?
    244312        if (LLVM_UNLIKELY(input.hasAttribute(AttrId::RequiresPopCountArray))) {
     
    261329        }
    262330        args.push_back(produced);
    263         args.push_back(getItemCountIncrement(b, output, first, last, mWritableOutputItems[i]));
    264     }
    265 
    266 
     331        Value * writable = getItemCountIncrement(b, output, first, last, mWritableOutputItems[i]);
     332        writable = b->CreateSelect(isFinal, mWritableOutputItems[i], writable);
     333        args.push_back(writable);
     334    }
    267335
    268336    Value * const terminated = b->CreateCall(doSegment, args);
     
    274342
    275343    for (unsigned i = 0; i < numOfInputs; ++i) {
    276         const Binding & input = mInputStreamSets[i];
     344        const Binding & input = kernel->getInputStreamSetBinding(i);
    277345        if (isParamConstant(input)) {
    278346            Value * const processed = b->CreateLoad(mProcessedInputItemPtr[i]);
     
    281349            b->CreateStore(updatedInputCount, mProcessedInputItemPtr[i]);
    282350        }
     351//        Value * const processed = b->CreateLoad(mProcessedInputItemPtr[i]);
     352//        b->CallPrintInt(" &&& " + input.getName() + "_processed'", processed);
    283353    }
    284354
    285355    for (unsigned i = 0; i < numOfOutputs; ++i) {
    286         const Binding & output = mOutputStreamSets[i];
     356        const Binding & output = kernel->getOutputStreamSetBinding(i);
    287357        if (isParamConstant(output)) {
    288358            Value * const produced = b->CreateLoad(mProducedOutputItemPtr[i]);
     
    291361            b->CreateStore(updatedOutputCount, mProducedOutputItemPtr[i]);
    292362        }
     363//        Value * const processed = b->CreateLoad(mProducedOutputItemPtr[i]);
     364//        b->CallPrintInt(" &&& " + output.getName() + "_produced'", processed);
    293365    }
    294366
     
    298370        b->SetInsertPoint(kernelExit);
    299371    }
    300 
    301     b->CallPrintInt("branch_exec --------------", numOfStrides);
    302372
    303373}
     
    476546    Bindings && scalar_inputs,
    477547    Bindings && scalar_outputs)
    478 : MultiBlockKernel(b, TypeId::OptimizationBranch, std::move(signature),
    479                    std::move(stream_inputs), std::move(stream_outputs),
    480                    std::move(scalar_inputs), std::move(scalar_outputs),
    481                    // internal scalar
    482                    {Binding{b->getInt8Ty(), "priorState"}})
     548: Kernel(b, TypeId::OptimizationBranch, std::move(signature),
     549         std::move(stream_inputs), std::move(stream_outputs),
     550         std::move(scalar_inputs), std::move(scalar_outputs),
     551         {})
    483552, mCondition(condition.get())
    484553, mNonZeroKernel(nonZeroKernel.get())
  • icGREP/icgrep-devel/icgrep/kernels/pipeline/pipeline_compiler.hpp

    r6288 r6289  
    441441private:
    442442
    443     static constexpr StreamPort FAKE_CONSUMER{Port::Input, std::numeric_limits<unsigned>::max()};
     443    static const StreamPort FAKE_CONSUMER;
    444444
    445445protected:
     
    533533};
    534534
     535const StreamPort PipelineCompiler::FAKE_CONSUMER{Port::Input, std::numeric_limits<unsigned>::max()};
     536
    535537// NOTE: these graph functions not safe for general use since they are intended for inspection of *edge-immutable* graphs.
    536538
Note: See TracChangeset for help on using the changeset viewer.