source: icGREP/icgrep-devel/icgrep/kernels/until_n.cpp @ 5832

Last change on this file since 5832 was 5832, checked in by nmedfort, 15 months ago

Bug fix for UntilN

File size: 9.3 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "until_n.h"
7#include <llvm/IR/Module.h>
8#include <kernels/kernel_builder.h>
9#include <kernels/streamset.h>
10#include <toolchain/toolchain.h>
11
12namespace llvm { class Type; }
13
14using namespace llvm;
15using namespace parabix;
16
17namespace kernel {
18
/**
 * Emits the IR for UntilN's multi-block step.
 *
 * Every processed block of the "bits" input stream is copied to the "uptoN"
 * output stream while the running number of 1 bits is accumulated in the
 * "observed" scalar.  As soon as that count would reach the "N" scalar, the
 * block holding the N-th bit is rewritten through a mask built from the
 * position just past that bit, the termination signal is set, and the
 * produced item count of "uptoN" is advanced by that position.  If the N-th
 * bit is not reached, the produced item count is advanced by the available
 * item count of "bits" instead.
 *
 * @param b             kernel builder through which all IR is emitted
 * @param numOfStrides  stride count; the stride loop exits once its
 *                      incremented index equals this value
 */
19void UntilNkernel::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) {
20
21/* 
22   Strategy:  first form an index consisting of one bit per packsize input positions,
23   with a 1 bit signifying that the corresponding pack has at least one 1 bit.
24   Build the index one pack at a time, i.e., packsize * packsize positions at a time.
25   After an index pack is constructed, scan the index pack for 1 bits.  Each 1 bit
26   found identifies an input pack with a nonzero popcount.  Take the actual popcount
27   of the corresponding input pack and update the total number of bits seen.   If
28   the number of bits seen reaches N with any pack, determine the position of the
29   Nth bit and signal termination at that point.
30 
31   For normal processing, we process whole blocks only, always advancing the processed
32   and produced item counts by an integral number of blocks.   For final block
33   processing, we treat the final partial block as a whole block for the purpose
34   of finding the Nth bit.   However, if the located bit position is past the
35   EOF position, then this is treated as if the Nth bit does not exist in the
36   input stream.
37*/
    // NOTE(review): no comparison against an EOF position is visible in this
    // function; confirm where the "past EOF" case described above is handled.
38
    // A "pack" is one lane as wide as the builder's size type; vTy views a bit
    // block as packsPerBlock such lanes.
39    IntegerType * const sizeTy = b->getSizeTy();
40    const unsigned packSize = sizeTy->getBitWidth();
41    Constant * const ZERO = b->getSize(0);
42    Constant * const ONE = b->getSize(1);
43    const auto packsPerBlock = b->getBitBlockWidth() / packSize;
44    Constant * const PACK_SIZE = b->getSize(packSize);
45    Constant * const PACKS_PER_BLOCK = b->getSize(packsPerBlock);
46    VectorType * const vTy = VectorType::get(sizeTy, packsPerBlock);
47    Value * const ZEROES = Constant::getNullValue(vTy);
48
49    BasicBlock * const entry = b->GetInsertBlock();
50    BasicBlock * const strideLoop = b->CreateBasicBlock("strideLoop");
51
    // Read once in the entry block; used as the produced amount when the
    // stride loop finishes without finding the N-th bit.
52    Value * const allAvailableItems = b->getAvailableItemCount("bits");
53
54    b->CreateBr(strideLoop);
55    b->SetInsertPoint(strideLoop);
    // strideIndex: 0 on entry, nextStrideIndex when re-entered from nextStride.
56    PHINode * const strideIndex = b->CreatePHI(sizeTy, 2);
57    strideIndex->addIncoming(ZERO, entry);
58
    // One stride spans n blocks.  When both divisions are exact,
    // n * packsPerBlock == packSize, so one marker bit per pack of the stride
    // fits exactly into a single size-type value (groupMask).
59    const auto n = (packSize * packSize) / b->getBitBlockWidth();
60    Value * groupMask = nullptr;
61    Value * const baseOffset = b->CreateMul(strideIndex, b->getSize(n));
    // This loop is unrolled at code-generation time: n load/store/mask
    // sequences are emitted straight into strideLoop.
62    for (unsigned i = 0; i < n; ++i) {
63        Value * offset = b->CreateNUWAdd(baseOffset, b->getSize(i));
64        Value * inputPtr = b->getInputStreamBlockPtr("bits", ZERO, offset);
65        Value * inputValue = b->CreateBlockAlignedLoad(inputPtr);
        // Copy the block through unchanged; only the block containing the
        // N-th bit is overwritten later (in foundNthBit).
66        Value * outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, offset);
67        b->CreateBlockAlignedStore(inputValue, outputPtr);
        // Going by the helper names: markers flags each pack that differs
        // from zero, and hsimd_signmask condenses that to one bit per pack.
68        Value * markers = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
69        Value * blockMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, markers), sizeTy);
        // Block i's per-pack bits occupy bit positions starting at
        // i * packsPerBlock of groupMask.
