From: Lino Hsing-Yu Peng <[email protected]>

Implement Zvfofp8min narrowing conversions from BF16 to FP8, including
the saturating variants exposed by the RVV builtin layer. Update builtin
shapes, operand metadata, and float8 md patterns to cover these forms.

gcc/ChangeLog:

        * config/riscv/riscv-vector-builtins-bases.cc: Add f8 narrow and sat
        conversions.
        * config/riscv/riscv-vector-builtins-bases.h: Declare vfncvt_sat_f
        bases.
        * config/riscv/riscv-vector-builtins-functions.def: Add bf16-to-f8
        narrow builtins.
        * config/riscv/riscv-vector-builtins-shapes.cc: Add f8 narrow shapes
        and naming.
        * config/riscv/riscv-vector-builtins-shapes.h: Declare f8 narrow shapes.
        * config/riscv/riscv-vector-builtins.cc: Add bf16_to_f8 operand info.
        * config/riscv/vector-float8.md: Add f8 sat variants.
---
 .../riscv/riscv-vector-builtins-bases.cc      |  58 +++++++++-
 .../riscv/riscv-vector-builtins-bases.h       |   3 +
 .../riscv/riscv-vector-builtins-functions.def |   9 ++
 .../riscv/riscv-vector-builtins-shapes.cc     | 101 +++++++++++++++++-
 .../riscv/riscv-vector-builtins-shapes.h      |   4 +
 gcc/config/riscv/riscv-vector-builtins.cc     |   9 ++
 gcc/config/riscv/vector-float8.md             |  38 ++++++-
 7 files changed, 216 insertions(+), 6 deletions(-)

diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.cc b/gcc/config/riscv/riscv-vector-builtins-bases.cc
index 58ab57db5d4..3fa3ebb5d14 100644
--- a/gcc/config/riscv/riscv-vector-builtins-bases.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-bases.cc
@@ -1526,12 +1526,15 @@ enum altfmt
 static altfmt
 get_altfmt (const function_expander &e)
 {
-  if (e.shape == shapes::alu_f8e4m3)
+  if (e.shape == shapes::alu_f8e4m3 || e.shape == shapes::narrow_alu_f8e4m3
+      || e.shape == shapes::narrow_alu_frm_f8e4m3)
     return F8E4M3;
-  if (e.shape == shapes::alu_f8e5m2)
+  if (e.shape == shapes::alu_f8e5m2 || e.shape == shapes::narrow_alu_f8e5m2
+      || e.shape == shapes::narrow_alu_frm_f8e5m2)
     return F8E5M2;
   return F8NONE;
 }
+
 class vfwcvt_f : public function_base
 {
 public:
@@ -1604,7 +1607,19 @@ public:
   rtx expand (function_expander &e) const override
   {
     if (e.op_info->op == OP_TYPE_f_w)
-      return e.use_exact_insn (code_for_pred_trunc (e.vector_mode ()));
+      {
+       switch (get_altfmt (e))
+         {
+         case F8E4M3:
+           return e.use_exact_insn (
+             code_for_pred_trunc_to (e.vector_mode (), UNSPEC_F8E4M3));
+         case F8E5M2:
+           return e.use_exact_insn (
+             code_for_pred_trunc_to (e.vector_mode (), UNSPEC_F8E5M2));
+         default:
+           return e.use_exact_insn (code_for_pred_trunc (e.vector_mode ()));
+         }
+      }
     if (e.op_info->op == OP_TYPE_x_w)
       return e.use_exact_insn (code_for_pred_narrow (FLOAT, e.arg_mode (0)));
     if (e.op_info->op == OP_TYPE_xu_w)
@@ -1614,6 +1629,37 @@ public:
   }
 };
 
