This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new ef7ccad9aa [CORE] Minor code cleanups for TreeMemoryConsumer (#8254)
ef7ccad9aa is described below
commit ef7ccad9aa2a8fea0728f662dfb6db97a22c9069
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Dec 18 08:43:10 2024 +0800
[CORE] Minor code cleanups for TreeMemoryConsumer (#8254)
---
.../gluten/memory/memtarget/MemoryTargets.java | 43 ++++++----
.../memory/memtarget/RetryOnOomMemoryTarget.java | 31 +++----
.../gluten/memory/memtarget/TreeMemoryTargets.java | 3 +-
.../memtarget/spark/TreeMemoryConsumers.java | 97 +++++++++++-----------
.../org/apache/spark/task/TaskResources.scala | 5 +-
.../memtarget/spark/TreeMemoryConsumerTest.java | 90 ++++++++------------
6 files changed, 128 insertions(+), 141 deletions(-)
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
index c0f74c7990..1997ce61d2 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
@@ -24,10 +24,13 @@ import org.apache.spark.SparkEnv;
import org.apache.spark.annotation.Experimental;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.SparkResourceUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.util.Map;
public final class MemoryTargets {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(MemoryTargets.class);
private MemoryTargets() {
// enclose factory ctor
@@ -45,14 +48,6 @@ public final class MemoryTargets {
return new OverAcquire(target, overTarget, overAcquiredRatio);
}
- public static TreeMemoryTarget retrySpillOnOom(TreeMemoryTarget target) {
- SparkEnv env = SparkEnv.get();
- if (env != null && env.conf() != null &&
SparkResourceUtil.getTaskSlots(env.conf()) > 1) {
- return new RetryOnOomMemoryTarget(target);
- }
- return target;
- }
-
@Experimental
public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget
memoryTarget) {
if (GlutenConfig.getConf().dynamicOffHeapSizingEnabled()) {
@@ -67,14 +62,32 @@ public final class MemoryTargets {
String name,
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
- final TreeMemoryConsumers.Factory factory;
+ final TreeMemoryConsumers.Factory factory =
TreeMemoryConsumers.factory(tmm);
if (GlutenConfig.getConf().memoryIsolation()) {
- return TreeMemoryConsumers.isolated().newConsumer(tmm, name, spiller,
virtualChildren);
- } else {
- // Retry of spilling is needed in shared mode because the
maxMemoryPerTask of Vanilla Spark
- // ExecutionMemoryPool is dynamic when with multi-slot config.
- return MemoryTargets.retrySpillOnOom(
- TreeMemoryConsumers.shared().newConsumer(tmm, name, spiller,
virtualChildren));
+ return factory.newIsolatedConsumer(name, spiller, virtualChildren);
+ }
+ final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller,
virtualChildren);
+ if (SparkEnv.get() == null) {
+ // We are likely in test code. Return the consumer directly.
+ LOGGER.info("SparkEnv not found. We are likely in test code.");
+ return consumer;
+ }
+ final int taskSlots =
SparkResourceUtil.getTaskSlots(SparkEnv.get().conf());
+ if (taskSlots == 1) {
+ // We don't need to retry on OOM when a single task occupies
the whole executor.
+ return consumer;
}
+ // Since https://github.com/apache/incubator-gluten/pull/8132.
+ // Retry of spilling is needed in multi-slot and legacy mode (formerly
named shared mode)
+ // because the maxMemoryPerTask defined by vanilla Spark's
ExecutionMemoryPool is dynamic.
+ //
+ // See the original issue
https://github.com/apache/incubator-gluten/issues/8128.
+ return new RetryOnOomMemoryTarget(
+ consumer,
+ () -> {
+ LOGGER.info("Request for spilling on consumer {}...",
consumer.name());
+ long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE);
+ LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(),
spilled);
+ });
}
}
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
index 1a5388d0d1..b564bbcaa4 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
@@ -27,39 +27,30 @@ import java.util.Map;
public class RetryOnOomMemoryTarget implements TreeMemoryTarget {
private static final Logger LOGGER =
LoggerFactory.getLogger(RetryOnOomMemoryTarget.class);
private final TreeMemoryTarget target;
+ private final Runnable onRetry;
- RetryOnOomMemoryTarget(TreeMemoryTarget target) {
+ RetryOnOomMemoryTarget(TreeMemoryTarget target, Runnable onRetry) {
this.target = target;
+ this.onRetry = onRetry;
}
@Override
public long borrow(long size) {
long granted = target.borrow(size);
if (granted < size) {
- LOGGER.info("Retrying spill require:{} got:{}", size, granted);
- final long spilled = retryingSpill(Long.MAX_VALUE);
+ LOGGER.info("Granted size {} is less than requested size {},
retrying...", granted, size);
final long remaining = size - granted;
- if (spilled >= remaining) {
- granted += target.borrow(remaining);
- }
- LOGGER.info("Retrying spill spilled:{} final granted:{}", spilled,
granted);
+ // Invoke the `onRetry` callback, then retry borrowing.
+ // It's usually expected to run extra spilling logic in
+ // the `onRetry` callback so we may get enough memory space
+ // to allocate the remaining bytes.
+ onRetry.run();
+ granted += target.borrow(remaining);
+ LOGGER.info("Newest granted size after retrying: {}, requested size
{}.", granted, size);
}
return granted;
}
- private long retryingSpill(long size) {
- TreeMemoryTarget rootTarget = target;
- while (true) {
- try {
- rootTarget = rootTarget.parent();
- } catch (IllegalStateException e) {
- // Reached the root node
- break;
- }
- }
- return TreeMemoryTargets.spillTree(rootTarget, size);
- }
-
@Override
public long repay(long size) {
return target.repay(size);
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
index 26c6ea4800..598317a3c4 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
@@ -206,7 +206,8 @@ public class TreeMemoryTargets {
long capacity,
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
- final Node child = new Node(this, name, capacity, spiller,
virtualChildren);
+ final Node child =
+ new Node(this, name, Math.min(this.capacity, capacity), spiller,
virtualChildren);
if (children.containsKey(child.name())) {
throw new IllegalArgumentException("Child already registered: " +
child.name());
}
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
index 7ab05bd3a2..e8bfb5cf75 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
@@ -24,71 +24,74 @@ import org.apache.gluten.memory.memtarget.TreeMemoryTarget;
import org.apache.commons.collections.map.ReferenceMap;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public final class TreeMemoryConsumers {
- private static final Map<Long, Factory> FACTORIES = new
ConcurrentHashMap<>();
+ private static final ReferenceMap FACTORIES = new ReferenceMap();
private TreeMemoryConsumers() {}
- private static Factory createOrGetFactory(long perTaskCapacity) {
- return FACTORIES.computeIfAbsent(perTaskCapacity, Factory::new);
+ @SuppressWarnings("unchecked")
+ public static Factory factory(TaskMemoryManager tmm) {
+ synchronized (FACTORIES) {
+ return (Factory) FACTORIES.computeIfAbsent(tmm, m -> new
Factory((TaskMemoryManager) m));
+ }
}
- /**
- * A hub to provide memory target instances whose shared size (in the same
task) is limited to X,
- * X = executor memory / task slots.
- *
- * <p>Using this to prevent OOMs if the delegated memory target could
possibly hold large memory
- * blocks that are not spillable.
- *
- * <p>See <a
href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
- */
- public static Factory isolated() {
- return
createOrGetFactory(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
- }
+ public static class Factory {
+ private final TreeMemoryConsumer sparkConsumer;
+ private final Map<Long, TreeMemoryTarget> roots = new
ConcurrentHashMap<>();
- /**
- * This works as a legacy Spark memory consumer which grants as much as
possible of memory
- * capacity to each task.
- */
- public static Factory shared() {
- return createOrGetFactory(TreeMemoryTarget.CAPACITY_UNLIMITED);
- }
+ private Factory(TaskMemoryManager tmm) {
+ this.sparkConsumer = new TreeMemoryConsumer(tmm);
+ }
- public static class Factory {
- private final ReferenceMap map = new ReferenceMap(ReferenceMap.WEAK,
ReferenceMap.WEAK);
- private final long perTaskCapacity;
+ private TreeMemoryTarget ofCapacity(long capacity) {
+ return roots.computeIfAbsent(
+ capacity,
+ cap ->
+ sparkConsumer.newChild(
+ String.format("Capacity[%s]", Utils.bytesToString(cap)),
+ cap,
+ Spillers.NOOP,
+ Collections.emptyMap()));
+ }
+
+ private TreeMemoryTarget legacyRoot() {
+ return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED);
+ }
- private Factory(long perTaskCapacity) {
- this.perTaskCapacity = perTaskCapacity;
+ private TreeMemoryTarget isolatedRoot() {
+ return
ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
}
- @SuppressWarnings("unchecked")
- private TreeMemoryTarget getSharedAccount(TaskMemoryManager tmm) {
- synchronized (map) {
- return (TreeMemoryTarget)
- map.computeIfAbsent(
- tmm,
- m -> {
- TreeMemoryTarget tmc = new
TreeMemoryConsumer((TaskMemoryManager) m);
- return tmc.newChild(
- "root", perTaskCapacity, Spillers.NOOP,
Collections.emptyMap());
- });
- }
+ /**
+ * This works as a legacy Spark memory consumer which grants as much as
possible of memory
+ * capacity to each task.
+ */
+ public TreeMemoryTarget newLegacyConsumer(
+ String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder>
virtualChildren) {
+ final TreeMemoryTarget parent = legacyRoot();
+ return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED,
spiller, virtualChildren);
}
- public TreeMemoryTarget newConsumer(
- TaskMemoryManager tmm,
- String name,
- Spiller spiller,
- Map<String, MemoryUsageStatsBuilder> virtualChildren) {
- final TreeMemoryTarget account = getSharedAccount(tmm);
- return account.newChild(
- name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller,
virtualChildren);
+ /**
+ * A hub to provide memory target instances whose shared size (in the same
task) is limited to
+ * X, X = executor memory / task slots.
+ *
+ * <p>Use this to prevent OOMs if the delegated memory target could
possibly hold large memory
+ * blocks that are not spill-able.
+ *
+ * <p>See <a
href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
+ */
+ public TreeMemoryTarget newIsolatedConsumer(
+ String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder>
virtualChildren) {
+ final TreeMemoryTarget parent = isolatedRoot();
+ return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED,
spiller, virtualChildren);
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala
b/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala
index df5917125b..2f609b026d 100644
--- a/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.task
+import org.apache.gluten.GlutenConfig
import org.apache.gluten.task.TaskListener
import org.apache.spark.{TaskContext, TaskFailedReason, TaskKilledException,
UnknownReason}
@@ -65,8 +66,8 @@ object TaskResources extends TaskListener with Logging {
properties.put(key, value)
case _ =>
}
- properties.setIfMissing("spark.memory.offHeap.enabled", "true")
- properties.setIfMissing("spark.memory.offHeap.size", "1TB")
+ properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_ENABLED, "true")
+ properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_SIZE_KEY, "1TB")
TaskContext.setTaskContext(newUnsafeTaskContext(properties))
}
diff --git
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
index befe449186..934300a1ac 100644
---
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
+++
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
@@ -49,13 +49,10 @@ public class TreeMemoryConsumerTest {
public void testIsolated() {
test(
() -> {
- final TreeMemoryConsumers.Factory factory =
TreeMemoryConsumers.isolated();
+ final TreeMemoryConsumers.Factory factory =
+
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager());
final TreeMemoryTarget consumer =
- factory.newConsumer(
- TaskContext.get().taskMemoryManager(),
- "FOO",
- Spillers.NOOP,
- Collections.emptyMap());
+ factory.newIsolatedConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
Assert.assertEquals(20, consumer.borrow(20));
Assert.assertEquals(70, consumer.borrow(70));
Assert.assertEquals(10, consumer.borrow(20));
@@ -64,16 +61,13 @@ public class TreeMemoryConsumerTest {
}
@Test
- public void testShared() {
+ public void testLegacy() {
test(
() -> {
- final TreeMemoryConsumers.Factory factory =
TreeMemoryConsumers.shared();
+ final TreeMemoryConsumers.Factory factory =
+
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager());
final TreeMemoryTarget consumer =
- factory.newConsumer(
- TaskContext.get().taskMemoryManager(),
- "FOO",
- Spillers.NOOP,
- Collections.emptyMap());
+ factory.newLegacyConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
Assert.assertEquals(20, consumer.borrow(20));
Assert.assertEquals(70, consumer.borrow(70));
Assert.assertEquals(20, consumer.borrow(20));
@@ -82,24 +76,16 @@ public class TreeMemoryConsumerTest {
}
@Test
- public void testIsolatedAndShared() {
+ public void testIsolatedAndLegacy() {
test(
() -> {
- final TreeMemoryTarget shared =
- TreeMemoryConsumers.shared()
- .newConsumer(
- TaskContext.get().taskMemoryManager(),
- "FOO",
- Spillers.NOOP,
- Collections.emptyMap());
- Assert.assertEquals(110, shared.borrow(110));
+ final TreeMemoryTarget legacy =
+
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ .newLegacyConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
+ Assert.assertEquals(110, legacy.borrow(110));
final TreeMemoryTarget isolated =
- TreeMemoryConsumers.isolated()
- .newConsumer(
- TaskContext.get().taskMemoryManager(),
- "FOO",
- Spillers.NOOP,
- Collections.emptyMap());
+
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ .newIsolatedConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
Assert.assertEquals(100, isolated.borrow(110));
});
}
@@ -109,36 +95,32 @@ public class TreeMemoryConsumerTest {
test(
() -> {
final Spillers.AppendableSpillerList spillers =
Spillers.appendable();
- final TreeMemoryTarget shared =
- TreeMemoryConsumers.shared()
- .newConsumer(
- TaskContext.get().taskMemoryManager(),
- "FOO",
- spillers,
- Collections.emptyMap());
+ final TreeMemoryTarget legacy =
+
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ .newLegacyConsumer("FOO", spillers, Collections.emptyMap());
final AtomicInteger numSpills = new AtomicInteger(0);
final AtomicLong numSpilledBytes = new AtomicLong(0L);
spillers.append(
new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
- long repaid = shared.repay(size);
+ long repaid = legacy.repay(size);
numSpills.getAndIncrement();
numSpilledBytes.getAndAdd(repaid);
return repaid;
}
});
- Assert.assertEquals(300, shared.borrow(300));
- Assert.assertEquals(300, shared.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
Assert.assertEquals(1, numSpills.get());
Assert.assertEquals(200, numSpilledBytes.get());
- Assert.assertEquals(400, shared.usedBytes());
+ Assert.assertEquals(400, legacy.usedBytes());
- Assert.assertEquals(300, shared.borrow(300));
- Assert.assertEquals(300, shared.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
Assert.assertEquals(3, numSpills.get());
Assert.assertEquals(800, numSpilledBytes.get());
- Assert.assertEquals(400, shared.usedBytes());
+ Assert.assertEquals(400, legacy.usedBytes());
});
}
@@ -147,36 +129,32 @@ public class TreeMemoryConsumerTest {
test(
() -> {
final Spillers.AppendableSpillerList spillers =
Spillers.appendable();
- final TreeMemoryTarget shared =
- TreeMemoryConsumers.shared()
- .newConsumer(
- TaskContext.get().taskMemoryManager(),
- "FOO",
- spillers,
- Collections.emptyMap());
+ final TreeMemoryTarget legacy =
+
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ .newLegacyConsumer("FOO", spillers, Collections.emptyMap());
final AtomicInteger numSpills = new AtomicInteger(0);
final AtomicLong numSpilledBytes = new AtomicLong(0L);
spillers.append(
new Spiller() {
@Override
public long spill(MemoryTarget self, Phase phase, long size) {
- long repaid = shared.repay(Long.MAX_VALUE);
+ long repaid = legacy.repay(Long.MAX_VALUE);
numSpills.getAndIncrement();
numSpilledBytes.getAndAdd(repaid);
return repaid;
}
});
- Assert.assertEquals(300, shared.borrow(300));
- Assert.assertEquals(300, shared.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
Assert.assertEquals(1, numSpills.get());
Assert.assertEquals(300, numSpilledBytes.get());
- Assert.assertEquals(300, shared.usedBytes());
+ Assert.assertEquals(300, legacy.usedBytes());
- Assert.assertEquals(300, shared.borrow(300));
- Assert.assertEquals(300, shared.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
+ Assert.assertEquals(300, legacy.borrow(300));
Assert.assertEquals(3, numSpills.get());
Assert.assertEquals(900, numSpilledBytes.get());
- Assert.assertEquals(300, shared.usedBytes());
+ Assert.assertEquals(300, legacy.usedBytes());
});
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]