source: icGREP/icgrep-devel/icgrep/IR_Gen/idisa_avx_builder.cpp @ 5832

Last change on this file since 5832 was 5832, checked in by nmedfort, 22 months ago

Bug fix for UntilN

File size: 12.4 KB
Line 
1/*
2 *  Copyright (c) 2015 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "idisa_avx_builder.h"
8
9using namespace llvm;
10
11namespace IDISA {
12   
13std::string IDISA_AVX_Builder::getBuilderUniqueName() {
14    return mBitBlockWidth != 256 ? "AVX_" + std::to_string(mBitBlockWidth) : "AVX";
15}
16
17Value * IDISA_AVX_Builder::hsimd_signmask(unsigned fw, Value * a) {
18    // AVX2 special cases
19    if (mBitBlockWidth == 256) {
20        if (fw == 64) {
21            Value * signmask_f64func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_pd_256);
22            Type * bitBlock_f64type = VectorType::get(getDoubleTy(), mBitBlockWidth/64);
23            Value * a_as_pd = CreateBitCast(a, bitBlock_f64type);
24            return CreateCall(signmask_f64func, a_as_pd);
25        } else if (fw == 32) {
26            Value * signmask_f32func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_ps_256);
27            Type * bitBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth/32);
28            Value * a_as_ps = CreateBitCast(a, bitBlock_f32type);
29            return CreateCall(signmask_f32func, a_as_ps);
30        } else if (fw == 8) {
31            Value * signmask_f8func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_pmovmskb);
32            Type * bitBlock_i8type = VectorType::get(getInt8Ty(), mBitBlockWidth/8);
33            Value * a_as_ps = CreateBitCast(a, bitBlock_i8type);
34            return CreateCall(signmask_f8func, a_as_ps);
35        }
36    } else if (mBitBlockWidth == 512) {
37        if (fw == 64) {
38            Type * bitBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth / 32);
39            Value * a_as_ps = CreateBitCast(a, bitBlock_f32type);
40            Constant * indicies[8];
41            for (unsigned i = 0; i < 8; i++) {
42                indicies[i] = getInt32(2 * i + 1);
43            }
44            Value * packh = CreateShuffleVector(a_as_ps, UndefValue::get(bitBlock_f32type), ConstantVector::get({indicies, 8}));
45            Type * halfBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth/64);
46            Value * pack_as_ps = CreateBitCast(packh, halfBlock_f32type);
47            Value * signmask_f32func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_ps_256);
48            return CreateCall(signmask_f32func, pack_as_ps);
49        }
50    }
51    // Otherwise use default SSE logic.
52    return IDISA_SSE_Builder::hsimd_signmask(fw, a);
53}
54
55std::string IDISA_AVX2_Builder::getBuilderUniqueName() {
56    return mBitBlockWidth != 256 ? "AVX2_" + std::to_string(mBitBlockWidth) : "AVX2";
57}
58   
59Value * IDISA_AVX2_Builder::hsimd_packh(unsigned fw, Value * a, Value * b) {
60    if (fw <= 64) {       
61        Value * aVec = fwCast(fw / 2, a);
62        Value * bVec = fwCast(fw / 2, b);
63        const auto field_count = 2 * mBitBlockWidth / fw;
64        Constant * Idxs[field_count];
65        const auto H = (field_count / 2);
66        const auto Q = (field_count / 4);
67        for (unsigned i = 0; i < Q; i++) {
68            Idxs[i] = getInt32(2 * i);
69            Idxs[i + Q] = getInt32((2 * i) + 1);
70            Idxs[i + H] = getInt32((2 * i) + H);
71            Idxs[i + H + Q] = getInt32((2 * i) + 1 + H);
72        }
73        Value * shufa = CreateShuffleVector(aVec, aVec, ConstantVector::get({Idxs, field_count}));
74        Value * shufb = CreateShuffleVector(bVec, bVec, ConstantVector::get({Idxs, field_count}));
75        return hsimd_packh(mBitBlockWidth / 2, shufa, shufb);
76    }
77    // Otherwise use default SSE logic.
78    return IDISA_SSE_Builder::hsimd_packh(fw, a, b);
79}
80
81Value * IDISA_AVX2_Builder::hsimd_packl(unsigned fw, Value * a, Value * b) {
82    if (fw <= 64) {
83        Value * aVec = fwCast(fw / 2, a);
84        Value * bVec = fwCast(fw / 2, b);
85        const auto field_count = 2 * mBitBlockWidth / fw;
86        Constant * Idxs[field_count];
87        const auto H = (field_count / 2);
88        const auto Q = (field_count / 4);
89        for (unsigned i = 0; i < Q; i++) {
90            Idxs[i] = getInt32(2 * i);
91            Idxs[i + Q] = getInt32((2 * i) + 1);
92            Idxs[i + H] = getInt32((2 * i) + H);
93            Idxs[i + H + Q] = getInt32((2 * i) + H + 1);
94        }
95        Value * shufa = CreateShuffleVector(aVec, aVec, ConstantVector::get({Idxs, field_count}));
96        Value * shufb = CreateShuffleVector(bVec, bVec, ConstantVector::get({Idxs, field_count}));
97        return hsimd_packl(mBitBlockWidth / 2, shufa, shufb);
98    }
99    // Otherwise use default SSE logic.
100    return IDISA_SSE_Builder::hsimd_packl(fw, a, b);
101}
102   
103Value * IDISA_AVX2_Builder::esimd_mergeh(unsigned fw, Value * a, Value * b) {
104#if LLVM_VERSION_INTEGER < LLVM_6_0_0
105    if ((fw == 128) && (mBitBlockWidth == 256)) {
106        Value * vperm2i128func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_vperm2i128);
107        return CreateCall(vperm2i128func, {fwCast(64, a), fwCast(64, b), getInt8(0x31)});
108    }
109#endif
110    // Otherwise use default SSE logic.
111    return IDISA_SSE_Builder::esimd_mergeh(fw, a, b);
112}
113
114Value * IDISA_AVX2_Builder::esimd_mergel(unsigned fw, Value * a, Value * b) {
115#if LLVM_VERSION_INTEGER < LLVM_6_0_0
116    if ((fw == 128) && (mBitBlockWidth == 256)) {
117        Value * vperm2i128func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_vperm2i128);
118        return CreateCall(vperm2i128func, {fwCast(64, a), fwCast(64, b), getInt8(0x20)});
119    }
120#endif
121    // Otherwise use default SSE logic.
122    return IDISA_SSE_Builder::esimd_mergel(fw, a, b);
123}
124
125Value * IDISA_AVX2_Builder::hsimd_packl_in_lanes(unsigned lanes, unsigned fw, Value * a, Value * b) {
126    if ((fw == 16)  && (lanes == 2)) {
127        Value * vpackuswbfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_packuswb);
128        Value * a_low = fwCast(16, simd_and(a, simd_lomask(fw)));
129        Value * b_low = fwCast(16, simd_and(b, simd_lomask(fw)));
130        return CreateCall(vpackuswbfunc, {a_low, b_low});
131    }
132    // Otherwise use default SSE logic.
133    return IDISA_SSE_Builder::hsimd_packl_in_lanes(lanes, fw, a, b);
134}
135
136Value * IDISA_AVX2_Builder::hsimd_packh_in_lanes(unsigned lanes, unsigned fw, Value * a, Value * b) {
137    if ((fw == 16)  && (lanes == 2)) {
138        Value * vpackuswbfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_packuswb);
139        Value * a_low = simd_srli(fw, a, fw/2);
140        Value * b_low = simd_srli(fw, b, fw/2);
141        return CreateCall(vpackuswbfunc, {a_low, b_low});
142    }
143    // Otherwise use default SSE logic.
144    return IDISA_SSE_Builder::hsimd_packh_in_lanes(lanes, fw, a, b);
145}
146   
147std::pair<Value *, Value *> IDISA_AVX2_Builder::bitblock_add_with_carry(Value * e1, Value * e2, Value * carryin) {
148    // using LONG_ADD
149    Type * carryTy = carryin->getType();
150    if (carryTy == mBitBlockType) {
151        carryin = mvmd_extract(32, carryin, 0);
152    }
153    Value * carrygen = simd_and(e1, e2);
154    Value * carryprop = simd_or(e1, e2);
155    Value * digitsum = simd_add(64, e1, e2);
156    Value * digitcarry = simd_or(carrygen, simd_and(carryprop, CreateNot(digitsum)));
157    Value * carryMask = hsimd_signmask(64, digitcarry);
158    Value * carryMask2 = CreateOr(CreateAdd(carryMask, carryMask), carryin);
159    Value * bubble = simd_eq(64, digitsum, allOnes());
160    Value * bubbleMask = hsimd_signmask(64, bubble);
161    Value * incrementMask = CreateXor(CreateAdd(bubbleMask, carryMask2), bubbleMask);
162    Value * increments = esimd_bitspread(64,incrementMask);
163    Value * sum = simd_add(64, digitsum, increments);
164    Value * carry_out = CreateLShr(incrementMask, mBitBlockWidth / 64);
165    if (carryTy == mBitBlockType) {
166        carry_out = bitCast(CreateZExt(carry_out, getIntNTy(mBitBlockWidth)));
167    }
168    return std::pair<Value *, Value *>{carry_out, bitCast(sum)};
169}
170   
171std::pair<Value *, Value *> IDISA_AVX2_Builder::bitblock_indexed_advance(Value * strm, Value * index_strm, Value * shiftIn, unsigned shiftAmount) {
172    Value * const popcount = Intrinsic::getDeclaration(getModule(), Intrinsic::ctpop, getSizeTy());
173    Value * PEXT_f = nullptr;
174    Value * PDEP_f = nullptr;
175    const unsigned bitWidth = getSizeTy()->getBitWidth();
176    if (bitWidth == 64) {
177        PEXT_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pext_64);
178        PDEP_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pdep_64);
179    } else if ((bitWidth == 32)  && (shiftAmount < 32)) {
180        PEXT_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pext_32);
181        PDEP_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pdep_32);
182    } else {
183        llvm::report_fatal_error("indexed_advance unsupported bit width");
184    }
185    Type * iBitBlock = getIntNTy(getBitBlockWidth());
186    Value * shiftVal = getSize(shiftAmount);
187    const auto n = getBitBlockWidth() / bitWidth;
188    VectorType * const vecTy = VectorType::get(getSizeTy(), n);
189    if (LLVM_LIKELY(shiftAmount < bitWidth)) {
190        Value * carry = mvmd_extract(bitWidth, shiftIn, 0);
191        Value * result = UndefValue::get(vecTy);
192        for (unsigned i = 0; i < n; i++) {
193            Value * s = mvmd_extract(bitWidth, strm, i);
194            Value * ix = mvmd_extract(bitWidth, index_strm, i);
195            Value * ix_popcnt = CreateCall(popcount, {ix});
196            Value * bits = CreateCall(PEXT_f, {s, ix});
197            Value * adv = CreateOr(CreateShl(bits, shiftAmount), carry);
198            // We have two cases depending on whether the popcount of the index pack is < shiftAmount or not.
199            Value * popcount_small = CreateICmpULT(ix_popcnt, shiftVal);
200            Value * carry_if_popcount_small =
201                CreateOr(CreateShl(bits, CreateSub(shiftVal, ix_popcnt)),
202                            CreateLShr(carry, ix_popcnt));
203            Value * carry_if_popcount_large = CreateLShr(bits, CreateSub(ix_popcnt, shiftVal));
204            carry = CreateSelect(popcount_small, carry_if_popcount_small, carry_if_popcount_large);
205            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {adv, ix}), i);
206        }
207        Value * carryOut = mvmd_insert(bitWidth, allZeroes(), carry, 0);
208        return std::pair<Value *, Value *>{bitCast(carryOut), bitCast(result)};
209    }
210    else if (shiftAmount <= mBitBlockWidth) {
211        // The shift amount is always greater than the popcount of the individual
212        // elements that we deal with.   This simplifies some of the logic.
213        Value * carry = CreateBitCast(shiftIn, iBitBlock);
214        Value * result = UndefValue::get(vecTy);
215        for (unsigned i = 0; i < n; i++) {
216            Value * s = mvmd_extract(bitWidth, strm, i);
217            Value * ix = mvmd_extract(bitWidth, index_strm, i);
218            Value * ix_popcnt = CreateCall(popcount, {ix});
219            Value * bits = CreateCall(PEXT_f, {s, ix});  // All these bits are shifted out (appended to carry).
220            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {mvmd_extract(bitWidth, carry, 0), ix}), i);
221            carry = CreateLShr(carry, CreateZExt(ix_popcnt, iBitBlock)); // Remove the carry bits consumed, make room for new bits.
222            carry = CreateOr(carry, CreateShl(CreateZExt(bits, iBitBlock), CreateZExt(CreateSub(shiftVal, ix_popcnt), iBitBlock)));
223        }
224        return std::pair<Value *, Value *>{bitCast(carry), bitCast(result)};
225    }
226    else {
227        // The shift amount is greater than the total popcount.   We will consume popcount
228        // bits from the shiftIn value only, and produce a carry out value of the selected bits.
229        // elements that we deal with.   This simplifies some of the logic.
230        Value * carry = CreateBitCast(shiftIn, iBitBlock);
231        Value * result = UndefValue::get(vecTy);
232        Value * carryOut = ConstantInt::getNullValue(iBitBlock);
233        Value * generated = getSize(0);
234        for (unsigned i = 0; i < n; i++) {
235            Value * s = mvmd_extract(bitWidth, strm, i);
236            Value * ix = mvmd_extract(bitWidth, index_strm, i);
237            Value * ix_popcnt = CreateCall(popcount, {ix});
238            Value * bits = CreateCall(PEXT_f, {s, ix});  // All these bits are shifted out (appended to carry).
239            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {mvmd_extract(bitWidth, carry, 0), ix}), i);
240            carry = CreateLShr(carry, CreateZExt(ix_popcnt, iBitBlock)); // Remove the carry bits consumed.
241            carryOut = CreateOr(carryOut, CreateShl(CreateZExt(bits, iBitBlock), CreateZExt(generated, iBitBlock)));
242            generated = CreateAdd(generated, ix_popcnt);
243        }
244        return std::pair<Value *, Value *>{bitCast(carryOut), bitCast(result)};
245    }
246}
247
248}
Note: See TracBrowser for help on using the repository browser.