This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 7e0ff1a46 Chore: Refactor serde for math expressions (#2259)
7e0ff1a46 is described below

commit 7e0ff1a468162fd768471d6d822c1fbd7c98daef
Author: Kazantsev Maksim <kazantsev....@yandex.ru>
AuthorDate: Fri Aug 29 12:56:23 2025 -0700

    Chore: Refactor serde for math expressions (#2259)
    
    * Maths expr refactor
    
    * Fix
    
    * Format
    
    ---------
    
    Co-authored-by: Kazantsev Maksim <mn.kazant...@gmail.com>
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    |  79 ++------------
 .../main/scala/org/apache/comet/serde/math.scala   | 120 +++++++++++++++++++++
 2 files changed, 128 insertions(+), 71 deletions(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index bef7e15e1..22bd6fd03 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -170,7 +170,14 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[DateSub] -> CometDateSub,
     classOf[TruncDate] -> CometTruncDate,
     classOf[TruncTimestamp] -> CometTruncTimestamp,
-    classOf[Flatten] -> CometFlatten)
+    classOf[Flatten] -> CometFlatten,
+    classOf[Atan2] -> CometAtan2,
+    classOf[Ceil] -> CometCeil,
+    classOf[Floor] -> CometFloor,
+    classOf[Log] -> CometLog,
+    classOf[Log10] -> CometLog10,
+    classOf[Log2] -> CometLog2,
+    classOf[Pow] -> CometScalarFunction[Pow]("pow"))
 
   /**
    * Mapping of Spark aggregate expression class to Comet expression handler.
@@ -1108,12 +1115,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
 //            None
 //          }
 
-      case Atan2(left, right) =>
-        val leftExpr = exprToProtoInternal(left, inputs, binding)
-        val rightExpr = exprToProtoInternal(right, inputs, binding)
-        val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr)
-        optExprWithInfo(optExpr, expr, left, right)
-
       case Hex(child) =>
         val childExpr = exprToProtoInternal(child, inputs, binding)
         val optExpr =
@@ -1131,56 +1132,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
           scalarFunctionExprToProtoWithReturnType("unhex", e.dataType, 
childExpr, failOnErrorExpr)
         optExprWithInfo(optExpr, expr, unHex._1)
 
-      case e @ Ceil(child) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-        child.dataType match {
-          case t: DecimalType if t.scale == 0 => // zero scale is no-op
-            childExpr
-          case t: DecimalType if t.scale < 0 => // Spark disallows negative 
scale SPARK-30252
-            withInfo(e, s"Decimal type $t has negative scale")
-            None
-          case _ =>
-            val optExpr = scalarFunctionExprToProtoWithReturnType("ceil", 
e.dataType, childExpr)
-            optExprWithInfo(optExpr, expr, child)
-        }
-
-      case e @ Floor(child) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-        child.dataType match {
-          case t: DecimalType if t.scale == 0 => // zero scale is no-op
-            childExpr
-          case t: DecimalType if t.scale < 0 => // Spark disallows negative 
scale SPARK-30252
-            withInfo(e, s"Decimal type $t has negative scale")
-            None
-          case _ =>
-            val optExpr = scalarFunctionExprToProtoWithReturnType("floor", 
e.dataType, childExpr)
-            optExprWithInfo(optExpr, expr, child)
-        }
-
-      // The expression for `log` functions is defined as null on numbers less than or equal
-      // to 0. This matches Spark and Hive behavior, where non positive values eval to null
-      // instead of NaN or -Infinity.
-      case Log(child) =>
-        val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, 
binding)
-        val optExpr = scalarFunctionExprToProto("ln", childExpr)
-        optExprWithInfo(optExpr, expr, child)
-
-      case Log10(child) =>
-        val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, 
binding)
-        val optExpr = scalarFunctionExprToProto("log10", childExpr)
-        optExprWithInfo(optExpr, expr, child)
-
-      case Log2(child) =>
-        val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, 
binding)
-        val optExpr = scalarFunctionExprToProto("log2", childExpr)
-        optExprWithInfo(optExpr, expr, child)
-
-      case Pow(left, right) =>
-        val leftExpr = exprToProtoInternal(left, inputs, binding)
-        val rightExpr = exprToProtoInternal(right, inputs, binding)
-        val optExpr = scalarFunctionExprToProto("pow", leftExpr, rightExpr)
-        optExprWithInfo(optExpr, expr, left, right)
-
       case RegExpReplace(subject, pattern, replacement, startPosition) =>
         if (!RegExp.isSupportedPattern(pattern.toString) &&
           !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
@@ -1265,15 +1216,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
-      case BitwiseAnd(left, right) =>
-        createBinaryExpr(
-          expr,
-          left,
-          right,
-          inputs,
-          binding,
-          (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
-
       case n @ Not(In(_, _)) =>
         CometNotIn.convert(n, inputs, binding)
 
@@ -1611,11 +1553,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
     Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
   }
 
-  private def nullIfNegative(expression: Expression): Expression = {
-    val zero = Literal.default(expression.dataType)
-    If(LessThanOrEqual(expression, zero), Literal.create(null, 
expression.dataType), expression)
-  }
-
   /**
   * Returns true if given datatype is supported as a key in DataFusion sort merge join.
    */
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala
new file mode 100644
index 000000000..700b9bd44
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, 
Expression, Floor, If, LessThanOrEqual, Literal, Log, Log10, Log2}
+import org.apache.spark.sql.types.DecimalType
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, 
optExprWithInfo, scalarFunctionExprToProto, 
scalarFunctionExprToProtoWithReturnType}
+
+object CometAtan2 extends CometExpressionSerde[Atan2] {
+  override def convert(
+      expr: Atan2,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
+    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
+    val optExpr = scalarFunctionExprToProto("atan2", leftExpr, rightExpr)
+    optExprWithInfo(optExpr, expr, expr.left, expr.right)
+  }
+}
+
+object CometCeil extends CometExpressionSerde[Ceil] {
+  override def convert(
+      expr: Ceil,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+    expr.child.dataType match {
+      case t: DecimalType if t.scale == 0 => // zero scale is no-op
+        childExpr
+      case t: DecimalType if t.scale < 0 => // Spark disallows negative scale 
SPARK-30252
+        withInfo(expr, s"Decimal type $t has negative scale")
+        None
+      case _ =>
+        val optExpr = scalarFunctionExprToProtoWithReturnType("ceil", 
expr.dataType, childExpr)
+        optExprWithInfo(optExpr, expr, expr.child)
+    }
+  }
+}
+
+object CometFloor extends CometExpressionSerde[Floor] {
+  override def convert(
+      expr: Floor,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+    expr.child.dataType match {
+      case t: DecimalType if t.scale == 0 => // zero scale is no-op
+        childExpr
+      case t: DecimalType if t.scale < 0 => // Spark disallows negative scale 
SPARK-30252
+        withInfo(expr, s"Decimal type $t has negative scale")
+        None
+      case _ =>
+        val optExpr = scalarFunctionExprToProtoWithReturnType("floor", 
expr.dataType, childExpr)
+        optExprWithInfo(optExpr, expr, expr.child)
+    }
+  }
+}
+
+// The expression for `log` functions is defined as null on numbers less than or equal
+// to 0. This matches Spark and Hive behavior, where non positive values eval to null
+// instead of NaN or -Infinity.
+object CometLog extends CometExpressionSerde[Log] with MathExprBase {
+  override def convert(
+      expr: Log,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(nullIfNegative(expr.child), inputs, 
binding)
+    val optExpr = scalarFunctionExprToProto("ln", childExpr)
+    optExprWithInfo(optExpr, expr, expr.child)
+  }
+}
+
+object CometLog10 extends CometExpressionSerde[Log10] with MathExprBase {
+  override def convert(
+      expr: Log10,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(nullIfNegative(expr.child), inputs, 
binding)
+    val optExpr = scalarFunctionExprToProto("log10", childExpr)
+    optExprWithInfo(optExpr, expr, expr.child)
+  }
+}
+
+object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase {
+  override def convert(
+      expr: Log2,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(nullIfNegative(expr.child), inputs, 
binding)
+    val optExpr = scalarFunctionExprToProto("log2", childExpr)
+    optExprWithInfo(optExpr, expr, expr.child)
+
+  }
+}
+
+sealed trait MathExprBase {
+  protected def nullIfNegative(expression: Expression): Expression = {
+    val zero = Literal.default(expression.dataType)
+    If(LessThanOrEqual(expression, zero), Literal.create(null, 
expression.dataType), expression)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to