source: icGREP/icgrep-devel/icgrep/kernels/symboltablepipeline.cpp @ 4986

Last change on this file since 4986 was 4986, checked in by nmedfort, 3 years ago

First attempt at dynamic segment size integration.

File size: 13.0 KB
Line 
1#include "symboltablepipeline.h"
2
3/*
4 *  Copyright (c) 2016 International Characters.
5 *  This software is licensed to the public under the Open Software License 3.0.
6 */
7
8#include "pipeline.h"
9#include "toolchain.h"
10#include "utf_encoding.h"
11
12#include <kernels/s2p_kernel.h>
13#include <kernels/instance.h>
14
15#include <pablo/function.h>
16#include <pablo/pablo_compiler.h>
17
18#include <re/re_cc.h>
19#include <re/re_rep.h>
20#include <re/re_name.h>
21#include <re/re_compiler.h>
22#include <re/printer_re.h>
23
24#include <cc/cc_compiler.h>
25
26#include <pablo/printer_pablos.h>
#include <iostream>
#include <memory>
#include <stdexcept>
28
29using namespace re;
30using namespace pablo;
31
32namespace kernel {
33
34SymbolTableBuilder::SymbolTableBuilder(Module * m, IDISA::IDISA_Builder * b)
35: mMod(m)
36, iBuilder(b)
37, mLongestLookahead(0)
38, mBitBlockType(b->getBitBlockType())
39, mBlockSize(b->getBitBlockWidth()) {
40
41}
42
43/** ------------------------------------------------------------------------------------------------------------- *
44 * @brief generateLeadingFunction
45 ** ------------------------------------------------------------------------------------------------------------- */
46PabloFunction * SymbolTableBuilder::generateLeadingFunction(const std::vector<unsigned> & endpoints) {
47    PabloFunction * const function = PabloFunction::Create("leading", 8, endpoints.size() + 2);
48    Encoding enc(Encoding::Type::ASCII, 8);
49    cc::CC_Compiler ccCompiler(*function, enc);
50    re::RE_Compiler reCompiler(*function, ccCompiler);
51    RE * cc = makeName(makeCC(makeCC(65, 90), makeCC(97, 122)));
52    reCompiler.compileUnicodeNames(cc);
53    PabloAST * const matches = reCompiler.compile(cc).stream;
54    PabloBlock * const entry = function->getEntryBlock();
55    PabloAST * const adv = entry->createAdvance(matches, 1);
56    PabloAST * const starts = entry->createAnd(matches, entry->createNot(adv));
57    PabloAST * const ends = entry->createAnd(adv, entry->createNot(matches));
58
59    function->setResult(0, entry->createAssign("S", starts));
60    function->setResult(1, entry->createAssign("E", ends));
61
62    PabloAST * M = ends;
63    unsigned step = 1;
64    unsigned i = 0;
65    for (unsigned endpoint : endpoints) {
66        assert (endpoint >= step);
67        unsigned span = endpoint - step;
68        while (span > step) {
69            M = entry->createOr(entry->createAdvance(M, step), M);
70            span = span - step;
71            step *= 2;
72        }
73        M = entry->createOr(entry->createAdvance(M, span), M);
74        function->setResult(i + 2, entry->createAssign("M" + std::to_string(i), M));
75        ++i;
76        step += span;
77    }
78
79    return function;
80}
81
82/** ------------------------------------------------------------------------------------------------------------- *
83 * @brief generateSortingFunction
84 ** ------------------------------------------------------------------------------------------------------------- */
85PabloFunction * SymbolTableBuilder::generateSortingFunction(const PabloFunction * const leading, const std::vector<unsigned> & endpoints) {
86    PabloFunction * const function = PabloFunction::Create("sorting", leading->getNumOfResults(), leading->getNumOfResults() * 2);
87    PabloBlock * const entry = function->getEntryBlock();
88    function->setParameter(0, entry->createVar("S"));
89    function->setParameter(1, entry->createVar("E"));
90    for (unsigned i = 2; i < leading->getNumOfResults(); ++i) {
91        function->setParameter(i, entry->createVar("M" + std::to_string(i - 2)));
92    }
93    PabloAST * R = function->getParameter(0);
94    PabloAST * const E = entry->createNot(function->getParameter(1));
95    unsigned i = 1;
96    unsigned lowerbound = 0;
97    for (unsigned endpoint : endpoints) {
98        PabloAST * const M = function->getParameter(i + 1);
99        PabloAST * const L = entry->createLookahead(M, endpoint, "lookahead" + std::to_string(endpoint));
100        PabloAST * S = entry->createAnd(L, R);
101        Assign * Si = entry->createAssign("S_" + std::to_string(i), S);
102        R = entry->createXor(R, S);
103        PabloAST * F = entry->createScanThru(R, E);
104        Assign * Ei = entry->createAssign("E_" + std::to_string(i), F);
105        function->setResult(i * 2, Si);
106        function->setResult(i * 2 + 1, Ei);
107        ++i;
108        lowerbound = endpoint;
109    }
110    Assign * Si = entry->createAssign("S_n", R);
111    PabloAST * F = entry->createScanThru(R, E);
112    Assign * Ei = entry->createAssign("E_n", F);
113    function->setResult(i * 2, Si);
114    function->setResult(i * 2 + 1, Ei);
115    mLongestLookahead = lowerbound;
116    return function;
117}
118
119/** ------------------------------------------------------------------------------------------------------------- *
120 * @brief createKernels
121 ** ------------------------------------------------------------------------------------------------------------- */
122void SymbolTableBuilder::createKernels() {
123
124    std::vector<unsigned> endpoints;
125    endpoints.push_back(1);
126    endpoints.push_back(2);
127    endpoints.push_back(4);
128    endpoints.push_back(8);
129    endpoints.push_back(16);
130
131    PabloCompiler pablo_compiler(mMod, iBuilder);
132    PabloFunction * const leading = generateLeadingFunction(endpoints);
133    PabloFunction * const sorting = generateSortingFunction(leading, endpoints);
134
135    mS2PKernel = new KernelBuilder("s2p", mMod, iBuilder);
136    mLeadingKernel = new KernelBuilder("leading", mMod, iBuilder);
137    mSortingKernel = new KernelBuilder("sorting", mMod, iBuilder);
138
139    mLeadingKernel->setLongestLookaheadAmount(mLongestLookahead);
140    mSortingKernel->setLongestLookaheadAmount(mLongestLookahead);
141
142    generateS2PKernel(mMod, iBuilder, mS2PKernel);
143
144    pablo_compiler.setKernel(mLeadingKernel);
145    pablo_compiler.compile(leading);
146    pablo_compiler.setKernel(mSortingKernel);
147    pablo_compiler.compile(sorting);
148
149    delete leading;
150    delete sorting;
151
152    releaseSlabAllocatorMemory();
153}
154
155Function * SymbolTableBuilder::ExecuteKernels(){
156
157    Type * intType = iBuilder->getInt64Ty();
158
159    Type * inputType = PointerType::get(ArrayType::get(StructType::get(mMod->getContext(), std::vector<Type *>({ArrayType::get(mBitBlockType, 8)})), 1), 0);
160    Function * const main = cast<Function>(mMod->getOrInsertFunction("Main", Type::getVoidTy(mMod->getContext()), inputType, intType, nullptr));
161    main->setCallingConv(CallingConv::C);
162    Function::arg_iterator args = main->arg_begin();
163
164    Value * const inputStream = args++;
165    inputStream->setName("input");
166
167    Value * const bufferSize = args++;
168    bufferSize->setName("buffersize");
169
170    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", main,0));
171
172    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
173
174    BasicBlock * leadingTestBlock = BasicBlock::Create(mMod->getContext(), "leadingCond", main, 0);
175    BasicBlock * leadingBodyBlock = BasicBlock::Create(mMod->getContext(), "leadingBody", main, 0);
176
177    BasicBlock * regularTestBlock = BasicBlock::Create(mMod->getContext(), "fullCond", main, 0);
178    BasicBlock * regularBodyBlock = BasicBlock::Create(mMod->getContext(), "fullBody", main, 0);
179    BasicBlock * regularExitBlock = BasicBlock::Create(mMod->getContext(), "fullExit", main, 0);
180
181    BasicBlock * partialBlock = BasicBlock::Create(mMod->getContext(),  "partialBlock", main, 0);
182
183    BasicBlock * finalTestBlock = BasicBlock::Create(mMod->getContext(),  "finalCond", main, 0);
184    BasicBlock * finalBodyBlock = BasicBlock::Create(mMod->getContext(),  "finalBody", main, 0);
185
186    BasicBlock * exitBlock = BasicBlock::Create(mMod->getContext(), "exit", main, 0);
187
188    Instance * s2pInstance = mS2PKernel->instantiate(inputStream);
189    Instance * leadingInstance = mLeadingKernel->instantiate(s2pInstance->getOutputStreamSet());
190    Instance * sortingInstance = mSortingKernel->instantiate(leadingInstance->getOutputStreamSet());
191
192    const unsigned leadingBlocks = (mLongestLookahead + iBuilder->getBitBlockWidth() - 1) / iBuilder->getBitBlockWidth();
193
194    Value * const requiredBytes = iBuilder->getInt64(mBlockSize * leadingBlocks);
195    Value * const blockSize = iBuilder->getInt64(mBlockSize);
196
197    // If the buffer size is smaller than our largest length group, only check up to the buffer size.
198    Value * safetyCheck = iBuilder->CreateICmpUGE(bufferSize, blockSize);
199    if (blockSize == requiredBytes) {
200        iBuilder->CreateCondBr(safetyCheck, leadingTestBlock, exitBlock); // fix this to be a special case
201    } else {
202        throw std::runtime_error("Not supported yet!");
203    }
204
205    // First compute any necessary leading blocks to allow the sorting kernel access to the "future" data produced by
206    // the leading kernel ...
207    iBuilder->SetInsertPoint(leadingTestBlock);
208    PHINode * blockNo = iBuilder->CreatePHI(intType, 2);
209    blockNo->addIncoming(iBuilder->getInt64(0), entryBlock);
210    PHINode * remainingBytes = iBuilder->CreatePHI(intType, 2);
211    remainingBytes->addIncoming(bufferSize, entryBlock);
212    Value * leadingBlocksCond = iBuilder->CreateICmpULT(blockNo, iBuilder->getInt64(leadingBlocks));
213    iBuilder->CreateCondBr(leadingBlocksCond, leadingBodyBlock, regularTestBlock);
214    iBuilder->SetInsertPoint(leadingBodyBlock);
215    s2pInstance->CreateDoBlockCall();
216    leadingInstance->CreateDoBlockCall();
217    blockNo->addIncoming(iBuilder->CreateAdd(blockNo, iBuilder->getInt64(1)), leadingBodyBlock);
218    remainingBytes->addIncoming(iBuilder->CreateSub(remainingBytes, blockSize), leadingBodyBlock);
219    iBuilder->CreateBr(leadingTestBlock);
220
221    // Now all the data for which we can produce and consume a full leading block...
222    iBuilder->SetInsertPoint(regularTestBlock);
223    PHINode * blockNo2 = iBuilder->CreatePHI(intType, 2);
224    blockNo2->addIncoming(blockNo, leadingTestBlock);
225    PHINode * remainingBytes2 = iBuilder->CreatePHI(intType, 2);
226    remainingBytes2->addIncoming(remainingBytes, leadingTestBlock);
227    Value * remainingBytesCond = iBuilder->CreateICmpUGE(remainingBytes2, requiredBytes);
228    iBuilder->CreateCondBr(remainingBytesCond, regularBodyBlock, regularExitBlock);
229    iBuilder->SetInsertPoint(regularBodyBlock);
230    s2pInstance->CreateDoBlockCall();
231    leadingInstance->CreateDoBlockCall();
232    sortingInstance->CreateDoBlockCall();
233    blockNo2->addIncoming(iBuilder->CreateAdd(blockNo2, iBuilder->getInt64(1)), regularBodyBlock);
234    remainingBytes2->addIncoming(iBuilder->CreateSub(remainingBytes2, blockSize), regularBodyBlock);
235    iBuilder->CreateBr(regularTestBlock);
236
237
238    // Check if we have a partial blocks worth of leading data remaining
239    iBuilder->SetInsertPoint(regularExitBlock);
240    Value * partialBlockCond = iBuilder->CreateICmpUGT(remainingBytes2, ConstantInt::getNullValue(intType));
241    iBuilder->CreateCondBr(partialBlockCond, partialBlock, finalTestBlock);
242
243    // If we do, process it and mask out the data
244    iBuilder->SetInsertPoint(partialBlock);
245    s2pInstance->CreateDoBlockCall();
246    Value * partialLeadingData[2];
247    for (unsigned i = 0; i < 2; ++i) {
248        partialLeadingData[i] = leadingInstance->getOutputStream(i);
249    }
250    leadingInstance->CreateDoBlockCall();
251    Type * fullBitBlockType = iBuilder->getIntNTy(mBlockSize);
252    Value * remaining = iBuilder->CreateZExt(iBuilder->CreateSub(blockSize, remainingBytes2), fullBitBlockType);
253    Value * eofMask = iBuilder->CreateLShr(ConstantInt::getAllOnesValue(fullBitBlockType), remaining);
254    eofMask = iBuilder->CreateBitCast(eofMask, mBitBlockType);
255    for (unsigned i = 0; i < 2; ++i) {
256        Value * value = iBuilder->CreateAnd(iBuilder->CreateBlockAlignedLoad(partialLeadingData[i]), eofMask);
257        iBuilder->CreateBlockAlignedStore(value, partialLeadingData[i]);
258    }
259    for (unsigned i = 0; i < 2; ++i) {
260        iBuilder->CreateBlockAlignedStore(ConstantInt::getNullValue(mBitBlockType), leadingInstance->getOutputStream(i));
261    }
262    sortingInstance->CreateDoBlockCall();
263    iBuilder->CreateBr(finalTestBlock);
264
265    // Now clear the leading data and test the final blocks
266    iBuilder->SetInsertPoint(finalTestBlock);
267    PHINode * remainingFullBlocks = iBuilder->CreatePHI(iBuilder->getInt64Ty(), 3);
268    remainingFullBlocks->addIncoming(iBuilder->getInt64(leadingBlocks), regularExitBlock);
269    remainingFullBlocks->addIncoming(iBuilder->getInt64(leadingBlocks), partialBlock);
270    Value * remainingFullBlocksCond = iBuilder->CreateICmpUGT(remainingFullBlocks, ConstantInt::getNullValue(intType));
271    iBuilder->CreateCondBr(remainingFullBlocksCond, finalBodyBlock, exitBlock);
272
273    iBuilder->SetInsertPoint(finalBodyBlock);
274    for (unsigned i = 0; i < 2; ++i) {
275        iBuilder->CreateBlockAlignedStore(ConstantInt::getNullValue(mBitBlockType), leadingInstance->getOutputStream(i));
276    }
277    Value * blockNoPtr = leadingInstance->getBlockNo();
278    Value * blockNoValue = iBuilder->CreateLoad(blockNoPtr);
279    blockNoValue = iBuilder->CreateAdd(blockNoValue, ConstantInt::get(blockNoValue->getType(), 1));
280    iBuilder->CreateStore(blockNoValue, blockNoPtr);
281
282    sortingInstance->CreateDoBlockCall();
283
284    remainingFullBlocks->addIncoming(iBuilder->CreateSub(remainingFullBlocks, iBuilder->getInt64(1)), finalBodyBlock);
285
286    iBuilder->CreateBr(finalTestBlock);
287
288    iBuilder->SetInsertPoint(exitBlock);
289    iBuilder->CreateRetVoid();
290
291    main->dump();
292
293    return main;
294}
295
296SymbolTableBuilder::~SymbolTableBuilder() {
297    delete mS2PKernel;
298    delete mLeadingKernel;
299    delete mSortingKernel;
300}
301
302}
Note: See TracBrowser for help on using the repository browser.