source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4272

Last change on this file since 4272 was 4270, checked in by nmedfort, 4 years ago

Minor changes.

File size: 33.9 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7/*
8 *  Copyright (c) 2014 International Characters.
9 *  This software is licensed to the public under the Open Software License 3.0.
10 *  icgrep is a trademark of International Characters.
11 */
12
#include <pablo/pablo_compiler.h>
#include <pablo/codegenstate.h>
#include <pablo/printer_pablos.h>
#include <cc/cc_namemap.hpp>
#include <re/re_name.h>
#include <stdexcept>
#include <string>
#include <include/simd-lib/bitblock.hpp>
20
21//#define DUMP_GENERATED_IR
22//#define DUMP_OPTIMIZED_IR
23
24extern "C" {
25  void wrapped_print_register(BitBlock bit_block) {
26      print_register<BitBlock>("", bit_block);
27  }
28}
29
30#define CREATE_GENERAL_CODE_CATEGORY(SUFFIX) \
31SUFFIX * f##SUFFIX = nullptr; \
32extern "C" { \
33    BitBlock __get_category_##SUFFIX(Basis_bits &basis_bits) { \
34        if (f##SUFFIX == nullptr) f##SUFFIX = new SUFFIX(); \
35        Struct_##SUFFIX output; \
36        f##SUFFIX->do_block(basis_bits, output); \
37        return output.cc; \
38    } \
39}
40
41CREATE_GENERAL_CODE_CATEGORY(Cc)
42CREATE_GENERAL_CODE_CATEGORY(Cf)
43CREATE_GENERAL_CODE_CATEGORY(Cn)
44CREATE_GENERAL_CODE_CATEGORY(Co)
45CREATE_GENERAL_CODE_CATEGORY(Cs)
46CREATE_GENERAL_CODE_CATEGORY(Ll)
47CREATE_GENERAL_CODE_CATEGORY(Lm)
48CREATE_GENERAL_CODE_CATEGORY(Lo)
49CREATE_GENERAL_CODE_CATEGORY(Lt)
50CREATE_GENERAL_CODE_CATEGORY(Lu)
51CREATE_GENERAL_CODE_CATEGORY(Mc)
52CREATE_GENERAL_CODE_CATEGORY(Me)
53CREATE_GENERAL_CODE_CATEGORY(Mn)
54CREATE_GENERAL_CODE_CATEGORY(Nd)
55CREATE_GENERAL_CODE_CATEGORY(Nl)
56CREATE_GENERAL_CODE_CATEGORY(No)
57CREATE_GENERAL_CODE_CATEGORY(Pc)
58CREATE_GENERAL_CODE_CATEGORY(Pd)
59CREATE_GENERAL_CODE_CATEGORY(Pe)
60CREATE_GENERAL_CODE_CATEGORY(Pf)
61CREATE_GENERAL_CODE_CATEGORY(Pi)
62CREATE_GENERAL_CODE_CATEGORY(Po)
63CREATE_GENERAL_CODE_CATEGORY(Ps)
64CREATE_GENERAL_CODE_CATEGORY(Sc)
65CREATE_GENERAL_CODE_CATEGORY(Sk)
66CREATE_GENERAL_CODE_CATEGORY(Sm)
67CREATE_GENERAL_CODE_CATEGORY(So)
68CREATE_GENERAL_CODE_CATEGORY(Zl)
69CREATE_GENERAL_CODE_CATEGORY(Zp)
70CREATE_GENERAL_CODE_CATEGORY(Zs)
71
72#undef CREATE_GENERAL_CODE_CATEGORY
73
74namespace pablo {
75
76PabloCompiler::PabloCompiler(const std::vector<Var*> & basisBits)
77: mBasisBits(basisBits)
78, mMod(new Module("icgrep", getGlobalContext()))
79, mBasicBlock(nullptr)
80, mExecutionEngine(nullptr)
81, mBitBlockType(VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64))
82, mBasisBitsInputPtr(nullptr)
83, mCarryQueueIdx(0)
84, mCarryQueuePtr(nullptr)
85, mNestingDepth(0)
86, mCarryQueueSize(0)
87, mZeroInitializer(ConstantAggregateZero::get(mBitBlockType))
88, mOneInitializer(ConstantVector::getAllOnesValue(mBitBlockType))
89, mFunctionType(nullptr)
90, mFunc_process_block(nullptr)
91, mBasisBitsAddr(nullptr)
92, mOutputAddrPtr(nullptr)
93{
94    //Create the jit execution engine.up
95    InitializeNativeTarget();
96    std::string ErrStr;
97    mExecutionEngine = EngineBuilder(mMod).setUseMCJIT(true).setErrorStr(&ErrStr).setOptLevel(CodeGenOpt::Level::None).create();
98    if (mExecutionEngine == nullptr) {
99        throw std::runtime_error("Could not create ExecutionEngine: " + ErrStr);
100    }
101
102    InitializeNativeTargetAsmPrinter();
103    InitializeNativeTargetAsmParser();
104
105    DefineTypes();
106    DeclareFunctions();
107}
108
109PabloCompiler::~PabloCompiler()
110{
111    delete mMod;
112    delete fPs;
113    delete fNl;
114    delete fNo;
115    delete fLo;
116    delete fLl;
117    delete fLm;
118    delete fNd;
119    delete fPc;
120    delete fLt;
121    delete fLu;
122    delete fPf;
123    delete fPd;
124    delete fPe;
125    delete fPi;
126    delete fPo;
127    delete fMe;
128    delete fMc;
129    delete fMn;
130    delete fSk;
131    delete fSo;
132    delete fSm;
133    delete fSc;
134    delete fZl;
135    delete fCo;
136    delete fCn;
137    delete fCc;
138    delete fCf;
139    delete fCs;
140    delete fZp;
141    delete fZs;
142
143}
144
145LLVM_Gen_RetVal PabloCompiler::compile(PabloBlock & pb)
146{
147    mCarryQueueSize = 0;
148    DeclareCallFunctions(pb.statements());
149    mCarryQueueVector.resize(mCarryQueueSize);
150
151    Function::arg_iterator args = mFunc_process_block->arg_begin();
152    mBasisBitsAddr = args++;
153    mBasisBitsAddr->setName("basis_bits");
154    mCarryQueuePtr = args++;
155    mCarryQueuePtr->setName("carry_q");
156    mOutputAddrPtr = args++;
157    mOutputAddrPtr->setName("output");
158
159    //Create the carry queue.
160    mCarryQueueIdx = 0;
161    mNestingDepth = 0;
162    mBasicBlock = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunc_process_block,0);
163
164    //The basis bits structure
165    for (unsigned i = 0; i != mBasisBits.size(); ++i) {
166        IRBuilder<> b(mBasicBlock);
167        Value* indices[] = {b.getInt64(0), b.getInt32(i)};
168        const String * const name = mBasisBits[i]->getName();
169        Value * gep = b.CreateGEP(mBasisBitsAddr, indices);
170        LoadInst * basisBit = b.CreateAlignedLoad(gep, BLOCK_SIZE/8, false, name->str());
171        mMarkerMap.insert(std::make_pair(name, basisBit));
172    }
173
174    //Generate the IR instructions for the function.
175    compileStatements(pb.statements());
176
177    assert (mCarryQueueIdx == mCarryQueueSize);
178    assert (mNestingDepth == 0);
179    //Terminate the block
180    ReturnInst::Create(mMod->getContext(), mBasicBlock);
181
182    //Un-comment this line in order to display the IR that has been generated by this module.
183    #ifdef DUMP_GENERATED_IR
184    mMod->dump();
185    #endif
186
187    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
188    #ifdef USE_LLVM_3_5
189    verifyModule(*mMod, &dbgs());
190    #endif
191    #ifdef USE_LLVM_3_4
192    verifyModule(*mMod, PrintMessageAction);
193    #endif
194
195    //Use the pass manager to run optimizations on the function.
196    FunctionPassManager fpm(mMod);
197
198#ifdef USE_LLVM_3_5
199    mMod->setDataLayout(mExecutionEngine->getDataLayout());
200    // Set up the optimizer pipeline.  Start with registering info about how the target lays out data structures.
201    fpm.add(new DataLayoutPass(mMod));
202#endif
203
204#ifdef USE_LLVM_3_4
205    fpm.add(new DataLayout(*mExecutionEngine->getDataLayout()));
206#endif
207
208    //fpm.add(createPromoteMemoryToRegisterPass()); //Transform to SSA form.
209    //fpm.add(createBasicAliasAnalysisPass());      //Provide basic AliasAnalysis support for GVN. (Global Value Numbering)
210    //fpm.add(createCFGSimplificationPass());       //Simplify the control flow graph.
211    fpm.add(createInstructionCombiningPass());    //Simple peephole optimizations and bit-twiddling.
212    fpm.add(createReassociatePass());             //Reassociate expressions.
213    fpm.add(createGVNPass());                     //Eliminate common subexpressions.
214
215    fpm.doInitialization();
216
217    fpm.run(*mFunc_process_block);
218
219#ifdef DUMP_OPTIMIZED_IR
220    mMod->dump();
221#endif
222    mExecutionEngine->finalizeObject();
223
224    LLVM_Gen_RetVal retVal;
225    //Return the required size of the carry queue and a pointer to the process_block function.
226    retVal.carry_q_size = mCarryQueueSize;
227    retVal.process_block_fptr = mExecutionEngine->getPointerToFunction(mFunc_process_block);
228
229    return retVal;
230}
231
232void PabloCompiler::DefineTypes()
233{
234    StructType * structBasisBits = mMod->getTypeByName("struct.Basis_bits");
235    if (structBasisBits == nullptr) {
236        structBasisBits = StructType::create(mMod->getContext(), "struct.Basis_bits");
237    }
238    std::vector<Type*>StructTy_struct_Basis_bits_fields;
239    for (int i = 0; i != mBasisBits.size(); i++)
240    {
241        StructTy_struct_Basis_bits_fields.push_back(mBitBlockType);
242    }
243    if (structBasisBits->isOpaque()) {
244        structBasisBits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
245    }
246    mBasisBitsInputPtr = PointerType::get(structBasisBits, 0);
247
248    std::vector<Type*>functionTypeArgs;
249    functionTypeArgs.push_back(mBasisBitsInputPtr);
250
251    //The carry q array.
252    //A pointer to the BitBlock vector.
253    functionTypeArgs.push_back(PointerType::get(mBitBlockType, 0));
254
255    //The output structure.
256    StructType * outputStruct = mMod->getTypeByName("struct.Output");
257    if (!outputStruct) {
258        outputStruct = StructType::create(mMod->getContext(), "struct.Output");
259    }
260    if (outputStruct->isOpaque()) {
261        std::vector<Type*>fields;
262        fields.push_back(mBitBlockType);
263        fields.push_back(mBitBlockType);
264        outputStruct->setBody(fields, /*isPacked=*/false);
265    }
266    PointerType* outputStructPtr = PointerType::get(outputStruct, 0);
267
268    //The &output parameter.
269    functionTypeArgs.push_back(outputStructPtr);
270
271    mFunctionType = FunctionType::get(
272     /*Result=*/Type::getVoidTy(mMod->getContext()),
273     /*Params=*/functionTypeArgs,
274     /*isVarArg=*/false);
275}
276
277void PabloCompiler::DeclareFunctions()
278{
279    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
280    //mFunc_print_register = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), mXi64Vect, NULL);
281    //mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mFunc_print_register), (void *)&wrapped_print_register);
282    // to call->  b.CreateCall(mFunc_print_register, unicode_category);
283
284#ifdef USE_UADD_OVERFLOW
285    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
286    std::vector<Type*>StructTy_0_fields;
287    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
288    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
289    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
290
291    std::vector<Type*>FuncTy_1_args;
292    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
293    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
294    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
295    FunctionType* FuncTy_1 = FunctionType::get(
296                                              /*Result=*/StructTy_0,
297                                              /*Params=*/FuncTy_1_args,
298                                              /*isVarArg=*/false);
299
300    mFunc_llvm_uadd_with_overflow = mMod->getFunction("llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE);
301    if (!mFunc_llvm_uadd_with_overflow) {
302        mFunc_llvm_uadd_with_overflow = Function::Create(
303          /*Type=*/ FuncTy_1,
304          /*Linkage=*/ GlobalValue::ExternalLinkage,
305          /*Name=*/ "llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE, mMod); // (external, no body)
306        mFunc_llvm_uadd_with_overflow->setCallingConv(CallingConv::C);
307    }
308    AttributeSet mFunc_llvm_uadd_with_overflow_PAL;
309    {
310        SmallVector<AttributeSet, 4> Attrs;
311        AttributeSet PAS;
312        {
313          AttrBuilder B;
314          B.addAttribute(Attribute::NoUnwind);
315          B.addAttribute(Attribute::ReadNone);
316          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
317        }
318
319        Attrs.push_back(PAS);
320        mFunc_llvm_uadd_with_overflow_PAL = AttributeSet::get(mMod->getContext(), Attrs);
321    }
322    mFunc_llvm_uadd_with_overflow->setAttributes(mFunc_llvm_uadd_with_overflow_PAL);
323#endif
324
325    //Starts on process_block
326    SmallVector<AttributeSet, 4> Attrs;
327    AttributeSet PAS;
328    {
329        AttrBuilder B;
330        B.addAttribute(Attribute::ReadOnly);
331        B.addAttribute(Attribute::NoCapture);
332        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
333    }
334    Attrs.push_back(PAS);
335    {
336        AttrBuilder B;
337        B.addAttribute(Attribute::NoCapture);
338        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
339    }
340    Attrs.push_back(PAS);
341    {
342        AttrBuilder B;
343        B.addAttribute(Attribute::NoCapture);
344        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
345    }
346    Attrs.push_back(PAS);
347    {
348        AttrBuilder B;
349        B.addAttribute(Attribute::NoUnwind);
350        B.addAttribute(Attribute::UWTable);
351        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
352    }
353    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
354
355    //Create the function that will be generated.
356    mFunc_process_block = mMod->getFunction("process_block");
357    if (!mFunc_process_block) {
358        mFunc_process_block = Function::Create(
359            /*Type=*/mFunctionType,
360            /*Linkage=*/GlobalValue::ExternalLinkage,
361            /*Name=*/"process_block", mMod);
362        mFunc_process_block->setCallingConv(CallingConv::C);
363    }
364    mFunc_process_block->setAttributes(AttrSet);
365}
366
367void PabloCompiler::DeclareCallFunctions(const StatementList & stmts) {
368    for (PabloAST * stmt : stmts) {
369        if (const Assign * assign = dyn_cast<Assign>(stmt)) {
370            DeclareCallFunctions(assign->getExpr());
371        }
372        if (const Next * next = dyn_cast<Next>(stmt)) {
373            DeclareCallFunctions(next->getExpr());
374        }
375        else if (If * ifStatement = dyn_cast<If>(stmt)) {
376            const auto preIfCarryCount = mCarryQueueSize;
377            DeclareCallFunctions(ifStatement->getCondition());
378            DeclareCallFunctions(ifStatement->getBody());
379            ifStatement->setInclusiveCarryCount(mCarryQueueSize - preIfCarryCount);
380        }
381        else if (While * whileStatement = dyn_cast<While>(stmt)) {
382            const auto preWhileCarryCount = mCarryQueueSize;
383            DeclareCallFunctions(whileStatement->getCondition());
384            DeclareCallFunctions(whileStatement->getBody());
385            whileStatement->setInclusiveCarryCount(mCarryQueueSize - preWhileCarryCount);
386        }
387    }
388}
389
390void PabloCompiler::DeclareCallFunctions(const PabloAST * expr)
391{
392    if (const Call * call = dyn_cast<const Call>(expr)) {
393        const String * const callee = call->getCallee();
394        assert (callee);
395        if (mCalleeMap.find(callee) == mCalleeMap.end()) {
396            void * callee_ptr = nullptr;
397            #define CHECK_GENERAL_CODE_CATEGORY(SUFFIX) \
398                if (callee->str() == #SUFFIX) { \
399                    callee_ptr = (void*)&__get_category_##SUFFIX; \
400                } else
401            CHECK_GENERAL_CODE_CATEGORY(Cc)
402            CHECK_GENERAL_CODE_CATEGORY(Cf)
403            CHECK_GENERAL_CODE_CATEGORY(Cn)
404            CHECK_GENERAL_CODE_CATEGORY(Co)
405            CHECK_GENERAL_CODE_CATEGORY(Cs)
406            CHECK_GENERAL_CODE_CATEGORY(Ll)
407            CHECK_GENERAL_CODE_CATEGORY(Lm)
408            CHECK_GENERAL_CODE_CATEGORY(Lo)
409            CHECK_GENERAL_CODE_CATEGORY(Lt)
410            CHECK_GENERAL_CODE_CATEGORY(Lu)
411            CHECK_GENERAL_CODE_CATEGORY(Mc)
412            CHECK_GENERAL_CODE_CATEGORY(Me)
413            CHECK_GENERAL_CODE_CATEGORY(Mn)
414            CHECK_GENERAL_CODE_CATEGORY(Nd)
415            CHECK_GENERAL_CODE_CATEGORY(Nl)
416            CHECK_GENERAL_CODE_CATEGORY(No)
417            CHECK_GENERAL_CODE_CATEGORY(Pc)
418            CHECK_GENERAL_CODE_CATEGORY(Pd)
419            CHECK_GENERAL_CODE_CATEGORY(Pe)
420            CHECK_GENERAL_CODE_CATEGORY(Pf)
421            CHECK_GENERAL_CODE_CATEGORY(Pi)
422            CHECK_GENERAL_CODE_CATEGORY(Po)
423            CHECK_GENERAL_CODE_CATEGORY(Ps)
424            CHECK_GENERAL_CODE_CATEGORY(Sc)
425            CHECK_GENERAL_CODE_CATEGORY(Sk)
426            CHECK_GENERAL_CODE_CATEGORY(Sm)
427            CHECK_GENERAL_CODE_CATEGORY(So)
428            CHECK_GENERAL_CODE_CATEGORY(Zl)
429            CHECK_GENERAL_CODE_CATEGORY(Zp)
430            CHECK_GENERAL_CODE_CATEGORY(Zs)
431            // OTHERWISE ...
432            throw std::runtime_error("Unknown unicode category \"" + callee->str() + "\"");
433            #undef CHECK_GENERAL_CODE_CATEGORY
434            Value * unicodeCategory = mMod->getOrInsertFunction("__get_category_" + callee->str(), mBitBlockType, mBasisBitsInputPtr, NULL);
435            if (unicodeCategory == nullptr) {
436                throw std::runtime_error("Could not create static method call for unicode category \"" + callee->str() + "\"");
437            }
438            mExecutionEngine->addGlobalMapping(cast<GlobalValue>(unicodeCategory), callee_ptr);
439            mCalleeMap.insert(std::make_pair(callee, unicodeCategory));
440        }
441    }
442    else if (const And * pablo_and = dyn_cast<const And>(expr))
443    {
444        DeclareCallFunctions(pablo_and->getExpr1());
445        DeclareCallFunctions(pablo_and->getExpr2());
446    }
447    else if (const Or * pablo_or = dyn_cast<const Or>(expr))
448    {
449        DeclareCallFunctions(pablo_or->getExpr1());
450        DeclareCallFunctions(pablo_or->getExpr2());
451    }
452    else if (const Sel * pablo_sel = dyn_cast<const Sel>(expr))
453    {
454        DeclareCallFunctions(pablo_sel->getCondition());
455        DeclareCallFunctions(pablo_sel->getTrueExpr());
456        DeclareCallFunctions(pablo_sel->getFalseExpr());
457    }
458    else if (const Not * pablo_not = dyn_cast<const Not>(expr))
459    {
460        DeclareCallFunctions(pablo_not->getExpr());
461    }
462    else if (const Advance * adv = dyn_cast<const Advance>(expr))
463    {
464        ++mCarryQueueSize;
465        DeclareCallFunctions(adv->getExpr());
466    }
467    else if (const MatchStar * mstar = dyn_cast<const MatchStar>(expr))
468    {
469        ++mCarryQueueSize;
470        DeclareCallFunctions(mstar->getMarker());
471        DeclareCallFunctions(mstar->getCharClass());
472    }
473    else if (const ScanThru * sthru = dyn_cast<const ScanThru>(expr))
474    {
475        ++mCarryQueueSize;
476        DeclareCallFunctions(sthru->getScanFrom());
477        DeclareCallFunctions(sthru->getScanThru());
478    }
479}
480
481Value * PabloCompiler::compileStatements(const StatementList & stmts) {
482    Value * retVal = nullptr;
483    for (PabloAST * statement : stmts) {
484        retVal = compileStatement(statement);
485    }
486    return retVal;
487}
488
489Value * PabloCompiler::compileStatement(const PabloAST * stmt)
490{
491    Value * retVal = nullptr;
492    if (const Assign * assign = dyn_cast<const Assign>(stmt))
493    {
494        Value* expr = compileExpression(assign->getExpr());
495        mMarkerMap[assign->getName()] = expr;
496        if (unlikely(assign->isOutputAssignment())) {
497            SetOutputValue(expr, assign->getOutputIndex());
498        }
499        retVal = expr;
500    }
501    if (const Next * next = dyn_cast<const Next>(stmt))
502    {
503        Value* expr = compileExpression(next->getExpr());
504        mMarkerMap[next->getName()] = expr;
505        retVal = expr;
506    }
507    else if (const If * ifstmt = dyn_cast<const If>(stmt))
508    {
509        BasicBlock * ifEntryBlock = mBasicBlock;
510        BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunc_process_block, 0);
511        BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunc_process_block, 0);
512
513        int if_start_idx = mCarryQueueIdx;
514
515        Value* if_test_value = compileExpression(ifstmt->getCondition());
516
517        /* Generate the statements into the if body block, and also determine the
518           final carry index.  */
519
520        IRBuilder<> bIfBody(ifBodyBlock);
521        mBasicBlock = ifBodyBlock;
522
523        ++mNestingDepth;
524
525        Value *  returnMarker = compileStatements(ifstmt->getBody());
526
527        int if_end_idx = mCarryQueueIdx;
528        if (if_start_idx < if_end_idx + 1) {
529            // Have at least two internal carries.   Accumulate and store.
530            int if_accum_idx = mCarryQueueIdx++;
531
532            Value* if_carry_accum_value = genCarryInLoad(if_start_idx);
533
534            for (int c = if_start_idx+1; c < if_end_idx; c++)
535            {
536                Value* carryq_value = genCarryInLoad(c);
537                if_carry_accum_value = bIfBody.CreateOr(carryq_value, if_carry_accum_value);
538            }
539            genCarryOutStore(if_carry_accum_value, if_accum_idx);
540
541        }
542        bIfBody.CreateBr(ifEndBlock);
543
544        IRBuilder<> b_entry(ifEntryBlock);
545        mBasicBlock = ifEntryBlock;
546        if (if_start_idx < if_end_idx) {
547            // Have at least one internal carry.
548            int if_accum_idx = mCarryQueueIdx - 1;
549            Value* last_if_pending_carries = genCarryInLoad(if_accum_idx);
550            if_test_value = b_entry.CreateOr(if_test_value, last_if_pending_carries);
551        }
552        b_entry.CreateCondBr(genBitBlockAny(if_test_value), ifEndBlock, ifBodyBlock);
553
554        mBasicBlock = ifEndBlock;
555        --mNestingDepth;
556
557        retVal = returnMarker;
558    }
559    else if (const While * whileStatement = dyn_cast<const While>(stmt))
560    {
561        const auto baseCarryQueueIdx = mCarryQueueIdx;
562        if (mNestingDepth == 0) {
563            for (auto i = 0; i != whileStatement->getInclusiveCarryCount(); ++i) {
564                genCarryInLoad(baseCarryQueueIdx + i);
565            }
566        }       
567
568        SmallVector<Next*, 4> nextNodes;
569        for (PabloAST * node : whileStatement->getBody()) {
570            if (isa<Next>(node)) {
571                nextNodes.push_back(cast<Next>(node));
572            }
573        }
574
575        // Compile the initial iteration statements; the calls to genCarryOutStore will update the
576        // mCarryQueueVector with the appropriate values. Although we're not actually entering a new basic
577        // block yet, increment the nesting depth so that any calls to genCarryInLoad or genCarryOutStore
578        // will refer to the previous value.
579        ++mNestingDepth;
580        compileStatements(whileStatement->getBody());
581        // Reset the carry queue index. Note: this ought to be changed in the future. Currently this assumes
582        // that compiling the while body twice will generate the equivalent IR. This is not necessarily true
583        // but works for now.
584        mCarryQueueIdx = baseCarryQueueIdx;
585
586        BasicBlock* whileCondBlock = BasicBlock::Create(mMod->getContext(), "while.cond", mFunc_process_block, 0);
587        BasicBlock* whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunc_process_block, 0);
588        BasicBlock* whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunc_process_block, 0);
589
590        // Note: compileStatements may update the mBasicBlock pointer if the body contains nested loops. It
591        // may not be same one that we entered the function with.
592        IRBuilder<> bEntry(mBasicBlock);
593        bEntry.CreateBr(whileCondBlock);
594
595        // CONDITION BLOCK
596        IRBuilder<> bCond(whileCondBlock);
597        // generate phi nodes for any carry propogating instruction
598        std::vector<PHINode*> phiNodes(whileStatement->getInclusiveCarryCount() + nextNodes.size());
599        unsigned index = 0;
600        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
601            PHINode * phi = bCond.CreatePHI(mBitBlockType, 2);
602            phi->addIncoming(mCarryQueueVector[baseCarryQueueIdx + index], mBasicBlock);
603            mCarryQueueVector[baseCarryQueueIdx + index] = mZeroInitializer; // (use phi for multi-carry mode.)
604            phiNodes[index] = phi;
605        }
606        // and for any Next nodes in the loop body
607        for (Next * n : nextNodes) {
608            PHINode * phi = bCond.CreatePHI(mBitBlockType, 2, n->getName()->str());
609            auto f = mMarkerMap.find(n->getName());
610            assert (f != mMarkerMap.end());
611            phi->addIncoming(f->second, mBasicBlock);
612            mMarkerMap[n->getName()] = phi;
613            phiNodes[index++] = phi;
614        }
615
616        mBasicBlock = whileCondBlock;
617        bCond.CreateCondBr(genBitBlockAny(compileExpression(whileStatement->getCondition())), whileEndBlock, whileBodyBlock);
618
619        // BODY BLOCK
620        mBasicBlock = whileBodyBlock;
621        retVal = compileStatements(whileStatement->getBody());
622        // update phi nodes for any carry propogating instruction
623        IRBuilder<> bWhileBody(mBasicBlock);
624        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
625            Value * carryOut = bWhileBody.CreateOr(phiNodes[index], mCarryQueueVector[baseCarryQueueIdx + index]);
626            PHINode * phi = phiNodes[index];
627            phi->addIncoming(carryOut, mBasicBlock);
628            mCarryQueueVector[baseCarryQueueIdx + index] = phi;
629        }
630        // and for any Next nodes in the loop body
631        for (Next * n : nextNodes) {
632            auto f = mMarkerMap.find(n->getName());
633            assert (f != mMarkerMap.end());
634            PHINode * phi = phiNodes[index++];
635            phi->addIncoming(f->second, mBasicBlock);
636            mMarkerMap[n->getName()] = phi;
637        }
638
639        bWhileBody.CreateBr(whileCondBlock);
640
641        // EXIT BLOCK
642        mBasicBlock = whileEndBlock;   
643        if (--mNestingDepth == 0) {
644            for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
645                genCarryOutStore(phiNodes[index], baseCarryQueueIdx + index);
646            }
647        }
648    }
649    return retVal;
650}
651
652Value * PabloCompiler::compileExpression(const PabloAST * expr)
653{
654    Value * retVal = nullptr;
655    IRBuilder<> b(mBasicBlock);
656    if (isa<Ones>(expr)) {
657        retVal = mOneInitializer;
658    }
659    else if (isa<Zeroes>(expr)) {
660        retVal = mZeroInitializer;
661    }
662    else if (const Call* call = dyn_cast<Call>(expr)) {
663        //Call the callee once and store the result in the marker map.
664        auto mi = mMarkerMap.find(call->getCallee());
665        if (mi == mMarkerMap.end()) {
666            auto ci = mCalleeMap.find(call->getCallee());
667            if (ci == mCalleeMap.end()) {
668                throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee()->str() + "\"");
669            }
670            mi = mMarkerMap.insert(std::make_pair(call->getCallee(), b.CreateCall(ci->second, mBasisBitsAddr))).first;
671        }
672        retVal = mi->second;
673    }
674    else if (const Var * var = dyn_cast<Var>(expr))
675    {
676        auto f = mMarkerMap.find(var->getName());
677        assert (f != mMarkerMap.end());
678        retVal = f->second;
679    }
680    else if (const And * pablo_and = dyn_cast<And>(expr))
681    {
682        retVal = b.CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
683    }
684    else if (const Or * pablo_or = dyn_cast<Or>(expr))
685    {
686        retVal = b.CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
687    }
688    else if (const Sel * sel = dyn_cast<Sel>(expr))
689    {
690        Value* ifMask = compileExpression(sel->getCondition());
691        Value* ifTrue = b.CreateAnd(ifMask, compileExpression(sel->getTrueExpr()));
692        Value* ifFalse = b.CreateAnd(genNot(ifMask), compileExpression(sel->getFalseExpr()));
693        retVal = b.CreateOr(ifTrue, ifFalse);
694    }
695    else if (const Not * pablo_not = dyn_cast<Not>(expr))
696    {
697        retVal = genNot(compileExpression(pablo_not->getExpr()));
698    }
699    else if (const Advance * adv = dyn_cast<Advance>(expr))
700    {
701        Value* strm_value = compileExpression(adv->getExpr());
702        int shift = adv->getAdvanceAmount();
703        retVal = genAdvanceWithCarry(strm_value, shift);
704    }
705    else if (const MatchStar * mstar = dyn_cast<MatchStar>(expr))
706    {
707        Value* marker = compileExpression(mstar->getMarker());
708        Value* cc = compileExpression(mstar->getCharClass());
709        Value* marker_and_cc = b.CreateAnd(marker, cc);
710        retVal = b.CreateOr(b.CreateXor(genAddWithCarry(marker_and_cc, cc), cc), marker, "matchstar");
711    }
712    else if (const ScanThru * sthru = dyn_cast<ScanThru>(expr))
713    {
714        Value* marker_expr = compileExpression(sthru->getScanFrom());
715        Value* cc_expr = compileExpression(sthru->getScanThru());
716        retVal = b.CreateAnd(genAddWithCarry(marker_expr, cc_expr), genNot(cc_expr), "scanthru");
717    }
718    return retVal;
719}
720
#ifdef USE_UADD_OVERFLOW
// Emits a call to the uadd-with-overflow-and-carry-in function at the end of the
// current basic block and unpacks its two results: the sum (element 0) and the
// overflow bit (element 1).
SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
    Value * const operands[] = { int128_e1, int128_e2, int1_cin };
    CallInst * const uaddCall = CallInst::Create(mFunc_llvm_uadd_with_overflow, operands, "uadd_overflow_res", mBasicBlock);
    uaddCall->setCallingConv(CallingConv::C);
    uaddCall->setTailCall(false);
    const AttributeSet noAttributes;
    uaddCall->setAttributes(noAttributes);

    const unsigned sumIndex[] = { 0 };
    const unsigned overflowIndex[] = { 1 };

    SumWithOverflowPack result;
    result.sum = ExtractValueInst::Create(uaddCall, sumIndex, "sum", mBasicBlock);
    result.obit = ExtractValueInst::Create(uaddCall, overflowIndex, "obit", mBasicBlock);
    return result;
}
#endif
746
747Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2) {
748    IRBuilder<> b(mBasicBlock);
749
750    //CarryQ - carry in.
751    const int carryIdx = mCarryQueueIdx++;
752    Value* carryq_value = genCarryInLoad(carryIdx);
753
754#ifdef USE_UADD_OVERFLOW
755    //use llvm.uadd.with.overflow.i128 or i256
756    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
757    CastInst* int128_e1 = new BitCastInst(e1, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e1_128", mBasicBlock);
758    CastInst* int128_e2 = new BitCastInst(e2, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e2_128", mBasicBlock);
759    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
760    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
761    SumWithOverflowPack sumpack0;
762    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
763    Value* obit = sumpack0.obit;
764    Value* sum = b.CreateBitCast(sumpack0.sum, mXi64Vect, "sum");
765    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
766    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mXi64Vect);
767    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
768    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
769#else
770    //calculate carry through logical ops
771    Value* carrygen = b.CreateAnd(e1, e2, "carrygen");
772    Value* carryprop = b.CreateOr(e1, e2, "carryprop");
773    Value* digitsum = b.CreateAdd(e1, e2, "digitsum");
774    Value* partial = b.CreateAdd(digitsum, carryq_value, "partial");
775    Value* digitcarry = b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(partial)));
776    Value* mid_carry_in = genShiftLeft64(b.CreateLShr(digitcarry, 63), "mid_carry_in");
777
778    Value* sum = b.CreateAdd(partial, mid_carry_in, "sum");
779    Value* carry_out = genShiftHighbitToLow(b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(sum))), "carry_out");
780#endif
781    genCarryOutStore(carry_out, carryIdx);
782    return sum;
783}
784
785Value* PabloCompiler::genCarryInLoad(const unsigned index) {   
786    assert (index < mCarryQueueVector.size());
787    if (mNestingDepth == 0) {
788        IRBuilder<> b(mBasicBlock);
789        mCarryQueueVector[index] = b.CreateAlignedLoad(b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
790    }
791    return mCarryQueueVector[index];
792}
793
794void PabloCompiler::genCarryOutStore(Value* carryOut, const unsigned index ) {
795    assert (carryOut);
796    assert (index < mCarryQueueVector.size());
797    if (mNestingDepth == 0) {       
798        IRBuilder<> b(mBasicBlock);
799        b.CreateAlignedStore(carryOut, b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
800    }
801    mCarryQueueVector[index] = carryOut;
802}
803
804inline Value* PabloCompiler::genBitBlockAny(Value* test) {
805    IRBuilder<> b(mBasicBlock);
806    Value* cast_marker_value_1 = b.CreateBitCast(test, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
807    return b.CreateICmpEQ(cast_marker_value_1, ConstantInt::get(IntegerType::get(mMod->getContext(), BLOCK_SIZE), 0));
808}
809
810Value* PabloCompiler::genShiftHighbitToLow(Value* e, const Twine &namehint) {
811    IRBuilder<> b(mBasicBlock);
812    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
813    return b.CreateBitCast(b.CreateLShr(i128_val, BLOCK_SIZE - 1, namehint), mBitBlockType);
814}
815
816Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
817    IRBuilder<> b(mBasicBlock);
818    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
819    return b.CreateBitCast(b.CreateShl(i128_val, 64, namehint), mBitBlockType);
820}
821
822inline Value* PabloCompiler::genNot(Value* expr) {
823    IRBuilder<> b(mBasicBlock);
824    return b.CreateXor(expr, mOneInitializer, "not");
825}
826
827Value* PabloCompiler::genAdvanceWithCarry(Value* strm_value, int shift_amount) {
828
829    IRBuilder<> b(mBasicBlock);
830#if (BLOCK_SIZE == 128)
831    const auto carryIdx = mCarryQueueIdx++;
832    if (shift_amount == 1) {
833        Value* carryq_value = genCarryInLoad(carryIdx);
834        Value* srli_1_value = b.CreateLShr(strm_value, 63);
835        Value* packed_shuffle;
836        Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
837        Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
838        packed_shuffle = b.CreateShuffleVector(carryq_value, srli_1_value, const_packed_1);
839
840        Constant* const_packed_2_elems[] = {b.getInt64(1), b.getInt64(1)};
841        Constant* const_packed_2 = ConstantVector::get(const_packed_2_elems);
842
843        Value* shl_value = b.CreateShl(strm_value, const_packed_2);
844        Value* result_value = b.CreateOr(shl_value, packed_shuffle, "advance");
845
846        Value* carry_out = genShiftHighbitToLow(strm_value, "carry_out");
847        //CarryQ - carry out:
848        genCarryOutStore(carry_out, carryIdx);
849           
850        return result_value;
851    }
852    else if (shift_amount < 64) {
853        // This is the preferred logic, but is too slow for the general case.   
854        // We need to speed up our custom LLVM for this code.
855        Value* carryq_longint = b.CreateBitCast(genCarryInLoad(carryIdx), IntegerType::get(mMod->getContext(), BLOCK_SIZE));
856        Value* strm_longint = b.CreateBitCast(strm_value, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
857        Value* adv_longint = b.CreateOr(b.CreateShl(strm_longint, shift_amount), carryq_longint, "advance");
858    Value* result_value = b.CreateBitCast(adv_longint, mBitBlockType);
859    Value* carry_out = b.CreateBitCast(b.CreateLShr(strm_longint, BLOCK_SIZE - shift_amount, "advance_out"), mBitBlockType);
860        //CarryQ - carry out:
861        genCarryOutStore(carry_out, carryIdx);
862           
863        return result_value;
864    }
865    else {//if (shift_amount >= 64) {
866        throw std::runtime_error("Shift amount >= 64 in Advance is currently unsupported.");
867    }
868#endif
869
870#if (BLOCK_SIZE == 256)
871    return genAddWithCarry(strm_value, strm_value);
872#endif
873
874}
875
876void PabloCompiler::SetOutputValue(Value * marker, const unsigned index) {
877    IRBuilder<> b(mBasicBlock);
878    if (marker->getType()->isPointerTy()) {
879        marker = b.CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
880    }
881    Value* indices[] = {b.getInt64(0), b.getInt32(index)};
882    Value* gep = b.CreateGEP(mOutputAddrPtr, indices);
883    b.CreateAlignedStore(marker, gep, BLOCK_SIZE/8, false);
884}
885
886}
Note: See TracBrowser for help on using the repository browser.