From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai>

Hi, Richard and Richi.

In the previous patch we added support for COND_LEN_* binary operations.
However, we didn't support COND_LEN_* ternary operations.

This patch adds support for COND_LEN_* ternary operations.  Consider the following case:

#define TEST_TYPE(TYPE)                                                        \
  __attribute__ ((noipa)) void ternop_##TYPE (TYPE *__restrict dst,            \
                                              TYPE *__restrict a,              \
                                              TYPE *__restrict b,\
                TYPE *__restrict c, int n)       \
  {                                                                            \
    for (int i = 0; i < n; i++)                                                \
      dst[i] += a[i] * b[i];                                                   \
  }

#define TEST_ALL() TEST_TYPE (double)

TEST_ALL ()

Before this patch:
...
COND_LEN_MUL
COND_LEN_ADD

After this patch:
...
COND_LEN_FMA

gcc/ChangeLog:

        * genmatch.cc (commutative_op): Add COND_LEN_*.
        * internal-fn.cc (first_commutative_argument): Ditto.
        (CASE): Ditto.
        (get_unconditional_internal_fn): Ditto.
        (can_interpret_as_conditional_op_p): Ditto.
        (internal_fn_len_index): Ditto.
        * internal-fn.h (can_interpret_as_conditional_op_p): Ditto.
        * tree-ssa-math-opts.cc (convert_mult_to_fma_1): Ditto.
        (convert_mult_to_fma): Ditto.
        (math_opts_dom_walker::after_dom_children): Ditto.

---
 gcc/genmatch.cc           | 13 ++++++
 gcc/internal-fn.cc        | 87 ++++++++++++++++++++++++++++++++++-----
 gcc/internal-fn.h         |  2 +-
 gcc/tree-ssa-math-opts.cc | 80 +++++++++++++++++++++++++++++------
 4 files changed, 159 insertions(+), 23 deletions(-)

diff --git a/gcc/genmatch.cc b/gcc/genmatch.cc
index 5fceeec9780..2302f2a7ff0 100644
--- a/gcc/genmatch.cc
+++ b/gcc/genmatch.cc
@@ -559,6 +559,19 @@ commutative_op (id_base *id)
       case CFN_COND_FMS:
       case CFN_COND_FNMA:
       case CFN_COND_FNMS:
+      case CFN_COND_LEN_ADD:
+      case CFN_COND_LEN_MUL:
+      case CFN_COND_LEN_MIN:
+      case CFN_COND_LEN_MAX:
+      case CFN_COND_LEN_FMIN:
+      case CFN_COND_LEN_FMAX:
+      case CFN_COND_LEN_AND:
+      case CFN_COND_LEN_IOR:
+      case CFN_COND_LEN_XOR:
+      case CFN_COND_LEN_FMA:
+      case CFN_COND_LEN_FMS:
+      case CFN_COND_LEN_FNMA:
+      case CFN_COND_LEN_FNMS:
        return 1;
 
       default:
diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
index c11123a1173..e698f0bffc7 100644
--- a/gcc/internal-fn.cc
+++ b/gcc/internal-fn.cc
@@ -4191,6 +4191,19 @@ first_commutative_argument (internal_fn fn)
     case IFN_COND_FMS:
     case IFN_COND_FNMA:
     case IFN_COND_FNMS:
+    case IFN_COND_LEN_ADD:
+    case IFN_COND_LEN_MUL:
+    case IFN_COND_LEN_MIN:
+    case IFN_COND_LEN_MAX:
+    case IFN_COND_LEN_FMIN:
+    case IFN_COND_LEN_FMAX:
+    case IFN_COND_LEN_AND:
+    case IFN_COND_LEN_IOR:
+    case IFN_COND_LEN_XOR:
+    case IFN_COND_LEN_FMA:
+    case IFN_COND_LEN_FMS:
+    case IFN_COND_LEN_FNMA:
+    case IFN_COND_LEN_FNMS:
       return 1;
 
     default:
@@ -4330,11 +4343,14 @@ conditional_internal_fn_code (internal_fn ifn)
 {
   switch (ifn)
     {
-#define CASE(CODE, IFN) case IFN_COND_##IFN: return CODE;
-      FOR_EACH_CODE_MAPPING(CASE)
+#define CASE(CODE, IFN)                                                        
\
+  case IFN_COND_##IFN:                                                         
\
+  case IFN_COND_LEN_##IFN:                                                     
\
+    return CODE;
+      FOR_EACH_CODE_MAPPING (CASE)
 #undef CASE
-    default:
-      return ERROR_MARK;
+      default:
+       return ERROR_MARK;
     }
 }
 
