source: parabix-LLVM/llvm_git/lib/Target/X86/X86ParabixISelLowering.cpp @ 4136

Last change on this file since 4136 was 4136, checked in by linmengl, 5 years ago

long stream addition for i128

File size: 39.2 KB
Line 
1//===-- X86ParabixISelLowering.cpp - X86 Parabix DAG Lowering Implementation ----===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------------===//
9//
10// This file defines the interfaces that Parabix uses to lower LLVM code into a
11// selection DAG on X86.
12//
13// Lowering Strategy Sequence:
14// For LowerParabixOperation, we check if the op fits in CastAndOpKind first. If fits,
15// lowering process is done and terminated. Otherwise, we would check those general
16// policies (like in-place vector promotion). If fail again, custom lowering code will
17// be executed (like PXLowerADD).
18//
19//===----------------------------------------------------------------------------===//
20
21#include "X86ISelLowering.h"
22#include "Utils/X86ShuffleDecode.h"
23#include "X86CallingConv.h"
24#include "X86InstrBuilder.h"
25#include "X86MachineFunctionInfo.h"
26#include "X86TargetMachine.h"
27#include "X86TargetObjectFile.h"
28#include "X86ParabixISelLowering.h"
29#include "llvm/ADT/SmallSet.h"
30#include "llvm/ADT/Statistic.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringSwitch.h"
33#include "llvm/ADT/VariadicFunction.h"
34#include "llvm/CodeGen/IntrinsicLowering.h"
35#include "llvm/CodeGen/MachineFrameInfo.h"
36#include "llvm/CodeGen/MachineFunction.h"
37#include "llvm/CodeGen/MachineInstrBuilder.h"
38#include "llvm/CodeGen/MachineJumpTableInfo.h"
39#include "llvm/CodeGen/MachineModuleInfo.h"
40#include "llvm/CodeGen/MachineRegisterInfo.h"
41#include "llvm/IR/CallSite.h"
42#include "llvm/IR/CallingConv.h"
43#include "llvm/IR/Constants.h"
44#include "llvm/IR/DerivedTypes.h"
45#include "llvm/IR/Function.h"
46#include "llvm/IR/GlobalAlias.h"
47#include "llvm/IR/GlobalVariable.h"
48#include "llvm/IR/Instructions.h"
49#include "llvm/IR/Intrinsics.h"
50#include "llvm/MC/MCAsmInfo.h"
51#include "llvm/MC/MCContext.h"
52#include "llvm/MC/MCExpr.h"
53#include "llvm/MC/MCSymbol.h"
54#include "llvm/Support/Debug.h"
55#include "llvm/Support/ErrorHandling.h"
56#include "llvm/Support/MathExtras.h"
57#include "llvm/Target/TargetOptions.h"
58#include "ParabixGeneratedFuncs.h"
59#include <bitset>
60#include <cctype>
61#include <map>
62#include <string>
63using namespace llvm;
64
65#define DEBUG_TYPE "parabix"
66
67//v32i1 => i32, v32i2 => i64, etc
68MVT getFullRegisterType(MVT VT) {
69  MVT castType;
70  if (VT.is32BitVector())
71    castType = MVT::i32;
72  else if (VT.is64BitVector())
73    castType = MVT::i64;
74  else if (VT.is128BitVector())
75    castType = MVT::v2i64;
76  else
77    llvm_unreachable("unsupported parabix vector width");
78
79  return castType;
80}
81
82typedef std::pair<ISD::NodeType, MVT> CastAndOpKind;
83static std::map<CastAndOpKind, ISD::NodeType> CAOops;
84
85enum PXLegalizeAction {
86  InPlacePromote
87};
88
89typedef CastAndOpKind OpKind;
90static std::map<OpKind, PXLegalizeAction> OpKindActions;
91
92static void addCastAndOpKind(ISD::NodeType Op, MVT VT, ISD::NodeType ReplaceOp)
93{
94  CAOops[std::make_pair(Op, VT)] = ReplaceOp;
95}
96
97static void addOpKindAction(ISD::NodeType Op, MVT VT, PXLegalizeAction action)
98{
99  OpKindActions[std::make_pair(Op, VT)] = action;
100}
101
102static void resetOperations()
103{
104  //NEED: setOperationAction in X86ISelLowering with Custom
105  CAOops.clear();
106
107  //use XOR to simulate ADD on v32i1
108  addCastAndOpKind(ISD::ADD, MVT::v32i1, ISD::XOR);
109  addCastAndOpKind(ISD::SUB, MVT::v32i1, ISD::XOR);
110  addCastAndOpKind(ISD::MUL, MVT::v32i1, ISD::AND);
111  addCastAndOpKind(ISD::AND, MVT::v32i1, ISD::AND);
112  addCastAndOpKind(ISD::XOR, MVT::v32i1, ISD::XOR);
113  addCastAndOpKind(ISD::OR,  MVT::v32i1, ISD::OR);
114  addCastAndOpKind(ISD::MULHU,  MVT::v32i1, ISD::AND);
115
116  addCastAndOpKind(ISD::ADD, MVT::v64i1, ISD::XOR);
117  addCastAndOpKind(ISD::SUB, MVT::v64i1, ISD::XOR);
118  addCastAndOpKind(ISD::MUL, MVT::v64i1, ISD::AND);
119  addCastAndOpKind(ISD::AND, MVT::v64i1, ISD::AND);
120  addCastAndOpKind(ISD::XOR, MVT::v64i1, ISD::XOR);
121  addCastAndOpKind(ISD::OR,  MVT::v64i1, ISD::OR);
122  addCastAndOpKind(ISD::MULHU,  MVT::v64i1, ISD::AND);
123
124  addCastAndOpKind(ISD::ADD, MVT::v128i1, ISD::XOR);
125  addCastAndOpKind(ISD::SUB, MVT::v128i1, ISD::XOR);
126  addCastAndOpKind(ISD::MUL, MVT::v128i1, ISD::AND);
127  addCastAndOpKind(ISD::AND, MVT::v128i1, ISD::AND);
128  addCastAndOpKind(ISD::XOR, MVT::v128i1, ISD::XOR);
129  addCastAndOpKind(ISD::OR,  MVT::v128i1, ISD::OR);
130  addCastAndOpKind(ISD::MULHU,  MVT::v128i1, ISD::AND);
131
132  //cast v64i2 to v2i64 to lower logic ops.
133  addCastAndOpKind(ISD::AND, MVT::v64i2, ISD::AND);
134  addCastAndOpKind(ISD::XOR, MVT::v64i2, ISD::XOR);
135  addCastAndOpKind(ISD::OR,  MVT::v64i2, ISD::OR);
136
137  addCastAndOpKind(ISD::AND, MVT::v32i4, ISD::AND);
138  addCastAndOpKind(ISD::XOR, MVT::v32i4, ISD::XOR);
139  addCastAndOpKind(ISD::OR,  MVT::v32i4, ISD::OR);
140
141  //A custom lowering for v32i4 add is implmented. So ADD is not here.
142  addOpKindAction(ISD::SUB, MVT::v32i4, InPlacePromote);
143  addOpKindAction(ISD::MUL, MVT::v32i4, InPlacePromote);
144  addOpKindAction(ISD::SHL, MVT::v32i4, InPlacePromote);
145  addOpKindAction(ISD::SRL, MVT::v32i4, InPlacePromote);
146  addOpKindAction(ISD::SRA, MVT::v32i4, InPlacePromote);
147  addOpKindAction(ISD::SETCC, MVT::v32i4, InPlacePromote);
148
149  addOpKindAction(ISD::MUL, MVT::v16i8, InPlacePromote);
150}
151
152static SDValue getFullRegister(SDValue Op, SelectionDAG &DAG) {
153  MVT VT = Op.getSimpleValueType();
154  SDLoc dl(Op);
155
156  return DAG.getNode(ISD::BITCAST, dl, getFullRegisterType(VT), Op);
157}
158
159//Promote vector type in place, doubling fieldwidth within the same register
160//e.g. v32i4 => v16i8
161static MVT PromoteTypeDouble(MVT VT) {
162  unsigned RegisterWidth = VT.getSizeInBits();
163  unsigned FieldWidth = VT.getScalarSizeInBits();
164  unsigned NumElems = RegisterWidth / FieldWidth;
165
166  MVT ToVT = MVT::getVectorVT(MVT::getIntegerVT(FieldWidth * 2), NumElems / 2);
167  return ToVT;
168}
169
170//Root function for general policy lowering. Register the OpKind in resetOperations,
171//then the policy will be executed here.
172static SDValue lowerWithOpAction(SDValue Op, SelectionDAG &DAG) {
173  MVT VT = Op.getSimpleValueType();
174  OpKind kind = std::make_pair((ISD::NodeType)Op.getOpcode(), VT);
175  SDNodeTreeBuilder b(Op, &DAG);
176  unsigned RegisterWidth = VT.getSizeInBits();
177  unsigned FieldWidth = VT.getScalarSizeInBits();
178  SDValue Op0 = Op.getOperand(0);
179  SDValue Op1 = Op.getOperand(1);
180
181  switch (OpKindActions[kind]) {
182  default: llvm_unreachable("Unknown OpAction to lower parabix op");
183  case InPlacePromote:
184    MVT DoubleVT = PromoteTypeDouble(VT);
185    SDValue Himask = b.HiMask(RegisterWidth, FieldWidth * 2);
186    SDValue Lowmask = b.NOT(Himask);
187    SDValue HiBits, LowBits;
188
189    Op0 = getFullRegister(Op0, DAG);
190    Op1 = getFullRegister(Op1, DAG);
191
192    if (Op.getOpcode() == ISD::MUL) {
193      //MUL is a little different, needs to shift right high bits before calc
194      HiBits = b.SHL(FieldWidth, b.DoOp(DoubleVT,
195                                        b.SRL(FieldWidth, b.BITCAST(Op0, DoubleVT)),
196                                        b.SRL(FieldWidth, b.BITCAST(Op1, DoubleVT))));
197    } else if (Op.getOpcode() == ISD::SHL) {
198      // shift left
199      HiBits = b.DoOp(DoubleVT, b.AND(Op0, Himask),
200                      b.SRL(FieldWidth, b.BITCAST(Op1, DoubleVT)));
201    } else if (Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) {
202      // shift right, logic or arithmetic are the same
203      HiBits = b.DoOp(DoubleVT,
204                      Op0, b.SRL(FieldWidth, b.BITCAST(Op1, DoubleVT)));
205    } else {
206      HiBits = b.DoOp(DoubleVT, b.AND(Op0, Himask), b.AND(Op1, Himask));
207    }
208
209    if (Op.getOpcode() == ISD::SETCC) {
210      //SETCC needs to shift the lowbits left, to properly set the sign bit.
211      LowBits = b.DoOp(DoubleVT,
212                       b.SHL(FieldWidth, b.BITCAST(Op0, DoubleVT)),
213                       b.SHL(FieldWidth, b.BITCAST(Op1, DoubleVT)));
214    } else if (Op.getOpcode() == ISD::SHL) {
215      //shift left
216      LowBits = b.DoOp(DoubleVT, Op0, b.AND(Op1, Lowmask));
217    } else if (Op.getOpcode() == ISD::SRL) {
218      //shift right logic
219      LowBits = b.DoOp(DoubleVT, b.AND(Op0, Lowmask), b.AND(Op1, Lowmask));
220    } else if (Op.getOpcode() == ISD::SRA) {
221      //shift right arithmetic. Need to shift left low half to high half to set sign bit
222      LowBits = b.SRL(FieldWidth,
223                      b.DoOp(DoubleVT, b.SHL(FieldWidth, b.BITCAST(Op0, DoubleVT)),
224                             b.AND(Op1, Lowmask)));
225
226    } else {
227      LowBits = b.DoOp(DoubleVT, Op0, Op1);
228    }
229
230    SDValue R = b.IFH1(Himask, HiBits, LowBits);
231    return b.BITCAST(R, VT);
232  }
233
234  llvm_unreachable("Reach the end of lowerWithOpAction");
235  return SDValue();
236}
237
238//Bitcast this vector to a full length integer and then do one op
239//Lookup op in CAOops map
240static SDValue lowerWithCastAndOp(SDValue Op, SelectionDAG &DAG) {
241  SDLoc dl(Op);
242  MVT VT = Op.getSimpleValueType();
243  SDValue A = Op.getOperand(0);
244  SDValue B = Op.getOperand(1);
245  CastAndOpKind kind = std::make_pair((ISD::NodeType)Op.getOpcode(), VT);
246
247  assert(CAOops.find(kind) != CAOops.end() && "Undefined cast and op kind");
248
249  MVT castType = getFullRegisterType(VT);
250  SDValue transA = DAG.getNode(ISD::BITCAST, dl, castType, A);
251  SDValue transB = DAG.getNode(ISD::BITCAST, dl, castType, B);
252  SDValue res = DAG.getNode(CAOops[kind], dl, castType, transA, transB);
253
254  return DAG.getNode(ISD::BITCAST, dl, VT, res);
255}
256
257//Don't do the lookup, use NewOp on casted operands.
258//If swap is set true, will swap operands
259static SDValue lowerWithCastAndOp(SDValue Op, SelectionDAG &DAG,
260                                  ISD::NodeType NewOp, bool Swap=false) {
261  SDLoc dl(Op);
262  MVT VT = Op.getSimpleValueType();
263  SDValue A = Op.getOperand(0);
264  SDValue B = Op.getOperand(1);
265
266  MVT castType = getFullRegisterType(VT);
267  SDValue transA = DAG.getNode(ISD::BITCAST, dl, castType, A);
268  SDValue transB = DAG.getNode(ISD::BITCAST, dl, castType, B);
269  SDValue res;
270  if (!Swap)
271    res = DAG.getNode(NewOp, dl, castType, transA, transB);
272  else
273    res = DAG.getNode(NewOp, dl, castType, transB, transA);
274
275  return DAG.getNode(ISD::BITCAST, dl, VT, res);
276}
277
278static SDValue PXLowerShift(SDValue Op, SelectionDAG &DAG) {
279  assert((Op.getOpcode() == ISD::SHL || Op.getOpcode() == ISD::SRA ||
280          Op.getOpcode() == ISD::SRL) && "Only lower shift ops here");
281
282  SDLoc dl(Op);
283  MVT VT = Op.getSimpleValueType();
284  SDValue A = Op.getOperand(0);
285  SDValue B = Op.getOperand(1);
286  SDValue res;
287  SDNodeTreeBuilder b(Op, &DAG);
288
289  MVT VectorEleType = VT.getVectorElementType();
290
291  if (VectorEleType == MVT::i1 && Op.getOpcode() != ISD::SRA) {
292    //SRL or SHL
293    res = b.AND(getFullRegister(A, DAG), b.NOT(getFullRegister(B, DAG)));
294  }
295  else if (VectorEleType == MVT::i1 && Op.getOpcode() == ISD::SRA) {
296    return A;
297  }
298  else if (VT == MVT::v64i2) {
299    if (Op.getOpcode() == ISD::SHL)
300      return GENLowerSHL(Op, DAG);
301    else if (Op.getOpcode() == ISD::SRL)
302      return GENLowerLSHR(Op, DAG);
303    else
304      return GENLowerASHR(Op, DAG);
305  } else
306    llvm_unreachable("lowering undefined parabix shift ops");
307
308  return DAG.getNode(ISD::BITCAST, dl, VT, res);
309}
310
311static SDValue PXLowerADD(SDValue Op, SelectionDAG &DAG) {
312  MVT VT = Op.getSimpleValueType();
313  SDNodeTreeBuilder b(Op, &DAG);
314  SDValue Op0 = Op.getOperand(0);
315  SDValue Op1 = Op.getOperand(1);
316
317  if (VT == MVT::v64i2) {
318    return GENLowerADD(Op, DAG);
319  }
320  else if (VT == MVT::v32i4) {
321    // Use mask = 0x8888... to mask out high bits and then we can do the i4 add
322    // with only one paddb.
323    std::string mask = "";
324    for (unsigned int i = 0; i < 16; i++) mask += "1000";
325    SDValue Mask = b.Constant(mask, MVT::v2i64);
326
327    MVT DoubleVT = MVT::v16i8;
328
329    SDValue Ah = b.AND(Mask, b.BITCAST(Op0, MVT::v2i64));
330    SDValue Bh = b.AND(Mask, b.BITCAST(Op1, MVT::v2i64));
331    SDValue R = b.DoOp(DoubleVT,
332                       b.AND(b.BITCAST(Op0, MVT::v2i64), b.NOT(Ah)),
333                       b.AND(b.BITCAST(Op1, MVT::v2i64), b.NOT(Bh)));
334    R = b.XOR(R, b.BITCAST(b.XOR(Ah, Bh), DoubleVT));
335
336    return b.BITCAST(R, VT);
337  }
338
339  llvm_unreachable("lowering add for unsupported type");
340  return SDValue();
341}
342
343static SDValue PXLowerSUB(SDValue Op, SelectionDAG &DAG) {
344  if (Op.getSimpleValueType() == MVT::v64i2) {
345    return GENLowerSUB(Op, DAG);
346  }
347
348  llvm_unreachable("lowering sub for unsupported type");
349  return SDValue();
350}
351
352static SDValue PXLowerMUL(SDValue Op, SelectionDAG &DAG) {
353  if (Op.getSimpleValueType() == MVT::v64i2) {
354    return GENLowerMUL(Op, DAG);
355  }
356
357  llvm_unreachable("only lowering parabix MUL");
358  return SDValue();
359}
360
361static SDValue getTruncateOrZeroExtend(SDValue V, SelectionDAG &DAG, MVT ToVT)
362{
363  SDNodeTreeBuilder b(V, &DAG);
364  MVT VT = V.getSimpleValueType();
365  if (VT.bitsLT(ToVT))
366    return b.ZERO_EXTEND(V, ToVT);
367  else if (VT.bitsGT(ToVT))
368    return b.TRUNCATE(V, ToVT);
369
370  return V;
371}
372
373static SDValue PXLowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG, const X86Subtarget* Subtarget) {
374  MVT VT = Op.getSimpleValueType();
375  MVT FullVT = getFullRegisterType(VT);
376  SDNodeTreeBuilder b(Op, &DAG);
377
378  SDLoc dl(Op);
379  SDValue N0 = Op.getOperand(0); // vector <val>
380  SDValue N1 = Op.getOperand(1); // elt
381  SDValue N2 = Op.getOperand(2); // idx
382
383  int RegisterWidth = VT.getSizeInBits();
384  int NumElts = VT.getVectorNumElements();
385  int FieldWidth = RegisterWidth / NumElts;
386
387  if (VT == MVT::v32i1 || VT == MVT::v64i1) {
388    //Cast VT into full register and do bit manipulation.
389    SDValue TransN0 = getFullRegister(N0, DAG);
390    SDValue Res;
391
392    if (isa<ConstantSDNode>(N1)) {
393      if (cast<ConstantSDNode>(N1)->isNullValue()) {
394        //insert zero
395        SDValue Mask = DAG.getNode(ISD::SHL, dl, FullVT, DAG.getConstant(1, FullVT), N2);
396        SDValue NegMask = DAG.getNOT(dl, Mask, FullVT);
397        Res = DAG.getNode(ISD::AND, dl, FullVT, NegMask, TransN0);
398      } else {
399        //insert one
400        SDValue Mask = DAG.getNode(ISD::SHL, dl, FullVT, DAG.getConstant(1, FullVT), N2);
401        Res = DAG.getNode(ISD::OR, dl, FullVT, Mask, TransN0);
402      }
403    } else {
404      // Elt is not a constant node
405      // Mask = NOT(SHL(ZEXT(NOT(elt, i1), i32), idx))
406      // return AND(Vector, Mask)
407      // NOT is sensitive of bit width
408      SDValue NotV = DAG.getNode(ISD::AND, dl, MVT::i8, DAG.getConstant(1, MVT::i8),
409                                 DAG.getNOT(dl, N1, MVT::i8));
410      SDValue Zext = DAG.getNode(ISD::ZERO_EXTEND, dl, FullVT, NotV);
411      SDValue Mask = DAG.getNOT(dl, DAG.getNode(ISD::SHL, dl, FullVT, Zext, N2),
412                                FullVT);
413      Res = DAG.getNode(ISD::AND, dl, FullVT, Mask, TransN0);
414    }
415
416    // Cast back
417    return DAG.getNode(ISD::BITCAST, dl, VT, Res);
418  }
419  else {
420    //General strategy here
421    //extract an i16 from the vector and insert N1 into proper location.
422    //then, insert the modified i16 back
423    //SSE2 don't have extract i32, only have extract i16
424    assert(VT.getVectorElementType().bitsLE(MVT::i8) &&
425           "general INSERT_VECTOR_ELT only works with FieldWidth <= 8");
426
427    N2 = getTruncateOrZeroExtend(N2, DAG, MVT::i16);
428
429    int I16VecNumElts = RegisterWidth / 16;
430    MVT I16VecType = MVT::getVectorVT(MVT::i16, I16VecNumElts);
431    int lowbitsMask = (1 << FieldWidth) - 1;
432
433    SDValue IdxVec = b.UDIV(N2, b.Constant(NumElts / I16VecNumElts, MVT::i16));
434    SDValue IdxInside = b.UREM(N2, b.Constant(NumElts / I16VecNumElts, MVT::i16));
435
436    SDValue TransVal = b.BITCAST(N0, I16VecType);
437    SDValue ExtVal = b.EXTRACT_VECTOR_ELT(TransVal, IdxVec);
438
439    SDValue NewElt = b.SHL(b.ZERO_EXTEND(b.AND(N1, b.Constant(lowbitsMask, MVT::i8)),
440                                         MVT::i16),
441                           b.MUL(IdxInside, b.Constant(FieldWidth, MVT::i16)));
442    SDValue Mask   = b.SHL(b.Constant(lowbitsMask, MVT::i16),
443                           b.MUL(IdxInside, b.Constant(FieldWidth, MVT::i16)));
444
445    ExtVal = b.OR(b.AND(ExtVal, b.NOT(Mask)), NewElt);
446
447    // idx for insert_vector_elt should match the subtarget
448    MVT IdxValueType = MVT::i32;
449    if (Subtarget->is64Bit())
450        IdxValueType = MVT::i64;
451
452    return b.BITCAST(b.INSERT_VECTOR_ELT(TransVal, ExtVal, b.ZERO_EXTEND(IdxVec, IdxValueType)), VT);
453  }
454
455  llvm_unreachable("lowering insert_vector_elt for unsupported type");
456  return SDValue();
457}
458
459static SDValue PXLowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) {
460  SDLoc dl(Op);
461  SDValue Vec = Op.getOperand(0);
462  MVT VecVT = Vec.getSimpleValueType();
463  MVT FullVT = getFullRegisterType(VecVT);
464  SDValue Idx = Op.getOperand(1);
465  SDNodeTreeBuilder b(Op, &DAG);
466
467  int RegisterWidth = VecVT.getSizeInBits();
468  int NumElts = VecVT.getVectorNumElements();
469  int FieldWidth = RegisterWidth / NumElts;
470
471  if (VecVT == MVT::v32i1 || VecVT == MVT::v64i1) {
472    //TRUNC(AND(1, SRL(FULL_REG(VecVT), Idx)), i8)
473    SDValue TransV = DAG.getNode(ISD::BITCAST, dl, FullVT, Vec);
474    SDValue ShiftV = DAG.getNode(ISD::SRL, dl, FullVT, TransV, Idx);
475    return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
476                       DAG.getNode(ISD::AND, dl, FullVT, ShiftV, DAG.getConstant(1, FullVT)));
477  }
478  else {
479    //General strategy here, extract i16 from the vector and then do shifting
480    //and truncate.
481    assert(VecVT.getVectorElementType().bitsLE(MVT::i8) &&
482           "general EXTRACT_VECTOR_ELT only works with FieldWidth <= 8");
483
484    Idx = getTruncateOrZeroExtend(Idx, DAG, MVT::i16);
485
486    int I16VecNumElts = RegisterWidth / 16;
487    MVT I16VecType = MVT::getVectorVT(MVT::i16, I16VecNumElts);
488
489    SDValue IdxVec = b.UDIV(Idx, b.Constant(NumElts / I16VecNumElts, MVT::i16));
490    SDValue IdxInside = b.UREM(Idx, b.Constant(NumElts / I16VecNumElts, MVT::i16));
491
492    SDValue TransVal = b.BITCAST(Vec, I16VecType);
493    SDValue ExtVal = b.EXTRACT_VECTOR_ELT(TransVal, IdxVec);
494
495    return b.TRUNCATE(b.AND(b.SRL(ExtVal, b.MUL(IdxInside, b.Constant(FieldWidth, MVT::i16))),
496                            b.Constant( (1 << FieldWidth) - 1, MVT::i16)),
497                      MVT::i8);
498  }
499
500  llvm_unreachable("lowering extract_vector_elt for unsupported type");
501  return SDValue();
502}
503
504static SDValue PXLowerSCALAR_TO_VECTOR(SDValue Op, SelectionDAG &DAG) {
505  SDLoc dl(Op);
506  MVT VecVT = Op.getSimpleValueType();
507  MVT FullVT = getFullRegisterType(VecVT);
508  SDValue Val = Op.getOperand(0);
509  EVT EltVT = VecVT.getVectorElementType();
510  SDNodeTreeBuilder b(Op, &DAG);
511
512  assert(EltVT.isInteger() && Val.getValueType().bitsGE(EltVT) &&
513         "incorrect scalar_to_vector parameters");
514
515  //Lowering assumes i8 is the smallest legal integer, which is true on X86
516  if (VecVT == MVT::v32i1 || VecVT == MVT::v64i1) {
517    SDValue Trunc = DAG.getNode(ISD::AND, dl, MVT::i8, DAG.getConstant(1, MVT::i8),
518                                DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Val));
519    SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, dl, FullVT, Trunc);
520    return DAG.getNode(ISD::BITCAST, dl, VecVT, Ext);
521  }
522  else {
523    assert(Val.getSimpleValueType().bitsLE(MVT::i8) &&
524           "GetVectorFromScalarInteger only work with i2 or i4");
525
526    MVT I32VecType = MVT::getVectorVT(MVT::i32, VecVT.getSizeInBits() / 32);
527    int mask = (1 << EltVT.getSizeInBits()) - 1;
528
529    SDValue R1 = b.ANY_EXTEND(b.AND(b.TRUNCATE(Val, MVT::i8),
530                                    b.Constant(mask, MVT::i8)), MVT::i32);
531    SDValue R2 = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, I32VecType, R1);
532    return b.BITCAST(R2, VecVT);
533  }
534
535  llvm_unreachable("lowering unsupported scalar_to_vector");
536  return SDValue();
537}
538
539//get zero vector for parabix
540static SDValue getPXZeroVector(EVT VT, SDNodeTreeBuilder b) {
541  SDValue Vec;
542  if (VT.isSimple() && VT.getSimpleVT().is32BitVector()) {
543    // Careful here, don't use TargetConstant until you are sure.
544    Vec = b.Constant(0, MVT::i32);
545  } else if (VT.isSimple() && VT.getSimpleVT().is64BitVector()) {
546    Vec = b.Constant(0, MVT::i64);
547  } else if (VT.isSimple() && VT.getSimpleVT().is128BitVector()) {
548    Vec = b.ConstantVector(MVT::v4i32, 0);
549  } else if (VT.isSimple() && VT.getSimpleVT().is256BitVector()) {
550    Vec = b.ConstantVector(MVT::v8i32, 0);
551  } else
552    llvm_unreachable("Unexpected vector type");
553
554  return b.BITCAST(Vec, VT.getSimpleVT());
555}
556
557static SDValue getPXOnesVector(EVT VT, SDNodeTreeBuilder b) {
558  SDValue Vec;
559  if (VT.isSimple() && VT.getSimpleVT().is32BitVector()) {
560    // Careful here, don't use TargetConstant until you are sure.
561    Vec = b.Constant(-1, MVT::i32);
562  } else if (VT.isSimple() && VT.getSimpleVT().is64BitVector()) {
563    Vec = b.Constant(-1, MVT::i64);
564  } else if (VT.isSimple() && VT.getSimpleVT().is128BitVector()) {
565    Vec = b.ConstantVector(MVT::v4i32, -1);
566  } else if (VT.isSimple() && VT.getSimpleVT().is256BitVector()) {
567    Vec = b.ConstantVector(MVT::v8i32, -1);
568  } else
569    llvm_unreachable("Unexpected vector type");
570
571  return b.BITCAST(Vec, VT.getSimpleVT());
572}
573
574SDValue
575X86TargetLowering::PXLowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
576  SDLoc dl(Op);
577
578  MVT VT = Op.getSimpleValueType();
579  unsigned NumElems = Op.getNumOperands();
580  SDNodeTreeBuilder b(Op, &DAG);
581
582  // Vectors containing all zeros can be matched by pxor and xorps later
583  if (ISD::isBuildVectorAllZeros(Op.getNode())) {
584    // Canonicalize this to <4 x i32> to 1) ensure the zero vectors are CSE'd
585    // and 2) ensure that i64 scalars are eliminated on x86-32 hosts.
586    if (VT == MVT::v4i32 || VT == MVT::v8i32 || VT == MVT::v16i32)
587      return Op;
588
589    return getPXZeroVector(VT, b);
590  }
591
592  if (ISD::isBuildVectorAllOnes(Op.getNode())) {
593    return getPXOnesVector(VT, b);
594  }
595
596  if (VT == MVT::v32i1 || VT == MVT::v64i1 || VT == MVT::v128i1) {
597    //Brutely insert element
598    //TODO: improve efficiency of v128i1
599    MVT FullVT = getFullRegisterType(VT);
600    SDValue Base = DAG.getNode(ISD::BITCAST, dl, VT,
601                               DAG.getConstant(0, FullVT));
602    for (unsigned i = 0; i < NumElems; ++i) {
603      SDValue Elt = Op.getOperand(i);
604      if (Elt.getOpcode() == ISD::UNDEF)
605        continue;
606      Base = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Base, Elt,
607                         DAG.getConstant(i, MVT::i32));
608    }
609
610    return Base;
611  }
612
613  if (VT == MVT::v64i2) {
614    //Rearrange index and do 4 shifts and or
615    SmallVector<SmallVector<SDValue, 16>, 4> RearrangedVectors;
616    SmallVector<SDValue, 16> RV;
617    for (unsigned vi = 0; vi < 4; vi++) {
618      RV.clear();
619      for (unsigned i = vi; i < NumElems; i += 4) {
620        RV.push_back(Op.getOperand(i));
621      }
622
623      RearrangedVectors.push_back(RV);
624    }
625
626    //i2 is not legal on X86, so the 64 operands are all i8
627    SDValue V0 = b.BUILD_VECTOR(MVT::v16i8, RearrangedVectors[0]);
628    SDValue V1 = b.BUILD_VECTOR(MVT::v16i8, RearrangedVectors[1]);
629    SDValue V2 = b.BUILD_VECTOR(MVT::v16i8, RearrangedVectors[2]);
630    SDValue V3 = b.BUILD_VECTOR(MVT::v16i8, RearrangedVectors[3]);
631
632    return b.BITCAST(b.OR(b.OR(b.OR(V0, b.SHL<2>(V1)), b.SHL<4>(V2)), b.SHL<6>(V3)),
633                     MVT::v64i2);
634  }
635
636  if (VT == MVT::v32i4) {
637    //Rearrange index and do 2 shifts and or
638    //We have 32 x i8 as build_vector oprand, we build 2 v16i8, V0 and V1, then
639    //we can return the result as V0 | (V1 << 4), where
640    //V0 = build_vector(Op0, Op2, Op4, ... , Op30)
641    //V1 = build_vector(Op1, Op3, Op5, ... , Op31)
642    SmallVector<SDValue, 2> V;
643    SmallVector<SDValue, 16> RowV;
644    for (unsigned vi = 0; vi < 2; vi ++) {
645      RowV.clear();
646      for (unsigned i = vi; i < NumElems; i += 2) {
647        RowV.push_back(Op.getOperand(i));
648      }
649
650      V.push_back(b.BUILD_VECTOR(MVT::v16i8, RowV));
651    }
652
653    return b.BITCAST(b.OR(V[0], b.SHL<4>(V[1])), VT);
654  }
655
656  llvm_unreachable("lowering build_vector for unsupported type");
657  return SDValue();
658}
659
660static SDValue PXLowerSETCC(SDValue Op, SelectionDAG &DAG) {
661  MVT VT = Op.getSimpleValueType();
662  MVT FullVT = getFullRegisterType(VT);
663  SDLoc dl(Op);
664  SDValue Op0 = Op.getOperand(0);
665  SDValue Op1 = Op.getOperand(1);
666  ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
667  SDValue NEVec, TransA, TransB, Res, NotOp1, NotOp0;
668  SDNodeTreeBuilder b(Op, &DAG);
669
670  if (VT == MVT::v32i1 || VT == MVT::v64i1 || VT == MVT::v128i1) {
671    switch (CC) {
672    default: llvm_unreachable("Can't lower this parabix SETCC");
673    case ISD::SETNE:    return lowerWithCastAndOp(Op, DAG, ISD::XOR);
674    case ISD::SETEQ:
675      NEVec = lowerWithCastAndOp(Op, DAG, ISD::XOR);
676      return DAG.getNOT(dl, NEVec, VT);
677    case ISD::SETLT:
678    case ISD::SETUGT:
679      NotOp1 = DAG.getNOT(dl, Op1, VT);
680      TransA = DAG.getNode(ISD::BITCAST, dl, FullVT, Op0);
681      TransB = DAG.getNode(ISD::BITCAST, dl, FullVT, NotOp1);
682      Res = DAG.getNode(ISD::AND, dl, FullVT, TransA, TransB);
683      return DAG.getNode(ISD::BITCAST, dl, VT, Res);
684    case ISD::SETGT:
685    case ISD::SETULT:
686      NotOp0 = DAG.getNOT(dl, getFullRegister(Op0, DAG), FullVT);
687      Res = DAG.getNode(ISD::AND, dl, FullVT, NotOp0,
688                                getFullRegister(Op1, DAG));
689      return DAG.getNode(ISD::BITCAST, dl, VT, Res);
690    case ISD::SETLE:
691    case ISD::SETUGE:
692      Res = DAG.getNode(ISD::SETCC, dl, VT, Op0, Op1, DAG.getCondCode(ISD::SETGT));
693      return DAG.getNOT(dl, Res, VT);
694    case ISD::SETGE:
695    case ISD::SETULE:
696      Res = DAG.getNode(ISD::SETCC, dl, VT, Op0, Op1, DAG.getCondCode(ISD::SETLT));
697      return DAG.getNOT(dl, Res, VT);
698    }
699  }
700  else if (VT == MVT::v64i2) {
701    switch (CC) {
702    default: llvm_unreachable("Can't lower this parabix SETCC");
703    case ISD::SETEQ:  return GENLowerICMP_EQ(Op, DAG);
704    case ISD::SETLT:  return GENLowerICMP_SLT(Op, DAG);
705    case ISD::SETGT:  return GENLowerICMP_SGT(Op, DAG);
706    case ISD::SETULT: return GENLowerICMP_ULT(Op, DAG);
707    case ISD::SETUGT: return GENLowerICMP_UGT(Op, DAG);
708    case ISD::SETNE:
709      Res = GENLowerICMP_EQ(Op, DAG);
710      return b.NOT(Res);
711    case ISD::SETGE:
712      Res = GENLowerICMP_SLT(Op, DAG);
713      return b.NOT(Res);
714    case ISD::SETLE:
715      Res = GENLowerICMP_SGT(Op, DAG);
716      return b.NOT(Res);
717    case ISD::SETUGE:
718      Res = GENLowerICMP_ULT(Op, DAG);
719      return b.NOT(Res);
720    case ISD::SETULE:
721      Res = GENLowerICMP_UGT(Op, DAG);
722      return b.NOT(Res);
723    }
724  }
725
726  llvm_unreachable("only lowering parabix SETCC");
727  return SDValue();
728}
729
730///Entrance for parabix lowering.
731SDValue X86TargetLowering::LowerParabixOperation(SDValue Op, SelectionDAG &DAG) const {
732  //NEED: setOperationAction in target specific lowering (X86ISelLowering.cpp)
733  DEBUG(dbgs() << "Parabix Lowering:" << "\n"; Op.dump());
734
735  //Only resetOperations for the first time.
736  static bool FirstTimeThrough = true;
737  if (FirstTimeThrough) {
738    //dbgs() << "Parabix Lowering:" << "\n"; Op.dump();
739    resetOperations();
740    FirstTimeThrough = false;
741  }
742
743  MVT VT = Op.getSimpleValueType();
744  //Check if we have registered CastAndOp action
745  CastAndOpKind kind = std::make_pair((ISD::NodeType)Op.getOpcode(), VT);
746  if (CAOops.find(kind) != CAOops.end())
747    return lowerWithCastAndOp(Op, DAG);
748  //Check general policy
749  if (OpKindActions.find((OpKind) kind) != OpKindActions.end())
750    return lowerWithOpAction(Op, DAG);
751
752  switch (Op.getOpcode()) {
753  default: llvm_unreachable("[ROOT SWITCH] Should not custom lower this parabix op!");
754  case ISD::ADD:                return PXLowerADD(Op, DAG);
755  case ISD::SUB:                return PXLowerSUB(Op, DAG);
756  case ISD::MUL:                return PXLowerMUL(Op, DAG);
757  case ISD::BUILD_VECTOR:       return PXLowerBUILD_VECTOR(Op, DAG);
758  case ISD::SHL:
759  case ISD::SRA:
760  case ISD::SRL:                return PXLowerShift(Op, DAG);
761  case ISD::INSERT_VECTOR_ELT:  return PXLowerINSERT_VECTOR_ELT(Op, DAG, Subtarget);
762  case ISD::EXTRACT_VECTOR_ELT: return PXLowerEXTRACT_VECTOR_ELT(Op, DAG);
763  case ISD::SCALAR_TO_VECTOR:   return PXLowerSCALAR_TO_VECTOR(Op, DAG);
764  case ISD::SETCC:              return PXLowerSETCC(Op, DAG);
765  }
766}
767
768static SDValue PXPerformVSELECTCombine(SDNode *N, SelectionDAG &DAG,
769                                    TargetLowering::DAGCombinerInfo &DCI,
770                                    const X86Subtarget *Subtarget) {
771  MVT VT = N->getSimpleValueType(0);
772  SDValue Mask = N->getOperand(0);
773  MVT MaskTy = Mask.getSimpleValueType();
774  SDLoc dl(N);
775
776  SDNodeTreeBuilder b(&DAG, dl);
777
778  if (DCI.isBeforeLegalize()) {
779    //v128i1 (select v128i1, v128i1, v128i1) can be combined into logical ops
780    if (MaskTy == MVT::v128i1 && VT == MVT::v128i1) {
781      DEBUG(dbgs() << "Combining select v128i1 \n");
782      return b.IFH1(Mask, N->getOperand(0), N->getOperand(1));
783   }
784  }
785
786  //v32i8 (select v32i1, v32i8, v32i8) don't have proper lowering on AVX2, so
787  //we convert the mask to v32i8
788  if (MaskTy == MVT::v32i1 && VT == MVT::v32i8 &&
789      (Subtarget->hasAVX2() || Subtarget->hasAVX())) {
790    Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, VT, Mask);
791    DCI.AddToWorklist(Mask.getNode());
792    return DAG.getNode(N->getOpcode(), dl, VT, Mask, N->getOperand(1), N->getOperand(2));
793  }
794
795  return SDValue();
796}
797
798/// isUndefOrEqual - Val is either less than zero (undef) or equal to the
799/// specified value.
800static bool isUndefOrEqual(int Val, int CmpVal) {
801  return (Val < 0 || Val == CmpVal);
802}
803
804//Check whether the shuffle node is same as IDISA simd<16>::packl
805//Convert packed 16-bit integers from a and b to packed 8-bit integers
806//Collect all the low parts of vectors a and b
807static bool isPackLowMask(ShuffleVectorSDNode *SVOp) {
808  EVT VT = SVOp->getValueType(0);
809  unsigned NumElems = VT.getVectorNumElements();
810
811  //v16i8 (shufflevector v16i8, v16i8, <0, 2, 4, 6, 8, ..., 30>)
812  for (unsigned i = 0; i < NumElems; i++) {
813    if (!isUndefOrEqual(SVOp->getMaskElt(i), i * 2))
814      return false;
815  }
816
817  return true;
818}
819
820//Check whether the shuffle node is same as IDISA simd<16>::packh
821//Collect all the high parts of vectors
822static bool isPackHighMask(ShuffleVectorSDNode *SVOp) {
823  EVT VT = SVOp->getValueType(0);
824  unsigned NumElems = VT.getVectorNumElements();
825
826  //v16i8 (shufflevector v16i8, v16i8, <1, 3, 5, 7, ..., 31>)
827  for (unsigned i = 0; i < NumElems; i++) {
828    if (!isUndefOrEqual(SVOp->getMaskElt(i), i * 2 + 1))
829      return false;
830  }
831
832  return true;
833}
834
835static SDValue PXPerformVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
836                                    TargetLowering::DAGCombinerInfo &DCI,
837                                    const X86Subtarget *Subtarget) {
838  MVT VT = N->getSimpleValueType(0);
839  SDLoc dl(N);
840  ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(N);
841  SDValue V1 = SVOp->getOperand(0);
842  SDValue V2 = SVOp->getOperand(1);
843  SDNodeTreeBuilder b(&DAG, dl);
844
845  //v16i8 (vector_shuffle v16i8, v16i8, v16i32) can be combined into
846  //X86ISD::PACKUS
847  //simd<16>::packl
848  if (Subtarget->hasSSE2() && VT == MVT::v16i8 && isPackLowMask(SVOp) &&
849      V1.getOpcode() != ISD::UNDEF && V2.getOpcode() != ISD::UNDEF) {
850    DEBUG(dbgs() << "Parabix combine: \n"; N->dumpr());
851
852    //00000000111111110000000011111111
853    SDValue LowMaskInteger = b.Constant(16711935, MVT::i32);
854    SDValue VPool[] = {LowMaskInteger, LowMaskInteger, LowMaskInteger, LowMaskInteger};
855    SDValue LowMask16 = b.BITCAST(b.BUILD_VECTOR(MVT::v4i32, VPool), MVT::v8i16);
856
857    SDValue NewV1 = b.AND(LowMask16, b.BITCAST(V1, MVT::v8i16));
858    SDValue NewV2 = b.AND(LowMask16, b.BITCAST(V2, MVT::v8i16));
859
860    DCI.AddToWorklist(LowMask16.getNode());
861    DCI.AddToWorklist(NewV1.getNode());
862    DCI.AddToWorklist(NewV2.getNode());
863
864    return DAG.getNode(X86ISD::PACKUS, dl, MVT::v16i8, NewV1, NewV2);
865  }
866
867  //X86ISD::PACKUS cont.
868  //For simd<16>::packh
869  if (Subtarget->hasSSE2() && VT == MVT::v16i8 && isPackHighMask(SVOp) &&
870      V1.getOpcode() != ISD::UNDEF && V2.getOpcode() != ISD::UNDEF) {
871    DEBUG(dbgs() << "Parabix combine: \n"; N->dumpr());
872
873    SDValue Cst = b.Constant(8, MVT::i16);
874    SDValue VPool[] = {Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst};
875    SDValue Shift = b.BUILD_VECTOR(MVT::v8i16, VPool);
876
877    SDValue NewV1 = b.SRL(b.BITCAST(V1, MVT::v8i16), Shift);
878    SDValue NewV2 = b.SRL(b.BITCAST(V2, MVT::v8i16), Shift);
879
880    DCI.AddToWorklist(Shift.getNode());
881    DCI.AddToWorklist(NewV1.getNode());
882    DCI.AddToWorklist(NewV2.getNode());
883
884    return DAG.getNode(X86ISD::PACKUS, dl, MVT::v16i8, NewV1, NewV2);
885  }
886
887  //PEXT for simd<2, 4, 8>::packl or packh
888  //the Mask is the only thing different
889  if (Subtarget->hasBMI2() && Subtarget->is64Bit() &&
890      (isPackLowMask(SVOp) || isPackHighMask(SVOp)) &&
891      (VT == MVT::v32i4 || VT == MVT::v64i2 || VT == MVT::v128i1)) {
892    DEBUG(dbgs() << "Parabix combine: \n"; N->dumpr());
893
894    std::string Mask;
895    if (isPackLowMask(SVOp)) {
896      switch (VT.SimpleTy) {
897      default: llvm_unreachable("unsupported type");
898      case MVT::v32i4:
899        //simd<8>::packl
900        Mask = "0000111100001111000011110000111100001111000011110000111100001111";
901        break;
902      case MVT::v64i2:
903        //simd<4>::packl
904        Mask = "0011001100110011001100110011001100110011001100110011001100110011";
905        break;
906      case MVT::v128i1:
907        //simd<2>::packl
908        Mask = "0101010101010101010101010101010101010101010101010101010101010101";
909        break;
910      }
911    } else if (isPackHighMask(SVOp)) {
912      switch (VT.SimpleTy) {
913      default: llvm_unreachable("unsupported type");
914      case MVT::v32i4:
915        //simd<8>::packl
916        Mask = "1111000011110000111100001111000011110000111100001111000011110000";
917        break;
918      case MVT::v64i2:
919        //simd<4>::packl
920        Mask = "1100110011001100110011001100110011001100110011001100110011001100";
921        break;
922      case MVT::v128i1:
923        //simd<2>::packl
924        Mask = "1010101010101010101010101010101010101010101010101010101010101010";
925        break;
926      }
927    }
928
929    APInt MaskInt(64, Mask, 2);
930    SDValue MaskNode = DAG.getConstant(MaskInt, MVT::i64);
931
932    SDValue A = b.BITCAST(V1, MVT::v2i64);
933    SDValue A0 = b.EXTRACT_VECTOR_ELT(A, b.Constant(0));
934    SDValue A1 = b.EXTRACT_VECTOR_ELT(A, b.Constant(1));
935
936    SDValue B = b.BITCAST(V2, MVT::v2i64);
937    SDValue B0 = b.EXTRACT_VECTOR_ELT(B, b.Constant(0));
938    SDValue B1 = b.EXTRACT_VECTOR_ELT(B, b.Constant(1));
939
940    //There are 2 ways of implementation at this point. OR/SHL is the first one.
941    //It will generate 3 more ops for each packh/l, but have better performance
942    //for whole transposition.
943    SDValue P0 = b.OR(b.PEXT64(A0, MaskNode),
944                      b.SHL(b.PEXT64(A1, MaskNode), b.Constant(32, MVT::i64)));
945    SDValue P1 = b.OR(b.PEXT64(B0, MaskNode),
946                      b.SHL(b.PEXT64(B1, MaskNode), b.Constant(32, MVT::i64)));
947    SDValue P[] = {P0, P1};
948    return b.BITCAST(b.BUILD_VECTOR(MVT::v2i64, P), VT);
949
950    //////////////////////////////////////
951    //Below is the second implementation. Less instructions will be generated,
952    //but hurt the whole performance.
953
954    //SDValue P0 = b.TRUNCATE(b.PEXT64(A0, MaskNode), MVT::i32);
955    //SDValue P1 = b.TRUNCATE(b.PEXT64(A1, MaskNode), MVT::i32);
956    //SDValue P2 = b.TRUNCATE(b.PEXT64(B0, MaskNode), MVT::i32);
957    //SDValue P3 = b.TRUNCATE(b.PEXT64(B1, MaskNode), MVT::i32);
958
959    //SDValue P[] = {P0, P1, P2, P3};
960    //return b.BITCAST(b.BUILD_VECTOR(MVT::v4i32, P), VT);
961  }
962
963  return SDValue();
964}
965
966static bool isImmediateShiftingMask(SDValue Mask, int &imm) {
967  if (Mask.getOpcode() != ISD::BUILD_VECTOR)
968    return false;
969
970  bool FirstImmediate = true;
971  uint64_t ImmNumber;
972
973  for (unsigned i = 0, e = Mask.getNumOperands(); i != e; ++i) {
974    SDValue Op = Mask.getOperand(i);
975    if (Op.getOpcode() == ISD::UNDEF)
976      continue;
977    if (!isa<ConstantSDNode>(Op))
978      return false;
979
980    if (FirstImmediate) {
981      FirstImmediate = false;
982      ImmNumber = cast<ConstantSDNode>(Op)->getZExtValue();
983    }
984    else if (cast<ConstantSDNode>(Op)->getZExtValue() != ImmNumber) {
985      return false;
986    }
987  }
988
989  imm = (int) ImmNumber;
990  return true;
991}
992
993static SDValue PXPerformShiftCombine(SDNode *N, SelectionDAG &DAG,
994                                     TargetLowering::DAGCombinerInfo &DCI,
995                                     const X86Subtarget *Subtarget) {
996  MVT VT = N->getSimpleValueType(0);
997  SDLoc dl(N);
998  SDValue V1 = N->getOperand(0);
999  SDValue V2 = N->getOperand(1);
1000  SDNodeTreeBuilder b(&DAG, dl);
1001
1002  assert((N->getOpcode() == ISD::SHL || N->getOpcode() == ISD::SRL ||
1003          N->getOpcode() == ISD::SRA) && "Only lowering shift");
1004
1005  int imm;
1006
1007  //Optimize immediate shiftings.
1008  if (Subtarget->hasSSE2() && VT == MVT::v32i4 && isImmediateShiftingMask(V2, imm)) {
1009
1010    MVT I32VecType = MVT::v4i32;
1011
1012    if (N->getOpcode() == ISD::SHL) {
1013      DEBUG(dbgs() << "Parabix combining: "; N->dump());
1014
1015      SDValue R = b.AND(b.SHL(imm, b.BITCAST(V1, I32VecType)),
1016                        b.BITCAST(b.ConstantVector(MVT::v32i4, (15 << imm) & 15), I32VecType));
1017      return b.BITCAST(R, VT);
1018    }
1019    else if (N->getOpcode() == ISD::SRL) {
1020      DEBUG(dbgs() << "Parabix combining: "; N->dump());
1021
1022      SDValue R = b.AND(b.SRL(imm, b.BITCAST(V1, I32VecType)),
1023                        b.BITCAST(b.ConstantVector(MVT::v32i4, 15 >> imm), I32VecType));
1024      return b.BITCAST(R, VT);
1025    }
1026  }
1027
1028  return SDValue();
1029}
1030
1031static SDValue PXPerformUADDO(SDNode *N, SelectionDAG &DAG,
1032                                     TargetLowering::DAGCombinerInfo &DCI,
1033                                     const X86Subtarget *Subtarget) {
1034  MVT VT = N->getSimpleValueType(0);
1035  SDLoc dl(N);
1036  SDValue V1 = N->getOperand(0);
1037  SDValue V2 = N->getOperand(1);
1038  SDNodeTreeBuilder b(&DAG, dl);
1039
1040  if (DCI.isBeforeLegalize() && Subtarget->hasSSE2() && VT == MVT::i128) {
1041    DEBUG(dbgs() << "Parabix combining: "; N->dump());
1042
1043    //general logic for uadd.with.overflow.iXXX
1044    int RegisterWidth = VT.getSizeInBits();
1045    int f = RegisterWidth / 64;
1046    MVT VXi64Ty = MVT::getVectorVT(MVT::i64, f);
1047    MVT MaskTy = MVT::getIntegerVT(f);
1048    MVT MaskVecTy = MVT::getVectorVT(MVT::i1, f);
1049
1050    SDValue X = b.BITCAST(V1, VXi64Ty);
1051    SDValue Y = b.BITCAST(V2, VXi64Ty);
1052    SDValue R = b.ADD(X, Y);
1053
1054    SDValue Ones = getPXOnesVector(VXi64Ty, b);
1055
1056    //x = hsimd<64>::signmask(X), x, y, r are all i32 type
1057    SDValue x, y, r;
1058    if (f == 2) {
1059      //i128, v2i1 to i2 seems to be problematic
1060      x = b.SignMask2x64(X);
1061      y = b.SignMask2x64(Y);
1062      r = b.SignMask2x64(R);
1063    }
1064    else if (f == 4) {
1065      //i256
1066      x = b.SignMask4x64(X);
1067      y = b.SignMask4x64(Y);
1068      r = b.SignMask4x64(R);
1069    }
1070    else
1071    {
1072      //i512, i1024, ..., i4096
1073      SDValue Zero = getPXZeroVector(VXi64Ty, b);
1074      x = b.ZERO_EXTEND(b.BITCAST(b.SETCC(X, Zero, ISD::SETLT), MaskTy), MVT::i32);
1075      y = b.ZERO_EXTEND(b.BITCAST(b.SETCC(Y, Zero, ISD::SETLT), MaskTy), MVT::i32);
1076      r = b.ZERO_EXTEND(b.BITCAST(b.SETCC(R, Zero, ISD::SETLT), MaskTy), MVT::i32);
1077    }
1078
1079    SDValue carry = b.OR(b.AND(x, y), b.AND(b.OR(x, y), b.NOT(r)));
1080    SDValue bubble = b.ZERO_EXTEND(b.BITCAST(b.SETCC(R, Ones, ISD::SETEQ), MaskTy), MVT::i32);
1081
1082    SDValue increments = b.MatchStar(b.SHL(carry, b.Constant(1, MVT::i32)), bubble);
1083    SDValue carry_out = b.TRUNCATE(b.SRL(increments, b.Constant(f, MVT::i32)), MVT::i1);
1084
1085    SDValue spread = b.ZERO_EXTEND(b.BITCAST(b.TRUNCATE(increments, MaskTy), MaskVecTy),
1086                                   VXi64Ty);
1087    SDValue sum = b.BITCAST(b.ADD(R, spread), VT);
1088
1089    SDValue Pool[] = {sum, carry_out};
1090    SDValue Ret = DAG.getMergeValues(Pool, dl);
1091
1092    DEBUG(dbgs() << "Combined into: \n"; Ret.dumpr());
1093
1094    return Ret;
1095  }
1096
1097  return SDValue();
1098}
1099
1100SDValue X86TargetLowering::PerformParabixDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const
1101{
1102  //For now, only combine simple value type.
1103  if (!N->getValueType(0).isSimple()) return SDValue();
1104
1105  SelectionDAG &DAG = DCI.DAG;
1106  switch (N->getOpcode()) {
1107  default: break;
1108  case ISD::VSELECT:            return PXPerformVSELECTCombine(N, DAG, DCI, Subtarget);
1109  case ISD::VECTOR_SHUFFLE:     return PXPerformVECTOR_SHUFFLECombine(N, DAG, DCI, Subtarget);
1110  case ISD::SHL:
1111  case ISD::SRL:                return PXPerformShiftCombine(N, DAG, DCI, Subtarget);
1112  case ISD::UADDO:              return PXPerformUADDO(N, DAG, DCI, Subtarget);
1113  }
1114
1115  return SDValue();
1116}
1117
Note: See TracBrowser for help on using the repository browser.