From: Matthias Kretz <[email protected]>

Before this change, _Ap::_S_is_bitmask would pick up false from
_ScalarAbi<N>. Now that __scalar_abi_tag identifies any _Abi<N, N, V>,
where V can also identify bit-masks, the shortcut of setting
_S_use_bitmask to _Ap::_S_is_bitmask is wrong. It would be correct to
have it say _Ap::_S_is_bitmask && !__scalar_abi_tag<_Ap>. I decided to
implement the latter only in the _S_nreg == 1 specialization and have
the multi-register specializations inherit the value from their vec/mask
data member. The _S_is_bitmask bit is not erased for __scalar_abi_tag
since it makes a difference for __abi_rebind.

Signed-off-by: Matthias Kretz <[email protected]>

libstdc++-v3/ChangeLog:

        * include/bits/simd_details.h (_ScalarAbi): Remove.
        (__scalar_abi_tag): Identify _Abi<N, N> as scalar now.
        (__native_abi): Replace _ScalarAbi<1> with _Abi_t<1, 1, ...>.
        (__abi_rebind): Refactor rebinding from/to __scalar_abi_tag.
        * include/bits/simd_mask.h (_S_use_bitmask): Only true if
        !_S_is_scalar.
        (_M_and_neighbors, _M_or_neighbors): Add case for _S_is_scalar
        where the and/or must be executed one step earlier.
        (_M_reduce_min_index, _M_reduce_max_index): Delete dead code.
        * include/bits/simd_vec.h (_S_use_bitmask): Inherit the value
        from the first data member.
        * testsuite/std/simd/traits_impl.cc: Adjust for the removal of
        _ScalarAbi.
---
 libstdc++-v3/include/bits/simd_details.h      | 86 +++++++------------
 libstdc++-v3/include/bits/simd_mask.h         | 38 +++++---
 libstdc++-v3/include/bits/simd_vec.h          |  2 +-
 .../testsuite/std/simd/traits_impl.cc         | 11 +--
 4 files changed, 58 insertions(+), 79 deletions(-)

diff --git a/libstdc++-v3/include/bits/simd_details.h 
b/libstdc++-v3/include/bits/simd_details.h
index a1acc5bd9464..e6185fac64a3 100644
--- a/libstdc++-v3/include/bits/simd_details.h
+++ b/libstdc++-v3/include/bits/simd_details.h
@@ -241,46 +241,21 @@ namespace simd
 #endif
 
   /** @internal
-   * This ABI tag describes basic_vec objects that store one element per data 
member and basic_mask
-   * objects that store one bool data members.
+   * @brief This ABI tag determines the data member(s) of basic_vec and 
basic_mask.
    *
-   * @tparam _Np   The number of elements, which also matches the number of 
data members in
-   *               basic_vec and basic_mask.
-   */
-  template <int _Np = 1>
-    struct _ScalarAbi
-    {
-      static constexpr int _S_size = _Np;
-
-      static constexpr int _S_nreg = _Np;
-
-      static constexpr _AbiVariant _S_variant = {};
-
-      template <typename _Tp>
-       using _DataType = __canonical_vec_type_t<_Tp>;
-
-      static constexpr bool _S_is_vecmask = false;
-
-      // in principle a bool is a 1-bit bitmask, but this is asking for an 
AVX512 bitmask
-      static constexpr bool _S_is_bitmask = false;
-
-      template <size_t>
-       using _MaskDataType = bool;
-
-      template <int _N2, int _Nreg2 = _N2>
-       static consteval _ScalarAbi<_N2>
-       _S_resize()
-       {
-         static_assert(_N2 == _Nreg2);
-         return {};
-       }
-    };
-
-  /** @internal
-   * This ABI tag describes basic_vec objects that store one or more objects 
declared with the
-   * [[gnu::vector_size(N)]] attribute.
-   * Applied to basic_mask objects, this ABI tag either describes 
corresponding vector-mask objects
-   * or bit-mask objects. Which one is used is determined via @p _Var.
+   * @p _Nreg determines the number of (recursively nested) basic_vec/basic_mask data members whose
+   * own @p _Nreg is equal to 1. With @p _Nreg equal to 1, the basic_vec/basic_mask holds one vector
+   * builtin (@p _Np greater than 1) or a scalar (@p _Np equal to 1).
+   * @f$\lceil\frac{\mathtt{Np}}{\mathtt{Nreg}}\rceil@f$ therefore determines 
the number of elements
+   * in a register (except for a remainder where it can be smaller). If @p _Np equals @p _Nreg (the
+   * aforementioned quotient is 1), then basic_vec (recursively) holds non-vector data members and
+   * basic_mask holds bools.
+   *
+   * The @p _Var parameter determines details about the data member in the one-register case. Masks
+   * can be represented as vector masks (the default comparison result of GNU vector builtins),
+   * bit-masks as used by AVX-512, bit-masks as used by ARM SVE (not yet implemented), or a single
+   * bool (for the @p _Np equals 1 case). For basic_mask it determines the actual data layout and
+   * for basic_vec it determines the result of compares.
    *
    * @tparam _Np    The number of elements.
    * @tparam _Nreg  The number of registers needed to store @p _Np elements.
@@ -391,9 +366,13 @@ namespace simd
            { __x.template _S_resize<_Tp::_S_size, _Tp::_S_nreg>() } -> 
same_as<_Tp>;
          };
 
+  /** @internal
+   * Satisfied if @p _Tp is a valid simd ABI tag and one element is stored per 
register (number of
+   * registers equals size).
+   */
   template <typename _Tp>
     concept __scalar_abi_tag
