Ignore:
Timestamp: Jan 15, 2018, 3:42:27 PM (22 months ago)
Author: nmedfort
Message: Bug fix for UntilN

File: 1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/kernels/until_n.cpp

    r5831 r5832  
    37 37 */
    38 38
    39     const unsigned packSize = b->getSizeTy()->getBitWidth();
     39    IntegerType * const sizeTy = b->getSizeTy();
     40    const unsigned packSize = sizeTy->getBitWidth();
    40 41    Constant * const ZERO = b->getSize(0);
    41 42    Constant * const ONE = b->getSize(1);
     
    43 44    Constant * const PACK_SIZE = b->getSize(packSize);
    44 45    Constant * const PACKS_PER_BLOCK = b->getSize(packsPerBlock);
    45     Value * const ZEROES = b->allZeroes();
    46     Type * packTy = b->getIntNTy(packSize);
     46    VectorType * const vTy = VectorType::get(sizeTy, packsPerBlock);
     47    Value * const ZEROES = Constant::getNullValue(vTy);
    47 48
    48 49    BasicBlock * const entry = b->GetInsertBlock();
    49 50    BasicBlock * const strideLoop = b->CreateBasicBlock("strideLoop");
    50 51
     52    Value * const allAvailableItems = b->getAvailableItemCount("bits");
     53
    51 54    b->CreateBr(strideLoop);
    52 55    b->SetInsertPoint(strideLoop);
    53     PHINode * const strideIndex = b->CreatePHI(b->getSizeTy(), 2);
     56    PHINode * const strideIndex = b->CreatePHI(sizeTy, 2);
    54 57    strideIndex->addIncoming(ZERO, entry);
    55 58
     
    64 67        b->CreateBlockAlignedStore(inputValue, outputPtr);
    65 68        Value * markers = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
    66         Value * blockMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, markers), packTy);
     69        Value * blockMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, markers), sizeTy);
    67 70        if (i) {
    68 71            blockMask = b->CreateShl(blockMask, i * packsPerBlock);
     
    90 93    groupMarkers->addIncoming(groupMask, processGroups);
    91 94
    92     Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), b->getSizeTy());
     95    Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), sizeTy);
    93 96    Value * const blockIndex = b->CreateNUWAdd(baseOffset, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
    94 97    Value * const packOffset = b->CreateURem(groupIndex, PACKS_PER_BLOCK);
    95 98    Value * const groupPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
    96 99    Value * const groupValue = b->CreateBlockAlignedLoad(groupPtr);
    97     Value * const packBits = b->CreateExtractElement(groupValue, packOffset);
    98 
     100    Value * const packBits = b->CreateExtractElement(b->CreateBitCast(groupValue, vTy), packOffset);
    99 101    //Type * packPtrTy = packTy->getPointerTo();
    100 102    //Value * const packPtr = b->CreateGEP(b->CreatePointerCast(groupPtr, packPtrTy), packOffset);
    101 103    //Value * const packBits = b->CreateLoad(packPtr);
    102     Value * const packCount = b->CreateZExtOrTrunc(b->CreatePopcount(packBits), b->getSizeTy());
     104    Value * const packCount = b->CreateZExtOrTrunc(b->CreatePopcount(packBits), sizeTy);
    103 105    Value * const observedUpTo = b->CreateNUWAdd(observed, packCount);
    104 106
     
    126 128
    127 129    b->SetInsertPoint(findNthBit);
    128     PHINode * const remainingPositions = b->CreatePHI(bitsToFind->getType(), 2);
    129     remainingPositions->addIncoming(bitsToFind, seenNOrMore);
     130    PHINode * const remainingBitsToFind = b->CreatePHI(bitsToFind->getType(), 2);
     131    remainingBitsToFind->addIncoming(bitsToFind, seenNOrMore);
    130 132    PHINode * const remainingBits = b->CreatePHI(packBits->getType(), 2);
    131 133    remainingBits->addIncoming(packBits, seenNOrMore);
    132     Value * const nextRemainingPositions = b->CreateNUWSub(remainingPositions, ONE);
    133     remainingPositions->addIncoming(nextRemainingPositions, findNthBit);
    134 134    Value * const nextRemainingBits = b->CreateResetLowestBit(remainingBits);
    135 135    remainingBits->addIncoming(nextRemainingBits, findNthBit);
    136 
    137     b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingPositions), foundNthBit, findNthBit);
     136    Value * const nextRemainingBitsToFind = b->CreateNUWSub(remainingBitsToFind, ONE);
     137    remainingBitsToFind->addIncoming(nextRemainingBitsToFind, findNthBit);
     138    b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingBitsToFind), foundNthBit, findNthBit);
    138 139
    139 140    // If we've found the n-th bit, end the segment after clearing the markers
     
    141 142    Value * const inputPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
    142 143    Value * const inputValue = b->CreateBlockAlignedLoad(inputPtr);
    143     Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), b->getSizeTy());
     144    Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), sizeTy);
    144 145    Value * const basePosition = b->CreateNUWMul(packOffset, PACK_SIZE);
    145 146    Value * const blockOffset = b->CreateNUWAdd(b->CreateOr(basePosition, packPosition), ONE);
     
    160 161    b->CreateLikelyCondBr(b->CreateICmpEQ(nextStrideIndex, numOfStrides), segmentDone, strideLoop);
    161 162
    162     Constant * const FULL_STRIDE = b->getSize(packSize * packSize);
    163 
    164 163    b->SetInsertPoint(segmentDone);
    165     PHINode * const produced = b->CreatePHI(b->getSizeTy(), 2);
     164    PHINode * const produced = b->CreatePHI(sizeTy, 2);
    166 165    produced->addIncoming(positionOfNthItem, foundNthBit);
    167     produced->addIncoming(FULL_STRIDE, nextStride);
    168 
     166    produced->addIncoming(allAvailableItems, nextStride);
    169 167    Value * producedCount = b->getProducedItemCount("uptoN");
    170     producedCount = b->CreateNUWAdd(producedCount, b->CreateNUWMul(FULL_STRIDE, strideIndex));
    171 168    producedCount = b->CreateNUWAdd(producedCount, produced);
    172 169    b->setProducedItemCount("uptoN", producedCount);
Note: See TracChangeset for help on using the changeset viewer.