Ignore:
Timestamp:
Feb 2, 2018, 2:49:08 PM (15 months ago)
Author:
nmedfort
Message:

Revised pipeline structure to better control I/O rates

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/kernels/until_n.cpp

    r5832 r5856  
    4444    Constant * const PACK_SIZE = b->getSize(packSize);
    4545    Constant * const PACKS_PER_BLOCK = b->getSize(packsPerBlock);
    46     VectorType * const vTy = VectorType::get(sizeTy, packsPerBlock);
    47     Value * const ZEROES = Constant::getNullValue(vTy);
     46    const auto blocksPerStride = getStride() / b->getBitBlockWidth();
     47    Constant * const BLOCKS_PER_STRIDE = b->getSize(blocksPerStride);
     48    const auto maximumBlocksPerIteration = packSize / packsPerBlock;
     49    Constant * const MAXIMUM_BLOCKS_PER_ITERATION = b->getSize(maximumBlocksPerIteration);
     50    VectorType * const packVectorTy = VectorType::get(sizeTy, packsPerBlock);
     51    Value * const ZEROES = Constant::getNullValue(packVectorTy);
    4852
    4953    BasicBlock * const entry = b->GetInsertBlock();
     54    Value * const numOfBlocks = b->CreateMul(numOfStrides, BLOCKS_PER_STRIDE);
    5055    BasicBlock * const strideLoop = b->CreateBasicBlock("strideLoop");
     56    b->CreateBr(strideLoop);
    5157
    52     Value * const allAvailableItems = b->getAvailableItemCount("bits");
     58    b->SetInsertPoint(strideLoop);
     59    PHINode * const baseBlockIndex = b->CreatePHI(b->getSizeTy(), 2);
     60    baseBlockIndex->addIncoming(ZERO, entry);
     61    PHINode * const blocksRemaining = b->CreatePHI(b->getSizeTy(), 2);
     62    blocksRemaining->addIncoming(numOfBlocks, entry);
     63    Value * const blocksToDo = b->CreateUMin(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION);
     64    BasicBlock * const iteratorLoop = b->CreateBasicBlock("iteratorLoop");
     65    BasicBlock * const checkForMatches = b->CreateBasicBlock("checkForMatches");
     66    b->CreateBr(iteratorLoop);
    5367
    54     b->CreateBr(strideLoop);
    55     b->SetInsertPoint(strideLoop);
    56     PHINode * const strideIndex = b->CreatePHI(sizeTy, 2);
    57     strideIndex->addIncoming(ZERO, entry);
    5868
    59     const auto n = (packSize * packSize) / b->getBitBlockWidth();
    60     Value * groupMask = nullptr;
    61     Value * const baseOffset = b->CreateMul(strideIndex, b->getSize(n));
    62     for (unsigned i = 0; i < n; ++i) {
    63         Value * offset = b->CreateNUWAdd(baseOffset, b->getSize(i));
    64         Value * inputPtr = b->getInputStreamBlockPtr("bits", ZERO, offset);
    65         Value * inputValue = b->CreateBlockAlignedLoad(inputPtr);
    66         Value * outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, offset);
    67         b->CreateBlockAlignedStore(inputValue, outputPtr);
    68         Value * markers = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
    69         Value * blockMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, markers), sizeTy);
    70         if (i) {
    71             blockMask = b->CreateShl(blockMask, i * packsPerBlock);
    72             groupMask = b->CreateOr(groupMask, blockMask);
    73         } else {
    74             groupMask = blockMask;
    75         }
    76     }
     69    // Construct the outer iterator mask indicating whether any markers are in the stream.
     70    b->SetInsertPoint(iteratorLoop);
     71    PHINode * const groupMaskPhi = b->CreatePHI(b->getSizeTy(), 2);
     72    groupMaskPhi->addIncoming(ZERO, strideLoop);
     73    PHINode * const localIndex = b->CreatePHI(b->getSizeTy(), 2);
     74    localIndex->addIncoming(ZERO, strideLoop);
     75    Value * const blockIndex = b->CreateAdd(baseBlockIndex, localIndex);
     76    Value * inputPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
     77    Value * inputValue = b->CreateBlockAlignedLoad(inputPtr);
     78    Value * outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex);
     79    b->CreateBlockAlignedStore(inputValue, outputPtr);
     80    Value * const inputPackValue = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
     81    Value * iteratorMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, inputPackValue), sizeTy);
     82    iteratorMask = b->CreateShl(iteratorMask, b->CreateMul(localIndex, PACKS_PER_BLOCK));
     83    iteratorMask = b->CreateOr(groupMaskPhi, iteratorMask);
     84    groupMaskPhi->addIncoming(iteratorMask, iteratorLoop);
     85    Value * const nextLocalIndex = b->CreateAdd(localIndex, ONE);
     86    localIndex->addIncoming(nextLocalIndex, iteratorLoop);
     87    b->CreateCondBr(b->CreateICmpNE(nextLocalIndex, blocksToDo), iteratorLoop, checkForMatches);
     88
     89    // Now check whether we have any matches
     90    b->SetInsertPoint(checkForMatches);
    7791
    7892    BasicBlock * const processGroups = b->CreateBasicBlock("processGroups");
    7993    BasicBlock * const nextStride = b->CreateBasicBlock("nextStride");
    80 
    81     b->CreateLikelyCondBr(b->CreateIsNull(groupMask), nextStride, processGroups);
     94    b->CreateLikelyCondBr(b->CreateIsNull(iteratorMask), nextStride, processGroups);
    8295
    8396    b->SetInsertPoint(processGroups);
     
    90103    PHINode * const observed = b->CreatePHI(initiallyObserved->getType(), 2);
    91104    observed->addIncoming(initiallyObserved, processGroups);
    92     PHINode * const groupMarkers = b->CreatePHI(groupMask->getType(), 2);
    93     groupMarkers->addIncoming(groupMask, processGroups);
     105    PHINode * const groupMarkers = b->CreatePHI(iteratorMask->getType(), 2);
     106    groupMarkers->addIncoming(iteratorMask, processGroups);
    94107
    95108    Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), sizeTy);
    96     Value * const blockIndex = b->CreateNUWAdd(baseOffset, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
     109    Value * const blockIndex2 = b->CreateAdd(baseBlockIndex, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
    97110    Value * const packOffset = b->CreateURem(groupIndex, PACKS_PER_BLOCK);
    98     Value * const groupPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
    99     Value * const groupValue = b->CreateBlockAlignedLoad(groupPtr);
    100     Value * const packBits = b->CreateExtractElement(b->CreateBitCast(groupValue, vTy), packOffset);
    101     //Type * packPtrTy = packTy->getPointerTo();
    102     //Value * const packPtr = b->CreateGEP(b->CreatePointerCast(groupPtr, packPtrTy), packOffset);
    103     //Value * const packBits = b->CreateLoad(packPtr);
     111    Value * const groupPtr2 = b->getInputStreamBlockPtr("bits", ZERO, blockIndex2);
     112    Value * const groupValue = b->CreateBlockAlignedLoad(groupPtr2);
     113    Value * const packBits = b->CreateExtractElement(b->CreateBitCast(groupValue, packVectorTy), packOffset);
    104114    Value * const packCount = b->CreateZExtOrTrunc(b->CreatePopcount(packBits), sizeTy);
    105     Value * const observedUpTo = b->CreateNUWAdd(observed, packCount);
    106 
     115    Value * const observedUpTo = b->CreateAdd(observed, packCount);
    107116    BasicBlock * const haveNotSeenEnough = b->CreateBasicBlock("haveNotSeenEnough");
    108117    BasicBlock * const seenNOrMore = b->CreateBasicBlock("seenNOrMore");
     
    122131        b->CreateAssert(b->CreateICmpUGT(N, observed), "N must be greater than observed count!");
    123132    }
    124     Value * const bitsToFind = b->CreateNUWSub(N, observed);
     133    Value * const bitsToFind = b->CreateSub(N, observed);
    125134    BasicBlock * const findNthBit = b->CreateBasicBlock("findNthBit");
    126135    BasicBlock * const foundNthBit = b->CreateBasicBlock("foundNthBit");
     
    134143    Value * const nextRemainingBits = b->CreateResetLowestBit(remainingBits);
    135144    remainingBits->addIncoming(nextRemainingBits, findNthBit);
    136     Value * const nextRemainingBitsToFind = b->CreateNUWSub(remainingBitsToFind, ONE);
     145    Value * const nextRemainingBitsToFind = b->CreateSub(remainingBitsToFind, ONE);
    137146    remainingBitsToFind->addIncoming(nextRemainingBitsToFind, findNthBit);
    138147    b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingBitsToFind), foundNthBit, findNthBit);
     
    140149    // If we've found the n-th bit, end the segment after clearing the markers
    141150    b->SetInsertPoint(foundNthBit);
    142     Value * const inputPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
    143     Value * const inputValue = b->CreateBlockAlignedLoad(inputPtr);
     151
     152    Value * const inputPtr2 = b->getInputStreamBlockPtr("bits", ZERO, blockIndex2);
     153    Value * const inputValue2 = b->CreateBlockAlignedLoad(inputPtr2);
    144154    Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), sizeTy);
    145     Value * const basePosition = b->CreateNUWMul(packOffset, PACK_SIZE);
    146     Value * const blockOffset = b->CreateNUWAdd(b->CreateOr(basePosition, packPosition), ONE);
     155    Value * const basePosition = b->CreateMul(packOffset, PACK_SIZE);
     156    Value * const blockOffset = b->CreateAdd(b->CreateOr(basePosition, packPosition), ONE);
    147157    Value * const mask = b->CreateNot(b->bitblock_mask_from(blockOffset));
    148     Value * const maskedInputValue = b->CreateAnd(inputValue, mask);
    149     Value * const outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex);
    150     b->CreateBlockAlignedStore(maskedInputValue, outputPtr);
    151     Value * const positionOfNthItem = b->CreateNUWAdd(b->CreateMul(blockIndex, b->getSize(b->getBitBlockWidth())), blockOffset);
     158    Value * const maskedInputValue = b->CreateAnd(inputValue2, mask);
     159    Value * const outputPtr2 = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex2);
     160    b->CreateBlockAlignedStore(maskedInputValue, outputPtr2);
     161    Value * const positionOfNthItem = b->CreateAdd(b->CreateMul(blockIndex2, b->getSize(b->getBitBlockWidth())), blockOffset);
    152162    b->setTerminationSignal();
    153163    BasicBlock * const segmentDone = b->CreateBasicBlock("segmentDone");
     
    157167
    158168    b->SetInsertPoint(nextStride);
    159     Value * const nextStrideIndex = b->CreateNUWAdd(strideIndex, ONE);
    160     strideIndex->addIncoming(nextStrideIndex, nextStride);
    161     b->CreateLikelyCondBr(b->CreateICmpEQ(nextStrideIndex, numOfStrides), segmentDone, strideLoop);
     169    blocksRemaining->addIncoming(b->CreateSub(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION), nextStride);
     170    baseBlockIndex->addIncoming(b->CreateAdd(baseBlockIndex, MAXIMUM_BLOCKS_PER_ITERATION), nextStride);
     171    b->CreateLikelyCondBr(b->CreateICmpULE(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION), segmentDone, strideLoop);
    162172
    163173    b->SetInsertPoint(segmentDone);
    164174    PHINode * const produced = b->CreatePHI(sizeTy, 2);
    165175    produced->addIncoming(positionOfNthItem, foundNthBit);
    166     produced->addIncoming(allAvailableItems, nextStride);
     176    produced->addIncoming(b->getAvailableItemCount("bits"), nextStride);
    167177    Value * producedCount = b->getProducedItemCount("uptoN");
    168     producedCount = b->CreateNUWAdd(producedCount, produced);
     178    producedCount = b->CreateAdd(producedCount, produced);
    169179    b->setProducedItemCount("uptoN", producedCount);
    170180
    171181}
    172182
    173 unsigned LLVM_READNONE calculateRate(const std::unique_ptr<kernel::KernelBuilder> & b) {
    174     const unsigned packSize = b->getSizeTy()->getBitWidth();
    175     return (packSize * packSize) / b->getBitBlockWidth();
    176 }
    177 
    178183UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & b)
    179 : MultiBlockKernel("UntilN_" + std::to_string(calculateRate(b)),
     184: MultiBlockKernel("UntilN",
    180185// inputs
    181 {Binding{b->getStreamSetTy(), "bits", FixedRate(calculateRate(b))}},
     186{Binding{b->getStreamSetTy(), "bits"}},
    182187// outputs
    183 {Binding{b->getStreamSetTy(), "uptoN", BoundedRate(0, calculateRate(b))}},
     188{Binding{b->getStreamSetTy(), "uptoN", BoundedRate(0, 1)}},
    184189// input scalar
    185190{Binding{b->getSizeTy(), "N"}}, {},
    186191// internal state
    187192{Binding{b->getSizeTy(), "observed"}}) {
    188 
     193    addAttribute(CanTerminateEarly());
    189194}
    190195
Note: See TracChangeset for help on using the changeset viewer.