source: icGREP/icgrep-devel/icgrep/toolchain/cpudriver.cpp @ 6187

Last change on this file since 6187 was 6187, checked in by nmedfort, 10 months ago

Potential bug fix for u32u8. CPUDriver only constructs the pass manager if uncached kernels exist.

File size: 12.2 KB
RevLine 
[5464]1#include "cpudriver.h"
2
3#include <IR_Gen/idisa_target.h>
[5731]4#include <toolchain/toolchain.h>
[5913]5#include <llvm/Support/DynamicLibrary.h>           // for LoadLibraryPermanently
[5915]6#include <llvm/ExecutionEngine/ExecutionEngine.h>  // for EngineBuilder
[5913]7#include <llvm/ExecutionEngine/RTDyldMemoryManager.h>
[5915]8
[5464]9#include <llvm/IR/LegacyPassManager.h>             // for PassManager
10#include <llvm/IR/IRPrintingPasses.h>
[5915]11#include <llvm/InitializePasses.h>                 // for initializeCodeGen
[5464]12#include <llvm/PassRegistry.h>                     // for PassRegistry
13#include <llvm/Support/CodeGen.h>                  // for Level, Level::None
14#include <llvm/Support/Compiler.h>                 // for LLVM_UNLIKELY
15#include <llvm/Support/TargetSelect.h>
[5733]16#include <llvm/Support/FileSystem.h>
[5464]17#include <llvm/Target/TargetMachine.h>             // for TargetMachine, Tar...
18#include <llvm/Target/TargetOptions.h>             // for TargetOptions
19#include <llvm/Transforms/Scalar.h>
[5841]20#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
[5731]21#include <llvm/Transforms/Scalar/GVN.h>
22#endif
[5464]23#include <llvm/Transforms/Utils/Local.h>
24#include <toolchain/object_cache.h>
25#include <kernels/kernel_builder.h>
[6184]26#include <kernels/pipeline_builder.h>
[5464]27#include <llvm/IR/Verifier.h>
[5920]28#include "llvm/IR/Mangler.h"
[5915]29#ifdef ORCJIT
30#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(4, 0, 0)
31#include <llvm/ExecutionEngine/Orc/JITSymbol.h>
32#else
33#include <llvm/ExecutionEngine/JITSymbol.h>
34#endif
35#include <llvm/ExecutionEngine/RuntimeDyld.h>
36#include <llvm/ExecutionEngine/SectionMemoryManager.h>
37#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
38#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
39#include <llvm/ExecutionEngine/Orc/IRTransformLayer.h>
40#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
[5917]41#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
[5915]42#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
[5917]43#else
44#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
45#endif
[5915]46#include <llvm/ExecutionEngine/Orc/GlobalMappingLayer.h>
47#endif
[5646]48
[5464]49#ifndef NDEBUG
50#define IN_DEBUG_MODE true
51#else
52#define IN_DEBUG_MODE false
53#endif
54
55using namespace llvm;
[6184]56using kernel::Kernel;
57using kernel::PipelineKernel;
58using kernel::StreamSetBuffer;
59using kernel::StreamSetBuffers;
60using kernel::KernelBuilder;
[5464]61
[6184]62CPUDriver::CPUDriver(std::string && moduleName)
63: BaseDriver(std::move(moduleName))
[5464]64, mTarget(nullptr)
[5915]65#ifndef ORCJIT
[5464]66, mEngine(nullptr)
[5915]67#endif
[5932]68, mUnoptimizedIROutputStream(nullptr)
[5616]69, mIROutputStream(nullptr)
70, mASMOutputStream(nullptr) {
[5464]71
72    InitializeNativeTarget();
73    InitializeNativeTargetAsmPrinter();
74    InitializeNativeTargetAsmParser();
[5913]75    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
[5915]76   
[5464]77
[6184]78    #ifdef ORCJIT
[5915]79    EngineBuilder builder;
[6184]80    #else
[5464]81    std::string errMessage;
82    EngineBuilder builder{std::unique_ptr<Module>(mMainModule)};
83    builder.setErrorStr(&errMessage);
84    builder.setUseOrcMCJITReplacement(true);
[5915]85    builder.setVerifyModules(false);
86    builder.setEngineKind(EngineKind::JIT);
[6184]87    #endif
[6030]88    builder.setTargetOptions(codegen::target_Options);
[5464]89    builder.setOptLevel(codegen::OptLevel);
90
91    StringMap<bool> HostCPUFeatures;
92    if (sys::getHostCPUFeatures(HostCPUFeatures)) {
93        std::vector<std::string> attrs;
94        for (auto &flag : HostCPUFeatures) {
[5924]95            if (flag.second) {
96                attrs.push_back("+" + flag.first().str());
97            }
[5464]98        }
99        builder.setMAttrs(attrs);
100    }
101
[5915]102    mTarget = builder.selectTarget();
103   
104    if (mTarget == nullptr) {
105        throw std::runtime_error("Could not selectTarget");
106    }
[6184]107    #ifdef ORCJIT
[5915]108    mCompileLayer = make_unique<CompileLayerT>(mObjectLayer, orc::SimpleCompiler(*mTarget));
[6184]109    #else
[5464]110    mEngine = builder.create();
111    if (mEngine == nullptr) {
112        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
113    }
[6184]114    #endif
115    auto cache = ObjectCacheManager::getObjectCache();
116    if (cache) {
117        #ifdef ORCJIT
118        #if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
119        mCompileLayer->setObjectCache(cache);
120        #else
121        mCompileLayer->getCompiler().setObjectCache(cache);
122        #endif
123        #else
124        mEngine->setObjectCache(cache);
125        #endif
[5464]126    }
[5915]127    auto triple = mTarget->getTargetTriple().getTriple();
128    const DataLayout DL(mTarget->createDataLayout());
129    mMainModule->setTargetTriple(triple);
130    mMainModule->setDataLayout(DL);
[5489]131    iBuilder.reset(IDISA::GetIDISA_Builder(*mContext));
[5464]132    iBuilder->setDriver(this);
133    iBuilder->setModule(mMainModule);
134}
135
[6184]136Function * CPUDriver::addLinkFunction(Module * mod, llvm::StringRef name, FunctionType * type, void * functionPtr) const {
[5630]137    if (LLVM_UNLIKELY(mod == nullptr)) {
138        report_fatal_error("addLinkFunction(" + name + ") cannot be called until after addKernelCall or makeKernelCall");
139    }
[5486]140    Function * f = mod->getFunction(name);
141    if (LLVM_UNLIKELY(f == nullptr)) {
142        f = Function::Create(type, Function::ExternalLinkage, name, mod);
[6184]143        #ifndef ORCJIT
[5630]144        mEngine->updateGlobalMapping(f, functionPtr);
[6184]145        #endif
[5486]146    } else if (LLVM_UNLIKELY(f->getType() != type->getPointerTo())) {
[5493]147        report_fatal_error("Cannot link " + name + ": a function with a different signature already exists with that name in " + mod->getName());
[5486]148    }
[5464]149    return f;
150}
151
[6184]152std::string CPUDriver::getMangledName(std::string s) {
[5985]153    #if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
[5919]154    DataLayout DL(mTarget->createDataLayout());   
155    std::string MangledName;
156    raw_string_ostream MangledNameStream(MangledName);
157    Mangler::getNameWithPrefix(MangledNameStream, s, DL);
158    return MangledName;
[5985]159    #else
160    return s;
161    #endif
[5919]162}
163
[6184]164void CPUDriver::preparePassManager() {
[5913]165    PassRegistry * Registry = PassRegistry::getPassRegistry();
166    initializeCore(*Registry);
167    initializeCodeGen(*Registry);
168    initializeLowerIntrinsicsPass(*Registry);
169   
[5745]170    if (LLVM_UNLIKELY(codegen::ShowUnoptimizedIROption != codegen::OmittedOption)) {
[5616]171        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
[5745]172            if (codegen::ShowUnoptimizedIROption != "") {
[5616]173                std::error_code error;
[5932]174                mUnoptimizedIROutputStream = make_unique<raw_fd_ostream>(codegen::ShowUnoptimizedIROption, error, sys::fs::OpenFlags::F_None);
[5616]175            } else {
[5932]176                mUnoptimizedIROutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
[5616]177            }
[5464]178        }
[5932]179        mPassManager.add(createPrintModulePass(*mUnoptimizedIROutputStream));
[5464]180    }
181    if (IN_DEBUG_MODE || LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::VerifyIR))) {
[5913]182        mPassManager.add(createVerifierPass());
[5464]183    }
[6184]184
185    mPassManager.add(createDeadCodeEliminationPass());        // Eliminate any trivially dead code
[5913]186    mPassManager.add(createPromoteMemoryToRegisterPass());    // Promote stack variables to constants or PHI nodes
187    mPassManager.add(createCFGSimplificationPass());          // Remove dead basic blocks and unnecessary branch statements / phi nodes
188    mPassManager.add(createEarlyCSEPass());                   // Simple common subexpression elimination pass
189    mPassManager.add(createInstructionCombiningPass());       // Simple peephole optimizations and bit-twiddling.
190    mPassManager.add(createReassociatePass());                // Canonicalizes commutative expressions
191    mPassManager.add(createGVNPass());                        // Global value numbering redundant expression elimination pass
192    mPassManager.add(createCFGSimplificationPass());          // Repeat CFG Simplification to "clean up" any newly found redundant phi nodes
[6184]193    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
194        mPassManager.add(createRemoveRedundantAssertionsPass());
195        mPassManager.add(createDeadCodeEliminationPass());
196        mPassManager.add(createCFGSimplificationPass());
197    }
198
[5745]199    if (LLVM_UNLIKELY(codegen::ShowIROption != codegen::OmittedOption)) {
[5616]200        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
[5745]201            if (codegen::ShowIROption != "") {
[5464]202                std::error_code error;
[5932]203                mIROutputStream = make_unique<raw_fd_ostream>(codegen::ShowIROption, error, sys::fs::OpenFlags::F_None);
[5486]204            } else {
[5932]205                mIROutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
[5464]206            }
207        }
[5913]208        mPassManager.add(createPrintModulePass(*mIROutputStream));
[5464]209    }
[5913]210   
[5841]211#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 7, 0)
[5745]212    if (LLVM_UNLIKELY(codegen::ShowASMOption != codegen::OmittedOption)) {
213        if (codegen::ShowASMOption != "") {
[5464]214            std::error_code error;
[5932]215            mASMOutputStream = make_unique<raw_fd_ostream>(codegen::ShowASMOption, error, sys::fs::OpenFlags::F_None);
[5486]216        } else {
[5932]217            mASMOutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
[5464]218        }
[5913]219        if (LLVM_UNLIKELY(mTarget->addPassesToEmitFile(mPassManager, *mASMOutputStream, TargetMachine::CGFT_AssemblyFile))) {
[5464]220            report_fatal_error("LLVM error: could not add emit assembly pass");
221        }
222    }
[5731]223#endif
[5913]224}
[5464]225
[6184]226void CPUDriver::generateUncachedKernels() {
[6187]227    if (mUncachedKernel.empty()) return;
228    preparePassManager();
[6184]229    for (auto & kernel : mUncachedKernel) {
230        kernel->prepareKernel(iBuilder);
231    }
232    mCachedKernel.reserve(mUncachedKernel.size());
233    for (auto & kernel : mUncachedKernel) {
234        kernel->generateKernel(iBuilder);
235        Module * const module = kernel->getModule(); assert (module);
236        module->setTargetTriple(mMainModule->getTargetTriple());
237        mPassManager.run(*module);
238        mCachedKernel.emplace_back(kernel.release());
239    }
240    mUncachedKernel.clear();
241}
242
243void * CPUDriver::finalizeObject(llvm::Function * mainMethod) {
244
245    #ifdef ORCJIT
246    auto Resolver = llvm::orc::createLambdaResolver(
247        [&](const std::string &Name) {
248            auto Sym = mCompileLayer->findSymbol(Name, false);
249            if (!Sym) Sym = mCompileLayer->findSymbol(getMangledName(Name), false);
250            #if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
251            if (Sym) return Sym.toRuntimeDyldSymbol();
252            return RuntimeDyld::SymbolInfo(nullptr);
253            #else
254            if (Sym) return Sym;
255            return JITSymbol(nullptr);
256            #endif
257        },
258        [&](const std::string &Name) {
259            auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name);
260            if (!SymAddr) SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(getMangledName(Name));
261            #if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
262            if (SymAddr) return RuntimeDyld::SymbolInfo(SymAddr, JITSymbolFlags::Exported);
263            return RuntimeDyld::SymbolInfo(nullptr);
264            #else
265            if (SymAddr) return JITSymbol(SymAddr, JITSymbolFlags::Exported);
266            return JITSymbol(nullptr);
267            #endif
268        });
269    #endif
270
271    iBuilder->setModule(mMainModule);
272    mPassManager.run(*mMainModule);
273    #ifdef ORCJIT
[5915]274    std::vector<std::unique_ptr<Module>> moduleSet;
[6184]275    moduleSet.reserve(mCachedKernel.size());
276    #endif
277    for (const auto & kernel : mCachedKernel) {
278        if (LLVM_UNLIKELY(kernel->getModule() == nullptr)) {
279            report_fatal_error(kernel->getName() + " was neither loaded from cache nor generated prior to finalizeObject");
[5464]280        }
[6184]281        #ifndef ORCJIT
282        mEngine->addModule(std::unique_ptr<Module>(kernel->getModule()));
283        #else
284        moduleSet.push_back(std::unique_ptr<Module>(kernel->getModule()));
285        #endif
[5464]286    }
[6184]287    mCachedKernel.clear();
288    // compile any uncompiled kernel/method
289    #ifndef ORCJIT
290    mEngine->finalizeObject();
291    #else
292    moduleSet.push_back(std::unique_ptr<Module>(mMainModule);
293    mCompileLayer->addModuleSet(std::move(moduleSet), make_unique<SectionMemoryManager>(), std::move(Resolver));
294    #endif
[5464]295
[6184]296    // return the compiled main method
297    #ifndef ORCJIT
298    return mEngine->getPointerToFunction(mainMethod);
299    #else
300    auto MainSym = mCompileLayer->findSymbol(getMangledName(mMainMethod->getName()), false);
301    assert (MainSym && "Main not found");
302    return (void *)MainSym.getAddress();
303    #endif
[5464]304}
305
[6184]306bool CPUDriver::hasExternalFunction(llvm::StringRef functionName) const {
[5913]307    return RTDyldMemoryManager::getSymbolAddressInProcess(functionName);
[5493]308}
309
[6184]310CPUDriver::~CPUDriver() {
311    #ifndef ORCJIT
[5474]312    delete mEngine;
[6184]313    #endif
[5464]314    delete mTarget;
[6184]315}
Note: See TracBrowser for help on using the repository browser.