This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e9f0e1d65 Implement prettier SQL unparsing (more human readable) 
(#11186)
1e9f0e1d65 is described below

commit 1e9f0e1d650f0549e6a8f7d6971b7373fae5199c
Author: Mohamed Abdeen <[email protected]>
AuthorDate: Thu Jul 11 19:20:10 2024 +0300

    Implement prettier SQL unparsing (more human readable) (#11186)
    
    * initial prettier unparse
    
    * bug fix
    
    * handling minus and divide
    
    * cleaning references and comments
    
    * moved tests
    
    * Update precedence of BETWEEN
    
    * rerun CI
    
    * Change precedence to match PGSQLs
    
    * more pretty unparser tests
    
    * Update operator precedence to match latest PGSQL
    
    * directly prettify expr_to_sql
    
    * handle IS operator
    
    * correct IS precedence
    
    * update unparser tests
    
    * update unparser example
    
    * update more unparser examples
    
    * add with_pretty builder to unparser
---
 datafusion-examples/examples/parse_sql_expr.rs |   9 +
 datafusion-examples/examples/plan_to_sql.rs    |  18 +-
 datafusion/expr/src/operator.rs                |  24 ++-
 datafusion/sql/src/unparser/expr.rs            | 230 ++++++++++++++++++++-----
 datafusion/sql/src/unparser/mod.rs             |  15 +-
 datafusion/sql/tests/cases/plan_to_sql.rs      |  99 +++++++++--
 6 files changed, 319 insertions(+), 76 deletions(-)

diff --git a/datafusion-examples/examples/parse_sql_expr.rs 
b/datafusion-examples/examples/parse_sql_expr.rs
index a1fc5d269a..e23e5accae 100644
--- a/datafusion-examples/examples/parse_sql_expr.rs
+++ b/datafusion-examples/examples/parse_sql_expr.rs
@@ -153,5 +153,14 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> {
 
     assert_eq!(sql, round_trip_sql);
 
+    // Enable pretty-unparsing. This makes the output more human-readable
+    // but can be problematic when passed to other SQL engines due to
+    // differences in precedence rules between DataFusion and target engines.
+    let unparser = Unparser::default().with_pretty(true);
+
+    let pretty = "int_col < 5 OR double_col = 8";
+    let pretty_round_trip_sql = 
unparser.expr_to_sql(&parsed_expr)?.to_string();
+    assert_eq!(pretty, pretty_round_trip_sql);
+
     Ok(())
 }
diff --git a/datafusion-examples/examples/plan_to_sql.rs 
b/datafusion-examples/examples/plan_to_sql.rs
index bd708fe52b..f719a33fb6 100644
--- a/datafusion-examples/examples/plan_to_sql.rs
+++ b/datafusion-examples/examples/plan_to_sql.rs
@@ -31,9 +31,9 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
 /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with
 /// fluent API and convert to sql suitable for passing to another database
 ///
-/// 2. [`simple_expr_to_sql_demo_no_escape`]  Create a simple expression
-/// [`Exprs`] with fluent API and convert to sql without escaping column names
-/// more suitable for displaying to humans.
+/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression
+/// [`Exprs`] with fluent API and convert to sql without extra parentheses,
+/// suitable for displaying to humans
 ///
 /// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple
 /// expression [`Exprs`] with fluent API and convert to sql escaping column
@@ -49,6 +49,7 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser};
 async fn main() -> Result<()> {
     // See how to evaluate expressions
     simple_expr_to_sql_demo()?;
+    simple_expr_to_pretty_sql_demo()?;
     simple_expr_to_sql_demo_escape_mysql_style()?;
     simple_plan_to_sql_demo().await?;
     round_trip_plan_to_sql_demo().await?;
@@ -64,6 +65,17 @@ fn simple_expr_to_sql_demo() -> Result<()> {
     Ok(())
 }
 
+/// DataFusion can remove parentheses when converting an expression to SQL.
+/// Note that the output is intended for humans, not for other SQL engines,
+/// as differences in precedence rules can cause expressions to be parsed differently.
+fn simple_expr_to_pretty_sql_demo() -> Result<()> {
+    let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));
+    let unparser = Unparser::default().with_pretty(true);
+    let sql = unparser.expr_to_sql(&expr)?.to_string();
+    assert_eq!(sql, r#"a < 5 OR a = 8"#);
+    Ok(())
+}
+
 /// DataFusion can convert expressions to SQL without escaping column names 
using
 /// using a custom dialect and an explicit unparser
 fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> {
diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs
index a10312e234..9bb8c48d6c 100644
--- a/datafusion/expr/src/operator.rs
+++ b/datafusion/expr/src/operator.rs
@@ -218,29 +218,23 @@ impl Operator {
     }
 
     /// Get the operator precedence
-    /// use <https://www.postgresql.org/docs/7.0/operators.htm#AEN2026> as a 
reference
+    /// use <https://www.postgresql.org/docs/7.2/sql-precedence.html> as a 
reference
     pub fn precedence(&self) -> u8 {
         match self {
             Operator::Or => 5,
             Operator::And => 10,
-            Operator::NotEq
-            | Operator::Eq
-            | Operator::Lt
-            | Operator::LtEq
-            | Operator::Gt
-            | Operator::GtEq => 20,
-            Operator::Plus | Operator::Minus => 30,
-            Operator::Multiply | Operator::Divide | Operator::Modulo => 40,
+            Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq 
=> 15,
+            Operator::Lt | Operator::Gt => 20,
+            Operator::LikeMatch
+            | Operator::NotLikeMatch
+            | Operator::ILikeMatch
+            | Operator::NotILikeMatch => 25,
             Operator::IsDistinctFrom
             | Operator::IsNotDistinctFrom
             | Operator::RegexMatch
             | Operator::RegexNotMatch
             | Operator::RegexIMatch
             | Operator::RegexNotIMatch
-            | Operator::LikeMatch
-            | Operator::ILikeMatch
-            | Operator::NotLikeMatch
-            | Operator::NotILikeMatch
             | Operator::BitwiseAnd
             | Operator::BitwiseOr
             | Operator::BitwiseShiftLeft
@@ -248,7 +242,9 @@ impl Operator {
             | Operator::BitwiseXor
             | Operator::StringConcat
             | Operator::AtArrow
-            | Operator::ArrowAt => 0,
+            | Operator::ArrowAt => 30,
+            Operator::Plus | Operator::Minus => 40,
+            Operator::Multiply | Operator::Divide | Operator::Modulo => 45,
         }
     }
 }
diff --git a/datafusion/sql/src/unparser/expr.rs 
b/datafusion/sql/src/unparser/expr.rs
index 198186934c..e0d05c400c 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -30,8 +30,8 @@ use arrow_array::{Date32Array, Date64Array, PrimitiveArray};
 use arrow_schema::DataType;
 use sqlparser::ast::Value::SingleQuotedString;
 use sqlparser::ast::{
-    self, Expr as AstExpr, Function, FunctionArg, Ident, Interval, 
TimezoneInfo,
-    UnaryOperator,
+    self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, 
Interval,
+    TimezoneInfo, UnaryOperator,
 };
 
 use datafusion_common::{
@@ -101,8 +101,21 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result<Unparsed> {
     unparser.expr_to_unparsed(expr)
 }
 
+const LOWEST: &BinaryOperator = &BinaryOperator::Or;
+// `IS` has no BinaryOperator; BitwiseAnd ("any other" operator in the PG docs) is the
+// closest precedence available (https://www.postgresql.org/docs/7.2/sql-precedence.html)
+const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd;
+
 impl Unparser<'_> {
     pub fn expr_to_sql(&self, expr: &Expr) -> Result<ast::Expr> {
+        let mut root_expr = self.expr_to_sql_inner(expr)?;
+        if self.pretty {
+            root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, 
LOWEST);
+        }
+        Ok(root_expr)
+    }
+
+    fn expr_to_sql_inner(&self, expr: &Expr) -> Result<ast::Expr> {
         match expr {
             Expr::InList(InList {
                 expr,
@@ -111,10 +124,10 @@ impl Unparser<'_> {
             }) => {
                 let list_expr = list
                     .iter()
-                    .map(|e| self.expr_to_sql(e))
+                    .map(|e| self.expr_to_sql_inner(e))
                     .collect::<Result<Vec<_>>>()?;
                 Ok(ast::Expr::InList {
-                    expr: Box::new(self.expr_to_sql(expr)?),
+                    expr: Box::new(self.expr_to_sql_inner(expr)?),
                     list: list_expr,
                     negated: *negated,
                 })
@@ -128,7 +141,7 @@ impl Unparser<'_> {
                         if matches!(e, Expr::Wildcard { qualifier: None }) {
                             
Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
                         } else {
-                            self.expr_to_sql(e).map(|e| {
+                            self.expr_to_sql_inner(e).map(|e| {
                                 
FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
                             })
                         }
@@ -157,9 +170,9 @@ impl Unparser<'_> {
                 low,
                 high,
             }) => {
-                let sql_parser_expr = self.expr_to_sql(expr)?;
-                let sql_low = self.expr_to_sql(low)?;
-                let sql_high = self.expr_to_sql(high)?;
+                let sql_parser_expr = self.expr_to_sql_inner(expr)?;
+                let sql_low = self.expr_to_sql_inner(low)?;
+                let sql_high = self.expr_to_sql_inner(high)?;
                 Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql(
                     sql_parser_expr,
                     *negated,
@@ -169,8 +182,8 @@ impl Unparser<'_> {
             }
             Expr::Column(col) => self.col_to_sql(col),
             Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                let l = self.expr_to_sql(left.as_ref())?;
-                let r = self.expr_to_sql(right.as_ref())?;
+                let l = self.expr_to_sql_inner(left.as_ref())?;
+                let r = self.expr_to_sql_inner(right.as_ref())?;
                 let op = self.op_to_sql(op)?;
 
                 Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, 
op))))
@@ -182,21 +195,21 @@ impl Unparser<'_> {
             }) => {
                 let conditions = when_then_expr
                     .iter()
-                    .map(|(w, _)| self.expr_to_sql(w))
+                    .map(|(w, _)| self.expr_to_sql_inner(w))
                     .collect::<Result<Vec<_>>>()?;
                 let results = when_then_expr
                     .iter()
-                    .map(|(_, t)| self.expr_to_sql(t))
+                    .map(|(_, t)| self.expr_to_sql_inner(t))
                     .collect::<Result<Vec<_>>>()?;
                 let operand = match expr.as_ref() {
-                    Some(e) => match self.expr_to_sql(e) {
+                    Some(e) => match self.expr_to_sql_inner(e) {
                         Ok(sql_expr) => Some(Box::new(sql_expr)),
                         Err(_) => None,
                     },
                     None => None,
                 };
                 let else_result = match else_expr.as_ref() {
-                    Some(e) => match self.expr_to_sql(e) {
+                    Some(e) => match self.expr_to_sql_inner(e) {
                         Ok(sql_expr) => Some(Box::new(sql_expr)),
                         Err(_) => None,
                     },
@@ -211,7 +224,7 @@ impl Unparser<'_> {
                 })
             }
             Expr::Cast(Cast { expr, data_type }) => {
-                let inner_expr = self.expr_to_sql(expr)?;
+                let inner_expr = self.expr_to_sql_inner(expr)?;
                 Ok(ast::Expr::Cast {
                     kind: ast::CastKind::Cast,
                     expr: Box::new(inner_expr),
@@ -220,7 +233,7 @@ impl Unparser<'_> {
                 })
             }
             Expr::Literal(value) => Ok(self.scalar_to_sql(value)?),
-            Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr),
+            Expr::Alias(Alias { expr, name: _, .. }) => 
self.expr_to_sql_inner(expr),
             Expr::WindowFunction(WindowFunction {
                 fun,
                 args,
@@ -255,7 +268,7 @@ impl Unparser<'_> {
                     window_name: None,
                     partition_by: partition_by
                         .iter()
-                        .map(|e| self.expr_to_sql(e))
+                        .map(|e| self.expr_to_sql_inner(e))
                         .collect::<Result<Vec<_>>>()?,
                     order_by,
                     window_frame: Some(ast::WindowFrame {
@@ -296,8 +309,8 @@ impl Unparser<'_> {
                 case_insensitive: _,
             }) => Ok(ast::Expr::Like {
                 negated: *negated,
-                expr: Box::new(self.expr_to_sql(expr)?),
-                pattern: Box::new(self.expr_to_sql(pattern)?),
+                expr: Box::new(self.expr_to_sql_inner(expr)?),
+                pattern: Box::new(self.expr_to_sql_inner(pattern)?),
                 escape_char: escape_char.map(|c| c.to_string()),
             }),
             Expr::AggregateFunction(agg) => {
@@ -305,7 +318,7 @@ impl Unparser<'_> {
 
                 let args = self.function_args_to_sql(&agg.args)?;
                 let filter = match &agg.filter {
-                    Some(filter) => Some(Box::new(self.expr_to_sql(filter)?)),
+                    Some(filter) => 
Some(Box::new(self.expr_to_sql_inner(filter)?)),
                     None => None,
                 };
                 Ok(ast::Expr::Function(Function {
@@ -339,7 +352,7 @@ impl Unparser<'_> {
                 Ok(ast::Expr::Subquery(sub_query))
             }
             Expr::InSubquery(insubq) => {
-                let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?);
+                let inexpr = 
Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?);
                 let sub_statement =
                     self.plan_to_sql(insubq.subquery.subquery.as_ref())?;
                 let sub_query = if let ast::Statement::Query(inner_query) = 
sub_statement
@@ -377,38 +390,38 @@ impl Unparser<'_> {
                 nulls_first: _,
             }) => plan_err!("Sort expression should be handled by 
expr_to_unparsed"),
             Expr::IsNull(expr) => {
-                Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotNull(expr) => {
-                Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?)))
+                Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?)))
             }
