committed, thanks!

On Sat, Jan 28, 2023 at 7:26 AM <juzhe.zh...@rivai.ai> wrote:

> From: Ju-Zhe Zhong <juzhe.zh...@rivai.ai>
>
> gcc/ChangeLog:
>
>         * config/riscv/predicates.md (pmode_reg_or_0_operand): New
> predicate.
>         * config/riscv/riscv-vector-builtins-bases.cc (class loadstore):
> Support vlse/vsse.
>         (BASE): Ditto.
>         * config/riscv/riscv-vector-builtins-bases.h: Ditto.
>         * config/riscv/riscv-vector-builtins-functions.def (vlse): New
> class.
>         (vsse): New class.
>         * config/riscv/riscv-vector-builtins.cc
> (function_expander::use_contiguous_load_insn): Support vlse/vsse.
>         * config/riscv/vector.md (@pred_strided_load<mode>): New md
> pattern.
>         (@pred_strided_store<mode>): Ditto.
>
> ---
>  gcc/config/riscv/predicates.md                |  4 +
>  .../riscv/riscv-vector-builtins-bases.cc      | 26 +++++-
>  .../riscv/riscv-vector-builtins-bases.h       |  2 +
>  .../riscv/riscv-vector-builtins-functions.def |  2 +
>  gcc/config/riscv/riscv-vector-builtins.cc     | 33 ++++++-
>  gcc/config/riscv/vector.md                    | 90 +++++++++++++++++--
>  6 files changed, 143 insertions(+), 14 deletions(-)
>
> diff --git a/gcc/config/riscv/predicates.md
> b/gcc/config/riscv/predicates.md
> index 766a427570c..f9013bbf8bb 100644
> --- a/gcc/config/riscv/predicates.md
> +++ b/gcc/config/riscv/predicates.md
> @@ -286,6 +286,10 @@
>             (match_test "GET_CODE (op) == UNSPEC
>                          && (XINT (op, 1) == UNSPEC_VUNDEF)"))))
>
> +(define_special_predicate "pmode_reg_or_0_operand"
> +  (ior (match_operand 0 "const_0_operand")
> +       (match_operand 0 "pmode_register_operand")))
> +
>  ;; The scalar operand can be directly broadcast by RVV instructions.
>  (define_predicate "direct_broadcast_operand"
>    (ior (match_operand 0 "register_operand")
> diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.cc
> b/gcc/config/riscv/riscv-vector-builtins-bases.cc
> index cf6a060ddfb..f9a16c68e07 100644
> --- a/gcc/config/riscv/riscv-vector-builtins-bases.cc
> +++ b/gcc/config/riscv/riscv-vector-builtins-bases.cc
> @@ -84,8 +84,8 @@ public:
>    }
>  };
>
> -/* Implements vle.v/vse.v/vlm.v/vsm.v codegen.  */
> -template <bool STORE_P>
> +/* Implements vle.v/vse.v/vlm.v/vsm.v/vlse.v/vsse.v codegen.  */
> +template <bool STORE_P, bool STRIDED_P = false>
>  class loadstore : public function_base
>  {
>    unsigned int call_properties (const function_instance &) const override
> @@ -106,9 +106,23 @@ class loadstore : public function_base
>    rtx expand (function_expander &e) const override
>    {
>      if (STORE_P)
> -      return e.use_contiguous_store_insn (code_for_pred_store
> (e.vector_mode ()));
> +      {
> +       if (STRIDED_P)
> +         return e.use_contiguous_store_insn (
> +           code_for_pred_strided_store (e.vector_mode ()));
> +       else
> +         return e.use_contiguous_store_insn (
> +           code_for_pred_store (e.vector_mode ()));
> +      }
>      else
> -      return e.use_contiguous_load_insn (code_for_pred_mov (e.vector_mode
> ()));
> +      {
> +       if (STRIDED_P)
> +         return e.use_contiguous_load_insn (
> +           code_for_pred_strided_load (e.vector_mode ()));
> +       else
> +         return e.use_contiguous_load_insn (
> +           code_for_pred_mov (e.vector_mode ()));
> +      }
>    }
>  };
>
> @@ -118,6 +132,8 @@ static CONSTEXPR const loadstore<false> vle_obj;
>  static CONSTEXPR const loadstore<true> vse_obj;
>  static CONSTEXPR const loadstore<false> vlm_obj;
>  static CONSTEXPR const loadstore<true> vsm_obj;
> +static CONSTEXPR const loadstore<false, true> vlse_obj;
> +static CONSTEXPR const loadstore<true, true> vsse_obj;
>
>  /* Declare the function base NAME, pointing it to an instance
>     of class <NAME>_obj.  */
> @@ -130,5 +146,7 @@ BASE (vle)
>  BASE (vse)
>  BASE (vlm)
>  BASE (vsm)
> +BASE (vlse)
> +BASE (vsse)
>
>  } // end namespace riscv_vector
> diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.h
> b/gcc/config/riscv/riscv-vector-builtins-bases.h
> index 7af462b9530..93999e2cbee 100644
> --- a/gcc/config/riscv/riscv-vector-builtins-bases.h
> +++ b/gcc/config/riscv/riscv-vector-builtins-bases.h
> @@ -30,6 +30,8 @@ extern const function_base *const vle;
>  extern const function_base *const vse;
>  extern const function_base *const vlm;
>  extern const function_base *const vsm;
> +extern const function_base *const vlse;
> +extern const function_base *const vsse;
>  }
>
>  } // end namespace riscv_vector
> diff --git a/gcc/config/riscv/riscv-vector-builtins-functions.def
> b/gcc/config/riscv/riscv-vector-builtins-functions.def
> index 8bcaf2e3267..1ddde7b9d76 100644
> --- a/gcc/config/riscv/riscv-vector-builtins-functions.def
> +++ b/gcc/config/riscv/riscv-vector-builtins-functions.def
> @@ -44,5 +44,7 @@ DEF_RVV_FUNCTION (vle, loadstore, full_preds,
> all_v_scalar_const_ptr_ops)
>  DEF_RVV_FUNCTION (vse, loadstore, none_m_preds, all_v_scalar_ptr_ops)
>  DEF_RVV_FUNCTION (vlm, loadstore, none_preds, b_v_scalar_const_ptr_ops)
>  DEF_RVV_FUNCTION (vsm, loadstore, none_preds, b_v_scalar_ptr_ops)
> +DEF_RVV_FUNCTION (vlse, loadstore, full_preds,
> all_v_scalar_const_ptr_ptrdiff_ops)
> +DEF_RVV_FUNCTION (vsse, loadstore, none_m_preds,
> all_v_scalar_ptr_ptrdiff_ops)
>
>  #undef DEF_RVV_FUNCTION
> diff --git a/gcc/config/riscv/riscv-vector-builtins.cc
> b/gcc/config/riscv/riscv-vector-builtins.cc
> index 9023930560c..593a5f08e69 100644
> --- a/gcc/config/riscv/riscv-vector-builtins.cc
> +++ b/gcc/config/riscv/riscv-vector-builtins.cc
> @@ -167,6 +167,19 @@ static CONSTEXPR const rvv_arg_type_info
> scalar_ptr_args[]
>    = {rvv_arg_type_info (RVV_BASE_scalar_ptr),
>       rvv_arg_type_info (RVV_BASE_vector), rvv_arg_type_info_end};
>
> +/* A list of args for vector_type func (const scalar_type *, ptrdiff_t)
> + * function.  */
> +static CONSTEXPR const rvv_arg_type_info scalar_const_ptr_ptrdiff_args[]
> +  = {rvv_arg_type_info (RVV_BASE_scalar_const_ptr),
> +     rvv_arg_type_info (RVV_BASE_ptrdiff), rvv_arg_type_info_end};
> +
> +/* A list of args for void func (scalar_type *, ptrdiff_t, vector_type)
> + * function.  */
> +static CONSTEXPR const rvv_arg_type_info scalar_ptr_ptrdiff_args[]
> +  = {rvv_arg_type_info (RVV_BASE_scalar_ptr),
> +     rvv_arg_type_info (RVV_BASE_ptrdiff), rvv_arg_type_info
> (RVV_BASE_vector),
> +     rvv_arg_type_info_end};
> +
>  /* A list of none preds that will be registered for intrinsic functions.
> */
>  static CONSTEXPR const predication_type_index none_preds[]
>    = {PRED_TYPE_none, NUM_PRED_TYPES};
> @@ -227,6 +240,22 @@ static CONSTEXPR const rvv_op_info b_v_scalar_ptr_ops
>       rvv_arg_type_info (RVV_BASE_void), /* Return type */
>       scalar_ptr_args /* Args */};
>
> +/* A static operand information for vector_type func (const scalar_type *,
> + * ptrdiff_t) function registration. */
> +static CONSTEXPR const rvv_op_info all_v_scalar_const_ptr_ptrdiff_ops
> +  = {all_ops,                            /* Types */
> +     OP_TYPE_v,                                  /* Suffix */
> +     rvv_arg_type_info (RVV_BASE_vector), /* Return type */
> +     scalar_const_ptr_ptrdiff_args /* Args */};
> +
> +/* A static operand information for void func (scalar_type *, ptrdiff_t,
> + * vector_type) function registration. */
> +static CONSTEXPR const rvv_op_info all_v_scalar_ptr_ptrdiff_ops
> +  = {all_ops,                          /* Types */
> +     OP_TYPE_v,                                /* Suffix */
> +     rvv_arg_type_info (RVV_BASE_void), /* Return type */
> +     scalar_ptr_ptrdiff_args /* Args */};
> +
>  /* A list of all RVV intrinsic functions.  */
>  static function_group_info function_groups[] = {
>  #define DEF_RVV_FUNCTION(NAME, SHAPE, PREDS, OPS_INFO)
>      \
> @@ -921,7 +950,9 @@ function_expander::use_contiguous_load_insn (insn_code
> icode)
>        add_input_operand (Pmode, get_tail_policy_for_pred (pred));
>        add_input_operand (Pmode, get_mask_policy_for_pred (pred));
>      }
> -  add_input_operand (Pmode, get_avl_type_rtx (avl_type::NONVLMAX));
> +
> +  if (opno != insn_data[icode].n_generator_args)
> +    add_input_operand (Pmode, get_avl_type_rtx (avl_type::NONVLMAX));
>
>    return generate_insn (icode);
>  }
> diff --git a/gcc/config/riscv/vector.md b/gcc/config/riscv/vector.md
> index 4319266974d..1453be116a9 100644
> --- a/gcc/config/riscv/vector.md
> +++ b/gcc/config/riscv/vector.md
> @@ -33,6 +33,7 @@
>    UNSPEC_VUNDEF
>    UNSPEC_VPREDICATE
>    UNSPEC_VLMAX
> +  UNSPEC_STRIDED
>  ])
>
>  (define_constants [
> @@ -204,28 +205,56 @@
>
>  ;; The index of operand[] to get the avl op.
>  (define_attr "vl_op_idx" ""
> -       (cond [(eq_attr "type"
> "vlde,vste,vimov,vfmov,vldm,vstm,vlds,vmalu")
> -        (const_int 4)]
> -       (const_int INVALID_ATTRIBUTE)))
> +  (cond [(eq_attr "type" "vlde,vste,vimov,vfmov,vldm,vstm,vmalu,vsts")
> +          (const_int 4)
> +
> +        ;; If operands[3] of "vlds" is not vector mode, it is
> pred_broadcast.
> +        ;; whereas it is pred_strided_load if operands[3] is vector mode.
> +         (eq_attr "type" "vlds")
> +          (if_then_else (match_test "VECTOR_MODE_P (GET_MODE
> (operands[3]))")
> +             (const_int 5)
> +             (const_int 4))]
> +  (const_int INVALID_ATTRIBUTE)))
>
>  ;; The tail policy op value.
>  (define_attr "ta" ""
> -  (cond [(eq_attr "type" "vlde,vimov,vfmov,vlds")
> -          (symbol_ref "riscv_vector::get_ta(operands[5])")]
> +  (cond [(eq_attr "type" "vlde,vimov,vfmov")
> +          (symbol_ref "riscv_vector::get_ta(operands[5])")
> +
> +        ;; If operands[3] of "vlds" is not vector mode, it is
> pred_broadcast.
> +        ;; whereas it is pred_strided_load if operands[3] is vector mode.
> +        (eq_attr "type" "vlds")
> +          (if_then_else (match_test "VECTOR_MODE_P (GET_MODE
> (operands[3]))")
> +            (symbol_ref "riscv_vector::get_ta(operands[6])")
> +            (symbol_ref "riscv_vector::get_ta(operands[5])"))]
>         (const_int INVALID_ATTRIBUTE)))
>
>  ;; The mask policy op value.
>  (define_attr "ma" ""
> -  (cond [(eq_attr "type" "vlde,vlds")
> -          (symbol_ref "riscv_vector::get_ma(operands[6])")]
> +  (cond [(eq_attr "type" "vlde")
> +          (symbol_ref "riscv_vector::get_ma(operands[6])")
> +
> +        ;; If operands[3] of "vlds" is not vector mode, it is
> pred_broadcast.
> +        ;; whereas it is pred_strided_load if operands[3] is vector mode.
> +        (eq_attr "type" "vlds")
> +          (if_then_else (match_test "VECTOR_MODE_P (GET_MODE
> (operands[3]))")
> +            (symbol_ref "riscv_vector::get_ma(operands[7])")
> +            (symbol_ref "riscv_vector::get_ma(operands[6])"))]
>         (const_int INVALID_ATTRIBUTE)))
>
>  ;; The avl type value.
>  (define_attr "avl_type" ""
> -  (cond [(eq_attr "type"
> "vlde,vlde,vste,vimov,vimov,vimov,vfmov,vlds,vlds")
> +  (cond [(eq_attr "type" "vlde,vlde,vste,vimov,vimov,vimov,vfmov")
>            (symbol_ref "INTVAL (operands[7])")
>          (eq_attr "type" "vldm,vstm,vimov,vmalu,vmalu")
> -          (symbol_ref "INTVAL (operands[5])")]
> +          (symbol_ref "INTVAL (operands[5])")
> +
> +        ;; If operands[3] of "vlds" is not vector mode, it is
> pred_broadcast.
> +        ;; whereas it is pred_strided_load if operands[3] is vector mode.
> +        (eq_attr "type" "vlds")
> +          (if_then_else (match_test "VECTOR_MODE_P (GET_MODE
> (operands[3]))")
> +            (const_int INVALID_ATTRIBUTE)
> +            (symbol_ref "INTVAL (operands[7])"))]
>         (const_int INVALID_ATTRIBUTE)))
>
>  ;; -----------------------------------------------------------------
> @@ -760,3 +789,46 @@
>     vlse<sew>.v\t%0,%3,zero"
>    [(set_attr "type" "vimov,vfmov,vlds,vlds")
>     (set_attr "mode" "<MODE>")])
> +
> +;;
> -------------------------------------------------------------------------------
> +;; ---- Predicated Strided loads/stores
> +;;
> -------------------------------------------------------------------------------
> +;; Includes:
> +;; - 7.5. Vector Strided Instructions
> +;;
> -------------------------------------------------------------------------------
> +
> +(define_insn "@pred_strided_load<mode>"
> +  [(set (match_operand:V 0 "register_operand"              "=vr,    vr,
>   vd")
> +       (if_then_else:V
> +         (unspec:<VM>
> +           [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1,   Wc1,
> vm")
> +            (match_operand 5 "vector_length_operand"    "   rK,    rK,
> rK")
> +            (match_operand 6 "const_int_operand"        "    i,     i,
>  i")
> +            (match_operand 7 "const_int_operand"        "    i,     i,
>  i")
> +            (reg:SI VL_REGNUM)
> +            (reg:SI VTYPE_REGNUM)] UNSPEC_VPREDICATE)
> +         (unspec:V
> +           [(match_operand:V 3 "memory_operand"         "    m,     m,
>  m")
> +            (match_operand 4 "pmode_reg_or_0_operand"   "   rJ,    rJ,
> rJ")] UNSPEC_STRIDED)
> +         (match_operand:V 2 "vector_merge_operand"      "    0,    vu,
> vu")))]
> +  "TARGET_VECTOR"
> +  "vlse<sew>.v\t%0,%3,%z4%p1"
> +  [(set_attr "type" "vlds")
> +   (set_attr "mode" "<MODE>")])
> +
> +(define_insn "@pred_strided_store<mode>"
> +  [(set (match_operand:V 0 "memory_operand"                 "+m")
> +       (if_then_else:V
> +         (unspec:<VM>
> +           [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1")
> +            (match_operand 4 "vector_length_operand"    "   rK")
> +            (reg:SI VL_REGNUM)
> +            (reg:SI VTYPE_REGNUM)] UNSPEC_VPREDICATE)
> +         (unspec:V
> +           [(match_operand 2 "pmode_reg_or_0_operand"   "   rJ")
> +            (match_operand:V 3 "register_operand"       "   vr")]
> UNSPEC_STRIDED)
> +         (match_dup 0)))]
> +  "TARGET_VECTOR"
> +  "vsse<sew>.v\t%3,%0,%z2%p1"
> +  [(set_attr "type" "vsts")
> +   (set_attr "mode" "<MODE>")])
> --
> 2.36.3
>
>

Reply via email to