source: icGREP/icgrep-devel/icgrep/kernels/until_n.cpp @ 5831

Last change on this file since 5831 was 5831, checked in by nmedfort, 16 months ago

Potential bug fix for 32-bit

File size: 9.3 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "until_n.h"
7#include <llvm/IR/Module.h>
8#include <kernels/kernel_builder.h>
9#include <kernels/streamset.h>
10#include <toolchain/toolchain.h>
11
12namespace llvm { class Type; }
13
14using namespace llvm;
15using namespace parabix;
16
17namespace kernel {
18
19void UntilNkernel::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) {
20
21/* 
22   Strategy:  first form an index consisting of one bit per packsize input positions,
23   with a 1 bit signifying that the corresponding pack has at least one 1 bit.
24   Build the index one pack at a time, i.e, packsize * packsize positions at a time.
25   After an index pack is constructed, scan the index pack for 1 bits.  Each 1 bit
26   found identifies an input pack with a nonzero popcount.  Take the actual popcount
27   of the corresponding input pack and update the total number of bits seen.   If
28   the number of bits seen reaches N with any pack, determine the position of the
29   Nth bit and signal termination at that point.
30 
31   For normal processing, we process whole blocks only, always advanced processed
32   and produced item counts by an integral number of blocks.   For final block
33   processing, we treat the final partial block as a whole block for the purpose
34   of finding the Nth bit.   However, if the located bit position is past the
35   EOF position, then this is treated as if the Nth bit does not exist in the
36   input stream.
37*/
38
39    const unsigned packSize = b->getSizeTy()->getBitWidth();
40    Constant * const ZERO = b->getSize(0);
41    Constant * const ONE = b->getSize(1);
42    const auto packsPerBlock = b->getBitBlockWidth() / packSize;
43    Constant * const PACK_SIZE = b->getSize(packSize);
44    Constant * const PACKS_PER_BLOCK = b->getSize(packsPerBlock);
45    Value * const ZEROES = b->allZeroes();
46    Type * packTy = b->getIntNTy(packSize);
47
48    BasicBlock * const entry = b->GetInsertBlock();
49    BasicBlock * const strideLoop = b->CreateBasicBlock("strideLoop");
50
51    b->CreateBr(strideLoop);
52    b->SetInsertPoint(strideLoop);
53    PHINode * const strideIndex = b->CreatePHI(b->getSizeTy(), 2);
54    strideIndex->addIncoming(ZERO, entry);
55
56    const auto n = (packSize * packSize) / b->getBitBlockWidth();
57    Value * groupMask = nullptr;
58    Value * const baseOffset = b->CreateMul(strideIndex, b->getSize(n));
59    for (unsigned i = 0; i < n; ++i) {
60        Value * offset = b->CreateNUWAdd(baseOffset, b->getSize(i));
61        Value * inputPtr = b->getInputStreamBlockPtr("bits", ZERO, offset);
62        Value * inputValue = b->CreateBlockAlignedLoad(inputPtr);
63        Value * outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, offset);
64        b->CreateBlockAlignedStore(inputValue, outputPtr);
65        Value * markers = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
66        Value * blockMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, markers), packTy);
67        if (i) {
68            blockMask = b->CreateShl(blockMask, i * packsPerBlock);
69            groupMask = b->CreateOr(groupMask, blockMask);
70        } else {
71            groupMask = blockMask;
72        }
73    }
74
75    BasicBlock * const processGroups = b->CreateBasicBlock("processGroups");
76    BasicBlock * const nextStride = b->CreateBasicBlock("nextStride");
77
78    b->CreateLikelyCondBr(b->CreateIsNull(groupMask), nextStride, processGroups);
79
80    b->SetInsertPoint(processGroups);
81    Value * const N = b->getScalarField("N");
82    Value * const initiallyObserved = b->getScalarField("observed");
83    BasicBlock * const processGroup = b->CreateBasicBlock("processGroup");
84    b->CreateBr(processGroup);
85
86    b->SetInsertPoint(processGroup);
87    PHINode * const observed = b->CreatePHI(initiallyObserved->getType(), 2);
88    observed->addIncoming(initiallyObserved, processGroups);
89    PHINode * const groupMarkers = b->CreatePHI(groupMask->getType(), 2);
90    groupMarkers->addIncoming(groupMask, processGroups);
91
92    Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), b->getSizeTy());
93    Value * const blockIndex = b->CreateNUWAdd(baseOffset, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
94    Value * const packOffset = b->CreateURem(groupIndex, PACKS_PER_BLOCK);
95    Value * const groupPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
96    Value * const groupValue = b->CreateBlockAlignedLoad(groupPtr);
97    Value * const packBits = b->CreateExtractElement(groupValue, packOffset);
98
99    //Type * packPtrTy = packTy->getPointerTo();
100    //Value * const packPtr = b->CreateGEP(b->CreatePointerCast(groupPtr, packPtrTy), packOffset);
101    //Value * const packBits = b->CreateLoad(packPtr);
102    Value * const packCount = b->CreateZExtOrTrunc(b->CreatePopcount(packBits), b->getSizeTy());
103    Value * const observedUpTo = b->CreateNUWAdd(observed, packCount);
104
105    BasicBlock * const haveNotSeenEnough = b->CreateBasicBlock("haveNotSeenEnough");
106    BasicBlock * const seenNOrMore = b->CreateBasicBlock("seenNOrMore");
107    b->CreateLikelyCondBr(b->CreateICmpULT(observedUpTo, N), haveNotSeenEnough, seenNOrMore);
108
109    // update our kernel state and check whether we have any other groups to process
110    b->SetInsertPoint(haveNotSeenEnough);
111    observed->addIncoming(observedUpTo, haveNotSeenEnough);
112    b->setScalarField("observed", observedUpTo);
113    Value * const remainingGroupMarkers = b->CreateResetLowestBit(groupMarkers);
114    groupMarkers->addIncoming(remainingGroupMarkers, haveNotSeenEnough);
115    b->CreateLikelyCondBr(b->CreateIsNull(remainingGroupMarkers), nextStride, processGroup);
116
117    // we've seen N non-zero items; determine the position of our items and clear any subsequent markers
118    b->SetInsertPoint(seenNOrMore);
119    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
120        b->CreateAssert(b->CreateICmpUGT(N, observed), "N must be greater than observed count!");
121    }
122    Value * const bitsToFind = b->CreateNUWSub(N, observed);
123    BasicBlock * const findNthBit = b->CreateBasicBlock("findNthBit");
124    BasicBlock * const foundNthBit = b->CreateBasicBlock("foundNthBit");
125    b->CreateBr(findNthBit);
126
127    b->SetInsertPoint(findNthBit);
128    PHINode * const remainingPositions = b->CreatePHI(bitsToFind->getType(), 2);
129    remainingPositions->addIncoming(bitsToFind, seenNOrMore);
130    PHINode * const remainingBits = b->CreatePHI(packBits->getType(), 2);
131    remainingBits->addIncoming(packBits, seenNOrMore);
132    Value * const nextRemainingPositions = b->CreateNUWSub(remainingPositions, ONE);
133    remainingPositions->addIncoming(nextRemainingPositions, findNthBit);
134    Value * const nextRemainingBits = b->CreateResetLowestBit(remainingBits);
135    remainingBits->addIncoming(nextRemainingBits, findNthBit);
136
137    b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingPositions), foundNthBit, findNthBit);
138
139    // If we've found the n-th bit, end the segment after clearing the markers
140    b->SetInsertPoint(foundNthBit);
141    Value * const inputPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
142    Value * const inputValue = b->CreateBlockAlignedLoad(inputPtr);
143    Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), b->getSizeTy());
144    Value * const basePosition = b->CreateNUWMul(packOffset, PACK_SIZE);
145    Value * const blockOffset = b->CreateNUWAdd(b->CreateOr(basePosition, packPosition), ONE);
146    Value * const mask = b->CreateNot(b->bitblock_mask_from(blockOffset));
147    Value * const maskedInputValue = b->CreateAnd(inputValue, mask);
148    Value * const outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex);
149    b->CreateBlockAlignedStore(maskedInputValue, outputPtr);
150    Value * const positionOfNthItem = b->CreateNUWAdd(b->CreateMul(blockIndex, b->getSize(b->getBitBlockWidth())), blockOffset);
151    b->setTerminationSignal();
152    BasicBlock * const segmentDone = b->CreateBasicBlock("segmentDone");
153    b->CreateBr(segmentDone);
154
155    nextStride->moveAfter(foundNthBit);
156
157    b->SetInsertPoint(nextStride);
158    Value * const nextStrideIndex = b->CreateNUWAdd(strideIndex, ONE);
159    strideIndex->addIncoming(nextStrideIndex, nextStride);
160    b->CreateLikelyCondBr(b->CreateICmpEQ(nextStrideIndex, numOfStrides), segmentDone, strideLoop);
161
162    Constant * const FULL_STRIDE = b->getSize(packSize * packSize);
163
164    b->SetInsertPoint(segmentDone);
165    PHINode * const produced = b->CreatePHI(b->getSizeTy(), 2);
166    produced->addIncoming(positionOfNthItem, foundNthBit);
167    produced->addIncoming(FULL_STRIDE, nextStride);
168
169    Value * producedCount = b->getProducedItemCount("uptoN");
170    producedCount = b->CreateNUWAdd(producedCount, b->CreateNUWMul(FULL_STRIDE, strideIndex));
171    producedCount = b->CreateNUWAdd(producedCount, produced);
172    b->setProducedItemCount("uptoN", producedCount);
173
174}
175
176unsigned LLVM_READNONE calculateRate(const std::unique_ptr<kernel::KernelBuilder> & b) {
177    const unsigned packSize = b->getSizeTy()->getBitWidth();
178    return (packSize * packSize) / b->getBitBlockWidth();
179}
180
181UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & b)
182: MultiBlockKernel("UntilN_" + std::to_string(calculateRate(b)),
183// inputs
184{Binding{b->getStreamSetTy(), "bits", FixedRate(calculateRate(b))}},
185// outputs
186{Binding{b->getStreamSetTy(), "uptoN", BoundedRate(0, calculateRate(b))}},
187// input scalar
188{Binding{b->getSizeTy(), "N"}}, {},
189// internal state
190{Binding{b->getSizeTy(), "observed"}}) {
191
192}
193
194}
Note: See TracBrowser for help on using the repository browser.