This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b492c6a5e Improve `round` scalar function unparsing for Postgres (#12744)
9b492c6a5e is described below

commit 9b492c6a5e168171f14d4e985fcb43b535c2e872
Author: Sergei Grebnov <[email protected]>
AuthorDate: Sun Oct 6 04:12:50 2024 -0700

    Improve `round` scalar function unparsing for Postgres (#12744)
    
    * Postgres: enforce required `NUMERIC` type for `round` scalar function (#34)
    
    Includes initial support for dialects to override scalar functions unparsing
    
    * Document scalar_function_to_sql_overrides fn
---
 datafusion/sql/src/unparser/dialect.rs | 119 +++++++++++++++++++-
 datafusion/sql/src/unparser/expr.rs    | 198 ++++++++++++---------------------
 datafusion/sql/src/unparser/utils.rs   |  82 +++++++++++++-
 3 files changed, 273 insertions(+), 126 deletions(-)

diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index d8a4fb2542..609e6f2240 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -18,12 +18,17 @@
 use std::sync::Arc;
 
 use arrow_schema::TimeUnit;
+use datafusion_expr::Expr;
 use regex::Regex;
 use sqlparser::{
-    ast::{self, Ident, ObjectName, TimezoneInfo},
+    ast::{self, Function, Ident, ObjectName, TimezoneInfo},
     keywords::ALL_KEYWORDS,
 };
 
+use datafusion_common::Result;
+
+use super::{utils::date_part_to_sql, Unparser};
+
 /// `Dialect` to use for Unparsing
 ///
 /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
@@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync {
     fn supports_column_alias_in_table_alias(&self) -> bool {
         true
     }
+
+    /// Allows the dialect to override scalar function unparsing if the dialect has specific rules.
+    /// Returns `Ok(None)` if the default unparsing should be used, or `Ok(Some(ast::Expr))` if there is
+    /// a custom implementation for the function; implementations may return `Err` if unparsing an argument fails.
+    fn scalar_function_to_sql_overrides(
+        &self,
+        _unparser: &Unparser,
+        _func_name: &str,
+        _args: &[Expr],
+    ) -> Result<Option<ast::Expr>> {
+        Ok(None) // NOTE(review): no date_part rewrite by default; dialects with an EXTRACT/strftime date_field_extract_style must override this
+    }
 }
 
 /// `IntervalStyle` to use for unparsing
@@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect {
     fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
         sqlparser::ast::DataType::DoublePrecision
     }
+
+    fn scalar_function_to_sql_overrides(
+        &self,
+        unparser: &Unparser,
+        func_name: &str,
+        args: &[Expr],
+    ) -> Result<Option<ast::Expr>> {
+        if func_name == "round" { // Postgres: round's first argument is forced to NUMERIC
+            return Ok(Some(
+                self.round_to_sql_enforce_numeric(unparser, func_name, args)?,
+            ));
+        }
+
+        Ok(None)
+    }
+}
+
+impl PostgreSqlDialect {
+    fn round_to_sql_enforce_numeric( // renders `func_name(args...)` with the first argument cast to NUMERIC
+        &self,
+        unparser: &Unparser,
+        func_name: &str,
+        args: &[Expr],
+    ) -> Result<ast::Expr> {
+        let mut args = unparser.function_args_to_sql(args)?;
+
+        // Enforce the first argument to be Numeric; the remaining arguments are left untouched
+        if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) =
+            args.first_mut() // a missing or wildcard first argument is left as-is
+        {
+            if let ast::Expr::Cast { data_type, .. } = expr {
+                // Don't create an additional cast wrapper if we can update the existing one (its original target type is replaced)
+                *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None);
+            } else {
+                // Wrap the expression in a new cast
+                *expr = ast::Expr::Cast {
+                    kind: ast::CastKind::Cast,
+                    expr: Box::new(expr.clone()),
+                    data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None),
+                    format: None,
+                };
+            }
+        }
+
+        Ok(ast::Expr::Function(Function {
+            name: ast::ObjectName(vec![Ident {
+                value: func_name.to_string(),
+                quote_style: None,
+            }]),
+            args: ast::FunctionArguments::List(ast::FunctionArgumentList {
+                duplicate_treatment: None,
+                args,
+                clauses: vec![],
+            }),
+            filter: None,
+            null_treatment: None,
+            over: None,
+            within_group: vec![],
+            parameters: ast::FunctionArguments::None,
+        }))
+    }
 }
 
 pub struct MySqlDialect {}
@@ -211,6 +289,19 @@ impl Dialect for MySqlDialect {
     ) -> ast::DataType {
         ast::DataType::Datetime(None)
     }
