This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5911d182ee feat: implement more expr_to_sql functionality (#9578)
5911d182ee is described below

commit 5911d182eef08afee4fbdef3da7642ee92d1314c
Author: Devin D'Angelo <[email protected]>
AuthorDate: Thu Mar 14 06:10:01 2024 -0400

    feat: implement more expr_to_sql functionality (#9578)
    
    * more impls
    
    * fix tests
    
    * cargo update dfcli
    
    * fix custom_dialect test
    
    * add tests and feature flag
    
    * fix comment
    
    * remove chrono use arrow-array conversions
    
    * fix cargo lock again
    
    * fix count distinct
    
    * retry windows ci
    
    * retry windows ci again
    
    * add roundtrip tests
    
    * cargo fmt
---
 README.md                               |   1 +
 datafusion-cli/Cargo.lock               |   1 +
 datafusion/sql/Cargo.toml               |   4 +-
 datafusion/sql/src/lib.rs               |   1 +
 datafusion/sql/src/unparser/expr.rs     | 389 ++++++++++++++++++++++++++------
 datafusion/sql/tests/sql_integration.rs |  24 +-
 6 files changed, 338 insertions(+), 82 deletions(-)

diff --git a/README.md b/README.md
index e5ac9503be..abd727672a 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,7 @@ Default features:
 - `parquet`: support for reading the [Apache Parquet] format
 - `regex_expressions`: regular expression functions, such as `regexp_match`
 - `unicode_expressions`: Include unicode aware functions such as `character_length`
+- `unparser` : enables support to reverse LogicalPlans back into SQL
 
 Optional features:
 
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index e0c7c4391b..1c2514811c 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1366,6 +1366,7 @@ name = "datafusion-sql"
 version = "36.0.0"
 dependencies = [
  "arrow",
+ "arrow-array",
  "arrow-schema",
  "datafusion-common",
  "datafusion-expr",
diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml
index 7739058a5c..ca2c1a240c 100644
--- a/datafusion/sql/Cargo.toml
+++ b/datafusion/sql/Cargo.toml
@@ -33,11 +33,13 @@ name = "datafusion_sql"
 path = "src/lib.rs"
 
 [features]
-default = ["unicode_expressions"]
+default = ["unicode_expressions", "unparser"]
 unicode_expressions = []
+unparser = []
 
 [dependencies]
 arrow = { workspace = true }
+arrow-array = { workspace = true }
 arrow-schema = { workspace = true }
 datafusion-common = { workspace = true, default-features = true }
 datafusion-expr = { workspace = true }
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index da66ee197a..e8e07eebe2 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -36,6 +36,7 @@ mod relation;
 mod select;
 mod set_expr;
 mod statement;
+#[cfg(feature = "unparser")]
 pub mod unparser;
 pub mod utils;
 mod values;
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index 2a9fdd47ad..403a7c6193 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,12 +15,16 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion_common::{not_impl_err, Column, Result, ScalarValue};
+use arrow_array::{Date32Array, Date64Array};
+use arrow_schema::DataType;
+use datafusion_common::{
+    internal_datafusion_err, not_impl_err, Column, Result, ScalarValue,
+};
 use datafusion_expr::{
-    expr::{Alias, InList, ScalarFunction, WindowFunction},
+    expr::{AggregateFunctionDefinition, Alias, InList, ScalarFunction, WindowFunction},
     Between, BinaryExpr, Case, Cast, Expr, Like, Operator,
 };
-use sqlparser::ast;
+use sqlparser::ast::{self, Function, FunctionArg, Ident};
 
 use super::Unparser;
 
@@ -36,7 +40,7 @@ use super::Unparser;
 /// let expr = col("a").gt(lit(4));
 /// let sql = expr_to_sql(&expr).unwrap();
 ///
-/// assert_eq!(format!("{}", sql), "a > 4")
+/// assert_eq!(format!("{}", sql), "(a > 4)")
 /// ```
 pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
     let unparser = Unparser::default();
@@ -70,7 +74,7 @@ impl Unparser<'_> {
                 let r = self.expr_to_sql(right.as_ref())?;
                 let op = self.op_to_sql(op)?;
 
-                Ok(self.binary_op_to_sql(l, r, op))
+                Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op))))
             }
             Expr::Case(Case {
                 expr,
@@ -79,10 +83,15 @@ impl Unparser<'_> {
             }) => {
                 not_impl_err!("Unsupported expression: {expr:?}")
             }
