source: icGREP/icgrep-devel/icgrep/kernels/until_n.cpp @ 5757

Last change on this file since 5757 was 5755, checked in by nmedfort, 23 months ago

Bug fixes and simplified MultiBlockKernel logic

File size: 10.6 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "until_n.h"
7#include <llvm/IR/Module.h>
8#include <kernels/kernel_builder.h>
9#include <kernels/streamset.h>
10
11namespace llvm { class Type; }
12
13using namespace llvm;
14using namespace parabix;
15
16namespace kernel {
17
// Width, in bits, of one scan "pack": the integer unit that is loaded,
// popcounted and bit-scanned when searching for the Nth one bit.  The
// per-group index mask uses the same width.
constexpr unsigned packSize = 64;
19   
20Value * UntilNkernel::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & kb, Value * const numOfStrides) {
21/* 
22   Strategy:  first form an index consisting of one bit per packsize input positions,
23   with a 1 bit signifying that the corresponding pack has at least one 1 bit.
24   Build the index one pack at a time, i.e, packsize * packsize positions at a time.
25   After an index pack is constructed, scan the index pack for 1 bits.  Each 1 bit
26   found identifies an input pack with a nonzero popcount.  Take the actual popcount
27   of the corresponding input pack and update the total number of bits seen.   If
28   the number of bits seen reaches N with any pack, determine the position of the
29   Nth bit and signal termination at that point.
30 
31   For normal processing, we process whole blocks only, always advanced processed
32   and produced item counts by an integral number of blocks.   For final block
33   processing, we treat the final partial block as a whole block for the purpose
34   of finding the Nth bit.   However, if the located bit position is past the
35   EOF position, then this is treated as if the Nth bit does not exist in the
36   input stream.
37*/
38   
39    BasicBlock * entry = kb->GetInsertBlock();
40   
41    BasicBlock * processGroups = kb->CreateBasicBlock("processGroups");
42    BasicBlock * processBlockGroup = kb->CreateBasicBlock("processBlockGroup");
43    BasicBlock * doScan = kb->CreateBasicBlock("doScan");
44    BasicBlock * scanLoop = kb->CreateBasicBlock("scanLoop");
45    BasicBlock * continueScanLoop = kb->CreateBasicBlock("continueScanLoop");
46    BasicBlock * scanDone = kb->CreateBasicBlock("scanDone");
47    BasicBlock * notFoundYet = kb->CreateBasicBlock("notFoundYet");
48    BasicBlock * findNth = kb->CreateBasicBlock("findNth");
49    BasicBlock * getPosnAfterNth = kb->CreateBasicBlock("getPosnAfterNth");
50    BasicBlock * nthPosFound = kb->CreateBasicBlock("nthPosFound");
51    BasicBlock * doSegmentReturn = kb->CreateBasicBlock("doSegmentReturn");
52    Constant * blockSize = kb->getSize(kb->getBitBlockWidth());
53    Constant * blockSizeLess1 = kb->getSize(kb->getBitBlockWidth() - 1);
54    Constant * packsPerBlock = kb->getSize(kb->getBitBlockWidth()/packSize);
55   
56    Value * N = kb->getScalarField("N");
57   
58    // Set up the types for processing by pack.
59    Type * iPackTy = kb->getIntNTy(packSize);
60    Type * iPackPtrTy = iPackTy->getPointerTo();
61   
62//    Function::arg_iterator args = mCurrentMethod->arg_begin();
63//    /* self = */ args++;
64//    Value * itemsToDo = &*(args++);
65//    Value * sourceBitstream = &*(args++);
66//    Value * uptoN_bitstream = &*(args);
67   
68    Value * itemsToDo = mAvailableItemCount[0];
69    Value * sourceBitstream = kb->getInputStreamBlockPtr("bits", kb->getInt32(0)); // mStreamBufferPtr[0];
70    Value * uptoN_bitstream = kb->getInputStreamBlockPtr("uptoN", kb->getInt32(0)); // mStreamBufferPtr[1];
71
72    // Compute the ceiling of the number of blocks to do.  If we have a final
73    // partial block, it is treated as a full block initially.   
74    Value * blocksToDo = kb->CreateUDiv(kb->CreateAdd(itemsToDo, blockSizeLess1), blockSize);
75   
76    // We will create a bitmask of size packSize with one bit for every packSize positions.
77    // The index can accommodate blocksPerGroup blocks.
78    Constant * blocksPerGroup = kb->getSize(packSize/((kb->getBitBlockWidth()/packSize)));
79    kb->CreateCondBr(kb->CreateICmpUGT(blocksToDo, kb->getSize(0)), processGroups, notFoundYet);
80   
81    // Each iteration of the outerloop processes one blockGroup of at most blocksPerGroup.
82    kb->SetInsertPoint(processGroups);
83    PHINode * blockGroupBase = kb->CreatePHI(kb->getSizeTy(), 2);
84    blockGroupBase->addIncoming(kb->getSize(0), entry);
85    Value * groupPackPtr = kb->CreatePointerCast(kb->CreateGEP(sourceBitstream, blockGroupBase), iPackPtrTy);
86    Value * blockGroupLimit = kb->CreateAdd(blockGroupBase, blocksPerGroup);
87    blockGroupLimit = kb->CreateSelect(kb->CreateICmpULT(blockGroupLimit, blocksToDo), blockGroupLimit, blocksToDo);
88    kb->CreateBr(processBlockGroup);
89
90    // Outer loop processes the blocksToDo in groups of up to blocksPerGroup at a time.
91    // The bitmask for this group is assembled.
92    kb->SetInsertPoint(processBlockGroup);
93    PHINode * blockNo = kb->CreatePHI(kb->getSizeTy(), 2);
94    PHINode * groupMask = kb->CreatePHI(iPackTy, 2);
95    blockNo->addIncoming(blockGroupBase, processGroups);
96    groupMask->addIncoming(ConstantInt::getNullValue(iPackTy), processGroups);
97
98    Value * blk = kb->CreateBlockAlignedLoad(kb->CreateGEP(sourceBitstream, {blockNo, kb->getInt32(0)}));
99    kb->CreateBlockAlignedStore(blk, kb->CreateGEP(uptoN_bitstream, {blockNo, kb->getInt32(0)}));
100    Value * hasbit = kb->simd_ugt(packSize, blk, kb->allZeroes());
101    Value * blockMask = kb->CreateZExtOrTrunc(kb->hsimd_signmask(packSize, hasbit), iPackTy);
102    Value * nextBlockNo = kb->CreateAdd(blockNo, kb->getSize(1));
103    Value * blockMaskPosition = kb->CreateMul(kb->CreateSub(blockNo, blockGroupBase), packsPerBlock);
104    Value * nextgroupMask = kb->CreateOr(groupMask, kb->CreateShl(blockMask, blockMaskPosition));
105    blockNo->addIncoming(nextBlockNo, processBlockGroup);
106    groupMask->addIncoming(nextgroupMask, processBlockGroup);
107    kb->CreateCondBr(kb->CreateICmpULT(nextBlockNo, blockGroupLimit), processBlockGroup, doScan);
108
109    // The index pack has been assembled - process the corresponding blocks.
110    kb->SetInsertPoint(doScan);
111    Value * seenSoFar = kb->getScalarField("seenSoFar");
112    kb->CreateCondBr(kb->CreateICmpUGT(nextgroupMask, ConstantInt::getNullValue(iPackTy)), scanLoop, scanDone);
113   
114    kb->SetInsertPoint(scanLoop);
115    PHINode * groupMaskPhi = kb->CreatePHI(iPackTy, 2);
116    groupMaskPhi->addIncoming(nextgroupMask, doScan);
117    PHINode * seenSoFarPhi = kb->CreatePHI(kb->getSizeTy(), 2);
118    seenSoFarPhi->addIncoming(seenSoFar, doScan);
119    Value * nonZeroPack = kb->CreateZExtOrTrunc(kb->CreateCountForwardZeroes(groupMaskPhi), kb->getSizeTy());
120    Value * scanMask = kb->CreateLoad(kb->CreateGEP(groupPackPtr, nonZeroPack));
121    Value * packCount = kb->CreateZExtOrTrunc(kb->CreatePopcount(scanMask), kb->getSizeTy());
122    Value * newTotalSeen = kb->CreateAdd(packCount, seenSoFarPhi);
123    Value * seenLessThanN = kb->CreateICmpULT(newTotalSeen, N);
124    kb->CreateCondBr(seenLessThanN, continueScanLoop, findNth);
125
126    kb->SetInsertPoint(continueScanLoop);
127    Value * reducedGroupMask = kb->CreateResetLowestBit(groupMaskPhi);
128    groupMaskPhi->addIncoming(reducedGroupMask, continueScanLoop);
129    seenSoFarPhi->addIncoming(newTotalSeen, continueScanLoop);
130    kb->CreateCondBr(kb->CreateICmpUGT(reducedGroupMask, ConstantInt::getNullValue(iPackTy)), scanLoop, scanDone);
131
132    // Now we have processed the group of blocks and updated the number of positions
133    // seenSoFar without finding the Nth bit. 
134    kb->SetInsertPoint(scanDone);
135    PHINode * newTotalSeenPhi = kb->CreatePHI(kb->getSizeTy(), 2);
136    newTotalSeenPhi->addIncoming(seenSoFar, doScan);
137    newTotalSeenPhi->addIncoming(newTotalSeen, continueScanLoop);
138    kb->setScalarField("seenSoFar", newTotalSeenPhi);
139    blockGroupBase->addIncoming(nextBlockNo, scanDone);
140    kb->CreateCondBr(kb->CreateICmpULT(nextBlockNo, blocksToDo), processGroups, notFoundYet);
141
142    kb->SetInsertPoint(notFoundYet);
143    // Now we have determined that the Nth bit has not been found in the entire
144    // set of itemsToDo.
145   
146    Value * finalCount = kb->CreateAdd(kb->getProducedItemCount("uptoN"), itemsToDo);
147    kb->setProducedItemCount("uptoN", finalCount);
148    kb->CreateBr(doSegmentReturn);
149
150    //
151    // With the last input scanMask loaded, the count of one bits seen reaches or
152    // exceeds N.  Determine the position immediately after the Nth one bit.
153    //
154    kb->SetInsertPoint(findNth);
155    PHINode * seen1 = kb->CreatePHI(kb->getSizeTy(), 2);
156    seen1->addIncoming(seenSoFarPhi, scanLoop);
157    PHINode * remainingBits = kb->CreatePHI(iPackTy, 2);
158    remainingBits->addIncoming(scanMask, scanLoop);
159    Value * clearLowest = kb->CreateResetLowestBit(remainingBits);
160    Value * oneMoreSeen = kb->CreateAdd(seen1, kb->getSize(1));
161    seen1->addIncoming(oneMoreSeen, findNth);
162    remainingBits->addIncoming(clearLowest, findNth);
163    kb->CreateCondBr(kb->CreateICmpULT(oneMoreSeen, N), findNth, getPosnAfterNth);
164
165    //
166    // We have cleared the low bits of scanMask up to and including the Nth in the stream.
167    kb->SetInsertPoint(getPosnAfterNth);
168    Value * scanMaskUpToN = kb->CreateXor(scanMask, clearLowest);
169    Value * posnInPack = kb->CreateSub(ConstantInt::get(iPackTy, packSize-1), kb->CreateCountReverseZeroes(scanMaskUpToN));
170    Value * posnInGroup = kb->CreateAdd(kb->CreateMul(nonZeroPack, kb->getSize(packSize)), posnInPack);
171    Value * posnInItemsToDo = kb->CreateAdd(kb->CreateMul(blockGroupBase, blockSize), posnInGroup);
172    // It is conceivable that we found a bit at a position beyond the given itemsToDo,
173    // when we have a partial pack at the end of input.  In this case, the Nth bit does
174    // not exist in the valid range of itemsToDo.
175    kb->CreateCondBr(kb->CreateICmpUGE(posnInItemsToDo, itemsToDo), notFoundYet, nthPosFound);
176   
177    kb->SetInsertPoint(nthPosFound);
178    Value * itemsToKeep = kb->CreateAdd(posnInItemsToDo, kb->getSize(1));
179    finalCount = kb->CreateAdd(kb->getProcessedItemCount("bits"), itemsToKeep);
180    Value * finalBlock = kb->CreateUDiv(itemsToKeep, blockSize);
181    blk = kb->CreateBlockAlignedLoad(kb->CreateGEP(sourceBitstream, {finalBlock, kb->getInt32(0)}));
182    blk = kb->CreateAnd(blk, kb->CreateNot(kb->bitblock_mask_from(kb->CreateURem(itemsToKeep, blockSize))));
183    Value * outputPtr = kb->CreateGEP(uptoN_bitstream, {finalBlock, kb->getInt32(0)});
184    kb->CreateBlockAlignedStore(blk, outputPtr);
185    kb->setProcessedItemCount("bits", finalCount);
186    kb->setProducedItemCount("uptoN", finalCount);
187    kb->setTerminationSignal();
188    kb->CreateBr(doSegmentReturn);
189   
190    kb->SetInsertPoint(doSegmentReturn);
191    return numOfStrides;
192}
193
194UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & kb)
195: MultiBlockKernel("UntilN", {Binding{kb->getStreamSetTy(1, 1), "bits"}},
196                             {Binding{kb->getStreamSetTy(1, 1), "uptoN", BoundedRate(0, 1)}},
197                             {Binding{kb->getSizeTy(), "N"}}, {},
198                             {Binding{kb->getSizeTy(), "seenSoFar"}}) {
199}
200
201}
Note: See TracBrowser for help on using the repository browser.