+template <enum frm_op_type FRM_OP = NO_FRM>
+class vfncvt_sat_f : public function_base /* Saturating narrow to FP8.  */
+{
+public:
+  bool has_rounding_mode_operand_p () const override
+  {
+    return FRM_OP == HAS_FRM; /* True only for the HAS_FRM variant.  */
+  }
+
+  bool may_require_frm_p () const override { return true; }
+
+  rtx expand (function_expander &e) const override
+  {
+    if (e.op_info->op == OP_TYPE_f_w)
+      {
+       switch (get_altfmt (e)) /* FP8 format is derived from e.shape.  */
+         {
+         case F8E4M3:
+           return e.use_exact_insn (
+             code_for_pred_trunc_to (e.vector_mode (), UNSPEC_F8E4M3_SAT));
+         case F8E5M2:
+           return e.use_exact_insn (
+             code_for_pred_trunc_to (e.vector_mode (), UNSPEC_F8E5M2_SAT));
+         default:
+           break; /* Not an FP8 shape: fall through to unreachable.  */
+         }
+      }
+    gcc_unreachable ();
+  }
+};
+
 class vfncvt_rod_f : public function_base
 {
 public:
@@ -2809,6 +2855,9 @@ static CONSTEXPR const vfwcvtbf16_f vfwcvtbf16_f_obj;
 /* Zvfbfwma; */
 static CONSTEXPR const vfwmaccbf16<NO_FRM> vfwmaccbf16_obj;
 static CONSTEXPR const vfwmaccbf16<HAS_FRM> vfwmaccbf16_frm_obj;
+/* Zvfofp8min */
+static CONSTEXPR const vfncvt_sat_f<NO_FRM> vfncvt_sat_f_obj;
+static CONSTEXPR const vfncvt_sat_f<HAS_FRM> vfncvt_sat_f_frm_obj;
 
 /* Declare the function base NAME, pointing it to an instance
    of class <NAME>_obj.  */
@@ -3137,4 +3186,7 @@ BASE (vfwcvtbf16_f)
 /* Zvfbfwma */
 BASE (vfwmaccbf16)
 BASE (vfwmaccbf16_frm)
+/* Zvfofp8min */
+BASE (vfncvt_sat_f)
+BASE (vfncvt_sat_f_frm)
 } // end namespace riscv_vector
diff --git a/gcc/config/riscv/riscv-vector-builtins-bases.h b/gcc/config/riscv/riscv-vector-builtins-bases.h
index 9261d353e22..0df8801b7cc 100644
--- a/gcc/config/riscv/riscv-vector-builtins-bases.h
+++ b/gcc/config/riscv/riscv-vector-builtins-bases.h
@@ -352,6 +352,9 @@ extern const function_base *const vfwcvtbf16_f;
 /* Zvfbfwma */
 extern const function_base *const vfwmaccbf16;
 extern const function_base *const vfwmaccbf16_frm;
+/* Zvfofp8min */
+extern const function_base *const vfncvt_sat_f;
+extern const function_base *const vfncvt_sat_f_frm;
 }
 
 } // end namespace riscv_vector
diff --git a/gcc/config/riscv/riscv-vector-builtins-functions.def b/gcc/config/riscv/riscv-vector-builtins-functions.def
index 185d811e2b7..49ad8585933 100644
--- a/gcc/config/riscv/riscv-vector-builtins-functions.def
+++ b/gcc/config/riscv/riscv-vector-builtins-functions.def
@@ -773,6 +773,15 @@ DEF_RVV_FUNCTION (vfwcvtbf16_f, alu, full_preds, bf16_to_f32_f_v_ops)
 #define REQUIRED_EXTENSIONS ZVFOFP8MIN_EXT
 DEF_RVV_FUNCTION (vfwcvt_f, alu_f8e4m3, full_preds, f8_to_bf16_f_v_ops)
 DEF_RVV_FUNCTION (vfwcvt_f, alu_f8e5m2, full_preds, f8_to_bf16_f_v_ops)
