On Thu, 13 Jul 2023, juzhe.zh...@rivai.ai wrote:

> From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai>
> 
> Hi, Richard and Richi.
> 
> In the previous patch we added support for COND_LEN_* binary operations.
> However, we didn't support COND_LEN_* ternary operations.
> 
> Now, this patch supports COND_LEN_* ternary operations. Consider the following case:
> 
> #define TEST_TYPE(TYPE)                                                        \
>   __attribute__ ((noipa)) void ternop_##TYPE (TYPE *__restrict dst,            \
>                                               TYPE *__restrict a,              \
>                                               TYPE *__restrict b,              \
>                                               TYPE *__restrict c, int n)       \
>   {                                                                            \
>     for (int i = 0; i < n; i++)                                                \
>       dst[i] += a[i] * b[i];                                                   \
>   }
> 
> #define TEST_ALL() TEST_TYPE (double)
> 
> TEST_ALL ()
> 
> Before this patch:
> ...
> COND_LEN_MUL
> COND_LEN_ADD
> 
> After this patch:
> ...
> COND_LEN_FMA

OK.

Thanks,
Richard.

> gcc/ChangeLog:
> 
>         * genmatch.cc (commutative_op): Add COND_LEN_*
>         * internal-fn.cc (first_commutative_argument): Ditto.
>         (CASE): Ditto.
>         (get_unconditional_internal_fn): Ditto.
>         (can_interpret_as_conditional_op_p): Ditto.
>         (internal_fn_len_index): Ditto.
>         * internal-fn.h (can_interpret_as_conditional_op_p): Ditto.
>         * tree-ssa-math-opts.cc (convert_mult_to_fma_1): Ditto.
>         (convert_mult_to_fma): Ditto.
>         (math_opts_dom_walker::after_dom_children): Ditto.
> 
> ---
>  gcc/genmatch.cc           | 13 ++++++
>  gcc/internal-fn.cc        | 87 ++++++++++++++++++++++++++++++++++-----
>  gcc/internal-fn.h         |  2 +-
>  gcc/tree-ssa-math-opts.cc | 80 +++++++++++++++++++++++++++++------
>  4 files changed, 159 insertions(+), 23 deletions(-)
> 
> diff --git a/gcc/genmatch.cc b/gcc/genmatch.cc
> index 5fceeec9780..2302f2a7ff0 100644
> --- a/gcc/genmatch.cc
> +++ b/gcc/genmatch.cc
> @@ -559,6 +559,19 @@ commutative_op (id_base *id)
>        case CFN_COND_FMS:
>        case CFN_COND_FNMA:
>        case CFN_COND_FNMS:
> +      case CFN_COND_LEN_ADD:
> +      case CFN_COND_LEN_MUL:
> +      case CFN_COND_LEN_MIN:
> +      case CFN_COND_LEN_MAX:
> +      case CFN_COND_LEN_FMIN:
> +      case CFN_COND_LEN_FMAX:
> +      case CFN_COND_LEN_AND:
> +      case CFN_COND_LEN_IOR:
> +      case CFN_COND_LEN_XOR:
> +      case CFN_COND_LEN_FMA:
> +      case CFN_COND_LEN_FMS:
> +      case CFN_COND_LEN_FNMA:
> +      case CFN_COND_LEN_FNMS:
>       return 1;
>  
>        default:
> diff --git a/gcc/internal-fn.cc b/gcc/internal-fn.cc
> index c11123a1173..e698f0bffc7 100644
> --- a/gcc/internal-fn.cc
> +++ b/gcc/internal-fn.cc
> @@ -4191,6 +4191,19 @@ first_commutative_argument (internal_fn fn)
>      case IFN_COND_FMS:
>      case IFN_COND_FNMA:
>      case IFN_COND_FNMS:
> +    case IFN_COND_LEN_ADD:
> +    case IFN_COND_LEN_MUL:
> +    case IFN_COND_LEN_MIN:
> +    case IFN_COND_LEN_MAX:
> +    case IFN_COND_LEN_FMIN:
> +    case IFN_COND_LEN_FMAX:
> +    case IFN_COND_LEN_AND:
> +    case IFN_COND_LEN_IOR:
> +    case IFN_COND_LEN_XOR:
> +    case IFN_COND_LEN_FMA:
> +    case IFN_COND_LEN_FMS:
> +    case IFN_COND_LEN_FNMA:
> +    case IFN_COND_LEN_FNMS:
>        return 1;
>  
>      default:
> @@ -4330,11 +4343,14 @@ conditional_internal_fn_code (internal_fn ifn)
>  {
>    switch (ifn)
>      {
> -#define CASE(CODE, IFN) case IFN_COND_##IFN: return CODE;
> -      FOR_EACH_CODE_MAPPING(CASE)
> +#define CASE(CODE, IFN)                                                      
>   \
> +  case IFN_COND_##IFN:                                                       
>   \
> +  case IFN_COND_LEN_##IFN:                                                   
>   \
> +    return CODE;
> +      FOR_EACH_CODE_MAPPING (CASE)
>  #undef CASE
> -    default:
> -      return ERROR_MARK;
> +      default:
> +     return ERROR_MARK;
>      }
>  }
>  
> @@ -4433,6 +4449,18 @@ get_unconditional_internal_fn (internal_fn ifn)
>     operating elementwise if the operands are vectors.  This includes
>     the case of an all-true COND, so that the operation always happens.
>  
> +   There is an alternative approach to interpret the STMT when the operands
> +   are vectors which is the operation predicated by both conditional mask
> +   and loop control length, the equivalent C code:
> +
> +     for (int i = 0; i < NUNITS; i++)
> +      {
> +     if (i < LEN + BIAS && COND[i])
> +       LHS[i] = A[i] CODE B[i];
> +     else
> +       LHS[i] = ELSE[i];
> +      }
> +
>     When returning true, set:
>  
>     - *COND_OUT to the condition COND, or to NULL_TREE if the condition
> @@ -4440,13 +4468,18 @@ get_unconditional_internal_fn (internal_fn ifn)
>     - *CODE_OUT to the tree code
>     - OPS[I] to operand I of *CODE_OUT
>     - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the
> -     condition is known to be all true.  */
> +     condition is known to be all true.
> +   - *LEN to the len argument if it is a COND_LEN_* operation, or to NULL_TREE.
> +   - *BIAS to the bias argument if it is a COND_LEN_* operation, or to NULL_TREE.
>  */
>  
>  bool
>  can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out,
>                                  tree_code *code_out,
> -                                tree (&ops)[3], tree *else_out)
> +                                tree (&ops)[3], tree *else_out,
> +                                tree *len, tree *bias)
>  {
> +  *len = NULL_TREE;
> +  *bias = NULL_TREE;
>    if (gassign *assign = dyn_cast <gassign *> (stmt))
>      {
>        *cond_out = NULL_TREE;
> @@ -4462,18 +4495,28 @@ can_interpret_as_conditional_op_p (gimple *stmt, tree 
> *cond_out,
>        {
>       internal_fn ifn = gimple_call_internal_fn (call);
>       tree_code code = conditional_internal_fn_code (ifn);
> +     int len_index = internal_fn_len_index (ifn);
> +     int cond_nargs = len_index >= 0 ? 4 : 2;
>       if (code != ERROR_MARK)
>         {
>           *cond_out = gimple_call_arg (call, 0);
>           *code_out = code;
> -         unsigned int nops = gimple_call_num_args (call) - 2;
> +         unsigned int nops = gimple_call_num_args (call) - cond_nargs;
>           for (unsigned int i = 0; i < 3; ++i)
>             ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE;
>           *else_out = gimple_call_arg (call, nops + 1);
> -         if (integer_truep (*cond_out))
> +         if (len_index < 0)
> +           {
> +             if (integer_truep (*cond_out))
> +               {
> +                 *cond_out = NULL_TREE;
> +                 *else_out = NULL_TREE;
> +               }
> +           }
> +         else
>             {
> -             *cond_out = NULL_TREE;
> -             *else_out = NULL_TREE;
> +             *len = gimple_call_arg (call, len_index);
> +             *bias = gimple_call_arg (call, len_index + 1);
>             }
>           return true;
>         }
> @@ -4561,8 +4604,32 @@ internal_fn_len_index (internal_fn fn)
>  
>      case IFN_LEN_MASK_GATHER_LOAD:
>      case IFN_LEN_MASK_SCATTER_STORE:
> +    case IFN_COND_LEN_FMA:
> +    case IFN_COND_LEN_FMS:
> +    case IFN_COND_LEN_FNMA:
> +    case IFN_COND_LEN_FNMS:
>        return 5;
>  
> +    case IFN_COND_LEN_ADD:
> +    case IFN_COND_LEN_SUB:
> +    case IFN_COND_LEN_MUL:
> +    case IFN_COND_LEN_DIV:
> +    case IFN_COND_LEN_MOD:
> +    case IFN_COND_LEN_RDIV:
> +    case IFN_COND_LEN_MIN:
> +    case IFN_COND_LEN_MAX:
> +    case IFN_COND_LEN_FMIN:
> +    case IFN_COND_LEN_FMAX:
> +    case IFN_COND_LEN_AND:
> +    case IFN_COND_LEN_IOR:
> +    case IFN_COND_LEN_XOR:
> +    case IFN_COND_LEN_SHL:
> +    case IFN_COND_LEN_SHR:
> +      return 4;
> +
> +    case IFN_COND_LEN_NEG:
> +      return 3;
> +
>      default:
>        return -1;
>      }
> diff --git a/gcc/internal-fn.h b/gcc/internal-fn.h
> index dd1bab0bddf..a5c3f4765ff 100644
> --- a/gcc/internal-fn.h
> +++ b/gcc/internal-fn.h
> @@ -229,7 +229,7 @@ extern tree_code conditional_internal_fn_code 
> (internal_fn);
>  extern internal_fn get_unconditional_internal_fn (internal_fn);
>  extern bool can_interpret_as_conditional_op_p (gimple *, tree *,
>                                              tree_code *, tree (&)[3],
> -                                            tree *);
> +                                            tree *, tree *, tree *);
>  
>  extern bool internal_load_fn_p (internal_fn);
>  extern bool internal_store_fn_p (internal_fn);
> diff --git a/gcc/tree-ssa-math-opts.cc b/gcc/tree-ssa-math-opts.cc
> index 68fc518b1ab..712097ac5be 100644
> --- a/gcc/tree-ssa-math-opts.cc
> +++ b/gcc/tree-ssa-math-opts.cc
> @@ -3099,10 +3099,11 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, 
> tree op2)
>         negate_p = true;
>       }
>  
> -      tree cond, else_value, ops[3];
> +      tree cond, else_value, ops[3], len, bias;
>        tree_code code;
>        if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code,
> -                                           ops, &else_value))
> +                                           ops, &else_value,
> +                                           &len, &bias))
>       gcc_unreachable ();
>        addop = ops[0] == result ? ops[1] : ops[0];
>  
> @@ -3122,7 +3123,11 @@ convert_mult_to_fma_1 (tree mul_result, tree op1, tree 
> op2)
>        if (seq)
>       gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT);
>  
> -      if (cond)
> +      if (len)
> +     fma_stmt
> +       = gimple_build_call_internal (IFN_COND_LEN_FMA, 7, cond, mulop1, op2,
> +                                     addop, else_value, len, bias);
> +      else if (cond)
>       fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1,
>                                              op2, addop, else_value);
>        else
> @@ -3307,7 +3312,8 @@ last_fma_candidate_feeds_initial_phi 
> (fma_deferring_state *state,
>  
>  static bool
>  convert_mult_to_fma (gimple *mul_stmt, tree op1, tree op2,
> -                  fma_deferring_state *state, tree mul_cond = NULL_TREE)
> +                  fma_deferring_state *state, tree mul_cond = NULL_TREE,
> +                  tree mul_len = NULL_TREE, tree mul_bias = NULL_TREE)
>  {
>    tree mul_result = gimple_get_lhs (mul_stmt);
>    /* If there isn't a LHS then this can't be an FMA.  There can be no LHS
> @@ -3420,10 +3426,10 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree 
> op2,
>         negate_p = seen_negate_p = true;
>       }
>  
> -      tree cond, else_value, ops[3];
> +      tree cond, else_value, ops[3], len, bias;
>        tree_code code;
>        if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops,
> -                                           &else_value))
> +                                           &else_value, &len, &bias))
>       return false;
>  
>        switch (code)
> @@ -3439,15 +3445,49 @@ convert_mult_to_fma (gimple *mul_stmt, tree op1, tree 
> op2,
>         return false;
>       }
>  
> -      if (mul_cond && cond != mul_cond)
> -     return false;
> -
> -      if (cond)
> +      if (len)
>       {
> -       if (cond == result || else_value == result)
> +       /* For COND_LEN_* operations, we may have a dummy mask which is
> +          the all-true mask.  In that case mul_cond and cond may be
> +          different trees, but we still consider them equal.  */
> +       if (mul_cond && cond != mul_cond
> +           && !(integer_truep (mul_cond) && integer_truep (cond)))
>           return false;
> -       if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type))
> +
> +       if (else_value == result)
> +         return false;
> +
> +       if (!direct_internal_fn_supported_p (IFN_COND_LEN_FMA, type,
> +                                            opt_type))
>           return false;
> +
> +       if (mul_len)
> +         {
> +           poly_int64 mul_value, value;
> +           if (poly_int_tree_p (mul_len, &mul_value)
> +               && poly_int_tree_p (len, &value)
> +               && maybe_ne (mul_value, value))
> +             return false;
> +           else if (mul_len != len)
> +             return false;
> +
> +           if (wi::to_widest (mul_bias) != wi::to_widest (bias))
> +             return false;
> +         }
> +     }
> +      else
> +     {
> +       if (mul_cond && cond != mul_cond)
> +         return false;
> +
> +       if (cond)
> +         {
> +           if (cond == result || else_value == result)
> +             return false;
> +           if (!direct_internal_fn_supported_p (IFN_COND_FMA, type,
> +                                                opt_type))
> +             return false;
> +         }
>       }
>  
>        /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that
> @@ -5632,6 +5672,22 @@ math_opts_dom_walker::after_dom_children (basic_block 
> bb)
>               }
>             break;
>  
> +         case CFN_COND_LEN_MUL:
> +           if (convert_mult_to_fma (stmt,
> +                                    gimple_call_arg (stmt, 1),
> +                                    gimple_call_arg (stmt, 2),
> +                                    &fma_state,
> +                                    gimple_call_arg (stmt, 0),
> +                                    gimple_call_arg (stmt, 4),
> +                                    gimple_call_arg (stmt, 5)))
> +
> +             {
> +               gsi_remove (&gsi, true);
> +               release_defs (stmt);
> +               continue;
> +             }
> +           break;
> +
>           case CFN_LAST:
>             cancel_fma_deferring (&fma_state);
>             break;
> 

-- 
Richard Biener <rguent...@suse.de>
SUSE Software Solutions Germany GmbH, Frankenstrasse 146, 90461 Nuernberg,
Germany; GF: Ivo Totev, Andrew Myers, Andrew McDonald, Boudien Moerman;
HRB 36809 (AG Nuernberg)

Reply via email to