source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4275

Last change on this file since 4275 was 4275, checked in by linmengl, 5 years ago

USE_UADD_OVERFLOW flag is now working

File size: 36.5 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7/*
8 *  Copyright (c) 2014 International Characters.
9 *  This software is licensed to the public under the Open Software License 3.0.
10 *  icgrep is a trademark of International Characters.
11 */
12
13#include <pablo/pablo_compiler.h>
14#include <pablo/codegenstate.h>
15#include <pablo/printer_pablos.h>
16#include <cc/cc_namemap.hpp>
17#include <re/re_name.h>
18#include <stdexcept>
19#include <include/simd-lib/bitblock.hpp>
20
21#ifdef USE_LLVM_3_4
22#include <llvm/Analysis/Verifier.h>
23#include <llvm/Assembly/PrintModulePass.h>
24#include <llvm/Linker.h>
25#endif
26#ifdef USE_LLVM_3_5
27#include <llvm/IR/Verifier.h>
28#endif
29
30#include <llvm/Pass.h>
31#include <llvm/PassManager.h>
32#include <llvm/ADT/SmallVector.h>
33#include <llvm/Analysis/Passes.h>
34#include <llvm/IR/BasicBlock.h>
35#include <llvm/IR/CallingConv.h>
36#include <llvm/IR/Constants.h>
37#include <llvm/IR/DataLayout.h>
38#include <llvm/IR/DerivedTypes.h>
39#include <llvm/IR/Function.h>
40#include <llvm/IR/GlobalVariable.h>
41#include <llvm/IR/InlineAsm.h>
42#include <llvm/IR/Instructions.h>
43#include <llvm/IR/LLVMContext.h>
44#include <llvm/IR/Module.h>
45#include <llvm/Support/FormattedStream.h>
46#include <llvm/Support/MathExtras.h>
47#include <llvm/Support/Casting.h>
48#include <llvm/Support/Debug.h>
49#include <llvm/Support/TargetSelect.h>
50#include <llvm/Support/Host.h>
51#include <llvm/Transforms/Scalar.h>
52#include <llvm/ExecutionEngine/ExecutionEngine.h>
53#include <llvm/ExecutionEngine/MCJIT.h>
54#include <llvm/IRReader/IRReader.h>
55#include <llvm/Bitcode/ReaderWriter.h>
56#include <llvm/Support/MemoryBuffer.h>
57
58#include <llvm/IR/IRBuilder.h>
59
60//#define DUMP_GENERATED_IR
61//#define DUMP_OPTIMIZED_IR
62
63extern "C" {
64  void wrapped_print_register(BitBlock bit_block) {
65      print_register<BitBlock>("", bit_block);
66  }
67}
68
// For each Unicode general category SUFFIX this macro defines:
//   * a global pointer f<SUFFIX> to a SUFFIX object, created lazily on the
//     first call of the function below (PabloCompiler::~PabloCompiler deletes it);
//   * an extern "C" function __get_category_<SUFFIX>(Basis_bits &) that runs
//     SUFFIX::do_block on the given basis bits and returns the `cc` bit block
//     of the output struct. These are the functions that Pablo Call nodes
//     invoke from JIT-compiled code.
// NOTE(review): the lazy creation is not synchronised; presumably these are
// only ever called from a single thread -- confirm before sharing across threads.
#define CREATE_GENERAL_CODE_CATEGORY(SUFFIX) \
SUFFIX * f##SUFFIX = nullptr; \
extern "C" { \
    BitBlock __get_category_##SUFFIX(Basis_bits &basis_bits) { \
        if (f##SUFFIX == nullptr) f##SUFFIX = new SUFFIX(); \
        Struct_##SUFFIX output; \
        f##SUFFIX->do_block(basis_bits, output); \
        return output.cc; \
    } \
}

// One instantiation per supported general category.
CREATE_GENERAL_CODE_CATEGORY(Cc)
CREATE_GENERAL_CODE_CATEGORY(Cf)
CREATE_GENERAL_CODE_CATEGORY(Cn)
CREATE_GENERAL_CODE_CATEGORY(Co)
CREATE_GENERAL_CODE_CATEGORY(Cs)
CREATE_GENERAL_CODE_CATEGORY(Ll)
CREATE_GENERAL_CODE_CATEGORY(Lm)
CREATE_GENERAL_CODE_CATEGORY(Lo)
CREATE_GENERAL_CODE_CATEGORY(Lt)
CREATE_GENERAL_CODE_CATEGORY(Lu)
CREATE_GENERAL_CODE_CATEGORY(Mc)
CREATE_GENERAL_CODE_CATEGORY(Me)
CREATE_GENERAL_CODE_CATEGORY(Mn)
CREATE_GENERAL_CODE_CATEGORY(Nd)
CREATE_GENERAL_CODE_CATEGORY(Nl)
CREATE_GENERAL_CODE_CATEGORY(No)
CREATE_GENERAL_CODE_CATEGORY(Pc)
CREATE_GENERAL_CODE_CATEGORY(Pd)
CREATE_GENERAL_CODE_CATEGORY(Pe)
CREATE_GENERAL_CODE_CATEGORY(Pf)
CREATE_GENERAL_CODE_CATEGORY(Pi)
CREATE_GENERAL_CODE_CATEGORY(Po)
CREATE_GENERAL_CODE_CATEGORY(Ps)
CREATE_GENERAL_CODE_CATEGORY(Sc)
CREATE_GENERAL_CODE_CATEGORY(Sk)
CREATE_GENERAL_CODE_CATEGORY(Sm)
CREATE_GENERAL_CODE_CATEGORY(So)
CREATE_GENERAL_CODE_CATEGORY(Zl)
CREATE_GENERAL_CODE_CATEGORY(Zp)
CREATE_GENERAL_CODE_CATEGORY(Zs)

