source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4657

Last change on this file since 4657 was 4657, checked in by nmedfort, 4 years ago

Initial introduction of a PabloFunction type.

File size: 34.2 KB
Line 
1/*
2 *  Copyright (c) 2014-15 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <pablo/pablo_compiler.h>
8#include <pablo/codegenstate.h>
9#include <pablo/carry_data.h>
10#include <pablo/carry_manager.h>
11#include <pablo/printer_pablos.h>
12#include <pablo/function.h>
13#include <cc/cc_namemap.hpp>
14#include <re/re_name.h>
15#include <stdexcept>
16#include <include/simd-lib/bitblock.hpp>
17#include <sstream>
18#include <llvm/IR/Verifier.h>
19#include <llvm/Pass.h>
20#include <llvm/PassManager.h>
21#include <llvm/ADT/SmallVector.h>
22#include <llvm/Analysis/Passes.h>
23#include <llvm/IR/BasicBlock.h>
24#include <llvm/IR/CallingConv.h>
25#include <llvm/IR/Constants.h>
26#include <llvm/IR/DataLayout.h>
27#include <llvm/IR/DerivedTypes.h>
28#include <llvm/IR/Function.h>
29#include <llvm/IR/GlobalVariable.h>
30#include <llvm/IR/InlineAsm.h>
31#include <llvm/IR/Instructions.h>
32#include <llvm/IR/LLVMContext.h>
33#include <llvm/IR/Module.h>
34#include <llvm/Support/FormattedStream.h>
35#include <llvm/Support/MathExtras.h>
36#include <llvm/Support/Casting.h>
37#include <llvm/Support/Compiler.h>
38#include <llvm/Support/Debug.h>
39#include <llvm/Support/TargetSelect.h>
40#include <llvm/Support/Host.h>
41#include <llvm/Transforms/Scalar.h>
42#include <llvm/ExecutionEngine/ExecutionEngine.h>
43#include <llvm/ExecutionEngine/MCJIT.h>
44#include <llvm/IRReader/IRReader.h>
45#include <llvm/Bitcode/ReaderWriter.h>
46#include <llvm/Support/MemoryBuffer.h>
47#include <llvm/IR/IRBuilder.h>
48#include <llvm/Support/CommandLine.h>
49#include <llvm/ADT/Twine.h>
50#include <iostream>
51
52static cl::OptionCategory eIRDumpOptions("LLVM IR Dump Options", "These options control dumping of LLVM IR.");
53static cl::opt<bool> DumpGeneratedIR("dump-generated-IR", cl::init(false), cl::desc("print LLVM IR generated by RE compilation"), cl::cat(eIRDumpOptions));
54
55static cl::OptionCategory fTracingOptions("Run-time Tracing Options", "These options control execution traces.");
56static cl::opt<bool> TraceNext("trace-next-nodes", cl::init(false), cl::desc("Generate dynamic traces of executed Next nodes (while control variables)."), cl::cat(fTracingOptions));
57static cl::opt<bool> DumpTrace("dump-trace", cl::init(false), cl::desc("Generate dynamic traces of executed assignments."), cl::cat(fTracingOptions));
58
59extern "C" {
60  void wrapped_print_register(char * regName, BitBlock bit_block) {
61      print_register<BitBlock>(regName, bit_block);
62  }
63}
64
65namespace pablo {
66
/** ------------------------------------------------------------------------------------------------------------- *
 * @brief PabloCompiler constructor
 *
 * Creates the "icgrep" module in the global LLVM context and points mBuilder at LLVM_Builder. It derives
 * mBitBlockType as a vector of BLOCK_SIZE / 64 x i64, along with its all-zero and all-one constants. All
 * other members stay null/zero until compile() runs.
 *
 * With USE_LLVM_3_5 the module is a raw `new Module`, later handed to the EngineBuilder in compile().
 * Otherwise mModOwner holds it until compile() moves it into the EngineBuilder.
 *
 * Members are initialized in header declaration order, not in the order written below. mBitBlockType reads
 * mMod, and the two initializers read mBitBlockType, so those members must remain declared after the
 * members they depend on.
 */
