https://gcc.gnu.org/g:5b7639b0de2f1d610e16eec628033fe3650d4763

commit r15-10468-g5b7639b0de2f1d610e16eec628033fe3650d4763
Author: Tamar Christina <[email protected]>
Date:   Mon Oct 27 17:55:38 2025 +0000

    vect: Fix operand swapping on complex multiplication detection [PR122408]
    
    For
    
    SUBROUTINE a( j, b, c, d )
      !GCC$ ATTRIBUTES noinline :: a
      COMPLEX*16         b
      COMPLEX*16         c( * ), d( * )
      DO k = 1, j
         c( k ) = - b * CONJG( d( k ) )
      END DO
    END
    
    we incorrectly generate .IFN_COMPLEX_MUL instead of .IFN_COMPLEX_MUL_CONJ.
    
    The issue happens because in the call to vect_validate_multiplication the
    operand vectors are passed by reference and so the stripping of the
    NEGATE_EXPR after matching modifies the input vector.  If validation fails
    we flip the operands and try again.  But we've already stripped the negates,
    so if we match we would match a normal multiply.
    
    This fixes the API by marking the operands as const and instead passing an
    explicit output vec that's to be used.  This also reduces the number of
    copies we were doing.
    
    With this we now correctly detect .IFN_COMPLEX_MUL_CONJ.  Weirdly enough I
    couldn't reproduce this with any C example because they get reassociated
    differently and always succeed on the first attempt.  It is easy to
    trigger from Fortran though, so new Fortran tests are added.
    
    gcc/ChangeLog:
    
            PR tree-optimization/122408
            * tree-vect-slp-patterns.cc (vect_validate_multiplication): Clean up and
            document interface.
            (complex_mul_pattern::matches, complex_fms_pattern::matches): Update to
            new interface.
    
    gcc/testsuite/ChangeLog:
    
            PR tree-optimization/122408
            * gfortran.target/aarch64/pr122408_1.f90: New test.
            * gfortran.target/aarch64/pr122408_2.f90: New test.
    
    (cherry picked from commit c5fa3d4c88fc4f8799318e463c47941eb52b7546)

Diff:
---
 .../gfortran.target/aarch64/pr122408_1.f90         |  61 +++++++++
 .../gfortran.target/aarch64/pr122408_2.f90         | 140 +++++++++++++++++++++
 gcc/tree-vect-slp-patterns.cc                      |  75 ++++++-----
 3 files changed, 244 insertions(+), 32 deletions(-)

