From: Lino Hsing-Yu Peng <[email protected]>

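Teach the vsetvl pass to track altfmt.  Each vsetvl_info now records the
altfmt value its instruction needs and whether that value is demanded.
The demand system checks altfmt compatibility and availability when
merging states, and get_vsetvl_pat emits the tracked altfmt value instead
of a hard-coded zero when selecting vsetvl patterns.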

gcc/ChangeLog:

        * config/riscv/riscv-vsetvl.cc (altfmt_to_str, get_altfmt): New helpers.
        (demand_flags, altfmt_demand_type): Add altfmt demand tracking.
        (vsetvl_info): Track altfmt and its demand.
        (normalize_demand, parse_insn, get_vsetvl_pat, dump, operator==):
        Handle altfmt.
        (demand_system): Add altfmt compatibility, availability, and merge.
        (preds_all_same_avl_and_ratio_p): Check altfmt availability.
        * config/riscv/riscv-vsetvl.def (DEF_ALTFMT_RULE): New macro.
        Add altfmt fusion rules.
---
 gcc/config/riscv/riscv-vsetvl.cc  | 183 ++++++++++++++++++++++++++----
 gcc/config/riscv/riscv-vsetvl.def |  15 +++
 2 files changed, 175 insertions(+), 23 deletions(-)

diff --git a/gcc/config/riscv/riscv-vsetvl.cc b/gcc/config/riscv/riscv-vsetvl.cc
index c62295ee89b..22a8f6c6ef1 100644
--- a/gcc/config/riscv/riscv-vsetvl.cc
+++ b/gcc/config/riscv/riscv-vsetvl.cc
@@ -256,6 +256,22 @@ policy_to_str (bool agnostic_p)
   return agnostic_p ? "agnostic" : "undisturbed";
 }
 
+static const char *
+altfmt_to_str (uint8_t altfmt)
+{
+  switch (altfmt)
+    {
+    case ALTFMT_NONE:
+      return "none";
+    case ALTFMT_ALT:
+      return "alt";
+    case ALTFMT_ANY:
+      return "any";
+    default:
+      return "unknown";
+    }
+}
+
 /* Return true if it is an RVV instruction depends on VTYPE global
    status register.  */
 bool
@@ -501,6 +517,14 @@ mask_agnostic_p (rtx_insn *rinsn)
   return ma == INVALID_ATTRIBUTE ? get_default_ma () : IS_AGNOSTIC (ma);
 }
 
+static uint8_t
+get_altfmt (rtx_insn *rinsn)
+{
+  extract_insn_cached (rinsn);
+  int altfmt = get_attr_altfmt (rinsn);
+  return altfmt == INVALID_ATTRIBUTE ? ALTFMT_NONE : altfmt;
+}
+
 /* Return true if FN has a vector instruction that use VL/VTYPE.  */
 static bool
 has_vector_insn (function *fn)
@@ -817,19 +841,21 @@ enum demand_flags : unsigned
   DEMAND_MASK_POLICY_P = 1 << 5,
   DEMAND_AVL_P = 1 << 6,
   DEMAND_NON_ZERO_AVL_P = 1 << 7,
+  DEMAND_ALTFMT_P = 1 << 8,
 };
 
-/* We split the demand information into three parts. They are sew and lmul
+/* We split the demand information into four parts. They are sew and lmul
    related (sew_lmul_demand_type), tail and mask policy related
-   (policy_demand_type) and avl related (avl_demand_type). Then we define three
-   interfaces available_p, compatible_p and merge. available_p is
-   used to determine whether the two vsetvl infos prev_info and next_info are
-   available or not. If prev_info is available for next_info, it means that the
-   RVV insn corresponding to next_info on the path from prev_info to next_info
-   can be used without inserting a separate vsetvl instruction. compatible_p
-   is used to determine whether prev_info is compatible with next_info, and if
-   so, merge can be used to merge the stricter demand information from
-   next_info into prev_info so that prev_info becomes available to next_info.
+   (policy_demand_type), alternate FP8 format related (altfmt_demand_type) and
+   avl related (avl_demand_type). Then we define three interfaces available_p,
+   compatible_p and merge. available_p is used to determine whether the two
+   vsetvl infos prev_info and next_info are available or not. If prev_info is
+   available for next_info, it means that the RVV insn corresponding to
+   next_info on the path from prev_info to next_info can be used without
+   inserting a separate vsetvl instruction. compatible_p is used to determine
+   whether prev_info is compatible with next_info, and if so, merge can be used
+   to merge the stricter demand information from next_info into prev_info so
+   that prev_info becomes available to next_info.
  */
 
 enum class sew_lmul_demand_type : unsigned
@@ -851,6 +877,12 @@ enum class policy_demand_type : unsigned
   ignore_policy = demand_flags::DEMAND_EMPTY_P,
 };
 
+enum class altfmt_demand_type : unsigned
+{
+  altfmt = demand_flags::DEMAND_ALTFMT_P,
+  ignore_altfmt = demand_flags::DEMAND_EMPTY_P,
+};
+
 enum class avl_demand_type : unsigned
 {
   avl = demand_flags::DEMAND_AVL_P,
@@ -900,11 +932,13 @@ private:
   uint8_t m_max_sew;
   vlmul_type m_vlmul;
   uint8_t m_ratio;
+  uint8_t m_altfmt;
   bool m_ta;
   bool m_ma;
 
   sew_lmul_demand_type m_sew_lmul_demand;
   policy_demand_type m_policy_demand;
+  altfmt_demand_type m_altfmt_demand;
   avl_demand_type m_avl_demand;
 
   enum class state_type
@@ -925,9 +959,10 @@ public:
   vsetvl_info ()
     : m_insn (nullptr), m_bb (nullptr), m_avl (NULL_RTX), m_vl (NULL_RTX),
       m_avl_def (nullptr), m_sew (0), m_max_sew (0), m_vlmul (LMUL_RESERVED),
-      m_ratio (0), m_ta (false), m_ma (false),
+      m_ratio (0), m_altfmt (ALTFMT_NONE), m_ta (false), m_ma (false),
       m_sew_lmul_demand (sew_lmul_demand_type::sew_lmul),
       m_policy_demand (policy_demand_type::tail_mask_policy),
+      m_altfmt_demand (altfmt_demand_type::altfmt),
       m_avl_demand (avl_demand_type::avl), m_state (state_type::UNINITIALIZED),
       m_delete (false), m_change_vtype_only (false), m_read_vl_insn (nullptr),
       m_vl_used_by_non_rvv_insn (false)
@@ -943,6 +978,7 @@ public:
   void set_sew (uint8_t sew) { m_sew = sew; }
   void set_vlmul (vlmul_type vlmul) { m_vlmul = vlmul; }
   void set_ratio (uint8_t ratio) { m_ratio = ratio; }
+  void set_altfmt (uint8_t altfmt) { m_altfmt = altfmt; }
   void set_ta (bool ta) { m_ta = ta; }
   void set_ma (bool ma) { m_ma = ma; }
   void set_delete () { m_delete = true; }
@@ -957,6 +993,7 @@ public:
   uint8_t get_sew () const { return m_sew; }
   vlmul_type get_vlmul () const { return m_vlmul; }
   uint8_t get_ratio () const { return m_ratio; }
+  uint8_t get_altfmt () const { return m_altfmt; }
   bool get_ta () const { return m_ta; }
   bool get_ma () const { return m_ma; }
   insn_info *get_insn () const { return m_insn; }
@@ -1026,6 +1063,10 @@ public:
   {
     m_policy_demand = demand;
   }
+  void set_altfmt_demand (altfmt_demand_type demand)
+  {
+    m_altfmt_demand = demand;
+  }
   void set_avl_demand (avl_demand_type demand) { m_avl_demand = demand; }
 
   sew_lmul_demand_type get_sew_lmul_demand () const
@@ -1033,6 +1074,7 @@ public:
     return m_sew_lmul_demand;
   }
   policy_demand_type get_policy_demand () const { return m_policy_demand; }
+  altfmt_demand_type get_altfmt_demand () const { return m_altfmt_demand; }
   avl_demand_type get_avl_demand () const { return m_avl_demand; }
 
   void normalize_demand (unsigned demand_flags)
@@ -1077,6 +1119,18 @@ public:
        gcc_unreachable ();
       }
 
+    switch (demand_flags & DEMAND_ALTFMT_P)
+      {
+      case (unsigned) altfmt_demand_type::altfmt:
+       m_altfmt_demand = altfmt_demand_type::altfmt;
+       break;
+      case (unsigned) altfmt_demand_type::ignore_altfmt:
+       m_altfmt_demand = altfmt_demand_type::ignore_altfmt;
+       break;
+      default:
+       gcc_unreachable ();
+      }
+
     switch (demand_flags & (DEMAND_AVL_P | DEMAND_NON_ZERO_AVL_P))
       {
       case (unsigned) avl_demand_type::avl:
@@ -1107,6 +1161,9 @@ public:
       m_vl = ::get_vl (rinsn);
     m_sew = ::get_sew (rinsn);
     m_vlmul = ::get_vlmul (rinsn);
+    m_altfmt = ::get_altfmt (rinsn);
+    m_altfmt_demand = m_altfmt == ALTFMT_ANY ? altfmt_demand_type::ignore_altfmt
+                                            : altfmt_demand_type::altfmt;
     m_ta = tail_agnostic_p (rinsn);
     m_ma = mask_agnostic_p (rinsn);
   }
@@ -1166,6 +1223,7 @@ public:
        in demand info backward analysis.  */
     if (m_ratio == INVALID_ATTRIBUTE)
       m_ratio = calculate_ratio (m_sew, m_vlmul);
+    m_altfmt = ::get_altfmt (insn->rtl ());
     m_ta = tail_agnostic_p (insn->rtl ());
     m_ma = mask_agnostic_p (insn->rtl ());
 
@@ -1232,6 +1290,9 @@ public:
          dflags |= demand_flags::DEMAND_MASK_POLICY_P;
       }
 
+    if (m_altfmt != ALTFMT_ANY)
+      dflags |= demand_flags::DEMAND_ALTFMT_P;
+
     normalize_demand (dflags);
 
     /* Optimize AVL from the vsetvl instruction.  */
@@ -1288,7 +1349,8 @@ public:
       avl = GEN_INT (0);
     rtx sew = gen_int_mode (get_sew (), Pmode);
     rtx vlmul = gen_int_mode (get_vlmul (), Pmode);
-    rtx altfmt = const0_rtx;
+    uint8_t altfmt_val = get_altfmt () == ALTFMT_ALT ? ALTFMT_ALT : ALTFMT_NONE;
+    rtx altfmt = gen_int_mode (altfmt_val, Pmode);
     rtx ta = gen_int_mode (get_ta (), Pmode);
     rtx ma = gen_int_mode (get_ma (), Pmode);
 
@@ -1336,11 +1398,13 @@ public:
           && get_avl () == other.get_avl () && get_vl () == other.get_vl ()
           && get_avl_def () == other.get_avl_def ()
           && get_sew () == other.get_sew ()
-          && get_vlmul () == other.get_vlmul () && get_ta () == other.get_ta ()
-          && get_ma () == other.get_ma ()
+          && get_vlmul () == other.get_vlmul ()
+          && get_altfmt () == other.get_altfmt ()
+          && get_ta () == other.get_ta () && get_ma () == other.get_ma ()
           && get_avl_demand () == other.get_avl_demand ()
           && get_sew_lmul_demand () == other.get_sew_lmul_demand ()
-          && get_policy_demand () == other.get_policy_demand ();
+          && get_policy_demand () == other.get_policy_demand ()
+          && get_altfmt_demand () == other.get_altfmt_demand ();
   }
 
   void dump (FILE *file, const char *indent = "") const
@@ -1385,6 +1449,9 @@ public:
     else if (m_policy_demand == policy_demand_type::mask_policy_only)
       fprintf (file, " demand_mask_policy_only");
 
+    if (m_altfmt_demand == altfmt_demand_type::altfmt)
+      fprintf (file, " demand_altfmt");
+
     if (m_avl_demand == avl_demand_type::avl)
       fprintf (file, " demand_avl");
     else if (m_avl_demand == avl_demand_type::non_zero_avl)
@@ -1393,6 +1460,7 @@ public:
 
     fprintf (file, "%sSEW=%d, ", indent, get_sew ());
     fprintf (file, "VLMUL=%s, ", vlmul_to_str (get_vlmul ()));
+    fprintf (file, "ALTFMT=%s, ", altfmt_to_str (get_altfmt ()));
     fprintf (file, "RATIO=%d, ", get_ratio ());
     fprintf (file, "MAX_SEW=%d\n", get_max_sew ());
 
@@ -1670,6 +1738,13 @@ private:
     return tail_policy_eq_p (prev, next) && mask_policy_eq_p (prev, next);
   }
 
+  /* predictors for altfmt */
+
+  inline bool altfmt_eq_p (const vsetvl_info &prev, const vsetvl_info &next)
+  {
+    return prev.get_altfmt () == next.get_altfmt ();
+  }
+
   /* predictors for avl */
 
   inline bool modify_or_use_vl_p (insn_info *i, const vsetvl_info &info)
@@ -1917,6 +1992,13 @@ private:
     use_mask_policy (prev, next);
   }
 
+  /* modifiers for altfmt */
+
+  inline void use_next_altfmt (vsetvl_info &prev, const vsetvl_info &next)
+  {
+    prev.set_altfmt (next.get_altfmt ());
+  }
+
   /* modifiers for avl */
 
   inline void use_next_avl (vsetvl_info &prev, const vsetvl_info &next)
@@ -2152,6 +2234,59 @@ public:
       return;                                                                  \
     }
 
+#include "riscv-vsetvl.def"
+
+    gcc_unreachable ();
+  }
+
+  bool altfmt_compatible_p (const vsetvl_info &prev, const vsetvl_info &next)
+  {
+    gcc_assert (prev.valid_p () && next.valid_p ());
+    altfmt_demand_type prev_flags = prev.get_altfmt_demand ();
+    altfmt_demand_type next_flags = next.get_altfmt_demand ();
+#define DEF_ALTFMT_RULE(PREV_FLAGS, NEXT_FLAGS, NEW_FLAGS, COMPATIBLE_P,       \
+                       AVAILABLE_P, FUSE)                                     \
+  if (prev_flags == altfmt_demand_type::PREV_FLAGS                             \
+      && next_flags == altfmt_demand_type::NEXT_FLAGS)                         \
+    return COMPATIBLE_P (prev, next);
+
+#include "riscv-vsetvl.def"
+
+    gcc_unreachable ();
+  }
+
+  bool altfmt_available_p (const vsetvl_info &prev, const vsetvl_info &next)
+  {
+    gcc_assert (prev.valid_p () && next.valid_p ());
+    altfmt_demand_type prev_flags = prev.get_altfmt_demand ();
+    altfmt_demand_type next_flags = next.get_altfmt_demand ();
+#define DEF_ALTFMT_RULE(PREV_FLAGS, NEXT_FLAGS, NEW_FLAGS, COMPATIBLE_P,       \
+                       AVAILABLE_P, FUSE)                                     \
+  if (prev_flags == altfmt_demand_type::PREV_FLAGS                             \
+      && next_flags == altfmt_demand_type::NEXT_FLAGS)                         \
+    return AVAILABLE_P (prev, next);
+
+#include "riscv-vsetvl.def"
+
+    gcc_unreachable ();
+  }
+
+  void merge_altfmt (vsetvl_info &prev, const vsetvl_info &next)
+  {
+    gcc_assert (prev.valid_p () && next.valid_p ());
+    altfmt_demand_type prev_flags = prev.get_altfmt_demand ();
+    altfmt_demand_type next_flags = next.get_altfmt_demand ();
+#define DEF_ALTFMT_RULE(PREV_FLAGS, NEXT_FLAGS, NEW_FLAGS, COMPATIBLE_P,       \
+                       AVAILABLE_P, FUSE)                                     \
+  if (prev_flags == altfmt_demand_type::PREV_FLAGS                             \
+      && next_flags == altfmt_demand_type::NEXT_FLAGS)                         \
+    {                                                                           \
+      gcc_assert (COMPATIBLE_P (prev, next));                                  \
+      FUSE (prev, next);                                                       \
+      prev.set_altfmt_demand (altfmt_demand_type::NEW_FLAGS);                  \
+      return;                                                                  \
+    }
+
 #include "riscv-vsetvl.def"
 
     gcc_unreachable ();
@@ -2226,19 +2361,19 @@ public:
 
   bool compatible_p (const vsetvl_info &prev, const vsetvl_info &next)
   {
-    bool compatible_p = sew_lmul_compatible_p (prev, next)
-                       && policy_compatible_p (prev, next)
-                       && avl_compatible_p (prev, next)
-                       && vl_not_in_conflict_p (prev, next);
+    bool compatible_p
+      = sew_lmul_compatible_p (prev, next) && policy_compatible_p (prev, next)
+       && altfmt_compatible_p (prev, next) && avl_compatible_p (prev, next)
+       && vl_not_in_conflict_p (prev, next);
     return compatible_p;
   }
 
   bool available_p (const vsetvl_info &prev, const vsetvl_info &next)
   {
-    bool available_p = sew_lmul_available_p (prev, next)
-                      && policy_available_p (prev, next)
-                      && avl_available_p (prev, next)
-                      && vl_not_in_conflict_p (prev, next);
+    bool available_p
+      = sew_lmul_available_p (prev, next) && policy_available_p (prev, next)
+       && altfmt_available_p (prev, next) && avl_available_p (prev, next)
+       && vl_not_in_conflict_p (prev, next);
     gcc_assert (!available_p || compatible_p (prev, next));
     return available_p;
   }
@@ -2248,6 +2383,7 @@ public:
     gcc_assert (compatible_p (prev, next));
     merge_sew_lmul (prev, next);
     merge_policy (prev, next);
+    merge_altfmt (prev, next);
     merge_avl (prev, next);
     gcc_assert (available_p (prev, next));
   }
@@ -2497,6 +2633,7 @@ private:
        const vsetvl_info &prev_info = *m_vsetvl_def_exprs[expr_index];
        if (!prev_info.valid_p ()
            || !m_dem.avl_available_p (prev_info, curr_info)
+           || !m_dem.altfmt_available_p (prev_info, curr_info)
            || prev_info.get_ratio () != curr_info.get_ratio ())
          return false;
       }
diff --git a/gcc/config/riscv/riscv-vsetvl.def b/gcc/config/riscv/riscv-vsetvl.def
index 2cef36bc4e9..948e812c45e 100644
--- a/gcc/config/riscv/riscv-vsetvl.def
+++ b/gcc/config/riscv/riscv-vsetvl.def
@@ -38,6 +38,11 @@ along with GCC; see the file COPYING3.  If not see
                        available_p, fuse)
 #endif
 
+#ifndef DEF_ALTFMT_RULE
+#define DEF_ALTFMT_RULE(prev_demand, next_demand, fused_demand, compatible_p,  \
+                       available_p, fuse)
+#endif
+
 #ifndef DEF_AVL_RULE
 #define DEF_AVL_RULE(prev_demand, next_demand, fused_demand, compatible_p,     \
                     available_p, fuse)
@@ -153,6 +158,15 @@ DEF_POLICY_RULE (ignore_policy, mask_policy_only, mask_policy_only, always_true,
 DEF_POLICY_RULE (ignore_policy, ignore_policy, ignore_policy, always_true,
                 always_true, nop)
 
+/* Define ALTFMT compatible and merge rules.  */
+
+DEF_ALTFMT_RULE (altfmt, altfmt, altfmt, altfmt_eq_p, altfmt_eq_p, nop)
+DEF_ALTFMT_RULE (altfmt, ignore_altfmt, altfmt, always_true, always_true, nop)
+DEF_ALTFMT_RULE (ignore_altfmt, altfmt, altfmt, always_true, always_false,
+                use_next_altfmt)
+DEF_ALTFMT_RULE (ignore_altfmt, ignore_altfmt, ignore_altfmt, always_true,
+                always_true, nop)
+
 /* Define AVL compatible and merge rules.  */
 
 DEF_AVL_RULE (avl, avl, avl, avl_equal_p, avl_equal_p, nop)
@@ -177,4 +191,5 @@ DEF_AVL_RULE (ignore_avl, ignore_avl, ignore_avl, always_true, always_true, nop)
 
 #undef DEF_SEW_LMUL_RULE
 #undef DEF_POLICY_RULE
+#undef DEF_ALTFMT_RULE
 #undef DEF_AVL_RULE
-- 
2.34.1

Reply via email to