source: icGREP/icgrep-devel/icgrep/kernels/until_n.cpp @ 5589

Last change on this file since 5589 was 5493, checked in by cameron, 2 years ago

Restore check-ins from the last several days

File size: 10.2 KB
Line 
/*
 *  Copyright (c) 2017 International Characters.
 *  This software is licensed to the public under the Open Software License 3.0.
 */
5
6#include "until_n.h"
7#include <llvm/IR/Module.h>
8#include <kernels/kernel_builder.h>
9#include <kernels/streamset.h>
10
11namespace llvm { class Type; }
12
13using namespace llvm;
14using namespace parabix;
15
16namespace kernel {
17
// Width, in bits, of one "pack": the integer unit used both for scanning the
// input (popcount / bit scans on one pack at a time) and for the per-group
// index mask, which holds one bit per input pack.
constexpr unsigned packSize = 64;
19   
20void UntilNkernel::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & kb) {
21/* 
22   Strategy:  first form an index consisting of one bit per packsize input positions,
23   with a 1 bit signifying that the corresponding pack has at least one 1 bit.
24   Build the index one pack at a time, i.e, packsize * packsize positions at a time.
25   After an index pack is constructed, scan the index pack for 1 bits.  Each 1 bit
26   found identifies an input pack with a nonzero popcount.  Take the actual popcount
27   of the corresponding input pack and update the total number of bits seen.   If
28   the number of bits seen reaches N with any pack, determine the position of the
29   Nth bit and signal termination at that point.
30 
31   For normal processing, we process whole blocks only, always advanced processed
32   and produced item counts by an integral number of blocks.   For final block
33   processing, we treat the final partial block as a whole block for the purpose
34   of finding the Nth bit.   However, if the located bit position is past the
35   EOF position, then this is treated as if the Nth bit does not exist in the
36   input stream.
37*/
38   
39    BasicBlock * entry = kb->GetInsertBlock();
40   
41    BasicBlock * processGroups = kb->CreateBasicBlock("processGroups");
42    BasicBlock * processBlockGroup = kb->CreateBasicBlock("processBlockGroup");
43    BasicBlock * doScan = kb->CreateBasicBlock("doScan");
44    BasicBlock * scanLoop = kb->CreateBasicBlock("scanLoop");
45    BasicBlock * continueScanLoop = kb->CreateBasicBlock("continueScanLoop");
46    BasicBlock * scanDone = kb->CreateBasicBlock("scanDone");
47    BasicBlock * notFoundYet = kb->CreateBasicBlock("notFoundYet");
48    BasicBlock * findNth = kb->CreateBasicBlock("findNth");
49    BasicBlock * getPosnAfterNth = kb->CreateBasicBlock("getPosnAfterNth");
50    BasicBlock * nthPosFound = kb->CreateBasicBlock("nthPosFound");
51    BasicBlock * doSegmentReturn = kb->CreateBasicBlock("doSegmentReturn");
52    Constant * blockSize = kb->getSize(kb->getBitBlockWidth());
53    Constant * blockSizeLess1 = kb->getSize(kb->getBitBlockWidth() - 1);
54    Constant * packsPerBlock = kb->getSize(kb->getBitBlockWidth()/packSize);
55   
56    Value * N = kb->getScalarField("N");
57   
58    // Set up the types for processing by pack.
59    Type * iPackTy = kb->getIntNTy(packSize);
60    Type * iPackPtrTy = iPackTy->getPointerTo();
61   
62    Function::arg_iterator args = mCurrentMethod->arg_begin();
63    /* self = */ args++;
64    Value * itemsToDo = &*(args++);
65    Value * sourceBitstream = &*(args++);
66    Value * uptoN_bitstream = &*(args);
67   
68    // Compute the ceiling of the number of blocks to do.  If we have a final
69    // partial block, it is treated as a full block initially.   
70    Value * blocksToDo = kb->CreateUDiv(kb->CreateAdd(itemsToDo, blockSizeLess1), blockSize);
71   
72    // We will create a bitmask of size packSize with one bit for every packSize positions.
73    // The index can accommodate blocksPerGroup blocks.
74    Constant * blocksPerGroup = kb->getSize(packSize/((kb->getBitBlockWidth()/packSize)));
75    kb->CreateCondBr(kb->CreateICmpUGT(blocksToDo, kb->getSize(0)), processGroups, notFoundYet);
76   
77    // Each iteration of the outerloop processes one blockGroup of at most blocksPerGroup.
78    kb->SetInsertPoint(processGroups);
79    PHINode * blockGroupBase = kb->CreatePHI(kb->getSizeTy(), 2);
80    blockGroupBase->addIncoming(kb->getSize(0), entry);
81    Value * groupPackPtr = kb->CreatePointerCast(kb->CreateGEP(sourceBitstream, blockGroupBase), iPackPtrTy);
82    Value * blockGroupLimit = kb->CreateAdd(blockGroupBase, blocksPerGroup);
83    blockGroupLimit = kb->CreateSelect(kb->CreateICmpULT(blockGroupLimit, blocksToDo), blockGroupLimit, blocksToDo);
84    kb->CreateBr(processBlockGroup);
85
86    // Outer loop processes the blocksToDo in groups of up to blocksPerGroup at a time.
87    // The bitmask for this group is assembled.
88    kb->SetInsertPoint(processBlockGroup);
89    PHINode * blockNo = kb->CreatePHI(kb->getSizeTy(), 2);
90    PHINode * groupMask = kb->CreatePHI(iPackTy, 2);
91    blockNo->addIncoming(blockGroupBase, processGroups);
92    groupMask->addIncoming(ConstantInt::getNullValue(iPackTy), processGroups);
93
94    Value * blk = kb->CreateBlockAlignedLoad(kb->CreateGEP(sourceBitstream, {blockNo, kb->getInt32(0)}));
95    kb->CreateBlockAlignedStore(blk, kb->CreateGEP(uptoN_bitstream, {blockNo, kb->getInt32(0)}));
96    Value * hasbit = kb->simd_ugt(packSize, blk, kb->allZeroes());
97    Value * blockMask = kb->CreateZExtOrTrunc(kb->hsimd_signmask(packSize, hasbit), iPackTy);
98    Value * nextBlockNo = kb->CreateAdd(blockNo, kb->getSize(1));
99    Value * blockMaskPosition = kb->CreateMul(kb->CreateSub(blockNo, blockGroupBase), packsPerBlock);
100    Value * nextgroupMask = kb->CreateOr(groupMask, kb->CreateShl(blockMask, blockMaskPosition));
101    blockNo->addIncoming(nextBlockNo, processBlockGroup);
102    groupMask->addIncoming(nextgroupMask, processBlockGroup);
103    kb->CreateCondBr(kb->CreateICmpULT(nextBlockNo, blockGroupLimit), processBlockGroup, doScan);
104
105    // The index pack has been assembled - process the corresponding blocks.
106    kb->SetInsertPoint(doScan);
107    Value * seenSoFar = kb->getScalarField("seenSoFar");
108    kb->CreateCondBr(kb->CreateICmpUGT(nextgroupMask, ConstantInt::getNullValue(iPackTy)), scanLoop, scanDone);
109   
110    kb->SetInsertPoint(scanLoop);
111    PHINode * groupMaskPhi = kb->CreatePHI(iPackTy, 2);
112    groupMaskPhi->addIncoming(nextgroupMask, doScan);
113    PHINode * seenSoFarPhi = kb->CreatePHI(kb->getSizeTy(), 2);
114    seenSoFarPhi->addIncoming(seenSoFar, doScan);
115    Value * nonZeroPack = kb->CreateZExtOrTrunc(kb->CreateCountForwardZeroes(groupMaskPhi), kb->getSizeTy());
116    Value * scanMask = kb->CreateLoad(kb->CreateGEP(groupPackPtr, nonZeroPack));
117    Value * packCount = kb->CreateZExtOrTrunc(kb->CreatePopcount(scanMask), kb->getSizeTy());
118    Value * newTotalSeen = kb->CreateAdd(packCount, seenSoFarPhi);
119    Value * seenLessThanN = kb->CreateICmpULT(newTotalSeen, N);
120    kb->CreateCondBr(seenLessThanN, continueScanLoop, findNth);
121
122    kb->SetInsertPoint(continueScanLoop);
123    Value * reducedGroupMask = kb->CreateResetLowestBit(groupMaskPhi);
124    groupMaskPhi->addIncoming(reducedGroupMask, continueScanLoop);
125    seenSoFarPhi->addIncoming(newTotalSeen, continueScanLoop);
126    kb->CreateCondBr(kb->CreateICmpUGT(reducedGroupMask, ConstantInt::getNullValue(iPackTy)), scanLoop, scanDone);
127
128    // Now we have processed the group of blocks and updated the number of positions
129    // seenSoFar without finding the Nth bit. 
130    kb->SetInsertPoint(scanDone);
131    PHINode * newTotalSeenPhi = kb->CreatePHI(kb->getSizeTy(), 2);
132    newTotalSeenPhi->addIncoming(seenSoFar, doScan);
133    newTotalSeenPhi->addIncoming(newTotalSeen, continueScanLoop);
134    kb->setScalarField("seenSoFar", newTotalSeenPhi);
135    blockGroupBase->addIncoming(nextBlockNo, scanDone);
136    kb->CreateCondBr(kb->CreateICmpULT(nextBlockNo, blocksToDo), processGroups, notFoundYet);
137
138    kb->SetInsertPoint(notFoundYet);
139    // Now we have determined that the Nth bit has not been found in the entire
140    // set of itemsToDo.
141   
142    Value * finalCount = kb->CreateAdd(kb->getProducedItemCount("uptoN"), itemsToDo);
143    kb->setProducedItemCount("uptoN", finalCount);
144    kb->CreateBr(doSegmentReturn);
145
146    //
147    // With the last input scanMask loaded, the count of one bits seen reaches or
148    // exceeds N.  Determine the position immediately after the Nth one bit.
149    //
150    kb->SetInsertPoint(findNth);
151   
152    PHINode * seen1 = kb->CreatePHI(kb->getSizeTy(), 2);
153    seen1->addIncoming(seenSoFarPhi, scanLoop);
154    PHINode * remainingBits = kb->CreatePHI(iPackTy, 2);
155    remainingBits->addIncoming(scanMask, scanLoop);
156    Value * clearLowest = kb->CreateResetLowestBit(remainingBits);
157    Value * oneMoreSeen = kb->CreateAdd(seen1, kb->getSize(1));
158    seen1->addIncoming(oneMoreSeen, findNth);
159    remainingBits->addIncoming(clearLowest, findNth);
160    kb->CreateCondBr(kb->CreateICmpULT(oneMoreSeen, N), findNth, getPosnAfterNth);
161
162    //
163    // We have cleared the low bits of scanMask up to and including the Nth in the stream.
164    kb->SetInsertPoint(getPosnAfterNth);
165    Value * scanMaskUpToN = kb->CreateXor(scanMask, clearLowest);
166    Value * posnInPack = kb->CreateSub(ConstantInt::get(iPackTy, packSize), kb->CreateCountReverseZeroes(scanMaskUpToN));
167    Value * posnInGroup = kb->CreateAdd(kb->CreateMul(nonZeroPack, kb->getSize(packSize)), posnInPack);
168    Value * posnInItemsToDo = kb->CreateAdd(kb->CreateMul(blockGroupBase, blockSize), posnInGroup);
169    // It is conceivable that we found a bit at a position beyond the given itemsToDo,
170    // when we have a partial pack at the end of input.  In this case, the Nth bit does
171    // not exist in the valid range of itemsToDo.
172    kb->CreateCondBr(kb->CreateICmpUGE(posnInItemsToDo, itemsToDo), notFoundYet, nthPosFound);
173   
174    kb->SetInsertPoint(nthPosFound);
175    finalCount = kb->CreateAdd(kb->getProcessedItemCount("bits"), posnInItemsToDo);
176    Value * finalBlock = kb->CreateUDiv(posnInItemsToDo, blockSize);
177    blk = kb->CreateBlockAlignedLoad(kb->CreateGEP(sourceBitstream, {finalBlock, kb->getInt32(0)}));
178    blk = kb->CreateAnd(blk, kb->CreateNot(kb->bitblock_mask_from(kb->CreateURem(posnInItemsToDo, blockSize))));
179    Value * outputPtr = kb->CreateGEP(uptoN_bitstream, {finalBlock, kb->getInt32(0)});
180    kb->CreateBlockAlignedStore(blk, outputPtr);
181    kb->setProcessedItemCount("bits", finalCount);
182    kb->setProducedItemCount("uptoN", finalCount);
183    kb->setTerminationSignal();
184    kb->CreateBr(doSegmentReturn);
185   
186    kb->SetInsertPoint(doSegmentReturn);
187}
188
189UntilNkernel::UntilNkernel(const std::unique_ptr<kernel::KernelBuilder> & kb)
190: MultiBlockKernel("UntilN", {Binding{kb->getStreamSetTy(1, 1), "bits"}},
191                             {Binding{kb->getStreamSetTy(1, 1), "uptoN", MaxRatio(1)}},
192                             {Binding{kb->getSizeTy(), "N"}}, {},
193                             {Binding{kb->getSizeTy(), "seenSoFar"}}) {
194}
195
196}
Note: See TracBrowser for help on using the repository browser.