source: icGREP/icgrep-devel/icgrep/wc.cpp @ 5035

Last change on this file since 5035 was 5035, checked in by cameron, 3 years ago

Add EOFmask internal state value to generated Pablo functions; implement pablo.inFile

File size: 15.6 KB
Line 
1/*
2 *  Copyright (c) 2015 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <string>
8#include <iostream>
9#include <iomanip>
10#include <fstream>
11#include <sstream>
12
13
14#include <toolchain.h>
15#include <llvm/IR/Function.h>
16#include <llvm/IR/Module.h>
17#include <llvm/ExecutionEngine/ExecutionEngine.h>
18#include <llvm/ExecutionEngine/MCJIT.h>
19
20#include <llvm/Support/CommandLine.h>
21#include <llvm/Support/raw_ostream.h>
22
23#include <utf_encoding.h>
24#include <re/re_cc.h>
25#include <cc/cc_compiler.h>
26#include <pablo/function.h>
27#include <IDISA/idisa_builder.h>
28#include <IDISA/idisa_target.h>
29#include <kernels/instance.h>
30#include <kernels/kernel.h>
31#include <kernels/s2p_kernel.h>
32
33#include <pablo/pablo_compiler.h>
34#include <pablo/pablo_toolchain.h>
35
36
37#include <utf_encoding.h>
38
39// mmap system
40#include <boost/filesystem.hpp>
41#include <boost/iostreams/device/mapped_file.hpp>
42using namespace boost::iostreams;
43using namespace boost::filesystem;
44
45#include <fcntl.h>
46static cl::OptionCategory wcFlags("Command Flags", "wc options");
47
48static cl::list<std::string> inputFiles(cl::Positional, cl::desc("<input file ...>"), cl::OneOrMore, cl::cat(wcFlags));
49
50enum CountOptions {
51    LineOption, WordOption, CharOption, ByteOption
52};
53
54static cl::list<CountOptions> wcOptions(
55  cl::values(clEnumValN(LineOption, "l", "Report the number of lines in each input file."),
56             clEnumValN(WordOption, "w", "Report the number of words in each input file."),
57             clEnumValN(CharOption, "m", "Report the number of characters in each input file (override -c)."),
58             clEnumValN(ByteOption, "c", "Report the number of bytes in each input file (override -m)."),
59             clEnumValEnd), cl::cat(wcFlags), cl::Grouping);
60                                                 
61
62
// Minimum width of a report column; main() widens columns if a total needs more digits.
static int defaultFieldWidth{7};

// Which counts to compute and report; assigned from the command line in main().
bool CountLines{false};
bool CountWords{false};
bool CountChars{false};
bool CountBytes{false};

// Per-file results, indexed by the file's position in inputFiles.
// main() sizes these vectors; record_counts() fills them in.
std::vector<uint64_t> lineCount{};
std::vector<uint64_t> wordCount{};
std::vector<uint64_t> charCount{};
std::vector<uint64_t> byteCount{};

// Running totals across all processed files, accumulated by record_counts().
uint64_t TotalLines{0};
uint64_t TotalWords{0};
uint64_t TotalChars{0};
uint64_t TotalBytes{0};
80
81
82//  The callback routine that records counts in progress.
83//
84extern "C" {
85    void record_counts(uint64_t lines, uint64_t words, uint64_t chars, uint64_t bytes, uint64_t fileIdx) {
86        lineCount[fileIdx] = lines;
87        wordCount[fileIdx] = words;
88        charCount[fileIdx] = chars;
89        byteCount[fileIdx] = bytes;
90        TotalLines += lines;
91        TotalWords += words;
92        TotalChars += chars;
93        TotalBytes += bytes;
94    }
95}
96
97//
98//
99
100pablo::PabloFunction * wc_gen(Encoding encoding) {
101    //  input: 8 basis bit streams
102    //  output: 3 count streams
103   
104    pablo::PabloFunction * function = pablo::PabloFunction::Create("wc", 8, 3);
105    cc::CC_Compiler ccc(*function, encoding);
106   
107    pablo::PabloBuilder pBuilder(ccc.getBuilder().getPabloBlock(), ccc.getBuilder());
108    const std::vector<pablo::Var *> u8_bits = ccc.getBasisBits();
109
110    if (CountLines) {
111        pablo::PabloAST * LF = ccc.compileCC(re::makeCC(0x0A));
112        function->setResult(0, pBuilder.createAssign("lineCount", pBuilder.createCount(LF)));
113    }
114    else function->setResult(0, pBuilder.createAssign("lineCount", pBuilder.createZeroes()));
115    if (CountWords) {
116        pablo::PabloAST * WS = ccc.compileCC(re::makeCC(re::makeCC(0x09, 0x0D), re::makeCC(0x20)));
117       
118        pablo::PabloAST * wordChar = pBuilder.createNot(WS);
119        // WS_follow_or_start = 1 past WS or at start of file
120        pablo::PabloAST * WS_follow_or_start = pBuilder.createNot(pBuilder.createAdvance(wordChar, 1));
121        //
122        pablo::PabloAST * wordStart = pBuilder.createInFile(pBuilder.createAnd(wordChar, WS_follow_or_start));
123        function->setResult(1, pBuilder.createAssign("wordCount", pBuilder.createCount(wordStart)));
124    }
125    else function->setResult(1, pBuilder.createAssign("wordCount", pBuilder.createZeroes()));
126    if (CountChars) {
127        //
128        // FIXME: This correctly counts characters assuming valid UTF-8 input.  But what if input is
129        // not UTF-8, or is not valid?
130        //
131        pablo::PabloAST * u8Begin = ccc.compileCC(re::makeCC(re::makeCC(0, 0x7F), re::makeCC(0xC2, 0xF4)));
132        function->setResult(2, pBuilder.createAssign("charCount", pBuilder.createCount(u8Begin)));
133    }
134    else function->setResult(2, pBuilder.createAssign("charCount", pBuilder.createZeroes()));
135    return function;
136}
137
138using namespace kernel;
139
140
141class wcPipelineBuilder {
142public:
143    wcPipelineBuilder(llvm::Module * m, IDISA::IDISA_Builder * b);
144   
145    ~wcPipelineBuilder();
146   
147    void CreateKernels(pablo::PabloFunction * function);
148    llvm::Function * ExecuteKernels();
149   
150private:
151    llvm::Module *                      mMod;
152    IDISA::IDISA_Builder *              iBuilder;
153    KernelBuilder *                     mS2PKernel;
154    KernelBuilder *                     mWC_Kernel;
155    llvm::Type *                        mBitBlockType;
156    int                                 mBlockSize;
157};
158
159
160using namespace pablo;
161using namespace kernel;
162
163wcPipelineBuilder::wcPipelineBuilder(Module * m, IDISA::IDISA_Builder * b)
164: mMod(m)
165, iBuilder(b)
166, mBitBlockType(b->getBitBlockType())
167, mBlockSize(b->getBitBlockWidth()){
168   
169}
170
171wcPipelineBuilder::~wcPipelineBuilder(){
172    delete mS2PKernel;
173    delete mWC_Kernel;
174}
175
176void wcPipelineBuilder::CreateKernels(PabloFunction * function){
177    mS2PKernel = new KernelBuilder(iBuilder, "s2p", codegen::SegmentSize);
178    mWC_Kernel = new KernelBuilder(iBuilder, "wc", codegen::SegmentSize);
179   
180    generateS2PKernel(mMod, iBuilder, mS2PKernel);
181   
182    pablo_function_passes(function);
183   
184    PabloCompiler pablo_compiler(mMod, iBuilder);
185    try {
186        pablo_compiler.setKernel(mWC_Kernel);
187        pablo_compiler.compile(function);
188        delete function;
189        releaseSlabAllocatorMemory();
190    } catch (std::runtime_error e) {
191        delete function;
192        releaseSlabAllocatorMemory();
193        std::cerr << "Runtime error: " << e.what() << std::endl;
194        exit(1);
195    }
196   
197}
198
199
200
201
202Function * wcPipelineBuilder::ExecuteKernels() {
203    Constant * record_counts_routine;
204    Type * const int64ty = iBuilder->getInt64Ty();
205    Type * const voidTy = Type::getVoidTy(mMod->getContext());
206    record_counts_routine = mMod->getOrInsertFunction("record_counts", voidTy, int64ty, int64ty, int64ty, int64ty, int64ty, nullptr);
207    Type * const inputType = PointerType::get(ArrayType::get(StructType::get(mMod->getContext(), std::vector<Type *>({ArrayType::get(mBitBlockType, 8)})), 1), 0);
208   
209    Function * const main = cast<Function>(mMod->getOrInsertFunction("Main", voidTy, inputType, int64ty, int64ty, nullptr));
210    main->setCallingConv(CallingConv::C);
211    Function::arg_iterator args = main->arg_begin();
212   
213    Value * const inputStream = &*(args++);
214    inputStream->setName("input");
215    Value * const bufferSize = &*(args++);
216    bufferSize->setName("bufferSize");
217    Value * const fileIdx = &*(args++);
218    fileIdx->setName("fileIdx");
219   
220    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", main,0));
221   
222    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
223
224    BasicBlock * segmentCondBlock = nullptr;
225    BasicBlock * segmentBodyBlock = nullptr;
226    const unsigned segmentSize = codegen::SegmentSize;
227    if (segmentSize > 1) {
228        segmentCondBlock = BasicBlock::Create(mMod->getContext(), "segmentCond", main, 0);
229        segmentBodyBlock = BasicBlock::Create(mMod->getContext(), "segmentBody", main, 0);
230    }
231    BasicBlock * fullCondBlock = BasicBlock::Create(mMod->getContext(), "fullCond", main, 0);
232    BasicBlock * fullBodyBlock = BasicBlock::Create(mMod->getContext(), "fullBody", main, 0);
233    BasicBlock * finalBlock = BasicBlock::Create(mMod->getContext(), "final", main, 0);
234    BasicBlock * finalPartialBlock = BasicBlock::Create(mMod->getContext(), "partial", main, 0);
235    BasicBlock * finalEmptyBlock = BasicBlock::Create(mMod->getContext(), "empty", main, 0);
236    BasicBlock * endBlock = BasicBlock::Create(mMod->getContext(), "end", main, 0);
237
238    Instance * s2pInstance = mS2PKernel->instantiate(inputStream);
239    Instance * wcInstance = mWC_Kernel->instantiate(s2pInstance->getOutputStreamBuffer());
240
241    Value * initialBufferSize = nullptr;
242    BasicBlock * initialBlock = nullptr;
243   
244    if (segmentSize > 1) {
245        iBuilder->CreateBr(segmentCondBlock);
246        iBuilder->SetInsertPoint(segmentCondBlock);
247        PHINode * remainingBytes = iBuilder->CreatePHI(int64ty, 2, "remainingBytes");
248        remainingBytes->addIncoming(bufferSize, entryBlock);
249        Constant * const step = ConstantInt::get(int64ty, mBlockSize * segmentSize);
250        Value * segmentCondTest = iBuilder->CreateICmpULT(remainingBytes, step);
251        iBuilder->CreateCondBr(segmentCondTest, fullCondBlock, segmentBodyBlock);
252        iBuilder->SetInsertPoint(segmentBodyBlock);
253        for (unsigned i = 0; i < segmentSize; ++i) {
254            s2pInstance->CreateDoBlockCall();
255        }
256        for (unsigned i = 0; i < segmentSize; ++i) {
257            wcInstance->CreateDoBlockCall();
258        }
259        remainingBytes->addIncoming(iBuilder->CreateSub(remainingBytes, step), segmentBodyBlock);
260        iBuilder->CreateBr(segmentCondBlock);
261        initialBufferSize = remainingBytes;
262        initialBlock = segmentCondBlock;
263    } else {
264        initialBufferSize = bufferSize;
265        initialBlock = entryBlock;
266        iBuilder->CreateBr(fullCondBlock);
267    }
268
269    iBuilder->SetInsertPoint(fullCondBlock);
270    PHINode * remainingBytes = iBuilder->CreatePHI(int64ty, 2, "remainingBytes");
271    remainingBytes->addIncoming(initialBufferSize, initialBlock);
272
273    Constant * const step = ConstantInt::get(int64ty, mBlockSize);
274    Value * fullCondTest = iBuilder->CreateICmpULT(remainingBytes, step);
275    iBuilder->CreateCondBr(fullCondTest, finalBlock, fullBodyBlock);
276   
277    iBuilder->SetInsertPoint(fullBodyBlock);
278
279    s2pInstance->CreateDoBlockCall();
280    wcInstance->CreateDoBlockCall();
281
282    Value * diff = iBuilder->CreateSub(remainingBytes, step);
283
284    remainingBytes->addIncoming(diff, fullBodyBlock);
285    iBuilder->CreateBr(fullCondBlock);
286   
287    iBuilder->SetInsertPoint(finalBlock);
288    Value * EOF_mask = iBuilder->CreateShl(Constant::getAllOnesValue(iBuilder->getIntNTy(mBlockSize)), remainingBytes);
289        wcInstance->setInternalState("EOFmask", iBuilder->CreateBitCast(EOF_mask, mBitBlockType));
290   
291    Value * emptyBlockCond = iBuilder->CreateICmpEQ(remainingBytes, ConstantInt::get(int64ty, 0));
292    iBuilder->CreateCondBr(emptyBlockCond, finalEmptyBlock, finalPartialBlock);
293   
294   
295    iBuilder->SetInsertPoint(finalPartialBlock);
296    s2pInstance->CreateDoBlockCall();
297
298    iBuilder->CreateBr(endBlock);
299   
300    iBuilder->SetInsertPoint(finalEmptyBlock);
301    s2pInstance->clearOutputStreamSet();
302    iBuilder->CreateBr(endBlock);
303   
304    iBuilder->SetInsertPoint(endBlock);
305
306    wcInstance->CreateDoBlockCall();
307   
308    Value * lineCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream((int) 0)), iBuilder->getInt32(0));
309    Value * wordCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream(1)), iBuilder->getInt32(0));
310    Value * charCount = iBuilder->CreateExtractElement(iBuilder->CreateBlockAlignedLoad(wcInstance->getOutputStream(2)), iBuilder->getInt32(0));
311   
312    iBuilder->CreateCall(record_counts_routine, std::vector<Value *>({lineCount, wordCount, charCount, bufferSize, fileIdx}));
313   
314    iBuilder->CreateRetVoid();
315   
316    return main;
317}
318
319
// Signature of the JIT-compiled "Main" driver: (file bytes, file size, file index).
using wcFunctionType = void (*)(char * byte_data, size_t filesize, size_t fileIdx);
321
322static ExecutionEngine * wcEngine = nullptr;
323
324wcFunctionType wcCodeGen(void) {
325                           
326    Module * M = new Module("wc", getGlobalContext());
327    IDISA::IDISA_Builder * idb = IDISA::GetIDISA_Builder(M);
328
329    wcPipelineBuilder pipelineBuilder(M, idb);
330    Encoding encoding(Encoding::Type::UTF_8, 8);
331    pablo::PabloFunction * function = wc_gen(encoding);
332    pipelineBuilder.CreateKernels(function);
333    llvm::Function * main_IR = pipelineBuilder.ExecuteKernels();
334
335    wcEngine = JIT_to_ExecutionEngine(M);
336   
337    wcEngine->finalizeObject();
338
339    delete idb;
340    return reinterpret_cast<wcFunctionType>(wcEngine->getPointerToFunction(main_IR));
341}
342
343void wc(wcFunctionType fn_ptr, const int64_t fileIdx) {
344    std::string fileName = inputFiles[fileIdx];
345    size_t fileSize;
346    char * fileBuffer;
347   
348    const path file(fileName);
349    if (exists(file)) {
350        if (is_directory(file)) {
351            return;
352        }
353    } else {
354        std::cerr << "Error: cannot open " << fileName << " for processing. Skipped.\n";
355        return;
356    }
357   
358    fileSize = file_size(file);
359    mapped_file_source mappedFile;
360    if (fileSize == 0) {
361        fileBuffer = nullptr;
362    }
363    else {
364        try {
365            mappedFile.open(fileName);
366        } catch (std::exception &e) {
367            std::cerr << "Error: Boost mmap of " << fileName << ": " << e.what() << std::endl;
368            return;
369        }
370        fileBuffer = const_cast<char *>(mappedFile.data());
371    }
372    fn_ptr(fileBuffer, fileSize, fileIdx);
373
374    mappedFile.close();
375   
376}
377
378
379
380
381int main(int argc, char *argv[]) {
382    HideUnrelatedOptions(wcFlags);
383
384    cl::ParseCommandLineOptions(argc, argv);
385    if (wcOptions.size() == 0) {
386        CountLines = true;
387        CountWords = true;
388        CountBytes = true;
389    }
390    else {
391        CountLines = false;
392        CountWords = false;
393        CountBytes = false;
394        CountChars = false;
395        for (unsigned i = 0; i < wcOptions.size(); i++) {
396            switch (wcOptions[i]) {
397                case WordOption: CountWords = true; break;
398                case LineOption: CountLines = true; break;
399                case CharOption: CountBytes = true; CountChars = false; break;
400                case ByteOption: CountChars = true; CountBytes = false; break;
401            }
402        }
403    }
404   
405   
406    wcFunctionType fn_ptr = wcCodeGen();
407
408    int fileCount = inputFiles.size();
409    lineCount.resize(fileCount);
410    wordCount.resize(fileCount);
411    charCount.resize(fileCount);
412    byteCount.resize(fileCount);
413   
414    for (unsigned i = 0; i < inputFiles.size(); ++i) {
415        wc(fn_ptr, i);
416    }
417   
418    delete wcEngine;
419   
420    size_t maxCount = 0;
421    if (CountLines) maxCount = TotalLines;
422    if (CountWords) maxCount = TotalWords;
423    if (CountChars) maxCount = TotalChars;
424    if (CountBytes) maxCount = TotalBytes;
425   
426    int fieldWidth = std::to_string(maxCount).size() + 1;
427    if (fieldWidth < defaultFieldWidth) fieldWidth = defaultFieldWidth;
428
429    for (unsigned i = 0; i < inputFiles.size(); ++i) {
430        std::cout << std::setw(fieldWidth-1);
431        if (CountLines) {
432            std::cout << lineCount[i] << std::setw(fieldWidth);
433        }
434        if (CountWords) {
435            std::cout << wordCount[i] << std::setw(fieldWidth);
436        }
437        if (CountChars) {
438            std::cout << charCount[i] << std::setw(fieldWidth);
439        }
440        if (CountBytes) {
441            std::cout << byteCount[i];
442        }
443        std::cout << " " << inputFiles[i] << std::endl;
444    }
445    if (inputFiles.size() > 1) {
446        std::cout << std::setw(fieldWidth-1);
447        if (CountLines) {
448            std::cout << TotalLines << std::setw(fieldWidth);
449        }
450        if (CountWords) {
451            std::cout << TotalWords << std::setw(fieldWidth);
452        }
453        if (CountChars) {
454            std::cout << TotalChars << std::setw(fieldWidth);
455        }
456        if (CountBytes) {
457            std::cout << TotalBytes;
458        }
459        std::cout << " total" << std::endl;
460    }
461
462    return 0;
463}
464
465                       
Note: See TracBrowser for help on using the repository browser.