This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new a03f6bf7c1 [GLUTEN-11509][VL] Make TreeMemoryConsumer thread-safe
(#11553)
a03f6bf7c1 is described below
commit a03f6bf7c1fe2aa0db029b9bb498398edb534886
Author: Mohammad Linjawi <[email protected]>
AuthorDate: Wed Feb 4 16:39:44 2026 +0300
[GLUTEN-11509][VL] Make TreeMemoryConsumer thread-safe (#11553)
---
.../memory/memtarget/spark/TreeMemoryConsumer.java | 35 +--
.../memtarget/spark/TreeMemoryConsumerTest.java | 318 +++++++++++++++++++++
2 files changed, 330 insertions(+), 23 deletions(-)
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
index f454b55324..0f00e1d669 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
@@ -29,9 +29,8 @@ import org.apache.spark.util.Utils;
import java.io.IOException;
import java.util.Collections;
-import java.util.HashMap;
import java.util.Map;
-import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
@@ -50,7 +49,7 @@ import java.util.stream.Collectors;
public class TreeMemoryConsumer extends MemoryConsumer implements
TreeMemoryTarget {
private final SimpleMemoryUsageRecorder recorder = new
SimpleMemoryUsageRecorder();
- private final Map<String, TreeMemoryTarget> children = new HashMap<>();
+ private final Map<String, TreeMemoryTarget> children = new
ConcurrentHashMap<>();
private final String name = MemoryTargetUtil.toUniqueName("Gluten.Tree");
TreeMemoryConsumer(TaskMemoryManager taskMemoryManager, MemoryMode mode) {
@@ -97,17 +96,10 @@ public class TreeMemoryConsumer extends MemoryConsumer
implements TreeMemoryTarg
@Override
public MemoryUsageStats stats() {
- Set<Map.Entry<String, TreeMemoryTarget>> entries = children.entrySet();
Map<String, MemoryUsageStats> childrenStats =
- entries.stream()
- .collect(Collectors.toMap(e -> e.getValue().name(), e ->
e.getValue().stats()));
-
- Preconditions.checkState(childrenStats.size() == children.size());
- MemoryUsageStats stats = recorder.toStats(childrenStats);
- Preconditions.checkState(
- stats.getCurrent() == getUsed(),
- "Used bytes mismatch between gluten memory consumer and Spark task
memory manager");
- return stats;
+ children.values().stream()
+ .collect(Collectors.toMap(TreeMemoryTarget::name,
TreeMemoryTarget::stats));
+ return recorder.toStats(childrenStats);
}
@Override
@@ -123,10 +115,10 @@ public class TreeMemoryConsumer extends MemoryConsumer
implements TreeMemoryTarg
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final TreeMemoryTarget child = new Node(this, name, capacity, spiller,
virtualChildren);
- if (children.containsKey(child.name())) {
+ TreeMemoryTarget existing = children.putIfAbsent(child.name(), child);
+ if (existing != null) {
throw new IllegalArgumentException("Child already registered: " +
child.name());
}
- children.put(child.name(), child);
return child;
}
@@ -153,7 +145,7 @@ public class TreeMemoryConsumer extends MemoryConsumer
implements TreeMemoryTarg
}
public static class Node implements TreeMemoryTarget, KnownNameAndStats {
- private final Map<String, Node> children = new HashMap<>();
+ private final Map<String, Node> children = new ConcurrentHashMap<>();
private final TreeMemoryTarget parent;
private final String name;
private final long capacity;
@@ -251,11 +243,8 @@ public class TreeMemoryConsumer extends MemoryConsumer
implements TreeMemoryTarg
@Override
public MemoryUsageStats stats() {
final Map<String, MemoryUsageStats> childrenStats =
- new HashMap<>(
- children.entrySet().stream()
- .collect(Collectors.toMap(e -> e.getValue().name(), e ->
e.getValue().stats())));
-
- Preconditions.checkState(childrenStats.size() == children.size());
+ children.values().stream()
+ .collect(Collectors.toMap(TreeMemoryTarget::name,
TreeMemoryTarget::stats));
// add virtual children
for (Map.Entry<String, MemoryUsageStatsBuilder> entry :
virtualChildren.entrySet()) {
@@ -275,10 +264,10 @@ public class TreeMemoryConsumer extends MemoryConsumer
implements TreeMemoryTarg
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final Node child =
new Node(this, name, Math.min(this.capacity, capacity), spiller,
virtualChildren);
- if (children.containsKey(child.name())) {
+ Node existing = children.putIfAbsent(child.name(), child);
+ if (existing != null) {
throw new IllegalArgumentException("Child already registered: " +
child.name());
}
- children.put(child.name(), child);
return child;
}
diff --git
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
index e252b7d7c4..b3b6ad2bd1 100644
---
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
+++
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
@@ -30,8 +30,14 @@ import org.junit.Before;
import org.junit.Test;
import java.util.Collections;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.CyclicBarrier;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
import scala.Function0;
@@ -184,6 +190,318 @@ public class TreeMemoryConsumerTest {
});
}
+ /**
+ * Test concurrent child addition and spilling operations. This test guards against the
+ * ConcurrentModificationException that used to occur when one thread adds children while
+ * another thread is iterating during spilling.
+ */
+ @Test
+ public void testConcurrentAddAndSpill() {
+ test(
+ () -> {
+ final Spillers.AppendableSpillerList spillers =
Spillers.appendable();
+ final TreeMemoryTarget root =
+ TreeMemoryConsumers.factory(MemoryMode.OFF_HEAP).legacyRoot();
+
+ final AtomicInteger spillCount = new AtomicInteger(0);
+ final AtomicReference<Throwable> failure = new AtomicReference<>();
+ final AtomicInteger childrenAdded = new AtomicInteger(0);
+
+ // Create initial child with spiller
+ final TreeMemoryTarget initialChild =
+ root.newChild(
+ "INITIAL", TreeMemoryTarget.CAPACITY_UNLIMITED, spillers,
Collections.emptyMap());
+
+ spillers.append(
+ new Spiller() {
+ @Override
+ public long spill(MemoryTarget self, Phase phase, long size) {
+ spillCount.incrementAndGet();
+ // Simulate spilling by repaying some memory
+ return initialChild.repay(size / 2);
+ }
+ });
+
+ // Allocate some memory to trigger spilling
+ initialChild.borrow(200);
+
+ final int numThreads = 4;
+ final int operationsPerThread = 50;
+ final ExecutorService executor =
Executors.newFixedThreadPool(numThreads);
+ final CyclicBarrier barrier = new CyclicBarrier(numThreads);
+ final CountDownLatch latch = new CountDownLatch(numThreads);
+
+ // Thread 1 & 2: Add children concurrently
+ for (int t = 0; t < 2; t++) {
+ executor.submit(
+ () -> {
+ try {
+ barrier.await(); // Synchronize start
+ for (int i = 0; i < operationsPerThread; i++) {
+ String childName = "CHILD_" +
Thread.currentThread().getId() + "_" + i;
+ root.newChild(
+ childName,
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
+ childrenAdded.incrementAndGet();
+ Thread.sleep(1); // Small delay to increase contention
+ }
+ } catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ failure.compareAndSet(null, e);
+ } finally {
+ latch.countDown();
+ }
+ });
+ }
+
+ // Thread 3 & 4: Trigger spilling and stats collection concurrently
+ for (int t = 0; t < 2; t++) {
+ executor.submit(
+ () -> {
+ try {
+ barrier.await(); // Synchronize start
+ for (int i = 0; i < operationsPerThread; i++) {
+ // Trigger spilling by borrowing memory
+ initialChild.borrow(100);
+ // Also collect stats which iterates over children
+ root.stats();
+ Thread.sleep(1); // Small delay to increase contention
+ }
+ } catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ failure.compareAndSet(null, e);
+ } finally {
+ latch.countDown();
+ }
+ });
+ }
+
+ // Wait for all threads to complete
+ try {
+ Assert.assertTrue(
+ "Threads did not complete in time", latch.await(30,
TimeUnit.SECONDS));
+ executor.shutdown();
+ Assert.assertTrue(
+ "Executor did not terminate", executor.awaitTermination(10,
TimeUnit.SECONDS));
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ Assert.fail("Test interrupted: " + e.getMessage());
+ }
+
+ // Verify no exceptions occurred
+ if (failure.get() != null) {
+ Assert.fail("Test failed due to concurrent modification: " +
failure.get());
+ }
+
+ // Verify children were added
+ Assert.assertEquals(
+ "Expected children to be added", operationsPerThread * 2,
childrenAdded.get());
+
+ // Verify spilling occurred
+ Assert.assertTrue("Expected spilling to occur", spillCount.get() >
0);
+ });
+ }
+
+ /**
+ * Test concurrent stats collection while modifying the tree. This ensures that stats() can be
+ * called safely while children are being added (child removal is not exercised by this test).
+ */
+ @Test
+ public void testConcurrentStatsCollection() {
+ test(
+ () -> {
+ final TreeMemoryTarget root =
+ TreeMemoryConsumers.factory(MemoryMode.OFF_HEAP).legacyRoot();
+ final TreeMemoryTarget warmup =
+ root.newChild(
+ "WARMUP",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
+ warmup.borrow(1);
+
+ final AtomicReference<Throwable> failure = new AtomicReference<>();
+ final AtomicLong totalBytesObserved = new AtomicLong(0);
+
+ final int numThreads = 6;
+ final int operationsPerThread = 100;
+ final ExecutorService executor =
Executors.newFixedThreadPool(numThreads);
+ final CyclicBarrier barrier = new CyclicBarrier(numThreads);
+ final CountDownLatch latch = new CountDownLatch(numThreads);
+
+ // Threads 1-3: Add children and allocate memory
+ for (int t = 0; t < 3; t++) {
+ final int threadId = t;
+ executor.submit(
+ () -> {
+ try {
+ barrier.await();
+ for (int i = 0; i < operationsPerThread; i++) {
+ String childName = "CHILD_" + threadId + "_" + i;
+ TreeMemoryTarget child =
+ root.newChild(
+ childName,
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
+ child.borrow(10);
+ }
+ } catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ failure.compareAndSet(null, e);
+ } finally {
+ latch.countDown();
+ }
+ });
+ }
+
+ // Threads 4-6: Continuously collect stats
+ for (int t = 0; t < 3; t++) {
+ executor.submit(
+ () -> {
+ try {
+ barrier.await();
+ for (int i = 0; i < operationsPerThread * 2; i++) {
+ long used = root.stats().getCurrent();
+ totalBytesObserved.addAndGet(used);
+ }
+ } catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ failure.compareAndSet(null, e);
+ } finally {
+ latch.countDown();
+ }
+ });
+ }
+
+ try {
+ Assert.assertTrue(
+ "Threads did not complete in time", latch.await(30,
TimeUnit.SECONDS));
+ executor.shutdown();
+ Assert.assertTrue(
+ "Executor did not terminate", executor.awaitTermination(10,
TimeUnit.SECONDS));
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ Assert.fail("Test interrupted: " + e.getMessage());
+ }
+
+ if (failure.get() != null) {
+ Assert.fail("Test failed due to concurrent modification: " +
failure.get());
+ }
+ Assert.assertTrue("Expected to observe memory usage",
totalBytesObserved.get() > 0);
+ });
+ }
+
+ /**
+ * Stress test with high contention on multiple operations. This test
hammers the
+ * TreeMemoryConsumer with concurrent operations to expose race conditions.
+ */
+ @Test
+ public void testHighContention() {
+ test(
+ () -> {
+ final Spillers.AppendableSpillerList spillers =
Spillers.appendable();
+ final TreeMemoryTarget root =
+ TreeMemoryConsumers.factory(MemoryMode.OFF_HEAP).legacyRoot();
+
+ final AtomicInteger spillCount = new AtomicInteger(0);
+ final AtomicReference<Throwable> failure = new AtomicReference<>();
+
+ // Create a child with spiller
+ final TreeMemoryTarget spillableChild =
+ root.newChild(
+ "SPILLABLE",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ spillers,
+ Collections.emptyMap());
+
+ spillers.append(
+ new Spiller() {
+ @Override
+ public long spill(MemoryTarget self, Phase phase, long size) {
+ spillCount.incrementAndGet();
+ return spillableChild.repay(size / 2);
+ }
+ });
+
+ spillableChild.borrow(150);
+
+ final int numThreads = 8;
+ final int operationsPerThread = 100;
+ final ExecutorService executor =
Executors.newFixedThreadPool(numThreads);
+ final CyclicBarrier barrier = new CyclicBarrier(numThreads);
+ final CountDownLatch latch = new CountDownLatch(numThreads);
+
+ for (int t = 0; t < numThreads; t++) {
+ final int threadId = t;
+ executor.submit(
+ () -> {
+ try {
+ barrier.await();
+ for (int i = 0; i < operationsPerThread; i++) {
+ // Mix of operations
+ switch (i % 4) {
+ case 0:
+ // Add child
+ root.newChild(
+ "CHILD_" + threadId + "_" + i,
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
+ break;
+ case 1:
+ // Trigger spilling
+ spillableChild.borrow(50);
+ break;
+ case 2:
+ // Collect stats
+ root.stats();
+ break;
+ case 3:
+ // Repay memory
+ spillableChild.repay(10);
+ break;
+ }
+ }
+ } catch (Exception e) {
+ if (e instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ }
+ failure.compareAndSet(null, e);
+ } finally {
+ latch.countDown();
+ }
+ });
+ }
+
+ try {
+ Assert.assertTrue(
+ "Threads did not complete in time", latch.await(60,
TimeUnit.SECONDS));
+ executor.shutdown();
+ Assert.assertTrue(
+ "Executor did not terminate", executor.awaitTermination(10,
TimeUnit.SECONDS));
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ Assert.fail("Test interrupted: " + e.getMessage());
+ }
+
+ if (failure.get() != null) {
+ Assert.fail("Test failed due to concurrent access issues: " +
failure.get());
+ }
+ Assert.assertTrue("Expected spilling to occur", spillCount.get() >
0);
+ });
+ }
+
private void test(Runnable r) {
TaskResources$.MODULE$.runUnsafe(
new Function0<Object>() {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]