source: icGREP/icgrep-devel/icgrep/IR_Gen/idisa_avx_builder.cpp @ 5977

Last change on this file since 5977 was 5977, checked in by cameron, 13 months ago

AVX512 subfeature detection and popcount from Cole with further modification

File size: 19.3 KB
Line 
1/*
2 *  Copyright (c) 2018 International Characters.
3 *  This software is licensed to the public under the Open Software License 3.0.
4 *  icgrep is a trademark of International Characters.
5 */
6
7#include "idisa_avx_builder.h"
8#include <toolchain/toolchain.h>
9#include <llvm/Support/raw_ostream.h>
10
11using namespace llvm;
12
13namespace IDISA {
14
15std::string IDISA_AVX_Builder::getBuilderUniqueName() {
16    return mBitBlockWidth != 256 ? "AVX_" + std::to_string(mBitBlockWidth) : "AVX";
17}
18
19Value * IDISA_AVX_Builder::hsimd_signmask(unsigned fw, Value * a) {
20    // AVX2 special cases
21    if (mBitBlockWidth == 256) {
22        if (fw == 64) {
23            Value * signmask_f64func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_pd_256);
24            Type * bitBlock_f64type = VectorType::get(getDoubleTy(), mBitBlockWidth/64);
25            Value * a_as_pd = CreateBitCast(a, bitBlock_f64type);
26            return CreateCall(signmask_f64func, a_as_pd);
27        } else if (fw == 32) {
28            Value * signmask_f32func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_ps_256);
29            Type * bitBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth/32);
30            Value * a_as_ps = CreateBitCast(a, bitBlock_f32type);
31            return CreateCall(signmask_f32func, a_as_ps);
32        }
33    } else if (mBitBlockWidth == 512) {
34        if (fw == 64) {
35            Type * bitBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth / 32);
36            Value * a_as_ps = CreateBitCast(a, bitBlock_f32type);
37            Constant * indicies[8];
38            for (unsigned i = 0; i < 8; i++) {
39                indicies[i] = getInt32(2 * i + 1);
40            }
41            Value * packh = CreateShuffleVector(a_as_ps, UndefValue::get(bitBlock_f32type), ConstantVector::get({indicies, 8}));
42            Type * halfBlock_f32type = VectorType::get(getFloatTy(), mBitBlockWidth/64);
43            Value * pack_as_ps = CreateBitCast(packh, halfBlock_f32type);
44            Value * signmask_f32func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx_movmsk_ps_256);
45            return CreateCall(signmask_f32func, pack_as_ps);
46        }
47    }
48    // Otherwise use default SSE logic.
49    return IDISA_SSE_Builder::hsimd_signmask(fw, a);
50}
51
52std::string IDISA_AVX2_Builder::getBuilderUniqueName() {
53    return mBitBlockWidth != 256 ? "AVX2_" + std::to_string(mBitBlockWidth) : "AVX2";
54}
55
56Value * IDISA_AVX2_Builder::hsimd_packh(unsigned fw, Value * a, Value * b) {
57    if ((fw > 8) && (fw <= 64)) {
58        Value * aVec = fwCast(fw / 2, a);
59        Value * bVec = fwCast(fw / 2, b);
60        const auto field_count = 2 * mBitBlockWidth / fw;
61        Constant * Idxs[field_count];
62        const auto H = (field_count / 2);
63        const auto Q = (field_count / 4);
64        for (unsigned i = 0; i < Q; i++) {
65            Idxs[i] = getInt32(2 * i);
66            Idxs[i + Q] = getInt32((2 * i) + 1);
67            Idxs[i + H] = getInt32((2 * i) + H);
68            Idxs[i + H + Q] = getInt32((2 * i) + 1 + H);
69        }
70        Value * shufa = CreateShuffleVector(aVec, aVec, ConstantVector::get({Idxs, field_count}));
71        Value * shufb = CreateShuffleVector(bVec, bVec, ConstantVector::get({Idxs, field_count}));
72        return hsimd_packh(mBitBlockWidth / 2, shufa, shufb);
73    }
74    // Otherwise use default SSE logic.
75    return IDISA_SSE_Builder::hsimd_packh(fw, a, b);
76}
77
78Value * IDISA_AVX2_Builder::hsimd_packl(unsigned fw, Value * a, Value * b) {
79    if ((fw > 8) && (fw <= 64)) {
80        Value * aVec = fwCast(fw / 2, a);
81        Value * bVec = fwCast(fw / 2, b);
82        const auto field_count = 2 * mBitBlockWidth / fw;
83        Constant * Idxs[field_count];
84        const auto H = (field_count / 2);
85        const auto Q = (field_count / 4);
86        for (unsigned i = 0; i < Q; i++) {
87            Idxs[i] = getInt32(2 * i);
88            Idxs[i + Q] = getInt32((2 * i) + 1);
89            Idxs[i + H] = getInt32((2 * i) + H);
90            Idxs[i + H + Q] = getInt32((2 * i) + H + 1);
91        }
92        Value * shufa = CreateShuffleVector(aVec, aVec, ConstantVector::get({Idxs, field_count}));
93        Value * shufb = CreateShuffleVector(bVec, bVec, ConstantVector::get({Idxs, field_count}));
94        return hsimd_packl(mBitBlockWidth / 2, shufa, shufb);
95    }
96    // Otherwise use default SSE logic.
97    return IDISA_SSE_Builder::hsimd_packl(fw, a, b);
98}
99
100Value * IDISA_AVX2_Builder::esimd_mergeh(unsigned fw, Value * a, Value * b) {
101#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(6, 0, 0)
102    if ((fw == 128) && (mBitBlockWidth == 256)) {
103        Value * vperm2i128func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_vperm2i128);
104        return CreateCall(vperm2i128func, {fwCast(64, a), fwCast(64, b), getInt8(0x31)});
105    }
106#endif
107    // Otherwise use default SSE logic.
108    return IDISA_SSE_Builder::esimd_mergeh(fw, a, b);
109}
110
111Value * IDISA_AVX2_Builder::esimd_mergel(unsigned fw, Value * a, Value * b) {
112#if LLVM_VERSION_INTEGER < LLVM_VERSION_CODE(6, 0, 0)
113    if ((fw == 128) && (mBitBlockWidth == 256)) {
114        Value * vperm2i128func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_vperm2i128);
115        return CreateCall(vperm2i128func, {fwCast(64, a), fwCast(64, b), getInt8(0x20)});
116    }
117#endif
118    // Otherwise use default SSE logic.
119    return IDISA_SSE_Builder::esimd_mergel(fw, a, b);
120}
121
122Value * IDISA_AVX2_Builder::hsimd_packl_in_lanes(unsigned lanes, unsigned fw, Value * a, Value * b) {
123    if ((fw == 16)  && (lanes == 2)) {
124        Value * vpackuswbfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_packuswb);
125        Value * a_low = fwCast(16, simd_and(a, simd_lomask(fw)));
126        Value * b_low = fwCast(16, simd_and(b, simd_lomask(fw)));
127        return CreateCall(vpackuswbfunc, {a_low, b_low});
128    }
129    // Otherwise use default SSE logic.
130    return IDISA_SSE_Builder::hsimd_packl_in_lanes(lanes, fw, a, b);
131}
132
133Value * IDISA_AVX2_Builder::hsimd_packh_in_lanes(unsigned lanes, unsigned fw, Value * a, Value * b) {
134    if ((fw == 16)  && (lanes == 2)) {
135        Value * vpackuswbfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_packuswb);
136        Value * a_low = simd_srli(fw, a, fw/2);
137        Value * b_low = simd_srli(fw, b, fw/2);
138        return CreateCall(vpackuswbfunc, {a_low, b_low});
139    }
140    // Otherwise use default SSE logic.
141    return IDISA_SSE_Builder::hsimd_packh_in_lanes(lanes, fw, a, b);
142}
143
144std::pair<Value *, Value *> IDISA_AVX2_Builder::bitblock_add_with_carry(Value * e1, Value * e2, Value * carryin) {
145    // using LONG_ADD
146    Type * carryTy = carryin->getType();
147    if (carryTy == mBitBlockType) {
148        carryin = mvmd_extract(32, carryin, 0);
149    }
150    Value * carrygen = simd_and(e1, e2);
151    Value * carryprop = simd_or(e1, e2);
152    Value * digitsum = simd_add(64, e1, e2);
153    Value * digitcarry = simd_or(carrygen, simd_and(carryprop, CreateNot(digitsum)));
154    Value * carryMask = hsimd_signmask(64, digitcarry);
155    Value * carryMask2 = CreateOr(CreateAdd(carryMask, carryMask), carryin);
156    Value * bubble = simd_eq(64, digitsum, allOnes());
157    Value * bubbleMask = hsimd_signmask(64, bubble);
158    Value * incrementMask = CreateXor(CreateAdd(bubbleMask, carryMask2), bubbleMask);
159    Value * increments = esimd_bitspread(64,incrementMask);
160    Value * sum = simd_add(64, digitsum, increments);
161    Value * carry_out = CreateLShr(incrementMask, mBitBlockWidth / 64);
162    if (carryTy == mBitBlockType) {
163        carry_out = bitCast(CreateZExt(carry_out, getIntNTy(mBitBlockWidth)));
164    }
165    return std::pair<Value *, Value *>{carry_out, bitCast(sum)};
166}
167
168std::pair<Value *, Value *> IDISA_AVX2_Builder::bitblock_indexed_advance(Value * strm, Value * index_strm, Value * shiftIn, unsigned shiftAmount) {
169    Value * const popcount = Intrinsic::getDeclaration(getModule(), Intrinsic::ctpop, getSizeTy());
170    Value * PEXT_f = nullptr;
171    Value * PDEP_f = nullptr;
172    const unsigned bitWidth = getSizeTy()->getBitWidth();
173    if (bitWidth == 64) {
174        PEXT_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pext_64);
175        PDEP_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pdep_64);
176    } else if ((bitWidth == 32)  && (shiftAmount < 32)) {
177        PEXT_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pext_32);
178        PDEP_f = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_bmi_pdep_32);
179    } else {
180        llvm::report_fatal_error("indexed_advance unsupported bit width");
181    }
182    Type * iBitBlock = getIntNTy(getBitBlockWidth());
183    Value * shiftVal = getSize(shiftAmount);
184    const auto n = getBitBlockWidth() / bitWidth;
185    VectorType * const vecTy = VectorType::get(getSizeTy(), n);
186    if (LLVM_LIKELY(shiftAmount < bitWidth)) {
187        Value * carry = mvmd_extract(bitWidth, shiftIn, 0);
188        Value * result = UndefValue::get(vecTy);
189        for (unsigned i = 0; i < n; i++) {
190            Value * s = mvmd_extract(bitWidth, strm, i);
191            Value * ix = mvmd_extract(bitWidth, index_strm, i);
192            Value * ix_popcnt = CreateCall(popcount, {ix});
193            Value * bits = CreateCall(PEXT_f, {s, ix});
194            Value * adv = CreateOr(CreateShl(bits, shiftAmount), carry);
195            // We have two cases depending on whether the popcount of the index pack is < shiftAmount or not.
196            Value * popcount_small = CreateICmpULT(ix_popcnt, shiftVal);
197            Value * carry_if_popcount_small =
198                CreateOr(CreateShl(bits, CreateSub(shiftVal, ix_popcnt)),
199                            CreateLShr(carry, ix_popcnt));
200            Value * carry_if_popcount_large = CreateLShr(bits, CreateSub(ix_popcnt, shiftVal));
201            carry = CreateSelect(popcount_small, carry_if_popcount_small, carry_if_popcount_large);
202            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {adv, ix}), i);
203        }
204        Value * carryOut = mvmd_insert(bitWidth, allZeroes(), carry, 0);
205        return std::pair<Value *, Value *>{bitCast(carryOut), bitCast(result)};
206    }
207    else if (shiftAmount <= mBitBlockWidth) {
208        // The shift amount is always greater than the popcount of the individual
209        // elements that we deal with.   This simplifies some of the logic.
210        Value * carry = CreateBitCast(shiftIn, iBitBlock);
211        Value * result = UndefValue::get(vecTy);
212        for (unsigned i = 0; i < n; i++) {
213            Value * s = mvmd_extract(bitWidth, strm, i);
214            Value * ix = mvmd_extract(bitWidth, index_strm, i);
215            Value * ix_popcnt = CreateCall(popcount, {ix});
216            Value * bits = CreateCall(PEXT_f, {s, ix});  // All these bits are shifted out (appended to carry).
217            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {mvmd_extract(bitWidth, carry, 0), ix}), i);
218            carry = CreateLShr(carry, CreateZExt(ix_popcnt, iBitBlock)); // Remove the carry bits consumed, make room for new bits.
219            carry = CreateOr(carry, CreateShl(CreateZExt(bits, iBitBlock), CreateZExt(CreateSub(shiftVal, ix_popcnt), iBitBlock)));
220        }
221        return std::pair<Value *, Value *>{bitCast(carry), bitCast(result)};
222    }
223    else {
224        // The shift amount is greater than the total popcount.   We will consume popcount
225        // bits from the shiftIn value only, and produce a carry out value of the selected bits.
226        // elements that we deal with.   This simplifies some of the logic.
227        Value * carry = CreateBitCast(shiftIn, iBitBlock);
228        Value * result = UndefValue::get(vecTy);
229        Value * carryOut = ConstantInt::getNullValue(iBitBlock);
230        Value * generated = getSize(0);
231        for (unsigned i = 0; i < n; i++) {
232            Value * s = mvmd_extract(bitWidth, strm, i);
233            Value * ix = mvmd_extract(bitWidth, index_strm, i);
234            Value * ix_popcnt = CreateCall(popcount, {ix});
235            Value * bits = CreateCall(PEXT_f, {s, ix});  // All these bits are shifted out (appended to carry).
236            result = mvmd_insert(bitWidth, result, CreateCall(PDEP_f, {mvmd_extract(bitWidth, carry, 0), ix}), i);
237            carry = CreateLShr(carry, CreateZExt(ix_popcnt, iBitBlock)); // Remove the carry bits consumed.
238            carryOut = CreateOr(carryOut, CreateShl(CreateZExt(bits, iBitBlock), CreateZExt(generated, iBitBlock)));
239            generated = CreateAdd(generated, ix_popcnt);
240        }
241        return std::pair<Value *, Value *>{bitCast(carryOut), bitCast(result)};
242    }
243}
244
245Value * IDISA_AVX2_Builder::hsimd_signmask(unsigned fw, Value * a) {
246    // AVX2 special cases
247    if (mBitBlockWidth == 256) {
248        if (fw == 8) {
249            Value * signmask_f8func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx2_pmovmskb);
250            Type * bitBlock_i8type = VectorType::get(getInt8Ty(), mBitBlockWidth/8);
251            Value * a_as_ps = CreateBitCast(a, bitBlock_i8type);
252            return CreateCall(signmask_f8func, a_as_ps);
253        }
254    }
255    // Otherwise use default SSE logic.
256    return IDISA_AVX_Builder::hsimd_signmask(fw, a);
257}
258
259std::string IDISA_AVX512F_Builder::getBuilderUniqueName() {
260    return mBitBlockWidth != 512 ? "AVX512F_" + std::to_string(mBitBlockWidth) : "AVX512BW";
261}
262
263llvm::Value * IDISA_AVX512F_Builder::hsimd_packh(unsigned fw, llvm::Value * a, llvm::Value * b) {
264    if ((mBitBlockWidth == 512) && (fw == 16)) {
265
266        const unsigned int field_count = 64;
267        Constant * Idxs[field_count];
268
269        for (unsigned int i = 0; i < field_count; i++) {
270            Idxs[i] = getInt32(i);
271        }
272
273        llvm::Value * pmovfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx512_mask_pmov_wb_512);
274        llvm::Value * mask = getInt32(-1);
275        llvm::Constant * shuffleMask = ConstantVector::get({Idxs, 64});
276
277        a = fwCast(fw, a);
278        a = IDISA_Builder::simd_srai(fw, a, fw/2);
279        a = CreateCall(pmovfunc, {a, a, mask});
280        b = fwCast(fw, b);
281        b = IDISA_Builder::simd_srai(fw, b, fw/2);
282        b = CreateCall(pmovfunc, {b, b, mask});
283
284        llvm::Value * c = CreateShuffleVector(a, b, shuffleMask);
285        c = bitCast(c);
286        return c;
287    }
288return IDISA_Builder::hsimd_packh(fw, a, b);
289}
290
291llvm::Value * IDISA_AVX512F_Builder::hsimd_packl(unsigned fw, llvm::Value * a, llvm::Value * b) {
292    if ((mBitBlockWidth == 512) && (fw == 16)) {
293
294        const unsigned int field_count = 64;
295        Constant * Idxs[field_count];
296        for (unsigned int i = 0; i < field_count; i++) {
297            Idxs[i] = getInt32(i);
298        }
299
300        llvm::Value * pmovfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx512_mask_pmov_wb_512);
301        llvm::Value * mask = getInt32(-1);
302        llvm::Constant * shuffleMask = ConstantVector::get({Idxs, 64});
303
304        a = fwCast(fw, a);
305        a = CreateCall(pmovfunc, {a, a, mask});
306        b = fwCast(fw, b);
307        b = CreateCall(pmovfunc, {b, b, mask});
308
309        llvm::Value * c = CreateShuffleVector(a, b, shuffleMask);
310        c = bitCast(c);
311        return c;
312    }
313return IDISA_Builder::hsimd_packl(fw, a, b);
314}
315
316llvm::Value * IDISA_AVX512F_Builder::esimd_bitspread(unsigned fw, llvm::Value * bitmask) {
317   
318    if (mBitBlockWidth == 512 && fw == 64) {
319        Value * broadcastFunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx512_mask_broadcasti64x4_512);
320        Value * broadcastMask = CreateZExtOrTrunc(bitmask, getInt8Ty());
321       
322        const unsigned int srcFieldCount = 8;
323        Constant * srcArr[srcFieldCount];
324        for (unsigned int i = 0; i < srcFieldCount; i++) {
325            srcArr[i] = getInt64(0);
326        }
327        Constant * src = ConstantVector::get({srcArr, srcFieldCount});
328       
329        const unsigned int aFieldCount = 4;
330        Constant * aArr[aFieldCount];
331        for (unsigned int i = 0; i < aFieldCount; i++) {
332            aArr[i] = getInt64(1);
333        }
334        Constant * a = ConstantVector::get({aArr, aFieldCount});
335       
336        return CreateCall(broadcastFunc, {a, src, broadcastMask});
337    }
338   
339    return IDISA_Builder::esimd_bitspread(fw, bitmask);
340}
341
342llvm::Value * IDISA_AVX512F_Builder::simd_popcount(unsigned fw, llvm::Value * a) {
343     if (fw == 512) {
344         Constant * zero16xi8 = Constant::getNullValue(VectorType::get(getInt8Ty(), 16));
345         Constant * zeroInt32 = Constant::getNullValue(getInt32Ty());
346         Value * c = simd_popcount(64, a);
347         //  Should probably use _mm512_reduce_add_epi64, but not found in LLVM 3.8
348         Value * pack64_8_func = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx512_mask_pmov_qb_512);
349         // popcounts of 64 bit fields will always fit in 8 bit fields.
350         // We don't need the masked version of this, but the unmasked intrinsic was not found.
351         c = CreateCall(pack64_8_func, {c, zero16xi8, Constant::getAllOnesValue(getInt8Ty())});
352         Value * horizSADfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_sse2_psad_bw);
353         c = CreateCall(horizSADfunc, {c, zero16xi8});
354         return CreateInsertElement(allZeroes(), CreateExtractElement(c, zeroInt32), zeroInt32);
355    }
356    if (hostCPUFeatures.hasAVX512VPOPCNTDQ && (fw == 32 || fw == 64)){
357        //llvm should use vpopcntd or vpopcntq instructions
358        return CreatePopcount(fwCast(fw, a));
359    }
360    if (hostCPUFeatures.hasAVX512BW && (fw == 64)) {
361        Value * horizSADfunc = Intrinsic::getDeclaration(getModule(), Intrinsic::x86_avx512_psad_bw_512);
362        return bitCast(CreateCall(horizSADfunc, {fwCast(8, simd_popcount(8, a)), fwCast(8, allZeroes())}));
363    }
364    //https://en.wikipedia.org/wiki/Hamming_weight#Efficient_implementation
365    if((fw == 64) && (mBitBlockWidth == 512)){
366        Constant * m1Arr[8];
367        llvm::Constant * m1;
368        for (unsigned int i = 0; i < 8; i++) {
369            m1Arr[i] = getInt64(0x5555555555555555);
370        }
371        m1 = ConstantVector::get({m1Arr, 8});
372       
373        Constant * m2Arr[8];
374        llvm::Constant * m2;
375        for (unsigned int i = 0; i < 8; i++) {
376            m2Arr[i] = getInt64(0x3333333333333333);
377        }
378        m2 = ConstantVector::get({m2Arr, 8});
379       
380        Constant * m4Arr[8];
381        llvm::Constant * m4;
382        for (unsigned int i = 0; i < 8; i++) {
383            m4Arr[i] = getInt64(0x0f0f0f0f0f0f0f0f);
384        }
385        m4 = ConstantVector::get({m4Arr, 8});
386       
387        Constant * h01Arr[8];
388        llvm::Constant * h01;
389        for (unsigned int i = 0; i < 8; i++) {
390            h01Arr[i] = getInt64(0x0101010101010101);
391        }
392        h01 = ConstantVector::get({h01Arr, 8});
393       
394        a = simd_sub(fw, a, simd_and(simd_srli(fw, a, 1), m1));
395        a = simd_add(fw, simd_and(a, m2), simd_and(simd_srli(fw, a, 2), m2));
396        a = simd_and(simd_add(fw, a, simd_srli(fw, a, 4)), m4);
397        return simd_srli(fw, simd_mult(fw, a, h01), 56);
398       
399    }
400    return IDISA_Builder::simd_popcount(fw, a);
401}
402
403
404void IDISA_AVX512F_Builder::getAVX512Features() {
405    llvm::StringMap<bool> features;
406    if (llvm::sys::getHostCPUFeatures(features)) {
407        hostCPUFeatures.hasAVX512CD = features.lookup("avx512cd");
408        hostCPUFeatures.hasAVX512BW = features.lookup("avx512bw");
409        hostCPUFeatures.hasAVX512DQ = features.lookup("avx512dq");
410        hostCPUFeatures.hasAVX512VL = features.lookup("avx512vl");
411       
412        //hostCPUFeatures.hasAVX512VBMI, hostCPUFeatures.hasAVX512VBMI2,
413        //hostCPUFeatures.hasAVX512VPOPCNTDQ have not been tested as we
414        //did not have hardware support. It should work in theory (tm)
415       
416        hostCPUFeatures.hasAVX512VBMI = features.lookup("avx512_vbmi");
417        hostCPUFeatures.hasAVX512VBMI2 = features.lookup("avx512_vbmi2");
418        hostCPUFeatures.hasAVX512VPOPCNTDQ = features.lookup("avx512_vpopcntdq");
419    }
420}
421
422
423}
Note: See TracBrowser for help on using the repository browser.