source: icGREP/icgrep-devel/icgrep/kernels/until_n.cpp @ 5830

Last change on this file since 5830 was 5830, checked in by nmedfort, 17 months ago

UntilN kernel rewritten to use the new MultiBlock system

File size: 9.1 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "until_n.h"
7#include <llvm/IR/Module.h>
8#include <kernels/kernel_builder.h>
9#include <kernels/streamset.h>
10#include <toolchain/toolchain.h>
11
12namespace llvm { class Type; }
13
14using namespace llvm;
15using namespace parabix;
16
17namespace kernel {
18
// Width, in bits, of one "pack": the unit whose emptiness is summarised by a
// single bit of the index built in generateMultiBlockLogic.
constexpr unsigned packSize = 64U;
20   
21llvm::Value * UntilNkernel::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, llvm::Value * const numOfStrides) {
22
23/* 
24   Strategy:  first form an index consisting of one bit per packsize input positions,
25   with a 1 bit signifying that the corresponding pack has at least one 1 bit.
26   Build the index one pack at a time, i.e, packsize * packsize positions at a time.
27   After an index pack is constructed, scan the index pack for 1 bits.  Each 1 bit
28   found identifies an input pack with a nonzero popcount.  Take the actual popcount
29   of the corresponding input pack and update the total number of bits seen.   If
30   the number of bits seen reaches N with any pack, determine the position of the
31   Nth bit and signal termination at that point.
32 
33   For normal processing, we process whole blocks only, always advanced processed
34   and produced item counts by an integral number of blocks.   For final block
35   processing, we treat the final partial block as a whole block for the purpose
36   of finding the Nth bit.   However, if the located bit position is past the
37   EOF position, then this is treated as if the Nth bit does not exist in the
38   input stream.
39*/
40
41    Constant * const ZERO = b->getSize(0);
42    Constant * const ONE = b->getSize(1);
43    const auto packsPerBlock = b->getBitBlockWidth() / packSize;
44    Constant * const PACK_SIZE = b->getSize(packSize);
45    Constant * const PACKS_PER_BLOCK = b->getSize(packsPerBlock);
46    Value * const ZEROES = b->allZeroes();
47    Type * packTy = b->getIntNTy(packSize);
48
49    BasicBlock * const entry = b->GetInsertBlock();
50    BasicBlock * const strideLoop = b->CreateBasicBlock("strideLoop");
51
52    b->CreateBr(strideLoop);
53    b->SetInsertPoint(strideLoop);
54    PHINode * const strideIndex = b->CreatePHI(b->getSizeTy(), 2);
55    strideIndex->addIncoming(ZERO, entry);
56
57    const auto n = (packSize * packSize) / b->getBitBlockWidth();
58    Value * groupMask = nullptr;
59    Value * const baseOffset = b->CreateMul(strideIndex, b->getSize(n));
60    for (unsigned i = 0; i < n; ++i) {
61        Value * offset = b->CreateNUWAdd(baseOffset, b->getSize(i));
62        Value * inputPtr = b->getInputStreamBlockPtr("bits", ZERO, offset);
63        Value * inputValue = b->CreateBlockAlignedLoad(inputPtr);
64        Value * outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, offset);
65        b->CreateBlockAlignedStore(inputValue, outputPtr);
66        Value * markers = b->CreateNot(b->simd_eq(packSize, inputValue, ZEROES));
67        Value * blockMask = b->CreateZExtOrTrunc(b->hsimd_signmask(packSize, markers), packTy);
68        if (i) {
69            blockMask = b->CreateShl(blockMask, i * packsPerBlock);
70            groupMask = b->CreateOr(groupMask, blockMask);
71        } else {
72            groupMask = blockMask;
73        }
74    }
75
76    BasicBlock * const processGroups = b->CreateBasicBlock("processGroups");
77    BasicBlock * const nextStride = b->CreateBasicBlock("nextStride");
78
79    b->CreateLikelyCondBr(b->CreateIsNull(groupMask), nextStride, processGroups);
80
81    b->SetInsertPoint(processGroups);
82    Value * const N = b->getScalarField("N");
83    Value * const initiallyObserved = b->getScalarField("observed");
84    BasicBlock * const processGroup = b->CreateBasicBlock("processGroup");
85    b->CreateBr(processGroup);
86
87    b->SetInsertPoint(processGroup);
88    PHINode * const observed = b->CreatePHI(initiallyObserved->getType(), 2);
89    observed->addIncoming(initiallyObserved, processGroups);
90    PHINode * const groupMarkers = b->CreatePHI(groupMask->getType(), 2);
91    groupMarkers->addIncoming(groupMask, processGroups);
92
93    Value * const groupIndex = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(groupMarkers), b->getSizeTy());
94    Value * const blockIndex = b->CreateNUWAdd(baseOffset, b->CreateUDiv(groupIndex, PACKS_PER_BLOCK));
95    Value * const packOffset = b->CreateURem(groupIndex, PACKS_PER_BLOCK);
96    Value * const groupPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
97    Value * const groupValue = b->CreateBlockAlignedLoad(groupPtr);
98    Value * const packBits = b->CreateExtractElement(groupValue, packOffset);
99
100    //Type * packPtrTy = packTy->getPointerTo();
101    //Value * const packPtr = b->CreateGEP(b->CreatePointerCast(groupPtr, packPtrTy), packOffset);
102    //Value * const packBits = b->CreateLoad(packPtr);
103    Value * const packCount = b->CreatePopcount(packBits);
104    Value * const observedUpTo = b->CreateNUWAdd(observed, packCount);
105
106    BasicBlock * const haveNotSeenEnough = b->CreateBasicBlock("haveNotSeenEnough");
107    BasicBlock * const seenNOrMore = b->CreateBasicBlock("seenNOrMore");
108    b->CreateLikelyCondBr(b->CreateICmpULT(observedUpTo, N), haveNotSeenEnough, seenNOrMore);
109
110    // update our kernel state and check whether we have any other groups to process
111    b->SetInsertPoint(haveNotSeenEnough);
112    observed->addIncoming(observedUpTo, haveNotSeenEnough);
113    b->setScalarField("observed", observedUpTo);
114    Value * const remainingGroupMarkers = b->CreateResetLowestBit(groupMarkers);
115    groupMarkers->addIncoming(remainingGroupMarkers, haveNotSeenEnough);
116    b->CreateLikelyCondBr(b->CreateIsNull(remainingGroupMarkers), nextStride, processGroup);
117
118    // we've seen N non-zero items; determine the position of our items and clear any subsequent markers
119    b->SetInsertPoint(seenNOrMore);
120    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
121        b->CreateAssert(b->CreateICmpUGT(N, observed), "N must be greater than observed count!");
122    }
123    Value * const bitsToFind = b->CreateNUWSub(N, observed);
124    BasicBlock * const findNthBit = b->CreateBasicBlock("findNthBit");
125    BasicBlock * const foundNthBit = b->CreateBasicBlock("foundNthBit");
126    b->CreateBr(findNthBit);
127
128    b->SetInsertPoint(findNthBit);
129    PHINode * const remainingPositions = b->CreatePHI(bitsToFind->getType(), 2);
130    remainingPositions->addIncoming(bitsToFind, seenNOrMore);
131    PHINode * const remainingBits = b->CreatePHI(packBits->getType(), 2);
132    remainingBits->addIncoming(packBits, seenNOrMore);
133    Value * const nextRemainingPositions = b->CreateNUWSub(remainingPositions, ONE);
134    remainingPositions->addIncoming(nextRemainingPositions, findNthBit);
135    Value * const nextRemainingBits = b->CreateResetLowestBit(remainingBits);
136    remainingBits->addIncoming(nextRemainingBits, findNthBit);
137
138    b->CreateLikelyCondBr(b->CreateIsNull(nextRemainingPositions), foundNthBit, findNthBit);
139
140    // If we've found the n-th bit, end the segment after clearing the markers
141    b->SetInsertPoint(foundNthBit);
142    Value * const inputPtr = b->getInputStreamBlockPtr("bits", ZERO, blockIndex);
143    Value * const inputValue = b->CreateBlockAlignedLoad(inputPtr);
144    Value * const packPosition = b->CreateZExtOrTrunc(b->CreateCountForwardZeroes(remainingBits), b->getSizeTy());
145    Value * const basePosition = b->CreateNUWMul(packOffset, PACK_SIZE);
146    Value * const blockOffset = b->CreateNUWAdd(b->CreateOr(basePosition, packPosition), ONE);
147    Value * const mask = b->CreateNot(b->bitblock_mask_from(blockOffset));
148    Value * const maskedInputValue = b->CreateAnd(inputValue, mask);
149    Value * const outputPtr = b->getOutputStreamBlockPtr("uptoN", ZERO, blockIndex);
150    b->CreateBlockAlignedStore(maskedInputValue, outputPtr);
151    Value * const positionOfNthItem = b->CreateNUWAdd(b->CreateMul(blockIndex, b->getSize(b->getBitBlockWidth())), blockOffset);
152    b->setTerminationSignal();
153    BasicBlock * const segmentDone = b->CreateBasicBlock("segmentDone");
154    b->CreateBr(segmentDone);
155
156    nextStride->moveAfter(foundNthBit);
157
158    b->SetInsertPoint(nextStride);
159    Value * const nextStrideIndex = b->CreateNUWAdd(strideIndex, ONE);
160    strideIndex->addIncoming(nextStrideIndex, nextStride);
161    b->CreateLikelyCondBr(b->CreateICmpEQ(nextStrideIndex, numOfStrides), segmentDone, strideLoop);
162
163    Constant * const FULL_STRIDE = b->getSize(packSize * packSize);
164
165    b->SetInsertPoint(segmentDone);
166    PHINode * const produced = b->CreatePHI(b->getSizeTy(), 2);
167    produced->addIncoming(positionOfNthItem, foundNthBit);
168    produced->addIncoming(FULL_STRIDE, nextStride);
169
170    Value * producedCount = b->getProducedItemCount("uptoN");
171    producedCount = b->CreateNUWAdd(producedCount, b->CreateNUWMul(FULL_STRIDE, strideIndex));
172    producedCount = b->CreateNUWAdd(producedCount, produced);
173    b->setProducedItemCount("uptoN", producedCount);
174
175    return numOfStrides;
176}
177
178UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & b)
179: MultiBlockKernel("UntilN",
180// inputs
181{Binding{b->getStreamSetTy(), "bits", FixedRate((packSize * packSize) / b->getBitBlockWidth())}},
182// outputs
183{Binding{b->getStreamSetTy(), "uptoN", BoundedRate(0, (packSize * packSize) / b->getBitBlockWidth())}},
184// input scalar
185{Binding{b->getSizeTy(), "N"}}, {},
186// internal state
187{Binding{b->getSizeTy(), "observed"}}) {
188
189}
190
191}
Note: See TracBrowser for help on using the repository browser.