https://gcc.gnu.org/g:8fbea0880f7a4082203dd4ac32596e3970d5f7d9

commit 8fbea0880f7a4082203dd4ac32596e3970d5f7d9
Author: Michael Meissner <[email protected]>
Date:   Mon Oct 13 12:41:28 2025 -0400

    Rework bfloat16 to v4sf optimization.
    
    2025-10-13  Michael Meissner  <[email protected]>
    
    gcc/
    
            * config/rs6000/float16.cc (bfloat16_operation_as_v4sf): Rewrite
            bfloat16_binary_op_as_v4sf so it will be able to handle FMA
            operations in the future.
            * config/rs6000/float16.md (bfloat16_binary_op_internal1): Likewise.
            (bfloat16_binary_op_internal2): Likewise.
            (bfloat16_binary_op_internal3): Likewise.
            (bfloat16_binary_op_internal4): Likewise.
            (bfloat16_binary_op_internal5): Likewise.
            (bfloat16_binary_op_internal6): Likewise.
            * config/rs6000/rs6000-protos.h (enum bfloat16_operation): New
            enumeration.
            (bfloat16_binary_op_as_v4sf): Delete.
            (bfloat16_operation_as_v4sf): New declaration.
            * config/rs6000/vsx.md (vsx_fmav4sf4): Add generator.
            (vsx_fms<mode>4): Likewise.
            (vsx_nfma<mode>4): Likewise.
            (vsx_nfmsv4sf4): Likewise.

Diff:
---
 gcc/config/rs6000/float16.cc      | 185 ++++++++++++++++++++++----------------
 gcc/config/rs6000/float16.md      |  84 +++++++----------
 gcc/config/rs6000/rs6000-protos.h |  13 ++-
 gcc/config/rs6000/vsx.md          |   8 +-
 4 files changed, 151 insertions(+), 139 deletions(-)
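
In short, the old bfloat16_binary_op_as_v4sf required the caller to pass three
V4SFmode scratch temporaries; the reworked bfloat16_operation_as_v4sf allocates
its own pseudos and instead takes an optional third input operand plus a
subtype enumerator.  A minimal sketch of the plain binary call shape, assuming
placeholder rtx values dest, src1 and src2 (illustrative names, not from the
patch):

  /* BFmode/SFmode binary operation done in V4SFmode; the third input is
     NULL_RTX and BF16_BINARY selects the two-operand path.  */
  bfloat16_operation_as_v4sf (PLUS, dest, src1, src2,
                              NULL_RTX, BF16_BINARY);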

diff --git a/gcc/config/rs6000/float16.cc b/gcc/config/rs6000/float16.cc
index 0d606609dab3..484d04f4ddb4 100644
--- a/gcc/config/rs6000/float16.cc
+++ b/gcc/config/rs6000/float16.cc
@@ -42,15 +42,14 @@
 #include "common/common-target.h"
 #include "rs6000-internal.h"
 
-/* Expand a bfloat16 floating point binary operation:
+/* Expand a bfloat16 floating point operation:
 
-   ICODE: Operation to perform.
-   OP0:   Result (BFmode or SFmode).
-   OP1:   First input argument (BFmode or SFmode).
-   OP2:   Second input argument (BFmode or SFmode).
-   TMP0:  Temporary for result (V4SFmode).
-   TMP1:  Temporary for first input argument (V4SFmode).
-   TMP2:  Temporary for second input argument (V4SFmode).
+   ICODE:   Operation to perform.
+   RESULT:  Result of the operation.
+   OP1:     Input operand1.
+   OP2:     Input operand2.
+   OP3:     Input operand3 or NULL_RTX.
+   SUBTYPE: Describe the operation.
 
    The operation is done as a V4SFmode vector operation.  This is because
    converting BFmode from a scalar BFmode to SFmode to do the operation and
@@ -60,108 +59,136 @@
    SFmode.  */
        
 void
-bfloat16_binary_op_as_v4sf (enum rtx_code icode,
-                           rtx op0,
+bfloat16_operation_as_v4sf (enum rtx_code icode,
+                           rtx result,
                            rtx op1,
                            rtx op2,
-                           rtx tmp0,
-                           rtx tmp1,
-                           rtx tmp2)
+                           rtx op3,
+                           enum bfloat16_operation subtype)
 {
-  if (GET_CODE (tmp0) == SCRATCH)
-    tmp0 = gen_reg_rtx (V4SFmode);
+  gcc_assert (can_create_pseudo_p ());
 
-  if (GET_CODE (tmp1) == SCRATCH)
-    tmp1 = gen_reg_rtx (V4SFmode);
+  rtx result_v4sf = gen_reg_rtx (V4SFmode);
+  rtx ops_bf[3];
+  rtx ops_v4sf[3];
+  size_t n_opts;
 
-  if (GET_CODE (tmp2) == SCRATCH)
-    tmp2 = gen_reg_rtx (V4SFmode);
-
-  /* Convert operand1 and operand2 to V4SFmode format.  We use SPLAT for
-     registers to get the value into the upper 32-bits.  We can use XXSPLTW
-     to splat words instead of VSPLTIH since the XVCVBF16SPN instruction
-     ignores the odd half-words, and XXSPLTW can operate on all VSX registers
-     instead of just the Altivec registers.  Using SPLAT instead of a shift
-     also insure that other bits are not a signalling NaN.  If we are using
-     XXSPLTIW or XXSPLTIB to load the constant the other bits are duplicated.  */
-
-  /* Operand1.  */
-  if (GET_MODE (op1) == BFmode)
+  switch (subtype)
     {
-      emit_insn (gen_xxspltw_bf (tmp1, op1));
-      emit_insn (gen_xvcvbf16spn_bf (tmp1, tmp1));
+    case BF16_BINARY:
+      n_opts = 2;
+      ops_bf[0] = op1;
+      ops_bf[1] = op2;
+      gcc_assert (op3 == NULL_RTX);
+      break;
+
+    case BF16_FMA:
+    case BF16_FMS:
+    case BF16_NFMA:
+    case BF16_NFMS:
+      gcc_assert (icode == FMA);
+      n_opts = 3;
+      ops_bf[0] = op1;
+      ops_bf[1] = op2;
+      ops_bf[2] = op3;
+      break;
+
+    default:
+      gcc_unreachable ();
     }
 
-  else if (GET_MODE (op1) == SFmode)
-    emit_insn (gen_vsx_splat_v4sf (tmp1,
-                                  force_reg (SFmode, op1)));
-
-  else
-    gcc_unreachable ();
-
-  /* Operand2.  */
-  if (GET_MODE (op2) == BFmode)
+  for (size_t i = 0; i < n_opts; i++)
     {
-      if (REG_P (op2) || SUBREG_P (op2))
-       emit_insn (gen_xxspltw_bf (tmp2, op2));
+      rtx op = ops_bf[i];
+      rtx tmp = ops_v4sf[i] = gen_reg_rtx (V4SFmode);
+
+      gcc_assert (op != NULL_RTX);
 
-      else if (op2 == CONST0_RTX (BFmode))
-       emit_move_insn (tmp2, CONST0_RTX (V4SFmode));
+      /* Convert operands to V4SFmode format.  We use SPLAT for registers to
+        get the value into the upper 32-bits.  We can use XXSPLTW to splat
+        words instead of VSPLTIH since the XVCVBF16SPN instruction ignores the
+        odd half-words, and XXSPLTW can operate on all VSX registers instead
+        of just the Altivec registers.  Using SPLAT instead of a shift also
+        ensures that other bits are not a signalling NaN.  If we are using
+        XXSPLTIW or XXSPLTIB to load the constant the other bits are
+        duplicated.  */
 
-      else if (fp16_xxspltiw_constant (op2, BFmode))
+      if (GET_MODE (op) == BFmode)
        {
-         rtx op2_bf = gen_lowpart (BFmode, tmp2);
-         emit_move_insn (op2_bf, op2);
+         emit_insn (gen_xxspltw_bf (tmp, op));
+         emit_insn (gen_xvcvbf16spn_bf (tmp, tmp));
        }
 
-      else
-       gcc_unreachable ();
+      else if (op == CONST0_RTX (SFmode)
+              || op == CONST0_RTX (BFmode))
+       emit_move_insn (tmp, CONST0_RTX (V4SFmode));
 
-      emit_insn (gen_xvcvbf16spn_bf (tmp2, tmp2));
-    }
+      else if (GET_MODE (op) == SFmode)
+       {
+         if (GET_CODE (op) == CONST_DOUBLE)
+           {
+             rtvec v = rtvec_alloc (4);
 
-  else if (GET_MODE (op2) == SFmode)
-    {
-      if (REG_P (op2) || SUBREG_P (op2))
-       emit_insn (gen_vsx_splat_v4sf (tmp2, op2));
+             for (size_t i = 0; i < 4; i++)
+               RTVEC_ELT (v, i) = op;
 
-      else if (op2 == CONST0_RTX (SFmode))
-       emit_move_insn (tmp2, CONST0_RTX (V4SFmode));
+             emit_insn (gen_rtx_SET (tmp,
+                                     gen_rtx_CONST_VECTOR (V4SFmode, v)));
+           }
 
-      else if (GET_CODE (op2) == CONST_DOUBLE)
-       {
-         rtvec v = rtvec_alloc (4);
-         RTVEC_ELT (v, 0) = op2;
-         RTVEC_ELT (v, 1) = op2;
-         RTVEC_ELT (v, 2) = op2;
-         RTVEC_ELT (v, 3) = op2;
-         emit_insn (gen_rtx_SET (tmp2,
-                                 gen_rtx_CONST_VECTOR (V4SFmode, v)));
+         else
+           emit_insn (gen_vsx_splat_v4sf (tmp,
+                                          force_reg (SFmode, op)));
        }
 
       else
-       emit_insn (gen_vsx_splat_v4sf (tmp2,
-                                      force_reg (SFmode, op2)));
+       gcc_unreachable ();
     }
 
-  else
-    gcc_unreachable ();
-
   /* Do the operation in V4SFmode.  */
-  emit_insn (gen_rtx_SET (tmp0,
-                         gen_rtx_fmt_ee (icode, V4SFmode, tmp1, tmp2)));
+  switch (subtype)
+    {
+    case BF16_BINARY:
+      emit_insn (gen_rtx_SET (result_v4sf,
+                             gen_rtx_fmt_ee (icode, V4SFmode,
+                                             ops_v4sf[0],
+                                             ops_v4sf[1])));
+      break;
+
+    case BF16_FMA:
+      emit_insn (gen_vsx_fmav4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                  ops_v4sf[2]));
+      break;
+
+    case BF16_FMS:
+      emit_insn (gen_vsx_fmsv4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                  ops_v4sf[2]));
+      break;
+
+    case BF16_NFMA:
+      emit_insn (gen_vsx_nfmav4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                   ops_v4sf[2]));
+      break;
+
+    case BF16_NFMS:
+      emit_insn (gen_vsx_nfmsv4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1],
+                                   ops_v4sf[2]));
+      break;
+
+    default:
+      gcc_unreachable ();
+    }
 
   /* Convert V4SF result back to scalar mode.  */
-  if (GET_MODE (op0) == BFmode)
-    emit_insn (gen_xvcvspbf16_bf (op0, tmp0));
+  if (GET_MODE (result) == BFmode)
+    emit_insn (gen_xvcvspbf16_bf (result, result_v4sf));
 
-  else if (GET_MODE (op0) == SFmode)
+  else if (GET_MODE (result) == SFmode)
     {
       rtx element = GEN_INT (WORDS_BIG_ENDIAN ? 2 : 3);
-      emit_insn (gen_vsx_extract_v4sf (op0, tmp0, element));
+      emit_insn (gen_vsx_extract_v4sf (result, result_v4sf, element));
     }
 
   else
     gcc_unreachable ();
 }
-
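
Seen end to end, a BFmode add of two BFmode registers with a BFmode result now
expands to roughly the following sequence (a conceptual trace of the new code
above, not literal compiler output; t1, t2 and res stand for the pseudos the
function allocates itself):

  rtx t1 = gen_reg_rtx (V4SFmode);
  rtx t2 = gen_reg_rtx (V4SFmode);
  rtx res = gen_reg_rtx (V4SFmode);

  emit_insn (gen_xxspltw_bf (t1, op1));         /* splat the word holding the bf16 value */
  emit_insn (gen_xvcvbf16spn_bf (t1, t1));      /* widen bfloat16 -> single precision    */
  emit_insn (gen_xxspltw_bf (t2, op2));
  emit_insn (gen_xvcvbf16spn_bf (t2, t2));
  emit_insn (gen_rtx_SET (res, gen_rtx_fmt_ee (PLUS, V4SFmode, t1, t2)));
  emit_insn (gen_xvcvspbf16_bf (result, res));  /* narrow the V4SF result back to BFmode */
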
diff --git a/gcc/config/rs6000/float16.md b/gcc/config/rs6000/float16.md
index bab03ffddb6e..3715bde0df03 100644
--- a/gcc/config/rs6000/float16.md
+++ b/gcc/config/rs6000/float16.md
@@ -450,22 +450,18 @@
         [(float_extend:SF
           (match_operand:BF 2 "vsx_register_operand" "wa"))
          (float_extend:SF
-          (match_operand:BF 3 "vsx_register_operand" "wa"))]))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+          (match_operand:BF 3 "vsx_register_operand" "wa"))]))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                              operands[0],
                              operands[2],
                              operands[3],
-                             operands[4],
-                             operands[5],
-                             operands[6]);
+                             NULL_RTX,
+                             BF16_BINARY);
   DONE;
 })
 
@@ -476,22 +472,18 @@
          [(float_extend:SF
            (match_operand:BF 2 "vsx_register_operand" "wa"))
           (float_extend:SF
-           (match_operand:BF 3 "vsx_register_operand" "wa"))])))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+           (match_operand:BF 3 "vsx_register_operand" "wa"))])))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                              operands[0],
                              operands[2],
                              operands[3],
-                             operands[4],
-                             operands[5],
-                             operands[6]);
+                             NULL_RTX,
+                             BF16_BINARY);
   DONE;
 })
 
@@ -500,22 +492,18 @@
        (match_operator:SF 1 "bfloat16_binary_operator"
         [(float_extend:SF
           (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
-         (match_operand:SF 3 "input_operand" "wa,j,eP")]))
-   (clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
-  "TARGET_BFLOAT16_HW"
+         (match_operand:SF 3 "input_operand" "wa,j,eP")]))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                              operands[0],
                              operands[2],
                              operands[3],
-                             operands[4],
-                             operands[5],
-                             operands[6]);
+                             NULL_RTX,
+                             BF16_BINARY);
   DONE;
 })
 
@@ -525,22 +513,18 @@
         (match_operator:SF 1 "bfloat16_binary_operator"
          [(float_extend:SF
            (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa"))
-          (match_operand:SF 3 "input_operand" "wa,j,eP")])))
-   (clobber (match_scratch:V4SF 4 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa,&wa,&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa,&wa,&wa"))]
-  "TARGET_BFLOAT16_HW"
+          (match_operand:SF 3 "input_operand" "wa,j,eP")])))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                              operands[0],
                              operands[2],
                              operands[3],
-                             operands[4],
-                             operands[5],
-                             operands[6]);
+                             NULL_RTX,
+                             BF16_BINARY);
   DONE;
 })
 
@@ -549,22 +533,18 @@
        (match_operator:SF 1 "bfloat16_binary_operator"
         [(match_operand:SF 2 "vsx_register_operand" "wa")
          (float_extend:SF
-          (match_operand:BF 3 "vsx_register_operand" "wa"))]))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+          (match_operand:BF 3 "vsx_register_operand" "wa"))]))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                              operands[0],
                              operands[2],
                              operands[3],
-                             operands[4],
-                             operands[5],
-                             operands[6]);
+                             NULL_RTX,
+                             BF16_BINARY);
   DONE;
 })
 
@@ -574,22 +554,18 @@
         (match_operator:SF 1 "bfloat16_binary_operator"
          [(match_operand:SF 3 "vsx_register_operand" "wa")
           (float_extend:SF
-           (match_operand:BF 2 "vsx_register_operand" "wa"))])))
-   (clobber (match_scratch:V4SF 4 "=&wa"))
-   (clobber (match_scratch:V4SF 5 "=&wa"))
-   (clobber (match_scratch:V4SF 6 "=&wa"))]
-  "TARGET_BFLOAT16_HW"
+           (match_operand:BF 2 "vsx_register_operand" "wa"))])))]
+  "TARGET_BFLOAT16_HW && can_create_pseudo_p ()"
   "#"
   "&& 1"
   [(pc)]
 {
-  bfloat16_binary_op_as_v4sf (GET_CODE (operands[1]),
+  bfloat16_operation_as_v4sf (GET_CODE (operands[1]),
                              operands[0],
                              operands[2],
                              operands[3],
-                             operands[4],
-                             operands[5],
-                             operands[6]);
+                             NULL_RTX,
+                             BF16_BINARY);
   DONE;
 })
 
diff --git a/gcc/config/rs6000/rs6000-protos.h b/gcc/config/rs6000/rs6000-protos.h
index 063f74f6e3f6..db38468df816 100644
--- a/gcc/config/rs6000/rs6000-protos.h
+++ b/gcc/config/rs6000/rs6000-protos.h
@@ -260,8 +260,17 @@ extern unsigned constant_generates_xxspltiw (vec_const_128bit_type *);
 extern unsigned constant_generates_xxspltidp (vec_const_128bit_type *);
 
 /* From float16.cc.  */
-extern void bfloat16_binary_op_as_v4sf (enum rtx_code, rtx, rtx, rtx,
-                                       rtx, rtx, rtx);
+/* Optimize bfloat16 operations.  */
+enum bfloat16_operation {
+  BF16_BINARY,                         /* Bfloat16 binary op.  */
+  BF16_FMA,                            /* (a * b) + c.  */
+  BF16_FMS,                            /* (a * b) - c.  */
+  BF16_NFMA,                           /* - ((a * b) + c).  */
+  BF16_NFMS                            /* - ((a * b) - c).  */
+};
+
+extern void bfloat16_operation_as_v4sf (enum rtx_code, rtx, rtx, rtx, rtx,
+                                       enum bfloat16_operation);
 #endif /* RTX_CODE */
 
 #ifdef TREE_CODE
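
The new enumeration is what opens the door to the FMA support mentioned in the
commit message: each BF16_FMA/BF16_FMS/BF16_NFMA/BF16_NFMS subtype is routed to
the matching vsx.md generator (gen_vsx_fmav4sf4, gen_vsx_fmsv4sf4,
gen_vsx_nfmav4sf4, gen_vsx_nfmsv4sf4) in the V4SFmode step.  A hypothetical
future split body in float16.md (not part of this patch) could then be as small
as:

  /* Hypothetical future split body for a BFmode fused multiply-add;
     operands[1..3] would be the inputs, operands[0] the result.  */
  bfloat16_operation_as_v4sf (FMA, operands[0], operands[1], operands[2],
                              operands[3], BF16_FMA);
  DONE;
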
diff --git a/gcc/config/rs6000/vsx.md b/gcc/config/rs6000/vsx.md
index 6c11d7766ed1..2611660921a5 100644
--- a/gcc/config/rs6000/vsx.md
+++ b/gcc/config/rs6000/vsx.md
@@ -2098,7 +2098,7 @@
 ;; vmaddfp and vnmsubfp can have different behaviors than the VSX instructions
 ;; in some corner cases due to VSCR[NJ] being set or if the addend is +0.0
 ;; instead of -0.0.
-(define_insn "*vsx_fmav4sf4"
+(define_insn "vsx_fmav4sf4"
   [(set (match_operand:V4SF 0 "vsx_register_operand" "=wa,wa")
        (fma:V4SF
          (match_operand:V4SF 1 "vsx_register_operand" "%wa,wa")
@@ -2122,7 +2122,7 @@
    xvmaddmdp %x0,%x1,%x3"
   [(set_attr "type" "vecdouble")])
 
-(define_insn "*vsx_fms<mode>4"
+(define_insn "vsx_fms<mode>4"
   [(set (match_operand:VSX_F 0 "vsx_register_operand" "=wa,wa")
        (fma:VSX_F
          (match_operand:VSX_F 1 "vsx_register_operand" "%wa,wa")
@@ -2135,7 +2135,7 @@
    xvmsubm<sd>p %x0,%x1,%x3"
   [(set_attr "type" "<VStype_mul>")])
 
-(define_insn "*vsx_nfma<mode>4"
+(define_insn "vsx_nfma<mode>4"
   [(set (match_operand:VSX_F 0 "vsx_register_operand" "=wa,wa")
        (neg:VSX_F
         (fma:VSX_F
@@ -2148,7 +2148,7 @@
    xvnmaddm<sd>p %x0,%x1,%x3"
   [(set_attr "type" "<VStype_mul>")])
 
-(define_insn "*vsx_nfmsv4sf4"
+(define_insn "vsx_nfmsv4sf4"
   [(set (match_operand:V4SF 0 "vsx_register_operand" "=wa,wa")
        (neg:V4SF
         (fma:V4SF
