This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d1d28d287 fix: Add Int32 type override for Dialects (#12916)
3d1d28d287 is described below

commit 3d1d28d287d6584668bde510908f65ebe262d22e
Author: peasee <[email protected]>
AuthorDate: Thu Oct 17 21:33:40 2024 +1000

    fix: Add Int32 type override for Dialects (#12916)
    
    * fix: Add Int32 type override for Dialects
    
    * fix: Dialect builder with_int32_cast_dtype
    
    * test: Fix with_int32 test
---
 datafusion/sql/src/unparser/dialect.rs | 25 +++++++++++++++++++++++++
 datafusion/sql/src/unparser/expr.rs    | 30 +++++++++++++++++++++++++++++-
 2 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index aef3b0dfab..cfc28f2c49 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -86,6 +86,12 @@ pub trait Dialect: Send + Sync {
         ast::DataType::BigInt(None)
     }
 
+    /// The SQL type to use for Arrow Int32 unparsing
+    /// Most dialects use Integer, but some, like MySQL, require SIGNED
+    fn int32_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::Integer(None)
+    }
+
     /// The SQL type to use for Timestamp unparsing
     /// Most dialects use Timestamp, but some, like MySQL, require Datetime
     /// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp
@@ -282,6 +288,10 @@ impl Dialect for MySqlDialect {
         ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
     }
 
+    fn int32_cast_dtype(&self) -> ast::DataType {
+        ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
+    }
+
     fn timestamp_cast_dtype(
         &self,
         _time_unit: &TimeUnit,
@@ -347,6 +357,7 @@ pub struct CustomDialect {
     large_utf8_cast_dtype: ast::DataType,
     date_field_extract_style: DateFieldExtractStyle,
     int64_cast_dtype: ast::DataType,
+    int32_cast_dtype: ast::DataType,
     timestamp_cast_dtype: ast::DataType,
     timestamp_tz_cast_dtype: ast::DataType,
     date32_cast_dtype: sqlparser::ast::DataType,
@@ -365,6 +376,7 @@ impl Default for CustomDialect {
             large_utf8_cast_dtype: ast::DataType::Text,
             date_field_extract_style: DateFieldExtractStyle::DatePart,
             int64_cast_dtype: ast::DataType::BigInt(None),
+            int32_cast_dtype: ast::DataType::Integer(None),
             timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
             timestamp_tz_cast_dtype: ast::DataType::Timestamp(
                 None,
@@ -424,6 +436,10 @@ impl Dialect for CustomDialect {
         self.int64_cast_dtype.clone()
     }
 
+    fn int32_cast_dtype(&self) -> ast::DataType {
+        self.int32_cast_dtype.clone()
+    }
+
     fn timestamp_cast_dtype(
         &self,
         _time_unit: &TimeUnit,
@@ -482,6 +498,7 @@ pub struct CustomDialectBuilder {
     large_utf8_cast_dtype: ast::DataType,
     date_field_extract_style: DateFieldExtractStyle,
     int64_cast_dtype: ast::DataType,
+    int32_cast_dtype: ast::DataType,
     timestamp_cast_dtype: ast::DataType,
     timestamp_tz_cast_dtype: ast::DataType,
     date32_cast_dtype: ast::DataType,
@@ -506,6 +523,7 @@ impl CustomDialectBuilder {
             large_utf8_cast_dtype: ast::DataType::Text,
             date_field_extract_style: DateFieldExtractStyle::DatePart,
             int64_cast_dtype: ast::DataType::BigInt(None),
+            int32_cast_dtype: ast::DataType::Integer(None),
             timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
             timestamp_tz_cast_dtype: ast::DataType::Timestamp(
                 None,
@@ -527,6 +545,7 @@ impl CustomDialectBuilder {
             large_utf8_cast_dtype: self.large_utf8_cast_dtype,
             date_field_extract_style: self.date_field_extract_style,
             int64_cast_dtype: self.int64_cast_dtype,
+            int32_cast_dtype: self.int32_cast_dtype,
             timestamp_cast_dtype: self.timestamp_cast_dtype,
             timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype,
             date32_cast_dtype: self.date32_cast_dtype,
@@ -604,6 +623,12 @@ impl CustomDialectBuilder {
         self
     }
 
+    /// Customize the dialect with a specific SQL type for Int32 casting: Integer, SIGNED, etc.
+    pub fn with_int32_cast_dtype(mut self, int32_cast_dtype: ast::DataType) -> Self {
+        self.int32_cast_dtype = int32_cast_dtype;
+        self
+    }
+
     /// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc.
     pub fn with_timestamp_cast_dtype(
         mut self,
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index b7491d1f88..1be5aa68bf 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -1352,7 +1352,7 @@ impl Unparser<'_> {
             DataType::Boolean => Ok(ast::DataType::Bool),
             DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
             DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
-            DataType::Int32 => Ok(ast::DataType::Integer(None)),
+            DataType::Int32 => Ok(self.dialect.int32_cast_dtype()),
             DataType::Int64 => Ok(self.dialect.int64_cast_dtype()),
             DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
             DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
@@ -2253,6 +2253,34 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn custom_dialect_with_int32_cast_dtype() -> Result<()> {
+        let default_dialect = CustomDialectBuilder::new().build();
+        let mysql_dialect = CustomDialectBuilder::new()
+            .with_int32_cast_dtype(ast::DataType::Custom(
+                ObjectName(vec![Ident::new("SIGNED")]),
+                vec![],
+            ))
+            .build();
+
+        for (dialect, identifier) in
+            [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")]
+        {
+            let unparser = Unparser::new(&dialect);
+            let expr = Expr::Cast(Cast {
+                expr: Box::new(col("a")),
+                data_type: DataType::Int32,
+            });
+            let ast = unparser.expr_to_sql(&expr)?;
+
+            let actual = format!("{}", ast);
+            let expected = format!(r#"CAST(a AS {identifier})"#);
+
+            assert_eq!(actual, expected);
+        }
+        Ok(())
+    }
+
     #[test]
     fn custom_dialect_with_timestamp_cast_dtype() -> Result<()> {
         let default_dialect = CustomDialectBuilder::new().build();


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to