-      = same_as<_Tp, _ScalarAbi<_Tp::_S_size>> && __abi_tag<_Tp>;
+      = same_as<_Tp, _Abi_t<_Tp::_S_size, _Tp::_S_size, _Tp::_S_variant>> && 
__abi_tag<_Tp>;
 
   // Determine if math functions must *raise* floating-point exceptions.
   // math_errhandling may expand to an extern symbol, in which case we must 
assume fp exceptions
@@ -760,7 +739,7 @@ namespace simd
       else if constexpr (_Traits._M_have_avx512f())
        return _Abi_t<64 / __adj_sizeof, 1, _AbiVariant::_BitMask>();
       else if constexpr (is_same_v<_Tp, _Float16> && !_Traits._M_have_f16c())
-       return _ScalarAbi<1>();
+       return _Abi_t<1, 1>();
       else if constexpr (_Traits._M_have_avx2())
        return _Abi_t<32 / __adj_sizeof, 1>();
       else if constexpr (_Traits._M_have_avx() && is_floating_point_v<_Tp>)
@@ -772,7 +751,7 @@ namespace simd
        return _Abi_t<16 / __adj_sizeof, 1>();
       // no MMX: we can't emit EMMS where it would be necessary
       else
-       return _ScalarAbi<1>();
+       return _Abi_t<1, 1>();
     }
 
 #else
@@ -794,7 +773,7 @@ namespace simd
       if constexpr (!__vectorizable<_Tp>)
        return _InvalidAbi();
       else
-       return _ScalarAbi<1>();
+       return _Abi_t<1, 1>();
     }
 
 #endif
@@ -850,17 +829,19 @@ namespace simd
       if constexpr (_Np <= 0 || !__vectorizable<_Tp>)
        return _InvalidAbi();
 
