This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 0dc71c091 chore: Refactor serde for more array and struct expressions (#2257)
0dc71c091 is described below

commit 0dc71c091d1c10ec5480235ec3acc28e3522fe89
Author: Andy Grove <agr...@apache.org>
AuthorDate: Fri Aug 29 16:59:18 2025 -0600

    chore: Refactor serde for more array and struct expressions (#2257)
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 170 +--------------------
 .../main/scala/org/apache/comet/serde/arrays.scala |  68 ++++++++-
 .../scala/org/apache/comet/serde/structs.scala     | 169 ++++++++++++++++++++
 3 files changed, 242 insertions(+), 165 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 22bd6fd03..ad9be300f 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -95,6 +95,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[ArraysOverlap] -> CometArraysOverlap,
     classOf[ArrayUnion] -> CometArrayUnion,
     classOf[CreateArray] -> CometCreateArray,
+    classOf[GetArrayItem] -> CometGetArrayItem,
+    classOf[ElementAt] -> CometElementAt,
     classOf[Ascii] -> CometScalarFunction("ascii"),
     classOf[ConcatWs] -> CometScalarFunction("concat_ws"),
     classOf[Chr] -> CometScalarFunction("char"),
@@ -170,6 +172,10 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[DateSub] -> CometDateSub,
     classOf[TruncDate] -> CometTruncDate,
     classOf[TruncTimestamp] -> CometTruncTimestamp,
+    classOf[CreateNamedStruct] -> CometCreateNamedStruct,
+    classOf[GetStructField] -> CometGetStructField,
+    classOf[GetArrayStructFields] -> CometGetArrayStructFields,
+    classOf[StructsToJson] -> CometStructsToJson,
     classOf[Flatten] -> CometFlatten,
     classOf[Atan2] -> CometAtan2,
     classOf[Ceil] -> CometCeil,
@@ -922,66 +928,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
             None
         }
 
-      case StructsToJson(options, child, timezoneId) =>
-        if (options.nonEmpty) {
-          withInfo(expr, "StructsToJson with options is not supported")
-          None
-        } else {
-
-          def isSupportedType(dt: DataType): Boolean = {
-            dt match {
-              case StructType(fields) =>
-                fields.forall(f => isSupportedType(f.dataType))
-              case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
-                  DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
-                  DataTypes.DoubleType | DataTypes.StringType =>
-                true
-              case DataTypes.DateType | DataTypes.TimestampType =>
-                // TODO implement these types with tests for formatting options and timezone
-                false
-              case _: MapType | _: ArrayType =>
-                // Spark supports map and array in StructsToJson but this is not yet
-                // implemented in Comet
-                false
-              case _ => false
-            }
-          }
-
-          val isSupported = child.dataType match {
-            case s: StructType =>
-              s.fields.forall(f => isSupportedType(f.dataType))
-            case _: MapType | _: ArrayType =>
-              // Spark supports map and array in StructsToJson but this is not yet
-              // implemented in Comet
-              false
-            case _ =>
-              false
-          }
-
-          if (isSupported) {
-            exprToProtoInternal(child, inputs, binding) match {
-              case Some(p) =>
-                val toJson = ExprOuterClass.ToJson
-                  .newBuilder()
-                  .setChild(p)
-                  .setTimezone(timezoneId.getOrElse("UTC"))
-                  .setIgnoreNullFields(true)
-                  .build()
-                Some(
-                  ExprOuterClass.Expr
-                    .newBuilder()
-                    .setToJson(toJson)
-                    .build())
-              case _ =>
-                withInfo(expr, child)
-                None
-            }
-          } else {
-            withInfo(expr, "Unsupported data type", child)
-            None
-          }
-        }
-
       case SortOrder(child, direction, nullOrdering, _) =>
         val childExpr = exprToProtoInternal(child, inputs, binding)
 
@@ -1336,110 +1282,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
           withInfo(expr, bloomFilter, value)
           None
         }