-            Expr::Cast(Cast { expr, data_type: _ }) => {
-                not_impl_err!("Unsupported expression: {expr:?}")
+            Expr::Cast(Cast { expr, data_type }) => {
+                let inner_expr = self.expr_to_sql(expr)?;
+                Ok(ast::Expr::Cast {
+                    expr: Box::new(inner_expr),
+                    data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+                    format: None,
+                })
             }
-            Expr::Literal(value) => 
Ok(ast::Expr::Value(self.scalar_to_sql(value)?)),
+            Expr::Literal(value) => Ok(self.scalar_to_sql(value)?),
             Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr),
             Expr::WindowFunction(WindowFunction {
                 fun: _,
@@ -103,6 +112,45 @@ impl Unparser<'_> {
             }) => {
                 not_impl_err!("Unsupported expression: {expr:?}")
             }
+            Expr::AggregateFunction(agg) => {
+                let func_name = if let 
AggregateFunctionDefinition::BuiltIn(built_in) =
+                    &agg.func_def
+                {
+                    built_in.name()
+                } else {
+                    return not_impl_err!(
+                        "Only built in agg functions are supported, got {agg:?}"
+                    );
+                };
+
+                let args = agg
+                    .args
+                    .iter()
+                    .map(|e| {
+                        if matches!(e, Expr::Wildcard { qualifier: None }) {
+                            
Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
+                        } else {
+                            self.expr_to_sql(e).map(|e| {
+                                
FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
+                            })
+                        }
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                Ok(ast::Expr::Function(Function {
+                    name: ast::ObjectName(vec![Ident {
+                        value: func_name.to_string(),
+                        quote_style: None,
+                    }]),
+                    args,
+                    filter: None,
+                    null_treatment: None,
+                    over: None,
+                    distinct: agg.distinct,
+                    special: false,
+                    order_by: vec![],
+                }))
+            }
             _ => not_impl_err!("Unsupported expression: {expr:?}"),
         }
     }
@@ -174,139 +222,265 @@ impl Unparser<'_> {
         }
     }
 
