source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4237

Last change on this file since 4237 was 4237, checked in by nmedfort, 5 years ago

Moved llvm_gen.* into pablo/pablo_compiler.* and updated CMakeLists.txt

File size: 32.1 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7/*
8 *  Copyright (c) 2014 International Characters.
9 *  This software is licensed to the public under the Open Software License 3.0.
10 *  icgrep is a trademark of International Characters.
11 */
12
13#include <pablo/pablo_compiler.h>
14#include <pablo/codegenstate.h>
15#include <pablo/printer_pablos.h>
16#include <stdexcept>
17
18// #define DUMP_GENERATED_IR
19// #define DUMP_OPTIMIZED_IR
20
21extern "C" {
22  void wrapped_print_register(BitBlock bit_block) {
23      print_register<BitBlock>("", bit_block);
24  }
25}
26
27#define CREATE_GENERAL_CODE_CATEGORY(SUFFIX) \
28SUFFIX * f##SUFFIX = nullptr; \
29extern "C" { \
30    BitBlock __get_category_##SUFFIX(Basis_bits &basis_bits) { \
31        if (f##SUFFIX == nullptr) f##SUFFIX = new SUFFIX(); \
32        Struct_##SUFFIX output; \
33        f##SUFFIX->do_block(basis_bits, output); \
34        return output.cc; \
35    } \
36}
37
38CREATE_GENERAL_CODE_CATEGORY(Cc)
39CREATE_GENERAL_CODE_CATEGORY(Cf)
40CREATE_GENERAL_CODE_CATEGORY(Cn)
41CREATE_GENERAL_CODE_CATEGORY(Co)
42CREATE_GENERAL_CODE_CATEGORY(Cs)
43CREATE_GENERAL_CODE_CATEGORY(Ll)
44CREATE_GENERAL_CODE_CATEGORY(Lm)
45CREATE_GENERAL_CODE_CATEGORY(Lo)
46CREATE_GENERAL_CODE_CATEGORY(Lt)
47CREATE_GENERAL_CODE_CATEGORY(Lu)
48CREATE_GENERAL_CODE_CATEGORY(Mc)
49CREATE_GENERAL_CODE_CATEGORY(Me)
50CREATE_GENERAL_CODE_CATEGORY(Mn)
51CREATE_GENERAL_CODE_CATEGORY(Nd)
52CREATE_GENERAL_CODE_CATEGORY(Nl)
53CREATE_GENERAL_CODE_CATEGORY(No)
54CREATE_GENERAL_CODE_CATEGORY(Pc)
55CREATE_GENERAL_CODE_CATEGORY(Pd)
56CREATE_GENERAL_CODE_CATEGORY(Pe)
57CREATE_GENERAL_CODE_CATEGORY(Pf)
58CREATE_GENERAL_CODE_CATEGORY(Pi)
59CREATE_GENERAL_CODE_CATEGORY(Po)
60CREATE_GENERAL_CODE_CATEGORY(Ps)
61CREATE_GENERAL_CODE_CATEGORY(Sc)
62CREATE_GENERAL_CODE_CATEGORY(Sk)
63CREATE_GENERAL_CODE_CATEGORY(Sm)
64CREATE_GENERAL_CODE_CATEGORY(So)
65CREATE_GENERAL_CODE_CATEGORY(Zl)
66CREATE_GENERAL_CODE_CATEGORY(Zp)
67CREATE_GENERAL_CODE_CATEGORY(Zs)
68
69#undef CREATE_GENERAL_CODE_CATEGORY
70
71namespace pablo {
72
73PabloCompiler::PabloCompiler(std::map<std::string, std::string> name_map, std::string basis_pattern, int bits)
74: mBits(bits)
75, m_name_map(name_map)
76, mBasisBitPattern(basis_pattern)
77, mMod(new Module("icgrep", getGlobalContext()))
78, mBasicBlock(nullptr)
79, mExecutionEngine(nullptr)
80, mXi64Vect(nullptr)
81, mXi128Vect(nullptr)
82, mBasisBitsInputPtr(nullptr)
83, mOutputPtr(nullptr)
84, mCarryQueueIdx(0)
85, mptr_carry_q(nullptr)
86, mCarryQueueSize(0)
87, mConst_int64_neg1(nullptr)
88, mZeroInitializer(nullptr)
89, mAllOneInitializer(nullptr)
90, mFuncTy_0(nullptr)
91, mFunc_process_block(nullptr)
92, mBasisBitsAddr(nullptr)
93, mPtr_carry_q_addr(nullptr)
94, mPtr_output_addr(nullptr)
95{
96    //Create the jit execution engine.up
97    InitializeNativeTarget();
98    std::string ErrStr;
99
100    mExecutionEngine = EngineBuilder(mMod).setUseMCJIT(true).setErrorStr(&ErrStr).setOptLevel(CodeGenOpt::Level::Less).create();
101    if (mExecutionEngine == nullptr) {
102        throw std::runtime_error("\nCould not create ExecutionEngine: " + ErrStr);
103    }
104
105    InitializeNativeTargetAsmPrinter();
106    InitializeNativeTargetAsmParser();
107
108    DefineTypes();
109    DeclareFunctions();
110}
111
112PabloCompiler::~PabloCompiler()
113{
114    delete mMod;
115    delete fPs;
116    delete fNl;
117    delete fNo;
118    delete fLo;
119    delete fLl;
120    delete fLm;
121    delete fNd;
122    delete fPc;
123    delete fLt;
124    delete fLu;
125    delete fPf;
126    delete fPd;
127    delete fPe;
128    delete fPi;
129    delete fPo;
130    delete fMe;
131    delete fMc;
132    delete fMn;
133    delete fSk;
134    delete fSo;
135    delete fSm;
136    delete fSc;
137    delete fZl;
138    delete fCo;
139    delete fCn;
140    delete fCc;
141    delete fCf;
142    delete fCs;
143    delete fZp;
144    delete fZs;
145
146}
147
148LLVM_Gen_RetVal PabloCompiler::compile(const PabloBlock & cg_state)
149{
150    mCarryQueueSize = 0;
151    DeclareCallFunctions(cg_state.expressions());
152
153    Function::arg_iterator args = mFunc_process_block->arg_begin();
154    Value* ptr_basis_bits = args++;
155    ptr_basis_bits->setName("basis_bits");
156    mptr_carry_q = args++;
157    mptr_carry_q->setName("carry_q");
158    Value* ptr_output = args++;
159    ptr_output->setName("output");
160
161    //Create the carry queue.
162    mCarryQueueIdx = 0;
163    mBasicBlock = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunc_process_block,0);
164
165    //The basis bits structure
166    mBasisBitsAddr = new AllocaInst(mBasisBitsInputPtr, "basis_bits.addr", mBasicBlock);
167    new StoreInst(ptr_basis_bits, mBasisBitsAddr, false, mBasicBlock);
168    for (unsigned i = 0; i < mBits; ++i) {
169        IRBuilder<> b(mBasicBlock);
170        Value* basisBit = b.CreateLoad(mBasisBitsAddr);
171        Value* indices[] = {b.getInt64(0), b.getInt32(i)};
172        const std::string name = mBasisBitPattern + std::to_string(i);
173        Value* basis_bits_struct_GEP = b.CreateGEP(basisBit, indices, name);
174        mMarkerMap.insert(make_pair(name, basis_bits_struct_GEP));
175    }
176    mPtr_output_addr = new AllocaInst(mOutputPtr, "output.addr", mBasicBlock);
177    new StoreInst(ptr_output, mPtr_output_addr, false, mBasicBlock);
178
179    //Generate the IR instructions for the function.
180    SetReturnMarker(compileStatements(cg_state.expressions()), 0); // matches
181    SetReturnMarker(GetMarker(m_name_map.find("LineFeed")->second), 1); // line feeds
182
183    assert (mCarryQueueIdx == mCarryQueueSize);
184
185    //Terminate the block
186    ReturnInst::Create(mMod->getContext(), mBasicBlock);
187
188    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
189    #ifdef USE_LLVM_3_5
190    verifyModule(*mMod, &dbgs());
191    #endif
192    #ifdef USE_LLVM_3_4
193    verifyModule(*mMod, PrintMessageAction);
194    #endif
195
196    //Un-comment this line in order to display the IR that has been generated by this module.
197    #ifdef DUMP_GENERATED_IR
198    mMod->dump();
199    #endif
200
201    //Use the pass manager to run optimizations on the function.
202    FunctionPassManager fpm(mMod);
203
204#ifdef USE_LLVM_3_5
205    mMod->setDataLayout(mExecutionEngine->getDataLayout());
206    // Set up the optimizer pipeline.  Start with registering info about how the target lays out data structures.
207    fpm.add(new DataLayoutPass(mMod));
208#endif
209
210#ifdef USE_LLVM_3_4
211    fpm.add(new DataLayout(*mExecutionEngine->getDataLayout()));
212#endif
213
214    fpm.add(createPromoteMemoryToRegisterPass()); //Transform to SSA form.
215    fpm.add(createBasicAliasAnalysisPass());      //Provide basic AliasAnalysis support for GVN. (Global Value Numbering)
216    fpm.add(createCFGSimplificationPass());       //Simplify the control flow graph.
217    fpm.add(createInstructionCombiningPass());    //Simple peephole optimizations and bit-twiddling.
218    fpm.add(createReassociatePass());             //Reassociate expressions.
219    fpm.add(createGVNPass());                     //Eliminate common subexpressions.
220
221    fpm.doInitialization();
222
223    fpm.run(*mFunc_process_block);
224
225#ifdef DUMP_OPTIMIZED_IR
226    mMod->dump();
227#endif
228    mExecutionEngine->finalizeObject();
229
230    LLVM_Gen_RetVal retVal;
231    //Return the required size of the carry queue and a pointer to the process_block function.
232    retVal.carry_q_size = mCarryQueueSize;
233    retVal.process_block_fptr = mExecutionEngine->getPointerToFunction(mFunc_process_block);
234
235    return retVal;
236}
237
238void PabloCompiler::DefineTypes()
239{
240    //The BitBlock vector.
241    mXi64Vect = VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64);
242    mXi128Vect = VectorType::get(IntegerType::get(mMod->getContext(), 128), BLOCK_SIZE / 128);
243
244    //Constant definitions.
245    mConst_int64_neg1 = ConstantInt::get(mMod->getContext(), APInt(64, -1));
246    mZeroInitializer = ConstantAggregateZero::get(mXi64Vect);
247
248    std::vector<Constant*> const_packed_27_elems;
249    for (int i = 0; i < BLOCK_SIZE / 64; ++i) {
250        const_packed_27_elems.push_back(mConst_int64_neg1);
251    }
252    mAllOneInitializer = ConstantVector::get(const_packed_27_elems);
253
254
255    StructType * StructTy_struct_Basis_bits = mMod->getTypeByName("struct.Basis_bits");
256    if (StructTy_struct_Basis_bits == nullptr) {
257        StructTy_struct_Basis_bits = StructType::create(mMod->getContext(), "struct.Basis_bits");
258    }
259    std::vector<Type*>StructTy_struct_Basis_bits_fields;
260    for (int i = 0; i < mBits; i++)
261    {
262        StructTy_struct_Basis_bits_fields.push_back(mXi64Vect);
263    }
264
265
266
267
268    if (StructTy_struct_Basis_bits->isOpaque()) {
269        StructTy_struct_Basis_bits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
270    }
271    mBasisBitsInputPtr = PointerType::get(StructTy_struct_Basis_bits, 0);
272
273    std::vector<Type*>FuncTy_0_args;
274    FuncTy_0_args.push_back(mBasisBitsInputPtr);
275
276    //The carry q array.
277    //A pointer to the BitBlock vector.
278    FuncTy_0_args.push_back(PointerType::get(mXi64Vect, 0));
279
280    //The output structure.
281    StructType *StructTy_struct_Output = mMod->getTypeByName("struct.Output");
282    if (!StructTy_struct_Output) {
283        StructTy_struct_Output = StructType::create(mMod->getContext(), "struct.Output");
284    }
285    std::vector<Type*>StructTy_struct_Output_fields;
286    StructTy_struct_Output_fields.push_back(mXi64Vect);
287    StructTy_struct_Output_fields.push_back(mXi64Vect);
288    if (StructTy_struct_Output->isOpaque()) {
289        StructTy_struct_Output->setBody(StructTy_struct_Output_fields, /*isPacked=*/false);
290    }
291    mOutputPtr = PointerType::get(StructTy_struct_Output, 0);
292
293    //The &output parameter.
294    FuncTy_0_args.push_back(mOutputPtr);
295
296    mFuncTy_0 = FunctionType::get(
297     /*Result=*/Type::getVoidTy(mMod->getContext()),
298     /*Params=*/FuncTy_0_args,
299     /*isVarArg=*/false);
300}
301
302void PabloCompiler::DeclareFunctions()
303{
304    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
305    //mFunc_print_register = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), mXi64Vect, NULL);
306    //mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mFunc_print_register), (void *)&wrapped_print_register);
307    // to call->  b.CreateCall(mFunc_print_register, unicode_category);
308
309#ifdef USE_UADD_OVERFLOW
310    // Type Definitions for llvm.uadd.with.overflow.i128 or .i256
311    std::vector<Type*>StructTy_0_fields;
312    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
313    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
314    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
315
316    std::vector<Type*>FuncTy_1_args;
317    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
318    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
319    FunctionType* FuncTy_1 = FunctionType::get(
320                                              /*Result=*/StructTy_0,
321                                              /*Params=*/FuncTy_1_args,
322                                              /*isVarArg=*/false);
323
324    mFunc_llvm_uadd_with_overflow = mMod->getFunction("llvm.uadd.with.overflow.i" + std::to_string(BLOCK_SIZE));
325    if (!mFunc_llvm_uadd_with_overflow) {
326        mFunc_llvm_uadd_with_overflow = Function::Create(
327          /*Type=*/FuncTy_1,
328          /*Linkage=*/GlobalValue::ExternalLinkage,
329          /*Name=*/"llvm.uadd.with.overflow.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
330        mFunc_llvm_uadd_with_overflow->setCallingConv(CallingConv::C);
331    }
332    AttributeSet mFunc_llvm_uadd_with_overflow_PAL;
333    {
334        SmallVector<AttributeSet, 4> Attrs;
335        AttributeSet PAS;
336        {
337          AttrBuilder B;
338          B.addAttribute(Attribute::NoUnwind);
339          B.addAttribute(Attribute::ReadNone);
340          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
341        }
342
343        Attrs.push_back(PAS);
344        mFunc_llvm_uadd_with_overflow_PAL = AttributeSet::get(mMod->getContext(), Attrs);
345    }
346    mFunc_llvm_uadd_with_overflow->setAttributes(mFunc_llvm_uadd_with_overflow_PAL);
347#endif
348
349    //Starts on process_block
350    SmallVector<AttributeSet, 4> Attrs;
351    AttributeSet PAS;
352    {
353        AttrBuilder B;
354        B.addAttribute(Attribute::ReadOnly);
355        B.addAttribute(Attribute::NoCapture);
356        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
357    }
358    Attrs.push_back(PAS);
359    {
360        AttrBuilder B;
361        B.addAttribute(Attribute::NoCapture);
362        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
363    }
364    Attrs.push_back(PAS);
365    {
366        AttrBuilder B;
367        B.addAttribute(Attribute::NoCapture);
368        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
369    }
370    Attrs.push_back(PAS);
371    {
372        AttrBuilder B;
373        B.addAttribute(Attribute::NoUnwind);
374        B.addAttribute(Attribute::UWTable);
375        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
376    }
377    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
378
379    //Create the function that will be generated.
380    mFunc_process_block = mMod->getFunction("process_block");
381    if (!mFunc_process_block) {
382        mFunc_process_block = Function::Create(
383            /*Type=*/mFuncTy_0,
384            /*Linkage=*/GlobalValue::ExternalLinkage,
385            /*Name=*/"process_block", mMod);
386        mFunc_process_block->setCallingConv(CallingConv::C);
387    }
388    mFunc_process_block->setAttributes(AttrSet);
389}
390
391void PabloCompiler::DeclareCallFunctions(const ExpressionList & stmts) {
392    for (const PabloE * stmt : stmts) {
393        if (const Assign * an = dyn_cast<const Assign>(stmt)) {
394            DeclareCallFunctions(an->getExpr());
395        }
396        else if (const If * ifstmt = dyn_cast<const If>(stmt)) {
397            DeclareCallFunctions(ifstmt->getCondition());
398            DeclareCallFunctions(ifstmt->getBody());
399        }
400        else if (const While * whl = dyn_cast<const While>(stmt)) {
401            DeclareCallFunctions(whl->getCondition());
402            DeclareCallFunctions(whl->getBody());
403        }
404    }
405}
406
407void PabloCompiler::DeclareCallFunctions(const PabloE * expr)
408{
409    if (const Call * pablo_call = dyn_cast<const Call>(expr)) {
410        const std::string callee = pablo_call->getCallee();
411        if (mCalleeMap.find(callee) == mCalleeMap.end()) {
412            void * callee_ptr = nullptr;
413            #define CHECK_GENERAL_CODE_CATEGORY(SUFFIX) \
414                if (callee == #SUFFIX) { \
415                    callee_ptr = (void*)&__get_category_##SUFFIX; \
416                } else
417            CHECK_GENERAL_CODE_CATEGORY(Cc)
418            CHECK_GENERAL_CODE_CATEGORY(Cf)
419            CHECK_GENERAL_CODE_CATEGORY(Cn)
420            CHECK_GENERAL_CODE_CATEGORY(Co)
421            CHECK_GENERAL_CODE_CATEGORY(Cs)
422            CHECK_GENERAL_CODE_CATEGORY(Ll)
423            CHECK_GENERAL_CODE_CATEGORY(Lm)
424            CHECK_GENERAL_CODE_CATEGORY(Lo)
425            CHECK_GENERAL_CODE_CATEGORY(Lt)
426            CHECK_GENERAL_CODE_CATEGORY(Lu)
427            CHECK_GENERAL_CODE_CATEGORY(Mc)
428            CHECK_GENERAL_CODE_CATEGORY(Me)
429            CHECK_GENERAL_CODE_CATEGORY(Mn)
430            CHECK_GENERAL_CODE_CATEGORY(Nd)
431            CHECK_GENERAL_CODE_CATEGORY(Nl)
432            CHECK_GENERAL_CODE_CATEGORY(No)
433            CHECK_GENERAL_CODE_CATEGORY(Pc)
434            CHECK_GENERAL_CODE_CATEGORY(Pd)
435            CHECK_GENERAL_CODE_CATEGORY(Pe)
436            CHECK_GENERAL_CODE_CATEGORY(Pf)
437            CHECK_GENERAL_CODE_CATEGORY(Pi)
438            CHECK_GENERAL_CODE_CATEGORY(Po)
439            CHECK_GENERAL_CODE_CATEGORY(Ps)
440            CHECK_GENERAL_CODE_CATEGORY(Sc)
441            CHECK_GENERAL_CODE_CATEGORY(Sk)
442            CHECK_GENERAL_CODE_CATEGORY(Sm)
443            CHECK_GENERAL_CODE_CATEGORY(So)
444            CHECK_GENERAL_CODE_CATEGORY(Zl)
445            CHECK_GENERAL_CODE_CATEGORY(Zp)
446            CHECK_GENERAL_CODE_CATEGORY(Zs)
447            // OTHERWISE ...
448            throw std::runtime_error("Unknown unicode category \"" + callee + "\"");
449            #undef CHECK_GENERAL_CODE_CATEGORY
450            Value * get_unicode_category = mMod->getOrInsertFunction("__get_category_" + callee, mXi64Vect, mBasisBitsInputPtr, NULL);
451            if (get_unicode_category == nullptr) {
452                throw std::runtime_error("Could not create static method call for unicode category \"" + callee + "\"");
453            }
454            mExecutionEngine->addGlobalMapping(cast<GlobalValue>(get_unicode_category), callee_ptr);
455            mCalleeMap.insert(std::make_pair(callee, get_unicode_category));
456        }
457    }
458    else if (const And * pablo_and = dyn_cast<const And>(expr))
459    {
460        DeclareCallFunctions(pablo_and->getExpr1());
461        DeclareCallFunctions(pablo_and->getExpr2());
462    }
463    else if (const Or * pablo_or = dyn_cast<const Or>(expr))
464    {
465        DeclareCallFunctions(pablo_or->getExpr1());
466        DeclareCallFunctions(pablo_or->getExpr2());
467    }
468    else if (const Sel * pablo_sel = dyn_cast<const Sel>(expr))
469    {
470        DeclareCallFunctions(pablo_sel->getCondition());
471        DeclareCallFunctions(pablo_sel->getTrueExpr());
472        DeclareCallFunctions(pablo_sel->getFalseExpr());
473    }
474    else if (const Not * pablo_not = dyn_cast<const Not>(expr))
475    {
476        DeclareCallFunctions(pablo_not->getExpr());
477    }
478    else if (const Advance * adv = dyn_cast<const Advance>(expr))
479    {
480        ++mCarryQueueSize;
481        DeclareCallFunctions(adv->getExpr());
482    }
483    else if (const MatchStar * mstar = dyn_cast<const MatchStar>(expr))
484    {
485        ++mCarryQueueSize;
486        DeclareCallFunctions(mstar->getExpr1());
487        DeclareCallFunctions(mstar->getExpr2());
488    }
489    else if (const ScanThru * sthru = dyn_cast<const ScanThru>(expr))
490    {
491        ++mCarryQueueSize;
492        DeclareCallFunctions(sthru->getScanFrom());
493        DeclareCallFunctions(sthru->getScanThru());
494    }
495}
496
497Value* PabloCompiler::GetMarker(const std::string & name)
498{
499    IRBuilder<> b(mBasicBlock);
500    auto itr = mMarkerMap.find(name);
501    if (itr == mMarkerMap.end()) {
502        Value* ptr = b.CreateAlloca(mXi64Vect);
503        b.CreateStore(mZeroInitializer, ptr);
504        itr = mMarkerMap.insert(make_pair(name, ptr)).first;
505    }
506    return itr->second;
507}
508
509void PabloCompiler::SetReturnMarker(Value * marker, const unsigned output_idx)
510{
511    IRBuilder<> b(mBasicBlock);
512    Value* marker_bitblock = b.CreateLoad(marker);
513    Value* output_struct = b.CreateLoad(mPtr_output_addr);
514    Value* output_indices[] = {b.getInt64(0), b.getInt32(output_idx)};
515    Value* output_struct_GEP = b.CreateGEP(output_struct, output_indices);
516    b.CreateStore(marker_bitblock, output_struct_GEP);
517}
518
519
520Value * PabloCompiler::compileStatements(const ExpressionList & stmts) {
521    Value * retVal = nullptr;
522    for (PabloE * statement : stmts) {
523        retVal = compileStatement(statement);
524    }
525    return retVal;
526}
527
528Value * PabloCompiler::compileStatement(PabloE * stmt)
529{
530    Value * retVal = nullptr;
531    if (const Assign * assign = dyn_cast<const Assign>(stmt))
532    {
533        IRBuilder<> b(mBasicBlock);
534        Value * marker = GetMarker(assign->getName());
535        Value * expr = compileExpression(assign->getExpr());
536        b.CreateStore(expr, marker);
537        retVal = marker;
538    }
539    else if (const If * ifstmt = dyn_cast<const If>(stmt))
540    {
541        BasicBlock * ifEntryBlock = mBasicBlock;
542        BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body",mFunc_process_block, 0);
543        BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end",mFunc_process_block, 0);
544
545        int if_start_idx = mCarryQueueIdx;
546
547        Value* if_test_value = compileExpression(ifstmt->getCondition());
548
549        /* Generate the statements into the if body block, and also determine the
550           final carry index.  */
551
552        IRBuilder<> b_ifbody(ifBodyBlock);
553        mBasicBlock = ifBodyBlock;
554
555        Value *  returnMarker = compileStatements(ifstmt->getBody());
556
557        int if_end_idx = mCarryQueueIdx;
558        if (if_start_idx < if_end_idx + 1) {
559            // Have at least two internal carries.   Accumulate and store.
560            int if_accum_idx = mCarryQueueIdx++;
561
562            Value* if_carry_accum_value = genCarryInLoad(mptr_carry_q, if_start_idx);
563
564            for (int c = if_start_idx+1; c < if_end_idx; c++)
565            {
566                Value* carryq_value = genCarryInLoad(mptr_carry_q, c);
567                if_carry_accum_value = b_ifbody.CreateOr(carryq_value, if_carry_accum_value);
568            }
569            genCarryOutStore(if_carry_accum_value, mptr_carry_q, if_accum_idx);
570
571        }
572        b_ifbody.CreateBr(ifEndBlock);
573
574        IRBuilder<> b_entry(ifEntryBlock);
575        mBasicBlock = ifEntryBlock;
576        if (if_start_idx < if_end_idx) {
577            // Have at least one internal carry.
578            int if_accum_idx = mCarryQueueIdx - 1;
579            Value* last_if_pending_carries = genCarryInLoad(mptr_carry_q, if_accum_idx);
580            if_test_value = b_entry.CreateOr(if_test_value, last_if_pending_carries);
581        }
582        b_entry.CreateCondBr(genBitBlockAny(if_test_value), ifEndBlock, ifBodyBlock);
583
584        mBasicBlock = ifEndBlock;
585
586        retVal = returnMarker;
587    }
588    else if (const While* whl = dyn_cast<const While>(stmt))
589    {
590        int idx = mCarryQueueIdx;
591
592        //With this call to the while body we will account for all of the carry in values.
593        Value * returnMarker = compileStatements(whl->getBody());
594
595        BasicBlock*  whileCondBlock = BasicBlock::Create(mMod->getContext(), "while.cond", mFunc_process_block, 0);
596        BasicBlock*  whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunc_process_block, 0);
597        BasicBlock*  whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunc_process_block, 0);
598
599        IRBuilder<> b(mBasicBlock);
600        b.CreateBr(whileCondBlock);
601        mBasicBlock = whileCondBlock;
602        IRBuilder<> b_cond(whileCondBlock);
603
604        Value* expression_marker_value = compileExpression(whl->getCondition());
605        Value* int_tobool1 = genBitBlockAny(expression_marker_value);
606
607        b_cond.CreateCondBr(int_tobool1, whileEndBlock, whileBodyBlock);
608
609        mBasicBlock = whileBodyBlock;
610        mCarryQueueIdx = 0;
611        //Store the current carry queue.
612        Value* ptr_last_carry_q = mptr_carry_q;
613
614        IRBuilder<> b_wb1(mBasicBlock);
615        //Create and initialize a new carry queue.
616        Value * ptr_while_carry_q = b_wb1.CreateAlloca(mXi64Vect, b_wb1.getInt64(mCarryQueueSize - idx));
617        for (int i = 0; i < (mCarryQueueSize - idx); i++) {
618            genCarryOutStore(mZeroInitializer, ptr_while_carry_q, i);
619        }
620
621        //Point mptr_carry_q to the new local carry queue.
622        mptr_carry_q = ptr_while_carry_q;
623
624        returnMarker = compileStatements(whl->getBody());
625
626        IRBuilder<> b_wb2(mBasicBlock);
627        //Copy back to the last carry queue the carries from the execution of the while statement list.
628        for (int c = 0; c < (mCarryQueueSize - idx); c++)
629        {
630            Value* new_carryq_value = b_wb2.CreateOr(genCarryInLoad(mptr_carry_q, c), genCarryInLoad(ptr_last_carry_q, idx + c));
631            genCarryOutStore(new_carryq_value, ptr_last_carry_q, idx + c);
632        }
633
634        b_wb2.CreateBr(whileCondBlock);
635
636        mBasicBlock = whileEndBlock;
637        mptr_carry_q = ptr_last_carry_q;
638        mCarryQueueIdx += idx;
639
640        retVal = returnMarker;
641    }
642    return retVal;
643}
644
645Value * PabloCompiler::compileExpression(PabloE * expr)
646{
647    Value * retVal = nullptr;
648    IRBuilder<> b(mBasicBlock);
649    if (const All* all = dyn_cast<All>(expr))
650    {
651        Value* ptr_all = b.CreateAlloca(mXi64Vect);
652        b.CreateStore((all->getValue() == 0 ? mZeroInitializer : mAllOneInitializer), ptr_all);
653        retVal = b.CreateLoad(ptr_all);
654    }
655    else if (const Call* call = dyn_cast<Call>(expr))
656    {
657        //Call the callee once and store the result in the marker map.
658        auto mi = mMarkerMap.find(call->getCallee());
659        if (mi == mMarkerMap.end()) {
660            auto ci = mCalleeMap.find(call->getCallee());
661            if (ci == mCalleeMap.end()) {
662                throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee() + "\"");
663            }
664            Value* basis_bits_struct = b.CreateLoad(mBasisBitsAddr);
665            Value* unicode_category = b.CreateCall(ci->second, basis_bits_struct);
666            Value* ptr = b.CreateAlloca(mXi64Vect);
667            b.CreateStore(unicode_category, ptr);
668            mi = mMarkerMap.insert(std::make_pair(call->getCallee(), ptr)).first;
669        }
670        retVal = b.CreateLoad(mi->second);
671    }
672    else if (const Var * var = dyn_cast<Var>(expr))
673    {
674        retVal = b.CreateLoad(GetMarker(var->getName()), false, var->getName());
675    }
676    else if (const And * pablo_and = dyn_cast<And>(expr))
677    {
678        retVal = b.CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
679    }
680    else if (const Or * pablo_or = dyn_cast<Or>(expr))
681    {
682        retVal = b.CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
683    }
684    else if (const Sel * pablo_sel = dyn_cast<Sel>(expr))
685    {
686        Value* ifMask = compileExpression(pablo_sel->getCondition());
687        Value* and_if_true_result = b.CreateAnd(ifMask, compileExpression(pablo_sel->getTrueExpr()));
688        Value* and_if_false_result = b.CreateAnd(genNot(ifMask), compileExpression(pablo_sel->getFalseExpr()));
689        retVal = b.CreateOr(and_if_true_result, and_if_false_result);
690    }
691    else if (const Not * pablo_not = dyn_cast<Not>(expr))
692    {
693        Value* expr_value = compileExpression(pablo_not->getExpr());
694        retVal = b.CreateXor(expr_value, mAllOneInitializer, "not");
695    }
696    else if (const CharClass * cc = dyn_cast<CharClass>(expr))
697    {
698        retVal = b.CreateLoad(GetMarker(cc->getCharClass()));
699    }
700    else if (const Advance * adv = dyn_cast<Advance>(expr))
701    {
702        Value* strm_value = compileExpression(adv->getExpr());
703        retVal = genAdvanceWithCarry(strm_value);
704    }
705    else if (const MatchStar * mstar = dyn_cast<MatchStar>(expr))
706    {
707        Value* marker_expr = compileExpression(mstar->getExpr1());
708        Value* cc_expr = compileExpression(mstar->getExpr2());
709        Value* marker_and_cc = b.CreateAnd(marker_expr, cc_expr);
710        retVal = b.CreateOr(b.CreateXor(genAddWithCarry(marker_and_cc, cc_expr), cc_expr), marker_expr, "matchstar");
711    }
712    else if (const ScanThru * sthru = dyn_cast<ScanThru>(expr))
713    {
714        Value* marker_expr = compileExpression(sthru->getScanFrom());
715        Value* cc_expr = compileExpression(sthru->getScanThru());
716        retVal = b.CreateAnd(genAddWithCarry(marker_expr, cc_expr), genNot(cc_expr), "scanthru_rslt");
717    }
718    return retVal;
719}
720
#ifdef USE_UADD_OVERFLOW
/*
 * Emit a call to llvm.uadd.with.overflow.i<BLOCK_SIZE>(int128_e1, int128_e2)
 * at mBasicBlock and return its two results: the wrapped sum and the i1
 * overflow bit.
 *
 * Previously defined as LLVM_Generator::callUaddOverflow, the class name from
 * before the rename, which does not compile once USE_UADD_OVERFLOW is defined.
 * NOTE(review): assumes pablo_compiler.h declares this member under USE_UADD_OVERFLOW.
 */
SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2) {
    Value* operands[] = {int128_e1, int128_e2};
    CallInst* struct_res = CallInst::Create(mFunc_llvm_uadd_with_overflow, operands, "uadd_overflow_res", mBasicBlock);
    struct_res->setCallingConv(CallingConv::C);
    struct_res->setTailCall(false);

    SumWithOverflowPack ret;

    unsigned sum_index[] = {0};
    ret.sum = ExtractValueInst::Create(struct_res, sum_index, "sum", mBasicBlock);

    unsigned obit_index[] = {1};
    ret.obit = ExtractValueInst::Create(struct_res, obit_index, "obit", mBasicBlock);

    return ret;
}
#endif
745
746Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2) {
747    IRBuilder<> b(mBasicBlock);
748
749    //CarryQ - carry in.
750    int this_carry_idx = mCarryQueueIdx;
751    mCarryQueueIdx++;
752
753    Value* carryq_value = genCarryInLoad(mptr_carry_q, this_carry_idx);
754
755#ifdef USE_UADD_OVERFLOW
756    //use llvm.uadd.with.overflow.i128 or i256
757
758    CastInst* int128_e1 = new BitCastInst(e1, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e1_128", mBasicBlock);
759    CastInst* int128_e2 = new BitCastInst(e2, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "e2_128", mBasicBlock);
760    CastInst* int128_carryq_value = new BitCastInst(carryq_value, IntegerType::get(mMod->getContext(), BLOCK_SIZE), "carryq_128", mBasicBlock);
761
762    SumWithOverflowPack sumpack0, sumpack1;
763
764    sumpack0 = callUaddOverflow(int128_e1, int128_e2);
765    sumpack1 = callUaddOverflow(sumpack0.sum, int128_carryq_value);
766
767    Value* obit = b.CreateOr(sumpack0.obit, sumpack1.obit, "carry_bit");
768    Value* ret_sum = b.CreateBitCast(sumpack1.sum, mXi64Vect, "ret_sum");
769
770    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
771    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mXi64Vect);
772    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
773    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
774    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
775
776    Value* void_1 = genCarryOutStore(carry_out, mptr_carry_q, this_carry_idx);
777    return ret_sum;
778#else
779    //calculate carry through logical ops
780    Value* carrygen = b.CreateAnd(e1, e2, "carrygen");
781    Value* carryprop = b.CreateOr(e1, e2, "carryprop");
782    Value* digitsum = b.CreateAdd(e1, e2, "digitsum");
783    Value* partial = b.CreateAdd(digitsum, carryq_value, "partial");
784    Value* digitcarry = b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(partial)));
785    Value* mid_carry_in = genShiftLeft64(b.CreateLShr(digitcarry, 63), "mid_carry_in");
786
787    Value* sum = b.CreateAdd(partial, mid_carry_in, "sum");
788    Value* carry_out = genShiftHighbitToLow(b.CreateOr(carrygen, b.CreateAnd(carryprop, genNot(sum))), "carry_out");
789    Value* void_1 = genCarryOutStore(carry_out, mptr_carry_q, this_carry_idx);
790
791    return sum;
792#endif
793}
794
795Value* PabloCompiler::genCarryInLoad(Value* ptr_carry_q, int n) {
796    IRBuilder<> b(mBasicBlock);
797    Value* carryq_idx = b.getInt64(n);
798    Value* carryq_GEP = b.CreateGEP(ptr_carry_q, carryq_idx);
799    return b.CreateLoad(carryq_GEP);
800}
801
802Value* PabloCompiler::genCarryOutStore(Value* carryout, Value* ptr_carry_q, int n ) {
803    IRBuilder<> b(mBasicBlock);
804    Value* carryq_idx = b.getInt64(n);
805    Value* carryq_GEP = b.CreateGEP(ptr_carry_q, carryq_idx);
806    return b.CreateStore(carryout, carryq_GEP);
807}
808
809Value* PabloCompiler::genBitBlockAny(Value* e) {
810    IRBuilder<> b(mBasicBlock);
811    Value* cast_marker_value_1 = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
812    return b.CreateICmpEQ(cast_marker_value_1, ConstantInt::get(IntegerType::get(mMod->getContext(), BLOCK_SIZE), 0));
813}
814
815Value* PabloCompiler::genShiftHighbitToLow(Value* e, const Twine &namehint) {
816    IRBuilder<> b(mBasicBlock);
817    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
818    return b.CreateBitCast(b.CreateLShr(i128_val, BLOCK_SIZE - 1, namehint), mXi64Vect);
819}
820
821Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
822    IRBuilder<> b(mBasicBlock);
823    Value* i128_val = b.CreateBitCast(e, IntegerType::get(mMod->getContext(), BLOCK_SIZE));
824    return b.CreateBitCast(b.CreateShl(i128_val, 64, namehint), mXi64Vect);
825}
826
827Value* PabloCompiler::genNot(Value* e, const Twine &namehint) {
828    IRBuilder<> b(mBasicBlock);
829    return b.CreateXor(e, mAllOneInitializer, namehint);
830}
831
832Value* PabloCompiler::genAdvanceWithCarry(Value* strm_value) {
833    IRBuilder<> b(mBasicBlock);
834#if (BLOCK_SIZE == 128)
835    int this_carry_idx = mCarryQueueIdx;
836    mCarryQueueIdx++;
837
838    Value* carryq_value = genCarryInLoad(mptr_carry_q, this_carry_idx);
839
840    Value* srli_1_value = b.CreateLShr(strm_value, 63);
841
842    Value* packed_shuffle;
843    Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
844    Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
845    packed_shuffle = b.CreateShuffleVector(carryq_value, srli_1_value, const_packed_1, "packed_shuffle nw");
846
847    Constant* const_packed_2_elems[] = {b.getInt64(1), b.getInt64(1)};
848    Constant* const_packed_2 = ConstantVector::get(const_packed_2_elems);
849
850    Value* shl_value = b.CreateShl(strm_value, const_packed_2, "shl_value");
851    Value* result_value = b.CreateOr(shl_value, packed_shuffle, "or.result_value");
852
853    Value* carry_out = genShiftHighbitToLow(strm_value, "carry_out");
854    //CarryQ - carry out:
855    Value* void_1 = genCarryOutStore(carry_out, mptr_carry_q, this_carry_idx);
856
857    return result_value;
858#endif
859
860#if (BLOCK_SIZE == 256)
861    return genAddWithCarry(strm_value, strm_value);
862#endif
863
864}
865
866}
Note: See TracBrowser for help on using the repository browser.