-
-      case struct @ CreateNamedStruct(_) =>
-        if (struct.names.length != struct.names.distinct.length) {
-          withInfo(expr, "CreateNamedStruct with duplicate field names are not supported")
-          return None
-        }
-
-        val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding))
-
-        if (valExprs.forall(_.isDefined)) {
-          val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
-          structBuilder.addAllValues(valExprs.map(_.get).asJava)
-          structBuilder.addAllNames(struct.names.map(_.toString).asJava)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setCreateNamedStruct(structBuilder)
-              .build())
-        } else {
-          withInfo(expr, "unsupported arguments for CreateNamedStruct", struct.valExprs: _*)
-          None
-        }
-
-      case GetStructField(child, ordinal, _) =>
-        exprToProtoInternal(child, inputs, binding).map { childExpr =>
-          val getStructFieldBuilder = ExprOuterClass.GetStructField
-            .newBuilder()
-            .setChild(childExpr)
-            .setOrdinal(ordinal)
-
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setGetStructField(getStructFieldBuilder)
-            .build()
-        }
-
-      case GetArrayItem(child, ordinal, failOnError) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-        val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
-
-        if (childExpr.isDefined && ordinalExpr.isDefined) {
-          val listExtractBuilder = ExprOuterClass.ListExtract
-            .newBuilder()
-            .setChild(childExpr.get)
-            .setOrdinal(ordinalExpr.get)
-            .setOneBased(false)
-            .setFailOnError(failOnError)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setListExtract(listExtractBuilder)
-              .build())
-        } else {
-          withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal)
-          None
-        }
-
-      case ElementAt(child, ordinal, defaultValue, failOnError)
-          if child.dataType.isInstanceOf[ArrayType] =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-        val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding)
-        val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding))
-
-        if (childExpr.isDefined && ordinalExpr.isDefined &&
-          defaultExpr.isDefined == defaultValue.isDefined) {
-          val arrayExtractBuilder = ExprOuterClass.ListExtract
-            .newBuilder()
-            .setChild(childExpr.get)
-            .setOrdinal(ordinalExpr.get)
-            .setOneBased(true)
-            .setFailOnError(failOnError)
-
-          defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setListExtract(arrayExtractBuilder)
-              .build())
-        } else {
-          withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
-          None
-        }
-
-      case GetArrayStructFields(child, _, ordinal, _, _) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-
-        if (childExpr.isDefined) {
-          val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
-            .newBuilder()
-            .setChild(childExpr.get)
-            .setOrdinal(ordinal)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setGetArrayStructFields(arrayStructFieldsBuilder)
-              .build())
-        } else {
-          withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
-          None
-        }
       case af @ ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] =>
         convert(af, CometArrayCompact)
       case expr =>
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 411ef00b4..5b1603aaf 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, Expression, Flatten, Literal}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -404,6 +404,72 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] {
   }
 }
 
