Ignore:
Timestamp:
Jul 13, 2015, 3:55:59 PM (4 years ago)
Author:
nmedfort
Message:

Made pablo compiler reenterant through alternate compile method that takes a Module parameter.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp

    r4663 r4665  
    6767
    6868PabloCompiler::PabloCompiler()
    69 #ifdef USE_LLVM_3_5
    70 : mMod(new Module("icgrep", getGlobalContext()))
    71 #else
    72 : mModOwner(make_unique<Module>("icgrep", getGlobalContext()))
    73 , mMod(mModOwner.get())
    74 #endif
    75 , mBuilder(&LLVM_Builder)
     69: mMod(nullptr)
     70, mBuilder(nullptr)
    7671, mCarryManager(nullptr)
    77 , mBitBlockType(VectorType::get(IntegerType::get(mMod->getContext(), 64), BLOCK_SIZE / 64))
    78 , iBuilder(mMod, mBuilder, mBitBlockType)
    79 , mInputPtr(nullptr)
     72, mBitBlockType(VectorType::get(IntegerType::get(getGlobalContext(), 64), BLOCK_SIZE / 64))
     73, iBuilder(mBitBlockType)
     74, mInputType(nullptr)
    8075, mCarryDataPtr(nullptr)
    8176, mWhileDepth(0)
     
    9186}
    9287
    93 PabloCompiler::~PabloCompiler()
    94 {
    95 
     88PabloCompiler::~PabloCompiler() {
    9689}
    9790   
     
    10295void PabloCompiler::genPrintRegister(std::string regName, Value * bitblockValue) {
    10396    Constant * regNameData = ConstantDataArray::getString(mMod->getContext(), regName);
    104     GlobalVariable *regStrVar = new GlobalVariable(*mMod, 
     97    GlobalVariable *regStrVar = new GlobalVariable(*mMod,
    10598                                                   ArrayType::get(IntegerType::get(mMod->getContext(), 8), regName.length()+1),
    10699                                                   /*isConstant=*/ true,
     
    112105
    113106CompiledPabloFunction PabloCompiler::compile(PabloFunction & function) {
    114     mWhileDepth = 0;
    115     mIfDepth = 0;
    116     mMaxWhileDepth = 0;
    117     mCarryManager = new CarryManager(mBuilder, mBitBlockType, mZeroInitializer, mOneInitializer, &iBuilder);
    118 
    119     Examine(function.getEntryBlock());
    120    
     107
     108    Examine(function);
     109
    121110    InitializeNativeTarget();
    122111    InitializeNativeTargetAsmPrinter();
    123112    InitializeNativeTargetAsmParser();
    124113
     114    Module * module = new Module("", getGlobalContext());
     115
     116    mMod = module;
     117
    125118    std::string errMessage;
    126 #ifdef USE_LLVM_3_5
     119    #ifdef USE_LLVM_3_5
    127120    EngineBuilder builder(mMod);
    128 #else
    129     EngineBuilder builder(std::move(mModOwner));
    130 #endif
     121    #else
     122    EngineBuilder builder(std::move(std::unique_ptr<Module>(mMod)));
     123    #endif
    131124    builder.setErrorStr(&errMessage);
    132125    builder.setMCPU(sys::getHostCPUName());
    133 #ifdef USE_LLVM_3_5
     126    #ifdef USE_LLVM_3_5
    134127    builder.setUseMCJIT(true);
    135 #endif
     128    #endif
    136129    builder.setOptLevel(mMaxWhileDepth ? CodeGenOpt::Level::Less : CodeGenOpt::Level::None);
    137     ExecutionEngine * ee = builder.create();
    138     if (ee == nullptr) {
     130    ExecutionEngine * engine = builder.create();
     131    if (engine == nullptr) {
    139132        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
    140133    }
     134    DeclareFunctions(engine);
     135    DeclareCallFunctions(function, engine);
     136
     137    auto func = compile(function, mMod);
     138
     139    //Display the IR that has been generated by this module.
     140    if (LLVM_UNLIKELY(DumpGeneratedIR)) {
     141        module->dump();
     142    }
     143    //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
     144    verifyModule(*module, &dbgs());
     145
     146    engine->finalizeObject();
     147
     148    return CompiledPabloFunction(func.second, func.first, engine);
     149}
     150
     151std::pair<llvm::Function *, size_t> PabloCompiler::compile(PabloFunction & function, Module * module) {
     152
     153    Examine(function);
     154
     155    mMod = module;
     156
     157    mBuilder = new IRBuilder<>(mMod->getContext());
     158
     159    iBuilder.initialize(mMod, mBuilder);
     160
     161    mCarryManager = new CarryManager(mBuilder, mBitBlockType, mZeroInitializer, mOneInitializer, &iBuilder);
    141162
    142163    GenerateFunction(function);
    143     DeclareFunctions(ee);
    144     DeclareCallFunctions(ee);
    145 
    146     mWhileDepth = 0;
    147     mIfDepth = 0;
    148     mMaxWhileDepth = 0;
    149     BasicBlock * b = BasicBlock::Create(mMod->getContext(), "entry", mFunction,0);
    150     mBuilder->SetInsertPoint(b);
     164
     165    mBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", mFunction,0));
    151166
    152167    //The basis bits structure
     
    172187    }
    173188   
    174     if (LLVM_UNLIKELY(mWhileDepth != 0)) {
    175         throw std::runtime_error("Non-zero nesting depth error (" + std::to_string(mWhileDepth) + ")");
    176     }
    177 
    178189    // Write the output values out
    179190    for (unsigned i = 0; i != function.getResults().size(); ++i) {
     
    184195    ReturnInst::Create(mMod->getContext(), mBuilder->GetInsertBlock());
    185196
    186     //Display the IR that has been generated by this module.
    187     if (LLVM_UNLIKELY(DumpGeneratedIR)) {
    188         mMod->dump();
    189     }
    190     //Create a verifier.  The verifier will print an error message if our module is malformed in any way.
    191     verifyModule(*mMod, &dbgs());
    192 
    193     ee->finalizeObject();
    194 
    195     delete mCarryManager;
    196     mCarryManager = nullptr;
     197    // Clean up
     198    delete mCarryManager; mCarryManager = nullptr;
     199    delete mBuilder; mBuilder = nullptr;
     200    mMod = nullptr; // don't delete this. It's either owned by the ExecutionEngine or the calling function.
    197201
    198202    //Return the required size of the carry data area to the process_block function.
    199     return CompiledPabloFunction(totalCarryDataSize * sizeof(BitBlock), mFunction, ee);
     203    return std::make_pair(mFunction, totalCarryDataSize * sizeof(BitBlock));
    200204}
    201205
    202206inline void PabloCompiler::GenerateFunction(PabloFunction & function) {
    203     std::vector<Type *> inputType(function.getParameters().size(), mBitBlockType);
    204     std::vector<Type *> outputType(function.getResults().size(), mBitBlockType);
    205     mInputPtr = PointerType::get(StructType::get(mMod->getContext(), inputType), 0);
    206     Type * carryPtr = PointerType::get(mBitBlockType, 0);
    207     Type * outputPtr = PointerType::get(StructType::get(mMod->getContext(), outputType), 0);
    208     FunctionType * functionType = FunctionType::get(Type::getVoidTy(mMod->getContext()), {{mInputPtr, carryPtr, outputPtr}}, false);
    209 
     207    mInputType = PointerType::get(StructType::get(mMod->getContext(), std::vector<Type *>(function.getParameters().size(), mBitBlockType)), 0);
     208    Type * carryType = PointerType::get(mBitBlockType, 0);
     209    Type * outputType = PointerType::get(StructType::get(mMod->getContext(), std::vector<Type *>(function.getResults().size(), mBitBlockType)), 0);
     210    FunctionType * functionType = FunctionType::get(Type::getVoidTy(mMod->getContext()), {{mInputType, carryType, outputType}}, false);
    210211
    211212#ifdef USE_UADD_OVERFLOW
     
    314315}
    315316
    316 inline void PabloCompiler::DeclareFunctions(ExecutionEngine * ee) {
    317     if (DumpTrace || TraceNext) {
    318         //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
    319         mPrintRegisterFunction = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(getGlobalContext()), Type::getInt8PtrTy(getGlobalContext()), mBitBlockType, NULL);
    320         ee->addGlobalMapping(cast<GlobalValue>(mPrintRegisterFunction), (void *)&wrapped_print_register);
    321     }
    322 }
    323    
    324 void PabloCompiler::Examine(PabloBlock & blk) {
    325     for (Statement * stmt : blk) {
     317inline void PabloCompiler::Examine(PabloFunction & function) {
     318    if (mMod == nullptr) {
     319
     320        mWhileDepth = 0;
     321        mIfDepth = 0;
     322        mMaxWhileDepth = 0;
     323
     324        Examine(function.getEntryBlock());
     325
     326        if (LLVM_UNLIKELY(mWhileDepth != 0 || mIfDepth != 0)) {
     327            throw std::runtime_error("Malformed Pablo AST: Unbalanced If or While nesting depth!");
     328        }
     329    }
     330}
     331
     332
     333void PabloCompiler::Examine(PabloBlock & block) {
     334    for (Statement * stmt : block) {
    326335        if (Call * call = dyn_cast<Call>(stmt)) {
    327336            mCalleeMap.insert(std::make_pair(call->getCallee(), nullptr));
    328337        }
    329338        else if (If * ifStatement = dyn_cast<If>(stmt)) {
    330             ++mIfDepth;
    331339            Examine(ifStatement->getBody());
    332             --mIfDepth;
    333340        }
    334341        else if (While * whileStatement = dyn_cast<While>(stmt)) {
     
    340347}
    341348
    342 void PabloCompiler::DeclareCallFunctions(ExecutionEngine * ee) {
     349inline void PabloCompiler::DeclareFunctions(ExecutionEngine * engine) {
     350    if (DumpTrace || TraceNext) {
     351        //This function can be used for testing to print the contents of a register from JIT'd code to the terminal window.
     352        mPrintRegisterFunction = mMod->getOrInsertFunction("wrapped_print_register", Type::getVoidTy(mMod->getContext()), Type::getInt8PtrTy(mMod->getContext()), mBitBlockType, NULL);
     353        engine->addGlobalMapping(cast<GlobalValue>(mPrintRegisterFunction), (void *)&wrapped_print_register);
     354    }
     355}
     356   
     357void PabloCompiler::DeclareCallFunctions(PabloFunction & function, ExecutionEngine * engine) {
    343358    for (auto mapping : mCalleeMap) {
    344359        const String * callee = mapping.first;
    345         //std::cerr << callee->str() << " to be declared\n";
    346360        auto ei = mExternalMap.find(callee->value());
    347361        if (ei != mExternalMap.end()) {
    348             void * fn_ptr = ei->second;
    349             Value * externalValue = mMod->getOrInsertFunction(callee->value(), mBitBlockType, mInputPtr, NULL);
     362
     363            PointerType * inputType = PointerType::get(StructType::get(mMod->getContext(), std::vector<Type *>(function.getParameters().size(), mBitBlockType)), 0);
     364            ArrayRef<Type*> args = {inputType};
     365            FunctionType * functionType = FunctionType::get(mBitBlockType, args, false);
     366
     367            SmallVector<AttributeSet, 4> Attrs;
     368            Attrs.push_back(AttributeSet::get(mMod->getContext(), ~0U, { Attribute::NoUnwind, Attribute::UWTable }));
     369            Attrs.push_back(AttributeSet::get(mMod->getContext(), 1U, { Attribute::ReadOnly, Attribute::NoCapture }));
     370            AttributeSet AttrSet = AttributeSet::get(mMod->getContext(), Attrs);
     371
     372            Value * externalValue = mMod->getOrInsertFunction(callee->value(), functionType, AttrSet);
     373
    350374            if (LLVM_UNLIKELY(externalValue == nullptr)) {
    351375                throw std::runtime_error("Could not create static method call for external function \"" + callee->to_string() + "\"");
    352376            }
    353             ee->addGlobalMapping(cast<GlobalValue>(externalValue), fn_ptr);
     377            engine->addGlobalMapping(cast<GlobalValue>(externalValue), ei->second);
    354378            mCalleeMap[callee] = externalValue;
    355379        }
     
    409433    //  body.
    410434    //
     435
    411436    BasicBlock * ifEntryBlock = mBuilder->GetInsertBlock();
    412437    BasicBlock * ifBodyBlock = BasicBlock::Create(mMod->getContext(), "if.body", mFunction, 0);
Note: See TracChangeset for help on using the changeset viewer.