source: icGREP/icgrep-devel/icgrep/toolchain/cpudriver.cpp

Last change on this file was 6297, checked in by cameron, 6 months ago

Merge branch 'master' of https://cs-git-research.cs.surrey.sfu.ca/cameron/parabix-devel

File size: 14.6 KB
Line 
1#include "cpudriver.h"
2
3#include <IR_Gen/idisa_target.h>
4#include <toolchain/toolchain.h>
5#include <llvm/Support/DynamicLibrary.h>           // for LoadLibraryPermanently
6#include <llvm/ExecutionEngine/ExecutionEngine.h>  // for EngineBuilder
7#include <llvm/ExecutionEngine/RTDyldMemoryManager.h>
8
9#include <llvm/IR/LegacyPassManager.h>             // for PassManager
10#include <llvm/IR/IRPrintingPasses.h>
11#include <llvm/InitializePasses.h>                 // for initializeCodeGencd .
12#include <llvm/PassRegistry.h>                     // for PassRegistry
13#include <llvm/Support/CodeGen.h>                  // for Level, Level::None
14#include <llvm/Support/Compiler.h>                 // for LLVM_UNLIKELY
15#include <llvm/Support/TargetSelect.h>
16#include <llvm/Support/FileSystem.h>
17#include <llvm/Target/TargetMachine.h>             // for TargetMachine, Tar...
18#include <llvm/Target/TargetOptions.h>             // for TargetOptions
19#include <llvm/Transforms/Scalar.h>
20#if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
21#include <llvm/Transforms/Scalar/GVN.h>
22#endif
23#include <llvm/Transforms/Utils/Local.h>
24#include <toolchain/object_cache.h>
25#include <kernels/kernel_builder.h>
26#include <kernels/pipeline_builder.h>
27#include <llvm/IR/Verifier.h>
28#include "llvm/IR/Mangler.h"
29#ifdef ORCJIT
30#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(4, 0, 0)
31#include <llvm/ExecutionEngine/Orc/JITSymbol.h>
32#else
33#include <llvm/ExecutionEngine/JITSymbol.h>
34#endif
35#include <llvm/ExecutionEngine/RuntimeDyld.h>
36#include <llvm/ExecutionEngine/SectionMemoryManager.h>
37#include <llvm/ExecutionEngine/Orc/CompileUtils.h>
38#include <llvm/ExecutionEngine/Orc/IRCompileLayer.h>
39#include <llvm/ExecutionEngine/Orc/IRTransformLayer.h>
40#include <llvm/ExecutionEngine/Orc/LambdaResolver.h>
41#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
42#include <llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h>
43#else
44#include <llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h>
45#endif
46#include <llvm/ExecutionEngine/Orc/GlobalMappingLayer.h>
47#endif
48
49#ifndef NDEBUG
50#define IN_DEBUG_MODE true
51#else
52#define IN_DEBUG_MODE false
53#endif
54
55using namespace llvm;
56using kernel::Kernel;
57using kernel::PipelineKernel;
58using kernel::StreamSetBuffer;
59using kernel::StreamSetBuffers;
60using kernel::KernelBuilder;
61
62CPUDriver::CPUDriver(std::string && moduleName)
63: BaseDriver(std::move(moduleName))
64, mTarget(nullptr)
65#ifndef ORCJIT
66, mEngine(nullptr)
67#endif
68, mPassManager{}
69, mUnoptimizedIROutputStream{}
70, mIROutputStream{}
71, mASMOutputStream{} {
72
73    InitializeNativeTarget();
74    InitializeNativeTargetAsmPrinter();
75    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
76
77    #ifdef ORCJIT
78    EngineBuilder builder;
79    #else
80    std::string errMessage;
81    EngineBuilder builder{std::unique_ptr<Module>(mMainModule)};
82    builder.setErrorStr(&errMessage);
83    builder.setUseOrcMCJITReplacement(true);
84    builder.setVerifyModules(false);
85    builder.setEngineKind(EngineKind::JIT);
86    #endif
87    builder.setTargetOptions(codegen::target_Options);
88    builder.setOptLevel(codegen::OptLevel);
89
90    StringMap<bool> HostCPUFeatures;
91    if (sys::getHostCPUFeatures(HostCPUFeatures)) {
92        std::vector<std::string> attrs;
93        for (auto &flag : HostCPUFeatures) {
94            if (flag.second) {
95                attrs.push_back("+" + flag.first().str());
96            }
97        }
98        builder.setMAttrs(attrs);
99    }
100
101    mTarget = builder.selectTarget();
102
103    if (mTarget == nullptr) {
104        throw std::runtime_error("Could not selectTarget");
105    }
106    #ifdef ORCJIT
107    mCompileLayer = make_unique<CompileLayerT>(mObjectLayer, orc::SimpleCompiler(*mTarget));
108    #else
109    mEngine = builder.create();
110    if (mEngine == nullptr) {
111        throw std::runtime_error("Could not create ExecutionEngine: " + errMessage);
112    }
113    #endif
114    auto cache = ParabixObjectCache::getInstance();
115    if (cache) {
116        #ifdef ORCJIT
117        #if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(5, 0, 0)
118        mCompileLayer->setObjectCache(cache);
119        #else
120        mCompileLayer->getCompiler().setObjectCache(cache);
121        #endif
122        #else
123        mEngine->setObjectCache(cache);
124        #endif
125    }
126    auto triple = mTarget->getTargetTriple().getTriple();
127    const DataLayout DL(mTarget->createDataLayout());
128    mMainModule->setTargetTriple(triple);
129    mMainModule->setDataLayout(DL);
130    iBuilder.reset(IDISA::GetIDISA_Builder(*mContext));
131    iBuilder->setDriver(this);
132    iBuilder->setModule(mMainModule);
133}
134
135Function * CPUDriver::addLinkFunction(Module * mod, llvm::StringRef name, FunctionType * type, void * functionPtr) const {
136    if (LLVM_UNLIKELY(mod == nullptr)) {
137        report_fatal_error("addLinkFunction(" + name + ") cannot be called until after addKernel");
138    }
139    Function * f = mod->getFunction(name);
140    if (LLVM_UNLIKELY(f == nullptr)) {
141        f = Function::Create(type, Function::ExternalLinkage, name, mod);
142        #ifndef ORCJIT
143        mEngine->updateGlobalMapping(f, functionPtr);
144        #endif
145    } else if (LLVM_UNLIKELY(f->getType() != type->getPointerTo())) {
146        report_fatal_error("Cannot link " + name + ": a function with a different signature already exists with that name in " + mod->getName());
147    }
148    return f;
149}
150
151std::string CPUDriver::getMangledName(std::string s) {
152    #if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 9, 0)
153    DataLayout DL(mTarget->createDataLayout());
154    std::string MangledName;
155    raw_string_ostream MangledNameStream(MangledName);
156    Mangler::getNameWithPrefix(MangledNameStream, s, DL);
157    return MangledName;
158    #else
159    return s;
160    #endif
161}
162
163inline void CPUDriver::preparePassManager() {
164
165    if (mPassManager) return;
166
167    mPassManager = make_unique<legacy::PassManager>();
168
169    PassRegistry * Registry = PassRegistry::getPassRegistry();
170    initializeCore(*Registry);
171    initializeCodeGen(*Registry);
172    initializeLowerIntrinsicsPass(*Registry);
173
174    if (LLVM_UNLIKELY(codegen::ShowUnoptimizedIROption != codegen::OmittedOption)) {
175        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
176            if (!codegen::ShowUnoptimizedIROption.empty()) {
177                std::error_code error;
178                mUnoptimizedIROutputStream = make_unique<raw_fd_ostream>(codegen::ShowUnoptimizedIROption, error, sys::fs::OpenFlags::F_None);
179            } else {
180                mUnoptimizedIROutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
181            }
182        }
183        mPassManager->add(createPrintModulePass(*mUnoptimizedIROutputStream));
184    }
185    if (IN_DEBUG_MODE || LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::VerifyIR))) {
186        mPassManager->add(createVerifierPass());
187    }
188    mPassManager->add(createDeadCodeEliminationPass());        // Eliminate any trivially dead code
189    mPassManager->add(createPromoteMemoryToRegisterPass());    // Promote stack variables to constants or PHI nodes
190    mPassManager->add(createCFGSimplificationPass());          // Remove dead basic blocks and unnecessary branch statements / phi nodes
191    mPassManager->add(createEarlyCSEPass());                   // Simple common subexpression elimination pass
192    mPassManager->add(createInstructionCombiningPass());       // Simple peephole optimizations and bit-twiddling.
193    mPassManager->add(createReassociatePass());                // Canonicalizes commutative expressions
194    mPassManager->add(createGVNPass());                        // Global value numbering redundant expression elimination pass
195    mPassManager->add(createCFGSimplificationPass());          // Repeat CFG Simplification to "clean up" any newly found redundant phi nodes
196    if (LLVM_UNLIKELY(codegen::DebugOptionIsSet(codegen::EnableAsserts))) {
197        mPassManager->add(createRemoveRedundantAssertionsPass());
198        mPassManager->add(createDeadCodeEliminationPass());
199        mPassManager->add(createCFGSimplificationPass());
200    }
201    if (LLVM_UNLIKELY(!codegen::TraceOption.empty())) {
202        mPassManager->add(createTracePass(iBuilder.get(), codegen::TraceOption));
203    }
204    if (LLVM_UNLIKELY(codegen::ShowIROption != codegen::OmittedOption)) {
205        if (LLVM_LIKELY(mIROutputStream == nullptr)) {
206            if (!codegen::ShowIROption.empty()) {
207                std::error_code error;
208                mIROutputStream = make_unique<raw_fd_ostream>(codegen::ShowIROption, error, sys::fs::OpenFlags::F_None);
209            } else {
210                mIROutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
211            }
212        }
213        mPassManager->add(createPrintModulePass(*mIROutputStream));
214    }
215    #if LLVM_VERSION_INTEGER >= LLVM_VERSION_CODE(3, 7, 0)
216    if (LLVM_UNLIKELY(codegen::ShowASMOption != codegen::OmittedOption)) {
217        if (!codegen::ShowASMOption.empty()) {
218            std::error_code error;
219            mASMOutputStream = make_unique<raw_fd_ostream>(codegen::ShowASMOption, error, sys::fs::OpenFlags::F_None);
220        } else {
221            mASMOutputStream = make_unique<raw_fd_ostream>(STDERR_FILENO, false, true);
222        }
223        if (LLVM_UNLIKELY(mTarget->addPassesToEmitFile(*mPassManager, *mASMOutputStream, TargetMachine::CGFT_AssemblyFile))) {
224            report_fatal_error("LLVM error: could not add emit assembly pass");
225        }
226    }
227    #endif
228}
229
230void CPUDriver::generateUncachedKernels() {
231    if (mUncachedKernel.empty()) return;
232    preparePassManager();
233    for (auto & kernel : mUncachedKernel) {
234        kernel->prepareKernel(iBuilder);
235    }
236    mCachedKernel.reserve(mUncachedKernel.size());
237    for (auto & kernel : mUncachedKernel) {
238        kernel->generateKernel(iBuilder);
239        Module * const module = kernel->getModule(); assert (module);
240        module->setTargetTriple(mMainModule->getTargetTriple());
241        mPassManager->run(*module);
242        mCachedKernel.emplace_back(kernel.release());
243    }
244    mUncachedKernel.clear();
245}
246
247void * CPUDriver::finalizeObject(PipelineKernel * const pipeline) {
248
249    #ifdef ORCJIT
250    auto Resolver = llvm::orc::createLambdaResolver(
251        [&](const std::string &Name) {
252            auto Sym = mCompileLayer->findSymbol(Name, false);
253            if (!Sym) Sym = mCompileLayer->findSymbol(getMangledName(Name), false);
254            #if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
255            if (Sym) return Sym.toRuntimeDyldSymbol();
256            return RuntimeDyld::SymbolInfo(nullptr);
257            #else
258            if (Sym) return Sym;
259            return JITSymbol(nullptr);
260            #endif
261        },
262        [&](const std::string &Name) {
263            auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name);
264            if (!SymAddr) SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(getMangledName(Name));
265            #if LLVM_VERSION_INTEGER <= LLVM_VERSION_CODE(3, 9, 1)
266            if (SymAddr) return RuntimeDyld::SymbolInfo(SymAddr, JITSymbolFlags::Exported);
267            return RuntimeDyld::SymbolInfo(nullptr);
268            #else
269            if (SymAddr) return JITSymbol(SymAddr, JITSymbolFlags::Exported);
270            return JITSymbol(nullptr);
271            #endif
272        });
273    #endif
274
275    iBuilder->setModule(mMainModule);
276    // write/declare the "main" method
277    const auto method = pipeline->hasStaticMain() ? PipelineKernel::DeclareExternal : PipelineKernel::AddInternal;
278    Function * const main = pipeline->addOrDeclareMainFunction(iBuilder, method);
279
280    #ifdef ORCJIT
281    std::vector<std::unique_ptr<Module>> moduleSet;
282    moduleSet.reserve(mCachedKernel.size());
283    #endif
284    for (const auto & kernel : mCachedKernel) {
285        if (LLVM_UNLIKELY(kernel->getModule() == nullptr)) {
286            report_fatal_error(kernel->getName() + " was neither loaded from cache nor generated prior to finalizeObject");
287        }
288        kernel->addKernelDeclarations(iBuilder);
289        #ifndef ORCJIT
290        mEngine->addModule(std::unique_ptr<Module>(kernel->getModule()));
291        #else
292        moduleSet.push_back(std::unique_ptr<Module>(kernel->getModule()));
293        #endif
294    }
295    mCachedKernel.clear();
296    // compile any uncompiled kernel/method
297    #ifndef ORCJIT
298    mEngine->finalizeObject();
299    #else
300    moduleSet.push_back(std::unique_ptr<Module>(mMainModule));
301    mCompileLayer->addModuleSet(std::move(moduleSet), make_unique<SectionMemoryManager>(), std::move(Resolver));
302    #endif
303    // return the compiled main method
304    #ifndef ORCJIT
305    return mEngine->getPointerToFunction(main);
306    #else
307    auto MainSym = mCompileLayer->findSymbol(getMangledName(main->getName()), false);
308    assert (MainSym && "Main not found");
309    return (void *)MainSym.getAddress();
310    #endif
311}
312
313bool CPUDriver::hasExternalFunction(llvm::StringRef functionName) const {
314    return RTDyldMemoryManager::getSymbolAddressInProcess(functionName);
315}
316
317CPUDriver::~CPUDriver() {
318    #ifndef ORCJIT
319    delete mEngine;
320    #endif
321    delete mTarget;
322}
323
324
325class TracePass : public ModulePass {
326public:
327    static char ID;
328    TracePass(kernel::KernelBuilder * kb, StringRef to_trace) : ModulePass(ID), iBuilder(kb), mToTrace(to_trace) { }
329
330    bool addTraceStmt(BasicBlock * BB, BasicBlock::iterator to_trace, BasicBlock::iterator insert_pt) {
331        bool modified = false;
332        Type * t = (*to_trace).getType();
333        //t->dump();
334        if (t == iBuilder->getBitBlockType()) {
335            iBuilder->SetInsertPoint(BB, insert_pt);
336            iBuilder->CallPrintRegister((*to_trace).getName(), &*to_trace);
337            modified = true;
338        }
339        else if (t == iBuilder->getInt64Ty()) {
340            iBuilder->SetInsertPoint(BB, insert_pt);
341            iBuilder->CallPrintInt((*to_trace).getName(), &*to_trace);
342            modified = true;
343        }
344        return modified;
345    }
346
347    virtual bool runOnModule(Module &M) override;
348private:
349    kernel::KernelBuilder * iBuilder;
350    StringRef mToTrace;
351};
352
353char TracePass::ID = 0;
354
355bool TracePass::runOnModule(Module & M) {
356    Module * saveModule = iBuilder->getModule();
357    iBuilder->setModule(&M);
358    bool modified = false;
359    for (auto & F : M) {
360        for (auto & B : F) {
361            std::vector<BasicBlock::iterator> tracedPhis;
362            BasicBlock::iterator i = B.begin();
363            while (isa<PHINode>(*i)) {
364                if ((*i).getName().startswith(mToTrace)) {
365                    tracedPhis.push_back(i);
366                }
367                ++i;
368            }
369            for (auto t : tracedPhis) {
370                modified = addTraceStmt(&B, t, i) || modified;
371            }
372            while (i != B.end()) {
373                auto i0 = i;
374                ++i;
375                if ((*i0).getName().startswith(mToTrace)) {
376                    modified = addTraceStmt(&B, i0, i) || modified;
377                }
378            }
379        }
380    }
381    //if (modified) M.dump();
382    iBuilder->setModule(saveModule);
383    return modified;
384}
385
386ModulePass * CPUDriver::createTracePass(kernel::KernelBuilder * kb, StringRef to_trace) {
387    return new TracePass(iBuilder.get(), to_trace);
388}
389
Note: See TracBrowser for help on using the repository browser.