source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4261

Last change on this file since 4261 was 4261, checked in by cameron, 5 years ago

log 2 lower bound technique; VARIABLE_ADVANCE ifdef

File size: 33.7 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7/*
8 *  Copyright (c) 2014 International Characters.
9 *  This software is licensed to the public under the Open Software License 3.0.
10 *  icgrep is a trademark of International Characters.
11 */
12
13#include <pablo/pablo_compiler.h>
14#include <pablo/codegenstate.h>
15#include <pablo/printer_pablos.h>
16#include <cc/cc_namemap.hpp>
17#include <re/re_name.h>
18#include <stdexcept>
19#include <include/simd-lib/bitblock.hpp>
20
21//#define DUMP_GENERATED_IR
22//#define DUMP_OPTIMIZED_IR
23
24extern "C" {
25  void wrapped_print_register(BitBlock bit_block) {
26      print_register<BitBlock>("", bit_block);
27  }
28}
29
30#define CREATE_GENERAL_CODE_CATEGORY(SUFFIX) \
31SUFFIX * f##SUFFIX = nullptr; \
32extern "C" { \
33    BitBlock __get_category_##SUFFIX(Basis_bits &basis_bits) { \
34        if (f##SUFFIX == nullptr) f##SUFFIX = new SUFFIX(); \
35        Struct_##SUFFIX output; \
36        f##SUFFIX->do_block(basis_bits, output); \
37        return output.cc; \
38    } \
39}
40
41CREATE_GENERAL_CODE_CATEGORY(Cc)
42CREATE_GENERAL_CODE_CATEGORY(Cf)
43CREATE_GENERAL_CODE_CATEGORY(Cn)
44CREATE_GENERAL_CODE_CATEGORY(Co)
45CREATE_GENERAL_CODE_CATEGORY(Cs)
46CREATE_GENERAL_CODE_CATEGORY(Ll)
47CREATE_GENERAL_CODE_CATEGORY(Lm)
48CREATE_GENERAL_CODE_CATEGORY(Lo)
49CREATE_GENERAL_CODE_CATEGORY(Lt)
50CREATE_GENERAL_CODE_CATEGORY(Lu)
51CREATE_GENERAL_CODE_CATEGORY(Mc)
52CREATE_GENERAL_CODE_CATEGORY(Me)
53CREATE_GENERAL_CODE_CATEGORY(Mn)
54CREATE_GENERAL_CODE_CATEGORY(Nd)
55CREATE_GENERAL_CODE_CATEGORY(Nl)
56CREATE_GENERAL_CODE_CATEGORY(No)
57CREATE_GENERAL_CODE_CATEGORY(Pc)
58CREATE_GENERAL_CODE_CATEGORY(Pd)
59CREATE_GENERAL_CODE_CATEGORY(Pe)
60CREATE_GENERAL_CODE_CATEGORY(Pf)
61CREATE_GENERAL_CODE_CATEGORY(Pi)
62CREATE_GENERAL_CODE_CATEGORY(Po)
63CREATE_GENERAL_CODE_CATEGORY(Ps)
64CREATE_GENERAL_CODE_CATEGORY(Sc)
65CREATE_GENERAL_CODE_CATEGORY(Sk)
66CREATE_GENERAL_CODE_CATEGORY(Sm)
67CREATE_GENERAL_CODE_CATEGORY(So)
68CREATE_GENERAL_CODE_CATEGORY(Zl)
69CREATE_GENERAL_CODE_CATEGORY(Zp)
70CREATE_GENERAL_CODE_CATEGORY(Zs)
71
72#undef CREATE_GENERAL_CODE_CATEGORY
73
74namespace pablo {
75
76PabloCompiler::PabloCompiler(const cc::CC_NameMap & nameMap, const BasisBitVars & basisBitVars, int bits)
77: mBits(bits)
78, mBasisBitVars(basisBitVars)
79, mMod(new Module("icgrep", getGlobalContext()))
80, mBasicBlock(nullptr)
81, mExecutionEngine(nullptr)
82, mXi64Vect(VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64))
83, mXi128Vect(VectorType::get(IntegerType::get(mMod->getContext(), 128), BLOCK_SIZE / 128))
84, mBasisBitsInputPtr(nullptr)
85, mCarryQueueIdx(0)
86, mCarryQueuePtr(nullptr)
87, mNestingDepth(0)
88, mCarryQueueSize(0)
89, mZeroInitializer(ConstantAggregateZero::get(mXi64Vect))
90, mOneInitializer(ConstantVector::getAllOnesValue(mXi64Vect))
91, mFunctionType(nullptr)
92, mFunc_process_block(nullptr)
93, mBasisBitsAddr(nullptr)
94, mOutputAddrPtr(nullptr)
95, mNameMap(nameMap)
96{
97    //Create the jit execution engine.up
98    InitializeNativeTarget();
99    std::string ErrStr;
100    mExecutionEngine = EngineBuilder(mMod).setUseMCJIT(true).setErrorStr(&ErrStr).setOptLevel(CodeGenOpt::Level::Less).create();
101    if (mExecutionEngine == nullptr) {
102        throw std::runtime_error("Could not create ExecutionEngine: " + ErrStr);
103    }
104
105    InitializeNativeTargetAsmPrinter();
106    InitializeNativeTargetAsmParser();
107
108    DefineTypes();
109    DeclareFunctions();
110}
111
112PabloCompiler::~PabloCompiler()
113{
114    delete mMod;
115    delete fPs;
116    delete fNl;
117    delete fNo;
118    delete fLo;
119    delete fLl;
120    delete fLm;
121    delete fNd;
122    delete fPc;
123    delete fLt;
124    delete fLu;
125    delete fPf;
126    delete fPd;
127    delete fPe;
128    delete fPi;
129    delete fPo;
130    delete fMe;
131    delete fMc;
132    delete fMn;
133    delete fSk;
134    delete fSo;
135    delete fSm;
136    delete fSc;
137    delete fZl;
138    delete fCo;
139    delete fCn;
140    delete fCc;
141    delete fCf;
142    delete fCs;
143    delete fZp;
144    delete fZs;
145
146}
147
148LLVM_Gen_RetVal PabloCompiler::compile(PabloBlock & pb)
149{
150    mCarryQueueSize = 0;
151    DeclareCallFunctions(pb.expressions());
152    mCarryQueueVector.resize(mCarryQueueSize);
153
154    Function::arg_iterator args = mFunc_process_block->arg_begin();
155    mBasisBitsAddr = args++;
156    mBasisBitsAddr->setName("basis_bits");
157    mCarryQueuePtr = args++;
158    mCarryQueuePtr->setName("carry_q");
159    mOutputAddrPtr = args++;
160    mOutputAddrPtr->setName("output");
161
162    //Create the carry queue.
163    mCarryQueueIdx = 0;
164    mNestingDepth = 0;
165    mBasicBlock = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunc_process_block,0);
166
167    //The basis bits structure
168    for (unsigned i = 0; i < mBits; ++i) {
169        IRBuilder<> b(mBasicBlock);
170        Value* indices[] = {b.getInt64(0), b.getInt32(i)};
171        const String * const name = mBasisBitVars[i]->getName();
172        mMarkerMap.insert(std::make_pair(name, b.CreateGEP(mBasisBitsAddr, indices, name->str())));
173    }
174
175    //Generate the IR instructions for the function.
176    Value * result = compileStatements(pb.expressions());
177    SetReturnMarker(result, 0); // matches
178
179    const String * lf = pb.createVar(mNameMap["LineFeed"]->getName())->getName();
180
181    SetReturnMarker(GetMarker(lf), 1); // line feeds
182
183    assert (mCarryQueueIdx == mCarryQueueSize);
184    assert (mNestingDepth == 0);
185    //Terminate the block
186    ReturnInst::Create(mMod->getContext(), mBasicBlock);
187
188    //Un-comment this line in order to display the IR that has been generated by this module.
189    #ifdef DUMP_GENERATED_IR
190    mMod->dump();
191    #endif
192
193    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
194    #ifdef USE_LLVM_3_5
195    verifyModule(*mMod, &dbgs());
196    #endif
197    #ifdef USE_LLVM_3_4
198    verifyModule(*mMod, PrintMessageAction);
199    #endif
200
201    //Use the pass manager to run optimizations on the function.
202    FunctionPassManager fpm(mMod);
203
204#ifdef USE_LLVM_3_5
205    mMod->setDataLayout(mExecutionEngine->getDataLayout());
206    // Set up the optimizer pipeline.  Start with registering info about how the target lays out data structures.
207    fpm.add(new DataLayoutPass(mMod));
208#endif
209
210#ifdef USE_LLVM_3_4
211    fpm.add(new DataLayout(*mExecutionEngine->getDataLayout()));
212#endif
213
214    fpm.add(createPromoteMemoryToRegisterPass()); //Transform to SSA form.
215    fpm.add(createBasicAliasAnalysisPass());      //Provide basic AliasAnalysis support for GVN. (Global Value Numbering)
216    fpm.add(createCFGSimplificationPass());       //Simplify the control flow graph.
217    fpm.add(createInstructionCombiningPass());    //Simple peephole optimizations and bit-twiddling.
218    fpm.add(createReassociatePass());             //Reassociate expressions.
219    fpm.add(createGVNPass());                     //Eliminate common subexpressions.
220
221    fpm.doInitialization();
222
223    fpm.run(*mFunc_process_block);
224
225#ifdef DUMP_OPTIMIZED_IR
226    mMod->dump();
227#endif
228    mExecutionEngine->finalizeObject();
229
230    LLVM_Gen_RetVal retVal;
231    //Return the required size of the carry queue and a pointer to the process_block function.
232    retVal.carry_q_size = mCarryQueueSize;
233    retVal.process_block_fptr = mExecutionEngine->getPointerToFunction(mFunc_process_block);
234
235    return retVal;
236}
237
238void PabloCompiler::DefineTypes()
239{
240    StructType * structBasisBits = mMod->getTypeByName("struct.Basis_bits");
241    if (structBasisBits == nullptr) {
242        structBasisBits = StructType::create(mMod->getContext(), "struct.Basis_bits");
243    }
244    std::vector<Type*>StructTy_struct_Basis_bits_fields;
245    for (int i = 0; i < mBits; i++)
246    {
247        StructTy_struct_Basis_bits_fields.push_back(mXi64Vect);
248    }
249    if (structBasisBits->isOpaque()) {
250        structBasisBits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
251    }
252    mBasisBitsInputPtr = PointerType::get(structBasisBits, 0);
253
254    std::vector<Type*>functionTypeArgs;
255    functionTypeArgs.push_back(mBasisBitsInputPtr);
256
257    //The carry q array.
258    //A pointer to the BitBlock vector.
259    functionTypeArgs.push_back(PointerType::get(mXi64Vect, 0));
260
261    //The output structure.
262    StructType * outputStruct = mMod->getTypeByName("struct.Output");
263    if (!outputStruct) {
264        outputStruct = StructType::create(mMod->getContext(), "struct.Output");
265    }
266    if (outputStruct->isOpaque()) {
267        std::vector<Type*>fields;
268        fields.push_back(mXi64Vect);
269        fields.push_back(mXi64Vect);
270        outputStruct->setBody(fields, /*isPacked=*/false);
271    }
272    PointerType* outputStructPtr = PointerType::get(outputStruct, 0);
273
274    //The &output parameter.
275    functionTypeArgs.push_back(outputStructPtr);
276
277    mFunctionType = FunctionType::get(
278     /*Result=*/Type::getVoidTy(mMod->getContext()),
279     /*Params=*/functionTypeArgs,
280     /*isVarArg=*/false);
281}
282
283void PabloCompiler::DeclareFunctions()
284{
285    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
286    //mFunc_print_register = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), mXi64Vect, NULL);
287    //mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mFunc_print_register), (void *)&wrapped_print_register);
288    // to call->  b.CreateCall(mFunc_print_register, unicode_category);
289
290#ifdef USE_UADD_OVERFLOW
291    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
292    std::vector<Type*>StructTy_0_fields;
293    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
294    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
295    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
296
297    std::vector<Type*>FuncTy_1_args;
298    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
299    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
300    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
301    FunctionType* FuncTy_1 = FunctionType::get(
302                                              /*Result=*/StructTy_0,
303                                              /*Params=*/FuncTy_1_args,
304                                              /*isVarArg=*/false);
305
306    mFunc_llvm_uadd_with_overflow = mMod->getFunction("llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE);
307    if (!mFunc_llvm_uadd_with_overflow) {
308        mFunc_llvm_uadd_with_overflow = Function::Create(
309          /*Type=*/ FuncTy_1,
310          /*Linkage=*/ GlobalValue::ExternalLinkage,
311          /*Name=*/ "llvm.uadd.with.overflow.carryin.i"##BLOCK_SIZE, mMod); // (external, no body)
312        mFunc_llvm_uadd_with_overflow->setCallingConv(CallingConv::C);
313    }
314    AttributeSet mFunc_llvm_uadd_with_overflow_PAL;
315    {
316        SmallVector<AttributeSet, 4> Attrs;
317        AttributeSet PAS;
318        {
319          AttrBuilder B;
320          B.addAttribute(Attribute::NoUnwind);
321          B.addAttribute(Attribute::ReadNone);
322          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
323        }
324
325        Attrs.push_back(PAS);
326        mFunc_llvm_uadd_with_overflow_PAL = AttributeSet::get(mMod->getContext(), Attrs);
327    }
328    mFunc_llvm_uadd_with_overflow->setAttributes(mFunc_llvm_uadd_with_overflow_PAL);
329#endif
330
331    //Starts on process_block
332    SmallVector<AttributeSet, 4> Attrs;
333    AttributeSet PAS;
334    {
335        AttrBuilder B;
336        B.addAttribute(Attribute::ReadOnly);
337        B.addAttribute(Attribute::NoCapture);
338        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
339    }
340    Attrs.push_back(PAS);
341    {
342        AttrBuilder B;
343        B.addAttribute(Attribute::NoCapture);
344        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
345    }
346    Attrs.push_back(PAS);
347    {
348        AttrBuilder B;
349        B.addAttribute(Attribute::NoCapture);
350        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
351    }
352    Attrs.push_back(PAS);
353    {
354        AttrBuilder B;
355        B.addAttribute(Attribute::NoUnwind);
356        B.addAttribute(Attribute::UWTable);
357        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
358    }
359    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
360
361    //Create the function that will be generated.
362    mFunc_process_block = mMod->getFunction("process_block");
363    if (!mFunc_process_block) {
364        mFunc_process_block = Function::Create(
365            /*Type=*/mFunctionType,
366            /*Linkage=*/GlobalValue::ExternalLinkage,
367            /*Name=*/"process_block", mMod);
368        mFunc_process_block->setCallingConv(CallingConv::C);
369    }
370    mFunc_process_block->setAttributes(AttrSet);
371}
372
373void PabloCompiler::DeclareCallFunctions(const ExpressionList & stmts) {
374    for (PabloAST * stmt : stmts) {
375        if (const Assign * assign = dyn_cast<Assign>(stmt)) {
376            DeclareCallFunctions(assign->getExpr());
377        }
378        if (const Next * next = dyn_cast<Next>(stmt)) {
379            DeclareCallFunctions(next->getExpr());
380        }
381        else if (If * ifStatement = dyn_cast<If>(stmt)) {
382            const auto preIfCarryCount = mCarryQueueSize;
383            DeclareCallFunctions(ifStatement->getCondition());
384            DeclareCallFunctions(ifStatement->getBody());
385            ifStatement->setInclusiveCarryCount(mCarryQueueSize - preIfCarryCount);
386        }
387        else if (While * whileStatement = dyn_cast<While>(stmt)) {
388            const auto preWhileCarryCount = mCarryQueueSize;
389            DeclareCallFunctions(whileStatement->getCondition());
390            DeclareCallFunctions(whileStatement->getBody());
391            whileStatement->setInclusiveCarryCount(mCarryQueueSize - preWhileCarryCount);
392        }
393    }
394}
395
396void PabloCompiler::DeclareCallFunctions(const PabloAST * expr)
397{
398    if (const Call * call = dyn_cast<const Call>(expr)) {
399        const String * const callee = call->getCallee();
400        assert (callee);
401        if (mCalleeMap.find(callee) == mCalleeMap.end()) {
402            void * callee_ptr = nullptr;
403            #define CHECK_GENERAL_CODE_CATEGORY(SUFFIX) \
404                if (callee->str() == #SUFFIX) { \
405                    callee_ptr = (void*)&__get_category_##SUFFIX; \
406                } else
407            CHECK_GENERAL_CODE_CATEGORY(Cc)
408            CHECK_GENERAL_CODE_CATEGORY(Cf)
409            CHECK_GENERAL_CODE_CATEGORY(Cn)
410            CHECK_GENERAL_CODE_CATEGORY(Co)
411            CHECK_GENERAL_CODE_CATEGORY(Cs)
412            CHECK_GENERAL_CODE_CATEGORY(Ll)
413            CHECK_GENERAL_CODE_CATEGORY(Lm)
414            CHECK_GENERAL_CODE_CATEGORY(Lo)
415            CHECK_GENERAL_CODE_CATEGORY(Lt)
416            CHECK_GENERAL_CODE_CATEGORY(Lu)
417            CHECK_GENERAL_CODE_CATEGORY(Mc)
418            CHECK_GENERAL_CODE_CATEGORY(Me)
419            CHECK_GENERAL_CODE_CATEGORY(Mn)
420            CHECK_GENERAL_CODE_CATEGORY(Nd)
421            CHECK_GENERAL_CODE_CATEGORY(Nl)
422            CHECK_GENERAL_CODE_CATEGORY(No)
423            CHECK_GENERAL_CODE_CATEGORY(Pc)
424            CHECK_GENERAL_CODE_CATEGORY(Pd)
425            CHECK_GENERAL_CODE_CATEGORY(Pe)
426            CHECK_GENERAL_CODE_CATEGORY(Pf)
427            CHECK_GENERAL_CODE_CATEGORY(Pi)
428            CHECK_GENERAL_CODE_CATEGORY(Po)
429            CHECK_GENERAL_CODE_CATEGORY(Ps)
430            CHECK_GENERAL_CODE_CATEGORY(Sc)
431            CHECK_GENERAL_CODE_CATEGORY(Sk)
432            CHECK_GENERAL_CODE_CATEGORY(Sm)
433            CHECK_GENERAL_CODE_CATEGORY(So)
434            CHECK_GENERAL_CODE_CATEGORY(Zl)
435            CHECK_GENERAL_CODE_CATEGORY(Zp)
436            CHECK_GENERAL_CODE_CATEGORY(Zs)
437            // OTHERWISE ...
438            throw std::runtime_error("Unknown unicode category \"" + callee->str() + "\"");
439            #undef CHECK_GENERAL_CODE_CATEGORY
440            Value * unicodeCategory = mMod->getOrInsertFunction("__get_category_" + callee->str(), mXi64Vect, mBasisBitsInputPtr, NULL);
441            if (unicodeCategory == nullptr) {
442                throw std::runtime_error("Could not create static method call for unicode category \"" + callee->str() + "\"");
443            }
444            mExecutionEngine->addGlobalMapping(cast<GlobalValue>(unicodeCategory), callee_ptr);
445            mCalleeMap.insert(std::make_pair(callee, unicodeCategory));
446        }
447    }
448    else if (const And * pablo_and = dyn_cast<const And>(expr))
449    {
450        DeclareCallFunctions(pablo_and->getExpr1());
451        DeclareCallFunctions(pablo_and->getExpr2());
452    }
453    else if (const Or * pablo_or = dyn_cast<const Or>(expr))
454    {
455        DeclareCallFunctions(pablo_or->getExpr1());
456        DeclareCallFunctions(pablo_or->getExpr2());
457    }
458    else if (const Sel * pablo_sel = dyn_cast<const Sel>(expr))
459    {
460        DeclareCallFunctions(pablo_sel->getCondition());
461        DeclareCallFunctions(pablo_sel->getTrueExpr());
462        DeclareCallFunctions(pablo_sel->getFalseExpr());
463    }
464    else if (const Not * pablo_not = dyn_cast<const Not>(expr))
465    {
466        DeclareCallFunctions(pablo_not->getExpr());
467    }
468    else if (const Advance * adv = dyn_cast<const Advance>(expr))
469    {
470        ++mCarryQueueSize;
471        DeclareCallFunctions(adv->getExpr());
472    }
473    else if (const MatchStar * mstar = dyn_cast<const MatchStar>(expr))
474    {
475        ++mCarryQueueSize;
476        DeclareCallFunctions(mstar->getMarker());
477        DeclareCallFunctions(mstar->getCharClass());
478    }
479    else if (const ScanThru * sthru = dyn_cast<const ScanThru>(expr))
480    {
481        ++mCarryQueueSize;
482        DeclareCallFunctions(sthru->getScanFrom());
483        DeclareCallFunctions(sthru->getScanThru());
484    }
485}
486
487inline Value* PabloCompiler::GetMarker(const String * name) {
488    auto itr = mMarkerMap.find(name);
489    if (itr == mMarkerMap.end()) {
490        IRBuilder<> b(mBasicBlock);
491        itr = mMarkerMap.insert(std::make_pair(name, b.CreateAlloca(mXi64Vect))).first;
492    }
493    return itr->second;
494}
495
496void PabloCompiler::SetReturnMarker(Value * marker, const unsigned index) {
497    IRBuilder<> b(mBasicBlock);
498    Value* marker_bitblock = b.CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
499    Value* output_indices[] = {b.getInt64(0), b.getInt32(index)};
500    Value* output_struct_GEP = b.CreateGEP(mOutputAddrPtr, output_indices);
501    b.CreateAlignedStore(marker_bitblock, output_struct_GEP, BLOCK_SIZE/8, false);
502}
503
504
505Value * PabloCompiler::compileStatements(const ExpressionList & stmts) {
506    Value * retVal = nullptr;
507    for (PabloAST * statement : stmts) {
508        retVal = compileStatement(statement);
509    }
510    return retVal;
511}
512
// Emits IR for one Pablo statement (Assign, Next, If or While) starting in
// mBasicBlock. If/While create new basic blocks and leave mBasicBlock pointing
// at their exit block.
// Returns:
// * Assign/Next: the marker storage that was written;
// * If: the value of the last body statement;
// * While: the value of the last statement of the second body compilation;
// * any other node: nullptr.
Value * PabloCompiler::compileStatement(const PabloAST * stmt)
{
    Value * retVal = nullptr;
    if (const Assign * assign = dyn_cast<const Assign>(stmt))
    {
        Value* expr = compileExpression(assign->getExpr());
        Value* marker = GetMarker(assign->getName());
        IRBuilder<> b(mBasicBlock);
        b.CreateAlignedStore(expr, marker, BLOCK_SIZE/8, false);
        retVal = marker;
    }
    // NOTE(review): this is a separate `if`, not `else if`. It is harmless only
    // while Assign and Next are disjoint node kinds (presumably they are; confirm).
    if (const Next * next = dyn_cast<const Next>(stmt))
    {
        // Next overwrites the storage of an already-declared variable.
        IRBuilder<> b(mBasicBlock);
        auto f = mMarkerMap.find(next->getName());
        assert (f != mMarkerMap.end());
        Value * marker = f->second;
        Value * expr = compileExpression(next->getExpr());
        b.CreateAlignedStore(expr, marker, BLOCK_SIZE/8, false);
        retVal = marker;
    }
    else if (const If * ifstmt = dyn_cast<const If>(stmt))
    {
        BasicBlock * ifEntryBlock = mBasicBlock;
        BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunc_process_block, 0);
        BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunc_process_block, 0);

        int if_start_idx = mCarryQueueIdx;

        Value* if_test_value = compileExpression(ifstmt->getCondition());

        /* Generate the statements into the if body block, and also determine the
           final carry index.  */

        IRBuilder<> bIfBody(ifBodyBlock);
        mBasicBlock = ifBodyBlock;

        // NOTE(review): at nesting depth > 0, genCarryInLoad does not load from
        // carry_q; it returns the cached mCarryQueueVector entry. Unlike the While
        // case below, nothing here pre-loads the body's carries for a top-level If,
        // so those entries may still be unset (presumably nullptr); confirm.
        ++mNestingDepth;

        Value *  returnMarker = compileStatements(ifstmt->getBody());

        int if_end_idx = mCarryQueueIdx;
        // NOTE(review): the comment below says "at least two", which would be
        // `if_start_idx < if_end_idx - 1`. As written, this condition is always true,
        // because if_end_idx >= if_start_idx. The accumulator slot taken here is also
        // not counted by DeclareCallFunctions for an If. That can push
        // mCarryQueueIdx past mCarryQueueSize and trip the index asserts. Needs a
        // coordinated fix with DeclareCallFunctions.
        if (if_start_idx < if_end_idx + 1) {
            // Have at least two internal carries.   Accumulate and store.
            int if_accum_idx = mCarryQueueIdx++;

            // OR together all carries produced inside the body into one summary slot.
            Value* if_carry_accum_value = genCarryInLoad(if_start_idx);

            for (int c = if_start_idx+1; c < if_end_idx; c++)
            {
                Value* carryq_value = genCarryInLoad(c);
                if_carry_accum_value = bIfBody.CreateOr(carryq_value, if_carry_accum_value);
            }
            genCarryOutStore(if_carry_accum_value, if_accum_idx);

        }
        bIfBody.CreateBr(ifEndBlock);

        IRBuilder<> b_entry(ifEntryBlock);
        mBasicBlock = ifEntryBlock;
        if (if_start_idx < if_end_idx) {
            // Have at least one internal carry.
            // Pending carries from the previous block force the body to run.
            int if_accum_idx = mCarryQueueIdx - 1;
            Value* last_if_pending_carries = genCarryInLoad(if_accum_idx);
            if_test_value = b_entry.CreateOr(if_test_value, last_if_pending_carries);
        }
        // A true result selects if.end, so genBitBlockAny presumably tests for an
        // all-zero block (its body is not visible here; confirm).
        b_entry.CreateCondBr(genBitBlockAny(if_test_value), ifEndBlock, ifBodyBlock);

        mBasicBlock = ifEndBlock;
        --mNestingDepth;

        retVal = returnMarker;
    }
    else if (const While * whl = dyn_cast<const While>(stmt))
    {
        // At the outermost level, load every carry the loop uses from carry_q into
        // mCarryQueueVector before entering nested code generation.
        const auto baseCarryQueueIdx = mCarryQueueIdx;
        if (mNestingDepth == 0) {
            for (auto i = 0; i != whl->getInclusiveCarryCount(); ++i) {
                genCarryInLoad(baseCarryQueueIdx + i);
            }
        }       

        // First compile the initial iteration statements; the calls to genCarryOutStore will update the
        // mCarryQueueVector with the appropriate values. Although we're not actually entering a new basic
        // block yet, increment the nesting depth so that any calls to genCarryInLoad or genCarryOutStore
        // will refer to the previous value.
        ++mNestingDepth;
        compileStatements(whl->getBody());
        // Reset the carry queue index. Note: this ought to be changed in the future. Currently this assumes
        // that compiling the while body twice will generate the equivalent IR. This is not necessarily true
        // but works for now.
        mCarryQueueIdx = baseCarryQueueIdx;

        BasicBlock* whileCondBlock = BasicBlock::Create(mMod->getContext(), "while.cond", mFunc_process_block, 0);
        BasicBlock* whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunc_process_block, 0);
        BasicBlock* whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunc_process_block, 0);

        // Note: compileStatements may update the mBasicBlock pointer if the body contains nested loops. Do not
        // assume it is the same one in which we entered the function with.
        IRBuilder<> bEntry(mBasicBlock);
        bEntry.CreateBr(whileCondBlock);

        // CONDITION BLOCK
        IRBuilder<> bCond(whileCondBlock);
        // generate phi nodes for any carry propogating instruction
        // Each phi merges the carry from the pre-loop path with the accumulated
        // carry from the loop back edge (added below). The body itself then sees a
        // zero carry-in on every iteration.
        std::vector<PHINode*> carryQueuePhiNodes(whl->getInclusiveCarryCount());
        for (auto i = 0; i != whl->getInclusiveCarryCount(); ++i) {
            PHINode * phi = bCond.CreatePHI(mXi64Vect, 2);
            phi->addIncoming(mCarryQueueVector[baseCarryQueueIdx + i], mBasicBlock);
            mCarryQueueVector[baseCarryQueueIdx + i] = mZeroInitializer; // phi; // (use phi for multi-carry mode.)
            carryQueuePhiNodes[i] = phi;
        }

        mBasicBlock = whileCondBlock;
        // A true result exits the loop (see the If branch note on genBitBlockAny).
        bCond.CreateCondBr(genBitBlockAny(compileExpression(whl->getCondition())), whileEndBlock, whileBodyBlock);

        // BODY BLOCK
        mBasicBlock = whileBodyBlock;
        retVal = compileStatements(whl->getBody());
        // update phi nodes for any carry propogating instruction
        IRBuilder<> bWhileBody(mBasicBlock);
        for (auto i = 0; i != whl->getInclusiveCarryCount(); ++i) {
            Value * carryOut = bWhileBody.CreateOr(carryQueuePhiNodes[i], mCarryQueueVector[baseCarryQueueIdx + i]);
            carryQueuePhiNodes[i]->addIncoming(carryOut, mBasicBlock);
            mCarryQueueVector[baseCarryQueueIdx + i] = carryQueuePhiNodes[i];
        }

        bWhileBody.CreateBr(whileCondBlock);

        // EXIT BLOCK
        // Back at the outermost level, write the accumulated carries to carry_q.
        mBasicBlock = whileEndBlock;   
        if (--mNestingDepth == 0) {
            for (auto i = 0; i != whl->getInclusiveCarryCount(); ++i) {
                genCarryOutStore(carryQueuePhiNodes[i], baseCarryQueueIdx + i);
            }
        }
    }
    return retVal;
}
652
653Value * PabloCompiler::compileExpression(const PabloAST * expr)
654{
655    Value * retVal = nullptr;
656    IRBuilder<> b(mBasicBlock);
657    if (isa<Ones>(expr)) {
658        retVal = mOneInitializer;
659    }
660    else if (isa<Zeroes>(expr)) {
661        retVal = mZeroInitializer;
662    }
663    else if (const Call* call = dyn_cast<Call>(expr)) {
664        //Call the callee once and store the result in the marker map.
665        auto mi = mMarkerMap.find(call->getCallee());
666        if (mi == mMarkerMap.end()) {
667            auto ci = mCalleeMap.find(call->getCallee());
668            if (ci == mCalleeMap.end()) {
669                throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee()->str() + "\"");
670            }
671            Value * unicode_category = b.CreateCall(ci->second, mBasisBitsAddr);
672            Value * ptr = b.CreateAlloca(mXi64Vect);
673            b.CreateAlignedStore(unicode_category, ptr, BLOCK_SIZE/8, false);
674            mi = mMarkerMap.insert(std::make_pair(call->getCallee(), ptr)).first;
675        }
676        retVal = b.CreateAlignedLoad(mi->second, BLOCK_SIZE/8, false, call->getCallee()->str() );
677    }
678    else if (const Var * var = dyn_cast<Var>(expr))
679    {
680        auto f = mMarkerMap.find(var->getName());
681        assert (f != mMarkerMap.end());
682        retVal = b.CreateAlignedLoad(f->second, BLOCK_SIZE/8, false, var->getName()->str() );
683    }
684    else if (const And * pablo_and = dyn_cast<And>(expr))
685    {
686        retVal = b.CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
687    }
688    else if (const Or * pablo_or = dyn_cast<Or>(expr))
689    {
690        retVal = b.CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
691    }
692    else if (const Sel * pablo_sel = dyn_cast<Sel>(expr))
693    {
694        Value* ifMask = compileExpression(pablo_sel->getCondition());
695        Value* and_if_true_result = b.CreateAnd(ifMask, compileExpression(pablo_sel->getTrueExpr()));
696        Value* and_if_false_result = b.CreateAnd(genNot(ifMask), compileExpression(pablo_sel->getFalseExpr()));
697        retVal = b.CreateOr(and_if_true_result, and_if_false_result);
698    }
699    else if (const Not * pablo_not = dyn_cast<Not>(expr))
700    {
701        retVal = genNot(compileExpression(pablo_not->getExpr()));
702    }
703    else if (const Advance * adv = dyn_cast<Advance>(expr))
704    {
705        Value* strm_value = compileExpression(adv->getExpr());
706                int shift = adv->getAdvanceAmount();
707        retVal = genAdvanceWithCarry(strm_value, shift);
708    }
709    else if (const MatchStar * mstar = dyn_cast<MatchStar>(expr))
710    {
711        Value* marker = compileExpression(mstar->getMarker());
712        Value* cc = compileExpression(mstar->getCharClass());
713        Value* marker_and_cc = b.CreateAnd(marker, cc);
714        retVal = b.CreateOr(b.CreateXor(genAddWithCarry(marker_and_cc, cc), cc), marker, "matchstar");
715    }
716    else if (const ScanThru * sthru = dyn_cast<ScanThru>(expr))
717    {
718        Value* marker_expr = compileExpression(sthru->getScanFrom());
719        Value* cc_expr = compileExpression(sthru->getScanThru());
720        retVal = b.CreateAnd(genAddWithCarry(marker_expr, cc_expr), genNot(cc_expr), "scanthru");
721    }
722    return retVal;
723}
724
#ifdef USE_UADD_OVERFLOW
// Appends to mBasicBlock a call to the uadd-with-overflow/carry-in function
// declared in DeclareFunctions. Returns its two results unpacked:
// .sum (field 0) and .obit (field 1).
SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
    Value * const callArgs[] = { int128_e1, int128_e2, int1_cin };
    CallInst * call = CallInst::Create(mFunc_llvm_uadd_with_overflow, callArgs, "uadd_overflow_res", mBasicBlock);
    call->setCallingConv(CallingConv::C);
    call->setTailCall(false);
    call->setAttributes(AttributeSet());

    const unsigned sumField = 0;
    const unsigned obitField = 1;

    SumWithOverflowPack pack;
    pack.sum = ExtractValueInst::Create(call, sumField, "sum", mBasicBlock);
    pack.obit = ExtractValueInst::Create(call, obitField, "obit", mBasicBlock);
    return pack;
}
#endif
750
// Emits a full-width (BLOCK_SIZE-bit) addition e1 + e2 + carry-in. The carry-in
// comes from the next carry-queue slot (mCarryQueueIdx, then incremented); the
// carry-out is written back to that same slot via genCarryOutStore.
// Returns the sum as an mXi64Vect value.
Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2) {
    IRBuilder<> b(mBasicBlock);

    //CarryQ - carry in.
    const int carryIdx = mCarryQueueIdx++;
    Value* carryq_value = genCarryInLoad(carryIdx);

#ifdef USE_UADD_OVERFLOW
    //use llvm.uadd.with.overflow.i128 or i256
    // Treat both operands as single BLOCK_SIZE-bit integers. Take the carry-in from
    // element 0 of the carry vector, truncated to i1.
    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
    CastInst* int128_e1 = new BitCastInst(e1, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e1_128", mBasicBlock);
    CastInst* int128_e2 = new BitCastInst(e2, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e2_128", mBasicBlock);
    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
    SumWithOverflowPack sumpack0;
    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
    Value* obit = sumpack0.obit;
    Value* sum = b.CreateBitCast(sumpack0.sum, mXi64Vect, "sum");
    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mXi64Vect);
    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
#else
    //calculate carry through logical ops
    // Lane-wise 64-bit adds. The high bit of (a & b) | ((a | b) & ~sum) is the
    // carry out of each 64-bit lane.
    Value* carrygen = b.CreateAnd(e1, e2, "carrygen");
    Value* carryprop = b.CreateOr(e1, e2, "carryprop");
    Value* digitsum = b.CreateAdd(e1, e2, "digitsum");
    Value* partial = b.CreateAdd(digitsum, carryq_value, "partial");
    Value* digitcarry = b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(partial)));
    // Move each lane's carry bit (bit 63) to bit 0, then (presumably) shift it up
    // one 64-bit lane so it becomes the next lane's carry-in.
    // NOTE(review): only one step of inter-lane propagation is done. A carry that
    // would ripple through an all-ones 64-bit lane is not propagated further. This
    // looks sufficient for two lanes (BLOCK_SIZE == 128) but not for more; confirm.
    Value* mid_carry_in = genShiftLeft64(b.CreateLShr(digitcarry, 63), "mid_carry_in");

    Value* sum = b.CreateAdd(partial, mid_carry_in, "sum");
    // Block carry-out: presumably genShiftHighbitToLow moves the top bit of the
    // block to bit 0 of the carry vector (its body is not visible here).
    Value* carry_out = genShiftHighbitToLow(b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(sum))), "carry_out");