#undef CREATE_GENERAL_CODE_CATEGORY
112
113namespace pablo {
114
/// Construct a compiler for Pablo programs over the given basis-bit variables.
///
/// Creates the LLVM module "icgrep" and the BitBlock type
/// (a vector of BLOCK_SIZE/64 x i64), plus all-zero and all-ones BitBlock
/// constants. It then registers the native target, asm printer and asm parser,
/// and declares the types (DefineTypes) and functions (DeclareFunctions) that
/// the generated code uses. The execution engine is NOT created here;
/// compile() creates it.
///
/// NOTE(review): mBitBlockType reads mMod, and mZeroInitializer/mOneInitializer
/// read mBitBlockType. This relies on those members being declared in that
/// order in the header, since members are initialised in declaration order,
/// not in initializer-list order.
PabloCompiler::PabloCompiler(const std::vector<Var*> & basisBits)
: mBasisBits(basisBits)
, mMod(new Module("icgrep", getGlobalContext()))
, mBasicBlock(nullptr)
, mExecutionEngine(nullptr)
, mBitBlockType(VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64))
, mBasisBitsInputPtr(nullptr)
, mCarryQueueIdx(0)
, mCarryQueuePtr(nullptr)
, mNestingDepth(0)
, mCarryQueueSize(0)
, mZeroInitializer(ConstantAggregateZero::get(mBitBlockType))
, mOneInitializer(ConstantVector::getAllOnesValue(mBitBlockType))
, mFunctionType(nullptr)
, mFunction(nullptr)
, mBasisBitsAddr(nullptr)
, mOutputAddrPtr(nullptr)
, mMaxPabloWhileDepth(0)
{
    // Register the native target so that an MCJIT execution engine can be
    // built later in compile().
    InitializeNativeTarget();
    InitializeNativeTargetAsmPrinter();
    InitializeNativeTargetAsmParser();
    // Build struct/function types first: DeclareFunctions uses mFunctionType.
    DefineTypes();
    DeclareFunctions();
}
141
142PabloCompiler::~PabloCompiler()
143{
144    delete mMod;
145    delete fPs;
146    delete fNl;
147    delete fNo;
148    delete fLo;
149    delete fLl;
150    delete fLm;
151    delete fNd;
152    delete fPc;
153    delete fLt;
154    delete fLu;
155    delete fPf;
156    delete fPd;
157    delete fPe;
158    delete fPi;
159    delete fPo;
160    delete fMe;
161    delete fMc;
162    delete fMn;
163    delete fSk;
164    delete fSo;
165    delete fSm;
166    delete fSc;
167    delete fZl;
168    delete fCo;
169    delete fCn;
170    delete fCc;
171    delete fCf;
172    delete fCs;
173    delete fZp;
174    delete fZs;
175
176}
177
/// Compile a Pablo program into the JIT'd function "process_block".
///
/// Steps:
///  1. Count the carry-queue slots and declare the static category functions
///     that Call nodes need (DeclareCallFunctions), then size mCarryQueueVector.
///  2. Name the process_block arguments (basis_bits, carry_q, output) and
///     create the entry block "parabix_entry".
///  3. Load every basis-bit field from the basis_bits struct into mMarkerMap.
///  4. Generate IR for the statements, terminate with `ret void`, and verify.
///  5. Build an MCJIT execution engine for the host CPU, run a small
///     FunctionPassManager pipeline on process_block, and finalize the object.
///
/// @param pb  the Pablo block to compile.
/// @return the number of carry-queue BitBlocks the caller must provide
///         (carry_q_size) and the native address of process_block.
/// @throws std::runtime_error if the execution engine cannot be created.
///
/// NOTE(review): the execution engine is created only after
/// DeclareCallFunctions() and compileStatements() have run. Anything in those
/// passes that touches mExecutionEngine sees nullptr on the first compile().
/// NOTE(review): mMaxPabloWhileDepth is never reset in this function, so on a
/// second compile() the chosen opt level also reflects earlier programs.
LLVM_Gen_RetVal PabloCompiler::compile(PabloBlock & pb)
{
    mCarryQueueSize = 0;
    DeclareCallFunctions(pb.statements());
    mCarryQueueVector.resize(mCarryQueueSize);

    // process_block(basis_bits, carry_q, output): name the arguments.
    Function::arg_iterator args = mFunction->arg_begin();
    mBasisBitsAddr = args++;
    mBasisBitsAddr->setName("basis_bits");
    mCarryQueuePtr = args++;
    mCarryQueuePtr->setName("carry_q");
    mOutputAddrPtr = args++;
    mOutputAddrPtr->setName("output");

    // Reset carry-queue bookkeeping and create the entry block.
    mCarryQueueIdx = 0;
    mNestingDepth = 0;
    mBasicBlock = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunction,0);

    // Load each basis-bit field (aligned to a full BitBlock) and register it
    // under its variable name so Var expressions can find it.
    for (unsigned i = 0; i != mBasisBits.size(); ++i) {
        IRBuilder<> b(mBasicBlock);
        Value* indices[] = {b.getInt64(0), b.getInt32(i)};
        const String * const name = mBasisBits[i]->getName();
        Value * gep = b.CreateGEP(mBasisBitsAddr, indices);
        LoadInst * basisBit = b.CreateAlignedLoad(gep, BLOCK_SIZE/8, false, name->str());
        mMarkerMap.insert(std::make_pair(name, basisBit));
    }

    //Generate the IR instructions for the function.
    compileStatements(pb.statements());

    // Every slot counted by DeclareCallFunctions must have been consumed, and
    // all If/While nesting must be closed.
    assert (mCarryQueueIdx == mCarryQueueSize);
    assert (mNestingDepth == 0);
    //Terminate the block
    ReturnInst::Create(mMod->getContext(), mBasicBlock);

    // Define DUMP_GENERATED_IR (near the top of this file) to print the unoptimized IR.
    #ifdef DUMP_GENERATED_IR
    mMod->dump();
    #endif

    // Run the verifier; it prints an error message if the module is malformed.
    // Its return value is ignored, so compilation continues regardless.
    #ifdef USE_LLVM_3_5
    verifyModule(*mMod, &dbgs());
    #endif
    #ifdef USE_LLVM_3_4
    verifyModule(*mMod, PrintMessageAction);
    #endif

    //Use the pass manager to run optimizations on the function.
    FunctionPassManager fpm(mMod);
    // NOTE(review): `if (true)` only scopes the builder and error string;
    // it is always taken.
    if (true) {
        std::string errMessage;
        EngineBuilder builder(mMod);
        builder.setErrorStr(&errMessage);
        builder.setMCPU(sys::getHostCPUName());
        builder.setUseMCJIT(true);
        // Programs without while loops are compiled without backend optimisation.
        builder.setOptLevel(mMaxPabloWhileDepth == 0 ? CodeGenOpt::Level::None : CodeGenOpt::Level::Less);
        mExecutionEngine = builder.create();
        if (mExecutionEngine == nullptr) {
            throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
        }
    }
