This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 46057a7b1 chore: Refactor serde for `CheckOverflow` (#2537)
46057a7b1 is described below

commit 46057a7b1ace736e8be2dec39c27d3b9eadb6551
Author: Andy Grove <[email protected]>
AuthorDate: Thu Oct 9 10:32:32 2025 -0600

    chore: Refactor serde for `CheckOverflow` (#2537)
---
 .github/workflows/pr_build_linux.yml               |  1 +
 .github/workflows/pr_build_macos.yml               |  1 +
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 24 +--------
 .../main/scala/org/apache/comet/serde/math.scala   | 42 ++++++++++++++-
 .../org/apache/comet/CometFuzzMathSuite.scala      | 59 ++++++++++++++++++++++
 5 files changed, 102 insertions(+), 25 deletions(-)

diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 9b918ad8b..208bf87f6 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -103,6 +103,7 @@ jobs:
             value: |
               org.apache.comet.CometFuzzTestSuite
               org.apache.comet.CometFuzzAggregateSuite
+              org.apache.comet.CometFuzzMathSuite
               org.apache.comet.DataGeneratorSuite
           - name: "shuffle"
             value: |
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index fb6a8295b..465533834 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -68,6 +68,7 @@ jobs:
             value: |
               org.apache.comet.CometFuzzTestSuite
               org.apache.comet.CometFuzzAggregateSuite
+              org.apache.comet.CometFuzzMathSuite
               org.apache.comet.DataGeneratorSuite
           - name: "shuffle"
             value: |
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 43f8b7293..0a4b61fce 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -221,7 +221,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
   private val miscExpressions: Map[Class[_ <: Expression], 
CometExpressionSerde[_]] = Map(
     // TODO SortOrder (?)
     // TODO PromotePrecision
-    // TODO CheckOverflow
     // TODO KnownFloatingPointNormalized
     // TODO ScalarSubquery
     // TODO UnscaledValue
@@ -230,6 +229,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     // TODO RegExpReplace
     classOf[Alias] -> CometAlias,
     classOf[AttributeReference] -> CometAttributeReference,
+    classOf[CheckOverflow] -> CometCheckOverflow,
     classOf[Coalesce] -> CometCoalesce,
     classOf[Literal] -> CometLiteral,
     classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId,
@@ -772,28 +772,6 @@ object QueryPlanSerde extends Logging with CometExprShim {
         // `PromotePrecision` is just a wrapper, don't need to serialize it.
         exprToProtoInternal(child, inputs, binding)
 
-      case CheckOverflow(child, dt, nullOnOverflow) =>
-        val childExpr = exprToProtoInternal(child, inputs, binding)
-
-        if (childExpr.isDefined) {
-          val builder = ExprOuterClass.CheckOverflow.newBuilder()
-          builder.setChild(childExpr.get)
-          builder.setFailOnError(!nullOnOverflow)
-
-          // `dataType` must be decimal type
-          val dataType = serializeDataType(dt)
-          builder.setDatatype(dataType.get)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setCheckOverflow(builder)
-              .build())
-        } else {
-          withInfo(expr, child)
-          None
-        }
-
       case RegExpReplace(subject, pattern, replacement, startPosition) =>
         if (!RegExp.isSupportedPattern(pattern.toString) &&
           !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala
index 8bb27153f..f2b0010d8 100644
--- a/spark/src/main/scala/org/apache/comet/serde/math.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -19,11 +19,11 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, 
Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Unhex}
+import org.apache.spark.sql.catalyst.expressions.{Atan2, Attribute, Ceil, 
CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, 
Log10, Log2, Unhex}
 import org.apache.spark.sql.types.DecimalType
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, 
optExprWithInfo, scalarFunctionExprToProto, 
scalarFunctionExprToProtoWithReturnType}
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, 
optExprWithInfo, scalarFunctionExprToProto, 
scalarFunctionExprToProtoWithReturnType, serializeDataType}
 
 object CometAtan2 extends CometExpressionSerde[Atan2] {
   override def convert(
@@ -143,3 +143,41 @@ sealed trait MathExprBase {
     If(LessThanOrEqual(expression, zero), Literal.create(null, 
expression.dataType), expression)
   }
 }
+
+object CometCheckOverflow extends CometExpressionSerde[CheckOverflow] {
+
+  /** Only decimal target types are supported; `convert` relies on this precondition. */
+  override def getSupportLevel(expr: CheckOverflow): SupportLevel = expr.dataType match {
+    case _: DecimalType => Compatible()
+    case _ => Unsupported(Some("dataType must be DecimalType"))
+  }
+
+  /**
+   * Serializes Spark's `CheckOverflow` into the protobuf `CheckOverflow` expression, with
+   * `failOnError` set to the inverse of Spark's `nullOnOverflow`.
+   *
+   * Returns `None` and records the reason on `expr` via `withInfo` when either the child
+   * or the target data type cannot be serialized, instead of throwing on `Option.get`.
+   */
+  override def convert(
+      expr: CheckOverflow,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
+    (childExpr, serializeDataType(expr.dataType)) match {
+      case (Some(child), Some(dataType)) =>
+        val builder = ExprOuterClass.CheckOverflow
+          .newBuilder()
+          .setChild(child)
+          .setFailOnError(!expr.nullOnOverflow)
+          .setDatatype(dataType)
+        Some(ExprOuterClass.Expr.newBuilder().setCheckOverflow(builder).build())
+      case (None, _) =>
+        withInfo(expr, expr.child)
+        None
+      case (_, None) =>
+        withInfo(expr, s"Unsupported data type: ${expr.dataType}")
+        None
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometFuzzMathSuite.scala b/spark/src/test/scala/org/apache/comet/CometFuzzMathSuite.scala
new file mode 100644
index 000000000..d48712ed5
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometFuzzMathSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}
+
+class CometFuzzMathSuite extends CometFuzzTestBase {
+
+  private val mathOps = Seq("+", "-", "*", "/", "div")
+
+  /** Query `cols(0) <op> cols(1)` from t1, ordered by both inputs so the LIMIT is stable. */
+  private def mathQuery(op: String, cols: Array[String]): String = {
+    assert(cols.length >= 2, s"need two columns, found: ${cols.mkString(", ")}")
+    s"SELECT ${cols(0)} $op ${cols(1)} FROM t1 ORDER BY ${cols(0)}, ${cols(1)} LIMIT 500"
+  }
+
+  for (op <- mathOps) {
+    test(s"integer math: $op") {
+      val df = spark.read.parquet(filename)
+      val cols = df.schema.fields.collect {
+        case f if f.dataType == IntegerType || f.dataType == LongType => f.name
+      }
+      df.createOrReplaceTempView("t1")
+      if (op == "div") {
+        // Integer `div` involves cast(cast(_ as bigint) as decimal(19,0)), which is not fully
+        // compatible with Spark (no overflow check), so only the query results are compared.
+        checkSparkAnswer(mathQuery(op, cols))
+      } else {
+        checkSparkAnswerAndOperator(mathQuery(op, cols))
+      }
+    }
+  }
+
+  for (op <- mathOps) {
+    test(s"decimal math: $op") {
+      val df = spark.read.parquet(filename)
+      val cols = df.schema.fields.filter(_.dataType.isInstanceOf[DecimalType]).map(_.name)
+      df.createOrReplaceTempView("t1")
+      checkSparkAnswerAndOperator(mathQuery(op, cols))
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to