source: icGREP/icgrep-devel/icgrep/kernels/lz4/decompression/lz4_twist_decompression.cpp @ 6143

Last change on this file since 6143 was 6143, checked in by xwa163, 10 months ago

lz4_grep: some bug fixing and refactor

File size: 17.9 KB
Line 
1
2#include "lz4_twist_decompression.h"
3
4#include <kernels/kernel_builder.h>
5#include <iostream>
6#include <string>
7#include <llvm/Support/raw_ostream.h>
8#include <kernels/streamset.h>
9
10using namespace llvm;
11using namespace kernel;
12using namespace std;
13
14
15namespace kernel {
16    size_t LZ4TwistDecompressionKernel::getNormalCopyLength() {
17        return (COPY_FW - BYTE_WIDTH) / mTwistWidth;
18    }
19    llvm::Value* LZ4TwistDecompressionKernel::getNormalCopyLengthValue(const std::unique_ptr<KernelBuilder> &b) {
20        return b->getSize(this->getNormalCopyLength());
21    }
22
23
24    LZ4TwistDecompressionKernel::LZ4TwistDecompressionKernel(const std::unique_ptr<kernel::KernelBuilder> &b, unsigned twistWidth, unsigned blockSize)
25            : LZ4SequentialDecompressionKernel(b, "LZ4TwistDecompressionKernel", blockSize),
26              mTwistWidth(twistWidth),
27              mItemsPerByte(BYTE_WIDTH / twistWidth)
28    {
29        mStreamSetInputs.push_back(Binding{b->getStreamSetTy(1, twistWidth), "inputTwistStream", RateEqualTo("byteStream")});
30//        mStreamSetInputs.push_back(Binding{b->getStreamSetTy(1, twistWidth), "refTwistStream"});
31        mStreamSetOutputs.push_back(Binding{b->getStreamSetTy(1, twistWidth), "outputTwistStream", BoundedRate(0, 1)});
32
33        this->addScalar(b->getInt8PtrTy(), "temporaryInputPtr");
34        this->addScalar(b->getInt8PtrTy(), "temporaryOutputPtr");
35    }
36
37
38
39    void LZ4TwistDecompressionKernel::doLiteralCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *literalStart,
40                               llvm::Value *literalLength, llvm::Value* blockStart) {
41        // Constant and Type
42        Constant* SIZE_0 = b->getSize(0);
43        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
44        Constant* INT_FW_TWIST_WIDTH = b->getIntN(COPY_FW, mTwistWidth);
45        Type* INT_FW_TY = b->getIntNTy(COPY_FW);
46        Type* INT_FW_PTR_TY = INT_FW_TY->getPointerTo();
47
48
49
50        Value* temporayInputPtr = b->getScalarField("temporaryInputPtr");
51        Value* initInputOffset = b->CreateSub(
52                b->CreateUDiv(literalStart, SIZE_ITEMS_PER_BYTE),
53                b->CreateUDiv(blockStart, SIZE_ITEMS_PER_BYTE)
54        );
55
56        Value* initInputPtr = b->CreateGEP(temporayInputPtr, initInputOffset);
57
58
59        Value* outputByteBasePtr = b->getScalarField("temporaryOutputPtr");
60        Value* outputPos = b->getScalarField("outputPos");
61
62        Value* outputPosRem = b->CreateSub(outputPos, b->getProducedItemCount("outputTwistStream"));
63        Value* outputPosByteRem = b->CreateUDiv(outputPosRem, SIZE_ITEMS_PER_BYTE);
64
65        Value* initOutputPtr = b->CreateGEP(outputByteBasePtr, outputPosByteRem);
66        Value* initOutputLastByte = b->CreateZExt(b->CreateLoad(initOutputPtr), INT_FW_TY);
67
68
69        Value* literalStartRemByteItem = b->CreateURem(literalStart, SIZE_ITEMS_PER_BYTE);
70        Value* outputPosRemByteItem = b->CreateURem(outputPos, SIZE_ITEMS_PER_BYTE);
71        Value* outputMask = this->getOutputMask(b, outputPos);
72
73        // ---- EntryBlock
74        BasicBlock* entryBlock = b->GetInsertBlock();
75
76        BasicBlock* literalCopyCon = b->CreateBasicBlock("literalCopyCon");
77
78        b->CreateBr(literalCopyCon);
79
80        // ---- literalCopyCon
81        b->SetInsertPoint(literalCopyCon);
82        PHINode* phiInputPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
83        phiInputPtr->addIncoming(initInputPtr, entryBlock);
84
85        PHINode* phiOutputPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
86        phiOutputPtr->addIncoming(initOutputPtr, entryBlock);
87
88        PHINode* phiCopiedLength = b->CreatePHI(b->getSizeTy(), 2);
89        phiCopiedLength->addIncoming(SIZE_0, entryBlock);
90
91        PHINode* phiOutputLastByte = b->CreatePHI(INT_FW_TY, 2);
92        phiOutputLastByte->addIncoming(initOutputLastByte, entryBlock);
93
94        BasicBlock* literalCopyBody = b->CreateBasicBlock("literalCopyBody");
95        BasicBlock* literalCopyExit = b->CreateBasicBlock("literalCopyExit");
96
97
98        b->CreateCondBr(b->CreateICmpULT(phiCopiedLength, literalLength), literalCopyBody, literalCopyExit);
99
100        // ---- literalCopyBody
101        b->SetInsertPoint(literalCopyBody);
102
103        Value* inputFwPtr = b->CreatePointerCast(phiInputPtr, INT_FW_PTR_TY);
104        Value* outputFwPtr = b->CreatePointerCast(phiOutputPtr, INT_FW_PTR_TY);
105
106
107        Value* inputTargetValue = b->CreateLoad(inputFwPtr);
108
109
110        inputTargetValue = b->CreateLShr(inputTargetValue, b->CreateMul(literalStartRemByteItem, INT_FW_TWIST_WIDTH));
111        inputTargetValue = b->CreateShl(inputTargetValue, b->CreateMul(outputPosRemByteItem, INT_FW_TWIST_WIDTH));
112
113
114        Value* newCopyLength = this->getNormalCopyLengthValue(b);
115
116        Value* outputValue = b->CreateAnd(phiOutputLastByte, outputMask);
117
118        outputValue = b->CreateOr(outputValue, inputTargetValue);
119
120        b->CreateStore(outputValue, outputFwPtr);
121
122        phiOutputLastByte->addIncoming(b->CreateLShr(outputValue, b->getSize(this->getNormalCopyLength() * mTwistWidth)), b->GetInsertBlock());
123        phiCopiedLength->addIncoming(b->CreateAdd(phiCopiedLength, newCopyLength), b->GetInsertBlock());
124        phiInputPtr->addIncoming(b->CreateGEP(phiInputPtr, b->CreateUDiv(newCopyLength, SIZE_ITEMS_PER_BYTE)), b->GetInsertBlock());
125        phiOutputPtr->addIncoming(b->CreateGEP(phiOutputPtr, b->CreateUDiv(newCopyLength, SIZE_ITEMS_PER_BYTE)), b->GetInsertBlock());
126
127        b->CreateBr(literalCopyCon);
128
129        // ---- literalCopyExit
130        b->SetInsertPoint(literalCopyExit);
131
132
133        b->setScalarField("outputPos", b->CreateAdd(outputPos, literalLength));
134    }
135
136
137
138    void LZ4TwistDecompressionKernel::doShortMatchCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *matchOffset,
139                                  llvm::Value *matchLength) {
140        // Constant and Type
141        Constant* SIZE_0 = b->getSize(0);
142        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
143        Constant* INT_FW_TWIST_WIDTH = b->getIntN(COPY_FW, mTwistWidth);
144        Type* INT8_PTR_TY = b->getInt8PtrTy();
145        Type* INT_FW_TY = b->getIntNTy(COPY_FW);
146        Type* INT_FW_PTR_TY = INT_FW_TY->getPointerTo();
147
148
149
150//        Value* outputCapacity = b->getCapacity("outputTwistStream");
151
152//        Value* outputByteBasePtr = b->CreatePointerCast(b->getRawOutputPointer("outputTwistStream", SIZE_0), INT8_PTR_TY);
153        Value* outputByteBasePtr = b->getScalarField("temporaryOutputPtr");
154
155        Value* outputPos = b->getScalarField("outputPos");
156//        Value* outputPosRem = b->CreateURem(outputPos, outputCapacity);
157        Value* outputPosRem = b->CreateSub(outputPos, b->getProducedItemCount("outputTwistStream"));
158
159        // ---- EntryBlock
160        BasicBlock* entryBlock = b->GetInsertBlock();
161        BasicBlock* literalCopyCon = b->CreateBasicBlock("literalCopyCon");
162        b->CreateBr(literalCopyCon);
163
164        // ---- literalCopyCon
165        b->SetInsertPoint(literalCopyCon);
166        PHINode* phiCopiedLength = b->CreatePHI(b->getSizeTy(), 2);
167        phiCopiedLength->addIncoming(SIZE_0, entryBlock);
168
169        BasicBlock* literalCopyBody = b->CreateBasicBlock("literalCopyBody");
170        BasicBlock* literalCopyExit = b->CreateBasicBlock("literalCopyExit");
171
172        b->CreateCondBr(b->CreateICmpULT(phiCopiedLength, matchLength), literalCopyBody, literalCopyExit);
173
174        // ---- literalCopyBody
175        b->SetInsertPoint(literalCopyBody);
176        Value* outputStartRem = b->CreateAdd(outputPosRem, phiCopiedLength);
177
178        Value* outputStartRemByteItem = b->CreateURem(outputStartRem, SIZE_ITEMS_PER_BYTE);
179        Value* outputStartByteRem = b->CreateUDiv(outputStartRem, SIZE_ITEMS_PER_BYTE);
180
181        Value* outputTargetPtr = b->CreateGEP(outputByteBasePtr, outputStartByteRem);
182        outputTargetPtr = b->CreatePointerCast(outputTargetPtr, INT_FW_PTR_TY);
183
184
185        Value* copyStartRem = b->CreateSub(outputStartRem, matchOffset);
186        Value* copyStartRemByteItem = b->CreateURem(copyStartRem, SIZE_ITEMS_PER_BYTE);
187        Value* copyStartByteRem = b->CreateUDiv(copyStartRem, SIZE_ITEMS_PER_BYTE);
188
189        Value* inputTargetPtr = b->CreateGEP(outputByteBasePtr, copyStartByteRem);
190        inputTargetPtr = b->CreatePointerCast(inputTargetPtr, INT_FW_PTR_TY);
191
192        Value* inputTargetValue = b->CreateLoad(inputTargetPtr);
193        inputTargetValue = b->CreateLShr(inputTargetValue, b->CreateMul(copyStartRemByteItem, INT_FW_TWIST_WIDTH));
194
195        Value* outputValue = b->CreateLoad(outputTargetPtr);
196        Value* outputMask = this->getOutputMask(b, outputStartRemByteItem);
197        outputValue = b->CreateAnd(outputValue, outputMask);
198
199        inputTargetValue = b->CreateShl(inputTargetValue, b->CreateMul(outputStartRemByteItem, INT_FW_TWIST_WIDTH));
200        outputValue = b->CreateOr(outputValue, inputTargetValue);
201        b->CreateStore(outputValue, outputTargetPtr);
202
203        Value* newCopyLength = matchOffset;
204
205        phiCopiedLength->addIncoming(b->CreateAdd(phiCopiedLength, newCopyLength), b->GetInsertBlock());
206
207        b->CreateBr(literalCopyCon);
208
209        // ---- literalCopyExit
210        b->SetInsertPoint(literalCopyExit);
211        b->setScalarField("outputPos", b->CreateAdd(outputPos, matchLength));
212    }
213
214    void LZ4TwistDecompressionKernel::doLongMatchCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *matchOffset,
215                                 llvm::Value *matchLength) {
216        // Constant and Type
217        Constant* SIZE_0 = b->getSize(0);
218        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
219        Constant* INT_FW_TWIST_WIDTH = b->getIntN(COPY_FW, mTwistWidth);
220        Type* INT8_PTR_TY = b->getInt8PtrTy();
221        Type* INT_FW_TY = b->getIntNTy(COPY_FW);
222        Type* INT_FW_PTR_TY = INT_FW_TY->getPointerTo();
223
224
225//        Value* outputByteBasePtr = b->CreatePointerCast(b->getRawOutputPointer("outputTwistStream", b->getSize(0)), INT8_PTR_TY);
226//        Value* outputCapacity = b->getCapacity("outputTwistStream");
227        Value* outputByteBasePtr = b->getScalarField("temporaryOutputPtr");
228        Value* outputPos = b->getScalarField("outputPos");
229
230
231//        Value* outputPosRem = b->CreateURem(outputPos, outputCapacity);
232        Value* outputPosRem = b->CreateSub(outputPos, b->getProducedItemCount("outputTwistStream"));
233
234        Value* outputPosRemByteItem = b->CreateURem(outputPosRem, SIZE_ITEMS_PER_BYTE);
235        Value* outputMask = this->getOutputMask(b, outputPosRem);
236
237
238        Value* outputBytePos = b->CreateUDiv(outputPosRem, SIZE_ITEMS_PER_BYTE);
239
240
241        Value* initCopyToPtr = b->CreateGEP(outputByteBasePtr, outputBytePos);
242        Value* initOutputLastByte = b->CreateZExt(b->CreateLoad(initCopyToPtr), INT_FW_TY);
243
244        Value* copyFromPosRem = b->CreateSub(outputPosRem, matchOffset);
245
246        Value* copyFromPosRemByteItem = b->CreateURem(copyFromPosRem, SIZE_ITEMS_PER_BYTE);
247        Value* copyFromBytePos = b->CreateUDiv(copyFromPosRem, SIZE_ITEMS_PER_BYTE);
248        Value* initCopyFromPtr = b->CreateGEP(outputByteBasePtr, copyFromBytePos);
249
250
251        Value* copyLength = this->getNormalCopyLengthValue(b);
252        Value* copyLengthByte = b->CreateUDiv(copyLength, SIZE_ITEMS_PER_BYTE);
253
254        // ---- EntryBlock
255        BasicBlock* entryBlock = b->GetInsertBlock();
256        BasicBlock* literalCopyCon = b->CreateBasicBlock("literalCopyCon");
257        b->CreateBr(literalCopyCon);
258
259        // ---- literalCopyCon
260        b->SetInsertPoint(literalCopyCon);
261        PHINode* phiCopiedLength = b->CreatePHI(b->getSizeTy(), 2);
262        phiCopiedLength->addIncoming(SIZE_0, entryBlock);
263
264        PHINode* phiCopyFromPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
265        phiCopyFromPtr->addIncoming(initCopyFromPtr, entryBlock);
266
267        PHINode* phiCopyToPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
268        phiCopyToPtr->addIncoming(initCopyToPtr, entryBlock);
269
270        PHINode* phiOutputLastByte = b->CreatePHI(b->getIntNTy(COPY_FW), 2);
271        phiOutputLastByte->addIncoming(initOutputLastByte, entryBlock);
272
273
274        BasicBlock* literalCopyBody = b->CreateBasicBlock("literalCopyBody");
275        BasicBlock* literalCopyExit = b->CreateBasicBlock("literalCopyExit");
276
277        b->CreateCondBr(b->CreateICmpULT(phiCopiedLength, matchLength), literalCopyBody, literalCopyExit);
278
279        // ---- literalCopyBody
280        b->SetInsertPoint(literalCopyBody);
281        Value* outputTargetPtr = b->CreatePointerCast(phiCopyToPtr, INT_FW_PTR_TY);
282        Value* inputTargetPtr = b->CreatePointerCast(phiCopyFromPtr, INT_FW_PTR_TY);
283
284        Value* inputTargetValue = b->CreateLoad(inputTargetPtr);
285        inputTargetValue = b->CreateLShr(inputTargetValue, b->CreateMul(copyFromPosRemByteItem, INT_FW_TWIST_WIDTH));
286        inputTargetValue = b->CreateShl(inputTargetValue, b->CreateMul(outputPosRemByteItem, INT_FW_TWIST_WIDTH));
287
288        Value* outputValue = b->CreateAnd(phiOutputLastByte, outputMask);
289
290        outputValue = b->CreateOr(outputValue, inputTargetValue);
291        b->CreateStore(outputValue, outputTargetPtr);
292
293        phiCopiedLength->addIncoming(b->CreateAdd(phiCopiedLength, copyLength), b->GetInsertBlock());
294        phiCopyFromPtr->addIncoming(b->CreateGEP(phiCopyFromPtr, copyLengthByte), b->GetInsertBlock());
295        phiCopyToPtr->addIncoming(b->CreateGEP(phiCopyToPtr, copyLengthByte), b->GetInsertBlock());
296        phiOutputLastByte->addIncoming(b->CreateLShr(outputValue, b->getSize(this->getNormalCopyLength() * mTwistWidth)), b->GetInsertBlock());
297
298        b->CreateBr(literalCopyCon);
299
300        // ---- literalCopyExit
301        b->SetInsertPoint(literalCopyExit);
302
303        b->setScalarField("outputPos", b->CreateAdd(outputPos, matchLength));
304    }
305
306
307    void LZ4TwistDecompressionKernel::doMatchCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *matchOffset,
308                             llvm::Value *matchLength) {
309
310        BasicBlock* shortMatchCopyBlock = b->CreateBasicBlock("shortMatchCopyBlock");
311        BasicBlock* longMatchCopyBlock = b->CreateBasicBlock("longMatchCopyBlock");
312        BasicBlock* matchCopyFinishBlock = b->CreateBasicBlock("matchCopyFinishBlock");
313
314        b->CreateUnlikelyCondBr(
315                b->CreateICmpULT(matchOffset, this->getNormalCopyLengthValue(b)),
316                shortMatchCopyBlock,
317                longMatchCopyBlock
318        );
319
320        // ---- shortMatchCopyBlock
321        b->SetInsertPoint(shortMatchCopyBlock);
322        this->doShortMatchCopy(b, matchOffset, matchLength);
323        b->CreateBr(matchCopyFinishBlock);
324
325        // ---- longMatchCopyBlock
326        b->SetInsertPoint(longMatchCopyBlock);
327        this->doLongMatchCopy(b, matchOffset, matchLength);
328        b->CreateBr(matchCopyFinishBlock);
329
330        b->SetInsertPoint(matchCopyFinishBlock);
331    }
332
333    void LZ4TwistDecompressionKernel::setProducedOutputItemCount(const std::unique_ptr<KernelBuilder> &b, llvm::Value* produced) {
334        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
335        Constant* SIZE_0 = b->getSize(0);
336        Type* INT8_PTR_TY = b->getInt8PtrTy();
337
338        Value* oldProduced = b->getProducedItemCount("outputTwistStream");
339
340        Value* outputByteBasePtr = b->CreatePointerCast(b->getRawOutputPointer("outputTwistStream", SIZE_0), INT8_PTR_TY);
341        Value* outputCapacity = b->getCapacity("outputTwistStream");
342        Value* outputPosRem = b->CreateURem(oldProduced, outputCapacity);
343
344        Value* outputPosByteRem = b->CreateUDiv(outputPosRem, SIZE_ITEMS_PER_BYTE);
345        Value* actualOutputPtr = b->CreateGEP(outputByteBasePtr, outputPosByteRem);
346        b->CreateMemCpy(actualOutputPtr, b->getScalarField("temporaryOutputPtr"), b->getSize(mBlockSize / mItemsPerByte), 1);
347
348        Value* ptr = b->CreateGEP(b->CreatePointerCast(b->getScalarField("temporaryOutputPtr"), b->getBitBlockType()->getPointerTo()), b->getSize(0x16f));
349
350        b->setProducedItemCount("outputTwistStream", produced);
351    }
352
353
354    void LZ4TwistDecompressionKernel::initializationMethod(const std::unique_ptr<KernelBuilder> &b) {
355        b->setScalarField("temporaryInputPtr", b->CreateMalloc(b->getSize(mBlockSize / mItemsPerByte)));
356        b->setScalarField("temporaryOutputPtr", b->CreateMalloc(b->getSize(mBlockSize / mItemsPerByte + COPY_FW / BYTE_WIDTH)));
357
358    }
359
360    void LZ4TwistDecompressionKernel::prepareProcessBlock(const std::unique_ptr<KernelBuilder> &b, llvm::Value* blockStart, llvm::Value* blockEnd) {
361        Constant* SIZE_0 = b->getSize(0);
362        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
363        Type* INT8_PTR_TY = b->getInt8PtrTy();
364
365
366        Value* rawInputPtr = b->CreatePointerCast(b->getRawInputPointer("inputTwistStream", SIZE_0), INT8_PTR_TY);
367        Value* inputCapacity = b->getCapacity("inputTwistStream");
368
369        Value* inputByteCapacity = b->CreateUDiv(inputCapacity, SIZE_ITEMS_PER_BYTE);
370
371        Value* blockStartRem = b->CreateURem(blockStart, inputCapacity);
372        Value* blockStartByteRem = b->CreateUDiv(blockStartRem, SIZE_ITEMS_PER_BYTE);
373        Value* remByte = b->CreateSub(inputByteCapacity, blockStartByteRem);
374
375        Value* blockSize = b->CreateSub(blockEnd, blockStart);
376        Value* copyTotalByte = b->CreateAdd(b->CreateUDiv(blockSize, SIZE_ITEMS_PER_BYTE), SIZE_ITEMS_PER_BYTE); // It will be safe if we copy a few bytes more
377
378        Value* copyBytes1 = b->CreateUMin(remByte, copyTotalByte);
379        Value* copyBytes2 = b->CreateSub(copyTotalByte, copyBytes1);
380
381        Value* temporayInputPtr = b->getScalarField("temporaryInputPtr");
382
383        b->CreateMemCpy(temporayInputPtr, b->CreateGEP(rawInputPtr, blockStartByteRem), copyBytes1, 1);
384        b->CreateMemCpy(b->CreateGEP(temporayInputPtr, copyBytes1), rawInputPtr, copyBytes2, 1);
385    }
386
387    void LZ4TwistDecompressionKernel::beforeTermination(const std::unique_ptr<KernelBuilder> &b) {
388        b->CreateFree(b->getScalarField("temporaryInputPtr"));
389        b->CreateFree(b->getScalarField("temporaryOutputPtr"));
390    }
391
392    llvm::Value *LZ4TwistDecompressionKernel::getOutputMask(const std::unique_ptr<KernelBuilder> &b, llvm::Value *outputPos) {
393        Value* remByteItems = b->CreateURem(outputPos, b->getSize(mItemsPerByte));
394        Value* INT_FW_1 = b->getIntN(COPY_FW, 1);
395        Value* shiftAmount = b->CreateMul(remByteItems, b->getIntN(COPY_FW, mTwistWidth));
396        return b->CreateSub(
397                b->CreateShl(INT_FW_1, shiftAmount),
398                INT_FW_1
399        );
400    }
401}
Note: See TracBrowser for help on using the repository browser.