diff --git a/gcc/testsuite/gfortran.target/aarch64/pr122408_1.f90 b/gcc/testsuite/gfortran.target/aarch64/pr122408_1.f90
new file mode 100644
index 000000000000..8a3416231ff1
--- /dev/null
+++ b/gcc/testsuite/gfortran.target/aarch64/pr122408_1.f90
@@ -0,0 +1,61 @@
+! { dg-do compile }
+! { dg-additional-options "-O2 -march=armv8.3-a" }
+
+subroutine c_add_ab(n, a, c, b)         ! C(k) += A * B(k) for k = 1..n
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_add_ab
+  integer, intent(in) :: n               ! number of elements of c and b processed
+  complex(real64), intent(in)    :: a    ! scalar multiplier
+  complex(real64), intent(inout) :: c(*) ! accumulated in place
+  complex(real64), intent(in)    :: b(*)
+  integer :: k
+  do k = 1, n                            ! FCMLA candidate; the dg-final scans below count the result
+    c(k) = c(k) + a * b(k)
+  end do
+end subroutine c_add_ab
+
+subroutine c_sub_ab(n, a, c, b)         ! C(k) -= A * B(k) for k = 1..n
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_sub_ab
+  integer, intent(in) :: n               ! number of elements of c and b processed
+  complex(real64), intent(in)    :: a    ! scalar multiplier
+  complex(real64), intent(inout) :: c(*) ! accumulated in place
+  complex(real64), intent(in)    :: b(*)
+  integer :: k
+  do k = 1, n                            ! FCMLA candidate; the dg-final scans below count the result
+    c(k) = c(k) - a * b(k)
+  end do
+end subroutine c_sub_ab
+
+subroutine c_add_a_conjb(n, a, c, b)    ! C(k) += A * conjg(B(k)) for k = 1..n
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_add_a_conjb
+  integer, intent(in) :: n               ! number of elements of c and b processed
+  complex(real64), intent(in)    :: a    ! scalar multiplier
+  complex(real64), intent(inout) :: c(*) ! accumulated in place
+  complex(real64), intent(in)    :: b(*) ! conjugated before the multiply
+  integer :: k
+  do k = 1, n                            ! FCMLA candidate; the dg-final scans below count the result
+    c(k) = c(k) + a * conjg(b(k))
+  end do
+end subroutine c_add_a_conjb
+
+subroutine c_sub_a_conjb(n, a, c, b)    ! C(k) -= A * conjg(B(k)) for k = 1..n
+  use iso_fortran_env, only: real64
+  implicit none
+  !GCC$ ATTRIBUTES noinline :: c_sub_a_conjb
+  integer, intent(in) :: n               ! number of elements of c and b processed
+  complex(real64), intent(in)    :: a    ! scalar multiplier
+  complex(real64), intent(inout) :: c(*) ! accumulated in place
+  complex(real64), intent(in)    :: b(*) ! conjugated before the multiply
+  integer :: k
+  do k = 1, n                            ! FCMLA candidate; the dg-final scans below count the result
+    c(k) = c(k) - a * conjg(b(k))
+  end do
+end subroutine c_sub_a_conjb
+
+! { dg-final { scan-assembler-times {fcmla\s+v[0-9]+.2d, v[0-9]+.2d, v[0-9]+.2d, #0} 2 } }
+! { dg-final { scan-assembler-times {fcmla\s+v[0-9]+.2d, v[0-9]+.2d, v[0-9]+.2d, #270} 2 } }
diff --git a/gcc/testsuite/gfortran.target/aarch64/pr122408_2.f90 b/gcc/testsuite/gfortran.target/aarch64/pr122408_2.f90
new file mode 100644
index 000000000000..feb6dc14af8a
--- /dev/null
+++ b/gcc/testsuite/gfortran.target/aarch64/pr122408_2.f90
@@ -0,0 +1,140 @@
+! { dg-do run }
+! { dg-additional-options "-O2" }
+! { dg-additional-options "-O2 -march=armv8.3-a" { target arm_v8_3a_complex_neon_hw } }
+
+module util
+  use iso_fortran_env, only: real64, int64
+  implicit none
+contains
+  pure logical function bitwise_eq(x, y)  ! .true. iff x and y have identical bit patterns
+    complex(real64), intent(in) :: x, y
+    integer(int64) :: xr, xi, yr, yi
+    xr = transfer(real(x, kind=real64), 0_int64)
+    xi = transfer(aimag(x), 0_int64)
+    yr = transfer(real(y, kind=real64), 0_int64)
+    yi = transfer(aimag(y), 0_int64)
+    bitwise_eq = (xr == yr) .and. (xi == yi)  ! unlike ==, this tells -0.0 from +0.0
+  end function bitwise_eq
+
+  subroutine check_equal(tag, got, ref, nfail)  ! print and count bitwise mismatches
+    character(*), intent(in) :: tag
+    complex(real64), intent(in) :: got(:), ref(:)
+    integer, intent(inout) :: nfail
+    integer :: i
+    do i = 1, size(got)  ! assumes size(ref) >= size(got)
+      if (.not. bitwise_eq(got(i), ref(i))) then
+        nfail = nfail + 1
+        write(*,'(A,": mismatch at i=",I0, "  got=",2ES16.8,"  ref=",2ES16.8)') &
+             trim(tag), i, real(got(i)), aimag(got(i)), real(ref(i)), aimag(ref(i))
+      end if
+    end do
+  end subroutine check_equal
+end module util
+
+module fcmla_ops                          ! kernels under test, one per accumulate form
+  use iso_fortran_env, only: real64
+  implicit none
+contains
+  subroutine c_add_ab(n, a, c, b)         ! C(k) += A * B(k) for k = 1..n
+    !GCC$ ATTRIBUTES noinline :: c_add_ab
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n                           ! result is compared bitwise with a scalar reference
+      c(k) = c(k) + a * b(k)
+    end do
+  end subroutine c_add_ab
+
+  subroutine c_sub_ab(n, a, c, b)         ! C(k) -= A * B(k) for k = 1..n
+    !GCC$ ATTRIBUTES noinline :: c_sub_ab
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n                           ! result is compared bitwise with a scalar reference
+      c(k) = c(k) - a * b(k)
+    end do
+  end subroutine c_sub_ab
+
+  subroutine c_add_a_conjb(n, a, c, b)    ! C(k) += A * conjg(B(k)) for k = 1..n
+    !GCC$ ATTRIBUTES noinline :: c_add_a_conjb
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n                           ! result is compared bitwise with a scalar reference
+      c(k) = c(k) + a * conjg(b(k))
+    end do
+  end subroutine c_add_a_conjb
+
+  subroutine c_sub_a_conjb(n, a, c, b)    ! C(k) -= A * conjg(B(k)) for k = 1..n
+    !GCC$ ATTRIBUTES noinline :: c_sub_a_conjb
+    integer, intent(in) :: n
+    complex(real64), intent(in)    :: a
+    complex(real64), intent(inout) :: c(*)
+    complex(real64), intent(in)    :: b(*)
+    integer :: k
+    do k = 1, n                           ! result is compared bitwise with a scalar reference
+      c(k) = c(k) - a * conjg(b(k))
+    end do
+  end subroutine c_sub_a_conjb
+end module fcmla_ops
+
+program fcmla_accum_pairs
+  use iso_fortran_env, only: real64
+  use util                      ! check_equal
+  use fcmla_ops                 ! kernels under test
+  implicit none
+
+  integer, parameter :: n = 4
+  complex(real64) :: a, b(n), c0(n)
+  complex(real64) :: c_add_ab_got(n),      c_add_ab_ref(n)
+  complex(real64) :: c_sub_ab_got(n),      c_sub_ab_ref(n)
+  complex(real64) :: c_add_conjb_got(n),   c_add_conjb_ref(n)
+  complex(real64) :: c_sub_conjb_got(n),   c_sub_conjb_ref(n)
+  integer :: i, fails
+
+  ! Inputs: b(3) is a signed-zero lane; all products and sums formed below are exact in real64
+  a    = cmplx( 2.0_real64, -3.0_real64, kind=real64)
+  b(1) = cmplx( 1.5_real64, -2.0_real64, kind=real64)
+  b(2) = cmplx(-4.0_real64,  5.0_real64, kind=real64)
+  b(3) = cmplx(-0.0_real64,  0.0_real64, kind=real64)
+  b(4) = cmplx( 0.25_real64, 3.0_real64, kind=real64)
+
+  c0(1) = cmplx( 1.0_real64, -2.0_real64, kind=real64)
+  c0(2) = cmplx( 3.0_real64, -4.0_real64, kind=real64)
+  c0(3) = cmplx(-5.0_real64,  6.0_real64, kind=real64)
+  c0(4) = cmplx( 0.0_real64,  0.0_real64, kind=real64)
+
+  ! Run each kernel on its own copy of c0
+  c_add_ab_got    = c0; call c_add_ab     (n, a, c_add_ab_got,    b)
+  c_sub_ab_got    = c0; call c_sub_ab     (n, a, c_sub_ab_got,    b)
+  c_add_conjb_got = c0; call c_add_a_conjb(n, a, c_add_conjb_got, b)
+  c_sub_conjb_got = c0; call c_sub_a_conjb(n, a, c_sub_conjb_got, b)
+
+  ! Scalar references computed from the same inputs
+  do i = 1, n
+    c_add_ab_ref(i)    = c0(i) + a * b(i)
+    c_sub_ab_ref(i)    = c0(i) - a * b(i)
+    c_add_conjb_ref(i) = c0(i) + a * conjg(b(i))
+    c_sub_conjb_ref(i) = c0(i) - a * conjg(b(i))
+  end do
+
+  ! Bitwise checks; every mismatch is printed and counted in fails
+  fails = 0
+  call check_equal("C +=  A*B       ", c_add_ab_got,    c_add_ab_ref,    fails)
+  call check_equal("C -=  A*B       ", c_sub_ab_got,    c_sub_ab_ref,    fails)
+  call check_equal("C +=  A*conj(B) ", c_add_conjb_got, c_add_conjb_ref, fails)
+  call check_equal("C -=  A*conj(B) ", c_sub_conjb_got, c_sub_conjb_ref, fails)
+
+  if (fails == 0) then
+    stop 0
+  else
+    stop 1                      ! non-zero exit status signals the failure
+  end if
+end program fcmla_accum_pairs
+
diff --git a/gcc/tree-vect-slp-patterns.cc b/gcc/tree-vect-slp-patterns.cc
index c0dff90d9baf..cebd9aa1c13c 100644
--- a/gcc/tree-vect-slp-patterns.cc
+++ b/gcc/tree-vect-slp-patterns.cc
@@ -847,15 +847,23 @@ compatible_complex_nodes_p (slp_compat_nodes_map_t *compat_cache,
   return true;
 }
 
+
+/* Check to see if the operands to two multiplies, 2 each in LEFT_OP and
+   RIGHT_OP, match a complex multiplication or complex multiply-and-accumulate
+   or complex multiply-and-subtract pattern.  Do this using the permute cache
+   PERM_CACHE and the combination compatibility list COMPAT_CACHE.  LEFT_OP and
+   RIGHT_OP are not modified.  If the operation is successful the matching
+   operands are returned in OPS, *_STATUS indicates whether the match includes
+   a conjugate of one of the operands and true is returned.  Otherwise false
+   is returned and the values in OPS are meaningless.  */
 static inline bool
 vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
                              slp_compat_nodes_map_t *compat_cache,
-                             vec<slp_tree> &left_op,
-                             vec<slp_tree> &right_op,
-                             bool subtract,
+                             const vec<slp_tree> &left_op,
+                             const vec<slp_tree> &right_op,
+                             bool subtract, vec<slp_tree> &ops,
                              enum _conj_status *_status)
 {
-  auto_vec<slp_tree> ops;
   enum _conj_status stats = CONJ_NONE;
 
   /* The complex operations can occur in two layouts and two permute sequences
@@ -886,31 +894,31 @@ vect_validate_multiplication (slp_tree_to_load_perm_map_t *perm_cache,
   bool neg0 = vect_match_expression_p (right_op[0], NEGATE_EXPR);
   bool neg1 = vect_match_expression_p (right_op[1], NEGATE_EXPR);
 
+  /* Reset OPS and flatten the inputs; the remapping below only edits OPS.  */
+  ops.truncate (0);
+  ops.safe_splice (left_op);
+  ops.safe_splice (right_op);
+
   /* Determine which style we're looking at.  We only have different ones
      whenever a conjugate is involved.  */
   if (neg0 && neg1)
     ;
   else if (neg0)
     {
-      right_op[0] = SLP_TREE_CHILDREN (right_op[0])[0];
+      ops[2] = SLP_TREE_CHILDREN (right_op[0])[0];
       stats = CONJ_FST;
       if (subtract)
        perm = 0;
     }
   else if (neg1)
     {
-      right_op[1] = SLP_TREE_CHILDREN (right_op[1])[0];
+      ops[3] = SLP_TREE_CHILDREN (right_op[1])[0];
       stats = CONJ_SND;
       perm = 1;
     }
 
   *_status = stats;
 
-  /* Flatten the inputs after we've remapped them.  */
-  ops.create (4);
-  ops.safe_splice (left_op);
-  ops.safe_splice (right_op);
-
   /* Extract out the elements to check.  */
   slp_tree op0 = ops[styles[style][0]];
   slp_tree op1 = ops[styles[style][1]];
@@ -1073,15 +1081,16 @@ complex_mul_pattern::matches (complex_operation_t op,
     return IFN_LAST;
 
   enum _conj_status status;
+  auto_vec<slp_tree> res_ops;
   if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
-                                    right_op, false, &status))
+                                    right_op, false, res_ops, &status))
     {
       /* Try swapping the order and re-trying since multiplication is
         commutative.  */
       std::swap (left_op[0], left_op[1]);
       std::swap (right_op[0], right_op[1]);
       if (!vect_validate_multiplication (perm_cache, compat_cache, left_op,
-                                        right_op, false, &status))
+                                        right_op, false, res_ops, &status))
        return IFN_LAST;
     }
 
@@ -1109,24 +1118,24 @@ complex_mul_pattern::matches (complex_operation_t op,
   if (add0)
     ops->quick_push (add0);
 
-  complex_perm_kinds_t kind = linear_loads_p (perm_cache, left_op[0]);
+  complex_perm_kinds_t kind = linear_loads_p (perm_cache, res_ops[0]);
   if (kind == PERM_EVENODD || kind == PERM_TOP)
     {
-      ops->quick_push (left_op[1]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (left_op[0]);
+      ops->quick_push (res_ops[1]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[0]);
     }
   else if (kind == PERM_EVENEVEN && status != CONJ_SND)
     {
-      ops->quick_push (left_op[0]);
-      ops->quick_push (right_op[0]);
-      ops->quick_push (left_op[1]);
+      ops->quick_push (res_ops[0]);
+      ops->quick_push (res_ops[2]);
+      ops->quick_push (res_ops[1]);
     }
   else
     {
-      ops->quick_push (left_op[0]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (left_op[1]);
+      ops->quick_push (res_ops[0]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[1]);
     }
 
   return ifn;
@@ -1298,15 +1307,17 @@ complex_fms_pattern::matches (complex_operation_t op,
     return IFN_LAST;
 
   enum _conj_status status;
+  auto_vec<slp_tree> res_ops;
   if (!vect_validate_multiplication (perm_cache, compat_cache, right_op,
-                                    left_op, true, &status))
+                                    left_op, true, res_ops, &status))
     {
       /* Try swapping the order and re-trying since multiplication is
          commutative.  */
       std::swap (left_op[0], left_op[1]);
       std::swap (right_op[0], right_op[1]);
+      res_ops.release ();  /* Drop the failed first attempt; do not shadow.  */
       if (!vect_validate_multiplication (perm_cache, compat_cache, right_op,
-                                        left_op, true, &status))
+                                        left_op, true, res_ops, &status))
        return IFN_LAST;
     }
 
@@ -1321,20 +1332,20 @@ complex_fms_pattern::matches (complex_operation_t op,
   ops->truncate (0);
   ops->create (4);
 
-  complex_perm_kinds_t kind = linear_loads_p (perm_cache, right_op[0]);
+  complex_perm_kinds_t kind = linear_loads_p (perm_cache, res_ops[2]);
   if (kind == PERM_EVENODD)
     {
       ops->quick_push (l0node[0]);
-      ops->quick_push (right_op[0]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (left_op[1]);
+      ops->quick_push (res_ops[2]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[1]);
     }
   else
     {
       ops->quick_push (l0node[0]);
-      ops->quick_push (right_op[1]);
-      ops->quick_push (right_op[0]);
-      ops->quick_push (left_op[0]);
+      ops->quick_push (res_ops[3]);
+      ops->quick_push (res_ops[2]);
+      ops->quick_push (res_ops[0]);
     }
 
   return ifn;

Reply via email to