70        if (i) {
71            blockMask = b->CreateShl(blockMask, i * packsPerBlock);
72            groupMask = b->CreateOr(groupMask, blockMask);
73        } else {
74            groupMask = blockMask;
75        }
76    }
77
78    BasicBlock * const processGroups = b->CreateBasicBlock("processGroups");
79    BasicBlock * const nextStride = b->CreateBasicBlock("nextStride");
80
    // A zero groupMask means no pack in this stride was flagged, so counting
    // is skipped entirely (hinted as the likely case).
81    b->CreateLikelyCondBr(b->CreateIsNull(groupMask), nextStride, processGroups);
82
83    b->SetInsertPoint(processGroups);
84    Value * const N = b->getScalarField("N");
85    Value * const initiallyObserved = b->getScalarField("observed");
86    BasicBlock * const processGroup = b->CreateBasicBlock("processGroup");
87    b->CreateBr(processGroup);
88
    // processGroup: one iteration per flagged pack of the stride.
    //   observed     - bits counted before the current pack
    //   groupMarkers - flagged packs not yet visited
89    b->SetInsertPoint(processGroup);
90    PHINode * const observed = b->CreatePHI(initiallyObserved->getType(), 2);
91    observed->addIncoming(initiallyObserved, processGroups);
92    PHINode * const groupMarkers = b->CreatePHI(groupMask->getType(), 2);
93    groupMarkers->addIncoming(groupMask, processGroups);
94
    // The forward-zero count of groupMarkers is used as the index of the next
    // flagged pack; split it into a block index (offset by this stride's
    // first block) and a pack index within that block.
95    Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), sizeTy);
96    Value * const blockIndex = b->CreateNUWAdd(baseOffset, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
97    Value * const packOffset = b->CreateURem(groupIndex, PACKS_PER_BLOCK);
98    Value * const groupPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
99    Value * const groupValue = b->CreateBlockAlignedLoad(groupPtr);
    // Reinterpret the block as a vector of packs and pull out the selected one.
100    Value * const packBits = b->CreateExtractElement(b->CreateBitCast(groupValue, vTy), packOffset);
101    //Type * packPtrTy = packTy->getPointerTo();
102    //Value * const packPtr = b->CreateGEP(b->CreatePointerCast(groupPtr, packPtrTy), packOffset);
103    //Value * const packBits = b->CreateLoad(packPtr);
104    Value * const packCount = b->CreateZExtOrTrunc(b->CreatePopcount(packBits), sizeTy);
105    Value * const observedUpTo = b->CreateNUWAdd(observed, packCount);
106
107    BasicBlock * const haveNotSeenEnough = b->CreateBasicBlock("haveNotSeenEnough");
108    BasicBlock * const seenNOrMore = b->CreateBasicBlock("seenNOrMore");
    // Still below N after adding this pack's popcount -> keep counting;
    // otherwise the N-th bit lies inside this pack.
109    b->CreateLikelyCondBr(b->CreateICmpULT(observedUpTo, N), haveNotSeenEnough, seenNOrMore);
110
111    // update our kernel state and check whether we have any other groups to process
112    b->SetInsertPoint(haveNotSeenEnough);
113    observed->addIncoming(observedUpTo, haveNotSeenEnough);
    // Stored after every counted pack, so the "observed" scalar always holds
    // the count reached so far.
114    b->setScalarField("observed", observedUpTo);
    // Drop the marker of the pack just counted; loop again while any remain.
115    Value * const remainingGroupMarkers = b->CreateResetLowestBit(groupMarkers);
116    groupMarkers->addIncoming(remainingGroupMarkers, haveNotSeenEnough);
117    b->CreateLikelyCondBr(b->CreateIsNull(remainingGroupMarkers), nextStride, processGroup);
118
119    // we've seen N non-zero items; determine the position of our items and clear any subsequent markers
120    b->SetInsertPoint(seenNOrMore);
    // The check below is emitted only when the EnableAsserts debug option is set.
121    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
122        b->CreateAssert(b->CreateICmpUGT(N, observed), "N must be greater than observed count!");
123    }
    // Number of set bits of the current pack that are still needed to reach N.
124    Value * const bitsToFind = b->CreateNUWSub(N, observed);
125    BasicBlock * const findNthBit = b->CreateBasicBlock("findNthBit");
126    BasicBlock * const foundNthBit = b->CreateBasicBlock("foundNthBit");
127    b->CreateBr(findNthBit);
128
    // findNthBit: each iteration resets the lowest bit of the pack and
    // decrements the remaining count; the loop exits when the count hits zero.
    // foundNthBit later reads the remainingBits PHI itself, i.e. the value
    // held at the start of the exiting iteration (before that reset).
