This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 83d951d1c Chore: Used DataFusion impl of date_add and date_sub
functions (#2473)
83d951d1c is described below
commit 83d951d1c1a89eb34fc1dc5aad3c8bfa00cc063c
Author: Kazantsev Maksim <[email protected]>
AuthorDate: Tue Sep 30 18:17:39 2025 -0700
Chore: Used DataFusion impl of date_add and date_sub functions (#2473)
* Date_add and date_sub to DataFusion impl
* Fix tests
---------
Co-authored-by: Kazantsev Maksim <[email protected]>
---
native/core/src/execution/jni_api.rs | 4 +
native/spark-expr/src/comet_scalar_funcs.rs | 16 +---
.../src/datetime_funcs/date_arithmetic.rs | 101 ---------------------
native/spark-expr/src/datetime_funcs/mod.rs | 2 -
native/spark-expr/src/lib.rs | 5 +-
.../scala/org/apache/comet/serde/datetime.scala | 28 +-----
.../org/apache/comet/CometExpressionSuite.scala | 5 +-
7 files changed, 15 insertions(+), 146 deletions(-)
diff --git a/native/core/src/execution/jni_api.rs
b/native/core/src/execution/jni_api.rs
index 1f9a4263f..52b8eb6a3 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -41,6 +41,8 @@ use datafusion::{
};
use datafusion_comet_proto::spark_operator::Operator;
use datafusion_spark::function::bitwise::bit_get::SparkBitGet;
+use datafusion_spark::function::datetime::date_add::SparkDateAdd;
+use datafusion_spark::function::datetime::date_sub::SparkDateSub;
use datafusion_spark::function::hash::sha2::SparkSha2;
use datafusion_spark::function::math::expm1::SparkExpm1;
use datafusion_spark::function::string::char::CharFunc;
@@ -303,6 +305,8 @@ fn prepare_datafusion_session_context(
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha2::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(CharFunc::default()));
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitGet::default()));
+        session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateAdd::default()));
+        session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateSub::default()));
// Must be the last one to override existing functions with the same name
datafusion_comet_spark_expr::register_all_comet_functions(&mut
session_ctx)?;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs
b/native/spark-expr/src/comet_scalar_funcs.rs
index 4bf1cd45d..4b863927e 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -19,10 +19,10 @@ use crate::hash_funcs::*;
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div,
checked_mul, checked_sub};
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
- spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub,
spark_decimal_div,
- spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan,
spark_make_decimal,
- spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, EvalMode,
- SparkBitwiseCount, SparkBitwiseNot, SparkDateTrunc, SparkStringSpace,
+ spark_array_repeat, spark_ceil, spark_decimal_div,
spark_decimal_integral_div, spark_floor,
+ spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding,
spark_round, spark_rpad,
+ spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount,
SparkBitwiseNot,
+ SparkDateTrunc, SparkStringSpace,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -166,14 +166,6 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_isnan);
make_comet_scalar_udf!("isnan", func, without data_type)
}
- "date_add" => {
- let func = Arc::new(spark_date_add);
- make_comet_scalar_udf!("date_add", func, without data_type)
- }
- "date_sub" => {
- let func = Arc::new(spark_date_sub);
- make_comet_scalar_udf!("date_sub", func, without data_type)
- }
"array_repeat" => {
let func = Arc::new(spark_array_repeat);
make_comet_scalar_udf!("array_repeat", func, without data_type)
diff --git a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs
b/native/spark-expr/src/datetime_funcs/date_arithmetic.rs
deleted file mode 100644
index 4b4db2eb5..000000000
--- a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs
+++ /dev/null
@@ -1,101 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use arrow::array::builder::IntervalDayTimeBuilder;
-use arrow::array::types::{Int16Type, Int32Type, Int8Type};
-use arrow::array::{Array, Datum};
-use arrow::array::{ArrayRef, AsArray};
-use arrow::compute::kernels::numeric::{add, sub};
-use arrow::datatypes::DataType;
-use arrow::datatypes::IntervalDayTime;
-use arrow::error::ArrowError;
-use datafusion::common::{DataFusionError, ScalarValue};
-use datafusion::physical_expr_common::datum;
-use datafusion::physical_plan::ColumnarValue;
-use std::sync::Arc;
-
-macro_rules! scalar_date_arithmetic {
- ($start:expr, $days:expr, $op:expr) => {{
- let interval = IntervalDayTime::new(*$days as i32, 0);
- let interval_cv =
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval)));
- datum::apply($start, &interval_cv, $op)
- }};
-}
-macro_rules! array_date_arithmetic {
- ($days:expr, $interval_builder:expr, $intType:ty) => {{
- for day in $days.as_primitive::<$intType>().into_iter() {
- if let Some(non_null_day) = day {
-                $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0));
- } else {
- $interval_builder.append_null();
- }
- }
- }};
-}
-
-/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days
for the second
-/// argument, but we cannot directly add that to a Date32. We generate an
IntervalDayTime from the
-/// second argument and use DataFusion's interface to apply Arrow's operators.
-fn spark_date_arithmetic(
- args: &[ColumnarValue],
- op: impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>,
-) -> Result<ColumnarValue, DataFusionError> {
- let start = &args[0];
- match &args[1] {
- ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => {
- scalar_date_arithmetic!(start, days, op)
- }
- ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => {
- scalar_date_arithmetic!(start, days, op)
- }
- ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => {
- scalar_date_arithmetic!(start, days, op)
- }
- ColumnarValue::Array(days) => {
- let mut interval_builder =
IntervalDayTimeBuilder::with_capacity(days.len());
- match days.data_type() {
- DataType::Int8 => {
- array_date_arithmetic!(days, interval_builder, Int8Type)
- }
- DataType::Int16 => {
- array_date_arithmetic!(days, interval_builder, Int16Type)
- }
- DataType::Int32 => {
- array_date_arithmetic!(days, interval_builder, Int32Type)
- }
- _ => {
- return Err(DataFusionError::Internal(format!(
- "Unsupported data types {args:?} for date arithmetic.",
- )))
- }
- }
- let interval_cv =
ColumnarValue::Array(Arc::new(interval_builder.finish()));
- datum::apply(start, &interval_cv, op)
- }
- _ => Err(DataFusionError::Internal(format!(
- "Unsupported data types {args:?} for date arithmetic.",
- ))),
- }
-}
-
-pub fn spark_date_add(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
- spark_date_arithmetic(args, add)
-}
-
-pub fn spark_date_sub(args: &[ColumnarValue]) -> Result<ColumnarValue,
DataFusionError> {
- spark_date_arithmetic(args, sub)
-}
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs
b/native/spark-expr/src/datetime_funcs/mod.rs
index 0ca7bb940..ef8041e5f 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -15,12 +15,10 @@
// specific language governing permissions and limitations
// under the License.
-mod date_arithmetic;
mod date_trunc;
mod extract_date_part;
mod timestamp_trunc;
-pub use date_arithmetic::{spark_date_add, spark_date_sub};
pub use date_trunc::SparkDateTrunc;
pub use extract_date_part::SparkHour;
pub use extract_date_part::SparkMinute;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 7bdc7ff51..932fcbe53 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -68,10 +68,7 @@ pub use comet_scalar_funcs::{
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode,
register_all_comet_functions,
};
-pub use datetime_funcs::{
- spark_date_add, spark_date_sub, SparkDateTrunc, SparkHour, SparkMinute,
SparkSecond,
- TimestampTruncExpr,
-};
+pub use datetime_funcs::{SparkDateTrunc, SparkHour, SparkMinute, SparkSecond,
TimestampTruncExpr};
pub use error::{SparkError, SparkResult};
pub use hash_funcs::*;
pub use json_funcs::ToJson;
diff --git a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
index 8e4c92d70..9473ee30e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/datetime.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/datetime.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types.{DateType, IntegerType}
import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.CometGetDateField.CometGetDateField
import org.apache.comet.serde.ExprOuterClass.Expr
-import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal,
optExprWithInfo, scalarFunctionExprToProto,
scalarFunctionExprToProtoWithReturnType, serializeDataType}
+import org.apache.comet.serde.QueryPlanSerde._
private object CometGetDateField extends Enumeration {
type CometGetDateField = Value
@@ -251,31 +251,9 @@ object CometSecond extends CometExpressionSerde[Second] {
}
}
-object CometDateAdd extends CometExpressionSerde[DateAdd] {
- override def convert(
- expr: DateAdd,
- inputs: Seq[Attribute],
- binding: Boolean): Option[ExprOuterClass.Expr] = {
- val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
- val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
- val optExpr =
- scalarFunctionExprToProtoWithReturnType("date_add", DateType, leftExpr,
rightExpr)
- optExprWithInfo(optExpr, expr, expr.left, expr.right)
- }
-}
+object CometDateAdd extends CometScalarFunction[DateAdd]("date_add")
-object CometDateSub extends CometExpressionSerde[DateSub] {
- override def convert(
- expr: DateSub,
- inputs: Seq[Attribute],
- binding: Boolean): Option[ExprOuterClass.Expr] = {
- val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
- val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
- val optExpr =
- scalarFunctionExprToProtoWithReturnType("date_sub", DateType, leftExpr,
rightExpr)
- optExprWithInfo(optExpr, expr, expr.left, expr.right)
- }
-}
+object CometDateSub extends CometScalarFunction[DateSub]("date_sub")
object CometTruncDate extends CometExpressionSerde[TruncDate] {
override def convert(
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index daf0e45cc..07663ea91 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -252,7 +252,7 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
} else {
assert(sparkErr.get.getMessage.contains("integer overflow"))
}
-          assert(cometErr.get.getMessage.contains("`NaiveDate + TimeDelta` overflowed"))
+          assert(cometErr.get.getMessage.contains("attempt to add with overflow"))
}
}
}
@@ -296,10 +296,11 @@ class CometExpressionSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
          checkSparkMaybeThrows(sql(s"SELECT _20 - ${Int.MaxValue} FROM tbl"))
if (isSpark40Plus) {
assert(sparkErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
+          assert(cometErr.get.getMessage.contains("EXPRESSION_DECODING_FAILED"))
} else {
assert(sparkErr.get.getMessage.contains("integer overflow"))
+ assert(cometErr.get.getMessage.contains("integer overflow"))
}
-          assert(cometErr.get.getMessage.contains("`NaiveDate - TimeDelta` overflowed"))
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]