+object CometGetArrayItem extends CometExpressionSerde[GetArrayItem] {
+  override def convert(
+      expr: GetArrayItem,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+    val ordinalExpr = exprToProtoInternal(expr.ordinal, inputs, binding)
+
+    if (childExpr.isDefined && ordinalExpr.isDefined) {
+      val listExtractBuilder = ExprOuterClass.ListExtract
+        .newBuilder()
+        .setChild(childExpr.get)
+        .setOrdinal(ordinalExpr.get)
+        .setOneBased(false)
+        .setFailOnError(expr.failOnError)
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setListExtract(listExtractBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for GetArrayItem", expr.child, expr.ordinal)
+      None
+    }
+  }
+}
+
+object CometElementAt extends CometExpressionSerde[ElementAt] {
+
+  override def convert(
+      expr: ElementAt,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.left, inputs, binding)
+    val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding)
+    val defaultExpr = expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding))
+
+    if (!expr.left.dataType.isInstanceOf[ArrayType]) {
+      withInfo(expr, "Input is not an array")
+      return None
+    }
+
+    if (childExpr.isDefined && ordinalExpr.isDefined &&
+      defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined) {
+      val arrayExtractBuilder = ExprOuterClass.ListExtract
+        .newBuilder()
+        .setChild(childExpr.get)
+        .setOrdinal(ordinalExpr.get)
+        .setOneBased(true)
+        .setFailOnError(expr.failOnError)
+
+      defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setListExtract(arrayExtractBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for ElementAt", expr.left, expr.right)
+      None
+    }
+  }
+}
+
 object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase {
 
   override def convert(
diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala
new file mode 100644
index 000000000..1c25d87bb
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, StructsToJson}
+import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType}
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal
+
+object CometCreateNamedStruct extends CometExpressionSerde[CreateNamedStruct] {
+  override def convert(
+      expr: CreateNamedStruct,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.names.length != expr.names.distinct.length) {
+      withInfo(expr, "CreateNamedStruct with duplicate field names are not supported")
+      return None
+    }
+
+    val valExprs = expr.valExprs.map(exprToProtoInternal(_, inputs, binding))
+
+    if (valExprs.forall(_.isDefined)) {
+      val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
+      structBuilder.addAllValues(valExprs.map(_.get).asJava)
+      structBuilder.addAllNames(expr.names.map(_.toString).asJava)
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setCreateNamedStruct(structBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for CreateNamedStruct", expr.valExprs: _*)
+      None
+    }
+
+  }
+}
+
+object CometGetStructField extends CometExpressionSerde[GetStructField] {
+  override def convert(
+      expr: GetStructField,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    exprToProtoInternal(expr.child, inputs, binding).map { childExpr =>
+      val getStructFieldBuilder = ExprOuterClass.GetStructField
+        .newBuilder()
+        .setChild(childExpr)
+        .setOrdinal(expr.ordinal)
+
+      ExprOuterClass.Expr
+        .newBuilder()
+        .setGetStructField(getStructFieldBuilder)
+        .build()
+    }
+  }
+}
+
+object CometGetArrayStructFields extends CometExpressionSerde[GetArrayStructFields] {
+  override def convert(
+      expr: GetArrayStructFields,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+
+    if (childExpr.isDefined) {
+      val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+        .newBuilder()
+        .setChild(childExpr.get)
+        .setOrdinal(expr.ordinal)
+
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setGetArrayStructFields(arrayStructFieldsBuilder)
+          .build())
+    } else {
+      withInfo(expr, "unsupported arguments for GetArrayStructFields", expr.child)
+      None
+    }
+  }
+}
+
+object CometStructsToJson extends CometExpressionSerde[StructsToJson] {
+
+  override def convert(
+      expr: StructsToJson,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (expr.options.nonEmpty) {
+      withInfo(expr, "StructsToJson with options is not supported")
+      None
+    } else {
+
+      def isSupportedType(dt: DataType): Boolean = {
+        dt match {
+          case StructType(fields) =>
+            fields.forall(f => isSupportedType(f.dataType))
+          case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
+              DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
+              DataTypes.DoubleType | DataTypes.StringType =>
+            true
+          case DataTypes.DateType | DataTypes.TimestampType =>
+            // TODO implement these types with tests for formatting options and timezone
+            false
+          case _: MapType | _: ArrayType =>
+            // Spark supports map and array in StructsToJson but this is not yet
+            // implemented in Comet
+            false
+          case _ => false
+        }
+      }
+
+      val isSupported = expr.child.dataType match {
+        case s: StructType =>
+          s.fields.forall(f => isSupportedType(f.dataType))
+        case _: MapType | _: ArrayType =>
+          // Spark supports map and array in StructsToJson but this is not yet
+          // implemented in Comet
+          false
+        case _ =>
+          false
+      }
+
+      if (isSupported) {
+        exprToProtoInternal(expr.child, inputs, binding) match {
+          case Some(p) =>
+            val toJson = ExprOuterClass.ToJson
+              .newBuilder()
+              .setChild(p)
+              .setTimezone(expr.timeZoneId.getOrElse("UTC"))
+              .setIgnoreNullFields(true)
+              .build()
+            Some(
+              ExprOuterClass.Expr
+                .newBuilder()
+                .setToJson(toJson)
+                .build())
+          case _ =>
+            withInfo(expr, expr.child)
+            None
+        }
+      } else {
+        withInfo(expr, "Unsupported data type", expr.child)
+        None
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to