source: icGREP/icgrep-devel/icgrep/wc.cpp @ 5065

Last change on this file since 5065 was 5065, checked in by cameron, 3 years ago

LLVM type error fix

File size: 12.8 KB
Line 
1/*
2 *  Copyright (c) 2015 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include <string>
8#include <iostream>
9#include <iomanip>
10#include <fstream>
11#include <sstream>
12
13
14#include <toolchain.h>
15#include <pablo/pablo_toolchain.h>
16#include <llvm/IR/Function.h>
17#include <llvm/IR/Module.h>
18#include <llvm/ExecutionEngine/ExecutionEngine.h>
19#include <llvm/ExecutionEngine/MCJIT.h>
20
21#include <llvm/Support/CommandLine.h>
22#include <llvm/Support/raw_ostream.h>
23
24#include <utf_encoding.h>
25#include <re/re_cc.h>
26#include <cc/cc_compiler.h>
27#include <pablo/function.h>
28#include <pablo/pablo_kernel.h>
29#include <IDISA/idisa_builder.h>
30#include <IDISA/idisa_target.h>
31#include <kernels/interface.h>
32#include <kernels/kernel.h>
33#include <kernels/s2p_kernel.h>
34
35#include <pablo/pablo_compiler.h>
36#include <pablo/pablo_toolchain.h>
37
38
39#include <utf_encoding.h>
40
41// mmap system
42#include <boost/filesystem.hpp>
43#include <boost/iostreams/device/mapped_file.hpp>
44using namespace boost::iostreams;
45using namespace boost::filesystem;
46
47#include <fcntl.h>
48static cl::OptionCategory wcFlags("Command Flags", "wc options");
49
50static cl::list<std::string> inputFiles(cl::Positional, cl::desc("<input file ...>"), cl::OneOrMore, cl::cat(wcFlags));
51
52enum CountOptions {
53    LineOption, WordOption, CharOption, ByteOption
54};
55
56static cl::list<CountOptions> wcOptions(
57  cl::values(clEnumValN(LineOption, "l", "Report the number of lines in each input file."),
58             clEnumValN(WordOption, "w", "Report the number of words in each input file."),
59             clEnumValN(CharOption, "m", "Report the number of characters in each input file (override -c)."),
60             clEnumValN(ByteOption, "c", "Report the number of bytes in each input file (override -m)."),
61             clEnumValEnd), cl::cat(wcFlags), cl::Grouping);
62                                                 
63
64
// Minimum width of each numeric output column.
static int defaultFieldWidth{7};

// Which counts have been requested; set by main() from the command line.
bool CountLines{false};
bool CountWords{false};
bool CountChars{false};
bool CountBytes{false};

// Per-file results, indexed by the file's position in the input file list.
std::vector<uint64_t> lineCount{};
std::vector<uint64_t> wordCount{};
std::vector<uint64_t> charCount{};
std::vector<uint64_t> byteCount{};

// Running totals across every file recorded so far.
uint64_t TotalLines{0};
uint64_t TotalWords{0};
uint64_t TotalChars{0};
uint64_t TotalBytes{0};
82
83
84//  The callback routine that records counts in progress.
85//
86extern "C" {
87    void record_counts(uint64_t lines, uint64_t words, uint64_t chars, uint64_t bytes, uint64_t fileIdx) {
88        lineCount[fileIdx] = lines;
89        wordCount[fileIdx] = words;
90        charCount[fileIdx] = chars;
91        byteCount[fileIdx] = bytes;
92        TotalLines += lines;
93        TotalWords += words;
94        TotalChars += chars;
95        TotalBytes += bytes;
96    }
97}
98
99//
100//
101
102pablo::PabloFunction * wc_gen(Encoding encoding) {
103    //  input: 8 basis bit streams
104    //  output: 3 counters
105   
106    pablo::PabloFunction * function = pablo::PabloFunction::Create("wc", 8, 0);
107    cc::CC_Compiler ccc(*function, encoding);
108   
109    pablo::PabloBuilder pBuilder(ccc.getBuilder().getPabloBlock(), ccc.getBuilder());
110    const std::vector<pablo::Var *> u8_bits = ccc.getBasisBits();
111
112    if (CountLines) {
113        pablo::PabloAST * LF = ccc.compileCC(re::makeCC(0x0A));
114        function->setResultCount(pBuilder.createCount("lineCount", LF));
115    }
116    if (CountWords) {
117        pablo::PabloAST * WS = ccc.compileCC(re::makeCC(re::makeCC(0x09, 0x0D), re::makeCC(0x20)));
118       
119        pablo::PabloAST * wordChar = pBuilder.createNot(WS);
120        // WS_follow_or_start = 1 past WS or at start of file
121        pablo::PabloAST * WS_follow_or_start = pBuilder.createNot(pBuilder.createAdvance(wordChar, 1));
122        //
123        pablo::PabloAST * wordStart = pBuilder.createInFile(pBuilder.createAnd(wordChar, WS_follow_or_start));
124        function->setResultCount(pBuilder.createCount("wordCount", wordStart));
125    }
126    if (CountChars) {
127        //
128        // FIXME: This correctly counts characters assuming valid UTF-8 input.  But what if input is
129        // not UTF-8, or is not valid?
130        //
131        pablo::PabloAST * u8Begin = ccc.compileCC(re::makeCC(re::makeCC(0, 0x7F), re::makeCC(0xC2, 0xF4)));
132        function->setResultCount(pBuilder.createCount("charCount", u8Begin));
133    }
134    return function;
135}
136
137using namespace kernel;
138
139
140class wcPipelineBuilder {
141public:
142    wcPipelineBuilder(llvm::Module * m, IDISA::IDISA_Builder * b);
143   
144    ~wcPipelineBuilder();
145   
146    llvm::Function * ExecuteKernels(pablo::PabloFunction * function);
147   
148private:
149    llvm::Module *                      mMod;
150    IDISA::IDISA_Builder *              iBuilder;
151    llvm::Type *                        mBitBlockType;
152    int                                 mBlockSize;
153};
154
155
156using namespace pablo;
157using namespace kernel;
158
159wcPipelineBuilder::wcPipelineBuilder(Module * m, IDISA::IDISA_Builder * b)
160: mMod(m)
161, iBuilder(b)
162, mBitBlockType(b->getBitBlockType())
163, mBlockSize(b->getBitBlockWidth()){
164   
165}
166
167wcPipelineBuilder::~wcPipelineBuilder(){
168}
169
170
171Function * wcPipelineBuilder::ExecuteKernels(PabloFunction * function) {
172    s2pKernel  s2pk(iBuilder);
173    s2pk.generateKernel();
174   
175    pablo_function_passes(function);
176    PabloKernel  wck(iBuilder, "wc", function, {"lineCount", "wordCount", "charCount"});
177    wck.prepareKernel();
178    wck.generateKernel();
179
180    Constant * record_counts_routine;
181    Type * const int64ty = iBuilder->getInt64Ty();
182    Type * const voidTy = Type::getVoidTy(mMod->getContext());
183    record_counts_routine = mMod->getOrInsertFunction("record_counts", voidTy, int64ty, int64ty, int64ty, int64ty, int64ty, nullptr);
184    Type * const inputType = PointerType::get(ArrayType::get(ArrayType::get(mBitBlockType, 8), 1), 0);
185   
186    Function * const main = cast<Function>(mMod->getOrInsertFunction("Main", voidTy, inputType, int64ty, int64ty, nullptr));
187    main->setCallingConv(CallingConv::C);
188    Function::arg_iterator args = main->arg_begin();
189   
190    Value * const inputStream = &*(args++);
191    inputStream->setName("input");
192    Value * const bufferSize = &*(args++);
193    bufferSize->setName("bufferSize");
194    Value * const fileIdx = &*(args++);
195    fileIdx->setName("fileIdx");
196   
197    iBuilder->SetInsertPoint(BasicBlock::Create(mMod->getContext(), "entry", main,0));
198   
199    BasicBlock * entryBlock = iBuilder->GetInsertBlock();
200
201    BasicBlock * fullCondBlock = BasicBlock::Create(mMod->getContext(), "fullCond", main, 0);
202    BasicBlock * fullBodyBlock = BasicBlock::Create(mMod->getContext(), "fullBody", main, 0);
203    BasicBlock * finalBlock = BasicBlock::Create(mMod->getContext(), "final", main, 0);
204
205    StreamSetBuffer ByteStream(iBuilder, StreamSetType(1, 8), 0);
206    StreamSetBuffer BasisBits(iBuilder, StreamSetType(8, 1), 1);
207    ByteStream.setStreamSetBuffer(inputStream);
208    Value * basisBits = BasisBits.allocateBuffer();
209
210    Value * s2pInstance = s2pk.createInstance({});
211    Value * wcInstance = wck.createInstance({});
212   
213    Value * initialBufferSize = bufferSize;
214    BasicBlock * initialBlock = entryBlock;
215    Value * initialBlockNo = iBuilder->getInt64(0);
216
217    iBuilder->CreateBr(fullCondBlock);
218
219   
220    iBuilder->SetInsertPoint(fullCondBlock);
221    PHINode * remainingBytes = iBuilder->CreatePHI(int64ty, 2, "remainingBytes");
222    remainingBytes->addIncoming(initialBufferSize, initialBlock);
223    PHINode * blockNo = iBuilder->CreatePHI(int64ty, 2, "blockNo");
224    blockNo->addIncoming(initialBlockNo, initialBlock);
225
226    Constant * const step = ConstantInt::get(int64ty, mBlockSize);
227    Value * fullCondTest = iBuilder->CreateICmpULT(remainingBytes, step);
228    iBuilder->CreateCondBr(fullCondTest, finalBlock, fullBodyBlock);
229   
230    iBuilder->SetInsertPoint(fullBodyBlock);
231
232    s2pk.createDoBlockCall(s2pInstance, {ByteStream.getBlockPointer(blockNo), basisBits});
233    wck.createDoBlockCall(wcInstance, {basisBits});
234
235    Value * diff = iBuilder->CreateSub(remainingBytes, step);
236
237    remainingBytes->addIncoming(diff, fullBodyBlock);
238    blockNo->addIncoming(iBuilder->CreateAdd(blockNo, iBuilder->getInt64(1)), fullBodyBlock);
239    iBuilder->CreateBr(fullCondBlock);
240   
241    iBuilder->SetInsertPoint(finalBlock);
242    s2pk.createFinalBlockCall(s2pInstance, remainingBytes, {ByteStream.getBlockPointer(blockNo), basisBits});
243    wck.createFinalBlockCall(wcInstance, remainingBytes, {basisBits});
244   
245    Value * lineCount = wck.createGetAccumulatorCall(wcInstance, "lineCount");
246    Value * wordCount = wck.createGetAccumulatorCall(wcInstance, "wordCount");
247    Value * charCount = wck.createGetAccumulatorCall(wcInstance, "charCount");;
248
249    iBuilder->CreateCall(record_counts_routine, std::vector<Value *>({lineCount, wordCount, charCount, bufferSize, fileIdx}));
250   
251    iBuilder->CreateRetVoid();
252    return main;
253}
254
255
256typedef void (*wcFunctionType)(char * byte_data, size_t filesize, size_t fileIdx);
257
258static ExecutionEngine * wcEngine = nullptr;
259
260wcFunctionType wcCodeGen(void) {
261                           
262    Module * M = new Module("wc", getGlobalContext());
263    IDISA::IDISA_Builder * idb = IDISA::GetIDISA_Builder(M);
264
265    wcPipelineBuilder pipelineBuilder(M, idb);
266    Encoding encoding(Encoding::Type::UTF_8, 8);
267    pablo::PabloFunction * function = wc_gen(encoding);
268    llvm::Function * main_IR = pipelineBuilder.ExecuteKernels(function);
269
270    wcEngine = JIT_to_ExecutionEngine(M);
271   
272    wcEngine->finalizeObject();
273
274    delete idb;
275    return reinterpret_cast<wcFunctionType>(wcEngine->getPointerToFunction(main_IR));
276}
277
278void wc(wcFunctionType fn_ptr, const int64_t fileIdx) {
279    std::string fileName = inputFiles[fileIdx];
280    size_t fileSize;
281    char * fileBuffer;
282   
283    const path file(fileName);
284    if (exists(file)) {
285        if (is_directory(file)) {
286            return;
287        }
288    } else {
289        std::cerr << "Error: cannot open " << fileName << " for processing. Skipped.\n";
290        return;
291    }
292   
293    fileSize = file_size(file);
294    mapped_file_source mappedFile;
295    if (fileSize == 0) {
296        fileBuffer = nullptr;
297    }
298    else {
299        try {
300            mappedFile.open(fileName);
301        } catch (std::exception &e) {
302            std::cerr << "Error: Boost mmap of " << fileName << ": " << e.what() << std::endl;
303            return;
304        }
305        fileBuffer = const_cast<char *>(mappedFile.data());
306    }
307    fn_ptr(fileBuffer, fileSize, fileIdx);
308
309    mappedFile.close();
310   
311}
312
313
314
315
316int main(int argc, char *argv[]) {
317    cl::HideUnrelatedOptions(ArrayRef<const cl::OptionCategory *>{&wcFlags, pablo::pablo_toolchain_flags(), codegen::codegen_flags()});
318    cl::ParseCommandLineOptions(argc, argv);
319    if (wcOptions.size() == 0) {
320        CountLines = true;
321        CountWords = true;
322        CountBytes = true;
323    }
324    else {
325        CountLines = false;
326        CountWords = false;
327        CountBytes = false;
328        CountChars = false;
329        for (unsigned i = 0; i < wcOptions.size(); i++) {
330            switch (wcOptions[i]) {
331                case WordOption: CountWords = true; break;
332                case LineOption: CountLines = true; break;
333                case CharOption: CountBytes = true; CountChars = false; break;
334                case ByteOption: CountChars = true; CountBytes = false; break;
335            }
336        }
337    }
338   
339   
340    wcFunctionType fn_ptr = wcCodeGen();
341
342    int fileCount = inputFiles.size();
343    lineCount.resize(fileCount);
344    wordCount.resize(fileCount);
345    charCount.resize(fileCount);
346    byteCount.resize(fileCount);
347   
348    for (unsigned i = 0; i < inputFiles.size(); ++i) {
349        wc(fn_ptr, i);
350    }
351   
352    delete wcEngine;
353   
354    size_t maxCount = 0;
355    if (CountLines) maxCount = TotalLines;
356    if (CountWords) maxCount = TotalWords;
357    if (CountChars) maxCount = TotalChars;
358    if (CountBytes) maxCount = TotalBytes;
359   
360    int fieldWidth = std::to_string(maxCount).size() + 1;
361    if (fieldWidth < defaultFieldWidth) fieldWidth = defaultFieldWidth;
362
363    for (unsigned i = 0; i < inputFiles.size(); ++i) {
364        std::cout << std::setw(fieldWidth-1);
365        if (CountLines) {
366            std::cout << lineCount[i] << std::setw(fieldWidth);
367        }
368        if (CountWords) {
369            std::cout << wordCount[i] << std::setw(fieldWidth);
370        }
371        if (CountChars) {
372            std::cout << charCount[i] << std::setw(fieldWidth);
373        }
374        if (CountBytes) {
375            std::cout << byteCount[i];
376        }
377        std::cout << " " << inputFiles[i] << std::endl;
378    }
379    if (inputFiles.size() > 1) {
380        std::cout << std::setw(fieldWidth-1);
381        if (CountLines) {
382            std::cout << TotalLines << std::setw(fieldWidth);
383        }
384        if (CountWords) {
385            std::cout << TotalWords << std::setw(fieldWidth);
386        }
387        if (CountChars) {
388            std::cout << TotalChars << std::setw(fieldWidth);
389        }
390        if (CountBytes) {
391            std::cout << TotalBytes;
392        }
393        std::cout << " total" << std::endl;
394    }
395
396    return 0;
397}
398
399                       
Note: See TracBrowser for help on using the repository browser.