https://github.com/philnik777 created 
https://github.com/llvm/llvm-project/pull/93113

This changes `__is_trivially_equality_comparable` to perform overload
resolution on `obj == obj` instead of scanning the class's methods and friends
for a defaulted `operator==`. This fixes a couple of false positives (and, as a
drive-by, a false negative).

Fixes #89293



>From 53b52a07f8720db4495b93099d3e1874453f6950 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklau...@berlin.de>
Date: Thu, 23 May 2024 01:48:06 +0200
Subject: [PATCH] [Clang] Fix __is_trivially_equality_comparable returning true
 with ineligible defaulted overloads

---
 clang/include/clang/AST/Type.h     |  3 --
 clang/lib/AST/Type.cpp             | 60 -------------------------
 clang/lib/Sema/SemaExprCXX.cpp     | 71 +++++++++++++++++++++++++++++-
 clang/test/SemaCXX/type-traits.cpp | 28 ++++++++++++
 4 files changed, 98 insertions(+), 64 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 9a5c6e8d562c3..628c7a0d2df83 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -1126,9 +1126,6 @@ class QualType {
   /// Return true if this is a trivially relocatable type.
   bool isTriviallyRelocatableType(const ASTContext &Context) const;
 
-  /// Return true if this is a trivially equality comparable type.
-  bool isTriviallyEqualityComparableType(const ASTContext &Context) const;
-
   /// Returns true if it is a class and it might be dynamic.
   bool mayBeDynamicClass() const;
 
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 3b90b8229dd18..62ca402460f94 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -2768,66 +2768,6 @@ bool QualType::isTriviallyRelocatableType(const 
ASTContext &Context) const {
   }
 }
 
-static bool
-HasNonDeletedDefaultedEqualityComparison(const CXXRecordDecl *Decl) {
-  if (Decl->isUnion())
-    return false;
-  if (Decl->isLambda())
-    return Decl->isCapturelessLambda();
-
-  auto IsDefaultedOperatorEqualEqual = [&](const FunctionDecl *Function) {
-    return Function->getOverloadedOperator() ==
-               OverloadedOperatorKind::OO_EqualEqual &&
-           Function->isDefaulted() && Function->getNumParams() > 0 &&
-           (Function->getParamDecl(0)->getType()->isReferenceType() ||
-            Decl->isTriviallyCopyable());
-  };
-
-  if (llvm::none_of(Decl->methods(), IsDefaultedOperatorEqualEqual) &&
-      llvm::none_of(Decl->friends(), [&](const FriendDecl *Friend) {
-        if (NamedDecl *ND = Friend->getFriendDecl()) {
-          return ND->isFunctionOrFunctionTemplate() &&
-                 IsDefaultedOperatorEqualEqual(ND->getAsFunction());
-        }
-        return false;
-      }))
-    return false;
-
-  return llvm::all_of(Decl->bases(),
-                      [](const CXXBaseSpecifier &BS) {
-                        if (const auto *RD = 
BS.getType()->getAsCXXRecordDecl())
-                          return HasNonDeletedDefaultedEqualityComparison(RD);
-                        return true;
-                      }) &&
-         llvm::all_of(Decl->fields(), [](const FieldDecl *FD) {
-           auto Type = FD->getType();
-           if (Type->isArrayType())
-             Type = 
Type->getBaseElementTypeUnsafe()->getCanonicalTypeUnqualified();
-
-           if (Type->isReferenceType() || Type->isEnumeralType())
-             return false;
-           if (const auto *RD = Type->getAsCXXRecordDecl())
-             return HasNonDeletedDefaultedEqualityComparison(RD);
-           return true;
-         });
-}
-
-bool QualType::isTriviallyEqualityComparableType(
-    const ASTContext &Context) const {
-  QualType CanonicalType = getCanonicalType();
-  if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
-      CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
-    return false;
-
-  if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
-    if (!HasNonDeletedDefaultedEqualityComparison(RD))
-      return false;
-  }
-
-  return Context.hasUniqueObjectRepresentations(
-      CanonicalType, /*CheckIfTriviallyCopyable=*/false);
-}
-
 bool QualType::isNonWeakInMRRWithObjCWeak(const ASTContext &Context) const {
   return !Context.getLangOpts().ObjCAutoRefCount &&
          Context.getLangOpts().ObjCWeak &&
diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index f543e006060d6..ccf678e666ecb 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -5199,6 +5199,75 @@ static bool HasNoThrowOperator(const RecordType *RT, 
OverloadedOperatorKind Op,
   return false;
 }
 
+static bool
+HasNonDeletedDefaultedEqualityComparison(Sema& S, const CXXRecordDecl *Decl) {
+  if (Decl->isUnion())
+    return false;
+  if (Decl->isLambda())
+    return Decl->isCapturelessLambda();
+
+  {
+    EnterExpressionEvaluationContext UnevaluatedContext(
+        S, Sema::ExpressionEvaluationContext::Unevaluated);
+    Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/true);
+    Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
+
+    // const ClassT& obj;
+    OpaqueValueExpr Operand(
+        {}, Decl->getTypeForDecl()->getCanonicalTypeUnqualified().withConst(),
+        ExprValueKind::VK_LValue);
+    UnresolvedSet<16> Functions;
+    // obj == obj;
+    S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
+
+    auto Result = S.CreateOverloadedBinOp({}, BinaryOperatorKind::BO_EQ,
+                                          Functions, &Operand, &Operand);
+    if (Result.isInvalid() || SFINAE.hasErrorOccurred())
+      return false;
+    auto Callee = Result.getAs<CXXOperatorCallExpr>()->getDirectCallee();
+    auto ParamT = Callee->getParamDecl(0)->getType();
+    if (!Callee->isDefaulted() ||
+        (!ParamT->isReferenceType() && !Decl->isTriviallyCopyable()))
+      return false;
+  }
+
+  return llvm::all_of(Decl->bases(),
+                      [&](const CXXBaseSpecifier &BS) {
+                        if (const auto *RD = 
BS.getType()->getAsCXXRecordDecl())
+                          return HasNonDeletedDefaultedEqualityComparison(S,
+                                                                          RD);
+                        return true;
+                      }) &&
+         llvm::all_of(Decl->fields(), [&](const FieldDecl *FD) {
+           auto Type = FD->getType();
+           if (Type->isArrayType())
+             Type = Type->getBaseElementTypeUnsafe()
+                        ->getCanonicalTypeUnqualified();
+
+           if (Type->isReferenceType() || Type->isEnumeralType())
+             return false;
+           if (const auto *RD = Type->getAsCXXRecordDecl())
+             return HasNonDeletedDefaultedEqualityComparison(S, RD);
+           return true;
+         });
+}
+
+static bool isTriviallyEqualityComparableType(
+    Sema& S, QualType Type) {
+  QualType CanonicalType = Type.getCanonicalType();
+  if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
+      CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
+    return false;
+
+  if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
+    if (!HasNonDeletedDefaultedEqualityComparison(S, RD))
+      return false;
+  }
+
+  return S.getASTContext().hasUniqueObjectRepresentations(
+      CanonicalType, /*CheckIfTriviallyCopyable=*/false);
+}
+
 static bool EvaluateUnaryTypeTrait(Sema &Self, TypeTrait UTT,
                                    SourceLocation KeyLoc,
                                    TypeSourceInfo *TInfo) {
@@ -5629,7 +5698,7 @@ static bool EvaluateUnaryTypeTrait(Sema &Self, TypeTrait 
UTT,
     Self.Diag(KeyLoc, diag::err_builtin_pass_in_regs_non_class) << T;
     return false;
   case UTT_IsTriviallyEqualityComparable:
-    return T.isTriviallyEqualityComparableType(C);
+    return isTriviallyEqualityComparableType(Self, T);
   }
 }
 
