Skip to content

Commit b467405

Browse files
author
git apple-llvm automerger
committed
Merge commit 'ef927ae26318' from llvm.org/main into next
2 parents fd2b29a + ef927ae commit b467405

File tree

4 files changed

+297
-36
lines changed

4 files changed

+297
-36
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 47 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -15265,18 +15265,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
1526515265
break;
1526615266
}
1526715267
case RISCVISD::PASUB:
15268-
case RISCVISD::PASUBU: {
15268+
case RISCVISD::PASUBU:
15269+
case RISCVISD::PMULHSU: {
1526915270
MVT VT = N->getSimpleValueType(0);
1527015271
SDValue Op0 = N->getOperand(0);
1527115272
SDValue Op1 = N->getOperand(1);
15272-
assert(VT == MVT::v2i16 || VT == MVT::v4i8);
15273+
unsigned Opcode = N->getOpcode();
15274+
// PMULHSU doesn't support i8 variants
15275+
assert(VT == MVT::v2i16 ||
15276+
(Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
1527315277
MVT NewVT = MVT::v4i16;
1527415278
if (VT == MVT::v4i8)
1527515279
NewVT = MVT::v8i8;
1527615280
SDValue Undef = DAG.getUNDEF(VT);
1527715281
Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
1527815282
Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
15279-
Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
15283+
Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
1528015284
return;
1528115285
}
1528215286
case ISD::EXTRACT_VECTOR_ELT: {
@@ -16386,9 +16390,9 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
1638616390
return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
1638716391
}
1638816392

16389-
// Handle P extension averaging subtraction pattern:
16390-
// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
16391-
// -> PASUB/PASUBU
16393+
// Handle P extension truncate patterns:
16394+
// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
16395+
// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
1639216396
static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1639316397
const RISCVSubtarget &Subtarget) {
1639416398
SDValue N0 = N->getOperand(0);
@@ -16401,7 +16405,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1640116405
VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
1640216406
return SDValue();
1640316407

16404-
// Check if shift amount is 1
16408+
// Check if shift amount is a splat constant
1640516409
SDValue ShAmt = N0.getOperand(1);
1640616410
if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
1640716411
return SDValue();
@@ -16415,44 +16419,57 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1641516419
ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
1641616420
if (!C)
1641716421
return SDValue();
16418-
if (C->getZExtValue() != 1)
16419-
return SDValue();
1642016422

16421-
// Check for SUB operation
16422-
SDValue Sub = N0.getOperand(0);
16423-
if (Sub.getOpcode() != ISD::SUB)
16424-
return SDValue();
16423+
SDValue Op = N0.getOperand(0);
16424+
unsigned ShAmtVal = C->getZExtValue();
1642516425

16426-
SDValue LHS = Sub.getOperand(0);
16427-
SDValue RHS = Sub.getOperand(1);
16426+
SDValue LHS = Op.getOperand(0);
16427+
SDValue RHS = Op.getOperand(1);
1642816428

16429-
// Check if both operands are sign/zero extends from the target
16430-
// type
16431-
bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
16432-
RHS.getOpcode() == ISD::SIGN_EXTEND;
16433-
bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
16434-
RHS.getOpcode() == ISD::ZERO_EXTEND;
16429+
bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
16430+
bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
16431+
bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
16432+
bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
1643516433

16436-
if (!IsSignExt && !IsZeroExt)
16434+
if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
1643716435
return SDValue();
1643816436

1643916437
SDValue A = LHS.getOperand(0);
1644016438
SDValue B = RHS.getOperand(0);
1644116439

16442-
// Check if the extends are from our target vector type
1644316440
if (A.getValueType() != VT || B.getValueType() != VT)
1644416441
return SDValue();
1644516442

16446-
// Determine the instruction based on type and signedness
1644716443
unsigned Opc;
16448-
if (IsSignExt)
16449-
Opc = RISCVISD::PASUB;
16450-
else if (IsZeroExt)
16451-
Opc = RISCVISD::PASUBU;
16452-
else
16444+
switch (Op.getOpcode()) {
16445+
default:
1645316446
return SDValue();
16447+
case ISD::SUB:
16448+
// PASUB/PASUBU: shift amount must be 1
16449+
if (ShAmtVal != 1)
16450+
return SDValue();
16451+
if (LHSIsSExt && RHSIsSExt)
16452+
Opc = RISCVISD::PASUB;
16453+
else if (LHSIsZExt && RHSIsZExt)
16454+
Opc = RISCVISD::PASUBU;
16455+
else
16456+
return SDValue();
16457+
break;
16458+
case ISD::MUL:
16459+
// PMULHSU: shift amount must be element size, only for i16/i32
16460+
unsigned EltBits = VecVT.getScalarSizeInBits();
16461+
if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
16462+
return SDValue();
16463+
if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
16464+
Opc = RISCVISD::PMULHSU;
16465+
// commuted case
16466+
if (LHSIsZExt && RHSIsSExt)
16467+
std::swap(A, B);
16468+
} else
16469+
return SDValue();
16470+
break;
16471+
}
1645416472

16455-
// Create the machine node directly
1645616473
return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
1645716474
}
1645816475

llvm/lib/Target/RISCV/RISCVInstrInfoP.td

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1463,12 +1463,13 @@ let Predicates = [HasStdExtP, IsRV32] in {
14631463

14641464
def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
14651465

1466-
def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
1467-
SDTCisInt<0>,
1468-
SDTCisSameAs<0, 1>,
1469-
SDTCisSameAs<0, 2>]>;
1470-
def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
1471-
def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
1466+
def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
1467+
SDTCisInt<0>,
1468+
SDTCisSameAs<0, 1>,
1469+
SDTCisSameAs<0, 2>]>;
1470+
def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
1471+
def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
1472+
def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
14721473

