This is an automated email from the ASF dual-hosted git repository.

comphead pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 1f75eda09 chore: Implement date_trunc as ScalarUDFImpl (#1880)
1f75eda09 is described below

commit 1f75eda09890a30903bfd9a7e02c2287588b8d76
Author: Leung Ming <165622843+leung-m...@users.noreply.github.com>
AuthorDate: Tue Jun 17 01:49:15 2025 +0800

    chore: Implement date_trunc as ScalarUDFImpl (#1880)
---
 native/core/src/execution/planner.rs               | 17 ++--
 native/proto/src/proto/expr.proto                  |  6 --
 native/spark-expr/src/datetime_funcs/date_trunc.rs | 92 ++++++++--------------
 native/spark-expr/src/datetime_funcs/mod.rs        |  2 +-
 native/spark-expr/src/lib.rs                       |  2 +-
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 18 +----
 6 files changed, 43 insertions(+), 94 deletions(-)

diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 5d5d39635..09853b6d4 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -66,6 +66,7 @@ use datafusion::{
 };
 use datafusion_comet_spark_expr::{
     create_comet_physical_fun, create_negate_expr, SparkBitwiseCount, 
SparkBitwiseNot,
+    SparkDateTrunc,
 };
 
 use crate::execution::operators::ExecutionError::GeneralError;
@@ -105,10 +106,10 @@ use datafusion_comet_proto::{
 };
 use datafusion_comet_spark_expr::{
     ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Contains, Correlation, 
Covariance,
-    CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, 
GetStructField, HourExpr,
-    IfExpr, Like, ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, 
SecondExpr,
-    SparkCastOptions, StartsWith, Stddev, StringSpaceExpr, SubstringExpr, 
SumDecimal,
-    TimestampTruncExpr, ToJson, UnboundColumn, Variance,
+    CreateNamedStruct, EndsWith, GetArrayStructFields, GetStructField, 
HourExpr, IfExpr, Like,
+    ListExtract, MinuteExpr, NormalizeNaNAndZero, RLike, SecondExpr, 
SparkCastOptions, StartsWith,
+    Stddev, StringSpaceExpr, SubstringExpr, SumDecimal, TimestampTruncExpr, 
ToJson, UnboundColumn,
+    Variance,
 };
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use itertools::Itertools;
@@ -158,6 +159,7 @@ impl PhysicalPlanner {
         
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkExpm1::default()));
         
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
         
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseCount::default()));
+        
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkDateTrunc::default()));
         Self {
             exec_context_id: TEST_EXEC_CONTEXT_ID,
             session_ctx,
@@ -475,13 +477,6 @@ impl PhysicalPlanner {
 
                 Ok(Arc::new(SecondExpr::new(child, timezone)))
             }
-            ExprStruct::TruncDate(expr) => {
-                let child =
-                    self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
-                let format = self.create_expr(expr.format.as_ref().unwrap(), 
input_schema)?;
-
-                Ok(Arc::new(DateTruncExpr::new(child, format)))
-            }
             ExprStruct::TruncTimestamp(expr) => {
                 let child =
                     self.create_expr(expr.child.as_ref().unwrap(), 
Arc::clone(&input_schema))?;
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index d74e675f7..4a1f6eb4f 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -70,7 +70,6 @@ message Expr {
     BinaryExpr bitwiseShiftLeft = 43;
     IfExpr if = 44;
     NormalizeNaNAndZero normalize_nan_and_zero = 45;
-    TruncDate truncDate = 46;
     TruncTimestamp truncTimestamp = 47;
     Abs abs = 49;
     Subquery subquery = 50;
@@ -344,11 +343,6 @@ message IfExpr {
   Expr false_expr = 3;
 }
 
-message TruncDate {
-  Expr child = 1;
-  Expr format = 2;
-}
-
 message TruncTimestamp {
   Expr format = 1;
   Expr child = 2;
diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs
index 1f91ba64b..861f5a2ae 100644
--- a/native/spark-expr/src/datetime_funcs/date_trunc.rs
+++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs
@@ -15,76 +15,58 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::{DataType, Schema};
-use arrow::record_batch::RecordBatch;
-use datafusion::common::{DataFusionError, ScalarValue::Utf8};
-use datafusion::logical_expr::ColumnarValue;
-use datafusion::physical_expr::PhysicalExpr;
-use std::hash::Hash;
-use std::{
-    any::Any,
-    fmt::{Debug, Display, Formatter},
-    sync::Arc,
+use arrow::datatypes::DataType;
+use datafusion::common::{utils::take_function_args, DataFusionError, Result, 
ScalarValue::Utf8};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
 };
+use std::any::Any;
 
 use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};
 
-#[derive(Debug, Eq)]
-pub struct DateTruncExpr {
-    /// An array with DataType::Date32
-    child: Arc<dyn PhysicalExpr>,
-    /// Scalar UTF8 string matching the valid values in Spark SQL: 
https://spark.apache.org/docs/latest/api/sql/index.html#trunc
-    format: Arc<dyn PhysicalExpr>,
+#[derive(Debug)]
+pub struct SparkDateTrunc {
+    signature: Signature,
+    aliases: Vec<String>,
 }
 
-impl Hash for DateTruncExpr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.child.hash(state);
-        self.format.hash(state);
-    }
-}
-impl PartialEq for DateTruncExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.child.eq(&other.child) && self.format.eq(&other.format)
-    }
-}
-
-impl DateTruncExpr {
-    pub fn new(child: Arc<dyn PhysicalExpr>, format: Arc<dyn PhysicalExpr>) -> 
Self {
-        DateTruncExpr { child, format }
+impl SparkDateTrunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::exact(
+                vec![DataType::Date32, DataType::Utf8],
+                Volatility::Immutable,
+            ),
+            aliases: vec![],
+        }
     }
 }
 
