https://github.com/usx95 updated 
https://github.com/llvm/llvm-project/pull/159845

>From f105e5a5c094e3bfc8f3ceb6331651eaed5823e9 Mon Sep 17 00:00:00 2001
From: Utkarsh Saxena <u...@google.com>
Date: Fri, 19 Sep 2025 20:00:17 +0000
Subject: [PATCH] fix-add-immutable-set

---
 clang/lib/Analysis/LifetimeSafety.cpp   | 12 ++----
 llvm/include/llvm/ADT/ImmutableSet.h    | 49 ++++++++++++++++++-------
 llvm/unittests/ADT/ImmutableSetTest.cpp | 31 ++++++++++++++++
 3 files changed, 70 insertions(+), 22 deletions(-)

diff --git a/clang/lib/Analysis/LifetimeSafety.cpp b/clang/lib/Analysis/LifetimeSafety.cpp
index d016c6f12e82e..0dd5716d93fb6 100644
--- a/clang/lib/Analysis/LifetimeSafety.cpp
+++ b/clang/lib/Analysis/LifetimeSafety.cpp
@@ -910,13 +910,10 @@ template <typename T>
 static llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A,
                                   llvm::ImmutableSet<T> B,
                                   typename llvm::ImmutableSet<T>::Factory &F) {
-  if (A == B)
-    return A;
   if (A.getHeight() < B.getHeight())
     std::swap(A, B);
   for (const T &E : B)
-    if (!A.contains(E))
-      A = F.add(A, E);
+    A = F.add(A, E);
   return A;
 }
 
@@ -950,11 +947,10 @@ join(llvm::ImmutableMap<K, V> A, llvm::ImmutableMap<K, V> B,
   for (const auto &Entry : B) {
     const K &Key = Entry.first;
     const V &ValB = Entry.second;
-    const V *ValA = A.lookup(Key);
-    if (!ValA)
-      A = F.add(A, Key, ValB);
-    else if (*ValA != ValB)
+    if (const V *ValA = A.lookup(Key))
       A = F.add(A, Key, JoinValues(*ValA, ValB));
+    else
+      A = F.add(A, Key, ValB);
   }
   return A;
 }
