This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new b131cc39 feat: Support `GetArrayStructFields` expression (#993)
b131cc39 is described below
commit b131cc39baa617abc8fc99adf786ae88f9d33676
Author: Adam Binford <[email protected]>
AuthorDate: Mon Oct 7 11:59:04 2024 -0400
feat: Support `GetArrayStructFields` expression (#993)
* Start working on GetArrayStructFields
* Almost have working
* Working
* Add another test
* Remove unused
* Remove unused sql conf
---
native/core/src/execution/datafusion/planner.rs | 13 +-
native/proto/src/proto/expr.proto | 6 +
native/spark-expr/src/lib.rs | 2 +-
native/spark-expr/src/list.rs | 140 ++++++++++++++++++++-
.../org/apache/comet/serde/QueryPlanSerde.scala | 19 +++
.../org/apache/comet/CometExpressionSuite.scala | 22 ++++
6 files changed, 198 insertions(+), 4 deletions(-)
diff --git a/native/core/src/execution/datafusion/planner.rs
b/native/core/src/execution/datafusion/planner.rs
index 15de7c9a..d63fd707 100644
--- a/native/core/src/execution/datafusion/planner.rs
+++ b/native/core/src/execution/datafusion/planner.rs
@@ -96,8 +96,8 @@ use datafusion_comet_proto::{
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as
SparkPartitioning},
};
use datafusion_comet_spark_expr::{
- Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr,
ListExtract,
- MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
+ Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields,
GetStructField, HourExpr, IfExpr,
+ ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
};
use datafusion_common::scalar::ScalarStructBuilder;
use datafusion_common::{
@@ -680,6 +680,15 @@ impl PhysicalPlanner {
expr.fail_on_error,
)))
}
+            ExprStruct::GetArrayStructFields(expr) => {
+                // A missing child means the serialized plan is malformed; report it instead of
+                // panicking.
+                let child_proto = expr.child.as_ref().ok_or_else(|| {
+                    ExecutionError::GeneralError(
+                        "GetArrayStructFields is missing its child expression".to_string(),
+                    )
+                })?;
+                let child = self.create_expr(child_proto, Arc::clone(&input_schema))?;
+
+                // The ordinal arrives as a protobuf int32; reject negative values rather than
+                // letting `as usize` wrap them into an enormous field index.
+                let ordinal = usize::try_from(expr.ordinal).map_err(|_| {
+                    ExecutionError::GeneralError(format!(
+                        "Invalid ordinal for GetArrayStructFields: {}",
+                        expr.ordinal
+                    ))
+                })?;
+
+                Ok(Arc::new(GetArrayStructFields::new(child, ordinal)))
+            }
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
diff --git a/native/proto/src/proto/expr.proto
b/native/proto/src/proto/expr.proto
index 88940f38..1a3e3c9f 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -81,6 +81,7 @@ message Expr {
GetStructField get_struct_field = 54;
ToJson to_json = 55;
ListExtract list_extract = 56;
+ GetArrayStructFields get_array_struct_fields = 57;
}
}
@@ -517,6 +518,11 @@ message ListExtract {
bool fail_on_error = 5;
}
+// Extracts the struct field at position `ordinal` from every element of the
+// array-of-structs produced by `child` (serialized from Spark's GetArrayStructFields).
+message GetArrayStructFields {
+ // Expression producing the array of structs.
+ Expr child = 1;
+ // Zero-based index of the struct field to extract.
+ int32 ordinal = 2;
+}
+
enum SortDirection {
Ascending = 0;
Descending = 1;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index c4b1c99b..cc22dfcb 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -38,7 +38,7 @@ mod xxhash64;
pub use cast::{spark_cast, Cast};
pub use error::{SparkError, SparkResult};
pub use if_expr::IfExpr;
-pub use list::ListExtract;
+pub use list::{GetArrayStructFields, ListExtract};
pub use regexp::RLike;
pub use structs::{CreateNamedStruct, GetStructField};
pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
TimestampTruncExpr};
diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/list.rs
index 0b85a842..a376198d 100644
--- a/native/spark-expr/src/list.rs
+++ b/native/spark-expr/src/list.rs
@@ -16,7 +16,7 @@
// under the License.
use arrow::{array::MutableArrayData, datatypes::ArrowNativeType,
record_batch::RecordBatch};
-use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
+use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait,
StructArray};
use arrow_schema::{DataType, FieldRef, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
@@ -275,6 +275,144 @@ impl PartialEq<dyn Any> for ListExtract {
}
}
+/// Physical expression that extracts the struct field at position `ordinal` from every element
+/// of a list-of-structs child, producing a list of that field's values.
+///
+/// The child must evaluate to a `List` or `LargeList` whose element type is a `Struct`.
+#[derive(Debug, Hash)]
+pub struct GetArrayStructFields {
+    /// Expression producing the list-of-structs input.
+    child: Arc<dyn PhysicalExpr>,
+    /// Zero-based index of the struct field to extract.
+    ordinal: usize,
+}
+
+impl GetArrayStructFields {
+    /// Creates an expression extracting field `ordinal` from each struct element of `child`.
+    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
+        Self { child, ordinal }
+    }
+
+    /// Returns the element field of the child's list type.
+    ///
+    /// # Errors
+    /// Returns an internal error if the child is neither a `List` nor a `LargeList`.
+    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.child.data_type(input_schema)? {
+            DataType::List(field) | DataType::LargeList(field) => Ok(field),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    /// Returns the struct field selected by `ordinal`.
+    ///
+    /// # Errors
+    /// Returns an internal error if the list element is not a struct, or if `ordinal` is out of
+    /// range for that struct (previously this indexed directly and panicked).
+    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.list_field(input_schema)?.data_type() {
+            DataType::Struct(fields) => {
+                fields.get(self.ordinal).map(Arc::clone).ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "Ordinal {} out of bounds for struct with {} fields in GetArrayStructFields",
+                        self.ordinal,
+                        fields.len()
+                    ))
+                })
+            }
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+}
+
+impl PhysicalExpr for GetArrayStructFields {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// Returns a list of the selected field's type, keeping the input's list flavour
+    /// (`List` vs `LargeList`). The element field is nullable when either the struct elements
+    /// or the selected field are nullable, because a null struct element yields a null value.
+    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+        let list_field = self.list_field(input_schema)?;
+        let struct_field = self.child_field(input_schema)?;
+        let element_field = output_element_field(&list_field, &struct_field);
+        match self.child.data_type(input_schema)? {
+            DataType::List(_) => Ok(DataType::List(element_field)),
+            DataType::LargeList(_) => Ok(DataType::LargeList(element_field)),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    /// The output list reuses the input list's validity (see `get_array_struct_fields`), so the
+    /// output is null exactly when the child list is null.
+    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+        self.child.nullable(input_schema)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+
+        match child_value.data_type() {
+            DataType::List(_) => {
+                let list_array = as_list_array(&child_value)?;
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            DataType::LargeList(_) => {
+                let list_array = as_large_list_array(&child_value)?;
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected child type for GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        match children.len() {
+            1 => Ok(Arc::new(GetArrayStructFields::new(
+                Arc::clone(&children[0]),
+                self.ordinal,
+            ))),
+            _ => internal_err!("GetArrayStructFields should have exactly one child"),
+        }
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        // `child` and `ordinal` are exactly the fields compared by `PartialEq`; hash each once.
+        self.child.hash(&mut s);
+        self.ordinal.hash(&mut s);
+    }
+}
+
+/// Builds the element field of the output list: the selected struct field, marked nullable if
+/// either it or the enclosing struct element (`list_field`) is nullable.
+fn output_element_field(list_field: &FieldRef, struct_field: &FieldRef) -> FieldRef {
+    let nullable = list_field.is_nullable() || struct_field.is_nullable();
+    Arc::new((**struct_field).clone().with_nullable(nullable))
+}
+
+/// Extracts column `ordinal` from the struct values of `list_array` and re-wraps it in a list
+/// with the same offsets and list-level validity.
+///
+/// # Errors
+/// Returns an error if the list elements are not structs, if `ordinal` is out of range, or if
+/// Arrow rejects the rebuilt arrays.
+fn get_array_struct_fields<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    ordinal: usize,
+) -> DataFusionResult<ColumnarValue> {
+    let list_field = match list_array.data_type() {
+        DataType::List(field) | DataType::LargeList(field) => field,
+        data_type => {
+            return Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            )))
+        }
+    };
+
+    let values = list_array
+        .values()
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "GetArrayStructFields expects a list of structs, got elements of type {:?}",
+                list_array.value_type()
+            ))
+        })?;
+
+    let struct_field = values.fields().get(ordinal).ok_or_else(|| {
+        DataFusionError::Internal(format!(
+            "Ordinal {} out of bounds for struct with {} fields in GetArrayStructFields",
+            ordinal,
+            values.num_columns()
+        ))
+    })?;
+    // Fields and columns of a StructArray have equal length, so this index is in bounds.
+    let column = values.column(ordinal);
+
+    // Arrow does not require a struct's children to be null where the struct itself is null,
+    // so fold the struct validity into the extracted column: a null struct element must yield
+    // a null value. Layouts that cannot carry a validity bitmap are left untouched.
+    let column = match values.nulls() {
+        Some(struct_nulls)
+            if struct_nulls.null_count() > 0
+                && !matches!(
+                    column.data_type(),
+                    DataType::Null | DataType::Union(..) | DataType::RunEndEncoded(..)
+                ) =>
+        {
+            let nulls = arrow::buffer::NullBuffer::union(Some(struct_nulls), column.nulls());
+            let data = column.to_data().into_builder().nulls(nulls).build()?;
+            arrow_array::make_array(data)
+        }
+        _ => Arc::clone(column),
+    };
+
+    let array = GenericListArray::try_new(
+        output_element_field(list_field, struct_field),
+        list_array.offsets().clone(),
+        column,
+        list_array.nulls().cloned(),
+    )?;
+
+    Ok(ColumnarValue::Array(Arc::new(array)))
+}
+
+impl Display for GetArrayStructFields {
+    /// Renders `GetArrayStructFields [child: <child>, ordinal: <n>]`, using the `Debug` form of
+    /// both the child expression and the ordinal.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let Self { child, ordinal } = self;
+        write!(f, "GetArrayStructFields [child: {child:?}, ordinal: {ordinal:?}]")
+    }
+}
+
+impl PartialEq<dyn Any> for GetArrayStructFields {
+    /// Equal when `other` downcasts to `GetArrayStructFields` with the same ordinal and an equal
+    /// child expression; any other type compares unequal.
+    fn eq(&self, other: &dyn Any) -> bool {
+        match down_cast_any_ref(other).downcast_ref::<Self>() {
+            Some(that) => self.ordinal == that.ordinal && self.child.eq(&that.child),
+            None => false,
+        }
+    }
+}
+
#[cfg(test)]
mod test {
use crate::list::{list_extract, zero_based_index};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 51b32b7d..02b845e7 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2542,6 +2542,25 @@ object QueryPlanSerde extends Logging with
ShimQueryPlanSerde with CometExprShim
None
}
+      case GetArrayStructFields(child, _, ordinal, _, _) =>
+        exprToProto(child, inputs, binding) match {
+          case Some(childProto) =>
+            // Only the child and the field ordinal are sent; the native side does the extraction.
+            val getArrayStructFields = ExprOuterClass.GetArrayStructFields
+              .newBuilder()
+              .setChild(childProto)
+              .setOrdinal(ordinal)
+            Some(
+              ExprOuterClass.Expr
+                .newBuilder()
+                .setGetArrayStructFields(getArrayStructFields)
+                .build())
+          case None =>
+            // Child could not be serialized: record why and report this expression unsupported.
+            withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
+            None
+        }
+
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported",
expr.children: _*)
None
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 16bc15b8..da22df40 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2271,4 +2271,26 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
}
}
}
+
+ test("GetArrayStructFields") {
+ // Run against both dictionary-encoded and plain Parquet input.
+ Seq(true, false).foreach { dictionaryEnabled =>
+ // NOTE(review): SimplifyExtractValueOps is presumably excluded so the optimizer does not
+ // rewrite field access on a literal array(struct(...)) away from GetArrayStructFields,
+ // which would leave the expression under test out of the plan. Confirm.
+ withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
SimplifyExtractValueOps.ruleName) {
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "test.parquet")
+ makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled,
10000)
+ // Array holding one struct built from columns plus a null element, so extraction
+ // also sees a null struct element.
+ val df = spark.read
+ .parquet(path.toString)
+ .select(
+ array(struct(col("_2"), col("_3"), col("_4"), col("_8")),
lit(null)).alias("arr"))
+ checkSparkAnswerAndOperator(df.select("arr._2", "arr._3", "arr._4"))
+
+ // Nested access: each array element holds a struct named `nested`, and `_4` is read
+ // from inside it.
+ val complex = spark.read
+ .parquet(path.toString)
+ .select(array(struct(struct(col("_4"),
col("_8")).alias("nested"))).alias("arr"))
+
+ checkSparkAnswerAndOperator(complex.select(col("arr.nested._4")))
+ }
+ }
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]