PabloCompiler::PabloCompiler()
#ifdef USE_LLVM_3_5
: mMod(new Module("icgrep", getGlobalContext()))
#else
: mModOwner(make_unique<Module>("icgrep", getGlobalContext()))
, mMod(mModOwner.get())
#endif
, mBuilder(&LLVM_Builder)
, mCarryManager(nullptr)
, mExecutionEngine(nullptr)
, mBitBlockType(VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64))
, mBasisBitsInputPtr(nullptr)
, mCarryDataPtr(nullptr)
, mWhileDepth(0)
, mIfDepth(0)
, mZeroInitializer(ConstantAggregateZero::get(mBitBlockType))
, mOneInitializer(ConstantVector::getAllOnesValue(mBitBlockType))
, mFunctionType(nullptr)
, mFunction(nullptr)
, mParameterAddr(nullptr)
, mOutputAddrPtr(nullptr)
, mMaxWhileDepth(0)
, mPrintRegisterFunction(nullptr)
{
    // Register the native target, its asm printer and its asm parser with LLVM.
    // The JIT execution engine itself is not created here; compile() builds it.
    InitializeNativeTarget();
    InitializeNativeTargetAsmPrinter();
    InitializeNativeTargetAsmParser();
}
96
97PabloCompiler::~PabloCompiler()
98{
99
100}
101   
102void PabloCompiler::InstallExternalFunction(std::string C_fn_name, void * fn_ptr) {
103    mExternalMap.insert(std::make_pair(C_fn_name, fn_ptr));
104}
105
106void PabloCompiler::genPrintRegister(std::string regName, Value * bitblockValue) {
107    Constant * regNameData = ConstantDataArray::getString(mMod->getContext(), regName);
108    GlobalVariable *regStrVar = new GlobalVariable(*mMod, 
109                                                   ArrayType::get(IntegerType::get(mMod->getContext(), 8), regName.length()+1),
110                                                   /*isConstant=*/ true,
111                                                   /*Linkage=*/ GlobalValue::PrivateLinkage,
112                                                   /*Initializer=*/ regNameData);
113    Value * regStrPtr = mBuilder->CreateGEP(regStrVar, {mBuilder->getInt64(0), mBuilder->getInt32(0)});
114    mBuilder->CreateCall(mPrintRegisterFunction, {regStrPtr, bitblockValue});
115}
116
117CompiledPabloFunction PabloCompiler::compile(PabloFunction & function)
118{
119    mWhileDepth = 0;
120    mIfDepth = 0;
121    mMaxWhileDepth = 0;
122    mCarryManager = new CarryManager(mMod, mBuilder, mBitBlockType, mZeroInitializer, mOneInitializer);
123   
124    std::string errMessage;
125#ifdef USE_LLVM_3_5
126    EngineBuilder builder(mMod);
127#else
128    EngineBuilder builder(std::move(mModOwner));
129#endif
130    builder.setErrorStr(&errMessage);
131    builder.setMCPU(sys::getHostCPUName());
132#ifdef USE_LLVM_3_5
133    builder.setUseMCJIT(true);
134#endif
135    builder.setOptLevel(mMaxWhileDepth ? CodeGenOpt::Level::Less : CodeGenOpt::Level::None);
136    mExecutionEngine = builder.create();
137    if (mExecutionEngine == nullptr) {
138        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
139    }
140
141    DefineTypes(function);
142    DeclareFunctions();
143
144    Examine(function.getEntryBlock());
145    DeclareCallFunctions();
146
147    Function::arg_iterator args = mFunction->arg_begin();
148    mParameterAddr = args++;
149    mParameterAddr->setName("basis_bits");
150    mCarryDataPtr = args++;
151    mCarryDataPtr->setName("carry_data");
152    mOutputAddrPtr = args++;
153    mOutputAddrPtr->setName("output");
154
155    mWhileDepth = 0;
156    mIfDepth = 0;
157    mMaxWhileDepth = 0;
158    BasicBlock * b = BasicBlock::Create(mMod->getContext(), "parabix_entry", mFunction,0);
159    mBuilder->SetInsertPoint(b);
160
161    //The basis bits structure
162
163    for (unsigned i = 0; i != function.getParameters().size(); ++i) {
164        Value* indices[] = {mBuilder->getInt64(0), mBuilder->getInt32(i)};
165        Value * gep = mBuilder->CreateGEP(mParameterAddr, indices);
166        LoadInst * basisBit = mBuilder->CreateAlignedLoad(gep, BLOCK_SIZE/8, false, function.getParameter(i)->getName()->to_string());
167        mMarkerMap.insert(std::make_pair(function.getParameter(i), basisBit));
168    }
169       
170    unsigned totalCarryDataSize = mCarryManager->initialize(&(function.getEntryBlock()), mCarryDataPtr);
171   
172    //Generate the IR instructions for the function.
173    compileBlock(function.getEntryBlock());
174   
175    mCarryManager->generateBlockNoIncrement();
176
177    if (DumpTrace || TraceNext) {
178        genPrintRegister("mBlockNo", mBuilder->CreateAlignedLoad(mBuilder->CreateBitCast(mCarryManager->getBlockNoPtr(), PointerType::get(mBitBlockType, 0)), BLOCK_SIZE/8, false));
179    }
180   
181    if (LLVM_UNLIKELY(mWhileDepth != 0)) {
182        throw std::runtime_error("Non-zero nesting depth error (" + std::to_string(mWhileDepth) + ")");
183    }
184
185    // Write the output values out
186    for (unsigned i = 0; i != function.getResults().size(); ++i) {
187        SetOutputValue(mMarkerMap[function.getResult(i)], i);
188    }
189
190    //Terminate the block
191    ReturnInst::Create(mMod->getContext(), mBuilder->GetInsertBlock());
192
193    //Display the IR that has been generated by this module.
194    if (LLVM_UNLIKELY(DumpGeneratedIR)) {
195        mMod->dump();
196    }
197    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
198    verifyModule(*mMod, &dbgs());
199
200    mExecutionEngine->finalizeObject();
201
202    delete mCarryManager;
203    mCarryManager = nullptr;
204
205    //Return the required size of the carry data area to the process_block function.
206    return CompiledPabloFunction(totalCarryDataSize * sizeof(BitBlock), mFunction, mExecutionEngine);
207}
208
209void PabloCompiler::DefineTypes(PabloFunction & function) {
210
211    StructType * structBasisBits = mMod->getTypeByName("struct.Basis_bits");
212    if (structBasisBits == nullptr) {
213        structBasisBits = StructType::create(mMod->getContext(), "struct.Basis_bits");
214    }
215    std::vector<Type*> StructTy_struct_Basis_bits_fields;
216    for (int i = 0; i != function.getParameters().size(); i++) {
217        StructTy_struct_Basis_bits_fields.push_back(mBitBlockType);
218    }
219    if (structBasisBits->isOpaque()) {
220        structBasisBits->setBody(StructTy_struct_Basis_bits_fields, /*isPacked=*/false);
221    }
222    mBasisBitsInputPtr = PointerType::get(structBasisBits, 0);
223
224    std::vector<Type*>functionTypeArgs;
225    functionTypeArgs.push_back(mBasisBitsInputPtr);
226
227    //The carry data array.
228    //A pointer to the BitBlock vector.
229    functionTypeArgs.push_back(PointerType::get(mBitBlockType, 0));
230
231    //The output structure.
232    StructType * outputStruct = mMod->getTypeByName("struct.Output");
233    if (!outputStruct) {
234        outputStruct = StructType::create(mMod->getContext(), "struct.Output");
235    }
236    if (outputStruct->isOpaque()) {
237        std::vector<Type*>fields;
238        for (int i = 0; i != function.getResults().size(); i++) {
239            fields.push_back(mBitBlockType);
240        }
241        outputStruct->setBody(fields, /*isPacked=*/false);
242    }
243    PointerType * outputStructPtr = PointerType::get(outputStruct, 0);
244
245    //The &output parameter.
246    functionTypeArgs.push_back(outputStructPtr);
247
248    mFunctionType = FunctionType::get(
249     /*Result=*/Type::getVoidTy(mMod->getContext()),
250     /*Params=*/functionTypeArgs,
251     /*isVarArg=*/false);
252}
253
254void PabloCompiler::DeclareFunctions()
255{
256    //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
257    mPrintRegisterFunction = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), Type::getInt8PtrTy(getGlobalContext()), mBitBlockType, NULL);
258    mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mPrintRegisterFunction), (void *)&wrapped_print_register);
259    // to call->  mBuilder->CreateCall(mFunc_print_register, unicode_category);
260
261#ifdef USE_UADD_OVERFLOW
262#ifdef USE_TWO_UADD_OVERFLOW
263    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
264    std::vector<Type*>StructTy_0_fields;
265    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
266    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
267    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
268
269    std::vector<Type*>FuncTy_1_args;
270    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
271    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
272    FunctionType* FuncTy_1 = FunctionType::get(
273                                              /*Result=*/StructTy_0,
274                                              /*Params=*/FuncTy_1_args,
275                                              /*isVarArg=*/false);
276
277    mFunctionUaddOverflow = mMod->getFunction("llvm.uadd.with.overflow.i" +
278                                              std::to_string(BLOCK_SIZE));
279    if (!mFunctionUaddOverflow) {
280        mFunctionUaddOverflow= Function::Create(
281          /*Type=*/ FuncTy_1,
282          /*Linkage=*/ GlobalValue::ExternalLinkage,
283          /*Name=*/ "llvm.uadd.with.overflow.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
284        mFunctionUaddOverflow->setCallingConv(CallingConv::C);
285    }
286    AttributeSet mFunctionUaddOverflowPAL;
287    {
288        SmallVector<AttributeSet, 4> Attrs;
289        AttributeSet PAS;
290        {
291          AttrBuilder B;
292          B.addAttribute(Attribute::NoUnwind);
293          B.addAttribute(Attribute::ReadNone);
294          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
295        }
296
297        Attrs.push_back(PAS);
298        mFunctionUaddOverflowPAL = AttributeSet::get(mMod->getContext(), Attrs);
299    }
300    mFunctionUaddOverflow->setAttributes(mFunctionUaddOverflowPAL);
301#else
302    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
303    std::vector<Type*>StructTy_0_fields;
304    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
305    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
306    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
307
308    std::vector<Type*>FuncTy_1_args;
309    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
310    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
311    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
312    FunctionType* FuncTy_1 = FunctionType::get(
313                                              /*Result=*/StructTy_0,
314                                              /*Params=*/FuncTy_1_args,
315                                              /*isVarArg=*/false);
316
317    mFunctionUaddOverflowCarryin = mMod->getFunction("llvm.uadd.with.overflow.carryin.i" +
318                                              std::to_string(BLOCK_SIZE));
319    if (!mFunctionUaddOverflowCarryin) {
320        mFunctionUaddOverflowCarryin = Function::Create(
321          /*Type=*/ FuncTy_1,
322          /*Linkage=*/ GlobalValue::ExternalLinkage,
323          /*Name=*/ "llvm.uadd.with.overflow.carryin.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
324        mFunctionUaddOverflowCarryin->setCallingConv(CallingConv::C);
325    }
326    AttributeSet mFunctionUaddOverflowCarryinPAL;
327    {
328        SmallVector<AttributeSet, 4> Attrs;
329        AttributeSet PAS;
330        {
331          AttrBuilder B;
332          B.addAttribute(Attribute::NoUnwind);
333          B.addAttribute(Attribute::ReadNone);
334          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
335        }
336
337        Attrs.push_back(PAS);
338        mFunctionUaddOverflowCarryinPAL = AttributeSet::get(mMod->getContext(), Attrs);
339    }
340    mFunctionUaddOverflowCarryin->setAttributes(mFunctionUaddOverflowCarryinPAL);
341#endif
342#endif
343
344    //Starts on process_block
345    SmallVector<AttributeSet, 4> Attrs;
346    AttributeSet PAS;
347    {
348        AttrBuilder B;
349        B.addAttribute(Attribute::ReadOnly);
350        B.addAttribute(Attribute::NoCapture);
351        PAS = AttributeSet::get(mMod->getContext(), 1U, B);
352    }
353    Attrs.push_back(PAS);
354    {
355        AttrBuilder B;
356        B.addAttribute(Attribute::NoCapture);
357        PAS = AttributeSet::get(mMod->getContext(), 2U, B);
358    }
359    Attrs.push_back(PAS);
360    {
361        AttrBuilder B;
362        B.addAttribute(Attribute::NoCapture);
363        PAS = AttributeSet::get(mMod->getContext(), 3U, B);
364    }
365    Attrs.push_back(PAS);
366    {
367        AttrBuilder B;
368        B.addAttribute(Attribute::NoUnwind);
369        B.addAttribute(Attribute::UWTable);
370        PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
371    }
372    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
373
374    //Create the function that will be generated.
375    mFunction = mMod->getFunction("process_block");
376    if (!mFunction) {
377        mFunction = Function::Create(
378            /*Type=*/mFunctionType,
379            /*Linkage=*/GlobalValue::ExternalLinkage,
380            /*Name=*/"process_block", mMod);
381        mFunction->setCallingConv(CallingConv::C);
382    }
383    mFunction->setAttributes(AttrSet);
384}
385   
386void PabloCompiler::Examine(PabloBlock & blk) {
387    for (Statement * stmt : blk) {
388        if (Call * call = dyn_cast<Call>(stmt)) {
389            mCalleeMap.insert(std::make_pair(call->getCallee(), nullptr));
390        }
391        else if (If * ifStatement = dyn_cast<If>(stmt)) {
392            ++mIfDepth;
393            Examine(ifStatement->getBody());
394            --mIfDepth;
395        }
396        else if (While * whileStatement = dyn_cast<While>(stmt)) {
397            mMaxWhileDepth = std::max(mMaxWhileDepth, ++mWhileDepth);
398            Examine(whileStatement->getBody());
399            --mWhileDepth;
400        }
401    }
402}
403
404void PabloCompiler::DeclareCallFunctions() {
405    for (auto mapping : mCalleeMap) {
406        const String * callee = mapping.first;
407        //std::cerr << callee->str() << " to be declared\n";
408        auto ei = mExternalMap.find(callee->value());
409        if (ei != mExternalMap.end()) {
410            void * fn_ptr = ei->second;
411            //std::cerr << "Ptr found:" <<  std::hex << ((intptr_t) fn_ptr) << std::endl;
412            Value * externalValue = mMod->getOrInsertFunction(callee->value(), mBitBlockType, mBasisBitsInputPtr, NULL);
413            if (LLVM_UNLIKELY(externalValue == nullptr)) {
414                throw std::runtime_error("Could not create static method call for external function \"" + callee->to_string() + "\"");
415            }
416            mExecutionEngine->addGlobalMapping(cast<GlobalValue>(externalValue), fn_ptr);
417            mCalleeMap[callee] = externalValue;
418        }
419        else {
420            throw std::runtime_error("External function \"" + callee->to_string() + "\" not installed");
421        }
422    }
423}
424
425void PabloCompiler::compileBlock(PabloBlock & block) {
426    mCarryManager->ensureCarriesLoadedLocal(block);
427    mPabloBlock = & block;
428    for (const Statement * statement : block) {
429        compileStatement(statement);
430    }
431    mPabloBlock = block.getParent();
432    mCarryManager->ensureCarriesStoredLocal(block);
433}
434
435
436void PabloCompiler::compileIf(const If * ifStatement) {       
437    //
438    //  The If-ElseZero stmt:
439    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
440    //  If the value of the predicate is nonzero, then determine the values of variables
441    //  <var>* by executing the given statements.  Otherwise, the value of the
442    //  variables are all zero.  Requirements: (a) no variable that is defined within
443    //  the body of the if may be accessed outside unless it is explicitly
444    //  listed in the variable list, (b) every variable in the defined list receives
445    //  a value within the body, and (c) the logical consequence of executing
446    //  the statements in the event that the predicate is zero is that the
447    //  values of all defined variables indeed work out to be 0.
448    //
449    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
450    //  is inserted for each variable in the defined variable list.  It receives
451    //  a zero value from the ifentry block and the defined value from the if
452    //  body.
453    //
454    BasicBlock * ifEntryBlock = mBuilder->GetInsertBlock();
455    BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunction, 0);
456    BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunction, 0);
457   
458    PabloBlock & ifBody = ifStatement -> getBody();
459   
460    Value* if_test_value = compileExpression(ifStatement->getCondition());
461    if (mCarryManager->blockHasCarries(ifBody)) {
462        // load the summary variable
463        Value* last_if_pending_data = mCarryManager->getCarrySummaryExpr(ifBody);
464        if_test_value = mBuilder->CreateOr(if_test_value, last_if_pending_data);
465    }
466    mBuilder->CreateCondBr(genBitBlockAny(if_test_value), ifEndBlock, ifBodyBlock);
467    // Entry processing is complete, now handle the body of the if.
468    mBuilder->SetInsertPoint(ifBodyBlock);
469   
470    ++mIfDepth;
471    compileBlock(ifBody);
472    --mIfDepth;
473    if (mCarryManager->blockHasCarries(ifBody)) {
474        mCarryManager->generateCarryOutSummaryCode(ifBody);
475    }
476    BasicBlock * ifBodyFinalBlock = mBuilder->GetInsertBlock();
477    mBuilder->CreateBr(ifEndBlock);
478    //End Block
479    mBuilder->SetInsertPoint(ifEndBlock);
480    for (const PabloAST * node : ifStatement->getDefined()) {
481        const Assign * assign = cast<Assign>(node);
482        PHINode * phi = mBuilder->CreatePHI(mBitBlockType, 2, assign->getName()->value());
483        auto f = mMarkerMap.find(assign);
484        assert (f != mMarkerMap.end());
485        phi->addIncoming(mZeroInitializer, ifEntryBlock);
486        phi->addIncoming(f->second, ifBodyFinalBlock);
487        mMarkerMap[assign] = phi;
488    }
489    // Create the phi Node for the summary variable, if needed.
490    if (mCarryManager->summaryNeededInParentBlock(ifBody)) {
491        mCarryManager->addSummaryPhi(ifBody, ifEntryBlock, ifBodyFinalBlock);
492    }
493}
494
495void PabloCompiler::compileWhile(const While * whileStatement) {
496
497    PabloBlock & whileBody = whileStatement -> getBody();
498   
499    BasicBlock * whileEntryBlock = mBuilder->GetInsertBlock();
500    BasicBlock * whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunction, 0);
501    BasicBlock * whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunction, 0);
502
503    mCarryManager->ensureCarriesLoadedRecursive(whileBody);
504
505    const auto & nextNodes = whileStatement->getVariants();
506    std::vector<PHINode *> nextPhis;
507    nextPhis.reserve(nextNodes.size());
508
509    // On entry to the while structure, proceed to execute the first iteration
510    // of the loop body unconditionally.   The while condition is tested at the end of
511    // the loop.
512
513    mBuilder->CreateBr(whileBodyBlock);
514    mBuilder->SetInsertPoint(whileBodyBlock);
515
516    //
517    // There are 3 sets of Phi nodes for the while loop.
518    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
519    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
520    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
521
522    mCarryManager->initializeCarryDataPhisAtWhileEntry(whileBody, whileEntryBlock);
523
524    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
525    for (const Next * n : nextNodes) {
526        PHINode * phi = mBuilder->CreatePHI(mBitBlockType, 2, n->getName()->value());
527        auto f = mMarkerMap.find(n->getInitial());
528        assert (f != mMarkerMap.end());
529        phi->addIncoming(f->second, whileEntryBlock);
530        mMarkerMap[n->getInitial()] = phi;
531        nextPhis.push_back(phi);
532    }
533
534    //
535    // Now compile the loop body proper.  Carry-out accumulated values
536    // and iterated values of Next nodes will be computed.
537    ++mWhileDepth;
538    compileBlock(whileBody);
539
540    BasicBlock * whileBodyFinalBlock = mBuilder->GetInsertBlock();
541
542    mCarryManager->extendCarryDataPhisAtWhileBodyFinalBlock(whileBody, whileBodyFinalBlock);
543
544    // Terminate the while loop body with a conditional branch back.
545    mBuilder->CreateCondBr(genBitBlockAny(compileExpression(whileStatement->getCondition())), whileEndBlock, whileBodyBlock);
546
547    // and for any Next nodes in the loop body
548    for (unsigned i = 0; i < nextNodes.size(); i++) {
549        const Next * n = nextNodes[i];
550        auto f = mMarkerMap.find(n->getExpr());
551        if (LLVM_UNLIKELY(f == mMarkerMap.end())) {
552            throw std::runtime_error("Next node expression was not compiled!");
553        }
554        nextPhis[i]->addIncoming(f->second, whileBodyFinalBlock);
555    }
556
557    mBuilder->SetInsertPoint(whileEndBlock);
558    --mWhileDepth;
559
560    mCarryManager->ensureCarriesStoredRecursive(whileBody);
561}
562
563
/** ------------------------------------------------------------------------------------------------------------- *
 * @brief compileStatement
 *
 * Emits IR for a single Pablo statement at the builder's current insertion point. The resulting value is
 * recorded in mMarkerMap under the statement. If and While are delegated to compileIf / compileWhile and
 * record nothing here.
 *
 * When -dump-trace is set, a print call is emitted for every recorded value. When -trace-next-nodes is set,
 * Next values are printed as well.
 *
 * @throws std::runtime_error for an unknown statement kind or an unresolved callee. Operands that are used
 *         before definition are reported by compileExpression().
 */
