This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch comet-parquet-exec
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/comet-parquet-exec by this push:
     new e0d80773 [comet-parquet-exec] Simplify schema logic for CometNativeScan (#1142)
e0d80773 is described below

commit e0d807736b97ac1351dc34336c249bc7ac540558
Author: Matt Butrovich <[email protected]>
AuthorDate: Thu Dec 5 12:10:56 2024 -0500

    [comet-parquet-exec] Simplify schema logic for CometNativeScan (#1142)
    
    * Serialize original data schema and required schema, generate projection vector on the Java side.
    
    * Sending over more schema info like column names and nullability.
    
    * Using the new stuff in the proto. About to take the old out.
    
    * Remove old logic.
    
    * remove errant print.
    
    * Remove commented print. format.
    
    * Fix projection_vector to include partition_schema cols correctly.
    
    * Rename variable.
---
 native/core/src/execution/datafusion/planner.rs    | 96 +++++++++-------------
 .../src/execution/datafusion/schema_adapter.rs     |  3 +-
 native/proto/src/proto/operator.proto              | 13 ++-
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 40 ++++++---
 4 files changed, 82 insertions(+), 70 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index c5147d77..40fe4515 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -121,7 +121,6 @@ use datafusion_physical_expr::LexOrdering;
 use itertools::Itertools;
 use jni::objects::GlobalRef;
 use num::{BigInt, ToPrimitive};
-use parquet::schema::parser::parse_message_type;
 use std::cmp::max;
 use std::{collections::HashMap, sync::Arc};
 use url::Url;
@@ -950,50 +949,28 @@ impl PhysicalPlanner {
                 ))
             }
             OpStruct::NativeScan(scan) => {
-                let data_schema = parse_message_type(&scan.data_schema).unwrap();
-                let required_schema = parse_message_type(&scan.required_schema).unwrap();
-
-                let data_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(data_schema));
-                let data_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(&data_schema_descriptor, None)
-                        .unwrap(),
-                );
-
-                let required_schema_descriptor =
-                    parquet::schema::types::SchemaDescriptor::new(Arc::new(required_schema));
-                let required_schema_arrow = Arc::new(
-                    parquet::arrow::schema::parquet_to_arrow_schema(
-                        &required_schema_descriptor,
-                        None,
-                    )
-                    .unwrap(),
-                );
-
-                let partition_schema_arrow = scan
-                    .partition_schema
+                let data_schema = convert_spark_types_to_arrow_schema(scan.data_schema.as_slice());
+                let required_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.required_schema.as_slice());
+                let partition_schema: SchemaRef =
+                    convert_spark_types_to_arrow_schema(scan.partition_schema.as_slice());
+                let projection_vector: Vec<usize> = scan
+                    .projection_vector
                     .iter()
-                    .map(to_arrow_datatype)
-                    .collect_vec();
-                let partition_fields: Vec<_> = partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .map(|(idx, data_type)| {
-                        Field::new(format!("part_{}", idx), data_type.clone(), true)
-                    })
+                    .map(|offset| *offset as usize)
                     .collect();
 
                 // Convert the Spark expressions to Physical expressions
                 let data_filters: Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError> = scan
                     .data_filters
                     .iter()
-                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema_arrow)))
+                    .map(|expr| self.create_expr(expr, Arc::clone(&required_schema)))
                     .collect();
 
                 // Create a conjunctive form of the vector because ParquetExecBuilder takes
                 // a single expression
                 let data_filters = data_filters?;
-                let test_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
+                let cnf_data_filters = data_filters.clone().into_iter().reduce(|left, right| {
                     Arc::new(BinaryExpr::new(
                         left,
                         datafusion::logical_expr::Operator::And,
@@ -1064,29 +1041,21 @@ impl PhysicalPlanner {
                 assert_eq!(file_groups.len(), partition_count);
 
                 let object_store_url = ObjectStoreUrl::local_filesystem();
+                let partition_fields: Vec<Field> = partition_schema
+                    .fields()
+                    .iter()
+                    .map(|field| {
+                        Field::new(field.name(), field.data_type().clone(), field.is_nullable())
+                    })
+                    .collect_vec();
                 let mut file_scan_config =
-                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema_arrow))
+                    FileScanConfig::new(object_store_url, Arc::clone(&data_schema))
                         .with_file_groups(file_groups)
                         .with_table_partition_cols(partition_fields);
 
-                // Check for projection, if so generate the vector and add to FileScanConfig.
-                let mut projection_vector: Vec<usize> =
-                    Vec::with_capacity(required_schema_arrow.fields.len());
-                // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
-                required_schema_arrow.fields.iter().for_each(|field| {
-                    projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-                });
-
-                partition_schema_arrow
-                    .iter()
-                    .enumerate()
-                    .for_each(|(idx, _)| {
-                        projection_vector.push(idx + data_schema_arrow.fields.len());
-                    });
-
                 assert_eq!(
                     projection_vector.len(),
-                    required_schema_arrow.fields.len() + partition_schema_arrow.len()
+                    required_schema.fields.len() + partition_schema.fields.len()
                 );
                 file_scan_config = file_scan_config.with_projection(Some(projection_vector));
 
@@ -1095,13 +1064,11 @@ impl PhysicalPlanner {
                 table_parquet_options.global.pushdown_filters = true;
                 table_parquet_options.global.reorder_filters = true;
 
-                    let mut builder = ParquetExecBuilder::new(file_scan_config)
-                        .with_table_parquet_options(table_parquet_options)
-                        .with_schema_adapter_factory(
-                            Arc::new(CometSchemaAdapterFactory::default()),
-                        );
+                let mut builder = ParquetExecBuilder::new(file_scan_config)
+                    .with_table_parquet_options(table_parquet_options)
+                    .with_schema_adapter_factory(Arc::new(CometSchemaAdapterFactory::default()));
 
-                if let Some(filter) = test_data_filters {
+                if let Some(filter) = cnf_data_filters {
                     builder = builder.with_predicate(filter);
                 }
 
@@ -2309,6 +2276,23 @@ fn from_protobuf_eval_mode(value: i32) -> Result<EvalMode, prost::DecodeError> {
     }
 }
 
+fn convert_spark_types_to_arrow_schema(
+    spark_types: &[spark_operator::SparkStructField],
+) -> SchemaRef {
+    let arrow_fields = spark_types
+        .iter()
+        .map(|spark_type| {
+            Field::new(
+                String::clone(&spark_type.name),
+                to_arrow_datatype(spark_type.data_type.as_ref().unwrap()),
+                spark_type.nullable,
+            )
+        })
+        .collect_vec();
+    let arrow_schema: SchemaRef = Arc::new(Schema::new(arrow_fields));
+    arrow_schema
+}
+
 #[cfg(test)]
 mod tests {
     use std::{sync::Arc, task::Poll};
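
The new convert_spark_types_to_arrow_schema helper replaces the previous
round-trip through Parquet message types: the JVM side now ships each field's
name, data type, and nullability as a SparkStructField, and the planner maps
those directly onto Arrow fields. A minimal standalone sketch of that mapping
follows; it uses a simplified stand-in struct, since the real
spark_operator::SparkStructField is prost-generated and carries a protobuf
DataType that goes through to_arrow_datatype first:

    use std::sync::Arc;
    use arrow_schema::{DataType, Field, Schema, SchemaRef};

    // Hypothetical stand-in for the prost-generated SparkStructField; the
    // real type holds a protobuf DataType, converted via to_arrow_datatype.
    struct SparkStructField {
        name: String,
        data_type: DataType,
        nullable: bool,
    }

    fn convert_spark_types_to_arrow_schema(fields: &[SparkStructField]) -> SchemaRef {
        // One Arrow Field per serialized Spark field, preserving nullability.
        let arrow_fields: Vec<Field> = fields
            .iter()
            .map(|f| Field::new(f.name.clone(), f.data_type.clone(), f.nullable))
            .collect();
        Arc::new(Schema::new(arrow_fields))
    }

    fn main() {
        let schema = convert_spark_types_to_arrow_schema(&[
            SparkStructField { name: "id".into(), data_type: DataType::Int64, nullable: false },
            SparkStructField { name: "name".into(), data_type: DataType::Utf8, nullable: true },
        ]);
        assert_eq!(schema.fields().len(), 2);
    }
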
diff --git a/native/core/src/execution/datafusion/schema_adapter.rs b/native/core/src/execution/datafusion/schema_adapter.rs
index 16d4b9d6..79dcd5c1 100644
--- a/native/core/src/execution/datafusion/schema_adapter.rs
+++ b/native/core/src/execution/datafusion/schema_adapter.rs
@@ -259,7 +259,8 @@ impl SchemaMapper for SchemaMapping {
                             EvalMode::Legacy,
                             "UTC",
                             false,
-                        )?.into_array(batch_col.len())
+                        )?
+                        .into_array(batch_col.len())
                         // and if that works, return the field and column.
                         .map(|new_col| (new_col, table_field.clone()))
                     })
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 5e8a80f9..b4e12d12 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -61,6 +61,12 @@ message SparkFilePartition {
   repeated SparkPartitionedFile partitioned_file = 1;
 }
 
+message SparkStructField {
+  string name = 1;
+  spark.spark_expression.DataType data_type = 2;
+  bool nullable = 3;
+}
+
 message Scan {
   repeated spark.spark_expression.DataType fields = 1;
   // The source of the scan (e.g. file scan, broadcast exchange, shuffle, etc). This
@@ -75,11 +81,12 @@ message NativeScan {
   // is purely for informational purposes when viewing native query plans in
   // debug mode.
   string source = 2;
-  string required_schema = 3;
-  string data_schema = 4;
-  repeated spark.spark_expression.DataType partition_schema = 5;
+  repeated SparkStructField required_schema = 3;
+  repeated SparkStructField data_schema = 4;
+  repeated SparkStructField partition_schema = 5;
   repeated spark.spark_expression.Expr data_filters = 6;
   repeated SparkFilePartition file_partitions = 7;
+  repeated int64 projection_vector = 8;
 }
 
 message Projection {
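
For reference, prost turns the new message into roughly the following Rust
struct, which explains the data_type.as_ref().unwrap() access in planner.rs:
proto3 message-typed fields are generated as Option. This is a sketch of the
expected codegen shape, not the actual build output, and the module path used
for DataType is an assumption:

    // Approximate prost output for SparkStructField (illustrative only).
    #[derive(Clone, PartialEq, ::prost::Message)]
    pub struct SparkStructField {
        #[prost(string, tag = "1")]
        pub name: ::prost::alloc::string::String,
        // Message fields are optional in proto3, hence Option here.
        #[prost(message, optional, tag = "2")]
        pub data_type: ::core::option::Option<super::spark_expression::DataType>,
        #[prost(bool, tag = "3")]
        pub nullable: bool,
    }
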
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7473e932..29e50d73 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD}
-import org.apache.spark.sql.execution.datasources.parquet.SparkToParquetSchemaConverter
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
@@ -2520,18 +2519,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
             case _ =>
           }
 
-          val requiredSchemaParquet =
-            new SparkToParquetSchemaConverter(conf).convert(scan.requiredSchema)
-          val dataSchemaParquet =
-            new SparkToParquetSchemaConverter(conf).convert(scan.relation.dataSchema)
-          val partitionSchema = scan.relation.partitionSchema.fields.flatMap { field =>
-            serializeDataType(field.dataType)
-          }
+          val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
+          val requiredSchema = schema2Proto(scan.requiredSchema.fields)
+          val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
+
+          val data_schema_idxs = scan.requiredSchema.fields.map(field => {
+            scan.relation.dataSchema.fieldIndex(field.name)
+          })
+          val partition_schema_idxs = Array
+            .range(
+              scan.relation.dataSchema.fields.length,
+              scan.relation.dataSchema.length + scan.relation.partitionSchema.fields.length)
+
+          val projection_vector = (data_schema_idxs ++ partition_schema_idxs).map(idx =>
+            idx.toLong.asInstanceOf[java.lang.Long])
+
+          nativeScanBuilder.addAllProjectionVector(projection_vector.toIterable.asJava)
+
           // In `CometScanRule`, we ensure partitionSchema is supported.
           assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
 
-          nativeScanBuilder.setRequiredSchema(requiredSchemaParquet.toString)
-          nativeScanBuilder.setDataSchema(dataSchemaParquet.toString)
+          nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
+          nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
           nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
 
           Some(result.setNativeScan(nativeScanBuilder).build())
@@ -3198,6 +3207,17 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     true
   }
 
+  private def schema2Proto(
+      fields: Array[StructField]): Array[OperatorOuterClass.SparkStructField] = {
+    val fieldBuilder = OperatorOuterClass.SparkStructField.newBuilder()
+    fields.map(field => {
+      fieldBuilder.setName(field.name)
+      fieldBuilder.setDataType(serializeDataType(field.dataType).get)
+      fieldBuilder.setNullable(field.nullable)
+      fieldBuilder.build()
+    })
+  }
+
   private def partition2Proto(
       partition: FilePartition,
       nativeScanBuilder: OperatorOuterClass.NativeScan.Builder,
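
The projection vector computed on the Scala side indexes into the
concatenation of the data schema and the partition columns, so the native
FileScanConfig can apply the projection without re-deriving the name-to-index
mapping (the per-field index_of loop removed from planner.rs above). A small
sketch of the same index arithmetic, written in Rust for consistency with the
earlier examples and using a hypothetical helper name:

    // Mirrors the Scala logic: required columns are looked up by name in the
    // data schema; partition columns follow, offset by the data schema width.
    fn build_projection_vector(
        data_schema: &[&str],
        required: &[&str],
        num_partition_cols: usize,
    ) -> Vec<usize> {
        let mut v: Vec<usize> = required
            .iter()
            .map(|name| {
                data_schema
                    .iter()
                    .position(|col| col == name)
                    .expect("required column must exist in the data schema")
            })
            .collect();
        v.extend(data_schema.len()..data_schema.len() + num_partition_cols);
        v
    }

    // For data schema [a, b, c], required columns [c, a], and one partition
    // column: build_projection_vector(&["a", "b", "c"], &["c", "a"], 1)
    // returns [2, 0, 3]; the partition column lands at index 3, right after
    // the data schema.
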

