source: icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp @ 4741

Last change on this file since 4741 was 4741, checked in by nmedfort, 4 years ago

More work on the reassociation pass.

File size: 33.1 KB
Line 
1/*
2 *  Copyright (c) 2014-15 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <pablo/pablo_compiler.h>
8#include <pablo/codegenstate.h>
9#include <pablo/carry_data.h>
10#include <pablo/carry_manager.h>
11#include <pablo/printer_pablos.h>
12#include <pablo/function.h>
13#include <cc/cc_namemap.hpp>
14#include <re/re_name.h>
15#include <stdexcept>
16#include <include/simd-lib/bitblock.hpp>
17#include <sstream>
18#include <IDISA/idisa_builder.h>
19#include <llvm/IR/Verifier.h>
20#include <llvm/Pass.h>
21#include <llvm/PassManager.h>
22#include <llvm/ADT/SmallVector.h>
23#include <llvm/Analysis/Passes.h>
24#include <llvm/IR/BasicBlock.h>
25#include <llvm/IR/CallingConv.h>
26#include <llvm/IR/Constants.h>
27#include <llvm/IR/DataLayout.h>
28#include <llvm/IR/DerivedTypes.h>
29#include <llvm/IR/Function.h>
30#include <llvm/IR/GlobalVariable.h>
31#include <llvm/IR/InlineAsm.h>
32#include <llvm/IR/Instructions.h>
33#include <llvm/IR/LLVMContext.h>
34#include <llvm/IR/Module.h>
35#include <llvm/Support/FormattedStream.h>
36#include <llvm/Support/MathExtras.h>
37#include <llvm/Support/Casting.h>
38#include <llvm/Support/Compiler.h>
39#include <llvm/Support/Debug.h>
40#include <llvm/Support/TargetSelect.h>
41#include <llvm/Support/Host.h>
42#include <llvm/Transforms/Scalar.h>
43#include <llvm/IRReader/IRReader.h>
44#include <llvm/Bitcode/ReaderWriter.h>
45#include <llvm/Support/MemoryBuffer.h>
46#include <llvm/IR/IRBuilder.h>
47#include <llvm/Support/CommandLine.h>
48#include <llvm/ADT/Twine.h>
49#include <iostream>
50
51static cl::OptionCategory eIRDumpOptions("LLVM IR Dump Options", "These options control dumping of LLVM IR.");
52static cl::opt<bool> DumpGeneratedIR("dump-generated-IR", cl::init(false), cl::desc("Print LLVM IR generated by Pablo Compiler."), cl::cat(eIRDumpOptions));
53
54static cl::OptionCategory fTracingOptions("Run-time Tracing Options", "These options control execution traces.");
55static cl::opt<bool> DumpTrace("dump-trace", cl::init(false), cl::desc("Generate dynamic traces of executed assignments."), cl::cat(fTracingOptions));
56
57namespace pablo {
58
59PabloCompiler::PabloCompiler()
60: mMod(nullptr)
61, mBuilder(nullptr)
62, mCarryManager(nullptr)
63, mBitBlockType(VectorType::get(IntegerType::get(getGlobalContext(), 64), BLOCK_SIZE / 64))
64, iBuilder(mBitBlockType)
65, mInputType(nullptr)
66, mCarryDataPtr(nullptr)
67, mWhileDepth(0)
68, mIfDepth(0)
69, mZeroInitializer(ConstantAggregateZero::get(mBitBlockType))
70, mOneInitializer(ConstantVector::getAllOnesValue(mBitBlockType))
71, mFunction(nullptr)
72, mInputAddressPtr(nullptr)
73, mOutputAddressPtr(nullptr)
74, mMaxWhileDepth(0)
75, mPrintRegisterFunction(nullptr) {
76
77}
78
79PabloCompiler::~PabloCompiler() {
80}
81   
82
83void PabloCompiler::genPrintRegister(std::string regName, Value * bitblockValue) {
84    Constant * regNameData = ConstantDataArray::getString(mMod->getContext(), regName);
85    GlobalVariable *regStrVar = new GlobalVariable(*mMod,
86                                                   ArrayType::get(IntegerType::get(mMod->getContext(), 8), regName.length()+1),
87                                                   /*isConstant=*/ true,
88                                                   /*Linkage=*/ GlobalValue::PrivateLinkage,
89                                                   /*Initializer=*/ regNameData);
90    Value * regStrPtr = mBuilder->CreateGEP(regStrVar, {mBuilder->getInt64(0), mBuilder->getInt32(0)});
91    mBuilder->CreateCall(mPrintRegisterFunction, {regStrPtr, bitblockValue});
92}
93
94llvm::Function * PabloCompiler::compile(PabloFunction * function) {
95    Module * module = new Module("", getGlobalContext());
96   
97    auto func = compile(function, module);
98   
99    //Display the IR that has been generated by this module.
100    if (LLVM_UNLIKELY(DumpGeneratedIR)) {
101        module->dump();
102    }
103    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
104    verifyModule(*module, &dbgs());
105
106    return func;
107}
108
109llvm::Function * PabloCompiler::compile(PabloFunction * function, Module * module) {
110
111 
112    PabloBlock & mainScope = function->getEntryBlock();
113
114    mainScope.enumerateScopes(0);
115   
116    Examine(*function);
117
118    mMod = module;
119
120    mBuilder = new IRBuilder<>(mMod->getContext());
121
122    iBuilder.initialize(mMod, mBuilder);
123
124    mCarryManager = new CarryManager(mBuilder, mBitBlockType, mZeroInitializer, mOneInitializer, &iBuilder);
125   
126    if (DumpTrace) DeclareDebugFunctions();
127       
128    GenerateFunction(*function);
129   
130    mBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", mFunction,0));
131
132    //The basis bits structure
133    for (unsigned i = 0; i != function->getNumOfParameters(); ++i) {
134        Value* indices[] = {mBuilder->getInt64(0), mBuilder->getInt32(i)};
135        Value * gep = mBuilder->CreateGEP(mInputAddressPtr, indices);
136        LoadInst * basisBit = mBuilder->CreateAlignedLoad(gep, BLOCK_SIZE/8, false, function->getParameter(i)->getName()->to_string());
137        mMarkerMap[function->getParameter(i)] = basisBit;
138        if (DumpTrace) {
139            genPrintRegister(function->getParameter(i)->getName()->to_string(), basisBit);
140        }
141    }
142     
143    //Generate the IR instructions for the function.
144   
145    mCarryManager->initialize(mMod, &mainScope);
146   
147    compileBlock(mainScope);
148   
149    mCarryManager->ensureCarriesStoredLocal();
150    mCarryManager->leaveScope();
151   
152   
153    mCarryManager->generateBlockNoIncrement();
154
155    if (DumpTrace) {
156        genPrintRegister("mBlockNo", mBuilder->CreateAlignedLoad(mBuilder->CreateBitCast(mCarryManager->getBlockNoPtr(), PointerType::get(mBitBlockType, 0)), BLOCK_SIZE/8, false));
157    }
158   
159    // Write the output values out
160    for (unsigned i = 0; i != function->getNumOfResults(); ++i) {
161        assert (function->getResult(i));
162        SetOutputValue(mMarkerMap[function->getResult(i)], i);
163    }
164
165    //Terminate the block
166    ReturnInst::Create(mMod->getContext(), mBuilder->GetInsertBlock());
167   
168    // Clean up
169    delete mCarryManager; mCarryManager = nullptr;
170    delete mBuilder; mBuilder = nullptr;
171    mMod = nullptr; // don't delete this. It's either owned by the ExecutionEngine or the calling function.
172
173    //Return the required size of the carry data area to the process_block function.
174    return mFunction;
175}
176
177inline void PabloCompiler::GenerateFunction(PabloFunction & function) {
178    mInputType = PointerType::get(StructType::get(mMod->getContext(), std::vector<Type *>(function.getNumOfParameters(), mBitBlockType)), 0);
179    Type * outputType = PointerType::get(StructType::get(mMod->getContext(), std::vector<Type *>(function.getNumOfResults(), mBitBlockType)), 0);
180    FunctionType * functionType = FunctionType::get(Type::getVoidTy(mMod->getContext()), {{mInputType, outputType}}, false);
181
182#ifdef USE_UADD_OVERFLOW
183#ifdef USE_TWO_UADD_OVERFLOW
184    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
185    std::vector<Type*>StructTy_0_fields;
186    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
187    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
188    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
189
190    std::vector<Type*>FuncTy_1_args;
191    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
192    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
193    FunctionType* FuncTy_1 = FunctionType::get(
194                                              /*Result=*/StructTy_0,
195                                              /*Params=*/FuncTy_1_args,
196                                              /*isVarArg=*/false);
197
198    mFunctionUaddOverflow = mMod->getFunction("llvm.uadd.with.overflow.i" +
199                                              std::to_string(BLOCK_SIZE));
200    if (!mFunctionUaddOverflow) {
201        mFunctionUaddOverflow= Function::Create(
202          /*Type=*/ FuncTy_1,
203          /*Linkage=*/ GlobalValue::ExternalLinkage,
204          /*Name=*/ "llvm.uadd.with.overflow.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
205        mFunctionUaddOverflow->setCallingConv(CallingConv::C);
206    }
207    AttributeSet mFunctionUaddOverflowPAL;
208    {
209        SmallVector<AttributeSet, 4> Attrs;
210        AttributeSet PAS;
211        {
212          AttrBuilder B;
213          B.addAttribute(Attribute::NoUnwind);
214          B.addAttribute(Attribute::ReadNone);
215          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
216        }
217
218        Attrs.push_back(PAS);
219        mFunctionUaddOverflowPAL = AttributeSet::get(mMod->getContext(), Attrs);
220    }
221    mFunctionUaddOverflow->setAttributes(mFunctionUaddOverflowPAL);
222#else
223    // Type Definitions for llvm.uadd.with.overflow.carryin.i128 or .i256
224    std::vector<Type*>StructTy_0_fields;
225    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
226    StructTy_0_fields.push_back(IntegerType::get(mMod->getContext(), 1));
227    StructType *StructTy_0 = StructType::get(mMod->getContext(), StructTy_0_fields, /*isPacked=*/false);
228
229    std::vector<Type*>FuncTy_1_args;
230    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
231    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), BLOCK_SIZE));
232    FuncTy_1_args.push_back(IntegerType::get(mMod->getContext(), 1));
233    FunctionType* FuncTy_1 = FunctionType::get(
234                                              /*Result=*/StructTy_0,
235                                              /*Params=*/FuncTy_1_args,
236                                              /*isVarArg=*/false);
237
238    mFunctionUaddOverflowCarryin = mMod->getFunction("llvm.uadd.with.overflow.carryin.i" +
239                                              std::to_string(BLOCK_SIZE));
240    if (!mFunctionUaddOverflowCarryin) {
241        mFunctionUaddOverflowCarryin = Function::Create(
242          /*Type=*/ FuncTy_1,
243          /*Linkage=*/ GlobalValue::ExternalLinkage,
244          /*Name=*/ "llvm.uadd.with.overflow.carryin.i" + std::to_string(BLOCK_SIZE), mMod); // (external, no body)
245        mFunctionUaddOverflowCarryin->setCallingConv(CallingConv::C);
246    }
247    AttributeSet mFunctionUaddOverflowCarryinPAL;
248    {
249        SmallVector<AttributeSet, 4> Attrs;
250        AttributeSet PAS;
251        {
252          AttrBuilder B;
253          B.addAttribute(Attribute::NoUnwind);
254          B.addAttribute(Attribute::ReadNone);
255          PAS = AttributeSet::get(mMod->getContext(), ~0U, B);
256        }
257
258        Attrs.push_back(PAS);
259        mFunctionUaddOverflowCarryinPAL = AttributeSet::get(mMod->getContext(), Attrs);
260    }
261    mFunctionUaddOverflowCarryin->setAttributes(mFunctionUaddOverflowCarryinPAL);
262#endif
263#endif
264
265    //Starts on process_block
266    SmallVector<AttributeSet, 3> Attrs;
267    Attrs.push_back(AttributeSet::get(mMod->getContext(), ~0U, { Attribute::NoUnwind, Attribute::UWTable }));
268    Attrs.push_back(AttributeSet::get(mMod->getContext(), 1U, { Attribute::ReadOnly, Attribute::NoCapture }));
269    Attrs.push_back(AttributeSet::get(mMod->getContext(), 2U, { Attribute::ReadNone, Attribute::NoCapture }));
270    AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
271
272    // Create the function that will be generated.
273    mFunction = Function::Create(functionType, GlobalValue::ExternalLinkage, function.getName()->value(), mMod);
274    mFunction->setCallingConv(CallingConv::C);
275    mFunction->setAttributes(AttrSet);
276
277    Function::arg_iterator args = mFunction->arg_begin();
278    mInputAddressPtr = args++;
279    mInputAddressPtr->setName("input");
280    mOutputAddressPtr = args++;
281    mOutputAddressPtr->setName("output");
282}
283
284inline void PabloCompiler::Examine(PabloFunction & function) {
285    mWhileDepth = 0;
286    mIfDepth = 0;
287    mMaxWhileDepth = 0;
288    Examine(function.getEntryBlock());
289    if (LLVM_UNLIKELY(mWhileDepth != 0 || mIfDepth != 0)) {
290        throw std::runtime_error("Malformed Pablo AST: Unbalanced If or While nesting depth!");
291    }
292}
293
294
295void PabloCompiler::Examine(PabloBlock & block) {
296    for (Statement * stmt : block) {
297        if (If * ifStatement = dyn_cast<If>(stmt)) {
298            Examine(ifStatement->getBody());
299        }
300        else if (While * whileStatement = dyn_cast<While>(stmt)) {
301            mMaxWhileDepth = std::max(mMaxWhileDepth, ++mWhileDepth);
302            Examine(whileStatement->getBody());
303            --mWhileDepth;
304        }
305    }
306}
307
308inline void PabloCompiler::DeclareDebugFunctions() {
309        //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
310        mPrintRegisterFunction = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(mMod->getContext()), Type::getInt8PtrTy(mMod->getContext()), mBitBlockType, NULL);
311}
312
313void PabloCompiler::compileBlock(PabloBlock & block) {
314    mPabloBlock = & block;
315    for (const Statement * statement : block) {
316        compileStatement(statement);
317    }
318    mPabloBlock = block.getParent();
319}
320
321    Value * PabloCompiler::genBitTest2(Value * e1, Value * e2) {
322        Type * t1 = e1->getType();
323        Type * t2 = e2->getType();
324        if (t1 == mBitBlockType) {
325            if (t2 == mBitBlockType) {
326                return iBuilder.bitblock_any(mBuilder->CreateOr(e1, e2));
327            }
328            else {
329                Value * m1 = mBuilder->CreateZExt(iBuilder.hsimd_signmask(16, e1), t2);
330                return mBuilder->CreateICmpNE(mBuilder->CreateOr(m1, e2), ConstantInt::get(t2, 0));
331            }
332        }
333        else if (t2 == mBitBlockType) {
334            Value * m2 = mBuilder->CreateZExt(iBuilder.hsimd_signmask(16, e2), t1);
335            return mBuilder->CreateICmpNE(mBuilder->CreateOr(e1, m2), ConstantInt::get(t1, 0));
336        }
337        else {
338            return mBuilder->CreateICmpNE(mBuilder->CreateOr(e1, e2), ConstantInt::get(t1, 0));
339        }
340    }
341   
342    void PabloCompiler::compileIf(const If * ifStatement) {       
343    //
344    //  The If-ElseZero stmt:
345    //  if <predicate:expr> then <body:stmt>* elsezero <defined:var>* endif
346    //  If the value of the predicate is nonzero, then determine the values of variables
347    //  <var>* by executing the given statements.  Otherwise, the value of the
348    //  variables are all zero.  Requirements: (a) no variable that is defined within
349    //  the body of the if may be accessed outside unless it is explicitly
350    //  listed in the variable list, (b) every variable in the defined list receives
351    //  a value within the body, and (c) the logical consequence of executing
352    //  the statements in the event that the predicate is zero is that the
353    //  values of all defined variables indeed work out to be 0.
354    //
355    //  Simple Implementation with Phi nodes:  a phi node in the if exit block
356    //  is inserted for each variable in the defined variable list.  It receives
357    //  a zero value from the ifentry block and the defined value from the if
358    //  body.
359    //
360
361    BasicBlock * ifEntryBlock = mBuilder->GetInsertBlock();
362    BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunction, 0);
363    BasicBlock * ifEndBlock = BasicBlock::Create(mMod->getContext(), "if.end", mFunction, 0);
364   
365    PabloBlock & ifBody = ifStatement -> getBody();
366   
367    Value * if_test_value = compileExpression(ifStatement->getCondition());
368   
369    mCarryManager->enterScope(&ifBody);
370    if (mCarryManager->blockHasCarries()) {
371        // load the summary variable
372        Value* last_if_pending_data = mCarryManager->getCarrySummaryExpr();
373        mBuilder->CreateCondBr(genBitTest2(if_test_value, last_if_pending_data), ifBodyBlock, ifEndBlock);
374
375    }
376    else {
377        mBuilder->CreateCondBr(iBuilder.bitblock_any(if_test_value), ifBodyBlock, ifEndBlock);
378    }
379    // Entry processing is complete, now handle the body of the if.
380    mBuilder->SetInsertPoint(ifBodyBlock);
381   
382    mCarryManager->initializeCarryDataAtIfEntry();
383    compileBlock(ifBody);
384    if (mCarryManager->blockHasCarries()) {
385        mCarryManager->generateCarryOutSummaryCodeIfNeeded();
386    }
387    BasicBlock * ifBodyFinalBlock = mBuilder->GetInsertBlock();
388    mCarryManager->ensureCarriesStoredLocal();
389    mBuilder->CreateBr(ifEndBlock);
390    //End Block
391    mBuilder->SetInsertPoint(ifEndBlock);
392    for (const PabloAST * node : ifStatement->getDefined()) {
393        const Assign * assign = cast<Assign>(node);
394        PHINode * phi = mBuilder->CreatePHI(mBitBlockType, 2, assign->getName()->value());
395        auto f = mMarkerMap.find(assign);
396        assert (f != mMarkerMap.end());
397        phi->addIncoming(mZeroInitializer, ifEntryBlock);
398        phi->addIncoming(f->second, ifBodyFinalBlock);
399        mMarkerMap[assign] = phi;
400    }
401    // Create the phi Node for the summary variable, if needed.
402    mCarryManager->buildCarryDataPhisAfterIfBody(ifEntryBlock, ifBodyFinalBlock);
403    mCarryManager->leaveScope();
404}
405
406void PabloCompiler::compileWhile(const While * whileStatement) {
407
408    PabloBlock & whileBody = whileStatement -> getBody();
409   
410    BasicBlock * whileEntryBlock = mBuilder->GetInsertBlock();
411    BasicBlock * whileBodyBlock = BasicBlock::Create(mMod->getContext(), "while.body", mFunction, 0);
412    BasicBlock * whileEndBlock = BasicBlock::Create(mMod->getContext(), "while.end", mFunction, 0);
413
414    mCarryManager->enterScope(&whileBody);
415    mCarryManager->ensureCarriesLoadedRecursive();
416
417    const auto & nextNodes = whileStatement->getVariants();
418    std::vector<PHINode *> nextPhis;
419    nextPhis.reserve(nextNodes.size());
420
421    // On entry to the while structure, proceed to execute the first iteration
422    // of the loop body unconditionally.   The while condition is tested at the end of
423    // the loop.
424
425    mBuilder->CreateBr(whileBodyBlock);
426    mBuilder->SetInsertPoint(whileBodyBlock);
427
428    //
429    // There are 3 sets of Phi nodes for the while loop.
430    // (1) Carry-ins: (a) incoming carry data first iterations, (b) zero thereafter
431    // (2) Carry-out accumulators: (a) zero first iteration, (b) |= carry-out of each iteration
432    // (3) Next nodes: (a) values set up before loop, (b) modified values calculated in loop.
433
434    mCarryManager->initializeCarryDataPhisAtWhileEntry(whileEntryBlock);
435
436    // for any Next nodes in the loop body, initialize to (a) pre-loop value.
437    for (const Next * n : nextNodes) {
438        PHINode * phi = mBuilder->CreatePHI(mBitBlockType, 2, n->getName()->value());
439        auto f = mMarkerMap.find(n->getInitial());
440        assert (f != mMarkerMap.end());
441        phi->addIncoming(f->second, whileEntryBlock);
442        mMarkerMap[n->getInitial()] = phi;
443        nextPhis.push_back(phi);
444    }
445
446    //
447    // Now compile the loop body proper.  Carry-out accumulated values
448    // and iterated values of Next nodes will be computed.
449    ++mWhileDepth;
450    compileBlock(whileBody);
451
452    BasicBlock * whileBodyFinalBlock = mBuilder->GetInsertBlock();
453
454    if (mCarryManager->blockHasCarries()) {
455        mCarryManager->generateCarryOutSummaryCodeIfNeeded();
456    }
457    mCarryManager->extendCarryDataPhisAtWhileBodyFinalBlock(whileBodyFinalBlock);
458
459    // Terminate the while loop body with a conditional branch back.
460    mBuilder->CreateCondBr(iBuilder.bitblock_any(compileExpression(whileStatement->getCondition())), whileBodyBlock, whileEndBlock);
461
462    // and for any Next nodes in the loop body
463    for (unsigned i = 0; i < nextNodes.size(); i++) {
464        const Next * n = nextNodes[i];
465        auto f = mMarkerMap.find(n->getExpr());
466        if (LLVM_UNLIKELY(f == mMarkerMap.end())) {
467            throw std::runtime_error("Next node expression was not compiled!");
468        }
469        nextPhis[i]->addIncoming(f->second, whileBodyFinalBlock);
470    }
471
472    mBuilder->SetInsertPoint(whileEndBlock);
473    --mWhileDepth;
474
475    mCarryManager->ensureCarriesStoredRecursive();
476    mCarryManager->leaveScope();
477}
478
479
480void PabloCompiler::compileStatement(const Statement * stmt) {
481    Value * expr = nullptr;
482    if (const Assign * assign = dyn_cast<const Assign>(stmt)) {
483        expr = compileExpression(assign->getExpression());
484    }
485    else if (const Next * next = dyn_cast<const Next>(stmt)) {
486        expr = compileExpression(next->getExpr());
487    }
488    else if (const If * ifStatement = dyn_cast<const If>(stmt)) {
489        compileIf(ifStatement);
490        return;
491    }
492    else if (const While * whileStatement = dyn_cast<const While>(stmt)) {
493        compileWhile(whileStatement);
494        return;
495    }
496    else if (const Call* call = dyn_cast<Call>(stmt)) {
497        //Call the callee once and store the result in the marker map.
498        if (mMarkerMap.count(call) != 0) {
499            return;
500        }
501
502        const Prototype * proto = call->getPrototype();
503        const String * callee = proto->getName();
504
505        Type * inputType = StructType::get(mMod->getContext(), std::vector<Type *>{proto->getNumOfParameters(), mBitBlockType});
506        Type * outputType = StructType::get(mMod->getContext(), std::vector<Type *>{proto->getNumOfResults(), mBitBlockType});
507        FunctionType * functionType = FunctionType::get(Type::getVoidTy(mMod->getContext()), std::vector<Type *>{PointerType::get(inputType, 0), PointerType::get(outputType, 0)}, false);
508
509        //Starts on process_block
510        SmallVector<AttributeSet, 3> Attrs;
511        Attrs.push_back(AttributeSet::get(mMod->getContext(), 1U, { Attribute::ReadOnly, Attribute::NoCapture }));
512        Attrs.push_back(AttributeSet::get(mMod->getContext(), 2U, { Attribute::ReadNone, Attribute::NoCapture }));
513        AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
514
515        Function * externalFunction = cast<Function>(mMod->getOrInsertFunction(callee->value(), functionType, AttrSet));
516        if (LLVM_UNLIKELY(externalFunction == nullptr)) {
517            throw std::runtime_error("Could not create static method call for external function \"" + callee->to_string() + "\"");
518        }
519        externalFunction->setCallingConv(llvm::CallingConv::C);
520
521
522        AllocaInst * outputStruct = mBuilder->CreateAlloca(outputType);
523        mBuilder->CreateCall2(externalFunction, mInputAddressPtr, outputStruct);
524        Value * outputPtr = mBuilder->CreateGEP(outputStruct, { mBuilder->getInt32(0), mBuilder->getInt32(0) });
525        expr = mBuilder->CreateAlignedLoad(outputPtr, BLOCK_SIZE / 8, false);
526    }
527    else if (const And * pablo_and = dyn_cast<And>(stmt)) {
528        expr = mBuilder->CreateAnd(compileExpression(pablo_and->getExpr1()), compileExpression(pablo_and->getExpr2()), "and");
529    }
530    else if (const Or * pablo_or = dyn_cast<Or>(stmt)) {
531        expr = mBuilder->CreateOr(compileExpression(pablo_or->getExpr1()), compileExpression(pablo_or->getExpr2()), "or");
532    }
533    else if (const Xor * pablo_xor = dyn_cast<Xor>(stmt)) {
534        expr = mBuilder->CreateXor(compileExpression(pablo_xor->getExpr1()), compileExpression(pablo_xor->getExpr2()), "xor");
535    }
536    else if (const Sel * sel = dyn_cast<Sel>(stmt)) {
537        Value* ifMask = compileExpression(sel->getCondition());
538        Value* ifTrue = mBuilder->CreateAnd(ifMask, compileExpression(sel->getTrueExpr()));
539        Value* ifFalse = mBuilder->CreateAnd(genNot(ifMask), compileExpression(sel->getFalseExpr()));
540        expr = mBuilder->CreateOr(ifTrue, ifFalse);
541    }
542    else if (const Not * pablo_not = dyn_cast<Not>(stmt)) {
543        expr = genNot(compileExpression(pablo_not->getExpr()));
544    }
545    else if (const Advance * adv = dyn_cast<Advance>(stmt)) {
546        Value* strm_value = compileExpression(adv->getExpr());
547        int shift = adv->getAdvanceAmount();
548        unsigned advance_index = adv->getLocalAdvanceIndex();
549        expr = mCarryManager->advanceCarryInCarryOut(advance_index, shift, strm_value);
550    }
551    else if (const Mod64Advance * adv = dyn_cast<Mod64Advance>(stmt)) {
552        Value* strm_value = compileExpression(adv->getExpr());
553        int shift = adv->getAdvanceAmount();
554        expr = iBuilder.simd_slli(64, strm_value, shift);
555    }
556    else if (const MatchStar * mstar = dyn_cast<MatchStar>(stmt)) {
557        Value * marker = compileExpression(mstar->getMarker());
558        Value * cc = compileExpression(mstar->getCharClass());
559        Value * marker_and_cc = mBuilder->CreateAnd(marker, cc);
560        unsigned carry_index = mstar->getLocalCarryIndex();
561        Value * sum = mCarryManager->addCarryInCarryOut(carry_index, marker_and_cc, cc);
562        expr = mBuilder->CreateOr(mBuilder->CreateXor(sum, cc), marker, "matchstar");
563    }
564    else if (const Mod64MatchStar * mstar = dyn_cast<Mod64MatchStar>(stmt)) {
565        Value * marker = compileExpression(mstar->getMarker());
566        Value * cc = compileExpression(mstar->getCharClass());
567        Value * marker_and_cc = mBuilder->CreateAnd(marker, cc);
568        Value * sum = iBuilder.simd_add(64, marker_and_cc, cc);
569        expr = mBuilder->CreateOr(mBuilder->CreateXor(sum, cc), marker, "matchstar64");
570    }
571    else if (const ScanThru * sthru = dyn_cast<ScanThru>(stmt)) {
572        Value * marker_expr = compileExpression(sthru->getScanFrom());
573        Value * cc_expr = compileExpression(sthru->getScanThru());
574        unsigned carry_index = sthru->getLocalCarryIndex();
575        Value * sum = mCarryManager->addCarryInCarryOut(carry_index, marker_expr, cc_expr);
576        expr = mBuilder->CreateAnd(sum, genNot(cc_expr), "scanthru");
577    }
578    else if (const Mod64ScanThru * sthru = dyn_cast<Mod64ScanThru>(stmt)) {
579        Value * marker_expr = compileExpression(sthru->getScanFrom());
580        Value * cc_expr = compileExpression(sthru->getScanThru());
581        Value * sum = iBuilder.simd_add(64, marker_expr, cc_expr);
582        expr = mBuilder->CreateAnd(sum, genNot(cc_expr), "scanthru64");
583    }
584    else if (const Count * c = dyn_cast<Count>(stmt)) {
585        unsigned count_index = c->getGlobalCountIndex();
586        Value * to_count = compileExpression(c->getExpr());
587        expr = mCarryManager->popCount(to_count, count_index);
588    }
589    else {
590        llvm::raw_os_ostream cerr(std::cerr);
591        PabloPrinter::print(stmt, cerr);
592        throw std::runtime_error("Unrecognized Pablo Statement! can't compile.");
593    }
594    mMarkerMap[stmt] = expr;
595    if (DumpTrace) {
596        genPrintRegister(stmt->getName()->to_string(), expr);
597    }
598   
599}
600
601Value * PabloCompiler::compileExpression(const PabloAST * expr) {
602    if (isa<Ones>(expr)) {
603        return mOneInitializer;
604    }
605    else if (isa<Zeroes>(expr)) {
606        return mZeroInitializer;
607    }
608    auto f = mMarkerMap.find(expr);
609    if (LLVM_UNLIKELY(f == mMarkerMap.end())) {
610        std::string o;
611        llvm::raw_string_ostream str(o);
612        str << "\"";
613        PabloPrinter::print(expr, str);
614        str << "\" was used before definition!";
615        throw std::runtime_error(str.str());
616    }
617    return f->second;
618}
619
620
#ifdef USE_UADD_OVERFLOW
#ifdef USE_TWO_UADD_OVERFLOW
// Emits a call to mFunctionUaddOverflow on two full-block integers and unpacks the
// returned {sum, overflow bit} aggregate.
//   int128_e1, int128_e2 - the two addends.
// Returns the sum (field 0) and the overflow bit (field 1) of the call result.
PabloCompiler::SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2) {
    std::vector<Value*> operands{int128_e1, int128_e2};
    CallInst* const result = CallInst::Create(mFunctionUaddOverflow, operands, "uadd_overflow_res", mBasicBlock);
    result->setCallingConv(CallingConv::C);
    result->setTailCall(false);
    result->setAttributes(AttributeSet());

    const unsigned sumField[] = {0};
    const unsigned overflowField[] = {1};

    SumWithOverflowPack pack;
    pack.sum = ExtractValueInst::Create(result, sumField, "sum", mBasicBlock);
    pack.obit = ExtractValueInst::Create(result, overflowField, "obit", mBasicBlock);
    return pack;
}
#else
// Emits a call to mFunctionUaddOverflowCarryin on two full-block integers plus a
// one-bit carry-in and unpacks the returned {sum, overflow bit} aggregate.
//   int128_e1, int128_e2 - the two addends.
//   int1_cin             - the carry-in bit.
// Returns the sum (field 0) and the overflow bit (field 1) of the call result.
PabloCompiler::SumWithOverflowPack PabloCompiler::callUaddOverflow(Value* int128_e1, Value* int128_e2, Value* int1_cin) {
    std::vector<Value*> operands{int128_e1, int128_e2, int1_cin};
    CallInst* const result = CallInst::Create(mFunctionUaddOverflowCarryin, operands, "uadd_overflow_res", mBasicBlock);
    result->setCallingConv(CallingConv::C);
    result->setTailCall(false);
    result->setAttributes(AttributeSet());

    const unsigned sumField[] = {0};
    const unsigned overflowField[] = {1};

    SumWithOverflowPack pack;
    pack.sum = ExtractValueInst::Create(result, sumField, "sum", mBasicBlock);
    pack.obit = ExtractValueInst::Create(result, overflowField, "obit", mBasicBlock);
    return pack;
}
#endif
#endif
671
672
673Value* PabloCompiler::genAddWithCarry(Value* e1, Value* e2, unsigned localIndex) {
674    Value * carryq_value = mCarryManager->getCarryOpCarryIn(localIndex);
675#ifdef USE_TWO_UADD_OVERFLOW
676    //This is the ideal implementation, which uses two uadd.with.overflow
677    //The back end should be able to recognize this pattern and combine it into uadd.with.overflow.carryin
678    CastInst* int128_e1 = new BitCastInst(e1, mBuilder->getIntNTy(BLOCK_SIZE), "e1_128", mBasicBlock);
679    CastInst* int128_e2 = new BitCastInst(e2, mBuilder->getIntNTy(BLOCK_SIZE), "e2_128", mBasicBlock);
680    CastInst* int128_carryq_value = new BitCastInst(carryq_value, mBuilder->getIntNTy(BLOCK_SIZE), "carryq_128", mBasicBlock);
681
682    SumWithOverflowPack sumpack0, sumpack1;
683
684    sumpack0 = callUaddOverflow(int128_e1, int128_e2);
685    sumpack1 = callUaddOverflow(sumpack0.sum, int128_carryq_value);
686
687    Value* obit = mBuilder->CreateOr(sumpack0.obit, sumpack1.obit, "carry_bit");
688    Value* sum = mBuilder->CreateBitCast(sumpack1.sum, mBitBlockType, "ret_sum");
689
690    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
691    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mBitBlockType);
692    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
693    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
694    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
695
696#elif defined USE_UADD_OVERFLOW
697    //use llvm.uadd.with.overflow.i128 or i256
698    CastInst* int128_e1 = new BitCastInst(e1, mBuilder->getIntNTy(BLOCK_SIZE), "e1_128", mBasicBlock);
699    CastInst* int128_e2 = new BitCastInst(e2, mBuilder->getIntNTy(BLOCK_SIZE), "e2_128", mBasicBlock);
700
701    //get i1 carryin from iBLOCK_SIZE
702    ConstantInt* const_int32_6 = ConstantInt::get(mMod->getContext(), APInt(32, StringRef("0"), 10));
703    ExtractElementInst * int64_carryq_value = ExtractElementInst::Create(carryq_value, const_int32_6, "carryq_64", mBasicBlock);
704    CastInst* int1_carryq_value = new TruncInst(int64_carryq_value, IntegerType::get(mMod->getContext(), 1), "carryq_1", mBasicBlock);
705
706    SumWithOverflowPack sumpack0;
707    sumpack0 = callUaddOverflow(int128_e1, int128_e2, int1_carryq_value);
708    Value* obit = sumpack0.obit;
709    Value* sum = mBuilder->CreateBitCast(sumpack0.sum, mBitBlockType, "sum");
710
711    /*obit is the i1 carryout, zero extend and insert it into a v2i64 or v4i64 vector.*/
712    ConstantAggregateZero* const_packed_5 = ConstantAggregateZero::get(mBitBlockType);
713    CastInst* int64_o0 = new ZExtInst(obit, IntegerType::get(mMod->getContext(), 64), "o0", mBasicBlock);
714    InsertElementInst* carry_out = InsertElementInst::Create(const_packed_5, int64_o0, const_int32_6, "carry_out", mBasicBlock);
715#elif (BLOCK_SIZE == 128)
716    //calculate carry through logical ops
717    Value* carrygen = mBuilder->CreateAnd(e1, e2, "carrygen");
718    Value* carryprop = mBuilder->CreateOr(e1, e2, "carryprop");
719    Value* digitsum = mBuilder->CreateAdd(e1, e2, "digitsum");
720    Value* partial = mBuilder->CreateAdd(digitsum, carryq_value, "partial");
721    Value* digitcarry = mBuilder->CreateOr(carrygen, mBuilder->CreateAnd(carryprop, genNot(partial)));
722    Value* mid_carry_in = genShiftLeft64(mBuilder->CreateLShr(digitcarry, 63), "mid_carry_in");
723
724    Value* sum = mBuilder->CreateAdd(partial, mid_carry_in, "sum");
725    Value* carry_out = genShiftHighbitToLow(BLOCK_SIZE, mBuilder->CreateOr(carrygen, mBuilder->CreateAnd(carryprop, genNot(sum))));
726#else
727    //BLOCK_SIZE == 256, there is no other implementation
728    static_assert(false, "Add with carry for 256-bit bitblock requires USE_UADD_OVERFLOW");
729#endif //USE_TWO_UADD_OVERFLOW
730
731    mCarryManager->setCarryOpCarryOut(localIndex, carry_out);
732    return sum;
733}
734
735Value * PabloCompiler::genShiftHighbitToLow(unsigned FieldWidth, Value * op) {
736    unsigned FieldCount = BLOCK_SIZE/FieldWidth;
737    VectorType * vType = VectorType::get(IntegerType::get(mMod->getContext(), FieldWidth), FieldCount);
738    Value * v = mBuilder->CreateBitCast(op, vType);
739    return mBuilder->CreateBitCast(mBuilder->CreateLShr(v, FieldWidth - 1), mBitBlockType);
740}
741
742Value* PabloCompiler::genShiftLeft64(Value* e, const Twine &namehint) {
743    Value* i128_val = mBuilder->CreateBitCast(e, mBuilder->getIntNTy(BLOCK_SIZE));
744    return mBuilder->CreateBitCast(mBuilder->CreateShl(i128_val, 64, namehint), mBitBlockType);
745}
746
747inline Value* PabloCompiler::genNot(Value* expr) {
748    return mBuilder->CreateXor(expr, mOneInitializer, "not");
749}
750   
751void PabloCompiler::SetOutputValue(Value * marker, const unsigned index) {
752    if (LLVM_UNLIKELY(marker == nullptr)) {
753        throw std::runtime_error("Cannot set result " + std::to_string(index) + " to Null");
754    }
755    if (LLVM_UNLIKELY(marker->getType()->isPointerTy())) {
756        marker = mBuilder->CreateAlignedLoad(marker, BLOCK_SIZE/8, false);
757    }
758    Value* indices[] = {mBuilder->getInt64(0), mBuilder->getInt32(index)};
759    Value* gep = mBuilder->CreateGEP(mOutputAddressPtr, indices);
760    mBuilder->CreateAlignedStore(marker, gep, BLOCK_SIZE/8, false);
761}
762
763}
Note: See TracBrowser for help on using the repository browser.