Author: halbi2
Date: 2026-02-16T11:52:13+01:00
New Revision: 1f65d4dda14cfea4323fd7139e222d26c7dc365d

URL: 
https://github.com/llvm/llvm-project/commit/1f65d4dda14cfea4323fd7139e222d26c7dc365d
DIFF: 
https://github.com/llvm/llvm-project/commit/1f65d4dda14cfea4323fd7139e222d26c7dc365d.diff

LOG: [Clang] make most enums trivially equality comparable (#169079)

std::equal(std::byte) currently has sub-optimal codegen because enum
types are not recognized as trivially equality comparable. To fix this,
an enum type is now treated as trivially equality comparable when `==`
on it resolves to the built-in operator. In the process I factored the
overload-resolution check out into a standalone function,
equalityComparisonIsDefaulted, and refactored the test cases.

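Enum types cannot have a member or hidden-friend operator==, so their
equality comparison can never be defaulted: any user-declared operator==
selected by overload resolution makes the enum not trivially equality
comparable.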

Fixes #132672

Added: 
    

Modified: 
    clang/docs/ReleaseNotes.rst
    clang/lib/Sema/SemaTypeTraits.cpp
    clang/test/SemaCXX/type-traits.cpp

Removed: 
    


################################################################################
diff  --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index c9bedb87c6a79..5ac5ae4a7d37e 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -103,6 +103,8 @@ What's New in Clang |release|?
 C++ Language Changes
 --------------------
 
+- ``__is_trivially_equality_comparable`` no longer returns false for all enum types: it now returns true for an enum type whose ``==`` resolves to the built-in operator. (#GH132672)
+
 C++2c Feature Support
 ^^^^^^^^^^^^^^^^^^^^^
 

diff  --git a/clang/lib/Sema/SemaTypeTraits.cpp 
b/clang/lib/Sema/SemaTypeTraits.cpp
index be2c8853d6433..a94a59e8add7b 100644
--- a/clang/lib/Sema/SemaTypeTraits.cpp
+++ b/clang/lib/Sema/SemaTypeTraits.cpp
@@ -503,6 +503,41 @@ static bool HasNoThrowOperator(CXXRecordDecl *RD, 
OverloadedOperatorKind Op,
   return false;
 }
 
+/// Determine whether `obj == obj`, for an lvalue `obj` of type `const T`
+/// (where T is the type declared by \p Decl), resolves to an equality
+/// comparison that counts as "defaulted" for __is_trivially_equality_comparable.
+///
+/// For an enumeration this holds only when overload resolution selects the
+/// built-in operator: an enumeration cannot declare a member or hidden-friend
+/// operator==, so any user-provided operator== disqualifies it. For a class
+/// this holds when the selected operator== is defaulted and takes its
+/// parameter by reference or, if the class is trivially copyable, by value.
+///
+/// Lookup and overload resolution run in an unevaluated SFINAE context at
+/// translation-unit scope, so a failed resolution yields `false` instead of a
+/// diagnostic.
+static bool equalityComparisonIsDefaulted(Sema &S, const TagDecl *Decl,
+                                          SourceLocation KeyLoc) {
+  CanQualType T = S.Context.getCanonicalTagType(Decl);
+
+  EnterExpressionEvaluationContext UnevaluatedContext(
+      S, Sema::ExpressionEvaluationContext::Unevaluated);
+  Sema::SFINAETrap SFINAE(S, /*WithAccessChecking=*/true);
+  Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
+
+  // const ClassT& obj;
+  OpaqueValueExpr Operand(KeyLoc, T.withConst(), ExprValueKind::VK_LValue);
+  UnresolvedSet<16> Functions;
+  // obj == obj;
+  S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
+
+  ExprResult Result = S.CreateOverloadedBinOp(
+      KeyLoc, BinaryOperatorKind::BO_EQ, Functions, &Operand, &Operand);
+  if (Result.isInvalid() || SFINAE.hasErrorOccurred())
+    return false;
+
+  // A user-provided operator== whose return type is a class with a
+  // non-trivial destructor comes back wrapped (e.g. in a
+  // CXXBindTemporaryExpr). Strip such implicit nodes so that the call is not
+  // mistaken for the built-in operator.
+  const Expr *E = Result.get()->IgnoreImplicit();
+
+  // The built-in operator was selected. That is trivial only for an
+  // enumeration; a class can only get here through a conversion function.
+  if (isa<BinaryOperator>(E))
+    return isa<EnumDecl>(Decl);
+
+  // Anything other than a direct call to an overloaded operator (for example
+  // a rewritten or reversed candidate) is conservatively treated as
+  // non-trivial.
+  const auto *OpCall = dyn_cast<CXXOperatorCallExpr>(E);
+  if (!OpCall)
+    return false;
+  const FunctionDecl *Callee = OpCall->getDirectCallee();
+  if (!Callee || !Callee->isDefaulted())
+    return false;
+
+  QualType ParamT = Callee->getParamDecl(0)->getType();
+  if (!ParamT->isReferenceType()) {
+    // A by-value parameter copies the operand, which is only trivial when the
+    // class itself is trivially copyable.
+    const auto *RD = dyn_cast<CXXRecordDecl>(Decl);
+    if (RD && !RD->isTriviallyCopyable())
+      return false;
+  }
+  return S.Context.hasSameUnqualifiedType(ParamT.getNonReferenceType(), T);
+}
+
 static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
                                                      const CXXRecordDecl *Decl,
                                                      SourceLocation KeyLoc) {
@@ -511,36 +546,8 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema 
&S,
   if (Decl->isLambda())
     return Decl->isCapturelessLambda();
 
-  CanQualType T = S.Context.getCanonicalTagType(Decl);
-  {
-    EnterExpressionEvaluationContext UnevaluatedContext(
-        S, Sema::ExpressionEvaluationContext::Unevaluated);
-    Sema::SFINAETrap SFINAE(S, /*ForValidityCheck=*/true);
-    Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
-
-    // const ClassT& obj;
-    OpaqueValueExpr Operand(KeyLoc, T.withConst(), ExprValueKind::VK_LValue);
-    UnresolvedSet<16> Functions;
-    // obj == obj;
-    S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
-
-    auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
-                                          Functions, &Operand, &Operand);
-    if (Result.isInvalid() || SFINAE.hasErrorOccurred())
-      return false;
-
-    const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
-    if (!CallExpr)
-      return false;
-    const auto *Callee = CallExpr->getDirectCallee();
-    auto ParamT = Callee->getParamDecl(0)->getType();
-    if (!Callee->isDefaulted())
-      return false;
-    if (!ParamT->isReferenceType() && !Decl->isTriviallyCopyable())
-      return false;
-    if (!S.Context.hasSameUnqualifiedType(ParamT.getNonReferenceType(), T))
-      return false;
-  }
+  if (!equalityComparisonIsDefaulted(S, Decl, KeyLoc))
+    return false;
 
   return llvm::all_of(Decl->bases(),
                       [&](const CXXBaseSpecifier &BS) {
@@ -555,9 +562,13 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema 
&S,
              Type = Type->getBaseElementTypeUnsafe()
                         ->getCanonicalTypeUnqualified();
 
-           if (Type->isReferenceType() || Type->isEnumeralType())
+           if (Type->isReferenceType())
              return false;
-           if (const auto *RD = Type->getAsCXXRecordDecl())
+           if (Type->isEnumeralType()) {
+             EnumDecl *ED =
+                 Type->castAs<EnumType>()->getDecl()->getDefinitionOrSelf();
+             return equalityComparisonIsDefaulted(S, ED, KeyLoc);
+           } else if (const auto *RD = Type->getAsCXXRecordDecl())
              return HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc);
            return true;
          });
@@ -567,9 +578,15 @@ static bool isTriviallyEqualityComparableType(Sema &S, 
QualType Type,
                                               SourceLocation KeyLoc) {
   QualType CanonicalType = Type.getCanonicalType();
   if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
-      CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
+      CanonicalType->isArrayType())
     return false;
 
+  if (CanonicalType->isEnumeralType()) {
+    EnumDecl *ED =
+        CanonicalType->castAs<EnumType>()->getDecl()->getDefinitionOrSelf();
+    return equalityComparisonIsDefaulted(S, ED, KeyLoc);
+  }
+
   if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
     if (!HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc))
       return false;

diff  --git a/clang/test/SemaCXX/type-traits.cpp 
b/clang/test/SemaCXX/type-traits.cpp
index 561c9ca8286b9..65c0729571f99 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -4003,6 +4003,14 @@ namespace is_trivially_equality_comparable {
 struct ForwardDeclared; // expected-note {{forward declaration of 
'is_trivially_equality_comparable::ForwardDeclared'}}
 static_assert(!__is_trivially_equality_comparable(ForwardDeclared)); // 
expected-error {{incomplete type 'ForwardDeclared' used in type trait 
expression}}
 
+enum Byte : unsigned char {};
+enum ByteWithOpEq : unsigned char {};
+bool operator==(ByteWithOpEq, ByteWithOpEq);
+
+enum Enum {};
+enum EnumWithOpEq {};
+bool operator==(EnumWithOpEq, EnumWithOpEq);
+
 static_assert(!__is_trivially_equality_comparable(void));
 static_assert(__is_trivially_equality_comparable(int));
 static_assert(!__is_trivially_equality_comparable(int[]));
@@ -4010,6 +4018,10 @@ 
static_assert(!__is_trivially_equality_comparable(int[3]));
 static_assert(!__is_trivially_equality_comparable(float));
 static_assert(!__is_trivially_equality_comparable(double));
 static_assert(!__is_trivially_equality_comparable(long double));
+static_assert(__is_trivially_equality_comparable(Byte));
+static_assert(!__is_trivially_equality_comparable(ByteWithOpEq));
+static_assert(__is_trivially_equality_comparable(Enum));
+static_assert(!__is_trivially_equality_comparable(EnumWithOpEq));
 
 struct NonTriviallyEqualityComparableNoComparator {
   int i;
@@ -4043,19 +4055,26 @@ struct TriviallyEqualityComparable {
 };
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparable));
 
-struct TriviallyEqualityComparableContainsArray {
-  int a[4];
-
-  bool operator==(const TriviallyEqualityComparableContainsArray&) const = 
default;
-};
-static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsArray));
-
-struct TriviallyEqualityComparableContainsMultiDimensionArray {
-  int a[4][4];
-
-  bool operator==(const 
TriviallyEqualityComparableContainsMultiDimensionArray&) const = default;
-};
-static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsMultiDimensionArray));
+template <class T>
+struct TriviallyEqualityComparableContains {
+  T t;
+
+  bool operator==(const TriviallyEqualityComparableContains&) const = default;
+};
+
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int&>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<float>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<double>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<long
 double>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4][4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2][2]>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2]>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2][2]>));
 
 auto GetNonCapturingLambda() { return [](){ return 42; }; }
 
@@ -4184,13 +4203,6 @@ struct 
NotTriviallyEqualityComparableImplicitlyDeletedOperatorByStruct {
 };
 
static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableImplicitlyDeletedOperatorByStruct));
 
-struct NotTriviallyEqualityComparableHasReferenceMember {
-  int& i;
-
-  bool operator==(const NotTriviallyEqualityComparableHasReferenceMember&) 
const = default;
-};
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasReferenceMember));
-
 struct NotTriviallyEqualityComparableNonTriviallyComparableBaseBase {
   int i;
 
@@ -4206,34 +4218,6 @@ struct 
NotTriviallyEqualityComparableNonTriviallyComparableBase : NotTriviallyEq
 };
 
static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyComparableBase));
 
-enum E {
-  a,
-  b
-};
-bool operator==(E, E) { return false; }
-static_assert(!__is_trivially_equality_comparable(E));
-
-struct NotTriviallyEqualityComparableHasEnum {
-  E e;
-  bool operator==(const NotTriviallyEqualityComparableHasEnum&) const = 
default;
-};
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasEnum));
-
-struct NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs {
-  E e[1];
-
-  bool operator==(const 
NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs&) const = 
default;
-};
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs));
-
-struct NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2 {
-  E e[1][1];
-
-  bool operator==(const 
NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2&) const = 
default;
-};
-
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2));
-
 struct NotTriviallyEqualityComparablePrivateComparison {
   int i;
 
@@ -4321,6 +4305,27 @@ struct TriviallyEqualityComparable {
 };
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparable));
 
+template <class T>
+struct TriviallyEqualityComparableContains {
+  T t;
+
+  friend bool operator==(const TriviallyEqualityComparableContains&, const 
TriviallyEqualityComparableContains&) = default;
+};
+
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int&>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<float>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<double>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<long
 double>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4][4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum[2][2]>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2]>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq[2][2]>));
+
 struct TriviallyEqualityComparableNonTriviallyCopyable {
   TriviallyEqualityComparableNonTriviallyCopyable(const 
TriviallyEqualityComparableNonTriviallyCopyable&);
   ~TriviallyEqualityComparableNonTriviallyCopyable();
@@ -4437,26 +4442,6 @@ struct 
NotTriviallyEqualityComparableImplicitlyDeletedOperatorByStruct {
 };
 
static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableImplicitlyDeletedOperatorByStruct));
 
-struct NotTriviallyEqualityComparableHasReferenceMember {
-  int& i;
-
-  friend bool operator==(const 
NotTriviallyEqualityComparableHasReferenceMember&, const 
NotTriviallyEqualityComparableHasReferenceMember&) = default;
-};
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasReferenceMember));
-
-enum E {
-  a,
-  b
-};
-bool operator==(E, E) { return false; }
-static_assert(!__is_trivially_equality_comparable(E));
-
-struct NotTriviallyEqualityComparableHasEnum {
-  E e;
-  friend bool operator==(const NotTriviallyEqualityComparableHasEnum&, const 
NotTriviallyEqualityComparableHasEnum&) = default;
-};
-static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasEnum));
-
 struct NonTriviallyEqualityComparableValueComparisonNonTriviallyCopyable {
   int i;
   NonTriviallyEqualityComparableValueComparisonNonTriviallyCopyable(const 
NonTriviallyEqualityComparableValueComparisonNonTriviallyCopyable&);
@@ -4475,7 +4460,7 @@ 
static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableRefC
 }
 
 #endif // __cplusplus >= 202002L
-};
+}
 
 namespace can_pass_in_regs {
 


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to