This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 44289300f5 [VL] Add API for reserving global off-heap memory from Spark (#9066)
44289300f5 is described below
commit 44289300f5aa01d62d1a53d11360488c3c41fbe9
Author: Hongze Zhang <[email protected]>
AuthorDate: Thu Mar 20 15:00:03 2025 +0000
[VL] Add API for reserving global off-heap memory from Spark (#9066)
---
.../apache/spark/memory/GlobalOffHeapMemory.scala | 57 ++++++++++
.../spark/memory/GlobalOffHeapMemorySuite.scala | 121 +++++++++++++++++++++
.../apache/spark/storage/BlockManagerUtils.scala | 49 +++++++++
.../scala/org/apache/spark/TaskContextUtils.scala | 2 +
.../scala/org/apache/spark/TaskContextUtils.scala | 2 +
.../scala/org/apache/spark/TaskContextUtils.scala | 2 +
.../scala/org/apache/spark/TaskContextUtils.scala | 2 +
7 files changed, 235 insertions(+)
diff --git a/backends-velox/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala b/backends-velox/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
new file mode 100644
index 0000000000..cfaf797869
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/memory/GlobalOffHeapMemory.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.memory
+
+import org.apache.gluten.exception.GlutenException
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.storage.BlockId
+
+import java.lang.reflect.Field
+import java.util.UUID
+
+object GlobalOffHeapMemory {
+ private val FIELD_MEMORY_MANAGER: Field = {
+ val f = classOf[TaskMemoryManager].getDeclaredField("memoryManager")
+ f.setAccessible(true)
+ f
+ }
+
+ def acquire(numBytes: Long): Boolean = {
+ memoryManager().acquireStorageMemory(
+ BlockId(s"test_${UUID.randomUUID()}"),
+ numBytes,
+ MemoryMode.OFF_HEAP)
+ }
+
+ def free(numBytes: Long): Unit = {
+ memoryManager().releaseStorageMemory(numBytes, MemoryMode.OFF_HEAP)
+ }
+
+ private def memoryManager(): MemoryManager = {
+ val env = SparkEnv.get
+ if (env != null) {
+ return env.memoryManager
+ }
+ val tc = TaskContext.get()
+ if (tc != null) {
+      return FIELD_MEMORY_MANAGER.get(tc.taskMemoryManager()).asInstanceOf[MemoryManager]
+ }
+    throw new GlutenException(
+      "Memory manager not found because the code is unlikely be run in a Spark application")
+ }
+}
diff --git a/backends-velox/src/test/scala/org/apache/spark/memory/GlobalOffHeapMemorySuite.scala b/backends-velox/src/test/scala/org/apache/spark/memory/GlobalOffHeapMemorySuite.scala
new file mode 100644
index 0000000000..2a25d09d3b
--- /dev/null
+++ b/backends-velox/src/test/scala/org/apache/spark/memory/GlobalOffHeapMemorySuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.memory;
+
+import org.apache.gluten.config.GlutenConfig
+import org.apache.gluten.memory.memtarget.{Spillers, TreeMemoryTarget}
+import org.apache.gluten.memory.memtarget.spark.TreeMemoryConsumers
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.task.TaskResources
+
+import org.junit.Assert
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.util.Collections;
+
+class GlobalOffHeapMemorySuite extends AnyFunSuite with BeforeAndAfterAll {
+
+ override protected def beforeAll(): Unit = {
+ val conf = SQLConf.get
+ conf.setConfString("spark.memory.offHeap.enabled", "true")
+ conf.setConfString("spark.memory.offHeap.size", "400")
+    conf.setConfString(GlutenConfig.COLUMNAR_CONSERVATIVE_TASK_OFFHEAP_SIZE_IN_BYTES.key, "100")
+ }
+
+ test("Sanity") {
+ TaskResources.runUnsafe {
+ val factory =
+ TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ val consumer =
+ factory
+ .legacyRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap())
+ Assert.assertEquals(300, consumer.borrow(300))
+ Assert.assertTrue(GlobalOffHeapMemory.acquire(50))
+ Assert.assertTrue(GlobalOffHeapMemory.acquire(40))
+ Assert.assertFalse(GlobalOffHeapMemory.acquire(30))
+ Assert.assertFalse(GlobalOffHeapMemory.acquire(11))
+ Assert.assertTrue(GlobalOffHeapMemory.acquire(10))
+ Assert.assertTrue(GlobalOffHeapMemory.acquire(0))
+ Assert.assertFalse(GlobalOffHeapMemory.acquire(1))
+ }
+ }
+
+ test("Task OOM by global occupation") {
+ TaskResources.runUnsafe {
+ val factory =
+ TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ val consumer =
+ factory
+ .legacyRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap())
+ Assert.assertTrue(GlobalOffHeapMemory.acquire(200))
+ Assert.assertEquals(100, consumer.borrow(100))
+ Assert.assertEquals(100, consumer.borrow(200))
+ Assert.assertEquals(0, consumer.borrow(50))
+ }
+ }
+
+ test("Release global") {
+ TaskResources.runUnsafe {
+ val factory =
+ TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ val consumer =
+ factory
+ .legacyRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap())
+ Assert.assertTrue(GlobalOffHeapMemory.acquire(300))
+ Assert.assertEquals(100, consumer.borrow(200))
+ GlobalOffHeapMemory.free(10)
+ Assert.assertEquals(10, consumer.borrow(50))
+ }
+ }
+
+ test("Release task") {
+ TaskResources.runUnsafe {
+ val factory =
+ TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
+ val consumer =
+ factory
+ .legacyRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap())
+ Assert.assertEquals(300, consumer.borrow(300))
+ Assert.assertFalse(GlobalOffHeapMemory.acquire(200))
+ Assert.assertEquals(100, consumer.repay(100))
+ Assert.assertTrue(GlobalOffHeapMemory.acquire(200))
+ }
+ }
+}
diff --git a/shims/common/src/main/scala/org/apache/spark/storage/BlockManagerUtils.scala b/shims/common/src/main/scala/org/apache/spark/storage/BlockManagerUtils.scala
new file mode 100644
index 0000000000..5a68d15d2a
--- /dev/null
+++ b/shims/common/src/main/scala/org/apache/spark/storage/BlockManagerUtils.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.SERIALIZER
+import org.apache.spark.memory.MemoryManager
+import org.apache.spark.serializer.{Serializer, SerializerManager}
+import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBuffer
+
+import scala.reflect.ClassTag
+
+object BlockManagerUtils {
+  def setTestMemoryStore(conf: SparkConf, memoryManager: MemoryManager, isDriver: Boolean): Unit = {
+ val store = new MemoryStore(
+ conf,
+ new BlockInfoManager,
+ new SerializerManager(
+        Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver),
+ conf),
+ memoryManager,
+ new BlockEvictionHandler {
+ override private[storage] def dropFromMemory[T: ClassTag](
+ blockId: BlockId,
+ data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = {
+ throw new UnsupportedOperationException(
+ s"Cannot drop block ID $blockId from test memory store")
+ }
+ }
+ )
+ memoryManager.setMemoryStore(store)
+ }
+}
diff --git a/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala b/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala
index ac7f926f1f..8ff7717fb9 100644
--- a/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark32/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
import java.util.Properties
@@ -29,6 +30,7 @@ object TaskContextUtils {
val conf = new SparkConf()
conf.setAll(properties.asScala)
val memoryManager = UnifiedMemoryManager(conf, 1)
+ BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
new TaskContextImpl(
-1,
-1,
diff --git a/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala b/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala
index c4fea992d5..058467888f 100644
--- a/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark33/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
import java.util.Properties
@@ -29,6 +30,7 @@ object TaskContextUtils {
val conf = new SparkConf()
conf.setAll(properties.asScala)
val memoryManager = UnifiedMemoryManager(conf, 1)
+ BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
new TaskContextImpl(
-1,
-1,
diff --git a/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala b/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala
index 7a81b61211..267b177920 100644
--- a/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark34/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
import java.util.Properties
@@ -29,6 +30,7 @@ object TaskContextUtils {
val conf = new SparkConf()
conf.setAll(properties.asScala)
val memoryManager = UnifiedMemoryManager(conf, 1)
+ BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
new TaskContextImpl(
-1,
-1,
diff --git a/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala b/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
index 7a81b61211..267b177920 100644
--- a/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
+++ b/shims/spark35/src/main/scala/org/apache/spark/TaskContextUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.storage.BlockManagerUtils
import java.util.Properties
@@ -29,6 +30,7 @@ object TaskContextUtils {
val conf = new SparkConf()
conf.setAll(properties.asScala)
val memoryManager = UnifiedMemoryManager(conf, 1)
+ BlockManagerUtils.setTestMemoryStore(conf, memoryManager, isDriver = false)
new TaskContextImpl(
-1,
-1,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]