source: icGREP/icgrep-devel/icgrep/kernels/lz4/aio/lz4_twist_aio.cpp @ 6135

Last change on this file since 6135 was 6135, checked in by xwa163, 6 months ago
  1. Implement twist_kernel and untwist_kernel by PEXT and PDEP
  2. Use twist form for multiplexing lz4 grep
File size: 16.2 KB
Line 
1
2#include "lz4_twist_aio.h"
3
4#include <kernels/kernel_builder.h>
5#include <iostream>
6#include <string>
7#include <llvm/Support/raw_ostream.h>
8#include <kernels/streamset.h>
9
10using namespace llvm;
11using namespace kernel;
12using namespace std;
13
14
15namespace kernel {
16    size_t LZ4TwistAioKernel::getNormalCopyLength() {
17        return (COPY_FW - BYTE_WIDTH) / mTwistWidth;
18    }
19    llvm::Value* LZ4TwistAioKernel::getNormalCopyLengthValue(const std::unique_ptr<KernelBuilder> &b) {
20        return b->getSize(this->getNormalCopyLength());
21    }
22
23
24    LZ4TwistAioKernel::LZ4TwistAioKernel(const std::unique_ptr<kernel::KernelBuilder> &b, unsigned twistWidth, unsigned blockSize)
25            : LZ4SequentialAioBaseKernel(b, "LZ4TwistAioKernel", blockSize),
26              mTwistWidth(twistWidth),
27              mItemsPerByte(BYTE_WIDTH / twistWidth)
28    {
29        mStreamSetInputs.push_back(Binding{b->getStreamSetTy(1, twistWidth), "inputTwistStream", RateEqualTo("byteStream")});
30        mStreamSetOutputs.push_back(Binding{b->getStreamSetTy(1, twistWidth), "outputTwistStream", BoundedRate(0, 1)});
31
32        this->addScalar(b->getInt8PtrTy(), "temporaryInputPtr");
33    }
34
35
36
37    void LZ4TwistAioKernel::doLiteralCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *literalStart,
38                               llvm::Value *literalLength, llvm::Value* blockStart) {
39        // Constant and Type
40        Constant* SIZE_0 = b->getSize(0);
41        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
42        Constant* INT_FW_TWIST_WIDTH = b->getIntN(COPY_FW, mTwistWidth);
43        Type* INT8_PTR_TY = b->getInt8PtrTy();
44        Type* INT_FW_TY = b->getIntNTy(COPY_FW);
45        Type* INT_FW_PTR_TY = INT_FW_TY->getPointerTo();
46
47
48
49        Value* temporayInputPtr = b->getScalarField("temporaryInputPtr");
50        Value* initInputOffset = b->CreateSub(
51                b->CreateUDiv(literalStart, SIZE_ITEMS_PER_BYTE),
52                b->CreateUDiv(blockStart, SIZE_ITEMS_PER_BYTE)
53        );
54
55        Value* initInputPtr = b->CreateGEP(temporayInputPtr, initInputOffset);
56
57        Value* outputByteBasePtr = b->CreatePointerCast(b->getRawOutputPointer("outputTwistStream", SIZE_0), INT8_PTR_TY);
58        Value* outputPos = b->getScalarField("outputPos");
59        Value* outputCapacity = b->getCapacity("outputTwistStream");
60
61        Value* outputPosRem = b->CreateURem(outputPos, outputCapacity);
62        Value* outputPosByteRem = b->CreateUDiv(outputPosRem, SIZE_ITEMS_PER_BYTE);
63
64        Value* initOutputPtr = b->CreateGEP(outputByteBasePtr, outputPosByteRem);
65        Value* initOutputLastByte = b->CreateZExt(b->CreateLoad(initOutputPtr), INT_FW_TY);
66
67
68        Value* literalStartRemByteItem = b->CreateURem(literalStart, SIZE_ITEMS_PER_BYTE);
69        Value* outputPosRemByteItem = b->CreateURem(outputPos, SIZE_ITEMS_PER_BYTE);
70        Value* outputMask = this->getOutputMask(b, outputPos);
71
72
73        // ---- EntryBlock
74        BasicBlock* entryBlock = b->GetInsertBlock();
75
76        BasicBlock* literalCopyCon = b->CreateBasicBlock("literalCopyCon");
77
78        b->CreateBr(literalCopyCon);
79
80        // ---- literalCopyCon
81        b->SetInsertPoint(literalCopyCon);
82        PHINode* phiInputPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
83        phiInputPtr->addIncoming(initInputPtr, entryBlock);
84
85        PHINode* phiOutputPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
86        phiOutputPtr->addIncoming(initOutputPtr, entryBlock);
87
88        PHINode* phiCopiedLength = b->CreatePHI(b->getSizeTy(), 2);
89        phiCopiedLength->addIncoming(SIZE_0, entryBlock);
90
91        PHINode* phiOutputLastByte = b->CreatePHI(INT_FW_TY, 2);
92        phiOutputLastByte->addIncoming(initOutputLastByte, entryBlock);
93
94        BasicBlock* literalCopyBody = b->CreateBasicBlock("literalCopyBody");
95        BasicBlock* literalCopyExit = b->CreateBasicBlock("literalCopyExit");
96
97        b->CreateCondBr(b->CreateICmpULT(phiCopiedLength, literalLength), literalCopyBody, literalCopyExit);
98
99        // ---- literalCopyBody
100        b->SetInsertPoint(literalCopyBody);
101
102        Value* inputFwPtr = b->CreatePointerCast(phiInputPtr, INT_FW_PTR_TY);
103        Value* outputFwPtr = b->CreatePointerCast(phiOutputPtr, INT_FW_PTR_TY);
104
105        Value* inputTargetValue = b->CreateLoad(inputFwPtr);
106        inputTargetValue = b->CreateLShr(inputTargetValue, b->CreateMul(literalStartRemByteItem, INT_FW_TWIST_WIDTH));
107        inputTargetValue = b->CreateShl(inputTargetValue, b->CreateMul(outputPosRemByteItem, INT_FW_TWIST_WIDTH));
108
109
110        Value* newCopyLength = this->getNormalCopyLengthValue(b);
111
112        Value* outputValue = b->CreateAnd(phiOutputLastByte, outputMask);
113        outputValue = b->CreateOr(outputValue, inputTargetValue);
114        b->CreateStore(outputValue, outputFwPtr);
115
116
117        phiOutputLastByte->addIncoming(b->CreateLShr(outputValue, b->getSize(this->getNormalCopyLength() * mTwistWidth)), b->GetInsertBlock());
118        phiCopiedLength->addIncoming(b->CreateAdd(phiCopiedLength, newCopyLength), b->GetInsertBlock());
119        phiInputPtr->addIncoming(b->CreateGEP(phiInputPtr, b->CreateUDiv(newCopyLength, SIZE_ITEMS_PER_BYTE)), b->GetInsertBlock());
120        phiOutputPtr->addIncoming(b->CreateGEP(phiOutputPtr, b->CreateUDiv(newCopyLength, SIZE_ITEMS_PER_BYTE)), b->GetInsertBlock());
121
122        b->CreateBr(literalCopyCon);
123
124        // ---- literalCopyExit
125        b->SetInsertPoint(literalCopyExit);
126        b->setScalarField("outputPos", b->CreateAdd(outputPos, literalLength));
127    }
128
129
130
131    void LZ4TwistAioKernel::doShortMatchCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *matchOffset,
132                                  llvm::Value *matchLength) {
133        // Constant and Type
134        Constant* SIZE_0 = b->getSize(0);
135        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
136        Constant* INT_FW_TWIST_WIDTH = b->getIntN(COPY_FW, mTwistWidth);
137        Type* INT8_PTR_TY = b->getInt8PtrTy();
138        Type* INT_FW_TY = b->getIntNTy(COPY_FW);
139        Type* INT_FW_PTR_TY = INT_FW_TY->getPointerTo();
140
141
142
143        Value* outputCapacity = b->getCapacity("outputTwistStream");
144
145        Value* outputByteBasePtr = b->CreatePointerCast(b->getRawOutputPointer("outputTwistStream", SIZE_0), INT8_PTR_TY);
146
147        Value* outputPos = b->getScalarField("outputPos");
148        Value* outputPosRem = b->CreateURem(outputPos, outputCapacity);
149
150        // ---- EntryBlock
151        BasicBlock* entryBlock = b->GetInsertBlock();
152        BasicBlock* literalCopyCon = b->CreateBasicBlock("literalCopyCon");
153        b->CreateBr(literalCopyCon);
154
155        // ---- literalCopyCon
156        b->SetInsertPoint(literalCopyCon);
157        PHINode* phiCopiedLength = b->CreatePHI(b->getSizeTy(), 2);
158        phiCopiedLength->addIncoming(SIZE_0, entryBlock);
159
160        BasicBlock* literalCopyBody = b->CreateBasicBlock("literalCopyBody");
161        BasicBlock* literalCopyExit = b->CreateBasicBlock("literalCopyExit");
162
163        b->CreateCondBr(b->CreateICmpULT(phiCopiedLength, matchLength), literalCopyBody, literalCopyExit);
164
165        // ---- literalCopyBody
166        b->SetInsertPoint(literalCopyBody);
167        Value* outputStartRem = b->CreateAdd(outputPosRem, phiCopiedLength);
168
169        Value* outputStartRemByteItem = b->CreateURem(outputStartRem, SIZE_ITEMS_PER_BYTE);
170        Value* outputStartByteRem = b->CreateUDiv(outputStartRem, SIZE_ITEMS_PER_BYTE);
171
172        Value* outputTargetPtr = b->CreateGEP(outputByteBasePtr, outputStartByteRem);
173        outputTargetPtr = b->CreatePointerCast(outputTargetPtr, INT_FW_PTR_TY);
174
175
176        Value* copyStartRem = b->CreateSub(outputStartRem, matchOffset);
177        Value* copyStartRemByteItem = b->CreateURem(copyStartRem, SIZE_ITEMS_PER_BYTE);
178        Value* copyStartByteRem = b->CreateUDiv(copyStartRem, SIZE_ITEMS_PER_BYTE);
179
180        Value* inputTargetPtr = b->CreateGEP(outputByteBasePtr, copyStartByteRem);
181        inputTargetPtr = b->CreatePointerCast(inputTargetPtr, INT_FW_PTR_TY);
182
183        Value* inputTargetValue = b->CreateLoad(inputTargetPtr);
184        inputTargetValue = b->CreateLShr(inputTargetValue, b->CreateMul(copyStartRemByteItem, INT_FW_TWIST_WIDTH));
185
186        Value* outputValue = b->CreateLoad(outputTargetPtr);
187        Value* outputMask = this->getOutputMask(b, outputStartRemByteItem);
188        outputValue = b->CreateAnd(outputValue, outputMask);
189
190        inputTargetValue = b->CreateShl(inputTargetValue, b->CreateMul(outputStartRemByteItem, INT_FW_TWIST_WIDTH));
191        outputValue = b->CreateOr(outputValue, inputTargetValue);
192        b->CreateStore(outputValue, outputTargetPtr);
193
194        Value* newCopyLength = matchOffset;
195
196        phiCopiedLength->addIncoming(b->CreateAdd(phiCopiedLength, newCopyLength), b->GetInsertBlock());
197
198        b->CreateBr(literalCopyCon);
199
200        // ---- literalCopyExit
201        b->SetInsertPoint(literalCopyExit);
202        b->setScalarField("outputPos", b->CreateAdd(outputPos, matchLength));
203    }
204
205    void LZ4TwistAioKernel::doLongMatchCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *matchOffset,
206                                 llvm::Value *matchLength) {
207        // Constant and Type
208        Constant* SIZE_0 = b->getSize(0);
209        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
210        Constant* INT_FW_TWIST_WIDTH = b->getIntN(COPY_FW, mTwistWidth);
211        Type* INT8_PTR_TY = b->getInt8PtrTy();
212        Type* INT_FW_TY = b->getIntNTy(COPY_FW);
213        Type* INT_FW_PTR_TY = INT_FW_TY->getPointerTo();
214
215
216
217        Value* outputByteBasePtr = b->CreatePointerCast(b->getRawOutputPointer("outputTwistStream", b->getSize(0)), INT8_PTR_TY);
218        Value* outputCapacity = b->getCapacity("outputTwistStream");
219        Value* outputPos = b->getScalarField("outputPos");
220        Value* outputPosRem = b->CreateURem(outputPos, outputCapacity);
221
222        Value* outputPosRemByteItem = b->CreateURem(outputPosRem, SIZE_ITEMS_PER_BYTE);
223        Value* outputMask = this->getOutputMask(b, outputPosRem);
224
225
226        Value* outputBytePos = b->CreateUDiv(outputPosRem, SIZE_ITEMS_PER_BYTE);
227        Value* initCopyToPtr = b->CreateGEP(outputByteBasePtr, outputBytePos);
228
229        Value* initOutputLastByte = b->CreateZExt(b->CreateLoad(initCopyToPtr), INT_FW_TY);
230
231        Value* copyFromPosRem = b->CreateSub(outputPosRem, matchOffset);
232
233        Value* copyFromPosRemByteItem = b->CreateURem(copyFromPosRem, SIZE_ITEMS_PER_BYTE);
234        Value* copyFromBytePos = b->CreateUDiv(copyFromPosRem, SIZE_ITEMS_PER_BYTE);
235        Value* initCopyFromPtr = b->CreateGEP(outputByteBasePtr, copyFromBytePos);
236
237
238        Value* copyLength = this->getNormalCopyLengthValue(b);
239        Value* copyLengthByte = b->CreateUDiv(copyLength, SIZE_ITEMS_PER_BYTE);
240
241        // ---- EntryBlock
242        BasicBlock* entryBlock = b->GetInsertBlock();
243        BasicBlock* literalCopyCon = b->CreateBasicBlock("literalCopyCon");
244        b->CreateBr(literalCopyCon);
245
246        // ---- literalCopyCon
247        b->SetInsertPoint(literalCopyCon);
248        PHINode* phiCopiedLength = b->CreatePHI(b->getSizeTy(), 2);
249        phiCopiedLength->addIncoming(SIZE_0, entryBlock);
250
251        PHINode* phiCopyFromPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
252        phiCopyFromPtr->addIncoming(initCopyFromPtr, entryBlock);
253
254        PHINode* phiCopyToPtr = b->CreatePHI(b->getInt8PtrTy(), 2);
255        phiCopyToPtr->addIncoming(initCopyToPtr, entryBlock);
256
257        PHINode* phiOutputLastByte = b->CreatePHI(b->getIntNTy(COPY_FW), 2);
258        phiOutputLastByte->addIncoming(initOutputLastByte, entryBlock);
259
260
261        BasicBlock* literalCopyBody = b->CreateBasicBlock("literalCopyBody");
262        BasicBlock* literalCopyExit = b->CreateBasicBlock("literalCopyExit");
263
264        b->CreateCondBr(b->CreateICmpULT(phiCopiedLength, matchLength), literalCopyBody, literalCopyExit);
265
266        // ---- literalCopyBody
267        b->SetInsertPoint(literalCopyBody);
268        Value* outputTargetPtr = b->CreatePointerCast(phiCopyToPtr, INT_FW_PTR_TY);
269        Value* inputTargetPtr = b->CreatePointerCast(phiCopyFromPtr, INT_FW_PTR_TY);
270
271        Value* inputTargetValue = b->CreateLoad(inputTargetPtr);
272        inputTargetValue = b->CreateLShr(inputTargetValue, b->CreateMul(copyFromPosRemByteItem, INT_FW_TWIST_WIDTH));
273        inputTargetValue = b->CreateShl(inputTargetValue, b->CreateMul(outputPosRemByteItem, INT_FW_TWIST_WIDTH));
274
275        Value* outputValue = b->CreateAnd(phiOutputLastByte, outputMask);
276
277        outputValue = b->CreateOr(outputValue, inputTargetValue);
278        b->CreateStore(outputValue, outputTargetPtr);
279
280        phiCopiedLength->addIncoming(b->CreateAdd(phiCopiedLength, copyLength), b->GetInsertBlock());
281        phiCopyFromPtr->addIncoming(b->CreateGEP(phiCopyFromPtr, copyLengthByte), b->GetInsertBlock());
282        phiCopyToPtr->addIncoming(b->CreateGEP(phiCopyToPtr, copyLengthByte), b->GetInsertBlock());
283        phiOutputLastByte->addIncoming(b->CreateLShr(outputValue, b->getSize(this->getNormalCopyLength() * mTwistWidth)), b->GetInsertBlock());
284
285        b->CreateBr(literalCopyCon);
286
287        // ---- literalCopyExit
288        b->SetInsertPoint(literalCopyExit);
289
290        b->setScalarField("outputPos", b->CreateAdd(outputPos, matchLength));
291    }
292
293
294    void LZ4TwistAioKernel::doMatchCopy(const std::unique_ptr<KernelBuilder> &b, llvm::Value *matchOffset,
295                             llvm::Value *matchLength) {
296
297        BasicBlock* shortMatchCopyBlock = b->CreateBasicBlock("shortMatchCopyBlock");
298        BasicBlock* longMatchCopyBlock = b->CreateBasicBlock("longMatchCopyBlock");
299        BasicBlock* matchCopyFinishBlock = b->CreateBasicBlock("matchCopyFinishBlock");
300
301        b->CreateUnlikelyCondBr(
302                b->CreateICmpULT(matchOffset, this->getNormalCopyLengthValue(b)),
303                shortMatchCopyBlock,
304                longMatchCopyBlock
305        );
306
307        // ---- shortMatchCopyBlock
308        b->SetInsertPoint(shortMatchCopyBlock);
309        this->doShortMatchCopy(b, matchOffset, matchLength);
310        b->CreateBr(matchCopyFinishBlock);
311
312        // ---- longMatchCopyBlock
313        b->SetInsertPoint(longMatchCopyBlock);
314        this->doLongMatchCopy(b, matchOffset, matchLength);
315        b->CreateBr(matchCopyFinishBlock);
316
317        b->SetInsertPoint(matchCopyFinishBlock);
318    }
319
320    void LZ4TwistAioKernel::setProducedOutputItemCount(const std::unique_ptr<KernelBuilder> &b, llvm::Value* produced) {
321        b->setProducedItemCount("outputTwistStream", produced);
322    }
323
324
325    void LZ4TwistAioKernel::initializationMethod(const std::unique_ptr<KernelBuilder> &b) {
326        b->setScalarField("temporaryInputPtr", b->CreateMalloc(b->getSize(mBlockSize / mItemsPerByte)));
327    }
328
329    void LZ4TwistAioKernel::prepareProcessBlock(const std::unique_ptr<KernelBuilder> &b, llvm::Value* blockStart, llvm::Value* blockEnd) {
330        Constant* SIZE_0 = b->getSize(0);
331        Constant* SIZE_ITEMS_PER_BYTE = b->getSize(mItemsPerByte);
332        Type* INT8_PTR_TY = b->getInt8PtrTy();
333
334
335        Value* rawInputPtr = b->CreatePointerCast(b->getRawInputPointer("inputTwistStream", SIZE_0), INT8_PTR_TY);
336        Value* inputCapacity = b->getCapacity("inputTwistStream");
337
338        Value* inputByteCapacity = b->CreateUDiv(inputCapacity, SIZE_ITEMS_PER_BYTE);
339
340        Value* blockStartRem = b->CreateURem(blockStart, inputCapacity);
341        Value* blockStartByteRem = b->CreateUDiv(blockStartRem, SIZE_ITEMS_PER_BYTE);
342        Value* remByte = b->CreateSub(inputByteCapacity, blockStartByteRem);
343
344        Value* blockSize = b->CreateSub(blockEnd, blockStart);
345        Value* copyTotalByte = b->CreateAdd(b->CreateUDiv(blockSize, SIZE_ITEMS_PER_BYTE), SIZE_ITEMS_PER_BYTE); // It will be safe if we copy a few bytes more
346
347        Value* copyBytes1 = b->CreateUMin(remByte, copyTotalByte);
348        Value* copyBytes2 = b->CreateSub(copyTotalByte, copyBytes1);
349
350        Value* temporayInputPtr = b->getScalarField("temporaryInputPtr");
351
352        b->CreateMemCpy(temporayInputPtr, b->CreateGEP(rawInputPtr, blockStartByteRem), copyBytes1, 1);
353        b->CreateMemCpy(b->CreateGEP(temporayInputPtr, copyBytes1), rawInputPtr, copyBytes2, 1);
354    }
355
356    void LZ4TwistAioKernel::beforeTermination(const std::unique_ptr<KernelBuilder> &b) {
357        b->CreateFree(b->getScalarField("temporaryInputPtr"));
358    }
359
360    llvm::Value *LZ4TwistAioKernel::getOutputMask(const std::unique_ptr<KernelBuilder> &b, llvm::Value *outputPos) {
361        Value* remByteItems = b->CreateURem(outputPos, b->getSize(mItemsPerByte));
362        Value* INT_FW_1 = b->getIntN(COPY_FW, 1);
363        Value* shiftAmount = b->CreateMul(remByteItems, b->getIntN(COPY_FW, mTwistWidth));
364        return b->CreateSub(
365                b->CreateShl(INT_FW_1, shiftAmount),
366                INT_FW_1
367        );
368    }
369}
Note: See TracBrowser for help on using the repository browser.