llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Pengcheng Wang (wangpc-pp)

<details>
<summary>Changes</summary>

Note that we only support SEW=8/16 for `vwabda(u)` (the widening absolute-difference-accumulate instructions added by this patch).


---
Full diff: https://github.com/llvm/llvm-project/pull/180162.diff


4 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+44) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+14-2) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td (+21-1) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll (+14-10) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp 
b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d46cb575c54c5..171fc391a7aa8 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18770,6 +18770,48 @@ static SDValue combineVWADDSUBWSelect(SDNode *N, 
SelectionDAG &DAG) {
                      N->getFlags());
 }
 
+// vwaddu C (vabd A B) -> vwabda(A B C)
+// vwaddu C (vabdu A B) -> vwabdau(A B C)
+static SDValue performVWABDACombine(SDNode *N, SelectionDAG &DAG,
+                                    const RISCVSubtarget &Subtarget) {
+  if (!Subtarget.hasStdExtZvabd())
+    return SDValue();
+
+  MVT VT = N->getSimpleValueType(0);
+  if (VT.getVectorElementType() != MVT::i8 &&
+      VT.getVectorElementType() != MVT::i16)
+    return SDValue();
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Passthru = N->getOperand(2);
+  if (!Passthru->isUndef())
+    return SDValue();
+
+  SDValue Mask = N->getOperand(3);
+  SDValue VL = N->getOperand(4);
+  auto IsABD = [](SDValue Op) {
+    if (Op->getOpcode() != RISCVISD::ABDS_VL &&
+        Op->getOpcode() != RISCVISD::ABDU_VL)
+      return SDValue();
+    return Op;
+  };
+
+  SDValue Diff = IsABD(Op0);
+  Diff = Diff ? Diff : IsABD(Op1);
+  if (!Diff)
+    return SDValue();
+  SDValue Acc = Diff == Op0 ? Op1 : Op0;
+
+  SDLoc DL(N);
+  Acc = DAG.getNode(RISCVISD::VZEXT_VL, DL, VT, Acc, Mask, VL);
+  SDValue Result = DAG.getNode(
+      Diff.getOpcode() == RISCVISD::ABDS_VL ? RISCVISD::VWABDA_VL
+                                            : RISCVISD::VWABDAU_VL,
+      DL, VT, Diff.getOperand(0), Diff.getOperand(1), Acc, Mask, VL);
+  return Result;
+}
+
 static SDValue performVWADDSUBW_VLCombine(SDNode *N,
                                           TargetLowering::DAGCombinerInfo &DCI,
                                           const RISCVSubtarget &Subtarget) {
@@ -21681,6 +21723,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode 
*N,
     if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
       return V;
     return combineToVWMACC(N, DAG, Subtarget);
+  case RISCVISD::VWADDU_VL:
+    return performVWABDACombine(N, DAG, Subtarget);
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
   case RISCVISD::VWSUB_W_VL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td 
b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 46dd45876a384..d1bcaffdeac5b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1750,8 +1750,9 @@ multiclass VPatMultiplyAddVL_VV_VX<SDNode op, string 
instruction_name> {
   }
 }
 
-multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name> {
-  foreach vtiTowti = AllWidenableIntVectors in {
+multiclass VPatWidenMultiplyAddVL_VV<SDNode vwmacc_op, string instr_name,
+                                     list<VTypeInfoToWide> vtilist = 
AllWidenableIntVectors> {
+  foreach vtiTowti = vtilist in {
     defvar vti = vtiTowti.Vti;
     defvar wti = vtiTowti.Wti;
     let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
@@ -1763,6 +1764,17 @@ multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode 
vwmacc_op, string instr_name> {
                 (!cast<Instruction>(instr_name#"_VV_"#vti.LMul.MX#"_MASK")
                     wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                     (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+    }
+  }
+}
+
+multiclass VPatWidenMultiplyAddVL_VV_VX<SDNode vwmacc_op, string instr_name>
+    : VPatWidenMultiplyAddVL_VV<vwmacc_op, instr_name> {
+  foreach vtiTowti = AllWidenableIntVectors in {
+    defvar vti = vtiTowti.Vti;
+    defvar wti = vtiTowti.Wti;
+    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                                 GetVTypePredicates<wti>.Predicates) in {
       def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1),
                            (vti.Vector vti.RegClass:$rs2),
                            (wti.Vector wti.RegClass:$rd),
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td 
b/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td
index 139372b70e590..46261d83711cc 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvabd.td
@@ -29,7 +29,6 @@ let Predicates = [HasStdExtZvabd] in {
 
//===----------------------------------------------------------------------===//
 // Pseudos
 
//===----------------------------------------------------------------------===//
-
 multiclass PseudoVABS {
   foreach m = MxList in {
     defvar mx = m.MX;
@@ -44,10 +43,23 @@ multiclass PseudoVABS {
   }
 }
 
+multiclass VPseudoVWABD_VV {
+  foreach m = MxListW in {
+    defvar mx = m.MX;
+    defm "" : VPseudoTernaryW_VV<m, Commutable = 1>,
+              SchedTernary<"WriteVIWMulAddV", "ReadVIWMulAddV",
+                           "ReadVIWMulAddV", "ReadVIWMulAddV", mx>;
+  }
+}
+
 let Predicates = [HasStdExtZvabd] in {
   defm PseudoVABS : PseudoVABS;
   defm PseudoVABD : VPseudoVALU_VV<Commutable = 1>;
   defm PseudoVABDU : VPseudoVALU_VV<Commutable = 1>;
+  let IsRVVWideningReduction = 1 in {
+    defm PseudoVWABDA : VPseudoVWABD_VV;
+    defm PseudoVWABDAU : VPseudoVWABD_VV;
+  } // IsRVVWideningReduction = 1
 } // Predicates = [HasStdExtZvabd]
 
 
//===----------------------------------------------------------------------===//
@@ -57,12 +69,17 @@ let HasPassthruOp = true, HasMaskOp = true in {
 def riscv_abs_vl  : RVSDNode<"ABS_VL", SDT_RISCVIntUnOp_VL>;
 def riscv_abds_vl : RVSDNode<"ABDS_VL", SDT_RISCVIntBinOp_VL, 
[SDNPCommutative]>;
 def riscv_abdu_vl : RVSDNode<"ABDU_VL", SDT_RISCVIntBinOp_VL, 
[SDNPCommutative]>;
+def rvv_vwabda_vl  : RVSDNode<"VWABDA_VL", SDT_RISCVVWIntTernOp_VL, 
[SDNPCommutative]>;
+def rvv_vwabdau_vl : RVSDNode<"VWABDAU_VL", SDT_RISCVVWIntTernOp_VL, 
[SDNPCommutative]>;
 } // let HasPassthruOp = true, HasMaskOp = true
 
 // These instructions are defined for SEW=8 and SEW=16, otherwise the 
instruction
 // encoding is reserved.
 defvar ABDIntVectors = !filter(vti, AllIntegerVectors, !or(!eq(vti.SEW, 8),
                                                            !eq(vti.SEW, 16)));
+defvar ABDAIntVectors = !filter(vtiTowti, AllWidenableIntVectors,
+                                          !or(!eq(vtiTowti.Vti.SEW, 8),
+                                              !eq(vtiTowti.Vti.SEW, 16)));
 
 let Predicates = [HasStdExtZvabd] in {
 defm : VPatBinarySDNode_VV<abds, "PseudoVABD", ABDIntVectors>;
@@ -79,4 +96,7 @@ foreach vti = AllIntegerVectors in {
 }
 
 defm : VPatUnaryVL_V<riscv_abs_vl, "PseudoVABS", HasStdExtZvabd>;
+
+defm : VPatWidenMultiplyAddVL_VV<rvv_vwabda_vl, "PseudoVWABDA", 
ABDAIntVectors>;
+defm : VPatWidenMultiplyAddVL_VV<rvv_vwabdau_vl, "PseudoVWABDAU", 
ABDAIntVectors>;
 } // Predicates = [HasStdExtZvabd]
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll 
b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
index 9f6c34cb052ff..dcb8b31c682b3 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll
@@ -199,16 +199,18 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr 
%b, i32 signext %stridea
 ; ZVABD-NEXT:    vle8.v v15, (a1)
 ; ZVABD-NEXT:    add a0, a0, a2
 ; ZVABD-NEXT:    add a1, a1, a3
+; ZVABD-NEXT:    vle8.v v16, (a0)
+; ZVABD-NEXT:    vle8.v v17, (a1)
 ; ZVABD-NEXT:    vabdu.vv v8, v8, v9
-; ZVABD-NEXT:    vle8.v v9, (a0)
-; ZVABD-NEXT:    vabdu.vv v10, v10, v11
-; ZVABD-NEXT:    vle8.v v11, (a1)
-; ZVABD-NEXT:    vwaddu.vv v12, v10, v8
+; ZVABD-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; ZVABD-NEXT:    vzext.vf2 v12, v8
+; ZVABD-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
+; ZVABD-NEXT:    vwabdau.vv v12, v10, v11
 ; ZVABD-NEXT:    vabdu.vv v8, v14, v15
 ; ZVABD-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
 ; ZVABD-NEXT:    vzext.vf2 v14, v8
 ; ZVABD-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
-; ZVABD-NEXT:    vabdu.vv v16, v9, v11
+; ZVABD-NEXT:    vabdu.vv v16, v16, v17
 ; ZVABD-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
 ; ZVABD-NEXT:    vwaddu.vv v8, v14, v12
 ; ZVABD-NEXT:    vzext.vf2 v12, v16
@@ -320,16 +322,18 @@ define signext i32 @sadu_2block_16xi8_as_i32(ptr %a, ptr 
%b, i32 signext %stride
 ; ZVABD-NEXT:    vle8.v v15, (a1)
 ; ZVABD-NEXT:    add a0, a0, a2
 ; ZVABD-NEXT:    add a1, a1, a3
+; ZVABD-NEXT:    vle8.v v16, (a0)
+; ZVABD-NEXT:    vle8.v v17, (a1)
 ; ZVABD-NEXT:    vabd.vv v8, v8, v9
-; ZVABD-NEXT:    vle8.v v9, (a0)
-; ZVABD-NEXT:    vabd.vv v10, v10, v11
-; ZVABD-NEXT:    vle8.v v11, (a1)
-; ZVABD-NEXT:    vwaddu.vv v12, v10, v8
+; ZVABD-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; ZVABD-NEXT:    vzext.vf2 v12, v8
+; ZVABD-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
+; ZVABD-NEXT:    vwabda.vv v12, v10, v11
 ; ZVABD-NEXT:    vabd.vv v8, v14, v15
 ; ZVABD-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
 ; ZVABD-NEXT:    vzext.vf2 v14, v8
 ; ZVABD-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
-; ZVABD-NEXT:    vabd.vv v16, v9, v11
+; ZVABD-NEXT:    vabd.vv v16, v16, v17
 ; ZVABD-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
 ; ZVABD-NEXT:    vwaddu.vv v8, v14, v12
 ; ZVABD-NEXT:    vzext.vf2 v12, v16

``````````

</details>


https://github.com/llvm/llvm-project/pull/180162
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to