Skip to content

Commit ef927ae

Browse files
authored
[llvm][RISCV] Support mulh for P extension codegen (llvm#171581)
For mulh pattern with operands that are both signed or unsigned, combination is performed automatically. However for mulh with operands which are signed and unsigned respectively we need to combine them manually same approach as what we've done for PASUB*. Note: This is first patch for mulh which only handle basic high part multiplication, there will be followup patches to handle rest of mulh related instructions.
1 parent 8975eb3 commit ef927ae

File tree

4 files changed

+297
-36
lines changed

4 files changed

+297
-36
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15265,18 +15265,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
1526515265
break;
1526615266
}
1526715267
case RISCVISD::PASUB:
15268-
case RISCVISD::PASUBU: {
15268+
case RISCVISD::PASUBU:
15269+
case RISCVISD::PMULHSU: {
1526915270
MVT VT = N->getSimpleValueType(0);
1527015271
SDValue Op0 = N->getOperand(0);
1527115272
SDValue Op1 = N->getOperand(1);
15272-
assert(VT == MVT::v2i16 || VT == MVT::v4i8);
15273+
unsigned Opcode = N->getOpcode();
15274+
// PMULHSU doesn't support i8 variants
15275+
assert(VT == MVT::v2i16 ||
15276+
(Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
1527315277
MVT NewVT = MVT::v4i16;
1527415278
if (VT == MVT::v4i8)
1527515279
NewVT = MVT::v8i8;
1527615280
SDValue Undef = DAG.getUNDEF(VT);
1527715281
Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
1527815282
Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
15279-
Results.push_back(DAG.getNode(N->getOpcode(), DL, NewVT, {Op0, Op1}));
15283+
Results.push_back(DAG.getNode(Opcode, DL, NewVT, {Op0, Op1}));
1528015284
return;
1528115285
}
1528215286
case ISD::EXTRACT_VECTOR_ELT: {
@@ -16386,9 +16390,9 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
1638616390
return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
1638716391
}
1638816392

16389-
// Handle P extension averaging subtraction pattern:
16390-
// (vXiY (trunc (srl (sub ([s|z]ext vXiY:$a), ([s|z]ext vXiY:$b)), 1)))
16391-
// -> PASUB/PASUBU
16393+
// Handle P extension truncate patterns:
16394+
// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
16395+
// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
1639216396
static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1639316397
const RISCVSubtarget &Subtarget) {
1639416398
SDValue N0 = N->getOperand(0);
@@ -16401,7 +16405,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1640116405
VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
1640216406
return SDValue();
1640316407

16404-
// Check if shift amount is 1
16408+
// Check if shift amount is a splat constant
1640516409
SDValue ShAmt = N0.getOperand(1);
1640616410
if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
1640716411
return SDValue();
@@ -16415,44 +16419,57 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1641516419
ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
1641616420
if (!C)
1641716421
return SDValue();
16418-
if (C->getZExtValue() != 1)
16419-
return SDValue();
1642016422

16421-
// Check for SUB operation
16422-
SDValue Sub = N0.getOperand(0);
16423-
if (Sub.getOpcode() != ISD::SUB)
16424-
return SDValue();
16423+
SDValue Op = N0.getOperand(0);
16424+
unsigned ShAmtVal = C->getZExtValue();
1642516425

16426-
SDValue LHS = Sub.getOperand(0);
16427-
SDValue RHS = Sub.getOperand(1);
16426+
SDValue LHS = Op.getOperand(0);
16427+
SDValue RHS = Op.getOperand(1);
1642816428

16429-
// Check if both operands are sign/zero extends from the target
16430-
// type
16431-
bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
16432-
RHS.getOpcode() == ISD::SIGN_EXTEND;
16433-
bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
16434-
RHS.getOpcode() == ISD::ZERO_EXTEND;
16429+
bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
16430+
bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
16431+
bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
16432+
bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
1643516433

16436-
if (!IsSignExt && !IsZeroExt)
16434+
if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt))
1643716435
return SDValue();
1643816436

1643916437
SDValue A = LHS.getOperand(0);
1644016438
SDValue B = RHS.getOperand(0);
1644116439

16442-
// Check if the extends are from our target vector type
1644316440
if (A.getValueType() != VT || B.getValueType() != VT)
1644416441
return SDValue();
1644516442

16446-
// Determine the instruction based on type and signedness
1644716443
unsigned Opc;
16448-
if (IsSignExt)
16449-
Opc = RISCVISD::PASUB;
16450-
else if (IsZeroExt)
16451-
Opc = RISCVISD::PASUBU;
16452-
else
16444+
switch (Op.getOpcode()) {
16445+
default:
1645316446
return SDValue();
16447+
case ISD::SUB:
16448+
// PASUB/PASUBU: shift amount must be 1
16449+
if (ShAmtVal != 1)
16450+
return SDValue();
16451+
if (LHSIsSExt && RHSIsSExt)
16452+
Opc = RISCVISD::PASUB;
16453+
else if (LHSIsZExt && RHSIsZExt)
16454+
Opc = RISCVISD::PASUBU;
16455+
else
16456+
return SDValue();
16457+
break;
16458+
case ISD::MUL:
16459+
// PMULHSU: shift amount must be element size, only for i16/i32
16460+
unsigned EltBits = VecVT.getScalarSizeInBits();
16461+
if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
16462+
return SDValue();
16463+
if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
16464+
Opc = RISCVISD::PMULHSU;
16465+
// commuted case
16466+
if (LHSIsZExt && RHSIsSExt)
16467+
std::swap(A, B);
16468+
} else
16469+
return SDValue();
16470+
break;
16471+
}
1645416472

