source: icGREP/icgrep-devel/icgrep/kernels/symboltablepipeline.cpp @ 5251

Last change on this file since 5251 was 5246, checked in by nmedfort, 3 years ago

Code cleanup to enforce the proper calling order of KernelBuilder methods

File size: 40.1 KB
Line 
1#include "symboltablepipeline.h"
2
3/*
4 *  Copyright (c) 2016 International Characters.
5 *  This software is licensed to the public under the Open Software License 3.0.
6 */
7
8#include "pipeline.h"
9#include "toolchain.h"
10#include "utf_encoding.h"
11
12#include <kernels/s2p_kernel.h>
13#include <kernels/instance.h>
14
15#include <pablo/pablo_compiler.h>
16#include <pablo/analysis/pabloverifier.hpp>
17
18#include <re/re_cc.h>
19#include <re/re_rep.h>
20#include <re/re_name.h>
21#include <re/re_compiler.h>
22#include <re/printer_re.h>
23
24#include <cc/cc_compiler.h>
25
26#include <pablo/printer_pablos.h>
27#include <iostream>
28
29#include <llvm/IR/Intrinsics.h>
30
31using namespace re;
32using namespace pablo;
33
34namespace kernel {
35
36SymbolTableBuilder::SymbolTableBuilder(Module * m, IDISA::IDISA_Builder * b)
37: mMod(m)
38, iBuilder(b)
39, mLongestLookahead(0)
40, mBitBlockType(b->getBitBlockType())
41, mBlockSize(b->getBitBlockWidth()) {
42
43}
44
45/** ------------------------------------------------------------------------------------------------------------- *
46 * @brief generateLeadingFunction
47 ** ------------------------------------------------------------------------------------------------------------- */
48PabloFunction * SymbolTableBuilder::generateLeadingFunction(const std::vector<unsigned> & endpoints) {
49    PabloFunction * const function = PabloFunction::Create("leading", 8, endpoints.size() + 2);
50    Encoding enc(Encoding::Type::ASCII, 8);
51    cc::CC_Compiler ccCompiler(*function, enc);
52    re::RE_Compiler reCompiler(*function, ccCompiler);
53    RE * cc = makeName(makeCC(makeCC(makeCC('a', 'z'), makeCC('A', 'Z')), makeCC('0', '9')));
54    reCompiler.compileUnicodeNames(cc);
55    PabloAST * const matches = reCompiler.compile(cc).stream;
56    PabloBlock * const entry = function->getEntryBlock();
57    PabloAST * const adv = entry->createAdvance(matches, 1);
58    PabloAST * const starts = entry->createAnd(matches, entry->createNot(adv));
59    PabloAST * const ends = entry->createAnd(adv, entry->createNot(matches));
60
61    function->setResult(0, entry->createAssign("l.S", starts));
62    function->setResult(1, entry->createAssign("l.E", ends));
63
64    PabloAST * M = ends;
65    unsigned step = 1;
66    unsigned i = 0;
67    for (unsigned endpoint : endpoints) {
68        assert (endpoint >= step);
69        unsigned span = endpoint - step;
70        while (span > step) {
71            M = entry->createOr(entry->createAdvance(M, step), M);
72            span = span - step;
73            step *= 2;
74        }
75        M = entry->createOr(entry->createAdvance(M, span), M);
76        function->setResult(i + 2, entry->createAssign("l.M" + std::to_string(i), M));
77        ++i;
78        step += span;
79    }
80
81    return function;
82}
83
84/** ------------------------------------------------------------------------------------------------------------- *
85 * @brief generateSortingFunction
86 ** ------------------------------------------------------------------------------------------------------------- */
87PabloFunction * SymbolTableBuilder::generateSortingFunction(const PabloFunction * const leading, const std::vector<unsigned> & endpoints) {
88    PabloFunction * const function = PabloFunction::Create("sorting", leading->getNumOfResults(), (leading->getNumOfResults() - 1) * 2);
89    PabloBlock * entry = function->getEntryBlock();
90    function->setParameter(0, entry->createVar("l.S"));
91    function->setParameter(1, entry->createVar("l.E"));
92    for (unsigned i = 2; i < leading->getNumOfResults(); ++i) {
93        function->setParameter(i, entry->createVar("l.M" + std::to_string(i - 2)));
94    }
95    PabloAST * R = function->getParameter(0);
96    PabloAST * const E = entry->createNot(function->getParameter(1));
97    unsigned i = 0;
98    unsigned lowerbound = 0;
99    for (unsigned endpoint : endpoints) {
100        PabloAST * const M = function->getParameter(i + 2);
101        PabloAST * const L = entry->createLookahead(M, endpoint, "lookahead" + std::to_string(endpoint));
102        PabloAST * S = entry->createAnd(L, R);
103        Assign * Si = entry->createAssign("s.S_" + std::to_string(i + 1), S);
104        PabloAST * F = entry->createScanThru(S, E);
105        Assign * Ei = entry->createAssign("s.E_" + std::to_string(i + 1), F);
106        function->setResult(i * 2, Si);
107        function->setResult(i * 2 + 1, Ei);
108        R = entry->createXor(R, S);
109        ++i;
110        lowerbound = endpoint;
111    }
112    Assign * Si = entry->createAssign("s.S_n", R);
113    PabloAST * F = entry->createScanThru(R, E);
114    Assign * Ei = entry->createAssign("s.E_n", F);
115    function->setResult(i * 2, Si);
116    function->setResult(i * 2 + 1, Ei);
117    mLongestLookahead = lowerbound;
118
119    return function;
120}
121
122/** ------------------------------------------------------------------------------------------------------------- *
123 * @brief generateCountForwardZeroes
124 ** ------------------------------------------------------------------------------------------------------------- */
125inline Value * generateCountForwardZeroes(IDISA::IDISA_Builder * iBuilder, Value * bits) {
126    Value * cttzFunc = Intrinsic::getDeclaration(iBuilder->getModule(), Intrinsic::cttz, bits->getType());
127    return iBuilder->CreateCall(cttzFunc, std::vector<Value *>({bits, ConstantInt::get(iBuilder->getInt1Ty(), 0)}));
128}
129
130/** ------------------------------------------------------------------------------------------------------------- *
131 * @brief generateMaskedGather
132 ** ------------------------------------------------------------------------------------------------------------- */
133inline Value * SymbolTableBuilder::generateMaskedGather(Value * const base, Value * const vindex, Value * const mask) {
134
135    /*
136        From Intel:
137
138        extern __m256i _mm256_mask_i32gather_epi32(__m256i def_vals, int const * base, __m256i vindex, __m256i vmask, const int scale);
139
140        From Clang avx2intrin.h:
141
142        #define _mm256_mask_i32gather_epi32(a, m, i, mask, s) __extension__ ({ \
143           (__m256i)__builtin_ia32_gatherd_d256((__v8si)(__m256i)(a), \
144                                                (int const *)(m), \
145                                                (__v8si)(__m256i)(i), \
146                                                (__v8si)(__m256i)(mask), (s)); })
147        From llvm IntrinsicsX86.td:
148
149        def llvm_ptr_ty        : LLVMPointerType<llvm_i8_ty>;             // i8*
150
151        def int_x86_avx2_gather_d_d_256 : GCCBuiltin<"__builtin_ia32_gatherd_d256">,
152           Intrinsic<[llvm_v8i32_ty],
153           [llvm_v8i32_ty, llvm_ptr_ty, llvm_v8i32_ty, llvm_v8i32_ty, llvm_i8_ty],
154           [IntrReadArgMem]>;
155
156     */
157
158    VectorType * const vecType = VectorType::get(iBuilder->getInt32Ty(), 8);
159    Function * const vgather = Intrinsic::getDeclaration(iBuilder->getModule(), Intrinsic::x86_avx2_gather_d_d_256);
160    return iBuilder->CreateCall(vgather, {Constant::getNullValue(vecType), base, iBuilder->CreateBitCast(vindex, vecType), iBuilder->CreateBitCast(mask, vecType), iBuilder->getInt8(1)});
161}
162
163/** ------------------------------------------------------------------------------------------------------------- *
164 * @brief generateResetLowestBit
165 ** ------------------------------------------------------------------------------------------------------------- */
166inline Value * generateResetLowestBit(IDISA::IDISA_Builder * iBuilder, Value * bits) {
167    Value * bits_minus1 = iBuilder->CreateSub(bits, ConstantInt::get(bits->getType(), 1));
168    return iBuilder->CreateAnd(bits_minus1, bits);
169}
170
171/** ------------------------------------------------------------------------------------------------------------- *
172 * @brief generateGatherKernel
173 ** ------------------------------------------------------------------------------------------------------------- */
174void SymbolTableBuilder::generateGatherKernel(KernelBuilder * kBuilder, const std::vector<unsigned> & endpoints, const unsigned scanWordBitWidth) {
175
176    Type * const intScanWordTy = iBuilder->getIntNTy(scanWordBitWidth);
177    const unsigned fieldCount = iBuilder->getBitBlockWidth() / scanWordBitWidth;
178    Type * const scanWordVectorType = VectorType::get(intScanWordTy, fieldCount);
179    const unsigned vectorWidth = iBuilder->getBitBlockWidth() / 32;
180    const unsigned gatherCount = vectorWidth * 4;
181
182    Type * startArrayType = ArrayType::get(iBuilder->getInt32Ty(), iBuilder->getBitBlockWidth() + gatherCount);
183    Type * endArrayType = ArrayType::get(iBuilder->getInt32Ty(), gatherCount);
184    Type * groupType = StructType::get(iBuilder->getInt32Ty(), startArrayType, iBuilder->getInt32Ty(), endArrayType, nullptr);
185    const unsigned baseIdx = kBuilder->addInternalState(iBuilder->getInt8PtrTy(), "Base");
186    const unsigned gatherPositionArrayIdx = kBuilder->addInternalState(ArrayType::get(groupType, endpoints.size()), "Positions");
187
188    for (unsigned maxKeyLength : endpoints) {
189        kBuilder->addInputStream(1, "startStream" + std::to_string(maxKeyLength));
190        kBuilder->addInputStream(1, "endStream" + std::to_string(maxKeyLength));
191        kBuilder->addOutputStream(4); // ((maxKeyLength + 3) / 4) * 4
192    }
193    kBuilder->addInputStream(1, "startStreamN");
194    kBuilder->addInputStream(1, "endStreamN");
195
196    Function * const function = kBuilder->prepareFunction();
197
198    BasicBlock * const entry = iBuilder->GetInsertBlock();
199
200    BasicBlock * groupCond = BasicBlock::Create(mMod->getContext(), "groupCond", function, 0);
201    BasicBlock * groupBody = BasicBlock::Create(mMod->getContext(), "groupBody", function, 0);
202
203    BasicBlock * startOuterCond = BasicBlock::Create(mMod->getContext(), "startOuterCond", function, 0);
204    BasicBlock * startOuterBody = BasicBlock::Create(mMod->getContext(), "startOuterBody", function, 0);
205    BasicBlock * startInnerCond = BasicBlock::Create(mMod->getContext(), "startInnerCond", function, 0);
206    BasicBlock * startInnerBody = BasicBlock::Create(mMod->getContext(), "startInnerBody", function, 0);
207
208    BasicBlock * endOuterCond = BasicBlock::Create(mMod->getContext(), "endOuterCond", function, 0);
209    BasicBlock * endOuterBody = BasicBlock::Create(mMod->getContext(), "endOuterBody", function, 0);
210    BasicBlock * endInnerCond = BasicBlock::Create(mMod->getContext(), "endInnerCond", function, 0);
211    BasicBlock * endInnerBody = BasicBlock::Create(mMod->getContext(), "endInnerBody", function, 0);
212
213    BasicBlock * gather = BasicBlock::Create(mMod->getContext(), "gather", function, 0);
214
215    BasicBlock * nextGroup = BasicBlock::Create(mMod->getContext(), "nextGroup", function, 0);
216
217    BasicBlock * exit = BasicBlock::Create(mMod->getContext(), "exit", function, 0);
218
219
220    // ENTRY BLOCK
221    iBuilder->SetInsertPoint(entry);
222    Type * const int32PtrTy = PointerType::get(iBuilder->getInt32Ty(), 0);
223    FunctionType * const gatherFunctionType = FunctionType::get(iBuilder->getVoidTy(), {iBuilder->getInt8PtrTy(), int32PtrTy, int32PtrTy, iBuilder->getInt32Ty(), iBuilder->getInt8PtrTy()}, false);
224    Value * const gatherFunctionPtrArray = iBuilder->CreateAlloca(PointerType::get(gatherFunctionType, 0), iBuilder->getInt32(endpoints.size()), "gatherFunctionPtrArray");
225
226    unsigned i = 0;
227    unsigned minKeyLength = 0;
228    for (unsigned maxKeyLength : endpoints) {
229        Function * f = generateGatherFunction(minKeyLength, maxKeyLength);
230        mGatherFunction.push_back(f);
231        iBuilder->CreateStore(f, iBuilder->CreateGEP(gatherFunctionPtrArray, iBuilder->getInt32(i++)));
232        minKeyLength = maxKeyLength;
233    }
234
235    //TODO: this won't work on files > 2^32 bytes yet; needs an intermediate flush then a recalculation of the base pointer.
236    Value * const base = iBuilder->CreateLoad(kBuilder->getInternalState(baseIdx), "base");
237    Value * const positionArray = kBuilder->getInternalState(gatherPositionArrayIdx);
238
239    Value * blockPos = iBuilder->CreateLoad(kBuilder->getBlockNo());
240    blockPos = iBuilder->CreateMul(blockPos, iBuilder->getSize(iBuilder->getBitBlockWidth()));
241
242    iBuilder->CreateBr(groupCond);
243
244    // GROUP COND
245    iBuilder->SetInsertPoint(groupCond);
246    PHINode * groupIV = iBuilder->CreatePHI(iBuilder->getInt32Ty(), 2);
247    groupIV->addIncoming(iBuilder->getInt32(0), entry);
248    Value * groupTest = iBuilder->CreateICmpNE(groupIV, iBuilder->getInt32(endpoints.size()));
249    iBuilder->CreateCondBr(groupTest, groupBody, exit);
250
251    // GROUP BODY
252    iBuilder->SetInsertPoint(groupBody);
253    // if two positions cannot be in the same vector element, we could possibly do some work in parallel here.
254
255    Value * index = iBuilder->CreateMul(groupIV, iBuilder->getInt32(2));
256    Value * startStreamPtr = kBuilder->getInputStream(index);
257    Value * startStream = iBuilder->CreateBlockAlignedLoad(startStreamPtr);
258    startStream = iBuilder->CreateBitCast(startStream, scanWordVectorType, "startStream");
259
260    index = iBuilder->CreateAdd(index, iBuilder->getInt32(1));
261    Value * endStreamPtr = kBuilder->getInputStream(index);
262    Value * endStream = iBuilder->CreateBlockAlignedLoad(endStreamPtr);
263    endStream = iBuilder->CreateBitCast(endStream, scanWordVectorType, "endStream");
264
265    Value * startIndexPtr = iBuilder->CreateGEP(positionArray, {iBuilder->getInt32(0), groupIV, iBuilder->getInt32(0)});
266    Value * startIndex = iBuilder->CreateLoad(startIndexPtr, "startIndex");
267    Value * startPosArray = iBuilder->CreateGEP(positionArray, {iBuilder->getInt32(0), groupIV, iBuilder->getInt32(1)}, "startPosArray");
268    Value * endIndexPtr = iBuilder->CreateGEP(positionArray, {iBuilder->getInt32(0), groupIV, iBuilder->getInt32(2)}, "endIndexPtr");
269    Value * endIndex = iBuilder->CreateLoad(endIndexPtr, "endIndex");
270    Value * endPosArray = iBuilder->CreateGEP(positionArray, {iBuilder->getInt32(0), groupIV, iBuilder->getInt32(3)}, "endPosArray");
271
272    iBuilder->CreateBr(startOuterCond);
273
274    // START OUTER COND
275    iBuilder->SetInsertPoint(startOuterCond);
276    PHINode * startBlockOffset = iBuilder->CreatePHI(iBuilder->getSizeTy(), 2);
277    startBlockOffset->addIncoming(blockPos, groupBody);
278    PHINode * startIndexPhi1 = iBuilder->CreatePHI(startIndex->getType(), 2, "startIndexPhi1");
279    startIndexPhi1->addIncoming(startIndex, groupBody);
280    PHINode * startIV = iBuilder->CreatePHI(iBuilder->getSizeTy(), 2);
281    startIV->addIncoming(iBuilder->getSize(0), groupBody);
282    Value * startOuterTest = iBuilder->CreateICmpNE(startIV, iBuilder->getSize(fieldCount));
283    iBuilder->CreateCondBr(startOuterTest, startOuterBody, endOuterCond);
284
285    // START OUTER BODY
286    iBuilder->SetInsertPoint(startOuterBody);
287    Value * startField = iBuilder->CreateExtractElement(startStream, startIV);
288    startIV->addIncoming(iBuilder->CreateAdd(startIV, iBuilder->getSize(1)), startInnerCond);
289    startBlockOffset->addIncoming(iBuilder->CreateAdd(startBlockOffset, iBuilder->getSize(scanWordBitWidth)), startInnerCond);
290    iBuilder->CreateBr(startInnerCond);
291
292    // START INNER COND
293    iBuilder->SetInsertPoint(startInnerCond);
294    PHINode * startIndexPhi2 = iBuilder->CreatePHI(startIndex->getType(), 2, "startIndexPhi2");
295    startIndexPhi2->addIncoming(startIndexPhi1, startOuterBody);
296    startIndexPhi1->addIncoming(startIndexPhi2, startInnerCond);
297    PHINode * startFieldPhi = iBuilder->CreatePHI(intScanWordTy, 2);
298    startFieldPhi->addIncoming(startField, startOuterBody);
299    Value * test = iBuilder->CreateICmpNE(startFieldPhi, ConstantInt::getNullValue(intScanWordTy));
300    iBuilder->CreateCondBr(test, startInnerBody, startOuterCond);
301
302    // START INNER BODY
303    iBuilder->SetInsertPoint(startInnerBody);
304    Value * startPos = generateCountForwardZeroes(iBuilder, startFieldPhi);
305    startFieldPhi->addIncoming(generateResetLowestBit(iBuilder, startFieldPhi), startInnerBody);
306    startPos = iBuilder->CreateTruncOrBitCast(iBuilder->CreateOr(startPos, startBlockOffset), iBuilder->getInt32Ty());
307    iBuilder->CreateStore(startPos, iBuilder->CreateGEP(startPosArray, {iBuilder->getInt32(0), startIndexPhi2}));
308    startIndexPhi2->addIncoming(iBuilder->CreateAdd(startIndexPhi2, ConstantInt::get(startIndexPhi2->getType(), 1)), startInnerBody);
309    iBuilder->CreateBr(startInnerCond);
310
311    // END POINT OUTER COND
312    iBuilder->SetInsertPoint(endOuterCond);
313    PHINode * endBlockOffset = iBuilder->CreatePHI(iBuilder->getSizeTy(), 2);
314    endBlockOffset->addIncoming(blockPos, startOuterCond);
315    PHINode * endIndexPhi1 = iBuilder->CreatePHI(endIndex->getType(), 2);
316    endIndexPhi1->addIncoming(endIndex, startOuterCond);
317    PHINode * startIndexPhi3 = iBuilder->CreatePHI(startIndex->getType(), 2, "startIndexPhi3");
318    startIndexPhi3->addIncoming(startIndexPhi1, startOuterCond);
319    PHINode * endIV = iBuilder->CreatePHI(iBuilder->getSizeTy(), 2);
320    endIV->addIncoming(iBuilder->getSize(0), startOuterCond);
321    Value * endOuterTest = iBuilder->CreateICmpNE(endIV, iBuilder->getSize(fieldCount));
322    iBuilder->CreateCondBr(endOuterTest, endOuterBody, nextGroup);
323
324    // END POINT OUTER BODY
325    iBuilder->SetInsertPoint(endOuterBody);
326    Value * endField = iBuilder->CreateExtractElement(endStream, endIV);
327    endIV->addIncoming(iBuilder->CreateAdd(endIV, iBuilder->getSize(1)), endInnerCond);
328    endBlockOffset->addIncoming(iBuilder->CreateAdd(endBlockOffset, iBuilder->getSize(scanWordBitWidth)), endInnerCond);
329    iBuilder->CreateBr(endInnerCond);
330
331    // END POINT INNER COND
332    iBuilder->SetInsertPoint(endInnerCond);
333    PHINode * startIndexPhi4 = iBuilder->CreatePHI(startIndexPhi3->getType(), 3, "startIndexPhi4");
334    startIndexPhi4->addIncoming(startIndexPhi3, endOuterBody);
335    startIndexPhi4->addIncoming(startIndexPhi4, endInnerBody);
336    startIndexPhi3->addIncoming(startIndexPhi4, endInnerCond);
337    PHINode * endIndexPhi2 = iBuilder->CreatePHI(endIndex->getType(), 3);
338    endIndexPhi2->addIncoming(endIndexPhi1, endOuterBody);
339    endIndexPhi1->addIncoming(endIndexPhi2, endInnerCond);
340    endIndexPhi2->addIncoming(ConstantInt::getNullValue(endIndex->getType()), gather);
341    PHINode * endFieldPhi = iBuilder->CreatePHI(intScanWordTy, 3);
342    endFieldPhi->addIncoming(endField, endOuterBody);
343    Value * endInnerTest = iBuilder->CreateICmpNE(endFieldPhi, ConstantInt::getNullValue(intScanWordTy));
344    iBuilder->CreateCondBr(endInnerTest, endInnerBody, endOuterCond);
345
346    // END POINT INNER BODY
347    iBuilder->SetInsertPoint(endInnerBody);
348    Value * endPos = generateCountForwardZeroes(iBuilder, endFieldPhi);
349    Value * updatedEndFieldPhi = generateResetLowestBit(iBuilder, endFieldPhi);
350    endFieldPhi->addIncoming(updatedEndFieldPhi, endInnerBody);
351    endFieldPhi->addIncoming(updatedEndFieldPhi, gather);
352    endPos = iBuilder->CreateTruncOrBitCast(iBuilder->CreateOr(endPos, endBlockOffset), iBuilder->getInt32Ty());
353    iBuilder->CreateStore(endPos, iBuilder->CreateGEP(endPosArray, {iBuilder->getInt32(0), endIndexPhi2}));
354    Value * updatedEndIndexPhi = iBuilder->CreateAdd(endIndexPhi2, ConstantInt::get(endIndexPhi2->getType(), 1));
355    endIndexPhi2->addIncoming(updatedEndIndexPhi, endInnerBody);
356    Value * filledEndPosBufferTest = iBuilder->CreateICmpEQ(updatedEndIndexPhi, ConstantInt::get(updatedEndIndexPhi->getType(), gatherCount));
357    iBuilder->CreateCondBr(filledEndPosBufferTest, gather, endInnerCond);
358
359    // GATHER
360    iBuilder->SetInsertPoint(gather);
361
362    Value * startArrayPtr = iBuilder->CreatePointerCast(startPosArray, PointerType::get(iBuilder->getInt32Ty(), 0));
363    Value * endArrayPtr = iBuilder->CreatePointerCast(endPosArray, PointerType::get(iBuilder->getInt32Ty(), 0));
364    Value * gatherFunctionPtr = iBuilder->CreateLoad(iBuilder->CreateGEP(gatherFunctionPtrArray, groupIV));
365    Value * outputBuffer = iBuilder->CreatePointerCast(kBuilder->getOutputStream(groupIV), iBuilder->getInt8PtrTy());
366    iBuilder->CreateCall(gatherFunctionPtr, {base, startArrayPtr, endArrayPtr, iBuilder->getInt32(32), outputBuffer});
367    // Copy the unused start positions to the front of the start position array and adjust the start index
368    Value * remainingArrayPtr = iBuilder->CreateGEP(startArrayPtr, iBuilder->getInt32(gatherCount));
369    Value * remainingCount = iBuilder->CreateSub(startIndexPhi4, iBuilder->getInt32(gatherCount));
370    Value * remainingBytes = iBuilder->CreateMul(remainingCount, iBuilder->getInt32(4));
371    iBuilder->CreateMemMove(startArrayPtr, remainingArrayPtr, remainingBytes, 4);
372    startIndexPhi4->addIncoming(remainingCount, gather);
373    iBuilder->CreateBr(endInnerCond);
374
375    // NEXT GROUP
376    iBuilder->SetInsertPoint(nextGroup);
377    iBuilder->CreateStore(startIndexPhi3, startIndexPtr);
378    iBuilder->CreateStore(endIndexPhi1, endIndexPtr);
379    groupIV->addIncoming(iBuilder->CreateAdd(groupIV, ConstantInt::get(groupIV->getType(), 1)), nextGroup);
380    iBuilder->CreateBr(groupCond);
381
382    iBuilder->SetInsertPoint(exit);
383    kBuilder->finalize();
384}
385
386/** ------------------------------------------------------------------------------------------------------------- *
387 * @brief generateGatherFunction
388 ** ------------------------------------------------------------------------------------------------------------- */
389Function * SymbolTableBuilder::generateGatherFunction(const unsigned minKeyLength, const unsigned maxKeyLength) {
390
391    assert (minKeyLength < maxKeyLength);
392
393    const std::string functionName = "gather_" + std::to_string(minKeyLength) + "_to_" + std::to_string(maxKeyLength);
394    Function * function = mMod->getFunction(functionName);
395    if (function == nullptr) {
396
397        const auto ip = iBuilder->saveIP();
398
399        const unsigned minCount = (minKeyLength / 4);
400        const unsigned maxCount = ((maxKeyLength + 3) / 4);
401
402        const unsigned vectorWidth = iBuilder->getBitBlockWidth() / 32;
403        Type * const gatherVectorType =  VectorType::get(iBuilder->getInt32Ty(), vectorWidth);
404        const unsigned gatherByteWidth = gatherVectorType->getPrimitiveSizeInBits() / 8;
405        Type * const transposedVectorType = VectorType::get(iBuilder->getInt8Ty(), iBuilder->getBitBlockWidth() / 8);
406        const unsigned transposedByteWidth = transposedVectorType->getPrimitiveSizeInBits() / 8;
407
408
409        Type * const int32PtrTy = PointerType::get(iBuilder->getInt32Ty(), 0);
410        FunctionType * const functionType = FunctionType::get(iBuilder->getVoidTy(), {iBuilder->getInt8PtrTy(), int32PtrTy, int32PtrTy, iBuilder->getInt32Ty(), iBuilder->getInt8PtrTy()}, false);
411        function = Function::Create(functionType, GlobalValue::ExternalLinkage, functionName, mMod);
412        function->setCallingConv(CallingConv::C);
413        function->setDoesNotCapture(1);
414        function->setDoesNotCapture(2);
415        function->setDoesNotCapture(3);
416        function->setDoesNotThrow();
417
418        Function::arg_iterator args = function->arg_begin();
419        Value * const base = &*(args++);
420        base->setName("base");
421        Value * startArray = &*(args++);
422        startArray->setName("startArray");
423        Value * endArray = &*(args++);
424        endArray->setName("endArray");
425        Value * const numOfKeys = &*(args++);
426        numOfKeys->setName("numOfKeys");
427        Value * result = &*(args++);
428        result->setName("result");
429
430        BasicBlock * entry = BasicBlock::Create(mMod->getContext(), "entry", function, 0);
431        BasicBlock * gatherCond = BasicBlock::Create(mMod->getContext(), "gatherCond", function, 0);
432        BasicBlock * partialGatherCond = BasicBlock::Create(mMod->getContext(), "partialGatherCond", function, 0);
433        BasicBlock * partialGatherBody = BasicBlock::Create(mMod->getContext(), "partialGatherBody", function, 0);
434        BasicBlock * gatherBody = BasicBlock::Create(mMod->getContext(), "gatherBody", function, 0);
435        BasicBlock * transposeCond = BasicBlock::Create(mMod->getContext(), "transposeCond", function, 0);
436        BasicBlock * transposeBody = BasicBlock::Create(mMod->getContext(), "transposeBody", function, 0);
437        BasicBlock * exit = BasicBlock::Create(mMod->getContext(), "exit", function, 0);
438
439        Value * const four = iBuilder->CreateVectorSplat(vectorWidth, iBuilder->getInt32(4));
440
441        // ENTRY
442        iBuilder->SetInsertPoint(entry);
443
444        AllocaInst * const buffer = iBuilder->CreateAlloca(gatherVectorType, iBuilder->getInt32(maxCount * 4), "buffer");
445        Value * end = iBuilder->CreateGEP(buffer, iBuilder->getInt32(maxCount * 4));
446        Value * size = iBuilder->CreateSub(iBuilder->CreatePtrToInt(end, iBuilder->getSizeTy()), iBuilder->CreatePtrToInt(buffer, iBuilder->getSizeTy()));
447        iBuilder->CreateMemSet(buffer, iBuilder->getInt8(0), size, 4);
448        Value * const transposed = iBuilder->CreateBitCast(buffer, transposedVectorType->getPointerTo(), "transposed");
449
450        startArray = iBuilder->CreateBitCast(startArray, gatherVectorType->getPointerTo());
451        endArray = iBuilder->CreateBitCast(endArray, gatherVectorType->getPointerTo());
452
453        iBuilder->CallPrintInt(functionName + ".numOfKeys", numOfKeys);
454
455        iBuilder->CreateBr(gatherCond);
456
457        // FULL GATHER COND
458        iBuilder->SetInsertPoint(gatherCond);
459        PHINode * remainingLanes = iBuilder->CreatePHI(iBuilder->getInt32Ty(), 2);
460        remainingLanes->addIncoming(numOfKeys, entry);
461
462        PHINode * gatherIV = iBuilder->CreatePHI(iBuilder->getInt32Ty(), 2);
463        gatherIV->addIncoming(iBuilder->getInt32(0), entry);
464
465        Value * gatherLoopTest = iBuilder->CreateICmpSGE(remainingLanes, iBuilder->getInt32(vectorWidth));
466        iBuilder->CreateCondBr(gatherLoopTest, gatherBody, partialGatherCond);
467
468        // PARTIAL GATHER COND
469        iBuilder->SetInsertPoint(partialGatherCond);
470        Value * partialGatherLoopTest = iBuilder->CreateICmpSLE(remainingLanes, iBuilder->getInt32(0));
471        iBuilder->CreateCondBr(partialGatherLoopTest, transposeCond, partialGatherBody);
472
473        // PARTIAL GATHER BODY
474        iBuilder->SetInsertPoint(partialGatherBody);
475        Type * registerType = iBuilder->getIntNTy(iBuilder->getBitBlockWidth());
476        Value * maskedLanes = iBuilder->CreateSub(iBuilder->getInt32(vectorWidth), remainingLanes);       
477        maskedLanes = iBuilder->CreateMul(maskedLanes, iBuilder->getInt32(32));
478        maskedLanes = iBuilder->CreateZExt(maskedLanes, registerType);
479        maskedLanes = iBuilder->CreateLShr(Constant::getAllOnesValue(registerType), maskedLanes);
480        maskedLanes = iBuilder->CreateBitCast(maskedLanes, gatherVectorType);
481        iBuilder->CreateBr(gatherBody);
482
483        // FULL GATHER BODY
484        iBuilder->SetInsertPoint(gatherBody);
485        PHINode * activeLanes = iBuilder->CreatePHI(gatherVectorType, 2, "activeLanes");
486        activeLanes->addIncoming(Constant::getAllOnesValue(gatherVectorType), gatherCond);
487        activeLanes->addIncoming(maskedLanes, partialGatherBody);
488
489
490        Value * startPos = iBuilder->CreateAlignedLoad(iBuilder->CreateGEP(startArray, gatherIV), 4);
491        Value * const endPos = iBuilder->CreateAlignedLoad(iBuilder->CreateGEP(endArray, gatherIV), 4);
492
493        for (unsigned blockCount = 0; blockCount < minCount; ++blockCount) {
494            Value * tokenData = generateMaskedGather(base, startPos, activeLanes);
495            Value * ptr = iBuilder->CreateGEP(buffer, iBuilder->CreateOr(gatherIV, iBuilder->getInt32(blockCount * 4)));
496            iBuilder->CreateAlignedStore(tokenData, ptr, transposedByteWidth);
497            startPos = iBuilder->CreateAdd(startPos, four);
498        }
499
500        for (unsigned blockCount = minCount; blockCount < maxCount; ++blockCount) {
501
502            // if we have not fully gathered the data for this key
503            Value * atLeastOneByte = iBuilder->CreateSExt(iBuilder->CreateICmpSLT(startPos, endPos), startPos->getType());
504            atLeastOneByte = iBuilder->CreateAnd(atLeastOneByte, activeLanes, "atLeastOneByte");
505
506            // gather it ...
507            Value * tokenData = generateMaskedGather(base, startPos, atLeastOneByte);
508
509            // and compute how much data is remaining.
510            Value * remaining = iBuilder->CreateSub(endPos, startPos);
511
512            // if this token has at least 4 bytes remaining ...
513            Value * atLeastFourBytes = iBuilder->CreateSExt(iBuilder->CreateICmpUGE(remaining, four), remaining->getType(), "atLeastFourBytes");
514
515            // determine how many bits do *not* belong to the token
516            remaining = iBuilder->CreateSub(four, remaining);
517            remaining = iBuilder->CreateShl(remaining, ConstantInt::get(remaining->getType(), 3));
518
519            // then mask them out prior to storing the value
520            Value * partialTokenMask = iBuilder->CreateLShr(ConstantInt::getAllOnesValue(remaining->getType()), remaining);
521            partialTokenMask = iBuilder->CreateOr(partialTokenMask, atLeastFourBytes);
522            tokenData = iBuilder->CreateAnd(partialTokenMask, tokenData);
523            Value * ptr = iBuilder->CreateGEP(buffer, iBuilder->CreateOr(gatherIV, iBuilder->getInt32(blockCount * 4)));
524            iBuilder->CreateAlignedStore(tokenData, ptr, transposedByteWidth);
525
526            startPos = iBuilder->CreateAdd(startPos, four);
527        }
528
529        gatherIV->addIncoming(iBuilder->CreateAdd(gatherIV, iBuilder->getInt32(1)), gatherBody);
530        remainingLanes->addIncoming(iBuilder->CreateSub(remainingLanes, iBuilder->getInt32(vectorWidth)), gatherBody);
531        iBuilder->CreateBr(gatherCond);
532
533        // TRANSPOSE COND
534        iBuilder->SetInsertPoint(transposeCond);
535        PHINode * transposeIV = iBuilder->CreatePHI(iBuilder->getInt32Ty(), 2);
536        transposeIV->addIncoming(iBuilder->getInt32(0), partialGatherCond);
537        Value * transposeLoopTest = iBuilder->CreateICmpNE(transposeIV, iBuilder->getInt32(maxCount));
538        iBuilder->CreateCondBr(transposeLoopTest, transposeBody, exit);
539
540        // TRANSPOSE BODY
541        iBuilder->SetInsertPoint(transposeBody);
542
543        Value * offset = iBuilder->CreateMul(transposeIV, iBuilder->getInt32(4));
544
545        Value * value[4];
546        for (unsigned i = 0; i < 4; ++i) {
547            Value * const ptr = iBuilder->CreateGEP(buffer, iBuilder->CreateAdd(offset, iBuilder->getInt32(i)));
548            value[i] = iBuilder->CreateLoad(ptr);
549        }
550
551        for (unsigned byteWidth = 2; byteWidth; --byteWidth) {
552            const unsigned fieldWidth = (byteWidth * 8);
553            const unsigned fieldCount = iBuilder->getBitBlockWidth() / fieldWidth;
554            VectorType * const type = VectorType::get(Type::getIntNTy(iBuilder->getContext(), fieldWidth), fieldCount);
555            std::vector<Constant *> even(fieldCount);
556            std::vector<Constant *> odd(fieldCount);
557            for (unsigned j = 0; j < fieldCount; ++j) {
558                even[j] = iBuilder->getInt32(j * 2);
559                odd[j] = iBuilder->getInt32(j * 2 + 1);
560            }
561            Constant * const evenVector = ConstantVector::get(even);
562            Constant * const oddVector = ConstantVector::get(odd);
563            Value * result[4];
564            for (unsigned i = 0; i < 4; i += 2) {
565                value[i] = iBuilder->CreateBitCast(value[i], type);
566                value[i + 1] = iBuilder->CreateBitCast(value[i + 1], type);
567                result[(i / byteWidth)] = iBuilder->CreateShuffleVector(value[i], value[i + 1], evenVector);
568                result[(i / byteWidth) + byteWidth] = iBuilder->CreateShuffleVector(value[i], value[i + 1], oddVector);
569            }
570            for (unsigned i = 0; i < 4; ++i) {
571                value[i] = result[i];
572            }
573        }
574
575        for (unsigned i = 0; i < 4; ++i) {
576            Value * ptr = iBuilder->CreateGEP(transposed, iBuilder->CreateAdd(offset, iBuilder->getInt32(i)));
577            iBuilder->CreateAlignedStore(value[i], ptr, gatherByteWidth);
578        }
579
580        transposeIV->addIncoming(iBuilder->CreateAdd(transposeIV, iBuilder->getInt32(1)), transposeBody);
581        iBuilder->CreateBr(transposeCond);
582
583        // EXIT
584        iBuilder->SetInsertPoint(exit);
585
586        // ... call hashing function ...
587
588        for (unsigned i = 0; i < maxKeyLength; ++i) {
589            Value * ptr = iBuilder->CreateGEP(transposed, iBuilder->getInt32(i));
590            Value * value = iBuilder->CreateAlignedLoad(ptr, gatherByteWidth);
591            iBuilder->CallPrintRegister(functionName + ".output" + std::to_string(i), value);
592        }
593
594        iBuilder->CreateRetVoid();
595
596        iBuilder->restoreIP(ip);
597    }
598
599    return function;
600}
601
602
603/** ------------------------------------------------------------------------------------------------------------- *
604 * @brief createKernels
605 ** ------------------------------------------------------------------------------------------------------------- */
606void SymbolTableBuilder::createKernels() {
607
608    std::vector<unsigned> endpoints;
609    endpoints.push_back(8);
610    endpoints.push_back(17);
611    endpoints.push_back(27);
612    endpoints.push_back(39);
613    endpoints.push_back(77);
614    endpoints.push_back(124);
615    endpoints.push_back(178);
616    endpoints.push_back(278);
617
618    PabloCompiler pablo_compiler(mMod, iBuilder);
619    PabloFunction * const leading = generateLeadingFunction(endpoints);
620    PabloFunction * const sorting = generateSortingFunction(leading, endpoints);
621
622    const auto bufferSize = ((mLongestLookahead + iBuilder->getBitBlockWidth() - 1) / iBuilder->getBitBlockWidth()) + 1;
623
624    mS2PKernel = new KernelBuilder(iBuilder, "s2p", 1);
625    mLeadingKernel = new KernelBuilder(iBuilder, "leading", bufferSize);
626    mSortingKernel = new KernelBuilder(iBuilder, "sorting", bufferSize);
627    mGatherKernel = new KernelBuilder(iBuilder, "gathering", 1);
628
629    generateS2PKernel(mMod, iBuilder, mS2PKernel);
630
631    pablo_compiler.setKernel(mLeadingKernel);
632    pablo_compiler.compile(leading);
633    pablo_compiler.setKernel(mSortingKernel);
634    pablo_compiler.compile(sorting);
635
636    delete leading;
637    delete sorting;
638
639    generateGatherKernel(mGatherKernel, endpoints, 64);
640}
641
642Function * SymbolTableBuilder::ExecuteKernels(){
643
644    Type * intType = iBuilder->getSizeTy();
645
646    Type * inputType = PointerType::get(ArrayType::get(StructType::get(mMod->getContext(), std::vector<Type *>({ArrayType::get(mBitBlockType, 8)})), 1), 0);
647    Function * const main = cast<Function>(mMod->getOrInsertFunction("Main", iBuilder->getVoidTy(), inputType, intType, nullptr));
648    main->setCallingConv(CallingConv::C);
649    Function::arg_iterator args = main->arg_begin();
650
651    Value * const inputStream = &*(args++);
652    inputStream->setName("inputStream");
653
654    Value * const bufferSize = &*(args++);
655    bufferSize->setName("bufferSize");
656
657    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", main,0));
658
659    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
660
661    BasicBlock * leadingBlock = BasicBlock::Create(mMod->getContext(), "leadingBody", main, 0);
662
663    BasicBlock * partialLeadingCond = BasicBlock::Create(mMod->getContext(), "partialLeadingCond", main, 0);
664    BasicBlock * partialLeadingBody = BasicBlock::Create(mMod->getContext(), "partialLeadingBody", main, 0);
665
666    BasicBlock * regularCondBlock = BasicBlock::Create(mMod->getContext(), "regularCond", main, 0);
667    BasicBlock * regularBodyBlock = BasicBlock::Create(mMod->getContext(), "regularBody", main, 0);
668
669    BasicBlock * partialCondBlock = BasicBlock::Create(mMod->getContext(), "partialCond", main, 0);
670    BasicBlock * partialBodyBlock = BasicBlock::Create(mMod->getContext(),  "partialBody", main, 0);
671
672    BasicBlock * flushLengthGroupsBlock = BasicBlock::Create(mMod->getContext(), "flushLengthGroups", main, 0);
673
674    Instance * s2pInstance = mS2PKernel->instantiate(inputStream);
675    Instance * leadingInstance = mLeadingKernel->instantiate(s2pInstance->getOutputStreamBuffer());
676    Instance * sortingInstance = mSortingKernel->instantiate(leadingInstance->getOutputStreamBuffer());
677    Instance * gatheringInstance = mGatherKernel->instantiate(sortingInstance->getOutputStreamBuffer());
678
679    gatheringInstance->setInternalState("Base", iBuilder->CreateBitCast(inputStream, iBuilder->getInt8PtrTy()));
680
681    const unsigned leadingBlocks = (mLongestLookahead + iBuilder->getBitBlockWidth() - 1) / iBuilder->getBitBlockWidth();
682
683    Value * const requiredBytes = iBuilder->getSize(mBlockSize * leadingBlocks);
684    Value * const blockSize = iBuilder->getSize(mBlockSize);
685
686    // First compute any necessary leading blocks to allow the sorting kernel access to the "future" data produced by
687    // the leading kernel ...
688
689    Value * enoughDataForLookaheadCond = iBuilder->CreateICmpUGE(bufferSize, requiredBytes);
690    iBuilder->CreateCondBr(enoughDataForLookaheadCond, leadingBlock, partialLeadingCond);
691
692    iBuilder->SetInsertPoint(leadingBlock);
693    for (unsigned i = 0; i < leadingBlocks; ++i) {
694        s2pInstance->CreateDoBlockCall();
695        leadingInstance->CreateDoBlockCall();
696    }
697    iBuilder->CreateBr(regularCondBlock);
698
699    iBuilder->SetInsertPoint(partialLeadingCond);
700    PHINode * remainingBytes1 = iBuilder->CreatePHI(intType, 2);
701    remainingBytes1->addIncoming(bufferSize, entryBlock);
702    Value * remainingCond = iBuilder->CreateICmpUGT(remainingBytes1, blockSize);
703    iBuilder->CreateCondBr(remainingCond, partialLeadingBody, partialCondBlock);
704
705    iBuilder->SetInsertPoint(partialLeadingBody);
706    s2pInstance->CreateDoBlockCall();
707    leadingInstance->CreateDoBlockCall();
708    remainingBytes1->addIncoming(iBuilder->CreateSub(remainingBytes1, blockSize), partialLeadingBody);
709    iBuilder->CreateBr(partialLeadingCond);
710
711    // Now all the data for which we can produce and consume a full leading block...
712    iBuilder->SetInsertPoint(regularCondBlock);
713    PHINode * remainingBytes2 = iBuilder->CreatePHI(intType, 2);
714    remainingBytes2->addIncoming(bufferSize, leadingBlock);
715    Value * remainingBytesCond = iBuilder->CreateICmpUGT(remainingBytes2, requiredBytes);
716    iBuilder->CreateCondBr(remainingBytesCond, regularBodyBlock, partialCondBlock);
717
718    iBuilder->SetInsertPoint(regularBodyBlock);
719    s2pInstance->CreateDoBlockCall();
720    leadingInstance->CreateDoBlockCall();
721    sortingInstance->CreateDoBlockCall();
722    gatheringInstance->CreateDoBlockCall();
723    remainingBytes2->addIncoming(iBuilder->CreateSub(remainingBytes2, blockSize), regularBodyBlock);
724    iBuilder->CreateBr(regularCondBlock);
725
726    // Check if we have a partial blocks worth of leading data remaining
727    iBuilder->SetInsertPoint(partialCondBlock);
728    PHINode * remainingBytes3 = iBuilder->CreatePHI(intType, 3);
729    remainingBytes3->addIncoming(bufferSize, partialLeadingCond);
730    remainingBytes3->addIncoming(remainingBytes2, regularCondBlock);
731    Value * partialBlockCond = iBuilder->CreateICmpSGT(remainingBytes3, iBuilder->getSize(0));
732    iBuilder->CreateCondBr(partialBlockCond, partialBodyBlock, flushLengthGroupsBlock);
733
734    // If we do, process it and mask out the data
735    iBuilder->SetInsertPoint(partialBodyBlock);
736    s2pInstance->clearOutputStreamSet();
737    leadingInstance->CreateDoBlockCall();   
738    sortingInstance->CreateDoBlockCall();
739    gatheringInstance->CreateDoBlockCall();
740    remainingBytes3->addIncoming(iBuilder->CreateSub(remainingBytes3, blockSize), partialBodyBlock);
741    iBuilder->CreateBr(partialCondBlock);
742
743    // perform a final partial gather on all length groups ...
744    iBuilder->SetInsertPoint(flushLengthGroupsBlock);
745
746    Value * const base = iBuilder->CreateLoad(gatheringInstance->getInternalState("Base"));
747    Value * positionArray = gatheringInstance->getInternalState("Positions");
748
749    for (unsigned i = 0; i < mGatherFunction.size(); ++i) {
750        BasicBlock * nonEmptyGroup = BasicBlock::Create(mMod->getContext(), "flushLengthGroup" + std::to_string(i), main, 0);
751
752        BasicBlock * nextNonEmptyGroup = BasicBlock::Create(mMod->getContext(), "", main, 0);
753
754        ConstantInt * groupIV = iBuilder->getInt32(i);
755        Value * startIndexPtr = iBuilder->CreateGEP(positionArray, {iBuilder->getInt32(0), groupIV, iBuilder->getInt32(0)}, "startIndexPtr");
756        Value * startIndex = iBuilder->CreateLoad(startIndexPtr, "remaining");
757        Value * cond = iBuilder->CreateICmpNE(startIndex, ConstantInt::getNullValue(startIndex->getType()));
758        iBuilder->CreateCondBr(cond, nonEmptyGroup, nextNonEmptyGroup);
759
760        iBuilder->SetInsertPoint(nonEmptyGroup);
761        Value * startArray = iBuilder->CreateGEP(positionArray, {iBuilder->getInt32(0), groupIV, iBuilder->getInt32(1)}, "startArray");
762        Value * startArrayPtr = iBuilder->CreatePointerCast(startArray, PointerType::get(iBuilder->getInt32Ty(), 0));
763        Value * endArray = iBuilder->CreateGEP(positionArray, {iBuilder->getInt32(0), groupIV, iBuilder->getInt32(3)}, "endArray");
764        Value * endArrayPtr = iBuilder->CreatePointerCast(endArray, PointerType::get(iBuilder->getInt32Ty(), 0));
765        Value * outputBuffer = iBuilder->CreatePointerCast(gatheringInstance->getOutputStream(groupIV), iBuilder->getInt8PtrTy());
766        iBuilder->CreateCall(mGatherFunction.at(i), {base, startArrayPtr, endArrayPtr, startIndex, outputBuffer});
767        iBuilder->CreateBr(nextNonEmptyGroup);
768
769        iBuilder->SetInsertPoint(nextNonEmptyGroup);
770    }
771    iBuilder->CreateRetVoid();
772
773    return main;
774}
775
776SymbolTableBuilder::~SymbolTableBuilder() {
777    delete mS2PKernel;
778    delete mLeadingKernel;
779    delete mSortingKernel;
780    delete mGatherKernel;
781}
782
783
784}
Note: See TracBrowser for help on using the repository browser.