source: icGREP/icgrep-devel/icgrep/kernels/until_n.cpp

Last change on this file was 6261, checked in by nmedfort, 4 months ago

Work on OptimizationBranch; revisited pipeline termination

File size: 10.3 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "until_n.h"
7#include <llvm/IR/Module.h>
8#include <kernels/kernel_builder.h>
9#include <kernels/streamset.h>
10#include <toolchain/toolchain.h>
11
12namespace llvm { class Type; }
13
14using namespace llvm;
15
16namespace kernel {
17
18void UntilNkernel::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) {
19
20/*
21   Strategy:  first form an index consisting of one bit per packsize input positions,
22   with a 1 bit signifying that the corresponding pack has at least one 1 bit.
23   Build the index one pack at a time, i.e, packsize * packsize positions at a time.
24   After an index pack is constructed, scan the index pack for 1 bits.  Each 1 bit
25   found identifies an input pack with a nonzero popcount.  Take the actual popcount
26   of the corresponding input pack and update the total number of bits seen.   If
27   the number of bits seen reaches N with any pack, determine the position of the
28   Nth bit and signal termination at that point.
29
30   For normal processing, we process whole blocks only, always advanced processed
31   and produced item counts by an integral number of blocks.   For final block
32   processing, we treat the final partial block as a whole block for the purpose
33   of finding the Nth bit.   However, if the located bit position is past the
34   EOF position, then this is treated as if the Nth bit does not exist in the
35   input stream.
36*/
37    IntegerType * const sizeTy = b->getSizeTy();
38    const unsigned packSize = sizeTy->getBitWidth();
39    Constant * const ZERO = b->getSize(0);
40    Constant * const ONE = b->getSize(1);
41    const auto packsPerBlock = b->getBitBlockWidth() / packSize;
42    Constant * const PACK_SIZE = b->getSize(packSize);
43    Constant * const PACKS_PER_BLOCK = b->getSize(packsPerBlock);
44    const auto blocksPerStride = getStride() / b->getBitBlockWidth();
45    Constant * const BLOCKS_PER_STRIDE = b->getSize(blocksPerStride);
46    const auto maximumBlocksPerIteration = packSize / packsPerBlock;
47    Constant * const MAXIMUM_BLOCKS_PER_ITERATION = b->getSize(maximumBlocksPerIteration);
48    VectorType * const packVectorTy = VectorType::get(sizeTy, packsPerBlock);
49    Value * const ZEROES = Constant::getNullValue(packVectorTy);
50
51    BasicBlock * const entry = b->GetInsertBlock();
52    Value * const numOfBlocks = b->CreateMul(numOfStrides, BLOCKS_PER_STRIDE);
53    BasicBlock * const strideLoop = b->CreateBasicBlock("strideLoop");
54    b->CreateBr(strideLoop);
55
56    b->SetInsertPoint(strideLoop);
57    PHINode * const baseBlockIndex = b->CreatePHI(sizeTy, 2);
58    baseBlockIndex->addIncoming(ZERO, entry);
59    PHINode * const blocksRemaining = b->CreatePHI(sizeTy, 2);
60    blocksRemaining->addIncoming(numOfBlocks, entry);
61    Value * const blocksToDo = b->CreateUMin(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION);
62    BasicBlock * const iteratorLoop = b->CreateBasicBlock("iteratorLoop");
63    BasicBlock * const checkForMatches = b->CreateBasicBlock("checkForMatches");
64    b->CreateBr(iteratorLoop);
65
66
67    // Construct the outer iterator mask indicating whether any markers are in the stream.
68    b->SetInsertPoint(iteratorLoop);
69    PHINode * const groupMaskPhi = b->CreatePHI(sizeTy, 2);
70    groupMaskPhi->addIncoming(ZERO, strideLoop);
71    PHINode * const localIndex = b->CreatePHI(sizeTy, 2);
72    localIndex->addIncoming(ZERO, strideLoop);
73    Value * const blockIndex = b->CreateAdd(baseBlockIndex, localIndex);
74    Value * inputValue = b->loadInputStreamBlock("bits", ZERO, blockIndex);
75    b->storeOutputStreamBlock("uptoN", ZERO, blockIndex, inputValue);
76    Value * const inputPackValue = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
77    Value * iteratorMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, inputPackValue), sizeTy);
78    iteratorMask = b->CreateShl(iteratorMask, b->CreateMul(localIndex, PACKS_PER_BLOCK));
79    iteratorMask = b->CreateOr(groupMaskPhi, iteratorMask);
80    groupMaskPhi->addIncoming(iteratorMask, iteratorLoop);
81    Value * const nextLocalIndex = b->CreateAdd(localIndex, ONE);
82    localIndex->addIncoming(nextLocalIndex, iteratorLoop);
83    b->CreateCondBr(b->CreateICmpNE(nextLocalIndex, blocksToDo), iteratorLoop, checkForMatches);
84
85    // Now check whether we have any matches
86    b->SetInsertPoint(checkForMatches);
87
88    BasicBlock * const processGroups = b->CreateBasicBlock("processGroups");
89    BasicBlock * const nextStride = b->CreateBasicBlock("nextStride");
90    b->CreateLikelyCondBr(b->CreateIsNull(iteratorMask), nextStride, processGroups);
91
92    b->SetInsertPoint(processGroups);
93    Value * const N = b->getScalarField("N");
94    Value * const initiallyObserved = b->getScalarField("observed");
95    BasicBlock * const processGroup = b->CreateBasicBlock("processGroup", nextStride);
96    b->CreateBr(processGroup);
97
98    b->SetInsertPoint(processGroup);
99    PHINode * const observed = b->CreatePHI(initiallyObserved->getType(), 2);
100    observed->addIncoming(initiallyObserved, processGroups);
101    PHINode * const groupMarkers = b->CreatePHI(iteratorMask->getType(), 2);
102    groupMarkers->addIncoming(iteratorMask, processGroups);
103
104    Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), sizeTy);
105    Value * const blockIndex2 = b->CreateAdd(baseBlockIndex, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
106    Value * const packOffset = b->CreateURem(groupIndex, PACKS_PER_BLOCK);
107    Value * const groupValue = b->loadInputStreamBlock("bits", ZERO, blockIndex2);
108    Value * const packBits = b->CreateExtractElement(b->CreateBitCast(groupValue, packVectorTy), packOffset);
109    Value * const packCount = b->CreateZExtOrTrunc(b->CreatePopcount(packBits), sizeTy);
110    Value * const observedUpTo = b->CreateAdd(observed, packCount);
111    BasicBlock * const haveNotSeenEnough = b->CreateBasicBlock("haveNotSeenEnough", nextStride);
112    BasicBlock * const seenNOrMore = b->CreateBasicBlock("seenNOrMore", nextStride);
113    b->CreateLikelyCondBr(b->CreateICmpULT(observedUpTo, N), haveNotSeenEnough, seenNOrMore);
114
115    // update our kernel state and check whether we have any other groups to process
116    b->SetInsertPoint(haveNotSeenEnough);
117    observed->addIncoming(observedUpTo, haveNotSeenEnough);
118    b->setScalarField("observed", observedUpTo);
119    Value * const remainingGroupMarkers = b->CreateResetLowestBit(groupMarkers);
120    groupMarkers->addIncoming(remainingGroupMarkers, haveNotSeenEnough);
121    b->CreateLikelyCondBr(b->CreateIsNull(remainingGroupMarkers), nextStride, processGroup);
122
123    // we've seen N non-zero items; determine the position of our items and clear any subsequent markers
124    b->SetInsertPoint(seenNOrMore);
125    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
126        b->CreateAssert(b->CreateICmpUGT(N, observed), "N must be greater than observed count!");
127    }
128    Value * const bitsToFind = b->CreateSub(N, observed);
129    BasicBlock * const findNthBit = b->CreateBasicBlock("findNthBit", nextStride);
130    BasicBlock * const foundNthBit = b->CreateBasicBlock("foundNthBit", nextStride);
131    b->CreateBr(findNthBit);
132
133    b->SetInsertPoint(findNthBit);
134    PHINode * const remainingBitsToFind = b->CreatePHI(bitsToFind->getType(), 2);
135    remainingBitsToFind->addIncoming(bitsToFind, seenNOrMore);
136    PHINode * const remainingBits = b->CreatePHI(packBits->getType(), 2);
137    remainingBits->addIncoming(packBits, seenNOrMore);
138    Value * const nextRemainingBits = b->CreateResetLowestBit(remainingBits);
139    remainingBits->addIncoming(nextRemainingBits, findNthBit);
140    Value * const nextRemainingBitsToFind = b->CreateSub(remainingBitsToFind, ONE);
141    remainingBitsToFind->addIncoming(nextRemainingBitsToFind, findNthBit);
142    b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingBitsToFind), foundNthBit, findNthBit);
143
144    // If we've found the n-th bit, end the segment after clearing the markers
145    b->SetInsertPoint(foundNthBit);
146    Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), sizeTy);
147    Value * const basePosition = b->CreateMul(packOffset, PACK_SIZE);
148    Value * const blockOffset = b->CreateOr(basePosition, packPosition);
149    Value * const inputValue2 = b->loadInputStreamBlock("bits", ZERO, blockIndex2);
150    Value * const mask = b->bitblock_mask_to(blockOffset, true);
151    Value * const maskedInputValue = b->CreateAnd(inputValue2, mask);
152    b->storeOutputStreamBlock("uptoN", ZERO, blockIndex2, maskedInputValue);
153    Value * const priorProducedItemCount = b->getProducedItemCount("uptoN");
154    const auto log2BlockWidth = std::log2<unsigned>(b->getBitBlockWidth());
155    Value * positionOfNthItem = b->CreateShl(blockIndex2, log2BlockWidth);
156    positionOfNthItem = b->CreateAdd(positionOfNthItem, b->CreateAdd(blockOffset, ONE));
157    positionOfNthItem = b->CreateAdd(positionOfNthItem, priorProducedItemCount);
158    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
159        Value * const availableBits = b->getAvailableItemCount("bits");
160        Value * const positionLessThanAvail = b->CreateICmpULT(positionOfNthItem, availableBits);
161        b->CreateAssert(positionLessThanAvail, "position of n-th item exceeds available items!");
162    }
163    b->setTerminationSignal();
164    BasicBlock * const segmentDone = b->CreateBasicBlock("segmentDone");
165    b->CreateBr(segmentDone);
166
167    b->SetInsertPoint(nextStride);
168    blocksRemaining->addIncoming(b->CreateSub(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION), nextStride);
169    baseBlockIndex->addIncoming(b->CreateAdd(baseBlockIndex, MAXIMUM_BLOCKS_PER_ITERATION), nextStride);
170    Value * const availableBits = b->getAvailableItemCount("bits");
171    b->CreateLikelyCondBr(b->CreateICmpULE(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION), segmentDone, strideLoop);
172
173    b->SetInsertPoint(segmentDone);
174    PHINode * const produced = b->CreatePHI(sizeTy, 2);
175    produced->addIncoming(positionOfNthItem, foundNthBit);
176    produced->addIncoming(availableBits, nextStride);
177    b->setProducedItemCount("uptoN", produced);
178
179}
180
181UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & b, Scalar * maxCount, StreamSet * AllMatches, StreamSet * Matches)
182: MultiBlockKernel(b, "UntilN",
183// inputs
184{Binding{"bits", AllMatches}},
185// outputs
186{Binding{"uptoN", Matches, BoundedRate(0, 1)}},
187// input scalar
188{Binding{"N", maxCount}}, {},
189// internal state
190{Binding{maxCount->getType(), "observed"}}) {
191    addAttribute(CanTerminateEarly());
192}
193
194}
Note: See TracBrowser for help on using the repository browser.