-      else if constexpr (__scalar_abi_tag<_A0>)
-       return _A0::template _S_resize<_Np>();
-
       else
        {
          using _Native = 
remove_const_t<decltype(std::simd::__native_abi<_Tp>())>;
          static_assert(0 != _Native::_S_size);
          constexpr int __nreg = __div_ceil(_Np, _Native::_S_size);
 
-         if constexpr (__scalar_abi_tag<_Native>)
-           return _Native::template _S_resize<_Np>();
+         // __scalar_abi_tag is sticky (unless we reach size 1, where we can't 
know whether it was
+         // an explicit __scalar_abi_tag before some resize_t)
+         if constexpr (__scalar_abi_tag<_Native> || (__scalar_abi_tag<_A0> && 
_A0::_S_size >= 2))
+           {
+               return _A0::template _S_resize<_Np, _Np>();
+           }
+
          else
            return _Abi_t<_Native::_S_size, 1, 
__filter_abi_variant(_A0::_S_variant,
                                                                    
_AbiVariant::_MaskVariants)
@@ -885,9 +866,6 @@ namespace simd
       if constexpr (_Bytes == 0 || _Np <= 0)
        return _InvalidAbi();
 
-      else if constexpr (__scalar_abi_tag<_A0>)
-       return _A0::template _S_resize<_Np>();
-
 #if _GLIBCXX_X86
       // AVX w/o AVX2:
       // e.g. resize_t<8, mask<float, Whatever>> needs to be _Abi<8, 1> not 
_Abi<8, 2>
@@ -939,12 +917,6 @@ namespace simd
       if (__b0 != __b1)
        return true;
 
-      // everything is better than _ScalarAbi, except when converting to a 
single bool
-      if constexpr (__scalar_abi_tag<_To>)
-       return __n > 1;
-      else if constexpr (__scalar_abi_tag<_From>)
-       return true;
-
       // converting to a bit-mask is better
       else if constexpr (_To::_S_is_vecmask != _From::_S_is_vecmask)
        return _To::_S_is_vecmask; // to vector-mask is explicit
diff --git a/libstdc++-v3/include/bits/simd_mask.h 
b/libstdc++-v3/include/bits/simd_mask.h
index 0a7cfa03cedd..81a0825ec6ce 100644
--- a/libstdc++-v3/include/bits/simd_mask.h
+++ b/libstdc++-v3/include/bits/simd_mask.h
@@ -543,7 +543,7 @@ namespace simd
 
       static constexpr bool _S_is_scalar = _S_has_bool_member;
 
-      static constexpr bool _S_use_bitmask = _Ap::_S_is_bitmask;
+      static constexpr bool _S_use_bitmask = _Ap::_S_is_bitmask && 
!_S_is_scalar;
 
       static constexpr int _S_full_size = [] {
        if constexpr (_S_is_scalar)
@@ -1519,8 +1519,16 @@ namespace simd
       constexpr basic_mask&
       _M_and_neighbors()
       {
-       _M_data0._M_and_neighbors();
-       _M_data1._M_and_neighbors();
+       if constexpr (_S_size == 2)
+         {
+           static_assert(_S_is_scalar);
+           _M_data0 = _M_data1 = _M_data0 && _M_data1;
+         }
+       else
+         {
+           _M_data0._M_and_neighbors();
+           _M_data1._M_and_neighbors();
+         }
        return *this;
       }
 
@@ -1528,8 +1536,16 @@ namespace simd
       constexpr basic_mask&
       _M_or_neighbors()
       {
-       _M_data0._M_or_neighbors();
-       _M_data1._M_or_neighbors();
+       if constexpr (_S_size == 2)
+         {
+           static_assert(_S_is_scalar);
+           _M_data0 = _M_data1 = _M_data0 || _M_data1;
+         }
+       else
+         {
+           _M_data0._M_or_neighbors();
+           _M_data1._M_or_neighbors();
+         }
        return *this;
       }
 
@@ -1650,7 +1666,7 @@ namespace simd
        else if constexpr (_M_data1._S_has_bool_member)
          // in some cases the last element can be 'bool' instead of 
bit-/vector-mask;
          // e.g. mask<short, 17> is {mask<short, 16>, mask<short, 1>}, where 
the latter uses
-         // _ScalarAbi<1>, which is stored as 'bool'
+         // _Abi<1, 1>, which is stored as 'bool'
          return __i < _N0 ? _M_data0[__i] : _M_data1[__i - _N0];
        else if constexpr (abi_type::_S_is_bitmask)
          {
@@ -1929,10 +1945,7 @@ namespace simd
          {
            const auto __bits = _M_to_uint();
            __glibcxx_simd_precondition(__bits, "An empty mask does not have a 
min_index.");
-           if constexpr (_S_size == 1)
-             return 0;
-           else
-             return __countr_zero(_M_to_uint());
+           return __countr_zero(_M_to_uint());
          }
        else if (_M_data0._M_none_of())
          return _M_data1._M_reduce_min_index() + _N0;
@@ -1948,10 +1961,7 @@ namespace simd
          {
            const auto __bits = _M_to_uint();
            __glibcxx_simd_precondition(__bits, "An empty mask does not have a 
max_index.");
-           if constexpr (_S_size == 1)
-             return 0;
-           else
-             return __highest_bit(_M_to_uint());
+           return __highest_bit(_M_to_uint());
          }
        else if (_M_data1._M_none_of())
          return _M_data0._M_reduce_max_index();
diff --git a/libstdc++-v3/include/bits/simd_vec.h 
b/libstdc++-v3/include/bits/simd_vec.h
index 5f3bd7fd2f61..5624ec781426 100644
--- a/libstdc++-v3/include/bits/simd_vec.h
+++ b/libstdc++-v3/include/bits/simd_vec.h
@@ -1776,7 +1776,7 @@ namespace simd
 
       _DataType1 _M_data1;
 
-      static constexpr bool _S_use_bitmask = _Ap::_S_is_bitmask;
+      static constexpr bool _S_use_bitmask = _DataType0::_S_use_bitmask;
 
       static constexpr bool _S_is_partial = _DataType1::_S_is_partial;
 
diff --git a/libstdc++-v3/testsuite/std/simd/traits_impl.cc 
b/libstdc++-v3/testsuite/std/simd/traits_impl.cc
index 2f705c7df2f7..94c6843b6228 100644
--- a/libstdc++-v3/testsuite/std/simd/traits_impl.cc
+++ b/libstdc++-v3/testsuite/std/simd/traits_impl.cc
@@ -49,24 +49,21 @@ void test()
   static_assert(sizeof(_Bitmask<3>) == 1);
   static_assert(sizeof(_Bitmask<30>) == 4);
 
-  static_assert(__scalar_abi_tag<_ScalarAbi<1>>);
-  static_assert(__scalar_abi_tag<_ScalarAbi<2>>);
-  static_assert(!__scalar_abi_tag<_Abi_t<1, 1>>);
-
-  static_assert(__abi_tag<_ScalarAbi<1>>);
-  static_assert(__abi_tag<_ScalarAbi<2>>);
+  static_assert(__scalar_abi_tag<_Abi_t<1, 1>>);
+  static_assert(__scalar_abi_tag<_Abi_t<2, 2>>);
+  static_assert(!__scalar_abi_tag<_Abi_t<2, 1>>);
 
   using AN = decltype(__native_abi<float>());
   using A1 = decltype(__native_abi<float>()._S_resize<1>());
   static_assert(A1::_S_size == 1);
   static_assert(A1::_S_nreg == 1);
   static_assert(A1::_S_variant == AN::_S_variant);
-  static_assert(__scalar_abi_tag<A1> == __scalar_abi_tag<AN>);
   static_assert(std::is_same_v<decltype(__abi_rebind<float, AN::_S_size, 
A1>()), AN>);
   if constexpr (AN::_S_size >= 2) // the target has SIMD support for float
     {
       {
        using A2 = decltype(__abi_rebind<float, 2, AN>());
+       static_assert(__scalar_abi_tag<A2> == __scalar_abi_tag<AN>);
        static_assert(A2::_S_size == 2);
        static_assert(A2::_S_nreg == 1);
        static_assert(A2::_S_variant == AN::_S_variant);
-- 
2.54.0

Reply via email to