source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4273

Last change on this file since 4273 was 4273, checked in by cameron, 4 years ago

CodeGenOpt::Level::Less optimization improves performance with while loops

File size: 34.2 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7/*
8 *  Copyright (c) 2014 International Characters.
9 *  This software is licensed to the public under the Open Software License 3.0.
10 *  icgrep is a trademark of International Characters.
11 */
12
#include <pablo/pablo_compiler.h>
#include <pablo/codegenstate.h>
#include <pablo/printer_pablos.h>
#include <cc/cc_namemap.hpp>
#include <re/re_name.h>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <include/simd-lib/bitblock.hpp>
20
21//#define DUMP_GENERATED_IR
22//#define DUMP_OPTIMIZED_IR
23
24extern "C" {
25  void wrapped_print_register(BitBlock bit_block) {
26      print_register<BitBlock>("", bit_block);
27  }
28}
29
30#define CREATE_GENERAL_CODE_CATEGORY(SUFFIX) \
31SUFFIX * f##SUFFIX = nullptr; \
32extern "C" { \
33    BitBlock __get_category_##SUFFIX(Basis_bits &basis_bits) { \
34        if (f##SUFFIX == nullptr) f##SUFFIX = new SUFFIX(); \
35        Struct_##SUFFIX output; \
36        f##SUFFIX->do_block(basis_bits, output); \
37        return output.cc; \
38    } \
39}
40
41CREATE_GENERAL_CODE_CATEGORY(Cc)
42CREATE_GENERAL_CODE_CATEGORY(Cf)
43CREATE_GENERAL_CODE_CATEGORY(Cn)
44CREATE_GENERAL_CODE_CATEGORY(Co)
45CREATE_GENERAL_CODE_CATEGORY(Cs)
46CREATE_GENERAL_CODE_CATEGORY(Ll)
47CREATE_GENERAL_CODE_CATEGORY(Lm)
48CREATE_GENERAL_CODE_CATEGORY(Lo)
49CREATE_GENERAL_CODE_CATEGORY(Lt)
50CREATE_GENERAL_CODE_CATEGORY(Lu)
51CREATE_GENERAL_CODE_CATEGORY(Mc)
52CREATE_GENERAL_CODE_CATEGORY(Me)
53CREATE_GENERAL_CODE_CATEGORY(Mn)
54CREATE_GENERAL_CODE_CATEGORY(Nd)
55CREATE_GENERAL_CODE_CATEGORY(Nl)
56CREATE_GENERAL_CODE_CATEGORY(No)
57CREATE_GENERAL_CODE_CATEGORY(Pc)
58CREATE_GENERAL_CODE_CATEGORY(Pd)
59CREATE_GENERAL_CODE_CATEGORY(Pe)
60CREATE_GENERAL_CODE_CATEGORY(Pf)
61CREATE_GENERAL_CODE_CATEGORY(Pi)
62CREATE_GENERAL_CODE_CATEGORY(Po)
63CREATE_GENERAL_CODE_CATEGORY(Ps)
64CREATE_GENERAL_CODE_CATEGORY(Sc)
65CREATE_GENERAL_CODE_CATEGORY(Sk)
66CREATE_GENERAL_CODE_CATEGORY(Sm)
67CREATE_GENERAL_CODE_CATEGORY(So)
68CREATE_GENERAL_CODE_CATEGORY(Zl)
69CREATE_GENERAL_CODE_CATEGORY(Zp)
70CREATE_GENERAL_CODE_CATEGORY(Zs)
71
72#undef CREATE_GENERAL_CODE_CATEGORY
73
74namespace pablo {
75
76PabloCompiler::PabloCompiler(const std::vector<Var*> & basisBits)
77: mBasisBits(basisBits)
78, mMod(new Module("icgrep", getGlobalContext()))
79, mBasicBlock(nullptr)
80, mExecutionEngine(nullptr)
81, mBitBlockType(VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64))
82, mBasisBitsInputPtr(nullptr)
83, mCarryQueueIdx(0)
84, mCarryQueuePtr(nullptr)
85, mNestingDepth(0)
86, mCarryQueueSize(0)
87, mZeroInitializer(ConstantAggregateZero::get(mBitBlockType))
88, mOneInitializer(ConstantVector::getAllOnesValue(mBitBlockType))
89, mFunctionType(nullptr)
90, mFunc_process_block(nullptr)
91, mBasisBitsAddr(nullptr)
92, mOutputAddrPtr(nullptr)
93, mMaxPabloWhileDepth(0)
94{
95    //Create the jit execution engine.up
96    InitializeNativeTarget();
97
98    InitializeNativeTargetAsmPrinter();
99    InitializeNativeTargetAsmParser();
100
101    DefineTypes();
102    DeclareFunctions();
103}
104
105PabloCompiler::~PabloCompiler()
106{
107    delete mMod;
108    delete fPs;
109    delete fNl;
110    delete fNo;
111    delete fLo;
112    delete fLl;
113    delete fLm;
114    delete fNd;
115    delete fPc;
116    delete fLt;
117    delete fLu;
118    delete fPf;
119    delete fPd;
120    delete fPe;
121    delete fPi;
122    delete fPo;
123    delete fMe;
124    delete fMc;
125    delete fMn;
126    delete fSk;
127    delete fSo;
128    delete fSm;
129    delete fSc;
130    delete fZl;
131    delete fCo;
132    delete fCn;
133    delete fCc;
134    delete fCf;
135    delete fCs;
136    delete fZp;
137    delete fZs;
138
139}
140
141LLVM_Gen_RetVal PabloCompiler::compile(PabloBlock & pb)
142{
143    mCarryQueueSize = 0;
144    DeclareCallFunctions(pb.statements());
145    mCarryQueueVector.resize(mCarryQueueSize);
146
147    Function::arg_iterator args = mFunc_process_block->arg_begin();
148    mBasisBitsAddr = args++;
149    mBasisBitsAddr->setName("basis_bits");
150    mCarryQueuePtr = args++;
151    mCarryQueuePtr->setName("carry_q");
152    mOutputAddrPtr = args++;
153    mOutputAddrPtr->setName("output");
154
155    //Create the carry queue.
156    mCarryQueueIdx = 0;
157    mNestingDepth = 0;
158    mBasicBlock = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunc_process_block,0);
159
160    //The basis bits structure
161    for (unsigned i = 0; i != mBasisBits.size(); ++i) {
162        IRBuilder<> b(mBasicBlock);
163        Value* indices[] = {b.getInt64(0), b.getInt32(i)};
164        const String * const name = mBasisBits[i]->getName();
165        Value * gep = b.CreateGEP(mBasisBitsAddr, indices);
166        LoadInst * basisBit = b.CreateAlignedLoad(gep, BLOCK_SIZE/8, false, name->str());
167        mMarkerMap.insert(std::make_pair(name, basisBit));
168    }
169
170    //Generate the IR instructions for the function.
171    compileStatements(pb.statements());
172
173    assert (mCarryQueueIdx == mCarryQueueSize);
174    assert (mNestingDepth == 0);
175    //Terminate the block
176    ReturnInst::Create(mMod->getContext(), mBasicBlock);
177
178    //Un-comment this line in order to display the IR that has been generated by this module.
179    #ifdef DUMP_GENERATED_IR
180    mMod->dump();
181    #endif
182
183    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
184    #ifdef USE_LLVM_3_5
185    verifyModule(*mMod, &dbgs());
186    #endif
187    #ifdef USE_LLVM_3_4
188    verifyModule(*mMod, PrintMessageAction);
189    #endif
190
191    //Use the pass manager to run optimizations on the function.
192    FunctionPassManager fpm(mMod);
193    std::string ErrStr;
194
195    if (mMaxPabloWhileDepth == 0) {
196      mExecutionEngine = EngineBuilder(mMod).setUseMCJIT(true).setErrorStr(&ErrStr).setOptLevel(CodeGenOpt::Level::None).create();
197    }
198    else {
199      mExecutionEngine = EngineBuilder(mMod).setUseMCJIT(true).setErrorStr(&ErrStr).setOptLevel(CodeGenOpt::Level::Less).create();
200    }
201    if (mExecutionEngine == nullptr) {
202        throw std::runtime_error("Could not create ExecutionEngine: " + ErrStr);
203    }
204
205
206#ifdef USE_LLVM_3_5
207    mMod->setDataLayout(mExecutionEngine->getDataLayout());
208    // Set up the optimizer pipeline.  Start with registering info about how the target lays out data structures.
209    fpm.add(new DataLayoutPass(mMod));
210#endif
211
212#ifdef USE_LLVM_3_4
213    fpm.add(new DataLayout(*mExecutionEngine->getDataLayout()));
214#endif
215
216    //fpm.add(createPromoteMemoryToRegisterPass()); //Transform to SSA form.
217    //fpm.add(createBasicAliasAnalysisPass());      //Provide basic AliasAnalysis support for GVN. (Global Value Numbering)
218    //fpm.add(createCFGSimplificationPass());       //Simplify the control flow graph.
219    fpm.add(createInstructionCombiningPass());    //Simple peephole optimizations and bit-twiddling.
220    fpm.add(createReassociatePass());             //Reassociate expressions.
221    fpm.add(createGVNPass());                     //Eliminate common subexpressions.
222
223    fpm.doInitialization();
224
225    fpm.run(*mFunc_process_block);
226
227#ifdef DUMP_OPTIMIZED_IR
228    mMod->dump();
229#endif
230    mExecutionEngine->finalizeObject();
231
232    LLVM_Gen_RetVal retVal;
233    //Return the required size of the carry queue and a pointer to the process_block function.
234    retVal.carry_q_size = mCarryQueueSize;
235    retVal.process_block_fptr = mExecutionEngine->getPointerToFunction(mFunc_process_block);
236
237    return retVal;
238}
239
240void PabloCompiler::DefineTypes()
241{
242    StructType * structBasisBits = mMod->getTypeByName("struct.Basis_bits");
243    if (structBasisBits == nullptr) {
244        structBasisBits = StructType::create(mMod->getContext(), "struct.Basis_bits");
245    }
246    std::vector<Type*>StructTy_struct_Basis_bits_fields;
247    for (int i = 0; i != mBasisBits.size(); i++)
248    {
249        StructTy_struct_Basis_bits_fields.push_back(mBitBlockType);
250    }
251    if (structBasisBits->isOpaque()) {
252        structBasisBits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
253    }
254    mBasisBitsInputPtr = PointerType::get(structBasisBits, 0);
255
256    std::vector<Type*>functionTypeArgs;
257    functionTypeArgs.push_back(mBasisBitsInputPtr);
258
259    //The carry q array.
260    //A pointer to the BitBlock vector.
261    functionTypeArgs.push_back(PointerType::get(mBitBlockType, 0));
262
263    //The output structure.
264    StructType * outputStruct = mMod->getTypeByName("struct.Output");
265    if (!outputStruct) {
266        outputStruct = StructType::create(mMod->getContext(), "struct.Output");
267    }
268    if (outputStruct->isOpaque()) {
269        std::vector<Type*>fields;
270        fields.push_back(mBitBlockType);
271        fields.push_back(mBitBlockType);
272        outputStruct->setBody(fields, /*isPacked=*/false);
273    }
274    PointerType* outputStructPtr = PointerType::get(outputStruct, 0);
275
276    //The &output parameter.
277    functionTypeArgs.push_back(outputStructPtr);
278
279    mFunctionType = FunctionType::get(
280     /*Result=*/Type::getVoidTy(mMod->getContext()),
281     /*Params=*/functionTypeArgs,
282     /*isVarArg=*/false);
283}
284
285void PabloCompiler::DeclareFunctions()
286{
287    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
288    //mFunc_print_register = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), mXi64Vect, NULL);
289    //mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mFunc_print_register), (void *)&wrapped_print_register);
290    // to call->  b.CreateCall(mFunc_print_register, unicode_category);
291
292#ifdef USE_UADD_OVERFLOW
293    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
294    std::vector<Type*>StructTy_0_fields;
295    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
296    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
297    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
298
299    std::vector<Type*>FuncTy_1_args;
300    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
301    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
302    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
303    FunctionType* FuncTy_1 = FunctionType::get(
304                                              /*Result=*/StructTy_0,
305                                              /*Params=*/FuncTy_1_args,
306                                              /*isVarArg=*/false);
307
308    mFunc_llvm_uadd_with_overflow = mMod->getFunction("llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE);
309    if (!mFunc_llvm_uadd_with_overflow) {
310        mFunc_llvm_uadd_with_overflow = Function::Create(
311          /*Type=*/ FuncTy_1,
312          /*Linkage=*/ GlobalValue::ExternalLinkage,
313          /*Name=*/ "llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE, mMod); // (external, no body)
314        mFunc_llvm_uadd_with_overflow->setCallingConv(CallingConv::C);
315    }
316    AttributeSet mFunc_llvm_uadd_with_overflow_PAL;
317    {
318        SmallVector<AttributeSet, 4> Attrs;
319        AttributeSet PAS;
320        {
321          AttrBuilder B;
322          B.addAttribute(Attribute::NoUnwind);
323          B.addAttribute(Attribute::ReadNone);
324          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
325        }
326
327        Attrs.push_back(PAS);
328        mFunc_llvm_uadd_with_overflow_PAL = AttributeSet::get(mMod->getContext(), Attrs);
329    }
330    mFunc_llvm_uadd_with_overflow->setAttributes(mFunc_llvm_uadd_with_overflow_PAL);
331#endif
332
333    //Starts on process_block
334    SmallVector<AttributeSet, 4> Attrs;
335    AttributeSet PAS;
336    {
337        AttrBuilder B;
338        B.addAttribute(Attribute::ReadOnly);
339        B.addAttribute(Attribute::NoCapture);
340        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
341    }
342    Attrs.push_back(PAS);
343    {
344        AttrBuilder B;
345        B.addAttribute(Attribute::NoCapture);
346        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
347    }
348    Attrs.push_back(PAS);
349    {
350        AttrBuilder B;
351        B.addAttribute(Attribute::NoCapture);
352        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
353    }
354    Attrs.push_back(PAS);
355    {
356        AttrBuilder B;
357        B.addAttribute(Attribute::NoUnwind);
358        B.addAttribute(Attribute::UWTable);
359        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
360    }
361    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
362
363    //Create the function that will be generated.
364    mFunc_process_block = mMod->getFunction("process_block");
365    if (!mFunc_process_block) {
366        mFunc_process_block = Function::Create(
367            /*Type=*/mFunctionType,
368            /*Linkage=*/GlobalValue::ExternalLinkage,
369            /*Name=*/"process_block", mMod);
370        mFunc_process_block->setCallingConv(CallingConv::C);
371    }
372    mFunc_process_block->setAttributes(AttrSet);
373}
374
375void PabloCompiler::DeclareCallFunctions(const StatementList & stmts) {
376    for (PabloAST * stmt : stmts) {
377        if (const Assign * assign = dyn_cast<Assign>(stmt)) {
378            DeclareCallFunctions(assign->getExpr());
379        }
380        if (const Next * next = dyn_cast<Next>(stmt)) {
381            DeclareCallFunctions(next->getExpr());
382        }
383        else if (If * ifStatement = dyn_cast<If>(stmt)) {
384            const auto preIfCarryCount = mCarryQueueSize;
385            DeclareCallFunctions(ifStatement->getCondition());
386            DeclareCallFunctions(ifStatement->getBody());
387            ifStatement->setInclusiveCarryCount(mCarryQueueSize - preIfCarryCount);
388        }
389        else if (While * whileStatement = dyn_cast<While>(stmt)) {
390            const auto preWhileCarryCount = mCarryQueueSize;
391            DeclareCallFunctions(whileStatement->getCondition());
392            DeclareCallFunctions(whileStatement->getBody());
393            whileStatement->setInclusiveCarryCount(mCarryQueueSize - preWhileCarryCount);
394        }
395    }
396}
397
398void PabloCompiler::DeclareCallFunctions(const PabloAST * expr)
399{
400    if (const Call * call = dyn_cast<const Call>(expr)) {
401        const String * const callee = call->getCallee();
402        assert (callee);
403        if (mCalleeMap.find(callee) == mCalleeMap.end()) {
404            void * callee_ptr = nullptr;
405            #define CHECK_GENERAL_CODE_CATEGORY(SUFFIX) \
406                if (callee->str() == #SUFFIX) { \
407                    callee_ptr = (void*)&__get_category_##SUFFIX; \
408                } else
409            CHECK_GENERAL_CODE_CATEGORY(Cc)
410            CHECK_GENERAL_CODE_CATEGORY(Cf)
411            CHECK_GENERAL_CODE_CATEGORY(Cn)
412            CHECK_GENERAL_CODE_CATEGORY(Co)
413            CHECK_GENERAL_CODE_CATEGORY(Cs)
414            CHECK_GENERAL_CODE_CATEGORY(Ll)
415            CHECK_GENERAL_CODE_CATEGORY(Lm)
416            CHECK_GENERAL_CODE_CATEGORY(Lo)
417            CHECK_GENERAL_CODE_CATEGORY(Lt)
418            CHECK_GENERAL_CODE_CATEGORY(Lu)
419            CHECK_GENERAL_CODE_CATEGORY(Mc)
420            CHECK_GENERAL_CODE_CATEGORY(Me)
421            CHECK_GENERAL_CODE_CATEGORY(Mn)
422            CHECK_GENERAL_CODE_CATEGORY(Nd)
423            CHECK_GENERAL_CODE_CATEGORY(Nl)
424            CHECK_GENERAL_CODE_CATEGORY(No)
425            CHECK_GENERAL_CODE_CATEGORY(Pc)
426            CHECK_GENERAL_CODE_CATEGORY(Pd)
427            CHECK_GENERAL_CODE_CATEGORY(Pe)
428            CHECK_GENERAL_CODE_CATEGORY(Pf)
429            CHECK_GENERAL_CODE_CATEGORY(Pi)
430            CHECK_GENERAL_CODE_CATEGORY(Po)
431            CHECK_GENERAL_CODE_CATEGORY(Ps)
432            CHECK_GENERAL_CODE_CATEGORY(Sc)
433            CHECK_GENERAL_CODE_CATEGORY(Sk)
434            CHECK_GENERAL_CODE_CATEGORY(Sm)
435            CHECK_GENERAL_CODE_CATEGORY(So)
436            CHECK_GENERAL_CODE_CATEGORY(Zl)
437            CHECK_GENERAL_CODE_CATEGORY(Zp)
438            CHECK_GENERAL_CODE_CATEGORY(Zs)
439            // OTHERWISE ...
440            throw std::runtime_error("Unknown unicode category \"" + callee->str() + "\"");
441            #undef CHECK_GENERAL_CODE_CATEGORY
442            Value * unicodeCategory = mMod->getOrInsertFunction("__get_category_" + callee->str(), mBitBlockType, mBasisBitsInputPtr, NULL);
443            if (unicodeCategory == nullptr) {
444                throw std::runtime_error("Could not create static method call for unicode category \"" + callee->str() + "\"");
445            }
446            mExecutionEngine->addGlobalMapping(cast<GlobalValue>(unicodeCategory), callee_ptr);
447            mCalleeMap.insert(std::make_pair(callee, unicodeCategory));
448        }
449    }
450    else if (const And * pablo_and = dyn_cast<const And>(expr))
451    {
452        DeclareCallFunctions(pablo_and->getExpr1());
453        DeclareCallFunctions(pablo_and->getExpr2());
454    }
455    else if (const Or * pablo_or = dyn_cast<const Or>(expr))
456    {
457        DeclareCallFunctions(pablo_or->getExpr1());
458        DeclareCallFunctions(pablo_or->getExpr2());
459    }
460    else if (const Sel * pablo_sel = dyn_cast<const Sel>(expr))
461    {
462        DeclareCallFunctions(pablo_sel->getCondition());
463        DeclareCallFunctions(pablo_sel->getTrueExpr());
464        DeclareCallFunctions(pablo_sel->getFalseExpr());
465    }
466    else if (const Not * pablo_not = dyn_cast<const Not>(expr))
467    {
468        DeclareCallFunctions(pablo_not->getExpr());
469    }
470    else if (const Advance * adv = dyn_cast<const Advance>(expr))
471    {
472        ++mCarryQueueSize;
473        DeclareCallFunctions(adv->getExpr());
474    }
475    else if (const MatchStar * mstar = dyn_cast<const MatchStar>(expr))
476    {
477        ++mCarryQueueSize;
478        DeclareCallFunctions(mstar->getMarker());
479        DeclareCallFunctions(mstar->getCharClass());
480    }
481    else if (const ScanThru * sthru = dyn_cast<const ScanThru>(expr))
482    {
483        ++mCarryQueueSize;
484        DeclareCallFunctions(sthru->getScanFrom());
485        DeclareCallFunctions(sthru->getScanThru());
486    }
487}
488
489Value * PabloCompiler::compileStatements(const StatementList & stmts) {
490    Value * retVal = nullptr;
491    for (PabloAST * statement : stmts) {
492        retVal = compileStatement(statement);
493    }
494    return retVal;
495}
496
497Value * PabloCompiler::compileStatement(const PabloAST * stmt)
498{
499    Value * retVal = nullptr;
500    if (const Assign * assign = dyn_cast<const Assign>(stmt))
501    {
502        Value* expr = compileExpression(assign->getExpr());
503        mMarkerMap[assign->getName()] = expr;
504        if (unlikely(assign->isOutputAssignment())) {
505            SetOutputValue(expr, assign->getOutputIndex());
506        }
507        retVal = expr;
508    }
509    if (const Next * next = dyn_cast<const Next>(stmt))
510    {
511        Value* expr = compileExpression(next->getExpr());
512        mMarkerMap[next->getName()] = expr;
513        retVal = expr;
514    }
515    else if (const If * ifstmt = dyn_cast<const If>(stmt))
516    {
517        BasicBlock * ifEntryBlock = mBasicBlock;
518        BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunc_process_block, 0);
519        BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunc_process_block, 0);
520
521        int if_start_idx = mCarryQueueIdx;
522
523        Value* if_test_value = compileExpression(ifstmt->getCondition());
524
525        /* Generate the statements into the if body block, and also determine the
526           final carry index.  */
527
528        IRBuilder<> bIfBody(ifBodyBlock);
529        mBasicBlock = ifBodyBlock;
530
531        ++mNestingDepth;
532
533        Value *  returnMarker = compileStatements(ifstmt->getBody());
534
535        int if_end_idx = mCarryQueueIdx;
536        if (if_start_idx < if_end_idx + 1) {
537            // Have at least two internal carries.   Accumulate and store.
538            int if_accum_idx = mCarryQueueIdx++;
539
540            Value* if_carry_accum_value = genCarryInLoad(if_start_idx);
541
542            for (int c = if_start_idx+1; c < if_end_idx; c++)
543            {
544                Value* carryq_value = genCarryInLoad(c);
545                if_carry_accum_value = bIfBody.CreateOr(carryq_value, if_carry_accum_value);
546            }
547            genCarryOutStore(if_carry_accum_value, if_accum_idx);
548
549        }
550        bIfBody.CreateBr(ifEndBlock);
551
552        IRBuilder<> b_entry(ifEntryBlock);
553        mBasicBlock = ifEntryBlock;
554        if (if_start_idx < if_end_idx) {
555            // Have at least one internal carry.
556            int if_accum_idx = mCarryQueueIdx - 1;
557            Value* last_if_pending_carries = genCarryInLoad(if_accum_idx);
558            if_test_value = b_entry.CreateOr(if_test_value, last_if_pending_carries);
559        }
560        b_entry.CreateCondBr(genBitBlockAny(if_test_value), ifEndBlock, ifBodyBlock);
561
562        mBasicBlock = ifEndBlock;
563        --mNestingDepth;
564
565        retVal = returnMarker;
566    }
567    else if (const While * whileStatement = dyn_cast<const While>(stmt))
568    {
569        const auto baseCarryQueueIdx = mCarryQueueIdx;
570        if (mNestingDepth == 0) {
571            for (auto i = 0; i != whileStatement->getInclusiveCarryCount(); ++i) {
572                genCarryInLoad(baseCarryQueueIdx + i);
573            }
574        }       
575
576        SmallVector<Next*, 4> nextNodes;
577        for (PabloAST * node : whileStatement->getBody()) {
578            if (isa<Next>(node)) {
579                nextNodes.push_back(cast<Next>(node));
580            }
581        }
582
583        // Compile the initial iteration statements; the calls to genCarryOutStore will update the
584        // mCarryQueueVector with the appropriate values. Although we're not actually entering a new basic
585        // block yet, increment the nesting depth so that any calls to genCarryInLoad or genCarryOutStore
586        // will refer to the previous value.
587        ++mNestingDepth;
588        if (mMaxPabloWhileDepth < mNestingDepth) mMaxPabloWhileDepth = mNestingDepth;
589        compileStatements(whileStatement->getBody());
590       
591        // Reset the carry queue index. Note: this ought to be changed in the future. Currently this assumes
592        // that compiling the while body twice will generate the equivalent IR. This is not necessarily true
593        // but works for now.
594        mCarryQueueIdx = baseCarryQueueIdx;
595
596        BasicBlock* whileCondBlock = BasicBlock::Create(mMod->getContext(), "while.cond", mFunc_process_block, 0);
597        BasicBlock* whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunc_process_block, 0);
598        BasicBlock* whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunc_process_block, 0);
599
600        // Note: compileStatements may update the mBasicBlock pointer if the body contains nested loops. It
601        // may not be same one that we entered the function with.
602        IRBuilder<> bEntry(mBasicBlock);
603        bEntry.CreateBr(whileCondBlock);
604
605        // CONDITION BLOCK
606        IRBuilder<> bCond(whileCondBlock);
607        // generate phi nodes for any carry propogating instruction
608        std::vector<PHINode*> phiNodes(whileStatement->getInclusiveCarryCount() + nextNodes.size());
609        unsigned index = 0;
610        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
611            PHINode * phi = bCond.CreatePHI(mBitBlockType, 2);
612            phi->addIncoming(mCarryQueueVector[baseCarryQueueIdx + index], mBasicBlock);
613            mCarryQueueVector[baseCarryQueueIdx + index] = mZeroInitializer; // (use phi for multi-carry mode.)
614            phiNodes[index] = phi;
615        }
616        // and for any Next nodes in the loop body
617        for (Next * n : nextNodes) {
618            PHINode * phi = bCond.CreatePHI(mBitBlockType, 2, n->getName()->str());
619            auto f = mMarkerMap.find(n->getName());
620            assert (f != mMarkerMap.end());
621            phi->addIncoming(f->second, mBasicBlock);
622            mMarkerMap[n->getName()] = phi;
623            phiNodes[index++] = phi;
624        }
625
626        mBasicBlock = whileCondBlock;
627        bCond.CreateCondBr(genBitBlockAny(compileExpression(whileStatement->getCondition())), whileEndBlock, whileBodyBlock);
628
629        // BODY BLOCK
630        mBasicBlock = whileBodyBlock;
631        retVal = compileStatements(whileStatement->getBody());
632        // update phi nodes for any carry propogating instruction
633        IRBuilder<> bWhileBody(mBasicBlock);
634        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
635            Value * carryOut = bWhileBody.CreateOr(phiNodes[index], mCarryQueueVector[baseCarryQueueIdx + index]);
636            PHINode * phi = phiNodes[index];
637            phi->addIncoming(carryOut, mBasicBlock);
638            mCarryQueueVector[baseCarryQueueIdx + index] = phi;
639        }
640        // and for any Next nodes in the loop body
641        for (Next * n : nextNodes) {
642            auto f = mMarkerMap.find(n->getName());
643            assert (f != mMarkerMap.end());
644            PHINode * phi = phiNodes[index++];
645            phi->addIncoming(f->second, mBasicBlock);
646            mMarkerMap[n->getName()] = phi;
647        }
648
649        bWhileBody.CreateBr(whileCondBlock);
650
651        // EXIT BLOCK
652        mBasicBlock = whileEndBlock;   
653        if (--mNestingDepth == 0) {
654            for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
655                genCarryOutStore(phiNodes[index], baseCarryQueueIdx + index);
656            }
657        }
658    }
659    return retVal;
660}
661
662Value * PabloCompiler::compileExpression(const PabloAST * expr)
663{
664    Value * retVal = nullptr;
665    IRBuilder<> b(mBasicBlock);
666    if (isa<Ones>(expr)) {
667        retVal = mOneInitializer;
668    }
669    else if (isa<Zeroes>(expr)) {
670        retVal = mZeroInitializer;
671    }
672    else if (const Call* call = dyn_cast<Call>(expr)) {
673        //Call the callee once and store the result in the marker map.
674        auto mi = mMarkerMap.find(call->getCallee());
675        if (mi == mMarkerMap.end()) {
676            auto ci = mCalleeMap.find(call->getCallee());
677            if (ci == mCalleeMap.end()) {
678                throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee()->str() + "\"");
679            }
680            mi = mMarkerMap.insert(std::make_pair(call->getCallee(), b.CreateCall(ci->second, mBasisBitsAddr))).first;
681        }
682        retVal = mi->second;
683    }
684    else if (const Var * var = dyn_cast<Var>(expr))
685    {
686        auto f = mMarkerMap.find(var->getName());
687        assert (f != mMarkerMap.end());
688        retVal = f->second;
689    }
690    else if (const And * pablo_and = dyn_cast<And>(expr))
691    {
692        retVal = b.CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
693    }
694    else if (const Or * pablo_or = dyn_cast<Or>(expr))
695    {
696        retVal = b.CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
697    }
698    else if (const Sel * sel = dyn_cast<Sel>(expr))
699    {
700        Value* ifMask = compileExpression(sel->getCondition());
701        Value* ifTrue = b.CreateAnd(ifMask, compileExpression(sel->getTrueExpr()));
702        Value* ifFalse = b.CreateAnd(genNot(ifMask), compileExpression(sel->getFalseExpr()));
703        retVal = b.CreateOr(ifTrue, ifFalse);
704    }
705    else if (const Not * pablo_not = dyn_cast<Not>(expr))
706    {
707        retVal = genNot(compileExpression(pablo_not->getExpr()));
708    }
709    else if (const Advance * adv = dyn_cast<Advance>(expr))
710    {
711        Value* strm_value = compileExpression(adv->getExpr());
712        int shift = adv->getAdvanceAmount();
713        retVal = genAdvanceWithCarry(strm_value, shift);
714    }
715    else if (const MatchStar * mstar = dyn_cast<MatchStar>(expr))
716    {
717        Value* marker = compileExpression(mstar->getMarker());
718        Value* cc = compileExpression(mstar->getCharClass());
719        Value* marker_and_cc = b.CreateAnd(marker, cc);
720        retVal = b.CreateOr(b.CreateXor(genAddWithCarry(marker_and_cc, cc), cc), marker, "matchstar");
721    }
722    else if (const ScanThru * sthru = dyn_cast<ScanThru>(expr))
723    {
724        Value* marker_expr = compileExpression(sthru->getScanFrom());
725        Value* cc_expr = compileExpression(sthru->getScanThru());
726        retVal = b.CreateAnd(genAddWithCarry(marker_expr, cc_expr), genNot(cc_expr), "scanthru");
727    }
728    return retVal;
729}
730
731#ifdef USE_UADD_OVERFLOW
732SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
733    std::vector<Value*> struct_res_params;
734    struct_res_params.push_back(int128_e1);
735    struct_res_params.push_back(int128_e2);
736    struct_res_params.push_back(int1_cin);
737    CallInst* struct_res = CallInst::Create(mFunc_llvm_uadd_with_overflow, struct_res_params, "uadd_overflow_res", mBasicBlock);
738    struct_res->setCallingConv(CallingConv::C);
739    struct_res->setTailCall(false);
740    AttributeSet struct_res_PAL;
741    struct_res->setAttributes(struct_res_PAL);
742
743    SumWithOverflowPack ret;
744
745    std::vector<unsigned> int128_sum_indices;
746    int128_sum_indices.push_back(0);
747    ret.sum = ExtractValueInst::Create(struct_res, int128_sum_indices, "sum", mBasicBlock);
748
749    std::vector<unsigned> int1_obit_indices;
750    int1_obit_indices.push_back(1);
751    ret.obit = ExtractValueInst::Create(struct_res, int1_obit_indices, "obit", mBasicBlock);
752
753    return ret;
754}
755#endif
756
// Emits IR for the block-wide addition e1 + e2 + carry-in.
//
// The carry-in comes from carry-queue slot mCarryQueueIdx (which is then
// post-incremented), and the carry-out is written back to that same slot via
// genCarryOutStore.
//
// e1, e2  : the BitBlock operands
// returns : the BitBlock sum
Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2) {
    IRBuilder<> b(mBasicBlock);

    //CarryQ - carry in.
    const int carryIdx = mCarryQueueIdx++;
    Value* carryq_value = genCarryInLoad(carryIdx);

#ifdef USE_UADD_OVERFLOW
    //use llvm.uadd.with.overflow.i128 or i256
    // i32 constant 0: used as the element index for the carry-in extract and
    // the carry-out insert below.
    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
    // View both operands as single BLOCK_SIZE-bit integers.
    CastInst* int128_e1 = new BitCastInst(e1, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e1_128", mBasicBlock);
    CastInst* int128_e2 = new BitCastInst(e2, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e2_128", mBasicBlock);
    // The carry-in bit is element 0 of the carry vector, truncated to i1.
    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
    SumWithOverflowPack sumpack0;
    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
    Value* obit = sumpack0.obit;
    Value* sum = b.CreateBitCast(sumpack0.sum, mXi64Vect, "sum");
    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mXi64Vect);
    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
#else
    //calculate carry through logical ops
    // The vector adds below are element-wise: presumably 64-bit lanes, given the
    // shifts by 63 and 64 (TODO confirm the BitBlock element type). Carries
    // between lanes are rebuilt with the identity
    //     carry = (a & b) | ((a | b) & ~sum)
    // whose top bit (per lane) is the carry out of that lane.
    Value* carrygen = b.CreateAnd(e1, e2, "carrygen");      // bit positions that generate a carry
    Value* carryprop = b.CreateOr(e1, e2, "carryprop");     // bit positions that propagate a carry
    Value* digitsum = b.CreateAdd(e1, e2, "digitsum");      // per-lane sum, inter-lane carries dropped
    // Add the carry-in. Presumably only bit 0 of lane 0 can be set, since the
    // stored carry-out is produced by genShiftHighbitToLow.
    Value* partial = b.CreateAdd(digitsum, carryq_value, "partial");
    Value* digitcarry = b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(partial)));
    // Reduce each lane's carry to bit 0, then move it into the next-higher lane.
    Value* mid_carry_in = genShiftLeft64(b.CreateLShr(digitcarry, 63), "mid_carry_in");

    // NOTE(review): a lane that overflows only because of mid_carry_in does not
    // ripple into the lane above it. With two 64-bit lanes (BLOCK_SIZE == 128),
    // the only such overflow is out of the top lane, and carry_out below captures
    // it. With more lanes (BLOCK_SIZE == 256) this looks like it can drop carries
    // — confirm.
    Value* sum = b.CreateAdd(partial, mid_carry_in, "sum");
    // Carry out of the whole block: its top bit, moved down to bit 0.
    Value* carry_out = genShiftHighbitToLow(b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(sum))), "carry_out");
