Committed, thanks Juzhe.
--------------
Li Xu
>Thanks a lot. LGTM.
>
>
>
>juzhe.zh...@rivai.ai
>
>From: Li Xu
>Date: 2023-09-21 11:12
>To: gcc-patches
>CC: kito.cheng; palmer; juzhe.zhong; xuli
>Subject: [PATCH] RISC-V: Optimized for strided load/store with stride == element width [PR111450]
>From: xuli <xu...@eswincomputing.com>
>
>When stride == element width, vlse.v should be optimized into vle.v,
>and vsse.v should be optimized into vse.v.
>
>PR target/111450
>
>gcc/ChangeLog:
>
>* config/riscv/constraints.md (c01): const_int 1.
>(c02): const_int 2.
>(c04): const_int 4.
>(c08): const_int 8.
>* config/riscv/predicates.md (vector_eew8_stride_operand): New predicate for 
>stride operand.
>(vector_eew16_stride_operand): Ditto.
>(vector_eew32_stride_operand): Ditto.
>(vector_eew64_stride_operand): Ditto.
>* config/riscv/vector-iterators.md: New iterator for stride operand.
>* config/riscv/vector.md: Add stride = element width constraint.
>
>gcc/testsuite/ChangeLog:
>
>* gcc.target/riscv/rvv/base/pr111450.c: New test.
>---
>gcc/config/riscv/constraints.md               |  20 ++++
>gcc/config/riscv/predicates.md                |  18 ++++
>gcc/config/riscv/vector-iterators.md          |  87 +++++++++++++++
>gcc/config/riscv/vector.md                    |  42 +++++---
>.../gcc.target/riscv/rvv/base/pr111450.c      | 100 ++++++++++++++++++
>5 files changed, 250 insertions(+), 17 deletions(-)
>create mode 100644 gcc/testsuite/gcc.target/riscv/rvv/base/pr111450.c
>
>diff --git a/gcc/config/riscv/constraints.md b/gcc/config/riscv/constraints.md
>index 3f52bc76f67..964fdd450c9 100644
>--- a/gcc/config/riscv/constraints.md
>+++ b/gcc/config/riscv/constraints.md
>@@ -45,6 +45,26 @@
>   (and (match_code "const_int")
>        (match_test "ival == 0")))
>+(define_constraint "c01"
>+  "Constant value 1."
>+  (and (match_code "const_int")
>+       (match_test "ival == 1")))
>+
>+(define_constraint "c02"
>+  "Constant value 2"
>+  (and (match_code "const_int")
>+       (match_test "ival == 2")))
>+
>+(define_constraint "c04"
>+  "Constant value 4"
>+  (and (match_code "const_int")
>+       (match_test "ival == 4")))
>+
>+(define_constraint "c08"
>+  "Constant value 8"
>+  (and (match_code "const_int")
>+       (match_test "ival == 8")))
>+
>(define_constraint "K"
>   "A 5-bit unsigned immediate for CSR access instructions."
>   (and (match_code "const_int")
>diff --git a/gcc/config/riscv/predicates.md b/gcc/config/riscv/predicates.md
>index 4bc7ff2c9d8..7845998e430 100644
>--- a/gcc/config/riscv/predicates.md
>+++ b/gcc/config/riscv/predicates.md
>@@ -514,6 +514,24 @@
>   (ior (match_operand 0 "const_0_operand")
>        (match_operand 0 "pmode_register_operand")))
>+;; [1, 2, 4, 8] means strided load/store with stride == element width
>+(define_special_predicate "vector_eew8_stride_operand"
>+  (ior (match_operand 0 "pmode_register_operand")
>+       (and (match_code "const_int")
>+            (match_test "INTVAL (op) == 1 || INTVAL (op) == 0"))))
>+(define_special_predicate "vector_eew16_stride_operand"
>+  (ior (match_operand 0 "pmode_register_operand")
>+       (and (match_code "const_int")
>+            (match_test "INTVAL (op) == 2 || INTVAL (op) == 0"))))
>+(define_special_predicate "vector_eew32_stride_operand"
>+  (ior (match_operand 0 "pmode_register_operand")
>+       (and (match_code "const_int")
>+            (match_test "INTVAL (op) == 4 || INTVAL (op) == 0"))))
>+(define_special_predicate "vector_eew64_stride_operand"
>+  (ior (match_operand 0 "pmode_register_operand")
>+       (and (match_code "const_int")
>+            (match_test "INTVAL (op) == 8 || INTVAL (op) == 0"))))
>+
>;; A special predicate that doesn't match a particular mode.
>(define_special_predicate "vector_any_register_operand"
>   (match_code "reg"))
>diff --git a/gcc/config/riscv/vector-iterators.md 
>b/gcc/config/riscv/vector-iterators.md
>index 73df55a69c8..f85d1cc80d1 100644
>--- a/gcc/config/riscv/vector-iterators.md
>+++ b/gcc/config/riscv/vector-iterators.md
>@@ -2596,6 +2596,93 @@
>   (V512DI "V512BI")
>])
>+(define_mode_attr stride_predicate [
>+  (RVVM8QI "vector_eew8_stride_operand") (RVVM4QI 
>"vector_eew8_stride_operand")
>+  (RVVM2QI "vector_eew8_stride_operand") (RVVM1QI 
>"vector_eew8_stride_operand")
>+  (RVVMF2QI "vector_eew8_stride_operand") (RVVMF4QI 
>"vector_eew8_stride_operand")
>+  (RVVMF8QI "vector_eew8_stride_operand")
>+
>+  (RVVM8HI "vector_eew16_stride_operand") (RVVM4HI 
>"vector_eew16_stride_operand")
>+  (RVVM2HI "vector_eew16_stride_operand") (RVVM1HI 
>"vector_eew16_stride_operand")
>+  (RVVMF2HI "vector_eew16_stride_operand") (RVVMF4HI 
>"vector_eew16_stride_operand")
>+
>+  (RVVM8HF "vector_eew16_stride_operand") (RVVM4HF 
>"vector_eew16_stride_operand")
>+  (RVVM2HF "vector_eew16_stride_operand") (RVVM1HF 
>"vector_eew16_stride_operand")
>+  (RVVMF2HF "vector_eew16_stride_operand") (RVVMF4HF 
>"vector_eew16_stride_operand")
>+
>+  (RVVM8SI "vector_eew32_stride_operand") (RVVM4SI 
>"vector_eew32_stride_operand")
>+  (RVVM2SI "vector_eew32_stride_operand") (RVVM1SI 
>"vector_eew32_stride_operand")
>+  (RVVMF2SI "vector_eew32_stride_operand")
>+
>+  (RVVM8SF "vector_eew32_stride_operand") (RVVM4SF 
>"vector_eew32_stride_operand")
>+  (RVVM2SF "vector_eew32_stride_operand") (RVVM1SF 
>"vector_eew32_stride_operand")
>+  (RVVMF2SF "vector_eew32_stride_operand")
>+
>+  (RVVM8DI "vector_eew64_stride_operand") (RVVM4DI 
>"vector_eew64_stride_operand")
>+  (RVVM2DI "vector_eew64_stride_operand") (RVVM1DI 
>"vector_eew64_stride_operand")
>+
>+  (RVVM8DF "vector_eew64_stride_operand") (RVVM4DF 
>"vector_eew64_stride_operand")
>+  (RVVM2DF "vector_eew64_stride_operand") (RVVM1DF 
>"vector_eew64_stride_operand")
>+])
>+
>+(define_mode_attr stride_load_constraint [
>+  (RVVM8QI "rJ,rJ,rJ,c01,c01,c01") (RVVM4QI "rJ,rJ,rJ,c01,c01,c01")
>+  (RVVM2QI "rJ,rJ,rJ,c01,c01,c01") (RVVM1QI "rJ,rJ,rJ,c01,c01,c01")
>+  (RVVMF2QI "rJ,rJ,rJ,c01,c01,c01") (RVVMF4QI "rJ,rJ,rJ,c01,c01,c01")
>+  (RVVMF8QI "rJ,rJ,rJ,c01,c01,c01")
>+
>+  (RVVM8HI "rJ,rJ,rJ,c02,c02,c02") (RVVM4HI "rJ,rJ,rJ,c02,c02,c02")
>+  (RVVM2HI "rJ,rJ,rJ,c02,c02,c02") (RVVM1HI "rJ,rJ,rJ,c02,c02,c02")
>+  (RVVMF2HI "rJ,rJ,rJ,c02,c02,c02") (RVVMF4HI "rJ,rJ,rJ,c02,c02,c02")
>+
>+  (RVVM8HF "rJ,rJ,rJ,c02,c02,c02") (RVVM4HF "rJ,rJ,rJ,c02,c02,c02")
>+  (RVVM2HF "rJ,rJ,rJ,c02,c02,c02") (RVVM1HF "rJ,rJ,rJ,c02,c02,c02")
>+  (RVVMF2HF "rJ,rJ,rJ,c02,c02,c02") (RVVMF4HF "rJ,rJ,rJ,c02,c02,c02")
>+
>+  (RVVM8SI "rJ,rJ,rJ,c04,c04,c04") (RVVM4SI "rJ,rJ,rJ,c04,c04,c04")
>+  (RVVM2SI "rJ,rJ,rJ,c04,c04,c04") (RVVM1SI "rJ,rJ,rJ,c04,c04,c04")
>+  (RVVMF2SI "rJ,rJ,rJ,c04,c04,c04")
>+
>+  (RVVM8SF "rJ,rJ,rJ,c04,c04,c04") (RVVM4SF "rJ,rJ,rJ,c04,c04,c04")
>+  (RVVM2SF "rJ,rJ,rJ,c04,c04,c04") (RVVM1SF "rJ,rJ,rJ,c04,c04,c04")
>+  (RVVMF2SF "rJ,rJ,rJ,c04,c04,c04")
>+
>+  (RVVM8DI "rJ,rJ,rJ,c08,c08,c08") (RVVM4DI "rJ,rJ,rJ,c08,c08,c08")
>+  (RVVM2DI "rJ,rJ,rJ,c08,c08,c08") (RVVM1DI "rJ,rJ,rJ,c08,c08,c08")
>+
>+  (RVVM8DF "rJ,rJ,rJ,c08,c08,c08") (RVVM4DF "rJ,rJ,rJ,c08,c08,c08")
>+  (RVVM2DF "rJ,rJ,rJ,c08,c08,c08") (RVVM1DF "rJ,rJ,rJ,c08,c08,c08")
>+])
>+
>+(define_mode_attr stride_store_constraint [
>+  (RVVM8QI "rJ,c01") (RVVM4QI "rJ,c01")
>+  (RVVM2QI "rJ,c01") (RVVM1QI "rJ,c01")
>+  (RVVMF2QI "rJ,c01") (RVVMF4QI "rJ,c01")
>+  (RVVMF8QI "rJ,c01")
>+
>+  (RVVM8HI "rJ,c02") (RVVM4HI "rJ,c02")
>+  (RVVM2HI "rJ,c02") (RVVM1HI "rJ,c02")
>+  (RVVMF2HI "rJ,c02") (RVVMF4HI "rJ,c02")
>+
>+  (RVVM8HF "rJ,c02") (RVVM4HF "rJ,c02")
>+  (RVVM2HF "rJ,c02") (RVVM1HF "rJ,c02")
>+  (RVVMF2HF "rJ,c02") (RVVMF4HF "rJ,c02")
>+
>+  (RVVM8SI "rJ,c04") (RVVM4SI "rJ,c04")
>+  (RVVM2SI "rJ,c04") (RVVM1SI "rJ,c04")
>+  (RVVMF2SI "rJ,c04")
>+
>+  (RVVM8SF "rJ,c04") (RVVM4SF "rJ,c04")
>+  (RVVM2SF "rJ,c04") (RVVM1SF "rJ,c04")
>+  (RVVMF2SF "rJ,c04")
>+
>+  (RVVM8DI "rJ,c08") (RVVM4DI "rJ,c08")
>+  (RVVM2DI "rJ,c08") (RVVM1DI "rJ,c08")
>+
>+  (RVVM8DF "rJ,c08") (RVVM4DF "rJ,c08")
>+  (RVVM2DF "rJ,c08") (RVVM1DF "rJ,c08")
>+])
>+
>(define_mode_attr gs_extension [
>   (RVVM8QI "const_1_operand") (RVVM4QI "vector_gs_extension_operand")
>   (RVVM2QI "immediate_operand") (RVVM1QI "immediate_operand") (RVVMF2QI 
>"immediate_operand")
>diff --git a/gcc/config/riscv/vector.md b/gcc/config/riscv/vector.md
>index f66ffebba24..5595789b3bb 100644
>--- a/gcc/config/riscv/vector.md
>+++ b/gcc/config/riscv/vector.md
>@@ -2083,40 +2083,48 @@
>;; 
>-------------------------------------------------------------------------------
>(define_insn "@pred_strided_load<mode>"
>-  [(set (match_operand:V 0 "register_operand"              "=vr,    vr,    
>vd")
>+  [(set (match_operand:V 0 "register_operand"              "=vr,    vr,    
>vd,    vr,    vr,    vd")
>(if_then_else:V
>  (unspec:<VM>
>-     [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1,   Wc1,    vm")
>-      (match_operand 5 "vector_length_operand"    "   rK,    rK,    rK")
>-      (match_operand 6 "const_int_operand"        "    i,     i,     i")
>-      (match_operand 7 "const_int_operand"        "    i,     i,     i")
>-      (match_operand 8 "const_int_operand"        "    i,     i,     i")
>+     [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1,   Wc1,    vm,    
>vmWc1,   Wc1,    vm")
>+      (match_operand 5 "vector_length_operand"    "   rK,    rK,    rK,       
>rK,    rK,    rK")
>+      (match_operand 6 "const_int_operand"        "    i,     i,     i,       
> i,     i,     i")
>+      (match_operand 7 "const_int_operand"        "    i,     i,     i,       
> i,     i,     i")
>+      (match_operand 8 "const_int_operand"        "    i,     i,     i,       
> i,     i,     i")
>     (reg:SI VL_REGNUM)
>     (reg:SI VTYPE_REGNUM)] UNSPEC_VPREDICATE)
>  (unspec:V
>-     [(match_operand:V 3 "memory_operand"         "    m,     m,     m")
>-      (match_operand 4 "pmode_reg_or_0_operand"   "   rJ,    rJ,    rJ")] 
>UNSPEC_STRIDED)
>-   (match_operand:V 2 "vector_merge_operand"      "    0,    vu,    vu")))]
>+     [(match_operand:V 3 "memory_operand"         "     m,     m,     m,    
>m,     m,     m")
>+      (match_operand 4 "<V:stride_predicate>"     
>"<V:stride_load_constraint>")] UNSPEC_STRIDED)
>+   (match_operand:V 2 "vector_merge_operand"      "     0,    vu,    vu,    
>0,    vu,    vu")))]
>   "TARGET_VECTOR"
>-  "vlse<sew>.v\t%0,%3,%z4%p1"
>+  "@
>+  vlse<sew>.v\t%0,%3,%z4%p1
>+  vlse<sew>.v\t%0,%3,%z4
>+  vlse<sew>.v\t%0,%3,%z4,%1.t
>+  vle<sew>.v\t%0,%3%p1
>+  vle<sew>.v\t%0,%3
>+  vle<sew>.v\t%0,%3,%1.t"
>   [(set_attr "type" "vlds")
>    (set_attr "mode" "<MODE>")])
>(define_insn "@pred_strided_store<mode>"
>-  [(set (match_operand:V 0 "memory_operand"                 "+m")
>+  [(set (match_operand:V 0 "memory_operand"                 "+m,    m")
>(if_then_else:V
>  (unspec:<VM>
>-     [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1")
>-      (match_operand 4 "vector_length_operand"    "   rK")
>-      (match_operand 5 "const_int_operand"        "    i")
>+     [(match_operand:<VM> 1 "vector_mask_operand" "vmWc1,    vmWc1")
>+      (match_operand 4 "vector_length_operand"    "   rK,       rK")
>+      (match_operand 5 "const_int_operand"        "    i,        i")
>     (reg:SI VL_REGNUM)
>     (reg:SI VTYPE_REGNUM)] UNSPEC_VPREDICATE)
>  (unspec:V
>-     [(match_operand 2 "pmode_reg_or_0_operand"   "   rJ")
>-      (match_operand:V 3 "register_operand"       "   vr")] UNSPEC_STRIDED)
>+     [(match_operand 2 "<V:stride_predicate>"     
>"<V:stride_store_constraint>")
>+      (match_operand:V 3 "register_operand"       "   vr,       vr")] 
>UNSPEC_STRIDED)
>  (match_dup 0)))]
>   "TARGET_VECTOR"
>-  "vsse<sew>.v\t%3,%0,%z2%p1"
>+  "@
>+  vsse<sew>.v\t%3,%0,%z2%p1
>+  vse<sew>.v\t%3,%0%p1"
>   [(set_attr "type" "vsts")
>    (set_attr "mode" "<MODE>")
>    (set (attr "avl_type") (symbol_ref "INTVAL (operands[5])"))])
>diff --git a/gcc/testsuite/gcc.target/riscv/rvv/base/pr111450.c 
>b/gcc/testsuite/gcc.target/riscv/rvv/base/pr111450.c
>new file mode 100644
>index 00000000000..50aadcd2024
>--- /dev/null
>+++ b/gcc/testsuite/gcc.target/riscv/rvv/base/pr111450.c
>@@ -0,0 +1,100 @@
>+/* { dg-do compile } */
>+/* { dg-options "-march=rv32gcv_zvfh -mabi=ilp32d -O2" } */
>+/* { dg-final { check-function-bodies "**" "" } } */
>+
>+#include "riscv_vector.h"
>+
>+typedef _Float16 float16_t;
>+typedef float float32_t;
>+typedef double float64_t;
>+
>+/*
>+**foo:
>+** vsetvli\s+zero,\s*[a-z0-9]+,e8,m1,ta,ma
>+** vle8\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** vse8\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** ret
>+*/
>+void foo (int8_t *in, int8_t *out, int n)
>+{
>+    vint8m1_t v = __riscv_vlse8_v_i8m1 (in, 1, n);
>+    __riscv_vsse8_v_i8m1 (out, 1, v, n);
>+}
>+
>+/*
>+**foo1:
>+** vsetvli\s+zero,\s*[a-z0-9]+,e16,m1,ta,ma
>+** vle16\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** vse16\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** ret
>+*/
>+void foo1 (int16_t *in, int16_t *out, int n)
>+{
>+    vint16m1_t v = __riscv_vlse16_v_i16m1 (in, 2, n);
>+    __riscv_vsse16_v_i16m1 (out, 2, v, n);
>+}
>+
>+/*
>+**foo2:
>+** vsetvli\s+zero,\s*[a-z0-9]+,e32,m1,ta,ma
>+** vle32\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** vse32\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** ret
>+*/
>+void foo2 (int32_t *in, int32_t *out, int n)
>+{
>+    vint32m1_t v = __riscv_vlse32_v_i32m1 (in, 4, n);
>+    __riscv_vsse32_v_i32m1 (out, 4, v, n);
>+}
>+
>+/*
>+**foo3:
>+** vsetvli\s+zero,\s*[a-z0-9]+,e64,m1,ta,ma
>+** vle64\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** vse64\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** ret
>+*/
>+void foo3 (int64_t *in, int64_t *out, int n)
>+{
>+    vint64m1_t v = __riscv_vlse64_v_i64m1 (in, 8, n);
>+    __riscv_vsse64_v_i64m1 (out, 8, v, n);
>+}
>+
>+/*
>+**foo4:
>+** vsetvli\s+zero,\s*[a-z0-9]+,e16,mf2,ta,ma
>+** vle16\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** vse16\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** ret
>+*/
>+void foo4 (float16_t *in, float16_t *out, int n)
>+{
>+    vfloat16mf2_t v = __riscv_vlse16_v_f16mf2 (in, 2, n);
>+    __riscv_vsse16_v_f16mf2 (out, 2, v, n);
>+}
>+
>+/*
>+**foo5:
>+** vsetvli\s+zero,\s*[a-z0-9]+,e32,m1,ta,ma
>+** vle32\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** vse32\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** ret
>+*/
>+void foo5 (float32_t *in, float32_t *out, int n)
>+{
>+    vfloat32m1_t v = __riscv_vlse32_v_f32m1 (in, 4, n);
>+    __riscv_vsse32_v_f32m1 (out, 4, v, n);
>+}
>+
>+/*
>+**foo6:
>+** vsetvli\s+zero,\s*[a-z0-9]+,e64,m1,ta,ma
>+** vle64\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** vse64\.v\s+v[0-9]+,\s*0\([a-x0-9]+\)
>+** ret
>+*/
>+void foo6 (float64_t *in, float64_t *out, int n)
>+{
>+    vfloat64m1_t v = __riscv_vlse64_v_f64m1 (in, 8, n);
>+    __riscv_vsse64_v_f64m1 (out, 8, v, n);
>+}
>--
>2.17.1
>
>

Reply via email to