This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5911d182ee feat: implement more expr_to_sql functionality (#9578)
5911d182ee is described below
commit 5911d182eef08afee4fbdef3da7642ee92d1314c
Author: Devin D'Angelo <[email protected]>
AuthorDate: Thu Mar 14 06:10:01 2024 -0400
feat: implement more expr_to_sql functionality (#9578)
* more impls
* fix tests
* cargo update dfcli
* fix custom_dialect test
* add tests and feature flag
* fix comment
* remove chrono use arrow-array conversions
* fix cargo lock again
* fix count distinct
* retry windows ci
* retry windows ci again
* add roundtrip tests
* cargo fmt
---
README.md | 1 +
datafusion-cli/Cargo.lock | 1 +
datafusion/sql/Cargo.toml | 4 +-
datafusion/sql/src/lib.rs | 1 +
datafusion/sql/src/unparser/expr.rs | 389 ++++++++++++++++++++++++++------
datafusion/sql/tests/sql_integration.rs | 24 +-
6 files changed, 338 insertions(+), 82 deletions(-)
diff --git a/README.md b/README.md
index e5ac9503be..abd727672a 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,7 @@ Default features:
- `parquet`: support for reading the [Apache Parquet] format
- `regex_expressions`: regular expression functions, such as `regexp_match`
- `unicode_expressions`: Include unicode aware functions such as
`character_length`
+- `unparser`: enables converting `LogicalPlan`s back into SQL
Optional features:
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index e0c7c4391b..1c2514811c 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1366,6 +1366,7 @@ name = "datafusion-sql"
version = "36.0.0"
dependencies = [
"arrow",
+ "arrow-array",
"arrow-schema",
"datafusion-common",
"datafusion-expr",
diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml
index 7739058a5c..ca2c1a240c 100644
--- a/datafusion/sql/Cargo.toml
+++ b/datafusion/sql/Cargo.toml
@@ -33,11 +33,13 @@ name = "datafusion_sql"
path = "src/lib.rs"
[features]
-default = ["unicode_expressions"]
+default = ["unicode_expressions", "unparser"]
unicode_expressions = []
+unparser = []
[dependencies]
arrow = { workspace = true }
+arrow-array = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs
index da66ee197a..e8e07eebe2 100644
--- a/datafusion/sql/src/lib.rs
+++ b/datafusion/sql/src/lib.rs
@@ -36,6 +36,7 @@ mod relation;
mod select;
mod set_expr;
mod statement;
+#[cfg(feature = "unparser")]
pub mod unparser;
pub mod utils;
mod values;
diff --git a/datafusion/sql/src/unparser/expr.rs
b/datafusion/sql/src/unparser/expr.rs
index 2a9fdd47ad..403a7c6193 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,12 +15,16 @@
// specific language governing permissions and limitations
// under the License.
-use datafusion_common::{not_impl_err, Column, Result, ScalarValue};
+use arrow_array::{Date32Array, Date64Array};
+use arrow_schema::DataType;
+use datafusion_common::{
+ internal_datafusion_err, not_impl_err, Column, Result, ScalarValue,
+};
use datafusion_expr::{
- expr::{Alias, InList, ScalarFunction, WindowFunction},
+ expr::{AggregateFunctionDefinition, Alias, InList, ScalarFunction,
WindowFunction},
Between, BinaryExpr, Case, Cast, Expr, Like, Operator,
};
-use sqlparser::ast;
+use sqlparser::ast::{self, Function, FunctionArg, Ident};
use super::Unparser;
@@ -36,7 +40,7 @@ use super::Unparser;
/// let expr = col("a").gt(lit(4));
/// let sql = expr_to_sql(&expr).unwrap();
///
-/// assert_eq!(format!("{}", sql), "a > 4")
+/// assert_eq!(format!("{}", sql), "(a > 4)")
/// ```
pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
let unparser = Unparser::default();
@@ -70,7 +74,7 @@ impl Unparser<'_> {
let r = self.expr_to_sql(right.as_ref())?;
let op = self.op_to_sql(op)?;
- Ok(self.binary_op_to_sql(l, r, op))
+ Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r,
op))))
}
Expr::Case(Case {
expr,
@@ -79,10 +83,15 @@ impl Unparser<'_> {
}) => {
not_impl_err!("Unsupported expression: {expr:?}")
}
- Expr::Cast(Cast { expr, data_type: _ }) => {
- not_impl_err!("Unsupported expression: {expr:?}")
+ Expr::Cast(Cast { expr, data_type }) => {
+ let inner_expr = self.expr_to_sql(expr)?;
+ Ok(ast::Expr::Cast {
+ expr: Box::new(inner_expr),
+ data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+ format: None,
+ })
}
- Expr::Literal(value) =>
Ok(ast::Expr::Value(self.scalar_to_sql(value)?)),
+ Expr::Literal(value) => Ok(self.scalar_to_sql(value)?),
Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr),
Expr::WindowFunction(WindowFunction {
fun: _,
@@ -103,6 +112,45 @@ impl Unparser<'_> {
}) => {
not_impl_err!("Unsupported expression: {expr:?}")
}
+ Expr::AggregateFunction(agg) => {
+ let func_name = if let
AggregateFunctionDefinition::BuiltIn(built_in) =
+ &agg.func_def
+ {
+ built_in.name()
+ } else {
+ return not_impl_err!(
+ "Only built in agg functions are supported, got
{agg:?}"
+ );
+ };
+
+ let args = agg
+ .args
+ .iter()
+ .map(|e| {
+ if matches!(e, Expr::Wildcard { qualifier: None }) {
+
Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
+ } else {
+ self.expr_to_sql(e).map(|e| {
+
FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
+ })
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(ast::Expr::Function(Function {
+ name: ast::ObjectName(vec![Ident {
+ value: func_name.to_string(),
+ quote_style: None,
+ }]),
+ args,
+ filter: None,
+ null_treatment: None,
+ over: None,
+ distinct: agg.distinct,
+ special: false,
+ order_by: vec![],
+ }))
+ }
_ => not_impl_err!("Unsupported expression: {expr:?}"),
}
}
@@ -174,139 +222,265 @@ impl Unparser<'_> {
}
}
-    fn scalar_to_sql(&self, v: &ScalarValue) -> Result<ast::Value> {
+    /// Converts a DataFusion [`ScalarValue`] into a SQL AST expression.
+    ///
+    /// Most scalars map directly to an `ast::Value` literal, but some need a
+    /// full `ast::Expr` to construct. For example, `ScalarValue::Date32(d)`
+    /// corresponds to the expression `CAST('datestr' AS DATE)`.
+    ///
+    /// Typed NULLs (`ScalarValue::X(None)`) are emitted as a bare `NULL`.
+    ///
+    /// # Errors
+    ///
+    /// Returns a "not implemented" error for scalar kinds that cannot be
+    /// unparsed yet. Returns an internal error if a date scalar cannot be
+    /// converted to a calendar date/datetime.
+    fn scalar_to_sql(&self, v: &ScalarValue) -> Result<ast::Expr> {
         match v {
-            ScalarValue::Null => Ok(ast::Value::Null),
-            ScalarValue::Boolean(Some(b)) => Ok(ast::Value::Boolean(b.to_owned())),
-            ScalarValue::Boolean(None) => Ok(ast::Value::Null),
-            ScalarValue::Float32(Some(f)) => Ok(ast::Value::Number(f.to_string(), false)),
-            ScalarValue::Float32(None) => Ok(ast::Value::Null),
-            ScalarValue::Float64(Some(f)) => Ok(ast::Value::Number(f.to_string(), false)),
-            ScalarValue::Float64(None) => Ok(ast::Value::Null),
+            ScalarValue::Null => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Boolean(Some(b)) => {
+                Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned())))
+            }
+            ScalarValue::Boolean(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Float32(Some(f)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
+            }
+            ScalarValue::Float32(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Float64(Some(f)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false)))
+            }
+            ScalarValue::Float64(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Decimal128(Some(_), ..) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Decimal128(None, ..) => Ok(ast::Value::Null),
+            ScalarValue::Decimal128(None, ..) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Decimal256(Some(_), ..) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Decimal256(None, ..) => Ok(ast::Value::Null),
-            ScalarValue::Int8(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)),
-            ScalarValue::Int8(None) => Ok(ast::Value::Null),
-            ScalarValue::Int16(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)),
-            ScalarValue::Int16(None) => Ok(ast::Value::Null),
-            ScalarValue::Int32(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)),
-            ScalarValue::Int32(None) => Ok(ast::Value::Null),
-            ScalarValue::Int64(Some(i)) => Ok(ast::Value::Number(i.to_string(), false)),
-            ScalarValue::Int64(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt8(Some(ui)) => Ok(ast::Value::Number(ui.to_string(), false)),
-            ScalarValue::UInt8(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt16(Some(ui)) => {
-                Ok(ast::Value::Number(ui.to_string(), false))
+            ScalarValue::Decimal256(None, ..) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int8(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
             }
-            ScalarValue::UInt16(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt32(Some(ui)) => {
-                Ok(ast::Value::Number(ui.to_string(), false))
+            ScalarValue::Int8(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int16(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
             }
-            ScalarValue::UInt32(None) => Ok(ast::Value::Null),
-            ScalarValue::UInt64(Some(ui)) => {
-                Ok(ast::Value::Number(ui.to_string(), false))
+            ScalarValue::Int16(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int32(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
+            }
+            ScalarValue::Int32(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Int64(Some(i)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(i.to_string(), false)))
             }
-            ScalarValue::UInt64(None) => Ok(ast::Value::Null),
-            ScalarValue::Utf8(Some(str)) => {
-                Ok(ast::Value::SingleQuotedString(str.to_string()))
+            ScalarValue::Int64(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt8(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
             }
-            ScalarValue::Utf8(None) => Ok(ast::Value::Null),
-            ScalarValue::LargeUtf8(Some(str)) => {
-                Ok(ast::Value::SingleQuotedString(str.to_string()))
+            ScalarValue::UInt8(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt16(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
+            }
+            ScalarValue::UInt16(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt32(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
             }
-            ScalarValue::LargeUtf8(None) => Ok(ast::Value::Null),
+            ScalarValue::UInt32(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::UInt64(Some(ui)) => {
+                Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false)))
+            }
+            ScalarValue::UInt64(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Utf8(Some(str)) => Ok(ast::Expr::Value(
+                ast::Value::SingleQuotedString(str.to_string()),
+            )),
+            ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value(
+                ast::Value::SingleQuotedString(str.to_string()),
+            )),
+            ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"),
-            ScalarValue::Binary(None) => Ok(ast::Value::Null),
+            ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::FixedSizeBinary(..) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
             ScalarValue::LargeBinary(Some(_)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::LargeBinary(None) => Ok(ast::Value::Null),
+            ScalarValue::LargeBinary(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::FixedSizeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"),
             ScalarValue::List(_a) => not_impl_err!("Unsupported scalar: {v:?}"),
             ScalarValue::LargeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"),
-            ScalarValue::Date32(Some(_d)) => not_impl_err!("Unsupported scalar: {v:?}"),
-            ScalarValue::Date32(None) => Ok(ast::Value::Null),
-            ScalarValue::Date64(Some(_d)) => not_impl_err!("Unsupported scalar: {v:?}"),
-            ScalarValue::Date64(None) => Ok(ast::Value::Null),
+            ScalarValue::Date32(Some(_)) => {
+                // Materialize a one-element array so arrow's temporal helper
+                // (`value_as_date`) can turn the day count into a calendar date.
+                // `ok_or_else` keeps the error message from being formatted on
+                // the (normal) success path.
+                let date = v
+                    .to_array()?
+                    .as_any()
+                    .downcast_ref::<Date32Array>()
+                    .ok_or_else(|| {
+                        internal_datafusion_err!(
+                            "Unable to downcast to Date32 from Date32 scalar"
+                        )
+                    })?
+                    .value_as_date(0)
+                    .ok_or_else(|| {
+                        internal_datafusion_err!("Unable to convert Date32 to NaiveDate")
+                    })?;
+
+                Ok(ast::Expr::Cast {
+                    expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString(
+                        date.to_string(),
+                    ))),
+                    data_type: ast::DataType::Date,
+                    format: None,
+                })
+            }
+            ScalarValue::Date32(None) => Ok(ast::Expr::Value(ast::Value::Null)),
+            ScalarValue::Date64(Some(_)) => {
+                // Same approach as Date32, but Date64 carries a datetime, so
+                // it is rendered as `CAST('<datetime>' AS DATETIME)`.
+                let datetime = v
+                    .to_array()?
+                    .as_any()
+                    .downcast_ref::<Date64Array>()
+                    .ok_or_else(|| {
+                        internal_datafusion_err!(
+                            "Unable to downcast to Date64 from Date64 scalar"
+                        )
+                    })?
+                    .value_as_datetime(0)
+                    .ok_or_else(|| {
+                        internal_datafusion_err!(
+                            "Unable to convert Date64 to NaiveDateTime"
+                        )
+                    })?;
+
+                Ok(ast::Expr::Cast {
+                    expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString(
+                        datetime.to_string(),
+                    ))),
+                    data_type: ast::DataType::Datetime(None),
+                    format: None,
+                })
+            }
+            ScalarValue::Date64(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Time32Second(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time32Second(None) => Ok(ast::Value::Null),
+            ScalarValue::Time32Second(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::Time32Millisecond(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time32Millisecond(None) => Ok(ast::Value::Null),
+            ScalarValue::Time32Millisecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::Time64Microsecond(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time64Microsecond(None) => Ok(ast::Value::Null),
+            ScalarValue::Time64Microsecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::Time64Nanosecond(Some(_t)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::Time64Nanosecond(None) => Ok(ast::Value::Null),
+            ScalarValue::Time64Nanosecond(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::TimestampSecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampSecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampSecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::TimestampMillisecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampMillisecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampMillisecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::TimestampMicrosecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampMicrosecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampMicrosecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::TimestampNanosecond(Some(_ts), _) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::TimestampNanosecond(None, _) => Ok(ast::Value::Null),
+            ScalarValue::TimestampNanosecond(None, _) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::IntervalYearMonth(Some(_i)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::IntervalYearMonth(None) => Ok(ast::Value::Null),
+            ScalarValue::IntervalYearMonth(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::IntervalDayTime(Some(_i)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::IntervalDayTime(None) => Ok(ast::Value::Null),
+            ScalarValue::IntervalDayTime(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::IntervalMonthDayNano(Some(_i)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::IntervalMonthDayNano(None) => Ok(ast::Value::Null),
+            ScalarValue::IntervalMonthDayNano(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::DurationSecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationSecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationSecond(None) => Ok(ast::Expr::Value(ast::Value::Null)),
             ScalarValue::DurationMillisecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationMillisecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationMillisecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::DurationMicrosecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationMicrosecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationMicrosecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::DurationNanosecond(Some(_d)) => {
                 not_impl_err!("Unsupported scalar: {v:?}")
             }
-            ScalarValue::DurationNanosecond(None) => Ok(ast::Value::Null),
+            ScalarValue::DurationNanosecond(None) => {
+                Ok(ast::Expr::Value(ast::Value::Null))
+            }
             ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"),
             ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"),
         }
     }
+
+    /// Maps an Arrow [`DataType`] to the SQL AST data type used when
+    /// unparsing casts (`CAST(expr AS <type>)`).
+    ///
+    /// # Errors
+    ///
+    /// Returns a "not implemented" error for Arrow types that have no SQL
+    /// mapping yet. Unsupported types must not panic (e.g. via `todo!()`):
+    /// this is reachable from the public `expr_to_sql` API for any `Cast`
+    /// expression, so callers need a recoverable error instead.
+    fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result<ast::DataType> {
+        match data_type {
+            DataType::Boolean => Ok(ast::DataType::Bool),
+            DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
+            DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
+            DataType::Int32 => Ok(ast::DataType::Integer(None)),
+            DataType::Int64 => Ok(ast::DataType::BigInt(None)),
+            DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
+            DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
+            DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)),
+            DataType::UInt64 => Ok(ast::DataType::UnsignedBigInt(None)),
+            DataType::Float32 => Ok(ast::DataType::Float(None)),
+            DataType::Float64 => Ok(ast::DataType::Double),
+            DataType::Date32 => Ok(ast::DataType::Date),
+            DataType::Date64 => Ok(ast::DataType::Datetime(None)),
+            DataType::Utf8 => Ok(ast::DataType::Varchar(None)),
+            DataType::LargeUtf8 => Ok(ast::DataType::Text),
+            // Listed exhaustively (no `_` arm) so that a new Arrow variant
+            // forces a deliberate decision here at compile time.
+            DataType::Null
+            | DataType::Float16
+            | DataType::Timestamp(_, _)
+            | DataType::Time32(_)
+            | DataType::Time64(_)
+            | DataType::Duration(_)
+            | DataType::Interval(_)
+            | DataType::Binary
+            | DataType::FixedSizeBinary(_)
+            | DataType::LargeBinary
+            | DataType::List(_)
+            | DataType::FixedSizeList(_, _)
+            | DataType::LargeList(_)
+            | DataType::Struct(_)
+            | DataType::Union(_, _)
+            | DataType::Dictionary(_, _)
+            | DataType::Decimal128(_, _)
+            | DataType::Decimal256(_, _)
+            | DataType::Map(_, _)
+            | DataType::RunEndEncoded(_, _) => {
+                not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
+            }
+        }
+    }
}
#[cfg(test)]
mod tests {
use datafusion_common::TableReference;
- use datafusion_expr::{col, lit};
+ use datafusion_expr::{col, expr::AggregateFunction, lit};
use crate::unparser::dialect::CustomDialect;
@@ -316,14 +490,81 @@ mod tests {
#[test]
fn expr_to_sql_ok() -> Result<()> {
- let tests: Vec<(Expr, &str)> = vec![(
- Expr::Column(Column {
- relation: Some(TableReference::partial("a", "b")),
- name: "c".to_string(),
- })
- .gt(lit(4)),
- r#"a.b.c > 4"#,
- )];
+ let tests: Vec<(Expr, &str)> = vec![
+ ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#),
+ (
+ Expr::Column(Column {
+ relation: Some(TableReference::partial("a", "b")),
+ name: "c".to_string(),
+ })
+ .gt(lit(4)),
+ r#"(a.b.c > 4)"#,
+ ),
+ (
+ Expr::Cast(Cast {
+ expr: Box::new(col("a")),
+ data_type: DataType::Date64,
+ }),
+ r#"CAST(a AS DATETIME)"#,
+ ),
+ (
+ Expr::Cast(Cast {
+ expr: Box::new(col("a")),
+ data_type: DataType::UInt32,
+ }),
+ r#"CAST(a AS INTEGER UNSIGNED)"#,
+ ),
+ (
+ Expr::Literal(ScalarValue::Date64(Some(0))),
+ r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#,
+ ),
+ (
+ Expr::Literal(ScalarValue::Date64(Some(10000))),
+ r#"CAST('1970-01-01 00:00:10' AS DATETIME)"#,
+ ),
+ (
+ Expr::Literal(ScalarValue::Date64(Some(-10000))),
+ r#"CAST('1969-12-31 23:59:50' AS DATETIME)"#,
+ ),
+ (
+ Expr::Literal(ScalarValue::Date32(Some(0))),
+ r#"CAST('1970-01-01' AS DATE)"#,
+ ),
+ (
+ Expr::Literal(ScalarValue::Date32(Some(10))),
+ r#"CAST('1970-01-11' AS DATE)"#,
+ ),
+ (
+ Expr::Literal(ScalarValue::Date32(Some(-1))),
+ r#"CAST('1969-12-31' AS DATE)"#,
+ ),
+ (
+ Expr::AggregateFunction(AggregateFunction {
+ func_def: AggregateFunctionDefinition::BuiltIn(
+ datafusion_expr::AggregateFunction::Sum,
+ ),
+ args: vec![col("a")],
+ distinct: false,
+ filter: None,
+ order_by: None,
+ null_treatment: None,
+ }),
+ "SUM(a)",
+ ),
+ (
+ Expr::AggregateFunction(AggregateFunction {
+ func_def: AggregateFunctionDefinition::BuiltIn(
+ datafusion_expr::AggregateFunction::Count,
+ ),
+ args: vec![Expr::Wildcard { qualifier: None }],
+ distinct: true,
+ filter: None,
+ order_by: None,
+ null_treatment: None,
+ }),
+ "COUNT(DISTINCT *)",
+ ),
+ ];
for (expr, expected) in tests {
let ast = expr_to_sql(&expr)?;
@@ -346,7 +587,7 @@ mod tests {
let actual = format!("{}", ast);
- let expected = r#"'a' > 4"#;
+ let expected = r#"('a' > 4)"#;
assert_eq!(actual, expected);
Ok(())
diff --git a/datafusion/sql/tests/sql_integration.rs
b/datafusion/sql/tests/sql_integration.rs
index fdf7ab8c3d..a6ea22db96 100644
--- a/datafusion/sql/tests/sql_integration.rs
+++ b/datafusion/sql/tests/sql_integration.rs
@@ -4493,8 +4493,18 @@ impl TableSource for EmptyTable {
#[test]
fn roundtrip_expr() {
let tests: Vec<(TableReference, &str, &str)> = vec![
- (TableReference::bare("person"), "age > 35", "age > 35"),
- (TableReference::bare("person"), "id = '10'", "id = '10'"),
+ (TableReference::bare("person"), "age > 35", "(age > 35)"),
+ (TableReference::bare("person"), "id = '10'", "(id = '10')"),
+ (
+ TableReference::bare("person"),
+ "CAST(id AS VARCHAR)",
+ "CAST(id AS VARCHAR)",
+ ),
+ (
+ TableReference::bare("person"),
+ "SUM((age * 2))",
+ "SUM((age * 2))",
+ ),
];
let roundtrip = |table, sql: &str| -> Result<String> {
@@ -4540,15 +4550,15 @@ fn roundtrip_statement() {
),
(
"select ta.j1_id from j1 ta where ta.j1_id > 1;",
- r#"SELECT ta.j1_id FROM j1 AS ta WHERE ta.j1_id > 1"#,
+ r#"SELECT ta.j1_id FROM j1 AS ta WHERE (ta.j1_id > 1)"#,
),
(
- "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on ta.j1_id =
tb.j2_id;",
- r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON
ta.j1_id = tb.j2_id"#,
+ "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id
= tb.j2_id);",
+ r#"SELECT ta.j1_id, tb.j2_string FROM j1 AS ta JOIN j2 AS tb ON
(ta.j1_id = tb.j2_id)"#,
),
(
- "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb
on ta.j1_id = tb.j2_id join j3 tc on ta.j1_id = tc.j3_id;",
- r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN
j2 AS tb ON ta.j1_id = tb.j2_id JOIN j3 AS tc ON ta.j1_id = tc.j3_id"#,
+ "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb
on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
+ r#"SELECT ta.j1_id, tb.j2_string, tc.j3_string FROM j1 AS ta JOIN
j2 AS tb ON (ta.j1_id = tb.j2_id) JOIN j3 AS tc ON (ta.j1_id = tc.j3_id)"#,
),
];