#endif
    genCarryOutStore(carry_out, carryIdx);
    return sum;
}
794
795Value* PabloCompiler::genCarryInLoad(const unsigned index) {   
796    assert (index < mCarryQueueVector.size());
797    if (mNestingDepth == 0) {
798        IRBuilder<> b(mBasicBlock);
799        mCarryQueueVector[index] = b.CreateAlignedLoad(b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
800    }
801    return mCarryQueueVector[index];
802}
803
804void PabloCompiler::genCarryOutStore(Value* carryOut, const unsigned index ) {
805    assert (carryOut);
806    assert (index < mCarryQueueVector.size());
807    if (mNestingDepth == 0) {       
808        IRBuilder<> b(mBasicBlock);
809        b.CreateAlignedStore(carryOut, b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
810    }
811    mCarryQueueVector[index] = carryOut;
812}
813
814inline Value* PabloCompiler::genBitBlockAny(Value* test) {
815    IRBuilder<> b(mBasicBlock);
816    Value* cast_marker_value_1 = b.CreateBitCast(test, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
817    return b.CreateICmpEQ(cast_marker_value_1, ConstantInt::get(IntegerType::get(mMod->getContext(), BLOCK_SIZE), 0));
818}
819
820Value* PabloCompiler::genShiftHighbitToLow(Value* e, const Twine &namehint) {
821    IRBuilder<> b(mBasicBlock);
822    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
823    return b.CreateBitCast(b.CreateLShr(i128_val, BLOCK_SIZE - 1, namehint), mBitBlockType);
824}
825
826Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
827    IRBuilder<> b(mBasicBlock);
828    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
829    return b.CreateBitCast(b.CreateShl(i128_val, 64, namehint), mBitBlockType);
830}
831
832inline Value* PabloCompiler::genNot(Value* expr) {
833    IRBuilder<> b(mBasicBlock);
834    return b.CreateXor(expr, mOneInitializer, "not");
835}
836
837Value* PabloCompiler::genAdvanceWithCarry(Value* strm_value, int shift_amount) {
838
839    IRBuilder<> b(mBasicBlock);
840#if (BLOCK_SIZE == 128)
841    const auto carryIdx = mCarryQueueIdx++;
842    if (shift_amount == 1) {
843        Value* carryq_value = genCarryInLoad(carryIdx);
844        Value* srli_1_value = b.CreateLShr(strm_value, 63);
845        Value* packed_shuffle;
846        Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
847        Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
848        packed_shuffle = b.CreateShuffleVector(carryq_value, srli_1_value, const_packed_1);
849
850        Constant* const_packed_2_elems[] = {b.getInt64(1), b.getInt64(1)};
851        Constant* const_packed_2 = ConstantVector::get(const_packed_2_elems);
852
853        Value* shl_value = b.CreateShl(strm_value, const_packed_2);
854        Value* result_value = b.CreateOr(shl_value, packed_shuffle, "advance");
855
856        Value* carry_out = genShiftHighbitToLow(strm_value, "carry_out");
857        //CarryQ - carry out:
858        genCarryOutStore(carry_out, carryIdx);
859           
860        return result_value;
861    }
862    else if (shift_amount < 64) {
863        // This is the preferred logic, but is too slow for the general case.   
864        // We need to speed up our custom LLVM for this code.
865        Value* carryq_longint = b.CreateBitCast(genCarryInLoad(carryIdx), IntegerType::get(mMod->getContext(), BLOCK_SIZE));
866        Value* strm_longint = b.CreateBitCast(strm_value, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
867        Value* adv_longint = b.CreateOr(b.CreateShl(strm_longint, shift_amount), carryq_longint, "advance");
868    Value* result_value = b.CreateBitCast(adv_longint, mBitBlockType);
869    Value* carry_out = b.CreateBitCast(b.CreateLShr(strm_longint, BLOCK_SIZE - shift_amount, "advance_out"), mBitBlockType);
870        //CarryQ - carry out:
871        genCarryOutStore(carry_out, carryIdx);
872           
873        return result_value;
874    }
875    else {//if (shift_amount >= 64) {
876        throw std::runtime_error("Shift amount >= 64 in Advance is currently unsupported.");
877    }
878#endif
879
880#if (BLOCK_SIZE == 256)
881    return genAddWithCarry(strm_value, strm_value);
882#endif
883
884}
885
886void PabloCompiler::SetOutputValue(Value * marker, const unsigned index) {
887    IRBuilder<> b(mBasicBlock);
888    if (marker->getType()->isPointerTy()) {
889        marker = b.CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
890    }
891    Value* indices[] = {b.getInt64(0), b.getInt32(index)};
892    Value* gep = b.CreateGEP(mOutputAddrPtr, indices);
893    b.CreateAlignedStore(marker, gep, BLOCK_SIZE/8, false);
894}
895
896}
Note: See TracBrowser for help on using the repository browser.