source: icGREP/icgrep-devel/icgrep/wc.cpp @ 5033

Last change on this file since 5033 was 5033, checked in by cameron, 3 years ago

Refactor: move grep-specific code out of toolchain

File size: 16.5 KB
Line 
1/*
2 *  Copyright (c) 2015 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <string>
8#include <iostream>
9#include <iomanip>
10#include <fstream>
11#include <sstream>
12
13
14#include <toolchain.h>
15#include <llvm/IR/Function.h>
16#include <llvm/IR/Module.h>
17#include <llvm/ExecutionEngine/ExecutionEngine.h>
18#include <llvm/ExecutionEngine/MCJIT.h>
19
20#include <llvm/Support/CommandLine.h>
21#include <llvm/Support/raw_ostream.h>
22
23#include <utf_encoding.h>
24#include <re/re_cc.h>
25#include <cc/cc_compiler.h>
26#include <pablo/function.h>
27#include <IDISA/idisa_builder.h>
28#include <IDISA/idisa_target.h>
29#include <kernels/instance.h>
30#include <kernels/kernel.h>
31#include <kernels/s2p_kernel.h>
32
33#include <pablo/pablo_compiler.h>
34#include <pablo/pablo_toolchain.h>
35
36
37#include <utf_encoding.h>
38
39// mmap system
40#include <boost/filesystem.hpp>
41#include <boost/iostreams/device/mapped_file.hpp>
42using namespace boost::iostreams;
43using namespace boost::filesystem;
44
45#include <fcntl.h>
46static cl::OptionCategory wcFlags("Command Flags", "wc options");
47
48static cl::list<std::string> inputFiles(cl::Positional, cl::desc("<input file ...>"), cl::OneOrMore, cl::cat(wcFlags));
49
50enum CountOptions {
51    LineOption, WordOption, CharOption, ByteOption
52};
53
54static cl::list<CountOptions> wcOptions(
55  cl::values(clEnumValN(LineOption, "l", "Report the number of lines in each input file."),
56             clEnumValN(WordOption, "w", "Report the number of words in each input file."),
57             clEnumValN(CharOption, "m", "Report the number of characters in each input file (override -c)."),
58             clEnumValN(ByteOption, "c", "Report the number of bytes in each input file (override -m)."),
59             clEnumValEnd), cl::cat(wcFlags), cl::Grouping);
60                                                 
61static cl::OptionCategory eIRDumpOptions("LLVM IR Dump Options", "These options control dumping of LLVM IR.");
62static cl::opt<bool> DumpGeneratedIR("dump-generated-IR", cl::init(false), cl::desc("Print LLVM IR generated by Pablo Compiler."), cl::cat(eIRDumpOptions));
63
64static cl::OptionCategory cMachineCodeOptimization("Machine Code Optimizations", "These options control back-end compilier optimization levels.");
65
66static cl::opt<char> OptLevel("O", cl::desc("Optimization level. [-O0, -O1, -O2, or -O3] (default = '-O0')"),
67                              cl::cat(cMachineCodeOptimization), cl::Prefix, cl::ZeroOrMore, cl::init('0'));
68
69static cl::opt<unsigned> SegmentSize("segment-size", cl::desc("Segment Size"), cl::value_desc("positive integer"), cl::init(1));
70
71
// Minimum width of a count column in the printed report.
static int defaultFieldWidth = 7;

// Which counts are enabled; main() sets these from the command line.
bool CountLines{false};
bool CountWords{false};
bool CountChars{false};
bool CountBytes{false};

// Per-file results, indexed by the file's position in the input file list.
std::vector<uint64_t> lineCount;
std::vector<uint64_t> wordCount;
std::vector<uint64_t> charCount;
std::vector<uint64_t> byteCount;

// Running sums over all files, updated by record_counts().
uint64_t TotalLines{0};
uint64_t TotalWords{0};
uint64_t TotalChars{0};
uint64_t TotalBytes{0};
89
90
91//  The callback routine that records counts in progress.
92//
93extern "C" {
94    void record_counts(uint64_t lines, uint64_t words, uint64_t chars, uint64_t bytes, uint64_t fileIdx) {
95        lineCount[fileIdx] = lines;
96        wordCount[fileIdx] = words;
97        charCount[fileIdx] = chars;
98        byteCount[fileIdx] = bytes;
99        TotalLines += lines;
100        TotalWords += words;
101        TotalChars += chars;
102        TotalBytes += bytes;
103    }
104}
105
106//
107//
108
109pablo::PabloFunction * wc_gen(Encoding encoding) {
110    //  input: 8 basis bit streams
111    //  output: 3 count streams
112   
113    pablo::PabloFunction * function = pablo::PabloFunction::Create("wc", 8, 3);
114    cc::CC_Compiler ccc(*function, encoding);
115   
116    pablo::PabloBuilder pBuilder(ccc.getBuilder().getPabloBlock(), ccc.getBuilder());
117    const std::vector<pablo::Var *> u8_bits = ccc.getBasisBits();
118
119    if (CountLines) {
120        pablo::PabloAST * LF = ccc.compileCC(re::makeCC(0x0A));
121        function->setResult(0, pBuilder.createAssign("lineCount", pBuilder.createCount(LF)));
122    }
123    else function->setResult(0, pBuilder.createAssign("lineCount", pBuilder.createZeroes()));
124    // FIXME - we need to limit this to pablo.inFile() because null bytes past EOF are matched by wordChar
125    if (CountWords) {
126        pablo::PabloAST * WS = ccc.compileCC(re::makeCC(re::makeCC(0x09, 0x0D), re::makeCC(0x20)));
127       
128        pablo::PabloAST * wordChar = ccc.compileCC(re::makeCC(re::makeCC(re::makeCC(0x00, 0x08), re::makeCC(0xE, 0x1F)), re::makeCC(0x21, 0xFF)));
129        // WS_follow_or_start = 1 past WS or at start of file
130        pablo::PabloAST * WS_follow_or_start = pBuilder.createNot(pBuilder.createAdvance(pBuilder.createNot(WS), 1));
131        //
132        pablo::PabloAST * wordStart = pBuilder.createAnd(wordChar, WS_follow_or_start);
133        function->setResult(1, pBuilder.createAssign("wordCount", pBuilder.createCount(wordStart)));
134    }
135    else function->setResult(1, pBuilder.createAssign("wordCount", pBuilder.createZeroes()));
136    if (CountChars) {
137        //
138        // FIXME: This correctly counts characters assuming valid UTF-8 input.  But what if input is
139        // not UTF-8, or is not valid?
140        //
141        pablo::PabloAST * u8Begin = ccc.compileCC(re::makeCC(re::makeCC(0, 0x7F), re::makeCC(0xC2, 0xF4)));
142        function->setResult(2, pBuilder.createAssign("charCount", pBuilder.createCount(u8Begin)));
143    }
144    else function->setResult(2, pBuilder.createAssign("charCount", pBuilder.createZeroes()));
145    return function;
146}
147
148using namespace kernel;
149
150
151class wcPipelineBuilder {
152public:
153    wcPipelineBuilder(llvm::Module * m, IDISA::IDISA_Builder * b);
154   
155    ~wcPipelineBuilder();
156   
157    void CreateKernels(pablo::PabloFunction * function);
158    llvm::Function * ExecuteKernels();
159   
160private:
161    llvm::Module *                      mMod;
162    IDISA::IDISA_Builder *              iBuilder;
163    KernelBuilder *                     mS2PKernel;
164    KernelBuilder *                     mWC_Kernel;
165    llvm::Type *                        mBitBlockType;
166    int                                 mBlockSize;
167};
168
169
170using namespace pablo;
171using namespace kernel;
172
173wcPipelineBuilder::wcPipelineBuilder(Module * m, IDISA::IDISA_Builder * b)
174: mMod(m)
175, iBuilder(b)
176, mBitBlockType(b->getBitBlockType())
177, mBlockSize(b->getBitBlockWidth()){
178   
179}
180
181wcPipelineBuilder::~wcPipelineBuilder(){
182    delete mS2PKernel;
183    delete mWC_Kernel;
184}
185
186void wcPipelineBuilder::CreateKernels(PabloFunction * function){
187    mS2PKernel = new KernelBuilder(iBuilder, "s2p", SegmentSize);
188    mWC_Kernel = new KernelBuilder(iBuilder, "wc", SegmentSize);
189   
190    generateS2PKernel(mMod, iBuilder, mS2PKernel);
191   
192    pablo_function_passes(function);
193   
194    PabloCompiler pablo_compiler(mMod, iBuilder);
195    try {
196        pablo_compiler.setKernel(mWC_Kernel);
197        pablo_compiler.compile(function);
198        delete function;
199        releaseSlabAllocatorMemory();
200    } catch (std::runtime_error e) {
201        delete function;
202        releaseSlabAllocatorMemory();
203        std::cerr << "Runtime error: " << e.what() << std::endl;
204        exit(1);
205    }
206   
207}
208
209
210
211
212Function * wcPipelineBuilder::ExecuteKernels() {
213    Constant * record_counts_routine;
214    Type * const int64ty = iBuilder->getInt64Ty();
215    Type * const voidTy = Type::getVoidTy(mMod->getContext());
216    record_counts_routine = mMod->getOrInsertFunction("record_counts", voidTy, int64ty, int64ty, int64ty, int64ty, int64ty, nullptr);
217    Type * const inputType = PointerType::get(ArrayType::get(StructType::get(mMod->getContext(), std::vector<Type *>({ArrayType::get(mBitBlockType, 8)})), 1), 0);
218   
219    Function * const main = cast<Function>(mMod->getOrInsertFunction("Main", voidTy, inputType, int64ty, int64ty, nullptr));
220    main->setCallingConv(CallingConv::C);
221    Function::arg_iterator args = main->arg_begin();
222   
223    Value * const inputStream = &*(args++);
224    inputStream->setName("input");
225    Value * const bufferSize = &*(args++);
226    bufferSize->setName("bufferSize");
227    Value * const fileIdx = &*(args++);
228    fileIdx->setName("fileIdx");
229   
230    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", main,0));
231   
232    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
233
234    BasicBlock * segmentCondBlock = nullptr;
235    BasicBlock * segmentBodyBlock = nullptr;
236    const unsigned segmentSize = SegmentSize;
237    if (segmentSize > 1) {
238        segmentCondBlock = BasicBlock::Create(mMod->getContext(), "segmentCond", main, 0);
239        segmentBodyBlock = BasicBlock::Create(mMod->getContext(), "segmentBody", main, 0);
240    }
241    BasicBlock * fullCondBlock = BasicBlock::Create(mMod->getContext(), "fullCond", main, 0);
242    BasicBlock * fullBodyBlock = BasicBlock::Create(mMod->getContext(), "fullBody", main, 0);
243    BasicBlock * finalBlock = BasicBlock::Create(mMod->getContext(), "final", main, 0);
244    BasicBlock * finalPartialBlock = BasicBlock::Create(mMod->getContext(), "partial", main, 0);
245    BasicBlock * finalEmptyBlock = BasicBlock::Create(mMod->getContext(), "empty", main, 0);
246    BasicBlock * endBlock = BasicBlock::Create(mMod->getContext(), "end", main, 0);
247
248    Instance * s2pInstance = mS2PKernel->instantiate(inputStream);
249    Instance * wcInstance = mWC_Kernel->instantiate(s2pInstance->getOutputStreamBuffer());
250
251    Value * initialBufferSize = nullptr;
252    BasicBlock * initialBlock = nullptr;
253   
254    if (segmentSize > 1) {
255        iBuilder->CreateBr(segmentCondBlock);
256        iBuilder->SetInsertPoint(segmentCondBlock);
257        PHINode * remainingBytes = iBuilder->CreatePHI(int64ty, 2, "remainingBytes");
258        remainingBytes->addIncoming(bufferSize, entryBlock);
259        Constant * const step = ConstantInt::get(int64ty, mBlockSize * segmentSize);
260        Value * segmentCondTest = iBuilder->CreateICmpULT(remainingBytes, step);
261        iBuilder->CreateCondBr(segmentCondTest, fullCondBlock, segmentBodyBlock);
262        iBuilder->SetInsertPoint(segmentBodyBlock);
263        for (unsigned i = 0; i < segmentSize; ++i) {
264            s2pInstance->CreateDoBlockCall();
265        }
266        for (unsigned i = 0; i < segmentSize; ++i) {
267            wcInstance->CreateDoBlockCall();
268        }
269        remainingBytes->addIncoming(iBuilder->CreateSub(remainingBytes, step), segmentBodyBlock);
270        iBuilder->CreateBr(segmentCondBlock);
271        initialBufferSize = remainingBytes;
272        initialBlock = segmentCondBlock;
273    } else {
274        initialBufferSize = bufferSize;
275        initialBlock = entryBlock;
276        iBuilder->CreateBr(fullCondBlock);
277    }
278
279    iBuilder->SetInsertPoint(fullCondBlock);
280    PHINode * remainingBytes = iBuilder->CreatePHI(int64ty, 2, "remainingBytes");
281    remainingBytes->addIncoming(initialBufferSize, initialBlock);
282
283    Constant * const step = ConstantInt::get(int64ty, mBlockSize);
284    Value * fullCondTest = iBuilder->CreateICmpULT(remainingBytes, step);
285    iBuilder->CreateCondBr(fullCondTest, finalBlock, fullBodyBlock);
286   
287    iBuilder->SetInsertPoint(fullBodyBlock);
288
289    s2pInstance->CreateDoBlockCall();
290    wcInstance->CreateDoBlockCall();
291
292    Value * diff = iBuilder->CreateSub(remainingBytes, step);
293
294    remainingBytes->addIncoming(diff, fullBodyBlock);
295    iBuilder->CreateBr(fullCondBlock);
296   
297    iBuilder->SetInsertPoint(finalBlock);
298    Value * emptyBlockCond = iBuilder->CreateICmpEQ(remainingBytes, ConstantInt::get(int64ty, 0));
299    iBuilder->CreateCondBr(emptyBlockCond, finalEmptyBlock, finalPartialBlock);
300   
301   
302    iBuilder->SetInsertPoint(finalPartialBlock);
303    s2pInstance->CreateDoBlockCall();
304    iBuilder->CreateBr(endBlock);
305   
306    iBuilder->SetInsertPoint(finalEmptyBlock);
307    s2pInstance->clearOutputStreamSet();
308    iBuilder->CreateBr(endBlock);
309   
310    iBuilder->SetInsertPoint(endBlock);
311
312    wcInstance->CreateDoBlockCall();
313   
314    Value * lineCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream((int) 0)), iBuilder->getInt32(0));
315    Value * wordCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream(1)), iBuilder->getInt32(0));
316    Value * charCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream(2)), iBuilder->getInt32(0));
317   
318    iBuilder->CreateCall(record_counts_routine, std::vector<Value *>({lineCount, wordCount, charCount, bufferSize, fileIdx}));
319   
320    iBuilder->CreateRetVoid();
321   
322    return main;
323}
324
325
326typedef void (*wcFunctionType)(char * byte_data, size_t filesize, size_t fileIdx);
327
328static ExecutionEngine * wcEngine = nullptr;
329
330wcFunctionType wcCodeGen(void) {
331                           
332    Module * M = new Module("wc", getGlobalContext());
333   
334    IDISA::IDISA_Builder * idb = IDISA::GetIDISA_Builder(M);
335
336    wcPipelineBuilder pipelineBuilder(M, idb);
337
338    Encoding encoding(Encoding::Type::UTF_8, 8);
339   
340    pablo::PabloFunction * function = wc_gen(encoding);
341   
342
343    pipelineBuilder.CreateKernels(function);
344
345    llvm::Function * main_IR = pipelineBuilder.ExecuteKernels();
346   
347    if (DumpGeneratedIR) {
348        M->dump();
349    }
350   
351    //verifyModule(*M, &dbgs());
352    //std::cerr << "ExecuteKernels(); done\n";
353    wcEngine = JIT_to_ExecutionEngine(M);
354   
355    wcEngine->finalizeObject();
356    //std::cerr << "finalizeObject(); done\n";
357
358    delete idb;
359    return reinterpret_cast<wcFunctionType>(wcEngine->getPointerToFunction(main_IR));
360}
361
362void wc(wcFunctionType fn_ptr, const int64_t fileIdx) {
363    std::string fileName = inputFiles[fileIdx];
364    size_t fileSize;
365    char * fileBuffer;
366   
367    const path file(fileName);
368    if (exists(file)) {
369        if (is_directory(file)) {
370            return;
371        }
372    } else {
373        std::cerr << "Error: cannot open " << fileName << " for processing. Skipped.\n";
374        return;
375    }
376   
377    fileSize = file_size(file);
378    mapped_file_source mappedFile;
379    if (fileSize == 0) {
380        fileBuffer = nullptr;
381    }
382    else {
383        try {
384            mappedFile.open(fileName);
385        } catch (std::exception &e) {
386            std::cerr << "Error: Boost mmap of " << fileName << ": " << e.what() << std::endl;
387            return;
388        }
389        fileBuffer = const_cast<char *>(mappedFile.data());
390    }
391    fn_ptr(fileBuffer, fileSize, fileIdx);
392
393    mappedFile.close();
394   
395}
396
397
398
399
400int main(int argc, char *argv[]) {
401    HideUnrelatedOptions(wcFlags);
402
403    cl::ParseCommandLineOptions(argc, argv);
404    if (wcOptions.size() == 0) {
405        CountLines = true;
406        CountWords = true;
407        CountBytes = true;
408    }
409    else {
410        CountLines = false;
411        CountWords = false;
412        CountBytes = false;
413        CountChars = false;
414        for (unsigned i = 0; i < wcOptions.size(); i++) {
415            switch (wcOptions[i]) {
416                case WordOption: CountWords = true; break;
417                case LineOption: CountLines = true; break;
418                case CharOption: CountBytes = true; CountChars = false; break;
419                case ByteOption: CountChars = true; CountBytes = false; break;
420            }
421        }
422    }
423   
424   
425    wcFunctionType fn_ptr = wcCodeGen();
426
427    int fileCount = inputFiles.size();
428    lineCount.resize(fileCount);
429    wordCount.resize(fileCount);
430    charCount.resize(fileCount);
431    byteCount.resize(fileCount);
432   
433    for (unsigned i = 0; i < inputFiles.size(); ++i) {
434        wc(fn_ptr, i);
435    }
436   
437    delete wcEngine;
438   
439    size_t maxCount = 0;
440    if (CountLines) maxCount = TotalLines;
441    if (CountWords) maxCount = TotalWords;
442    if (CountChars) maxCount = TotalChars;
443    if (CountBytes) maxCount = TotalBytes;
444   
445    int fieldWidth = std::to_string(maxCount).size() + 1;
446    if (fieldWidth < defaultFieldWidth) fieldWidth = defaultFieldWidth;
447
448    for (unsigned i = 0; i < inputFiles.size(); ++i) {
449        std::cout << std::setw(fieldWidth-1);
450        if (CountLines) {
451            std::cout << lineCount[i] << std::setw(fieldWidth);
452        }
453        if (CountWords) {
454            std::cout << wordCount[i] << std::setw(fieldWidth);
455        }
456        if (CountChars) {
457            std::cout << charCount[i] << std::setw(fieldWidth);
458        }
459        if (CountBytes) {
460            std::cout << byteCount[i];
461        }
462        std::cout << " " << inputFiles[i] << std::endl;
463    }
464    if (inputFiles.size() > 1) {
465        std::cout << std::setw(fieldWidth-1);
466        if (CountLines) {
467            std::cout << TotalLines << std::setw(fieldWidth);
468        }
469        if (CountWords) {
470            std::cout << TotalWords << std::setw(fieldWidth);
471        }
472        if (CountChars) {
473            std::cout << TotalChars << std::setw(fieldWidth);
474        }
475        if (CountBytes) {
476            std::cout << TotalBytes;
477        }
478        std::cout << " total" << std::endl;
479    }
480
481    return 0;
482}
483
484                       
Note: See TracBrowser for help on using the repository browser.