+            Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
             Expr::IsTrue(expr) => {
-                Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotTrue(expr) => {
-                Ok(ast::Expr::IsNotTrue(Box::new(self.expr_to_sql(expr)?)))
+                Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?)))
             }
+            Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
             Expr::IsFalse(expr) => {
-                Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotFalse(expr) => {
-                Ok(ast::Expr::IsNotFalse(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsUnknown(expr) => {
-                Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?)))
-            }
-            Expr::IsNotUnknown(expr) => {
-                Ok(ast::Expr::IsNotUnknown(Box::new(self.expr_to_sql(expr)?)))
-            }
+                Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?)))
+            }
+            Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
+            Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
+            Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new(
+                self.expr_to_sql_inner(expr)?,
+            ))),
             Expr::Not(expr) => {
-                let sql_parser_expr = self.expr_to_sql(expr)?;
+                let sql_parser_expr = self.expr_to_sql_inner(expr)?;
                 Ok(AstExpr::UnaryOp {
                     op: UnaryOperator::Not,
                     expr: Box::new(sql_parser_expr),
                 })
             }
             Expr::Negative(expr) => {
-                let sql_parser_expr = self.expr_to_sql(expr)?;
+                let sql_parser_expr = self.expr_to_sql_inner(expr)?;
                 Ok(AstExpr::UnaryOp {
                     op: UnaryOperator::Minus,
                     expr: Box::new(sql_parser_expr),
@@ -432,7 +445,7 @@ impl Unparser<'_> {
                 })
             }
             Expr::TryCast(TryCast { expr, data_type }) => {
-                let inner_expr = self.expr_to_sql(expr)?;
+                let inner_expr = self.expr_to_sql_inner(expr)?;
                 Ok(ast::Expr::Cast {
                     kind: ast::CastKind::TryCast,
                     expr: Box::new(inner_expr),
@@ -449,7 +462,7 @@ impl Unparser<'_> {
                         .iter()
                         .map(|set| {
                             set.iter()
-                                .map(|e| self.expr_to_sql(e))
+                                .map(|e| self.expr_to_sql_inner(e))
                                 .collect::<Result<Vec<_>>>()
                         })
                         .collect::<Result<Vec<_>>>()?;
@@ -460,7 +473,7 @@ impl Unparser<'_> {
                     let expr_ast_sets = cube
                         .iter()
                         .map(|e| {
-                            let sql = self.expr_to_sql(e)?;
+                            let sql = self.expr_to_sql_inner(e)?;
                             Ok(vec![sql])
                         })
                         .collect::<Result<Vec<_>>>()?;
@@ -470,7 +483,7 @@ impl Unparser<'_> {
                     let expr_ast_sets: Vec<Vec<AstExpr>> = rollup
                         .iter()
                         .map(|e| {
-                            let sql = self.expr_to_sql(e)?;
+                            let sql = self.expr_to_sql_inner(e)?;
                             Ok(vec![sql])
                         })
                         .collect::<Result<Vec<_>>>()?;
@@ -603,6 +616,88 @@ impl Unparser<'_> {
         }
     }
 
+    /// Given an expression of the form `((a + b) * (c * d))`, the parentheses are
+    /// redundant if the precedence of the nested expression is at least as high as the
+    /// surrounding operators' precedence; the above expression would become `(a + b) * c * d`.
+    /// Exception: at equal precedence they are kept after `-` or `/`, e.g. `a - (b - c)`.
+    ///
+    /// Also note that when fetching the precedence of a nested expression, we ignore other nested
+    /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`.
+    fn remove_unnecessary_nesting(
+        &self,
+        expr: ast::Expr,
+        left_op: &BinaryOperator,
+        right_op: &BinaryOperator,
+    ) -> ast::Expr {
+        match expr {
+            ast::Expr::Nested(nested) => {
+                let surrounding_precedence = self
+                    .sql_op_precedence(left_op)
+                    .max(self.sql_op_precedence(right_op));
+
+                let inner_precedence = self.inner_precedence(&nested);
+
+                let not_associative =
+                    matches!(left_op, BinaryOperator::Minus | 
BinaryOperator::Divide);
+
+                if inner_precedence == surrounding_precedence && 
not_associative {
+                    ast::Expr::Nested(Box::new(
+                        self.remove_unnecessary_nesting(*nested, LOWEST, 
LOWEST),
+                    ))
+                } else if inner_precedence >= surrounding_precedence {
+                    self.remove_unnecessary_nesting(*nested, left_op, right_op)
+                } else {
+                    ast::Expr::Nested(Box::new(
+                        self.remove_unnecessary_nesting(*nested, LOWEST, 
LOWEST),
+                    ))
+                }
+            }
+            ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp {
+                left: Box::new(self.remove_unnecessary_nesting(*left, left_op, 
&op)),
+                right: Box::new(self.remove_unnecessary_nesting(*right, &op, 
right_op)),
+                op,
+            },
+            ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new(
+                self.remove_unnecessary_nesting(*expr, left_op, IS),
+            )),
+            _ => expr,
+        }
+    }
+
+    fn inner_precedence(&self, expr: &ast::Expr) -> u8 {
+        match expr {
+            ast::Expr::Nested(_) | ast::Expr::Identifier(_) | 
ast::Expr::Value(_) => 100,
+            ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op),
+            // closest precedence we currently have to Between is PGLikeMatch
+            // (https://www.postgresql.org/docs/7.2/sql-precedence.html)
+            ast::Expr::Between { .. } => {
+                self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch)
+            }
+            _ => 0,
+        }
+    }
+
     pub(super) fn between_op_to_sql(
         &self,
         expr: ast::Expr,
@@ -618,6 +713,48 @@ impl Unparser<'_> {
         }
     }
 
+    fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 {
+        match self.sql_to_op(op) {
+            Ok(op) => op.precedence(),
+            Err(_) => 0,
+        }
+    }
+
+    fn sql_to_op(&self, op: &BinaryOperator) -> Result<Operator> {
+        match op {
+            ast::BinaryOperator::Eq => Ok(Operator::Eq),
+            ast::BinaryOperator::NotEq => Ok(Operator::NotEq),
+            ast::BinaryOperator::Lt => Ok(Operator::Lt),
+            ast::BinaryOperator::LtEq => Ok(Operator::LtEq),
+            ast::BinaryOperator::Gt => Ok(Operator::Gt),
+            ast::BinaryOperator::GtEq => Ok(Operator::GtEq),
+            ast::BinaryOperator::Plus => Ok(Operator::Plus),
+            ast::BinaryOperator::Minus => Ok(Operator::Minus),
+            ast::BinaryOperator::Multiply => Ok(Operator::Multiply),
+            ast::BinaryOperator::Divide => Ok(Operator::Divide),
+            ast::BinaryOperator::Modulo => Ok(Operator::Modulo),
+            ast::BinaryOperator::And => Ok(Operator::And),
+            ast::BinaryOperator::Or => Ok(Operator::Or),
+            ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch),
+            ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch),
+            ast::BinaryOperator::PGRegexNotMatch => 
Ok(Operator::RegexNotMatch),
+            ast::BinaryOperator::PGRegexNotIMatch => 
Ok(Operator::RegexNotIMatch),
+            ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch),
+            ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch),
+            ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch),
+            ast::BinaryOperator::PGNotILikeMatch => 
Ok(Operator::NotILikeMatch),
+            ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd),
+            ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr),
+            ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor),
+            ast::BinaryOperator::PGBitwiseShiftRight => 
Ok(Operator::BitwiseShiftRight),
+            ast::BinaryOperator::PGBitwiseShiftLeft => 
Ok(Operator::BitwiseShiftLeft),
+            ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat),
+            ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow),
+            ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt),
+            _ => not_impl_err!("unsupported operation: {op:?}"),
+        }
+    }
+
     fn op_to_sql(&self, op: &Operator) -> Result<ast::BinaryOperator> {
         match op {
             Operator::Eq => Ok(ast::BinaryOperator::Eq),
@@ -1538,6 +1675,7 @@ mod tests {
 
         Ok(())
     }
+
     #[test]
     fn custom_dialect() -> Result<()> {
         let dialect = CustomDialect::new(Some('\''));
diff --git a/datafusion/sql/src/unparser/mod.rs 
b/datafusion/sql/src/unparser/mod.rs
index fbbed4972b..e5ffbc8a21 100644
--- a/datafusion/sql/src/unparser/mod.rs
+++ b/datafusion/sql/src/unparser/mod.rs
@@ -29,11 +29,23 @@ pub mod dialect;
 
 pub struct Unparser<'a> {
     dialect: &'a dyn Dialect,
+    pretty: bool,
 }
 
 impl<'a> Unparser<'a> {
     pub fn new(dialect: &'a dyn Dialect) -> Self {
-        Self { dialect }
+        Self {
+            dialect,
+            pretty: false,
+        }
+    }
+
+    /// Allow the unparser to remove parentheses according to DataFusion's precedence rules.
+    /// The resulting SQL may be parsed differently (or be invalid) in other SQL engines with
+    /// different precedence rules, even if it is valid for DataFusion.
+    pub fn with_pretty(mut self, pretty: bool) -> Self {
+        self.pretty = pretty;
+        self
     }
 }
 
@@ -41,6 +53,7 @@ impl<'a> Default for Unparser<'a> {
     fn default() -> Self {
         Self {
             dialect: &DefaultDialect {},
+            pretty: false,
         }
     }
 }
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs 
b/datafusion/sql/tests/cases/plan_to_sql.rs
index 374403d853..91295b2e8a 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -104,26 +104,26 @@ fn roundtrip_statement() -> Result<()> {
             "select id, count(*) as cnt from (select p1.id as id from person 
p1 inner join person p2 on p1.id=p2.id) group by id",
             "select id, count(*), first_name from person group by first_name, 
id",
             "select id, sum(age), first_name from person group by first_name, 
id",
-            "select id, count(*), first_name 
-            from person 
+            "select id, count(*), first_name
+            from person
             where id!=3 and first_name=='test'
-            group by first_name, id 
+            group by first_name, id
             having count(*)>5 and count(*)<10
             order by count(*)",
-            r#"select id, count("First Name") as count_first_name, "Last Name" 
+            r#"select id, count("First Name") as count_first_name, "Last Name"
             from person_quoted_cols
             where id!=3 and "First Name"=='test'
-            group by "Last Name", id 
+            group by "Last Name", id
             having count_first_name>5 and count_first_name<10
             order by count_first_name, "Last Name""#,
             r#"select p.id, count("First Name") as count_first_name,
-            "Last Name", sum(qp.id/p.id - (select sum(id) from 
person_quoted_cols) ) / (select count(*) from person) 
+            "Last Name", sum(qp.id/p.id - (select sum(id) from 
person_quoted_cols) ) / (select count(*) from person)
             from (select id, "First Name", "Last Name" from 
person_quoted_cols) qp
             inner join (select * from person) p
             on p.id = qp.id
-            where p.id!=3 and "First Name"=='test' and qp.id in 
+            where p.id!=3 and "First Name"=='test' and qp.id in
             (select id from (select id, count(*) from person group by id 
having count(*) > 0))
-            group by "Last Name", p.id 
+            group by "Last Name", p.id
             having count_first_name>5 and count_first_name<10
             order by count_first_name, "Last Name""#,
             r#"SELECT j1_string as string FROM j1
@@ -134,12 +134,12 @@ fn roundtrip_statement() -> Result<()> {
             SELECT j2_string as string FROM j2
             ORDER BY string DESC
             LIMIT 10"#,
-            "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 
-            last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), 
+            "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
+            last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN 
UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
             first_name from person",
-            r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED 
PRECEDING AND UNBOUNDED FOLLOWING), 
+            r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED 
PRECEDING AND UNBOUNDED FOLLOWING),
             sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED 
PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
-            "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 
PRECEDING AND 2 FOLLOWING) from person",            
+            "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 
PRECEDING AND 2 FOLLOWING) from person",
         ];
 
     // For each test sql string, we transform as follows:
@@ -314,3 +314,78 @@ fn test_table_references_in_plan_to_sql() {
         "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"",
     );
 }
+
+#[test]
+fn test_pretty_roundtrip() -> Result<()> {
+    let schema = Schema::new(vec![
+        Field::new("id", DataType::Utf8, false),
+        Field::new("age", DataType::Utf8, false),
+    ]);
+
+    let df_schema = DFSchema::try_from(schema)?;
+
+    let context = MockContextProvider::default();
+    let sql_to_rel = SqlToRel::new(&context);
+
+    let unparser = Unparser::default().with_pretty(true);
+
+    let sql_to_pretty_unparse = vec![
+        ("((id < 5) OR (age = 8))", "id < 5 OR age = 8"),
+        ("((id + 5) * (age * 8))", "(id + 5) * age * 8"),
+        ("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"),
+        ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"),
+        ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"),
+        ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"),
+        ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"),
+        ("3 + 5 + 6 + 3", "3 + 5 + 6 + 3"),
+        ("3 + (5 + (6 + 3))", "3 + 5 + 6 + 3"),
+        ("3 + ((5 + 6) + 3)", "3 + 5 + 6 + 3"),
+        ("(3 + 5) + (6 + 3)", "3 + 5 + 6 + 3"),
+        ("((3 + 5) + (6 + 3))", "3 + 5 + 6 + 3"),
+        (
+            "((id > 10) OR (age BETWEEN 10 AND 20))",
+            "id > 10 OR age BETWEEN 10 AND 20",
+        ),
+        (
+            "((id > 10) * (age BETWEEN 10 AND 20))",
+            "(id > 10) * (age BETWEEN 10 AND 20)",
+        ),
+        ("id - (age - 8)", "id - (age - 8)"),
+        ("((id - age) - 8)", "id - age - 8"),
+        ("(id OR (age - 8))", "id OR age - 8"),
+        ("(id / (age - 8))", "id / (age - 8)"),
+        ("((id / age) * 8)", "id / age * 8"),
+        ("((age + 10) < 20) IS TRUE", "(age + 10 < 20) IS TRUE"),
+        (
+            "(20 > (age + 5)) IS NOT FALSE",
+            "(20 > age + 5) IS NOT FALSE",
+        ),
+        ("(true AND false) IS FALSE", "(true AND false) IS FALSE"),
+        ("true AND (false IS FALSE)", "true AND false IS FALSE"),
+    ];
+
+    for (sql, pretty) in sql_to_pretty_unparse.iter() {
+        let sql_expr = Parser::new(&GenericDialect {})
+            .try_with_sql(sql)?
+            .parse_expr()?;
+        let expr =
+            sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut 
PlannerContext::new())?;
+        let round_trip_sql = unparser.expr_to_sql(&expr)?.to_string();
+        assert_eq!(pretty.to_string(), round_trip_sql);
+
+        // verify that the pretty string parses to the same underlying Expr
+        let pretty_sql_expr = Parser::new(&GenericDialect {})
+            .try_with_sql(pretty)?
+            .parse_expr()?;
+
+        let pretty_expr = sql_to_rel.sql_to_expr(
+            pretty_sql_expr,
+            &df_schema,
+            &mut PlannerContext::new(),
+        )?;
+
+        assert_eq!(expr.to_string(), pretty_expr.to_string());
+    }
+
+    Ok(())
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to