This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 83d951d1c Chore: Used DataFusion impl of date_add and date_sub functions (#2473)
83d951d1c is described below

commit 83d951d1c1a89eb34fc1dc5aad3c8bfa00cc063c
Author: Kazantsev Maksim <[email protected]>
AuthorDate: Tue Sep 30 18:17:39 2025 -0700

    Chore: Used DataFusion impl of date_add and date_sub functions (#2473)
    
    * Date_add and date_sub to DataFusion impl
    
    * Fix tests
    
    ---------
    
    Co-authored-by: Kazantsev Maksim <[email protected]>
---
 native/core/src/execution/jni_api.rs               |   4 +
 native/spark-expr/src/comet_scalar_funcs.rs        |  16 +---
 .../src/datetime_funcs/date_arithmetic.rs          | 101 ---------------------
 native/spark-expr/src/datetime_funcs/mod.rs        |   2 -
 native/spark-expr/src/lib.rs                       |   5 +-
 .../scala/org/apache/comet/serde/datetime.scala    |  28 +-----
 .../org/apache/comet/CometExpressionSuite.scala    |   5 +-
 7 files changed, 15 insertions(+), 146 deletions(-)

diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 1f9a4263f..52b8eb6a3 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -41,6 +41,8 @@ use datafusion::{
 };
 use datafusion_comet_proto::spark_operator::Operator;
 use datafusion_spark::function::bitwise::bit_get::SparkBitGet;
+use datafusion_spark::function::datetime::date_add::SparkDateAdd;
+use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::hash::sha2::SparkSha2;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::string::char::CharFunc;
@@ -303,6 +305,8 @@ fn prepare_datafusion_session_context(
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));
 
     // Must be the last one to override existing functions with the same name
     datafusion_comet_spark_expr::register_all_comet_functions(&mut session_ctx)?;
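
The two added lines use DataFusion's standard UDF registration mechanism. A
minimal standalone sketch of the same pattern, assuming the datafusion and
datafusion-spark crates as dependencies (the helper name
register_spark_date_udfs is illustrative, not part of this commit):

    use datafusion::execution::context::SessionContext;
    use datafusion::logical_expr::ScalarUDF;
    use datafusion_spark::function::datetime::date_add::SparkDateAdd;
    use datafusion_spark::function::datetime::date_sub::SparkDateSub;

    // Each Spark-compatible function implements ScalarUDFImpl; wrapping it in
    // ScalarUDF::new_from_impl makes it callable by name from SQL and plans.
    fn register_spark_date_udfs(ctx: &SessionContext) {
        ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
        ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));
    }

Note the ordering constraint already visible in this hunk: the call to
register_all_comet_functions stays last so Comet's own functions override any
same-named ones.
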
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 4bf1cd45d..4b863927e 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -19,10 +19,10 @@ use crate::hash_funcs::*;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
-    spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
-    spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value, EvalMode,
-    SparkBitwiseCount, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
+    spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
+    spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkBitwiseNot,
+    SparkDateTrunc, SparkStringSpace,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -166,14 +166,6 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(spark_isnan);
             make_comet_scalar_udf!("isnan", func, without data_type)
         }
-        "date_add" => {
-            let func = Arc::new(spark_date_add);
-            make_comet_scalar_udf!("date_add", func, without data_type)
-        }
-        "date_sub" => {
-            let func = Arc::new(spark_date_sub);
-            make_comet_scalar_udf!("date_sub", func, without data_type)
-        }
         "array_repeat" => {
             let func = Arc::new(spark_array_repeat);
             make_comet_scalar_udf!("array_repeat", func, without data_type)
diff --git a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs b/native/spark-expr/src/datetime_funcs/date_arithmetic.rs
deleted file mode 100644
index 4b4db2eb5..000000000
--- a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs
+++ /dev/null
@@ -1,101 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use arrow::array::builder::IntervalDayTimeBuilder;
-use arrow::array::types::{Int16Type, Int32Type, Int8Type};
-use arrow::array::{Array, Datum};
-use arrow::array::{ArrayRef, AsArray};
-use arrow::compute::kernels::numeric::{add, sub};
-use arrow::datatypes::DataType;
-use arrow::datatypes::IntervalDayTime;
-use arrow::error::ArrowError;
-use datafusion::common::{DataFusionError, ScalarValue};
-use datafusion::physical_expr_common::datum;
-use datafusion::physical_plan::ColumnarValue;
-use std::sync::Arc;
-
-macro_rules! scalar_date_arithmetic {
-    ($start:expr, $days:expr, $op:expr) => {{
-        let interval = IntervalDayTime::new(*$days as i32, 0);
-        let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
-        datum::apply($start, &interval_cv, $op)
-    }};
-}
-macro_rules! array_date_arithmetic {
-    ($days:expr, $interval_builder:expr, $intType:ty) => {{
-        for day in $days.as_primitive::<$intType>().into_iter() {
-            if let Some(non_null_day) = day {
-                $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
-            } else {
-                $interval_builder.append_null();
-            }
-        }
-    }};
-}
-
-/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second
-/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the
-/// second argument and use DataFusion's interface to apply Arrow's operators.
-fn spark_date_arithmetic(
-    args: &[ColumnarValue],
-    op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
-) -> Result<ColumnarValue, DataFusionError> {
-    let start = &args[0];
-    match &args[1] {
-        ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
-            scalar_date_arithmetic!(start, days, op)
-        }
-        ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
-            scalar_date_arithmetic!(start, days, op)
-        }
-        ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
-            scalar_date_arithmetic!(start, days, op)
-        }
-        ColumnarValue::Array(days) => {
-            let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len());
-            match days.data_type() {
-                DataType::Int8 => {
-                    array_date_arithmetic!(days, interval_builder, Int8Type)
-                }
-                DataType::Int16 => {
-                    array_date_arithmetic!(days, interval_builder, Int16Type)
-                }
-                DataType::Int32 => {
-                    array_date_arithmetic!(days, interval_builder, Int32Type)
-                }
-                _ => {
-                    return Err(DataFusionError::Internal(format!(
-                        "Unsupported data types {args:?} for date arithmetic.",
-                    )))
-                }
-            }
-            let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish()));
-            datum::apply(start, &interval_cv, op)
-        }
-        _ => Err(DataFusionError::Internal(format!(
-            "Unsupported data types {args:?} for date arithmetic.",
-        ))),
-    }
-}
-
-pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_date_arithmetic(args, add)
-}
-
-pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    spark_date_arithmetic(args, sub)
-}
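
The deleted doc comment captures the trick this file used: Arrow has no
Date32 + Int32 kernel, so each day count was widened to an IntervalDayTime and
applied through Arrow's generic numeric kernels. A self-contained sketch of
that trick, using only the arrow APIs the file imported (values illustrative):

    use arrow::array::{Date32Array, IntervalDayTimeArray};
    use arrow::compute::kernels::numeric::add;
    use arrow::datatypes::IntervalDayTime;
    use arrow::error::ArrowError;

    fn main() -> Result<(), ArrowError> {
        // Date32 stores days since 1970-01-01; None exercises null handling.
        let dates = Date32Array::from(vec![Some(19000), None]);
        // Widen plain day counts into day/millisecond intervals.
        let days: IntervalDayTimeArray = [Some(3), Some(7)]
            .into_iter()
            .map(|d| d.map(|days| IntervalDayTime::new(days, 0)))
            .collect();
        // Date32 + IntervalDayTime is a supported kernel; result is Date32.
        let shifted = add(&dates, &days)?;
        println!("{shifted:?}");
        Ok(())
    }

The datafusion-spark implementations adopted by this commit expose the same
Date32-plus-days behavior as named UDFs, so Comet no longer needs this
hand-rolled version.
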
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs
index 0ca7bb940..ef8041e5f 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -15,12 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-mod date_arithmetic;
 mod date_trunc;
 mod extract_date_part;
 mod timestamp_trunc;
 
-pub use date_arithmetic::{spark_date_add, spark_date_sub};
 pub use date_trunc::SparkDateTrunc;
 pub use extract_date_part::SparkHour;
 pub use extract_date_part::SparkMinute;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 7bdc7ff51..932fcbe53 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -68,10 +68,7 @@ pub use comet_scalar_funcs::{
     create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
     register_all_comet_functions,
 };
-pub use datetime_funcs::{
-    spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
-    TimestampTruncExpr,
-};
+pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond, TimestampTruncExpr};
 pub use error::{SparkError, SparkResult};
 pub use hash_funcs::*;
 pub use json_funcs::ToJson;
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 8e4c92d70..9473ee30e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{DateType, IntegerType}
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.CometGetDateField.CometGetDateField
 import org.apache.comet.serde.ExprOuterClass.Expr
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde._
 
 private object CometGetDateField extends Enumeration {
   type CometGetDateField = Value
@@ -251,31 +251,9 @@ object CometSecond extends CometExpressionSerde[Second] {
   }
 }
 
-object CometDateAdd extends CometExpressionSerde[DateAdd] {
-  override def convert(
-      expr: DateAdd,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
-    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
-    val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_add", DateType, leftExpr, 
rightExpr)
-    optExprWithInfo(optExpr, expr, expr.left, expr.right)
-  }
-}
+object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
 
-object CometDateSub extends CometExpressionSerde[DateSub] {
-  override def convert(
-      expr: DateSub,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
-    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
-    val optExpr =
-      scalarFunctionExprToProtoWithReturnType("date_sub", DateType, leftExpr, 
rightExpr)
-    optExprWithInfo(optExpr, expr, expr.left, expr.right)
-  }
-}
+object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
 
 object CometTruncDate extends CometExpressionSerde[TruncDate] {
   override def convert(
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala 
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index daf0e45cc..07663ea91 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -252,7 +252,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           } else {
             assert(sparkErr.get.getMessage.contains("integer overflow"))
           }
-          assert(cometErr.get.getMessage.contains("`NaiveDate + TimeDelta` overflowed"))
+          assert(cometErr.get.getMessage.contains("attempt to add with overflow"))
         }
       }
     }
@@ -296,10 +296,11 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
             checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM 
tbl"))
           if (isSpark40Plus) {
             assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
+            assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
           } else {
             assert(sparkErr.get.getMessage.contains("integer overflow"))
+            assert(cometErr.get.getMessage.contains("integer overflow"))
           }
-          assert(cometErr.get.getMessage.contains("`NaiveDate - TimeDelta` overflowed"))
         }
       }
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
