source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4716

Last change on this file since 4716 was 4716, checked in by cameron, 4 years ago

Mod64 approximation mode

File size: 35.5 KB
Line 
1/*
2 *  Copyright (c) 2014-15 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <pablo/pablo_compiler.h>
8#include <pablo/codegenstate.h>
9#include <pablo/carry_data.h>
10#include <pablo/carry_manager.h>
11#include <pablo/printer_pablos.h>
12#include <pablo/function.h>
13#include <cc/cc_namemap.hpp>
14#include <re/re_name.h>
15#include <stdexcept>
16#include <include/simd-lib/bitblock.hpp>
17#include <sstream>
18#include <IDISA/idisa_builder.h>
19#include <llvm/IR/Verifier.h>
20#include <llvm/Pass.h>
21#include <llvm/PassManager.h>
22#include <llvm/ADT/SmallVector.h>
23#include <llvm/Analysis/Passes.h>
24#include <llvm/IR/BasicBlock.h>
25#include <llvm/IR/CallingConv.h>
26#include <llvm/IR/Constants.h>
27#include <llvm/IR/DataLayout.h>
28#include <llvm/IR/DerivedTypes.h>
29#include <llvm/IR/Function.h>
30#include <llvm/IR/GlobalVariable.h>
31#include <llvm/IR/InlineAsm.h>
32#include <llvm/IR/Instructions.h>
33#include <llvm/IR/LLVMContext.h>
34#include <llvm/IR/Module.h>
35#include <llvm/Support/FormattedStream.h>
36#include <llvm/Support/MathExtras.h>
37#include <llvm/Support/Casting.h>
38#include <llvm/Support/Compiler.h>
39#include <llvm/Support/Debug.h>
40#include <llvm/Support/TargetSelect.h>
41#include <llvm/Support/Host.h>
42#include <llvm/Transforms/Scalar.h>
43#include <llvm/ExecutionEngine/ExecutionEngine.h>
44#include <llvm/ExecutionEngine/MCJIT.h>
45#include <llvm/IRReader/IRReader.h>
46#include <llvm/Bitcode/ReaderWriter.h>
47#include <llvm/Support/MemoryBuffer.h>
48#include <llvm/IR/IRBuilder.h>
49#include <llvm/Support/CommandLine.h>
50#include <llvm/ADT/Twine.h>
51#include <iostream>
52
53static cl::OptionCategory eIRDumpOptions("LLVM IR Dump Options", "These options control dumping of LLVM IR.");
54static cl::opt<bool> DumpGeneratedIR("dump-generated-IR", cl::init(false), cl::desc("Print LLVM IR generated by Pablo Compiler."), cl::cat(eIRDumpOptions));
55
56static cl::OptionCategory fTracingOptions("Run-time Tracing Options", "These options control execution traces.");
57static cl::opt<bool> TraceNext("trace-next-nodes", cl::init(false), cl::desc("Generate dynamic traces of executed Next nodes (while control variables)."), cl::cat(fTracingOptions));
58static cl::opt<bool> DumpTrace("dump-trace", cl::init(false), cl::desc("Generate dynamic traces of executed assignments."), cl::cat(fTracingOptions));
59
60extern "C" {
61  void wrapped_print_register(char * regName, BitBlock bit_block) {
62      print_register<BitBlock>(regName, bit_block);
63  }
64}
65
66namespace pablo {
67
68PabloCompiler::PabloCompiler()
69: mMod(nullptr)
70, mExecutionEngine(nullptr)
71, mBuilder(nullptr)
72, mCarryManager(nullptr)
73, mCarryOffset(0)
74, mBitBlockType(VectorType::get(IntegerType::get(getGlobalContext(), 64), BLOCK_SIZE / 64))
75, iBuilder(mBitBlockType)
76, mInputType(nullptr)
77, mCarryDataPtr(nullptr)
78, mWhileDepth(0)
79, mIfDepth(0)
80, mZeroInitializer(ConstantAggregateZero::get(mBitBlockType))
81, mOneInitializer(ConstantVector::getAllOnesValue(mBitBlockType))
82, mFunction(nullptr)
83, mInputAddressPtr(nullptr)
84, mOutputAddressPtr(nullptr)
85, mMaxWhileDepth(0)
86, mPrintRegisterFunction(nullptr) {
87
88}
89
90PabloCompiler::~PabloCompiler() {
91}
92   
93
94void PabloCompiler::genPrintRegister(std::string regName, Value * bitblockValue) {
95    Constant * regNameData = ConstantDataArray::getString(mMod->getContext(), regName);
96    GlobalVariable *regStrVar = new GlobalVariable(*mMod,
97                                                   ArrayType::get(IntegerType::get(mMod->getContext(), 8), regName.length()+1),
98                                                   /*isConstant=*/ true,
99                                                   /*Linkage=*/ GlobalValue::PrivateLinkage,
100                                                   /*Initializer=*/ regNameData);
101    Value * regStrPtr = mBuilder->CreateGEP(regStrVar, {mBuilder->getInt64(0), mBuilder->getInt32(0)});
102    mBuilder->CreateCall(mPrintRegisterFunction, {regStrPtr, bitblockValue});
103}
104
105CompiledPabloFunction PabloCompiler::compile(PabloFunction & function) {
106
107    Examine(function);
108
109    InitializeNativeTarget();
110    InitializeNativeTargetAsmPrinter();
111    InitializeNativeTargetAsmParser();
112
113    Module * module = new Module("", getGlobalContext());
114
115    mMod = module;
116
117    std::string errMessage;
118    #ifdef USE_LLVM_3_5
119    EngineBuilder builder(mMod);
120    #else
121    EngineBuilder builder(std::move(std::unique_ptr<Module>(mMod)));
122    #endif
123    builder.setErrorStr(&errMessage);
124    builder.setMCPU(sys::getHostCPUName());
125    #ifdef USE_LLVM_3_5
126    builder.setUseMCJIT(true);
127    #endif
128    builder.setOptLevel(mMaxWhileDepth ? CodeGenOpt::Level::Less : CodeGenOpt::Level::None);
129    mExecutionEngine = builder.create();
130    if (mExecutionEngine == nullptr) {
131        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
132    }
133    DeclareDebugFunctions();
134
135    auto func = compile(function, mMod);
136
137    //Display the IR that has been generated by this module.
138    if (LLVM_UNLIKELY(DumpGeneratedIR)) {
139        module->dump();
140    }
141    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
142    verifyModule(*module, &dbgs());
143
144    mExecutionEngine->finalizeObject();
145    ExecutionEngine * engine = mExecutionEngine;
146    mExecutionEngine = nullptr; // <-- pass ownership of the execution engine to the caller
147
148    return CompiledPabloFunction(func.second, func.first, engine);
149}
150
151std::pair<llvm::Function *, size_t> PabloCompiler::compile(PabloFunction & function, Module * module) {
152
153 
154    function.getEntryBlock().enumerateScopes(0);
155   
156    Examine(function);
157
158    mMod = module;
159
160    mBuilder = new IRBuilder<>(mMod->getContext());
161
162    iBuilder.initialize(mMod, mBuilder);
163
164    mCarryManager = new CarryManager(mBuilder, mBitBlockType, mZeroInitializer, mOneInitializer, &iBuilder);
165
166    GenerateFunction(function);
167
168    mBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", mFunction,0));
169
170    //The basis bits structure
171    for (unsigned i = 0; i != function.getNumOfParameters(); ++i) {
172        Value* indices[] = {mBuilder->getInt64(0), mBuilder->getInt32(i)};
173        Value * gep = mBuilder->CreateGEP(mInputAddressPtr, indices);
174        LoadInst * basisBit = mBuilder->CreateAlignedLoad(gep, BLOCK_SIZE/8, false, function.getParameter(i)->getName()->to_string());
175        mMarkerMap[function.getParameter(i)] = basisBit;
176        if (DumpTrace) {
177            genPrintRegister(function.getParameter(i)->getName()->to_string(), basisBit);
178        }
179    }
180     
181    PabloBlock & mainScope = function.getEntryBlock();
182
183    mCarryOffset = mCarryManager->initialize(&mainScope, mCarryDataPtr);
184   
185    //Generate the IR instructions for the function.
186   
187    compileBlock(mainScope);
188   
189    mCarryManager->ensureCarriesStoredLocal();
190    mCarryManager->leaveScope();
191   
192   
193    mCarryManager->generateBlockNoIncrement();
194
195    if (DumpTrace || TraceNext) {
196        genPrintRegister("mBlockNo", mBuilder->CreateAlignedLoad(mBuilder->CreateBitCast(mCarryManager->getBlockNoPtr(), PointerType::get(mBitBlockType, 0)), BLOCK_SIZE/8, false));
197    }
198   
199    // Write the output values out
200    for (unsigned i = 0; i != function.getNumOfResults(); ++i) {
201        assert (function.getResult(i));
202        SetOutputValue(mMarkerMap[function.getResult(i)], i);
203    }
204
205    //Terminate the block
206    ReturnInst::Create(mMod->getContext(), mBuilder->GetInsertBlock());
207
208    // Clean up
209    delete mCarryManager; mCarryManager = nullptr;
210    delete mBuilder; mBuilder = nullptr;
211    mMod = nullptr; // don't delete this. It's either owned by the ExecutionEngine or the calling function.
212
213    //Return the required size of the carry data area to the process_block function.
214    return std::make_pair(mFunction, mCarryOffset * sizeof(BitBlock));
215}
216
// Declares the LLVM function that will hold the compiled Pablo code and
// records its three arguments in mInputAddressPtr, mCarryDataPtr and
// mOutputAddressPtr.
//
// The generated function returns void and takes, in order:
//   input  - pointer to a struct with one bit block per Pablo parameter,
//   carry  - pointer to bit blocks,
//   output - pointer to a struct with one bit block per Pablo result.
//
// When USE_UADD_OVERFLOW is defined, an add-with-overflow helper function is
// also looked up in (or added to) the module; USE_TWO_UADD_OVERFLOW selects
// which of the two variants is declared.
217inline void PabloCompiler::GenerateFunction(PabloFunction & function) {
218    mInputType = PointerType::get(StructType::get(mMod->getContext(), std::vector<Type *>(function.getNumOfParameters(), mBitBlockType)), 0);
219    Type * carryType = PointerType::get(mBitBlockType, 0);
220    Type * outputType = PointerType::get(StructType::get(mMod->getContext(), std::vector<Type *>(function.getNumOfResults(), mBitBlockType)), 0);
221    FunctionType * functionType = FunctionType::get(Type::getVoidTy(mMod->getContext()), {{mInputType, carryType, outputType}}, false);
222
223#ifdef USE_UADD_OVERFLOW
224#ifdef USE_TWO_UADD_OVERFLOW
225    // Type definitions for llvm.uadd.with.overflow.i<BLOCK_SIZE>: two BLOCK_SIZE-bit operands, result is {sum, overflow bit}.
226    std::vector<Type*>StructTy_0_fields;
227    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
228    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
229    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
230
231    std::vector<Type*>FuncTy_1_args;
232    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
233    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
234    FunctionType* FuncTy_1 = FunctionType::get(
235                                              /*Result=*/StructTy_0,
236                                              /*Params=*/FuncTy_1_args,
237                                              /*isVarArg=*/false);
238
    // Reuse the declaration if the module already has one; otherwise add it.
239    mFunctionUaddOverflow = mMod->getFunction("llvm.uadd.with.overflow.i" +
240                                              std::to_string(BLOCK_SIZE));
241    if (!mFunctionUaddOverflow) {
242        mFunctionUaddOverflow= Function::Create(
243          /*Type=*/ FuncTy_1,
244          /*Linkage=*/ GlobalValue::ExternalLinkage,
245          /*Name=*/ "llvm.uadd.with.overflow.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
246        mFunctionUaddOverflow->setCallingConv(CallingConv::C);
247    }
    // Mark the helper NoUnwind and ReadNone (attribute index ~0U).
248    AttributeSet mFunctionUaddOverflowPAL;
249    {
250        SmallVector<AttributeSet, 4> Attrs;
251        AttributeSet PAS;
252        {
253          AttrBuilder B;
254          B.addAttribute(Attribute::NoUnwind);
255          B.addAttribute(Attribute::ReadNone);
256          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
257        }
258
259        Attrs.push_back(PAS);
260        mFunctionUaddOverflowPAL = AttributeSet::get(mMod->getContext(), Attrs);
261    }
262    mFunctionUaddOverflow->setAttributes(mFunctionUaddOverflowPAL);
263#else
264    // Type definitions for llvm.uadd.with.overflow.carryin.i<BLOCK_SIZE>: two BLOCK_SIZE-bit operands plus a 1-bit carry-in, result is {sum, overflow bit}.
265    std::vector<Type*>StructTy_0_fields;
266    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
267    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
268    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
269
270    std::vector<Type*>FuncTy_1_args;
271    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
272    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
273    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
274    FunctionType* FuncTy_1 = FunctionType::get(
275                                              /*Result=*/StructTy_0,
276                                              /*Params=*/FuncTy_1_args,
277                                              /*isVarArg=*/false);
278
    // Reuse the declaration if the module already has one; otherwise add it.
279    mFunctionUaddOverflowCarryin = mMod->getFunction("llvm.uadd.with.overflow.carryin.i" +
280                                              std::to_string(BLOCK_SIZE));
281    if (!mFunctionUaddOverflowCarryin) {
282        mFunctionUaddOverflowCarryin = Function::Create(
283          /*Type=*/ FuncTy_1,
284          /*Linkage=*/ GlobalValue::ExternalLinkage,
285          /*Name=*/ "llvm.uadd.with.overflow.carryin.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
286        mFunctionUaddOverflowCarryin->setCallingConv(CallingConv::C);
287    }
    // Mark the helper NoUnwind and ReadNone (attribute index ~0U).
288    AttributeSet mFunctionUaddOverflowCarryinPAL;
289    {
290        SmallVector<AttributeSet, 4> Attrs;
291        AttributeSet PAS;
292        {
293          AttrBuilder B;
294          B.addAttribute(Attribute::NoUnwind);
295          B.addAttribute(Attribute::ReadNone);
296          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
297        }
298
299        Attrs.push_back(PAS);
300        mFunctionUaddOverflowCarryinPAL = AttributeSet::get(mMod->getContext(), Attrs);
301    }
302    mFunctionUaddOverflowCarryin->setAttributes(mFunctionUaddOverflowCarryinPAL);
303#endif
304#endif
305
306    // Attributes for the generated function: index ~0U carries the function-level set, indices 1-3 the three arguments in declaration order.
307    SmallVector<AttributeSet, 4> Attrs;
308    Attrs.push_back(AttributeSet::get(mMod->getContext(), ~0U, { Attribute::NoUnwind, Attribute::UWTable }));
309    Attrs.push_back(AttributeSet::get(mMod->getContext(), 1U, { Attribute::ReadOnly, Attribute::NoCapture }));
310    Attrs.push_back(AttributeSet::get(mMod->getContext(), 2U, { Attribute::NoCapture }));
    // NOTE(review): ReadNone on the third argument presumably means "never read"; confirm it is
    // compatible with the stores that SetOutputValue is expected to make through this pointer.
311    Attrs.push_back(AttributeSet::get(mMod->getContext(), 3U, { Attribute::ReadNone, Attribute::NoCapture }));
312    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
313
314    // Create the function that will be generated.
315    mFunction = Function::Create(functionType, GlobalValue::ExternalLinkage, function.getName()->value(), mMod);
316    mFunction->setCallingConv(CallingConv::C);
317    mFunction->setAttributes(AttrSet);
318
    // Name the three arguments and remember them for later code generation.
319    Function::arg_iterator args = mFunction->arg_begin();
320    mInputAddressPtr = args++;
321    mInputAddressPtr->setName("input");
322    mCarryDataPtr = args++;
323    mCarryDataPtr->setName("carry");
324    mOutputAddressPtr = args++;
325    mOutputAddressPtr->setName("output");
326}
327
328inline void PabloCompiler::Examine(PabloFunction & function) {
329    mWhileDepth = 0;
330    mIfDepth = 0;
331    mMaxWhileDepth = 0;
332    Examine(function.getEntryBlock());
333    if (LLVM_UNLIKELY(mWhileDepth != 0 || mIfDepth != 0)) {
334        throw std::runtime_error("Malformed Pablo AST: Unbalanced If or While nesting depth!");
335    }
336}
337
338
339void PabloCompiler::Examine(PabloBlock & block) {
340    for (Statement * stmt : block) {
341        if (If * ifStatement = dyn_cast<If>(stmt)) {
342            Examine(ifStatement->getBody());
343        }
344        else if (While * whileStatement = dyn_cast<While>(stmt)) {
345            mMaxWhileDepth = std::max(mMaxWhileDepth, ++mWhileDepth);
346            Examine(whileStatement->getBody());
347            --mWhileDepth;
348        }
349    }
350}
351
352inline void PabloCompiler::DeclareDebugFunctions() {
353    if (DumpTrace || TraceNext) {
354        //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
355        mPrintRegisterFunction = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(mMod->getContext()), Type::getInt8PtrTy(mMod->getContext()), mBitBlockType, NULL);
356        if (mExecutionEngine) mExecutionEngine->addGlobalMapping(cast<GlobalValue>(mPrintRegisterFunction), (void *)&wrapped_print_register);
357    }
358}
359
360void PabloCompiler::compileBlock(PabloBlock & block) {
361    mPabloBlock = & block;
362    for (const Statement * statement : block) {
363        compileStatement(statement);
364    }
365    mPabloBlock = block.getParent();
366}
367
368    Value * PabloCompiler::genBitTest2(Value * e1, Value * e2) {
369        Type * t1 = e1->getType();
370        Type * t2 = e2->getType();
371        if (t1 == mBitBlockType) {
372            if (t2 == mBitBlockType) {
373                return iBuilder.bitblock_any(mBuilder->CreateOr(e1, e2));
374            }
375            else {
376                Value * m1 = mBuilder->CreateZExt(iBuilder.hsimd_signmask(16, e1), t2);
377                return mBuilder->CreateICmpNE(mBuilder->CreateOr(m1, e2), ConstantInt::get(t2, 0));
378            }
379        }
380        else if (t2 == mBitBlockType) {
381            Value * m2 = mBuilder->CreateZExt(iBuilder.hsimd_signmask(16, e2), t1);
382            return mBuilder->CreateICmpNE(mBuilder->CreateOr(e1, m2), ConstantInt::get(t1, 0));
383        }
384        else {
385            return mBuilder->CreateICmpNE(mBuilder->CreateOr(e1, e2), ConstantInt::get(t1, 0));
386        }
387    }
388   
389    void PabloCompiler::compileIf(const If * ifStatement) {       
390    //
391    //  The If-ElseZero stmt:
392    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
393    //  If the value of the predicate is nonzero, then determine the values of variables
394    //  <var>* by executing the given statements.  Otherwise, the value of the
395    //  variables are all zero.  Requirements: (a) no variable that is defined within
396    //  the body of the if may be accessed outside unless it is explicitly
397    //  listed in the variable list, (b) every variable in the defined list receives
398    //  a value within the body, and (c) the logical consequence of executing
399    //  the statements in the event that the predicate is zero is that the
400    //  values of all defined variables indeed work out to be 0.
401    //
402    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
403    //  is inserted for each variable in the defined variable list.  It receives
404    //  a zero value from the ifentry block and the defined value from the if
405    //  body.
406    //
407
408    BasicBlock * ifEntryBlock = mBuilder->GetInsertBlock();
409    BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunction, 0);
410    BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunction, 0);
411   
412    PabloBlock & ifBody = ifStatement -> getBody();
413   
414    Value * if_test_value = compileExpression(ifStatement->getCondition());
415   
416    mCarryManager->enterScope(&ifBody);
417    if (mCarryManager->blockHasCarries()) {
418        // load the summary variable
419        Value* last_if_pending_data = mCarryManager->getCarrySummaryExpr();
420        mBuilder->CreateCondBr(genBitTest2(if_test_value, last_if_pending_data), ifBodyBlock, ifEndBlock);
421
422    }
423    else {
424        mBuilder->CreateCondBr(iBuilder.bitblock_any(if_test_value), ifBodyBlock, ifEndBlock);
425    }
426    // Entry processing is complete, now handle the body of the if.
427    mBuilder->SetInsertPoint(ifBodyBlock);
428   
429    mCarryManager->initializeCarryDataAtIfEntry();
430    compileBlock(ifBody);
431    if (mCarryManager->blockHasCarries()) {
432        mCarryManager->generateCarryOutSummaryCodeIfNeeded();
433    }
434    BasicBlock * ifBodyFinalBlock = mBuilder->GetInsertBlock();
435    mCarryManager->ensureCarriesStoredLocal();
436    mBuilder->CreateBr(ifEndBlock);
437    //End Block
438    mBuilder->SetInsertPoint(ifEndBlock);
439    for (const PabloAST * node : ifStatement->getDefined()) {
440        const Assign * assign = cast<Assign>(node);
441        PHINode * phi = mBuilder->CreatePHI(mBitBlockType, 2, assign->getName()->value());
442        auto f = mMarkerMap.find(assign);
443        assert (f != mMarkerMap.end());
444        phi->addIncoming(mZeroInitializer, ifEntryBlock);
445        phi->addIncoming(f->second, ifBodyFinalBlock);
446        mMarkerMap[assign] = phi;
447    }
448    // Create the phi Node for the summary variable, if needed.
449    mCarryManager->buildCarryDataPhisAfterIfBody(ifEntryBlock, ifBodyFinalBlock);
450    mCarryManager->leaveScope();
451}
452
453void PabloCompiler::compileWhile(const While * whileStatement) {
454
455    PabloBlock & whileBody = whileStatement -> getBody();
456   
457    BasicBlock * whileEntryBlock = mBuilder->GetInsertBlock();
458    BasicBlock * whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunction, 0);
459    BasicBlock * whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunction, 0);
460
461    mCarryManager->enterScope(&whileBody);
462    mCarryManager->ensureCarriesLoadedRecursive();
463
464    const auto & nextNodes = whileStatement->getVariants();
465    std::vector<PHINode *> nextPhis;
466    nextPhis.reserve(nextNodes.size());
467
468    // On entry to the while structure, proceed to execute the first iteration
469    // of the loop body unconditionally.   The while condition is tested at the end of
470    // the loop.
471
472    mBuilder->CreateBr(whileBodyBlock);
473    mBuilder->SetInsertPoint(whileBodyBlock);
474
475    //
476    // There are 3 sets of Phi nodes for the while loop.
477    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
478    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
479    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
480
481    mCarryManager->initializeCarryDataPhisAtWhileEntry(whileEntryBlock);
482
483    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
484    for (const Next * n : nextNodes) {
485        PHINode * phi = mBuilder->CreatePHI(mBitBlockType, 2, n->getName()->value());
486        auto f = mMarkerMap.find(n->getInitial());
487        assert (f != mMarkerMap.end());
488        phi->addIncoming(f->second, whileEntryBlock);
489        mMarkerMap[n->getInitial()] = phi;
490        nextPhis.push_back(phi);
491    }
492
493    //
494    // Now compile the loop body proper.  Carry-out accumulated values
495    // and iterated values of Next nodes will be computed.
496    ++mWhileDepth;
497    compileBlock(whileBody);
498
499    BasicBlock * whileBodyFinalBlock = mBuilder->GetInsertBlock();
500
501    if (mCarryManager->blockHasCarries()) {
502        mCarryManager->generateCarryOutSummaryCodeIfNeeded();
503    }
504    mCarryManager->extendCarryDataPhisAtWhileBodyFinalBlock(whileBodyFinalBlock);
505
506    // Terminate the while loop body with a conditional branch back.
507    mBuilder->CreateCondBr(iBuilder.bitblock_any(compileExpression(whileStatement->getCondition())), whileBodyBlock, whileEndBlock);
508
509    // and for any Next nodes in the loop body
510    for (unsigned i = 0; i < nextNodes.size(); i++) {
511        const Next * n = nextNodes[i];
512        auto f = mMarkerMap.find(n->getExpr());
513        if (LLVM_UNLIKELY(f == mMarkerMap.end())) {
514            throw std::runtime_error("Next node expression was not compiled!");
515        }
516        nextPhis[i]->addIncoming(f->second, whileBodyFinalBlock);
517    }
518
519    mBuilder->SetInsertPoint(whileEndBlock);
520    --mWhileDepth;
521
522    mCarryManager->ensureCarriesStoredRecursive();
523    mCarryManager->leaveScope();
524}
525
526
527void PabloCompiler::compileStatement(const Statement * stmt) {
528    Value * expr = nullptr;
529    if (const Assign * assign = dyn_cast<const Assign>(stmt)) {
530        expr = compileExpression(assign->getExpression());
531    }
532    else if (const Next * next = dyn_cast<const Next>(stmt)) {
533        expr = compileExpression(next->getExpr());
534        if (TraceNext) {
535            genPrintRegister(next->getName()->to_string(), expr);
536        }
537    }
538    else if (const If * ifStatement = dyn_cast<const If>(stmt)) {
539        compileIf(ifStatement);
540        return;
541    }
542    else if (const While * whileStatement = dyn_cast<const While>(stmt)) {
543        compileWhile(whileStatement);
544        return;
545    }
546    else if (const Call* call = dyn_cast<Call>(stmt)) {
547        //Call the callee once and store the result in the marker map.
548        if (mMarkerMap.count(call) != 0) {
549            return;
550        }
551
552        const Prototype * proto = call->getPrototype();
553        const String * callee = proto->getName();
554
555        Type * inputType = StructType::get(mMod->getContext(), std::vector<Type *>{proto->getNumOfParameters(), mBitBlockType});
556        Type * carryType = mBitBlockType;
557        Type * outputType = StructType::get(mMod->getContext(), std::vector<Type *>{proto->getNumOfResults(), mBitBlockType});
558        FunctionType * functionType = FunctionType::get(Type::getVoidTy(mMod->getContext()), std::vector<Type *>{PointerType::get(inputType, 0), PointerType::get(carryType, 0), PointerType::get(outputType, 0)}, false);
559
560        //Starts on process_block
561        SmallVector<AttributeSet, 3> Attrs;
562        Attrs.push_back(AttributeSet::get(mMod->getContext(), 1U, { Attribute::ReadOnly, Attribute::NoCapture }));
563        Attrs.push_back(AttributeSet::get(mMod->getContext(), 2U, { Attribute::NoCapture }));
564        Attrs.push_back(AttributeSet::get(mMod->getContext(), 3U, { Attribute::ReadNone, Attribute::NoCapture }));
565        AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
566
567        Function * externalFunction = cast<Function>(mMod->getOrInsertFunction(callee->value(), functionType, AttrSet));
568        if (LLVM_UNLIKELY(externalFunction == nullptr)) {
569            throw std::runtime_error("Could not create static method call for external function \"" + callee->to_string() + "\"");
570        }
571        externalFunction->setCallingConv(llvm::CallingConv::C);
572
573        if (mExecutionEngine) mExecutionEngine->addGlobalMapping(externalFunction, proto->getFunctionPtr());
574
575        // add mCarryOffset to mCarryDataPtr
576        Value * carryFramePtr = mBuilder->CreateGEP(mCarryDataPtr, mBuilder->getInt64(mCarryOffset));
577        AllocaInst * outputStruct = mBuilder->CreateAlloca(outputType);
578        mBuilder->CreateCall3(externalFunction, mInputAddressPtr, carryFramePtr, outputStruct);
579        Value * outputPtr = mBuilder->CreateGEP(outputStruct, { mBuilder->getInt32(0), mBuilder->getInt32(0) });
580        expr = mBuilder->CreateAlignedLoad(outputPtr, BLOCK_SIZE / 8, false);
581
582        mCarryOffset += (proto->getRequiredStateSpace() + (BLOCK_SIZE / 8) - 1) / (BLOCK_SIZE / 8);
583    }
584    else if (const And * pablo_and = dyn_cast<And>(stmt)) {
585        expr = mBuilder->CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
586    }
587    else if (const Or * pablo_or = dyn_cast<Or>(stmt)) {
588        expr = mBuilder->CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
589    }
590    else if (const Xor * pablo_xor = dyn_cast<Xor>(stmt)) {
591        expr = mBuilder->CreateXor(compileExpression(pablo_xor->getExpr1()), compileExpression(pablo_xor->getExpr2()), "xor");
592    }
593    else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
594        Value* ifMask = compileExpression(sel->getCondition());
595        Value* ifTrue = mBuilder->CreateAnd(ifMask, compileExpression(sel->getTrueExpr()));
596        Value* ifFalse = mBuilder->CreateAnd(genNot(ifMask), compileExpression(sel->getFalseExpr()));
597        expr = mBuilder->CreateOr(ifTrue, ifFalse);
598    }
599    else if (const Not * pablo_not = dyn_cast<Not>(stmt)) {
600        expr = genNot(compileExpression(pablo_not->getExpr()));
601    }
602    else if (const Advance * adv = dyn_cast<Advance>(stmt)) {
603        Value* strm_value = compileExpression(adv->getExpr());
604        int shift = adv->getAdvanceAmount();
605        if (adv->isMod64()) {
606            expr = iBuilder.simd_slli(64, strm_value, shift);
607        }
608        else {
609            unsigned advance_index = adv->getLocalAdvanceIndex();
610            expr = mCarryManager->advanceCarryInCarryOut(advance_index, shift, strm_value);
611        }
612    }
613    else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
614        Value * marker = compileExpression(mstar->getMarker());
615        Value * cc = compileExpression(mstar->getCharClass());
616        Value * marker_and_cc = mBuilder->CreateAnd(marker, cc);
617        Value * sum = nullptr;
618        if (mstar->isMod64()) {
619            sum = iBuilder.simd_add(64, marker_and_cc, cc);
620        }
621        else {
622            unsigned carry_index = mstar->getLocalCarryIndex();
623            sum = mCarryManager->addCarryInCarryOut(carry_index, marker_and_cc, cc);
624        }
625        expr = mBuilder->CreateOr(mBuilder->CreateXor(sum, cc), marker, "matchstar");
626    }
627    else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
628        Value * marker_expr = compileExpression(sthru->getScanFrom());
629        Value * cc_expr = compileExpression(sthru->getScanThru());
630        Value * sum = nullptr;
631        if (sthru->isMod64()) {
632            sum = iBuilder.simd_add(64, marker_expr, cc_expr);
633        }
634        else {
635            unsigned carry_index = sthru->getLocalCarryIndex();
636            sum = mCarryManager->addCarryInCarryOut(carry_index, marker_expr, cc_expr);
637        }
638        expr = mBuilder->CreateAnd(sum, genNot(cc_expr), "scanthru");
639    }
640    else {
641        llvm::raw_os_ostream cerr(std::cerr);
642        PabloPrinter::print(stmt, cerr);
643        throw std::runtime_error("Unrecognized Pablo Statement! can't compile.");
644    }
645    mMarkerMap[stmt] = expr;
646    if (DumpTrace) {
647        genPrintRegister(stmt->getName()->to_string(), expr);
648    }
649   
650}
651
652Value * PabloCompiler::compileExpression(const PabloAST * expr) {
653    if (isa<Ones>(expr)) {
654        return mOneInitializer;
655    }
656    else if (isa<Zeroes>(expr)) {
657        return mZeroInitializer;
658    }
659    auto f = mMarkerMap.find(expr);
660    if (LLVM_UNLIKELY(f == mMarkerMap.end())) {
661        std::string o;
662        llvm::raw_string_ostream str(o);
663        str << "\"";
664        PabloPrinter::print(expr, str);
665        str << "\" was used before definition!";
666        throw std::runtime_error(str.str());
667    }
668    return f->second;
669}
670
671
// callUaddOverflow: emits a call to the add-with-overflow helper declared in
// GenerateFunction and returns its two results as a SumWithOverflowPack
// (sum = struct field 0, obit = struct field 1). Only compiled when
// USE_UADD_OVERFLOW is defined; USE_TWO_UADD_OVERFLOW selects the variant.
// NOTE(review): both variants insert into mBasicBlock, which is not set
// anywhere in the visible part of this file - confirm it is still maintained.
672#ifdef USE_UADD_OVERFLOW
673#ifdef USE_TWO_UADD_OVERFLOW
// Two-operand variant: calls mFunctionUaddOverflow(e1, e2).
674PabloCompiler::SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2) {
675    std::vector<Value*> struct_res_params;
676    struct_res_params.push_back(int128_e1);
677    struct_res_params.push_back(int128_e2);
678    CallInst* struct_res = CallInst::Create(mFunctionUaddOverflow, struct_res_params, "uadd_overflow_res", mBasicBlock);
679    struct_res->setCallingConv(CallingConv::C);
680    struct_res->setTailCall(false);
    // The call itself carries an empty attribute set.
681    AttributeSet struct_res_PAL;
682    struct_res->setAttributes(struct_res_PAL);
683
684    SumWithOverflowPack ret;
685
    // Field 0 of the returned struct is the sum.
686    std::vector<unsigned> int128_sum_indices;
687    int128_sum_indices.push_back(0);
688    ret.sum = ExtractValueInst::Create(struct_res, int128_sum_indices, "sum", mBasicBlock);
689
    // Field 1 of the returned struct is the overflow bit.
690    std::vector<unsigned> int1_obit_indices;
691    int1_obit_indices.push_back(1);
692    ret.obit = ExtractValueInst::Create(struct_res, int1_obit_indices, "obit", mBasicBlock);
693
694    return ret;
695}
696#else
// Carry-in variant: calls mFunctionUaddOverflowCarryin(e1, e2, carry-in).
697PabloCompiler::SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
698    std::vector<Value*> struct_res_params;
699    struct_res_params.push_back(int128_e1);
700    struct_res_params.push_back(int128_e2);
701    struct_res_params.push_back(int1_cin);
702    CallInst* struct_res = CallInst::Create(mFunctionUaddOverflowCarryin, struct_res_params, "uadd_overflow_res", mBasicBlock);
703    struct_res->setCallingConv(CallingConv::C);
704    struct_res->setTailCall(false);
    // The call itself carries an empty attribute set.
