source: icGREP/icgrep-devel/icgrep/lz4/grep/lz4_grep_base_generator.cpp @ 6219

Last change on this file since 6219 was 6219, checked in by cameron, 6 months ago

Update lz4 grep; don't print definitions of names except when needed

File size: 18.6 KB
Line 
1
2#include "lz4_grep_base_generator.h"
3
4#include <boost/iostreams/device/mapped_file.hpp>
5
6#include <llvm/Support/PrettyStackTrace.h>
7
8#include <cc/cc_compiler.h>
9
10#include <kernels/cc_kernel.h>
11#include <kernels/s2p_kernel.h>
12#include <kernels/p2s_kernel.h>
13#include <kernels/source_kernel.h>
14#include <kernels/stdout_kernel.h>
15#include <kernels/fake_stream_generating_kernel.h>
16#include <re/re_toolchain.h>
17
18#include <re/collect_ccs.h>
19#include <re/replaceCC.h>
20
21#include <re/casing.h>
22#include <re/exclude_CC.h>
23#include <re/to_utf8.h>
24#include <re/re_analysis.h>
25#include <re/re_name.h>
26#include <re/re_name_resolve.h>
27#include <re/re_name_gather.h>
28#include <re/re_multiplex.h>
29#include <re/re_utility.h>
30
31#include <UCD/resolve_properties.h>
32#include <kernels/charclasses.h>
33#include <kernels/grep_kernel.h>
34#include <kernels/UCD_property_kernel.h>
35#include <kernels/grapheme_kernel.h>
36#include <kernels/linebreak_kernel.h>
37#include <kernels/streams_merge.h>
38#include <kernels/scanmatchgen.h>
39#include <kernels/until_n.h>
40#include <re/grapheme_clusters.h>
41#include <re/printer_re.h>
42#include <llvm/Support/raw_ostream.h>
43#include <llvm/Support/Debug.h>
44#include <kernels/lz4/lz4_block_decoder.h>
45#include <kernels/lz4/lz4_match_detector.h>
46
47#include <re/re_seq.h>
48#include <kernels/kernel_builder.h>
49#include <kernels/pipeline_builder.h>
50#include <re/re_alt.h>
51#include <kernels/lz4/decompression/lz4_bytestream_decompression.h>
52#include <kernels/lz4/lz4_not_kernel.h>
53
54namespace re { class CC; }
55
56using namespace llvm;
57using namespace kernel;
58using namespace grep;
59using namespace re;
60using namespace cc;
61
62inline RE * makeNonFinal() {
63    CC * const C2_F4 = makeByte(0xC2, 0xF4);
64    CC * const E0_F4 = makeByte(0xE0, 0xF4);
65    CC * const _80_F4 = makeByte(0x80, 0xBF);
66    RE * E0_F4x80_F4 = makeSeq({E0_F4, _80_F4});
67    CC * const F0_F4 = makeByte(0xE0, 0xF4);
68    RE * F0_F4x80_F4x80_F4 = makeSeq({F0_F4, _80_F4, _80_F4});
69    return makeAlt({C2_F4, E0_F4x80_F4, F0_F4x80_F4x80_F4});
70}
71
72LZ4GrepBaseGenerator::LZ4GrepBaseGenerator(const FunctionType type)
73: LZ4BaseGenerator()
74, u8NonFinalRe(makeNonFinal())
75, u8FinalRe(makeCC(0x0, 0x1FFFFF))
76, mMainMethod(nullptr) {
77    mGrepRecordBreak = grep::GrepRecordBreakKind::LF;
78    mMoveMatchesToEOL = true;
79    mPipeline = makeInternalPipeline(type);
80}
81
82inline std::unique_ptr<kernel::PipelineBuilder> LZ4GrepBaseGenerator::makeInternalPipeline(const FunctionType type) {
83    Bindings inputs;
84    Bindings outputs;
85
86    auto & b = mPxDriver.getBuilder();
87
88    Type * const inputType = b->getInt8PtrTy();
89    Type * const sizeTy = b->getSizeTy();
90    Type * const boolTy = b->getIntNTy(sizeof(bool) * 8);
91
92    inputs.emplace_back(inputType, "input");
93    inputs.emplace_back(sizeTy, "headerSize");
94    inputs.emplace_back(sizeTy, "fileSize");
95    inputs.emplace_back(boolTy, "hasBlockChecksum");
96
97    if (type == FunctionType::CountOnly) {
98        outputs.emplace_back(sizeTy, "countResult");
99    } else if (type == FunctionType::Match) {
100        Type * const intAddrTy = b->getIntAddrTy();
101        inputs.emplace_back(intAddrTy, "match_accumulator");
102    }
103
104    return mPxDriver.makePipeline(inputs, outputs);
105}
106
107void LZ4GrepBaseGenerator::generateScanMatchGrepPipeline(RE* regex, bool enableMultiplexing, bool utf8CC) {
108    if (enableMultiplexing) {
109        generateMultiplexingScanMatchGrepPipeline(regex, utf8CC);
110    } else {
111        generateFullyDecompressionScanMatchGrepPipeline(regex);
112    }
113}
114
115void LZ4GrepBaseGenerator::generateCountOnlyGrepPipeline(RE* regex, bool enableMultiplexing, bool utf8CC){
116    if (enableMultiplexing) {
117        generateMultiplexingCountOnlyGrepPipeline(regex, utf8CC);
118    } else {
119        generateFullyDecompressionCountOnlyGrepPipeline(regex);
120    }
121}
122
123void LZ4GrepBaseGenerator::initREs(RE * re) {
124    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
125        mBreakCC = makeCC(makeCC(0x0A, 0x0D), makeCC(makeCC(0x85), makeCC(0x2028, 0x2029)));
126    } else if (mGrepRecordBreak == GrepRecordBreakKind::Null) {
127        mBreakCC = makeByte(0);  // Null
128    } else {
129        mBreakCC = makeByte(0x0A); // LF
130    }
131    RE * anchorRE = mBreakCC;
132    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
133        Name * anchorName = makeName("UTF8_LB", Name::Type::Unicode);
134        anchorName->setDefinition(makeUnicodeBreak());
135        anchorRE = anchorName;
136    }
137
138    mRE = re;
139    bool allAnchored = true;
140
141    if (!hasEndAnchor(mRE)) allAnchored = false;
142    mRE = resolveModesAndExternalSymbols(mRE);
143    mRE = exclude_CC(mRE, mBreakCC);
144    mRE = resolveAnchors(mRE, anchorRE);
145    gatherUnicodeProperties(mRE, mUnicodeProperties);
146    mRE = regular_expression_passes(mRE);
147
148    if (allAnchored && (mGrepRecordBreak != GrepRecordBreakKind::Unicode)) mMoveMatchesToEOL = false;
149
150}
151
152StreamSet * LZ4GrepBaseGenerator::linefeedStreamFromUncompressedBits(StreamSet *uncompressedBasisBits) {
153    StreamSet * const LineFeedStream = mPipeline->CreateStreamSet(1, 1);
154    mPipeline->CreateKernelCall<LineFeedKernelBuilder>(uncompressedBasisBits, LineFeedStream, BitNumbering::BigEndian);
155    return LineFeedStream;
156}
157
158unsigned LZ4GrepBaseGenerator::calculateTwistWidth(unsigned numOfStreams) {
159    if (numOfStreams <= 2) {
160        return numOfStreams;
161    } else if (numOfStreams <= 4) {
162        return 4;
163    } else if (numOfStreams <= 8) {
164        return 8;
165    } else {
166        llvm::report_fatal_error("Twist: Unsupported numOfStreams " + std::to_string(numOfStreams));;
167    }
168}
169
170std::pair<StreamSet *, StreamSet *> LZ4GrepBaseGenerator::multiplexingGrep(RE * re, StreamSet * compressedByteStream, StreamSet * compressedBitStream, bool utf8CC) {
171
172    initREs(re);
173
174    //  Regular Expression Processing and Analysis Phase
175
176    StreamSet * fakeMatchCopiedBits = nullptr;
177    StreamSet * u8NoFinalStream = nullptr;
178    StreamSet * uncompressedCharClasses = nullptr;
179
180    CC * const linefeedCC = makeCC(0x0A);
181
182    std::shared_ptr<MultiplexedAlphabet> mpx;
183
184    if (utf8CC) {
185
186        const auto requireNonFinal = isRequireNonFinal(mRE);
187        Seq * const seq = cast<Seq>(makeSeq({mRE, linefeedCC}));
188
189        auto UnicodeSets = collectCCs(seq, Unicode, std::set<Name *>({makeZeroWidth("\\b{g}")}));;
190
191        mpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSets);
192
193        auto mpxCCs = mpx->getMultiplexedCCs();
194
195        bool mpxContainFinal = false;
196
197        seq->push_back(u8FinalRe);
198
199        auto UnicodeSetsWithU8Final = collectCCs(seq, Unicode, std::set<Name *>({makeZeroWidth("\\b{g}")}));;
200        auto u8FinalMpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSetsWithU8Final);
201        auto mpxCCsWithU8Final = u8FinalMpx->getMultiplexedCCs();
202
203        if (calculateTwistWidth(mpxCCs.size() + 1) > calculateTwistWidth(mpxCCsWithU8Final.size())) {
204            mpxContainFinal = true;
205            UnicodeSets = UnicodeSetsWithU8Final;
206            mpx = u8FinalMpx;
207            mpxCCs = mpxCCsWithU8Final;
208        }
209
210        mRE = transformCCs(mpx, mRE);
211
212        StreamSet * CharClasses = mPipeline->CreateStreamSet(mpxCCs.size());
213
214        mPipeline->CreateKernelCall<CharClassesKernel>(std::move(mpxCCs), compressedBitStream, CharClasses, BitNumbering::BigEndian);
215
216        if (!requireNonFinal) {
217            // We do not need to decompress U8 NonFinal Stream is all of the character class in target regular expression is byte length
218            uncompressedCharClasses = decompressBitStream(compressedByteStream, CharClasses);
219            auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8, 1});
220            fakeMatchCopiedBits = fakeStreams[0];
221            u8NoFinalStream = fakeStreams[1];
222        } else {
223            if (mpxContainFinal) {
224                auto decompressedStreams = decompressBitStreams(compressedByteStream, {CharClasses/*, compressedNonFinalStream*/});
225                uncompressedCharClasses = decompressedStreams[0];
226                auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8});
227                fakeMatchCopiedBits = fakeStreams[0];
228                StreamSet * u8FinalStream = mPipeline->CreateStreamSet();
229                RE * const mpxU8FinalRe = transformCCs(mpx, u8FinalRe);
230               
231                std::unique_ptr<GrepKernelOptions> options = make_unique<GrepKernelOptions>();
232                options->addAlphabet(mpx, uncompressedCharClasses);
233                options->setSource(fakeMatchCopiedBits);
234                options->setNumbering(BitNumbering::BigEndian);
235                options->setRE(mpxU8FinalRe);
236                options->setResults(u8FinalStream);
237                mPipeline->CreateKernelCall<ICGrepKernel>(std::move(options));
238
239                u8NoFinalStream = mPipeline->CreateStreamSet(1, 1);
240                mPipeline->CreateKernelCall<LZ4NotKernel>(u8FinalStream, u8NoFinalStream);
241            } else {
242                StreamSet * compressedNonFinalStream = mPipeline->CreateStreamSet(1, 1);
243 
244                std::unique_ptr<GrepKernelOptions> options = make_unique<GrepKernelOptions>();
245                options->setSource(compressedBitStream);
246                options->setNumbering(BitNumbering::BigEndian);
247                options->setRE(u8NonFinalRe);
248                options->setResults(compressedNonFinalStream);
249                mPipeline->CreateKernelCall<ICGrepKernel>(std::move(options));
250
251                auto decompressedStreams = decompressBitStreams(compressedByteStream, {CharClasses, compressedNonFinalStream});
252                uncompressedCharClasses = decompressedStreams[0];
253                u8NoFinalStream = decompressedStreams[1];
254                auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8});
255                fakeMatchCopiedBits = fakeStreams[0];
256            }
257        }
258
259    } else { // if (!utf8CC) {
260
261        RE * const targetRe = toUTF8(mRE, true);
262        Seq * const seq = cast<Seq>(makeSeq({targetRe, linefeedCC}));
263        auto UnicodeSets = collectCCs(seq, Byte, std::set<Name *>({makeZeroWidth("\\b{g}")}));
264
265        mpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSets);
266
267        mRE = transformCCs(mpx, targetRe);
268
269        auto mpx_basis = mpx->getMultiplexedCCs();
270        StreamSet * const CharClasses = mPipeline->CreateStreamSet(mpx_basis.size());
271
272        mPipeline->CreateKernelCall<ByteClassesKernel>(std::move(mpx_basis), compressedBitStream, CharClasses, BitNumbering::BigEndian);
273
274        uncompressedCharClasses = decompressBitStream(compressedByteStream, CharClasses);
275        auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8, 1});
276        fakeMatchCopiedBits = fakeStreams[0];
277        u8NoFinalStream = fakeStreams[1];
278    }
279
280    StreamSet * const MatchResults = mPipeline->CreateStreamSet(1, 1);
281
282    // Multiplexing Grep Kernel is not Cachable, since it is possible that two REs with name "mpx_1" have different alphabets
283    StreamSet * LineBreakStream = mPipeline->CreateStreamSet(1, 1);
284
285    RE * const transformedLF = transformCCs(mpx, linefeedCC);
286    std::unique_ptr<GrepKernelOptions> optionsLF = make_unique<GrepKernelOptions>();
287
288    optionsLF->addAlphabet(mpx, uncompressedCharClasses);
289    optionsLF->setSource(fakeMatchCopiedBits);
290    optionsLF->setNumbering(BitNumbering::BigEndian);
291    optionsLF->setRE(transformedLF);
292    optionsLF->setResults(LineBreakStream);
293    mPipeline->CreateKernelCall<ICGrepKernel>(std::move(optionsLF));
294
295    std::unique_ptr<GrepKernelOptions> options = make_unique<GrepKernelOptions>();
296    options->addAlphabet(mpx, uncompressedCharClasses);
297    options->setSource(fakeMatchCopiedBits);
298    options->setNumbering(BitNumbering::BigEndian);
299    options->addExternal("UTF8_nonfinal", u8NoFinalStream);
300    options->setRE(mRE);
301    options->setResults(MatchResults);
302    mPipeline->CreateKernelCall<ICGrepKernel>(std::move(options));
303
304    StreamSet * Matches = MatchResults;
305    if (mMoveMatchesToEOL) {
306        StreamSet * const MovedMatches = mPipeline->CreateStreamSet();
307        mPipeline->CreateKernelCall<MatchedLinesKernel>(Matches, LineBreakStream, MovedMatches);
308        Matches = MovedMatches;
309    }
310
311    return std::pair<StreamSet *, StreamSet *>(LineBreakStream, Matches);
312}
313
314std::pair<StreamSet *, StreamSet *> LZ4GrepBaseGenerator::grep(RE * re, StreamSet * byteStream, StreamSet * uncompressedBasisBits, bool ccMultiplexing) {
315
316    initREs(re);
317
318    //  Regular Expression Processing and Analysis Phase
319    StreamSet * const MatchResults = mPipeline->CreateStreamSet(1, 1);
320
321    if (uncompressedBasisBits == nullptr) {
322        uncompressedBasisBits = s2p(byteStream);
323    }
324
325    StreamSet * const LineBreakStream = linefeedStreamFromUncompressedBits(uncompressedBasisBits);
326
327    if (ccMultiplexing) {
328
329        const auto UnicodeSets = collectCCs(mRE, Unicode, std::set<Name *>({makeZeroWidth("\\b{g}")}));
330
331        auto mpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSets);
332        mRE = transformCCs(mpx, mRE);
333        auto mpx_basis = mpx->getMultiplexedCCs();
334        StreamSet * const CharClasses = mPipeline->CreateStreamSet(mpx_basis.size());
335        mPipeline->CreateKernelCall<CharClassesKernel>(std::move(mpx_basis), uncompressedBasisBits, CharClasses, BitNumbering::BigEndian);
336
337        std::unique_ptr<GrepKernelOptions> options = make_unique<GrepKernelOptions>();
338        options->addAlphabet(mpx, CharClasses);
339        options->setSource(uncompressedBasisBits);
340        options->setNumbering(BitNumbering::BigEndian);
341        options->setRE(mRE);
342        options->setResults(MatchResults);
343        mPipeline->CreateKernelCall<ICGrepKernel>(std::move(options));
344
345
346    } else {
347
348        bool anyGCB = hasGraphemeClusterBoundary(mRE);
349        bool isSimple = (mGrepRecordBreak != GrepRecordBreakKind::Unicode) && (!anyGCB);
350        if (isSimple) {
351            mRE = toUTF8(mRE);
352        }
353        std::unique_ptr<GrepKernelOptions> options = make_unique<GrepKernelOptions>();
354        options->setSource(uncompressedBasisBits);
355        options->setNumbering(BitNumbering::BigEndian);
356        options->setRE(mRE);
357        options->setResults(MatchResults);
358        mPipeline->CreateKernelCall<ICGrepKernel>(std::move(options));
359    }
360
361    StreamSet * Matches = MatchResults;
362    if (mMoveMatchesToEOL) {
363        StreamSet * const MovedMatches = mPipeline->CreateStreamSet();
364        mPipeline->CreateKernelCall<MatchedLinesKernel>(Matches, LineBreakStream, MovedMatches);
365        Matches = MovedMatches;
366    }
367
368    return std::pair<StreamSet *, StreamSet *>(LineBreakStream, Matches);
369
370}
371
372void LZ4GrepBaseGenerator::invokeScanMatchGrep(char* fileBuffer, size_t blockStart, size_t blockEnd, bool hasBlockChecksum) {
373    auto main = getScanMatchGrepMainFunction();
374    std::ostringstream s;
375    EmitMatch accum("", false, false, s);
376
377    main(fileBuffer, blockStart, blockEnd, hasBlockChecksum, reinterpret_cast<intptr_t>(&accum));
378    llvm::outs() << s.str();
379}
380
381
382void LZ4GrepBaseGenerator::generateFullyDecompressionScanMatchGrepPipeline(RE *regex) {
383    StreamSet* compressedByteStream = loadByteStream();
384
385    StreamSet * const uncompressedByteStream = byteStreamDecompression(compressedByteStream);
386    StreamSet * uncompressedBitStream = s2p(uncompressedByteStream);
387
388    StreamSet * LineBreakStream;
389    StreamSet * Matches;
390    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedByteStream, uncompressedBitStream);
391
392
393    Kernel * scanMatchK = mPipeline->CreateKernelCall<ScanMatchKernel>(Matches, LineBreakStream, uncompressedByteStream, match_accumulator);
394    mPxDriver.LinkFunction(scanMatchK, "accumulate_match_wrapper", accumulate_match_wrapper);
395    mPxDriver.LinkFunction(scanMatchK, "finalize_match_wrapper", finalize_match_wrapper);
396
397    mMainMethod = mPipeline->compile();
398
399}
400
401void LZ4GrepBaseGenerator::generateMultiplexingScanMatchGrepPipeline(RE *regex, bool utf8CC) {
402
403    StreamSet *compressedByteStream = nullptr, *compressedBasisBits = nullptr;
404    std::tie(compressedByteStream, compressedBasisBits) = loadByteStreamAndBitStream();
405
406    StreamSet * LineBreakStream;
407    StreamSet * Matches;
408    std::tie(LineBreakStream, Matches) = multiplexingGrep(regex, compressedByteStream, compressedBasisBits, utf8CC);
409
410    LZ4BlockInfo blockInfo = getBlockInfo(compressedByteStream);
411
412    StreamSet * const decompressionByteStream = mPipeline->CreateStreamSet(1, 8);
413    mPipeline->CreateKernelCall<LZ4ByteStreamDecompressionKernel>(mFileSize, compressedByteStream, blockInfo, nullptr, decompressionByteStream );
414    Kernel * const scanMatchK = mPipeline->CreateKernelCall<ScanMatchKernel>(Matches, LineBreakStream, decompressionByteStream, match_accumulator);
415    mPxDriver.LinkFunction(scanMatchK, "accumulate_match_wrapper", accumulate_match_wrapper);
416    mPxDriver.LinkFunction(scanMatchK, "finalize_match_wrapper", finalize_match_wrapper);
417    mMainMethod = mPipeline->compile();
418}
419
420
421void LZ4GrepBaseGenerator::generateMultiplexingCountOnlyGrepPipeline(RE *regex, bool utf8CC) {
422    StreamSet *compressedByteStream = nullptr, *compressedBasisBits = nullptr;
423    std::tie(compressedByteStream, compressedBasisBits) = loadByteStreamAndBitStream();
424    StreamSet * Matches = multiplexingGrep(regex, compressedByteStream, compressedBasisBits, utf8CC).second;
425    mPipeline->CreateKernelCall<PopcountKernel>(Matches, mPipeline->getOutputScalar("countResult"));
426    mMainMethod = mPipeline->compile();
427}
428
429
430void LZ4GrepBaseGenerator::generateFullyDecompressionCountOnlyGrepPipeline(RE *regex) {
431    StreamSet * const uncompressedByteStream = generateUncompressedByteStream();
432    StreamSet * LineBreakStream;
433    StreamSet * Matches;
434    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedByteStream, nullptr);
435    mPipeline->CreateKernelCall<PopcountKernel>(Matches, mPipeline->getOutputScalar("countResult"));
436    mMainMethod = mPipeline->compile();
437}
438
439
440ScanMatchGrepMainFunctionType LZ4GrepBaseGenerator::getScanMatchGrepMainFunction() {
441    return reinterpret_cast<ScanMatchGrepMainFunctionType>(mMainMethod);
442}
443CountOnlyGrepMainFunctionType LZ4GrepBaseGenerator::getCountOnlyGrepMainFunction() {
444    return reinterpret_cast<CountOnlyGrepMainFunctionType>(mMainMethod);
445}
446
447StreamSets LZ4GrepBaseGenerator::generateFakeStreams(StreamSet * refStream, std::vector<unsigned> numOfStreams) {
448    if (numOfStreams.empty()) {
449        return StreamSets{};
450    }
451    StreamSets outputStreams;
452    outputStreams.reserve(numOfStreams.size());
453    for (const auto k : numOfStreams) {
454        outputStreams.push_back(mPipeline->CreateStreamSet(k));
455    }
456    mPipeline->CreateKernelCall<FakeStreamGeneratingKernel>(refStream, outputStreams);
457    return outputStreams;
458}
459
460
461
462StreamSets LZ4GrepBaseGenerator::decompressBitStreams(StreamSet *compressedByteStream, StreamSets compressedBitStreams) {
463    // Default implementation here will be slow
464    StreamSets retVec;
465    for (unsigned i = 0; i < compressedBitStreams.size(); i++) {
466        retVec.push_back(decompressBitStream(compressedByteStream, compressedBitStreams[i]));
467    }
468    return retVec;
469}
Note: See TracBrowser for help on using the repository browser.