source: icGREP/icgrep-devel/icgrep/lz4/grep/lz4_grep_base_generator.cpp @ 6188

Last change on this file since 6188 was 6188, checked in by cameron, 7 months ago

Small fix

File size: 17.6 KB
Line 
1
2#include "lz4_grep_base_generator.h"
3
4#include <boost/iostreams/device/mapped_file.hpp>
5
6#include <llvm/Support/PrettyStackTrace.h>
7
8#include <cc/cc_compiler.h>
9
10#include <kernels/cc_kernel.h>
11#include <kernels/s2p_kernel.h>
12#include <kernels/p2s_kernel.h>
13#include <kernels/source_kernel.h>
14#include <kernels/stdout_kernel.h>
15#include <kernels/fake_stream_generating_kernel.h>
16#include <re/re_toolchain.h>
17
18#include <re/collect_ccs.h>
19#include <re/replaceCC.h>
20
21#include <re/casing.h>
22#include <re/exclude_CC.h>
23#include <re/to_utf8.h>
24#include <re/re_analysis.h>
25#include <re/re_name.h>
26#include <re/re_name_resolve.h>
27#include <re/re_name_gather.h>
28#include <re/re_multiplex.h>
29#include <re/re_utility.h>
30
31#include <UCD/resolve_properties.h>
32#include <kernels/charclasses.h>
33#include <kernels/grep_kernel.h>
34#include <kernels/UCD_property_kernel.h>
35#include <kernels/grapheme_kernel.h>
36#include <kernels/linebreak_kernel.h>
37#include <kernels/streams_merge.h>
38#include <kernels/scanmatchgen.h>
39#include <kernels/until_n.h>
40#include <re/grapheme_clusters.h>
41#include <re/printer_re.h>
42#include <llvm/Support/raw_ostream.h>
43#include <llvm/Support/Debug.h>
44#include <kernels/lz4/lz4_block_decoder.h>
45#include <kernels/lz4/lz4_match_detector.h>
46
47#include <re/re_seq.h>
48#include <kernels/kernel_builder.h>
49#include <kernels/pipeline_builder.h>
50#include <re/re_alt.h>
51#include <kernels/lz4/decompression/lz4_bytestream_decompression.h>
52#include <kernels/lz4/lz4_not_kernel.h>
53
54namespace re { class CC; }
55
56using namespace llvm;
57using namespace kernel;
58using namespace grep;
59using namespace re;
60using namespace cc;
61
62using Alphabets = ICGrepKernel::Alphabets;
63using Externals = ICGrepKernel::Externals;
64
65inline RE * makeNonFinal() {
66    CC * const C2_F4 = makeByte(0xC2, 0xF4);
67    CC * const E0_F4 = makeByte(0xE0, 0xF4);
68    CC * const _80_F4 = makeByte(0x80, 0xBF);
69    RE * E0_F4x80_F4 = makeSeq({E0_F4, _80_F4});
70    CC * const F0_F4 = makeByte(0xE0, 0xF4);
71    RE * F0_F4x80_F4x80_F4 = makeSeq({F0_F4, _80_F4, _80_F4});
72    return makeAlt({C2_F4, E0_F4x80_F4, F0_F4x80_F4x80_F4});
73}
74
75LZ4GrepBaseGenerator::LZ4GrepBaseGenerator(const FunctionType type)
76: LZ4BaseGenerator()
77, u8NonFinalRe(makeNonFinal())
78, u8FinalRe(makeCC(0x0, 0x1FFFFF))
79, mMainMethod(nullptr) {
80    mGrepRecordBreak = grep::GrepRecordBreakKind::LF;
81    mMoveMatchesToEOL = true;
82    mPipeline = makeInternalPipeline(type);
83}
84
85inline std::unique_ptr<kernel::PipelineBuilder> LZ4GrepBaseGenerator::makeInternalPipeline(const FunctionType type) {
86    Bindings inputs;
87    Bindings outputs;
88
89    auto & b = mPxDriver.getBuilder();
90
91    Type * const inputType = b->getInt8PtrTy();
92    Type * const sizeTy = b->getSizeTy();
93    Type * const boolTy = b->getIntNTy(sizeof(bool) * 8);
94
95    inputs.emplace_back(inputType, "input");
96    inputs.emplace_back(sizeTy, "headerSize");
97    inputs.emplace_back(sizeTy, "fileSize");
98    inputs.emplace_back(boolTy, "hasBlockChecksum");
99
100    if (type == FunctionType::CountOnly) {
101        outputs.emplace_back(sizeTy, "countResult");
102    } else if (type == FunctionType::Match) {
103        Type * const intAddrTy = b->getIntAddrTy();
104        inputs.emplace_back(intAddrTy, "match_accumulator");
105    }
106
107    return mPxDriver.makePipeline(inputs, outputs);
108}
109
110void LZ4GrepBaseGenerator::generateScanMatchGrepPipeline(RE* regex, bool enableMultiplexing, bool utf8CC) {
111    if (enableMultiplexing) {
112        generateMultiplexingScanMatchGrepPipeline(regex, utf8CC);
113    } else {
114        generateFullyDecompressionScanMatchGrepPipeline(regex);
115    }
116}
117
118void LZ4GrepBaseGenerator::generateCountOnlyGrepPipeline(RE* regex, bool enableMultiplexing, bool utf8CC){
119    if (enableMultiplexing) {
120        generateMultiplexingCountOnlyGrepPipeline(regex, utf8CC);
121    } else {
122        generateFullyDecompressionCountOnlyGrepPipeline(regex);
123    }
124}
125
126void LZ4GrepBaseGenerator::initREs(RE * re) {
127    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
128        mBreakCC = makeCC(makeCC(0x0A, 0x0D), makeCC(makeCC(0x85), makeCC(0x2028, 0x2029)));
129    } else if (mGrepRecordBreak == GrepRecordBreakKind::Null) {
130        mBreakCC = makeByte(0);  // Null
131    } else {
132        mBreakCC = makeByte(0x0A); // LF
133    }
134    RE * anchorRE = mBreakCC;
135    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
136        Name * anchorName = makeName("UTF8_LB", Name::Type::Unicode);
137        anchorName->setDefinition(makeUnicodeBreak());
138        anchorRE = anchorName;
139    }
140
141    mRE = re;
142    bool allAnchored = true;
143
144    if (!hasEndAnchor(mRE)) allAnchored = false;
145    mRE = resolveModesAndExternalSymbols(mRE);
146    mRE = exclude_CC(mRE, mBreakCC);
147    mRE = resolveAnchors(mRE, anchorRE);
148    gatherUnicodeProperties(mRE, mUnicodeProperties);
149    mRE = regular_expression_passes(mRE);
150
151    if (allAnchored && (mGrepRecordBreak != GrepRecordBreakKind::Unicode)) mMoveMatchesToEOL = false;
152
153}
154
155StreamSet * LZ4GrepBaseGenerator::linefeedStreamFromUncompressedBits(StreamSet *uncompressedBasisBits) {
156    StreamSet * const LineFeedStream = mPipeline->CreateStreamSet(1, 1);
157    mPipeline->CreateKernelCall<LineFeedKernelBuilder>(uncompressedBasisBits, LineFeedStream, BitNumbering::BigEndian);
158    return LineFeedStream;
159}
160
161unsigned LZ4GrepBaseGenerator::calculateTwistWidth(unsigned numOfStreams) {
162    if (numOfStreams <= 2) {
163        return numOfStreams;
164    } else if (numOfStreams <= 4) {
165        return 4;
166    } else if (numOfStreams <= 8) {
167        return 8;
168    } else {
169        llvm::report_fatal_error("Twist: Unsupported numOfStreams " + std::to_string(numOfStreams));;
170    }
171}
172
173std::pair<StreamSet *, StreamSet *> LZ4GrepBaseGenerator::multiplexingGrep(RE * re, StreamSet * compressedByteStream, StreamSet * compressedBitStream, bool utf8CC) {
174
175    initREs(re);
176
177    //  Regular Expression Processing and Analysis Phase
178
179    StreamSet * fakeMatchCopiedBits = nullptr;
180    StreamSet * u8NoFinalStream = nullptr;
181    StreamSet * uncompressedCharClasses = nullptr;
182
183    CC * const linefeedCC = makeCC(0x0A);
184
185    std::shared_ptr<MultiplexedAlphabet> mpx;
186
187    if (utf8CC) {
188
189        const auto requireNonFinal = isRequireNonFinal(mRE);
190        Seq * const seq = cast<Seq>(makeSeq({mRE, linefeedCC}));
191
192        auto UnicodeSets = collectCCs(seq, Unicode, std::set<Name *>({makeZeroWidth("\\b{g}")}));;
193
194        mpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSets);
195
196        auto mpxCCs = mpx->getMultiplexedCCs();
197
198        bool mpxContainFinal = false;
199
200        seq->push_back(u8FinalRe);
201
202        auto UnicodeSetsWithU8Final = collectCCs(seq, Unicode, std::set<Name *>({makeZeroWidth("\\b{g}")}));;
203        auto u8FinalMpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSetsWithU8Final);
204        auto mpxCCsWithU8Final = u8FinalMpx->getMultiplexedCCs();
205
206        if (calculateTwistWidth(mpxCCs.size() + 1) > calculateTwistWidth(mpxCCsWithU8Final.size())) {
207            mpxContainFinal = true;
208            UnicodeSets = UnicodeSetsWithU8Final;
209            mpx = u8FinalMpx;
210            mpxCCs = mpxCCsWithU8Final;
211        }
212
213        mRE = transformCCs(mpx, mRE);
214
215        StreamSet * CharClasses = mPipeline->CreateStreamSet(mpxCCs.size());
216
217        mPipeline->CreateKernelCall<CharClassesKernel>(std::move(mpxCCs), compressedBitStream, CharClasses, BitNumbering::BigEndian);
218
219        if (!requireNonFinal) {
220            // We do not need to decompress U8 NonFinal Stream is all of the character class in target regular expression is byte length
221            uncompressedCharClasses = decompressBitStream(compressedByteStream, CharClasses);
222            auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8, 1});
223            fakeMatchCopiedBits = fakeStreams[0];
224            u8NoFinalStream = fakeStreams[1];
225        } else {
226            if (mpxContainFinal) {
227                auto decompressedStreams = decompressBitStreams(compressedByteStream, {CharClasses/*, compressedNonFinalStream*/});
228                uncompressedCharClasses = decompressedStreams[0];
229                auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8});
230                fakeMatchCopiedBits = fakeStreams[0];
231                StreamSet * u8FinalStream = mPipeline->CreateStreamSet();
232                RE * const mpxU8FinalRe = transformCCs(mpx, u8FinalRe);
233                Alphabets alpha;
234                alpha.emplace_back(mpx, uncompressedCharClasses);
235                mPipeline->CreateKernelCall<ICGrepKernel>(mpxU8FinalRe, fakeMatchCopiedBits, uncompressedCharClasses, Externals{}, alpha, BitNumbering::BigEndian, false);
236                u8NoFinalStream = mPipeline->CreateStreamSet(1, 1);
237                mPipeline->CreateKernelCall<LZ4NotKernel>(u8FinalStream, u8NoFinalStream);
238            } else {
239                StreamSet * compressedNonFinalStream = mPipeline->CreateStreamSet(1, 1);
240                mPipeline->CreateKernelCall<ICGrepKernel>(u8NonFinalRe, compressedBitStream, compressedNonFinalStream, Externals{}, Alphabets{}, BitNumbering::BigEndian);
241                auto decompressedStreams = decompressBitStreams(compressedByteStream, {CharClasses, compressedNonFinalStream});
242                uncompressedCharClasses = decompressedStreams[0];
243                u8NoFinalStream = decompressedStreams[1];
244                auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8});
245                fakeMatchCopiedBits = fakeStreams[0];
246            }
247        }
248
249    } else { // if (!utf8CC) {
250
251        RE * const targetRe = toUTF8(mRE, true);
252        Seq * const seq = cast<Seq>(makeSeq({targetRe, linefeedCC}));
253        auto UnicodeSets = collectCCs(seq, Byte, std::set<Name *>({makeZeroWidth("\\b{g}")}));
254
255        mpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSets);
256
257        mRE = transformCCs(mpx, targetRe);
258
259        auto mpx_basis = mpx->getMultiplexedCCs();
260        StreamSet * const CharClasses = mPipeline->CreateStreamSet(mpx_basis.size());
261
262        mPipeline->CreateKernelCall<ByteClassesKernel>(std::move(mpx_basis), compressedBitStream, CharClasses, BitNumbering::BigEndian);
263
264        uncompressedCharClasses = decompressBitStream(compressedByteStream, CharClasses);
265        auto fakeStreams = generateFakeStreams(uncompressedCharClasses, std::vector<unsigned>{8, 1});
266        fakeMatchCopiedBits = fakeStreams[0];
267        u8NoFinalStream = fakeStreams[1];
268    }
269
270    StreamSet * const MatchResults = mPipeline->CreateStreamSet(1, 1);
271
272    // Multiplexing Grep Kernel is not Cachable, since it is possible that two REs with name "mpx_1" have different alphabets
273    StreamSet * LineBreakStream = mPipeline->CreateStreamSet(1, 1);
274
275    RE * const transformedCC = transformCCs(mpx, linefeedCC);
276
277    Alphabets alpha;
278    alpha.emplace_back(mpx, uncompressedCharClasses);
279
280    mPipeline->CreateKernelCall<ICGrepKernel>(transformedCC, fakeMatchCopiedBits, LineBreakStream, Externals{}, alpha, BitNumbering::BigEndian, false);
281
282    Externals externals;
283    externals.emplace_back("UTF8_nonfinal", u8NoFinalStream);
284
285    mPipeline->CreateKernelCall<ICGrepKernel>(mRE, fakeMatchCopiedBits, MatchResults, externals, alpha, BitNumbering::BigEndian, false);
286
287    StreamSet * Matches = MatchResults;
288    if (mMoveMatchesToEOL) {
289        StreamSet * const MovedMatches = mPipeline->CreateStreamSet();
290        mPipeline->CreateKernelCall<MatchedLinesKernel>(Matches, LineBreakStream, MovedMatches);
291        Matches = MovedMatches;
292    }
293
294    return std::pair<StreamSet *, StreamSet *>(LineBreakStream, Matches);
295}
296
297std::pair<StreamSet *, StreamSet *> LZ4GrepBaseGenerator::grep(RE * re, StreamSet * byteStream, StreamSet * uncompressedBasisBits, bool ccMultiplexing) {
298
299    initREs(re);
300
301    //  Regular Expression Processing and Analysis Phase
302    StreamSet * const MatchResults = mPipeline->CreateStreamSet(1, 1);
303
304    if (uncompressedBasisBits == nullptr) {
305        uncompressedBasisBits = s2p(byteStream);
306    }
307
308    StreamSet * const LineBreakStream = linefeedStreamFromUncompressedBits(uncompressedBasisBits);
309
310    if (ccMultiplexing) {
311
312        const auto UnicodeSets = collectCCs(mRE, Unicode, std::set<Name *>({makeZeroWidth("\\b{g}")}));
313
314        auto mpx = std::make_shared<MultiplexedAlphabet>("mpx", UnicodeSets);
315        mRE = transformCCs(mpx, mRE);
316        auto mpx_basis = mpx->getMultiplexedCCs();
317        StreamSet * const CharClasses = mPipeline->CreateStreamSet(mpx_basis.size());
318        mPipeline->CreateKernelCall<CharClassesKernel>(std::move(mpx_basis), uncompressedBasisBits, CharClasses, BitNumbering::BigEndian);
319
320        Alphabets alphabets;
321        alphabets.emplace_back(std::move(mpx), CharClasses);
322        mPipeline->CreateKernelCall<ICGrepKernel>(mRE, uncompressedBasisBits, MatchResults, Externals{}, alphabets, BitNumbering::BigEndian, false);
323
324
325    } else {
326
327        bool anyGCB = hasGraphemeClusterBoundary(mRE);
328        bool isSimple = (mGrepRecordBreak != GrepRecordBreakKind::Unicode) && (!anyGCB);
329        if (isSimple) {
330            mRE = toUTF8(mRE);
331        }
332        mPipeline->CreateKernelCall<ICGrepKernel>(mRE, uncompressedBasisBits, MatchResults, Externals{}, Alphabets{}, BitNumbering::BigEndian);
333    }
334
335    StreamSet * Matches = MatchResults;
336    if (mMoveMatchesToEOL) {
337        StreamSet * const MovedMatches = mPipeline->CreateStreamSet();
338        mPipeline->CreateKernelCall<MatchedLinesKernel>(Matches, LineBreakStream, MovedMatches);
339        Matches = MovedMatches;
340    }
341
342    return std::pair<StreamSet *, StreamSet *>(LineBreakStream, Matches);
343
344}
345
346void LZ4GrepBaseGenerator::invokeScanMatchGrep(char* fileBuffer, size_t blockStart, size_t blockEnd, bool hasBlockChecksum) {
347    auto main = getScanMatchGrepMainFunction();
348    std::ostringstream s;
349    EmitMatch accum("", false, false, s);
350
351    main(fileBuffer, blockStart, blockEnd, hasBlockChecksum, reinterpret_cast<intptr_t>(&accum));
352    llvm::outs() << s.str();
353}
354
355
356void LZ4GrepBaseGenerator::generateFullyDecompressionScanMatchGrepPipeline(RE *regex) {
357    StreamSet* compressedByteStream = loadByteStream();
358
359    StreamSet * const uncompressedByteStream = byteStreamDecompression(compressedByteStream);
360    StreamSet * uncompressedBitStream = s2p(uncompressedByteStream);
361
362    StreamSet * LineBreakStream;
363    StreamSet * Matches;
364    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedByteStream, uncompressedBitStream);
365
366
367    Kernel * scanMatchK = mPipeline->CreateKernelCall<ScanMatchKernel>(Matches, LineBreakStream, uncompressedByteStream, match_accumulator);
368    mPxDriver.LinkFunction(scanMatchK, "accumulate_match_wrapper", accumulate_match_wrapper);
369    mPxDriver.LinkFunction(scanMatchK, "finalize_match_wrapper", finalize_match_wrapper);
370
371    mMainMethod = mPipeline->compile();
372
373}
374
375void LZ4GrepBaseGenerator::generateMultiplexingScanMatchGrepPipeline(RE *regex, bool utf8CC) {
376
377    StreamSet *compressedByteStream = nullptr, *compressedBasisBits = nullptr;
378    std::tie(compressedByteStream, compressedBasisBits) = loadByteStreamAndBitStream();
379
380    StreamSet * LineBreakStream;
381    StreamSet * Matches;
382    std::tie(LineBreakStream, Matches) = multiplexingGrep(regex, compressedByteStream, compressedBasisBits, utf8CC);
383
384    LZ4BlockInfo blockInfo = getBlockInfo(compressedByteStream);
385
386    StreamSet * const decompressionByteStream = mPipeline->CreateStreamSet(1, 8);
387    mPipeline->CreateKernelCall<LZ4ByteStreamDecompressionKernel>(mFileSize, compressedByteStream, blockInfo, nullptr, decompressionByteStream );
388    Kernel * const scanMatchK = mPipeline->CreateKernelCall<ScanMatchKernel>(Matches, LineBreakStream, decompressionByteStream, match_accumulator);
389    mPxDriver.LinkFunction(scanMatchK, "accumulate_match_wrapper", accumulate_match_wrapper);
390    mPxDriver.LinkFunction(scanMatchK, "finalize_match_wrapper", finalize_match_wrapper);
391    mMainMethod = mPipeline->compile();
392}
393
394
395void LZ4GrepBaseGenerator::generateMultiplexingCountOnlyGrepPipeline(RE *regex, bool utf8CC) {
396    StreamSet *compressedByteStream = nullptr, *compressedBasisBits = nullptr;
397    std::tie(compressedByteStream, compressedBasisBits) = loadByteStreamAndBitStream();
398    StreamSet * Matches = multiplexingGrep(regex, compressedByteStream, compressedBasisBits, utf8CC).second;
399    mPipeline->CreateKernelCall<PopcountKernel>(Matches, mPipeline->getOutputScalar("countResult"));
400    mMainMethod = mPipeline->compile();
401}
402
403
404void LZ4GrepBaseGenerator::generateFullyDecompressionCountOnlyGrepPipeline(RE *regex) {
405    StreamSet * const uncompressedByteStream = generateUncompressedByteStream();
406    StreamSet * LineBreakStream;
407    StreamSet * Matches;
408    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedByteStream, nullptr);
409    mPipeline->CreateKernelCall<PopcountKernel>(Matches, mPipeline->getOutputScalar("countResult"));
410    mMainMethod = mPipeline->compile();
411}
412
413
414ScanMatchGrepMainFunctionType LZ4GrepBaseGenerator::getScanMatchGrepMainFunction() {
415    return reinterpret_cast<ScanMatchGrepMainFunctionType>(mMainMethod);
416}
417CountOnlyGrepMainFunctionType LZ4GrepBaseGenerator::getCountOnlyGrepMainFunction() {
418    return reinterpret_cast<CountOnlyGrepMainFunctionType>(mMainMethod);
419}
420
421StreamSets LZ4GrepBaseGenerator::generateFakeStreams(StreamSet * refStream, std::vector<unsigned> numOfStreams) {
422    if (numOfStreams.empty()) {
423        return StreamSets{};
424    }
425    StreamSets outputStreams;
426    outputStreams.reserve(numOfStreams.size());
427    for (const auto k : numOfStreams) {
428        outputStreams.push_back(mPipeline->CreateStreamSet(k));
429    }
430    mPipeline->CreateKernelCall<FakeStreamGeneratingKernel>(refStream, outputStreams);
431    return outputStreams;
432}
433
434
435
436StreamSets LZ4GrepBaseGenerator::decompressBitStreams(StreamSet *compressedByteStream, StreamSets compressedBitStreams) {
437    // Default implementation here will be slow
438    StreamSets retVec;
439    for (unsigned i = 0; i < compressedBitStreams.size(); i++) {
440        retVec.push_back(decompressBitStream(compressedByteStream, compressedBitStreams[i]));
441    }
442    return retVec;
443}
Note: See TracBrowser for help on using the repository browser.