source: icGREP/icgrep-devel/icgrep/wc.cpp @ 5034

Last change on this file since 5034 was 5034, checked in by cameron, 3 years ago

Clean out duplicate parameters for wc

File size: 15.6 KB
Line 
1/*
2 *  Copyright (c) 2015 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <string>
8#include <iostream>
9#include <iomanip>
10#include <fstream>
11#include <sstream>
12
13
14#include <toolchain.h>
15#include <llvm/IR/Function.h>
16#include <llvm/IR/Module.h>
17#include <llvm/ExecutionEngine/ExecutionEngine.h>
18#include <llvm/ExecutionEngine/MCJIT.h>
19
20#include <llvm/Support/CommandLine.h>
21#include <llvm/Support/raw_ostream.h>
22
23#include <utf_encoding.h>
24#include <re/re_cc.h>
25#include <cc/cc_compiler.h>
26#include <pablo/function.h>
27#include <IDISA/idisa_builder.h>
28#include <IDISA/idisa_target.h>
29#include <kernels/instance.h>
30#include <kernels/kernel.h>
31#include <kernels/s2p_kernel.h>
32
33#include <pablo/pablo_compiler.h>
34#include <pablo/pablo_toolchain.h>
35
36
37#include <utf_encoding.h>
38
39// mmap system
40#include <boost/filesystem.hpp>
41#include <boost/iostreams/device/mapped_file.hpp>
42using namespace boost::iostreams;
43using namespace boost::filesystem;
44
45#include <fcntl.h>
46static cl::OptionCategory wcFlags("Command Flags", "wc options");
47
48static cl::list<std::string> inputFiles(cl::Positional, cl::desc("<input file ...>"), cl::OneOrMore, cl::cat(wcFlags));
49
50enum CountOptions {
51    LineOption, WordOption, CharOption, ByteOption
52};
53
54static cl::list<CountOptions> wcOptions(
55  cl::values(clEnumValN(LineOption, "l", "Report the number of lines in each input file."),
56             clEnumValN(WordOption, "w", "Report the number of words in each input file."),
57             clEnumValN(CharOption, "m", "Report the number of characters in each input file (override -c)."),
58             clEnumValN(ByteOption, "c", "Report the number of bytes in each input file (override -m)."),
59             clEnumValEnd), cl::cat(wcFlags), cl::Grouping);
60                                                 
61
62
// Minimum width of each printed count column.
static int defaultFieldWidth = 7;

// Which counts were requested; set by main() from the command line.
bool CountLines(false);
bool CountWords(false);
bool CountChars(false);
bool CountBytes(false);

// Per-file results, indexed by the position of the file in inputFiles.
std::vector<uint64_t> lineCount;
std::vector<uint64_t> wordCount;
std::vector<uint64_t> charCount;
std::vector<uint64_t> byteCount;

// Running totals accumulated over every processed file.
uint64_t TotalLines(0);
uint64_t TotalWords(0);
uint64_t TotalChars(0);
uint64_t TotalBytes(0);
80
81
82//  The callback routine that records counts in progress.
83//
84extern "C" {
85    void record_counts(uint64_t lines, uint64_t words, uint64_t chars, uint64_t bytes, uint64_t fileIdx) {
86        lineCount[fileIdx] = lines;
87        wordCount[fileIdx] = words;
88        charCount[fileIdx] = chars;
89        byteCount[fileIdx] = bytes;
90        TotalLines += lines;
91        TotalWords += words;
92        TotalChars += chars;
93        TotalBytes += bytes;
94    }
95}
96
97//
98//
99
100pablo::PabloFunction * wc_gen(Encoding encoding) {
101    //  input: 8 basis bit streams
102    //  output: 3 count streams
103   
104    pablo::PabloFunction * function = pablo::PabloFunction::Create("wc", 8, 3);
105    cc::CC_Compiler ccc(*function, encoding);
106   
107    pablo::PabloBuilder pBuilder(ccc.getBuilder().getPabloBlock(), ccc.getBuilder());
108    const std::vector<pablo::Var *> u8_bits = ccc.getBasisBits();
109
110    if (CountLines) {
111        pablo::PabloAST * LF = ccc.compileCC(re::makeCC(0x0A));
112        function->setResult(0, pBuilder.createAssign("lineCount", pBuilder.createCount(LF)));
113    }
114    else function->setResult(0, pBuilder.createAssign("lineCount", pBuilder.createZeroes()));
115    // FIXME - we need to limit this to pablo.inFile() because null bytes past EOF are matched by wordChar
116    if (CountWords) {
117        pablo::PabloAST * WS = ccc.compileCC(re::makeCC(re::makeCC(0x09, 0x0D), re::makeCC(0x20)));
118       
119        pablo::PabloAST * wordChar = ccc.compileCC(re::makeCC(re::makeCC(re::makeCC(0x00, 0x08), re::makeCC(0xE, 0x1F)), re::makeCC(0x21, 0xFF)));
120        // WS_follow_or_start = 1 past WS or at start of file
121        pablo::PabloAST * WS_follow_or_start = pBuilder.createNot(pBuilder.createAdvance(pBuilder.createNot(WS), 1));
122        //
123        pablo::PabloAST * wordStart = pBuilder.createAnd(wordChar, WS_follow_or_start);
124        function->setResult(1, pBuilder.createAssign("wordCount", pBuilder.createCount(wordStart)));
125    }
126    else function->setResult(1, pBuilder.createAssign("wordCount", pBuilder.createZeroes()));
127    if (CountChars) {
128        //
129        // FIXME: This correctly counts characters assuming valid UTF-8 input.  But what if input is
130        // not UTF-8, or is not valid?
131        //
132        pablo::PabloAST * u8Begin = ccc.compileCC(re::makeCC(re::makeCC(0, 0x7F), re::makeCC(0xC2, 0xF4)));
133        function->setResult(2, pBuilder.createAssign("charCount", pBuilder.createCount(u8Begin)));
134    }
135    else function->setResult(2, pBuilder.createAssign("charCount", pBuilder.createZeroes()));
136    return function;
137}
138
139using namespace kernel;
140
141
142class wcPipelineBuilder {
143public:
144    wcPipelineBuilder(llvm::Module * m, IDISA::IDISA_Builder * b);
145   
146    ~wcPipelineBuilder();
147   
148    void CreateKernels(pablo::PabloFunction * function);
149    llvm::Function * ExecuteKernels();
150   
151private:
152    llvm::Module *                      mMod;
153    IDISA::IDISA_Builder *              iBuilder;
154    KernelBuilder *                     mS2PKernel;
155    KernelBuilder *                     mWC_Kernel;
156    llvm::Type *                        mBitBlockType;
157    int                                 mBlockSize;
158};
159
160
161using namespace pablo;
162using namespace kernel;
163
164wcPipelineBuilder::wcPipelineBuilder(Module * m, IDISA::IDISA_Builder * b)
165: mMod(m)
166, iBuilder(b)
167, mBitBlockType(b->getBitBlockType())
168, mBlockSize(b->getBitBlockWidth()){
169   
170}
171
172wcPipelineBuilder::~wcPipelineBuilder(){
173    delete mS2PKernel;
174    delete mWC_Kernel;
175}
176
177void wcPipelineBuilder::CreateKernels(PabloFunction * function){
178    mS2PKernel = new KernelBuilder(iBuilder, "s2p", codegen::SegmentSize);
179    mWC_Kernel = new KernelBuilder(iBuilder, "wc", codegen::SegmentSize);
180   
181    generateS2PKernel(mMod, iBuilder, mS2PKernel);
182   
183    pablo_function_passes(function);
184   
185    PabloCompiler pablo_compiler(mMod, iBuilder);
186    try {
187        pablo_compiler.setKernel(mWC_Kernel);
188        pablo_compiler.compile(function);
189        delete function;
190        releaseSlabAllocatorMemory();
191    } catch (std::runtime_error e) {
192        delete function;
193        releaseSlabAllocatorMemory();
194        std::cerr << "Runtime error: " << e.what() << std::endl;
195        exit(1);
196    }
197   
198}
199
200
201
202
203Function * wcPipelineBuilder::ExecuteKernels() {
204    Constant * record_counts_routine;
205    Type * const int64ty = iBuilder->getInt64Ty();
206    Type * const voidTy = Type::getVoidTy(mMod->getContext());
207    record_counts_routine = mMod->getOrInsertFunction("record_counts", voidTy, int64ty, int64ty, int64ty, int64ty, int64ty, nullptr);
208    Type * const inputType = PointerType::get(ArrayType::get(StructType::get(mMod->getContext(), std::vector<Type *>({ArrayType::get(mBitBlockType, 8)})), 1), 0);
209   
210    Function * const main = cast<Function>(mMod->getOrInsertFunction("Main", voidTy, inputType, int64ty, int64ty, nullptr));
211    main->setCallingConv(CallingConv::C);
212    Function::arg_iterator args = main->arg_begin();
213   
214    Value * const inputStream = &*(args++);
215    inputStream->setName("input");
216    Value * const bufferSize = &*(args++);
217    bufferSize->setName("bufferSize");
218    Value * const fileIdx = &*(args++);
219    fileIdx->setName("fileIdx");
220   
221    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", main,0));
222   
223    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
224
225    BasicBlock * segmentCondBlock = nullptr;
226    BasicBlock * segmentBodyBlock = nullptr;
227    const unsigned segmentSize = codegen::SegmentSize;
228    if (segmentSize > 1) {
229        segmentCondBlock = BasicBlock::Create(mMod->getContext(), "segmentCond", main, 0);
230        segmentBodyBlock = BasicBlock::Create(mMod->getContext(), "segmentBody", main, 0);
231    }
232    BasicBlock * fullCondBlock = BasicBlock::Create(mMod->getContext(), "fullCond", main, 0);
233    BasicBlock * fullBodyBlock = BasicBlock::Create(mMod->getContext(), "fullBody", main, 0);
234    BasicBlock * finalBlock = BasicBlock::Create(mMod->getContext(), "final", main, 0);
235    BasicBlock * finalPartialBlock = BasicBlock::Create(mMod->getContext(), "partial", main, 0);
236    BasicBlock * finalEmptyBlock = BasicBlock::Create(mMod->getContext(), "empty", main, 0);
237    BasicBlock * endBlock = BasicBlock::Create(mMod->getContext(), "end", main, 0);
238
239    Instance * s2pInstance = mS2PKernel->instantiate(inputStream);
240    Instance * wcInstance = mWC_Kernel->instantiate(s2pInstance->getOutputStreamBuffer());
241
242    Value * initialBufferSize = nullptr;
243    BasicBlock * initialBlock = nullptr;
244   
245    if (segmentSize > 1) {
246        iBuilder->CreateBr(segmentCondBlock);
247        iBuilder->SetInsertPoint(segmentCondBlock);
248        PHINode * remainingBytes = iBuilder->CreatePHI(int64ty, 2, "remainingBytes");
249        remainingBytes->addIncoming(bufferSize, entryBlock);
250        Constant * const step = ConstantInt::get(int64ty, mBlockSize * segmentSize);
251        Value * segmentCondTest = iBuilder->CreateICmpULT(remainingBytes, step);
252        iBuilder->CreateCondBr(segmentCondTest, fullCondBlock, segmentBodyBlock);
253        iBuilder->SetInsertPoint(segmentBodyBlock);
254        for (unsigned i = 0; i < segmentSize; ++i) {
255            s2pInstance->CreateDoBlockCall();
256        }
257        for (unsigned i = 0; i < segmentSize; ++i) {
258            wcInstance->CreateDoBlockCall();
259        }
260        remainingBytes->addIncoming(iBuilder->CreateSub(remainingBytes, step), segmentBodyBlock);
261        iBuilder->CreateBr(segmentCondBlock);
262        initialBufferSize = remainingBytes;
263        initialBlock = segmentCondBlock;
264    } else {
265        initialBufferSize = bufferSize;
266        initialBlock = entryBlock;
267        iBuilder->CreateBr(fullCondBlock);
268    }
269
270    iBuilder->SetInsertPoint(fullCondBlock);
271    PHINode * remainingBytes = iBuilder->CreatePHI(int64ty, 2, "remainingBytes");
272    remainingBytes->addIncoming(initialBufferSize, initialBlock);
273
274    Constant * const step = ConstantInt::get(int64ty, mBlockSize);
275    Value * fullCondTest = iBuilder->CreateICmpULT(remainingBytes, step);
276    iBuilder->CreateCondBr(fullCondTest, finalBlock, fullBodyBlock);
277   
278    iBuilder->SetInsertPoint(fullBodyBlock);
279
280    s2pInstance->CreateDoBlockCall();
281    wcInstance->CreateDoBlockCall();
282
283    Value * diff = iBuilder->CreateSub(remainingBytes, step);
284
285    remainingBytes->addIncoming(diff, fullBodyBlock);
286    iBuilder->CreateBr(fullCondBlock);
287   
288    iBuilder->SetInsertPoint(finalBlock);
289    Value * emptyBlockCond = iBuilder->CreateICmpEQ(remainingBytes, ConstantInt::get(int64ty, 0));
290    iBuilder->CreateCondBr(emptyBlockCond, finalEmptyBlock, finalPartialBlock);
291   
292   
293    iBuilder->SetInsertPoint(finalPartialBlock);
294    s2pInstance->CreateDoBlockCall();
295    iBuilder->CreateBr(endBlock);
296   
297    iBuilder->SetInsertPoint(finalEmptyBlock);
298    s2pInstance->clearOutputStreamSet();
299    iBuilder->CreateBr(endBlock);
300   
301    iBuilder->SetInsertPoint(endBlock);
302
303    wcInstance->CreateDoBlockCall();
304   
305    Value * lineCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream((int) 0)), iBuilder->getInt32(0));
306    Value * wordCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream(1)), iBuilder->getInt32(0));
307    Value * charCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream(2)), iBuilder->getInt32(0));
308   
309    iBuilder->CreateCall(record_counts_routine, std::vector<Value *>({lineCount, wordCount, charCount, bufferSize, fileIdx}));
310   
311    iBuilder->CreateRetVoid();
312   
313    return main;
314}
315
316
317typedef void (*wcFunctionType)(char * byte_data, size_t filesize, size_t fileIdx);
318
319static ExecutionEngine * wcEngine = nullptr;
320
321wcFunctionType wcCodeGen(void) {
322                           
323    Module * M = new Module("wc", getGlobalContext());
324    IDISA::IDISA_Builder * idb = IDISA::GetIDISA_Builder(M);
325
326    wcPipelineBuilder pipelineBuilder(M, idb);
327    Encoding encoding(Encoding::Type::UTF_8, 8);
328    pablo::PabloFunction * function = wc_gen(encoding);
329    pipelineBuilder.CreateKernels(function);
330    llvm::Function * main_IR = pipelineBuilder.ExecuteKernels();
331
332    wcEngine = JIT_to_ExecutionEngine(M);
333   
334    wcEngine->finalizeObject();
335
336    delete idb;
337    return reinterpret_cast<wcFunctionType>(wcEngine->getPointerToFunction(main_IR));
338}
339
340void wc(wcFunctionType fn_ptr, const int64_t fileIdx) {
341    std::string fileName = inputFiles[fileIdx];
342    size_t fileSize;
343    char * fileBuffer;
344   
345    const path file(fileName);
346    if (exists(file)) {
347        if (is_directory(file)) {
348            return;
349        }
350    } else {
351        std::cerr << "Error: cannot open " << fileName << " for processing. Skipped.\n";
352        return;
353    }
354   
355    fileSize = file_size(file);
356    mapped_file_source mappedFile;
357    if (fileSize == 0) {
358        fileBuffer = nullptr;
359    }
360    else {
361        try {
362            mappedFile.open(fileName);
363        } catch (std::exception &e) {
364            std::cerr << "Error: Boost mmap of " << fileName << ": " << e.what() << std::endl;
365            return;
366        }
367        fileBuffer = const_cast<char *>(mappedFile.data());
368    }
369    fn_ptr(fileBuffer, fileSize, fileIdx);
370
371    mappedFile.close();
372   
373}
374
375
376
377
378int main(int argc, char *argv[]) {
379    HideUnrelatedOptions(wcFlags);
380
381    cl::ParseCommandLineOptions(argc, argv);
382    if (wcOptions.size() == 0) {
383        CountLines = true;
384        CountWords = true;
385        CountBytes = true;
386    }
387    else {
388        CountLines = false;
389        CountWords = false;
390        CountBytes = false;
391        CountChars = false;
392        for (unsigned i = 0; i < wcOptions.size(); i++) {
393            switch (wcOptions[i]) {
394                case WordOption: CountWords = true; break;
395                case LineOption: CountLines = true; break;
396                case CharOption: CountBytes = true; CountChars = false; break;
397                case ByteOption: CountChars = true; CountBytes = false; break;
398            }
399        }
400    }
401   
402   
403    wcFunctionType fn_ptr = wcCodeGen();
404
405    int fileCount = inputFiles.size();
406    lineCount.resize(fileCount);
407    wordCount.resize(fileCount);
408    charCount.resize(fileCount);
409    byteCount.resize(fileCount);
410   
411    for (unsigned i = 0; i < inputFiles.size(); ++i) {
412        wc(fn_ptr, i);
413    }
414   
415    delete wcEngine;
416   
417    size_t maxCount = 0;
418    if (CountLines) maxCount = TotalLines;
419    if (CountWords) maxCount = TotalWords;
420    if (CountChars) maxCount = TotalChars;
421    if (CountBytes) maxCount = TotalBytes;
422   
423    int fieldWidth = std::to_string(maxCount).size() + 1;
424    if (fieldWidth < defaultFieldWidth) fieldWidth = defaultFieldWidth;
425
426    for (unsigned i = 0; i < inputFiles.size(); ++i) {
427        std::cout << std::setw(fieldWidth-1);
428        if (CountLines) {
429            std::cout << lineCount[i] << std::setw(fieldWidth);
430        }
431        if (CountWords) {
432            std::cout << wordCount[i] << std::setw(fieldWidth);
433        }
434        if (CountChars) {
435            std::cout << charCount[i] << std::setw(fieldWidth);
436        }
437        if (CountBytes) {
438            std::cout << byteCount[i];
439        }
440        std::cout << " " << inputFiles[i] << std::endl;
441    }
442    if (inputFiles.size() > 1) {
443        std::cout << std::setw(fieldWidth-1);
444        if (CountLines) {
445            std::cout << TotalLines << std::setw(fieldWidth);
446        }
447        if (CountWords) {
448            std::cout << TotalWords << std::setw(fieldWidth);
449        }
450        if (CountChars) {
451            std::cout << TotalChars << std::setw(fieldWidth);
452        }
453        if (CountBytes) {
454            std::cout << TotalBytes;
455        }
456        std::cout << " total" << std::endl;
457    }
458
459    return 0;
460}
461
462                       
Note: See TracBrowser for help on using the repository browser.