diff --git a/clang/test/SemaCXX/type-traits.cpp 
b/clang/test/SemaCXX/type-traits.cpp
index d40605f56f1ed..d06495e9cfb66 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -3885,8 +3885,36 @@ struct 
NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2 {
 
   bool operator==(const 
NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2&) const = 
default;
 };
+
 
static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableNonTriviallyEqualityComparableArrs2));
 
+template<bool B>
+struct MaybeTriviallyEqualityComparable {
+    int i;
+    bool operator==(const MaybeTriviallyEqualityComparable&) const requires B 
= default;
+    bool operator==(const MaybeTriviallyEqualityComparable& rhs) const { 
return (i % 3) == (rhs.i % 3); }
+};
+static_assert(__is_trivially_equality_comparable(MaybeTriviallyEqualityComparable<true>));
+static_assert(!__is_trivially_equality_comparable(MaybeTriviallyEqualityComparable<false>));
+
+struct NotTriviallyEqualityComparableMoreConstrainedExternalOp {
+  int i;
+  bool operator==(const 
NotTriviallyEqualityComparableMoreConstrainedExternalOp&) const = default;
+};
+
+bool operator==(const NotTriviallyEqualityComparableMoreConstrainedExternalOp&,
+                const 
NotTriviallyEqualityComparableMoreConstrainedExternalOp&) 
__attribute__((enable_if(true, ""))) {}
+
+static_assert(!__is_trivially_equality_comparable(NotTriviallyEqualityComparableMoreConstrainedExternalOp));
+
+struct TriviallyEqualityComparableExternalDefaultedOp {
+  int i;
+  friend bool operator==(TriviallyEqualityComparableExternalDefaultedOp, 
TriviallyEqualityComparableExternalDefaultedOp);
+};
+bool operator==(TriviallyEqualityComparableExternalDefaultedOp, 
TriviallyEqualityComparableExternalDefaultedOp) = default;
+
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableExternalDefaultedOp));
+
 namespace hidden_friend {
 
 struct TriviallyEqualityComparable {

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to