This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new a03f6bf7c1 [GLUTEN-11509][VL] Make TreeMemoryConsumer thread-safe (#11553)
a03f6bf7c1 is described below

commit a03f6bf7c1fe2aa0db029b9bb498398edb534886
Author: Mohammad Linjawi <[email protected]>
AuthorDate: Wed Feb 4 16:39:44 2026 +0300

    [GLUTEN-11509][VL] Make TreeMemoryConsumer thread-safe (#11553)
---
 .../memory/memtarget/spark/TreeMemoryConsumer.java |  35 +--
 .../memtarget/spark/TreeMemoryConsumerTest.java    | 318 +++++++++++++++++++++
 2 files changed, 330 insertions(+), 23 deletions(-)

diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
index f454b55324..0f00e1d669 100644
--- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
+++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
@@ -29,9 +29,8 @@ import org.apache.spark.util.Utils;
 
 import java.io.IOException;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
-import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 
 /**
@@ -50,7 +49,7 @@ import java.util.stream.Collectors;
 public class TreeMemoryConsumer extends MemoryConsumer implements 
TreeMemoryTarget {
 
   private final SimpleMemoryUsageRecorder recorder = new 
SimpleMemoryUsageRecorder();
-  private final Map<String, TreeMemoryTarget> children = new HashMap<>();
+  private final Map<String, TreeMemoryTarget> children = new 
ConcurrentHashMap<>();
   private final String name = MemoryTargetUtil.toUniqueName("Gluten.Tree");
 
   TreeMemoryConsumer(TaskMemoryManager taskMemoryManager, MemoryMode mode) {
@@ -97,17 +96,10 @@ public class TreeMemoryConsumer extends MemoryConsumer 
implements TreeMemoryTarg
 
   @Override
   public MemoryUsageStats stats() {
-    Set<Map.Entry<String, TreeMemoryTarget>> entries = children.entrySet();
     Map<String, MemoryUsageStats> childrenStats =
-        entries.stream()
-            .collect(Collectors.toMap(e -> e.getValue().name(), e -> 
e.getValue().stats()));
-
-    Preconditions.checkState(childrenStats.size() == children.size());
-    MemoryUsageStats stats = recorder.toStats(childrenStats);
-    Preconditions.checkState(
-        stats.getCurrent() == getUsed(),
-        "Used bytes mismatch between gluten memory consumer and Spark task 
memory manager");
-    return stats;
+        children.values().stream()
+            .collect(Collectors.toMap(TreeMemoryTarget::name, 
TreeMemoryTarget::stats));
+    return recorder.toStats(childrenStats);
   }
 
   @Override
@@ -123,10 +115,10 @@ public class TreeMemoryConsumer extends MemoryConsumer 
implements TreeMemoryTarg
       Spiller spiller,
       Map<String, MemoryUsageStatsBuilder> virtualChildren) {
     final TreeMemoryTarget child = new Node(this, name, capacity, spiller, 
virtualChildren);
-    if (children.containsKey(child.name())) {
+    TreeMemoryTarget existing = children.putIfAbsent(child.name(), child);
+    if (existing != null) {
       throw new IllegalArgumentException("Child already registered: " + 
child.name());
     }
-    children.put(child.name(), child);
     return child;
   }
 
@@ -153,7 +145,7 @@ public class TreeMemoryConsumer extends MemoryConsumer 
implements TreeMemoryTarg
   }
 
   public static class Node implements TreeMemoryTarget, KnownNameAndStats {
-    private final Map<String, Node> children = new HashMap<>();
+    private final Map<String, Node> children = new ConcurrentHashMap<>();
     private final TreeMemoryTarget parent;
     private final String name;
     private final long capacity;
@@ -251,11 +243,8 @@ public class TreeMemoryConsumer extends MemoryConsumer 
implements TreeMemoryTarg
     @Override
     public MemoryUsageStats stats() {
       final Map<String, MemoryUsageStats> childrenStats =
-          new HashMap<>(
-              children.entrySet().stream()
-                  .collect(Collectors.toMap(e -> e.getValue().name(), e -> 
e.getValue().stats())));
-
-      Preconditions.checkState(childrenStats.size() == children.size());
+          children.values().stream()
+              .collect(Collectors.toMap(TreeMemoryTarget::name, 
TreeMemoryTarget::stats));
 
       // add virtual children
       for (Map.Entry<String, MemoryUsageStatsBuilder> entry : 
virtualChildren.entrySet()) {
@@ -275,10 +264,10 @@ public class TreeMemoryConsumer extends MemoryConsumer 
implements TreeMemoryTarg
         Map<String, MemoryUsageStatsBuilder> virtualChildren) {
       final Node child =
           new Node(this, name, Math.min(this.capacity, capacity), spiller, 
virtualChildren);
-      if (children.containsKey(child.name())) {
+      Node existing = children.putIfAbsent(child.name(), child);
+      if (existing != null) {
         throw new IllegalArgumentException("Child already registered: " + 
child.name());
       }
-      children.put(child.name(), child);
       return child;
     }
 
diff --git a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
index e252b7d7c4..b3b6ad2bd1 100644
--- a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
+++ b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
@@ -30,8 +30,14 @@ import org.junit.Before;
 import org.junit.Test;
 
 import java.util.Collections;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 
 import scala.Function0;
 
@@ -184,6 +190,318 @@ public class TreeMemoryConsumerTest {
         });
   }
 
+  /**
+   * Test concurrent child addition and spilling operations. This test 
reproduces the
+   * ConcurrentModificationException that occurs when one thread adds children 
while another thread
+   * is iterating during spilling.
+   */
+  @Test
+  public void testConcurrentAddAndSpill() {
+    test(
+        () -> {
+          final Spillers.AppendableSpillerList spillers = 
Spillers.appendable();
+          final TreeMemoryTarget root =
+              TreeMemoryConsumers.factory(MemoryMode.OFF_HEAP).legacyRoot();
+
+          final AtomicInteger spillCount = new AtomicInteger(0);
+          final AtomicReference<Throwable> failure = new AtomicReference<>();
+          final AtomicInteger childrenAdded = new AtomicInteger(0);
+
+          // Create initial child with spiller
+          final TreeMemoryTarget initialChild =
+              root.newChild(
+                  "INITIAL", TreeMemoryTarget.CAPACITY_UNLIMITED, spillers, 
Collections.emptyMap());
+
+          spillers.append(
+              new Spiller() {
+                @Override
+                public long spill(MemoryTarget self, Phase phase, long size) {
+                  spillCount.incrementAndGet();
+                  // Simulate spilling by repaying some memory
+                  return initialChild.repay(size / 2);
+                }
+              });
+
+          // Allocate some memory to trigger spilling
+          initialChild.borrow(200);
+
+          final int numThreads = 4;
+          final int operationsPerThread = 50;
+          final ExecutorService executor = 
Executors.newFixedThreadPool(numThreads);
+          final CyclicBarrier barrier = new CyclicBarrier(numThreads);
+          final CountDownLatch latch = new CountDownLatch(numThreads);
+
+          // Threads 1 & 2: Add children concurrently
+          for (int t = 0; t < 2; t++) {
+            executor.submit(
+                () -> {
+                  try {
+                    barrier.await(); // Synchronize start
+                    for (int i = 0; i < operationsPerThread; i++) {
+                      String childName = "CHILD_" + 
Thread.currentThread().getId() + "_" + i;
+                      root.newChild(
+                          childName,
+                          TreeMemoryTarget.CAPACITY_UNLIMITED,
+                          Spillers.NOOP,
+                          Collections.emptyMap());
+                      childrenAdded.incrementAndGet();
+                      Thread.sleep(1); // Small delay to increase contention
+                    }
+                  } catch (Exception e) {
+                    if (e instanceof InterruptedException) {
+                      Thread.currentThread().interrupt();
+                    }
+                    failure.compareAndSet(null, e);
+                  } finally {
+                    latch.countDown();
+                  }
+                });
+          }
+
+          // Threads 3 & 4: Trigger spilling and stats collection concurrently
+          for (int t = 0; t < 2; t++) {
+            executor.submit(
+                () -> {
+                  try {
+                    barrier.await(); // Synchronize start
+                    for (int i = 0; i < operationsPerThread; i++) {
+                      // Trigger spilling by borrowing memory
+                      initialChild.borrow(100);
+                      // Also collect stats which iterates over children
+                      root.stats();
+                      Thread.sleep(1); // Small delay to increase contention
+                    }
+                  } catch (Exception e) {
+                    if (e instanceof InterruptedException) {
+                      Thread.currentThread().interrupt();
+                    }
+                    failure.compareAndSet(null, e);
+                  } finally {
+                    latch.countDown();
+                  }
+                });
+          }
+
+          // Wait for all threads to complete
+          try {
+            Assert.assertTrue(
+                "Threads did not complete in time", latch.await(30, 
TimeUnit.SECONDS));
+            executor.shutdown();
+            Assert.assertTrue(
+                "Executor did not terminate", executor.awaitTermination(10, 
TimeUnit.SECONDS));
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            Assert.fail("Test interrupted: " + e.getMessage());
+          }
+
+          // Verify no exceptions occurred
+          if (failure.get() != null) {
+            Assert.fail("Test failed due to concurrent modification: " + 
failure.get());
+          }
+
+          // Verify children were added
+          Assert.assertEquals(
+              "Expected children to be added", operationsPerThread * 2, 
childrenAdded.get());
+
+          // Verify spilling occurred
+          Assert.assertTrue("Expected spilling to occur", spillCount.get() > 
0);
+        });
+  }
+
+  /**
+   * Test concurrent stats collection while modifying the tree. This ensures 
that stats() can be
+   * called safely while children are being added and are borrowing memory.
+   */
+  @Test
+  public void testConcurrentStatsCollection() {
+    test(
+        () -> {
+          final TreeMemoryTarget root =
+              TreeMemoryConsumers.factory(MemoryMode.OFF_HEAP).legacyRoot();
+          final TreeMemoryTarget warmup =
+              root.newChild(
+                  "WARMUP",
+                  TreeMemoryTarget.CAPACITY_UNLIMITED,
+                  Spillers.NOOP,
+                  Collections.emptyMap());
+          warmup.borrow(1);
+
+          final AtomicReference<Throwable> failure = new AtomicReference<>();
+          final AtomicLong totalBytesObserved = new AtomicLong(0);
+
+          final int numThreads = 6;
+          final int operationsPerThread = 100;
+          final ExecutorService executor = 
Executors.newFixedThreadPool(numThreads);
+          final CyclicBarrier barrier = new CyclicBarrier(numThreads);
+          final CountDownLatch latch = new CountDownLatch(numThreads);
+
+          // Threads 1-3: Add children and allocate memory
+          for (int t = 0; t < 3; t++) {
+            final int threadId = t;
+            executor.submit(
+                () -> {
+                  try {
+                    barrier.await();
+                    for (int i = 0; i < operationsPerThread; i++) {
+                      String childName = "CHILD_" + threadId + "_" + i;
+                      TreeMemoryTarget child =
+                          root.newChild(
+                              childName,
+                              TreeMemoryTarget.CAPACITY_UNLIMITED,
+                              Spillers.NOOP,
+                              Collections.emptyMap());
+                      child.borrow(10);
+                    }
+                  } catch (Exception e) {
+                    if (e instanceof InterruptedException) {
+                      Thread.currentThread().interrupt();
+                    }
+                    failure.compareAndSet(null, e);
+                  } finally {
+                    latch.countDown();
+                  }
+                });
+          }
+
+          // Threads 4-6: Continuously collect stats
+          for (int t = 0; t < 3; t++) {
+            executor.submit(
+                () -> {
+                  try {
+                    barrier.await();
+                    for (int i = 0; i < operationsPerThread * 2; i++) {
+                      long used = root.stats().getCurrent();
+                      totalBytesObserved.addAndGet(used);
+                    }
+                  } catch (Exception e) {
+                    if (e instanceof InterruptedException) {
+                      Thread.currentThread().interrupt();
+                    }
+                    failure.compareAndSet(null, e);
+                  } finally {
+                    latch.countDown();
+                  }
+                });
+          }
+
+          try {
+            Assert.assertTrue(
+                "Threads did not complete in time", latch.await(30, 
TimeUnit.SECONDS));
+            executor.shutdown();
+            Assert.assertTrue(
+                "Executor did not terminate", executor.awaitTermination(10, 
TimeUnit.SECONDS));
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            Assert.fail("Test interrupted: " + e.getMessage());
+          }
+
+          if (failure.get() != null) {
+            Assert.fail("Test failed due to concurrent modification: " + 
failure.get());
+          }
+          Assert.assertTrue("Expected to observe memory usage", 
totalBytesObserved.get() > 0);
+        });
+  }
+
+  /**
+   * Stress test with high contention on multiple operations. This test 
hammers the
+   * TreeMemoryConsumer with concurrent operations to expose race conditions.
+   */
+  @Test
+  public void testHighContention() {
+    test(
+        () -> {
+          final Spillers.AppendableSpillerList spillers = 
Spillers.appendable();
+          final TreeMemoryTarget root =
+              TreeMemoryConsumers.factory(MemoryMode.OFF_HEAP).legacyRoot();
+
+          final AtomicInteger spillCount = new AtomicInteger(0);
+          final AtomicReference<Throwable> failure = new AtomicReference<>();
+
+          // Create a child with spiller
+          final TreeMemoryTarget spillableChild =
+              root.newChild(
+                  "SPILLABLE",
+                  TreeMemoryTarget.CAPACITY_UNLIMITED,
+                  spillers,
+                  Collections.emptyMap());
+
+          spillers.append(
+              new Spiller() {
+                @Override
+                public long spill(MemoryTarget self, Phase phase, long size) {
+                  spillCount.incrementAndGet();
+                  return spillableChild.repay(size / 2);
+                }
+              });
+
+          spillableChild.borrow(150);
+
+          final int numThreads = 8;
+          final int operationsPerThread = 100;
+          final ExecutorService executor = 
Executors.newFixedThreadPool(numThreads);
+          final CyclicBarrier barrier = new CyclicBarrier(numThreads);
+          final CountDownLatch latch = new CountDownLatch(numThreads);
+
+          for (int t = 0; t < numThreads; t++) {
+            final int threadId = t;
+            executor.submit(
+                () -> {
+                  try {
+                    barrier.await();
+                    for (int i = 0; i < operationsPerThread; i++) {
+                      // Mix of operations
+                      switch (i % 4) {
+                        case 0:
+                          // Add child
+                          root.newChild(
+                              "CHILD_" + threadId + "_" + i,
+                              TreeMemoryTarget.CAPACITY_UNLIMITED,
+                              Spillers.NOOP,
+                              Collections.emptyMap());
+                          break;
+                        case 1:
+                          // Trigger spilling
+                          spillableChild.borrow(50);
+                          break;
+                        case 2:
+                          // Collect stats
+                          root.stats();
+                          break;
+                        case 3:
+                          // Repay memory
+                          spillableChild.repay(10);
+                          break;
+                      }
+                    }
+                  } catch (Exception e) {
+                    if (e instanceof InterruptedException) {
+                      Thread.currentThread().interrupt();
+                    }
+                    failure.compareAndSet(null, e);
+                  } finally {
+                    latch.countDown();
+                  }
+                });
+          }
+
+          try {
+            Assert.assertTrue(
+                "Threads did not complete in time", latch.await(60, 
TimeUnit.SECONDS));
+            executor.shutdown();
+            Assert.assertTrue(
+                "Executor did not terminate", executor.awaitTermination(10, 
TimeUnit.SECONDS));
+          } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            Assert.fail("Test interrupted: " + e.getMessage());
+          }
+
+          if (failure.get() != null) {
+            Assert.fail("Test failed due to concurrent access issues: " + 
failure.get());
+          }
+          Assert.assertTrue("Expected spilling to occur", spillCount.get() > 
0);
+        });
+  }
+
   private void test(Runnable r) {
     TaskResources$.MODULE$.runUnsafe(
         new Function0<Object>() {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to