@@ -4433,6 +4449,18 @@ get_unconditional_internal_fn (internal_fn ifn)
    operating elementwise if the operands are vectors.  This includes
    the case of an all-true COND, so that the operation always happens.
 
+   There is an alternative way to interpret STMT when the operands are
+   vectors: as an operation predicated by both the conditional mask and
+   the loop control length.  The equivalent C code is:
+
+     for (int i = 0; i < NUNITS; i++)
+      {
+       if (i < LEN + BIAS && COND[i])
+         LHS[i] = A[i] CODE B[i];
+       else
+         LHS[i] = ELSE[i];
+      }
+
    When returning true, set:
 
    - *COND_OUT to the condition COND, or to NULL_TREE if the condition
@@ -4440,13 +4468,18 @@ get_unconditional_internal_fn (internal_fn ifn)
    - *CODE_OUT to the tree code
    - OPS[I] to operand I of *CODE_OUT
    - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
-     condition is known to be all true.  */
+     condition is known to be all true.
+   - *LEN to the len argument for COND_LEN_* operations, or to NULL_TREE.
+   - *BIAS to the bias argument for COND_LEN_* operations, or to NULL_TREE.  
*/
 
 bool
 can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
                                   tree_code *code_out,
-                                  tree (&ops)[3], tree *else_out)
+                                  tree (&ops)[3], tree *else_out,
+                                  tree *len, tree *bias)
 {
+  *len = NULL_TREE;
+  *bias = NULL_TREE;
   if (gassign *assign = dyn_cast <gassign *> (stmt))
     {
       *cond_out = NULL_TREE;
@@ -4462,18 +4495,28 @@ can_interpret_as_conditional_op_p (gimple *stmt, tree 
*cond_out,
       {
        internal_fn ifn = gimple_call_internal_fn (call);
        tree_code code = conditional_internal_fn_code (ifn);
+       int len_index = internal_fn_len_index (ifn);
+       int cond_nargs = len_index >= 0 ? 4 : 2;
        if (code != ERROR_MARK)
          {
            *cond_out = gimple_call_arg (call, 0);
            *code_out = code;
-           unsigned int nops = gimple_call_num_args (call) - 2;
+           unsigned int nops = gimple_call_num_args (call) - cond_nargs;
            for (unsigned int i = 0; i < 3; ++i)
              ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
            *else_out = gimple_call_arg (call, nops + 1);
-           if (integer_truep (*cond_out))
+           if (len_index < 0)
+             {
+               if (integer_truep (*cond_out))
+                 {
+                   *cond_out = NULL_TREE;
+                   *else_out = NULL_TREE;
+                 }
+             }
+           else
              {
-               *cond_out = NULL_TREE;
-               *else_out = NULL_TREE;
+               *len = gimple_call_arg (call, len_index);
+               *bias = gimple_call_arg (call, len_index + 1);
              }
            return true;
          }
@@ -4561,8 +4604,32 @@ internal_fn_len_index (internal_fn fn)
 
     case IFN_LEN_MASK_GATHER_LOAD:
     case IFN_LEN_MASK_SCATTER_STORE:
+    case IFN_COND_LEN_FMA:
+    case IFN_COND_LEN_FMS:
+    case IFN_COND_LEN_FNMA:
+    case IFN_COND_LEN_FNMS:
       return 5;
 
+    case IFN_COND_LEN_ADD:
+    case IFN_COND_LEN_SUB:
+    case IFN_COND_LEN_MUL:
+    case IFN_COND_LEN_DIV:
+    case IFN_COND_LEN_MOD:
+    case IFN_COND_LEN_RDIV:
+    case IFN_COND_LEN_MIN:
+    case IFN_COND_LEN_MAX:
+    case IFN_COND_LEN_FMIN:
+    case IFN_COND_LEN_FMAX:
+    case IFN_COND_LEN_AND:
+    case IFN_COND_LEN_IOR:
+    case IFN_COND_LEN_XOR:
+    case IFN_COND_LEN_SHL:
+    case IFN_COND_LEN_SHR:
+      return 4;
+
+    case IFN_COND_LEN_NEG:
+      return 3;
+
     default:
       return -1;
     }
diff --git a/gcc/internal-fn.h b/gcc/internal-fn.h
index dd1bab0bddf..a5c3f4765ff 100644
--- a/gcc/internal-fn.h
+++ b/gcc/internal-fn.h
@@ -229,7 +229,7 @@ extern tree_code conditional_internal_fn_code (internal_fn);
 extern internal_fn get_unconditional_internal_fn (internal_fn);
 extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
                                               tree_code *, tree (&)[3],
-                                              tree *);
+                                              tree *, tree *, tree *);
 
 extern bool internal_load_fn_p (internal_fn);
 extern bool internal_store_fn_p (internal_fn);
diff --git a/gcc/tree-ssa-math-opts.cc b/gcc/tree-ssa-math-opts.cc
index 68fc518b1ab..712097ac5be 100644
--- a/gcc/tree-ssa-math-opts.cc
+++ b/gcc/tree-ssa-math-opts.cc
@@ -3099,10 +3099,11 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree 
op2)
          negate_p = true;
        }
 
