source: icGREP/icgrep-devel/icgrep/lz4/grep/lz4_grep_base_generator.cpp @ 6153

Last change on this file since 6153 was 6153, checked in by xwa163, 7 months ago

lz4_grep: 1. fix typo 2. analyse the requirement of u8 non final

File size: 28.0 KB
Line 
1
2#include "lz4_grep_base_generator.h"
3
4#include <boost/iostreams/device/mapped_file.hpp>
5
6#include <llvm/Support/PrettyStackTrace.h>
7
8#include <cc/cc_compiler.h>
9
10#include <kernels/cc_kernel.h>
11#include <kernels/s2p_kernel.h>
12#include <kernels/p2s_kernel.h>
13#include <kernels/source_kernel.h>
14#include <kernels/stdout_kernel.h>
15#include <kernels/fake_stream_generating_kernel.h>
16#include <re/re_toolchain.h>
17
18#include <re/collect_ccs.h>
19#include <re/replaceCC.h>
20
21#include <re/casing.h>
22#include <re/exclude_CC.h>
23#include <re/to_utf8.h>
24#include <re/re_analysis.h>
25#include <re/re_name_resolve.h>
26#include <re/re_name_gather.h>
27#include <re/re_multiplex.h>
28#include <re/re_utility.h>
29
30#include <UCD/resolve_properties.h>
31#include <kernels/charclasses.h>
32#include <kernels/grep_kernel.h>
33#include <kernels/UCD_property_kernel.h>
34#include <kernels/grapheme_kernel.h>
35#include <kernels/linebreak_kernel.h>
36#include <kernels/streams_merge.h>
37#include <kernels/scanmatchgen.h>
38#include <kernels/until_n.h>
39#include <re/grapheme_clusters.h>
40#include <re/printer_re.h>
41#include <llvm/Support/raw_ostream.h>
42#include <llvm/Support/Debug.h>
43#include <kernels/lz4/lz4_block_decoder.h>
44#include <kernels/lz4/lz4_match_detector.h>
45
46#include <re/re_seq.h>
47#include <kernels/kernel_builder.h>
48#include <re/re_alt.h>
49#include <kernels/lz4/decompression/lz4_bytestream_decompression.h>
50#include <kernels/lz4/lz4_not_kernel.h>
51
52namespace re { class CC; }
53
54using namespace llvm;
55using namespace parabix;
56using namespace kernel;
57using namespace grep;
58using namespace re;
59
60LZ4GrepBaseGenerator::LZ4GrepBaseGenerator()
61        : LZ4BaseGenerator(),
62          u8NonFinalRe(makeAlt({makeByte(0xC2, 0xF4),
63                                makeSeq({makeByte(0xE0, 0xF4), makeByte(0x80, 0xBF)}),
64                                makeSeq({makeByte(0xF0, 0xF4), makeByte(0x80, 0xBF), makeByte(0x80, 0xBF)})})),
65          u8FinalRe(makeCC(0x0, 0x1FFFFF))
66{
67    mGrepRecordBreak = grep::GrepRecordBreakKind::LF;
68    mMoveMatchesToEOL = true;
69}
70
71void LZ4GrepBaseGenerator::generateScanMatchGrepPipeline(re::RE* regex, bool enableMultiplexing, bool utf8CC) {
72    if (enableMultiplexing) {
73        this->generateMultiplexingScanMatchGrepPipeline(regex, utf8CC);
74    } else {
75        this->generateFullyDecompressionScanMatchGrepPipeline(regex);
76    }
77}
78
79void LZ4GrepBaseGenerator::generateCountOnlyGrepPipeline(re::RE* regex, bool enableMultiplexing, bool utf8CC){
80    if (enableMultiplexing) {
81        this->generateMultiplexingCountOnlyGrepPipeline(regex, utf8CC);
82    } else {
83        this->generateFullyDecompressionCountOnlyGrepPipeline(regex);
84    }
85}
86
87void LZ4GrepBaseGenerator::initREs(re::RE * RE) {
88    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
89        mBreakCC = re::makeCC(re::makeCC(0x0A, 0x0D), re::makeCC(re::makeCC(0x85), re::makeCC(0x2028, 0x2029)));
90    } else if (mGrepRecordBreak == GrepRecordBreakKind::Null) {
91        mBreakCC = re::makeByte(0);  // Null
92    } else {
93        mBreakCC = re::makeByte(0x0A); // LF
94    }
95    re::RE * anchorRE = mBreakCC;
96    if (mGrepRecordBreak == GrepRecordBreakKind::Unicode) {
97        re::Name * anchorName = re::makeName("UTF8_LB", re::Name::Type::Unicode);
98        anchorName->setDefinition(re::makeUnicodeBreak());
99        anchorRE = anchorName;
100    }
101
102    mRE = RE;
103    bool allAnchored = true;
104
105    if (!hasEndAnchor(mRE)) allAnchored = false;
106    mRE = resolveModesAndExternalSymbols(mRE);
107    mRE = re::exclude_CC(mRE, mBreakCC);
108    mRE = resolveAnchors(mRE, anchorRE);
109    re::gatherUnicodeProperties(mRE, mUnicodeProperties);
110    mRE = regular_expression_passes(mRE);
111
112    if (allAnchored && (mGrepRecordBreak != GrepRecordBreakKind::Unicode)) mMoveMatchesToEOL = false;
113
114}
115
116parabix::StreamSetBuffer * LZ4GrepBaseGenerator::linefeedStreamFromUncompressedBits(
117        parabix::StreamSetBuffer *uncompressedBasisBits) {
118    auto & idb = mPxDriver.getBuilder();
119    const unsigned baseBufferSize = this->getDefaultBufferBlocks();
120    StreamSetBuffer * LineFeedStream = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
121    kernel::Kernel * linefeedK = mPxDriver.addKernelInstance<kernel::LineFeedKernelBuilder>(idb, Binding{idb->getStreamSetTy(8), "basis", FixedRate(), Principal()}, cc::BitNumbering::BigEndian);
122    mPxDriver.makeKernelCall(linefeedK, {uncompressedBasisBits}, {LineFeedStream});
123    return LineFeedStream;
124}
125
126std::pair<parabix::StreamSetBuffer *, parabix::StreamSetBuffer *> LZ4GrepBaseGenerator::multiplexingGrep(
127        re::RE *RE,
128        parabix::StreamSetBuffer *compressedByteStream,
129        parabix::StreamSetBuffer *compressedBitStream,
130        bool utf8CC
131) {
132
133    this->initREs(RE);
134    auto mGrepDriver = &mPxDriver;
135
136    auto & idb = mGrepDriver->getBuilder();
137    const unsigned baseBufferSize = this->getDefaultBufferBlocks();
138    int MaxCountFlag = 0;
139
140    //  Regular Expression Processing and Analysis Phase
141    const auto nREs = 1;
142
143    std::vector<StreamSetBuffer *> MatchResultsBufs(nREs);
144
145
146    std::map<std::string, StreamSetBuffer *> propertyStream;
147
148    std::vector<std::string> externalStreamNames;
149    std::set<re::Name *> UnicodeProperties;
150
151    StreamSetBuffer* fakeMatchCopiedBits = nullptr;
152    StreamSetBuffer* u8NoFinalStream = nullptr;
153    StreamSetBuffer * uncompressedCharClasses = nullptr;
154
155    re::CC* linefeedCC = nullptr;
156
157
158    if (utf8CC) {
159        re::Seq* seq = re::makeSeq();
160        re::RE* targetRe = mRE;
161
162        bool requireNonFinal = re::isRequireNonFinal(mRE);
163        std::vector<re::CC*> OriginalUnicodeSets = re::collectCCs(seq, &cc::Unicode, std::set<re::Name *>({re::makeZeroWidth("\\b{g}")}));;
164
165        linefeedCC = re::makeCC(0x0A);
166
167        seq->push_back(targetRe);
168        seq->push_back(std::move(linefeedCC));
169        if (requireNonFinal) {
170            seq->push_back(u8FinalRe);
171        }
172
173
174        std::vector<re::CC*> UnicodeSets = re::collectCCs(seq, &cc::Unicode, std::set<re::Name *>({re::makeZeroWidth("\\b{g}")}));;
175
176        mpx = make_unique<cc::MultiplexedAlphabet>("mpx", UnicodeSets);
177        mRE = transformCCs(mpx.get(), targetRe);
178
179
180        std::vector<re::CC *> mpx_basis = mpx->getMultiplexedCCs();
181        auto numOfCharacterClasses = mpx_basis.size();
182        llvm::errs() << "numOfUnicodeSet:" << UnicodeSets.size() << "\n";
183
184        llvm::errs() << "numOfCharacterClasses:" << numOfCharacterClasses << "\n";
185        StreamSetBuffer * CharClasses = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(numOfCharacterClasses), baseBufferSize, 1);
186
187        kernel::Kernel * ccK = mGrepDriver->addKernelInstance<kernel::CharClassesKernel>(idb, std::move(mpx_basis), false, cc::BitNumbering::BigEndian);
188        mGrepDriver->makeKernelCall(ccK, {compressedBitStream}, {CharClasses});
189
190        if (!requireNonFinal) {
191            // We do not need to decompress U8 NonFinal Stream is all of the character class in target regular expression is byte length
192            uncompressedCharClasses = this->decompressBitStream(compressedByteStream, CharClasses);
193            auto fakeStreams = this->generateFakeStreams(idb, uncompressedCharClasses, std::vector<unsigned>{8, 1});
194            fakeMatchCopiedBits = fakeStreams[0];
195            u8NoFinalStream = fakeStreams[1];
196        } else {
197//            StreamSetBuffer* compressedNonFinalStream = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
198//            kernel::Kernel * nonFinalK = mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, u8NonFinalRe, externalStreamNames, std::vector<cc::Alphabet *>(), cc::BitNumbering::BigEndian);
199//            mGrepDriver->makeKernelCall(nonFinalK, {compressedBitStream}, {compressedNonFinalStream});
200
201            auto decompressedStreams = this->decompressBitStreams(compressedByteStream, {CharClasses/*, compressedNonFinalStream*/});
202            uncompressedCharClasses = decompressedStreams[0];
203//            u8NoFinalStream = decompressedStreams[1];
204
205            auto fakeStreams = this->generateFakeStreams(idb, uncompressedCharClasses, std::vector<unsigned>{8});
206            fakeMatchCopiedBits = fakeStreams[0];
207
208            StreamSetBuffer * u8FinalStream = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), this->getDefaultBufferBlocks(), 1);
209            ICGrepKernel * u8FinalGrepK = (ICGrepKernel *)mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, transformCCs(mpx.get(), u8FinalRe), externalStreamNames, std::vector<cc::Alphabet *>{mpx.get()}, cc::BitNumbering::BigEndian);
210            u8FinalGrepK->setCachable(false);
211            mGrepDriver->makeKernelCall(u8FinalGrepK, {fakeMatchCopiedBits, uncompressedCharClasses}, {u8FinalStream});
212
213            u8NoFinalStream = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), this->getDefaultBufferBlocks(), 1);
214            Kernel* notK = mGrepDriver->addKernelInstance<LZ4NotKernel>(idb);
215            mGrepDriver->makeKernelCall(notK, {u8FinalStream}, {u8NoFinalStream});
216        }
217    } else {
218        re::Seq* seq = re::makeSeq();
219        re::RE* targetRe = mRE;
220        targetRe = re::toUTF8(targetRe, true);
221
222        linefeedCC = re::makeByte(0x0A);
223
224        seq->push_back(targetRe);
225        seq->push_back(std::move(linefeedCC));
226
227        std::vector<re::CC*> UnicodeSets = re::collectCCs(seq, &cc::Byte, std::set<re::Name *>({re::makeZeroWidth("\\b{g}")}));
228
229        mpx = make_unique<cc::MultiplexedAlphabet>("mpx", UnicodeSets);
230        mRE = transformCCs(mpx.get(), targetRe);
231
232        std::vector<re::CC *> mpx_basis = mpx->getMultiplexedCCs();
233        auto numOfCharacterClasses = mpx_basis.size();
234        llvm::errs() << "numOfUnicodeSet:" << UnicodeSets.size() << "\n";
235        llvm::errs() << "numOfCharacterClasses:" << numOfCharacterClasses << "\n";
236        StreamSetBuffer * CharClasses = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(numOfCharacterClasses), baseBufferSize, 1);
237
238        kernel::Kernel * ccK = mGrepDriver->addKernelInstance<kernel::ByteClassesKernel>(idb, std::move(mpx_basis), false, cc::BitNumbering::BigEndian);
239        mGrepDriver->makeKernelCall(ccK, {compressedBitStream}, {CharClasses});
240
241        uncompressedCharClasses = this->decompressBitStream(compressedByteStream, CharClasses);
242        auto fakeStreams = this->generateFakeStreams(idb, uncompressedCharClasses, std::vector<unsigned>{8, 1});
243        fakeMatchCopiedBits = fakeStreams[0];
244        u8NoFinalStream = fakeStreams[1];
245    }
246
247    StreamSetBuffer * const MatchResults = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
248
249    // Multiplexing Grep Kernel is not Cachable, since it is possible that two REs with name "mpx_1" have different alphabets
250    StreamSetBuffer * LineBreakStream = mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), this->getDefaultBufferBlocks(), 1);
251    ICGrepKernel * lineFeedGrepK = (ICGrepKernel *)mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, transformCCs(mpx.get(), linefeedCC), externalStreamNames, std::vector<cc::Alphabet *>{mpx.get()}, cc::BitNumbering::BigEndian);
252    lineFeedGrepK->setCachable(false);
253    mGrepDriver->makeKernelCall(lineFeedGrepK, {fakeMatchCopiedBits, uncompressedCharClasses}, {LineBreakStream});
254
255
256    externalStreamNames.push_back("UTF8_nonfinal");
257
258    ICGrepKernel * icgrepK = (ICGrepKernel *)mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, mRE, externalStreamNames, std::vector<cc::Alphabet *>{mpx.get()}, cc::BitNumbering::BigEndian);
259    icgrepK->setCachable(false);
260    mGrepDriver->makeKernelCall(icgrepK, {fakeMatchCopiedBits, u8NoFinalStream, uncompressedCharClasses}, {MatchResults});
261    MatchResultsBufs[0] = MatchResults;
262
263    StreamSetBuffer * MergedResults = MatchResultsBufs[0];
264
265    StreamSetBuffer * Matches = MergedResults;
266    if (mMoveMatchesToEOL) {
267        StreamSetBuffer * OriginalMatches = Matches;
268        kernel::Kernel * matchedLinesK = mGrepDriver->addKernelInstance<kernel::MatchedLinesKernel>(idb);
269        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
270        mGrepDriver->makeKernelCall(matchedLinesK, {OriginalMatches, LineBreakStream}, {Matches});
271    }
272
273    if (MaxCountFlag > 0) {
274        kernel::Kernel * untilK = mGrepDriver->addKernelInstance<kernel::UntilNkernel>(idb);
275        untilK->setInitialArguments({idb->getSize(MaxCountFlag)});
276        StreamSetBuffer * const AllMatches = Matches;
277        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
278        mGrepDriver->makeKernelCall(untilK, {AllMatches}, {Matches});
279    }
280
281    return std::pair<StreamSetBuffer *, StreamSetBuffer *>(LineBreakStream, Matches);
282};
283
284std::pair<parabix::StreamSetBuffer *, parabix::StreamSetBuffer *> LZ4GrepBaseGenerator::grep(
285        re::RE *RE, parabix::StreamSetBuffer *byteStream, parabix::StreamSetBuffer *uncompressedBasisBits, bool ccMultiplexing) {
286
287    this->initREs(RE);
288    auto mGrepDriver = &mPxDriver;
289
290    auto & idb = mGrepDriver->getBuilder();
291    // TODO: until we automate stream buffer sizing, use this calculation to determine how large our matches buffer needs to be.
292    const unsigned baseBufferSize = this->getDefaultBufferBlocks();
293    int MaxCountFlag = 0;
294
295    //  Regular Expression Processing and Analysis Phase
296    const auto nREs = 1;
297
298    std::vector<StreamSetBuffer *> MatchResultsBufs(nREs);
299
300    StreamSetBuffer * LineBreakStream = nullptr;
301
302
303    std::map<std::string, StreamSetBuffer *> propertyStream;
304
305    for(unsigned i = 0; i < nREs; ++i) {
306
307        if (ccMultiplexing) {
308
309            if (uncompressedBasisBits == nullptr) {
310                uncompressedBasisBits = this->s2p(byteStream);
311            }
312            this->linefeedStreamFromUncompressedBits(uncompressedBasisBits);
313            std::vector<std::string> externalStreamNames;
314            std::vector<StreamSetBuffer *> icgrepInputSets = {uncompressedBasisBits};
315
316            const auto UnicodeSets = re::collectCCs(mRE, &cc::Unicode, std::set<re::Name *>({re::makeZeroWidth("\\b{g}")}));
317            StreamSetBuffer * const MatchResults = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
318
319            std::unique_ptr<cc::MultiplexedAlphabet> mpx = make_unique<cc::MultiplexedAlphabet>("mpx", UnicodeSets);
320            mRE = transformCCs(mpx.get(), mRE);
321            std::vector<re::CC *> mpx_basis = mpx->getMultiplexedCCs();
322            auto numOfCharacterClasses = mpx_basis.size();
323            StreamSetBuffer * CharClasses = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(numOfCharacterClasses), baseBufferSize, 1);
324            kernel::Kernel * ccK = mGrepDriver->addKernelInstance<kernel::CharClassesKernel>(idb, std::move(mpx_basis), false, cc::BitNumbering::BigEndian);
325            mGrepDriver->makeKernelCall(ccK, {uncompressedBasisBits}, {CharClasses});
326
327            kernel::Kernel * icgrepK = mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, mRE, externalStreamNames, std::vector<cc::Alphabet *>{mpx.get()}, cc::BitNumbering::BigEndian);
328            icgrepInputSets.push_back(CharClasses);
329            mGrepDriver->makeKernelCall(icgrepK, icgrepInputSets, {MatchResults});
330            MatchResultsBufs[i] = MatchResults;
331        } else {
332
333            bool anyGCB = hasGraphemeClusterBoundary(mRE);
334            bool isSimple = (mGrepRecordBreak != GrepRecordBreakKind::Unicode) && (!anyGCB);
335            if (isSimple) {
336                mRE = toUTF8(mRE);
337            }
338            const unsigned ByteCClimit = 6;
339
340            if (byteTestsWithinLimit(mRE, ByteCClimit)) {
341                LineBreakStream = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize);
342                kernel::Kernel * breakK = mGrepDriver->addKernelInstance<kernel::DirectCharacterClassKernelBuilder>(idb, "breakCC", std::vector<re::CC *>{mBreakCC});
343                mGrepDriver->makeKernelCall(breakK, {byteStream}, {LineBreakStream});
344
345                std::vector<std::string> externalStreamNames;
346                std::vector<StreamSetBuffer *> icgrepInputSets = {byteStream};
347                StreamSetBuffer * MatchResults = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize);
348                kernel::Kernel * icgrepK = mGrepDriver->addKernelInstance<kernel::ByteGrepKernel>(idb, mRE, externalStreamNames);
349                mGrepDriver->makeKernelCall(icgrepK, icgrepInputSets, {MatchResults});
350                MatchResultsBufs[i] = MatchResults;
351
352            } else {
353
354                if (uncompressedBasisBits == nullptr) {
355                    uncompressedBasisBits = this->s2p(byteStream);
356                }
357                LineBreakStream = this->linefeedStreamFromUncompressedBits(uncompressedBasisBits);
358                std::vector<std::string> externalStreamNames;
359                std::vector<StreamSetBuffer *> icgrepInputSets = {uncompressedBasisBits};
360
361                std::set<re::Name *> UnicodeProperties;
362                StreamSetBuffer * MatchResults = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
363                kernel::Kernel * icgrepK = mGrepDriver->addKernelInstance<kernel::ICGrepKernel>(idb, mRE, externalStreamNames, std::vector<cc::Alphabet *>(), cc::BitNumbering::BigEndian);
364                mGrepDriver->makeKernelCall(icgrepK, icgrepInputSets, {MatchResults});
365                MatchResultsBufs[i] = MatchResults;
366            }
367        }
368    }
369
370    StreamSetBuffer * MergedResults = MatchResultsBufs[0];
371
372    StreamSetBuffer * Matches = MergedResults;
373    if (mMoveMatchesToEOL) {
374        StreamSetBuffer * OriginalMatches = Matches;
375        kernel::Kernel * matchedLinesK = mGrepDriver->addKernelInstance<kernel::MatchedLinesKernel>(idb);
376        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
377        mGrepDriver->makeKernelCall(matchedLinesK, {OriginalMatches, LineBreakStream}, {Matches});
378    }
379
380    if (MaxCountFlag > 0) {
381        kernel::Kernel * untilK = mGrepDriver->addKernelInstance<kernel::UntilNkernel>(idb);
382        untilK->setInitialArguments({idb->getSize(MaxCountFlag)});
383        StreamSetBuffer * const AllMatches = Matches;
384        Matches = mGrepDriver->addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(1, 1), baseBufferSize, 1);
385        mGrepDriver->makeKernelCall(untilK, {AllMatches}, {Matches});
386    }
387
388    return std::pair<StreamSetBuffer *, StreamSetBuffer *>(LineBreakStream, Matches);
389
390}
391
392void LZ4GrepBaseGenerator::invokeScanMatchGrep(char* fileBuffer, size_t blockStart, size_t blockEnd, bool hasBlockChecksum) {
393    auto main = this->getScanMatchGrepMainFunction();
394    std::ostringstream s;
395    EmitMatch accum("", false, false, s);
396
397    main(fileBuffer, blockStart, blockEnd, hasBlockChecksum, reinterpret_cast<intptr_t>(&accum));
398    llvm::outs() << s.str();
399}
400
401
402void LZ4GrepBaseGenerator::generateFullyDecompressionScanMatchGrepPipeline(re::RE *regex) {
403    auto & iBuilder = mPxDriver.getBuilder();
404    this->generateScanMatchMainFunc(iBuilder);
405
406    StreamSetBuffer* compressedByteStream = this->loadByteStream();
407
408    StreamSetBuffer * const uncompressedByteStream = this->byteStreamDecompression(compressedByteStream);
409    StreamSetBuffer * uncompressedBitStream = this->s2p(uncompressedByteStream);
410
411    StreamSetBuffer * LineBreakStream;
412    StreamSetBuffer * Matches;
413    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedByteStream, uncompressedBitStream);
414
415    kernel::Kernel * scanMatchK = mPxDriver.addKernelInstance<kernel::ScanMatchKernel>(iBuilder);
416    scanMatchK->setInitialArguments({match_accumulator});
417    mPxDriver.makeKernelCall(scanMatchK, {Matches, LineBreakStream, uncompressedByteStream}, {});
418    mPxDriver.LinkFunction(*scanMatchK, "accumulate_match_wrapper", &accumulate_match_wrapper);
419    mPxDriver.LinkFunction(*scanMatchK, "finalize_match_wrapper", &finalize_match_wrapper);
420
421    mPxDriver.generatePipelineIR();
422    mPxDriver.deallocateBuffers();
423
424    iBuilder->CreateRetVoid();
425
426    mPxDriver.finalizeObject();
427}
428
429void LZ4GrepBaseGenerator::generateMultiplexingScanMatchGrepPipeline(re::RE *regex, bool utf8CC) {
430    auto & iBuilder = mPxDriver.getBuilder();
431    this->generateScanMatchMainFunc(iBuilder);
432
433    StreamSetBuffer *compressedByteStream = nullptr, *compressedBasisBits = nullptr;
434    std::tie(compressedByteStream, compressedBasisBits) = this->loadByteStreamAndBitStream();
435
436    StreamSetBuffer * LineBreakStream;
437    StreamSetBuffer * Matches;
438    std::tie(LineBreakStream, Matches) = multiplexingGrep(regex, compressedByteStream, compressedBasisBits, utf8CC);
439
440//    Kernel* matchDetector = mPxDriver.addKernelInstance<LZ4MatchDetectorKernel>(iBuilder);
441//    StreamSetBuffer* hasMatch = mPxDriver.addBuffer<StaticBuffer>(iBuilder, iBuilder->getStreamSetTy(1, 8),
442//                                                                  this->getDefaultBufferBlocks(), 1);
443//    mPxDriver.makeKernelCall(matchDetector, {Matches, LineBreakStream}, {hasMatch});
444
445
446    LZ4BlockInfo blockInfo = this->getBlockInfo(compressedByteStream);
447
448    StreamSetBuffer *const uncompressedByteStream =
449            mPxDriver.addBuffer<StaticBuffer>(iBuilder, iBuilder->getStreamSetTy(1, 8),
450                                              this->getDefaultBufferBlocks(), 1);
451    Kernel* lz4AioK = mPxDriver.addKernelInstance<LZ4ByteStreamDecompressionKernel>(iBuilder, false, 4 * 1024 * 1024, true);
452    lz4AioK->setInitialArguments({mFileSize});
453    mPxDriver.makeKernelCall(
454            lz4AioK,
455            {
456                    compressedByteStream,
457
458                    // Block Data
459                    blockInfo.isCompress,
460                    blockInfo.blockStart,
461                    blockInfo.blockEnd,
462                    Matches
463            }, {
464                    uncompressedByteStream
465            });
466
467
468    kernel::Kernel * scanMatchK = mPxDriver.addKernelInstance<kernel::ScanMatchKernel>(iBuilder);
469    scanMatchK->setInitialArguments({match_accumulator});
470    mPxDriver.makeKernelCall(scanMatchK, {Matches, LineBreakStream, uncompressedByteStream}, {});
471    mPxDriver.LinkFunction(*scanMatchK, "accumulate_match_wrapper", &accumulate_match_wrapper);
472    mPxDriver.LinkFunction(*scanMatchK, "finalize_match_wrapper", &finalize_match_wrapper);
473
474
475    mPxDriver.generatePipelineIR();
476
477    mPxDriver.deallocateBuffers();
478    iBuilder->CreateRetVoid();
479    mPxDriver.finalizeObject();
480}
481
482
483void LZ4GrepBaseGenerator::generateMultiplexingCountOnlyGrepPipeline(re::RE *regex, bool utf8CC) {
484    auto & iBuilder = mPxDriver.getBuilder();
485    this->generateCountOnlyMainFunc(iBuilder);
486
487    StreamSetBuffer *compressedByteStream = nullptr, *compressedBasisBits = nullptr;
488    std::tie(compressedByteStream, compressedBasisBits) = this->loadByteStreamAndBitStream();
489
490    StreamSetBuffer * LineBreakStream;
491    StreamSetBuffer * Matches;
492    std::tie(LineBreakStream, Matches) = multiplexingGrep(regex, compressedByteStream, compressedBasisBits, utf8CC);
493
494    kernel::Kernel * matchCountK = mPxDriver.addKernelInstance<kernel::PopcountKernel>(iBuilder);
495    mPxDriver.makeKernelCall(matchCountK, {Matches}, {});
496    mPxDriver.generatePipelineIR();
497
498    iBuilder->setKernel(matchCountK);
499    Value * matchedLineCount = iBuilder->getAccumulator("countResult");
500    matchedLineCount = iBuilder->CreateZExt(matchedLineCount, iBuilder->getInt64Ty());
501
502    mPxDriver.deallocateBuffers();
503
504    iBuilder->CreateRet(matchedLineCount);
505
506    mPxDriver.finalizeObject();
507}
508
509
510void LZ4GrepBaseGenerator::generateFullyDecompressionCountOnlyGrepPipeline(re::RE *regex) {
511    auto & iBuilder = mPxDriver.getBuilder();
512    this->generateCountOnlyMainFunc(iBuilder);
513
514    StreamSetBuffer * const uncompressedByteStream = this->generateUncompressedByteStream();
515//    StreamSetBuffer * const uncompressedBitStream = this->generateUncompressedBitStreams();
516
517    StreamSetBuffer * LineBreakStream;
518    StreamSetBuffer * Matches;
519
520    std::tie(LineBreakStream, Matches) = grep(regex, uncompressedByteStream, nullptr);
521
522    kernel::Kernel * matchCountK = mPxDriver.addKernelInstance<kernel::PopcountKernel>(iBuilder);
523    mPxDriver.makeKernelCall(matchCountK, {Matches}, {});
524    mPxDriver.generatePipelineIR();
525
526    iBuilder->setKernel(matchCountK);
527    Value * matchedLineCount = iBuilder->getAccumulator("countResult");
528    matchedLineCount = iBuilder->CreateZExt(matchedLineCount, iBuilder->getInt64Ty());
529
530    mPxDriver.deallocateBuffers();
531
532    iBuilder->CreateRet(matchedLineCount);
533
534    mPxDriver.finalizeObject();
535}
536
537
538ScanMatchGrepMainFunctionType LZ4GrepBaseGenerator::getScanMatchGrepMainFunction() {
539    return reinterpret_cast<ScanMatchGrepMainFunctionType>(mPxDriver.getMain());
540}
541CountOnlyGrepMainFunctionType LZ4GrepBaseGenerator::getCountOnlyGrepMainFunction() {
542    return reinterpret_cast<CountOnlyGrepMainFunctionType>(mPxDriver.getMain());
543}
544
545void LZ4GrepBaseGenerator::generateCountOnlyMainFunc(const std::unique_ptr<kernel::KernelBuilder> & iBuilder) {
546    Module * M = iBuilder->getModule();
547    Type * const int64Ty = iBuilder->getInt64Ty();
548    Type * const sizeTy = iBuilder->getSizeTy();
549    Type * const boolTy = iBuilder->getIntNTy(sizeof(bool) * 8);
550//    Type * const voidTy = iBuilder->getVoidTy();
551    Type * const inputType = iBuilder->getInt8PtrTy();
552
553    Function * const main = cast<Function>(M->getOrInsertFunction("Main", int64Ty, inputType, sizeTy, sizeTy, boolTy, nullptr));
554    main->setCallingConv(CallingConv::C);
555    Function::arg_iterator args = main->arg_begin();
556    mInputStream = &*(args++);
557    mInputStream->setName("input");
558
559    mHeaderSize = &*(args++);
560    mHeaderSize->setName("mHeaderSize");
561
562    mFileSize = &*(args++);
563    mFileSize->setName("mFileSize");
564
565    mHasBlockChecksum = &*(args++);
566    mHasBlockChecksum->setName("mHasBlockChecksum");
567    // TODO for now, we do not handle blockCheckSum
568    mHasBlockChecksum = iBuilder->getInt1(false);
569
570    iBuilder->SetInsertPoint(BasicBlock::Create(M->getContext(), "entry", main, 0));
571}
572
573void LZ4GrepBaseGenerator::generateScanMatchMainFunc(const std::unique_ptr<kernel::KernelBuilder> & iBuilder) {
574    Module * M = iBuilder->getModule();
575    Type * const sizeTy = iBuilder->getSizeTy();
576    Type * const boolTy = iBuilder->getIntNTy(sizeof(bool) * 8);
577    Type * const voidTy = iBuilder->getVoidTy();
578    Type * const inputType = iBuilder->getInt8PtrTy();
579    Type * const intAddrTy = iBuilder->getIntAddrTy();
580
581    Function * const main = cast<Function>(M->getOrInsertFunction("Main", voidTy, inputType, sizeTy, sizeTy, boolTy, intAddrTy, nullptr));
582    main->setCallingConv(CallingConv::C);
583    Function::arg_iterator args = main->arg_begin();
584    mInputStream = &*(args++);
585    mInputStream->setName("input");
586
587    mHeaderSize = &*(args++);
588    mHeaderSize->setName("mHeaderSize");
589
590    mFileSize = &*(args++);
591    mFileSize->setName("mFileSize");
592
593    mHasBlockChecksum = &*(args++);
594    mHasBlockChecksum->setName("mHasBlockChecksum");
595
596    match_accumulator = &*(args++);
597    match_accumulator->setName("match_accumulator");
598
599    iBuilder->SetInsertPoint(BasicBlock::Create(M->getContext(), "entry", main, 0));
600}
601
602std::vector<parabix::StreamSetBuffer *>
603LZ4GrepBaseGenerator::generateFakeStreams(const std::unique_ptr<kernel::KernelBuilder> &idb,
604                                          parabix::StreamSetBuffer *refStream, std::vector<unsigned> numOfStreams) {
605
606    if (!numOfStreams.size()) {
607        return std::vector<StreamSetBuffer *>();
608    }
609    std::vector<StreamSetBuffer *> outputStreams;
610    for (unsigned i = 0; i < numOfStreams.size(); i++) {
611        outputStreams.push_back(mPxDriver.addBuffer<StaticBuffer>(idb, idb->getStreamSetTy(numOfStreams[i]),
612                                                                  this->getDefaultBufferBlocks(), 1));
613    }
614    Kernel* fakeStreamGeneratorK = mPxDriver.addKernelInstance<FakeStreamGeneratingKernel>(idb, refStream->getNumOfStreams(), numOfStreams);
615    mPxDriver.makeKernelCall(fakeStreamGeneratorK, {refStream}, outputStreams);
616    return outputStreams;
617}
618
619
620
621std::vector<parabix::StreamSetBuffer *>
622LZ4GrepBaseGenerator::decompressBitStreams(parabix::StreamSetBuffer *compressedByteStream,
623                                           std::vector<parabix::StreamSetBuffer *> compressedBitStreams) {
624    // Default implementation here will be slow
625    std::vector<parabix::StreamSetBuffer *> retVec;
626    for (unsigned i = 0; i < compressedBitStreams.size(); i++) {
627        retVec.push_back(this->decompressBitStream(compressedByteStream, compressedBitStreams[i]));
628    }
629    return retVec;
630}
Note: See TracBrowser for help on using the repository browser.