source: icGREP/icgrep-devel/icgrep/lz4/grep/lz4_grep_base_generator.cpp @ 6143

Last change on this file since 6143 was 6143, checked in by xwa163, 10 months ago

lz4_grep: some bug fixing and refactor

File size: 18.3 KB
Line 
1
2#include "lz4_grep_base_generator.h"
3
4#include <boost/iostreams/device/mapped_file.hpp>
5
6#include <llvm/Support/PrettyStackTrace.h>
7
8#include <cc/cc_compiler.h>
9
10#include <kernels/cc_kernel.h>
11#include <kernels/s2p_kernel.h>
12#include <kernels/p2s_kernel.h>
13#include <kernels/source_kernel.h>
14#include <kernels/stdout_kernel.h>
15#include <kernels/fake_stream_generating_kernel.h>
16#include <re/re_toolchain.h>
17
18#include <re/collect_ccs.h>
19#include <re/replaceCC.h>
20
21#include <re/casing.h>
22#include <re/exclude_CC.h>
23#include <re/to_utf8.h>
24#include <re/re_analysis.h>
25#include <re/re_name_resolve.h>
26#include <re/re_name_gather.h>
27#include <re/re_multiplex.h>
28#include <re/re_utility.h>
29
30#include <UCD/resolve_properties.h>
31#include <kernels/charclasses.h>
32#include <kernels/grep_kernel.h>
33#include <kernels/UCD_property_kernel.h>
34#include <kernels/grapheme_kernel.h>
35#include <kernels/linebreak_kernel.h>
36#include <kernels/streams_merge.h>
37#include <kernels/scanmatchgen.h>
38#include <kernels/until_n.h>
39#include <re/grapheme_clusters.h>
40#include <re/printer_re.h>
41#include <llvm/Support/raw_ostream.h>
42#include <llvm/Support/Debug.h>
43#include <kernels/lz4/lz4_block_decoder.h>
44
45#include <re/re_seq.h>
46#include <kernels/kernel_builder.h>
47
48namespace re { class CC; }
49
50using namespace llvm;
51using namespace parabix;
52using namespace kernel;
53using namespace grep;
54
55LZ4GrepBaseGenerator::LZ4GrepBaseGenerator(): LZ4BaseGenerator() {
56    mGrepRecordBreak = grep::GrepRecordBreakKind::LF;
57    mMoveMatchesToEOL = true;
58}
59
60
61
62void LZ4GrepBaseGenerator::generateScanMatchGrepPipeline(re::RE* regex) {
63    auto & iBuilder = mPxDriver.getBuilder();
64    this->generateScanMatchMainFunc(iBuilder);
65
66    StreamSetBuffer* compressedByteStream = this->loadByteStream();
67
68    StreamSetBuffer * const uncompressedByteStream = this->byteStreamDecompression(compressedByteStream);
69    StreamSetBuffer * uncompressedBitStream = this->s2p(uncompressedByteStream);
70
71    StreamSetBuffer * LineBreakStream;
72    StreamSetBuffer * Matches;
73    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedBitStream);
74
75    kernel::Kernel * scanMatchK = mPxDriver.addKernelInstance<kernel::ScanMatchKernel>(iBuilder);
76    scanMatchK->setInitialArguments({match_accumulator});
77    mPxDriver.makeKernelCall(scanMatchK, {Matches, LineBreakStream, uncompressedByteStream}, {});
78    mPxDriver.LinkFunction(*scanMatchK, "accumulate_match_wrapper", &accumulate_match_wrapper);
79    mPxDriver.LinkFunction(*scanMatchK, "finalize_match_wrapper", &finalize_match_wrapper);
80
81    mPxDriver.generatePipelineIR();
82    mPxDriver.deallocateBuffers();
83
84    iBuilder->CreateRetVoid();
85
86    mPxDriver.finalizeObject();
87}
88
89
90void LZ4GrepBaseGenerator::generateCountOnlyGrepPipeline(re::RE* regex, bool enableMultiplexing, bool utf8CC){
91    if (enableMultiplexing) {
92        this->generateMultiplexingCountOnlyGrepPipeline(regex, utf8CC);
93    } else {
94        this->generateFullyDecompressionCountOnlyGrepPipeline(regex);
95    }
96}
97
98
99void LZ4GrepBaseGenerator::initREs(re::RE * RE) {
100    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
101        mBreakCC = re::makeCC(re::makeCC(0x0A, 0x0D), re::makeCC(re::makeCC(0x85), re::makeCC(0x2028, 0x2029)));
102    } else if (mGrepRecordBreak == GrepRecordBreakKind::Null) {
103        mBreakCC = re::makeByte(0);  // Null
104    } else {
105        mBreakCC = re::makeByte(0x0A); // LF
106    }
107    re::RE * anchorRE = mBreakCC;
108    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
109        re::Name * anchorName = re::makeName("UTF8_LB", re::Name::Type::Unicode);
110        anchorName->setDefinition(re::makeUnicodeBreak());
111        anchorRE = anchorName;
112    }
113
114    mRE = RE;
115    bool allAnchored = true;
116
117    if (!hasEndAnchor(mRE)) allAnchored = false;
118    mRE = resolveModesAndExternalSymbols(mRE);
119    mRE = re::exclude_CC(mRE, mBreakCC);
120    mRE = resolveAnchors(mRE, anchorRE);
121    re::gatherUnicodeProperties(mRE, mUnicodeProperties);
122    mRE = regular_expression_passes(mRE);
123
124    if (allAnchored && (mGrepRecordBreak != GrepRecordBreakKind::Unicode)) mMoveMatchesToEOL = false;
125
126}
127
128
129parabix::StreamSetBuffer * LZ4GrepBaseGenerator::linefeedStreamFromUncompressedBits(
130        parabix::StreamSetBuffer *uncompressedBasisBits) {
131    auto & idb = mPxDriver.getBuilder();
132    const unsigned baseBufferSize = this->getDefaultBufferBlocks();
133    StreamSetBuffer * LineFeedStream = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
134    kernel::Kernel * linefeedK = mPxDriver.addKernelInstance<kernel::LineFeedKernelBuilder>(idb, Binding{idb->getStreamSetTy(8), "basis", FixedRate(), Principal()}, cc::BitNumbering::BigEndian);
135    mPxDriver.makeKernelCall(linefeedK, {uncompressedBasisBits}, {LineFeedStream});
136    return LineFeedStream;
137}
138
139
140
141
142
143
144std::pair<parabix::StreamSetBuffer *, parabix::StreamSetBuffer *> LZ4GrepBaseGenerator::multiplexingGrep(
145        re::RE *RE,
146        parabix::StreamSetBuffer *compressedByteStream,
147        parabix::StreamSetBuffer *compressedBitStream,
148        bool utf8CC
149) {
150
151    this->initREs(RE);
152    auto mGrepDriver = &mPxDriver;
153
154    auto & idb = mGrepDriver->getBuilder();
155    const unsigned baseBufferSize = this->getDefaultBufferBlocks();
156    int MaxCountFlag = 0;
157
158    //  Regular Expression Processing and Analysis Phase
159    const auto nREs = 1;
160
161    std::vector<StreamSetBuffer *> MatchResultsBufs(nREs);
162
163
164    std::map<std::string, StreamSetBuffer *> propertyStream;
165
166    std::vector<std::string> externalStreamNames;
167    std::set<re::Name *> UnicodeProperties;
168
169
170    re::Seq* seq = re::makeSeq();
171    re::RE* targetRe = mRE;
172
173
174    re::CC* linefeedCC = nullptr;
175    if (utf8CC) {
176        linefeedCC = re::makeCC(0x0A);
177    } else {
178        linefeedCC = re::makeByte(0x0A);
179        targetRe = re::toUTF8(targetRe, true);
180    }
181
182    seq->push_back(targetRe);
183    seq->push_back(std::move(linefeedCC));
184
185    std::vector<re::CC*> UnicodeSets;
186    if (utf8CC) {
187        UnicodeSets = re::collectCCs(seq, &cc::Unicode, std::set<re::Name *>({re::makeZeroWidth("\\b{g}")}));
188    } else {
189        UnicodeSets = re::collectCCs(seq, &cc::Byte, std::set<re::Name *>({re::makeZeroWidth("\\b{g}")}));
190    }
191
192    StreamSetBuffer * const MatchResults = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
193
194    mpx = make_unique<cc::MultiplexedAlphabet>("mpx", UnicodeSets);
195    mRE = transformCCs(mpx.get(), targetRe);
196
197
198    std::vector<re::CC *> mpx_basis = mpx->getMultiplexedCCs();
199    auto numOfCharacterClasses = mpx_basis.size();
200    StreamSetBuffer * CharClasses = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(numOfCharacterClasses), baseBufferSize, 1);
201
202    kernel::Kernel * ccK = nullptr;
203    if (utf8CC) {
204        ccK = mGrepDriver->addKernelInstance<kernel::CharClassesKernel>(idb, std::move(mpx_basis), false, cc::BitNumbering::BigEndian);
205    } else {
206        ccK = mGrepDriver->addKernelInstance<kernel::ByteClassesKernel>(idb, std::move(mpx_basis), false, cc::BitNumbering::BigEndian);
207    }
208    mGrepDriver->makeKernelCall(ccK, {compressedBitStream}, {CharClasses});
209
210    StreamSetBuffer * uncompressedCharClasses = this->decompressBitStream(compressedByteStream, CharClasses);
211
212
213    StreamSetBuffer * fakeMatchCopiedBits = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(8),
214                                                                              this->getDefaultBufferBlocks(), 1);
215    StreamSetBuffer * u8NoFinalStream = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1),
216                                                                          this->getDefaultBufferBlocks(), 1);
217
218    Kernel* fakeStreamGeneratorK = mPxDriver.addKernelInstance<FakeStreamGeneratingKernel>(idb, numOfCharacterClasses, std::vector<unsigned>({8, 1}));
219    mPxDriver.makeKernelCall(fakeStreamGeneratorK, {uncompressedCharClasses}, {fakeMatchCopiedBits, u8NoFinalStream});
220
221    StreamSetBuffer * LineBreakStream = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), this->getDefaultBufferBlocks(), 1);
222    kernel::Kernel * lineFeedGrepK = mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, transformCCs(mpx.get(), linefeedCC), externalStreamNames, std::vector<cc::Alphabet *>{mpx.get()}, cc::BitNumbering::BigEndian);
223    mGrepDriver->makeKernelCall(lineFeedGrepK, {fakeMatchCopiedBits, uncompressedCharClasses}, {LineBreakStream});
224
225
226    externalStreamNames.push_back("UTF8_nonfinal");
227
228    kernel::Kernel * icgrepK = mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, mRE, externalStreamNames, std::vector<cc::Alphabet *>{mpx.get()}, cc::BitNumbering::BigEndian);
229    mGrepDriver->makeKernelCall(icgrepK, {fakeMatchCopiedBits, u8NoFinalStream, uncompressedCharClasses}, {MatchResults});
230    MatchResultsBufs[0] = MatchResults;
231
232    StreamSetBuffer * MergedResults = MatchResultsBufs[0];
233
234    StreamSetBuffer * Matches = MergedResults;
235    if (mMoveMatchesToEOL) {
236        StreamSetBuffer * OriginalMatches = Matches;
237        kernel::Kernel * matchedLinesK = mGrepDriver->addKernelInstance<kernel::MatchedLinesKernel>(idb);
238        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
239        mGrepDriver->makeKernelCall(matchedLinesK, {OriginalMatches, LineBreakStream}, {Matches});
240    }
241
242    if (MaxCountFlag > 0) {
243        kernel::Kernel * untilK = mGrepDriver->addKernelInstance<kernel::UntilNkernel>(idb);
244        untilK->setInitialArguments({idb->getSize(MaxCountFlag)});
245        StreamSetBuffer * const AllMatches = Matches;
246        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
247        mGrepDriver->makeKernelCall(untilK, {AllMatches}, {Matches});
248    }
249
250    return std::pair<StreamSetBuffer *, StreamSetBuffer *>(LineBreakStream, Matches);
251};
252
253std::pair<parabix::StreamSetBuffer *, parabix::StreamSetBuffer *> LZ4GrepBaseGenerator::grep(
254        re::RE *RE, parabix::StreamSetBuffer *uncompressedBasisBits, bool ccMultiplexing) {
255
256    this->initREs(RE);
257    auto mGrepDriver = &mPxDriver;
258
259    auto & idb = mGrepDriver->getBuilder();
260    // TODO: until we automate stream buffer sizing, use this calculation to determine how large our matches buffer needs to be.
261    const unsigned baseBufferSize = this->getDefaultBufferBlocks();
262    int MaxCountFlag = 0;
263
264    //  Regular Expression Processing and Analysis Phase
265    const auto nREs = 1;
266
267    std::vector<StreamSetBuffer *> MatchResultsBufs(nREs);
268
269    StreamSetBuffer * LineBreakStream = this->linefeedStreamFromUncompressedBits(uncompressedBasisBits);
270
271
272    std::map<std::string, StreamSetBuffer *> propertyStream;
273
274    for(unsigned i = 0; i < nREs; ++i) {
275        std::vector<std::string> externalStreamNames;
276        std::vector<StreamSetBuffer *> icgrepInputSets = {uncompressedBasisBits};
277
278        if (ccMultiplexing) {
279            const auto UnicodeSets = re::collectCCs(mRE, &cc::Unicode, std::set<re::Name *>({re::makeZeroWidth("\\b{g}")}));
280            StreamSetBuffer * const MatchResults = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
281
282            std::unique_ptr<cc::MultiplexedAlphabet> mpx = make_unique<cc::MultiplexedAlphabet>("mpx", UnicodeSets);
283            mRE = transformCCs(mpx.get(), mRE);
284            std::vector<re::CC *> mpx_basis = mpx->getMultiplexedCCs();
285            auto numOfCharacterClasses = mpx_basis.size();
286            StreamSetBuffer * CharClasses = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(numOfCharacterClasses), baseBufferSize, 1);
287            kernel::Kernel * ccK = mGrepDriver->addKernelInstance<kernel::CharClassesKernel>(idb, std::move(mpx_basis), false, cc::BitNumbering::BigEndian);
288            mGrepDriver->makeKernelCall(ccK, {uncompressedBasisBits}, {CharClasses});
289
290            kernel::Kernel * icgrepK = mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, mRE, externalStreamNames, std::vector<cc::Alphabet *>{mpx.get()}, cc::BitNumbering::BigEndian);
291            icgrepInputSets.push_back(CharClasses);
292            mGrepDriver->makeKernelCall(icgrepK, icgrepInputSets, {MatchResults});
293            MatchResultsBufs[i] = MatchResults;
294        } else {
295            std::set<re::Name *> UnicodeProperties;
296
297            StreamSetBuffer * MatchResults = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
298            kernel::Kernel * icgrepK = mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, mRE, externalStreamNames, std::vector<cc::Alphabet *>(), cc::BitNumbering::BigEndian);
299            mGrepDriver->makeKernelCall(icgrepK, icgrepInputSets, {MatchResults});
300            MatchResultsBufs[i] = MatchResults;
301        }
302    }
303
304    StreamSetBuffer * MergedResults = MatchResultsBufs[0];
305
306    StreamSetBuffer * Matches = MergedResults;
307    if (mMoveMatchesToEOL) {
308        StreamSetBuffer * OriginalMatches = Matches;
309        kernel::Kernel * matchedLinesK = mGrepDriver->addKernelInstance<kernel::MatchedLinesKernel>(idb);
310        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
311        mGrepDriver->makeKernelCall(matchedLinesK, {OriginalMatches, LineBreakStream}, {Matches});
312    }
313
314    if (MaxCountFlag > 0) {
315        kernel::Kernel * untilK = mGrepDriver->addKernelInstance<kernel::UntilNkernel>(idb);
316        untilK->setInitialArguments({idb->getSize(MaxCountFlag)});
317        StreamSetBuffer * const AllMatches = Matches;
318        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
319        mGrepDriver->makeKernelCall(untilK, {AllMatches}, {Matches});
320    }
321
322    return std::pair<StreamSetBuffer *, StreamSetBuffer *>(LineBreakStream, Matches);
323
324}
325
326void LZ4GrepBaseGenerator::invokeScanMatchGrep(char* fileBuffer, size_t blockStart, size_t blockEnd, bool hasBlockChecksum) {
327    auto main = this->getScanMatchGrepMainFunction();
328    std::ostringstream s;
329    EmitMatch accum("", false, false, s);
330
331    main(fileBuffer, blockStart, blockEnd, hasBlockChecksum, reinterpret_cast<intptr_t>(&accum));
332    llvm::outs() << s.str();
333}
334
335
336
337void LZ4GrepBaseGenerator::generateMultiplexingCountOnlyGrepPipeline(re::RE *regex, bool utf8CC) {
338    auto & iBuilder = mPxDriver.getBuilder();
339    this->generateCountOnlyMainFunc(iBuilder);
340
341    StreamSetBuffer *compressedByteStream = nullptr, *compressedBasisBits = nullptr;
342    std::tie(compressedByteStream, compressedBasisBits) = this->loadByteStreamAndBitStream();
343
344    StreamSetBuffer * LineBreakStream;
345    StreamSetBuffer * Matches;
346    std::tie(LineBreakStream, Matches) = multiplexingGrep(regex, compressedByteStream, compressedBasisBits, utf8CC);
347
348    kernel::Kernel * matchCountK = mPxDriver.addKernelInstance<kernel::PopcountKernel>(iBuilder);
349    mPxDriver.makeKernelCall(matchCountK, {Matches}, {});
350    mPxDriver.generatePipelineIR();
351
352    iBuilder->setKernel(matchCountK);
353    Value * matchedLineCount = iBuilder->getAccumulator("countResult");
354    matchedLineCount = iBuilder->CreateZExt(matchedLineCount, iBuilder->getInt64Ty());
355
356    mPxDriver.deallocateBuffers();
357
358    iBuilder->CreateRet(matchedLineCount);
359
360    mPxDriver.finalizeObject();
361}
362
363
364void LZ4GrepBaseGenerator::generateFullyDecompressionCountOnlyGrepPipeline(re::RE *regex) {
365    auto & iBuilder = mPxDriver.getBuilder();
366    this->generateCountOnlyMainFunc(iBuilder);
367
368    StreamSetBuffer * const uncompressedBitStream = this->generateUncompressedBitStreams();
369
370    StreamSetBuffer * LineBreakStream;
371    StreamSetBuffer * Matches;
372
373    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedBitStream);
374
375    kernel::Kernel * matchCountK = mPxDriver.addKernelInstance<kernel::PopcountKernel>(iBuilder);
376    mPxDriver.makeKernelCall(matchCountK, {Matches}, {});
377    mPxDriver.generatePipelineIR();
378
379    iBuilder->setKernel(matchCountK);
380    Value * matchedLineCount = iBuilder->getAccumulator("countResult");
381    matchedLineCount = iBuilder->CreateZExt(matchedLineCount, iBuilder->getInt64Ty());
382
383    mPxDriver.deallocateBuffers();
384
385    iBuilder->CreateRet(matchedLineCount);
386
387    mPxDriver.finalizeObject();
388}
389
390
391ScanMatchGrepMainFunctionType LZ4GrepBaseGenerator::getScanMatchGrepMainFunction() {
392    return reinterpret_cast<ScanMatchGrepMainFunctionType>(mPxDriver.getMain());
393}
394CountOnlyGrepMainFunctionType LZ4GrepBaseGenerator::getCountOnlyGrepMainFunction() {
395    return reinterpret_cast<CountOnlyGrepMainFunctionType>(mPxDriver.getMain());
396}
397
398void LZ4GrepBaseGenerator::generateCountOnlyMainFunc(const std::unique_ptr<kernel::KernelBuilder> & iBuilder) {
399    Module * M = iBuilder->getModule();
400    Type * const int64Ty = iBuilder->getInt64Ty();
401    Type * const sizeTy = iBuilder->getSizeTy();
402    Type * const boolTy = iBuilder->getIntNTy(sizeof(bool) * 8);
403//    Type * const voidTy = iBuilder->getVoidTy();
404    Type * const inputType = iBuilder->getInt8PtrTy();
405
406    Function * const main = cast<Function>(M->getOrInsertFunction("Main", int64Ty, inputType, sizeTy, sizeTy, boolTy, nullptr));
407    main->setCallingConv(CallingConv::C);
408    Function::arg_iterator args = main->arg_begin();
409    mInputStream = &*(args++);
410    mInputStream->setName("input");
411
412    mHeaderSize = &*(args++);
413    mHeaderSize->setName("mHeaderSize");
414
415    mFileSize = &*(args++);
416    mFileSize->setName("mFileSize");
417
418    mHasBlockChecksum = &*(args++);
419    mHasBlockChecksum->setName("mHasBlockChecksum");
420    // TODO for now, we do not handle blockCheckSum
421    mHasBlockChecksum = iBuilder->getInt1(false);
422
423    iBuilder->SetInsertPoint(BasicBlock::Create(M->getContext(), "entry", main, 0));
424}
425
426void LZ4GrepBaseGenerator::generateScanMatchMainFunc(const std::unique_ptr<kernel::KernelBuilder> & iBuilder) {
427    Module * M = iBuilder->getModule();
428    Type * const sizeTy = iBuilder->getSizeTy();
429    Type * const boolTy = iBuilder->getIntNTy(sizeof(bool) * 8);
430    Type * const voidTy = iBuilder->getVoidTy();
431    Type * const inputType = iBuilder->getInt8PtrTy();
432    Type * const intAddrTy = iBuilder->getIntAddrTy();
433
434    Function * const main = cast<Function>(M->getOrInsertFunction("Main", voidTy, inputType, sizeTy, sizeTy, boolTy, intAddrTy, nullptr));
435    main->setCallingConv(CallingConv::C);
436    Function::arg_iterator args = main->arg_begin();
437    mInputStream = &*(args++);
438    mInputStream->setName("input");
439
440    mHeaderSize = &*(args++);
441    mHeaderSize->setName("mHeaderSize");
442
443    mFileSize = &*(args++);
444    mFileSize->setName("mFileSize");
445
446    mHasBlockChecksum = &*(args++);
447    mHasBlockChecksum->setName("mHasBlockChecksum");
448
449    match_accumulator = &*(args++);
450    match_accumulator->setName("match_accumulator");
451
452    iBuilder->SetInsertPoint(BasicBlock::Create(M->getContext(), "entry", main, 0));
453}
Note: See TracBrowser for help on using the repository browser.