This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 44289300f5 [VL] Add API for reserving global off-heap memory from 
Spark (#9066)
44289300f5 is described below

commit 44289300f5aa01d62d1a53d11360488c3c41fbe9
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Mar 20 15:00:03 2025 +0000

    [VL] Add API for reserving global off-heap memory from Spark (#9066)
---
 .../apache/spark/memory/GlobalOffHeapMemory.scala  |  57 ++++++++++
 .../spark/memory/GlobalOffHeapMemorySuite.scala    | 121 +++++++++++++++++++++
 .../apache/spark/storage/BlockManagerUtils.scala   |  49 +++++++++
 .../scala/org/apache/spark/TaskContextUtils.scala  |   2 +
 .../scala/org/apache/spark/TaskContextUtils.scala  |   2 +
 .../scala/org/apache/spark/TaskContextUtils.scala  |   2 +
 .../scala/org/apache/spark/TaskContextUtils.scala  |   2 +
 7 files changed, 235 insertions(+)

diff --git 
a/backends-velox/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
 
b/backends-velox/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
new file mode 100644
index 0000000000..cfaf797869
--- /dev/null
+++ 
b/backends-velox/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.memory
+
+import org.apache.gluten.exception.GlutenException
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.storage.BlockId
+
+import java.lang.reflect.Field
+import java.util.UUID
+
+object GlobalOffHeapMemory {
+  private val FIELD_MEMORY_MANAGER: Field = {
+    val f = classOf[TaskMemoryManager].getDeclaredField("memoryManager")
+    f.setAccessible(true)
+    f
+  }
+
+  def acquire(numBytes: Long): Boolean = {
+    memoryManager().acquireStorageMemory(
+      BlockId(s"test_${UUID.randomUUID()}"),
+      numBytes,
+      MemoryMode.OFF_HEAP)
+  }
+
+  def free(numBytes: Long): Unit = {
+    memoryManager().releaseStorageMemory(numBytes, MemoryMode.OFF_HEAP)
+  }
+
+  private def memoryManager(): MemoryManager = {
+    val env = SparkEnv.get
+    if (env != null) {
+      return env.memoryManager
+    }
+    val tc = TaskContext.get()
+    if (tc != null) {
+      return 
FIELD_MEMORY_MANAGER.get(tc.taskMemoryManager()).asInstanceOf[MemoryManager]
+    }
+    throw new GlutenException(
+      "Memory manager not found because the code is unlikely to be run in a Spark 
application")
+  }
+}
diff --git 
a/backends-velox/src/test/scala/org/apache/spark/memory/GlobalOffHeapMemorySuite.scala
 
b/backends-velox/src/test/scala/org/apache/spark/memory/GlobalOffHeapMemorySuite.scala
new file mode 100644
index 0000000000..2a25d09d3b
--- /dev/null
+++ 
b/backends-velox/src/test/scala/org/apache/spark/memory/GlobalOffHeapMemorySuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.memory;
+
+import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.memory.memtarget.{Spillers, TreeMemoryTarget}
+import org.apache.gluten.memory.memtarget.spark.TreeMemoryConsumers
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.task.TaskResources
+
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util.Collections;
+
+class GlobalOffHeapMemorySuite extends AnyFunSuite with BeforeAndAfterAll {
+
+  override protected def beforeAll(): Unit = {
+    val conf = SQLConf.get
+    conf.setConfString("spark.memory.offHeap.enabled", "true")
+    conf.setConfString("spark.memory.offHeap.size", "400")
+    
conf.setConfString(GlutenConfig.COLUMNAR_CONSERVATIVE_TASK_OFFHEAP_SIZE_IN_BYTES.key,
 "100")
+  }
+
+  test("Sanity") {
+    TaskResources.runUnsafe {
+      val factory =
+        TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+      val consumer =
+        factory
+          .legacyRoot()
+          .newChild(
+            "FOO",
+            TreeMemoryTarget.CAPACITY_UNLIMITED,
+            Spillers.NOOP,
+            Collections.emptyMap())
+      Assert.assertEquals(300, consumer.borrow(300))
+      Assert.assertTrue(GlobalOffHeapMemory.acquire(50))
+      Assert.assertTrue(GlobalOffHeapMemory.acquire(40))
+      Assert.assertFalse(GlobalOffHeapMemory.acquire(30))
+      Assert.assertFalse(GlobalOffHeapMemory.acquire(11))
+      Assert.assertTrue(GlobalOffHeapMemory.acquire(10))
+      Assert.assertTrue(GlobalOffHeapMemory.acquire(0))
+      Assert.assertFalse(GlobalOffHeapMemory.acquire(1))
+    }
+  }
+
+  test("Task OOM by global occupation") {
+    TaskResources.runUnsafe {
+      val factory =
+        TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+      val consumer =
+        factory
+          .legacyRoot()
+          .newChild(
+            "FOO",
+            TreeMemoryTarget.CAPACITY_UNLIMITED,
+            Spillers.NOOP,
+            Collections.emptyMap())
+      Assert.assertTrue(GlobalOffHeapMemory.acquire(200))
+      Assert.assertEquals(100, consumer.borrow(100))
+      Assert.assertEquals(100, consumer.borrow(200))
+      Assert.assertEquals(0, consumer.borrow(50))
+    }
+  }
+
+  test("Release global") {
+    TaskResources.runUnsafe {
+      val factory =
+        TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+      val consumer =
+        factory
+          .legacyRoot()
+          .newChild(
+            "FOO",
+            TreeMemoryTarget.CAPACITY_UNLIMITED,
+            Spillers.NOOP,
+            Collections.emptyMap())
+      Assert.assertTrue(GlobalOffHeapMemory.acquire(300))
+      Assert.assertEquals(100, consumer.borrow(200))
+      GlobalOffHeapMemory.free(10)
+      Assert.assertEquals(10, consumer.borrow(50))
+    }
+  }
+
+  test("Release task") {
+    TaskResources.runUnsafe {
+      val factory =
+        TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+      val consumer =
+        factory
+          .legacyRoot()
+          .newChild(
+            "FOO",
+            TreeMemoryTarget.CAPACITY_UNLIMITED,
+            Spillers.NOOP,
+            Collections.emptyMap())
+      Assert.assertEquals(300, consumer.borrow(300))
+      Assert.assertFalse(GlobalOffHeapMemory.acquire(200))
+      Assert.assertEquals(100, consumer.repay(100))
+      Assert.assertTrue(GlobalOffHeapMemory.acquire(200))
+    }
+  }
+}
diff --git 
a/shims/common/src/main/scala/org/apache/spark/storage/BlockManagerUtils.scala 
b/shims/common/src/main/scala/org/apache/spark/storage/BlockManagerUtils.scala
new file mode 100644
index 0000000000..5a68d15d2a
--- /dev/null
+++ 
b/shims/common/src/main/scala/org/apache/spark/storage/BlockManagerUtils.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.SERIALIZER
+import org.apache.spark.memory.MemoryManager
+import org.apache.spark.serializer.{Serializer, SerializerManager}
+import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBuffer
+
+import scala.reflect.ClassTag
+
+object BlockManagerUtils {
+  def setTestMemoryStore(conf: SparkConf, memoryManager: MemoryManager, 
isDriver: Boolean): Unit = {
+    val store = new MemoryStore(
+      conf,
+      new BlockInfoManager,
+      new SerializerManager(
+        Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, 
isDriver),
+        conf),
+      memoryManager,
+      new BlockEvictionHandler {
+        override private[storage] def dropFromMemory[T: ClassTag](
+            blockId: BlockId,
+            data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
+          throw new UnsupportedOperationException(
+            s"Cannot drop block ID $blockId from test memory store")
+        }
+      }
+    )
+    memoryManager.setMemoryStore(store)
+  }
+}
diff --git 
a/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala 
b/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala
index ac7f926f1f..8ff7717fb9 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
 
 import java.util.Properties
 
@@ -29,6 +30,7 @@ object TaskContextUtils {
     val conf = new SparkConf()
     conf.setAll(properties.asScala)
     val memoryManager = UnifiedMemoryManager(conf, 1)
+    BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
     new TaskContextImpl(
       -1,
       -1,
diff --git 
a/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala 
b/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala
index c4fea992d5..058467888f 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
 
 import java.util.Properties
 
@@ -29,6 +30,7 @@ object TaskContextUtils {
     val conf = new SparkConf()
     conf.setAll(properties.asScala)
     val memoryManager = UnifiedMemoryManager(conf, 1)
+    BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
     new TaskContextImpl(
       -1,
       -1,
diff --git 
a/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala 
b/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala
index 7a81b61211..267b177920 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
 
 import java.util.Properties
 
@@ -29,6 +30,7 @@ object TaskContextUtils {
     val conf = new SparkConf()
     conf.setAll(properties.asScala)
     val memoryManager = UnifiedMemoryManager(conf, 1)
+    BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
     new TaskContextImpl(
       -1,
       -1,
diff --git 
a/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala 
b/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
index 7a81b61211..267b177920 100644
--- a/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
 
 import java.util.Properties
 
@@ -29,6 +30,7 @@ object TaskContextUtils {
     val conf = new SparkConf()
     conf.setAll(properties.asScala)
     val memoryManager = UnifiedMemoryManager(conf, 1)
+    BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
     new TaskContextImpl(
       -1,
       -1,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to