-    fn scalar_to_sql(&self, v: &ScalarValue) -> Result<ast::Value> {
+    /// DataFusion ScalarValues sometimes require a ast::Expr to construct.
+    /// For example ScalarValue::Date32(d) corresponds to the ast::Expr CAST('datestr' as DATE)
+    fn scalar_to_sql(&self, v: &ScalarValue) -> Result<ast::Expr> {
         match v {
-            ScalarValue::Null => Ok(ast::Value::Null),
-            ScalarValue::Boolean(Some(b)) => 
Ok(ast::Value::Boolean(b.to_owned())),
-            ScalarValue::Boolean(None) => Ok(ast::Value::Null),
-            ScalarValue::Float32(Some(f)) => 
Ok(ast::Value::Number(f.to_string(), false)),
-            ScalarValue::Float32(None) => Ok(ast::Value::Null),
-            ScalarValue::Float64(Some(f)) => 
Ok(ast::Value::Number(f.to_string(), false)),
-            ScalarValue::Float64(None) => Ok(ast::Value::Null),
+            ScalarValue::Null => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Boolean(Some(b)) => {
+                Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned())))
+            }
+            ScalarValue::Boolean(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Float32(Some(f)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
+            }
+            ScalarValue::Float32(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Float64(Some(f)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
+            }
+            ScalarValue::Float64(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Decimal128(Some(_), ..) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Decimal128(None, ..) => Ok(ast::Value::Null),
+            ScalarValue::Decimal128(None, ..) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Decimal256(Some(_), ..) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Decimal256(None, ..) => Ok(ast::Value::Null),
-            ScalarValue::Int8(Some(i)) => Ok(ast::Value::Number(i.to_string(), 
false)),
-            ScalarValue::Int8(None) => Ok(ast::Value::Null),
-            ScalarValue::Int16(Some(i)) => 
Ok(ast::Value::Number(i.to_string(), false)),
-            ScalarValue::Int16(None) => Ok(ast::Value::Null),
-            ScalarValue::Int32(Some(i)) => 
Ok(ast::Value::Number(i.to_string(), false)),
-            ScalarValue::Int32(None) => Ok(ast::Value::Null),
-            ScalarValue::Int64(Some(i)) => 
Ok(ast::Value::Number(i.to_string(), false)),
-            ScalarValue::Int64(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt8(Some(ui)) => 
Ok(ast::Value::Number(ui.to_string(), false)),
-            ScalarValue::UInt8(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt16(Some(ui)) => {
-                Ok(ast::Value::Number(ui.to_string(), false))
+            ScalarValue::Decimal256(None, ..) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int8(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
             }
-            ScalarValue::UInt16(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt32(Some(ui)) => {
-                Ok(ast::Value::Number(ui.to_string(), false))
+            ScalarValue::Int8(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int16(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
             }
-            ScalarValue::UInt32(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt64(Some(ui)) => {
-                Ok(ast::Value::Number(ui.to_string(), false))
+            ScalarValue::Int16(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int32(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
+            }
+            ScalarValue::Int32(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int64(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
             }
-            ScalarValue::UInt64(None) => Ok(ast::Value::Null),
-            ScalarValue::Utf8(Some(str)) => {
-                Ok(ast::Value::SingleQuotedString(str.to_string()))
+            ScalarValue::Int64(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt8(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
             }
-            ScalarValue::Utf8(None) => Ok(ast::Value::Null),
-            ScalarValue::LargeUtf8(Some(str)) => {
-                Ok(ast::Value::SingleQuotedString(str.to_string()))
+            ScalarValue::UInt8(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt16(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
+            }
+            ScalarValue::UInt16(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt32(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
             }
-            ScalarValue::LargeUtf8(None) => Ok(ast::Value::Null),
+            ScalarValue::UInt32(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt64(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
+            }
+            ScalarValue::UInt64(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Utf8(Some(str)) => Ok(ast::Expr::Value(
+                ast::Value::SingleQuotedString(str.to_string()),
+            )),
+            ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value(
+                ast::Value::SingleQuotedString(str.to_string()),
+            )),
+            ScalarValue::LargeUtf8(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: 
{v:?}"),
-            ScalarValue::Binary(None) => Ok(ast::Value::Null),
+            ScalarValue::Binary(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::FixedSizeBinary(..) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
             ScalarValue::LargeBinary(Some(_)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::LargeBinary(None) => Ok(ast::Value::Null),
+            ScalarValue::LargeBinary(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::FixedSizeList(_a) => not_impl_err!("Unsupported 
scalar: {v:?}"),
             ScalarValue::List(_a) => not_impl_err!("Unsupported scalar: 
{v:?}"),
             ScalarValue::LargeList(_a) => not_impl_err!("Unsupported scalar: 
{v:?}"),
-            ScalarValue::Date32(Some(_d)) => not_impl_err!("Unsupported 
scalar: {v:?}"),
-            ScalarValue::Date32(None) => Ok(ast::Value::Null),
-            ScalarValue::Date64(Some(_d)) => not_impl_err!("Unsupported 
scalar: {v:?}"),
-            ScalarValue::Date64(None) => Ok(ast::Value::Null),
+            ScalarValue::Date32(Some(_)) => {
+                let date = v
+                    .to_array()?
+                    .as_any()
+                    .downcast_ref::<Date32Array>()
+                    .ok_or(internal_datafusion_err!(
+                        "Unable to downcast to Date32 from Date32 scalar"
+                    ))?
+                    .value_as_date(0)
+                    .ok_or(internal_datafusion_err!(
+                        "Unable to convert Date32 to NaiveDate"
+                    ))?;
+
+                Ok(ast::Expr::Cast {
+                    expr: 
Box::new(ast::Expr::Value(ast::Value::SingleQuotedString(
+                        date.to_string(),
+                    ))),
+                    data_type: ast::DataType::Date,
+                    format: None,
+                })
+            }
+            ScalarValue::Date32(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Date64(Some(_)) => {
+                let datetime = v
+                    .to_array()?
+                    .as_any()
+                    .downcast_ref::<Date64Array>()
+                    .ok_or(internal_datafusion_err!(
+                        "Unable to downcast to Date64 from Date64 scalar"
+                    ))?
+                    .value_as_datetime(0)
+                    .ok_or(internal_datafusion_err!(
+                        "Unable to convert Date64 to NaiveDateTime"
+                    ))?;
+
+                Ok(ast::Expr::Cast {
+                    expr: 
Box::new(ast::Expr::Value(ast::Value::SingleQuotedString(
+                        datetime.to_string(),
+                    ))),
+                    data_type: ast::DataType::Datetime(None),
+                    format: None,
+                })
+            }
+            ScalarValue::Date64(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Time32Second(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time32Second(None) => Ok(ast::Value::Null),
+            ScalarValue::Time32Second(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Time32Millisecond(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time32Millisecond(None) => Ok(ast::Value::Null),
+            ScalarValue::Time32Millisecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::Time64Microsecond(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time64Microsecond(None) => Ok(ast::Value::Null),
+            ScalarValue::Time64Microsecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::Time64Nanosecond(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time64Nanosecond(None) => Ok(ast::Value::Null),
+            ScalarValue::Time64Nanosecond(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::TimestampSecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampSecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampSecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::TimestampMillisecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampMillisecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampMillisecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::TimestampMicrosecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampMicrosecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampMicrosecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::TimestampNanosecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampNanosecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampNanosecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::IntervalYearMonth(Some(_i)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::IntervalYearMonth(None) => Ok(ast::Value::Null),
+            ScalarValue::IntervalYearMonth(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::IntervalDayTime(Some(_i)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::IntervalDayTime(None) => Ok(ast::Value::Null),
+            ScalarValue::IntervalDayTime(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::IntervalMonthDayNano(Some(_i)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::IntervalMonthDayNano(None) => Ok(ast::Value::Null),
+            ScalarValue::IntervalMonthDayNano(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::DurationSecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationSecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationSecond(None) => 
Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::DurationMillisecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationMillisecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationMillisecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::DurationMicrosecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationMicrosecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationMicrosecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::DurationNanosecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationNanosecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationNanosecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: 
{v:?}"),
             ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: 
{v:?}"),
         }
     }
+
+    fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result<ast::DataType> {
+        match data_type {
+            DataType::Null => {
+                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
+            }
+            DataType::Boolean => Ok(ast::DataType::Bool),
+            DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
+            DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
+            DataType::Int32 => Ok(ast::DataType::Integer(None)),
+            DataType::Int64 => Ok(ast::DataType::BigInt(None)),
+            DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
+            DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
+            DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)),
+            DataType::UInt64 => Ok(ast::DataType::UnsignedBigInt(None)),
+            DataType::Float16 => {
+                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
+            }
+            DataType::Float32 => Ok(ast::DataType::Float(None)),
+            DataType::Float64 => Ok(ast::DataType::Double),
+            DataType::Timestamp(_, _) => {
+                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
+            }
+            DataType::Date32 => Ok(ast::DataType::Date),
+            DataType::Date64 => Ok(ast::DataType::Datetime(None)),
+            DataType::Time32(_) => todo!(),
+            DataType::Time64(_) => todo!(),
+            DataType::Duration(_) => todo!(),
+            DataType::Interval(_) => todo!(),
+            DataType::Binary => todo!(),
+            DataType::FixedSizeBinary(_) => todo!(),
+            DataType::LargeBinary => todo!(),
+            DataType::Utf8 => Ok(ast::DataType::Varchar(None)),
+            DataType::LargeUtf8 => Ok(ast::DataType::Text),
+            DataType::List(_) => todo!(),
+            DataType::FixedSizeList(_, _) => todo!(),
+            DataType::LargeList(_) => todo!(),
+            DataType::Struct(_) => todo!(),
+            DataType::Union(_, _) => todo!(),
+            DataType::Dictionary(_, _) => todo!(),
+            DataType::Decimal128(_, _) => todo!(),
+            DataType::Decimal256(_, _) => todo!(),
+            DataType::Map(_, _) => todo!(),
+            DataType::RunEndEncoded(_, _) => todo!(),
+        }
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use datafusion_common::TableReference;
-    use datafusion_expr::{col, lit};
+    use datafusion_expr::{col, expr::AggregateFunction, lit};
 
     use crate::unparser::dialect::CustomDialect;
 
@@ -316,14 +490,81 @@ mod tests {
 
     #[test]
     fn expr_to_sql_ok() -> Result<()> {
-        let tests: Vec<(Expr, &str)> = vec![(
-            Expr::Column(Column {
-                relation: Some(TableReference::partial("a", "b")),
-                name: "c".to_string(),
-            })
-            .gt(lit(4)),
-            r#"a.b.c > 4"#,
-        )];
+        let tests: Vec<(Expr, &str)> = vec![
+            ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#),
+            (
+                Expr::Column(Column {
+                    relation: Some(TableReference::partial("a", "b")),
+                    name: "c".to_string(),
+                })
+                .gt(lit(4)),
+                r#"(a.b.c > 4)"#,
+            ),
+            (
+                Expr::Cast(Cast {
+                    expr: Box::new(col("a")),
+                    data_type: DataType::Date64,
+                }),
+                r#"CAST(a AS DATETIME)"#,
+            ),
+            (
+                Expr::Cast(Cast {
+                    expr: Box::new(col("a")),
+                    data_type: DataType::UInt32,
+                }),
+                r#"CAST(a AS INTEGER UNSIGNED)"#,
+            ),
+            (
+                Expr::Literal(ScalarValue::Date64(Some(0))),
+                r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#,
+            ),
+            (
+                Expr::Literal(ScalarValue::Date64(Some(10000))),
+                r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#,
+            ),
+            (
+                Expr::Literal(ScalarValue::Date64(Some(-10000))),
+                r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#,
+            ),
+            (
+                Expr::Literal(ScalarValue::Date32(Some(0))),
+                r#"CAST('1970-01-01' AS DATE)"#,
+            ),
+            (
+                Expr::Literal(ScalarValue::Date32(Some(10))),
+                r#"CAST('1970-01-11' AS DATE)"#,
+            ),
+            (
+                Expr::Literal(ScalarValue::Date32(Some(-1))),
+                r#"CAST('1969-12-31' AS DATE)"#,
+            ),
+            (
+                Expr::AggregateFunction(AggregateFunction {
+                    func_def: AggregateFunctionDefinition::BuiltIn(
+                        datafusion_expr::AggregateFunction::Sum,
+                    ),
+                    args: vec![col("a")],
+                    distinct: false,
+                    filter: None,
+                    order_by: None,
+                    null_treatment: None,
+                }),
+                "SUM(a)",
+            ),
+            (
+                Expr::AggregateFunction(AggregateFunction {
+                    func_def: AggregateFunctionDefinition::BuiltIn(
+                        datafusion_expr::AggregateFunction::Count,
+                    ),
+                    args: vec![Expr::Wildcard { qualifier: None }],
+                    distinct: true,
+                    filter: None,
+                    order_by: None,
+                    null_treatment: None,
+                }),
+                "COUNT(DISTINCT *)",
+            ),
+        ];
 
         for (expr, expected) in tests {
             let ast = expr_to_sql(&expr)?;
@@ -346,7 +587,7 @@ mod tests {
 
         let actual = format!("{}", ast);
 
-        let expected = r#"'a' > 4"#;
+        let expected = r#"('a' > 4)"#;
         assert_eq!(actual, expected);
 
         Ok(())
diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs
index fdf7ab8c3d..a6ea22db96 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -4493,8 +4493,18 @@ impl TableSource for EmptyTable {
 #[test]
 fn roundtrip_expr() {
     let tests: Vec<(TableReference, &str, &str)> = vec![
-        (TableReference::bare("person"), "age > 35", "age > 35"),
-        (TableReference::bare("person"), "id = '10'", "id = '10'"),
+        (TableReference::bare("person"), "age > 35", "(age > 35)"),
+        (TableReference::bare("person"), "id = '10'", "(id = '10')"),
+        (
+            TableReference::bare("person"),
+            "CAST(id AS VARCHAR)",
+            "CAST(id AS VARCHAR)",
+        ),
+        (
+            TableReference::bare("person"),
+            "SUM((age * 2))",
+            "SUM((age * 2))",
+        ),
     ];
 
     let roundtrip = |table, sql: &str| -> Result<String> {
@@ -4540,15 +4550,15 @@ fn roundtrip_statement() {
         ),
         (
             "select ta.j1_id from j1 ta where ta.j1_id > 1;",
-            r#"SELECT ta.j1_id FROM j1 AS ta WHERE ta.j1_id > 1"#,
+            r#"SELECT ta.j1_id FROM j1 AS ta WHERE (ta.j1_id > 1)"#,
         ),
         (
-            "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on ta.j1_id = tb.j2_id;",
-            r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON ta.j1_id = tb.j2_id"#,
+            "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);",
+            r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON (ta.j1_id = tb.j2_id)"#,
         ),
         (
-            "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on ta.j1_id = tb.j2_id join j3 tc on ta.j1_id = tc.j3_id;",
-            r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN j2 AS tb ON ta.j1_id = tb.j2_id JOIN j3 AS tc ON ta.j1_id = tc.j3_id"#,
+            "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
+            r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN j2 AS tb ON (ta.j1_id = tb.j2_id) JOIN j3 AS tc ON (ta.j1_id = tc.j3_id)"#,
         ),
     ];
 

Reply via email to