#ifdef USE_LLVM_3_5
    mMod->setDataLayout(mExecutionEngine->getDataLayout());
    // Set up the optimizer pipeline.  Start with registering info about how the target lays out data structures.
    fpm.add(new DataLayoutPass(mMod));
#endif

#ifdef USE_LLVM_3_4
    fpm.add(new DataLayout(*mExecutionEngine->getDataLayout()));
#endif

    //fpm.add(createPromoteMemoryToRegisterPass()); //Transform to SSA form.
    //fpm.add(createBasicAliasAnalysisPass());      //Provide basic AliasAnalysis support for GVN. (Global Value Numbering)
    //fpm.add(createCFGSimplificationPass());       //Simplify the control flow graph.

    fpm.add(createCorrelatedValuePropagationPass());
    fpm.add(createEarlyCSEPass());
    fpm.add(createInstructionCombiningPass());    //Simple peephole optimizations and bit-twiddling.
    fpm.add(createReassociatePass());             //Reassociate expressions.
    fpm.add(createGVNPass());                     //Eliminate common subexpressions.


    // -O1 is based on -O0

    // adds: -adce -always-inline -basicaa -basiccg -correlated-propagation -deadargelim -dse -early-cse -functionattrs -globalopt -indvars
    // -inline-cost -instcombine -ipsccp -jump-threading -lazy-value-info -lcssa -licm -loop-deletion -loop-idiom -loop-rotate -loop-simplify
    // -loop-unroll -loop-unswitch -loops -lower-expect -memcpyopt -memdep -no-aa -notti -prune-eh -reassociate -scalar-evolution -sccp
    // -simplifycfg -sroa -strip-dead-prototypes -tailcallelim -tbaa

/// no noticeable impact on performance

//    fpm.add(createConstantPropagationPass());
//    fpm.add(createDeadCodeEliminationPass());
//    fpm.add(createJumpThreadingPass());
//    fpm.add(createLoopIdiomPass());
//    fpm.add(createLoopRotatePass());
//    fpm.add(createLCSSAPass());
//    fpm.add(createLazyValueInfoPass());
//    fpm.add(createLowerExpectIntrinsicPass());
//    fpm.add(createLoopStrengthReducePass());
//    fpm.add(createLoopUnswitchPass());
//    fpm.add(createSCCPPass());
//    fpm.add(createLICMPass());

/// hurt performance significantly

//    fpm.add(createLoopUnrollPass());
//    fpm.add(createIndVarSimplifyPass());
    fpm.doInitialization();

    fpm.run(*mFunction);

#ifdef DUMP_OPTIMIZED_IR
    mMod->dump();