-      tree cond, else_value, ops[3];
+      tree cond, else_value, ops[3], len, bias;
       tree_code code;
       if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
-                                             ops, &else_value))
+                                             ops, &else_value,
+                                             &len, &bias))
        gcc_unreachable ();
       addop = ops[0] == result ? ops[1] : ops[0];
 
@@ -3122,7 +3123,11 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree 
op2)
       if (seq)
        gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
 
-      if (cond)
+      if (len)
+       fma_stmt
+         = gimple_build_call_internal (IFN_COND_LEN_FMA, 7, cond, mulop1, op2,
+                                       addop, else_value, len, bias);
+      else if (cond)
        fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
                                               op2, addop, else_value);
       else
@@ -3307,7 +3312,8 @@ last_fma_candidate_feeds_initial_phi (fma_deferring_state 
*state,
 
 static bool
 convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
-                    fma_deferring_state *state, tree mul_cond = NULL_TREE)
+                    fma_deferring_state *state, tree mul_cond = NULL_TREE,
+                    tree mul_len = NULL_TREE, tree mul_bias = NULL_TREE)
 {
   tree mul_result = gimple_get_lhs (mul_stmt);
   /* If there isn't a LHS then this can't be an FMA.  There can be no LHS
@@ -3420,10 +3426,10 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree 
op2,
          negate_p = seen_negate_p = true;
        }
 
-      tree cond, else_value, ops[3];
+      tree cond, else_value, ops[3], len, bias;
       tree_code code;
       if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
-                                             &else_value))
+                                             &else_value, &len, &bias))
        return false;
 
       switch (code)
@@ -3439,15 +3445,49 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree 
op2,
          return false;
        }
 
-      if (mul_cond && cond != mul_cond)
-       return false;
-
-      if (cond)
+      if (len)
        {
-         if (cond == result || else_value == result)
+         /* For COND_LEN_* operations, we may have a dummy mask which is
+            the all-true mask.  Such masks may have mul_cond != cond,
+            but we still consider them equal.  */
+         if (mul_cond && cond != mul_cond
+             && !(integer_truep (mul_cond) && integer_truep (cond)))
            return false;
-         if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
+
+         if (else_value == result)
+           return false;
+
+         if (!direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
+                                              opt_type))
            return false;
+
+         if (mul_len)
+           {
+             poly_int64 mul_value, value;
+             if (poly_int_tree_p (mul_len, &mul_value)
+                 && poly_int_tree_p (len, &value)
+                 && maybe_ne (mul_value, value))
+               return false;
+             else if (mul_len != len)
+               return false;
+
+             if (wi::to_widest (mul_bias) != wi::to_widest (bias))
+               return false;
+           }
+       }
+      else
+       {
+         if (mul_cond && cond != mul_cond)
+           return false;
+
+         if (cond)
+           {
+             if (cond == result || else_value == result)
+               return false;
+             if (!direct_internal_fn_supported_p (IFN_COND_FMA, type,
+                                                  opt_type))
+               return false;
+           }
        }
 
       /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
@@ -5632,6 +5672,22 @@ math_opts_dom_walker::after_dom_children (basic_block bb)
                }
              break;
 
+           case CFN_COND_LEN_MUL:
+             if (convert_mult_to_fma (stmt,
+                                      gimple_call_arg (stmt, 1),
+                                      gimple_call_arg (stmt, 2),
+                                      &fma_state,
+                                      gimple_call_arg (stmt, 0),
+                                      gimple_call_arg (stmt, 4),
+                                      gimple_call_arg (stmt, 5)))
+
+               {
+                 gsi_remove (&gsi, true);
+                 release_defs (stmt);
+                 continue;
+               }
+             break;
+
            case CFN_LAST:
              cancel_fma_deferring (&fma_state);
              break;
-- 
2.36.3

Reply via email to