This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch comet-parquet-exec
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/comet-parquet-exec by this push:
     new e0d80773 [comet-parquet-exec] Simplify schema logic for CometNativeScan (#1142)
e0d80773 is described below
commit e0d807736b97ac1351dc34336c249bc7ac540558
Author: Matt Butrovich <[email protected]>
AuthorDate: Thu Dec 5 12:10:56 2024 -0500
[comet-parquet-exec] Simplify schema logic for CometNativeScan (#1142)
* Serialize original data schema and required schema, generate projection vector on the Java side.
* Sending over more schema info like column names and nullability.
* Using the new stuff in the proto. About to take the old out.
* Remove old logic.
* Remove errant print.
* Remove commented print. format.
* Fix projection_vector to include partition_schema cols correctly.
* Rename variable.
---
native/core/src/execution/datafusion/planner.rs | 96 +++++++++-------------
.../src/execution/datafusion/schema_adapter.rs | 3 +-
native/proto/src/proto/operator.proto | 13 ++-
.../org/apache/comet/serde/QueryPlanSerde.scala | 40 ++++++---
4 files changed, 82 insertions(+), 70 deletions(-)
diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index c5147d77..40fe4515 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -121,7 +121,6 @@ use datafusion_physical_expr::LexOrdering;
use itertools::Itertools;
use jni::objects::GlobalRef;
use num::{BigInt, ToPrimitive};
-use parquet::schema::parser::parse_message_type;
use std::cmp::max;
use std::{collections::HashMap, sync::Arc};
use url::Url;
@@ -950,50 +949,28 @@ impl PhysicalPlanner {
))
}
OpStruct::NativeScan(scan) => {
-                let data_schema = parse_message_type(&scan.data_schema).unwrap();
-                let required_schema = parse_message_type(&scan.required_schema).unwrap();
-
-                let data_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema));
-                let data_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(&data_schema_descriptor, None)
-                        .unwrap(),
-                );
-
-                let required_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema));
-                let required_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(
-                        &required_schema_descriptor,
-                        None,
-                    )
-                    .unwrap(),
-                );
-
-                let partition_schema_arrow = scan
-                    .partition_schema
+                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
+                let required_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+                let partition_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
+                let projection_vector: Vec<usize> = scan
+                    .projection_vector
                     .iter()
-                    .map(to_arrow_datatype)
-                    .collect_vec();
-                let partition_fields: Vec<_> = partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .map(|(idx, data_type)| {
-                        Field::new(format!("part_{}", idx), data_type.clone(), true)
-                    })
+                    .map(|offset| *offset as usize)
                     .collect();
                 // Convert the Spark expressions to Physical expressions
                 let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
                     .data_filters
                     .iter()
-                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow)))
+                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
                     .collect();
                 // Create a conjunctive form of the vector because ParquetExecBuilder takes
                 // a single expression
                 let data_filters = data_filters?;
-                let test_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
+                let cnf_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
                     Arc::new(BinaryExpr::new(
                         left,
                         datafusion::logical_expr::Operator::And,
@@ -1064,29 +1041,21 @@ impl PhysicalPlanner {
assert_eq!(file_groups.len(), partition_count);
let object_store_url = ObjectStoreUrl::local_filesystem();
+                let partition_fields: Vec<Field> = partition_schema
+                    .fields()
+                    .iter()
+                    .map(|field| {
+                        Field::new(field.name(), field.data_type().clone(), field.is_nullable())
+                    })
+                    .collect_vec();
                 let mut file_scan_config =
-                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
+                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema))
                         .with_file_groups(file_groups)
                         .with_table_partition_cols(partition_fields);
-                // Check for projection, if so generate the vector and add to FileScanConfig.
-                let mut projection_vector: Vec<usize> =
-                    Vec::with_capacity(required_schema_arrow.fields.len());
-                // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
-                required_schema_arrow.fields.iter().for_each(|field| {
-                    projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-                });
-
-                partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .for_each(|(idx, _)| {
-                        projection_vector.push(idx + data_schema_arrow.fields.len());
-                    });
-
                 assert_eq!(
                     projection_vector.len(),
-                    required_schema_arrow.fields.len() + partition_schema_arrow.len()
+                    required_schema.fields.len() + partition_schema.fields.len()
                 );
                 file_scan_config = file_scan_config.with_projection(Some(projection_vector));
@@ -1095,13 +1064,11 @@ impl PhysicalPlanner {
table_parquet_options.global.pushdown_filters = true;
table_parquet_options.global.reorder_filters = true;
-                let mut builder = ParquetExecBuilder::new(file_scan_config)
-                    .with_table_parquet_options(table_parquet_options)
-                    .with_schema_adapter_factory(
-                        Arc::new(CometSchemaAdapterFactory::default()),
-                    );
+                let mut builder = ParquetExecBuilder::new(file_scan_config)
+                    .with_table_parquet_options(table_parquet_options)
+                    .with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default()));
- if let Some(filter) = test_data_filters {
+ if let Some(filter) = cnf_data_filters {
builder = builder.with_predicate(filter);
}
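
The `reduce` above folds the list of pushed-down data filters into a single AND-chained
predicate, because `ParquetExecBuilder::with_predicate` accepts one expression. Below is a
minimal standalone sketch of that fold (not Comet code), assuming a recent DataFusion
release; module paths and the `col`/`lit` helpers may differ slightly between versions.

    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_expr::Operator;
    use datafusion::physical_expr::expressions::{col, lit, BinaryExpr};
    use datafusion::physical_expr::PhysicalExpr;

    fn main() -> Result<()> {
        // Hypothetical two-column schema standing in for required_schema.
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]);

        // Stand-ins for the per-filter expressions that create_expr would produce.
        let filters: Vec<Arc<dyn PhysicalExpr>> = vec![
            Arc::new(BinaryExpr::new(col("a", &schema)?, Operator::Gt, lit(1i32))),
            Arc::new(BinaryExpr::new(col("b", &schema)?, Operator::Lt, lit(10i32))),
        ];

        // The same fold as in the hunk above: AND the filters together pairwise.
        let conjunction = filters
            .into_iter()
            .reduce(|left, right| Arc::new(BinaryExpr::new(left, Operator::And, right)));

        // None when there are no filters, otherwise one predicate to push down.
        if let Some(predicate) = conjunction {
            println!("{predicate}");
        }
        Ok(())
    }

With an empty filter list, `reduce` yields `None` and nothing is pushed down, which matches
the `if let Some(filter)` guard in the hunk above.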
@@ -2309,6 +2276,23 @@ fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, prost::DecodeError> {
}
}
+fn convert_spark_types_to_arrow_schema(
+ spark_types: &[spark_operator::SparkStructField],
+) -> SchemaRef {
+ let arrow_fields = spark_types
+ .iter()
+ .map(|spark_type| {
+ Field::new(
+ String::clone(&spark_type.name),
+ to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
+ spark_type.nullable,
+ )
+ })
+ .collect_vec();
+ let arrow_schema: SchemaRef = Arc::new(Schema::new(arrow_fields));
+ arrow_schema
+}
+
#[cfg(test)]
mod tests {
use std::{sync::Arc, task::Poll};
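
The new `convert_spark_types_to_arrow_schema` helper added at the bottom of planner.rs is a
direct field-by-field mapping from the serialized (name, data type, nullable) triples to
Arrow `Field`s. A standalone sketch of the same shape, assuming only the `arrow` crate and
using plain tuples in place of the generated `SparkStructField`/`to_arrow_datatype` pieces:

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema, SchemaRef};

    /// Build an Arrow schema from (name, data type, nullable) triples, mirroring
    /// the shape of the new helper in planner.rs, which reads the same three
    /// attributes from the serialized SparkStructField messages.
    fn schema_from_triples(triples: &[(&str, DataType, bool)]) -> SchemaRef {
        let fields: Vec<Field> = triples
            .iter()
            .map(|(name, dt, nullable)| Field::new(*name, dt.clone(), *nullable))
            .collect();
        Arc::new(Schema::new(fields))
    }

    fn main() {
        // Hypothetical columns standing in for a required_schema.
        let schema = schema_from_triples(&[
            ("id", DataType::Int64, false),
            ("name", DataType::Utf8, true),
        ]);
        assert_eq!(schema.fields().len(), 2);
        assert!(!schema.field(0).is_nullable());
        println!("{schema:#?}");
    }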
diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs
index 16d4b9d6..79dcd5c1 100644
--- a/native/core/src/execution/datafusion/schema_adapter.rs
+++ b/native/core/src/execution/datafusion/schema_adapter.rs
@@ -259,7 +259,8 @@ impl SchemaMapper for SchemaMapping {
EvalMode::Legacy,
"UTC",
false,
- )?.into_array(batch_col.len())
+ )?
+ .into_array(batch_col.len())
// and if that works, return the field and column.
.map(|new_col| (new_col, table_field.clone()))
})
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 5e8a80f9..b4e12d12 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -61,6 +61,12 @@ message SparkFilePartition {
repeated SparkPartitionedFile partitioned_file = 1;
}
+message SparkStructField {
+ string name = 1;
+ spark.spark_expression.DataType data_type = 2;
+ bool nullable = 3;
+}
+
message Scan {
repeated spark.spark_expression.DataType fields = 1;
   // The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
@@ -75,11 +81,12 @@ message NativeScan {
// is purely for informational purposes when viewing native query plans in
// debug mode.
string source = 2;
- string required_schema = 3;
- string data_schema = 4;
- repeated spark.spark_expression.DataType partition_schema = 5;
+ repeated SparkStructField required_schema = 3;
+ repeated SparkStructField data_schema = 4;
+ repeated SparkStructField partition_schema = 5;
repeated spark.spark_expression.Expr data_filters = 6;
repeated SparkFilePartition file_partitions = 7;
+ repeated int64 projection_vector = 8;
}
message Projection {
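
On the native side, prost turns the new message into a plain Rust struct, which is what
`convert_spark_types_to_arrow_schema` reads from. The sketch below shows the approximate
shape only (placeholder `DataType`, no prost derives or field tags), not the literal
generated code:

    // Placeholder for the generated spark.spark_expression.DataType message.
    #[derive(Debug, Default, Clone)]
    pub struct DataType;

    // Approximate shape of what prost generates for SparkStructField.
    #[derive(Debug, Default, Clone)]
    pub struct SparkStructField {
        pub name: String,                // field 1
        pub data_type: Option<DataType>, // field 2: message fields become Option<_>
        pub nullable: bool,              // field 3
    }

    fn main() {
        // One entry of the repeated data_schema / required_schema / partition_schema lists.
        let field = SparkStructField {
            name: "id".to_string(),
            data_type: Some(DataType),
            nullable: false,
        };
        println!("{field:?}");
    }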
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7473e932..29e50d73 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD}
-import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
@@ -2520,18 +2519,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
case _ =>
}
-        val requiredSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
-        val dataSchemaParquet =
-          new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
-        val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
-          serializeDataType(field.dataType)
-        }
+        val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
+        val requiredSchema = schema2Proto(scan.requiredSchema.fields)
+        val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
+
+        val data_schema_idxs = scan.requiredSchema.fields.map(field => {
+          scan.relation.dataSchema.fieldIndex(field.name)
+        })
+        val partition_schema_idxs = Array
+          .range(
+            scan.relation.dataSchema.fields.length,
+            scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
+
+        val projection_vector = (data_schema_idxs ++ partition_schema_idxs).map(idx =>
+          idx.toLong.asInstanceOf[java.lang.Long])
+
+        nativeScanBuilder.addAllProjectionVector(projection_vector.toIterable.asJava)
+
         // In `CometScanRule`, we ensure partitionSchema is supported.
         assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
-        nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
-        nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+        nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+        nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
         nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
         Some(result.setNativeScan(nativeScanBuilder).build())
@@ -3198,6 +3207,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
true
}
+ private def schema2Proto(
+      fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
+ val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
+ fields.map(field => {
+ fieldBuilder.setName(field.name)
+ fieldBuilder.setDataType(serializeDataType(field.dataType).get)
+ fieldBuilder.setNullable(field.nullable)
+ fieldBuilder.build()
+ })
+ }
+
private def partition2Proto(
partition: FilePartition,
nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
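
The Scala side now precomputes the projection: each required column is resolved to its index
in the full data schema, and partition columns are appended after it, numbered from the data
schema's length onward. A small standalone sketch of that index arithmetic in plain Rust,
with hypothetical column names, ending in the same invariant the native planner asserts
(projection length equals required plus partition columns):

    fn main() {
        // Hypothetical file schema, the columns requested by the query, and the
        // partition columns appended by the scan.
        let data_schema = ["id", "name", "price", "ts"];
        let required_schema = ["price", "id"];
        let partition_schema = ["dt"];

        // Required columns are resolved to their positions in the data schema.
        let mut projection: Vec<usize> = required_schema
            .iter()
            .map(|name| {
                data_schema
                    .iter()
                    .position(|col| col == name)
                    .expect("required column must exist in the data schema")
            })
            .collect();

        // Partition columns come after all data-schema columns, so their indices
        // start at data_schema.len().
        projection.extend((0..partition_schema.len()).map(|i| data_schema.len() + i));

        // Same invariant the native planner asserts before building the scan.
        assert_eq!(projection.len(), required_schema.len() + partition_schema.len());
        assert_eq!(projection, vec![2, 0, 4]);
        println!("projection vector: {projection:?}");
    }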
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]