llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-adt

Author: Utkarsh Saxena (usx95)

<details>
<summary>Changes</summary>

Optimize ImmutableSet operations to avoid unnecessary tree modifications when 
adding existing elements or removing non-existent elements.

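- Modified `ImutAVLFactory::add_internal()` to return the original tree when
both the key and the value are the same, avoiding unnecessary node creation.
- Updated `ImutAVLFactory::remove_internal()` and `add_internal()` to return
the original tree when no changes are made.
- Simplified the `join` helpers in `clang/lib/Analysis/LifetimeSafety.cpp`,
which no longer need `contains()`/`lookup()` pre-checks before calling `add`.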

Note that `balanceTree` always ends up creating at least one new node, even when
no rebalancing is needed, so we also need to avoid unnecessary calls to it.

---
Full diff: https://github.com/llvm/llvm-project/pull/159845.diff


3 Files Affected:

- (modified) clang/lib/Analysis/LifetimeSafety.cpp (+4-8) 
- (modified) llvm/include/llvm/ADT/ImmutableSet.h (+34-14) 
- (modified) llvm/unittests/ADT/ImmutableSetTest.cpp (+33) 


``````````diff
diff --git a/clang/lib/Analysis/LifetimeSafety.cpp b/clang/lib/Analysis/LifetimeSafety.cpp
index d016c6f12e82e..0dd5716d93fb6 100644
--- a/clang/lib/Analysis/LifetimeSafety.cpp
+++ b/clang/lib/Analysis/LifetimeSafety.cpp
@@ -910,13 +910,10 @@ template <typename T>
 static llvm::ImmutableSet<T> join(llvm::ImmutableSet<T> A,
                                   llvm::ImmutableSet<T> B,
                                   typename llvm::ImmutableSet<T>::Factory &F) {
-  if (A == B)
-    return A;
   if (A.getHeight() < B.getHeight())
     std::swap(A, B);
   for (const T &E : B)
-    if (!A.contains(E))
-      A = F.add(A, E);
+    A = F.add(A, E);
   return A;
 }
 
@@ -950,11 +947,10 @@ join(llvm::ImmutableMap<K, V> A, llvm::ImmutableMap<K, V> B,
   for (const auto &Entry : B) {
     const K &Key = Entry.first;
     const V &ValB = Entry.second;
-    const V *ValA = A.lookup(Key);
-    if (!ValA)
-      A = F.add(A, Key, ValB);
-    else if (*ValA != ValB)
+    if (const V *ValA = A.lookup(Key))
       A = F.add(A, Key, JoinValues(*ValA, ValB));
+    else
+      A = F.add(A, Key, ValB);
   }
   return A;
 }
diff --git a/llvm/include/llvm/ADT/ImmutableSet.h b/llvm/include/llvm/ADT/ImmutableSet.h
index ac86f43b2048e..c2d84d86b5e27 100644
--- a/llvm/include/llvm/ADT/ImmutableSet.h
+++ b/llvm/include/llvm/ADT/ImmutableSet.h
@@ -531,7 +531,7 @@ class ImutAVLFactory {
   /// add_internal - Creates a new tree that includes the specified
   ///  data and the data from the original tree.  If the original tree
   ///  already contained the data item, the original tree is returned.
-  TreeTy* add_internal(value_type_ref V, TreeTy* T) {
+  TreeTy *add_internal(value_type_ref V, TreeTy *T) {
     if (isEmpty(T))
       return createNode(T, V, T);
     assert(!T->isMutable());
@@ -539,19 +539,34 @@ class ImutAVLFactory {
     key_type_ref K = ImutInfo::KeyOfValue(V);
     key_type_ref KCurrent = ImutInfo::KeyOfValue(getValue(T));
 
-    if (ImutInfo::isEqual(K,KCurrent))
+    if (ImutInfo::isEqual(K, KCurrent)) {
+      // If both key and value are same, return the original tree.
+      if (ImutInfo::isDataEqual(ImutInfo::DataOfValue(V),
+                                ImutInfo::DataOfValue(getValue(T))))
+        return T;
+      // Otherwise create a new node with the new value.
       return createNode(getLeft(T), V, getRight(T));
-    else if (ImutInfo::isLess(K,KCurrent))
-      return balanceTree(add_internal(V, getLeft(T)), getValue(T), getRight(T));
+    }
+
+    TreeTy *NewL = getLeft(T);
+    TreeTy *NewR = getRight(T);
+    if (ImutInfo::isLess(K, KCurrent))
+      NewL = add_internal(V, NewL);
     else
-      return balanceTree(getLeft(T), getValue(T), add_internal(V, getRight(T)));
+      NewR = add_internal(V, NewR);
+
+    // If no changes were made, return the original tree. Otherwise, balance the
+    // tree and return the new root.
+    return NewL == getLeft(T) && NewR == getRight(T)
+               ? T
+               : balanceTree(NewL, getValue(T), NewR);
   }
 
   /// remove_internal - Creates a new tree that includes all the data
   ///  from the original tree except the specified data.  If the
   ///  specified data did not exist in the original tree, the original
   ///  tree is returned.
-  TreeTy* remove_internal(key_type_ref K, TreeTy* T) {
+    TreeTy* remove_internal(key_type_ref K, TreeTy* T) {
     if (isEmpty(T))
       return T;
 
@@ -559,15 +574,20 @@ class ImutAVLFactory {
 
     key_type_ref KCurrent = ImutInfo::KeyOfValue(getValue(T));
 
-    if (ImutInfo::isEqual(K,KCurrent)) {
+    if (ImutInfo::isEqual(K, KCurrent))
       return combineTrees(getLeft(T), getRight(T));
-    } else if (ImutInfo::isLess(K,KCurrent)) {
-      return balanceTree(remove_internal(K, getLeft(T)),
-                                            getValue(T), getRight(T));
-    } else {
-      return balanceTree(getLeft(T), getValue(T),
-                         remove_internal(K, getRight(T)));
-    }
+
+    TreeTy *NewL = getLeft(T);
+    TreeTy *NewR = getRight(T);
+    if (ImutInfo::isLess(K, KCurrent))
+      NewL = remove_internal(K, NewL);
+    else
+      NewR = remove_internal(K, NewR);
+
+    // If no changes were made, return the original tree. Otherwise, balance the
+    // tree and return the new root.
+    return NewL == getLeft(T) && NewR == getRight(T) ? T
+                                               : balanceTree(NewL, getValue(T), NewR);
   }
 
   TreeTy* combineTrees(TreeTy* L, TreeTy* R) {
diff --git a/llvm/unittests/ADT/ImmutableSetTest.cpp b/llvm/unittests/ADT/ImmutableSetTest.cpp
index c0bde4c4d680b..c85a642d06eb2 100644
--- a/llvm/unittests/ADT/ImmutableSetTest.cpp
+++ b/llvm/unittests/ADT/ImmutableSetTest.cpp
@@ -164,4 +164,37 @@ TEST_F(ImmutableSetTest, IterLongSetTest) {
   ASSERT_EQ(6, i);
 }
 
+TEST_F(ImmutableSetTest, AddIfNotFoundTest) {
+  ImmutableSet<long>::Factory f(/*canonicalize=*/false);
+  ImmutableSet<long> S = f.getEmptySet();
+  S = f.add(S, 1);
+  S = f.add(S, 2);
+  S = f.add(S, 3);
+
+  ImmutableSet<long> T1 = f.add(S, 1);
+  ImmutableSet<long> T2 = f.add(S, 2);
+  ImmutableSet<long> T3 = f.add(S, 3);
+  EXPECT_EQ(S.getRoot(), T1.getRoot());
+  EXPECT_EQ(S.getRoot(), T2.getRoot());
+  EXPECT_EQ(S.getRoot(), T3.getRoot());
+
+  ImmutableSet<long> U = f.add(S, 4);
+  EXPECT_NE(S.getRoot(), U.getRoot());
+}
+
+
+TEST_F(ImmutableSetTest, RemoveIfNotFoundTest) {
+  ImmutableSet<long>::Factory f(/*canonicalize=*/false);
+  ImmutableSet<long> S = f.getEmptySet();
+  S = f.add(S, 1);
+  S = f.add(S, 2);
+  S = f.add(S, 3);
+
+  ImmutableSet<long> T = f.remove(S, 4);
+  EXPECT_EQ(S.getRoot(), T.getRoot());
+
+  ImmutableSet<long> U = f.remove(S, 3);
+  EXPECT_NE(S.getRoot(), U.getRoot());
+}
+
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/159845
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to