+DEF_RVV_FUNCTION (vfncvt_f, narrow_alu_f8e4m3, full_preds, bf16_to_f8_f_w_ops)
+DEF_RVV_FUNCTION (vfncvt_f, narrow_alu_f8e5m2, full_preds, bf16_to_f8_f_w_ops)
+DEF_RVV_FUNCTION (vfncvt_sat_f, narrow_alu_f8e4m3, full_preds, bf16_to_f8_f_w_ops)
+DEF_RVV_FUNCTION (vfncvt_sat_f, narrow_alu_f8e5m2, full_preds, bf16_to_f8_f_w_ops)
+
+DEF_RVV_FUNCTION (vfncvt_f_frm, narrow_alu_frm_f8e4m3, full_preds, bf16_to_f8_f_w_ops)
+DEF_RVV_FUNCTION (vfncvt_f_frm, narrow_alu_frm_f8e5m2, full_preds, bf16_to_f8_f_w_ops)
+DEF_RVV_FUNCTION (vfncvt_sat_f_frm, narrow_alu_frm_f8e4m3, full_preds, bf16_to_f8_f_w_ops)
+DEF_RVV_FUNCTION (vfncvt_sat_f_frm, narrow_alu_frm_f8e5m2, full_preds, bf16_to_f8_f_w_ops)
 #undef REQUIRED_EXTENSIONS
 
 /* Zvfbfwma */
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.cc b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
index 3533ef01714..8385a49f517 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.cc
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.cc
@@ -96,7 +96,11 @@ supports_vectype_p (const function_group_info &group, unsigned int vec_type_idx)
       || *group.shape == shapes::seg_indexed_loadstore
       || *group.shape == shapes::seg_fault_load
       || *group.shape == shapes::alu_f8e4m3
-      || *group.shape == shapes::alu_f8e5m2)
+      || *group.shape == shapes::alu_f8e5m2
+      || *group.shape == shapes::narrow_alu_f8e4m3
+      || *group.shape == shapes::narrow_alu_f8e5m2
+      || *group.shape == shapes::narrow_alu_frm_f8e4m3
+      || *group.shape == shapes::narrow_alu_frm_f8e5m2)
     return true;
   return false;
 }
@@ -797,6 +801,97 @@ struct narrow_alu_def : public build_base
   }
 };
 
+static char *
+build_f8_narrow_name (function_builder &b, const function_instance &instance,
+                     bool overloaded_p, const char *altfmt, bool frm_p)
+{
+  if (overloaded_p && !instance.base->can_be_overloaded_p (instance.pred))
+    return nullptr; /* This predication has no overloaded form.  */
+
+  const char *base_name = instance.base_name;
+  char base_name_buf[BASE_NAME_MAX_LEN] = {};
+  if (frm_p) /* Rounding-mode shapes use the normalized base name.  */
+    {
+      build_frm_base::normalize_base_name (base_name_buf, instance.base_name,
+                                          sizeof (base_name_buf));
+      base_name = base_name_buf;
+    }
+
+  b.append_base_name (base_name);
+
+  if (overloaded_p) /* Overloaded: <source scalar suffix>_<altfmt>.  */
+    {
+      vector_type_index vti
+       = instance.op_info->args[0].get_function_type_index (
+         instance.type.index);
+      const char *src_scalar /* Falls back to "_bf16" when unavailable.  */
+       = vti == VECTOR_TYPE_INVALID ? nullptr : type_suffixes[vti].scalar;
+      b.append_name (src_scalar ? src_scalar : "_bf16");
+      b.append_name ("_");
+      b.append_name (altfmt);
+    }
+  else /* Explicit: operand, source vector, FP8 suffix, then "_rm" if FRM.  */
+    {
+      b.append_name (operand_suffixes[instance.op_info->op]);
+      vector_type_index vti
+       = instance.op_info->args[0].get_function_type_index (
+         instance.type.index);
+      if (vti != VECTOR_TYPE_INVALID)
+       b.append_name (type_suffixes[vti].vector);
+      append_f8_suffix (b,
+                       instance.op_info->ret.get_function_type_index (
+                         instance.type.index),
+                       altfmt);
+      if (frm_p)
+       b.append_name ("_rm");
+    }
+
+  if (overloaded_p && instance.pred == PRED_TYPE_m)
+    return b.finish_name (); /* No suffix for overloaded PRED_TYPE_m.  */
+  b.append_name (predication_suffixes[instance.pred]);
+  return b.finish_name ();
+}
+
+/* narrow_alu_f8e4m3_def class.  Name narrowing conversions to FP8 E4M3.  */
+struct narrow_alu_f8e4m3_def : public narrow_alu_def
+{
+  char *get_name (function_builder &b, const function_instance &instance,
+                 bool overloaded_p) const override
+  {
+    return build_f8_narrow_name (b, instance, overloaded_p, "f8e4m3", false);
+  }
+};
+
+/* narrow_alu_f8e5m2_def class.  Name narrowing conversions to FP8 E5M2.  */
+struct narrow_alu_f8e5m2_def : public narrow_alu_def
+{
+  char *get_name (function_builder &b, const function_instance &instance,
+                 bool overloaded_p) const override
+  {
+    return build_f8_narrow_name (b, instance, overloaded_p, "f8e5m2", false);
+  }
+};
+
+/* narrow_alu_frm_f8e4m3_def class.  Name FRM narrowing to FP8 E4M3.  */
+struct narrow_alu_frm_f8e4m3_def : public narrow_alu_frm_def
+{
+  char *get_name (function_builder &b, const function_instance &instance,
+                 bool overloaded_p) const override
+  {
+    return build_f8_narrow_name (b, instance, overloaded_p, "f8e4m3", true);
+  }
+};
+
+/* narrow_alu_frm_f8e5m2_def class.  Name FRM narrowing to FP8 E5M2.  */
+struct narrow_alu_frm_f8e5m2_def : public narrow_alu_frm_def
+{
+  char *get_name (function_builder &b, const function_instance &instance,
+                 bool overloaded_p) const override
+  {
+    return build_f8_narrow_name (b, instance, overloaded_p, "f8e5m2", true);
+  }
+};
+
 /* move_def class. Handle vmv.v.v/vmv.v.x.  */
 struct move_def : public build_base
 {
@@ -1505,4 +1600,8 @@ SHAPE(sf_vcix, sf_vcix)
 /* Zvfofp8min */
 SHAPE (alu_f8e4m3, alu_f8e4m3)
 SHAPE (alu_f8e5m2, alu_f8e5m2)
+SHAPE (narrow_alu_f8e4m3, narrow_alu_f8e4m3)
+SHAPE (narrow_alu_f8e5m2, narrow_alu_f8e5m2)
+SHAPE (narrow_alu_frm_f8e4m3, narrow_alu_frm_f8e4m3)
+SHAPE (narrow_alu_frm_f8e5m2, narrow_alu_frm_f8e5m2)
 } // end namespace riscv_vector
