source: icGREP/icgrep-devel/icgrep/llvm_gen.cpp @ 3911

Last change on this file since 3911 was 3854, checked in by cameron, 5 years ago

Use CodeGenOpt::Default + createPromoteMemoryToRegisterPass

File size: 18.5 KB
Line 
1/*
2 *  Copyright (c) 2014 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "llvm_gen.h"
8
9extern "C" {
10  void wrapped_print_register(BitBlock bit_block) {
11      print_register<BitBlock>("", bit_block);
12  }
13}
14
15void LLVM_Generator::Print_Register(char *name, BitBlock bit_block)
16{
17    print_register<BitBlock>(name, bit_block);
18}
19
20LLVM_Generator::LLVM_Generator(std::string basis_pattern, int bits)
21{
22    mBasis_Pattern = basis_pattern;
23    mBits = bits;
24    mInWhile = false;
25}
26
27LLVM_Generator::~LLVM_Generator()
28{
29    delete mMod;
30}
31
32LLVM_Gen_RetVal LLVM_Generator::Generate_LLVMIR(CodeGenState cg_state, std::list<PabloS*> cc_cgo_stmtsl)
33{
34    //Create the module.
35    MakeLLVMModule();
36
37    //Create the jit execution engine.up
38    InitializeNativeTarget();
39    std::string ErrStr;
40    mExecutionEngine = EngineBuilder(mMod).setUseMCJIT(true).setErrorStr(&ErrStr).setOptLevel(CodeGenOpt::Default).create();
41    if (!mExecutionEngine)
42    {
43        std::cout << "\nCould not create ExecutionEngine: " + ErrStr << std::endl;
44        exit(1);
45    }
46
47    InitializeNativeTargetAsmPrinter();
48    InitializeNativeTargetAsmParser();
49
50    DefineTypes();
51    DeclareFunctions();
52
53    Function::arg_iterator args = mFunc_process_block->arg_begin();
54    Value* ptr_basis_bits = args++;
55    ptr_basis_bits->setName("basis_bits");
56    mptr_carry_q = args++;
57    mptr_carry_q->setName("carry_q");
58    Value* ptr_output = args++;
59    ptr_output->setName("output");
60
61    //Create the carry queue.
62    mCarryQueueIdx = 0;
63
64    mBasicBlock = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunc_process_block,0);
65
66    //The basis bits structure
67    mPtr_basis_bits_addr = new AllocaInst(mStruct_Basis_Bits_Ptr1, "basis_bits.addr", mBasicBlock);
68    StoreInst* void_14 = new StoreInst(ptr_basis_bits, mPtr_basis_bits_addr, false, mBasicBlock);
69
70    for (int i = 0; i < mBits; i++)
71    {
72        StoreBitBlockMarkerPtr(mBasis_Pattern + INT2STRING(i), i);
73    }
74
75    //The output structure.
76    mPtr_output_addr = new AllocaInst(mStruct_Output_Ptr1, "output.addr", mBasicBlock);
77    StoreInst* void_16 = new StoreInst(ptr_output, mPtr_output_addr, false, mBasicBlock);
78
79    //Generate the IR instructions for the function.
80    Generate_PabloStatements(cc_cgo_stmtsl);
81    Generate_PabloStatements(cg_state.stmtsl);
82    SetReturnMarker(cg_state.newsym, 0);
83    SetReturnMarker("lex.cclf", 1);
84
85    //Terminate the block
86    ReturnInst::Create(mMod->getContext(), mBasicBlock);
87
88    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
89    verifyModule(*mMod, PrintMessageAction);
90
91    //Un-comment this line in order to display the IR that has been generated by this module.
92    //mMod->dump();
93
94    //Use the pass manager to run optimizations on the function.
95    FunctionPassManager fpm(mMod);
96
97    // Set up the optimizer pipeline.  Start with registering info about how the target lays out data structures.
98    fpm.add(new DataLayout(*mExecutionEngine->getDataLayout()));
99
100    fpm.add(createPromoteMemoryToRegisterPass());
101
102    fpm.doInitialization();
103
104    fpm.run(*mFunc_process_block);
105
106    //mMod->dump();
107
108    mExecutionEngine->finalizeObject();
109
110    LLVM_Gen_RetVal retVal;
111    //Return the required size of the carry queue and a pointer to the process_block function.
112    retVal.carry_q_size = LLVM_Generator_Helper::CarryCount_PabloStatements(cg_state.stmtsl);;
113    retVal.process_block_fptr = mExecutionEngine->getPointerToFunction(mFunc_process_block);
114
115    return retVal;
116}
117
118void LLVM_Generator::DefineTypes()
119{
120    //The BitBlock vector.
121    m64x2Vect = VectorType::get(IntegerType::get(mMod->getContext(), 64), 2);
122    //A pointer to the BitBlock vector.
123    m64x2Vect_Ptr1 = PointerType::get(m64x2Vect, 0);
124
125    //Constant definitions.
126    mConst_int64_neg1 = ConstantInt::get(mMod->getContext(), APInt(64, StringRef("-1"), 10));
127
128    mConst_Aggregate_64x2_0 = ConstantAggregateZero::get(m64x2Vect);
129    std::vector<Constant*> const_packed_27_elems;
130    const_packed_27_elems.push_back(mConst_int64_neg1);
131    const_packed_27_elems.push_back(mConst_int64_neg1);
132    mConst_Aggregate_64x2_neg1 = ConstantVector::get(const_packed_27_elems);
133
134
135    StructType *StructTy_struct_Basis_bits = mMod->getTypeByName("struct.Basis_bits");
136    if (!StructTy_struct_Basis_bits) {
137        StructTy_struct_Basis_bits = StructType::create(mMod->getContext(), "struct.Basis_bits");
138    }
139    std::vector<Type*>StructTy_struct_Basis_bits_fields;
140    for (int i = 0; i < mBits; i++)
141    {
142        StructTy_struct_Basis_bits_fields.push_back(m64x2Vect);
143    }
144    if (StructTy_struct_Basis_bits->isOpaque()) {
145        StructTy_struct_Basis_bits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
146    }
147
148    mStruct_Basis_Bits_Ptr1 = PointerType::get(StructTy_struct_Basis_bits, 0);
149
150    std::vector<Type*>FuncTy_0_args;
151    FuncTy_0_args.push_back(mStruct_Basis_Bits_Ptr1);
152
153    //The carry q array.
154    FuncTy_0_args.push_back(m64x2Vect_Ptr1);
155
156    //The output structure.
157    StructType *StructTy_struct_Output = mMod->getTypeByName("struct.Output");
158    if (!StructTy_struct_Output) {
159        StructTy_struct_Output = StructType::create(mMod->getContext(), "struct.Output");
160    }
161    std::vector<Type*>StructTy_struct_Output_fields;
162    StructTy_struct_Output_fields.push_back(m64x2Vect);
163    StructTy_struct_Output_fields.push_back(m64x2Vect);
164    if (StructTy_struct_Output->isOpaque()) {
165        StructTy_struct_Output->setBody(StructTy_struct_Output_fields, /*isPacked=*/false);
166    }
167    mStruct_Output_Ptr1 = PointerType::get(StructTy_struct_Output, 0);
168
169    //The &output parameter.
170    FuncTy_0_args.push_back(mStruct_Output_Ptr1);
171
172    mFuncTy_0 = FunctionType::get(
173     /*Result=*/Type::getVoidTy(mMod->getContext()),
174     /*Params=*/FuncTy_0_args,
175     /*isVarArg=*/false);
176}
177
178void LLVM_Generator::DeclareFunctions()
179{
180    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
181    mFunc_print_register = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), m64x2Vect, NULL);
182    mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mFunc_print_register), (void *)&wrapped_print_register);
183
184    SmallVector<AttributeSet, 4> Attrs;
185    AttributeSet PAS;
186    {
187        AttrBuilder B;
188        B.addAttribute(Attribute::ReadOnly);
189        B.addAttribute(Attribute::NoCapture);
190        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
191    }
192    Attrs.push_back(PAS);
193    {
194        AttrBuilder B;
195        B.addAttribute(Attribute::NoCapture);
196        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
197    }
198    Attrs.push_back(PAS);
199    {
200        AttrBuilder B;
201        B.addAttribute(Attribute::NoCapture);
202        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
203    }
204    Attrs.push_back(PAS);
205    {
206        AttrBuilder B;
207        B.addAttribute(Attribute::NoUnwind);
208        B.addAttribute(Attribute::UWTable);
209        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
210    }
211    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
212
213    //Create the function that will be generated.
214    mFunc_process_block = mMod->getFunction("process_block");
215    if (!mFunc_process_block) {
216        mFunc_process_block = Function::Create(
217            /*Type=*/mFuncTy_0,
218            /*Linkage=*/GlobalValue::ExternalLinkage,
219            /*Name=*/"process_block", mMod);
220        mFunc_process_block->setCallingConv(CallingConv::C);
221    }
222    mFunc_process_block->setAttributes(AttrSet);
223}
224
225void LLVM_Generator::MakeLLVMModule()
226{
227    mMod = new Module("icgrep", getGlobalContext());
228    mMod->setDataLayout("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128");
229    mMod->setTargetTriple("x86_64-unknown-linux-gnu");
230}
231
232void LLVM_Generator::StoreBitBlockMarkerPtr(std::string name, int index)
233{
234    IRBuilder<> b(mBasicBlock);
235
236    Value* basis_bits_struct = b.CreateLoad(mPtr_basis_bits_addr);
237    Value* struct_indices[] = {b.getInt64(0), b.getInt32(index)};
238    Value* basis_bits_struct_GEP = b.CreateGEP(basis_bits_struct, struct_indices, name);
239    mMarkerMap.insert(make_pair(name, basis_bits_struct_GEP));
240}
241
242Value* LLVM_Generator::GetMarker(std::string name)
243{
244    IRBuilder<> b(mBasicBlock);
245
246    if (mMarkerMap.find(name) == mMarkerMap.end())
247    {       
248        Value* ptr = b.CreateAlloca(m64x2Vect);
249        Value* void_1 = b.CreateStore(mConst_Aggregate_64x2_0, ptr);
250        mMarkerMap.insert(make_pair(name, ptr));
251    }
252    std::map<std::string, Value*>::iterator itGet = mMarkerMap.find(name);
253
254    return itGet->second;
255}
256
257void LLVM_Generator::SetReturnMarker(std::string marker, int output_idx)
258{
259    IRBuilder<> b(mBasicBlock);
260
261    Value* marker_bitblock = b.CreateLoad(GetMarker(marker));
262    Value* output_struct = b.CreateLoad(mPtr_output_addr);
263    Value* output_indices[] = {b.getInt64(0), b.getInt32(output_idx)};
264    Value* output_struct_GEP = b.CreateGEP(output_struct, output_indices, "return" + marker);
265    Value* store_marker = b.CreateStore(marker_bitblock, output_struct_GEP);
266}
267
268std::string LLVM_Generator::Generate_PabloStatements(std::list<PabloS*> stmts)
269{
270    std::string retVal = "";
271
272    std::list<PabloS*>::iterator it;
273    for (it = stmts.begin(); it != stmts.end(); ++it)
274    {
275        retVal = Generate_PabloS(*it);
276    }
277
278    return retVal;
279}
280
281std::string LLVM_Generator::Generate_PabloS(PabloS *stmt)
282{
283    std::string retVal = "";
284
285    if (Assign* assign = dynamic_cast<Assign*>(stmt))
286    {
287        IRBuilder<> b(mBasicBlock);
288
289        b.CreateStore(Generate_PabloE(assign->getExpr()), GetMarker(assign->getM()));
290
291        retVal = assign->getM();
292    }
293    else if (While* whl = dynamic_cast<While*>(stmt))
294    {
295        IRBuilder<> b(mBasicBlock);
296
297        mWhileCondBlock = BasicBlock::Create(mMod->getContext(), "while.cond", mFunc_process_block, 0);
298        mWhileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body",mFunc_process_block, 0);
299        mWhileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end",mFunc_process_block, 0);
300
301        int idx = mCarryQueueIdx;
302
303        std::string returnMarker = Generate_PabloStatements(whl->getPSList());
304
305        b.CreateBr(mWhileCondBlock);
306        mBasicBlock = mWhileCondBlock;
307        IRBuilder<> b_cond(mWhileCondBlock);
308
309        Value* expression_marker_value = Generate_PabloE(whl->getExpr());
310       
311        // Use an i128 compare for simplicity and speed.
312        Value* cast_marker_value_1 = b_cond.CreateBitCast(expression_marker_value, IntegerType::get(mMod->getContext(), 128));
313        Value* int_tobool1 = b_cond.CreateICmpEQ(cast_marker_value_1, ConstantInt::get(IntegerType::get(mMod->getContext(), 128), 0));
314        b_cond.CreateCondBr(int_tobool1, mWhileEndBlock, mWhileBodyBlock);
315
316        //Note: Everything that happens during the recursive calls for the pablo statements in the body of this while loop will
317        //happen within the basic block of the body of the while loop.  This strategy will not support kstars within
318        //kstars, a more complex stragegy for basicblocks will have to be devised for that.
319        mBasicBlock = mWhileBodyBlock;
320
321        mInWhile = true;
322        mCarryQueueIdx = idx;
323        returnMarker = Generate_PabloStatements(whl->getPSList());
324        mInWhile = false;
325        IRBuilder<> b_wb(mWhileBodyBlock);
326        b_wb.CreateBr(mWhileCondBlock);
327
328        mBasicBlock = mWhileEndBlock;
329
330        retVal = returnMarker;
331    }
332
333    return retVal;
334}
335
336Value* LLVM_Generator::Generate_PabloE(PabloE *expr)
337{
338    Value* retVal = 0;
339
340    if (All* all = dynamic_cast<All*>(expr))
341    {
342        IRBuilder<> b(mBasicBlock);
343
344        if ((all->getNum() != 0) && (all->getNum() != 1))
345            std::cout << "\nErr: 'All' can only be set to 1 or 0.\n" << std::endl;
346        Value* ptr_all = b.CreateAlloca(m64x2Vect);
347        Value* void_1 = b.CreateStore((all->getNum() == 0 ? mConst_Aggregate_64x2_0 : mConst_Aggregate_64x2_neg1), ptr_all);
348        Value* all_value = b.CreateLoad(ptr_all);
349
350        retVal = all_value;
351    }
352    else if (Var* var = dynamic_cast<Var*>(expr))
353    {
354        IRBuilder<> b(mBasicBlock);
355
356        Value* var_value = b.CreateLoad(GetMarker(var->getVar()), false, var->getVar());
357
358        retVal = var_value;
359    }
360    else if (And* pablo_and = dynamic_cast<And*>(expr))
361    {
362        IRBuilder<> b(mBasicBlock);
363
364        Value* and_result = b.CreateAnd(Generate_PabloE(pablo_and->getExpr1()), Generate_PabloE(pablo_and->getExpr2()), "and_inst");
365
366        retVal = and_result;
367    }
368    else if (Or* pablo_or = dynamic_cast<Or*>(expr))
369    {
370        IRBuilder<> b(mBasicBlock);
371
372        Value* or_result = b.CreateOr(Generate_PabloE(pablo_or->getExpr1()), Generate_PabloE(pablo_or->getExpr2()), "or_inst");
373
374        retVal = or_result;
375    }
376    else if (Sel* pablo_sel = dynamic_cast<Sel*>(expr))
377    {
378        IRBuilder<>b(mBasicBlock);
379
380        Value* and_if_true_result = b.CreateAnd(Generate_PabloE(pablo_sel->getIf_expr()), Generate_PabloE(pablo_sel->getT_expr()));
381        Constant* const_packed_elems [] = {b.getInt64(-1), b.getInt64(-1)};
382        Constant* const_packed = ConstantVector::get(const_packed_elems);
383        Value* not_if_result = b.CreateXor(Generate_PabloE(pablo_sel->getIf_expr()), const_packed);
384        Value* and_if_false_result = b.CreateAnd(not_if_result, Generate_PabloE(pablo_sel->getF_expr()));
385        Value* or_result = b.CreateOr(and_if_true_result, and_if_false_result);
386
387        retVal = or_result;
388    }
389    else if (Not* pablo_not = dynamic_cast<Not*>(expr))
390    {
391        IRBuilder<> b(mBasicBlock);
392
393        Constant* const_packed_elems [] = {b.getInt64(-1), b.getInt64(-1)};
394        Constant* const_packed = ConstantVector::get(const_packed_elems);
395        Value* expr_value = Generate_PabloE(pablo_not->getExpr());
396        Value* xor_rslt = b.CreateXor(expr_value, const_packed, "xor_inst");
397
398        retVal = xor_rslt;
399    }
400    else if (CharClass* cc = dynamic_cast<CharClass*>(expr))
401    {
402        IRBuilder<> b(mBasicBlock);
403
404        Value* character_class = b.CreateLoad(GetMarker(cc->getCharClass()));
405
406        retVal = character_class;
407    }
408    else if (Advance* adv = dynamic_cast<Advance*>(expr))
409    {
410        IRBuilder<> b(mBasicBlock);
411
412        //CarryQ - carry in.
413        Value* carryq_idx = b.getInt64(mCarryQueueIdx);
414        Value* carryq_GEP = b.CreateGEP(mptr_carry_q, carryq_idx);
415        Value* carryq_value = b.CreateLoad(carryq_GEP);
416
417        Value* strm_value = Generate_PabloE(adv->getExpr());
418        Value* srli_1_value = b.CreateLShr(strm_value, 63);
419
420        Value* packed_shuffle;
421        if (mInWhile)
422        {
423            Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
424            Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
425            packed_shuffle = b.CreateShuffleVector(mConst_Aggregate_64x2_0, srli_1_value, const_packed_1, "packed_shuffle iw");
426        }
427        else
428        {
429            Constant* const_packed_1_elems [] = {b.getInt32(0), b.getInt32(2)};
430            Constant* const_packed_1 = ConstantVector::get(const_packed_1_elems);
431            packed_shuffle = b.CreateShuffleVector(carryq_value, srli_1_value, const_packed_1, "packed_shuffle nw");
432        }
433
434        Constant* const_packed_2_elems[] = {b.getInt64(1), b.getInt64(1)};
435        Constant* const_packed_2 = ConstantVector::get(const_packed_2_elems);
436
437        Value* shl_value = b.CreateShl(strm_value, const_packed_2, "shl_value");
438        Value* result_value = b.CreateOr(shl_value, packed_shuffle, "or.result_value");
439
440        //CarryQ - carry out.
441        Value* cast_marker_value_1 = b.CreateBitCast(strm_value, IntegerType::get(mMod->getContext(), 128));
442        Value* srli_2_value = b.CreateLShr(cast_marker_value_1, 127);
443        Value* carryout_2_carry = b.CreateBitCast(srli_2_value, m64x2Vect);
444
445        if (mInWhile)
446        {
447            Value* carryout = b.CreateOr(carryq_value, carryout_2_carry);
448            Value* void_1 = b.CreateStore(carryout, carryq_GEP);
449        }
450        else
451        {
452            Value* void_1 = b.CreateStore(carryout_2_carry, carryq_GEP);
453        }
454
455        //Increment the idx for the next advance or scan through.
456        mCarryQueueIdx++;
457
458        retVal = result_value;
459    }
460    else if (MatchStar* mstar = dynamic_cast<MatchStar*>(expr))
461    {
462        IRBuilder<> b(mBasicBlock);
463
464        //CarryQ - carry in.
465        Value* carryq_idx = b.getInt64(mCarryQueueIdx);
466        Value* carryq_GEP = b.CreateGEP(mptr_carry_q, carryq_idx);
467        Value* carryq_value = b.CreateLoad(carryq_GEP);
468        //Get the input stream.
469        Value* strm_value = Generate_PabloE(mstar->getExpr1());
470        //Get the character that is to be matched.
471        Value* cc_value = Generate_PabloE(mstar->getExpr2());
472
473        Value* and_value_1 = b.CreateAnd(cc_value, strm_value, "match_star_and_value_1");
474        Value* add_value_1 = b.CreateAdd(and_value_1, cc_value, "match_star_add_value_1");
475        Value* add_value_2 = b.CreateAdd(add_value_1, carryq_value, "match_star_add_value_2");
476        Value* xor_value_1 = b.CreateXor(add_value_2, mConst_Aggregate_64x2_neg1, "match_star_xor_value_1");
477        Value* and_value_2 = b.CreateAnd(cc_value, xor_value_1, "match_star_and_value_2");
478        Value* or_value_1 = b.CreateOr(and_value_1, and_value_2, "match_star_or_value_1");
479
480        Value* srli_instr_1 = b.CreateLShr(or_value_1, 63);
481
482        Value* cast_marker_value_1 = b.CreateBitCast(srli_instr_1, IntegerType::get(mMod->getContext(), 128));
483        Value* sll_1_value = b.CreateShl(cast_marker_value_1, 64);
484        Value* cast_marker_value_2 = b.CreateBitCast(sll_1_value, m64x2Vect);
485
486
487        Value* add_value_3 = b.CreateAdd(cast_marker_value_2, add_value_2, "match_star_add_value_3");
488        Value* xor_value_2 = b.CreateXor(add_value_3, mConst_Aggregate_64x2_neg1, "match_star_xor_value_2");
489        Value* and_value_3 = b.CreateAnd(cc_value, xor_value_2, "match_star_and_value_3");
490        Value* or_value_2  = b.CreateOr(and_value_1, and_value_3, "match_star_or_value_2 ");
491        Value* xor_value_3 = b.CreateXor(add_value_3, cc_value, "match_star_xor_value_3");
492        Value* result_value = b.CreateOr(xor_value_3, strm_value, "match_star_result_value");
493
494        //CarryQ - carry out:
495        Value* cast_marker_value_3 = b.CreateBitCast(or_value_2, IntegerType::get(mMod->getContext(), 128));
496        Value* srli_2_value = b.CreateLShr(cast_marker_value_3, 127);
497        Value* carryout_2_carry = b.CreateBitCast(srli_2_value, m64x2Vect);
498
499        Value* void_1 = b.CreateStore(carryout_2_carry, carryq_GEP);
500
501        mCarryQueueIdx++;
502
503        retVal = result_value;
504    }
505
506    return retVal;
507}
508
Note: See TracBrowser for help on using the repository browser.