source: icGREP/icgrep-devel/icgrep/kernels/interface.cpp @ 5816

Last change on this file since 5816 was 5755, checked in by nmedfort, 23 months ago

Bug fixes and simplified MultiBlockKernel logic

File size: 5.4 KB
RevLine 
[5047]1/*
2 *  Copyright (c) 2016 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "interface.h"
[5260]7#include <llvm/IR/Value.h>         // for Value
8#include <llvm/IR/CallingConv.h>   // for ::C
9#include <llvm/IR/DerivedTypes.h>  // for FunctionType (ptr only), PointerType
10#include <llvm/IR/Function.h>      // for Function, Function::arg_iterator
[5047]11#include <llvm/IR/Module.h>
[5436]12#include <kernels/kernel_builder.h>
[5047]13
// Suffixes appended to a kernel's name to form the symbol names of the
// functions declared and looked up by KernelInterface.
static constexpr const char * INIT_SUFFIX = "_Init";
static constexpr const char * DO_SEGMENT_SUFFIX = "_DoSegment";
static constexpr const char * TERMINATE_SUFFIX = "_Terminate";
[5287]19
[5047]20using namespace llvm;
21
[5706]22namespace kernel {
[5328]23
[5440]24void KernelInterface::addKernelDeclarations(const std::unique_ptr<kernel::KernelBuilder> & idb) {
[5431]25
[5060]26    if (mKernelStateType == nullptr) {
[5297]27        throw std::runtime_error("Kernel interface " + getName() + " not yet finalized.");
[5047]28    }
[5431]29
[5440]30    Module * const module = idb->getModule();
[5408]31    PointerType * const selfType = mKernelStateType->getPointerTo();
[5440]32    IntegerType * const sizeTy = idb->getSizeTy();
[5733]33    PointerType * const consumerTy = StructType::get(idb->getContext(), {sizeTy, sizeTy->getPointerTo()->getPointerTo()})->getPointerTo();
[5440]34    Type * const voidTy = idb->getVoidTy();
[5227]35
[5047]36    // Create the initialization function prototype
[5431]37    std::vector<Type *> initParameters = {selfType};
[5047]38    for (auto binding : mScalarInputs) {
[5706]39        initParameters.push_back(binding.getType());
[5047]40    }
[5408]41    initParameters.insert(initParameters.end(), mStreamSetOutputs.size(), consumerTy);
42
[5418]43    FunctionType * const initType = FunctionType::get(voidTy, initParameters, false);
[5431]44    Function * const initFunc = Function::Create(initType, GlobalValue::ExternalLinkage, getName() + INIT_SUFFIX, module);
[5418]45    initFunc->setCallingConv(CallingConv::C);
46    initFunc->setDoesNotThrow();
47    auto args = initFunc->arg_begin();
[5297]48    args->setName("self");
[5706]49    for (const Binding & binding : mScalarInputs) {
50        (++args)->setName(binding.getName());
[5049]51    }
[5706]52    for (const Binding & binding : mStreamSetOutputs) {
53        (++args)->setName(binding.getName() + "ConsumerLocks");
[5408]54    }
[5047]55
[5408]56    // Create the doSegment function prototype.
[5440]57    std::vector<Type *> params = {selfType, idb->getInt1Ty()};
[5401]58
[5706]59    const auto count = mStreamSetInputs.size();
60    params.insert(params.end(), count, sizeTy);
61
[5418]62    FunctionType * const doSegmentType = FunctionType::get(voidTy, params, false);
[5431]63    Function * const doSegment = Function::Create(doSegmentType, GlobalValue::ExternalLinkage, getName() + DO_SEGMENT_SUFFIX, module);
[5292]64    doSegment->setCallingConv(CallingConv::C);
65    doSegment->setDoesNotThrow();
66    args = doSegment->arg_begin();
[5297]67    args->setName("self");
68    (++args)->setName("doFinal");
[5755]69//    if (mHasPrincipalItemCount) {
[5706]70//        (++args)->setName("principleAvailableItemCount");
71//    }
[5398]72    for (const Binding & input : mStreamSetInputs) {
[5706]73        //const ProcessingRate & r = input.getRate();
74        //if (!r.isDerived()) {
75            (++args)->setName(input.getName() + "AvailableItems");
76        //}
[5263]77    }
[5287]78
[5411]79    // Create the terminate function prototype
[5418]80    Type * resultType = nullptr;
81    if (mScalarOutputs.empty()) {
[5440]82        resultType = idb->getVoidTy();
[5418]83    } else {
84        const auto n = mScalarOutputs.size();
85        Type * outputType[n];
86        for (unsigned i = 0; i < n; ++i) {
[5706]87            outputType[i] = mScalarOutputs[i].getType();
[5418]88        }
89        if (n == 1) {
90            resultType = outputType[0];
91        } else {
[5440]92            resultType = StructType::get(idb->getContext(), ArrayRef<Type *>(outputType, n));
[5418]93        }
94    }
95    FunctionType * const terminateType = FunctionType::get(resultType, {selfType}, false);
[5431]96    Function * const terminateFunc = Function::Create(terminateType, GlobalValue::ExternalLinkage, getName() + TERMINATE_SUFFIX, module);
[5411]97    terminateFunc->setCallingConv(CallingConv::C);
98    terminateFunc->setDoesNotThrow();
99    args = terminateFunc->arg_begin();
100    args->setName("self");
101
[5440]102    linkExternalMethods(idb);
[5047]103}
104
[5706]105void  KernelInterface::setInstance(Value * const instance) {
106    assert ("kernel instance cannot be null!" && instance);
107    assert ("kernel instance must point to a valid kernel state type!" && (instance->getType()->getPointerElementType() == mKernelStateType));
108    mKernelInstance = instance;
109}
110
[5431]111Function * KernelInterface::getInitFunction(Module * const module) const {
[5297]112    const auto name = getName() + INIT_SUFFIX;
[5431]113    Function * f = module->getFunction(name);
[5287]114    if (LLVM_UNLIKELY(f == nullptr)) {
115        llvm::report_fatal_error("Cannot find " + name);
[5053]116    }
[5287]117    return f;
[5053]118}
119
[5706]120Function * KernelInterface::getDoSegmentFunction(Module * const module) const {
[5297]121    const auto name = getName() + DO_SEGMENT_SUFFIX;
[5431]122    Function * f = module->getFunction(name);
[5287]123    if (LLVM_UNLIKELY(f == nullptr)) {
124        llvm::report_fatal_error("Cannot find " + name);
125    }
126    return f;
[5285]127}
[5411]128
[5431]129Function * KernelInterface::getTerminateFunction(Module * const module) const {
[5411]130    const auto name = getName() + TERMINATE_SUFFIX;
[5431]131    Function * f = module->getFunction(name);
[5411]132    if (LLVM_UNLIKELY(f == nullptr)) {
133        llvm::report_fatal_error("Cannot find " + name);
134    }
135    return f;
136}
[5706]137
138CallInst * KernelInterface::makeDoSegmentCall(kernel::KernelBuilder & idb, const std::vector<llvm::Value *> & args) const {
139    Function * const doSegment = getDoSegmentFunction(idb.getModule());
140    assert (doSegment->getArgumentList().size() <= args.size());
141    return idb.CreateCall(doSegment, args);
142}
143
144}
Note: See TracBrowser for help on using the repository browser.