https://gcc.gnu.org/g:4d1b61eb2bea463bbca1998cf259be5ab4cd1833
commit 4d1b61eb2bea463bbca1998cf259be5ab4cd1833 Author: Michael Meissner <[email protected]> Date: Mon Oct 13 14:44:48 2025 -0400 Add bfloat16 fma optimizations and rework bfloat16 binary optimizations. 2025-10-13 Michael Meissner <[email protected]> gcc/ * config/rs6000/float16.cc (bfloat16_operation_as_v4sf): Iterate code. * config/rs6000/float16.md (bfloat16_binary_op_internal1): Rework bfloat16 optimization code. (bfloat16_binary_op_internal2): Likewise. (bfloat16_binary_op_internal3): Likewise. (bfloat16_binary_op_internal4): Likewise. (bfloat16_binary_op_internal5): Likewise. (bfloat16_fma_internal1): Likewise. (bfloat16_fma_internal2): Likewise. (bfloat16_fms_internal1): Likewise. (bfloat16_fms_interna2): Likewise. (bfloat16_nfma_internal1): Likewise. (bfloat16_nfma_internal2): Likewise. (bfloat16_nfma_internal3): Likewise. (bfloat16_nfms_internal1): Likewise. (bfloat16_nfms_internal2): Likewise. (bfloat16_nfms_internal3): Likewise. * config/rs6000/predicates.md (bfloat16_v4sf_operand): New predicate. (bfloat16_bf_operand): Likewise. * config/rs6000/vsx.md (vsx_fmav4sf4): Remove generator, added in the last change. (vsx_fms<mode>4): Likewise. (vsx_nfma<mode>4): Likewise. (vsx_nfmsv4sf4): Likewise. 
Diff: --- gcc/config/rs6000/float16.cc | 66 ++++------ gcc/config/rs6000/float16.md | 267 ++++++++++++++++++++++++++++++++-------- gcc/config/rs6000/predicates.md | 47 +++++++ gcc/config/rs6000/vsx.md | 8 +- 4 files changed, 294 insertions(+), 94 deletions(-) diff --git a/gcc/config/rs6000/float16.cc b/gcc/config/rs6000/float16.cc index 484d04f4ddb4..3dc7273719c1 100644 --- a/gcc/config/rs6000/float16.cc +++ b/gcc/config/rs6000/float16.cc @@ -69,7 +69,7 @@ bfloat16_operation_as_v4sf (enum rtx_code icode, gcc_assert (can_create_pseudo_p ()); rtx result_v4sf = gen_reg_rtx (V4SFmode); - rtx ops_bf[3]; + rtx ops_orig[3] = { op1, op2, op3 }; rtx ops_v4sf[3]; size_t n_opts; @@ -77,8 +77,6 @@ bfloat16_operation_as_v4sf (enum rtx_code icode, { case BF16_BINARY: n_opts = 2; - ops_bf[0] = op1; - ops_bf[1] = op2; gcc_assert (op3 == NULL_RTX); break; @@ -88,9 +86,6 @@ bfloat16_operation_as_v4sf (enum rtx_code icode, case BF16_NFMS: gcc_assert (icode == FMA); n_opts = 3; - ops_bf[0] = op1; - ops_bf[1] = op2; - ops_bf[3] = op3; break; default: @@ -99,11 +94,15 @@ bfloat16_operation_as_v4sf (enum rtx_code icode, for (size_t i = 0; i < n_opts; i++) { - rtx op = ops_bf[i]; + rtx op = ops_orig[i]; rtx tmp = ops_v4sf[i] = gen_reg_rtx (V4SFmode); gcc_assert (op != NULL_RTX); + /* Remove truncation/extend added. */ + if (GET_CODE (op) == FLOAT_EXTEND || GET_CODE (op) == FLOAT_TRUNCATE) + op = XEXP (op, 0); + /* Convert operands to V4SFmode format. We use SPLAT for registers to get the value into the upper 32-bits. We can use XXSPLTW to splat words instead of VSPLTIH since the XVCVBF16SPN instruction ignores the @@ -113,16 +112,15 @@ bfloat16_operation_as_v4sf (enum rtx_code icode, XXSPLTIW or XXSPLTIB to load the constant the other bits are duplicated. 
*/ - if (GET_MODE (op) == BFmode) + if (op == CONST0_RTX (SFmode) || op == CONST0_RTX (BFmode)) + emit_move_insn (tmp, CONST0_RTX (V4SFmode)); + + else if (GET_MODE (op) == BFmode) { - emit_insn (gen_xxspltw_bf (tmp, op)); + emit_insn (gen_xxspltw_bf (tmp, force_reg (BFmode, op))); emit_insn (gen_xvcvbf16spn_bf (tmp, tmp)); } - else if (op == CONST0_RTX (SFmode) - || op == CONST0_RTX (BFmode)) - emit_move_insn (tmp, CONST0_RTX (V4SFmode)); - else if (GET_MODE (op) == SFmode) { if (GET_CODE (op) == CONST_DOUBLE) @@ -146,37 +144,27 @@ bfloat16_operation_as_v4sf (enum rtx_code icode, } /* Do the operation in V4SFmode. */ - switch (subtype) - { - case BF16_BINARY: - emit_insn (gen_rtx_SET (result_v4sf, - gen_rtx_fmt_ee (icode, V4SFmode, - ops_v4sf[0], - ops_v4sf[1]))); - break; + if (subtype == BF16_BINARY) + emit_insn (gen_rtx_SET (result_v4sf, + gen_rtx_fmt_ee (icode, V4SFmode, + ops_v4sf[0], + ops_v4sf[1]))); - case BF16_FMA: - emit_insn (gen_vsx_fmav4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1], - ops_v4sf[2])); - break; + else /* FMA/FMS/NFMA/NFMS operation. */ + { + rtx op1 = ops_v4sf[0]; + rtx op2 = ops_v4sf[1]; + rtx op3 = ops_v4sf[2]; - case BF16_FMS: - emit_insn (gen_vsx_fmsv4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1], - ops_v4sf[2])); - break; + if (subtype == BF16_FMS || subtype == BF16_NFMS) + op3 = gen_rtx_NEG (V4SFmode, op3); - case BF16_NFMA: - emit_insn (gen_vsx_nfmav4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1], - ops_v4sf[2])); - break; + rtx op_fma = gen_rtx_FMA (V4SFmode, op1, op2, op3); - case BF16_NFMS: - emit_insn (gen_vsx_nfmsv4sf4 (result_v4sf, ops_v4sf[0], ops_v4sf[1], - ops_v4sf[2])); - break; + if (subtype == BF16_NFMA || subtype == BF16_NFMS) + op_fma = gen_rtx_NEG (V4SFmode, op_fma); - default: - gcc_unreachable (); + emit_insn (gen_rtx_SET (result_v4sf, op_fma)); } /* Convert V4SF result back to scalar mode. 
*/ diff --git a/gcc/config/rs6000/float16.md b/gcc/config/rs6000/float16.md index 3715bde0df03..bcfa475ac043 100644 --- a/gcc/config/rs6000/float16.md +++ b/gcc/config/rs6000/float16.md @@ -445,13 +445,13 @@ ;; SFmode. (define_insn_and_split "*bfloat16_binary_op_internal1" - [(set (match_operand:SF 0 "vsx_register_operand" "=wa") + [(set (match_operand:SF 0 "vsx_register_operand") (match_operator:SF 1 "bfloat16_binary_operator" - [(float_extend:SF - (match_operand:BF 2 "vsx_register_operand" "wa")) - (float_extend:SF - (match_operand:BF 3 "vsx_register_operand" "wa"))]))] - "TARGET_BFLOAT16_HW && can_create_pseudo_p ()" + [(match_operand:SF 2 "bfloat16_v4sf_operand") + (match_operand:SF 3 "bfloat16_v4sf_operand")]))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[2], SFmode) + || bfloat16_bf_operand (operands[3], SFmode))" "#" "&& 1" [(pc)] @@ -466,14 +466,14 @@ }) (define_insn_and_split "*bfloat16_binary_op_internal2" - [(set (match_operand:BF 0 "vsx_register_operand" "=wa") + [(set (match_operand:BF 0 "vsx_register_operand") (float_truncate:BF (match_operator:SF 1 "bfloat16_binary_operator" - [(float_extend:SF - (match_operand:BF 2 "vsx_register_operand" "wa")) - (float_extend:SF - (match_operand:BF 3 "vsx_register_operand" "wa"))])))] - "TARGET_BFLOAT16_HW && can_create_pseudo_p ()" + [(match_operand:SF 2 "bfloat16_v4sf_operand") + (match_operand:SF 3 "bfloat16_v4sf_operand")])))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[2], SFmode) + || bfloat16_bf_operand (operands[3], SFmode))" "#" "&& 1" [(pc)] @@ -487,85 +487,250 @@ DONE; }) -(define_insn_and_split "*bfloat16_binary_op_internal3" - [(set (match_operand:SF 0 "vsx_register_operand" "=wa,wa,wa") - (match_operator:SF 1 "bfloat16_binary_operator" - [(float_extend:SF - (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa")) - (match_operand:SF 3 "input_operand" "wa,j,eP")]))] - "TARGET_BFLOAT16_HW && can_create_pseudo_p ()" 
+(define_insn_and_split "*bfloat16_fma_internal1" + [(set (match_operand:SF 0 "vsx_register_operand") + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (match_operand:SF 3 "bfloat16_v4sf_operand")))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" "#" "&& 1" [(pc)] { - bfloat16_operation_as_v4sf (GET_CODE (operands[1]), + bfloat16_operation_as_v4sf (FMA, operands[0], + operands[1], operands[2], operands[3], - NULL_RTX, - BF16_BINARY); + BF16_FMA); DONE; }) -(define_insn_and_split "*bfloat16_binary_op_internal4" - [(set (match_operand:BF 0 "vsx_register_operand" "=wa,&wa,&wa") +(define_insn_and_split "*bfloat16_fma_internal2" + [(set (match_operand:BF 0 "vsx_register_operand" "=wa") (float_truncate:BF - (match_operator:SF 1 "bfloat16_binary_operator" - [(float_extend:SF - (match_operand:BF 2 "vsx_register_operand" "wa,wa,wa")) - (match_operand:SF 3 "input_operand" "wa,j,eP")])))] - "TARGET_BFLOAT16_HW && can_create_pseudo_p ()" + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (match_operand:SF 3 "bfloat16_v4sf_operand"))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" "#" "&& 1" [(pc)] { - bfloat16_operation_as_v4sf (GET_CODE (operands[1]), + bfloat16_operation_as_v4sf (FMA, operands[0], + operands[1], operands[2], operands[3], - NULL_RTX, - BF16_BINARY); + BF16_FMA); DONE; }) -(define_insn_and_split "*bfloat16_binary_op_internal5" - [(set (match_operand:SF 0 "vsx_register_operand" "=wa") - (match_operator:SF 1 "bfloat16_binary_operator" - [(match_operand:SF 2 "vsx_register_operand" "wa") - (float_extend:SF - (match_operand:BF 3 "vsx_register_operand" "wa"))]))] - 
"TARGET_BFLOAT16_HW && can_create_pseudo_p ()" +(define_insn_and_split "*bfloat16_fms_internal1" + [(set (match_operand:SF 0 "vsx_register_operand") + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (neg:SF + (match_operand:SF 3 "bfloat16_v4sf_operand"))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" "#" "&& 1" [(pc)] { - bfloat16_operation_as_v4sf (GET_CODE (operands[1]), + bfloat16_operation_as_v4sf (FMA, operands[0], + operands[1], operands[2], operands[3], - NULL_RTX, - BF16_BINARY); + BF16_FMS); DONE; }) -(define_insn_and_split "*bfloat16_binary_op_internal6" +(define_insn_and_split "*bfloat16_fms_interna2" + [(set (match_operand:BF 0 "vsx_register_operand") + (float_truncate:BF + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (neg:SF + (match_operand:SF 3 "bfloat16_v4sf_operand")))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" + "#" + "&& 1" + [(pc)] +{ + bfloat16_operation_as_v4sf (FMA, + operands[0], + operands[1], + operands[2], + operands[3], + BF16_FMS); + DONE; +}) + +(define_insn_and_split "*bfloat16_nfma_internal1" + [(set (match_operand:SF 0 "vsx_register_operand") + (neg:SF + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (match_operand:SF 3 "bfloat16_v4sf_operand"))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" + "#" + "&& 1" + [(pc)] +{ + bfloat16_operation_as_v4sf (FMA, + operands[0], + operands[1], + operands[2], + operands[3], + 
BF16_NFMA); + DONE; +}) + +(define_insn_and_split "*bfloat16_nfma_internal2" [(set (match_operand:BF 0 "vsx_register_operand" "=wa") (float_truncate:BF - (match_operator:SF 1 "bfloat16_binary_operator" - [(match_operand:SF 3 "vsx_register_operand" "wa") - (float_extend:SF - (match_operand:BF 2 "vsx_register_operand" "wa"))])))] - "TARGET_BFLOAT16_HW && can_create_pseudo_p ()" + (neg:SF + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (match_operand:SF 3 "bfloat16_v4sf_operand")))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" "#" "&& 1" [(pc)] { - bfloat16_operation_as_v4sf (GET_CODE (operands[1]), + bfloat16_operation_as_v4sf (FMA, operands[0], + operands[1], operands[2], operands[3], - NULL_RTX, - BF16_BINARY); + BF16_NFMA); + DONE; +}) + +(define_insn_and_split "*bfloat16_nfma_internal3" + [(set (match_operand:BF 0 "vsx_register_operand" "=wa") + (neg:BF + (float_truncate:BF + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (match_operand:SF 3 "bfloat16_v4sf_operand")))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" + "#" + "&& 1" + [(pc)] +{ + bfloat16_operation_as_v4sf (FMA, + operands[0], + operands[1], + operands[2], + operands[3], + BF16_NFMA); + DONE; +}) + +(define_insn_and_split "*bfloat16_nfms_internal1" + [(set (match_operand:SF 0 "vsx_register_operand") + (neg:SF + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (neg:SF + (match_operand:SF 3 "bfloat16_v4sf_operand")))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand 
(operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" + "#" + "&& 1" + [(pc)] +{ + bfloat16_operation_as_v4sf (FMA, + operands[0], + operands[1], + operands[2], + operands[3], + BF16_NFMS); + DONE; +}) + +(define_insn_and_split "*bfloat16_nfms_internal2" + [(set (match_operand:BF 0 "vsx_register_operand") + (float_truncate:BF + (neg:SF + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (neg:SF + (match_operand:SF 3 "bfloat16_v4sf_operand"))))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" + "#" + "&& 1" + [(pc)] +{ + bfloat16_operation_as_v4sf (FMA, + operands[0], + operands[1], + operands[2], + operands[3], + BF16_NFMS); + DONE; +}) + +(define_insn_and_split "*bfloat16_nfms_internal3" + [(set (match_operand:BF 0 "vsx_register_operand") + (neg:BF + (float_truncate:BF + (fma:SF + (match_operand:SF 1 "bfloat16_v4sf_operand") + (match_operand:SF 2 "bfloat16_v4sf_operand") + (neg:SF + (match_operand:SF 3 "bfloat16_v4sf_operand"))))))] + "TARGET_BFLOAT16_HW && can_create_pseudo_p () + && (bfloat16_bf_operand (operands[1], SFmode) + + bfloat16_bf_operand (operands[2], SFmode) + + bfloat16_bf_operand (operands[3], SFmode) >= 2)" + "#" + "&& 1" + [(pc)] +{ + bfloat16_operation_as_v4sf (FMA, + operands[0], + operands[1], + operands[2], + operands[3], + BF16_NFMS); DONE; }) diff --git a/gcc/config/rs6000/predicates.md b/gcc/config/rs6000/predicates.md index 2de33f7f32a6..55f2232ee6cf 100644 --- a/gcc/config/rs6000/predicates.md +++ b/gcc/config/rs6000/predicates.md @@ -2210,3 +2210,50 @@ ;; then converting the V4SFmode element to SFmode scalar. (define_predicate "bfloat16_binary_operator" (match_code "plus,minus,mult,smax,smin")) + +;; Match bfloat16/float operands that can be optimized to do the +;; operation in V4SFmode. 
+(define_predicate "bfloat16_v4sf_operand" + (match_code "reg,subreg,const_double,float_extend,float_truncate") +{ + if (mode != BFmode && mode != SFmode) + return false; + + if (REG_P (op) || SUBREG_P (op)) + return register_operand (op, mode); + + if (CONST_DOUBLE_P (op)) + return true; + + if (GET_CODE (op) == FLOAT_EXTEND + && mode == SFmode + && GET_MODE (XEXP (op, 0)) == BFmode) + return true; + + if (GET_CODE (op) == FLOAT_TRUNCATE + && mode == BFmode + && GET_MODE (XEXP (op, 0)) == SFmode) + return true; + + return false; +}) + +;; Match an operand that originally was a BFmode value to prevent +operations involving only SFmode values from being converted to +BFmode. +(define_predicate "bfloat16_bf_operand" + (match_code "reg,subreg,const_double,float_extend") +{ + if (mode == BFmode || GET_MODE (op) == BFmode) + return true; + + if (mode != SFmode) + return false; + + if (GET_MODE (op) == SFmode + && GET_CODE (op) == FLOAT_EXTEND + && GET_MODE (XEXP (op, 0)) == BFmode) + return true; + + return false; +}) diff --git a/gcc/config/rs6000/vsx.md b/gcc/config/rs6000/vsx.md index 2611660921a5..6c11d7766ed1 100644 --- a/gcc/config/rs6000/vsx.md +++ b/gcc/config/rs6000/vsx.md @@ -2098,7 +2098,7 @@ ;; vmaddfp and vnmsubfp can have different behaviors than the VSX instructions ;; in some corner cases due to VSCR[NJ] being set or if the addend is +0.0 ;; instead of -0.0. 
-(define_insn "vsx_fmav4sf4" +(define_insn "*vsx_fmav4sf4" [(set (match_operand:V4SF 0 "vsx_register_operand" "=wa,wa") (fma:V4SF (match_operand:V4SF 1 "vsx_register_operand" "%wa,wa") @@ -2122,7 +2122,7 @@ xvmaddmdp %x0,%x1,%x3" [(set_attr "type" "vecdouble")]) -(define_insn "vsx_fms<mode>4" +(define_insn "*vsx_fms<mode>4" [(set (match_operand:VSX_F 0 "vsx_register_operand" "=wa,wa") (fma:VSX_F (match_operand:VSX_F 1 "vsx_register_operand" "%wa,wa") @@ -2135,7 +2135,7 @@ xvmsubm<sd>p %x0,%x1,%x3" [(set_attr "type" "<VStype_mul>")]) -(define_insn "vsx_nfma<mode>4" +(define_insn "*vsx_nfma<mode>4" [(set (match_operand:VSX_F 0 "vsx_register_operand" "=wa,wa") (neg:VSX_F (fma:VSX_F @@ -2148,7 +2148,7 @@ xvnmaddm<sd>p %x0,%x1,%x3" [(set_attr "type" "<VStype_mul>")]) -(define_insn "vsx_nfmsv4sf4" +(define_insn "*vsx_nfmsv4sf4" [(set (match_operand:V4SF 0 "vsx_register_operand" "=wa,wa") (neg:V4SF (fma:V4SF
