Ignore:
Timestamp:
Jul 10, 2015, 4:51:39 PM (4 years ago)
Author:
nmedfort
Message:

Initial introduction of a PabloFunction? type.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/pablo/pablo_compiler.cpp

    r4654 r4657  
    1010#include <pablo/carry_manager.h>
    1111#include <pablo/printer_pablos.h>
     12#include <pablo/function.h>
    1213#include <cc/cc_namemap.hpp>
    1314#include <re/re_name.h>
     
    6465namespace pablo {
    6566
    66 PabloCompiler::PabloCompiler(const std::vector<Var*> & basisBits)
    67 : mBasisBits(basisBits)
     67PabloCompiler::PabloCompiler()
    6868#ifdef USE_LLVM_3_5
    69 , mMod(new Module("icgrep", getGlobalContext()))
     69: mMod(new Module("icgrep", getGlobalContext()))
    7070#else
    71 , mModOwner(make_unique<Module>("icgrep", getGlobalContext()))
     71: mModOwner(make_unique<Module>("icgrep", getGlobalContext()))
    7272, mMod(mModOwner.get())
    7373#endif
     
    8484, mFunctionType(nullptr)
    8585, mFunction(nullptr)
    86 , mBasisBitsAddr(nullptr)
     86, mParameterAddr(nullptr)
    8787, mOutputAddrPtr(nullptr)
    8888, mMaxWhileDepth(0)
     
    9393    InitializeNativeTargetAsmPrinter();
    9494    InitializeNativeTargetAsmParser();
    95     DefineTypes();
    9695}
    9796
     
    116115}
    117116
    118 CompiledPabloFunction PabloCompiler::compile(PabloBlock & pb)
     117CompiledPabloFunction PabloCompiler::compile(PabloFunction & function)
    119118{
    120119    mWhileDepth = 0;
     
    139138        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
    140139    }
     140
     141    DefineTypes(function);
    141142    DeclareFunctions();
    142143
    143     Examine(pb);
     144    Examine(function.getEntryBlock());
    144145    DeclareCallFunctions();
    145146
    146147    Function::arg_iterator args = mFunction->arg_begin();
    147     mBasisBitsAddr = args++;
    148     mBasisBitsAddr->setName("basis_bits");
     148    mParameterAddr = args++;
     149    mParameterAddr->setName("basis_bits");
    149150    mCarryDataPtr = args++;
    150151    mCarryDataPtr->setName("carry_data");
     
    159160
    160161    //The basis bits structure
    161     for (unsigned i = 0; i != mBasisBits.size(); ++i) {
     162
     163    for (unsigned i = 0; i != function.getParameters().size(); ++i) {
    162164        Value* indices[] = {mBuilder->getInt64(0), mBuilder->getInt32(i)};
    163         Value * gep = mBuilder->CreateGEP(mBasisBitsAddr, indices);
    164         LoadInst * basisBit = mBuilder->CreateAlignedLoad(gep, BLOCK_SIZE/8, false, mBasisBits[i]->getName()->to_string());
    165         mMarkerMap.insert(std::make_pair(mBasisBits[i], basisBit));
     165        Value * gep = mBuilder->CreateGEP(mParameterAddr, indices);
     166        LoadInst * basisBit = mBuilder->CreateAlignedLoad(gep, BLOCK_SIZE/8, false, function.getParameter(i)->getName()->to_string());
     167        mMarkerMap.insert(std::make_pair(function.getParameter(i), basisBit));
    166168    }
    167169       
    168     unsigned totalCarryDataSize = mCarryManager->initialize(&pb, mCarryDataPtr);
     170    unsigned totalCarryDataSize = mCarryManager->initialize(&(function.getEntryBlock()), mCarryDataPtr);
    169171   
    170172    //Generate the IR instructions for the function.
    171     compileBlock(pb);
     173    compileBlock(function.getEntryBlock());
    172174   
    173175    mCarryManager->generateBlockNoIncrement();
     
    179181    if (LLVM_UNLIKELY(mWhileDepth != 0)) {
    180182        throw std::runtime_error("Non-zero nesting depth error (" + std::to_string(mWhileDepth) + ")");
     183    }
     184
     185    // Write the output values out
     186    for (unsigned i = 0; i != function.getResults().size(); ++i) {
     187        SetOutputValue(mMarkerMap[function.getResult(i)], i);
    181188    }
    182189
     
    193200    mExecutionEngine->finalizeObject();
    194201
     202    delete mCarryManager;
     203    mCarryManager = nullptr;
     204
    195205    //Return the required size of the carry data area to the process_block function.
    196206    return CompiledPabloFunction(totalCarryDataSize * sizeof(BitBlock), mFunction, mExecutionEngine);
    197207}
    198208
    199 void PabloCompiler::DefineTypes()
    200 {
     209void PabloCompiler::DefineTypes(PabloFunction & function) {
     210
    201211    StructType * structBasisBits = mMod->getTypeByName("struct.Basis_bits");
    202212    if (structBasisBits == nullptr) {
    203213        structBasisBits = StructType::create(mMod->getContext(), "struct.Basis_bits");
    204214    }
    205     std::vector<Type*>StructTy_struct_Basis_bits_fields;
    206     for (int i = 0; i != mBasisBits.size(); i++)
    207     {
     215    std::vector<Type*> StructTy_struct_Basis_bits_fields;
     216    for (int i = 0; i != function.getParameters().size(); i++) {
    208217        StructTy_struct_Basis_bits_fields.push_back(mBitBlockType);
    209218    }
     
    227236    if (outputStruct->isOpaque()) {
    228237        std::vector<Type*>fields;
    229         fields.push_back(mBitBlockType);
    230         fields.push_back(mBitBlockType);
     238        for (int i = 0; i != function.getResults().size(); i++) {
     239            fields.push_back(mBitBlockType);
     240        }
    231241        outputStruct->setBody(fields, /*isPacked=*/false);
    232242    }
    233     PointerType* outputStructPtr = PointerType::get(outputStruct, 0);
     243    PointerType * outputStructPtr = PointerType::get(outputStruct, 0);
    234244
    235245    //The &output parameter.
     
    555565    Value * expr = nullptr;
    556566    if (const Assign * assign = dyn_cast<const Assign>(stmt)) {
    557         expr = compileExpression(assign->getExpr());
    558         if (LLVM_UNLIKELY(assign->isOutputAssignment())) {
    559             SetOutputValue(expr, assign->getOutputIndex());
    560         }
     567        expr = compileExpression(assign->getExpression());
    561568    }
    562569    else if (const Next * next = dyn_cast<const Next>(stmt)) {
     
    583590            throw std::runtime_error("Unexpected error locating static function for \"" + call->getCallee()->to_string() + "\"");
    584591        }
    585         expr = mBuilder->CreateCall(ci->second, mBasisBitsAddr);
     592        expr = mBuilder->CreateCall(ci->second, mParameterAddr);
    586593    }
    587594    else if (const And * pablo_and = dyn_cast<And>(stmt)) {
Note: See TracChangeset for help on using the changeset viewer.