source: icGREP/icgrep-devel/icgrep/kernels/zeroextend.cpp @ 6297

Last change on this file since 6297 was 6288, checked in by cameron, 7 months ago

Repeat of prior check in

File size: 6.5 KB
Line 
1#include "zeroextend.h"
2#include <kernels/kernel_builder.h>
3#include <kernels/streamset.h>
4#include <llvm/Support/raw_ostream.h>
5
6using namespace llvm;
7
8namespace kernel {
9
10inline bool notProperFactorOf(const unsigned n, const unsigned m) {
11    return ((n % m) != n) || (n == 1) || (n >= m);
12}
13
14inline static bool is_power_2(const unsigned n) {
15    return ((n & (n - 1)) == 0) && n;
16}
17
18void ZeroExtend::generateMultiBlockLogic(const std::unique_ptr<KernelBuilder> & b, Value * const numOfStrides) {
19
20    const Binding & input = getInputStreamSetBinding(0);
21    const auto inputFieldWidth = input.getFieldWidth();
22
23    // TODO: support for 1,2,4 bit field widths will require specialized logic.
24
25    // TODO: we cannot assume aligned I/O streams when handling more than one stream in a set
26
27    if (LLVM_UNLIKELY(inputFieldWidth < 8 || !is_power_2(inputFieldWidth))) {
28        report_fatal_error("ZeroExtend: input field width "
29                           "must be a power of 2 greater than 4");
30    }
31
32    const Binding & output = getOutputStreamSetBinding(0);
33    const auto outputFieldWidth = output.getFieldWidth();
34
35    if (LLVM_UNLIKELY(notProperFactorOf(inputFieldWidth, outputFieldWidth))) {
36        report_fatal_error("ZeroExtend: input field width "
37                           "must be a proper factor of "
38                           "output field width");
39    }
40
41    const auto blockWidth = b->getBitBlockWidth();
42
43    if (LLVM_UNLIKELY(notProperFactorOf(outputFieldWidth, blockWidth))) {
44        report_fatal_error("ZeroExtend: output field width "
45                           "must be a proper factor of "
46                           "block width");
47    }
48
49    if (LLVM_UNLIKELY(input.getNumElements() != 1 || output.getNumElements() != 1)) {
50        report_fatal_error("ZeroExtend: currently only supports "
51                           "single stream I/O");
52    }
53
54    const auto inputVectorSize = (blockWidth / inputFieldWidth); assert (is_power_2(inputVectorSize));
55    const auto outputVectorSize = (blockWidth / outputFieldWidth); assert (is_power_2(outputVectorSize));
56
57    IntegerType * const sizeTy = b->getSizeTy();
58
59    Value * const ZERO = b->getSize(0);
60
61    VectorType * const inputTy = VectorType::get(b->getIntNTy(inputFieldWidth), inputVectorSize);
62    PointerType * const inputPtrTy = inputTy->getPointerTo();
63
64    VectorType * const outputTy = VectorType::get(b->getIntNTy(outputFieldWidth), outputVectorSize);
65    PointerType * const outputPtrTy = outputTy->getPointerTo();
66
67    Value * const processed = b->getProcessedItemCount(input.getName());
68    Value * const baseInputPtr = b->CreatePointerCast(b->getRawInputPointer(input.getName(), processed), inputPtrTy);
69
70    Value * const produced = b->getProducedItemCount(output.getName());
71    Value * const baseOutputPtr = b->CreatePointerCast(b->getRawOutputPointer(output.getName(), produced), outputPtrTy);
72
73    BasicBlock * const entry = b->GetInsertBlock();
74    BasicBlock * const loop = b->CreateBasicBlock("Loop");
75    b->CreateBr(loop);
76
77    b->SetInsertPoint(loop);
78    PHINode * const index = b->CreatePHI(sizeTy, 2);
79    index->addIncoming(ZERO, entry);
80
81    std::vector<Value *> inputBuffer(inputFieldWidth);
82    // read the values from the input stream
83    Value * const baseInputOffset = b->CreateMul(index, b->getSize(inputFieldWidth));
84    for (unsigned i = 0; i < inputFieldWidth; ++i) {
85        Value * const offset = b->CreateAdd(baseInputOffset, b->getSize(i));
86        Value * const ptr = b->CreateGEP(baseInputPtr, offset);
87        inputBuffer[i] = b->CreateAlignedLoad(ptr, (inputFieldWidth / CHAR_BIT));
88    }
89
90    std::vector<Value *> outputBuffer(inputFieldWidth * 2);
91
92    std::vector<Constant *> lowerHalf(inputVectorSize);
93    std::vector<Constant *> upperHalf(inputVectorSize);
94
95    // expand by doubling repeatidly until we've reached the desired output size
96    for (;;) {
97
98        VectorType * const inputTy = cast<VectorType>(inputBuffer[0]->getType());
99
100        const auto n = inputTy->getVectorElementType()->getIntegerBitWidth();
101        const auto count = blockWidth / n;
102
103        const auto halfCount = (count / 2);
104
105        for (unsigned i = 0; i < halfCount; ++i) {
106            lowerHalf[i * 2] = b->getInt32(i);
107            lowerHalf[(i * 2) + 1] = b->getInt32(count + i);
108        }
109        Constant * const LOWER_MASK = ConstantVector::get(lowerHalf);
110
111        for (unsigned i = 0; i < halfCount; ++i) {
112            upperHalf[i * 2] = b->getInt32(halfCount + i);
113            upperHalf[(i * 2) + 1] = b->getInt32(count + halfCount + i);
114        }
115        Constant * const UPPER_MASK = ConstantVector::get(upperHalf);
116
117        VectorType * const outputTy = VectorType::get(b->getIntNTy(n * 2), halfCount);
118
119        Constant * const ZEROES = ConstantVector::getNullValue(inputTy);
120        for (unsigned i = 0; i < inputBuffer.size(); ++i) {
121            Value * const lower = b->CreateShuffleVector(ZEROES, inputBuffer[i], LOWER_MASK);
122            outputBuffer[i * 2] = b->CreateBitCast(lower, outputTy);
123            Value * const upper = b->CreateShuffleVector(ZEROES, inputBuffer[i], UPPER_MASK);
124            outputBuffer[(i * 2) + 1] = b->CreateBitCast(upper, outputTy);
125        }
126
127        if (LLVM_LIKELY(outputBuffer.size() == outputFieldWidth)) {
128            break;
129        }
130
131        inputBuffer.swap(outputBuffer);
132        outputBuffer.resize(inputBuffer.size() * 2);
133        lowerHalf.resize(halfCount);
134        upperHalf.resize(halfCount);
135    }
136
137    // write the values to the output stream
138    Value * const baseOutputOffset = b->CreateMul(index, b->getSize(outputFieldWidth));
139    for (unsigned i = 0; i < outputFieldWidth; ++i) {
140        Value * const offset = b->CreateAdd(baseOutputOffset, b->getSize(i));
141        Value * const ptr = b->CreateGEP(baseOutputPtr, offset);
142        b->CreateAlignedStore(outputBuffer[i], ptr, (outputFieldWidth / CHAR_BIT));
143    }
144
145    // loop until done
146    BasicBlock * const exit = b->CreateBasicBlock("exit");
147    Value * const nextIndex = b->CreateAdd(index, b->getSize(1));
148    Value * const notDone = b->CreateICmpNE(nextIndex, numOfStrides);
149    index->addIncoming(nextIndex, b->GetInsertBlock());
150    b->CreateLikelyCondBr(notDone, loop, exit);
151    b->SetInsertPoint(exit);
152}
153
154ZeroExtend::ZeroExtend(const std::unique_ptr<kernel::KernelBuilder> & b,
155                       StreamSet * const input, StreamSet * const output)
156: MultiBlockKernel(b, "zeroextend" + std::to_string(input->getFieldWidth()) + "x" + std::to_string(output->getFieldWidth()),
157{Binding{"input", input}},
158{Binding{"output", output}},
159{}, {}, {}) {
160    assert (input->getNumElements() == output->getNumElements());
161}
162
163}
Note: See TracBrowser for help on using the repository browser.