void PabloCompiler::compileStatement(const Statement * stmt) {
    Value * expr = nullptr;
    if (const Assign * assign = dyn_cast<const Assign>(stmt)) {
        // An Assign emits no IR; it simply names the value of its expression.
        expr = compileExpression(assign->getExpression());
    }
    else if (const Next * next = dyn_cast<const Next>(stmt)) {
        expr = compileExpression(next->getExpr());
        if (TraceNext) {
            genPrintRegister(next->getName()->to_string(), expr);
        }
    }
    else if (const If * ifStatement = dyn_cast<const If>(stmt)) {
        compileIf(ifStatement);
        return;
    }
    else if (const While * whileStatement = dyn_cast<const While>(stmt)) {
        compileWhile(whileStatement);
        return;
    }
    else if (const Call* call = dyn_cast<Call>(stmt)) {
        //Call the callee once and store the result in the marker map.
        if (mMarkerMap.count(call) != 0) {
            return;
        }
        auto ci = mCalleeMap.find(call->getCallee());
        if (LLVM_UNLIKELY(ci == mCalleeMap.end())) {
            throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee()->to_string() + "\"");
        }
        // External functions receive the basis-bits struct pointer and return a bit block.
        expr = mBuilder->CreateCall(ci->second, mParameterAddr);
    }
    else if (const And * pablo_and = dyn_cast<And>(stmt)) {
        expr = mBuilder->CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
    }
    else if (const Or * pablo_or = dyn_cast<Or>(stmt)) {
        expr = mBuilder->CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
    }
    else if (const Xor * pablo_xor = dyn_cast<Xor>(stmt)) {
        expr = mBuilder->CreateXor(compileExpression(pablo_xor->getExpr1()), compileExpression(pablo_xor->getExpr2()), "xor");
    }
    else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
        // Bitwise select: (mask & trueExpr) | (~mask & falseExpr).
        Value* ifMask = compileExpression(sel->getCondition());
        Value* ifTrue = mBuilder->CreateAnd(ifMask, compileExpression(sel->getTrueExpr()));
        Value* ifFalse = mBuilder->CreateAnd(genNot(ifMask), compileExpression(sel->getFalseExpr()));
        expr = mBuilder->CreateOr(ifTrue, ifFalse);
    }
    else if (const Not * pablo_not = dyn_cast<Not>(stmt)) {
        expr = genNot(compileExpression(pablo_not->getExpr()));
    }
    else if (const Advance * adv = dyn_cast<Advance>(stmt)) {
        // Shift by the advance amount; carry-in/carry-out across blocks is handled by the CarryManager,
        // keyed by the current Pablo block and this Advance's local index.
        Value* strm_value = compileExpression(adv->getExpr());
        int shift = adv->getAdvanceAmount();
        unsigned advance_index = adv->getLocalAdvanceIndex();
        expr = mCarryManager->advanceCarryInCarryOut(mPabloBlock, advance_index, shift, strm_value);
    }
    else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
        // MatchStar(M, C) = (((M & C) + C) ^ C) | M, where + is the carrying add.
        Value * marker = compileExpression(mstar->getMarker());
        Value * cc = compileExpression(mstar->getCharClass());
        Value * marker_and_cc = mBuilder->CreateAnd(marker, cc);
        unsigned carry_index = mstar->getLocalCarryIndex();
        expr = mBuilder->CreateOr(mBuilder->CreateXor(genAddWithCarry(marker_and_cc, cc, carry_index), cc), marker, "matchstar");
    }
    else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
        // ScanThru(M, C) = (M + C) & ~C, where + is the carrying add.
        Value * marker_expr = compileExpression(sthru->getScanFrom());
        Value * cc_expr = compileExpression(sthru->getScanThru());
        unsigned carry_index = sthru->getLocalCarryIndex();
        expr = mBuilder->CreateAnd(genAddWithCarry(marker_expr, cc_expr, carry_index), genNot(cc_expr), "scanthru");
    }
    else {
        // Print the offending statement to stderr before failing, to aid debugging.
        llvm::raw_os_ostream cerr(std::cerr);
        PabloPrinter::print(stmt, cerr);
        throw std::runtime_error("Unrecognized Pablo Statement! can't compile.");
    }
    mMarkerMap[stmt] = expr;
    if (DumpTrace) {
        genPrintRegister(stmt->getName()->to_string(), expr);
    }

}
642
643Value * PabloCompiler::compileExpression(const PabloAST * expr) {
644    if (isa<Ones>(expr)) {
645        return mOneInitializer;
646    }
647    else if (isa<Zeroes>(expr)) {
648        return mZeroInitializer;
649    }
650    auto f = mMarkerMap.find(expr);
651    if (LLVM_UNLIKELY(f == mMarkerMap.end())) {
652        std::string o;
653        llvm::raw_string_ostream str(o);
654        str << "\"";
655        PabloPrinter::print(expr, str);
656        str << "\" was used before definition!";
657        throw std::runtime_error(str.str());
658    }
659    return f->second;
660}
661
662
#ifdef USE_UADD_OVERFLOW
#ifdef USE_TWO_UADD_OVERFLOW
/** ------------------------------------------------------------------------------------------------------------- *
 * @brief callUaddOverflow
 *
 * Emits a call to llvm.uadd.with.overflow.i<BLOCK_SIZE>(e1, e2) and extracts the sum (field 0) and the
 * overflow bit (field 1).
 *
 * All instructions are inserted through mBuilder at its current insertion point, like every other generator
 * in this file. They used to be appended to mBasicBlock, which is not updated when compileIf/compileWhile
 * move the builder into new blocks.
 */