+
+    fn scalar_function_to_sql_overrides(
+        &self,
+        unparser: &Unparser,
+        func_name: &str,
+        args: &[Expr],
+    ) -> Result<Option<ast::Expr>> {
+        if func_name == "date_part" { // NOTE(review): case-sensitive, unlike the removed `func_name.to_lowercase()` check
+            return date_part_to_sql(unparser, self.date_field_extract_style(), args);
+        }
+
+        Ok(None)
+    }
 }
 
 pub struct SqliteDialect {}
@@ -231,6 +322,19 @@ impl Dialect for SqliteDialect {
     fn supports_column_alias_in_table_alias(&self) -> bool {
         false
     }
+
+    fn scalar_function_to_sql_overrides(
+        &self,
+        unparser: &Unparser,
+        func_name: &str,
+        args: &[Expr],
+    ) -> Result<Option<ast::Expr>> {
+        if func_name == "date_part" { // NOTE(review): case-sensitive, unlike the removed `func_name.to_lowercase()` check
+            return date_part_to_sql(unparser, self.date_field_extract_style(), args);
+        }
+
+        Ok(None)
+    }
 }
 
 pub struct CustomDialect {
@@ -339,6 +443,19 @@ impl Dialect for CustomDialect {
     fn supports_column_alias_in_table_alias(&self) -> bool {
         self.supports_column_alias_in_table_alias
     }
+
+    fn scalar_function_to_sql_overrides(
+        &self,
+        unparser: &Unparser,
+        func_name: &str,
+        args: &[Expr],
+    ) -> Result<Option<ast::Expr>> {
+        if func_name == "date_part" { // NOTE(review): case-sensitive, unlike the removed `func_name.to_lowercase()` check
+            return date_part_to_sql(unparser, self.date_field_extract_style(), args);
+        }
+
+        Ok(None)
+    }
 }
 
 /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index b924268a76..537ac22744 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,16 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion_expr::ScalarUDF;
 use sqlparser::ast::Value::SingleQuotedString;
 use sqlparser::ast::{
-    self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, 
Interval,
-    ObjectName, TimezoneInfo, UnaryOperator,
+    self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, 
ObjectName,
+    TimezoneInfo, UnaryOperator,
 };
 use std::sync::Arc;
 use std::vec;
 
-use super::dialect::{DateFieldExtractStyle, IntervalStyle};
+use super::dialect::IntervalStyle;
 use super::Unparser;
 use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType};
 use arrow::util::display::array_value_to_string;
@@ -116,47 +115,14 @@ impl Unparser<'_> {
             Expr::ScalarFunction(ScalarFunction { func, args }) => {
                 let func_name = func.name();
 
-                if let Some(expr) =
-                    self.scalar_function_to_sql_overrides(func_name, func, args)
+                if let Some(expr) = self
+                    .dialect
+                    .scalar_function_to_sql_overrides(self, func_name, args)? // dialect hook; its errors now propagate
                 {
                     return Ok(expr);
                 }
 
-                let args = args
-                    .iter()
-                    .map(|e| {
-                        if matches!(
-                            e,
-                            Expr::Wildcard {
-                                qualifier: None,
-                                ..
-                            }
-                        ) {
-                            Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
-                        } else {
-                            self.expr_to_sql_inner(e).map(|e| {
-                                FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
-                            })
-                        }
-                    })
-                    .collect::<Result<Vec<_>>>()?;
-
-                Ok(ast::Expr::Function(Function {
-                    name: ast::ObjectName(vec![Ident {
-                        value: func_name.to_string(),
-                        quote_style: None,
-                    }]),
-                    args: ast::FunctionArguments::List(ast::FunctionArgumentList {
-                        duplicate_treatment: None,
-                        args,
-                        clauses: vec![],
-                    }),
-                    filter: None,
-                    null_treatment: None,
-                    over: None,
-                    within_group: vec![],
-                    parameters: ast::FunctionArguments::None,
-                }))
+                self.scalar_function_to_sql(func_name, args) // no dialect override: plain `func_name(args...)`
             }
             Expr::Between(Between {
                 expr,
@@ -508,6 +474,30 @@ impl Unparser<'_> {
         }
     }
 
+    pub fn scalar_function_to_sql( // renders `func_name(args...)` with an unquoted name; default path and date_part fallback
+        &self,
+        func_name: &str,
+        args: &[Expr],
+    ) -> Result<ast::Expr> {
+        let args = self.function_args_to_sql(args)?;
+        Ok(ast::Expr::Function(Function {
+            name: ast::ObjectName(vec![Ident {
+                value: func_name.to_string(),
+                quote_style: None,
+            }]),
+            args: ast::FunctionArguments::List(ast::FunctionArgumentList {
+                duplicate_treatment: None,
+                args,
+                clauses: vec![],
+            }),
+            filter: None,
+            null_treatment: None,
+            over: None,
+            within_group: vec![],
+            parameters: ast::FunctionArguments::None,
+        }))
+    }
+
     pub fn sort_to_sql(&self, sort: &Sort) -> Result<ast::OrderByExpr> {
         let Sort {
             expr,
@@ -530,87 +520,6 @@ impl Unparser<'_> {
         })
     }
 
