This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 2671e0ca Stop passing Java config map into native createPlan (#1101)
2671e0ca is described below
commit 2671e0cafbd6fe96c9af852ad7f34592e299a62a
Author: Andy Grove <[email protected]>
AuthorDate: Tue Dec 3 17:40:54 2024 -0700
Stop passing Java config map into native createPlan (#1101)
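Configuration now crosses the JNI boundary as typed scalar arguments (batch size, debug and explain flags, and the Tokio worker/blocking thread counts) instead of a java.util.Map that had to be iterated and string-parsed on every plan creation. The per-plan passthrough of spark.datafusion.* settings into the DataFusion SessionConfig is removed along with it, and RuntimeEnv construction now propagates errors with `?` instead of unwrapping.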
---
native/core/src/execution/jni_api.rs | 69 +++++-----------------
.../scala/org/apache/comet/CometExecIterator.scala | 32 ++--------
spark/src/main/scala/org/apache/comet/Native.scala | 10 ++--
3 files changed, 28 insertions(+), 83 deletions(-)
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 083744f0..8afe134c 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -31,8 +31,8 @@ use futures::poll;
use jni::{
errors::Result as JNIResult,
objects::{
- JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, JObjectArray, JPrimitiveArray,
- JString, ReleaseMode,
+ JByteArray, JClass, JIntArray, JLongArray, JObject, JObjectArray, JPrimitiveArray, JString,
+ ReleaseMode,
},
sys::{jbyteArray, jint, jlong, jlongArray},
JNIEnv,
@@ -77,8 +77,6 @@ struct ExecutionContext {
pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
- /// Configurations for DF execution
- pub conf: HashMap<String, String>,
/// The Tokio runtime used for async.
pub runtime: Runtime,
/// Native metrics
@@ -103,11 +101,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
e: JNIEnv,
_class: JClass,
id: jlong,
- config_object: JObject,
iterators: jobjectArray,
serialized_query: jbyteArray,
metrics_node: JObject,
comet_task_memory_manager_obj: JObject,
+ batch_size: jint,
+ debug_native: jboolean,
+ explain_native: jboolean,
+ worker_threads: jint,
+ blocking_threads: jint,
) -> jlong {
try_unwrap_or_throw(&e, |mut env| {
// Init JVM classes
@@ -121,36 +123,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
// Deserialize query plan
let spark_plan = serde::deserialize_op(bytes.as_slice())?;
- // Sets up context
- let mut configs = HashMap::new();
-
- let config_map = JMap::from_env(&mut env, &config_object)?;
- let mut map_iter = config_map.iter(&mut env)?;
- while let Some((key, value)) = map_iter.next(&mut env)? {
- let key: String = env.get_string(&JString::from(key)).unwrap().into();
- let value: String = env.get_string(&JString::from(value)).unwrap().into();
- configs.insert(key, value);
- }
-
- // Whether we've enabled additional debugging on the native side
- let debug_native = parse_bool(&configs, "debug_native")?;
- let explain_native = parse_bool(&configs, "explain_native")?;
-
- let worker_threads = configs
- .get("worker_threads")
- .map(String::as_str)
- .unwrap_or("4")
- .parse::<usize>()?;
- let blocking_threads = configs
- .get("blocking_threads")
- .map(String::as_str)
- .unwrap_or("10")
- .parse::<usize>()?;
-
// Use multi-threaded tokio runtime to prevent blocking spawned tasks if any
let runtime = tokio::runtime::Builder::new_multi_thread()
- .worker_threads(worker_threads)
- .max_blocking_threads(blocking_threads)
+ .worker_threads(worker_threads as usize)
+ .max_blocking_threads(blocking_threads as usize)
.enable_all()
.build()?;
@@ -171,7 +147,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
// We need to keep the session context alive. Some session state like temporary
// dictionaries are stored in session context. If it is dropped, the temporary
// dictionaries will be dropped as well.
- let session = prepare_datafusion_session_context(&configs, task_memory_manager)?;
+ let session = prepare_datafusion_session_context(batch_size as usize, task_memory_manager)?;
let plan_creation_time = start.elapsed();
@@ -182,13 +158,12 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
scans: vec![],
input_sources,
stream: None,
- conf: configs,
runtime,
metrics,
plan_creation_time,
session_ctx: Arc::new(session),
- debug_native,
- explain_native,
+ debug_native: debug_native == 1,
+ explain_native: explain_native == 1,
metrics_jstrings: HashMap::new(),
});
@@ -196,19 +171,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
})
}
-/// Parse Comet configs and configure DataFusion session context.
+/// Configure DataFusion session context.
fn prepare_datafusion_session_context(
- conf: &HashMap<String, String>,
+ batch_size: usize,
comet_task_memory_manager: Arc<GlobalRef>,
) -> CometResult<SessionContext> {
- // Get the batch size from Comet JVM side
- let batch_size = conf
- .get("batch_size")
- .ok_or(CometError::Internal(
- "Config 'batch_size' is not specified from Comet JVM
side".to_string(),
- ))?
- .parse::<usize>()?;
-
let mut rt_config = RuntimeConfig::new().with_disk_manager(DiskManagerConfig::NewOs);
// Set Comet memory pool for native
@@ -218,7 +185,7 @@ fn prepare_datafusion_session_context(
// Get Datafusion configuration from Spark Execution context
// can be configured in Comet Spark JVM using Spark --conf parameters
// e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
- let mut session_config = SessionConfig::new()
+ let session_config = SessionConfig::new()
.with_batch_size(batch_size)
// DataFusion partial aggregates can emit duplicate rows so we disable the
// skip partial aggregation feature because this is not compatible with Spark's
@@ -231,11 +198,7 @@ fn prepare_datafusion_session_context(
&ScalarValue::Float64(Some(1.1)),
);
- for (key, value) in conf.iter().filter(|(k, _)| k.starts_with("datafusion.")) {
- session_config = session_config.set_str(key, value);
- }
-
- let runtime = RuntimeEnv::try_new(rt_config).unwrap();
+ let runtime = RuntimeEnv::try_new(rt_config)?;
let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));
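The native-side shape of this change: configuration values now arrive as JNI scalars, so the JMap iteration and string parsing removed above are no longer needed. A minimal standalone sketch of that pattern with the jni crate (hypothetical function and class names, not this commit's code):

    use jni::objects::JClass;
    use jni::sys::{jboolean, jint, jlong};
    use jni::JNIEnv;

    // Hypothetical example: scalar config values cross the JNI boundary as
    // C primitives, so no JMap iteration or string parsing is needed.
    #[no_mangle]
    pub extern "system" fn Java_com_example_Native_configure(
        _env: JNIEnv,
        _class: JClass,
        batch_size: jint,       // Scala Int arrives as jint (i32)
        debug_native: jboolean, // Scala Boolean arrives as jboolean (u8)
    ) -> jlong {
        let batch_size = batch_size as usize; // widen once, as createPlan now does
        let debug = debug_native == 1;        // JNI_TRUE is 1
        if debug {
            println!("batch_size = {batch_size}");
        }
        batch_size as jlong
    }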
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index bff3e792..d57e9e2b 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -60,43 +60,23 @@ class CometExecIterator(
new CometBatchIterator(iterator, nativeUtil)
}.toArray
private val plan = {
- val configs = createNativeConf
nativeLib.createPlan(
id,
- configs,
cometBatchIterators,
protobufQueryPlan,
nativeMetrics,
- new CometTaskMemoryManager(id))
+ new CometTaskMemoryManager(id),
+ batchSize = COMET_BATCH_SIZE.get(),
+ debug = COMET_DEBUG_ENABLED.get(),
+ explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
+ workerThreads = COMET_WORKER_THREADS.get(),
+ blockingThreads = COMET_BLOCKING_THREADS.get())
}
private var nextBatch: Option[ColumnarBatch] = None
private var currentBatch: ColumnarBatch = null
private var closed: Boolean = false
- /**
- * Creates a new configuration map to be passed to the native side.
- */
- private def createNativeConf: java.util.HashMap[String, String] = {
- val result = new java.util.HashMap[String, String]()
- val conf = SparkEnv.get.conf
-
- result.put("batch_size", String.valueOf(COMET_BATCH_SIZE.get()))
- result.put("debug_native", String.valueOf(COMET_DEBUG_ENABLED.get()))
- result.put("explain_native",
String.valueOf(COMET_EXPLAIN_NATIVE_ENABLED.get()))
- result.put("worker_threads", String.valueOf(COMET_WORKER_THREADS.get()))
- result.put("blocking_threads",
String.valueOf(COMET_BLOCKING_THREADS.get()))
-
- // Strip mandatory prefix spark. which is not required for DataFusion session params
- conf.getAll.foreach {
- case (k, v) if k.startsWith("spark.datafusion") =>
- result.put(k.replaceFirst("spark\\.", ""), v)
- case _ =>
- }
-
- result
- }
-
def getNextBatch(): Option[ColumnarBatch] = {
assert(partitionIndex >= 0 && partitionIndex < numParts)
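The call site now passes the typed values straight from the Comet config constants, so the defaults that the deleted string parsing supplied ("4" worker threads, "10" blocking threads) live on the JVM side. On the native side the two thread counts feed the Tokio builder directly; a minimal sketch of that construction (assumed to match, not copied from this commit):

    use tokio::runtime::{Builder, Runtime};

    // Minimal sketch: the two jint thread counts are widened to usize and
    // handed to the multi-threaded Tokio runtime builder.
    fn build_runtime(worker_threads: i32, blocking_threads: i32) -> std::io::Result<Runtime> {
        Builder::new_multi_thread()
            .worker_threads(worker_threads as usize)
            .max_blocking_threads(blocking_threads as usize)
            .enable_all()
            .build()
    }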
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index 52063419..64ada91a 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -19,8 +19,6 @@
package org.apache.comet
-import java.util.Map
-
import org.apache.spark.CometTaskMemoryManager
import org.apache.spark.sql.comet.CometMetricNode
@@ -47,11 +45,15 @@ class Native extends NativeBase {
*/
@native def createPlan(
id: Long,
- configMap: Map[String, String],
iterators: Array[CometBatchIterator],
plan: Array[Byte],
metrics: CometMetricNode,
- taskMemoryManager: CometTaskMemoryManager): Long
+ taskMemoryManager: CometTaskMemoryManager,
+ batchSize: Int,
+ debug: Boolean,
+ explain: Boolean,
+ workerThreads: Int,
+ blockingThreads: Int): Long
/**
* Execute a native query plan based on given input Arrow arrays.
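Since createPlan is declared @native, the JVM binds it to the exported Rust symbol Java_org_apache_comet_Native_createPlan, and each Scala primitive parameter must line up positionally with a jni::sys type. An illustrative check of that mapping (example values only, not from this commit):

    use jni::sys::{jboolean, jint, jlong, JNI_TRUE};

    // Illustrative mapping between the Scala signature and jni::sys types:
    //   Scala Long    -> jlong    (i64)
    //   Scala Int     -> jint     (i32)
    //   Scala Boolean -> jboolean (u8, where JNI_TRUE == 1)
    fn main() {
        let debug: jboolean = JNI_TRUE;
        assert_eq!(debug, 1); // mirrors the `debug_native == 1` check in createPlan
        let batch_size: jint = 8192;
        let _plan_id: jlong = batch_size as jlong;
    }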
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]