129    b->SetInsertPoint(findNthBit);
130    PHINode * const remainingBitsToFind = b->CreatePHI(bitsToFind->getType(), 2);
131    remainingBitsToFind->addIncoming(bitsToFind, seenNOrMore);
132    PHINode * const remainingBits = b->CreatePHI(packBits->getType(), 2);
133    remainingBits->addIncoming(packBits, seenNOrMore);
134    Value * const nextRemainingBits = b->CreateResetLowestBit(remainingBits);
135    remainingBits->addIncoming(nextRemainingBits, findNthBit);
136    Value * const nextRemainingBitsToFind = b->CreateNUWSub(remainingBitsToFind, ONE);
137    remainingBitsToFind->addIncoming(nextRemainingBitsToFind, findNthBit);
138    b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingBitsToFind), foundNthBit, findNthBit);
139
140    // If we've found the n-th bit, end the segment after clearing the markers
141    b->SetInsertPoint(foundNthBit);
142    Value * const inputPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
143    Value * const inputValue = b->CreateBlockAlignedLoad(inputPtr);
    // packPosition: forward-zero count of remainingBits, used as the N-th
    // bit's position within its pack.
    // basePosition: bit offset of that pack within the block.
    // blockOffset:  the combined position plus one, i.e. the position
    //               immediately after the N-th bit within the block.
144    Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), sizeTy);
145    Value * const basePosition = b->CreateNUWMul(packOffset, PACK_SIZE);
146    Value * const blockOffset = b->CreateNUWAdd(b->CreateOr(basePosition, packPosition), ONE);
    // Presumably bitblock_mask_from(k) selects positions k and above, so its
    // complement keeps only positions below blockOffset - TODO confirm.
147    Value * const mask = b->CreateNot(b->bitblock_mask_from(blockOffset));
148    Value * const maskedInputValue = b->CreateAnd(inputValue, mask);
    // Overwrites the unmasked copy of this block stored by the stride loop.
149    Value * const outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex);
150    b->CreateBlockAlignedStore(maskedInputValue, outputPtr);
    // Item offset of the position just past the N-th bit, measured from block
    // index 0 as addressed through getInputStreamBlockPtr in this call.
151    Value * const positionOfNthItem = b->CreateNUWAdd(b->CreateMul(blockIndex, b->getSize(b->getBitBlockWidth())), blockOffset);
152    b->setTerminationSignal();
153    BasicBlock * const segmentDone = b->CreateBasicBlock("segmentDone");
154    b->CreateBr(segmentDone);
155
    // Repositions nextStride after foundNthBit in the function's block order.
156    nextStride->moveAfter(foundNthBit);
157
    // nextStride: advance the stride counter; leave the loop once it equals
    // numOfStrides, otherwise process the next stride.
158    b->SetInsertPoint(nextStride);
159    Value * const nextStrideIndex = b->CreateNUWAdd(strideIndex, ONE);
160    strideIndex->addIncoming(nextStrideIndex, nextStride);
161    b->CreateLikelyCondBr(b->CreateICmpEQ(nextStrideIndex, numOfStrides), segmentDone, strideLoop);
162
    // segmentDone: the amount produced is the position just past the N-th bit
    // when arriving from foundNthBit, or the available item count of "bits"
    // when the stride loop ran to completion.
163    b->SetInsertPoint(segmentDone);
164    PHINode * const produced = b->CreatePHI(sizeTy, 2);
165    produced->addIncoming(positionOfNthItem, foundNthBit);
166    produced->addIncoming(allAvailableItems, nextStride);
167    Value * producedCount = b->getProducedItemCount("uptoN");
168    producedCount = b->CreateNUWAdd(producedCount, produced);
169    b->setProducedItemCount("uptoN", producedCount);
170
171}
172
173unsigned LLVM_READNONE calculateRate(const std::unique_ptr<kernel::KernelBuilder> & b) {
174    const unsigned packSize = b->getSizeTy()->getBitWidth();
175    return (packSize * packSize) / b->getBitBlockWidth();
176}
177
178UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & b)
179: MultiBlockKernel("UntilN_" + std::to_string(calculateRate(b)),
180// inputs
181{Binding{b->getStreamSetTy(), "bits", FixedRate(calculateRate(b))}},
182// outputs
183{Binding{b->getStreamSetTy(), "uptoN", BoundedRate(0, calculateRate(b))}},
184// input scalar
185{Binding{b->getSizeTy(), "N"}}, {},
186// internal state
187{Binding{b->getSizeTy(), "observed"}}) {
188
189}
190
191}
Note: See TracBrowser for help on using the repository browser.