diff --git a/gcc/config/riscv/riscv-vector-builtins-shapes.h b/gcc/config/riscv/riscv-vector-builtins-shapes.h
index d700d76da30..0158c82bd3b 100644
--- a/gcc/config/riscv/riscv-vector-builtins-shapes.h
+++ b/gcc/config/riscv/riscv-vector-builtins-shapes.h
@@ -67,6 +67,10 @@ extern const function_shape *const sf_vcix;
 /* Zvfofp8min extension.  */
 extern const function_shape *const alu_f8e4m3;
 extern const function_shape *const alu_f8e5m2;
+extern const function_shape *const narrow_alu_f8e4m3;
+extern const function_shape *const narrow_alu_f8e5m2;
+extern const function_shape *const narrow_alu_frm_f8e4m3;
+extern const function_shape *const narrow_alu_frm_f8e5m2;
 }
 
 } // end namespace riscv_vector
diff --git a/gcc/config/riscv/riscv-vector-builtins.cc b/gcc/config/riscv/riscv-vector-builtins.cc
index a7ad068b00a..c671b133c85 100644
--- a/gcc/config/riscv/riscv-vector-builtins.cc
+++ b/gcc/config/riscv/riscv-vector-builtins.cc
@@ -2058,6 +2058,15 @@ static CONSTEXPR const rvv_op_info f_to_nf_f_w_ops
      rvv_arg_type_info (RVV_BASE_double_trunc_float_vector), /* Return type */
      v_args /* Args */};
 
+/* Static operand information for the bf16 -> fp8 narrowing conversions
+ * (OP_TYPE_f_w); the result uses the double-truncated unsigned vector.  */
+static CONSTEXPR const rvv_op_info bf16_to_f8_f_w_ops
+  = {bf_ops,     /* Types */
+     OP_TYPE_f_w, /* Suffix */
+     rvv_arg_type_info (
+       RVV_BASE_double_trunc_unsigned_vector), /* Return type */
+     v_args /* Args */};
+
 /* A static operand information for vector_type func (vector_type)
  * function registration. */
 static CONSTEXPR const rvv_op_info f8_to_bf16_f_v_ops
