This is an automated email from the ASF dual-hosted git repository. comphead pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 1f75eda09 chore: Implement date_trunc as ScalarUDFImpl (#1880) 1f75eda09 is described below commit 1f75eda09890a30903bfd9a7e02c2287588b8d76 Author: Leung Ming <165622843+leung-m...@users.noreply.github.com> AuthorDate: Tue Jun 17 01:49:15 2025 +0800 chore: Implement date_trunc as ScalarUDFImpl (#1880) --- native/core/src/execution/planner.rs | 17 ++-- native/proto/src/proto/expr.proto | 6 -- native/spark-expr/src/datetime_funcs/date_trunc.rs | 92 ++++++++-------------- native/spark-expr/src/datetime_funcs/mod.rs | 2 +- native/spark-expr/src/lib.rs | 2 +- .../org/apache/comet/serde/QueryPlanSerde.scala | 18 +---- 6 files changed, 43 insertions(+), 94 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 5d5d39635..09853b6d4 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -66,6 +66,7 @@ use datafusion::{ }; use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_negate_expr, SparkBitwiseCount, SparkBitwiseNot, + SparkDateTrunc, }; use crate::execution::operators::ExecutionError::GeneralError; @@ -105,10 +106,10 @@ use datafusion_comet_proto::{ }; use datafusion_comet_spark_expr::{ ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Contains, Correlation, Covariance, - CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField, HourExpr, - IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, - SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, - TimestampTruncExpr, ToJson, UnboundColumn, Variance, + CreateNamedStruct, EndsWith, GetArrayStructFields, GetStructField, HourExpr, IfExpr, Like, + ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, SparkCastOptions, StartsWith, + Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, TimestampTruncExpr, ToJson, UnboundColumn, + Variance, }; use datafusion_spark::function::math::expm1::SparkExpm1; use itertools::Itertools; @@ -158,6 +159,7 @@ impl PhysicalPlanner { session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default())); session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseCount::default())); + session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateTrunc::default())); Self { exec_context_id: TEST_EXEC_CONTEXT_ID, session_ctx, @@ -475,13 +477,6 @@ impl PhysicalPlanner { Ok(Arc::new(SecondExpr::new(child, timezone))) } - ExprStruct::TruncDate(expr) => { - let child = - self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; - let format = self.create_expr(expr.format.as_ref().unwrap(), input_schema)?; - - Ok(Arc::new(DateTruncExpr::new(child, format))) - } ExprStruct::TruncTimestamp(expr) => { let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?; diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index d74e675f7..4a1f6eb4f 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -70,7 +70,6 @@ message Expr { BinaryExpr bitwiseShiftLeft = 43; IfExpr if = 44; NormalizeNaNAndZero normalize_nan_and_zero = 45; - TruncDate truncDate = 46; TruncTimestamp truncTimestamp = 47; Abs abs = 49; Subquery subquery = 50; @@ -344,11 +343,6 @@ message IfExpr { Expr false_expr = 3; } -message TruncDate { - Expr child = 1; - Expr format = 2; -} - message TruncTimestamp { Expr format = 1; Expr child = 2; diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs index 1f91ba64b..861f5a2ae 100644 --- a/native/spark-expr/src/datetime_funcs/date_trunc.rs +++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs @@ -15,76 +15,58 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; -use datafusion::common::{DataFusionError, ScalarValue::Utf8}; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_expr::PhysicalExpr; -use std::hash::Hash; -use std::{ - any::Any, - fmt::{Debug, Display, Formatter}, - sync::Arc, +use arrow::datatypes::DataType; +use datafusion::common::{utils::take_function_args, DataFusionError, Result, ScalarValue::Utf8}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; +use std::any::Any; use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn}; -#[derive(Debug, Eq)] -pub struct DateTruncExpr { - /// An array with DataType::Date32 - child: Arc<dyn PhysicalExpr>, - /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc - format: Arc<dyn PhysicalExpr>, +#[derive(Debug)] +pub struct SparkDateTrunc { + signature: Signature, + aliases: Vec<String>, } -impl Hash for DateTruncExpr { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.child.hash(state); - self.format.hash(state); - } -} -impl PartialEq for DateTruncExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.format.eq(&other.format) - } -} - -impl DateTruncExpr { - pub fn new(child: Arc<dyn PhysicalExpr>, format: Arc<dyn PhysicalExpr>) -> Self { - DateTruncExpr { child, format } +impl SparkDateTrunc { + pub fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::Date32, DataType::Utf8], + Volatility::Immutable, + ), + aliases: vec![], + } } } -impl Display for DateTruncExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "DateTrunc [child:{}, format: {}]", - self.child, self.format - ) +impl Default for SparkDateTrunc { + fn default() -> Self { + Self::new() } } -impl PhysicalExpr for DateTruncExpr { +impl ScalarUDFImpl for SparkDateTrunc { fn as_any(&self) -> &dyn Any { self } - fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result { - unimplemented!() + fn name(&self) -> &str { + "date_trunc" } - fn data_type(&self, input_schema: &Schema) -> datafusion::common::Result<DataType> { - self.child.data_type(input_schema) + fn signature(&self) -> &Signature { + &self.signature } - fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> { - Ok(true) + fn return_type(&self, _: &[DataType]) -> Result<DataType> { + Ok(DataType::Date32) } - fn evaluate(&self, batch: &RecordBatch) -> datafusion::common::Result<ColumnarValue> { - let date = self.child.evaluate(batch)?; - let format = self.format.evaluate(batch)?; + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let [date, format] = take_function_args(self.name(), args.args)?; match (date, format) { (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => { let result = date_trunc_dyn(&date, format)?; @@ -101,17 +83,7 @@ impl PhysicalExpr for DateTruncExpr { } } - fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { - vec![&self.child] - } - - fn with_new_children( - self: Arc<Self>, - children: Vec<Arc<dyn PhysicalExpr>>, - ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> { - Ok(Arc::new(DateTruncExpr::new( - Arc::clone(&children[0]), - Arc::clone(&self.format), - ))) + fn aliases(&self) -> &[String] { + &self.aliases } } diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs index 1f4d42728..e0baa1fce 100644 --- a/native/spark-expr/src/datetime_funcs/mod.rs +++ b/native/spark-expr/src/datetime_funcs/mod.rs @@ -23,7 +23,7 @@ mod second; mod timestamp_trunc; pub use date_arithmetic::{spark_date_add, spark_date_sub}; -pub use date_trunc::DateTruncExpr; +pub use date_trunc::SparkDateTrunc; pub use hour::HourExpr; pub use minute::MinuteExpr; pub use second::SecondExpr; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index ae8e639b3..c2aac93e2 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -60,7 +60,7 @@ pub use conversion_funcs::*; pub use comet_scalar_funcs::create_comet_physical_fun; pub use datetime_funcs::{ - spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, + spark_date_add, spark_date_sub, HourExpr, MinuteExpr, SecondExpr, SparkDateTrunc, TimestampTruncExpr, }; pub use error::{SparkError, SparkResult}; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 13bea457d..90a90e773 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1044,21 +1044,9 @@ object QueryPlanSerde extends Logging with CometExprShim { case TruncDate(child, format) => val childExpr = exprToProtoInternal(child, inputs, binding) val formatExpr = exprToProtoInternal(format, inputs, binding) - - if (childExpr.isDefined && formatExpr.isDefined) { - val builder = ExprOuterClass.TruncDate.newBuilder() - builder.setChild(childExpr.get) - builder.setFormat(formatExpr.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncDate(builder) - .build()) - } else { - withInfo(expr, child, format) - None - } + val optExpr = + scalarFunctionExprToProtoWithReturnType("date_trunc", DateType, childExpr, formatExpr) + optExprWithInfo(optExpr, expr, child, format) case TruncTimestamp(format, child, timeZoneId) => val childExpr = exprToProtoInternal(child, inputs, binding) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org