#endif
    // Generate machine code; process_block's address is valid after this.
    mExecutionEngine->finalizeObject();

    LLVM_Gen_RetVal retVal;
    //Return the required size of the carry queue and a pointer to the process_block function.
    retVal.carry_q_size = mCarryQueueSize;
    retVal.process_block_fptr = mExecutionEngine->getPointerToFunction(mFunction);

    return retVal;
}
305
306void PabloCompiler::DefineTypes()
307{
308    StructType * structBasisBits = mMod->getTypeByName("struct.Basis_bits");
309    if (structBasisBits == nullptr) {
310        structBasisBits = StructType::create(mMod->getContext(), "struct.Basis_bits");
311    }
312    std::vector<Type*>StructTy_struct_Basis_bits_fields;
313    for (int i = 0; i != mBasisBits.size(); i++)
314    {
315        StructTy_struct_Basis_bits_fields.push_back(mBitBlockType);
316    }
317    if (structBasisBits->isOpaque()) {
318        structBasisBits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
319    }
320    mBasisBitsInputPtr = PointerType::get(structBasisBits, 0);
321
322    std::vector<Type*>functionTypeArgs;
323    functionTypeArgs.push_back(mBasisBitsInputPtr);
324
325    //The carry q array.
326    //A pointer to the BitBlock vector.
327    functionTypeArgs.push_back(PointerType::get(mBitBlockType, 0));
328
329    //The output structure.
330    StructType * outputStruct = mMod->getTypeByName("struct.Output");
331    if (!outputStruct) {
332        outputStruct = StructType::create(mMod->getContext(), "struct.Output");
333    }
334    if (outputStruct->isOpaque()) {
335        std::vector<Type*>fields;
336        fields.push_back(mBitBlockType);
337        fields.push_back(mBitBlockType);
338        outputStruct->setBody(fields, /*isPacked=*/false);
339    }
340    PointerType* outputStructPtr = PointerType::get(outputStruct, 0);
341
342    //The &output parameter.
343    functionTypeArgs.push_back(outputStructPtr);
344
345    mFunctionType = FunctionType::get(
346     /*Result=*/Type::getVoidTy(mMod->getContext()),
347     /*Params=*/functionTypeArgs,
348     /*isVarArg=*/false);
349}
350
351void PabloCompiler::DeclareFunctions()
352{
353    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
354    //mFunc_print_register = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), mXi64Vect, NULL);
355    //mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mFunc_print_register), (void *)&wrapped_print_register);
356    // to call->  b.CreateCall(mFunc_print_register, unicode_category);
357
358#ifdef USE_UADD_OVERFLOW
359    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
360    std::vector<Type*>StructTy_0_fields;
361    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
362    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
363    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
364
365    std::vector<Type*>FuncTy_1_args;
366    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
367    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
368    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
369    FunctionType* FuncTy_1 = FunctionType::get(
370                                              /*Result=*/StructTy_0,
371                                              /*Params=*/FuncTy_1_args,
372                                              /*isVarArg=*/false);
373
374    mFunctionUaddOverflow = mMod->getFunction("llvm.uadd.with.overflow.carryin.i" + 
375                                              std::to_string(BLOCK_SIZE));
376    if (!mFunctionUaddOverflow) {
377        mFunctionUaddOverflow = Function::Create(
378          /*Type=*/ FuncTy_1,
379          /*Linkage=*/ GlobalValue::ExternalLinkage,
380          /*Name=*/ "llvm.uadd.with.overflow.carryin.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
381        mFunctionUaddOverflow->setCallingConv(CallingConv::C);
382    }
383    AttributeSet mFunctionUaddOverflowPAL;
384    {
385        SmallVector<AttributeSet, 4> Attrs;
386        AttributeSet PAS;
387        {
388          AttrBuilder B;
389          B.addAttribute(Attribute::NoUnwind);
390          B.addAttribute(Attribute::ReadNone);
391          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
392        }
393
394        Attrs.push_back(PAS);
395        mFunctionUaddOverflowPAL = AttributeSet::get(mMod->getContext(), Attrs);
396    }
397    mFunctionUaddOverflow->setAttributes(mFunctionUaddOverflowPAL);
398#endif
399
400    //Starts on process_block
401    SmallVector<AttributeSet, 4> Attrs;
402    AttributeSet PAS;
403    {
404        AttrBuilder B;
405        B.addAttribute(Attribute::ReadOnly);
406        B.addAttribute(Attribute::NoCapture);
407        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
408    }
409    Attrs.push_back(PAS);
410    {
411        AttrBuilder B;
412        B.addAttribute(Attribute::NoCapture);
413        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
414    }
415    Attrs.push_back(PAS);
416    {
417        AttrBuilder B;
418        B.addAttribute(Attribute::NoCapture);
419        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
420    }
421    Attrs.push_back(PAS);
422    {
423        AttrBuilder B;
424        B.addAttribute(Attribute::NoUnwind);
425        B.addAttribute(Attribute::UWTable);
426        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
427    }
428    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
429
430    //Create the function that will be generated.
431    mFunction = mMod->getFunction("process_block");
432    if (!mFunction) {
433        mFunction = Function::Create(
434            /*Type=*/mFunctionType,
435            /*Linkage=*/GlobalValue::ExternalLinkage,
436            /*Name=*/"process_block", mMod);
437        mFunction->setCallingConv(CallingConv::C);
438    }
439    mFunction->setAttributes(AttrSet);
440}
441
442void PabloCompiler::DeclareCallFunctions(const StatementList & stmts) {
443    for (PabloAST * stmt : stmts) {
444        if (const Assign * assign = dyn_cast<Assign>(stmt)) {
445            DeclareCallFunctions(assign->getExpr());
446        }
447        if (const Next * next = dyn_cast<Next>(stmt)) {
448            DeclareCallFunctions(next->getExpr());
449        }
450        else if (If * ifStatement = dyn_cast<If>(stmt)) {
451            const auto preIfCarryCount = mCarryQueueSize;
452            DeclareCallFunctions(ifStatement->getCondition());
453            DeclareCallFunctions(ifStatement->getBody());
454            ifStatement->setInclusiveCarryCount(mCarryQueueSize - preIfCarryCount);
455        }
456        else if (While * whileStatement = dyn_cast<While>(stmt)) {
457            const auto preWhileCarryCount = mCarryQueueSize;
458            DeclareCallFunctions(whileStatement->getCondition());
459            DeclareCallFunctions(whileStatement->getBody());
460            whileStatement->setInclusiveCarryCount(mCarryQueueSize - preWhileCarryCount);
461        }
462    }
463}
464
465void PabloCompiler::DeclareCallFunctions(const PabloAST * expr)
466{
467    if (const Call * call = dyn_cast<const Call>(expr)) {
468        const String * const callee = call->getCallee();
469        assert (callee);
470        if (mCalleeMap.find(callee) == mCalleeMap.end()) {
471            void * callee_ptr = nullptr;
472            #define CHECK_GENERAL_CODE_CATEGORY(SUFFIX) \
473                if (callee->str() == #SUFFIX) { \
474                    callee_ptr = (void*)&__get_category_##SUFFIX; \
475                } else
476            CHECK_GENERAL_CODE_CATEGORY(Cc)
477            CHECK_GENERAL_CODE_CATEGORY(Cf)
478            CHECK_GENERAL_CODE_CATEGORY(Cn)
479            CHECK_GENERAL_CODE_CATEGORY(Co)
480            CHECK_GENERAL_CODE_CATEGORY(Cs)
481            CHECK_GENERAL_CODE_CATEGORY(Ll)
482            CHECK_GENERAL_CODE_CATEGORY(Lm)
483            CHECK_GENERAL_CODE_CATEGORY(Lo)
484            CHECK_GENERAL_CODE_CATEGORY(Lt)
485            CHECK_GENERAL_CODE_CATEGORY(Lu)
486            CHECK_GENERAL_CODE_CATEGORY(Mc)
487            CHECK_GENERAL_CODE_CATEGORY(Me)
488            CHECK_GENERAL_CODE_CATEGORY(Mn)
489            CHECK_GENERAL_CODE_CATEGORY(Nd)
490            CHECK_GENERAL_CODE_CATEGORY(Nl)
491            CHECK_GENERAL_CODE_CATEGORY(No)
492            CHECK_GENERAL_CODE_CATEGORY(Pc)
493            CHECK_GENERAL_CODE_CATEGORY(Pd)
494            CHECK_GENERAL_CODE_CATEGORY(Pe)
495            CHECK_GENERAL_CODE_CATEGORY(Pf)
496            CHECK_GENERAL_CODE_CATEGORY(Pi)
497            CHECK_GENERAL_CODE_CATEGORY(Po)
498            CHECK_GENERAL_CODE_CATEGORY(Ps)
499            CHECK_GENERAL_CODE_CATEGORY(Sc)
500            CHECK_GENERAL_CODE_CATEGORY(Sk)
501            CHECK_GENERAL_CODE_CATEGORY(Sm)
502            CHECK_GENERAL_CODE_CATEGORY(So)
503            CHECK_GENERAL_CODE_CATEGORY(Zl)
504            CHECK_GENERAL_CODE_CATEGORY(Zp)
505            CHECK_GENERAL_CODE_CATEGORY(Zs)
506            // OTHERWISE ...
507            throw std::runtime_error("Unknown unicode category \"" + callee->str() + "\"");
508            #undef CHECK_GENERAL_CODE_CATEGORY
509            Value * unicodeCategory = mMod->getOrInsertFunction("__get_category_" + callee->str(), mBitBlockType, mBasisBitsInputPtr, NULL);
510            if (unicodeCategory == nullptr) {
511                throw std::runtime_error("Could not create static method call for unicode category \"" + callee->str() + "\"");
512            }
513            mExecutionEngine->addGlobalMapping(cast<GlobalValue>(unicodeCategory), callee_ptr);
514            mCalleeMap.insert(std::make_pair(callee, unicodeCategory));
515        }
516    }
517    else if (const And * pablo_and = dyn_cast<const And>(expr))
518    {
519        DeclareCallFunctions(pablo_and->getExpr1());
520        DeclareCallFunctions(pablo_and->getExpr2());
521    }
522    else if (const Or * pablo_or = dyn_cast<const Or>(expr))
523    {
524        DeclareCallFunctions(pablo_or->getExpr1());
525        DeclareCallFunctions(pablo_or->getExpr2());
526    }
527    else if (const Sel * pablo_sel = dyn_cast<const Sel>(expr))
528    {
529        DeclareCallFunctions(pablo_sel->getCondition());
530        DeclareCallFunctions(pablo_sel->getTrueExpr());
531        DeclareCallFunctions(pablo_sel->getFalseExpr());
532    }
533    else if (const Not * pablo_not = dyn_cast<const Not>(expr))
534    {
535        DeclareCallFunctions(pablo_not->getExpr());
536    }
537    else if (const Advance * adv = dyn_cast<const Advance>(expr))
538    {
539        ++mCarryQueueSize;
540        DeclareCallFunctions(adv->getExpr());
541    }
542    else if (const MatchStar * mstar = dyn_cast<const MatchStar>(expr))
543    {
544        ++mCarryQueueSize;
545        DeclareCallFunctions(mstar->getMarker());
546        DeclareCallFunctions(mstar->getCharClass());
547    }
548    else if (const ScanThru * sthru = dyn_cast<const ScanThru>(expr))
549    {
550        ++mCarryQueueSize;
551        DeclareCallFunctions(sthru->getScanFrom());
552        DeclareCallFunctions(sthru->getScanThru());
553    }
554}
555
556Value * PabloCompiler::compileStatements(const StatementList & stmts) {
557    Value * retVal = nullptr;
558    for (PabloAST * statement : stmts) {
559        retVal = compileStatement(statement);
560    }
561    return retVal;
562}
563
564Value * PabloCompiler::compileStatement(const PabloAST * stmt)
565{
566    Value * retVal = nullptr;
567    if (const Assign * assign = dyn_cast<const Assign>(stmt))
568    {
569        Value* expr = compileExpression(assign->getExpr());
570        mMarkerMap[assign->getName()] = expr;
571        if (unlikely(assign->isOutputAssignment())) {
572            SetOutputValue(expr, assign->getOutputIndex());
573        }
574        retVal = expr;
575    }
576    if (const Next * next = dyn_cast<const Next>(stmt))
577    {
578        Value* expr = compileExpression(next->getExpr());
579        mMarkerMap[next->getName()] = expr;
580        retVal = expr;
581    }
582    else if (const If * ifstmt = dyn_cast<const If>(stmt))
583    {
584        BasicBlock * ifEntryBlock = mBasicBlock;
585        BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunction, 0);
586        BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunction, 0);
587
588        int if_start_idx = mCarryQueueIdx;
589
590        Value* if_test_value = compileExpression(ifstmt->getCondition());
591
592        /* Generate the statements into the if body block, and also determine the
593           final carry index.  */
594
595        IRBuilder<> bIfBody(ifBodyBlock);
596        mBasicBlock = ifBodyBlock;
597
598        ++mNestingDepth;
599
600        Value *  returnMarker = compileStatements(ifstmt->getBody());
601
602        int if_end_idx = mCarryQueueIdx;
603        if (if_start_idx < if_end_idx + 1) {
604            // Have at least two internal carries.   Accumulate and store.
605            int if_accum_idx = mCarryQueueIdx++;
606
607            Value* if_carry_accum_value = genCarryInLoad(if_start_idx);
608
609            for (int c = if_start_idx+1; c < if_end_idx; c++)
610            {
611                Value* carryq_value = genCarryInLoad(c);
612                if_carry_accum_value = bIfBody.CreateOr(carryq_value, if_carry_accum_value);
613            }
614            genCarryOutStore(if_carry_accum_value, if_accum_idx);
615
616        }
617        bIfBody.CreateBr(ifEndBlock);
618
619        IRBuilder<> b_entry(ifEntryBlock);
620        mBasicBlock = ifEntryBlock;
621        if (if_start_idx < if_end_idx) {
622            // Have at least one internal carry.
623            int if_accum_idx = mCarryQueueIdx - 1;
624            Value* last_if_pending_carries = genCarryInLoad(if_accum_idx);
625            if_test_value = b_entry.CreateOr(if_test_value, last_if_pending_carries);
626        }
627        b_entry.CreateCondBr(genBitBlockAny(if_test_value), ifEndBlock, ifBodyBlock);
628
629        mBasicBlock = ifEndBlock;
630        --mNestingDepth;
631
632        retVal = returnMarker;
633    }
634    else if (const While * whileStatement = dyn_cast<const While>(stmt))
635    {
636        const auto baseCarryQueueIdx = mCarryQueueIdx;
637        if (mNestingDepth == 0) {
638            for (auto i = 0; i != whileStatement->getInclusiveCarryCount(); ++i) {
639                genCarryInLoad(baseCarryQueueIdx + i);
640            }
641        }       
642
643        SmallVector<Next*, 4> nextNodes;
644        for (PabloAST * node : whileStatement->getBody()) {
645            if (isa<Next>(node)) {
646                nextNodes.push_back(cast<Next>(node));
647            }
648        }
649
650        // Compile the initial iteration statements; the calls to genCarryOutStore will update the
651        // mCarryQueueVector with the appropriate values. Although we're not actually entering a new basic
652        // block yet, increment the nesting depth so that any calls to genCarryInLoad or genCarryOutStore
653        // will refer to the previous value.
654        ++mNestingDepth;
655        if (mMaxPabloWhileDepth < mNestingDepth) mMaxPabloWhileDepth = mNestingDepth;
656        compileStatements(whileStatement->getBody());
657       
658        // Reset the carry queue index. Note: this ought to be changed in the future. Currently this assumes
659        // that compiling the while body twice will generate the equivalent IR. This is not necessarily true
660        // but works for now.
661        mCarryQueueIdx = baseCarryQueueIdx;
662
663        BasicBlock* whileCondBlock = BasicBlock::Create(mMod->getContext(), "while.cond", mFunction, 0);
664        BasicBlock* whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunction, 0);
665        BasicBlock* whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunction, 0);
666
667        // Note: compileStatements may update the mBasicBlock pointer if the body contains nested loops. It
668        // may not be same one that we entered the function with.
669        IRBuilder<> bEntry(mBasicBlock);
670        bEntry.CreateBr(whileCondBlock);
671
672        // CONDITION BLOCK
673        IRBuilder<> bCond(whileCondBlock);
674        // generate phi nodes for any carry propogating instruction
675        std::vector<PHINode*> phiNodes(whileStatement->getInclusiveCarryCount() + nextNodes.size());
676        unsigned index = 0;
677        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
678            PHINode * phi = bCond.CreatePHI(mBitBlockType, 2);
679            phi->addIncoming(mCarryQueueVector[baseCarryQueueIdx + index], mBasicBlock);
680            mCarryQueueVector[baseCarryQueueIdx + index] = mZeroInitializer; // (use phi for multi-carry mode.)
681            phiNodes[index] = phi;
682        }
683        // and for any Next nodes in the loop body
684        for (Next * n : nextNodes) {
685            PHINode * phi = bCond.CreatePHI(mBitBlockType, 2, n->getName()->str());
686            auto f = mMarkerMap.find(n->getName());
687            assert (f != mMarkerMap.end());
688            phi->addIncoming(f->second, mBasicBlock);
689            mMarkerMap[n->getName()] = phi;
690            phiNodes[index++] = phi;
691        }
692
693        mBasicBlock = whileCondBlock;
694        bCond.CreateCondBr(genBitBlockAny(compileExpression(whileStatement->getCondition())), whileEndBlock, whileBodyBlock);
695
696        // BODY BLOCK
697        mBasicBlock = whileBodyBlock;
698        retVal = compileStatements(whileStatement->getBody());
699        // update phi nodes for any carry propogating instruction
700        IRBuilder<> bWhileBody(mBasicBlock);
701        for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
702            Value * carryOut = bWhileBody.CreateOr(phiNodes[index], mCarryQueueVector[baseCarryQueueIdx + index]);
703            PHINode * phi = phiNodes[index];
704            phi->addIncoming(carryOut, mBasicBlock);
705            mCarryQueueVector[baseCarryQueueIdx + index] = phi;
706        }
707        // and for any Next nodes in the loop body
708        for (Next * n : nextNodes) {
709            auto f = mMarkerMap.find(n->getName());
710            assert (f != mMarkerMap.end());
711            PHINode * phi = phiNodes[index++];
712            phi->addIncoming(f->second, mBasicBlock);
713            mMarkerMap[n->getName()] = phi;
714        }
715
716        bWhileBody.CreateBr(whileCondBlock);
717
718        // EXIT BLOCK
719        mBasicBlock = whileEndBlock;   
720        if (--mNestingDepth == 0) {
721            for (index = 0; index != whileStatement->getInclusiveCarryCount(); ++index) {
722                genCarryOutStore(phiNodes[index], baseCarryQueueIdx + index);
723            }
724        }
725    }
726    return retVal;
727}
728
729Value * PabloCompiler::compileExpression(const PabloAST * expr)
730{
731    Value * retVal = nullptr;
732    IRBuilder<> b(mBasicBlock);
733    if (isa<Ones>(expr)) {
734        retVal = mOneInitializer;
735    }
736    else if (isa<Zeroes>(expr)) {
737        retVal = mZeroInitializer;
738    }
739    else if (const Call* call = dyn_cast<Call>(expr)) {
740        //Call the callee once and store the result in the marker map.
741        auto mi = mMarkerMap.find(call->getCallee());
742        if (mi == mMarkerMap.end()) {
743            auto ci = mCalleeMap.find(call->getCallee());
744            if (ci == mCalleeMap.end()) {
745                throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee()->str() + "\"");
746            }
747            mi = mMarkerMap.insert(std::make_pair(call->getCallee(), b.CreateCall(ci->second, mBasisBitsAddr))).first;
748        }
749        retVal = mi->second;
750    }
751    else if (const Var * var = dyn_cast<Var>(expr))
752    {
753        auto f = mMarkerMap.find(var->getName());
754        assert (f != mMarkerMap.end());
755        retVal = f->second;
756    }
757    else if (const And * pablo_and = dyn_cast<And>(expr))
758    {
759        retVal = b.CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
760    }
761    else if (const Or * pablo_or = dyn_cast<Or>(expr))
762    {
763        retVal = b.CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
764    }
765    else if (const Sel * sel = dyn_cast<Sel>(expr))
766    {
767        Value* ifMask = compileExpression(sel->getCondition());
768        Value* ifTrue = b.CreateAnd(ifMask, compileExpression(sel->getTrueExpr()));
769        Value* ifFalse = b.CreateAnd(genNot(ifMask), compileExpression(sel->getFalseExpr()));
770        retVal = b.CreateOr(ifTrue, ifFalse);
771    }
772    else if (const Not * pablo_not = dyn_cast<Not>(expr))
773    {
774        retVal = genNot(compileExpression(pablo_not->getExpr()));
775    }
776    else if (const Advance * adv = dyn_cast<Advance>(expr))
777    {
778        Value* strm_value = compileExpression(adv->getExpr());
779        int shift = adv->getAdvanceAmount();
780        retVal = genAdvanceWithCarry(strm_value, shift);
781    }
782    else if (const MatchStar * mstar = dyn_cast<MatchStar>(expr))
783    {
784        Value* marker = compileExpression(mstar->getMarker());
785        Value* cc = compileExpression(mstar->getCharClass());
786        Value* marker_and_cc = b.CreateAnd(marker, cc);
787        retVal = b.CreateOr(b.CreateXor(genAddWithCarry(marker_and_cc, cc), cc), marker, "matchstar");
788    }
789    else if (const ScanThru * sthru = dyn_cast<ScanThru>(expr))
790    {
791        Value* marker_expr = compileExpression(sthru->getScanFrom());
792        Value* cc_expr = compileExpression(sthru->getScanThru());
793        retVal = b.CreateAnd(genAddWithCarry(marker_expr, cc_expr), genNot(cc_expr), "scanthru");
794    }
795    return retVal;
796}
797
798#ifdef USE_UADD_OVERFLOW
799PabloCompiler::SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
800    std::vector<Value*> struct_res_params;
801    struct_res_params.push_back(int128_e1);
802    struct_res_params.push_back(int128_e2);
803    struct_res_params.push_back(int1_cin);
804    CallInst* struct_res = CallInst::Create(mFunctionUaddOverflow, struct_res_params, "uadd_overflow_res", mBasicBlock);
805    struct_res->setCallingConv(CallingConv::C);
806    struct_res->setTailCall(false);
807    AttributeSet struct_res_PAL;
808    struct_res->setAttributes(struct_res_PAL);
809
810    SumWithOverflowPack ret;
811
812    std::vector<unsigned> int128_sum_indices;
813    int128_sum_indices.push_back(0);
814    ret.sum = ExtractValueInst::Create(struct_res, int128_sum_indices, "sum", mBasicBlock);
815
816    std::vector<unsigned> int1_obit_indices;
817    int1_obit_indices.push_back(1);
818    ret.obit = ExtractValueInst::Create(struct_res, int1_obit_indices, "obit", mBasicBlock);
819
820    return ret;
821}
822#endif
823
824Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2) {
825    IRBuilder<> b(mBasicBlock);
826
827    //CarryQ - carry in.
828    const int carryIdx = mCarryQueueIdx++;
829    Value* carryq_value = genCarryInLoad(carryIdx);
830
831#ifdef USE_UADD_OVERFLOW
832    //use llvm.uadd.with.overflow.i128 or i256
833    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
834    CastInst* int128_e1 = new BitCastInst(e1, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e1_128", mBasicBlock);
835    CastInst* int128_e2 = new BitCastInst(e2, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e2_128", mBasicBlock);
836    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
837    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
838    SumWithOverflowPack sumpack0;
839    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
840    Value* obit = sumpack0.obit;
841    Value* sum = b.CreateBitCast(sumpack0.sum, mBitBlockType, "sum");
842    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
843    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mBitBlockType);
844    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
845    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
846#else
847    //calculate carry through logical ops
848    Value* carrygen = b.CreateAnd(e1, e2, "carrygen");
849    Value* carryprop = b.CreateOr(e1, e2, "carryprop");
850    Value* digitsum = b.CreateAdd(e1, e2, "digitsum");
851    Value* partial = b.CreateAdd(digitsum, carryq_value, "partial");
852    Value* digitcarry = b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(partial)));
853    Value* mid_carry_in = genShiftLeft64(b.CreateLShr(digitcarry, 63), "mid_carry_in");
854
855    Value* sum = b.CreateAdd(partial, mid_carry_in, "sum");
856    Value* carry_out = genShiftHighbitToLow(b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(sum))), "carry_out");
857#endif
858    genCarryOutStore(carry_out, carryIdx);
859    return sum;
860}
861
862Value* PabloCompiler::genCarryInLoad(const unsigned index) {   
863    assert (index < mCarryQueueVector.size());
864    if (mNestingDepth == 0) {
865        IRBuilder<> b(mBasicBlock);
866        mCarryQueueVector[index] = b.CreateAlignedLoad(b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
867    }
868    return mCarryQueueVector[index];
869}
870
871void PabloCompiler::genCarryOutStore(Value* carryOut, const unsigned index ) {
872    assert (carryOut);
873    assert (index < mCarryQueueVector.size());
874    if (mNestingDepth == 0) {       
875        IRBuilder<> b(mBasicBlock);
876        b.CreateAlignedStore(carryOut, b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
877    }
878    mCarryQueueVector[index] = carryOut;
879}
880
881inline Value* PabloCompiler::genBitBlockAny(Value* test) {
882    IRBuilder<> b(mBasicBlock);
883    Value* cast_marker_value_1 = b.CreateBitCast(test, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
884    return b.CreateICmpEQ(cast_marker_value_1, ConstantInt::get(IntegerType::get(mMod->getContext(), BLOCK_SIZE), 0));
885}
886
887Value* PabloCompiler::genShiftHighbitToLow(Value* e, const Twine &namehint) {
888    IRBuilder<> b(mBasicBlock);
889    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
890    return b.CreateBitCast(b.CreateLShr(i128_val, BLOCK_SIZE - 1, namehint), mBitBlockType);
891}
892
893Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
894    IRBuilder<> b(mBasicBlock);
895    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
896    return b.CreateBitCast(b.CreateShl(i128_val, 64, namehint), mBitBlockType);
897}
898
899inline Value* PabloCompiler::genNot(Value* expr) {
900    IRBuilder<> b(mBasicBlock);
901    return b.CreateXor(expr, mOneInitializer, "not");
902}
903
904Value* PabloCompiler::genAdvanceWithCarry(Value* strm_value, int shift_amount) {
905
906    IRBuilder<> b(mBasicBlock);
907#if (BLOCK_SIZE == 128)
908    const auto carryIdx = mCarryQueueIdx++;
909    if (shift_amount == 1) {
910        Value* carryq_value = genCarryInLoad(carryIdx);
911        Value* srli_1_value = b.CreateLShr(strm_value, 63);
912        Value* packed_shuffle;
913        Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
914        Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
915        packed_shuffle = b.CreateShuffleVector(carryq_value, srli_1_value, const_packed_1);
916
917        Constant* const_packed_2_elems[] = {b.getInt64(1), b.getInt64(1)};
918        Constant* const_packed_2 = ConstantVector::get(const_packed_2_elems);
919
920        Value* shl_value = b.CreateShl(strm_value, const_packed_2);
921        Value* result_value = b.CreateOr(shl_value, packed_shuffle, "advance");
922
923        Value* carry_out = genShiftHighbitToLow(strm_value, "carry_out");
924        //CarryQ - carry out:
925        genCarryOutStore(carry_out, carryIdx);
926           
927        return result_value;
928    }
929    else if (shift_amount < 64) {
930        // This is the preferred logic, but is too slow for the general case.   
931        // We need to speed up our custom LLVM for this code.
932        Value* carryq_longint = b.CreateBitCast(genCarryInLoad(carryIdx), IntegerType::get(mMod->getContext(), BLOCK_SIZE));
933        Value* strm_longint = b.CreateBitCast(strm_value, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
934        Value* adv_longint = b.CreateOr(b.CreateShl(strm_longint, shift_amount), carryq_longint, "advance");
935    Value* result_value = b.CreateBitCast(adv_longint, mBitBlockType);
936    Value* carry_out = b.CreateBitCast(b.CreateLShr(strm_longint, BLOCK_SIZE - shift_amount, "advance_out"), mBitBlockType);
937        //CarryQ - carry out:
938        genCarryOutStore(carry_out, carryIdx);
939           
940        return result_value;
941    }
942    else {//if (shift_amount >= 64) {
943        throw std::runtime_error("Shift amount >= 64 in Advance is currently unsupported.");
944    }
945#endif
946
947#if (BLOCK_SIZE == 256)
948    return genAddWithCarry(strm_value, strm_value);
949#endif
950
951}
952
953void PabloCompiler::SetOutputValue(Value * marker, const unsigned index) {
954    IRBuilder<> b(mBasicBlock);
955    if (marker->getType()->isPointerTy()) {
956        marker = b.CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
957    }
958    Value* indices[] = {b.getInt64(0), b.getInt32(index)};
959    Value* gep = b.CreateGEP(mOutputAddrPtr, indices);
960    b.CreateAlignedStore(marker, gep, BLOCK_SIZE/8, false);
961}
962
963}
Note: See TracBrowser for help on using the repository browser.