source: icGREP/icgrep-devel/icgrep/kernels/swizzle.cpp @ 6133

Last change on this file since 6133 was 6133, checked in by xwa163, 9 months ago
  1. Add sourceCC in multiplexed CC
  2. Remove workaround FakeBasisBits from ICGrep
  3. Implement Swizzled version of LZParabix
  4. Initial check-in for SwizzleByGather kernel
File size: 6.2 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "swizzle.h"
7#include <kernels/kernel_builder.h>
8#include <string>
9#include <vector>
10
11using namespace llvm;
12
13namespace kernel {
14SwizzleGenerator::SwizzleGenerator(const std::unique_ptr<kernel::KernelBuilder> & iBuilder, unsigned bitStreamCount, unsigned outputSets, unsigned inputSets, unsigned fieldWidth, std::string prefix)
15: BlockOrientedKernel(prefix + "swizzle" + std::to_string(fieldWidth) + ":" + std::to_string(bitStreamCount) + "_" + std::to_string(outputSets) + "_" + std::to_string(inputSets) , {}, {}, {}, {}, {})
16, mBitStreamCount(bitStreamCount)
17, mFieldWidth(fieldWidth)
18, mSwizzleFactor(iBuilder->getBitBlockWidth() / fieldWidth)
19, mInputSets(inputSets)
20, mOutputSets(outputSets) {
21    assert((fieldWidth > 0) && ((fieldWidth & (fieldWidth - 1)) == 0) && "fieldWidth must be a power of 2");
22    assert(fieldWidth < iBuilder->getBitBlockWidth() && "fieldWidth must be less than the block width");
23    assert(mSwizzleFactor > 1 && "fieldWidth must be less than the block width");
24    unsigned inputStreamsPerSet = (bitStreamCount + inputSets - 1)/inputSets;
25    unsigned outputStreamsPerSet = (bitStreamCount + outputSets - 1)/outputSets;
26    // Maybe the following is unnecessary.
27    //assert(inputStreamsPerSet % swizzleFactor == 0 && "input sets must be an exact multiple of the swizzle factor");
28    assert(outputStreamsPerSet % mSwizzleFactor == 0 && "output sets must be an exact multiple of the swizzle factor");
29    for (unsigned i = 0; i < mInputSets; i++) {
30        mStreamSetInputs.push_back(Binding{iBuilder->getStreamSetTy(inputStreamsPerSet, 1), "inputGroup" + std::to_string(i)});
31    }
32    for (unsigned i = 0; i < mOutputSets; i++) {
33        mStreamSetOutputs.push_back(Binding{iBuilder->getStreamSetTy(outputStreamsPerSet, 1), "outputGroup" + std::to_string(i), FixedRate(1), BlockSize(fieldWidth)});
34    }
35}
36
37void SwizzleGenerator::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> & iBuilder) {
38       
39    // We may need a few passes depending on the swizzle factor
40    const unsigned swizzleFactor = mSwizzleFactor;
41    const unsigned passes = std::log2(mSwizzleFactor);
42    const unsigned swizzleGroups = (mBitStreamCount + mSwizzleFactor - 1)/mSwizzleFactor;
43    const unsigned inputStreamsPerSet = (mBitStreamCount + mInputSets - 1)/mInputSets;
44    const unsigned outputStreamsPerSet = (mBitStreamCount + mOutputSets - 1)/mOutputSets;
45
46    Value * sourceBlocks[swizzleFactor];
47    Value * targetBlocks[swizzleFactor];
48
49    for (unsigned grp = 0; grp < swizzleGroups; grp++) {
50        // First load all the data.       
51        for (unsigned i = 0; i < swizzleFactor; i++) {
52            unsigned streamNo = grp * swizzleFactor + i;
53            if (streamNo < mBitStreamCount) {
54                unsigned inputSetNo = streamNo / inputStreamsPerSet;
55                unsigned j = streamNo % inputStreamsPerSet;
56                sourceBlocks[i] = iBuilder->loadInputStreamBlock("inputGroup" + std::to_string(inputSetNo), iBuilder->getInt32(j));
57            } else {
58                // Fill in the remaining logically required streams of the last swizzle group with null values.
59                sourceBlocks[i] = Constant::getNullValue(iBuilder->getBitBlockType());
60            }
61        }
62        // Now perform the swizzle passes.
63        for (unsigned p = 0; p < passes; p++) {
64            for (unsigned i = 0; i < swizzleFactor / 2; i++) {
65                targetBlocks[i * 2] = iBuilder->esimd_mergel(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
66                targetBlocks[(i * 2) + 1] = iBuilder->esimd_mergeh(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
67            }
68            for (unsigned i = 0; i < swizzleFactor; i++) {
69                sourceBlocks[i] = targetBlocks[i];
70            }
71        }
72        for (unsigned i = 0; i < swizzleFactor; i++) {
73            unsigned streamNo = grp * swizzleFactor + i;
74            unsigned outputSetNo = streamNo / outputStreamsPerSet;
75            unsigned j = streamNo % outputStreamsPerSet;
76            iBuilder->storeOutputStreamBlock("outputGroup" + std::to_string(outputSetNo), iBuilder->getInt32(j), iBuilder->bitCast(sourceBlocks[i]));
77        }
78    }
79}
80
81
82    SwizzleByGather::SwizzleByGather(const std::unique_ptr<KernelBuilder> &iBuilder)
83    : BlockOrientedKernel("swizzleByGather", {}, {}, {}, {}, {}){
84        for (unsigned i = 0; i < 2; i++) {
85            mStreamSetInputs.push_back(Binding{iBuilder->getStreamSetTy(4, 1), "inputGroup" + std::to_string(i)});
86        }
87        for (unsigned i = 0; i < 1; i++) {
88            mStreamSetOutputs.push_back(Binding{iBuilder->getStreamSetTy(8, 1), "outputGroup" + std::to_string(i), FixedRate(1)});
89        }
90    }
91
92    void SwizzleByGather::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> &b) {
93        Value* outputStreamPtr = b->getOutputStreamBlockPtr("outputGroup0", b->getSize(0));
94
95        for (unsigned i = 0; i < 2; i++) {
96            std::vector<llvm::Value*> inputStream;
97            Value* inputPtr = b->getInputStreamBlockPtr("inputGroup" + std::to_string(i), b->getSize(0));
98
99            Value* inputBytePtr = b->CreatePointerCast(inputPtr, b->getInt8PtrTy());
100            Function *gatherFunc = Intrinsic::getDeclaration(b->getModule(), Intrinsic::x86_avx2_gather_d_q_256);
101            Value *addresses = ConstantVector::get(
102                    {b->getInt32(0), b->getInt32(32), b->getInt32(64), b->getInt32(96)});
103
104            for (unsigned j = 0; j < 4; j++) {
105                Value *gather_result = b->CreateCall(
106                        gatherFunc,
107                        {
108                                UndefValue::get(b->getBitBlockType()),
109                                inputBytePtr,
110                                addresses,
111                                Constant::getAllOnesValue(b->getBitBlockType()),
112                                b->getInt8(1)
113                        }
114                );
115
116                inputBytePtr = b->CreateGEP(inputBytePtr, b->getInt32(8));
117
118                b->CreateStore(gather_result, outputStreamPtr);
119                outputStreamPtr = b->CreateGEP(outputStreamPtr, b->getSize(1));
120            }
121        }
122    }
123}
Note: See TracBrowser for help on using the repository browser.