This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new b131cc39 feat: Support `GetArrayStructFields` expression (#993)
b131cc39 is described below

commit b131cc39baa617abc8fc99adf786ae88f9d33676
Author: Adam Binford <[email protected]>
AuthorDate: Mon Oct 7 11:59:04 2024 -0400

    feat: Support `GetArrayStructFields` expression (#993)
    
    * Start working on GetArrayStructFields
    
    * Almost have it working
    
    * Working
    
    * Add another test
    
    * Remove unused
    
    * Remove unused sql conf
---
 native/core/src/execution/datafusion/planner.rs    |  13 +-
 native/proto/src/proto/expr.proto                  |   6 +
 native/spark-expr/src/lib.rs                       |   2 +-
 native/spark-expr/src/list.rs                      | 140 ++++++++++++++++++++-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  19 +++
 .../org/apache/comet/CometExpressionSuite.scala    |  22 ++++
 6 files changed, 198 insertions(+), 4 deletions(-)

diff --git a/native/core/src/execution/datafusion/planner.rs b/native/core/src/execution/datafusion/planner.rs
index 15de7c9a..d63fd707 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -96,8 +96,8 @@ use datafusion_comet_proto::{
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as 
SparkPartitioning},
 };
 use datafusion_comet_spark_expr::{
-    Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, 
ListExtract,
-    MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+    Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, 
GetStructField, HourExpr, IfExpr,
+    ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
 };
 use datafusion_common::scalar::ScalarStructBuilder;
 use datafusion_common::{
@@ -680,6 +680,15 @@ impl PhysicalPlanner {
                     expr.fail_on_error,
                 )))
             }
+            ExprStruct::GetArrayStructFields(expr) => {
+                let child =
+                    self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
+
+                Ok(Arc::new(GetArrayStructFields::new(
+                    child,
+                    expr.ordinal as usize,
+                )))
+            }
             expr => Err(ExecutionError::GeneralError(format!(
                 "Not implemented: {:?}",
                 expr
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index 88940f38..1a3e3c9f 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -81,6 +81,7 @@ message Expr {
     GetStructField get_struct_field = 54;
     ToJson to_json = 55;
     ListExtract list_extract = 56;
+    GetArrayStructFields get_array_struct_fields = 57;
   }
 }
 
@@ -517,6 +518,11 @@ message ListExtract {
   bool fail_on_error = 5;
 }
 
+// Extracts one field from every struct element of an array of structs
+// (serialized from Spark's `GetArrayStructFields`), producing an array of
+// that field's values.
+message GetArrayStructFields {
+  // Expression producing the array (list) of structs.
+  Expr child = 1;
+  // Zero-based index of the struct field to extract.
+  int32 ordinal = 2;
+}
+
 enum SortDirection {
   Ascending = 0;
   Descending = 1;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c4b1c99b..cc22dfcb 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -38,7 +38,7 @@ mod xxhash64;
 pub use cast::{spark_cast, Cast};
 pub use error::{SparkError, SparkResult};
 pub use if_expr::IfExpr;
-pub use list::ListExtract;
+pub use list::{GetArrayStructFields, ListExtract};
 pub use regexp::RLike;
 pub use structs::{CreateNamedStruct, GetStructField};
 pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, 
TimestampTruncExpr};
diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index 0b85a842..a376198d 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, 
record_batch::RecordBatch};
-use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
+use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait, 
StructArray};
 use arrow_schema::{DataType, FieldRef, Schema};
 use datafusion::logical_expr::ColumnarValue;
 use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
@@ -275,6 +275,144 @@ impl PartialEq<dyn Any> for ListExtract {
     }
 }
 
