This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new ef7ccad9aa [CORE] Minor code cleanups for TreeMemoryConsumer (#8254)
ef7ccad9aa is described below

commit ef7ccad9aa2a8fea0728f662dfb6db97a22c9069
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Dec 18 08:43:10 2024 +0800

    [CORE] Minor code cleanups for TreeMemoryConsumer (#8254)
---
 .../gluten/memory/memtarget/MemoryTargets.java     | 43 ++++++----
 .../memory/memtarget/RetryOnOomMemoryTarget.java   | 31 +++----
 .../gluten/memory/memtarget/TreeMemoryTargets.java |  3 +-
 .../memtarget/spark/TreeMemoryConsumers.java       | 97 +++++++++++-----------
 .../org/apache/spark/task/TaskResources.scala      |  5 +-
 .../memtarget/spark/TreeMemoryConsumerTest.java    | 90 ++++++++------------
 6 files changed, 128 insertions(+), 141 deletions(-)

diff --git 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
index c0f74c7990..1997ce61d2 100644
--- 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
+++ 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
@@ -24,10 +24,13 @@ import org.apache.spark.SparkEnv;
 import org.apache.spark.annotation.Experimental;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.util.SparkResourceUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Map;
 
 public final class MemoryTargets {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(MemoryTargets.class);
 
   private MemoryTargets() {
     // enclose factory ctor
@@ -45,14 +48,6 @@ public final class MemoryTargets {
     return new OverAcquire(target, overTarget, overAcquiredRatio);
   }
 
-  public static TreeMemoryTarget retrySpillOnOom(TreeMemoryTarget target) {
-    SparkEnv env = SparkEnv.get();
-    if (env != null && env.conf() != null && 
SparkResourceUtil.getTaskSlots(env.conf()) > 1) {
-      return new RetryOnOomMemoryTarget(target);
-    }
-    return target;
-  }
-
   @Experimental
   public static MemoryTarget dynamicOffHeapSizingIfEnabled(MemoryTarget 
memoryTarget) {
     if (GlutenConfig.getConf().dynamicOffHeapSizingEnabled()) {
@@ -67,14 +62,32 @@ public final class MemoryTargets {
       String name,
       Spiller spiller,
       Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-    final TreeMemoryConsumers.Factory factory;
+    final TreeMemoryConsumers.Factory factory = 
TreeMemoryConsumers.factory(tmm);
     if (GlutenConfig.getConf().memoryIsolation()) {
-      return TreeMemoryConsumers.isolated().newConsumer(tmm, name, spiller, 
virtualChildren);
-    } else {
-      // Retry of spilling is needed in shared mode because the 
maxMemoryPerTask of Vanilla Spark
-      // ExecutionMemoryPool is dynamic when with multi-slot config.
-      return MemoryTargets.retrySpillOnOom(
-          TreeMemoryConsumers.shared().newConsumer(tmm, name, spiller, 
virtualChildren));
+      return factory.newIsolatedConsumer(name, spiller, virtualChildren);
+    }
+    final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller, 
virtualChildren);
+    if (SparkEnv.get() == null) {
+      // We are likely in test code. Return the consumer directly.
+      LOGGER.info("SparkEnv not found. We are likely in test code.");
+      return consumer;
+    }
+    final int taskSlots = 
SparkResourceUtil.getTaskSlots(SparkEnv.get().conf());
+    if (taskSlots == 1) {
+      // We don't need to retry on OOM in the case one single task occupies 
the whole executor.
+      return consumer;
     }
+    // Since https://github.com/apache/incubator-gluten/pull/8132.
+    // Retry of spilling is needed in multi-slot and legacy mode (formerly 
named as share mode)
+    // because the maxMemoryPerTask defined by vanilla Spark's 
ExecutionMemoryPool is dynamic.
+    //
+    // See the original issue 
https://github.com/apache/incubator-gluten/issues/8128.
+    return new RetryOnOomMemoryTarget(
+        consumer,
+        () -> {
+          LOGGER.info("Request for spilling on consumer {}...", 
consumer.name());
+          long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE);
+          LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(), 
spilled);
+        });
   }
 }
diff --git 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
index 1a5388d0d1..b564bbcaa4 100644
--- 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
+++ 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/RetryOnOomMemoryTarget.java
@@ -27,39 +27,30 @@ import java.util.Map;
 public class RetryOnOomMemoryTarget implements TreeMemoryTarget {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(RetryOnOomMemoryTarget.class);
   private final TreeMemoryTarget target;
+  private final Runnable onRetry;
 
-  RetryOnOomMemoryTarget(TreeMemoryTarget target) {
+  RetryOnOomMemoryTarget(TreeMemoryTarget target, Runnable onRetry) {
     this.target = target;
+    this.onRetry = onRetry;
   }
 
   @Override
   public long borrow(long size) {
     long granted = target.borrow(size);
     if (granted < size) {
-      LOGGER.info("Retrying spill require:{} got:{}", size, granted);
-      final long spilled = retryingSpill(Long.MAX_VALUE);
+      LOGGER.info("Granted size {} is less than requested size {}, 
retrying...", granted, size);
       final long remaining = size - granted;
-      if (spilled >= remaining) {
-        granted += target.borrow(remaining);
-      }
-      LOGGER.info("Retrying spill spilled:{} final granted:{}", spilled, 
granted);
+      // Invoke the `onRetry` callback, then retry borrowing.
+      // It's usually expected to run extra spilling logics in
+      // the `onRetry` callback so we may get enough memory space
+      // to allocate the remaining bytes.
+      onRetry.run();
+      granted += target.borrow(remaining);
+      LOGGER.info("Newest granted size after retrying: {}, requested size 
{}.", granted, size);
     }
     return granted;
   }
 
-  private long retryingSpill(long size) {
-    TreeMemoryTarget rootTarget = target;
-    while (true) {
-      try {
-        rootTarget = rootTarget.parent();
-      } catch (IllegalStateException e) {
-        // Reached the root node
-        break;
-      }
-    }
-    return TreeMemoryTargets.spillTree(rootTarget, size);
-  }
-
   @Override
   public long repay(long size) {
     return target.repay(size);
diff --git 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
index 26c6ea4800..598317a3c4 100644
--- 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
+++ 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
@@ -206,7 +206,8 @@ public class TreeMemoryTargets {
         long capacity,
         Spiller spiller,
         Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-      final Node child = new Node(this, name, capacity, spiller, 
virtualChildren);
+      final Node child =
+          new Node(this, name, Math.min(this.capacity, capacity), spiller, 
virtualChildren);
       if (children.containsKey(child.name())) {
         throw new IllegalArgumentException("Child already registered: " + 
child.name());
       }
diff --git 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
index 7ab05bd3a2..e8bfb5cf75 100644
--- 
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
+++ 
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
@@ -24,71 +24,74 @@ import org.apache.gluten.memory.memtarget.TreeMemoryTarget;
 
 import org.apache.commons.collections.map.ReferenceMap;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
 
 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
 public final class TreeMemoryConsumers {
-  private static final Map<Long, Factory> FACTORIES = new 
ConcurrentHashMap<>();
+  private static final ReferenceMap FACTORIES = new ReferenceMap();
 
   private TreeMemoryConsumers() {}
 
-  private static Factory createOrGetFactory(long perTaskCapacity) {
-    return FACTORIES.computeIfAbsent(perTaskCapacity, Factory::new);
+  @SuppressWarnings("unchecked")
+  public static Factory factory(TaskMemoryManager tmm) {
+    synchronized (FACTORIES) {
+      return (Factory) FACTORIES.computeIfAbsent(tmm, m -> new 
Factory((TaskMemoryManager) m));
+    }
   }
 
-  /**
-   * A hub to provide memory target instances whose shared size (in the same 
task) is limited to X,
-   * X = executor memory / task slots.
-   *
-   * <p>Using this to prevent OOMs if the delegated memory target could 
possibly hold large memory
-   * blocks that are not spillable.
-   *
-   * <p>See <a href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
-   */
-  public static Factory isolated() {
-    return 
createOrGetFactory(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
-  }
+  public static class Factory {
+    private final TreeMemoryConsumer sparkConsumer;
+    private final Map<Long, TreeMemoryTarget> roots = new 
ConcurrentHashMap<>();
 
-  /**
-   * This works as a legacy Spark memory consumer which grants as much as 
possible of memory
-   * capacity to each task.
-   */
-  public static Factory shared() {
-    return createOrGetFactory(TreeMemoryTarget.CAPACITY_UNLIMITED);
-  }
+    private Factory(TaskMemoryManager tmm) {
+      this.sparkConsumer = new TreeMemoryConsumer(tmm);
+    }
 
-  public static class Factory {
-    private final ReferenceMap map = new ReferenceMap(ReferenceMap.WEAK, 
ReferenceMap.WEAK);
-    private final long perTaskCapacity;
+    private TreeMemoryTarget ofCapacity(long capacity) {
+      return roots.computeIfAbsent(
+          capacity,
+          cap ->
+              sparkConsumer.newChild(
+                  String.format("Capacity[%s]", Utils.bytesToString(cap)),
+                  cap,
+                  Spillers.NOOP,
+                  Collections.emptyMap()));
+    }
+
+    private TreeMemoryTarget legacyRoot() {
+      return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED);
+    }
 
-    private Factory(long perTaskCapacity) {
-      this.perTaskCapacity = perTaskCapacity;
+    private TreeMemoryTarget isolatedRoot() {
+      return 
ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
     }
 
-    @SuppressWarnings("unchecked")
-    private TreeMemoryTarget getSharedAccount(TaskMemoryManager tmm) {
-      synchronized (map) {
-        return (TreeMemoryTarget)
-            map.computeIfAbsent(
-                tmm,
-                m -> {
-                  TreeMemoryTarget tmc = new 
TreeMemoryConsumer((TaskMemoryManager) m);
-                  return tmc.newChild(
-                      "root", perTaskCapacity, Spillers.NOOP, 
Collections.emptyMap());
-                });
-      }
+    /**
+     * This works as a legacy Spark memory consumer which grants as much as 
possible of memory
+     * capacity to each task.
+     */
+    public TreeMemoryTarget newLegacyConsumer(
+        String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> 
virtualChildren) {
+      final TreeMemoryTarget parent = legacyRoot();
+      return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, 
spiller, virtualChildren);
     }
 
-    public TreeMemoryTarget newConsumer(
-        TaskMemoryManager tmm,
-        String name,
-        Spiller spiller,
-        Map<String, MemoryUsageStatsBuilder> virtualChildren) {
-      final TreeMemoryTarget account = getSharedAccount(tmm);
-      return account.newChild(
-          name, TreeMemoryConsumer.CAPACITY_UNLIMITED, spiller, 
virtualChildren);
+    /**
+     * A hub to provide memory target instances whose shared size (in the same 
task) is limited to
+     * X, X = executor memory / task slots.
+     *
+     * <p>Using this to prevent OOMs if the delegated memory target could 
possibly hold large memory
+     * blocks that are not spill-able.
+     *
+     * <p>See <a href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
+     */
+    public TreeMemoryTarget newIsolatedConsumer(
+        String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder> 
virtualChildren) {
+      final TreeMemoryTarget parent = isolatedRoot();
+      return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED, 
spiller, virtualChildren);
     }
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala 
b/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala
index df5917125b..2f609b026d 100644
--- a/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/task/TaskResources.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.task
 
+import org.apache.gluten.GlutenConfig
 import org.apache.gluten.task.TaskListener
 
 import org.apache.spark.{TaskContext, TaskFailedReason, TaskKilledException, 
UnknownReason}
@@ -65,8 +66,8 @@ object TaskResources extends TaskListener with Logging {
         properties.put(key, value)
       case _ =>
     }
-    properties.setIfMissing("spark.memory.offHeap.enabled", "true")
-    properties.setIfMissing("spark.memory.offHeap.size", "1TB")
+    properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_ENABLED, "true")
+    properties.setIfMissing(GlutenConfig.SPARK_OFFHEAP_SIZE_KEY, "1TB")
     TaskContext.setTaskContext(newUnsafeTaskContext(properties))
   }
 
diff --git 
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
 
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
index befe449186..934300a1ac 100644
--- 
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
+++ 
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
@@ -49,13 +49,10 @@ public class TreeMemoryConsumerTest {
   public void testIsolated() {
     test(
         () -> {
-          final TreeMemoryConsumers.Factory factory = 
TreeMemoryConsumers.isolated();
+          final TreeMemoryConsumers.Factory factory =
+              
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager());
           final TreeMemoryTarget consumer =
-              factory.newConsumer(
-                  TaskContext.get().taskMemoryManager(),
-                  "FOO",
-                  Spillers.NOOP,
-                  Collections.emptyMap());
+              factory.newIsolatedConsumer("FOO", Spillers.NOOP, 
Collections.emptyMap());
           Assert.assertEquals(20, consumer.borrow(20));
           Assert.assertEquals(70, consumer.borrow(70));
           Assert.assertEquals(10, consumer.borrow(20));
@@ -64,16 +61,13 @@ public class TreeMemoryConsumerTest {
   }
 
   @Test
-  public void testShared() {
+  public void testLegacy() {
     test(
         () -> {
-          final TreeMemoryConsumers.Factory factory = 
TreeMemoryConsumers.shared();
+          final TreeMemoryConsumers.Factory factory =
+              
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager());
           final TreeMemoryTarget consumer =
-              factory.newConsumer(
-                  TaskContext.get().taskMemoryManager(),
-                  "FOO",
-                  Spillers.NOOP,
-                  Collections.emptyMap());
+              factory.newLegacyConsumer("FOO", Spillers.NOOP, 
Collections.emptyMap());
           Assert.assertEquals(20, consumer.borrow(20));
           Assert.assertEquals(70, consumer.borrow(70));
           Assert.assertEquals(20, consumer.borrow(20));
@@ -82,24 +76,16 @@ public class TreeMemoryConsumerTest {
   }
 
   @Test
-  public void testIsolatedAndShared() {
+  public void testIsolatedAndLegacy() {
     test(
         () -> {
-          final TreeMemoryTarget shared =
-              TreeMemoryConsumers.shared()
-                  .newConsumer(
-                      TaskContext.get().taskMemoryManager(),
-                      "FOO",
-                      Spillers.NOOP,
-                      Collections.emptyMap());
-          Assert.assertEquals(110, shared.borrow(110));
+          final TreeMemoryTarget legacy =
+              
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+                  .newLegacyConsumer("FOO", Spillers.NOOP, 
Collections.emptyMap());
+          Assert.assertEquals(110, legacy.borrow(110));
           final TreeMemoryTarget isolated =
-              TreeMemoryConsumers.isolated()
-                  .newConsumer(
-                      TaskContext.get().taskMemoryManager(),
-                      "FOO",
-                      Spillers.NOOP,
-                      Collections.emptyMap());
+              
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+                  .newIsolatedConsumer("FOO", Spillers.NOOP, 
Collections.emptyMap());
           Assert.assertEquals(100, isolated.borrow(110));
         });
   }
@@ -109,36 +95,32 @@ public class TreeMemoryConsumerTest {
     test(
         () -> {
           final Spillers.AppendableSpillerList spillers = 
Spillers.appendable();
-          final TreeMemoryTarget shared =
-              TreeMemoryConsumers.shared()
-                  .newConsumer(
-                      TaskContext.get().taskMemoryManager(),
-                      "FOO",
-                      spillers,
-                      Collections.emptyMap());
+          final TreeMemoryTarget legacy =
+              
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+                  .newLegacyConsumer("FOO", spillers, Collections.emptyMap());
           final AtomicInteger numSpills = new AtomicInteger(0);
           final AtomicLong numSpilledBytes = new AtomicLong(0L);
           spillers.append(
               new Spiller() {
                 @Override
                 public long spill(MemoryTarget self, Phase phase, long size) {
-                  long repaid = shared.repay(size);
+                  long repaid = legacy.repay(size);
                   numSpills.getAndIncrement();
                   numSpilledBytes.getAndAdd(repaid);
                   return repaid;
                 }
               });
-          Assert.assertEquals(300, shared.borrow(300));
-          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
           Assert.assertEquals(1, numSpills.get());
           Assert.assertEquals(200, numSpilledBytes.get());
-          Assert.assertEquals(400, shared.usedBytes());
+          Assert.assertEquals(400, legacy.usedBytes());
 
-          Assert.assertEquals(300, shared.borrow(300));
-          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
           Assert.assertEquals(3, numSpills.get());
           Assert.assertEquals(800, numSpilledBytes.get());
-          Assert.assertEquals(400, shared.usedBytes());
+          Assert.assertEquals(400, legacy.usedBytes());
         });
   }
 
@@ -147,36 +129,32 @@ public class TreeMemoryConsumerTest {
     test(
         () -> {
           final Spillers.AppendableSpillerList spillers = 
Spillers.appendable();
-          final TreeMemoryTarget shared =
-              TreeMemoryConsumers.shared()
-                  .newConsumer(
-                      TaskContext.get().taskMemoryManager(),
-                      "FOO",
-                      spillers,
-                      Collections.emptyMap());
+          final TreeMemoryTarget legacy =
+              
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+                  .newLegacyConsumer("FOO", spillers, Collections.emptyMap());
           final AtomicInteger numSpills = new AtomicInteger(0);
           final AtomicLong numSpilledBytes = new AtomicLong(0L);
           spillers.append(
               new Spiller() {
                 @Override
                 public long spill(MemoryTarget self, Phase phase, long size) {
-                  long repaid = shared.repay(Long.MAX_VALUE);
+                  long repaid = legacy.repay(Long.MAX_VALUE);
                   numSpills.getAndIncrement();
                   numSpilledBytes.getAndAdd(repaid);
                   return repaid;
                 }
               });
-          Assert.assertEquals(300, shared.borrow(300));
-          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
           Assert.assertEquals(1, numSpills.get());
           Assert.assertEquals(300, numSpilledBytes.get());
-          Assert.assertEquals(300, shared.usedBytes());
+          Assert.assertEquals(300, legacy.usedBytes());
 
-          Assert.assertEquals(300, shared.borrow(300));
-          Assert.assertEquals(300, shared.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
+          Assert.assertEquals(300, legacy.borrow(300));
           Assert.assertEquals(3, numSpills.get());
           Assert.assertEquals(900, numSpilledBytes.get());
-          Assert.assertEquals(300, shared.usedBytes());
+          Assert.assertEquals(300, legacy.usedBytes());
         });
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to