705    AttributeSet struct_res_PAL;
706    struct_res->setAttributes(struct_res_PAL);
707
708    SumWithOverflowPack ret;
709
    // Field 0 of the returned struct is the sum.
710    std::vector<unsigned> int128_sum_indices;
711    int128_sum_indices.push_back(0);
712    ret.sum = ExtractValueInst::Create(struct_res, int128_sum_indices, "sum", mBasicBlock);
713
    // Field 1 of the returned struct is the overflow bit.
714    std::vector<unsigned> int1_obit_indices;
715    int1_obit_indices.push_back(1);
716    ret.obit = ExtractValueInst::Create(struct_res, int1_obit_indices, "obit", mBasicBlock);
717
718    return ret;
719}
720#endif
721#endif
722
723
724Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2, unsigned localIndex) {
725    Value * carryq_value = mCarryManager->getCarryOpCarryIn(localIndex);
726#ifdef USE_TWO_UADD_OVERFLOW
727    //This is the ideal implementation, which uses two uadd.with.overflow
728    //The back end should be able to recognize this pattern and combine it into uadd.with.overflow.carryin
729    CastInst* int128_e1 = new BitCastInst(e1, mBuilder->getIntNTy(BLOCK_SIZE), "e1_128", mBasicBlock);
730    CastInst* int128_e2 = new BitCastInst(e2, mBuilder->getIntNTy(BLOCK_SIZE), "e2_128", mBasicBlock);
731    CastInst* int128_carryq_value = new BitCastInst(carryq_value, mBuilder->getIntNTy(BLOCK_SIZE), "carryq_128", mBasicBlock);
732
733    SumWithOverflowPack sumpack0, sumpack1;
734
735    sumpack0 = callUaddOverflow(int128_e1, int128_e2);
736    sumpack1 = callUaddOverflow(sumpack0.sum, int128_carryq_value);
737
738    Value* obit = mBuilder->CreateOr(sumpack0.obit, sumpack1.obit, "carry_bit");
739    Value* sum = mBuilder->CreateBitCast(sumpack1.sum, mBitBlockType, "ret_sum");
740
741    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
742    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mBitBlockType);
743    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
744    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
745    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
746
747#elif defined USE_UADD_OVERFLOW
748    //use llvm.uadd.with.overflow.i128 or i256
749    CastInst* int128_e1 = new BitCastInst(e1, mBuilder->getIntNTy(BLOCK_SIZE), "e1_128", mBasicBlock);
750    CastInst* int128_e2 = new BitCastInst(e2, mBuilder->getIntNTy(BLOCK_SIZE), "e2_128", mBasicBlock);
751
752    //get i1 carryin from iBLOCK_SIZE
753    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
754    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
755    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
756
757    SumWithOverflowPack sumpack0;
758    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
759    Value* obit = sumpack0.obit;
760    Value* sum = mBuilder->CreateBitCast(sumpack0.sum, mBitBlockType, "sum");
761
762    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
763    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mBitBlockType);
764    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
765    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
766#elif (BLOCK_SIZE == 128)
767    //calculate carry through logical ops
768    Value* carrygen = mBuilder->CreateAnd(e1, e2, "carrygen");
769    Value* carryprop = mBuilder->CreateOr(e1, e2, "carryprop");
770    Value* digitsum = mBuilder->CreateAdd(e1, e2, "digitsum");
771    Value* partial = mBuilder->CreateAdd(digitsum, carryq_value, "partial");
772    Value* digitcarry = mBuilder->CreateOr(carrygen, mBuilder->CreateAnd(carryprop, genNot(partial)));
773    Value* mid_carry_in = genShiftLeft64(mBuilder->CreateLShr(digitcarry, 63), "mid_carry_in");
774
775    Value* sum = mBuilder->CreateAdd(partial, mid_carry_in, "sum");
776    Value* carry_out = genShiftHighbitToLow(BLOCK_SIZE, mBuilder->CreateOr(carrygen, mBuilder->CreateAnd(carryprop, genNot(sum))));
777#else
778    //BLOCK_SIZE == 256, there is no other implementation
779    static_assert(false, "Add with carry for 256-bit bitblock requires USE_UADD_OVERFLOW");
780#endif //USE_TWO_UADD_OVERFLOW
781
782    mCarryManager->setCarryOpCarryOut(localIndex, carry_out);
783    return sum;
784}
785
786Value * PabloCompiler::genShiftHighbitToLow(unsigned FieldWidth, Value * op) {
787    unsigned FieldCount = BLOCK_SIZE/FieldWidth;
788    VectorType * vType = VectorType::get(IntegerType::get(mMod->getContext(), FieldWidth), FieldCount);
789    Value * v = mBuilder->CreateBitCast(op, vType);
790    return mBuilder->CreateBitCast(mBuilder->CreateLShr(v, FieldWidth - 1), mBitBlockType);
791}
792
793Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
794    Value* i128_val = mBuilder->CreateBitCast(e, mBuilder->getIntNTy(BLOCK_SIZE));
795    return mBuilder->CreateBitCast(mBuilder->CreateShl(i128_val, 64, namehint), mBitBlockType);
796}
797
798inline Value* PabloCompiler::genNot(Value* expr) {
799    return mBuilder->CreateXor(expr, mOneInitializer, "not");
800}
801   
802void PabloCompiler::SetOutputValue(Value * marker, const unsigned index) {
803    if (LLVM_UNLIKELY(marker == nullptr)) {
804        throw std::runtime_error("Cannot set result " + std::to_string(index) + " to Null");
805    }
806    if (LLVM_UNLIKELY(marker->getType()->isPointerTy())) {
807        marker = mBuilder->CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
808    }
809    Value* indices[] = {mBuilder->getInt64(0), mBuilder->getInt32(index)};
810    Value* gep = mBuilder->CreateGEP(mOutputAddressPtr, indices);
811    mBuilder->CreateAlignedStore(marker, gep, BLOCK_SIZE/8, false);
812}
813
814CompiledPabloFunction::CompiledPabloFunction(size_t carryDataSize, Function * function, ExecutionEngine * executionEngine)
815: CarryDataSize(carryDataSize)
816, FunctionPointer(executionEngine->getPointerToFunction(function))
817, mFunction(function)
818, mExecutionEngine(executionEngine)
819{
820
821}
822
823// Clean up the memory for the compiled function once we're finished using it.
824CompiledPabloFunction::~CompiledPabloFunction() {
825    if (mExecutionEngine) {
826        assert (mFunction);
827        // mExecutionEngine->freeMachineCodeForFunction(mFunction); // This function only prints a "not supported" message. Reevaluate with LLVM 3.6.
828        delete mExecutionEngine;
829    }
830}
831
832}
Note: See TracBrowser for help on using the repository browser.