source: icGREP/icgrep-devel/icgrep/kernels/swizzle.cpp @ 6161

Last change on this file since 6161 was 6133, checked in by xwa163, 15 months ago
  1. Add sourceCC in multiplexed CC
  2. Remove workaround FakeBasisBits from ICGrep
  3. Implement Swizzled version of LZParabix
  4. Initial check-in of SwizzleByGather kernel
File size: 6.2 KB
Line 
1/*
2 *  Copyright (c) 2017 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 */
5
6#include "swizzle.h"
7#include <kernels/kernel_builder.h>
8#include <string>
9#include <vector>
10
11using namespace llvm;
12
13namespace kernel {
14SwizzleGenerator::SwizzleGenerator(const std::unique_ptr<kernel::KernelBuilder> & iBuilder, unsigned bitStreamCount, unsigned outputSets, unsigned inputSets, unsigned fieldWidth, std::string prefix)
15: BlockOrientedKernel(prefix + "swizzle" + std::to_string(fieldWidth) + ":" + std::to_string(bitStreamCount) + "_" + std::to_string(outputSets) + "_" + std::to_string(inputSets) , {}, {}, {}, {}, {})
16, mBitStreamCount(bitStreamCount)
17, mFieldWidth(fieldWidth)
18, mSwizzleFactor(iBuilder->getBitBlockWidth() / fieldWidth)
19, mInputSets(inputSets)
20, mOutputSets(outputSets) {
21    assert((fieldWidth > 0) && ((fieldWidth & (fieldWidth - 1)) == 0) && "fieldWidth must be a power of 2");
22    assert(fieldWidth < iBuilder->getBitBlockWidth() && "fieldWidth must be less than the block width");
23    assert(mSwizzleFactor > 1 && "fieldWidth must be less than the block width");
24    unsigned inputStreamsPerSet = (bitStreamCount + inputSets - 1)/inputSets;
25    unsigned outputStreamsPerSet = (bitStreamCount + outputSets - 1)/outputSets;
26    // Maybe the following is unnecessary.
27    //assert(inputStreamsPerSet % swizzleFactor == 0 && "input sets must be an exact multiple of the swizzle factor");
28    assert(outputStreamsPerSet % mSwizzleFactor == 0 && "output sets must be an exact multiple of the swizzle factor");
29    for (unsigned i = 0; i < mInputSets; i++) {
30        mStreamSetInputs.push_back(Binding{iBuilder->getStreamSetTy(inputStreamsPerSet, 1), "inputGroup" + std::to_string(i)});
31    }
32    for (unsigned i = 0; i < mOutputSets; i++) {
33        mStreamSetOutputs.push_back(Binding{iBuilder->getStreamSetTy(outputStreamsPerSet, 1), "outputGroup" + std::to_string(i), FixedRate(1), BlockSize(fieldWidth)});
34    }
35}
36
37void SwizzleGenerator::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> & iBuilder) {
38       
39    // We may need a few passes depending on the swizzle factor
40    const unsigned swizzleFactor = mSwizzleFactor;
41    const unsigned passes = std::log2(mSwizzleFactor);
42    const unsigned swizzleGroups = (mBitStreamCount + mSwizzleFactor - 1)/mSwizzleFactor;
43    const unsigned inputStreamsPerSet = (mBitStreamCount + mInputSets - 1)/mInputSets;
44    const unsigned outputStreamsPerSet = (mBitStreamCount + mOutputSets - 1)/mOutputSets;
45
46    Value * sourceBlocks[swizzleFactor];
47    Value * targetBlocks[swizzleFactor];
48
49    for (unsigned grp = 0; grp < swizzleGroups; grp++) {
50        // First load all the data.       
51        for (unsigned i = 0; i < swizzleFactor; i++) {
52            unsigned streamNo = grp * swizzleFactor + i;
53            if (streamNo < mBitStreamCount) {
54                unsigned inputSetNo = streamNo / inputStreamsPerSet;
55                unsigned j = streamNo % inputStreamsPerSet;
56                sourceBlocks[i] = iBuilder->loadInputStreamBlock("inputGroup" + std::to_string(inputSetNo), iBuilder->getInt32(j));
57            } else {
58                // Fill in the remaining logically required streams of the last swizzle group with null values.
59                sourceBlocks[i] = Constant::getNullValue(iBuilder->getBitBlockType());
60            }
61        }
62        // Now perform the swizzle passes.
63        for (unsigned p = 0; p < passes; p++) {
64            for (unsigned i = 0; i < swizzleFactor / 2; i++) {
65                targetBlocks[i * 2] = iBuilder->esimd_mergel(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
66                targetBlocks[(i * 2) + 1] = iBuilder->esimd_mergeh(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
67            }
68            for (unsigned i = 0; i < swizzleFactor; i++) {
69                sourceBlocks[i] = targetBlocks[i];
70            }
71        }
72        for (unsigned i = 0; i < swizzleFactor; i++) {
73            unsigned streamNo = grp * swizzleFactor + i;
74            unsigned outputSetNo = streamNo / outputStreamsPerSet;
75            unsigned j = streamNo % outputStreamsPerSet;
76            iBuilder->storeOutputStreamBlock("outputGroup" + std::to_string(outputSetNo), iBuilder->getInt32(j), iBuilder->bitCast(sourceBlocks[i]));
77        }
78    }
79}
80
81
82    SwizzleByGather::SwizzleByGather(const std::unique_ptr<KernelBuilder> &iBuilder)
83    : BlockOrientedKernel("swizzleByGather", {}, {}, {}, {}, {}){
84        for (unsigned i = 0; i < 2; i++) {
85            mStreamSetInputs.push_back(Binding{iBuilder->getStreamSetTy(4, 1), "inputGroup" + std::to_string(i)});
86        }
87        for (unsigned i = 0; i < 1; i++) {
88            mStreamSetOutputs.push_back(Binding{iBuilder->getStreamSetTy(8, 1), "outputGroup" + std::to_string(i), FixedRate(1)});
89        }
90    }
91
92    void SwizzleByGather::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> &b) {
93        Value* outputStreamPtr = b->getOutputStreamBlockPtr("outputGroup0", b->getSize(0));
94
95        for (unsigned i = 0; i < 2; i++) {
96            std::vector<llvm::Value*> inputStream;
97            Value* inputPtr = b->getInputStreamBlockPtr("inputGroup" + std::to_string(i), b->getSize(0));
98
99            Value* inputBytePtr = b->CreatePointerCast(inputPtr, b->getInt8PtrTy());
100            Function *gatherFunc = Intrinsic::getDeclaration(b->getModule(), Intrinsic::x86_avx2_gather_d_q_256);
101            Value *addresses = ConstantVector::get(
102                    {b->getInt32(0), b->getInt32(32), b->getInt32(64), b->getInt32(96)});
103
104            for (unsigned j = 0; j < 4; j++) {
105                Value *gather_result = b->CreateCall(
106                        gatherFunc,
107                        {
108                                UndefValue::get(b->getBitBlockType()),
109                                inputBytePtr,
110                                addresses,
111                                Constant::getAllOnesValue(b->getBitBlockType()),
112                                b->getInt8(1)
113                        }
114                );
115
116                inputBytePtr = b->CreateGEP(inputBytePtr, b->getInt32(8));
117
118                b->CreateStore(gather_result, outputStreamPtr);
119                outputStreamPtr = b->CreateGEP(outputStreamPtr, b->getSize(1));
120            }
121        }
122    }
123}
Note: See TracBrowser for help on using the repository browser.