source: icGREP/icgrep-devel/icgrep/toolchain/cpudriver.cpp @ 6184

Last change on this file since 6184 was 6184, checked in by nmedfort, 9 months ago

Initial version of PipelineKernel + revised StreamSet model.

File size: 12.1 KB
Line 
1#include "cpudriver.h"
2
3#include <IR_Gen/idisa_target.h>
4#include <toolchain/toolchain.h>
5#include <llvm/Support/DynamicLibrary.h>           // for LoadLibraryPermanently
6#include <llvm/ExecutionEngine/ExecutionEngine.h>  // for EngineBuilder
7#include <llvm/ExecutionEngine/RTDyldMemoryManager.h>
8
9#include <llvm/IR/LegacyPassManager.h>             // for PassManager
10#include <llvm/IR/IRPrintingPasses.h>
11#include <llvm/InitializePasses.h>                 // for initializeCodeGencd .
12#include <llvm/PassRegistry.h>                     // for PassRegistry
13#include <llvm/Support/CodeGen.h>                  // for Level, Level::None
14#include <llvm/Support/Compiler.h>                 // for LLVM_UNLIKELY
15#include <llvm/Support/TargetSelect.h>
16#include <llvm/Support/FileSystem.h>
17#include <llvm/Target/TargetMachine.h>             // for TargetMachine, Tar...
18#include <llvm/Target/TargetOptions.h>             // for TargetOptions
19#include <llvm/Transforms/Scalar.h>
20#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
21#include <llvm/Transforms/Scalar/GVN.h>
22#endif
23#include <llvm/Transforms/Utils/Local.h>
24#include <toolchain/object_cache.h>
25#include <kernels/kernel_builder.h>
26#include <kernels/pipeline_builder.h>
27#include <llvm/IR/Verifier.h>
28#include "llvm/IR/Mangler.h"
29#ifdef ORCJIT
30#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(4, 0, 0)
31#include <llvm/ExecutionEngine/Orc/JITSymbol.h>
32#else
33#include <llvm/ExecutionEngine/JITSymbol.h>
34#endif
35#include <llvm/ExecutionEngine/RuntimeDyld.h>
36#include <llvm/ExecutionEngine/SectionMemoryManager.h>
37#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
38#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
39#include <llvm/ExecutionEngine/Orc/IRTransformLayer.h>
40#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
41#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
42#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
43#else
44#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
45#endif
46#include <llvm/ExecutionEngine/Orc/GlobalMappingLayer.h>
47#endif
48
// IN_DEBUG_MODE is `true` in assert-enabled builds (NDEBUG not defined) and `false`
// otherwise, so debug-only work can be folded into ordinary `if` conditions.
#ifdef NDEBUG
#define IN_DEBUG_MODE false
#else
#define IN_DEBUG_MODE true
#endif
54
55using namespace llvm;
56using kernel::Kernel;
57using kernel::PipelineKernel;
58using kernel::StreamSetBuffer;
59using kernel::StreamSetBuffers;
60using kernel::KernelBuilder;
61
62CPUDriver::CPUDriver(std::string && moduleName)
63: BaseDriver(std::move(moduleName))
64, mTarget(nullptr)
65#ifndef ORCJIT
66, mEngine(nullptr)
67#endif
68, mUnoptimizedIROutputStream(nullptr)
69, mIROutputStream(nullptr)
70, mASMOutputStream(nullptr) {
71
72    InitializeNativeTarget();
73    InitializeNativeTargetAsmPrinter();
74    InitializeNativeTargetAsmParser();
75    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
76   
77
78    #ifdef ORCJIT
79    EngineBuilder builder;
80    #else
81    std::string errMessage;
82    EngineBuilder builder{std::unique_ptr<Module>(mMainModule)};
83    builder.setErrorStr(&errMessage);
84    builder.setUseOrcMCJITReplacement(true);
85    builder.setVerifyModules(false);
86    builder.setEngineKind(EngineKind::JIT);
87    #endif
88    builder.setTargetOptions(codegen::target_Options);
89    builder.setOptLevel(codegen::OptLevel);
90
91    StringMap<bool> HostCPUFeatures;
92    if (sys::getHostCPUFeatures(HostCPUFeatures)) {
93        std::vector<std::string> attrs;
94        for (auto &flag : HostCPUFeatures) {
95            if (flag.second) {
96                attrs.push_back("+" + flag.first().str());
97            }
98        }
99        builder.setMAttrs(attrs);
100    }
101
102    mTarget = builder.selectTarget();
103   
104    if (mTarget == nullptr) {
105        throw std::runtime_error("Could not selectTarget");
106    }
107    preparePassManager();
108
109    #ifdef ORCJIT
110    mCompileLayer = make_unique<CompileLayerT>(mObjectLayer, orc::SimpleCompiler(*mTarget));
111    #else
112    mEngine = builder.create();
113    if (mEngine == nullptr) {
114        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
115    }
116    #endif
117    auto cache = ObjectCacheManager::getObjectCache();
118    if (cache) {
119        #ifdef ORCJIT
120        #if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
121        mCompileLayer->setObjectCache(cache);
122        #else
123        mCompileLayer->getCompiler().setObjectCache(cache);
124        #endif
125        #else
126        mEngine->setObjectCache(cache);
127        #endif
128    }
129    auto triple = mTarget->getTargetTriple().getTriple();
130    const DataLayout DL(mTarget->createDataLayout());
131    mMainModule->setTargetTriple(triple);
132    mMainModule->setDataLayout(DL);
133    iBuilder.reset(IDISA::GetIDISA_Builder(*mContext));
134    iBuilder->setDriver(this);
135    iBuilder->setModule(mMainModule);
136}
137
138Function * CPUDriver::addLinkFunction(Module * mod, llvm::StringRef name, FunctionType * type, void * functionPtr) const {
139    if (LLVM_UNLIKELY(mod == nullptr)) {
140        report_fatal_error("addLinkFunction(" + name + ") cannot be called until after addKernelCall or makeKernelCall");
141    }
142    Function * f = mod->getFunction(name);
143    if (LLVM_UNLIKELY(f == nullptr)) {
144        f = Function::Create(type, Function::ExternalLinkage, name, mod);
145        #ifndef ORCJIT
146        mEngine->updateGlobalMapping(f, functionPtr);
147        #endif
148    } else if (LLVM_UNLIKELY(f->getType() != type->getPointerTo())) {
149        report_fatal_error("Cannot link " + name + ": a function with a different signature already exists with that name in " + mod->getName());
150    }
151    return f;
152}
153
154std::string CPUDriver::getMangledName(std::string s) {
155    #if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
156    DataLayout DL(mTarget->createDataLayout());   
157    std::string MangledName;
158    raw_string_ostream MangledNameStream(MangledName);
159    Mangler::getNameWithPrefix(MangledNameStream, s, DL);
160    return MangledName;
161    #else
162    return s;
163    #endif
164}
165
166void CPUDriver::preparePassManager() {
167    PassRegistry * Registry = PassRegistry::getPassRegistry();
168    initializeCore(*Registry);
169    initializeCodeGen(*Registry);
170    initializeLowerIntrinsicsPass(*Registry);
171   
172    if (LLVM_UNLIKELY(codegen::ShowUnoptimizedIROption != codegen::OmittedOption)) {
173        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
174            if (codegen::ShowUnoptimizedIROption != "") {
175                std::error_code error;
176                mUnoptimizedIROutputStream = make_unique<raw_fd_ostream>(codegen::ShowUnoptimizedIROption, error, sys::fs::OpenFlags::F_None);
177            } else {
178                mUnoptimizedIROutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
179            }
180        }
181        mPassManager.add(createPrintModulePass(*mUnoptimizedIROutputStream));
182    }
183    if (IN_DEBUG_MODE || LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::VerifyIR))) {
184        mPassManager.add(createVerifierPass());
185    }
186
187    mPassManager.add(createDeadCodeEliminationPass());        // Eliminate any trivially dead code
188    mPassManager.add(createPromoteMemoryToRegisterPass());    // Promote stack variables to constants or PHI nodes
189    mPassManager.add(createCFGSimplificationPass());          // Remove dead basic blocks and unnecessary branch statements / phi nodes
190    mPassManager.add(createEarlyCSEPass());                   // Simple common subexpression elimination pass
191    mPassManager.add(createInstructionCombiningPass());       // Simple peephole optimizations and bit-twiddling.
192    mPassManager.add(createReassociatePass());                // Canonicalizes commutative expressions
193    mPassManager.add(createGVNPass());                        // Global value numbering redundant expression elimination pass
194    mPassManager.add(createCFGSimplificationPass());          // Repeat CFG Simplification to "clean up" any newly found redundant phi nodes
195    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
196        mPassManager.add(createRemoveRedundantAssertionsPass());
197        mPassManager.add(createDeadCodeEliminationPass());
198        mPassManager.add(createCFGSimplificationPass());
199    }
200
201    if (LLVM_UNLIKELY(codegen::ShowIROption != codegen::OmittedOption)) {
202        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
203            if (codegen::ShowIROption != "") {
204                std::error_code error;
205                mIROutputStream = make_unique<raw_fd_ostream>(codegen::ShowIROption, error, sys::fs::OpenFlags::F_None);
206            } else {
207                mIROutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
208            }
209        }
210        mPassManager.add(createPrintModulePass(*mIROutputStream));
211    }
212   
213#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 7, 0)
214    if (LLVM_UNLIKELY(codegen::ShowASMOption != codegen::OmittedOption)) {
215        if (codegen::ShowASMOption != "") {
216            std::error_code error;
217            mASMOutputStream = make_unique<raw_fd_ostream>(codegen::ShowASMOption, error, sys::fs::OpenFlags::F_None);
218        } else {
219            mASMOutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
220        }
221        if (LLVM_UNLIKELY(mTarget->addPassesToEmitFile(mPassManager, *mASMOutputStream, TargetMachine::CGFT_AssemblyFile))) {
222            report_fatal_error("LLVM error: could not add emit assembly pass");
223        }
224    }
225#endif
226}
227
228void CPUDriver::generateUncachedKernels() {
229    for (auto & kernel : mUncachedKernel) {
230        kernel->prepareKernel(iBuilder);
231    }
232    mCachedKernel.reserve(mUncachedKernel.size());
233    for (auto & kernel : mUncachedKernel) {
234        kernel->generateKernel(iBuilder);
235        Module * const module = kernel->getModule(); assert (module);
236        module->setTargetTriple(mMainModule->getTargetTriple());
237        mPassManager.run(*module);
238        mCachedKernel.emplace_back(kernel.release());
239    }
240    mUncachedKernel.clear();
241}
242
243void * CPUDriver::finalizeObject(llvm::Function * mainMethod) {
244
245    #ifdef ORCJIT
246    auto Resolver = llvm::orc::createLambdaResolver(
247        [&](const std::string &Name) {
248            auto Sym = mCompileLayer->findSymbol(Name, false);
249            if (!Sym) Sym = mCompileLayer->findSymbol(getMangledName(Name), false);
250            #if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
251            if (Sym) return Sym.toRuntimeDyldSymbol();
252            return RuntimeDyld::SymbolInfo(nullptr);
253            #else
254            if (Sym) return Sym;
255            return JITSymbol(nullptr);
256            #endif
257        },
258        [&](const std::string &Name) {
259            auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name);
260            if (!SymAddr) SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(getMangledName(Name));
261            #if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
262            if (SymAddr) return RuntimeDyld::SymbolInfo(SymAddr, JITSymbolFlags::Exported);
263            return RuntimeDyld::SymbolInfo(nullptr);
264            #else
265            if (SymAddr) return JITSymbol(SymAddr, JITSymbolFlags::Exported);
266            return JITSymbol(nullptr);
267            #endif
268        });
269    #endif
270
271    iBuilder->setModule(mMainModule);
272    mPassManager.run(*mMainModule);
273    #ifdef ORCJIT
274    std::vector<std::unique_ptr<Module>> moduleSet;
275    moduleSet.reserve(mCachedKernel.size());
276    #endif
277    for (const auto & kernel : mCachedKernel) {
278        if (LLVM_UNLIKELY(kernel->getModule() == nullptr)) {
279            report_fatal_error(kernel->getName() + " was neither loaded from cache nor generated prior to finalizeObject");
280        }
281        #ifndef ORCJIT
282        mEngine->addModule(std::unique_ptr<Module>(kernel->getModule()));
283        #else
284        moduleSet.push_back(std::unique_ptr<Module>(kernel->getModule()));
285        #endif
286    }
287    mCachedKernel.clear();
288    // compile any uncompiled kernel/method
289    #ifndef ORCJIT
290    mEngine->finalizeObject();
291    #else
292    moduleSet.push_back(std::unique_ptr<Module>(mMainModule);
293    mCompileLayer->addModuleSet(std::move(moduleSet), make_unique<SectionMemoryManager>(), std::move(Resolver));
294    #endif
295
296    // return the compiled main method
297    #ifndef ORCJIT
298    return mEngine->getPointerToFunction(mainMethod);
299    #else
300    auto MainSym = mCompileLayer->findSymbol(getMangledName(mMainMethod->getName()), false);
301    assert (MainSym && "Main not found");
302    return (void *)MainSym.getAddress();
303    #endif
304}
305
306bool CPUDriver::hasExternalFunction(llvm::StringRef functionName) const {
307    return RTDyldMemoryManager::getSymbolAddressInProcess(functionName);
308}
309
310CPUDriver::~CPUDriver() {
311    #ifndef ORCJIT
312    delete mEngine;
313    #endif
314    delete mTarget;
315}
Note: See TracBrowser for help on using the repository browser.