source: icGREP/icgrep-devel/icgrep/toolchain/cpudriver.cpp @ 5928

Last change on this file since 5928 was 5928, checked in by cameron, 16 months ago

Fix segfault for -ShowASM

File size: 13.0 KB
Line 
1#include "cpudriver.h"
2
3#include <IR_Gen/idisa_target.h>
4#include <toolchain/toolchain.h>
5#include <llvm/Support/DynamicLibrary.h>           // for LoadLibraryPermanently
6#include <llvm/ExecutionEngine/ExecutionEngine.h>  // for EngineBuilder
7#include <llvm/ExecutionEngine/RTDyldMemoryManager.h>
8
9#include <llvm/IR/LegacyPassManager.h>             // for PassManager
10#include <llvm/IR/IRPrintingPasses.h>
11#include <llvm/InitializePasses.h>                 // for initializeCodeGencd .
12#include <llvm/PassRegistry.h>                     // for PassRegistry
13#include <llvm/Support/CodeGen.h>                  // for Level, Level::None
14#include <llvm/Support/Compiler.h>                 // for LLVM_UNLIKELY
15#include <llvm/Support/TargetSelect.h>
16#include <llvm/Support/FileSystem.h>
17#include <llvm/Target/TargetMachine.h>             // for TargetMachine, Tar...
18#include <llvm/Target/TargetOptions.h>             // for TargetOptions
19#include <llvm/Transforms/Scalar.h>
20#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
21#include <llvm/Transforms/Scalar/GVN.h>
22#endif
23#include <llvm/Transforms/Utils/Local.h>
24#include <toolchain/object_cache.h>
25#include <toolchain/pipeline.h>
26#include <kernels/kernel_builder.h>
27#include <kernels/kernel.h>
28#include <llvm/IR/Verifier.h>
29#include "llvm/IR/Mangler.h"
30#ifdef ORCJIT
31#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(4, 0, 0)
32#include <llvm/ExecutionEngine/Orc/JITSymbol.h>
33#else
34#include <llvm/ExecutionEngine/JITSymbol.h>
35#endif
36#include <llvm/ExecutionEngine/RuntimeDyld.h>
37#include <llvm/ExecutionEngine/SectionMemoryManager.h>
38#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
39#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
40#include <llvm/ExecutionEngine/Orc/IRTransformLayer.h>
41#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
42#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
43#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
44#else
45#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
46#endif
47#include <llvm/ExecutionEngine/Orc/GlobalMappingLayer.h>
48#endif
49
50#ifndef NDEBUG
51#define IN_DEBUG_MODE true
52#else
53#define IN_DEBUG_MODE false
54#endif
55
56using namespace llvm;
57using Kernel = kernel::Kernel;
58using StreamSetBuffer = parabix::StreamSetBuffer;
59using KernelBuilder = kernel::KernelBuilder;
60
61ParabixDriver::ParabixDriver(std::string && moduleName)
62: Driver(std::move(moduleName))
63, mTarget(nullptr)
64#ifndef ORCJIT
65, mEngine(nullptr)
66#endif
67, mCache(nullptr)
68, mIROutputStream(nullptr)
69, mASMOutputStream(nullptr) {
70
71    InitializeNativeTarget();
72    InitializeNativeTargetAsmPrinter();
73    InitializeNativeTargetAsmParser();
74    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
75   
76
77#ifdef ORCJIT
78    EngineBuilder builder;
79#else
80    std::string errMessage;
81    EngineBuilder builder{std::unique_ptr<Module>(mMainModule)};
82    builder.setErrorStr(&errMessage);
83    builder.setUseOrcMCJITReplacement(true);
84    builder.setVerifyModules(false);
85    builder.setEngineKind(EngineKind::JIT);
86#endif
87    builder.setTargetOptions(codegen::Options);
88    builder.setOptLevel(codegen::OptLevel);
89
90    StringMap<bool> HostCPUFeatures;
91    if (sys::getHostCPUFeatures(HostCPUFeatures)) {
92        std::vector<std::string> attrs;
93        for (auto &flag : HostCPUFeatures) {
94            if (flag.second) {
95                attrs.push_back("+" + flag.first().str());
96            }
97        }
98        builder.setMAttrs(attrs);
99    }
100
101    mTarget = builder.selectTarget();
102   
103    if (mTarget == nullptr) {
104        throw std::runtime_error("Could not selectTarget");
105    }
106    preparePassManager();
107
108#ifdef ORCJIT
109    mCompileLayer = make_unique<CompileLayerT>(mObjectLayer, orc::SimpleCompiler(*mTarget));
110#else
111    mEngine = builder.create();
112    if (mEngine == nullptr) {
113        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
114    }
115#endif
116    if (LLVM_LIKELY(codegen::EnableObjectCache)) {
117        if (codegen::ObjectCacheDir) {
118            mCache = new ParabixObjectCache(codegen::ObjectCacheDir);
119        } else {
120            mCache = new ParabixObjectCache();
121        }
122#ifdef ORCJIT
123#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
124        mCompileLayer->setObjectCache(mCache);
125#else
126        mCompileLayer->getCompiler().setObjectCache(mCache);
127#endif
128#else
129        mEngine->setObjectCache(mCache);
130#endif
131    }
132    auto triple = mTarget->getTargetTriple().getTriple();
133    const DataLayout DL(mTarget->createDataLayout());
134    mMainModule->setTargetTriple(triple);
135    mMainModule->setDataLayout(DL);
136    iBuilder.reset(IDISA::GetIDISA_Builder(*mContext));
137    iBuilder->setDriver(this);
138    iBuilder->setModule(mMainModule);
139}
140
141void ParabixDriver::makeKernelCall(Kernel * kernel, const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
142    assert ("makeKernelCall was already run on this kernel." && (kernel->getModule() == nullptr));
143    mPipeline.emplace_back(kernel);
144    kernel->bindPorts(inputs, outputs);
145    if (!mCache || !mCache->loadCachedObjectFile(iBuilder, kernel)) {
146        mUncachedKernel.push_back(kernel);
147    }
148    if (kernel->getModule() == nullptr) {
149        kernel->makeModule(iBuilder);
150    }
151    assert (kernel->getModule());
152}
153
154void ParabixDriver::generatePipelineIR() {
155
156    for (Kernel * const kernel : mUncachedKernel) {
157        kernel->prepareKernel(iBuilder);
158    }
159
160    // note: instantiation of all kernels must occur prior to initialization
161    for (Kernel * const k : mPipeline) {
162        k->addKernelDeclarations(iBuilder);
163    }
164    for (Kernel * const k : mPipeline) {
165        k->createInstance(iBuilder);
166    }
167    for (Kernel * const k : mPipeline) {
168        k->initializeInstance(iBuilder);
169    }
170    if (codegen::SegmentPipelineParallel) {
171        generateSegmentParallelPipeline(iBuilder, mPipeline);
172    } else {
173        generatePipelineLoop(iBuilder, mPipeline);
174    }
175    for (const auto & k : mPipeline) {
176        k->finalizeInstance(iBuilder);
177    }
178}
179
180
181Function * ParabixDriver::addLinkFunction(Module * mod, llvm::StringRef name, FunctionType * type, void * functionPtr) const {
182    if (LLVM_UNLIKELY(mod == nullptr)) {
183        report_fatal_error("addLinkFunction(" + name + ") cannot be called until after addKernelCall or makeKernelCall");
184    }
185    Function * f = mod->getFunction(name);
186    if (LLVM_UNLIKELY(f == nullptr)) {
187        f = Function::Create(type, Function::ExternalLinkage, name, mod);
188#ifdef ORCJIT
189#else
190        mEngine->updateGlobalMapping(f, functionPtr);
191#endif
192    } else if (LLVM_UNLIKELY(f->getType() != type->getPointerTo())) {
193        report_fatal_error("Cannot link " + name + ": a function with a different signature already exists with that name in " + mod->getName());
194    }
195    return f;
196}
197
198std::string ParabixDriver::getMangledName(std::string s) {
199    DataLayout DL(mTarget->createDataLayout());   
200    std::string MangledName;
201    raw_string_ostream MangledNameStream(MangledName);
202    Mangler::getNameWithPrefix(MangledNameStream, s, DL);
203    return MangledName;
204}
205
206void ParabixDriver::preparePassManager() {
207    PassRegistry * Registry = PassRegistry::getPassRegistry();
208    initializeCore(*Registry);
209    initializeCodeGen(*Registry);
210    initializeLowerIntrinsicsPass(*Registry);
211   
212    if (LLVM_UNLIKELY(codegen::ShowUnoptimizedIROption != codegen::OmittedOption)) {
213        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
214            if (codegen::ShowUnoptimizedIROption != "") {
215                std::error_code error;
216                mIROutputStream = new raw_fd_ostream(codegen::ShowUnoptimizedIROption, error, sys::fs::OpenFlags::F_None);
217            } else {
218                mIROutputStream = new raw_fd_ostream(STDERR_FILENO, false, true);
219            }
220        }
221        mPassManager.add(createPrintModulePass(*mIROutputStream));
222    }
223    if (IN_DEBUG_MODE || LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::VerifyIR))) {
224        mPassManager.add(createVerifierPass());
225    }
226    mPassManager.add(createPromoteMemoryToRegisterPass());    // Promote stack variables to constants or PHI nodes
227    mPassManager.add(createCFGSimplificationPass());          // Remove dead basic blocks and unnecessary branch statements / phi nodes
228    mPassManager.add(createEarlyCSEPass());                   // Simple common subexpression elimination pass
229    mPassManager.add(createInstructionCombiningPass());       // Simple peephole optimizations and bit-twiddling.
230    mPassManager.add(createReassociatePass());                // Canonicalizes commutative expressions
231    mPassManager.add(createGVNPass());                        // Global value numbering redundant expression elimination pass
232    mPassManager.add(createCFGSimplificationPass());          // Repeat CFG Simplification to "clean up" any newly found redundant phi nodes
233   
234    if (LLVM_UNLIKELY(codegen::ShowIROption != codegen::OmittedOption)) {
235        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
236            if (codegen::ShowIROption != "") {
237                std::error_code error;
238                mIROutputStream = new raw_fd_ostream(codegen::ShowIROption, error, sys::fs::OpenFlags::F_None);
239            } else {
240                mIROutputStream = new raw_fd_ostream(STDERR_FILENO, false, true);
241            }
242        }
243        mPassManager.add(createPrintModulePass(*mIROutputStream));
244    }
245   
246#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 7, 0)
247    if (LLVM_UNLIKELY(codegen::ShowASMOption != codegen::OmittedOption)) {
248        if (codegen::ShowASMOption != "") {
249            std::error_code error;
250            mASMOutputStream = new raw_fd_ostream(codegen::ShowASMOption, error, sys::fs::OpenFlags::F_None);
251        } else {
252            mASMOutputStream = new raw_fd_ostream(STDERR_FILENO, false, true);
253        }
254        if (LLVM_UNLIKELY(mTarget->addPassesToEmitFile(mPassManager, *mASMOutputStream, TargetMachine::CGFT_AssemblyFile))) {
255            report_fatal_error("LLVM error: could not add emit assembly pass");
256        }
257    }
258#endif
259}
260
261void ParabixDriver::finalizeObject() {
262#ifdef ORCJIT
263    std::vector<std::unique_ptr<Module>> moduleSet;
264    auto Resolver = llvm::orc::createLambdaResolver(
265            [&](const std::string &Name) {
266                auto Sym = mCompileLayer->findSymbol(Name, false);
267                if (!Sym) Sym = mCompileLayer->findSymbol(getMangledName(Name), false);
268#if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
269                if (Sym) return Sym.toRuntimeDyldSymbol();
270                return RuntimeDyld::SymbolInfo(nullptr);
271#else
272                if (Sym) return Sym;
273                return JITSymbol(nullptr);
274#endif
275            },
276            [&](const std::string &Name) {
277                auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name);
278                if (!SymAddr) SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(getMangledName(Name));
279#if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
280                if (SymAddr) return RuntimeDyld::SymbolInfo(SymAddr, JITSymbolFlags::Exported);
281                return RuntimeDyld::SymbolInfo(nullptr);
282#else
283                if (SymAddr) return JITSymbol(SymAddr, JITSymbolFlags::Exported);
284                return JITSymbol(nullptr);
285#endif
286            });
287#endif
288   
289    Module * module = nullptr;
290    try {
291        for (Kernel * const kernel : mUncachedKernel) {
292            iBuilder->setKernel(kernel);
293            kernel->generateKernel(iBuilder);
294            module = kernel->getModule(); assert (module);
295            module->setTargetTriple(mMainModule->getTargetTriple());
296            mPassManager.run(*module);
297        }
298        module = mMainModule;
299        iBuilder->setKernel(nullptr);
300        //mPassManager.run(*mMainModule);
301        for (Kernel * const kernel : mPipeline) {
302            if (LLVM_UNLIKELY(kernel->getModule() == nullptr)) {
303                report_fatal_error(kernel->getName() + " was neither loaded from cache nor generated prior to finalizeObject");
304            }
305#ifndef ORCJIT
306            mEngine->addModule(std::unique_ptr<Module>(kernel->getModule()));
307#else
308            moduleSet.push_back(std::unique_ptr<Module>(kernel->getModule()));
309#endif
310        }
311#ifndef ORCJIT
312            mEngine->finalizeObject();
313#else
314            moduleSet.push_back(std::unique_ptr<Module>(mMainModule));
315            auto handle = mCompileLayer->addModuleSet(std::move(moduleSet), make_unique<SectionMemoryManager>(), std::move(Resolver));
316#endif
317    } catch (const std::exception & e) {
318        report_fatal_error(module->getName() + ": " + e.what());
319    }
320
321}
322
323bool ParabixDriver::hasExternalFunction(llvm::StringRef functionName) const {
324    return RTDyldMemoryManager::getSymbolAddressInProcess(functionName);
325}
326
327void * ParabixDriver::getMain() {
328#ifndef ORCJIT
329    return mEngine->getPointerToNamedFunction("Main");
330#else
331    auto MainSym = mCompileLayer->findSymbol(getMangledName("Main"), false);
332    assert (MainSym && "Main not found");
333   
334    intptr_t main = (intptr_t) MainSym.getAddress();
335    return (void *) main;
336#endif
337}
338
339void ParabixDriver::performIncrementalCacheCleanupStep() {
340    if (mCache) mCache->performIncrementalCacheCleanupStep();
341}
342
343ParabixDriver::~ParabixDriver() {
344#ifndef ORCJIT
345    delete mEngine;
346#endif
347    delete mCache;
348    delete mTarget;
349    delete mIROutputStream;
350    delete mASMOutputStream;
351 }
Note: See TracBrowser for help on using the repository browser.