diff --git a/llvm/include/llvm/ADT/ImmutableSet.h b/llvm/include/llvm/ADT/ImmutableSet.h
index ac86f43b2048e..017585a47ddd6 100644
--- a/llvm/include/llvm/ADT/ImmutableSet.h
+++ b/llvm/include/llvm/ADT/ImmutableSet.h
@@ -531,7 +531,7 @@ class ImutAVLFactory {
   /// add_internal - Creates a new tree that includes the specified
   ///  data and the data from the original tree.  If the original tree
   ///  already contained the data item, the original tree is returned.
-  TreeTy* add_internal(value_type_ref V, TreeTy* T) {
+  TreeTy *add_internal(value_type_ref V, TreeTy *T) {
     if (isEmpty(T))
       return createNode(T, V, T);
     assert(!T->isMutable());
@@ -539,19 +539,34 @@ class ImutAVLFactory {
     key_type_ref K = ImutInfo::KeyOfValue(V);
     key_type_ref KCurrent = ImutInfo::KeyOfValue(getValue(T));
 
-    if (ImutInfo::isEqual(K,KCurrent))
+    if (ImutInfo::isEqual(K, KCurrent)) {
+      // If both the key and the data are unchanged, return the original tree.
+      if (ImutInfo::isDataEqual(ImutInfo::DataOfValue(V),
+                                ImutInfo::DataOfValue(getValue(T))))
+        return T;
+      // Otherwise create a new node with the new value.
       return createNode(getLeft(T), V, getRight(T));
-    else if (ImutInfo::isLess(K,KCurrent))
-      return balanceTree(add_internal(V, getLeft(T)), getValue(T), getRight(T));
+    }
+
+    TreeTy *NewL = getLeft(T);
+    TreeTy *NewR = getRight(T);
+    if (ImutInfo::isLess(K, KCurrent))
+      NewL = add_internal(V, NewL);
     else
-      return balanceTree(getLeft(T), getValue(T), add_internal(V, getRight(T)));
+      NewR = add_internal(V, NewR);
+
+    // If no changes were made, return the original tree. Otherwise, balance the
+    // tree and return the new root.
+    return NewL == getLeft(T) && NewR == getRight(T)
+               ? T
+               : balanceTree(NewL, getValue(T), NewR);
   }
 
   /// remove_internal - Creates a new tree that includes all the data
   ///  from the original tree except the specified data.  If the
   ///  specified data did not exist in the original tree, the original
   ///  tree is returned.
-  TreeTy* remove_internal(key_type_ref K, TreeTy* T) {
+  TreeTy *remove_internal(key_type_ref K, TreeTy *T) {
     if (isEmpty(T))
       return T;
 
@@ -559,15 +574,21 @@ class ImutAVLFactory {
 
     key_type_ref KCurrent = ImutInfo::KeyOfValue(getValue(T));
 
-    if (ImutInfo::isEqual(K,KCurrent)) {
+    if (ImutInfo::isEqual(K, KCurrent))
       return combineTrees(getLeft(T), getRight(T));
-    } else if (ImutInfo::isLess(K,KCurrent)) {
-      return balanceTree(remove_internal(K, getLeft(T)),
-                                            getValue(T), getRight(T));
-    } else {
-      return balanceTree(getLeft(T), getValue(T),
-                         remove_internal(K, getRight(T)));
-    }
+
+    TreeTy *NewL = getLeft(T);
+    TreeTy *NewR = getRight(T);
+    if (ImutInfo::isLess(K, KCurrent))
+      NewL = remove_internal(K, NewL);
+    else
+      NewR = remove_internal(K, NewR);
+
+    // If no changes were made, return the original tree. Otherwise, balance the
+    // tree and return the new root.
+    return NewL == getLeft(T) && NewR == getRight(T)
+               ? T
+               : balanceTree(NewL, getValue(T), NewR);
   }
 
   TreeTy* combineTrees(TreeTy* L, TreeTy* R) {
diff --git a/llvm/unittests/ADT/ImmutableSetTest.cpp b/llvm/unittests/ADT/ImmutableSetTest.cpp
index c0bde4c4d680b..87bc2a8da4bad 100644
--- a/llvm/unittests/ADT/ImmutableSetTest.cpp
+++ b/llvm/unittests/ADT/ImmutableSetTest.cpp
@@ -164,4 +164,35 @@ TEST_F(ImmutableSetTest, IterLongSetTest) {
   ASSERT_EQ(6, i);
 }
 
+TEST_F(ImmutableSetTest, AddIfNotFoundTest) {
+  ImmutableSet<long>::Factory f(/*canonicalize=*/false);
+  ImmutableSet<long> S = f.getEmptySet();
+  S = f.add(S, 1);
+  S = f.add(S, 2);
+  S = f.add(S, 3);
+
+  ImmutableSet<long> T1 = f.add(S, 1);
+  ImmutableSet<long> T2 = f.add(S, 2);
+  ImmutableSet<long> T3 = f.add(S, 3);
+  EXPECT_EQ(S.getRoot(), T1.getRoot());
+  EXPECT_EQ(S.getRoot(), T2.getRoot());
+  EXPECT_EQ(S.getRoot(), T3.getRoot());
+
+  ImmutableSet<long> U = f.add(S, 4);
+  EXPECT_NE(S.getRoot(), U.getRoot());
+}
+
+TEST_F(ImmutableSetTest, RemoveIfNotFoundTest) {
+  ImmutableSet<long>::Factory f(/*canonicalize=*/false);
+  ImmutableSet<long> S = f.getEmptySet();
+  S = f.add(S, 1);
+  S = f.add(S, 2);
+  S = f.add(S, 3);
+
+  ImmutableSet<long> T = f.remove(S, 4);
+  EXPECT_EQ(S.getRoot(), T.getRoot());
+
+  ImmutableSet<long> U = f.remove(S, 3);
+  EXPECT_NE(S.getRoot(), U.getRoot());
 }
+} // namespace

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to