14731474
let Predicates = [HasStdExtP] in {
14741475
def : PatGpr<abs, ABS>;
@@ -1513,6 +1514,11 @@ let Predicates = [HasStdExtP] in {
15131514
def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
15141515
def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
15151516

1517+
// 16-bit multiply high patterns
1518+
def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
1519+
def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
1520+
def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
1521+
15161522
// 8-bit logical shift left/right patterns
15171523
def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
15181524
(PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1609,6 +1615,11 @@ let Predicates = [HasStdExtP, IsRV64] in {
16091615
def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
16101616
def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
16111617

1618+
// 32-bit multiply high patterns
1619+
def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
1620+
def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
1621+
def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
1622+
16121623
// 32-bit logical shift left/right
16131624
def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
16141625
(PSLL_WS GPR:$rs1, GPR:$rs2)>;

llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll

Lines changed: 78 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1040,3 +1040,81 @@ define void @test_psra_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
10401040
store <4 x i8> %res, ptr %ret_ptr
10411041
ret void
10421042
}
1043+
; Test packed multiply high signed for v2i16
1044+
define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1045+
; CHECK-LABEL: test_pmulh_h:
1046+
; CHECK: # %bb.0:
1047+
; CHECK-NEXT: lw a1, 0(a1)
1048+
; CHECK-NEXT: lw a2, 0(a2)
1049+
; CHECK-NEXT: pmulh.h a1, a1, a2
1050+
; CHECK-NEXT: sw a1, 0(a0)
1051+
; CHECK-NEXT: ret
1052+
%a = load <2 x i16>, ptr %a_ptr
1053+
%b = load <2 x i16>, ptr %b_ptr
1054+
%a_ext = sext <2 x i16> %a to <2 x i32>
1055+
%b_ext = sext <2 x i16> %b to <2 x i32>
1056+
%mul = mul <2 x i32> %a_ext, %b_ext
1057+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1058+
%res = trunc <2 x i32> %shift to <2 x i16>
1059+
store <2 x i16> %res, ptr %ret_ptr
1060+
ret void
1061+
}
1062+
1063+
; Test packed multiply high unsigned for v2i16
1064+
define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1065+
; CHECK-LABEL: test_pmulhu_h:
1066+
; CHECK: # %bb.0:
1067+
; CHECK-NEXT: lw a1, 0(a1)
1068+
; CHECK-NEXT: lw a2, 0(a2)
1069+
; CHECK-NEXT: pmulhu.h a1, a1, a2
1070+
; CHECK-NEXT: sw a1, 0(a0)
1071+
; CHECK-NEXT: ret
1072+
%a = load <2 x i16>, ptr %a_ptr
1073+
%b = load <2 x i16>, ptr %b_ptr
1074+
%a_ext = zext <2 x i16> %a to <2 x i32>
1075+
%b_ext = zext <2 x i16> %b to <2 x i32>
1076+
%mul = mul <2 x i32> %a_ext, %b_ext
1077+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1078+
%res = trunc <2 x i32> %shift to <2 x i16>
1079+
store <2 x i16> %res, ptr %ret_ptr
1080+
ret void
1081+
}
1082+
1083+
; Test packed multiply high signed-unsigned for v2i16
1084+
define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1085+
; CHECK-LABEL: test_pmulhsu_h:
1086+
; CHECK: # %bb.0:
1087+
; CHECK-NEXT: lw a1, 0(a1)
1088+
; CHECK-NEXT: lw a2, 0(a2)
1089+
; CHECK-NEXT: pmulhsu.h a1, a1, a2
1090+
; CHECK-NEXT: sw a1, 0(a0)
1091+
; CHECK-NEXT: ret
1092+
%a = load <2 x i16>, ptr %a_ptr
1093+
%b = load <2 x i16>, ptr %b_ptr
1094+
%a_ext = sext <2 x i16> %a to <2 x i32>
1095+
%b_ext = zext <2 x i16> %b to <2 x i32>
1096+
%mul = mul <2 x i32> %a_ext, %b_ext
1097+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1098+
%res = trunc <2 x i32> %shift to <2 x i16>
1099+
store <2 x i16> %res, ptr %ret_ptr
1100+
ret void
1101+
}
1102+
1103+
define void @test_pmulhsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1104+
; CHECK-LABEL: test_pmulhsu_h_commuted:
1105+
; CHECK: # %bb.0:
1106+
; CHECK-NEXT: lw a1, 0(a1)
1107+
; CHECK-NEXT: lw a2, 0(a2)
1108+
; CHECK-NEXT: pmulhsu.h a1, a2, a1
1109+
; CHECK-NEXT: sw a1, 0(a0)
1110+
; CHECK-NEXT: ret
1111+
%a = load <2 x i16>, ptr %a_ptr
1112+
%b = load <2 x i16>, ptr %b_ptr
1113+
%a_ext = zext <2 x i16> %a to <2 x i32>
1114+
%b_ext = sext <2 x i16> %b to <2 x i32>
1115+
%mul = mul <2 x i32> %a_ext, %b_ext
1116+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1117+
%res = trunc <2 x i32> %shift to <2 x i16>
1118+
store <2 x i16> %res, ptr %ret_ptr
1119+
ret void
1120+
}

0 commit comments

Comments
 (0)