16455-
// Create the machine node directly
1645616473
return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
1645716474
}
1645816475

llvm/lib/Target/RISCV/RISCVInstrInfoP.td

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,12 +1463,13 @@ let Predicates = [HasStdExtP, IsRV32] in {
14631463

14641464
def riscv_absw : RVSDNode<"ABSW", SDT_RISCVIntUnaryOpW>;
14651465

1466-
def SDT_RISCVPASUB : SDTypeProfile<1, 2, [SDTCisVec<0>,
1467-
SDTCisInt<0>,
1468-
SDTCisSameAs<0, 1>,
1469-
SDTCisSameAs<0, 2>]>;
1470-
def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPASUB>;
1471-
def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPASUB>;
1466+
def SDT_RISCVPBinOp : SDTypeProfile<1, 2, [SDTCisVec<0>,
1467+
SDTCisInt<0>,
1468+
SDTCisSameAs<0, 1>,
1469+
SDTCisSameAs<0, 2>]>;
1470+
def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
1471+
def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
1472+
def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
14721473

14731474
let Predicates = [HasStdExtP] in {
14741475
def : PatGpr<abs, ABS>;
@@ -1513,6 +1514,11 @@ let Predicates = [HasStdExtP] in {
15131514
def: Pat<(XLenVecI16VT (abds GPR:$rs1, GPR:$rs2)), (PABD_H GPR:$rs1, GPR:$rs2)>;
15141515
def: Pat<(XLenVecI16VT (abdu GPR:$rs1, GPR:$rs2)), (PABDU_H GPR:$rs1, GPR:$rs2)>;
15151516

1517+
// 16-bit multiply high patterns
1518+
def: Pat<(XLenVecI16VT (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_H GPR:$rs1, GPR:$rs2)>;
1519+
def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
1520+
def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
1521+
15161522
// 8-bit logical shift left/right patterns
15171523
def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
15181524
(PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1609,6 +1615,11 @@ let Predicates = [HasStdExtP, IsRV64] in {
16091615
def: Pat<(v2i32 (riscv_pasub GPR:$rs1, GPR:$rs2)), (PASUB_W GPR:$rs1, GPR:$rs2)>;
16101616
def: Pat<(v2i32 (riscv_pasubu GPR:$rs1, GPR:$rs2)), (PASUBU_W GPR:$rs1, GPR:$rs2)>;
16111617

1618+
// 32-bit multiply high patterns
1619+
def: Pat<(v2i32 (mulhs GPR:$rs1, GPR:$rs2)), (PMULH_W GPR:$rs1, GPR:$rs2)>;
1620+
def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
1621+
def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
1622+
16121623
// 32-bit logical shift left/right
16131624
def: Pat<(v2i32 (shl GPR:$rs1, (v2i32 (splat_vector (XLenVT GPR:$rs2))))),
16141625
(PSLL_WS GPR:$rs1, GPR:$rs2)>;

llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,3 +1040,81 @@ define void @test_psra_bs_vec_shamt(ptr %ret_ptr, ptr %a_ptr, ptr %shamt_ptr) {
10401040
store <4 x i8> %res, ptr %ret_ptr
10411041
ret void
10421042
}
1043+
; Test packed multiply high signed for v2i16
1044+
define void @test_pmulh_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1045+
; CHECK-LABEL: test_pmulh_h:
1046+
; CHECK: # %bb.0:
1047+
; CHECK-NEXT: lw a1, 0(a1)
1048+
; CHECK-NEXT: lw a2, 0(a2)
1049+
; CHECK-NEXT: pmulh.h a1, a1, a2
1050+
; CHECK-NEXT: sw a1, 0(a0)
1051+
; CHECK-NEXT: ret
1052+
%a = load <2 x i16>, ptr %a_ptr
1053+
%b = load <2 x i16>, ptr %b_ptr
1054+
%a_ext = sext <2 x i16> %a to <2 x i32>
1055+
%b_ext = sext <2 x i16> %b to <2 x i32>
1056+
%mul = mul <2 x i32> %a_ext, %b_ext
1057+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1058+
%res = trunc <2 x i32> %shift to <2 x i16>
1059+
store <2 x i16> %res, ptr %ret_ptr
1060+
ret void
1061+
}
1062+
1063+
; Test packed multiply high unsigned for v2i16
1064+
define void @test_pmulhu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1065+
; CHECK-LABEL: test_pmulhu_h:
1066+
; CHECK: # %bb.0:
1067+
; CHECK-NEXT: lw a1, 0(a1)
1068+
; CHECK-NEXT: lw a2, 0(a2)
1069+
; CHECK-NEXT: pmulhu.h a1, a1, a2
1070+
; CHECK-NEXT: sw a1, 0(a0)
1071+
; CHECK-NEXT: ret
1072+
%a = load <2 x i16>, ptr %a_ptr
1073+
%b = load <2 x i16>, ptr %b_ptr
1074+
%a_ext = zext <2 x i16> %a to <2 x i32>
1075+
%b_ext = zext <2 x i16> %b to <2 x i32>
1076+
%mul = mul <2 x i32> %a_ext, %b_ext
1077+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1078+
%res = trunc <2 x i32> %shift to <2 x i16>
1079+
store <2 x i16> %res, ptr %ret_ptr
1080+
ret void
1081+
}
1082+
1083+
; Test packed multiply high signed-unsigned for v2i16
1084+
define void @test_pmulhsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1085+
; CHECK-LABEL: test_pmulhsu_h:
1086+
; CHECK: # %bb.0:
1087+
; CHECK-NEXT: lw a1, 0(a1)
1088+
; CHECK-NEXT: lw a2, 0(a2)
1089+
; CHECK-NEXT: pmulhsu.h a1, a1, a2
1090+
; CHECK-NEXT: sw a1, 0(a0)
1091+
; CHECK-NEXT: ret
1092+
%a = load <2 x i16>, ptr %a_ptr
1093+
%b = load <2 x i16>, ptr %b_ptr
1094+
%a_ext = sext <2 x i16> %a to <2 x i32>
1095+
%b_ext = zext <2 x i16> %b to <2 x i32>
1096+
%mul = mul <2 x i32> %a_ext, %b_ext
1097+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1098+
%res = trunc <2 x i32> %shift to <2 x i16>
1099+
store <2 x i16> %res, ptr %ret_ptr
1100+
ret void
1101+
}
1102+
1103+
define void @test_pmulhsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
1104+
; CHECK-LABEL: test_pmulhsu_h_commuted:
1105+
; CHECK: # %bb.0:
1106+
; CHECK-NEXT: lw a1, 0(a1)
1107+
; CHECK-NEXT: lw a2, 0(a2)
1108+
; CHECK-NEXT: pmulhsu.h a1, a2, a1
1109+
; CHECK-NEXT: sw a1, 0(a0)
1110+
; CHECK-NEXT: ret
1111+
%a = load <2 x i16>, ptr %a_ptr
1112+
%b = load <2 x i16>, ptr %b_ptr
1113+
%a_ext = zext <2 x i16> %a to <2 x i32>
1114+
%b_ext = sext <2 x i16> %b to <2 x i32>
1115+
%mul = mul <2 x i32> %a_ext, %b_ext
1116+
%shift = lshr <2 x i32> %mul, <i32 16, i32 16>
1117+
%res = trunc <2 x i32> %shift to <2 x i16>
1118+
store <2 x i16> %res, ptr %ret_ptr
1119+
ret void
1120+
}

0 commit comments

Comments
 (0)