-impl Display for DateTruncExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "DateTrunc [child:{}, format: {}]",
-            self.child, self.format
-        )
+impl Default for SparkDateTrunc {
+    fn default() -> Self {
+        Self::new()
     }
 }
 
-impl PhysicalExpr for DateTruncExpr {
+impl ScalarUDFImpl for SparkDateTrunc {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
-        unimplemented!()
+    fn name(&self) -> &str {
+        "date_trunc"
     }
 
-    fn data_type(&self, input_schema: &Schema) -> 
datafusion::common::Result<DataType> {
-        self.child.data_type(input_schema)
+    fn signature(&self) -> &Signature {
+        &self.signature
     }
 
-    fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
-        Ok(true)
+    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Date32)
     }
 
-    fn evaluate(&self, batch: &RecordBatch) -> 
datafusion::common::Result<ColumnarValue> {
-        let date = self.child.evaluate(batch)?;
-        let format = self.format.evaluate(batch)?;
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
+        let [date, format] = take_function_args(self.name(), args.args)?;
         match (date, format) {
             (ColumnarValue::Array(date), 
ColumnarValue::Scalar(Utf8(Some(format)))) => {
                 let result = date_trunc_dyn(&date, format)?;
@@ -101,17 +83,7 @@ impl PhysicalExpr for DateTruncExpr {
         }
     }
 
-    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.child]
-    }
-
-    fn with_new_children(
-        self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
-    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
-        Ok(Arc::new(DateTruncExpr::new(
-            Arc::clone(&children[0]),
-            Arc::clone(&self.format),
-        )))
+    fn aliases(&self) -> &[String] {
+        &self.aliases
     }
 }
diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs
index 1f4d42728..e0baa1fce 100644
--- a/native/spark-expr/src/datetime_funcs/mod.rs
+++ b/native/spark-expr/src/datetime_funcs/mod.rs
@@ -23,7 +23,7 @@ mod second;
 mod timestamp_trunc;
 
 pub use date_arithmetic::{spark_date_add, spark_date_sub};
-pub use date_trunc::DateTruncExpr;
+pub use date_trunc::SparkDateTrunc;
 pub use hour::HourExpr;
 pub use minute::MinuteExpr;
 pub use second::SecondExpr;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index ae8e639b3..c2aac93e2 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -60,7 +60,7 @@ pub use conversion_funcs::*;
 
 pub use comet_scalar_funcs::create_comet_physical_fun;
 pub use datetime_funcs::{
-    spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr, 
SecondExpr,
+    spark_date_add, spark_date_sub, HourExpr, MinuteExpr, SecondExpr, 
SparkDateTrunc,
     TimestampTruncExpr,
 };
 pub use error::{SparkError, SparkResult};
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 13bea457d..90a90e773 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -1044,21 +1044,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
       case TruncDate(child, format) =>
         val childExpr = exprToProtoInternal(child, inputs, binding)
         val formatExpr = exprToProtoInternal(format, inputs, binding)
-
-        if (childExpr.isDefined && formatExpr.isDefined) {
-          val builder = ExprOuterClass.TruncDate.newBuilder()
-          builder.setChild(childExpr.get)
-          builder.setFormat(formatExpr.get)
-
-          Some(
-            ExprOuterClass.Expr
-              .newBuilder()
-              .setTruncDate(builder)
-              .build())
-        } else {
-          withInfo(expr, child, format)
-          None
-        }
+        val optExpr =
+          scalarFunctionExprToProtoWithReturnType("date_trunc", DateType, 
childExpr, formatExpr)
+        optExprWithInfo(optExpr, expr, child, format)
 
       case TruncTimestamp(format, child, timeZoneId) =>
         val childExpr = exprToProtoInternal(child, inputs, binding)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datafusion.apache.org
For additional commands, e-mail: commits-help@datafusion.apache.org

Reply via email to