@@ -15265,18 +15265,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
1526515265 break;
1526615266 }
1526715267 case RISCVISD::PASUB:
15268- case RISCVISD::PASUBU: {
15268+ case RISCVISD::PASUBU:
15269+ case RISCVISD::PMULHSU: {
1526915270 MVT VT = N->getSimpleValueType(0);
1527015271 SDValue Op0 = N->getOperand(0);
1527115272 SDValue Op1 = N->getOperand(1);
15272- assert(VT == MVT::v2i16 || VT == MVT::v4i8);
15273+ unsigned Opcode = N->getOpcode();
15274+ // PMULHSU doesn't support i8 variants
15275+ assert(VT == MVT::v2i16 ||
15276+ (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
1527315277 MVT NewVT = MVT::v4i16;
1527415278 if (VT == MVT::v4i8)
1527515279 NewVT = MVT::v8i8;
1527615280 SDValue Undef = DAG.getUNDEF(VT);
1527715281 Op0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op0, Undef});
1527815282 Op1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, NewVT, {Op1, Undef});
15279- Results.push_back(DAG.getNode(N->getOpcode() , DL, NewVT, {Op0, Op1}));
15283+ Results.push_back(DAG.getNode(Opcode , DL, NewVT, {Op0, Op1}));
1528015284 return;
1528115285 }
1528215286 case ISD::EXTRACT_VECTOR_ELT: {
@@ -16386,9 +16390,9 @@ static SDValue combineTruncSelectToSMaxUSat(SDNode *N, SelectionDAG &DAG) {
1638616390 return DAG.getNode(ISD::TRUNCATE, DL, VT, Min);
1638716391}
1638816392
16389- // Handle P extension averaging subtraction pattern :
16390- // (vXiY (trunc (srl (sub ([s|z]ext vXiY:$ a), ([s|z]ext vXiY:$ b)), 1) ))
16391- // -> PASUB/PASUBU
16393+ // Handle P extension truncate patterns:
16394+ // PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
16395+ // PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
1639216396static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1639316397 const RISCVSubtarget &Subtarget) {
1639416398 SDValue N0 = N->getOperand(0);
@@ -16401,7 +16405,7 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1640116405 VecVT != MVT::v4i8 && VecVT != MVT::v2i32)
1640216406 return SDValue();
1640316407
16404- // Check if shift amount is 1
16408+ // Check if shift amount is a splat constant
1640516409 SDValue ShAmt = N0.getOperand(1);
1640616410 if (ShAmt.getOpcode() != ISD::BUILD_VECTOR)
1640716411 return SDValue();
@@ -16415,44 +16419,57 @@ static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
1641516419 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Splat);
1641616420 if (!C)
1641716421 return SDValue();
16418- if (C->getZExtValue() != 1)
16419- return SDValue();
1642016422
16421- // Check for SUB operation
16422- SDValue Sub = N0.getOperand(0);
16423- if (Sub.getOpcode() != ISD::SUB)
16424- return SDValue();
16423+ SDValue Op = N0.getOperand(0);
16424+ unsigned ShAmtVal = C->getZExtValue();
1642516425
16426- SDValue LHS = Sub .getOperand(0);
16427- SDValue RHS = Sub .getOperand(1);
16426+ SDValue LHS = Op .getOperand(0);
16427+ SDValue RHS = Op .getOperand(1);
1642816428
16429- // Check if both operands are sign/zero extends from the target
16430- // type
16431- bool IsSignExt = LHS.getOpcode() == ISD::SIGN_EXTEND &&
16432- RHS.getOpcode() == ISD::SIGN_EXTEND;
16433- bool IsZeroExt = LHS.getOpcode() == ISD::ZERO_EXTEND &&
16434- RHS.getOpcode() == ISD::ZERO_EXTEND;
16429+ bool LHSIsSExt = LHS.getOpcode() == ISD::SIGN_EXTEND;
16430+ bool LHSIsZExt = LHS.getOpcode() == ISD::ZERO_EXTEND;
16431+ bool RHSIsSExt = RHS.getOpcode() == ISD::SIGN_EXTEND;
16432+ bool RHSIsZExt = RHS.getOpcode() == ISD::ZERO_EXTEND;
1643516433
16436- if (!IsSignExt && !IsZeroExt )
16434+ if (!(LHSIsSExt || LHSIsZExt) || !(RHSIsSExt || RHSIsZExt) )
1643716435 return SDValue();
1643816436
1643916437 SDValue A = LHS.getOperand(0);
1644016438 SDValue B = RHS.getOperand(0);
1644116439
16442- // Check if the extends are from our target vector type
1644316440 if (A.getValueType() != VT || B.getValueType() != VT)
1644416441 return SDValue();
1644516442
16446- // Determine the instruction based on type and signedness
1644716443 unsigned Opc;
16448- if (IsSignExt)
16449- Opc = RISCVISD::PASUB;
16450- else if (IsZeroExt)
16451- Opc = RISCVISD::PASUBU;
16452- else
16444+ switch (Op.getOpcode()) {
16445+ default:
1645316446 return SDValue();
16447+ case ISD::SUB:
16448+ // PASUB/PASUBU: shift amount must be 1
16449+ if (ShAmtVal != 1)
16450+ return SDValue();
16451+ if (LHSIsSExt && RHSIsSExt)
16452+ Opc = RISCVISD::PASUB;
16453+ else if (LHSIsZExt && RHSIsZExt)
16454+ Opc = RISCVISD::PASUBU;
16455+ else
16456+ return SDValue();
16457+ break;
16458+ case ISD::MUL:
16459+ // PMULHSU: shift amount must be element size, only for i16/i32
16460+ unsigned EltBits = VecVT.getScalarSizeInBits();
16461+ if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
16462+ return SDValue();
16463+ if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
16464+ Opc = RISCVISD::PMULHSU;
16465+ // commuted case
16466+ if (LHSIsZExt && RHSIsSExt)
16467+ std::swap(A, B);
16468+ } else
16469+ return SDValue();
16470+ break;
16471+ }
1645416472
16455- // Create the machine node directly
1645616473 return DAG.getNode(Opc, SDLoc(N), VT, {A, B});
1645716474}
1645816475
0 commit comments