This is an automated email from the ASF dual-hosted git repository. agrove pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 7e0ff1a46 Chore: Refactor serde for math expressions (#2259) 7e0ff1a46 is described below commit 7e0ff1a468162fd768471d6d822c1fbd7c98daef Author: Kazantsev Maksim <kazantsev....@yandex.ru> AuthorDate: Fri Aug 29 12:56:23 2025 -0700 Chore: Refactor serde for math expressions (#2259) * Maths expr refactor * Fix * Format --------- Co-authored-by: Kazantsev Maksim <mn.kazant...@gmail.com> --- .../org/apache/comet/serde/QueryPlanSerde.scala | 79 ++------------ .../main/scala/org/apache/comet/serde/math.scala | 120 +++++++++++++++++++++ 2 files changed, 128 insertions(+), 71 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index bef7e15e1..22bd6fd03 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -170,7 +170,14 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[DateSub] -> CometDateSub, classOf[TruncDate] -> CometTruncDate, classOf[TruncTimestamp] -> CometTruncTimestamp, - classOf[Flatten] -> CometFlatten) + classOf[Flatten] -> CometFlatten, + classOf[Atan2] -> CometAtan2, + classOf[Ceil] -> CometCeil, + classOf[Floor] -> CometFloor, + classOf[Log] -> CometLog, + classOf[Log10] -> CometLog10, + classOf[Log2] -> CometLog2, + classOf[Pow] -> CometScalarFunction[Pow]("pow")) /** * Mapping of Spark aggregate expression class to Comet expression handler. @@ -1108,12 +1115,6 @@ object QueryPlanSerde extends Logging with CometExprShim { // None // } - case Atan2(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, left, right) - case Hex(child) => val childExpr = exprToProtoInternal(child, inputs, binding) val optExpr = @@ -1131,56 +1132,6 @@ object QueryPlanSerde extends Logging with CometExprShim { scalarFunctionExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) optExprWithInfo(optExpr, expr, unHex._1) - case e @ Ceil(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - withInfo(e, s"Decimal type $t has negative scale") - None - case _ => - val optExpr = scalarFunctionExprToProtoWithReturnType("ceil", e.dataType, childExpr) - optExprWithInfo(optExpr, expr, child) - } - - case e @ Floor(child) => - val childExpr = exprToProtoInternal(child, inputs, binding) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - withInfo(e, s"Decimal type $t has negative scale") - None - case _ => - val optExpr = scalarFunctionExprToProtoWithReturnType("floor", e.dataType, childExpr) - optExprWithInfo(optExpr, expr, child) - } - - // The expression for `log` functions is defined as null on numbers less than or equal - // to 0. This matches Spark and Hive behavior, where non positive values eval to null - // instead of NaN or -Infinity. - case Log(child) => - val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding) - val optExpr = scalarFunctionExprToProto("ln", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Log10(child) => - val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding) - val optExpr = scalarFunctionExprToProto("log10", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Log2(child) => - val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding) - val optExpr = scalarFunctionExprToProto("log2", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Pow(left, right) => - val leftExpr = exprToProtoInternal(left, inputs, binding) - val rightExpr = exprToProtoInternal(right, inputs, binding) - val optExpr = scalarFunctionExprToProto("pow", leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, left, right) - case RegExpReplace(subject, pattern, replacement, startPosition) => if (!RegExp.isSupportedPattern(pattern.toString) && !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { @@ -1265,15 +1216,6 @@ object QueryPlanSerde extends Logging with CometExprShim { None } - case BitwiseAnd(left, right) => - createBinaryExpr( - expr, - left, - right, - inputs, - binding, - (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr)) - case n @ Not(In(_, _)) => CometNotIn.convert(n, inputs, binding) @@ -1611,11 +1553,6 @@ object QueryPlanSerde extends Logging with CometExprShim { Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) } - private def nullIfNegative(expression: Expression): Expression = { - val zero = Literal.default(expression.dataType) - If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression) - } - /** * Returns true if given datatype is supported as a key in DataFusion sort merge join. */ diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala new file mode 100644 index 000000000..700b9bd44 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, Expression, Floor, If, LessThanOrEqual, Literal, Log, Log10, Log2} +import org.apache.spark.sql.types.DecimalType + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} + +object CometAtan2 extends CometExpressionSerde[Atan2] { + override def convert( + expr: Atan2, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val leftExpr = exprToProtoInternal(expr.left, inputs, binding) + val rightExpr = exprToProtoInternal(expr.right, inputs, binding) + val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, expr.left, expr.right) + } +} + +object CometCeil extends CometExpressionSerde[Ceil] { + override def convert( + expr: Ceil, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + expr.child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + withInfo(expr, s"Decimal type $t has negative scale") + None + case _ => + val optExpr = scalarFunctionExprToProtoWithReturnType("ceil", expr.dataType, childExpr) + optExprWithInfo(optExpr, expr, expr.child) + } + } +} + +object CometFloor extends CometExpressionSerde[Floor] { + override def convert( + expr: Floor, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(expr.child, inputs, binding) + expr.child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + withInfo(expr, s"Decimal type $t has negative scale") + None + case _ => + val optExpr = scalarFunctionExprToProtoWithReturnType("floor", expr.dataType, childExpr) + optExprWithInfo(optExpr, expr, expr.child) + } + } +} + +// The expression for `log` functions is defined as null on numbers less than or equal +// to 0. This matches Spark and Hive behavior, where non positive values eval to null +// instead of NaN or -Infinity. +object CometLog extends CometExpressionSerde[Log] with MathExprBase { + override def convert( + expr: Log, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(nullIfNegative(expr.child), inputs, binding) + val optExpr = scalarFunctionExprToProto("ln", childExpr) + optExprWithInfo(optExpr, expr, expr.child) + } +} + +object CometLog10 extends CometExpressionSerde[Log10] with MathExprBase { + override def convert( + expr: Log10, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(nullIfNegative(expr.child), inputs, binding) + val optExpr = scalarFunctionExprToProto("log10", childExpr) + optExprWithInfo(optExpr, expr, expr.child) + } +} + +object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase { + override def convert( + expr: Log2, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(nullIfNegative(expr.child), inputs, binding) + val optExpr = scalarFunctionExprToProto("log2", childExpr) + optExprWithInfo(optExpr, expr, expr.child) + + } +} + +sealed trait MathExprBase { + protected def nullIfNegative(expression: Expression): Expression = { + val zero = Literal.default(expression.dataType) + If(LessThanOrEqual(expression, zero), Literal.create(null, expression.dataType), expression) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org