PabloCompiler::SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2) {
    std::vector<Value*> struct_res_params;
    struct_res_params.push_back(int128_e1);
    struct_res_params.push_back(int128_e2);
    CallInst* struct_res = mBuilder->Insert(CallInst::Create(mFunctionUaddOverflow, struct_res_params), "uadd_overflow_res");
    struct_res->setCallingConv(CallingConv::C);
    struct_res->setTailCall(false);

    SumWithOverflowPack ret;
    ret.sum = mBuilder->Insert(ExtractValueInst::Create(struct_res, 0U), "sum");
    ret.obit = mBuilder->Insert(ExtractValueInst::Create(struct_res, 1U), "obit");
    return ret;
}
#else
/** ------------------------------------------------------------------------------------------------------------- *
 * @brief callUaddOverflow
 *
 * Emits a call to llvm.uadd.with.overflow.carryin.i<BLOCK_SIZE>(e1, e2, cin) and extracts the sum (field 0)
 * and the overflow bit (field 1).
 *
 * All instructions are inserted through mBuilder at its current insertion point, like every other generator
 * in this file. They used to be appended to mBasicBlock, which is not updated when compileIf/compileWhile
 * move the builder into new blocks.
 */
PabloCompiler::SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
    std::vector<Value*> struct_res_params;
    struct_res_params.push_back(int128_e1);
    struct_res_params.push_back(int128_e2);
    struct_res_params.push_back(int1_cin);
    CallInst* struct_res = mBuilder->Insert(CallInst::Create(mFunctionUaddOverflowCarryin, struct_res_params), "uadd_overflow_res");
    struct_res->setCallingConv(CallingConv::C);
    struct_res->setTailCall(false);

    SumWithOverflowPack ret;
    ret.sum = mBuilder->Insert(ExtractValueInst::Create(struct_res, 0U), "sum");
    ret.obit = mBuilder->Insert(ExtractValueInst::Create(struct_res, 1U), "obit");
    return ret;
}
#endif
#endif
713
714
715Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2, unsigned localIndex) {
716    Value * carryq_value = mCarryManager->getCarryOpCarryIn(mPabloBlock, localIndex);
717#ifdef USE_TWO_UADD_OVERFLOW
718    //This is the ideal implementation, which uses two uadd.with.overflow
719    //The back end should be able to recognize this pattern and combine it into uadd.with.overflow.carryin
720    CastInst* int128_e1 = new BitCastInst(e1, mBuilder->getIntNTy(BLOCK_SIZE), "e1_128", mBasicBlock);
721    CastInst* int128_e2 = new BitCastInst(e2, mBuilder->getIntNTy(BLOCK_SIZE), "e2_128", mBasicBlock);
722    CastInst* int128_carryq_value = new BitCastInst(carryq_value, mBuilder->getIntNTy(BLOCK_SIZE), "carryq_128", mBasicBlock);
723
724    SumWithOverflowPack sumpack0, sumpack1;
725
726    sumpack0 = callUaddOverflow(int128_e1, int128_e2);
727    sumpack1 = callUaddOverflow(sumpack0.sum, int128_carryq_value);
728
729    Value* obit = mBuilder->CreateOr(sumpack0.obit, sumpack1.obit, "carry_bit");
730    Value* sum = mBuilder->CreateBitCast(sumpack1.sum, mBitBlockType, "ret_sum");
731
732    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
733    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mBitBlockType);
734    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
735    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
736    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
737
738#elif defined USE_UADD_OVERFLOW
739    //use llvm.uadd.with.overflow.i128 or i256
740    CastInst* int128_e1 = new BitCastInst(e1, mBuilder->getIntNTy(BLOCK_SIZE), "e1_128", mBasicBlock);
741    CastInst* int128_e2 = new BitCastInst(e2, mBuilder->getIntNTy(BLOCK_SIZE), "e2_128", mBasicBlock);
742
743    //get i1 carryin from iBLOCK_SIZE
744    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
745    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
746    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
747
748    SumWithOverflowPack sumpack0;
749    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
750    Value* obit = sumpack0.obit;
751    Value* sum = mBuilder->CreateBitCast(sumpack0.sum, mBitBlockType, "sum");
752
753    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
754    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mBitBlockType);
755    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
756    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
757#elif (BLOCK_SIZE == 128)
758    //calculate carry through logical ops
759    Value* carrygen = mBuilder->CreateAnd(e1, e2, "carrygen");
760    Value* carryprop = mBuilder->CreateOr(e1, e2, "carryprop");
761    Value* digitsum = mBuilder->CreateAdd(e1, e2, "digitsum");
762    Value* partial = mBuilder->CreateAdd(digitsum, carryq_value, "partial");
763    Value* digitcarry = mBuilder->CreateOr(carrygen, mBuilder->CreateAnd(carryprop, genNot(partial)));
764    Value* mid_carry_in = genShiftLeft64(mBuilder->CreateLShr(digitcarry, 63), "mid_carry_in");
765
766    Value* sum = mBuilder->CreateAdd(partial, mid_carry_in, "sum");
767    Value* carry_out = genShiftHighbitToLow(BLOCK_SIZE, mBuilder->CreateOr(carrygen, mBuilder->CreateAnd(carryprop, genNot(sum))));
768#else
769    //BLOCK_SIZE == 256, there is no other implementation
770    static_assert(false, "Add with carry for 256-bit bitblock requires USE_UADD_OVERFLOW");
771#endif //USE_TWO_UADD_OVERFLOW
772
773    mCarryManager->setCarryOpCarryOut(mPabloBlock, localIndex, carry_out);
774    return sum;
775}
776
777inline Value* PabloCompiler::genBitBlockAny(Value* test) {
778    Value* cast_marker_value_1 = mBuilder->CreateBitCast(test, mBuilder->getIntNTy(BLOCK_SIZE));
779    return mBuilder->CreateICmpEQ(cast_marker_value_1, ConstantInt::get(mBuilder->getIntNTy(BLOCK_SIZE), 0));
780}
781
782Value * PabloCompiler::genShiftHighbitToLow(unsigned FieldWidth, Value * op) {
783    unsigned FieldCount = BLOCK_SIZE/FieldWidth;
784    VectorType * vType = VectorType::get(IntegerType::get(mMod->getContext(), FieldWidth), FieldCount);
785    Value * v = mBuilder->CreateBitCast(op, vType);
786    return mBuilder->CreateBitCast(mBuilder->CreateLShr(v, FieldWidth - 1), mBitBlockType);
787}
788
789Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
790    Value* i128_val = mBuilder->CreateBitCast(e, mBuilder->getIntNTy(BLOCK_SIZE));
791    return mBuilder->CreateBitCast(mBuilder->CreateShl(i128_val, 64, namehint), mBitBlockType);
792}
793
794inline Value* PabloCompiler::genNot(Value* expr) {
795    return mBuilder->CreateXor(expr, mOneInitializer, "not");
796}
797
798   
799void PabloCompiler::SetOutputValue(Value * marker, const unsigned index) {
800    if (marker->getType()->isPointerTy()) {
801        marker = mBuilder->CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
802    }
803    Value* indices[] = {mBuilder->getInt64(0), mBuilder->getInt32(index)};
804    Value* gep = mBuilder->CreateGEP(mOutputAddrPtr, indices);
805    mBuilder->CreateAlignedStore(marker, gep, BLOCK_SIZE/8, false);
806}
807
808CompiledPabloFunction::CompiledPabloFunction(size_t carryDataSize, Function * function, ExecutionEngine * executionEngine)
809: CarryDataSize(carryDataSize)
810, FunctionPointer(executionEngine->getPointerToFunction(function))
811, mFunction(function)
812, mExecutionEngine(executionEngine)
813{
814
815}
816
817// Clean up the memory for the compiled function once we're finished using it.
818CompiledPabloFunction::~CompiledPabloFunction() {
819    if (mExecutionEngine) {
820        assert (mFunction);
821        // mExecutionEngine->freeMachineCodeForFunction(mFunction); // This function only prints a "not supported" message. Reevaluate with LLVM 3.6.
822        delete mExecutionEngine;
823    }
824}
825
826}
Note: See TracBrowser for help on using the repository browser.