Ignore:
Timestamp:
Jun 14, 2016, 2:31:21 PM (3 years ago)
Author:
cameron
Message:

Support for KernelInterface? instances

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/kernels/interface.cpp

    r5050 r5053  
    1919                                 std::vector<ScalarBinding> scalar_parameters,
    2020                                 std::vector<ScalarBinding> scalar_outputs,
    21                                  std::vector<ScalarBinding> internal_scalars) {
    22     iBuilder = builder;
    23     mKernelName = kernelName;
    24     mStreamSetInputs = stream_inputs;
    25     mStreamSetOutputs = stream_outputs;
    26     mScalarInputs = scalar_parameters;
    27     mScalarOutputs = scalar_outputs;
    28     mInternalScalars = internal_scalars;
    29     std::vector<Type *> kernelFields;
     21                                 std::vector<ScalarBinding> internal_scalars) :
     22                iBuilder(builder),
     23                mKernelName(kernelName),
     24                mStreamSetInputs(stream_inputs),
     25                mStreamSetOutputs(stream_outputs),
     26                mScalarInputs(scalar_parameters),
     27                mScalarOutputs(scalar_outputs),
     28                mInternalScalars(internal_scalars),
     29                mKernelStateType(nullptr) {
     30   
    3031    for (auto binding : scalar_parameters) {
    31         unsigned index = kernelFields.size();
    32         kernelFields.push_back(binding.scalarType);
    33         mInternalStateNameMap.emplace(binding.scalarName, iBuilder->getInt32(index));
     32        addScalar(binding.scalarType, binding.scalarName);
    3433    }
    3534    for (auto binding : scalar_outputs) {
    36         unsigned index = kernelFields.size();
    37         kernelFields.push_back(binding.scalarType);
    38         mInternalStateNameMap.emplace(binding.scalarName, iBuilder->getInt32(index));
     35        addScalar(binding.scalarType, binding.scalarName);
    3936    }
    4037    for (auto binding : internal_scalars) {
    41         unsigned index = kernelFields.size();
    42         kernelFields.push_back(binding.scalarType);
    43         mInternalStateNameMap.emplace(binding.scalarName, iBuilder->getInt32(index));
    44     }
    45     mKernelStateType = StructType::create(getGlobalContext(), kernelFields, kernelName);
     38        addScalar(binding.scalarType, binding.scalarName);
     39    }
     40}
     41
     42const std::string init_suffix = "_Init";
     43const std::string doBlock_suffix = "_DoBlock";
     44const std::string finalBlock_suffix = "_FinalBlock";
     45const std::string accumulator_infix = "_get_";
     46
     47
     48void KernelInterface::addScalar(Type * t, std::string scalarName) {
     49    if (LLVM_UNLIKELY(mKernelStateType != nullptr)) {
     50        throw std::runtime_error("Illegal addition of kernel field after kernel state finalized: " + scalarName);
     51    }
     52    unsigned index = mKernelFields.size();
     53    mKernelFields.push_back(t);
     54    mInternalStateNameMap.emplace(scalarName, iBuilder->getInt32(index));
     55}
     56
     57void KernelInterface::finalizeKernelStateType() {
     58    mKernelStateType = StructType::create(getGlobalContext(), mKernelFields, mKernelName);
    4659}
    4760
    4861void KernelInterface::addKernelDeclarations(Module * client) {
    49    
     62    finalizeKernelStateType();
    5063    Type * selfType = PointerType::getUnqual(mKernelStateType);
    5164    // Create the accumulator get function prototypes
    5265    for (auto binding : mScalarOutputs) {
    5366        FunctionType * accumFnType = FunctionType::get(binding.scalarType, {selfType}, false);
    54         Function * accumFn = Function::Create(accumFnType, GlobalValue::ExternalLinkage, mKernelName + "_get_" + binding.scalarName, client);
     67        std::string fnName = mKernelName + accumulator_infix + binding.scalarName;
     68        Function * accumFn = Function::Create(accumFnType, GlobalValue::ExternalLinkage, fnName, client);
    5569        accumFn->setCallingConv(CallingConv::C);
    5670        accumFn->setDoesNotThrow();
     
    6579    }
    6680    FunctionType * initFunctionType = FunctionType::get(iBuilder->getVoidTy(), initParameters, false);
    67     Function * initFn = Function::Create(initFunctionType, GlobalValue::ExternalLinkage, mKernelName + "_Init", client);
     81    std::string initFnName = mKernelName + init_suffix;
     82    Function * initFn = Function::Create(initFunctionType, GlobalValue::ExternalLinkage, initFnName, client);
    6883    initFn->setCallingConv(CallingConv::C);
    6984    initFn->setDoesNotThrow();
     
    91106    }
    92107    FunctionType * doBlockFunctionType = FunctionType::get(iBuilder->getVoidTy(), doBlockParameters, false);
    93     Function * doBlockFn = Function::Create(doBlockFunctionType, GlobalValue::ExternalLinkage, mKernelName + "_DoBlock", client);
     108    std::string doBlockName = mKernelName + doBlock_suffix;
     109    Function * doBlockFn = Function::Create(doBlockFunctionType, GlobalValue::ExternalLinkage, doBlockName, client);
    94110    doBlockFn->setCallingConv(CallingConv::C);
    95111    doBlockFn->setDoesNotThrow();
     
    98114    }
    99115   
    100     FunctionType * finalBlockFunctionType = FunctionType::get(iBuilder->getVoidTy(), finalBlockParameters, false);
    101     Function * finalBlockFn = Function::Create(finalBlockFunctionType, GlobalValue::ExternalLinkage, mKernelName + "_FinalBlock", client);
     116    FunctionType * finalBlockType = FunctionType::get(iBuilder->getVoidTy(), finalBlockParameters, false);
     117    std::string finalBlockName = mKernelName + finalBlock_suffix;
     118    Function * finalBlockFn = Function::Create(finalBlockType, GlobalValue::ExternalLinkage, finalBlockName, client);
    102119    finalBlockFn->setCallingConv(CallingConv::C);
    103120    finalBlockFn->setDoesNotThrow();
     
    133150
    134151std::unique_ptr<Module> KernelInterface::createKernelModule() {
    135     std::unique_ptr<Module> theModule = llvm::make_unique<Module>(mKernelName, getGlobalContext());
     152    std::unique_ptr<Module> theModule = make_unique<Module>(mKernelName, getGlobalContext());
    136153    addKernelDeclarations(theModule.get());
    137154   
     
    163180}
    164181
    165 llvm::Value * KernelInterface::getScalarIndex(std::string fieldName) {
     182Value * KernelInterface::getScalarIndex(std::string fieldName) {
    166183    const auto f = mInternalStateNameMap.find(fieldName);
    167184    if (LLVM_UNLIKELY(f == mInternalStateNameMap.end())) {
     
    172189
    173190
    174 llvm::Value * KernelInterface::getParameter(Function * f, std::string paramName) {
     191Value * KernelInterface::getScalarField(Value * self, std::string fieldName) {
     192    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
     193    return iBuilder->CreateLoad(ptr);
     194}
     195
     196void KernelInterface::setScalarField(Value * self, std::string fieldName, Value * newFieldVal) {
     197    Value * ptr = iBuilder->CreateGEP(self, {iBuilder->getInt32(0), getScalarIndex(fieldName)});
     198    iBuilder->CreateStore(ptr, newFieldVal);
     199}
     200
     201
     202Value * KernelInterface::getParameter(Function * f, std::string paramName) {
    175203    for (Function::arg_iterator argIter = f->arg_begin(), end = f->arg_end(); argIter != end; argIter++) {
    176204        Value * arg = &*argIter;
     
    180208}
    181209
     210
     211Value * KernelInterface::createInstance(std::vector<Value *> args) {
     212    Value * kernelInstance = iBuilder->CreateAlloca(mKernelStateType);
     213    Module * m = iBuilder->getModule();
     214    std::vector<Value *> init_args = {kernelInstance};
     215    for (auto a : args) {
     216        init_args.push_back(a);
     217    }
     218    std::string initFnName = mKernelName + init_suffix;
     219    Function * initMethod = m->getFunction(initFnName);
     220    if (!initMethod) {
     221        throw std::runtime_error("Cannot find " + initFnName);
     222        //Or just zero-initialize???
     223        //iBuilder->CreateStore(Constant::getNullValue(mKernelStateType), kernelInstance);
     224        //return kernelInstance;
     225    }
     226    iBuilder->CreateCall(initMethod, init_args);
     227    return kernelInstance;
     228}
     229
     230Value * KernelInterface::createDoBlockCall(Value * self, std::vector<Value *> streamSets) {
     231    Module * m = iBuilder->getModule();
     232    std::string doBlockName = mKernelName + doBlock_suffix;
     233    Function * doBlockMethod = m->getFunction(doBlockName);
     234    if (!doBlockMethod) {
     235        throw std::runtime_error("Cannot find " + doBlockName);
     236    }
     237    std::vector<Value *> args = {self};
     238    for (auto ss : streamSets) {
     239        args.push_back(ss);
     240    }
     241    return iBuilder->CreateCall(doBlockMethod, args);
     242}
     243
     244Value * KernelInterface::createFinalBlockCall(Value * self, Value * remainingBytes, std::vector<Value *> streamSets) {
     245    Module * m = iBuilder->getModule();
     246    std::string finalBlockName = mKernelName + finalBlock_suffix;
     247    Function * finalBlockMethod = m->getFunction(finalBlockName);
     248    if (!finalBlockMethod) {
     249        throw std::runtime_error("Cannot find " + finalBlockName);
     250    }
     251    std::vector<Value *> args = {self, remainingBytes};
     252    for (auto ss : streamSets) {
     253        args.push_back(ss);
     254    }
     255    return iBuilder->CreateCall(finalBlockMethod, args);
     256}
     257
     258Value * KernelInterface::createGetAccumulatorCall(Value * self, std::string accumName) {
     259    Module * m = iBuilder->getModule();
     260    std::string fnName = mKernelName + accumulator_infix + accumName;
     261    Function * accumMethod = m->getFunction(fnName);
     262    if (!accumMethod) {
     263        throw std::runtime_error("Cannot find " + fnName);
     264    }
     265    return iBuilder->CreateCall(accumMethod, {self});
     266}
     267
     268
Note: See TracChangeset for help on using the changeset viewer.