Ignore:
Timestamp:
Nov 2, 2018, 7:18:31 PM (6 months ago)
Author:
nmedfort
Message:

Initial version of PipelineKernel? + revised StreamSet? model.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • icGREP/icgrep-devel/icgrep/kernels/swizzle.cpp

    r6133 r6184  
    66#include "swizzle.h"
    77#include <kernels/kernel_builder.h>
     8#include <llvm/Support/raw_ostream.h>
    89#include <string>
    910#include <vector>
     
    1213
    1314namespace kernel {
    14 SwizzleGenerator::SwizzleGenerator(const std::unique_ptr<kernel::KernelBuilder> & iBuilder, unsigned bitStreamCount, unsigned outputSets, unsigned inputSets, unsigned fieldWidth, std::string prefix)
    15 : BlockOrientedKernel(prefix + "swizzle" + std::to_string(fieldWidth) + ":" + std::to_string(bitStreamCount) + "_" + std::to_string(outputSets) + "_" + std::to_string(inputSets) , {}, {}, {}, {}, {})
    16 , mBitStreamCount(bitStreamCount)
    17 , mFieldWidth(fieldWidth)
    18 , mSwizzleFactor(iBuilder->getBitBlockWidth() / fieldWidth)
    19 , mInputSets(inputSets)
    20 , mOutputSets(outputSets) {
    21     assert((fieldWidth > 0) && ((fieldWidth & (fieldWidth - 1)) == 0) && "fieldWidth must be a power of 2");
    22     assert(fieldWidth < iBuilder->getBitBlockWidth() && "fieldWidth must be less than the block width");
    23     assert(mSwizzleFactor > 1 && "fieldWidth must be less than the block width");
    24     unsigned inputStreamsPerSet = (bitStreamCount + inputSets - 1)/inputSets;
    25     unsigned outputStreamsPerSet = (bitStreamCount + outputSets - 1)/outputSets;
    26     // Maybe the following is unnecessary.
    27     //assert(inputStreamsPerSet % swizzleFactor == 0 && "input sets must be an exact multiple of the swizzle factor");
    28     assert(outputStreamsPerSet % mSwizzleFactor == 0 && "output sets must be an exact multiple of the swizzle factor");
    29     for (unsigned i = 0; i < mInputSets; i++) {
    30         mStreamSetInputs.push_back(Binding{iBuilder->getStreamSetTy(inputStreamsPerSet, 1), "inputGroup" + std::to_string(i)});
    31     }
    32     for (unsigned i = 0; i < mOutputSets; i++) {
    33         mStreamSetOutputs.push_back(Binding{iBuilder->getStreamSetTy(outputStreamsPerSet, 1), "outputGroup" + std::to_string(i), FixedRate(1), BlockSize(fieldWidth)});
    34     }
     15
     16inline static bool is_power_2(const uint64_t n) {
     17    return ((n & (n - 1)) == 0) && n;
    3518}
    3619
    37 void SwizzleGenerator::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> & iBuilder) {
     20LLVM_READNONE inline unsigned getBitStreamCount(const std::vector<StreamSet *> & inputs) {
     21    unsigned count = 0;
     22
     23
     24    for (StreamSet * input : inputs) {
     25        count += input->getNumElements();
     26    }
     27    return count;
     28}
     29
     30
     31inline std::string makeSwizzleName(const std::vector<StreamSet *> & inputs, const std::vector<StreamSet *> & outputs, const unsigned fieldWidth) {
     32    const auto inputStreamCount = getBitStreamCount(inputs);
     33    const auto outputStreamCount = getBitStreamCount(outputs);
     34    if (LLVM_UNLIKELY(inputStreamCount != outputStreamCount)) {
     35        report_fatal_error("total number of input elements does not match the output elements");
     36    }
     37    std::string tmp;
     38    raw_string_ostream out(tmp);
     39    out << "swizzle" << fieldWidth << ':' << inputStreamCount << '_' << inputs.size() << '_' << outputs.size();
     40    out.flush();
     41    return tmp;
     42}
     43
     44inline size_t ceil_udiv(const size_t n, const size_t m) {
     45    return (n + m - 1) / m;
     46}
     47
     48inline Bindings makeSwizzledInputs(const std::vector<StreamSet *> & inputs) {
     49    Bindings bindings;
     50    const auto n = inputs.size();
     51    bindings.reserve(n);
     52    const auto numElements = inputs[0]->getNumElements();
     53    for (unsigned i = 0; i < n; ++i) {
     54        if (LLVM_UNLIKELY(inputs[i]->getNumElements() != numElements)) {
     55            report_fatal_error("not all inputs have the same number of elements");
     56        }
     57        bindings.emplace_back("inputGroup" + std::to_string(i), inputs[i]);
     58    }
     59    return bindings;
     60}
     61
     62inline Bindings makeSwizzledOutputs(const std::vector<StreamSet *> & outputs, const unsigned fieldWidth) {
     63    Bindings bindings;
     64    const auto n = outputs.size();
     65    bindings.reserve(n);
     66    const auto numElements = outputs[0]->getNumElements();
     67    for (unsigned i = 0; i < n; ++i) {
     68        if (LLVM_UNLIKELY(outputs[i]->getNumElements() != numElements)) {
     69            report_fatal_error("not all outputs have the same number of elements");
     70        }
     71        bindings.emplace_back("outputGroup" + std::to_string(i), outputs[i], FixedRate(1), BlockSize(fieldWidth));
     72    }
     73    return bindings;
     74}
     75
     76SwizzleGenerator::SwizzleGenerator(const std::unique_ptr<kernel::KernelBuilder> &,
     77                                   const std::vector<StreamSet *> & inputs,
     78                                   const std::vector<StreamSet *> & outputs,
     79                                   const unsigned fieldWidth)
     80: BlockOrientedKernel(makeSwizzleName(inputs, outputs, fieldWidth),
     81std::move(makeSwizzledInputs(inputs)),
     82std::move(makeSwizzledOutputs(outputs, fieldWidth)),
     83{}, {}, {})
     84, mBitStreamCount(getBitStreamCount(inputs))
     85, mFieldWidth(fieldWidth) {
     86
     87}
     88
     89void SwizzleGenerator::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> & b) {
    3890       
    3991    // We may need a few passes depending on the swizzle factor
    40     const unsigned swizzleFactor = mSwizzleFactor;
    41     const unsigned passes = std::log2(mSwizzleFactor);
    42     const unsigned swizzleGroups = (mBitStreamCount + mSwizzleFactor - 1)/mSwizzleFactor;
    43     const unsigned inputStreamsPerSet = (mBitStreamCount + mInputSets - 1)/mInputSets;
    44     const unsigned outputStreamsPerSet = (mBitStreamCount + mOutputSets - 1)/mOutputSets;
     92
     93    if (LLVM_UNLIKELY(!is_power_2(mFieldWidth))) {
     94        report_fatal_error("fieldWidth must be a power of 2");
     95    }
     96    if (LLVM_UNLIKELY(mFieldWidth > b->getBitBlockWidth())) {
     97        report_fatal_error("fieldWidth must be a power of 2");
     98    }
     99
     100    const auto swizzleFactor = b->getBitBlockWidth() / mFieldWidth;
     101    const auto passes = std::log2(swizzleFactor);
     102    const auto swizzleGroups = ceil_udiv(mBitStreamCount, swizzleFactor);
     103    const auto inputStreamsPerSet = ceil_udiv(mBitStreamCount, getNumOfStreamInputs());
     104    const auto outputStreamsPerSet = ceil_udiv(mBitStreamCount, getNumOfStreamOutputs());
    45105
    46106    Value * sourceBlocks[swizzleFactor];
    47107    Value * targetBlocks[swizzleFactor];
    48 
    49108    for (unsigned grp = 0; grp < swizzleGroups; grp++) {
    50109        // First load all the data.       
    51110        for (unsigned i = 0; i < swizzleFactor; i++) {
    52             unsigned streamNo = grp * swizzleFactor + i;
     111            const auto streamNo = grp * swizzleFactor + i;
    53112            if (streamNo < mBitStreamCount) {
    54                 unsigned inputSetNo = streamNo / inputStreamsPerSet;
    55                 unsigned j = streamNo % inputStreamsPerSet;
    56                 sourceBlocks[i] = iBuilder->loadInputStreamBlock("inputGroup" + std::to_string(inputSetNo), iBuilder->getInt32(j));
     113                const auto inputSetNo = streamNo / inputStreamsPerSet;
     114                const auto j = streamNo % inputStreamsPerSet;
     115                sourceBlocks[i] = b->loadInputStreamBlock("inputGroup" + std::to_string(inputSetNo), b->getInt32(j));
    57116            } else {
    58117                // Fill in the remaining logically required streams of the last swizzle group with null values.
    59                 sourceBlocks[i] = Constant::getNullValue(iBuilder->getBitBlockType());
     118                sourceBlocks[i] = Constant::getNullValue(b->getBitBlockType());
    60119            }
    61120        }
     
    63122        for (unsigned p = 0; p < passes; p++) {
    64123            for (unsigned i = 0; i < swizzleFactor / 2; i++) {
    65                 targetBlocks[i * 2] = iBuilder->esimd_mergel(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
    66                 targetBlocks[(i * 2) + 1] = iBuilder->esimd_mergeh(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
     124                targetBlocks[i * 2] = b->esimd_mergel(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
     125                targetBlocks[(i * 2) + 1] = b->esimd_mergeh(mFieldWidth, sourceBlocks[i], sourceBlocks[i + (swizzleFactor / 2)]);
    67126            }
    68127            for (unsigned i = 0; i < swizzleFactor; i++) {
     
    74133            unsigned outputSetNo = streamNo / outputStreamsPerSet;
    75134            unsigned j = streamNo % outputStreamsPerSet;
    76             iBuilder->storeOutputStreamBlock("outputGroup" + std::to_string(outputSetNo), iBuilder->getInt32(j), iBuilder->bitCast(sourceBlocks[i]));
     135            b->storeOutputStreamBlock("outputGroup" + std::to_string(outputSetNo), b->getInt32(j), b->bitCast(sourceBlocks[i]));
    77136        }
    78137    }
     
    80139
    81140
    82     SwizzleByGather::SwizzleByGather(const std::unique_ptr<KernelBuilder> &iBuilder)
    83     : BlockOrientedKernel("swizzleByGather", {}, {}, {}, {}, {}){
    84         for (unsigned i = 0; i < 2; i++) {
    85             mStreamSetInputs.push_back(Binding{iBuilder->getStreamSetTy(4, 1), "inputGroup" + std::to_string(i)});
    86         }
    87         for (unsigned i = 0; i < 1; i++) {
    88             mStreamSetOutputs.push_back(Binding{iBuilder->getStreamSetTy(8, 1), "outputGroup" + std::to_string(i), FixedRate(1)});
    89         }
     141SwizzleByGather::SwizzleByGather(const std::unique_ptr<KernelBuilder> &iBuilder)
     142: BlockOrientedKernel("swizzleByGather", {}, {}, {}, {}, {}){
     143    for (unsigned i = 0; i < 2; i++) {
     144        mInputStreamSets.push_back(Binding{iBuilder->getStreamSetTy(4, 1), "inputGroup" + std::to_string(i)});
    90145    }
     146    for (unsigned i = 0; i < 1; i++) {
     147        mOutputStreamSets.push_back(Binding{iBuilder->getStreamSetTy(8, 1), "outputGroup" + std::to_string(i), FixedRate(1)});
     148    }
     149}
    91150
    92     void SwizzleByGather::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> &b) {
    93         Value* outputStreamPtr = b->getOutputStreamBlockPtr("outputGroup0", b->getSize(0));
     151void SwizzleByGather::generateDoBlockMethod(const std::unique_ptr<kernel::KernelBuilder> &b) {
     152    Value* outputStreamPtr = b->getOutputStreamBlockPtr("outputGroup0", b->getSize(0));
    94153
    95         for (unsigned i = 0; i < 2; i++) {
    96             std::vector<llvm::Value*> inputStream;
    97             Value* inputPtr = b->getInputStreamBlockPtr("inputGroup" + std::to_string(i), b->getSize(0));
     154    for (unsigned i = 0; i < 2; i++) {
     155        std::vector<llvm::Value*> inputStream;
     156        Value* inputPtr = b->getInputStreamBlockPtr("inputGroup" + std::to_string(i), b->getSize(0));
    98157
    99             Value* inputBytePtr = b->CreatePointerCast(inputPtr, b->getInt8PtrTy());
    100             Function *gatherFunc = Intrinsic::getDeclaration(b->getModule(), Intrinsic::x86_avx2_gather_d_q_256);
    101             Value *addresses = ConstantVector::get(
    102                     {b->getInt32(0), b->getInt32(32), b->getInt32(64), b->getInt32(96)});
     158        Value* inputBytePtr = b->CreatePointerCast(inputPtr, b->getInt8PtrTy());
     159        Function *gatherFunc = Intrinsic::getDeclaration(b->getModule(), Intrinsic::x86_avx2_gather_d_q_256);
     160        Value *addresses = ConstantVector::get(
     161                {b->getInt32(0), b->getInt32(32), b->getInt32(64), b->getInt32(96)});
    103162
    104             for (unsigned j = 0; j < 4; j++) {
    105                 Value *gather_result = b->CreateCall(
    106                         gatherFunc,
    107                         {
    108                                 UndefValue::get(b->getBitBlockType()),
    109                                 inputBytePtr,
    110                                 addresses,
    111                                 Constant::getAllOnesValue(b->getBitBlockType()),
    112                                 b->getInt8(1)
    113                         }
    114                 );
     163        for (unsigned j = 0; j < 4; j++) {
     164            Value *gather_result = b->CreateCall(
     165                    gatherFunc,
     166                    {
     167                            UndefValue::get(b->getBitBlockType()),
     168                            inputBytePtr,
     169                            addresses,
     170                            Constant::getAllOnesValue(b->getBitBlockType()),
     171                            b->getInt8(1)
     172                    }
     173            );
    115174
    116                 inputBytePtr = b->CreateGEP(inputBytePtr, b->getInt32(8));
     175            inputBytePtr = b->CreateGEP(inputBytePtr, b->getInt32(8));
    117176
    118                 b->CreateStore(gather_result, outputStreamPtr);
    119                 outputStreamPtr = b->CreateGEP(outputStreamPtr, b->getSize(1));
    120             }
     177            b->CreateStore(gather_result, outputStreamPtr);
     178            outputStreamPtr = b->CreateGEP(outputStreamPtr, b->getSize(1));
    121179        }
    122180    }
    123181}
     182
     183}
Note: See TracChangeset for help on using the changeset viewer.