-    fn scalar_function_to_sql_overrides(
-        &self,
-        func_name: &str,
-        _func: &Arc<ScalarUDF>,
-        args: &[Expr],
-    ) -> Option<ast::Expr> {
-        if func_name.to_lowercase() == "date_part" {
-            match (self.dialect.date_field_extract_style(), args.len()) {
-                (DateFieldExtractStyle::Extract, 2) => {
-                    let date_expr = self.expr_to_sql(&args[1]).ok()?;
-
-                    if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] {
-                        let field = match field.to_lowercase().as_str() {
-                            "year" => ast::DateTimeField::Year,
-                            "month" => ast::DateTimeField::Month,
-                            "day" => ast::DateTimeField::Day,
-                            "hour" => ast::DateTimeField::Hour,
-                            "minute" => ast::DateTimeField::Minute,
-                            "second" => ast::DateTimeField::Second,
-                            _ => return None,
-                        };
-
-                        return Some(ast::Expr::Extract {
-                            field,
-                            expr: Box::new(date_expr),
-                            syntax: ast::ExtractSyntax::From,
-                        });
-                    }
-                }
-                (DateFieldExtractStyle::Strftime, 2) => {
-                    let column = self.expr_to_sql(&args[1]).ok()?;
-
-                    if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] {
-                        let field = match field.to_lowercase().as_str() {
-                            "year" => "%Y",
-                            "month" => "%m",
-                            "day" => "%d",
-                            "hour" => "%H",
-                            "minute" => "%M",
-                            "second" => "%S",
-                            _ => return None,
-                        };
-
-                        return Some(ast::Expr::Function(ast::Function {
-                            name: ast::ObjectName(vec![ast::Ident {
-                                value: "strftime".to_string(),
-                                quote_style: None,
-                            }]),
-                            args: ast::FunctionArguments::List(
-                                ast::FunctionArgumentList {
-                                    duplicate_treatment: None,
-                                    args: vec![
-                                        ast::FunctionArg::Unnamed(
-                                            ast::FunctionArgExpr::Expr(ast::Expr::Value(
-                                                ast::Value::SingleQuotedString(
-                                                    field.to_string(),
-                                                ),
-                                            )),
-                                        ),
-                                        ast::FunctionArg::Unnamed(
-                                            ast::FunctionArgExpr::Expr(column),
-                                        ),
-                                    ],
-                                    clauses: vec![],
-                                },
-                            ),
-                            filter: None,
-                            null_treatment: None,
-                            over: None,
-                            within_group: vec![],
-                            parameters: ast::FunctionArguments::None,
-                        }));
-                    }
-                }
-                _ => {} // no overrides for DateFieldExtractStyle::DatePart, because it's already a date_part
-            }
-        }
-
-        None
-    }
-
     fn ast_type_for_date64_in_cast(&self) -> ast::DataType {
         if self.dialect.use_timestamp_for_date64() {
             ast::DataType::Timestamp(None, ast::TimezoneInfo::None)
@@ -665,7 +574,10 @@ impl Unparser<'_> {
         }
     }
 