+/// Physical expression that extracts the struct field at `ordinal` from every element of a
+/// `List`/`LargeList` of structs, yielding a list of that field's values.
+#[derive(Hash, Debug)]
+pub struct GetArrayStructFields {
+    /// Expression producing the list of structs.
+    child: Arc<dyn PhysicalExpr>,
+    /// Zero-based position of the field to extract within the element struct.
+    ordinal: usize,
+}
+
+impl GetArrayStructFields {
+    /// Creates an expression that extracts the field at `ordinal` from every struct element of
+    /// the list produced by `child`.
+    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
+        Self { child, ordinal }
+    }
+
+    /// Returns the element field of the child's `List`/`LargeList` type.
+    ///
+    /// Returns an internal error if the child is not a list type.
+    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.child.data_type(input_schema)? {
+            DataType::List(field) | DataType::LargeList(field) => Ok(field),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    /// Returns the struct field at `self.ordinal` within the list's element struct type.
+    ///
+    /// Returns an internal error, rather than panicking, when the list elements are not structs
+    /// or when `self.ordinal` is out of range for the element struct.
+    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.list_field(input_schema)?.data_type() {
+            DataType::Struct(fields) => fields.get(self.ordinal).cloned().ok_or_else(|| {
+                DataFusionError::Internal(format!(
+                    "GetArrayStructFields ordinal {} out of bounds for struct with {} fields",
+                    self.ordinal,
+                    fields.len()
+                ))
+            }),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+}
+
+impl PhysicalExpr for GetArrayStructFields {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// The result is a list whose element field is the selected struct field, keeping the
+    /// child's list flavour (`List` stays `List`, `LargeList` stays `LargeList`).
+    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+        let struct_field = self.child_field(input_schema)?;
+        match self.child.data_type(input_schema)? {
+            DataType::List(_) => Ok(DataType::List(struct_field)),
+            DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    /// Top-level nullability of the result.
+    ///
+    /// The output list reuses the child list's null buffer, so a result row is null exactly when
+    /// the child row is null (Spark's `GetArrayStructFields` likewise reports `child.nullable`).
+    /// Nullability of the extracted elements is carried by the element field in `data_type`,
+    /// not by the top-level column.
+    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+        self.child.nullable(input_schema)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+
+        match child_value.data_type() {
+            DataType::List(_) => {
+                let list_array = as_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            DataType::LargeList(_) => {
+                let list_array = as_large_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected child type for GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        match children.len() {
+            1 => Ok(Arc::new(GetArrayStructFields::new(
+                Arc::clone(&children[0]),
+                self.ordinal,
+            ))),
+            _ => internal_err!("GetArrayStructFields should have exactly one child"),
+        }
+    }
+
+    /// Hashes exactly the fields compared by `PartialEq` (child and ordinal), so equal
+    /// expressions hash equally; re-hashing `self` on top of that was redundant.
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.child.hash(&mut s);
+        self.ordinal.hash(&mut s);
+    }
+}
+
+/// Projects the struct field at `ordinal` out of every element of `list_array`.
+///
+/// The result shares the input's offsets and top-level null buffer, so it has the same list
+/// lengths and null rows as the input. Its values are the struct's column at `ordinal`, shared
+/// via `Arc` without copying.
+///
+/// Returns an internal error if the list elements are not structs or if `ordinal` is out of
+/// range. Returns an Arrow error if the resulting list is invalid (for example, a non-nullable
+/// field whose values contain nulls) instead of panicking.
+fn get_array_struct_fields<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    ordinal: usize,
+) -> DataFusionResult<ColumnarValue> {
+    let values = list_array
+        .values()
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "GetArrayStructFields expects a list of structs, got a list of {:?}",
+                list_array.values().data_type()
+            ))
+        })?;
+
+    let (field, column) = match (values.fields().get(ordinal), values.columns().get(ordinal)) {
+        (Some(field), Some(column)) => (Arc::clone(field), Arc::clone(column)),
+        _ => {
+            return Err(DataFusionError::Internal(format!(
+                "GetArrayStructFields ordinal {} out of bounds for struct with {} fields",
+                ordinal,
+                values.fields().len()
+            )));
+        }
+    };
+
+    // Only the element values are swapped for the selected column; offsets and list-level
+    // validity come straight from the input.
+    // NOTE(review): the struct elements' own validity (`values.nulls()`) is not applied to
+    // `column`, so a null struct element yields whatever `column` holds at that slot. Confirm
+    // that producers always null out child values beneath null structs.
+    let array = GenericListArray::try_new(
+        field,
+        list_array.offsets().clone(),
+        column,
+        list_array.nulls().cloned(),
+    )?;
+
+    Ok(ColumnarValue::Array(Arc::new(array)))
+}
+
+impl Display for GetArrayStructFields {
+    /// Renders the expression for plan display as
+    /// `GetArrayStructFields [child: <child debug>, ordinal: <n>]`.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let Self { child, ordinal } = self;
+        write!(f, "GetArrayStructFields [child: {child:?}, ordinal: {ordinal:?}]")
+    }
+}
+
+impl PartialEq<dyn Any> for GetArrayStructFields {
+    /// Equal when `other` (after unwrapping via `down_cast_any_ref`) is also a
+    /// `GetArrayStructFields` with an equal child expression and the same ordinal.
+    fn eq(&self, other: &dyn Any) -> bool {
+        match down_cast_any_ref(other).downcast_ref::<Self>() {
+            Some(that) => self.child.eq(&that.child) && self.ordinal == that.ordinal,
+            None => false,
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::list::{list_extract, zero_based_index};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 51b32b7d..02b845e7 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2542,6 +2542,25 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde with CometExprShim
             None
           }
 
+        case GetArrayStructFields(child, _, ordinal, _, _) =>
+          val childExpr = exprToProto(child, inputs, binding)
+
+          if (childExpr.isDefined) {
+            val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+              .newBuilder()
+              .setChild(childExpr.get)
+              .setOrdinal(ordinal)
+
+            Some(
+              ExprOuterClass.Expr
+                .newBuilder()
+                .setGetArrayStructFields(arrayStructFieldsBuilder)
+                .build())
+          } else {
+            withInfo(expr, "unsupported arguments for GetArrayStructFields", 
child)
+            None
+          }
+
         case _ =>
           withInfo(expr, s"${expr.prettyName} is not supported", 
expr.children: _*)
           None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 16bc15b8..da22df40 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2271,4 +2271,26 @@ class CometExpressionSuite extends CometTestBase with 
AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("GetArrayStructFields") {
+    for (dictEnabled <- Seq(true, false)) {
+      // Presumably excluded so Spark's optimizer does not simplify the extraction away before
+      // it reaches Comet.
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
+        withTempDir { tmpDir =>
+          val parquetPath = new Path(tmpDir.toURI.toString, "test.parquet")
+          makeParquetFileAllTypes(parquetPath, dictionaryEnabled = dictEnabled, 10000)
+          val input = spark.read.parquet(parquetPath.toString)
+
+          // Flat struct fields, with a null element in every array.
+          val flat = input.select(
+            array(struct(col("_2"), col("_3"), col("_4"), col("_8")), lit(null)).alias("arr"))
+          checkSparkAnswerAndOperator(flat.select("arr._2", "arr._3", "arr._4"))
+
+          // Field access through a nested struct.
+          val nested = input.select(
+            array(struct(struct(col("_4"), col("_8")).alias("nested"))).alias("arr"))
+          checkSparkAnswerAndOperator(nested.select(col("arr.nested._4")))
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to