diff --git a/gcc/config/riscv/vector-float8.md b/gcc/config/riscv/vector-float8.md
index 2a3eeeaa1db..423b490b811 100644
--- a/gcc/config/riscv/vector-float8.md
+++ b/gcc/config/riscv/vector-float8.md
@@ -35,10 +35,18 @@
   (RVVMF4BF "RVVMF8QI")
 ])
 
-(define_int_iterator ALTFMT [UNSPEC_F8E4M3 UNSPEC_F8E5M2])
+(define_int_iterator ALTFMT [UNSPEC_F8E4M3 UNSPEC_F8E5M2 UNSPEC_F8E4M3_SAT UNSPEC_F8E5M2_SAT])
 (define_int_attr altfmt
   [(UNSPEC_F8E4M3     "f8e4m3")
-   (UNSPEC_F8E5M2     "f8e5m2")])
+   (UNSPEC_F8E5M2     "f8e5m2")
+   (UNSPEC_F8E4M3_SAT "f8e4m3_sat")
+   (UNSPEC_F8E5M2_SAT "f8e5m2_sat")])
+
+(define_int_attr sat
+  [(UNSPEC_F8E4M3     "")
+   (UNSPEC_F8E5M2     "")
+   (UNSPEC_F8E4M3_SAT ".sat")
+   (UNSPEC_F8E5M2_SAT ".sat")])
 
 ;; Zvfofp8min extension: FP8 to BF16 widening conversions.
 
@@ -61,3 +69,29 @@
   "vfwcvtbf16.f.f.v\t%0,%3%p1"
   [(set_attr "type" "vfwcvtbf16")
    (set_attr "mode" "<VBF_DOUBLE_TRUNC>")])
+
+;; Zvfofp8min extension: BF16 to FP8 narrowing conversions.
+
+(define_insn "@pred_trunc_<mode>_to_<altfmt>"
+  [(set (match_operand:<VBF_DOUBLE_TRUNC> 0 "register_operand"   "=vd, vd, vr, vr,  &vr,  &vr")
+    (if_then_else:<VBF_DOUBLE_TRUNC>
+        (unspec:<VM>
+        [(match_operand:<VM> 1 "vector_mask_operand"            " vm, vm,Wc1,Wc1,vmWc1,vmWc1")
+        (match_operand 4 "vector_length_operand"               " rK, rK, rK, rK,   rK,   rK")
+        (match_operand 5 "const_int_operand"                   "  i,  i,  i,  i,    i,    i")
+        (match_operand 6 "const_int_operand"                   "  i,  i,  i,  i,    i,    i")
+        (match_operand 7 "const_int_operand"                   "  i,  i,  i,  i,    i,    i")
+        (match_operand 8 "const_int_operand"                   "  i,  i,  i,  i,    i,    i")
+        (reg:SI VL_REGNUM)
+        (reg:SI VTYPE_REGNUM)
+        (reg:SI FRM_REGNUM)] UNSPEC_VPREDICATE)
+         (unspec:<VBF_DOUBLE_TRUNC>
+        [(float_truncate:<VBF_DOUBLE_TRUNC>
+           (match_operand:VWEXTF_ZVFOFP8MIN 3 "register_operand"      "  0,  0,  0,  0,   vr,   vr"))] ALTFMT)
+        (match_operand:<VBF_DOUBLE_TRUNC> 2 "vector_merge_operand"  " vu,  0, vu,  0,   vu,    0")))]
+  "TARGET_VECTOR && TARGET_ZVFOFP8MIN && TARGET_ZVFBFMIN"
+  "vfncvtbf16<sat>.f.f.w\t%0,%3%p1"
+  [(set_attr "type" "vfncvtbf16")
+   (set_attr "mode" "<VBF_DOUBLE_TRUNC>")
+   (set (attr "frm_mode")
+   (symbol_ref "riscv_vector::get_frm_mode (operands[8])"))])
-- 
2.34.1

Reply via email to