This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch comet-parquet-exec
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/comet-parquet-exec by this push:
new 7488af810 [comet-parquet-exec] fix: fix various bugs in casting between struct types (#1226)
7488af810 is described below
commit 7488af810a62e4f323227fa7c7ae65faea1a636b
Author: Andy Grove <[email protected]>
AuthorDate: Tue Jan 7 08:38:20 2025 -0700
[comet-parquet-exec] fix: fix various bugs in casting between struct types (#1226)
* Handle top-level nulls when casting structs
* remove unused imports
* enable test
* simplify
* fix more issues
* fix more issues
---
native/spark-expr/src/cast.rs | 24 +++++++++++----
native/spark-expr/src/schema_adapter.rs | 25 ++++++++++++---
.../org/apache/comet/CometExpressionSuite.scala | 36 ++++++++--------------
3 files changed, 51 insertions(+), 34 deletions(-)
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index 08541e1d2..42897a00d 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -37,7 +37,7 @@ use arrow::{
};
use arrow_array::builder::StringBuilder;
use arrow_array::{DictionaryArray, StringArray, StructArray};
-use arrow_schema::{DataType, Field, Schema};
+use arrow_schema::{DataType, Schema};
use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{
@@ -50,6 +50,7 @@ use num::{
ToPrimitive,
};
use regex::Regex;
+use std::collections::HashMap;
use std::str::FromStr;
use std::{
any::Any,
@@ -817,17 +818,28 @@ fn cast_struct_to_struct(
cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
match (from_type, to_type) {
- (DataType::Struct(_), DataType::Struct(to_fields)) => {
- let mut cast_fields: Vec<(Arc<Field>, ArrayRef)> = Vec::with_capacity(to_fields.len());
+ (DataType::Struct(from_fields), DataType::Struct(to_fields)) => {
+ // TODO some of this logic may be specific to converting Parquet to Spark
+ let mut field_name_to_index_map = HashMap::new();
+ for (i, field) in from_fields.iter().enumerate() {
+ field_name_to_index_map.insert(field.name(), i);
+ }
+ assert_eq!(field_name_to_index_map.len(), from_fields.len());
+ let mut cast_fields: Vec<ArrayRef> = Vec::with_capacity(to_fields.len());
for i in 0..to_fields.len() {
+ let from_index = field_name_to_index_map[to_fields[i].name()];
let cast_field = cast_array(
- Arc::clone(array.column(i)),
+ Arc::clone(array.column(from_index)),
to_fields[i].data_type(),
cast_options,
)?;
- cast_fields.push((Arc::clone(&to_fields[i]), cast_field));
+ cast_fields.push(cast_field);
}
- Ok(Arc::new(StructArray::from(cast_fields)))
+ Ok(Arc::new(StructArray::new(
+ to_fields.clone(),
+ cast_fields,
+ array.nulls().map(|nulls| nulls.clone()),
+ )))
}
_ => unreachable!(),
}
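For reference, here is a minimal standalone sketch (not the code from this commit) of the same idea: resolve each target field by name instead of by ordinal position, cast each matched column, and carry the input's top-level null buffer over to the result. It uses arrow's generic cast kernel in place of Comet's Spark-specific cast_array, and the helper name cast_struct_by_name is made up for illustration:

    use std::collections::HashMap;
    use std::sync::Arc;

    use arrow::compute::cast;
    use arrow_array::{Array, ArrayRef, StructArray};
    use arrow_schema::{ArrowError, Fields};

    /// Hypothetical helper: cast `array` to the struct described by `to_fields`,
    /// matching source columns to target fields by name.
    fn cast_struct_by_name(array: &StructArray, to_fields: &Fields) -> Result<ArrayRef, ArrowError> {
        // Map each source field name to its column index.
        let from_index: HashMap<&str, usize> = array
            .fields()
            .iter()
            .enumerate()
            .map(|(i, f)| (f.name().as_str(), i))
            .collect();

        // Cast every target field from the source column with the same name.
        let columns = to_fields
            .iter()
            .map(|f| {
                let i = *from_index
                    .get(f.name().as_str())
                    .ok_or_else(|| ArrowError::SchemaError(format!("missing field {}", f.name())))?;
                cast(array.column(i).as_ref(), f.data_type())
            })
            .collect::<Result<Vec<_>, _>>()?;

        // Reuse the input's top-level null buffer so struct-level nulls survive the cast.
        Ok(Arc::new(StructArray::new(
            to_fields.clone(),
            columns,
            array.nulls().cloned(),
        )))
    }

Matching on name rather than position keeps the cast correct when the Parquet struct's fields are ordered differently from (or are a superset of) the Spark schema's fields, and passing the existing null buffer through keeps rows where the entire struct is null marked as null in the result.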
diff --git a/native/spark-expr/src/schema_adapter.rs b/native/spark-expr/src/schema_adapter.rs
index 9872b7560..626b63b80 100644
--- a/native/spark-expr/src/schema_adapter.rs
+++ b/native/spark-expr/src/schema_adapter.rs
@@ -23,6 +23,7 @@ use arrow_schema::{DataType, Schema, SchemaRef};
use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
use datafusion_common::plan_err;
use datafusion_expr::ColumnarValue;
+use std::collections::HashMap;
use std::sync::Arc;
/// An implementation of DataFusion's `SchemaAdapterFactory` that uses a Spark-compatible
@@ -321,10 +322,26 @@ fn cast_supported(from_type: &DataType, to_type: &DataType, options: &SparkCastO
(Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
(Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
(_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
- (Struct(from_fields), Struct(to_fields)) => from_fields
- .iter()
- .zip(to_fields.iter())
- .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
+ (Struct(from_fields), Struct(to_fields)) => {
+ // TODO some of this logic may be specific to converting Parquet to Spark
+ let mut field_types = HashMap::new();
+ for field in from_fields {
+ field_types.insert(field.name(), field.data_type());
+ }
+ if field_types.iter().len() != from_fields.len() {
+ return false;
+ }
+ for field in to_fields {
+ if let Some(from_type) = field_types.get(&field.name()) {
+ if !cast_supported(from_type, field.data_type(), options) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ true
+ }
_ => false,
}
}
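A similarly hedged sketch of the compatibility check above: a struct-to-struct cast is treated as supported only when the source field names are unique and every target field has a same-named source field whose type is itself castable. The names struct_cast_supported and type_castable are illustrative, standing in for Comet's cast_supported:

    use std::collections::HashMap;

    use arrow_schema::{DataType, Fields};

    /// Hypothetical helper mirroring the name-based check in cast_supported.
    fn struct_cast_supported(
        from_fields: &Fields,
        to_fields: &Fields,
        type_castable: impl Fn(&DataType, &DataType) -> bool,
    ) -> bool {
        let by_name: HashMap<&str, &DataType> = from_fields
            .iter()
            .map(|f| (f.name().as_str(), f.data_type()))
            .collect();
        // Duplicate source field names collapse in the map, so reject them outright.
        if by_name.len() != from_fields.len() {
            return false;
        }
        // Every target field must resolve to a castable source field.
        to_fields.iter().all(|f| {
            by_name
                .get(f.name().as_str())
                .copied()
                .is_some_and(|from| type_castable(from, f.data_type()))
        })
    }

Rejecting duplicate source field names keeps the lookup unambiguous; without that check, two Parquet fields with the same name would silently resolve to whichever entry survived in the map.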
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 3bcca69f4..034e58dce 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2195,7 +2195,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
- ignore("get_struct_field - select primitive fields") {
+ test("get_struct_field - select primitive fields") {
withTempPath { dir =>
// create input file with Comet disabled
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
@@ -2208,16 +2208,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
df.write.parquet(dir.toString())
}
- Seq("", "parquet").foreach { v1List =>
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
- val df = spark.read.parquet(dir.toString())
- checkSparkAnswerAndOperator(df.select("nested1.id"))
- }
- }
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("nested1.id"))
}
}
- ignore("get_struct_field - select subset of struct") {
+ test("get_struct_field - select subset of struct") {
withTempPath { dir =>
// create input file with Comet disabled
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
@@ -2236,19 +2232,15 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
df.write.parquet(dir.toString())
}
- Seq("", "parquet").foreach { v1List =>
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
- val df = spark.read.parquet(dir.toString())
- checkSparkAnswerAndOperator(df.select("nested1.id"))
- checkSparkAnswerAndOperator(df.select("nested1.nested2"))
- checkSparkAnswerAndOperator(df.select("nested1.nested2.id"))
- checkSparkAnswerAndOperator(df.select("nested1.id", "nested1.nested2.id"))
- }
- }
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("nested1.id"))
+ checkSparkAnswerAndOperator(df.select("nested1.nested2"))
+ checkSparkAnswerAndOperator(df.select("nested1.nested2.id"))
+ checkSparkAnswerAndOperator(df.select("nested1.id", "nested1.nested2.id"))
}
}
- ignore("get_struct_field - read entire struct") {
+ test("get_struct_field - read entire struct") {
withTempPath { dir =>
// create input file with Comet disabled
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
@@ -2267,12 +2259,8 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
df.write.parquet(dir.toString())
}
- Seq("", "parquet").foreach { v1List =>
- withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) {
- val df = spark.read.parquet(dir.toString())
- checkSparkAnswerAndOperator(df.select("nested1"))
- }
- }
+ val df = spark.read.parquet(dir.toString())
+ checkSparkAnswerAndOperator(df.select("nested1"))
}
}