#endif
    genCarryOutStore(carry_out, carryIdx);
    return sum;
}
788
789Value* PabloCompiler::genCarryInLoad(const unsigned index) {   
790    assert (index < mCarryQueueVector.size());
791    if (mNestingDepth == 0) {
792        IRBuilder<> b(mBasicBlock);
793        mCarryQueueVector[index] = b.CreateAlignedLoad(b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
794    }
795    return mCarryQueueVector[index];
796}
797
798void PabloCompiler::genCarryOutStore(Value* carryOut, const unsigned index ) {
799    assert (carryOut);
800    assert (index < mCarryQueueVector.size());
801    if (mNestingDepth == 0) {       
802        IRBuilder<> b(mBasicBlock);
803        b.CreateAlignedStore(carryOut, b.CreateGEP(mCarryQueuePtr, b.getInt64(index)), BLOCK_SIZE/8, false);
804    }
805    mCarryQueueVector[index] = carryOut;
806}
807
808inline Value* PabloCompiler::genBitBlockAny(Value* test) {
809    IRBuilder<> b(mBasicBlock);
810    Value* cast_marker_value_1 = b.CreateBitCast(test, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
811    return b.CreateICmpEQ(cast_marker_value_1, ConstantInt::get(IntegerType::get(mMod->getContext(), BLOCK_SIZE), 0));
812}
813
814Value* PabloCompiler::genShiftHighbitToLow(Value* e, const Twine &namehint) {
815    IRBuilder<> b(mBasicBlock);
816    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
817    return b.CreateBitCast(b.CreateLShr(i128_val, BLOCK_SIZE - 1, namehint), mXi64Vect);
818}
819
820Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
821    IRBuilder<> b(mBasicBlock);
822    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
823    return b.CreateBitCast(b.CreateShl(i128_val, 64, namehint), mXi64Vect);
824}
825
826inline Value* PabloCompiler::genNot(Value* expr) {
827    IRBuilder<> b(mBasicBlock);
828    return b.CreateXor(expr, mOneInitializer, "not");
829}
830
831Value* PabloCompiler::genAdvanceWithCarry(Value* strm_value, int shift_amount) {
832#ifndef VARIABLE_ADVANCE
833        if (shift_amount != 1) {
834                throw std::runtime_error("Shift amount != 1 in Advance is currently unsupported.");
835        }
836#endif
837
838    IRBuilder<> b(mBasicBlock);
839#if (BLOCK_SIZE == 128)
840    const auto carryIdx = mCarryQueueIdx++;
841#ifndef VARIABLE_ADVANCE
842    Value* carryq_value = genCarryInLoad(carryIdx);
843    Value* srli_1_value = b.CreateLShr(strm_value, 63);
844    Value* packed_shuffle;
845    Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
846    Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
847    packed_shuffle = b.CreateShuffleVector(carryq_value, srli_1_value, const_packed_1);
848
849    Constant* const_packed_2_elems[] = {b.getInt64(1), b.getInt64(1)};
850    Constant* const_packed_2 = ConstantVector::get(const_packed_2_elems);
851
852    Value* shl_value = b.CreateShl(strm_value, const_packed_2);
853    Value* result_value = b.CreateOr(shl_value, packed_shuffle, "advance");
854
855    Value* carry_out = genShiftHighbitToLow(strm_value, "carry_out");
856#endif
857#ifdef VARIABLE_ADVANCE
858    Value* carryq_longint = b.CreateBitCast(genCarryInLoad(carryIdx), IntegerType::get(mMod->getContext(), BLOCK_SIZE));
859    Value* strm_longint = b.CreateBitCast(strm_value, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
860    Value* adv_longint = b.CreateOr(b.CreateShl(strm_longint, shift_amount), carryq_longint, "advance");
861    Value* result_value = b.CreateBitCast(adv_longint, mXi64Vect);
862    Value* carry_out = b.CreateBitCast(b.CreateLShr(strm_longint, BLOCK_SIZE - shift_amount, "advance_out"), mXi64Vect);
863#endif
864    //CarryQ - carry out:
865    genCarryOutStore(carry_out, carryIdx);
866       
867    return result_value;
868#endif
869
870#if (BLOCK_SIZE == 256)
871    return genAddWithCarry(strm_value, strm_value);
872#endif
873
874}
875
876}
Note: See TracBrowser for help on using the repository browser.