source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4268

Last change on this file since 4268 was 4268, checked in by nmedfort, 5 years ago

Generalized the writing of output variables by adding a 'flag' to the Assign nodes.

File size: 33.9 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7/*
8 *  Copyright (c) 2014 International Characters.
9 *  This software is licensed to the public under the Open Software License 3.0.
10 *  icgrep is a trademark of International Characters.
11 */
12
#include <pablo/pablo_compiler.h>
#include <pablo/codegenstate.h>
#include <pablo/printer_pablos.h>
#include <cc/cc_namemap.hpp>
#include <re/re_name.h>
#include <stdexcept>
#include <string>
#include <include/simd-lib/bitblock.hpp>
20
21//#define DUMP_GENERATED_IR
22//#define DUMP_OPTIMIZED_IR
23
24extern "C" {
25  void wrapped_print_register(BitBlock bit_block) {
26      print_register<BitBlock>("", bit_block);
27  }
28}
29
30#define CREATE_GENERAL_CODE_CATEGORY(SUFFIX) \
31SUFFIX * f##SUFFIX = nullptr; \
32extern "C" { \
33    BitBlock __get_category_##SUFFIX(Basis_bits &basis_bits) { \
34        if (f##SUFFIX == nullptr) f##SUFFIX = new SUFFIX(); \
35        Struct_##SUFFIX output; \
36        f##SUFFIX->do_block(basis_bits, output); \
37        return output.cc; \
38    } \
39}
40
41CREATE_GENERAL_CODE_CATEGORY(Cc)
42CREATE_GENERAL_CODE_CATEGORY(Cf)
43CREATE_GENERAL_CODE_CATEGORY(Cn)
44CREATE_GENERAL_CODE_CATEGORY(Co)
45CREATE_GENERAL_CODE_CATEGORY(Cs)
46CREATE_GENERAL_CODE_CATEGORY(Ll)
47CREATE_GENERAL_CODE_CATEGORY(Lm)
48CREATE_GENERAL_CODE_CATEGORY(Lo)
49CREATE_GENERAL_CODE_CATEGORY(Lt)
50CREATE_GENERAL_CODE_CATEGORY(Lu)
51CREATE_GENERAL_CODE_CATEGORY(Mc)
52CREATE_GENERAL_CODE_CATEGORY(Me)
53CREATE_GENERAL_CODE_CATEGORY(Mn)
54CREATE_GENERAL_CODE_CATEGORY(Nd)
55CREATE_GENERAL_CODE_CATEGORY(Nl)
56CREATE_GENERAL_CODE_CATEGORY(No)
57CREATE_GENERAL_CODE_CATEGORY(Pc)
58CREATE_GENERAL_CODE_CATEGORY(Pd)
59CREATE_GENERAL_CODE_CATEGORY(Pe)
60CREATE_GENERAL_CODE_CATEGORY(Pf)
61CREATE_GENERAL_CODE_CATEGORY(Pi)
62CREATE_GENERAL_CODE_CATEGORY(Po)
63CREATE_GENERAL_CODE_CATEGORY(Ps)
64CREATE_GENERAL_CODE_CATEGORY(Sc)
65CREATE_GENERAL_CODE_CATEGORY(Sk)
66CREATE_GENERAL_CODE_CATEGORY(Sm)
67CREATE_GENERAL_CODE_CATEGORY(So)
68CREATE_GENERAL_CODE_CATEGORY(Zl)
69CREATE_GENERAL_CODE_CATEGORY(Zp)
70CREATE_GENERAL_CODE_CATEGORY(Zs)
71
72#undef CREATE_GENERAL_CODE_CATEGORY
73
74namespace pablo {
75
76PabloCompiler::PabloCompiler(const BasisBitVars & basisBitVars, int bits)
77: mBits(bits)
78, mBasisBitVars(basisBitVars)
79, mMod(new Module("icgrep", getGlobalContext()))
80, mBasicBlock(nullptr)
81, mExecutionEngine(nullptr)
82, mXi64Vect(VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64))
83, mXi128Vect(VectorType::get(IntegerType::get(mMod->getContext(), 128), BLOCK_SIZE / 128))
84, mBasisBitsInputPtr(nullptr)
85, mCarryQueueIdx(0)
86, mCarryQueuePtr(nullptr)
87, mNestingDepth(0)
88, mCarryQueueSize(0)
89, mZeroInitializer(ConstantAggregateZero::get(mXi64Vect))
90, mOneInitializer(ConstantVector::getAllOnesValue(mXi64Vect))
91, mFunctionType(nullptr)
92, mFunc_process_block(nullptr)
93, mBasisBitsAddr(nullptr)
94, mOutputAddrPtr(nullptr)
95{
96    //Create the jit execution engine.up
97    InitializeNativeTarget();
98    std::string ErrStr;
99    mExecutionEngine = EngineBuilder(mMod).setUseMCJIT(true).setErrorStr(&ErrStr).setOptLevel(CodeGenOpt::Level::None).create();
100    if (mExecutionEngine == nullptr) {
101        throw std::runtime_error("Could not create ExecutionEngine: " + ErrStr);
102    }
103
104    InitializeNativeTargetAsmPrinter();
105    InitializeNativeTargetAsmParser();
106
107    DefineTypes();
108    DeclareFunctions();
109}
110
111PabloCompiler::~PabloCompiler()
112{
113    delete mMod;
114    delete fPs;
115    delete fNl;
116    delete fNo;
117    delete fLo;
118    delete fLl;
119    delete fLm;
120    delete fNd;
121    delete fPc;
122    delete fLt;
123    delete fLu;
124    delete fPf;
125    delete fPd;
126    delete fPe;
127    delete fPi;
128    delete fPo;
129    delete fMe;
130    delete fMc;
131    delete fMn;
132    delete fSk;
133    delete fSo;
134    delete fSm;
135    delete fSc;
136    delete fZl;
137    delete fCo;
138    delete fCn;
139    delete fCc;
140    delete fCf;
141    delete fCs;
142    delete fZp;
143    delete fZs;
144
145}
146
147LLVM_Gen_RetVal PabloCompiler::compile(PabloBlock & pb)
148{
149    mCarryQueueSize = 0;
150    DeclareCallFunctions(pb.statements());
151    mCarryQueueVector.resize(mCarryQueueSize);
152
153    Function::arg_iterator args = mFunc_process_block->arg_begin();
154    mBasisBitsAddr = args++;
155    mBasisBitsAddr->setName("basis_bits");
156    mCarryQueuePtr = args++;
157    mCarryQueuePtr->setName("carry_q");
158    mOutputAddrPtr = args++;
159    mOutputAddrPtr->setName("output");
160
161    //Create the carry queue.
162    mCarryQueueIdx = 0;
163    mNestingDepth = 0;
164    mBasicBlock = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunc_process_block,0);
165
166    //The basis bits structure
167    for (unsigned i = 0; i < mBits; ++i) {
168        IRBuilder<> b(mBasicBlock);
169        Value* indices[] = {b.getInt64(0), b.getInt32(i)};
170        const String * const name = mBasisBitVars[i]->getName();
171        Value * gep = b.CreateGEP(mBasisBitsAddr, indices);
172        LoadInst * basisBit = b.CreateAlignedLoad(gep, BLOCK_SIZE/8, false, name->str());
173        mMarkerMap.insert(std::make_pair(name, basisBit));
174    }
175
176    //Generate the IR instructions for the function.
177    compileStatements(pb.statements());
178
179    assert (mCarryQueueIdx == mCarryQueueSize);
180    assert (mNestingDepth == 0);
181    //Terminate the block
182    ReturnInst::Create(mMod->getContext(), mBasicBlock);
183
184    //Un-comment this line in order to display the IR that has been generated by this module.
185    #ifdef DUMP_GENERATED_IR
186    mMod->dump();
187    #endif
188
189    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
190    #ifdef USE_LLVM_3_5
191    verifyModule(*mMod, &dbgs());
192    #endif
193    #ifdef USE_LLVM_3_4
194    verifyModule(*mMod, PrintMessageAction);
195    #endif
196
197    //Use the pass manager to run optimizations on the function.
198    FunctionPassManager fpm(mMod);
199
200#ifdef USE_LLVM_3_5
201    mMod->setDataLayout(mExecutionEngine->getDataLayout());
202    // Set up the optimizer pipeline.  Start with registering info about how the target lays out data structures.
203    fpm.add(new DataLayoutPass(mMod));
204#endif
205
206#ifdef USE_LLVM_3_4
207    fpm.add(new DataLayout(*mExecutionEngine->getDataLayout()));
208#endif
209
210    //fpm.add(createPromoteMemoryToRegisterPass()); //Transform to SSA form.
211    //fpm.add(createBasicAliasAnalysisPass());      //Provide basic AliasAnalysis support for GVN. (Global Value Numbering)
212    //fpm.add(createCFGSimplificationPass());       //Simplify the control flow graph.
213    fpm.add(createInstructionCombiningPass());    //Simple peephole optimizations and bit-twiddling.
214    fpm.add(createReassociatePass());             //Reassociate expressions.
215    fpm.add(createGVNPass());                     //Eliminate common subexpressions.
216
217    fpm.doInitialization();
218
219    fpm.run(*mFunc_process_block);
220
221#ifdef DUMP_OPTIMIZED_IR
222    mMod->dump();
223#endif
224    mExecutionEngine->finalizeObject();
225
226    LLVM_Gen_RetVal retVal;
227    //Return the required size of the carry queue and a pointer to the process_block function.
228    retVal.carry_q_size = mCarryQueueSize;
229    retVal.process_block_fptr = mExecutionEngine->getPointerToFunction(mFunc_process_block);
230
231    return retVal;
232}
233
234void PabloCompiler::DefineTypes()
235{
236    StructType * structBasisBits = mMod->getTypeByName("struct.Basis_bits");
237    if (structBasisBits == nullptr) {
238        structBasisBits = StructType::create(mMod->getContext(), "struct.Basis_bits");
239    }
240    std::vector<Type*>StructTy_struct_Basis_bits_fields;
241    for (int i = 0; i < mBits; i++)
242    {
243        StructTy_struct_Basis_bits_fields.push_back(mXi64Vect);
244    }
245    if (structBasisBits->isOpaque()) {
246        structBasisBits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
247    }
248    mBasisBitsInputPtr = PointerType::get(structBasisBits, 0);
249
250    std::vector<Type*>functionTypeArgs;
251    functionTypeArgs.push_back(mBasisBitsInputPtr);
252
253    //The carry q array.
254    //A pointer to the BitBlock vector.
255    functionTypeArgs.push_back(PointerType::get(mXi64Vect, 0));
256
257    //The output structure.
258    StructType * outputStruct = mMod->getTypeByName("struct.Output");
259    if (!outputStruct) {
260        outputStruct = StructType::create(mMod->getContext(), "struct.Output");
261    }
262    if (outputStruct->isOpaque()) {
263        std::vector<Type*>fields;
264        fields.push_back(mXi64Vect);
265        fields.push_back(mXi64Vect);
266        outputStruct->setBody(fields, /*isPacked=*/false);
267    }
268    PointerType* outputStructPtr = PointerType::get(outputStruct, 0);
269
270    //The &output parameter.
271    functionTypeArgs.push_back(outputStructPtr);
272
273    mFunctionType = FunctionType::get(
274     /*Result=*/Type::getVoidTy(mMod->getContext()),
275     /*Params=*/functionTypeArgs,
276     /*isVarArg=*/false);
277}
278
279void PabloCompiler::DeclareFunctions()
280{
281    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
282    //mFunc_print_register = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), mXi64Vect, NULL);
283    //mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mFunc_print_register), (void *)&wrapped_print_register);
284    // to call->  b.CreateCall(mFunc_print_register, unicode_category);
285
286#ifdef USE_UADD_OVERFLOW
287    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
288    std::vector<Type*>StructTy_0_fields;
289    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
290    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
291    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
292
293    std::vector<Type*>FuncTy_1_args;
294    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
295    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
296    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
297    FunctionType* FuncTy_1 = FunctionType::get(
298                                              /*Result=*/StructTy_0,
299                                              /*Params=*/FuncTy_1_args,
300                                              /*isVarArg=*/false);
301
302    mFunc_llvm_uadd_with_overflow = mMod->getFunction("llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE);
303    if (!mFunc_llvm_uadd_with_overflow) {
304        mFunc_llvm_uadd_with_overflow = Function::Create(
305          /*Type=*/ FuncTy_1,
306          /*Linkage=*/ GlobalValue::ExternalLinkage,
307          /*Name=*/ "llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE, mMod); // (external, no body)
308        mFunc_llvm_uadd_with_overflow->setCallingConv(CallingConv::C);
309    }
310    AttributeSet mFunc_llvm_uadd_with_overflow_PAL;
311    {
312        SmallVector<AttributeSet, 4> Attrs;
313        AttributeSet PAS;
314        {
315          AttrBuilder B;
316          B.addAttribute(Attribute::NoUnwind);
317          B.addAttribute(Attribute::ReadNone);
318          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
319        }
320
321        Attrs.push_back(PAS);
322        mFunc_llvm_uadd_with_overflow_PAL = AttributeSet::get(mMod->getContext(), Attrs);
323    }
324    mFunc_llvm_uadd_with_overflow->setAttributes(mFunc_llvm_uadd_with_overflow_PAL);
325#endif
326
327    //Starts on process_block
328    SmallVector<AttributeSet, 4> Attrs;
329    AttributeSet PAS;
330    {
331        AttrBuilder B;
332        B.addAttribute(Attribute::ReadOnly);
333        B.addAttribute(Attribute::NoCapture);
334        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
335    }
336    Attrs.push_back(PAS);
337    {
338        AttrBuilder B;
339        B.addAttribute(Attribute::NoCapture);
340        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
341    }
342    Attrs.push_back(PAS);
343    {
344        AttrBuilder B;
345        B.addAttribute(Attribute::NoCapture);
346        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
347    }
348    Attrs.push_back(PAS);
349    {
350        AttrBuilder B;
351        B.addAttribute(Attribute::NoUnwind);
352        B.addAttribute(Attribute::UWTable);
353        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
354    }
355    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
356
357    //Create the function that will be generated.
358    mFunc_process_block = mMod->getFunction("process_block");
359    if (!mFunc_process_block) {
360        mFunc_process_block = Function::Create(
361            /*Type=*/mFunctionType,
362            /*Linkage=*/GlobalValue::ExternalLinkage,
363            /*Name=*/"process_block", mMod);
364        mFunc_process_block->setCallingConv(CallingConv::C);
365    }
366    mFunc_process_block->setAttributes(AttrSet);
367}
368
369void PabloCompiler::DeclareCallFunctions(const StatementList & stmts) {
370    for (PabloAST * stmt : stmts) {
371        if (const Assign * assign = dyn_cast<Assign>(stmt)) {
372            DeclareCallFunctions(assign->getExpr());
373        }
374        if (const Next * next = dyn_cast<Next>(stmt)) {
375            DeclareCallFunctions(next->getExpr());
376        }
377        else if (If * ifStatement = dyn_cast<If>(stmt)) {
378            const auto preIfCarryCount = mCarryQueueSize;
379            DeclareCallFunctions(ifStatement->getCondition());
380            DeclareCallFunctions(ifStatement->getBody());
381            ifStatement->setInclusiveCarryCount(mCarryQueueSize - preIfCarryCount);
382        }
383        else if (While * whileStatement = dyn_cast<While>(stmt)) {
384            const auto preWhileCarryCount = mCarryQueueSize;
385            DeclareCallFunctions(whileStatement->getCondition());
386            DeclareCallFunctions(whileStatement->getBody());
387            whileStatement->setInclusiveCarryCount(mCarryQueueSize - preWhileCarryCount);
388        }
389    }
390}
391
392void PabloCompiler::DeclareCallFunctions(const PabloAST * expr)
393{
394    if (const Call * call = dyn_cast<const Call>(expr)) {
395        const String * const callee = call->getCallee();
396        assert (callee);
397        if (mCalleeMap.find(callee) == mCalleeMap.end()) {
398            void * callee_ptr = nullptr;
399            #define CHECK_GENERAL_CODE_CATEGORY(SUFFIX) \
400                if (callee->str() == #SUFFIX) { \
401                    callee_ptr = (void*)&__get_category_##SUFFIX; \
402                } else
403            CHECK_GENERAL_CODE_CATEGORY(Cc)
404            CHECK_GENERAL_CODE_CATEGORY(Cf)
405            CHECK_GENERAL_CODE_CATEGORY(Cn)
406            CHECK_GENERAL_CODE_CATEGORY(Co)
407            CHECK_GENERAL_CODE_CATEGORY(Cs)
408            CHECK_GENERAL_CODE_CATEGORY(Ll)
409            CHECK_GENERAL_CODE_CATEGORY(Lm)
410            CHECK_GENERAL_CODE_CATEGORY(Lo)
411            CHECK_GENERAL_CODE_CATEGORY(Lt)
412            CHECK_GENERAL_CODE_CATEGORY(Lu)
413            CHECK_GENERAL_CODE_CATEGORY(Mc)
414            CHECK_GENERAL_CODE_CATEGORY(Me)
415            CHECK_GENERAL_CODE_CATEGORY(Mn)
416            CHECK_GENERAL_CODE_CATEGORY(Nd)
417            CHECK_GENERAL_CODE_CATEGORY(Nl)
418            CHECK_GENERAL_CODE_CATEGORY(No)
419            CHECK_GENERAL_CODE_CATEGORY(Pc)
420            CHECK_GENERAL_CODE_CATEGORY(Pd)
421            CHECK_GENERAL_CODE_CATEGORY(Pe)
422            CHECK_GENERAL_CODE_CATEGORY(Pf)
423            CHECK_GENERAL_CODE_CATEGORY(Pi)
424            CHECK_GENERAL_CODE_CATEGORY(Po)
425            CHECK_GENERAL_CODE_CATEGORY(Ps)
426            CHECK_GENERAL_CODE_CATEGORY(Sc)
427            CHECK_GENERAL_CODE_CATEGORY(Sk)
428            CHECK_GENERAL_CODE_CATEGORY(Sm)
429            CHECK_GENERAL_CODE_CATEGORY(So)
430            CHECK_GENERAL_CODE_CATEGORY(Zl)
431            CHECK_GENERAL_CODE_CATEGORY(Zp)
432            CHECK_GENERAL_CODE_CATEGORY(Zs)
433            // OTHERWISE ...
434            throw std::runtime_error("Unknown unicode category \"" + callee->str() + "\"");
435            #undef CHECK_GENERAL_CODE_CATEGORY
436            Value * unicodeCategory = mMod->getOrInsertFunction("__get_category_" + callee->str(), mXi64Vect, mBasisBitsInputPtr, NULL);
437            if (unicodeCategory == nullptr) {
438                throw std::runtime_error("Could not create static method call for unicode category \"" + callee->str() + "\"");
439            }
440            mExecutionEngine->addGlobalMapping(cast<GlobalValue>(unicodeCategory), callee_ptr);
441            mCalleeMap.insert(std::make_pair(callee, unicodeCategory));
442        }
443    }
444    else if (const And * pablo_and = dyn_cast<const And>(expr))
445    {
446        DeclareCallFunctions(pablo_and->getExpr1());
447        DeclareCallFunctions(pablo_and->getExpr2());
448    }
449    else if (const Or * pablo_or = dyn_cast<const Or>(expr))
450    {
451        DeclareCallFunctions(pablo_or->getExpr1());
452        DeclareCallFunctions(pablo_or->getExpr2());
453    }
454    else if (const Sel * pablo_sel = dyn_cast<const Sel>(expr))
455    {
456        DeclareCallFunctions(pablo_sel->getCondition());
457        DeclareCallFunctions(pablo_sel->getTrueExpr());
458        DeclareCallFunctions(pablo_sel->getFalseExpr());
459    }
460    else if (const Not * pablo_not = dyn_cast<const Not>(expr))
461    {
462        DeclareCallFunctions(pablo_not->getExpr());
463    }
464    else if (const Advance * adv = dyn_cast<const Advance>(expr))
465    {
466        ++mCarryQueueSize;
467        DeclareCallFunctions(adv->getExpr());
468    }
469    else if (const MatchStar * mstar = dyn_cast<const MatchStar>(expr))
470    {
471        ++mCarryQueueSize;
472        DeclareCallFunctions(mstar->getMarker());
473        DeclareCallFunctions(mstar->getCharClass());
474    }
475    else if (const ScanThru * sthru = dyn_cast<const ScanThru>(expr))
476    {
477        ++mCarryQueueSize;
478        DeclareCallFunctions(sthru->getScanFrom());
479        DeclareCallFunctions(sthru->getScanThru());
480    }
481}
482
483Value * PabloCompiler::compileStatements(const StatementList & stmts) {
484    Value * retVal = nullptr;
485    for (PabloAST * statement : stmts) {
486        retVal = compileStatement(statement);
487    }
488    return retVal;
489}
490
491Value * PabloCompiler::compileStatement(const PabloAST * stmt)
492{
493    Value * retVal = nullptr;
494    if (const Assign * assign = dyn_cast<const Assign>(stmt))
495    {
496        Value* expr = compileExpression(assign->getExpr());
497        mMarkerMap[assign->getName()] = expr;
498        if (unlikely(assign->isOutputAssignment())) {
499            SetOutputValue(expr, assign->getOutputIndex());
500        }
501        retVal = expr;
502    }
503    if (const Next * next = dyn_cast<const Next>(stmt))
504    {
505        Value* expr = compileExpression(next->getExpr());
506        mMarkerMap[next->getName()] = expr;
507        retVal = expr;
508    }
509    else if (const If * ifstmt = dyn_cast<const If>(stmt))
510    {
511        BasicBlock * ifEntryBlock = mBasicBlock;
512        BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunc_process_block, 0);
513        BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunc_process_block, 0);
514
515        int if_start_idx = mCarryQueueIdx;
516
517        Value* if_test_value = compileExpression(ifstmt->getCondition());
518
519        /* Generate the statements into the if body block, and also determine the
520           final carry index.  */
521
522        IRBuilder<> bIfBody(ifBodyBlock);
523        mBasicBlock = ifBodyBlock;
524
525        ++mNestingDepth;
526
527        Value *  returnMarker = compileStatements(ifstmt->getBody());
528
529        int if_end_idx = mCarryQueueIdx;
530        if (if_start_idx < if_end_idx + 1) {
531            // Have at least two internal carries.   Accumulate and store.
532            int if_accum_idx = mCarryQueueIdx++;
533
534            Value* if_carry_accum_value = genCarryInLoad(if_start_idx);
535
536            for (int c = if_start_idx+1; c < if_end_idx; c++)
537            {
538                Value* carryq_value = genCarryInLoad(c);
539                if_carry_accum_value = bIfBody.CreateOr(carryq_value, if_carry_accum_value);
540            }
541            genCarryOutStore(if_carry_accum_value, if_accum_idx);
542
543        }
544        bIfBody.CreateBr(ifEndBlock);
545
546        IRBuilder<> b_entry(ifEntryBlock);
547        mBasicBlock = ifEntryBlock;
548        if (if_start_idx < if_end_idx) {
549            // Have at least one internal carry.
550            int if_accum_idx = mCarryQueueIdx - 1;
551            Value* last_if_pending_carries = genCarryInLoad(if_accum_idx);
552            if_test_value = b_entry.CreateOr(if_test_value, last_if_pending_carries);
553        }
554        b_entry.CreateCondBr(genBitBlockAny(if_test_value), ifEndBlock, ifBodyBlock);
555
556        mBasicBlock = ifEndBlock;
557        --mNestingDepth;
558
559        retVal = returnMarker;
560    }
561    else if (const While * whileStatement = dyn_cast<const While>(stmt))
562    {
563        const auto baseCarryQueueIdx = mCarryQueueIdx;
564        if (mNestingDepth == 0) {
565            for (auto i = 0; i != whileStatement->getInclusiveCarryCount(); ++i) {
566                genCarryInLoad(baseCarryQueueIdx + i);
567            }
568        }       
569
570        SmallVector<Next*, 4> nextNodes;
571        for (PabloAST * node : whileStatement->getBody()) {
572            if (isa<Next>(node)) {
573                nextNodes.push_back(cast<Next>(node));
574            }
575        }
576
577        // Compile the initial iteration statements; the calls to genCarryOutStore will update the
578        // mCarryQueueVector with the appropriate values. Although we're not actually entering a new basic
579        // block yet, increment the nesting depth so that any calls to genCarryInLoad or genCarryOutStore
580        // will refer to the previous value.
581        ++mNestingDepth;
582        compileStatements(whileStatement->getBody());
583        // Reset the carry queue index. Note: this ought to be changed in the future. Currently this assumes
584        // that compiling the while body twice will generate the equivalent IR. This is not necessarily true
585        // but works for now.
586        mCarryQueueIdx = baseCarryQueueIdx;
587
588        BasicBlock* whileCondBlock = BasicBlock::Create(mMod->getContext(), "while.cond", mFunc_process_block, 0);
589        BasicBlock* whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunc_process_block, 0);
590        BasicBlock* whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunc_process_block, 0);
591
592        // Note: compileStatements may update the mBasicBlock pointer if the body contains nested loops. It
593        // may not be same one that we entered the function with.
594        IRBuilder<> bEntry(mBasicBlock);
595        bEntry.CreateBr(whileCondBlock);
596
597        // CONDITION BLOCK
598        IRBuilder<> bCond(whileCondBlock);
599        // generate phi nodes for any carry propogating instruction
600        std::vector<PHINode*> phiNodes(whileStatement->getInclusiveCarryCount() + nextNodes.size());
601        unsigned index = 0;
602        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
603            PHINode * phi = bCond.CreatePHI(mXi64Vect, 2);
604            phi->addIncoming(mCarryQueueVector[baseCarryQueueIdx + index], mBasicBlock);
605            mCarryQueueVector[baseCarryQueueIdx + index] = mZeroInitializer; // (use phi for multi-carry mode.)
606            phiNodes[index] = phi;
607        }
608        // and for any Next nodes in the loop body
609        for (Next * n : nextNodes) {
610            PHINode * phi = bCond.CreatePHI(mXi64Vect, 2, n->getName()->str());
611            auto f = mMarkerMap.find(n->getName());
612            assert (f != mMarkerMap.end());
613            phi->addIncoming(f->second, mBasicBlock);
614            mMarkerMap[n->getName()] = phi;
615            phiNodes[index++] = phi;
616        }
617
618        mBasicBlock = whileCondBlock;
619        bCond.CreateCondBr(genBitBlockAny(compileExpression(whileStatement->getCondition())), whileEndBlock, whileBodyBlock);
620
621        // BODY BLOCK
622        mBasicBlock = whileBodyBlock;
623        retVal = compileStatements(whileStatement->getBody());
624        // update phi nodes for any carry propogating instruction
625        IRBuilder<> bWhileBody(mBasicBlock);
626        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
627            Value * carryOut = bWhileBody.CreateOr(phiNodes[index], mCarryQueueVector[baseCarryQueueIdx + index]);
628            PHINode * phi = phiNodes[index];
629            phi->addIncoming(carryOut, mBasicBlock);
630            mCarryQueueVector[baseCarryQueueIdx + index] = phi;
631        }
632        // and for any Next nodes in the loop body
633        for (Next * n : nextNodes) {
634            auto f = mMarkerMap.find(n->getName());
635            assert (f != mMarkerMap.end());
636            PHINode * phi = phiNodes[index++];
637            phi->addIncoming(f->second, mBasicBlock);
638            mMarkerMap[n->getName()] = phi;
639        }
640
641        bWhileBody.CreateBr(whileCondBlock);
642
643        // EXIT BLOCK
644        mBasicBlock = whileEndBlock;   
645        if (--mNestingDepth == 0) {
646            for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
647                genCarryOutStore(phiNodes[index], baseCarryQueueIdx + index);
648            }
649        }
650    }
651    return retVal;
652}
653
654Value * PabloCompiler::compileExpression(const PabloAST * expr)
655{
656    Value * retVal = nullptr;
657    IRBuilder<> b(mBasicBlock);
658    if (isa<Ones>(expr)) {
659        retVal = mOneInitializer;
660    }
661    else if (isa<Zeroes>(expr)) {
662        retVal = mZeroInitializer;
663    }
664    else if (const Call* call = dyn_cast<Call>(expr)) {
665        //Call the callee once and store the result in the marker map.
666        auto mi = mMarkerMap.find(call->getCallee());
667        if (mi == mMarkerMap.end()) {
668            auto ci = mCalleeMap.find(call->getCallee());
669            if (ci == mCalleeMap.end()) {
670                throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee()->str() + "\"");
671            }
672            mi = mMarkerMap.insert(std::make_pair(call->getCallee(), b.CreateCall(ci->second, mBasisBitsAddr))).first;
673        }
674        retVal = mi->second;
675    }
676    else if (const Var * var = dyn_cast<Var>(expr))
677    {       
678        auto f = mMarkerMap.find(var->getName());
679        assert (f != mMarkerMap.end());
680        retVal = f->second;
681    }
682    else if (const And * pablo_and = dyn_cast<And>(expr))
683    {
684        retVal = b.CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
685    }
686    else if (const Or * pablo_or = dyn_cast<Or>(expr))
687    {
688        retVal = b.CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
689    }
690    else if (const Sel * sel = dyn_cast<Sel>(expr))
691    {
692        Value* ifMask = compileExpression(sel->getCondition());
693        Value* ifTrue = b.CreateAnd(ifMask, compileExpression(sel->getTrueExpr()));
694        Value* ifFalse = b.CreateAnd(genNot(ifMask), compileExpression(sel->getFalseExpr()));
695        retVal = b.CreateOr(ifTrue, ifFalse);
696    }
697    else if (const Not * pablo_not = dyn_cast<Not>(expr))
698    {
699        retVal = genNot(compileExpression(pablo_not->getExpr()));
700    }
701    else if (const Advance * adv = dyn_cast<Advance>(expr))
702    {
703        Value* strm_value = compileExpression(adv->getExpr());
704                int shift = adv->getAdvanceAmount();
705        retVal = genAdvanceWithCarry(strm_value, shift);
706    }
707    else if (const MatchStar * mstar = dyn_cast<MatchStar>(expr))
708    {
709        Value* marker = compileExpression(mstar->getMarker());
710        Value* cc = compileExpression(mstar->getCharClass());
711        Value* marker_and_cc = b.CreateAnd(marker, cc);
712        retVal = b.CreateOr(b.CreateXor(genAddWithCarry(marker_and_cc, cc), cc), marker, "matchstar");
713    }
714    else if (const ScanThru * sthru = dyn_cast<ScanThru>(expr))
715    {
716        Value* marker_expr = compileExpression(sthru->getScanFrom());
717        Value* cc_expr = compileExpression(sthru->getScanThru());
718        retVal = b.CreateAnd(genAddWithCarry(marker_expr, cc_expr), genNot(cc_expr), "scanthru");
719    }
720    return retVal;
721}
722
#ifdef USE_UADD_OVERFLOW
// Emits a call to the uadd-with-overflow-and-carry-in function declared in
// DeclareFunctions and unpacks its struct result: field 0 becomes `sum` and
// field 1 becomes `obit` of the returned pack.
SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
    std::vector<Value*> operands;
    operands.reserve(3);
    operands.push_back(int128_e1);
    operands.push_back(int128_e2);
    operands.push_back(int1_cin);

    CallInst* addCall = CallInst::Create(mFunc_llvm_uadd_with_overflow, operands, "uadd_overflow_res", mBasicBlock);
    addCall->setCallingConv(CallingConv::C);
    addCall->setTailCall(false);
    addCall->setAttributes(AttributeSet());

    const std::vector<unsigned> sumField(1, 0);
    const std::vector<unsigned> overflowField(1, 1);

    SumWithOverflowPack unpacked;
    unpacked.sum = ExtractValueInst::Create(addCall, sumField, "sum", mBasicBlock);
    unpacked.obit = ExtractValueInst::Create(addCall, overflowField, "obit", mBasicBlock);
    return unpacked;
}
#endif
748
749Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2) {
750    IRBuilder<> b(mBasicBlock);
751
752    //CarryQ - carry in.
753    const int carryIdx = mCarryQueueIdx++;
754    Value* carryq_value = genCarryInLoad(carryIdx);
755
756#ifdef USE_UADD_OVERFLOW
757    //use llvm.uadd.with.overflow.i128 or i256
758    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
759    CastInst* int128_e1 = new BitCastInst(e1, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e1_128", mBasicBlock);
760    CastInst* int128_e2 = new BitCastInst(e2, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e2_128", mBasicBlock);
761    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
762    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
763    SumWithOverflowPack sumpack0;
764    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
765    Value* obit = sumpack0.obit;
766    Value* sum = b.CreateBitCast(sumpack0.sum, mXi64Vect, "sum");
767    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
768    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mXi64Vect);
769    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
770    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
771#else
772    //calculate carry through logical ops
773    Value* carrygen = b.CreateAnd(e1, e2, "carrygen");
774    Value* carryprop = b.CreateOr(e1, e2, "carryprop");
775    Value* digitsum = b.CreateAdd(e1, e2, "digitsum");
776    Value* partial = b.CreateAdd(digitsum, carryq_value, "partial");
777    Value* digitcarry = b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(partial)));
778    Value* mid_carry_in = genShiftLeft64(b.CreateLShr(digitcarry, 63), "mid_carry_in");
779
780    Value* sum = b.CreateAdd(partial, mid_carry_in, "sum");
781    Value* carry_out = genShiftHighbitToLow(b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(sum))), "carry_out");
782#endif
783    genCarryOutStore(carry_out, carryIdx);
784    return sum;
785}
786
787Value* PabloCompiler::genCarryInLoad(const unsigned index) {   
788    assert (index < mCarryQueueVector.size());
789    if (mNestingDepth == 0) {
790        IRBuilder<> b(mBasicBlock);
791        mCarryQueueVector[index] = b.CreateAlignedLoad(b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
792    }
793    return mCarryQueueVector[index];
794}
795
796void PabloCompiler::genCarryOutStore(Value* carryOut, const unsigned index ) {
797    assert (carryOut);
798    assert (index < mCarryQueueVector.size());
799    if (mNestingDepth == 0) {       
800        IRBuilder<> b(mBasicBlock);
801        b.CreateAlignedStore(carryOut, b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
802    }
803    mCarryQueueVector[index] = carryOut;
804}
805
806inline Value* PabloCompiler::genBitBlockAny(Value* test) {
807    IRBuilder<> b(mBasicBlock);
808    Value* cast_marker_value_1 = b.CreateBitCast(test, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
809    return b.CreateICmpEQ(cast_marker_value_1, ConstantInt::get(IntegerType::get(mMod->getContext(), BLOCK_SIZE), 0));
810}
811
812Value* PabloCompiler::genShiftHighbitToLow(Value* e, const Twine &namehint) {
813    IRBuilder<> b(mBasicBlock);
814    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
815    return b.CreateBitCast(b.CreateLShr(i128_val, BLOCK_SIZE - 1, namehint), mXi64Vect);
816}
817
818Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
819    IRBuilder<> b(mBasicBlock);
820    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
821    return b.CreateBitCast(b.CreateShl(i128_val, 64, namehint), mXi64Vect);
822}
823
824inline Value* PabloCompiler::genNot(Value* expr) {
825    IRBuilder<> b(mBasicBlock);
826    return b.CreateXor(expr, mOneInitializer, "not");
827}
828
829Value* PabloCompiler::genAdvanceWithCarry(Value* strm_value, int shift_amount) {
830
831    IRBuilder<> b(mBasicBlock);
832#if (BLOCK_SIZE == 128)
833    const auto carryIdx = mCarryQueueIdx++;
834    if (shift_amount == 1) {
835        Value* carryq_value = genCarryInLoad(carryIdx);
836        Value* srli_1_value = b.CreateLShr(strm_value, 63);
837        Value* packed_shuffle;
838        Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
839        Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
840        packed_shuffle = b.CreateShuffleVector(carryq_value, srli_1_value, const_packed_1);
841
842        Constant* const_packed_2_elems[] = {b.getInt64(1), b.getInt64(1)};
843        Constant* const_packed_2 = ConstantVector::get(const_packed_2_elems);
844
845        Value* shl_value = b.CreateShl(strm_value, const_packed_2);
846        Value* result_value = b.CreateOr(shl_value, packed_shuffle, "advance");
847
848        Value* carry_out = genShiftHighbitToLow(strm_value, "carry_out");
849        //CarryQ - carry out:
850        genCarryOutStore(carry_out, carryIdx);
851           
852        return result_value;
853    }
854    else if (shift_amount < 64) {
855        // This is the preferred logic, but is too slow for the general case.   
856        // We need to speed up our custom LLVM for this code.
857        Value* carryq_longint = b.CreateBitCast(genCarryInLoad(carryIdx), IntegerType::get(mMod->getContext(), BLOCK_SIZE));
858        Value* strm_longint = b.CreateBitCast(strm_value, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
859        Value* adv_longint = b.CreateOr(b.CreateShl(strm_longint, shift_amount), carryq_longint, "advance");
860        Value* result_value = b.CreateBitCast(adv_longint, mXi64Vect);
861        Value* carry_out = b.CreateBitCast(b.CreateLShr(strm_longint, BLOCK_SIZE - shift_amount, "advance_out"), mXi64Vect);
862        //CarryQ - carry out:
863        genCarryOutStore(carry_out, carryIdx);
864           
865        return result_value;
866    }
867    else {//if (shift_amount >= 64) {
868        throw std::runtime_error("Shift amount >= 64 in Advance is currently unsupported.");
869    }
870#endif
871
872#if (BLOCK_SIZE == 256)
873    return genAddWithCarry(strm_value, strm_value);
874#endif
875
876}
877
878void PabloCompiler::SetOutputValue(Value * marker, const unsigned index) {
879    IRBuilder<> b(mBasicBlock);
880    if (marker->getType()->isPointerTy()) {
881        marker = b.CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
882    }
883    Value* indices[] = {b.getInt64(0), b.getInt32(index)};
884    Value* gep = b.CreateGEP(mOutputAddrPtr, indices);
885    b.CreateAlignedStore(marker, gep, BLOCK_SIZE/8, false);
886}
887
888}
Note: See TracBrowser for help on using the repository browser.