-    fn function_args_to_sql(&self, args: &[Expr]) -> Result<Vec<ast::FunctionArg>> {
+    pub(crate) fn function_args_to_sql( // pub(crate): reused by dialect overrides such as PostgreSqlDialect
+        &self,
+        args: &[Expr],
+    ) -> Result<Vec<ast::FunctionArg>> {
         args.iter()
             .map(|e| {
                 if matches!(
@@ -1554,7 +1466,10 @@ mod tests {
     use datafusion_functions_aggregate::expr_fn::sum;
     use datafusion_functions_window::row_number::row_number_udwf;
 
-    use crate::unparser::dialect::{CustomDialect, CustomDialectBuilder};
+    use crate::unparser::dialect::{
+        CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect,
+        PostgreSqlDialect,
+    };
 
     use super::*;
 
@@ -2428,4 +2343,39 @@ mod tests {
             assert_eq!(actual, expected);
         }
     }
+
+    #[test]
+    fn test_round_scalar_fn_to_expr() -> Result<()> { // existing CAST target: kept (DOUBLE) by default, rewritten to NUMERIC for Postgres
+        let default_dialect: Arc<dyn Dialect> = Arc::new(
+            CustomDialectBuilder::new()
+                .with_identifier_quote_style('"')
+                .build(),
+        );
+        let postgres_dialect: Arc<dyn Dialect> = Arc::new(PostgreSqlDialect {});
+
+        for (dialect, identifier) in
+            [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")]
+        {
+            let unparser = Unparser::new(dialect.as_ref());
+            let expr = Expr::ScalarFunction(ScalarFunction {
+                func: Arc::new(ScalarUDF::from(
+                    datafusion_functions::math::round::RoundFunc::new(),
+                )),
+                args: vec![
+                    Expr::Cast(Cast {
+                        expr: Box::new(col("a")),
+                        data_type: DataType::Float64,
+                    }),
+                    Expr::Literal(ScalarValue::Int64(Some(2))),
+                ],
+            });
+            let ast = unparser.expr_to_sql(&expr)?;
+
+            let actual = format!("{}", ast);
+            let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#);
+
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
 }
diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs
index 0059aba257..8b2530a749 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -18,11 +18,14 @@
 use datafusion_common::{
     internal_err,
     tree_node::{Transformed, TreeNode},
-    Column, DataFusionError, Result,
+    Column, DataFusionError, Result, ScalarValue,
 };
 use datafusion_expr::{
     utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
 };
+use sqlparser::ast;
+
+use super::{dialect::DateFieldExtractStyle, Unparser};
 
 /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
 /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
@@ -187,3 +190,80 @@ fn find_window_expr<'a>(
         .flat_map(|w| w.window_expr.iter())
         .find(|expr| expr.schema_name().to_string() == column_name)
 }
+
+/// Converts a `date_part(field, expr)` call to SQL for the given [`DateFieldExtractStyle`]; `Ok(None)` means "use the default unparsing".
+pub(crate) fn date_part_to_sql(
+    unparser: &Unparser,
+    style: DateFieldExtractStyle,
+    date_part_args: &[Expr],
+) -> Result<Option<ast::Expr>> {
+    match (style, date_part_args.len()) {
+        (DateFieldExtractStyle::Extract, 2) => { // EXTRACT(<field> FROM <expr>)
+            let date_expr = unparser.expr_to_sql(&date_part_args[1])?; // NOTE(review): evaluated before the field is validated; errors propagate (removed code used `.ok()?`)
+            if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] {
+                let field = match field.to_lowercase().as_str() {
+                    "year" => ast::DateTimeField::Year,
+                    "month" => ast::DateTimeField::Month,
+                    "day" => ast::DateTimeField::Day,
+                    "hour" => ast::DateTimeField::Hour,
+                    "minute" => ast::DateTimeField::Minute,
+                    "second" => ast::DateTimeField::Second,
+                    _ => return Ok(None), // unsupported field: fall back to default unparsing
+                };
+
+                return Ok(Some(ast::Expr::Extract {
+                    field,
+                    expr: Box::new(date_expr),
+                    syntax: ast::ExtractSyntax::From,
+                }));
+            }
+        }
+        (DateFieldExtractStyle::Strftime, 2) => { // strftime('<format>', <expr>)
+            let column = unparser.expr_to_sql(&date_part_args[1])?;
+
+            if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] {
+                let field = match field.to_lowercase().as_str() {
+                    "year" => "%Y",
+                    "month" => "%m",
+                    "day" => "%d",
+                    "hour" => "%H",
+                    "minute" => "%M",
+                    "second" => "%S",
+                    _ => return Ok(None), // unsupported field: fall back to default unparsing
+                };
+
+                return Ok(Some(ast::Expr::Function(ast::Function {
+                    name: ast::ObjectName(vec![ast::Ident {
+                        value: "strftime".to_string(),
+                        quote_style: None,
+                    }]),
+                    args: ast::FunctionArguments::List(ast::FunctionArgumentList {
+                        duplicate_treatment: None,
+                        args: vec![
+                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
+                                ast::Expr::Value(ast::Value::SingleQuotedString(
+                                    field.to_string(),
+                                )),
+                            )),
+                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)),
+                        ],
+                        clauses: vec![],
+                    }),
+                    filter: None,
+                    null_treatment: None,
+                    over: None,
+                    within_group: vec![],
+                    parameters: ast::FunctionArguments::None,
+                })));
+            }
+        }
+        (DateFieldExtractStyle::DatePart, _) => { // already date_part: render the call unchanged
+            return Ok(Some(
+                unparser.scalar_function_to_sql("date_part", date_part_args)?,
+            ));
+        }
+        _ => {}
+    };
+
+    Ok(None) // other arity or non-literal field: default unparsing
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to