source: icGREP/icgrep-devel/icgrep/toolchain/cpudriver.cpp @ 5920

Last change on this file since 5920 was 5920, checked in by cameron, 13 months ago

Some small fixes and cleanup

File size: 12.9 KB
Line 
1#include "cpudriver.h"
2
3#include <IR_Gen/idisa_target.h>
4#include <toolchain/toolchain.h>
5#include <llvm/Support/DynamicLibrary.h>           // for LoadLibraryPermanently
6#include <llvm/ExecutionEngine/ExecutionEngine.h>  // for EngineBuilder
7#include <llvm/ExecutionEngine/RTDyldMemoryManager.h>
8
9#include <llvm/IR/LegacyPassManager.h>             // for PassManager
10#include <llvm/IR/IRPrintingPasses.h>
11#include <llvm/InitializePasses.h>                 // for initializeCodeGencd .
12#include <llvm/PassRegistry.h>                     // for PassRegistry
13#include <llvm/Support/CodeGen.h>                  // for Level, Level::None
14#include <llvm/Support/Compiler.h>                 // for LLVM_UNLIKELY
15#include <llvm/Support/TargetSelect.h>
16#include <llvm/Support/FileSystem.h>
17#include <llvm/Target/TargetMachine.h>             // for TargetMachine, Tar...
18#include <llvm/Target/TargetOptions.h>             // for TargetOptions
19#include <llvm/Transforms/Scalar.h>
20#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
21#include <llvm/Transforms/Scalar/GVN.h>
22#endif
23#include <llvm/Transforms/Utils/Local.h>
24#include <toolchain/object_cache.h>
25#include <toolchain/pipeline.h>
26#include <kernels/kernel_builder.h>
27#include <kernels/kernel.h>
28#include <llvm/IR/Verifier.h>
29#include "llvm/IR/Mangler.h"
30#ifdef ORCJIT
31#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(4, 0, 0)
32#include <llvm/ExecutionEngine/Orc/JITSymbol.h>
33#else
34#include <llvm/ExecutionEngine/JITSymbol.h>
35#endif
36#include <llvm/ExecutionEngine/RuntimeDyld.h>
37#include <llvm/ExecutionEngine/SectionMemoryManager.h>
38#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
39#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
40#include <llvm/ExecutionEngine/Orc/IRTransformLayer.h>
41#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
42#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
43#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
44#else
45#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
46#endif
47#include <llvm/ExecutionEngine/Orc/GlobalMappingLayer.h>
48#endif
49
50#ifndef NDEBUG
51#define IN_DEBUG_MODE true
52#else
53#define IN_DEBUG_MODE false
54#endif
55
56using namespace llvm;
57using Kernel = kernel::Kernel;
58using StreamSetBuffer = parabix::StreamSetBuffer;
59using KernelBuilder = kernel::KernelBuilder;
60
61ParabixDriver::ParabixDriver(std::string && moduleName)
62: Driver(std::move(moduleName))
63, mTarget(nullptr)
64#ifndef ORCJIT
65, mEngine(nullptr)
66#endif
67, mCache(nullptr)
68, mIROutputStream(nullptr)
69, mASMOutputStream(nullptr) {
70
71    InitializeNativeTarget();
72    InitializeNativeTargetAsmPrinter();
73    InitializeNativeTargetAsmParser();
74    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
75    preparePassManager();
76   
77
78#ifdef ORCJIT
79    EngineBuilder builder;
80#else
81    std::string errMessage;
82    EngineBuilder builder{std::unique_ptr<Module>(mMainModule)};
83    builder.setErrorStr(&errMessage);
84    builder.setUseOrcMCJITReplacement(true);
85    builder.setVerifyModules(false);
86    builder.setEngineKind(EngineKind::JIT);
87#endif
88    builder.setTargetOptions(codegen::Options);
89    builder.setOptLevel(codegen::OptLevel);
90
91    StringMap<bool> HostCPUFeatures;
92    if (sys::getHostCPUFeatures(HostCPUFeatures)) {
93        std::vector<std::string> attrs;
94        for (auto &flag : HostCPUFeatures) {
95            auto enabled = flag.second ? "+" : "-";
96            attrs.push_back(enabled + flag.first().str());
97        }
98        builder.setMAttrs(attrs);
99    }
100
101    mTarget = builder.selectTarget();
102   
103    if (mTarget == nullptr) {
104        throw std::runtime_error("Could not selectTarget");
105    }
106   
107#ifdef ORCJIT
108    mCompileLayer = make_unique<CompileLayerT>(mObjectLayer, orc::SimpleCompiler(*mTarget));
109#else
110    mEngine = builder.create();
111    if (mEngine == nullptr) {
112        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
113    }
114#endif
115    if (LLVM_LIKELY(codegen::EnableObjectCache)) {
116        if (codegen::ObjectCacheDir) {
117            mCache = new ParabixObjectCache(codegen::ObjectCacheDir);
118        } else {
119            mCache = new ParabixObjectCache();
120        }
121#ifdef ORCJIT
122#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
123        mCompileLayer->setObjectCache(mCache);
124#else
125        mCompileLayer->getCompiler().setObjectCache(mCache);
126#endif
127#else
128        mEngine->setObjectCache(mCache);
129#endif
130    }
131    auto triple = mTarget->getTargetTriple().getTriple();
132    const DataLayout DL(mTarget->createDataLayout());
133    mMainModule->setTargetTriple(triple);
134    mMainModule->setDataLayout(DL);
135    iBuilder.reset(IDISA::GetIDISA_Builder(*mContext));
136    iBuilder->setDriver(this);
137    iBuilder->setModule(mMainModule);
138}
139
140void ParabixDriver::makeKernelCall(Kernel * kernel, const std::vector<StreamSetBuffer *> & inputs, const std::vector<StreamSetBuffer *> & outputs) {
141    assert ("addKernelCall or makeKernelCall was already run on this kernel." && (kernel->getModule() == nullptr));
142    mPipeline.emplace_back(kernel);
143    kernel->bindPorts(inputs, outputs);
144    if (!mCache || !mCache->loadCachedObjectFile(iBuilder, kernel)) {
145        mUncachedKernel.push_back(kernel);
146    }
147    if (kernel->getModule() == nullptr) {
148        kernel->makeModule(iBuilder);
149    }
150    assert (kernel->getModule());
151}
152
153void ParabixDriver::generatePipelineIR() {
154
155    for (Kernel * const kernel : mUncachedKernel) {
156        kernel->prepareKernel(iBuilder);
157    }
158
159    // note: instantiation of all kernels must occur prior to initialization
160    for (Kernel * const k : mPipeline) {
161        k->addKernelDeclarations(iBuilder);
162    }
163    for (Kernel * const k : mPipeline) {
164        k->createInstance(iBuilder);
165    }
166    for (Kernel * const k : mPipeline) {
167        k->initializeInstance(iBuilder);
168    }
169    if (codegen::SegmentPipelineParallel) {
170        generateSegmentParallelPipeline(iBuilder, mPipeline);
171    } else {
172        generatePipelineLoop(iBuilder, mPipeline);
173    }
174    for (const auto & k : mPipeline) {
175        k->finalizeInstance(iBuilder);
176    }
177}
178
179
180Function * ParabixDriver::addLinkFunction(Module * mod, llvm::StringRef name, FunctionType * type, void * functionPtr) const {
181    if (LLVM_UNLIKELY(mod == nullptr)) {
182        report_fatal_error("addLinkFunction(" + name + ") cannot be called until after addKernelCall or makeKernelCall");
183    }
184    Function * f = mod->getFunction(name);
185    if (LLVM_UNLIKELY(f == nullptr)) {
186        f = Function::Create(type, Function::ExternalLinkage, name, mod);
187#ifdef ORCJIT
188#else
189        mEngine->updateGlobalMapping(f, functionPtr);
190#endif
191    } else if (LLVM_UNLIKELY(f->getType() != type->getPointerTo())) {
192        report_fatal_error("Cannot link " + name + ": a function with a different signature already exists with that name in " + mod->getName());
193    }
194    return f;
195}
196
197std::string ParabixDriver::getMangledName(std::string s) {
198    DataLayout DL(mTarget->createDataLayout());   
199    std::string MangledName;
200    raw_string_ostream MangledNameStream(MangledName);
201    Mangler::getNameWithPrefix(MangledNameStream, s, DL);
202    return MangledName;
203}
204
205void ParabixDriver::preparePassManager() {
206    PassRegistry * Registry = PassRegistry::getPassRegistry();
207    initializeCore(*Registry);
208    initializeCodeGen(*Registry);
209    initializeLowerIntrinsicsPass(*Registry);
210   
211    if (LLVM_UNLIKELY(codegen::ShowUnoptimizedIROption != codegen::OmittedOption)) {
212        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
213            if (codegen::ShowUnoptimizedIROption != "") {
214                std::error_code error;
215                mIROutputStream = new raw_fd_ostream(codegen::ShowUnoptimizedIROption, error, sys::fs::OpenFlags::F_None);
216            } else {
217                mIROutputStream = new raw_fd_ostream(STDERR_FILENO, false, true);
218            }
219        }
220        mPassManager.add(createPrintModulePass(*mIROutputStream));
221    }
222    if (IN_DEBUG_MODE || LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::VerifyIR))) {
223        mPassManager.add(createVerifierPass());
224    }
225    mPassManager.add(createPromoteMemoryToRegisterPass());    // Promote stack variables to constants or PHI nodes
226    mPassManager.add(createCFGSimplificationPass());          // Remove dead basic blocks and unnecessary branch statements / phi nodes
227    mPassManager.add(createEarlyCSEPass());                   // Simple common subexpression elimination pass
228    mPassManager.add(createInstructionCombiningPass());       // Simple peephole optimizations and bit-twiddling.
229    mPassManager.add(createReassociatePass());                // Canonicalizes commutative expressions
230    mPassManager.add(createGVNPass());                        // Global value numbering redundant expression elimination pass
231    mPassManager.add(createCFGSimplificationPass());          // Repeat CFG Simplification to "clean up" any newly found redundant phi nodes
232   
233    if (LLVM_UNLIKELY(codegen::ShowIROption != codegen::OmittedOption)) {
234        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
235            if (codegen::ShowIROption != "") {
236                std::error_code error;
237                mIROutputStream = new raw_fd_ostream(codegen::ShowIROption, error, sys::fs::OpenFlags::F_None);
238            } else {
239                mIROutputStream = new raw_fd_ostream(STDERR_FILENO, false, true);
240            }
241        }
242        mPassManager.add(createPrintModulePass(*mIROutputStream));
243    }
244   
245#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 7, 0)
246    if (LLVM_UNLIKELY(codegen::ShowASMOption != codegen::OmittedOption)) {
247        if (codegen::ShowASMOption != "") {
248            std::error_code error;
249            mASMOutputStream = new raw_fd_ostream(codegen::ShowASMOption, error, sys::fs::OpenFlags::F_None);
250        } else {
251            mASMOutputStream = new raw_fd_ostream(STDERR_FILENO, false, true);
252        }
253        if (LLVM_UNLIKELY(mTarget->addPassesToEmitFile(mPassManager, *mASMOutputStream, TargetMachine::CGFT_AssemblyFile))) {
254            report_fatal_error("LLVM error: could not add emit assembly pass");
255        }
256    }
257#endif
258}
259
260void ParabixDriver::finalizeObject() {
261#ifdef ORCJIT
262    std::vector<std::unique_ptr<Module>> moduleSet;
263    auto Resolver = llvm::orc::createLambdaResolver(
264            [&](const std::string &Name) {
265                auto Sym = mCompileLayer->findSymbol(Name, false);
266#if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
267                if (Sym) return Sym.toRuntimeDyldSymbol();
268                return RuntimeDyld::SymbolInfo(nullptr);
269#else
270                if (Sym) return Sym;
271                return JITSymbol(nullptr);
272#endif
273            },
274            [&](const std::string &Name) {
275                auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name);
276                if (!SymAddr) SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(getMangledName(Name));
277#if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
278                if (SymAddr) return RuntimeDyld::SymbolInfo(SymAddr, JITSymbolFlags::Exported);
279                return RuntimeDyld::SymbolInfo(nullptr);
280#else
281                if (SymAddr) return JITSymbol(SymAddr, JITSymbolFlags::Exported);
282                return JITSymbol(nullptr);
283#endif
284            });
285#endif
286   
287    Module * module = nullptr;
288    try {
289        for (Kernel * const kernel : mUncachedKernel) {
290            iBuilder->setKernel(kernel);
291            kernel->generateKernel(iBuilder);
292            module = kernel->getModule(); assert (module);
293            module->setTargetTriple(mMainModule->getTargetTriple());
294            mPassManager.run(*module);
295        }
296        module = mMainModule;
297        iBuilder->setKernel(nullptr);
298        //mPassManager.run(*mMainModule);
299        for (Kernel * const kernel : mPipeline) {
300            if (LLVM_UNLIKELY(kernel->getModule() == nullptr)) {
301                report_fatal_error(kernel->getName() + " was neither loaded from cache nor generated prior to finalizeObject");
302            }
303#ifndef ORCJIT
304            mEngine->addModule(std::unique_ptr<Module>(kernel->getModule()));
305#else
306            moduleSet.push_back(std::unique_ptr<Module>(kernel->getModule()));
307#endif
308        }
309#ifndef ORCJIT
310            mEngine->finalizeObject();
311#else
312            moduleSet.push_back(std::unique_ptr<Module>(mMainModule));
313            auto handle = mCompileLayer->addModuleSet(std::move(moduleSet), make_unique<SectionMemoryManager>(), std::move(Resolver));
314#endif
315    } catch (const std::exception & e) {
316        report_fatal_error(module->getName() + ": " + e.what());
317    }
318
319}
320
321bool ParabixDriver::hasExternalFunction(llvm::StringRef functionName) const {
322    return RTDyldMemoryManager::getSymbolAddressInProcess(functionName);
323}
324
325void * ParabixDriver::getMain() {
326#ifndef ORCJIT
327    return mEngine->getPointerToNamedFunction("Main");
328#else
329    auto MainSym = mCompileLayer->findSymbol(getMangledName("Main"), false);
330    assert (MainSym && "Main not found");
331   
332    intptr_t main = (intptr_t) MainSym.getAddress();
333    return (void *) main;
334#endif
335}
336
337void ParabixDriver::performIncrementalCacheCleanupStep() {
338    if (mCache) mCache->performIncrementalCacheCleanupStep();
339}
340
341ParabixDriver::~ParabixDriver() {
342#ifndef ORCJIT
343    delete mEngine;
344#endif
345    delete mCache;
346    delete mTarget;
347    delete mIROutputStream;
348    delete mASMOutputStream;
349 }
Note: See TracBrowser for help on using the repository browser.