This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 9320aedc8 feat: Add a `spark.comet.exec.memoryPool` configuration for experimenting with various datafusion memory pool setups. (#1021)
9320aedc8 is described below

commit 9320aedc8df2e8f7e5acb42ecdc44f33dff5d592
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Fri Jan 3 14:33:48 2025 +0800

    feat: Add a `spark.comet.exec.memoryPool` configuration for experimenting with various datafusion memory pool setups. (#1021)
---
 .../main/scala/org/apache/comet/CometConf.scala    |   9 +
 docs/source/user-guide/configs.md                  |   1 +
 native/core/src/execution/jni_api.rs               | 211 +++++++++++++++++++--
 .../scala/org/apache/comet/CometExecIterator.scala |  29 ++-
 spark/src/main/scala/org/apache/comet/Native.scala |   3 +
 5 files changed, 231 insertions(+), 22 deletions(-)
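
For anyone who wants to try the new option, here is a minimal, illustrative sketch of selecting a pool type from the Spark side. The config key and the allowed values come from the diff below; the session builder and application name are hypothetical. Note that, per parse_memory_pool_config in jni_api.rs, the setting only applies when spark.memory.offHeap.enabled is false; with the unified memory manager the Comet memory pool is used regardless of this value.

    import org.apache.spark.sql.SparkSession

    // Illustrative driver snippet: pick one of the pool types added by this commit.
    // Allowed values per the new doc string: greedy, fair_spill, greedy_task_shared,
    // fair_spill_task_shared, greedy_global, fair_spill_global (default: greedy_task_shared).
    val spark = SparkSession
      .builder()
      .appName("comet-memory-pool-example") // hypothetical application name
      .config("spark.comet.exec.memoryPool", "fair_spill")
      .getOrCreate()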

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 8815ac4eb..2fff0a04c 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -467,6 +467,15 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_EXEC_MEMORY_POOL_TYPE: ConfigEntry[String] = conf("spark.comet.exec.memoryPool")
+    .doc(
+      "The type of memory pool to be used for Comet native execution. " +
+        "Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', " +
+        "'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. By default, " +
+        "this config is 'greedy_task_shared'.")
+    .stringConf
+    .createWithDefault("greedy_task_shared")
+
   val COMET_SCAN_PREFETCH_ENABLED: ConfigEntry[Boolean] =
     conf("spark.comet.scan.preFetch.enabled")
       .doc("Whether to enable pre-fetching feature of CometScan.")
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md
index 7881f0763..ecea70254 100644
--- a/docs/source/user-guide/configs.md
+++ b/docs/source/user-guide/configs.md
@@ -48,6 +48,7 @@ Comet provides the following configuration settings.
 | spark.comet.exec.hashJoin.enabled | Whether to enable hashJoin by default. | true |
 | spark.comet.exec.localLimit.enabled | Whether to enable localLimit by default. | true |
 | spark.comet.exec.memoryFraction | The fraction of memory from Comet memory overhead that the native memory manager can use for execution. The purpose of this config is to set aside memory for untracked data structures, as well as imprecise size estimation during memory acquisition. | 0.7 |
+| spark.comet.exec.memoryPool | The type of memory pool to be used for Comet native execution. Available memory pool types are 'greedy', 'fair_spill', 'greedy_task_shared', 'fair_spill_task_shared', 'greedy_global' and 'fair_spill_global'. By default, this config is 'greedy_task_shared'. | greedy_task_shared |
 | spark.comet.exec.project.enabled | Whether to enable project by default. | true |
 | spark.comet.exec.replaceSortMergeJoin | Experimental feature to force Spark to replace SortMergeJoin with ShuffledHashJoin for improved performance. This feature is not stable yet. For more information, refer to the Comet Tuning Guide (https://datafusion.apache.org/comet/user-guide/tuning.html). | false |
 | spark.comet.exec.shuffle.compression.codec | The codec of Comet native shuffle used to compress shuffle data. Only zstd is supported. Compression can be disabled by setting spark.shuffle.compress=false. | zstd |
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 09caf5e27..b1190d905 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -24,6 +24,9 @@ use datafusion::{
     physical_plan::{display::DisplayableExecutionPlan, SendableRecordBatchStream},
     prelude::{SessionConfig, SessionContext},
 };
+use datafusion_execution::memory_pool::{
+    FairSpillPool, GreedyMemoryPool, MemoryPool, TrackConsumersPool,
+};
 use futures::poll;
 use jni::{
     errors::Result as JNIResult,
@@ -51,20 +54,26 @@ use datafusion_comet_proto::spark_operator::Operator;
 use datafusion_common::ScalarValue;
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use futures::stream::StreamExt;
+use jni::sys::JNI_FALSE;
 use jni::{
     objects::GlobalRef,
     sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
 };
+use std::num::NonZeroUsize;
+use std::sync::Mutex;
 use tokio::runtime::Runtime;
 
 use crate::execution::operators::ScanExec;
 use crate::execution::spark_plan::SparkPlan;
 use log::info;
+use once_cell::sync::{Lazy, OnceCell};
 
 /// Comet native execution context. Kept alive across JNI calls.
 struct ExecutionContext {
     /// The id of the execution context.
     pub id: i64,
+    /// Task attempt id
+    pub task_attempt_id: i64,
     /// The deserialized Spark plan
     pub spark_plan: Operator,
     /// The DataFusion root operator converted from the `spark_plan`
@@ -89,6 +98,51 @@ struct ExecutionContext {
     pub explain_native: bool,
     /// Map of metrics name -> jstring object to cache jni_NewStringUTF calls.
     pub metrics_jstrings: HashMap<String, Arc<GlobalRef>>,
+    /// Memory pool config
+    pub memory_pool_config: MemoryPoolConfig,
+}
+
+#[derive(PartialEq, Eq)]
+enum MemoryPoolType {
+    Unified,
+    Greedy,
+    FairSpill,
+    GreedyTaskShared,
+    FairSpillTaskShared,
+    GreedyGlobal,
+    FairSpillGlobal,
+}
+
+struct MemoryPoolConfig {
+    pool_type: MemoryPoolType,
+    pool_size: usize,
+}
+
+impl MemoryPoolConfig {
+    fn new(pool_type: MemoryPoolType, pool_size: usize) -> Self {
+        Self {
+            pool_type,
+            pool_size,
+        }
+    }
+}
+
+/// The per-task memory pools keyed by task attempt id.
+static TASK_SHARED_MEMORY_POOLS: Lazy<Mutex<HashMap<i64, PerTaskMemoryPool>>> =
+    Lazy::new(|| Mutex::new(HashMap::new()));
+
+struct PerTaskMemoryPool {
+    memory_pool: Arc<dyn MemoryPool>,
+    num_plans: usize,
+}
+
+impl PerTaskMemoryPool {
+    fn new(memory_pool: Arc<dyn MemoryPool>) -> Self {
+        Self {
+            memory_pool,
+            num_plans: 0,
+        }
+    }
 }
 
 /// Accept serialized query plan and return the address of the native query plan.
@@ -105,8 +159,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     comet_task_memory_manager_obj: JObject,
     batch_size: jint,
     use_unified_memory_manager: jboolean,
+    memory_pool_type: jstring,
     memory_limit: jlong,
+    memory_limit_per_task: jlong,
     memory_fraction: jdouble,
+    task_attempt_id: jlong,
     debug_native: jboolean,
     explain_native: jboolean,
     worker_threads: jint,
@@ -145,21 +202,27 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         let task_memory_manager =
             Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);
 
+        let memory_pool_type = env.get_string(&JString::from_raw(memory_pool_type))?.into();
+        let memory_pool_config = parse_memory_pool_config(
+            use_unified_memory_manager != JNI_FALSE,
+            memory_pool_type,
+            memory_limit,
+            memory_limit_per_task,
+            memory_fraction,
+        )?;
+        let memory_pool =
+            create_memory_pool(&memory_pool_config, task_memory_manager, task_attempt_id);
+
         // We need to keep the session context alive. Some session state like temporary
         // dictionaries are stored in session context. If it is dropped, the temporary
         // dictionaries will be dropped as well.
-        let session = prepare_datafusion_session_context(
-            batch_size as usize,
-            use_unified_memory_manager == 1,
-            memory_limit as usize,
-            memory_fraction,
-            task_memory_manager,
-        )?;
+        let session = prepare_datafusion_session_context(batch_size as usize, memory_pool)?;
 
         let plan_creation_time = start.elapsed();
 
         let exec_context = Box::new(ExecutionContext {
             id,
+            task_attempt_id,
             spark_plan,
             root_op: None,
             scans: vec![],
@@ -172,6 +235,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             debug_native: debug_native == 1,
             explain_native: explain_native == 1,
             metrics_jstrings: HashMap::new(),
+            memory_pool_config,
         });
 
         Ok(Box::into_raw(exec_context) as i64)
@@ -181,22 +245,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
 /// Configure DataFusion session context.
 fn prepare_datafusion_session_context(
     batch_size: usize,
-    use_unified_memory_manager: bool,
-    memory_limit: usize,
-    memory_fraction: f64,
-    comet_task_memory_manager: Arc<GlobalRef>,
+    memory_pool: Arc<dyn MemoryPool>,
 ) -> CometResult<SessionContext> {
     let mut rt_config = RuntimeEnvBuilder::new().with_disk_manager(DiskManagerConfig::NewOs);
-
-    // Check if we are using unified memory manager integrated with Spark.
-    if use_unified_memory_manager {
-        // Set Comet memory pool for native
-        let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
-        rt_config = rt_config.with_memory_pool(Arc::new(memory_pool));
-    } else {
-        // Use the memory pool from DF
-        rt_config = rt_config.with_memory_limit(memory_limit, memory_fraction)
-    }
+    rt_config = rt_config.with_memory_pool(memory_pool);
 
     // Get Datafusion configuration from Spark Execution context
     // can be configured in Comet Spark JVM using Spark --conf parameters
@@ -224,6 +276,107 @@ fn prepare_datafusion_session_context(
     Ok(session_ctx)
 }
 
+fn parse_memory_pool_config(
+    use_unified_memory_manager: bool,
+    memory_pool_type: String,
+    memory_limit: i64,
+    memory_limit_per_task: i64,
+    memory_fraction: f64,
+) -> CometResult<MemoryPoolConfig> {
+    let memory_pool_config = if use_unified_memory_manager {
+        MemoryPoolConfig::new(MemoryPoolType::Unified, 0)
+    } else {
+        // Use the memory pool from DF
+        let pool_size = (memory_limit as f64 * memory_fraction) as usize;
+        let pool_size_per_task = (memory_limit_per_task as f64 * memory_fraction) as usize;
+        match memory_pool_type.as_str() {
+            "fair_spill_task_shared" => {
+                MemoryPoolConfig::new(MemoryPoolType::FairSpillTaskShared, pool_size_per_task)
+            }
+            "greedy_task_shared" => {
+                MemoryPoolConfig::new(MemoryPoolType::GreedyTaskShared, pool_size_per_task)
+            }
+            "fair_spill_global" => {
+                MemoryPoolConfig::new(MemoryPoolType::FairSpillGlobal, pool_size)
+            }
+            "greedy_global" => MemoryPoolConfig::new(MemoryPoolType::GreedyGlobal, pool_size),
+            "fair_spill" => MemoryPoolConfig::new(MemoryPoolType::FairSpill, pool_size_per_task),
+            "greedy" => MemoryPoolConfig::new(MemoryPoolType::Greedy, pool_size_per_task),
+            _ => {
+                return Err(CometError::Config(format!(
+                    "Unsupported memory pool type: {}",
+                    memory_pool_type
+                )))
+            }
+        }
+    };
+    Ok(memory_pool_config)
+}
+
+fn create_memory_pool(
+    memory_pool_config: &MemoryPoolConfig,
+    comet_task_memory_manager: Arc<GlobalRef>,
+    task_attempt_id: i64,
+) -> Arc<dyn MemoryPool> {
+    const NUM_TRACKED_CONSUMERS: usize = 10;
+    match memory_pool_config.pool_type {
+        MemoryPoolType::Unified => {
+            // Set Comet memory pool for native
+            let memory_pool = CometMemoryPool::new(comet_task_memory_manager);
+            Arc::new(memory_pool)
+        }
+        MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
+            GreedyMemoryPool::new(memory_pool_config.pool_size),
+            NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+        )),
+        MemoryPoolType::FairSpill => Arc::new(TrackConsumersPool::new(
+            FairSpillPool::new(memory_pool_config.pool_size),
+            NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+        )),
+        MemoryPoolType::GreedyGlobal => {
+            static GLOBAL_MEMORY_POOL_GREEDY: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
+            let memory_pool = GLOBAL_MEMORY_POOL_GREEDY.get_or_init(|| {
+                Arc::new(TrackConsumersPool::new(
+                    GreedyMemoryPool::new(memory_pool_config.pool_size),
+                    NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+                ))
+            });
+            Arc::clone(memory_pool)
+        }
+        MemoryPoolType::FairSpillGlobal => {
+            static GLOBAL_MEMORY_POOL_FAIR: OnceCell<Arc<dyn MemoryPool>> = OnceCell::new();
+            let memory_pool = GLOBAL_MEMORY_POOL_FAIR.get_or_init(|| {
+                Arc::new(TrackConsumersPool::new(
+                    FairSpillPool::new(memory_pool_config.pool_size),
+                    NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+                ))
+            });
+            Arc::clone(memory_pool)
+        }
+        MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared => {
+            let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
+            let per_task_memory_pool =
+                memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
+                    let pool: Arc<dyn MemoryPool> =
+                        if memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared {
+                            Arc::new(TrackConsumersPool::new(
+                                GreedyMemoryPool::new(memory_pool_config.pool_size),
+                                NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+                            ))
+                        } else {
+                            Arc::new(TrackConsumersPool::new(
+                                FairSpillPool::new(memory_pool_config.pool_size),
+                                NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
+                            ))
+                        };
+                    PerTaskMemoryPool::new(pool)
+                });
+            per_task_memory_pool.num_plans += 1;
+            Arc::clone(&per_task_memory_pool.memory_pool)
+        }
+    }
+}
+
 /// Prepares arrow arrays for output.
 fn prepare_output(
     env: &mut JNIEnv,
@@ -407,6 +560,22 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
 ) {
     try_unwrap_or_throw(&e, |_| unsafe {
         let execution_context = get_execution_context(exec_context);
+        if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared
+            || execution_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared
+        {
+            // Decrement the number of native plans using the per-task shared memory pool, and
+            // remove the memory pool if the released native plan is the last native plan using it.
+            let task_attempt_id = execution_context.task_attempt_id;
+            let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
+            if let Some(per_task_memory_pool) = memory_pool_map.get_mut(&task_attempt_id) {
+                per_task_memory_pool.num_plans -= 1;
+                if per_task_memory_pool.num_plans == 0 {
+                    // Drop the memory pool from the per-task memory pool map if there are no
+                    // more native plans using it.
+                    memory_pool_map.remove(&task_attempt_id);
+                }
+            }
+        }
         let _: Box<ExecutionContext> = Box::from_raw(execution_context);
         Ok(())
     })
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 04d930695..0b90a91c7 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -23,7 +23,7 @@ import org.apache.spark._
 import org.apache.spark.sql.comet.CometMetricNode
 import org.apache.spark.sql.vectorized._
 
-import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
+import org.apache.comet.CometConf.{COMET_BATCH_SIZE, COMET_BLOCKING_THREADS, COMET_DEBUG_ENABLED, COMET_EXEC_MEMORY_FRACTION, COMET_EXEC_MEMORY_POOL_TYPE, COMET_EXPLAIN_NATIVE_ENABLED, COMET_WORKER_THREADS}
 import org.apache.comet.vector.NativeUtil
 
 /**
@@ -72,8 +72,11 @@ class CometExecIterator(
       new CometTaskMemoryManager(id),
       batchSize = COMET_BATCH_SIZE.get(),
       use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
+      memory_pool_type = COMET_EXEC_MEMORY_POOL_TYPE.get(),
       memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
+      memory_limit_per_task = getMemoryLimitPerTask(conf),
       memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
+      task_attempt_id = TaskContext.get().taskAttemptId,
       debug = COMET_DEBUG_ENABLED.get(),
       explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
       workerThreads = COMET_WORKER_THREADS.get(),
@@ -84,6 +87,30 @@ class CometExecIterator(
   private var currentBatch: ColumnarBatch = null
   private var closed: Boolean = false
 
+  private def getMemoryLimitPerTask(conf: SparkConf): Long = {
+    val numCores = numDriverOrExecutorCores(conf).toFloat
+    val maxMemory = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
+    val coresPerTask = conf.get("spark.task.cpus", "1").toFloat
+    // example 16GB maxMemory * 16 cores with 4 cores per task results
+    // in memory_limit_per_task = 16 GB * 4 / 16 = 16 GB / 4 = 4GB
+    (maxMemory.toFloat * coresPerTask / numCores).toLong
+  }
+
+  private def numDriverOrExecutorCores(conf: SparkConf): Int = {
+    def convertToInt(threads: String): Int = {
+      if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt
+    }
+    val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
+    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r
+    val master = conf.get("spark.master")
+    master match {
+      case "local" => 1
+      case LOCAL_N_REGEX(threads) => convertToInt(threads)
+      case LOCAL_N_FAILURES_REGEX(threads, _) => convertToInt(threads)
+      case _ => conf.get("spark.executor.cores", "1").toInt
+    }
+  }
+
   def getNextBatch(): Option[ColumnarBatch] = {
     assert(partitionIndex >= 0 && partitionIndex < numParts)
 
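
The per-task limit passed as memory_limit_per_task above follows the comment in getMemoryLimitPerTask: maxMemory * coresPerTask / numCores. A self-contained sketch of the same arithmetic, with hypothetical names, for the 16 GB / 16 cores / spark.task.cpus=4 example from that comment:

    // Standalone restatement of the per-task limit computed by getMemoryLimitPerTask.
    def memoryLimitPerTask(maxMemoryBytes: Long, numCores: Int, coresPerTask: Int): Long =
      (maxMemoryBytes.toFloat * coresPerTask / numCores).toLong

    val sixteenGb = 16L * 1024 * 1024 * 1024
    // 16 GB of Comet memory overhead, 16 executor cores, 4 cores per task => 4 GB per task.
    assert(memoryLimitPerTask(sixteenGb, 16, 4) == 4L * 1024 * 1024 * 1024)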
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 083c0f2b5..5fd84989b 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -52,8 +52,11 @@ class Native extends NativeBase {
       taskMemoryManager: CometTaskMemoryManager,
       batchSize: Int,
       use_unified_memory_manager: Boolean,
+      memory_pool_type: String,
       memory_limit: Long,
+      memory_limit_per_task: Long,
       memory_fraction: Double,
+      task_attempt_id: Long,
       debug: Boolean,
       explain: Boolean,
       workerThreads: Int,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
