This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 22634d2 feat: Pull based native execution (#69)
22634d2 is described below
commit 22634d2945c42bee3e13557c91f6b862952ebbdb
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Wed Feb 21 08:50:09 2024 -0800
feat: Pull based native execution (#69)
---
.../scala/org/apache/comet/vector/NativeUtil.scala | 16 +-
core/src/errors.rs | 9 +
core/src/execution/datafusion/planner.rs | 83 +++++++---
core/src/execution/jni_api.rs | 183 ++++++---------------
core/src/execution/operators/scan.rs | 164 +++++++++++++++---
core/src/jvm_bridge/batch_iterator.rs | 46 ++++++
core/src/jvm_bridge/mod.rs | 15 +-
.../java/org/apache/comet/CometBatchIterator.java | 56 +++++++
.../scala/org/apache/comet/CometExecIterator.scala | 93 ++---------
spark/src/main/scala/org/apache/comet/Native.scala | 53 +-----
.../apache/comet/parquet/ParquetReadSuite.scala | 3 +-
11 files changed, 405 insertions(+), 316 deletions(-)
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 1e27ed8..4bb63e5 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -39,11 +39,14 @@ class NativeUtil {
* @param batch
* the input Comet columnar batch
* @return
- * a list containing pairs of memory addresses in the format of (address
of Arrow array,
- * address of Arrow schema)
+ * an array whose first element is the number of rows, followed by pairs of
memory addresses in the format of
+ * (address of Arrow array, address of Arrow schema)
*/
def exportBatch(batch: ColumnarBatch): Array[Long] = {
- val vectors = (0 until batch.numCols()).flatMap { index =>
+ val exportedVectors = mutable.ArrayBuffer.empty[Long]
+ exportedVectors += batch.numRows()
+
+ (0 until batch.numCols()).foreach { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector
@@ -63,7 +66,8 @@ class NativeUtil {
arrowArray,
arrowSchema)
- Seq((arrowArray, arrowSchema))
+ exportedVectors += arrowArray.memoryAddress()
+ exportedVectors += arrowSchema.memoryAddress()
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
@@ -71,9 +75,7 @@ class NativeUtil {
}
}
- vectors.flatMap { pair =>
- Seq(pair._1.memoryAddress(), pair._2.memoryAddress())
- }.toArray
+ exportedVectors.toArray
}
/**
diff --git a/core/src/errors.rs b/core/src/errors.rs
index 0da2c9c..16ed7c3 100644
--- a/core/src/errors.rs
+++ b/core/src/errors.rs
@@ -159,6 +159,15 @@ impl From<CometError> for DataFusionError {
}
}
+impl From<CometError> for ExecutionError {
+ fn from(value: CometError) -> Self {
+ match value {
+ CometError::Execution { source } => source,
+ _ => ExecutionError::GeneralError(value.to_string()),
+ }
+ }
+}
+
impl jni::errors::ToException for CometError {
fn to_exception(&self) -> Exception {
match self {
diff --git a/core/src/execution/datafusion/planner.rs
b/core/src/execution/datafusion/planner.rs
index c132724..2feaace 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -47,6 +47,7 @@ use datafusion_physical_expr::{
AggregateExpr, ScalarFunctionExpr,
};
use itertools::Itertools;
+use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
use crate::{
@@ -70,7 +71,7 @@ use crate::{
operators::expand::CometExpandExec,
shuffle_writer::ShuffleWriterExec,
},
- operators::{CopyExec, ExecutionError, InputBatch, ScanExec},
+ operators::{CopyExec, ExecutionError, ScanExec},
serde::to_arrow_datatype,
spark_expression,
spark_expression::{
@@ -88,6 +89,8 @@ type PhyAggResult = Result<Vec<Arc<dyn AggregateExpr>>,
ExecutionError>;
type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>,
ExecutionError>;
type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>,
ExecutionError>;
+pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
+
/// The query planner for converting Spark query plans to DataFusion query
plans.
pub struct PhysicalPlanner {
// The execution context id of this planner.
@@ -105,7 +108,7 @@ impl PhysicalPlanner {
pub fn new() -> Self {
let execution_props = ExecutionProps::new();
Self {
- exec_context_id: -1,
+ exec_context_id: TEST_EXEC_CONTEXT_ID,
execution_props,
}
}
@@ -612,24 +615,28 @@ impl PhysicalPlanner {
/// Create a DataFusion physical plan from Spark physical plan.
///
- /// Note that we need `input_batches` parameter because we need to know
the exact schema (not
- /// only data type but also dictionary-encoding) at `ScanExec`s. It is
because some DataFusion
- /// operators, e.g., `ProjectionExec`, gets child operator schema during
initialization and
- /// uses it later for `RecordBatch`. We may be able to get rid of it once
`RecordBatch`
- /// relaxes schema check.
+ /// `inputs` is a vector of input source IDs. It is used to create
`ScanExec`s. Each `ScanExec`
+ /// will be assigned a unique ID from `inputs` and the ID will be used to
identify the input
+ /// source at JNI API.
+ ///
+ /// Note that `ScanExec` will pull the initial input batch during
initialization, because we
+ /// need to know the exact schema (not only data type but also
dictionary-encoding) at
+ /// `ScanExec`s. This is because some DataFusion operators, e.g.,
`ProjectionExec`, get the child
+ /// operator schema during initialization and use it later for
`RecordBatch`. We may be
+ /// able to get rid of it once `RecordBatch` relaxes its schema check.
///
/// Note that we return created `Scan`s which will be kept at JNI API. JNI
calls will use it to
/// feed in new input batch from Spark JVM side.
pub fn create_plan<'a>(
&'a self,
spark_plan: &'a Operator,
- input_batches: &mut Vec<InputBatch>,
+ inputs: &mut Vec<Arc<GlobalRef>>,
) -> Result<(Vec<ScanExec>, Arc<dyn ExecutionPlan>), ExecutionError> {
let children = &spark_plan.children;
match spark_plan.op_struct.as_ref().unwrap() {
OpStruct::Projection(project) => {
assert!(children.len() == 1);
- let (scans, child) = self.create_plan(&children[0],
input_batches)?;
+ let (scans, child) = self.create_plan(&children[0], inputs)?;
let exprs: PhyExprResult = project
.project_list
.iter()
@@ -643,7 +650,7 @@ impl PhysicalPlanner {
}
OpStruct::Filter(filter) => {
assert!(children.len() == 1);
- let (scans, child) = self.create_plan(&children[0],
input_batches)?;
+ let (scans, child) = self.create_plan(&children[0], inputs)?;
let predicate =
self.create_expr(filter.predicate.as_ref().unwrap(),
child.schema())?;
@@ -651,7 +658,7 @@ impl PhysicalPlanner {
}
OpStruct::HashAgg(agg) => {
assert!(children.len() == 1);
- let (scans, child) = self.create_plan(&children[0],
input_batches)?;
+ let (scans, child) = self.create_plan(&children[0], inputs)?;
let group_exprs: PhyExprResult = agg
.grouping_exprs
@@ -716,13 +723,13 @@ impl PhysicalPlanner {
OpStruct::Limit(limit) => {
assert!(children.len() == 1);
let num = limit.limit;
- let (scans, child) = self.create_plan(&children[0],
input_batches)?;
+ let (scans, child) = self.create_plan(&children[0], inputs)?;
Ok((scans, Arc::new(LocalLimitExec::new(child, num as usize))))
}
OpStruct::Sort(sort) => {
assert!(children.len() == 1);
- let (scans, child) = self.create_plan(&children[0],
input_batches)?;
+ let (scans, child) = self.create_plan(&children[0], inputs)?;
let exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = sort
.sort_orders
@@ -741,21 +748,32 @@ impl PhysicalPlanner {
}
OpStruct::Scan(scan) => {
let fields =
scan.fields.iter().map(to_arrow_datatype).collect_vec();
- if input_batches.is_empty() {
+
+ // If this is not the test execution context used by unit tests, we
should have at least one
+ // input source.
+ if self.exec_context_id != TEST_EXEC_CONTEXT_ID &&
inputs.is_empty() {
return Err(ExecutionError::GeneralError(
- "No input batch for scan".to_string(),
+ "No input for scan".to_string(),
));
}
- // Consumes the first input batch source for the scan
- let input_batch = input_batches.remove(0);
+
+ // Consumes the first input source for the scan
+ let input_source = if self.exec_context_id ==
TEST_EXEC_CONTEXT_ID
+ && inputs.is_empty()
+ {
+ // For unit test, we will set input batch to scan directly
by `set_input_batch`.
+ None
+ } else {
+ Some(inputs.remove(0))
+ };
// The `ScanExec` operator will take actual arrays from Spark
during execution
- let scan = ScanExec::new(input_batch, fields);
+ let scan = ScanExec::new(self.exec_context_id, input_source,
fields)?;
Ok((vec![scan.clone()], Arc::new(scan)))
}
OpStruct::ShuffleWriter(writer) => {
assert!(children.len() == 1);
- let (scans, child) = self.create_plan(&children[0],
input_batches)?;
+ let (scans, child) = self.create_plan(&children[0], inputs)?;
let partitioning = self
.create_partitioning(writer.partitioning.as_ref().unwrap(), child.schema())?;
@@ -772,7 +790,7 @@ impl PhysicalPlanner {
}
OpStruct::Expand(expand) => {
assert!(children.len() == 1);
- let (scans, child) = self.create_plan(&children[0],
input_batches)?;
+ let (scans, child) = self.create_plan(&children[0], inputs)?;
let mut projections = vec![];
let mut projection = vec![];
@@ -805,6 +823,18 @@ impl PhysicalPlanner {
.collect();
let schema = Arc::new(Schema::new(fields));
+ // `Expand` operator keeps the input batch and expands it to
multiple output
+ // batches. However, `ScanExec` will reuse input arrays for
the next
+ // input batch. Therefore, we need to copy the input batch to
avoid
+ // data corruption. Note that we only need to copy the
input batch
+ // if the child operator is `ScanExec`, because other
operators after `ScanExec`
+ // will create new arrays for the output batch.
+ let child = if
child.as_any().downcast_ref::<ScanExec>().is_some() {
+ Arc::new(CopyExec::new(child))
+ } else {
+ child
+ };
+
Ok((
scans,
Arc::new(CometExpandExec::new(projections, child, schema)),
@@ -997,9 +1027,9 @@ mod tests {
let values = Int32Array::from(vec![0, 1, 2, 3]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)],
row_count);
- let mut input_batches = vec![input_batch];
- let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut
input_batches).unwrap();
+ let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut
vec![]).unwrap();
+ scans[0].set_input_batch(input_batch);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
@@ -1077,9 +1107,11 @@ mod tests {
let values = StringArray::from(vec!["foo", "bar", "hello", "comet"]);
let input_array = DictionaryArray::new(keys, Arc::new(values));
let input_batch = InputBatch::Batch(vec![Arc::new(input_array)],
row_count);
- let mut input_batches = vec![input_batch];
- let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut
input_batches).unwrap();
+ let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut
vec![]).unwrap();
+
+ // Scan's schema is determined by the input batch, so we need to set
it before execution.
+ scans[0].set_input_batch(input_batch);
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
@@ -1147,8 +1179,7 @@ mod tests {
let op = create_filter(op_scan, 0);
let planner = PhysicalPlanner::new();
- let mut input_batches = vec![InputBatch::EOF];
- let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut
input_batches).unwrap();
+ let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut
vec![]).unwrap();
let scan = &mut scans[0];
scan.set_input_batch(InputBatch::EOF);
diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs
index 831f788..fae213b 100644
--- a/core/src/execution/jni_api.rs
+++ b/core/src/execution/jni_api.rs
@@ -17,9 +17,7 @@
//! Define JNI APIs which can be called from Java/Scala.
-use crate::execution::operators::{InputBatch, ScanExec};
use arrow::{
- array::{make_array, Array, ArrayData, ArrayRef},
datatypes::DataType as ArrowDataType,
ffi::{FFI_ArrowArray, FFI_ArrowSchema},
};
@@ -32,13 +30,12 @@ use datafusion::{
physical_plan::{display::DisplayableExecutionPlan, ExecutionPlan,
SendableRecordBatchStream},
prelude::{SessionConfig, SessionContext},
};
-use datafusion_common::DataFusionError;
use futures::poll;
use jni::{
errors::Result as JNIResult,
objects::{
- AutoElements, JBooleanArray, JByteArray, JClass, JIntArray,
JLongArray, JMap, JObject,
- JObjectArray, JPrimitiveArray, JString, ReleaseMode,
+ JByteArray, JClass, JIntArray, JLongArray, JMap, JObject,
JObjectArray, JPrimitiveArray,
+ JString, ReleaseMode,
},
sys::{jbyteArray, jint, jlong, jlongArray},
JNIEnv,
@@ -59,10 +56,11 @@ use crate::{
use futures::stream::StreamExt;
use jni::{
objects::GlobalRef,
- sys::{jboolean, jbooleanArray, jdouble, jintArray, jobjectArray, jstring},
+ sys::{jboolean, jdouble, jintArray, jobjectArray, jstring},
};
use tokio::runtime::Runtime;
+use crate::execution::operators::ScanExec;
use log::info;
/// Comet native execution context. Kept alive across JNI calls.
@@ -75,6 +73,8 @@ struct ExecutionContext {
pub root_op: Option<Arc<dyn ExecutionPlan>>,
/// The input sources for the DataFusion plan
pub scans: Vec<ScanExec>,
+ /// The global reference of input sources for the DataFusion plan
+ pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// The FFI arrays. We need to keep them alive here.
@@ -100,6 +100,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
_class: JClass,
id: jlong,
config_object: JObject,
+ iterators: jobjectArray,
serialized_query: jbyteArray,
metrics_node: JObject,
) -> jlong {
@@ -137,6 +138,16 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);
+ // Get the global references of input sources
+ let mut input_sources = vec![];
+ let iter_array = JObjectArray::from_raw(iterators);
+ let num_inputs = env.get_array_length(&iter_array)?;
+ for i in 0..num_inputs {
+ let input_source = env.get_object_array_element(&iter_array, i)?;
+ let input_source = Arc::new(jni_new_global_ref!(env,
input_source)?);
+ input_sources.push(input_source);
+ }
+
// We need to keep the session context alive. Some session state like
temporary
// dictionaries are stored in session context. If it is dropped, the
temporary
// dictionaries will be dropped as well.
@@ -147,6 +158,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
spark_plan,
root_op: None,
scans: vec![],
+ input_sources,
stream: None,
ffi_arrays: vec![],
conf: configs,
@@ -164,7 +176,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_createPlan(
fn prepare_datafusion_session_context(
conf: &HashMap<String, String>,
) -> CometResult<SessionContext> {
- // Get the batch size from Boson JVM side
+ // Get the batch size from Comet JVM side
let batch_size = conf
.get("batch_size")
.ok_or(CometError::Internal(
@@ -212,10 +224,9 @@ fn prepare_datafusion_session_context(
/// Prepares arrow arrays for output.
fn prepare_output(
env: &mut JNIEnv,
- output: Result<RecordBatch, DataFusionError>,
+ output_batch: RecordBatch,
exec_context: &mut ExecutionContext,
) -> CometResult<jlongArray> {
- let output_batch = output?;
let results = output_batch.columns();
let num_rows = output_batch.num_rows();
@@ -260,6 +271,20 @@ fn prepare_output(
Ok(long_array.into_raw())
}
+/// Pull the next input from JVM. Note that we cannot pull input batches in
+/// `ScanStream.poll_next` when the execution stream is polled for output,
+/// because the input source could be another native execution stream, which
+/// will be executed in another tokio blocking thread. That causes JNI to throw
+/// a Java exception. So we pull input batches here and insert them into scan
+/// operators before polling the stream.
+#[inline]
+fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(),
CometError> {
+ exec_context.scans.iter_mut().try_for_each(|scan| {
+ scan.get_next_batch()?;
+ Ok::<(), CometError>(())
+ })
+}
+
/// Accept serialized query plan and the addresses of Arrow Arrays from Spark,
/// then execute the query. Return addresses of arrow vector.
/// # Safety
@@ -269,76 +294,11 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
e: JNIEnv,
_class: JClass,
exec_context: jlong,
- addresses_array: jobjectArray,
- finishes: jbooleanArray,
- batch_rows: jint,
) -> jlongArray {
- try_unwrap_or_throw(&e, |mut env| unsafe {
+ try_unwrap_or_throw(&e, |mut env| {
+ // Retrieve the query
let exec_context = get_execution_context(exec_context);
- let addresses = JObjectArray::from_raw(addresses_array);
- let num_addresses = env.get_array_length(&addresses)? as usize;
-
- let mut all_inputs: Vec<Vec<ArrayRef>> =
Vec::with_capacity(num_addresses);
-
- for i in 0..num_addresses {
- let mut inputs: Vec<ArrayRef> = vec![];
-
- let inner_addresses = env.get_object_array_element(&addresses, i
as i32)?.into();
- let inner_address_array: AutoElements<jlong> =
- env.get_array_elements(&inner_addresses,
ReleaseMode::NoCopyBack)?;
-
- let num_inner_address = inner_address_array.len();
- assert_eq!(
- num_inner_address % 2,
- 0,
- "Arrow Array addresses are invalid!"
- );
-
- let num_arrays = num_inner_address / 2;
- let array_elements = inner_address_array.as_ptr();
-
- let mut i: usize = 0;
- while i < num_arrays {
- let array_ptr = *(array_elements.add(i * 2));
- let schema_ptr = *(array_elements.add(i * 2 + 1));
- let array_data = ArrayData::from_spark((array_ptr,
schema_ptr))?;
-
- if exec_context.debug_native {
- // Validate the array data from JVM.
- array_data.validate_full().expect("Invalid array data");
- }
-
- inputs.push(make_array(array_data));
- i += 1;
- }
-
- all_inputs.push(inputs);
- }
-
- // Prepares the input batches.
- let array = JBooleanArray::from_raw(finishes);
- let eofs = env.get_array_elements(&array, ReleaseMode::NoCopyBack)?;
- let eof_flags = eofs.as_ptr();
-
- // Whether reaching the end of input batches.
- let mut finished = true;
- let mut input_batches = all_inputs
- .into_iter()
- .enumerate()
- .map(|(idx, inputs)| {
- let eof = eof_flags.add(idx);
-
- if *eof == 1 {
- InputBatch::EOF
- } else {
- finished = false;
- InputBatch::new(inputs, Some(batch_rows as usize))
- }
- })
- .collect::<Vec<InputBatch>>();
-
- // Retrieve the query
let exec_context_id = exec_context.id;
// Initialize the execution stream.
@@ -346,8 +306,10 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
// query plan, we need to defer stream initialization to first time
execution.
if exec_context.root_op.is_none() {
let planner = PhysicalPlanner::new().with_exec_id(exec_context_id);
- let (scans, root_op) =
- planner.create_plan(&exec_context.spark_plan, &mut
input_batches)?;
+ let (scans, root_op) = planner.create_plan(
+ &exec_context.spark_plan,
+ &mut exec_context.input_sources.clone(),
+ )?;
exec_context.root_op = Some(root_op.clone());
exec_context.scans = scans;
@@ -366,15 +328,8 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
.execute(0, task_ctx)?;
exec_context.stream = Some(stream);
} else {
- input_batches
- .into_iter()
- .enumerate()
- .for_each(|(idx, input_batch)| {
- let scan = &mut exec_context.scans[idx];
-
- // Set inputs at `Scan` node.
- scan.set_input_batch(input_batch);
- });
+ // Pull input batches
+ pull_input_batches(exec_context)?;
}
loop {
@@ -384,7 +339,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
match poll_output {
Poll::Ready(Some(output)) => {
- return prepare_output(&mut env, output, exec_context);
+ return prepare_output(&mut env, output?, exec_context);
}
Poll::Ready(None) => {
// Reaches EOF of output.
@@ -397,23 +352,18 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_executePlan(
return Ok(long_array.into_raw());
}
- // After reaching the end of any input, a poll pending means
there are more than
- // one blocking operators, we don't need go back-forth
- // between JVM/Native. Just keeping polling.
- Poll::Pending if finished => {
+ // A poll pending means there is more than one blocking
operator,
+ // so we don't need to go back and forth between JVM/Native. Just
keep polling.
+ Poll::Pending => {
// Update metrics
update_metrics(&mut env, exec_context)?;
+ // Pull input batches
+ pull_input_batches(exec_context)?;
+
// Output not ready yet
continue;
}
- // Not reaching the end of input yet, so a poll pending means
there are blocking
- // operators. Just returning to keep reading next input.
- Poll::Pending => {
- // Update metrics
- update_metrics(&mut env, exec_context)?;
- return return_pending(env);
- }
}
}
})
@@ -425,37 +375,6 @@ fn return_pending(env: JNIEnv) -> Result<jlongArray,
CometError> {
Ok(long_array.into_raw())
}
-#[no_mangle]
-/// Peeks into next output if any.
-pub extern "system" fn Java_org_apache_comet_Native_peekNext(
- e: JNIEnv,
- _class: JClass,
- exec_context: jlong,
-) -> jlongArray {
- try_unwrap_or_throw(&e, |mut env| {
- // Retrieve the query
- let exec_context = get_execution_context(exec_context);
-
- if exec_context.stream.is_none() {
- // Plan is not initialized yet.
- return return_pending(env);
- }
-
- // Polling the stream.
- let next_item = exec_context.stream.as_mut().unwrap().next();
- let poll_output = exec_context.runtime.block_on(async {
poll!(next_item) });
-
- match poll_output {
- Poll::Ready(Some(output)) => prepare_output(&mut env, output,
exec_context),
- _ => {
- // Update metrics
- update_metrics(&mut env, exec_context)?;
- return_pending(env)
- }
- }
- })
-}
-
#[no_mangle]
/// Drop the native query plan object and context object.
pub extern "system" fn Java_org_apache_comet_Native_releasePlan(
@@ -507,7 +426,7 @@ fn get_execution_context<'a>(id: i64) -> &'a mut
ExecutionContext {
}
}
-/// Used by Boson shuffle external sorter to write sorted records to disk.
+/// Used by Comet shuffle external sorter to write sorted records to disk.
/// # Safety
/// This function is inheritly unsafe since it deals with raw pointers passed
from JNI.
#[no_mangle]
@@ -577,7 +496,7 @@ pub unsafe extern "system" fn
Java_org_apache_comet_Native_writeSortedFileNative
}
#[no_mangle]
-/// Used by Boson shuffle external sorter to sort in-memory row partition ids.
+/// Used by Comet shuffle external sorter to sort in-memory row partition ids.
pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative(
e: JNIEnv,
_class: JClass,
diff --git a/core/src/execution/operators/scan.rs
b/core/src/execution/operators/scan.rs
index f80db6c..9f85de8 100644
--- a/core/src/execution/operators/scan.rs
+++ b/core/src/execution/operators/scan.rs
@@ -26,33 +26,62 @@ use futures::Stream;
use itertools::Itertools;
use arrow::compute::{cast_with_options, CastOptions};
-use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions};
+use arrow_array::{make_array, ArrayRef, RecordBatch, RecordBatchOptions};
+use arrow_data::ArrayData;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use crate::{
+ errors::CometError,
+ execution::{
+ datafusion::planner::TEST_EXEC_CONTEXT_ID, operators::ExecutionError,
+ utils::SparkArrowConvert,
+ },
+ jvm_bridge::{jni_call, JVMClasses},
+};
use datafusion::{
execution::TaskContext,
physical_expr::*,
physical_plan::{ExecutionPlan, *},
};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
+use jni::{
+ objects::{GlobalRef, JLongArray, JObject, ReleaseMode},
+ sys::jlongArray,
+};
#[derive(Debug, Clone)]
pub struct ScanExec {
- pub batch: Arc<Mutex<Option<InputBatch>>>,
+ /// The ID of the execution context that owns this scan. It is used to
detect the unit test context
+ /// (`TEST_EXEC_CONTEXT_ID`) and to identify the plan in error messages.
+ pub exec_context_id: i64,
+ /// The input source of scan node. It is a global reference of JVM
`CometBatchIterator` object.
+ pub input_source: Option<Arc<GlobalRef>>,
+ /// The data types of columns of the input batch. Converted from Spark
schema.
pub data_types: Vec<DataType>,
+ /// The input batch of input data. Used to determine the schema of the
input data.
+ /// It is also used in unit test to mock the input data from JVM.
+ pub batch: Arc<Mutex<Option<InputBatch>>>,
}
impl ScanExec {
- pub fn new(batch: InputBatch, data_types: Vec<DataType>) -> Self {
- Self {
- batch: Arc::new(Mutex::new(Some(batch))),
- data_types,
- }
- }
+ pub fn new(
+ exec_context_id: i64,
+ input_source: Option<Arc<GlobalRef>>,
+ data_types: Vec<DataType>,
+ ) -> Result<Self, CometError> {
+ // Scan's schema is determined by the input batch, so we need to set
it before execution.
+ let first_batch = if let Some(input_source) = input_source.as_ref() {
+ ScanExec::get_next(exec_context_id, input_source.as_obj())?
+ } else {
+ InputBatch::EOF
+ };
- /// Feeds input batch into this `Scan`.
- pub fn set_input_batch(&mut self, input: InputBatch) {
- *self.batch.try_lock().unwrap() = Some(input);
+ Ok(Self {
+ exec_context_id,
+ input_source,
+ data_types,
+ batch: Arc::new(Mutex::new(Some(first_batch))),
+ })
}
/// Checks if the input data type `dt` is a dictionary type with primitive
value type.
@@ -74,6 +103,98 @@ impl ScanExec {
dt.clone()
}
+
+ /// Feeds input batch into this `Scan`. Only used in unit test.
+ pub fn set_input_batch(&mut self, input: InputBatch) {
+ *self.batch.try_lock().unwrap() = Some(input);
+ }
+
+ /// Pull the next input batch from JVM, but only if the current batch has been consumed.
+ pub fn get_next_batch(&mut self) -> Result<(), CometError> {
+ let mut current_batch = self.batch.try_lock().unwrap();
+
+ if self.input_source.is_none() {
+ // This is a unit test. We don't need to call JNI.
+ return Ok(());
+ }
+
+ if current_batch.is_none() {
+ let next_batch = ScanExec::get_next(
+ self.exec_context_id,
+ self.input_source.as_ref().unwrap().as_obj(),
+ )?;
+ *current_batch = Some(next_batch);
+ }
+
+ Ok(())
+ }
+
+ /// Invokes JNI call to get next batch.
+ fn get_next(exec_context_id: i64, iter: &JObject) -> Result<InputBatch,
CometError> {
+ if exec_context_id == TEST_EXEC_CONTEXT_ID {
+ // This is a unit test. We don't need to call JNI.
+ return Ok(InputBatch::EOF);
+ }
+
+ let mut env = JVMClasses::get_env();
+
+ if iter.is_null() {
+ return Err(CometError::from(ExecutionError::GeneralError(format!(
+ "Null batch iterator object. Plan id: {}",
+ exec_context_id
+ ))));
+ }
+
+ let batch_object: JObject = unsafe {
+ jni_call!(&mut env,
+ comet_batch_iterator(iter).next() -> JObject)?
+ };
+
+ if batch_object.is_null() {
+ return Err(CometError::from(ExecutionError::GeneralError(format!(
+ "Null batch object. Plan id: {}",
+ exec_context_id
+ ))));
+ }
+
+ let batch_object = unsafe { JLongArray::from_raw(batch_object.as_raw()
as jlongArray) };
+
+ let addresses = unsafe { env.get_array_elements(&batch_object,
ReleaseMode::NoCopyBack)? };
+
+ let mut inputs: Vec<ArrayRef> = vec![];
+
+ // First element is the number of rows.
+ let num_rows = unsafe { *addresses.as_ptr() as i64 };
+
+ if num_rows < 0 {
+ return Ok(InputBatch::EOF);
+ }
+
+ let array_num = addresses.len() - 1;
+ if array_num % 2 != 0 {
+ return Err(CometError::Internal(format!(
+ "Invalid number of Arrow Array addresses: {}",
+ array_num
+ )));
+ }
+
+ let num_arrays = array_num / 2;
+ let array_elements = unsafe { addresses.as_ptr().add(1) };
+
+ let mut i: usize = 0;
+ while i < num_arrays {
+ let array_ptr = unsafe { *(array_elements.add(i * 2)) };
+ let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) };
+ let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?;
+
+ // TODO: validate array input data
+
+ inputs.push(make_array(array_data));
+ i += 1;
+ }
+
+ Ok(InputBatch::new(inputs, Some(num_rows as usize)))
+ }
}
impl ExecutionPlan for ScanExec {
@@ -214,19 +335,22 @@ impl Stream for ScanStream {
fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) ->
Poll<Option<Self::Item>> {
let mut scan_batch = self.scan.batch.try_lock().unwrap();
let input_batch = &*scan_batch;
+
+ let input_batch = if let Some(batch) = input_batch {
+ batch
+ } else {
+ return Poll::Pending;
+ };
+
let result = match input_batch {
- // Input batch is not ready.
- None => Poll::Pending,
- Some(batch) => match batch {
- InputBatch::EOF => Poll::Ready(None),
- InputBatch::Batch(columns, num_rows) => {
- Poll::Ready(Some(self.build_record_batch(columns,
*num_rows)))
- }
- },
+ InputBatch::EOF => Poll::Ready(None),
+ InputBatch::Batch(columns, num_rows) => {
+ Poll::Ready(Some(self.build_record_batch(columns, *num_rows)))
+ }
};
- // Reset the current input batch so it won't be processed again
*scan_batch = None;
+
result
}
}
diff --git a/core/src/jvm_bridge/batch_iterator.rs
b/core/src/jvm_bridge/batch_iterator.rs
new file mode 100644
index 0000000..474e4fd
--- /dev/null
+++ b/core/src/jvm_bridge/batch_iterator.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::get_global_jclass;
+use jni::{
+ errors::Result as JniResult,
+ objects::{JClass, JMethodID},
+ signature::ReturnType,
+ JNIEnv,
+};
+
+/// A struct that holds all the JNI methods and fields for JVM
`CometBatchIterator` class.
+pub struct CometBatchIterator<'a> {
+ pub class: JClass<'a>,
+ pub method_next: JMethodID,
+ pub method_next_ret: ReturnType,
+}
+
+impl<'a> CometBatchIterator<'a> {
+ pub const JVM_CLASS: &'static str = "org/apache/comet/CometBatchIterator";
+
+ pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometBatchIterator<'a>> {
+ // Get the global class reference
+ let class = get_global_jclass(env, Self::JVM_CLASS)?;
+
+ Ok(CometBatchIterator {
+ class,
+ method_next: env.get_method_id(Self::JVM_CLASS, "next",
"()[J").unwrap(),
+ method_next_ret: ReturnType::Array,
+ })
+ }
+}
diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs
index 0870961..7a2882e 100644
--- a/core/src/jvm_bridge/mod.rs
+++ b/core/src/jvm_bridge/mod.rs
@@ -17,6 +17,8 @@
//! JNI JVM related functions
+use crate::errors::CometResult;
+
use jni::{
errors::{Error, Result as JniResult},
objects::{JClass, JMethodID, JObject, JString, JThrowable, JValueGen,
JValueOwned},
@@ -71,7 +73,7 @@ macro_rules! jni_call {
let ret = $env.call_method_unchecked($obj, method_id, ret_type, args);
// Check if JVM has thrown any exception, and handle it if so.
- let result = if let Some(exception) =
$crate::jvm_bridge::check_exception($env).unwrap() {
+ let result = if let Some(exception) =
$crate::jvm_bridge::check_exception($env)? {
Err(exception.into())
} else {
$crate::jvm_bridge::jni_map_error!($env, ret)
@@ -190,11 +192,11 @@ pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) ->
JniResult<JClass<'stati
mod comet_exec;
pub use comet_exec::*;
+mod batch_iterator;
mod comet_metric_node;
-use crate::{
- errors::{CometError, CometResult},
- JAVA_VM,
-};
+
+use crate::{errors::CometError, JAVA_VM};
+use batch_iterator::CometBatchIterator;
pub use comet_metric_node::*;
/// The JVM classes that are used in the JNI calls.
@@ -212,6 +214,8 @@ pub struct JVMClasses<'a> {
pub comet_metric_node: CometMetricNode<'a>,
/// The static CometExec class. Used for getting the subquery result.
pub comet_exec: CometExec<'a>,
+ /// The CometBatchIterator class. Used for iterating over the batches.
+ pub comet_batch_iterator: CometBatchIterator<'a>,
}
unsafe impl<'a> Send for JVMClasses<'a> {}
@@ -256,6 +260,7 @@ impl JVMClasses<'_> {
throwable_get_cause_method,
comet_metric_node: CometMetricNode::new(env).unwrap(),
comet_exec: CometExec::new(env).unwrap(),
+ comet_batch_iterator: CometBatchIterator::new(env).unwrap(),
}
});
}
diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java
b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
new file mode 100644
index 0000000..3360329
--- /dev/null
+++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet;
+
+import scala.collection.Iterator;
+
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+import org.apache.comet.vector.NativeUtil;
+
+/**
+ * Adapter that lets native code pull input from a Spark {@code Iterator<ColumnarBatch>} through
+ * JNI. Each call to {@link #next()} consumes one batch from the wrapped iterator and exports it
+ * with {@link NativeUtil#exportBatch}, so native code receives memory addresses, not JVM objects.
+ */
+public class CometBatchIterator {
+ final Iterator<ColumnarBatch> input;
+ final NativeUtil nativeUtil;
+
+ CometBatchIterator(Iterator<ColumnarBatch> input, NativeUtil nativeUtil) {
+ this.nativeUtil = nativeUtil;
+ this.input = input;
+ }
+
+ /**
+ * Pulls the next batch from the input iterator and exports it.
+ *
+ * @return the array produced by {@link NativeUtil#exportBatch} for the next batch (the row count
+ * followed by pairs of Arrow array / Arrow schema addresses), or the single-element array
+ * {@code {-1}} once the input iterator is exhausted
+ */
+ public long[] next() {
+ return input.hasNext() ? nativeUtil.exportBatch(input.next()) : new long[] {-1};
+ }
+}
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index 0140582..20b2d38 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -48,30 +48,24 @@ class CometExecIterator(
extends Iterator[ColumnarBatch] {
private val nativeLib = new Native()
+ private val nativeUtil = new NativeUtil
+ private val cometBatchIterators = inputs.map { iterator =>
+ new CometBatchIterator(iterator, nativeUtil)
+ }.toArray
private val plan = {
val configs = createNativeConf
- nativeLib.createPlan(id, configs, protobufQueryPlan, nativeMetrics)
+ nativeLib.createPlan(id, configs, cometBatchIterators, protobufQueryPlan,
nativeMetrics)
}
- private val nativeUtil = new NativeUtil
+
private var nextBatch: Option[ColumnarBatch] = None
private var currentBatch: ColumnarBatch = null
private var closed: Boolean = false
- private def peekNext(): ExecutionState = {
- convertNativeResult(nativeLib.peekNext(plan))
- }
+ private def executeNative(): ExecutionState = {
+ val result = nativeLib.executePlan(plan)
- private def executeNative(
- input: Array[Array[Long]],
- finishes: Array[Boolean],
- numRows: Int): ExecutionState = {
- convertNativeResult(nativeLib.executePlan(plan, input, finishes, numRows))
- }
-
- private def convertNativeResult(result: Array[Long]): ExecutionState = {
val flag = result(0)
if (flag == -1) EOF
- else if (flag == 0) Pending
else if (flag == 1) {
val numRows = result(1)
val addresses = result.slice(2, result.length)
@@ -113,36 +107,12 @@ class CometExecIterator(
/** The execution is finished - no more batch */
case object EOF extends ExecutionState
- /** The execution is pending (e.g., blocking operator is still consuming
batches) */
- case object Pending extends ExecutionState
-
- private def peek(): Option[ColumnarBatch] = {
- peekNext() match {
- case Batch(numRows, addresses) =>
- val cometVectors = nativeUtil.importVector(addresses)
- Some(new ColumnarBatch(cometVectors.toArray, numRows))
- case _ =>
- None
- }
- }
-
- def getNextBatch(
- inputArrays: Array[Array[Long]],
- finishes: Array[Boolean],
- numRows: Int): Option[ColumnarBatch] = {
- executeNative(inputArrays, finishes, numRows) match {
+ def getNextBatch(): Option[ColumnarBatch] = {
+ executeNative() match {
case EOF => None
case Batch(numRows, addresses) =>
val cometVectors = nativeUtil.importVector(addresses)
Some(new ColumnarBatch(cometVectors.toArray, numRows))
- case Pending =>
- if (finishes.forall(_ == true)) {
- // Once no input, we should not get a pending flag.
- throw new SparkException(
- "Native execution should not be pending after reaching end of
input batches")
- }
- // For pending, we keep reading next input.
- None
}
}
@@ -152,48 +122,12 @@ class CometExecIterator(
if (nextBatch.isDefined) {
return true
}
- // Before we pull next input batch, check if there is next output batch
available
- // from native side. Some operators might still have output batches ready
produced
- // from last input batch. For example, `expand` operator will produce
output batches
- // based on the input batch.
- nextBatch = peek()
-
- // Next input batches are available, execute native query plan with the
inputs until
- // we get next output batch ready
- while (nextBatch.isEmpty && inputs.exists(_.hasNext)) {
- val batches = inputs.map {
- case input if input.hasNext => Some(input.next())
- case _ => None
- }
- var numRows = -1
- val (batchAddresses, finishes) = batches
- .map {
- case Some(batch) =>
- numRows = batch.numRows()
- (nativeUtil.exportBatch(batch), false)
- case None => (Array.empty[Long], true)
- }
- .toArray
- .unzip
-
- // At least one input batch should be consumed
- assert(numRows != -1, "No input batch has been consumed")
-
- nextBatch = getNextBatch(batchAddresses, finishes, numRows)
- }
+ nextBatch = getNextBatch()
- // After we consume to the end of the iterators, the native side still can
output batches
- // back because there might be blocking operators e.g. Sort. We continue
ask for batches
- // until it returns empty columns.
if (nextBatch.isEmpty) {
- val finishes = inputs.map(_ => true).toArray
- nextBatch = getNextBatch(inputs.map(_ => Array.empty[Long]).toArray,
finishes, 0)
- val hasNext = nextBatch.isDefined
- if (!hasNext) {
- close()
- }
- hasNext
+ close()
+ false
} else {
true
}
@@ -222,6 +156,7 @@ class CometExecIterator(
currentBatch = null
}
nativeLib.releasePlan(plan)
+
// The allocator thoughts the exported ArrowArray and ArrowSchema
structs are not released,
// so it will report:
// Caused by: java.lang.IllegalStateException: Memory was leaked by
query.
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala
b/spark/src/main/scala/org/apache/comet/Native.scala
index 05bada5..8c1b8ac 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -31,6 +31,9 @@ class Native extends NativeBase {
* The id of the query plan.
* @param configMap
* The Java Map object for the configs of native engine.
+ * @param iterators
+ * the input iterators to the native query plan. It should be the same
number as the number of
+ * scan nodes in the SparkPlan.
* @param plan
* the bytes of serialized SparkPlan.
* @param metrics
@@ -41,61 +44,21 @@ class Native extends NativeBase {
@native def createPlan(
id: Long,
configMap: Map[String, String],
+ iterators: Array[CometBatchIterator],
plan: Array[Byte],
metrics: CometMetricNode): Long
- /**
- * Return the native query plan string for the given address of native query
plan. For debugging
- * purpose.
- *
- * @param plan
- * the address to native query plan.
- * @return
- * the string of native query plan.
- */
- @native def getPlanString(plan: Long): String
-
/**
* Execute a native query plan based on given input Arrow arrays.
*
* @param plan
* the address to native query plan.
- * @param addresses
- * the array of addresses of input Arrow arrays. The addresses are
exported from Arrow Arrays
- * so the number of addresses is always even number in the sequence like
[array_address1,
- * schema_address1, array_address2, schema_address2, ...]. Note that we
can pass empty
- * addresses to this API. In this case, it indicates there are no more
input arrays to the
- * native query plan, but the query plan possibly can still execute to
produce output batch
- * because it might contain blocking operators such as Sort, Aggregate.
When this API returns
- * an empty array back, it means the native query plan is finished.
- * @param finishes
- * whether the end of input arrays is reached for each input. If this is
set to true, the
- * native library will know there is no more inputs. But it doesn't mean
the execution is
- * finished immediately. For some blocking operators native execution will
continue to output.
- * @param numRows
- * the number of rows in the batch.
- * @return
- * an array containing: 1) the status flag (0 for pending, 1 for normal
returned arrays,
- * -1 for end of output), 2) (optional) the number of rows if returned flag
is 1 3) the
- * addresses of output Arrow arrays
- */
- @native def executePlan(
- plan: Long,
- addresses: Array[Array[Long]],
- finishes: Array[Boolean],
- numRows: Int): Array[Long]
-
- /**
- * Peeks the next batch of output Arrow arrays from the native query plan
without pulling any
- * input batches.
- *
- * @param plan
- * the address to native query plan.
* @return
- * an array containing: 1) the status flag (0 for pending, 1 for normal
returned arrays, 2)
- * (optional) the number of rows if returned flag is 1 3) the addresses of
output Arrow arrays
+ * an array containing: 1) the status flag (1 for normal returned arrays,
-1 for end of
+ * output) 2) (optional) the number of rows if returned flag is 1 3) the
addresses of output
+ * Arrow arrays
*/
- @native def peekNext(plan: Long): Array[Long]
+ @native def executePlan(plan: Long): Array[Long]
/**
* Release and drop the native query plan object and context object.
diff --git
a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
index 1cff74d..f447522 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala
@@ -42,7 +42,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.comet.CometBatchScanExec
import org.apache.spark.sql.comet.CometScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import
org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1139,7 +1138,7 @@ abstract class ParquetReadSuite extends CometTestBase {
.where(s"a < ${Long.MaxValue}")
.collect()
}
-
assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException])
+ assert(exception.getMessage.contains("Column: [a], Expected: bigint,
Found: INT32"))
}
}