source: icGREP/icgrep-devel/icgrep/lz4/grep/lz4_grep_bytestream_generator.cpp @ 6145

Last change on this file since 6145 was 6145, checked in by xwa163, 6 weeks ago
  1. LZ4 Grep: complete utf8 character classes for multiplexing pipeline
  2. Implement multiple streams version of S2P and P2S
File size: 7.3 KB
Line 
1
2#include <numeric>
3#include "lz4_grep_bytestream_generator.h"
4#include <kernels/lz4/twist_kernel.h>
5#include <kernels/lz4/decompression/lz4_twist_decompression.h>
6#include <kernels/lz4/untwist_kernel.h>
7#include <kernels/s2p_kernel.h>
8#include <kernels/p2s_kernel.h>
9#include <kernels/lz4/decompression/lz4_bytestream_decompression.h>
10#include <kernels/kernel_builder.h>
11
12
13using namespace kernel;
14using namespace parabix;
15
16StreamSetBuffer *LZ4GrepByteStreamGenerator::generateUncompressedBitStreams() {
17    StreamSetBuffer* compressedByteStream = this->loadByteStream();
18    parabix::StreamSetBuffer * uncompressedByteStream = this->byteStreamDecompression(compressedByteStream);
19    return this->s2p(uncompressedByteStream);
20}
21
22parabix::StreamSetBuffer *
23LZ4GrepByteStreamGenerator::decompressBitStream(parabix::StreamSetBuffer *compressedByteStream,
24                                                parabix::StreamSetBuffer *compressedBitStream) {
25    return this->decompressBitStreams(compressedByteStream, {compressedBitStream})[0];
26}
27
28unsigned LZ4GrepByteStreamGenerator::calculateTwistWidth(unsigned numOfStreams) {
29    if (numOfStreams <= 2) {
30        return numOfStreams;
31    } else if (numOfStreams <= 4) {
32        return 4;
33    } else if (numOfStreams <= 8) {
34        return 8;
35    } else {
36        llvm::report_fatal_error("Twist: Unsupported numOfStreams " + std::to_string(numOfStreams));;
37    }
38}
39
40std::vector<parabix::StreamSetBuffer *>
41LZ4GrepByteStreamGenerator::decompressBitStreams(parabix::StreamSetBuffer *compressedByteStream,
42                                                 std::vector<parabix::StreamSetBuffer *> compressedBitStreams) {
43    auto & b = mPxDriver.getBuilder();
44
45    std::vector<unsigned> numOfStreams(compressedBitStreams.size());
46    std::transform(compressedBitStreams.begin(), compressedBitStreams.end(), numOfStreams.begin(), [](StreamSetBuffer* b){return b->getNumOfStreams();});
47    unsigned totalStreamNum = std::accumulate(numOfStreams.begin(), numOfStreams.end(), 0u);
48
49    unsigned twistWidth = this->calculateTwistWidth(totalStreamNum);
50    StreamSetBuffer* twistedStream = this->twist(b, compressedBitStreams, twistWidth);
51
52    LZ4BlockInfo blockInfo = this->getBlockInfo(compressedByteStream);
53    StreamSetBuffer* uncompressedTwistedStream = mPxDriver.addBuffer<StaticBuffer>(b, b->getStreamSetTy(1, twistWidth), this->getDefaultBufferBlocks(), 1);
54    std::vector<StreamSetBuffer*> inputStreams = {
55            compressedByteStream,
56
57            blockInfo.isCompress,
58            blockInfo.blockStart,
59            blockInfo.blockEnd,
60
61            twistedStream
62    };
63    std::vector<StreamSetBuffer*> outputStreams = {
64            uncompressedTwistedStream
65    };
66
67    if (twistWidth <= 4) {
68        Kernel* lz4I4AioK = mPxDriver.addKernelInstance<LZ4TwistDecompressionKernel>(b, twistWidth);
69        lz4I4AioK->setInitialArguments({mFileSize});
70        mPxDriver.makeKernelCall(lz4I4AioK, inputStreams, outputStreams);
71
72    } else {
73        Kernel* lz4AioK = mPxDriver.addKernelInstance<LZ4ByteStreamDecompressionKernel>(b, true);
74        lz4AioK->setInitialArguments({mFileSize});
75        mPxDriver.makeKernelCall(lz4AioK, inputStreams, outputStreams);
76    }
77    return this->untwist(b, uncompressedTwistedStream, twistWidth, numOfStreams);
78}
79
80parabix::StreamSetBuffer* LZ4GrepByteStreamGenerator::twist(const std::unique_ptr<kernel::KernelBuilder> &b,
81                                                            std::vector<StreamSetBuffer*> inputStreams,
82                                                            unsigned twistWidth
83) {
84    std::vector<unsigned> numsOfStreams(inputStreams.size());
85    std::transform(inputStreams.begin(), inputStreams.end(), numsOfStreams.begin(), [](StreamSetBuffer* b){return b->getNumOfStreams();});
86    unsigned totalNumOfStreams = std::accumulate(numsOfStreams.begin(), numsOfStreams.end(), 0u);
87    assert(totalNumOfStreams <= twistWidth);
88
89    if (twistWidth == 1) {
90        for (unsigned i = 0; i < inputStreams.size(); i++) {
91            if (inputStreams[i]->getNumOfStreams() == 1) {
92                return inputStreams[i];
93            }
94        }
95    } else if (twistWidth == 2 || twistWidth == 4) {
96        StreamSetBuffer* twistedCharClasses = mPxDriver.addBuffer<StaticBuffer>(b, b->getStreamSetTy(1, twistWidth),
97                                                                                this->getDefaultBufferBlocks(), 1);
98        kernel::Kernel* twistK = mPxDriver.addKernelInstance<kernel::TwistMultipleByPDEPKernel>(b, numsOfStreams, twistWidth);
99        mPxDriver.makeKernelCall(twistK, inputStreams, {twistedCharClasses});
100        return twistedCharClasses;
101    } else if (twistWidth == 8) {
102        StreamSetBuffer * const mtxByteStream = mPxDriver.addBuffer<StaticBuffer>(b, b->getStreamSetTy(1, twistWidth),
103                                                                                  this->getDefaultBufferBlocks());
104        Kernel * p2sK = mPxDriver.addKernelInstance<P2SMultipleStreamsKernel>(b, cc::BitNumbering::BigEndian, numsOfStreams);
105        mPxDriver.makeKernelCall(p2sK, inputStreams, {mtxByteStream});
106        return mtxByteStream;
107    } else {
108        llvm::report_fatal_error("Twist: Unsupported twistWidth " + std::to_string(twistWidth));;
109    }
110}
111
112std::vector<StreamSetBuffer*> LZ4GrepByteStreamGenerator::untwist(const std::unique_ptr<kernel::KernelBuilder> &b,
113                                                              parabix::StreamSetBuffer *inputStream,
114                                                              unsigned twistWidth,
115                                                              std::vector<unsigned> numOfStreams
116) {
117    unsigned totalNumOfStreams = std::accumulate(numOfStreams.begin(), numOfStreams.end(), 0u);
118    assert(totalNumOfStreams <= twistWidth);
119    if (twistWidth == 1) {
120        std::vector<unsigned> fakeStreamNums;
121        for (unsigned i = 0; i < numOfStreams.size(); i++) {
122            if (numOfStreams[i] == 0) {
123                fakeStreamNums.push_back(0);
124            }
125        }
126        auto fakeStreams = this->generateFakeStreams(b, inputStream, fakeStreamNums);
127
128        std::vector<StreamSetBuffer*> retBuffers;
129        unsigned j = 0;
130        for (unsigned i = 0; i < numOfStreams.size(); i++) {
131            if (numOfStreams[i] == 0) {
132                retBuffers.push_back(fakeStreams[j]);
133                j++;
134            } else {
135                retBuffers.push_back(inputStream);
136            }
137        }
138        return retBuffers;
139    } else{
140        std::vector<StreamSetBuffer*> retBuffers;
141        for (unsigned i = 0; i < numOfStreams.size(); i++) {
142            retBuffers.push_back(mPxDriver.addBuffer<StaticBuffer>(b, b->getStreamSetTy(numOfStreams[i]), this->getDefaultBufferBlocks(), 1));
143        }
144
145
146        if (twistWidth == 2 || twistWidth == 4) {
147            kernel::Kernel* untwistK = mPxDriver.addKernelInstance<kernel::UntwistMultipleByPEXTKernel>(b, numOfStreams, twistWidth);
148            mPxDriver.makeKernelCall(untwistK, {inputStream}, retBuffers);
149            return retBuffers;
150        } else if (twistWidth == 8) {
151            Kernel * s2pk = mPxDriver.addKernelInstance<S2PMultipleStreamsKernel>(b, cc::BitNumbering::BigEndian, true, numOfStreams);
152            mPxDriver.makeKernelCall(s2pk, {inputStream}, retBuffers);
153            return retBuffers;
154        } else {
155            llvm::report_fatal_error("Twist: Unsupported twistWidth " + std::to_string(twistWidth));;
156        }
157    }
158}
159
160
Note: See TracBrowser for help on using the repository browser.