llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: None (halbi2)

<details>
<summary>Changes</summary>

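`std::equal` over `std::byte` ranges currently has suboptimal codegen because
enumeration types are not recognized as trivially equality comparable. To fix
this, an enumeration type is now treated as trivially equality comparable when
`operator==` on it resolves to the built-in comparison (i.e. no user-declared
`operator==` is selected). In the process, I factored the existing check out
into a standalone function, `EqualityComparisonIsDefaulted`, and refactored the
test cases.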

Unlike a class, an enumeration type cannot itself declare a hidden-friend
`operator==`.

Fixes #<!-- -->132672

---
Full diff: https://github.com/llvm/llvm-project/pull/169079.diff


2 Files Affected:

- (modified) clang/lib/Sema/SemaTypeTraits.cpp (+50-33) 
- (modified) clang/test/SemaCXX/type-traits.cpp (+20-26) 


``````````diff
diff --git a/clang/lib/Sema/SemaTypeTraits.cpp 
b/clang/lib/Sema/SemaTypeTraits.cpp
index 38877967af05e..581989e6d0069 100644
--- a/clang/lib/Sema/SemaTypeTraits.cpp
+++ b/clang/lib/Sema/SemaTypeTraits.cpp
@@ -591,6 +591,43 @@ static bool HasNoThrowOperator(CXXRecordDecl *RD, 
OverloadedOperatorKind Op,
   return false;
 }
 
+static bool EqualityComparisonIsDefaulted(Sema &S, const TypeDecl *Decl,
+                                          SourceLocation KeyLoc) {
+  CanQualType T = S.Context.getCanonicalTagType(Decl);
+
+  EnterExpressionEvaluationContext UnevaluatedContext(
+      S, Sema::ExpressionEvaluationContext::Unevaluated);
+  Sema::SFINAETrap SFINAE(S, /*ForValidityCheck=*/true);
+  Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
+
+  // const ClassT& obj;
+  OpaqueValueExpr Operand(
+      KeyLoc, T.withConst(),
+      ExprValueKind::VK_LValue);
+  UnresolvedSet<16> Functions;
+  // obj == obj;
+  S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
+
+  auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
+                                        Functions, &Operand, &Operand);
+  if (Result.isInvalid() || SFINAE.hasErrorOccurred())
+    return false;
+
+  const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
+  if (!CallExpr)
+    return isa<EnumDecl>(Decl);
+  const auto *Callee = CallExpr->getDirectCallee();
+  auto ParamT = Callee->getParamDecl(0)->getType();
+  if (!Callee->isDefaulted())
+    return false;
+  if (!ParamT->isReferenceType()) {
+    const CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(Decl);
+    if (!RD->isTriviallyCopyable())
+      return false;
+  }
+  return S.Context.hasSameUnqualifiedType(ParamT.getNonReferenceType(), T);
+}
+
 static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
                                                      const CXXRecordDecl *Decl,
                                                      SourceLocation KeyLoc) {
@@ -599,36 +636,8 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema 
&S,
   if (Decl->isLambda())
     return Decl->isCapturelessLambda();
 
-  CanQualType T = S.Context.getCanonicalTagType(Decl);
-  {
-    EnterExpressionEvaluationContext UnevaluatedContext(
-        S, Sema::ExpressionEvaluationContext::Unevaluated);
-    Sema::SFINAETrap SFINAE(S, /*ForValidityCheck=*/true);
-    Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
-
-    // const ClassT& obj;
-    OpaqueValueExpr Operand(KeyLoc, T.withConst(), ExprValueKind::VK_LValue);
-    UnresolvedSet<16> Functions;
-    // obj == obj;
-    S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
-
-    auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
-                                          Functions, &Operand, &Operand);
-    if (Result.isInvalid() || SFINAE.hasErrorOccurred())
-      return false;
-
-    const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
-    if (!CallExpr)
-      return false;
-    const auto *Callee = CallExpr->getDirectCallee();
-    auto ParamT = Callee->getParamDecl(0)->getType();
-    if (!Callee->isDefaulted())
-      return false;
-    if (!ParamT->isReferenceType() && !Decl->isTriviallyCopyable())
-      return false;
-    if (!S.Context.hasSameUnqualifiedType(ParamT.getNonReferenceType(), T))
-      return false;
-  }
+  if (!EqualityComparisonIsDefaulted(S, Decl, KeyLoc))
+    return false;
 
   return llvm::all_of(Decl->bases(),
                       [&](const CXXBaseSpecifier &BS) {
@@ -643,9 +652,12 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema 
&S,
              Type = Type->getBaseElementTypeUnsafe()
                         ->getCanonicalTypeUnqualified();
 
-           if (Type->isReferenceType() || Type->isEnumeralType())
+           if (Type->isReferenceType())
              return false;
-           if (const auto *RD = Type->getAsCXXRecordDecl())
+           if (Type->isEnumeralType()) {
+             EnumDecl *ED = 
Type->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
+             return EqualityComparisonIsDefaulted(S, ED, KeyLoc);
+           } else if (const auto *RD = Type->getAsCXXRecordDecl())
              return HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc);
            return true;
          });
@@ -655,9 +667,14 @@ static bool isTriviallyEqualityComparableType(Sema &S, 
QualType Type,
                                               SourceLocation KeyLoc) {
   QualType CanonicalType = Type.getCanonicalType();
   if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
-      CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
+      CanonicalType->isArrayType())
     return false;
 
+  if (CanonicalType->isEnumeralType()) {
+    EnumDecl *ED = 
CanonicalType->castAs<EnumType>()->getOriginalDecl()->getDefinitionOrSelf();
+    return EqualityComparisonIsDefaulted(S, ED, KeyLoc);
+  }
+
   if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
     if (!HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc))
       return false;
diff --git a/clang/test/SemaCXX/type-traits.cpp 
b/clang/test/SemaCXX/type-traits.cpp
index 9ef44d0346b48..76fa4a9c2b936 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -3993,6 +3993,10 @@ namespace is_trivially_equality_comparable {
 struct ForwardDeclared; // expected-note {{forward declaration of 
'is_trivially_equality_comparable::ForwardDeclared'}}
 static_assert(!__is_trivially_equality_comparable(ForwardDeclared)); // 
expected-error {{incomplete type 'ForwardDeclared' used in type trait 
expression}}
 
+enum Enum {};
+enum EnumWithOpEq {};
+bool operator==(EnumWithOpEq, EnumWithOpEq);
+
 static_assert(!__is_trivially_equality_comparable(void));
 static_assert(__is_trivially_equality_comparable(int));
 static_assert(!__is_trivially_equality_comparable(int[]));
@@ -4000,6 +4004,8 @@ 
static_assert(!__is_trivially_equality_comparable(int[3]));
 static_assert(!__is_trivially_equality_comparable(float));
 static_assert(!__is_trivially_equality_comparable(double));
 static_assert(!__is_trivially_equality_comparable(long double));
+static_assert(__is_trivially_equality_comparable(Enum));
+static_assert(!__is_trivially_equality_comparable(EnumWithOpEq));
 
 struct NonTriviallyEqualityComparableNoComparator {
   int i;
@@ -4033,19 +4039,21 @@ struct TriviallyEqualityComparable {
 };
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparable));
 
-struct TriviallyEqualityComparableContainsArray {
-  int a[4];
+template <class T>
+struct TriviallyEqualityComparableContains {
+  T t;
 
-  bool operator==(const TriviallyEqualityComparableContainsArray&) const = 
default;
+  bool operator==(const TriviallyEqualityComparableContains&) const = default;
 };
-static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsArray));
 
-struct TriviallyEqualityComparableContainsMultiDimensionArray {
-  int a[4][4];
-
-  bool operator==(const 
TriviallyEqualityComparableContainsMultiDimensionArray&) const = default;
-};
-static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsMultiDimensionArray));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<float>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<double>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<long
 double>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<int[4][4]>));
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContains<Enum>));
+static_assert(!__is_trivially_equality_comparable(TriviallyEqualityComparableContains<EnumWithOpEq>));
 
 auto GetNonCapturingLambda() { return [](){ return 42; }; }
 
@@ -4196,13 +4204,6 @@ struct 
NotTriviallyEqualityComparableNonTriviallyComparableBase : NotTriviallyEq
 };
 
static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyComparableBase));
 
-enum E {
-  a,
-  b
-};
-bool operator==(E, E) { return false; }
-static_assert(!__is_trivially_equality_comparable(E));
-
 struct NotTriviallyEqualityComparableHasEnum {
   E e;
   bool operator==(const NotTriviallyEqualityComparableHasEnum&) const = 
default;
@@ -4434,15 +4435,8 @@ struct NotTriviallyEqualityComparableHasReferenceMember {
 };
 
static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasReferenceMember));
 
-enum E {
-  a,
-  b
-};
-bool operator==(E, E) { return false; }
-static_assert(!__is_trivially_equality_comparable(E));
-
 struct NotTriviallyEqualityComparableHasEnum {
-  E e;
+  Enum e;
   friend bool operator==(const NotTriviallyEqualityComparableHasEnum&, const 
NotTriviallyEqualityComparableHasEnum&) = default;
 };
 
static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableHasEnum));
@@ -4465,7 +4459,7 @@ 
static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableRefC
 }
 
 #endif // __cplusplus >= 202002L
-};
+}
 
 namespace can_pass_in_regs {
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/169079
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to