source: icGREP/icgrep-devel/icgrep/kernels/until_n.cpp @ 6135

Last change on this file since 6135 was 5985, checked in by nmedfort, 13 months ago

Restructured MultiBlock kernel. Removal of Swizzled buffers. Inclusion of PopCount rates / non-linear access. Modifications to several kernels to better align them with the kernel and pipeline changes.

File size: 10.1 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "until_n.h"
7#include <llvm/IR/Module.h>
8#include <kernels/kernel_builder.h>
9#include <kernels/streamset.h>
10#include <toolchain/toolchain.h>
11
12namespace llvm { class Type; }
13
14using namespace llvm;
15using namespace parabix;
16
17namespace kernel {
18
19void UntilNkernel::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) {
20
21/* 
22   Strategy:  first form an index consisting of one bit per packsize input positions,
23   with a 1 bit signifying that the corresponding pack has at least one 1 bit.
24   Build the index one pack at a time, i.e, packsize * packsize positions at a time.
25   After an index pack is constructed, scan the index pack for 1 bits.  Each 1 bit
26   found identifies an input pack with a nonzero popcount.  Take the actual popcount
27   of the corresponding input pack and update the total number of bits seen.   If
28   the number of bits seen reaches N with any pack, determine the position of the
29   Nth bit and signal termination at that point.
30 
31   For normal processing, we process whole blocks only, always advanced processed
32   and produced item counts by an integral number of blocks.   For final block
33   processing, we treat the final partial block as a whole block for the purpose
34   of finding the Nth bit.   However, if the located bit position is past the
35   EOF position, then this is treated as if the Nth bit does not exist in the
36   input stream.
37*/
38
39    IntegerType * const sizeTy = b->getSizeTy();
40    const unsigned packSize = sizeTy->getBitWidth();
41    Constant * const ZERO = b->getSize(0);
42    Constant * const ONE = b->getSize(1);
43    const auto packsPerBlock = b->getBitBlockWidth() / packSize;
44    Constant * const PACK_SIZE = b->getSize(packSize);
45    Constant * const PACKS_PER_BLOCK = b->getSize(packsPerBlock);
46    const auto blocksPerStride = getStride() / b->getBitBlockWidth();
47    Constant * const BLOCKS_PER_STRIDE = b->getSize(blocksPerStride);
48    const auto maximumBlocksPerIteration = packSize / packsPerBlock;
49    Constant * const MAXIMUM_BLOCKS_PER_ITERATION = b->getSize(maximumBlocksPerIteration);
50    VectorType * const packVectorTy = VectorType::get(sizeTy, packsPerBlock);
51    Value * const ZEROES = Constant::getNullValue(packVectorTy);
52
53    BasicBlock * const entry = b->GetInsertBlock();
54    Value * const numOfBlocks = b->CreateMul(numOfStrides, BLOCKS_PER_STRIDE);
55    BasicBlock * const strideLoop = b->CreateBasicBlock("strideLoop");
56    b->CreateBr(strideLoop);
57
58    b->SetInsertPoint(strideLoop);
59    PHINode * const baseBlockIndex = b->CreatePHI(b->getSizeTy(), 2);
60    baseBlockIndex->addIncoming(ZERO, entry);
61    PHINode * const blocksRemaining = b->CreatePHI(b->getSizeTy(), 2);
62    blocksRemaining->addIncoming(numOfBlocks, entry);
63    Value * const blocksToDo = b->CreateUMin(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION);
64    BasicBlock * const iteratorLoop = b->CreateBasicBlock("iteratorLoop");
65    BasicBlock * const checkForMatches = b->CreateBasicBlock("checkForMatches");
66    b->CreateBr(iteratorLoop);
67
68
69    // Construct the outer iterator mask indicating whether any markers are in the stream.
70    b->SetInsertPoint(iteratorLoop);
71    PHINode * const groupMaskPhi = b->CreatePHI(b->getSizeTy(), 2);
72    groupMaskPhi->addIncoming(ZERO, strideLoop);
73    PHINode * const localIndex = b->CreatePHI(b->getSizeTy(), 2);
74    localIndex->addIncoming(ZERO, strideLoop);
75    Value * const blockIndex = b->CreateAdd(baseBlockIndex, localIndex);
76    Value * inputPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
77    Value * inputValue = b->CreateBlockAlignedLoad(inputPtr);
78    Value * outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex);
79    b->CreateBlockAlignedStore(inputValue, outputPtr);
80    Value * const inputPackValue = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
81    Value * iteratorMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, inputPackValue), sizeTy);
82    iteratorMask = b->CreateShl(iteratorMask, b->CreateMul(localIndex, PACKS_PER_BLOCK));
83    iteratorMask = b->CreateOr(groupMaskPhi, iteratorMask);
84    groupMaskPhi->addIncoming(iteratorMask, iteratorLoop);
85    Value * const nextLocalIndex = b->CreateAdd(localIndex, ONE);
86    localIndex->addIncoming(nextLocalIndex, iteratorLoop);
87    b->CreateCondBr(b->CreateICmpNE(nextLocalIndex, blocksToDo), iteratorLoop, checkForMatches);
88
89    // Now check whether we have any matches
90    b->SetInsertPoint(checkForMatches);
91
92    BasicBlock * const processGroups = b->CreateBasicBlock("processGroups");
93    BasicBlock * const nextStride = b->CreateBasicBlock("nextStride");
94    b->CreateLikelyCondBr(b->CreateIsNull(iteratorMask), nextStride, processGroups);
95
96    b->SetInsertPoint(processGroups);
97    Value * const N = b->getScalarField("N");
98    Value * const initiallyObserved = b->getScalarField("observed");
99    BasicBlock * const processGroup = b->CreateBasicBlock("processGroup");
100    b->CreateBr(processGroup);
101
102    b->SetInsertPoint(processGroup);
103    PHINode * const observed = b->CreatePHI(initiallyObserved->getType(), 2);
104    observed->addIncoming(initiallyObserved, processGroups);
105    PHINode * const groupMarkers = b->CreatePHI(iteratorMask->getType(), 2);
106    groupMarkers->addIncoming(iteratorMask, processGroups);
107
108    Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), sizeTy);
109    Value * const blockIndex2 = b->CreateAdd(baseBlockIndex, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
110    Value * const packOffset = b->CreateURem(groupIndex, PACKS_PER_BLOCK);
111    Value * const groupPtr2 = b->getInputStreamBlockPtr("bits", ZERO, blockIndex2);
112    Value * const groupValue = b->CreateBlockAlignedLoad(groupPtr2);
113    Value * const packBits = b->CreateExtractElement(b->CreateBitCast(groupValue, packVectorTy), packOffset);
114    Value * const packCount = b->CreateZExtOrTrunc(b->CreatePopcount(packBits), sizeTy);
115    Value * const observedUpTo = b->CreateAdd(observed, packCount);
116    BasicBlock * const haveNotSeenEnough = b->CreateBasicBlock("haveNotSeenEnough");
117    BasicBlock * const seenNOrMore = b->CreateBasicBlock("seenNOrMore");
118    b->CreateLikelyCondBr(b->CreateICmpULT(observedUpTo, N), haveNotSeenEnough, seenNOrMore);
119
120    // update our kernel state and check whether we have any other groups to process
121    b->SetInsertPoint(haveNotSeenEnough);
122    observed->addIncoming(observedUpTo, haveNotSeenEnough);
123    b->setScalarField("observed", observedUpTo);
124    Value * const remainingGroupMarkers = b->CreateResetLowestBit(groupMarkers);
125    groupMarkers->addIncoming(remainingGroupMarkers, haveNotSeenEnough);
126    b->CreateLikelyCondBr(b->CreateIsNull(remainingGroupMarkers), nextStride, processGroup);
127
128    // we've seen N non-zero items; determine the position of our items and clear any subsequent markers
129    b->SetInsertPoint(seenNOrMore);
130    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
131        b->CreateAssert(b->CreateICmpUGT(N, observed), "N must be greater than observed count!");
132    }
133    Value * const bitsToFind = b->CreateSub(N, observed);
134    BasicBlock * const findNthBit = b->CreateBasicBlock("findNthBit");
135    BasicBlock * const foundNthBit = b->CreateBasicBlock("foundNthBit");
136    b->CreateBr(findNthBit);
137
138    b->SetInsertPoint(findNthBit);
139    PHINode * const remainingBitsToFind = b->CreatePHI(bitsToFind->getType(), 2);
140    remainingBitsToFind->addIncoming(bitsToFind, seenNOrMore);
141    PHINode * const remainingBits = b->CreatePHI(packBits->getType(), 2);
142    remainingBits->addIncoming(packBits, seenNOrMore);
143    Value * const nextRemainingBits = b->CreateResetLowestBit(remainingBits);
144    remainingBits->addIncoming(nextRemainingBits, findNthBit);
145    Value * const nextRemainingBitsToFind = b->CreateSub(remainingBitsToFind, ONE);
146    remainingBitsToFind->addIncoming(nextRemainingBitsToFind, findNthBit);
147    b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingBitsToFind), foundNthBit, findNthBit);
148
149    // If we've found the n-th bit, end the segment after clearing the markers
150    b->SetInsertPoint(foundNthBit);
151
152    Value * const inputPtr2 = b->getInputStreamBlockPtr("bits", ZERO, blockIndex2);
153    Value * const inputValue2 = b->CreateBlockAlignedLoad(inputPtr2);
154    Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), sizeTy);
155    Value * const basePosition = b->CreateMul(packOffset, PACK_SIZE);
156    Value * const blockOffset = b->CreateAdd(b->CreateOr(basePosition, packPosition), ONE);
157    Value * const mask = b->CreateNot(b->bitblock_mask_from(blockOffset));
158    Value * const maskedInputValue = b->CreateAnd(inputValue2, mask);
159    Value * const outputPtr2 = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex2);
160    b->CreateBlockAlignedStore(maskedInputValue, outputPtr2);
161    Value * const positionOfNthItem = b->CreateAdd(b->CreateMul(blockIndex2, b->getSize(b->getBitBlockWidth())), blockOffset);
162    b->setTerminationSignal();
163    BasicBlock * const segmentDone = b->CreateBasicBlock("segmentDone");
164    b->CreateBr(segmentDone);
165
166    nextStride->moveAfter(foundNthBit);
167
168    b->SetInsertPoint(nextStride);
169    blocksRemaining->addIncoming(b->CreateSub(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION), nextStride);
170    baseBlockIndex->addIncoming(b->CreateAdd(baseBlockIndex, MAXIMUM_BLOCKS_PER_ITERATION), nextStride);
171    b->CreateLikelyCondBr(b->CreateICmpULE(blocksRemaining, MAXIMUM_BLOCKS_PER_ITERATION), segmentDone, strideLoop);
172
173    b->SetInsertPoint(segmentDone);
174    PHINode * const produced = b->CreatePHI(sizeTy, 2);
175    produced->addIncoming(positionOfNthItem, foundNthBit);
176    produced->addIncoming(b->getAvailableItemCount("bits"), nextStride);
177    Value * producedCount = b->getProducedItemCount("uptoN");
178    producedCount = b->CreateAdd(producedCount, produced);
179    b->setProducedItemCount("uptoN", producedCount);
180
181}
182
183UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & b)
184: MultiBlockKernel("UntilN",
185// inputs
186{Binding{b->getStreamSetTy(), "bits"}},
187// outputs
188{Binding{b->getStreamSetTy(), "uptoN", FixedRate(), Deferred()}},
189// input scalar
190{Binding{b->getSizeTy(), "N"}}, {},
191// internal state
192{Binding{b->getSizeTy(), "observed"}}) {
193    addAttribute(CanTerminateEarly());
194}
195
196}
Note: See TracBrowser for help on using the repository browser.