source: icGREP/icgrep-devel/icgrep/IR_Gen/idisa_avx_builder.cpp @ 5797

Last change on this file since 5797 was 5759, checked in by cameron, 19 months ago

Small fixes

File size: 12.0 KB
Line 
1/*
2 *  Copyright (c) 2015 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "idisa_avx_builder.h"
8
9using namespace llvm;
10
11namespace IDISA {
12   
13std::string IDISA_AVX_Builder::getBuilderUniqueName() {
14    return mBitBlockWidth != 256 ? "AVX_" + std::to_string(mBitBlockWidth) : "AVX";
15}
16
17Value * IDISA_AVX_Builder::hsimd_signmask(unsigned fw, Value * a) {
18    // AVX2 special cases
19    if (mBitBlockWidth == 256) {
20        if (fw == 64) {
21            Value * signmask_f64func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_pd_256);
22            Type * bitBlock_f64type = VectorType::get(getDoubleTy(), mBitBlockWidth/64);
23            Value * a_as_pd = CreateBitCast(a, bitBlock_f64type);
24            return CreateCall(signmask_f64func, a_as_pd);
25        } else if (fw == 32) {
26            Value * signmask_f32func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_ps_256);
27            Type * bitBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth/32);
28            Value * a_as_ps = CreateBitCast(a, bitBlock_f32type);
29            return CreateCall(signmask_f32func, a_as_ps);
30        }
31    } else if (mBitBlockWidth == 512) {
32        if (fw == 64) {
33            Type * bitBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth / 32);
34            Value * a_as_ps = CreateBitCast(a, bitBlock_f32type);
35            Constant * indicies[8];
36            for (unsigned i = 0; i < 8; i++) {
37                indicies[i] = getInt32(2 * i + 1);
38            }
39            Value * packh = CreateShuffleVector(a_as_ps, UndefValue::get(bitBlock_f32type), ConstantVector::get({indicies, 8}));
40            Type * halfBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth/64);
41            Value * pack_as_ps = CreateBitCast(packh, halfBlock_f32type);
42            Value * signmask_f32func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_ps_256);
43            return CreateCall(signmask_f32func, pack_as_ps);
44        }
45    }
46    // Otherwise use default SSE logic.
47    return IDISA_SSE_Builder::hsimd_signmask(fw, a);
48}
49
50std::string IDISA_AVX2_Builder::getBuilderUniqueName() {
51    return mBitBlockWidth != 256 ? "AVX2_" + std::to_string(mBitBlockWidth) : "AVX2";
52}
53   
54Value * IDISA_AVX2_Builder::hsimd_packh(unsigned fw, Value * a, Value * b) {
55    if (fw <= 64) {       
56        Value * aVec = fwCast(fw / 2, a);
57        Value * bVec = fwCast(fw / 2, b);
58        const auto field_count = 2 * mBitBlockWidth / fw;
59        Constant * Idxs[field_count];
60        const auto H = (field_count / 2);
61        const auto Q = (field_count / 4);
62        for (unsigned i = 0; i < Q; i++) {
63            Idxs[i] = getInt32(2 * i);
64            Idxs[i + Q] = getInt32((2 * i) + 1);
65            Idxs[i + H] = getInt32((2 * i) + H);
66            Idxs[i + H + Q] = getInt32((2 * i) + 1 + H);
67        }
68        Value * shufa = CreateShuffleVector(aVec, aVec, ConstantVector::get({Idxs, field_count}));
69        Value * shufb = CreateShuffleVector(bVec, bVec, ConstantVector::get({Idxs, field_count}));
70        return hsimd_packh(mBitBlockWidth / 2, shufa, shufb);
71    }
72    // Otherwise use default SSE logic.
73    return IDISA_SSE_Builder::hsimd_packh(fw, a, b);
74}
75
76Value * IDISA_AVX2_Builder::hsimd_packl(unsigned fw, Value * a, Value * b) {
77    if (fw <= 64) {
78        Value * aVec = fwCast(fw / 2, a);
79        Value * bVec = fwCast(fw / 2, b);
80        const auto field_count = 2 * mBitBlockWidth / fw;
81        Constant * Idxs[field_count];
82        const auto H = (field_count / 2);
83        const auto Q = (field_count / 4);
84        for (unsigned i = 0; i < Q; i++) {
85            Idxs[i] = getInt32(2 * i);
86            Idxs[i + Q] = getInt32((2 * i) + 1);
87            Idxs[i + H] = getInt32((2 * i) + H);
88            Idxs[i + H + Q] = getInt32((2 * i) + H + 1);
89        }
90        Value * shufa = CreateShuffleVector(aVec, aVec, ConstantVector::get({Idxs, field_count}));
91        Value * shufb = CreateShuffleVector(bVec, bVec, ConstantVector::get({Idxs, field_count}));
92        return hsimd_packl(mBitBlockWidth / 2, shufa, shufb);
93    }
94    // Otherwise use default SSE logic.
95    return IDISA_SSE_Builder::hsimd_packl(fw, a, b);
96}
97   
98Value * IDISA_AVX2_Builder::esimd_mergeh(unsigned fw, Value * a, Value * b) {
99#if LLVM_VERSION_INTEGER < LLVM_6_0_0
100    if ((fw == 128) && (mBitBlockWidth == 256)) {
101        Value * vperm2i128func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_vperm2i128);
102        return CreateCall(vperm2i128func, {fwCast(64, a), fwCast(64, b), getInt8(0x31)});
103    }
104#endif
105    // Otherwise use default SSE logic.
106    return IDISA_SSE_Builder::esimd_mergeh(fw, a, b);
107}
108
109Value * IDISA_AVX2_Builder::esimd_mergel(unsigned fw, Value * a, Value * b) {
110#if LLVM_VERSION_INTEGER < LLVM_6_0_0
111    if ((fw == 128) && (mBitBlockWidth == 256)) {
112        Value * vperm2i128func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_vperm2i128);
113        return CreateCall(vperm2i128func, {fwCast(64, a), fwCast(64, b), getInt8(0x20)});
114    }
115#endif
116    // Otherwise use default SSE logic.
117    return IDISA_SSE_Builder::esimd_mergel(fw, a, b);
118}
119
120Value * IDISA_AVX2_Builder::hsimd_packl_in_lanes(unsigned lanes, unsigned fw, Value * a, Value * b) {
121    if ((fw == 16)  && (lanes == 2)) {
122        Value * vpackuswbfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_packuswb);
123        Value * a_low = fwCast(16, simd_and(a, simd_lomask(fw)));
124        Value * b_low = fwCast(16, simd_and(b, simd_lomask(fw)));
125        return CreateCall(vpackuswbfunc, {a_low, b_low});
126    }
127    // Otherwise use default SSE logic.
128    return IDISA_SSE_Builder::hsimd_packl_in_lanes(lanes, fw, a, b);
129}
130
131Value * IDISA_AVX2_Builder::hsimd_packh_in_lanes(unsigned lanes, unsigned fw, Value * a, Value * b) {
132    if ((fw == 16)  && (lanes == 2)) {
133        Value * vpackuswbfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_packuswb);
134        Value * a_low = simd_srli(fw, a, fw/2);
135        Value * b_low = simd_srli(fw, b, fw/2);
136        return CreateCall(vpackuswbfunc, {a_low, b_low});
137    }
138    // Otherwise use default SSE logic.
139    return IDISA_SSE_Builder::hsimd_packh_in_lanes(lanes, fw, a, b);
140}
141   
142std::pair<Value *, Value *> IDISA_AVX2_Builder::bitblock_add_with_carry(Value * e1, Value * e2, Value * carryin) {
143    // using LONG_ADD
144    Type * carryTy = carryin->getType();
145    if (carryTy == mBitBlockType) {
146        carryin = mvmd_extract(32, carryin, 0);
147    }
148    Value * carrygen = simd_and(e1, e2);
149    Value * carryprop = simd_or(e1, e2);
150    Value * digitsum = simd_add(64, e1, e2);
151    Value * digitcarry = simd_or(carrygen, simd_and(carryprop, CreateNot(digitsum)));
152    Value * carryMask = hsimd_signmask(64, digitcarry);
153    Value * carryMask2 = CreateOr(CreateAdd(carryMask, carryMask), carryin);
154    Value * bubble = simd_eq(64, digitsum, allOnes());
155    Value * bubbleMask = hsimd_signmask(64, bubble);
156    Value * incrementMask = CreateXor(CreateAdd(bubbleMask, carryMask2), bubbleMask);
157    Value * increments = esimd_bitspread(64,incrementMask);
158    Value * sum = simd_add(64, digitsum, increments);
159    Value * carry_out = CreateLShr(incrementMask, mBitBlockWidth / 64);
160    if (carryTy == mBitBlockType) {
161        carry_out = bitCast(CreateZExt(carry_out, getIntNTy(mBitBlockWidth)));
162    }
163    return std::pair<Value *, Value *>{carry_out, bitCast(sum)};
164}
165   
166std::pair<Value *, Value *> IDISA_AVX2_Builder::bitblock_indexed_advance(Value * strm, Value * index_strm, Value * shiftIn, unsigned shiftAmount) {
167    Value * popcount_f = Intrinsic::getDeclaration(getModule(), Intrinsic::ctpop, getSizeTy());
168    Value * PEXT_f = nullptr;
169    Value * PDEP_f = nullptr;
170    unsigned bitWidth = sizeof(size_t) * 8;
171    if (bitWidth == 64) {
172        PEXT_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pext_64);
173        PDEP_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pdep_64);
174    }
175    else if ((bitWidth == 32)  && (shiftAmount < 32)) {
176        PEXT_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pext_32);
177        PDEP_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pdep_32);
178    }
179    else {
180        llvm::report_fatal_error("indexed_advance unsupported bit width");
181    }
182    Type * iBitBlock = getIntNTy(getBitBlockWidth());
183    Value * shiftVal = getSize(shiftAmount);
184    if (LLVM_LIKELY(shiftAmount < bitWidth)) {
185        Value * carry = mvmd_extract(bitWidth, shiftIn, 0);
186        Value * result = allZeroes();
187        for (unsigned i = 0; i < getBitBlockWidth()/bitWidth; i++) {
188            Value * s = mvmd_extract(bitWidth, strm, i);
189            Value * ix = mvmd_extract(bitWidth, index_strm, i);
190            Value * ix_popcnt = CreateCall(popcount_f, {ix});
191            Value * bits = CreateCall(PEXT_f, {s, ix});
192            Value * adv = CreateOr(CreateShl(bits, shiftAmount), carry);
193            // We have two cases depending on whether the popcount of the index pack is < shiftAmount or not.
194            Value * popcount_small = CreateICmpULT(ix_popcnt, shiftVal);
195            Value * carry_if_popcount_small = 
196                CreateOr(CreateShl(bits, CreateSub(shiftVal, ix_popcnt)),
197                            CreateLShr(carry, ix_popcnt));
198            Value * carry_if_popcount_large = CreateLShr(bits, CreateSub(ix_popcnt, shiftVal));
199            carry = CreateSelect(popcount_small, carry_if_popcount_small, carry_if_popcount_large);
200            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {adv, ix}), i);
201        }
202        Value * carryOut = mvmd_insert(bitWidth, allZeroes(), carry, 0);
203        return std::pair<Value *, Value *>{bitCast(carryOut), bitCast(result)};
204    }
205    else if (shiftAmount <= mBitBlockWidth) {
206        // The shift amount is always greater than the popcount of the individual
207        // elements that we deal with.   This simplifies some of the logic.
208        Value * carry = CreateBitCast(shiftIn, iBitBlock);
209        Value * result = allZeroes();
210        for (unsigned i = 0; i < getBitBlockWidth()/bitWidth; i++) {
211            Value * s = mvmd_extract(bitWidth, strm, i);
212            Value * ix = mvmd_extract(bitWidth, index_strm, i);
213            Value * ix_popcnt = CreateCall(popcount_f, {ix});
214            Value * bits = CreateCall(PEXT_f, {s, ix});  // All these bits are shifted out (appended to carry).
215            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {mvmd_extract(bitWidth, carry, 0), ix}), i);
216            carry = CreateLShr(carry, CreateZExt(ix_popcnt, iBitBlock)); // Remove the carry bits consumed, make room for new bits.
217            carry = CreateOr(carry, CreateShl(CreateZExt(bits, iBitBlock), CreateZExt(CreateSub(shiftVal, ix_popcnt), iBitBlock)));
218        }
219        return std::pair<Value *, Value *>{bitCast(carry), bitCast(result)};
220    }
221    else {
222        // The shift amount is greater than the total popcount.   We will consume popcount
223        // bits from the shiftIn value only, and produce a carry out value of the selected bits.
224        // elements that we deal with.   This simplifies some of the logic.
225        Value * carry = CreateBitCast(shiftIn, iBitBlock);
226        Value * result = allZeroes();
227        Value * carryOut = CreateBitCast(allZeroes(), iBitBlock);
228        Value * generated = getSize(0);
229        for (unsigned i = 0; i < getBitBlockWidth()/bitWidth; i++) {
230            Value * s = mvmd_extract(bitWidth, strm, i);
231            Value * ix = mvmd_extract(bitWidth, index_strm, i);
232            Value * ix_popcnt = CreateCall(popcount_f, {ix});
233            Value * bits = CreateCall(PEXT_f, {s, ix});  // All these bits are shifted out (appended to carry).
234            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {mvmd_extract(bitWidth, carry, 0), ix}), i);
235            carry = CreateLShr(carry, CreateZExt(ix_popcnt, iBitBlock)); // Remove the carry bits consumed.
236            carryOut = CreateOr(carryOut, CreateShl(CreateZExt(bits, iBitBlock), CreateZExt(generated, iBitBlock)));
237            generated = CreateAdd(generated, ix_popcnt);
238        }
239        return std::pair<Value *, Value *>{bitCast(carryOut), bitCast(result)};
240    }
241}
242
243}
Note: See TracBrowser for help on using the repository browser.