This is an automated email from the ASF dual-hosted git repository. agrove pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 0dc71c091 chore: Refactor serde for more array and struct expressions (#2257) 0dc71c091 is described below commit 0dc71c091d1c10ec5480235ec3acc28e3522fe89 Author: Andy Grove <agr...@apache.org> AuthorDate: Fri Aug 29 16:59:18 2025 -0600 chore: Refactor serde for more array and struct expressions (#2257) --- .../org/apache/comet/serde/QueryPlanSerde.scala | 170 +-------------------- .../main/scala/org/apache/comet/serde/arrays.scala | 68 ++++++++- .../scala/org/apache/comet/serde/structs.scala | 169 ++++++++++++++++++++ 3 files changed, 242 insertions(+), 165 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 22bd6fd03..ad9be300f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -95,6 +95,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ArraysOverlap] -> CometArraysOverlap, classOf[ArrayUnion] -> CometArrayUnion, classOf[CreateArray] -> CometCreateArray, + classOf[GetArrayItem] -> CometGetArrayItem, + classOf[ElementAt] -> CometElementAt, classOf[Ascii] -> CometScalarFunction("ascii"), classOf[ConcatWs] -> CometScalarFunction("concat_ws"), classOf[Chr] -> CometScalarFunction("char"), @@ -170,6 +172,10 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[DateSub] -> CometDateSub, classOf[TruncDate] -> CometTruncDate, classOf[TruncTimestamp] -> CometTruncTimestamp, + classOf[CreateNamedStruct] -> CometCreateNamedStruct, + classOf[GetStructField] -> CometGetStructField, + classOf[GetArrayStructFields] -> CometGetArrayStructFields, + classOf[StructsToJson] -> CometStructsToJson, classOf[Flatten] -> CometFlatten, classOf[Atan2] -> CometAtan2, classOf[Ceil] -> CometCeil, @@ -922,66 +928,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case StructsToJson(options, child, timezoneId) => - if (options.nonEmpty) { - withInfo(expr, "StructsToJson with options is not supported") - None - } else { - - def isSupportedType(dt: DataType): Boolean = { - dt match { - case StructType(fields) => - fields.forall(f => isSupportedType(f.dataType)) - case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType | - DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | - DataTypes.DoubleType | DataTypes.StringType => - true - case DataTypes.DateType | DataTypes.TimestampType => - // TODO implement these types with tests for formatting options and timezone - false - case _: MapType | _: ArrayType => - // Spark supports map and array in StructsToJson but this is not yet - // implemented in Comet - false - case _ => false - } - } - - val isSupported = child.dataType match { - case s: StructType => - s.fields.forall(f => isSupportedType(f.dataType)) - case _: MapType | _: ArrayType => - // Spark supports map and array in StructsToJson but this is not yet - // implemented in Comet - false - case _ => - false - } - - if (isSupported) { - exprToProtoInternal(child, inputs, binding) match { - case Some(p) => - val toJson = ExprOuterClass.ToJson - .newBuilder() - .setChild(p) - .setTimezone(timezoneId.getOrElse("UTC")) - .setIgnoreNullFields(true) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setToJson(toJson) - .build()) - case _ => - withInfo(expr, child) - None - } - } else { - withInfo(expr, "Unsupported data type", child) - None - } - } - case SortOrder(child, direction, nullOrdering, _) => val childExpr = exprToProtoInternal(child, inputs, binding) @@ -1336,110 +1282,6 @@ object QueryPlanSerde extends Logging with CometExprShim { withInfo(expr, bloomFilter, value) None } - - case struct @ CreateNamedStruct(_) => - if (struct.names.length != struct.names.distinct.length) { - withInfo(expr, "CreateNamedStruct with duplicate field names are not supported") - return None - } - - val valExprs = struct.valExprs.map(exprToProtoInternal(_, inputs, binding)) - - if (valExprs.forall(_.isDefined)) { - val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder() - structBuilder.addAllValues(valExprs.map(_.get).asJava) - structBuilder.addAllNames(struct.names.map(_.toString).asJava) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setCreateNamedStruct(structBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for CreateNamedStruct", struct.valExprs: _*) - None - } - - case GetStructField(child, ordinal, _) => - exprToProtoInternal(child, inputs, binding).map { childExpr => - val getStructFieldBuilder = ExprOuterClass.GetStructField - .newBuilder() - .setChild(childExpr) - .setOrdinal(ordinal) - - ExprOuterClass.Expr - .newBuilder() - .setGetStructField(getStructFieldBuilder) - .build() - } - - case GetArrayItem(child, ordinal, failOnError) => - val childExpr = exprToProtoInternal(child, inputs, binding) - val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) - - if (childExpr.isDefined && ordinalExpr.isDefined) { - val listExtractBuilder = ExprOuterClass.ListExtract - .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinalExpr.get) - .setOneBased(false) - .setFailOnError(failOnError) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setListExtract(listExtractBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) - None - } - - case ElementAt(child, ordinal, defaultValue, failOnError) - if child.dataType.isInstanceOf[ArrayType] => - val childExpr = exprToProtoInternal(child, inputs, binding) - val ordinalExpr = exprToProtoInternal(ordinal, inputs, binding) - val defaultExpr = defaultValue.flatMap(exprToProtoInternal(_, inputs, binding)) - - if (childExpr.isDefined && ordinalExpr.isDefined && - defaultExpr.isDefined == defaultValue.isDefined) { - val arrayExtractBuilder = ExprOuterClass.ListExtract - .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinalExpr.get) - .setOneBased(true) - .setFailOnError(failOnError) - - defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_)) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setListExtract(arrayExtractBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for ElementAt", child, ordinal) - None - } - - case GetArrayStructFields(child, _, ordinal, _, _) => - val childExpr = exprToProtoInternal(child, inputs, binding) - - if (childExpr.isDefined) { - val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields - .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinal) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setGetArrayStructFields(arrayStructFieldsBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for GetArrayStructFields", child) - None - } case af @ ArrayFilter(_, func) if func.children.head.isInstanceOf[IsNotNull] => convert(af, CometArrayCompact) case expr => diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 411ef00b4..5b1603aaf 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, Expression, Flatten, Literal} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, Literal} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -404,6 +404,72 @@ object CometCreateArray extends CometExpressionSerde[CreateArray] { } } +object CometGetArrayItem extends CometExpressionSerde[GetArrayItem] { + override def convert( + expr: GetArrayItem, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + val ordinalExpr = exprToProtoInternal(expr.ordinal, inputs, binding) + + if (childExpr.isDefined && ordinalExpr.isDefined) { + val listExtractBuilder = ExprOuterClass.ListExtract + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(false) + .setFailOnError(expr.failOnError) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setListExtract(listExtractBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for GetArrayItem", expr.child, expr.ordinal) + None + } + } +} + +object CometElementAt extends CometExpressionSerde[ElementAt] { + + override def convert( + expr: ElementAt, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.left, inputs, binding) + val ordinalExpr = exprToProtoInternal(expr.right, inputs, binding) + val defaultExpr = expr.defaultValueOutOfBound.flatMap(exprToProtoInternal(_, inputs, binding)) + + if (!expr.left.dataType.isInstanceOf[ArrayType]) { + withInfo(expr, "Input is not an array") + return None + } + + if (childExpr.isDefined && ordinalExpr.isDefined && + defaultExpr.isDefined == expr.defaultValueOutOfBound.isDefined) { + val arrayExtractBuilder = ExprOuterClass.ListExtract + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(true) + .setFailOnError(expr.failOnError) + + defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_)) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setListExtract(arrayExtractBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for ElementAt", expr.left, expr.right) + None + } + } +} + object CometFlatten extends CometExpressionSerde[Flatten] with ArraysBase { override def convert( diff --git a/spark/src/main/scala/org/apache/comet/serde/structs.scala b/spark/src/main/scala/org/apache/comet/serde/structs.scala new file mode 100644 index 000000000..1c25d87bb --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/structs.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, GetArrayStructFields, GetStructField, StructsToJson} +import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, MapType, StructType} + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal + +object CometCreateNamedStruct extends CometExpressionSerde[CreateNamedStruct] { + override def convert( + expr: CreateNamedStruct, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + if (expr.names.length != expr.names.distinct.length) { + withInfo(expr, "CreateNamedStruct with duplicate field names are not supported") + return None + } + + val valExprs = expr.valExprs.map(exprToProtoInternal(_, inputs, binding)) + + if (valExprs.forall(_.isDefined)) { + val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder() + structBuilder.addAllValues(valExprs.map(_.get).asJava) + structBuilder.addAllNames(expr.names.map(_.toString).asJava) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setCreateNamedStruct(structBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for CreateNamedStruct", expr.valExprs: _*) + None + } + + } +} + +object CometGetStructField extends CometExpressionSerde[GetStructField] { + override def convert( + expr: GetStructField, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + exprToProtoInternal(expr.child, inputs, binding).map { childExpr => + val getStructFieldBuilder = ExprOuterClass.GetStructField + .newBuilder() + .setChild(childExpr) + .setOrdinal(expr.ordinal) + + ExprOuterClass.Expr + .newBuilder() + .setGetStructField(getStructFieldBuilder) + .build() + } + } +} + +object CometGetArrayStructFields extends CometExpressionSerde[GetArrayStructFields] { + override def convert( + expr: GetArrayStructFields, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + + if (childExpr.isDefined) { + val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(expr.ordinal) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setGetArrayStructFields(arrayStructFieldsBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for GetArrayStructFields", expr.child) + None + } + } +} + +object CometStructsToJson extends CometExpressionSerde[StructsToJson] { + + override def convert( + expr: StructsToJson, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + if (expr.options.nonEmpty) { + withInfo(expr, "StructsToJson with options is not supported") + None + } else { + + def isSupportedType(dt: DataType): Boolean = { + dt match { + case StructType(fields) => + fields.forall(f => isSupportedType(f.dataType)) + case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType | + DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | + DataTypes.DoubleType | DataTypes.StringType => + true + case DataTypes.DateType | DataTypes.TimestampType => + // TODO implement these types with tests for formatting options and timezone + false + case _: MapType | _: ArrayType => + // Spark supports map and array in StructsToJson but this is not yet + // implemented in Comet + false + case _ => false + } + } + + val isSupported = expr.child.dataType match { + case s: StructType => + s.fields.forall(f => isSupportedType(f.dataType)) + case _: MapType | _: ArrayType => + // Spark supports map and array in StructsToJson but this is not yet + // implemented in Comet + false + case _ => + false + } + + if (isSupported) { + exprToProtoInternal(expr.child, inputs, binding) match { + case Some(p) => + val toJson = ExprOuterClass.ToJson + .newBuilder() + .setChild(p) + .setTimezone(expr.timeZoneId.getOrElse("UTC")) + .setIgnoreNullFields(true) + .build() + Some( + ExprOuterClass.Expr + .newBuilder() + .setToJson(toJson) + .build()) + case _ => + withInfo(expr, expr.child) + None + } + } else { + withInfo(expr, "Unsupported data type", expr.child) + None + } + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org