This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 3d1d28d287 fix: Add Int32 type override for Dialects (#12916)
3d1d28d287 is described below
commit 3d1d28d287d6584668bde510908f65ebe262d22e
Author: peasee <[email protected]>
AuthorDate: Thu Oct 17 21:33:40 2024 +1000
fix: Add Int32 type override for Dialects (#12916)
* fix: Add Int32 type override for Dialects
* fix: Dialect builder with_int32_cast_dtype:
* test: Fix with_int32 test
---
datafusion/sql/src/unparser/dialect.rs | 25 +++++++++++++++++++++++++
datafusion/sql/src/unparser/expr.rs | 30 +++++++++++++++++++++++++++++-
2 files changed, 54 insertions(+), 1 deletion(-)
diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index aef3b0dfab..cfc28f2c49 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -86,6 +86,12 @@ pub trait Dialect: Send + Sync {
ast::DataType::BigInt(None)
}
+ /// The SQL type to use for Arrow Int32 unparsing
+ /// Most dialects use Integer, but some, like MySQL, require SIGNED
+ fn int32_cast_dtype(&self) -> ast::DataType {
+ ast::DataType::Integer(None)
+ }
+
/// The SQL type to use for Timestamp unparsing
/// Most dialects use Timestamp, but some, like MySQL, require Datetime
/// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp
@@ -282,6 +288,10 @@ impl Dialect for MySqlDialect {
ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
}
+ fn int32_cast_dtype(&self) -> ast::DataType {
+ ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
+ }
+
fn timestamp_cast_dtype(
&self,
_time_unit: &TimeUnit,
@@ -347,6 +357,7 @@ pub struct CustomDialect {
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
int64_cast_dtype: ast::DataType,
+ int32_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: sqlparser::ast::DataType,
@@ -365,6 +376,7 @@ impl Default for CustomDialect {
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
int64_cast_dtype: ast::DataType::BigInt(None),
+ int32_cast_dtype: ast::DataType::Integer(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
timestamp_tz_cast_dtype: ast::DataType::Timestamp(
None,
@@ -424,6 +436,10 @@ impl Dialect for CustomDialect {
self.int64_cast_dtype.clone()
}
+ fn int32_cast_dtype(&self) -> ast::DataType {
+ self.int32_cast_dtype.clone()
+ }
+
fn timestamp_cast_dtype(
&self,
_time_unit: &TimeUnit,
@@ -482,6 +498,7 @@ pub struct CustomDialectBuilder {
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
int64_cast_dtype: ast::DataType,
+ int32_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: ast::DataType,
@@ -506,6 +523,7 @@ impl CustomDialectBuilder {
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
int64_cast_dtype: ast::DataType::BigInt(None),
+ int32_cast_dtype: ast::DataType::Integer(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
timestamp_tz_cast_dtype: ast::DataType::Timestamp(
None,
@@ -527,6 +545,7 @@ impl CustomDialectBuilder {
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
date_field_extract_style: self.date_field_extract_style,
int64_cast_dtype: self.int64_cast_dtype,
+ int32_cast_dtype: self.int32_cast_dtype,
timestamp_cast_dtype: self.timestamp_cast_dtype,
timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype,
date32_cast_dtype: self.date32_cast_dtype,
@@ -604,6 +623,12 @@ impl CustomDialectBuilder {
self
}
+ /// Customize the dialect with a specific SQL type for Int32 casting: Integer, SIGNED, etc.
+ pub fn with_int32_cast_dtype(mut self, int32_cast_dtype: ast::DataType) -> Self {
+ self.int32_cast_dtype = int32_cast_dtype;
+ self
+ }
+
/// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc.
pub fn with_timestamp_cast_dtype(
mut self,
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index b7491d1f88..1be5aa68bf 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -1352,7 +1352,7 @@ impl Unparser<'_> {
DataType::Boolean => Ok(ast::DataType::Bool),
DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
- DataType::Int32 => Ok(ast::DataType::Integer(None)),
+ DataType::Int32 => Ok(self.dialect.int32_cast_dtype()),
DataType::Int64 => Ok(self.dialect.int64_cast_dtype()),
DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
@@ -2253,6 +2253,34 @@ mod tests {
Ok(())
}
+ #[test]
+ fn custom_dialect_with_int32_cast_dtype() -> Result<()> {
+ let default_dialect = CustomDialectBuilder::new().build();
+ let mysql_dialect = CustomDialectBuilder::new()
+ .with_int32_cast_dtype(ast::DataType::Custom(
+ ObjectName(vec![Ident::new("SIGNED")]),
+ vec![],
+ ))
+ .build();
+
+ for (dialect, identifier) in
+ [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")]
+ {
+ let unparser = Unparser::new(&dialect);
+ let expr = Expr::Cast(Cast {
+ expr: Box::new(col("a")),
+ data_type: DataType::Int32,
+ });
+ let ast = unparser.expr_to_sql(&expr)?;
+
+ let actual = format!("{}", ast);
+ let expected = format!(r#"CAST(a AS {identifier})"#);
+
+ assert_eq!(actual, expected);
+ }
+ Ok(())
+ }
+
#[test]
fn custom_dialect_with_timestamp_cast_dtype() -> Result<()> {
let default_dialect = CustomDialectBuilder::new().build();
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]