This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 88fa0dfce6 Add `Field` to `Expr::Cast` -- allow logical expressions to 
express a cast to an extension type (#18136)
88fa0dfce6 is described below

commit 88fa0dfce64d146508319fd3dd8a3eba137ed410
Author: Dewey Dunnington <[email protected]>
AuthorDate: Tue Mar 3 11:01:16 2026 -0600

    Add `Field` to `Expr::Cast` -- allow logical expressions to express a cast 
to an extension type (#18136)
    
    ## Which issue does this PR close?
    
    - Closes #18060.
    
    I am sorry that I missed the previous PR implementing this (
    https://github.com/apache/datafusion/pull/18120 ) and I'm also happy to
    review that one instead of updating this!
    
    ## Rationale for this change
    
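    Other systems that interact with the logical plan (e.g., SQL, Substrait)
    can express types that are not strictly within the arrow DataType enum.
    Storing a field (data type plus metadata) as the cast destination lets a
    logical expression express a cast to such an extension type.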
    
    ## What changes are included in this PR?
    
    For the Cast and TryCast structs, the destination data type was changed
    from a DataType to a FieldRef.
    
    ## Are these changes tested?
    
    Yes.
    
    ## Are there any user-facing changes?
    
    Yes, any code using `Cast { .. }` to create an expression would need to
    use `Cast::new()` instead (or `Cast::new_from_field()` to pass on field
    metadata if it has it). Existing matches will need to be updated for the
    `data_type` -> `field` member rename.
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../provider_filter_pushdown.rs                    |   2 +-
 .../user_defined/user_defined_scalar_functions.rs  |   5 +-
 datafusion/expr/src/expr.rs                        |  60 +++++---
 datafusion/expr/src/expr_rewriter/order_by.rs      |   8 +-
 datafusion/expr/src/expr_schema.rs                 |  23 ++-
 datafusion/expr/src/tree_node.rs                   |   8 +-
 datafusion/functions/src/core/arrow_cast.rs        |   8 +-
 datafusion/optimizer/src/eliminate_outer_join.rs   |   4 +-
 .../src/simplify_expressions/expr_simplifier.rs    |   5 +-
 datafusion/physical-expr/src/planner.rs            |  89 ++++++++++--
 datafusion/proto/proto/datafusion.proto            |   4 +
 datafusion/proto/src/generated/pbjson.rs           |  72 ++++++++++
 datafusion/proto/src/generated/prost.rs            |  14 ++
 datafusion/proto/src/logical_plan/from_proto.rs    |  20 ++-
 datafusion/proto/src/logical_plan/to_proto.rs      |  12 +-
 datafusion/sql/src/unparser/expr.rs                | 160 +++++++++------------
 .../src/logical_plan/consumer/expr/cast.rs         |  11 +-
 .../substrait/src/logical_plan/consumer/types.rs   |  20 ++-
 .../src/logical_plan/producer/expr/cast.rs         |  16 +--
 .../substrait/src/logical_plan/producer/types.rs   |  44 +++---
 20 files changed, 387 insertions(+), 198 deletions(-)

diff --git 
a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs 
b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
index 6bb119fcb7..8078b0a7ec 100644
--- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs
@@ -207,7 +207,7 @@ impl TableProvider for CustomProvider {
                     Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64,
                     Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64,
                     Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i,
-                    Expr::Cast(Cast { expr, data_type: _ }) => match 
expr.deref() {
+                    Expr::Cast(Cast { expr, field: _ }) => match expr.deref() {
                         Expr::Literal(lit_value, _) => match lit_value {
                             ScalarValue::Int8(Some(v)) => *v as i64,
                             ScalarValue::Int16(Some(v)) => *v as i64,
diff --git 
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs 
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index b4ce3a03db..025ee9767c 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -715,10 +715,7 @@ impl ScalarUDFImpl for CastToI64UDF {
             arg
         } else {
             // need to use an actual cast to get the correct type
-            Expr::Cast(datafusion_expr::Cast {
-                expr: Box::new(arg),
-                data_type: DataType::Int64,
-            })
+            Expr::Cast(datafusion_expr::Cast::new(Box::new(arg), 
DataType::Int64))
         };
         // return the newly written argument to DataFusion
         Ok(ExprSimplifyResult::Simplified(new_expr))
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 87e8e029a6..ef0187a878 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -32,6 +32,8 @@ use crate::{ExprSchemable, Operator, Signature, WindowFrame, 
WindowUDF};
 
 use arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
+use datafusion_common::datatype::DataTypeExt;
+use datafusion_common::metadata::format_type_and_metadata;
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeContainer, 
TreeNodeRecursion,
 };
@@ -800,13 +802,20 @@ pub struct Cast {
     /// The expression being cast
     pub expr: Box<Expr>,
     /// The `DataType` the expression will yield
-    pub data_type: DataType,
+    pub field: FieldRef,
 }
 
 impl Cast {
     /// Create a new Cast expression
     pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
-        Self { expr, data_type }
+        Self {
+            expr,
+            field: data_type.into_nullable_field_ref(),
+        }
+    }
+
+    pub fn new_from_field(expr: Box<Expr>, field: FieldRef) -> Self {
+        Self { expr, field }
     }
 }
 
@@ -816,13 +825,20 @@ pub struct TryCast {
     /// The expression being cast
     pub expr: Box<Expr>,
     /// The `DataType` the expression will yield
-    pub data_type: DataType,
+    pub field: FieldRef,
 }
 
 impl TryCast {
     /// Create a new TryCast expression
     pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
-        Self { expr, data_type }
+        Self {
+            expr,
+            field: data_type.into_nullable_field_ref(),
+        }
+    }
+
+    pub fn new_from_field(expr: Box<Expr>, field: FieldRef) -> Self {
+        Self { expr, field }
     }
 }
 
@@ -2323,23 +2339,23 @@ impl NormalizeEq for Expr {
             (
                 Expr::Cast(Cast {
                     expr: self_expr,
-                    data_type: self_data_type,
+                    field: self_field,
                 }),
                 Expr::Cast(Cast {
                     expr: other_expr,
-                    data_type: other_data_type,
+                    field: other_field,
                 }),
             )
             | (
                 Expr::TryCast(TryCast {
                     expr: self_expr,
-                    data_type: self_data_type,
+                    field: self_field,
                 }),
                 Expr::TryCast(TryCast {
                     expr: other_expr,
-                    data_type: other_data_type,
+                    field: other_field,
                 }),
-            ) => self_data_type == other_data_type && 
self_expr.normalize_eq(other_expr),
+            ) => self_field == other_field && 
self_expr.normalize_eq(other_expr),
             (
                 Expr::ScalarFunction(ScalarFunction {
                     func: self_func,
@@ -2655,15 +2671,9 @@ impl HashNode for Expr {
                 when_then_expr: _when_then_expr,
                 else_expr: _else_expr,
             }) => {}
-            Expr::Cast(Cast {
-                expr: _expr,
-                data_type,
-            })
-            | Expr::TryCast(TryCast {
-                expr: _expr,
-                data_type,
-            }) => {
-                data_type.hash(state);
+            Expr::Cast(Cast { expr: _expr, field })
+            | Expr::TryCast(TryCast { expr: _expr, field }) => {
+                field.hash(state);
             }
             Expr::ScalarFunction(ScalarFunction { func, args: _args }) => {
                 func.hash(state);
@@ -3369,11 +3379,15 @@ impl Display for Expr {
                 }
                 write!(f, "END")
             }
-            Expr::Cast(Cast { expr, data_type }) => {
-                write!(f, "CAST({expr} AS {data_type})")
+            Expr::Cast(Cast { expr, field }) => {
+                let formatted =
+                    format_type_and_metadata(field.data_type(), 
Some(field.metadata()));
+                write!(f, "CAST({expr} AS {formatted})")
             }
-            Expr::TryCast(TryCast { expr, data_type }) => {
-                write!(f, "TRY_CAST({expr} AS {data_type})")
+            Expr::TryCast(TryCast { expr, field }) => {
+                let formatted =
+                    format_type_and_metadata(field.data_type(), 
Some(field.metadata()));
+                write!(f, "TRY_CAST({expr} AS {formatted})")
             }
             Expr::Not(expr) => write!(f, "NOT {expr}"),
             Expr::Negative(expr) => write!(f, "(- {expr})"),
@@ -3765,7 +3779,7 @@ mod test {
     fn format_cast() -> Result<()> {
         let expr = Expr::Cast(Cast {
             expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), 
None)),
-            data_type: DataType::Utf8,
+            field: DataType::Utf8.into_nullable_field_ref(),
         });
         let expected_canonical = "CAST(Float32(1.23) AS Utf8)";
         assert_eq!(expected_canonical, format!("{expr}"));
diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs 
b/datafusion/expr/src/expr_rewriter/order_by.rs
index ec22be5254..7c6af56c8d 100644
--- a/datafusion/expr/src/expr_rewriter/order_by.rs
+++ b/datafusion/expr/src/expr_rewriter/order_by.rs
@@ -116,13 +116,13 @@ fn rewrite_in_terms_of_projection(
 
         if let Some(found) = found {
             return Ok(Transformed::yes(match normalized_expr {
-                Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast {
+                Expr::Cast(Cast { expr: _, field }) => Expr::Cast(Cast {
                     expr: Box::new(found),
-                    data_type,
+                    field,
                 }),
-                Expr::TryCast(TryCast { expr: _, data_type }) => 
Expr::TryCast(TryCast {
+                Expr::TryCast(TryCast { expr: _, field }) => 
Expr::TryCast(TryCast {
                     expr: Box::new(found),
-                    data_type,
+                    field,
                 }),
                 _ => found,
             }));
diff --git a/datafusion/expr/src/expr_schema.rs 
b/datafusion/expr/src/expr_schema.rs
index f4e4f014f5..f2ac777af3 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -132,8 +132,9 @@ impl ExprSchemable for Expr {
                     .as_ref()
                     .map_or(Ok(DataType::Null), |e| e.get_type(schema))
             }
-            Expr::Cast(Cast { data_type, .. })
-            | Expr::TryCast(TryCast { data_type, .. }) => 
Ok(data_type.clone()),
+            Expr::Cast(Cast { field, .. }) | Expr::TryCast(TryCast { field, .. 
}) => {
+                Ok(field.data_type().clone())
+            }
             Expr::Unnest(Unnest { expr }) => {
                 let arg_data_type = expr.get_type(schema)?;
                 // Unnest's output type is the inner type of the list
@@ -550,9 +551,23 @@ impl ExprSchemable for Expr {
                 func.return_field_from_args(args)
             }
             // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
-            Expr::Cast(Cast { expr, data_type }) => expr
+            Expr::Cast(Cast { expr, field }) => expr
                 .to_field(schema)
-                .map(|(_, f)| f.retyped(data_type.clone())),
+                .map(|(_table_ref, destination_field)| {
+                    // This propagates the nullability of the input rather than
+                    // force the nullability of the destination field. This is
+                    // usually the desired behaviour (i.e., specifying a cast
+                    // destination type usually does not force a user to pick
+                    // nullability, and assuming `true` would prevent the 
non-nullability
+                    // of the parent expression to make the result eligible for
+                    // optimizations that only apply to non-nullable values).
+                    destination_field
+                        .as_ref()
+                        .clone()
+                        .with_data_type(field.data_type().clone())
+                        .with_metadata(destination_field.metadata().clone())
+                })
+                .map(Arc::new),
             Expr::Placeholder(Placeholder {
                 id: _,
                 field: Some(field),
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index 226c512a97..f3bec6bbf9 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -234,12 +234,12 @@ impl TreeNode for Expr {
                 .update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
                     Expr::Case(Case::new(new_expr, new_when_then_expr, 
new_else_expr))
                 }),
-            Expr::Cast(Cast { expr, data_type }) => expr
+            Expr::Cast(Cast { expr, field }) => expr
                 .map_elements(f)?
-                .update_data(|be| Expr::Cast(Cast::new(be, data_type))),
-            Expr::TryCast(TryCast { expr, data_type }) => expr
+                .update_data(|be| Expr::Cast(Cast::new_from_field(be, field))),
+            Expr::TryCast(TryCast { expr, field }) => expr
                 .map_elements(f)?
-                .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))),
+                .update_data(|be| Expr::TryCast(TryCast::new_from_field(be, 
field))),
             Expr::ScalarFunction(ScalarFunction { func, args }) => {
                 args.map_elements(f)?.map_data(|new_args| {
                     Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
diff --git a/datafusion/functions/src/core/arrow_cast.rs 
b/datafusion/functions/src/core/arrow_cast.rs
index 7c24450adf..e555081e41 100644
--- a/datafusion/functions/src/core/arrow_cast.rs
+++ b/datafusion/functions/src/core/arrow_cast.rs
@@ -19,11 +19,11 @@
 
 use arrow::datatypes::{DataType, Field, FieldRef};
 use arrow::error::ArrowError;
-use datafusion_common::types::logical_string;
 use datafusion_common::{
-    Result, ScalarValue, arrow_datafusion_err, exec_err, internal_err,
+    Result, ScalarValue, arrow_datafusion_err, datatype::DataTypeExt,
+    exec_datafusion_err, exec_err, internal_err, types::logical_string,
+    utils::take_function_args,
 };
-use datafusion_common::{exec_datafusion_err, utils::take_function_args};
 use std::any::Any;
 
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
@@ -176,7 +176,7 @@ impl ScalarUDFImpl for ArrowCastFunc {
             // Use an actual cast to get the correct type
             Expr::Cast(datafusion_expr::Cast {
                 expr: Box::new(arg),
-                data_type: target_type,
+                field: target_type.into_nullable_field_ref(),
             })
         };
         // return the newly written argument to DataFusion
diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs 
b/datafusion/optimizer/src/eliminate_outer_join.rs
index 58abe38d04..5c47d6b7c5 100644
--- a/datafusion/optimizer/src/eliminate_outer_join.rs
+++ b/datafusion/optimizer/src/eliminate_outer_join.rs
@@ -290,8 +290,8 @@ fn extract_non_nullable_columns(
                 false,
             )
         }
-        Expr::Cast(Cast { expr, data_type: _ })
-        | Expr::TryCast(TryCast { expr, data_type: _ }) => 
extract_non_nullable_columns(
+        Expr::Cast(Cast { expr, field: _ })
+        | Expr::TryCast(TryCast { expr, field: _ }) => 
extract_non_nullable_columns(
             expr,
             non_nullable_cols,
             left_schema,
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index fe2e1a3b04..28fcdf1ded 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -646,12 +646,11 @@ impl ConstEvaluator {
             Expr::ScalarFunction(ScalarFunction { func, .. }) => {
                 Self::volatility_ok(func.signature().volatility)
             }
-            Expr::Cast(Cast { expr, data_type })
-            | Expr::TryCast(TryCast { expr, data_type }) => {
+            Expr::Cast(Cast { expr, field }) | Expr::TryCast(TryCast { expr, 
field }) => {
                 if let (
                     Ok(DataType::Struct(source_fields)),
                     DataType::Struct(target_fields),
-                ) = (expr.get_type(&DFSchema::empty()), data_type)
+                ) = (expr.get_type(&DFSchema::empty()), field.data_type())
                 {
                     // Don't const-fold struct casts with different field 
counts
                     if source_fields.len() != target_fields.len() {
diff --git a/datafusion/physical-expr/src/planner.rs 
b/datafusion/physical-expr/src/planner.rs
index 84a6aa4309..5c170700d9 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -25,7 +25,7 @@ use crate::{
 
 use arrow::datatypes::Schema;
 use datafusion_common::config::ConfigOptions;
-use datafusion_common::metadata::FieldMetadata;
+use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata};
 use datafusion_common::{
     DFSchema, Result, ScalarValue, ToDFSchema, exec_err, not_impl_err, 
plan_err,
 };
@@ -34,7 +34,7 @@ use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, 
ScalarFunction};
 use datafusion_expr::var_provider::VarType;
 use datafusion_expr::var_provider::is_system_variables;
 use datafusion_expr::{
-    Between, BinaryExpr, Expr, Like, Operator, TryCast, binary_expr, lit,
+    Between, BinaryExpr, Expr, ExprSchemable, Like, Operator, TryCast, 
binary_expr, lit,
 };
 
 /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1
@@ -288,16 +288,44 @@ pub fn create_physical_expr(
                 };
             Ok(expressions::case(expr, when_then_expr, else_expr)?)
         }
-        Expr::Cast(Cast { expr, data_type }) => expressions::cast(
-            create_physical_expr(expr, input_dfschema, execution_props)?,
-            input_schema,
-            data_type.clone(),
-        ),
-        Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast(
-            create_physical_expr(expr, input_dfschema, execution_props)?,
-            input_schema,
-            data_type.clone(),
-        ),
+        Expr::Cast(Cast { expr, field }) => {
+            if !field.metadata().is_empty() {
+                let (_, src_field) = expr.to_field(input_dfschema)?;
+                return plan_err!(
+                    "Cast from {} to {} is not supported",
+                    format_type_and_metadata(
+                        src_field.data_type(),
+                        Some(src_field.metadata()),
+                    ),
+                    format_type_and_metadata(field.data_type(), 
Some(field.metadata()))
+                );
+            }
+
+            expressions::cast(
+                create_physical_expr(expr, input_dfschema, execution_props)?,
+                input_schema,
+                field.data_type().clone(),
+            )
+        }
+        Expr::TryCast(TryCast { expr, field }) => {
+            if !field.metadata().is_empty() {
+                let (_, src_field) = expr.to_field(input_dfschema)?;
+                return plan_err!(
+                    "TryCast from {} to {} is not supported",
+                    format_type_and_metadata(
+                        src_field.data_type(),
+                        Some(src_field.metadata()),
+                    ),
+                    format_type_and_metadata(field.data_type(), 
Some(field.metadata()))
+                );
+            }
+
+            expressions::try_cast(
+                create_physical_expr(expr, input_dfschema, execution_props)?,
+                input_schema,
+                field.data_type().clone(),
+            )
+        }
         Expr::Not(expr) => {
             expressions::not(create_physical_expr(expr, input_dfschema, 
execution_props)?)
         }
@@ -417,7 +445,7 @@ pub fn logical2physical(expr: &Expr, schema: &Schema) -> 
Arc<dyn PhysicalExpr> {
 mod tests {
     use arrow::array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
     use arrow::datatypes::{DataType, Field};
-
+    use datafusion_common::datatype::DataTypeExt;
     use datafusion_expr::{Operator, col, lit};
 
     use super::*;
@@ -447,6 +475,41 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_cast_to_extension_type() -> Result<()> {
+        let extension_field_type = Arc::new(
+            DataType::FixedSizeBinary(16)
+                .into_nullable_field()
+                .with_metadata(
+                    [("ARROW:extension:name".to_string(), 
"arrow.uuid".to_string())]
+                        .into(),
+                ),
+        );
+        let expr = lit("3230e5d4-888e-408b-b09b-831f44aa0c58");
+        let cast_expr = Expr::Cast(Cast::new_from_field(
+            Box::new(expr.clone()),
+            Arc::clone(&extension_field_type),
+        ));
+        let err =
+            create_physical_expr(&cast_expr, &DFSchema::empty(), 
&ExecutionProps::new())
+                .unwrap_err();
+        assert!(err.message().contains("arrow.uuid"));
+
+        let try_cast_expr = Expr::TryCast(TryCast::new_from_field(
+            Box::new(expr.clone()),
+            Arc::clone(&extension_field_type),
+        ));
+        let err = create_physical_expr(
+            &try_cast_expr,
+            &DFSchema::empty(),
+            &ExecutionProps::new(),
+        )
+        .unwrap_err();
+        assert!(err.message().contains("arrow.uuid"));
+
+        Ok(())
+    }
+
     /// Test that deeply nested expressions do not cause a stack overflow.
     ///
     /// This test only runs when the `recursive_protection` feature is enabled,
diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto
index 7c02688676..e64bcdb41e 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -594,11 +594,15 @@ message WhenThen {
 message CastNode {
   LogicalExprNode expr = 1;
   datafusion_common.ArrowType arrow_type = 2;
+  map<string, string> metadata = 3;
+  optional bool nullable = 4;
 }
 
 message TryCastNode {
   LogicalExprNode expr = 1;
   datafusion_common.ArrowType arrow_type = 2;
+  map<string, string> metadata = 3;
+  optional bool nullable = 4;
 }
 
 message SortExprNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs 
b/datafusion/proto/src/generated/pbjson.rs
index 5b2b9133ce..15b9ba88f4 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -2203,6 +2203,12 @@ impl serde::Serialize for CastNode {
         if self.arrow_type.is_some() {
             len += 1;
         }
+        if !self.metadata.is_empty() {
+            len += 1;
+        }
+        if self.nullable.is_some() {
+            len += 1;
+        }
         let mut struct_ser = 
serializer.serialize_struct("datafusion.CastNode", len)?;
         if let Some(v) = self.expr.as_ref() {
             struct_ser.serialize_field("expr", v)?;
@@ -2210,6 +2216,12 @@ impl serde::Serialize for CastNode {
         if let Some(v) = self.arrow_type.as_ref() {
             struct_ser.serialize_field("arrowType", v)?;
         }
+        if !self.metadata.is_empty() {
+            struct_ser.serialize_field("metadata", &self.metadata)?;
+        }
+        if let Some(v) = self.nullable.as_ref() {
+            struct_ser.serialize_field("nullable", v)?;
+        }
         struct_ser.end()
     }
 }
@@ -2223,12 +2235,16 @@ impl<'de> serde::Deserialize<'de> for CastNode {
             "expr",
             "arrow_type",
             "arrowType",
+            "metadata",
+            "nullable",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Expr,
             ArrowType,
+            Metadata,
+            Nullable,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -2252,6 +2268,8 @@ impl<'de> serde::Deserialize<'de> for CastNode {
                         match value {
                             "expr" => Ok(GeneratedField::Expr),
                             "arrowType" | "arrow_type" => 
Ok(GeneratedField::ArrowType),
+                            "metadata" => Ok(GeneratedField::Metadata),
+                            "nullable" => Ok(GeneratedField::Nullable),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -2273,6 +2291,8 @@ impl<'de> serde::Deserialize<'de> for CastNode {
             {
                 let mut expr__ = None;
                 let mut arrow_type__ = None;
+                let mut metadata__ = None;
+                let mut nullable__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::Expr => {
@@ -2287,11 +2307,27 @@ impl<'de> serde::Deserialize<'de> for CastNode {
                             }
                             arrow_type__ = map_.next_value()?;
                         }
+                        GeneratedField::Metadata => {
+                            if metadata__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("metadata"));
+                            }
+                            metadata__ = Some(
+                                map_.next_value::<std::collections::HashMap<_, 
_>>()?
+                            );
+                        }
+                        GeneratedField::Nullable => {
+                            if nullable__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("nullable"));
+                            }
+                            nullable__ = map_.next_value()?;
+                        }
                     }
                 }
                 Ok(CastNode {
                     expr: expr__,
                     arrow_type: arrow_type__,
+                    metadata: metadata__.unwrap_or_default(),
+                    nullable: nullable__,
                 })
             }
         }
@@ -23040,6 +23076,12 @@ impl serde::Serialize for TryCastNode {
         if self.arrow_type.is_some() {
             len += 1;
         }
+        if !self.metadata.is_empty() {
+            len += 1;
+        }
+        if self.nullable.is_some() {
+            len += 1;
+        }
         let mut struct_ser = 
serializer.serialize_struct("datafusion.TryCastNode", len)?;
         if let Some(v) = self.expr.as_ref() {
             struct_ser.serialize_field("expr", v)?;
@@ -23047,6 +23089,12 @@ impl serde::Serialize for TryCastNode {
         if let Some(v) = self.arrow_type.as_ref() {
             struct_ser.serialize_field("arrowType", v)?;
         }
+        if !self.metadata.is_empty() {
+            struct_ser.serialize_field("metadata", &self.metadata)?;
+        }
+        if let Some(v) = self.nullable.as_ref() {
+            struct_ser.serialize_field("nullable", v)?;
+        }
         struct_ser.end()
     }
 }
@@ -23060,12 +23108,16 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
             "expr",
             "arrow_type",
             "arrowType",
+            "metadata",
+            "nullable",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
             Expr,
             ArrowType,
+            Metadata,
+            Nullable,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -23089,6 +23141,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
                         match value {
                             "expr" => Ok(GeneratedField::Expr),
                             "arrowType" | "arrow_type" => 
Ok(GeneratedField::ArrowType),
+                            "metadata" => Ok(GeneratedField::Metadata),
+                            "nullable" => Ok(GeneratedField::Nullable),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -23110,6 +23164,8 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
             {
                 let mut expr__ = None;
                 let mut arrow_type__ = None;
+                let mut metadata__ = None;
+                let mut nullable__ = None;
                 while let Some(k) = map_.next_key()? {
                     match k {
                         GeneratedField::Expr => {
@@ -23124,11 +23180,27 @@ impl<'de> serde::Deserialize<'de> for TryCastNode {
                             }
                             arrow_type__ = map_.next_value()?;
                         }
+                        GeneratedField::Metadata => {
+                            if metadata__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("metadata"));
+                            }
+                            metadata__ = Some(
+                                map_.next_value::<std::collections::HashMap<_, 
_>>()?
+                            );
+                        }
+                        GeneratedField::Nullable => {
+                            if nullable__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("nullable"));
+                            }
+                            nullable__ = map_.next_value()?;
+                        }
                     }
                 }
                 Ok(TryCastNode {
                     expr: expr__,
                     arrow_type: arrow_type__,
+                    metadata: metadata__.unwrap_or_default(),
+                    nullable: nullable__,
                 })
             }
         }
diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs
index d9602665c2..48ad8dcfea 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -922,6 +922,13 @@ pub struct CastNode {
     pub expr: 
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
     #[prost(message, optional, tag = "2")]
     pub arrow_type: 
::core::option::Option<super::datafusion_common::ArrowType>,
+    #[prost(map = "string, string", tag = "3")]
+    pub metadata: ::std::collections::HashMap<
+        ::prost::alloc::string::String,
+        ::prost::alloc::string::String,
+    >,
+    #[prost(bool, optional, tag = "4")]
+    pub nullable: ::core::option::Option<bool>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct TryCastNode {
@@ -929,6 +936,13 @@ pub struct TryCastNode {
     pub expr: 
::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
     #[prost(message, optional, tag = "2")]
     pub arrow_type: 
::core::option::Option<super::datafusion_common::ArrowType>,
+    #[prost(map = "string, string", tag = "3")]
+    pub metadata: ::std::collections::HashMap<
+        ::prost::alloc::string::String,
+        ::prost::alloc::string::String,
+    >,
+    #[prost(bool, optional, tag = "4")]
+    pub nullable: ::core::option::Option<bool>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct SortExprNode {
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs 
b/datafusion/proto/src/logical_plan/from_proto.rs
index a653f517b7..ed33d9fab1 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -17,7 +17,8 @@
 
 use std::sync::Arc;
 
-use arrow::datatypes::Field;
+use arrow::datatypes::{DataType, Field};
+use datafusion_common::datatype::DataTypeExt;
 use datafusion_common::{
     NullEquality, RecursionUnnestOption, Result, ScalarValue, TableReference,
     UnnestOptions, exec_datafusion_err, internal_err, plan_datafusion_err,
@@ -527,8 +528,11 @@ pub fn parse_expr(
                 "expr",
                 codec,
             )?);
-            let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
-            Ok(Expr::Cast(Cast::new(expr, data_type)))
+            let data_type: DataType = 
cast.arrow_type.as_ref().required("arrow_type")?;
+            let field = data_type
+                .into_nullable_field()
+                .with_nullable(cast.nullable.unwrap_or(true));
+            Ok(Expr::Cast(Cast::new_from_field(expr, Arc::new(field))))
         }
         ExprType::TryCast(cast) => {
             let expr = Box::new(parse_required_expr(
@@ -537,8 +541,14 @@ pub fn parse_expr(
                 "expr",
                 codec,
             )?);
-            let data_type = cast.arrow_type.as_ref().required("arrow_type")?;
-            Ok(Expr::TryCast(TryCast::new(expr, data_type)))
+            let data_type: DataType = 
cast.arrow_type.as_ref().required("arrow_type")?;
+            let field = data_type
+                .into_nullable_field()
+                .with_nullable(cast.nullable.unwrap_or(true));
+            Ok(Expr::TryCast(TryCast::new_from_field(
+                expr,
+                Arc::new(field),
+            )))
         }
         ExprType::Negative(negative) => Ok(Expr::Negative(Box::new(
             parse_required_expr(negative.expr.as_deref(), registry, "expr", 
codec)?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs 
b/datafusion/proto/src/logical_plan/to_proto.rs
index fe63fce6ee..6fcb738992 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -521,19 +521,23 @@ pub fn serialize_expr(
                 expr_type: Some(ExprType::Case(expr)),
             }
         }
-        Expr::Cast(Cast { expr, data_type }) => {
+        Expr::Cast(Cast { expr, field }) => {
             let expr = Box::new(protobuf::CastNode {
                 expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)),
-                arrow_type: Some(data_type.try_into()?),
+                arrow_type: Some(field.data_type().try_into()?),
+                metadata: field.metadata().clone(),
+                nullable: Some(field.is_nullable()),
             });
             protobuf::LogicalExprNode {
                 expr_type: Some(ExprType::Cast(expr)),
             }
         }
-        Expr::TryCast(TryCast { expr, data_type }) => {
+        Expr::TryCast(TryCast { expr, field }) => {
             let expr = Box::new(protobuf::TryCastNode {
                 expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)),
-                arrow_type: Some(data_type.try_into()?),
+                arrow_type: Some(field.data_type().try_into()?),
+                metadata: field.metadata().clone(),
+                nullable: Some(field.is_nullable()),
             });
             protobuf::LogicalExprNode {
                 expr_type: Some(ExprType::TryCast(expr)),
diff --git a/datafusion/sql/src/unparser/expr.rs 
b/datafusion/sql/src/unparser/expr.rs
index 59a9207b51..bbe1f3dd9d 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion_common::datatype::DataTypeExt;
 use datafusion_expr::expr::{AggregateFunctionParams, Unnest, 
WindowFunctionParams};
 use sqlparser::ast::Value::SingleQuotedString;
 use sqlparser::ast::{
@@ -37,6 +38,7 @@ use arrow::array::{
 };
 use arrow::datatypes::{
     DataType, Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type, 
DecimalType,
+    FieldRef,
 };
 use arrow::util::display::array_value_to_string;
 use datafusion_common::{
@@ -188,9 +190,7 @@ impl Unparser<'_> {
                     end_token: AttachedToken::empty(),
                 })
             }
-            Expr::Cast(Cast { expr, data_type }) => {
-                Ok(self.cast_to_sql(expr, data_type)?)
-            }
+            Expr::Cast(Cast { expr, field }) => Ok(self.cast_to_sql(expr, 
field)?),
             Expr::Literal(value, _) => Ok(self.scalar_to_sql(value)?),
             Expr::Alias(Alias { expr, name: _, .. }) => 
self.expr_to_sql_inner(expr),
             Expr::WindowFunction(window_fun) => {
@@ -488,12 +488,12 @@ impl Unparser<'_> {
                     )
                 })
             }
-            Expr::TryCast(TryCast { expr, data_type }) => {
+            Expr::TryCast(TryCast { expr, field }) => {
                 let inner_expr = self.expr_to_sql_inner(expr)?;
                 Ok(ast::Expr::Cast {
                     kind: ast::CastKind::TryCast,
                     expr: Box::new(inner_expr),
-                    data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+                    data_type: self.arrow_dtype_to_ast_dtype(field)?,
                     array: false,
                     format: None,
                 })
@@ -1176,17 +1176,20 @@ impl Unparser<'_> {
 
     // Explicit type cast on ast::Expr::Value is not needed by underlying 
engine for certain types
     // For example: CAST(Utf8("binary_value") AS Binary) and  
CAST(Utf8("dictionary_value") AS Dictionary)
-    fn cast_to_sql(&self, expr: &Expr, data_type: &DataType) -> 
Result<ast::Expr> {
+    fn cast_to_sql(&self, expr: &Expr, field: &FieldRef) -> Result<ast::Expr> {
         let inner_expr = self.expr_to_sql_inner(expr)?;
+        let data_type = field.data_type();
         match inner_expr {
             ast::Expr::Value(_) => match data_type {
-                DataType::Dictionary(_, _) | DataType::Binary | 
DataType::BinaryView => {
+                DataType::Dictionary(_, _) | DataType::Binary | 
DataType::BinaryView
+                    if field.metadata().is_empty() =>
+                {
                     Ok(inner_expr)
                 }
                 _ => Ok(ast::Expr::Cast {
                     kind: ast::CastKind::Cast,
                     expr: Box::new(inner_expr),
-                    data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+                    data_type: self.arrow_dtype_to_ast_dtype(field)?,
                     array: false,
                     format: None,
                 }),
@@ -1194,7 +1197,7 @@ impl Unparser<'_> {
             _ => Ok(ast::Expr::Cast {
                 kind: ast::CastKind::Cast,
                 expr: Box::new(inner_expr),
-                data_type: self.arrow_dtype_to_ast_dtype(data_type)?,
+                data_type: self.arrow_dtype_to_ast_dtype(field)?,
                 array: false,
                 format: None,
             }),
@@ -1724,7 +1727,8 @@ impl Unparser<'_> {
         }))
     }
 
-    fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> 
Result<ast::DataType> {
+    fn arrow_dtype_to_ast_dtype(&self, field: &FieldRef) -> 
Result<ast::DataType> {
+        let data_type = field.data_type();
         match data_type {
             DataType::Null => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type}")
@@ -1797,10 +1801,10 @@ impl Unparser<'_> {
             DataType::Union(_, _) => {
                 not_impl_err!("Unsupported DataType: conversion: {data_type}")
             }
-            DataType::Dictionary(_, val) => self.arrow_dtype_to_ast_dtype(val),
-            DataType::RunEndEncoded(_, val) => {
-                self.arrow_dtype_to_ast_dtype(val.data_type())
+            DataType::Dictionary(_, val) => {
+                
self.arrow_dtype_to_ast_dtype(&val.clone().into_nullable_field_ref())
             }
+            DataType::RunEndEncoded(_, val) => 
self.arrow_dtype_to_ast_dtype(val),
             DataType::Decimal32(precision, scale)
             | DataType::Decimal64(precision, scale)
             | DataType::Decimal128(precision, scale)
@@ -1938,34 +1942,25 @@ mod tests {
                 r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#,
             ),
             (
-                Expr::Cast(Cast {
-                    expr: Box::new(col("a")),
-                    data_type: DataType::Date64,
-                }),
+                Expr::Cast(Cast::new(Box::new(col("a")), DataType::Date64)),
                 r#"CAST(a AS DATETIME)"#,
             ),
             (
-                Expr::Cast(Cast {
-                    expr: Box::new(col("a")),
-                    data_type: DataType::Timestamp(
-                        TimeUnit::Nanosecond,
-                        Some("+08:00".into()),
-                    ),
-                }),
+                Expr::Cast(Cast::new(
+                    Box::new(col("a")),
+                    DataType::Timestamp(TimeUnit::Nanosecond, 
Some("+08:00".into())),
+                )),
                 r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#,
             ),
             (
-                Expr::Cast(Cast {
-                    expr: Box::new(col("a")),
-                    data_type: DataType::Timestamp(TimeUnit::Millisecond, 
None),
-                }),
+                Expr::Cast(Cast::new(
+                    Box::new(col("a")),
+                    DataType::Timestamp(TimeUnit::Millisecond, None),
+                )),
                 r#"CAST(a AS TIMESTAMP)"#,
             ),
             (
-                Expr::Cast(Cast {
-                    expr: Box::new(col("a")),
-                    data_type: DataType::UInt32,
-                }),
+                Expr::Cast(Cast::new(Box::new(col("a")), DataType::UInt32)),
                 r#"CAST(a AS INTEGER UNSIGNED)"#,
             ),
             (
@@ -2283,10 +2278,7 @@ mod tests {
                 r#"((a + b) > 100.123)"#,
             ),
             (
-                Expr::Cast(Cast {
-                    expr: Box::new(col("a")),
-                    data_type: DataType::Decimal128(10, -2),
-                }),
+                Expr::Cast(Cast::new(Box::new(col("a")), 
DataType::Decimal128(10, -2))),
                 r#"CAST(a AS DECIMAL(12,0))"#,
             ),
             (
@@ -2434,10 +2426,7 @@ mod tests {
                 .build();
             let unparser = Unparser::new(&dialect);
 
-            let expr = Expr::Cast(Cast {
-                expr: Box::new(col("a")),
-                data_type: DataType::Date64,
-            });
+            let expr = Expr::Cast(Cast::new(Box::new(col("a")), 
DataType::Date64));
             let ast = unparser.expr_to_sql(&expr)?;
 
             let actual = format!("{ast}");
@@ -2459,10 +2448,7 @@ mod tests {
                 .build();
             let unparser = Unparser::new(&dialect);
 
-            let expr = Expr::Cast(Cast {
-                expr: Box::new(col("a")),
-                data_type: DataType::Float64,
-            });
+            let expr = Expr::Cast(Cast::new(Box::new(col("a")), 
DataType::Float64));
             let ast = unparser.expr_to_sql(&expr)?;
 
             let actual = format!("{ast}");
@@ -2692,23 +2678,23 @@ mod tests {
     fn test_cast_value_to_binary_expr() {
         let tests = [
             (
-                Expr::Cast(Cast {
-                    expr: Box::new(Expr::Literal(
+                Expr::Cast(Cast::new(
+                    Box::new(Expr::Literal(
                         ScalarValue::Utf8(Some("blah".to_string())),
                         None,
                     )),
-                    data_type: DataType::Binary,
-                }),
+                    DataType::Binary,
+                )),
                 "'blah'",
             ),
             (
-                Expr::Cast(Cast {
-                    expr: Box::new(Expr::Literal(
+                Expr::Cast(Cast::new(
+                    Box::new(Expr::Literal(
                         ScalarValue::Utf8(Some("blah".to_string())),
                         None,
                     )),
-                    data_type: DataType::BinaryView,
-                }),
+                    DataType::BinaryView,
+                )),
                 "'blah'",
             ),
         ];
@@ -2739,10 +2725,7 @@ mod tests {
         ] {
             let unparser = Unparser::new(dialect);
 
-            let expr = Expr::Cast(Cast {
-                expr: Box::new(col("a")),
-                data_type,
-            });
+            let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type));
             let ast = unparser.expr_to_sql(&expr)?;
 
             let actual = format!("{ast}");
@@ -2825,10 +2808,7 @@ mod tests {
             [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")]
         {
             let unparser = Unparser::new(&dialect);
-            let expr = Expr::Cast(Cast {
-                expr: Box::new(col("a")),
-                data_type: DataType::Int64,
-            });
+            let expr = Expr::Cast(Cast::new(Box::new(col("a")), 
DataType::Int64));
             let ast = unparser.expr_to_sql(&expr)?;
 
             let actual = format!("{ast}");
@@ -2853,10 +2833,7 @@ mod tests {
             [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")]
         {
             let unparser = Unparser::new(&dialect);
-            let expr = Expr::Cast(Cast {
-                expr: Box::new(col("a")),
-                data_type: DataType::Int32,
-            });
+            let expr = Expr::Cast(Cast::new(Box::new(col("a")), 
DataType::Int32));
             let ast = unparser.expr_to_sql(&expr)?;
 
             let actual = format!("{ast}");
@@ -2892,10 +2869,7 @@ mod tests {
             (&mysql_dialect, &timestamp_with_tz, "DATETIME"),
         ] {
             let unparser = Unparser::new(dialect);
-            let expr = Expr::Cast(Cast {
-                expr: Box::new(col("a")),
-                data_type: data_type.clone(),
-            });
+            let expr = Expr::Cast(Cast::new(Box::new(col("a")), 
data_type.clone()));
             let ast = unparser.expr_to_sql(&expr)?;
 
             let actual = format!("{ast}");
@@ -2948,10 +2922,7 @@ mod tests {
         ] {
             let unparser = Unparser::new(dialect);
 
-            let expr = Expr::Cast(Cast {
-                expr: Box::new(col("a")),
-                data_type,
-            });
+            let expr = Expr::Cast(Cast::new(Box::new(col("a")), data_type));
             let ast = unparser.expr_to_sql(&expr)?;
 
             let actual = format!("{ast}");
@@ -2991,13 +2962,13 @@ mod tests {
     #[test]
     fn test_cast_value_to_dict_expr() {
         let tests = [(
-            Expr::Cast(Cast {
-                expr: Box::new(Expr::Literal(
+            Expr::Cast(Cast::new(
+                Box::new(Expr::Literal(
                     ScalarValue::Utf8(Some("variation".to_string())),
                     None,
                 )),
-                data_type: DataType::Dictionary(Box::new(Int8), 
Box::new(DataType::Utf8)),
-            }),
+                DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)),
+            )),
             "'variation'",
         )];
         for (value, expected) in tests {
@@ -3029,10 +3000,7 @@ mod tests {
                     datafusion_functions::math::round::RoundFunc::new(),
                 )),
                 args: vec![
-                    Expr::Cast(Cast {
-                        expr: Box::new(col("a")),
-                        data_type: DataType::Float64,
-                    }),
+                    Expr::Cast(Cast::new(Box::new(col("a")), 
DataType::Float64)),
                     Expr::Literal(ScalarValue::Int64(Some(2)), None),
                 ],
             });
@@ -3194,10 +3162,12 @@ mod tests {
 
         let unparser = Unparser::new(&dialect);
 
-        let ast_dtype = 
unparser.arrow_dtype_to_ast_dtype(&DataType::Dictionary(
-            Box::new(DataType::Int32),
-            Box::new(DataType::Utf8),
-        ))?;
+        let arrow_field = Arc::new(Field::new(
+            "",
+            DataType::Dictionary(Box::new(DataType::Int32), 
Box::new(DataType::Utf8)),
+            true,
+        ));
+        let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?;
 
         assert_eq!(ast_dtype, ast::DataType::Varchar(None));
 
@@ -3210,10 +3180,13 @@ mod tests {
 
         let unparser = Unparser::new(&dialect);
 
-        let ast_dtype = 
unparser.arrow_dtype_to_ast_dtype(&DataType::RunEndEncoded(
-            Field::new("run_ends", DataType::Int32, false).into(),
-            Field::new("values", DataType::Utf8, true).into(),
-        ))?;
+        let ast_dtype = unparser.arrow_dtype_to_ast_dtype(
+            &DataType::RunEndEncoded(
+                Field::new("run_ends", DataType::Int32, false).into(),
+                Field::new("values", DataType::Utf8, true).into(),
+            )
+            .into_nullable_field_ref(),
+        )?;
 
         assert_eq!(ast_dtype, ast::DataType::Varchar(None));
 
@@ -3227,7 +3200,8 @@ mod tests {
             .build();
         let unparser = Unparser::new(&dialect);
 
-        let ast_dtype = 
unparser.arrow_dtype_to_ast_dtype(&DataType::Utf8View)?;
+        let arrow_field = Arc::new(Field::new("", DataType::Utf8View, true));
+        let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&arrow_field)?;
 
         assert_eq!(ast_dtype, ast::DataType::Char(None));
 
@@ -3295,10 +3269,10 @@ mod tests {
         let dialect: Arc<dyn Dialect> = Arc::new(SqliteDialect {});
 
         let unparser = Unparser::new(dialect.as_ref());
-        let expr = Expr::Cast(Cast {
-            expr: Box::new(col("a")),
-            data_type: DataType::Timestamp(TimeUnit::Nanosecond, None),
-        });
+        let expr = Expr::Cast(Cast::new(
+            Box::new(col("a")),
+            DataType::Timestamp(TimeUnit::Nanosecond, None),
+        ));
 
         let ast = unparser.expr_to_sql(&expr)?;
 
diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs 
b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs
index ec70ac3fec..3dd62afe8f 100644
--- a/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs
+++ b/datafusion/substrait/src/logical_plan/consumer/expr/cast.rs
@@ -15,8 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::logical_plan::consumer::SubstraitConsumer;
-use crate::logical_plan::consumer::types::from_substrait_type_without_names;
+use crate::logical_plan::consumer::{
+    SubstraitConsumer, field_from_substrait_type_without_names,
+};
 use datafusion::common::{DFSchema, substrait_err};
 use datafusion::logical_expr::{Cast, Expr, TryCast};
 use substrait::proto::expression as substrait_expression;
@@ -37,11 +38,11 @@ pub async fn from_cast(
                     )
                     .await?,
             );
-            let data_type = from_substrait_type_without_names(consumer, 
output_type)?;
+            let field = field_from_substrait_type_without_names(consumer, 
output_type)?;
             if cast.failure_behavior() == ReturnNull {
-                Ok(Expr::TryCast(TryCast::new(input_expr, data_type)))
+                Ok(Expr::TryCast(TryCast::new_from_field(input_expr, field)))
             } else {
-                Ok(Expr::Cast(Cast::new(input_expr, data_type)))
+                Ok(Expr::Cast(Cast::new_from_field(input_expr, field)))
             }
         }
         None => substrait_err!("Cast expression without output type is not 
allowed"),
diff --git a/datafusion/substrait/src/logical_plan/consumer/types.rs 
b/datafusion/substrait/src/logical_plan/consumer/types.rs
index 9ef7a0dd46..2493ac1e5a 100644
--- a/datafusion/substrait/src/logical_plan/consumer/types.rs
+++ b/datafusion/substrait/src/logical_plan/consumer/types.rs
@@ -34,14 +34,22 @@ use crate::variation_const::{
 };
 use crate::variation_const::{FLOAT_16_TYPE_NAME, NULL_TYPE_NAME};
 use datafusion::arrow::datatypes::{
-    DataType, Field, Fields, IntervalUnit, Schema, TimeUnit,
+    DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
 };
+use datafusion::common::datatype::DataTypeExt;
 use datafusion::common::{
     DFSchema, not_impl_err, substrait_datafusion_err, substrait_err,
 };
 use std::sync::Arc;
 use substrait::proto::{NamedStruct, Type, r#type};
 
+pub(crate) fn field_from_substrait_type_without_names(
+    consumer: &impl SubstraitConsumer,
+    dt: &Type,
+) -> datafusion::common::Result<FieldRef> {
+    Ok(from_substrait_type_without_names(consumer, 
dt)?.into_nullable_field_ref())
+}
+
 pub(crate) fn from_substrait_type_without_names(
     consumer: &impl SubstraitConsumer,
     dt: &Type,
@@ -49,6 +57,16 @@ pub(crate) fn from_substrait_type_without_names(
     from_substrait_type(consumer, dt, &[], &mut 0)
 }
 
+pub fn field_from_substrait_type(
+    consumer: &impl SubstraitConsumer,
+    dt: &Type,
+    dfs_names: &[String],
+    name_idx: &mut usize,
+) -> datafusion::common::Result<FieldRef> {
+    // We could add nullability here now that we are returning a Field
+    Ok(from_substrait_type(consumer, dt, dfs_names, 
name_idx)?.into_nullable_field_ref())
+}
+
 pub fn from_substrait_type(
     consumer: &impl SubstraitConsumer,
     dt: &Type,
diff --git a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs 
b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs
index 6eb27fc39d..2a5a6fe5c3 100644
--- a/datafusion/substrait/src/logical_plan/producer/expr/cast.rs
+++ b/datafusion/substrait/src/logical_plan/producer/expr/cast.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::logical_plan::producer::{SubstraitProducer, to_substrait_type};
+use crate::logical_plan::producer::{SubstraitProducer, 
to_substrait_type_from_field};
 use crate::variation_const::DEFAULT_TYPE_VARIATION_REF;
 use datafusion::common::{DFSchemaRef, ScalarValue};
 use datafusion::logical_expr::{Cast, Expr, TryCast};
@@ -29,7 +29,7 @@ pub fn from_cast(
     cast: &Cast,
     schema: &DFSchemaRef,
 ) -> datafusion::common::Result<Expression> {
-    let Cast { expr, data_type } = cast;
+    let Cast { expr, field } = cast;
     // since substrait Null must be typed, so if we see a cast(null, dt), we 
make it a typed null
     if let Expr::Literal(lit, _) = expr.as_ref() {
         // only the untyped(a null scalar value) null literal need this 
special handling
@@ -39,8 +39,8 @@ pub fn from_cast(
             let lit = Literal {
                 nullable: true,
                 type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
-                literal_type: Some(LiteralType::Null(to_substrait_type(
-                    producer, data_type, true,
+                literal_type: 
Some(LiteralType::Null(to_substrait_type_from_field(
+                    producer, field,
                 )?)),
             };
             return Ok(Expression {
@@ -51,7 +51,7 @@ pub fn from_cast(
     Ok(Expression {
         rex_type: Some(RexType::Cast(Box::new(
             substrait::proto::expression::Cast {
-                r#type: Some(to_substrait_type(producer, data_type, true)?),
+                r#type: Some(to_substrait_type_from_field(producer, field)?),
                 input: Some(Box::new(producer.handle_expr(expr, schema)?)),
                 failure_behavior: FailureBehavior::ThrowException.into(),
             },
@@ -64,11 +64,11 @@ pub fn from_try_cast(
     cast: &TryCast,
     schema: &DFSchemaRef,
 ) -> datafusion::common::Result<Expression> {
-    let TryCast { expr, data_type } = cast;
+    let TryCast { expr, field } = cast;
     Ok(Expression {
         rex_type: Some(RexType::Cast(Box::new(
             substrait::proto::expression::Cast {
-                r#type: Some(to_substrait_type(producer, data_type, true)?),
+                r#type: Some(to_substrait_type_from_field(producer, field)?),
                 input: Some(Box::new(producer.handle_expr(expr, schema)?)),
                 failure_behavior: FailureBehavior::ReturnNull.into(),
             },
@@ -80,7 +80,7 @@ pub fn from_try_cast(
 mod tests {
     use super::*;
     use crate::logical_plan::producer::{
-        DefaultSubstraitProducer, to_substrait_extended_expr,
+        DefaultSubstraitProducer, to_substrait_extended_expr, 
to_substrait_type,
     };
     use datafusion::arrow::datatypes::{DataType, Field};
     use datafusion::common::DFSchema;
diff --git a/datafusion/substrait/src/logical_plan/producer/types.rs 
b/datafusion/substrait/src/logical_plan/producer/types.rs
index 3727596119..fa58949e6e 100644
--- a/datafusion/substrait/src/logical_plan/producer/types.rs
+++ b/datafusion/substrait/src/logical_plan/producer/types.rs
@@ -27,7 +27,7 @@ use crate::variation_const::{
     TIME_32_TYPE_VARIATION_REF, TIME_64_TYPE_VARIATION_REF,
     UNSIGNED_INTEGER_TYPE_VARIATION_REF, VIEW_CONTAINER_TYPE_VARIATION_REF,
 };
-use datafusion::arrow::datatypes::{DataType, IntervalUnit};
+use datafusion::arrow::datatypes::{DataType, Field, FieldRef, IntervalUnit};
 use datafusion::common::{DFSchemaRef, not_impl_err, plan_err};
 use substrait::proto::{NamedStruct, r#type};
 
@@ -36,12 +36,19 @@ pub(crate) fn to_substrait_type(
     dt: &DataType,
     nullable: bool,
 ) -> datafusion::common::Result<substrait::proto::Type> {
-    let nullability = if nullable {
+    to_substrait_type_from_field(producer, &Field::new("", dt.clone(), 
nullable).into())
+}
+
+pub(crate) fn to_substrait_type_from_field(
+    producer: &mut impl SubstraitProducer,
+    field: &FieldRef,
+) -> datafusion::common::Result<substrait::proto::Type> {
+    let nullability = if field.is_nullable() {
         r#type::Nullability::Nullable as i32
     } else {
         r#type::Nullability::Required as i32
     };
-    match dt {
+    match field.data_type() {
         DataType::Null => {
             let type_anchor = 
producer.register_type(NULL_TYPE_NAME.to_string());
             Ok(substrait::proto::Type {
@@ -288,16 +295,9 @@ pub(crate) fn to_substrait_type(
         }
         DataType::Map(inner, _) => match inner.data_type() {
             DataType::Struct(key_and_value) if key_and_value.len() == 2 => {
-                let key_type = to_substrait_type(
-                    producer,
-                    key_and_value[0].data_type(),
-                    key_and_value[0].is_nullable(),
-                )?;
-                let value_type = to_substrait_type(
-                    producer,
-                    key_and_value[1].data_type(),
-                    key_and_value[1].is_nullable(),
-                )?;
+                let key_type = to_substrait_type_from_field(producer, 
&key_and_value[0])?;
+                let value_type =
+                    to_substrait_type_from_field(producer, &key_and_value[1])?;
                 Ok(substrait::proto::Type {
                     kind: Some(r#type::Kind::Map(Box::new(r#type::Map {
                         key: Some(Box::new(key_type)),
@@ -310,8 +310,14 @@ pub(crate) fn to_substrait_type(
             _ => plan_err!("Map fields must contain a Struct with exactly 2 
fields"),
         },
         DataType::Dictionary(key_type, value_type) => {
-            let key_type = to_substrait_type(producer, key_type, nullable)?;
-            let value_type = to_substrait_type(producer, value_type, 
nullable)?;
+            let key_type = to_substrait_type_from_field(
+                producer,
+                &Field::new("", key_type.as_ref().clone(), 
field.is_nullable()).into(),
+            )?;
+            let value_type = to_substrait_type_from_field(
+                producer,
+                &Field::new("", value_type.as_ref().clone(), 
field.is_nullable()).into(),
+            )?;
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::Map(Box::new(r#type::Map {
                     key: Some(Box::new(key_type)),
@@ -324,9 +330,7 @@ pub(crate) fn to_substrait_type(
         DataType::Struct(fields) => {
             let field_types = fields
                 .iter()
-                .map(|field| {
-                    to_substrait_type(producer, field.data_type(), 
field.is_nullable())
-                })
+                .map(|field| to_substrait_type_from_field(producer, field))
                 .collect::<datafusion::common::Result<Vec<_>>>()?;
             Ok(substrait::proto::Type {
                 kind: Some(r#type::Kind::Struct(r#type::Struct {
@@ -352,7 +356,7 @@ pub(crate) fn to_substrait_type(
                 precision: *p as i32,
             })),
         }),
-        _ => not_impl_err!("Unsupported cast type: {dt}"),
+        _ => not_impl_err!("Unsupported cast type: {field}"),
     }
 }
 
@@ -369,7 +373,7 @@ pub(crate) fn to_substrait_named_struct(
         types: schema
             .fields()
             .iter()
-            .map(|f| to_substrait_type(producer, f.data_type(), 
f.is_nullable()))
+            .map(|f| to_substrait_type_from_field(producer, f))
             .collect::<datafusion::common::Result<_>>()?,
         type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
         nullability: r#type::Nullability::Required as i32,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@datafusion.apache.org
For additional commands, e-mail: commits-help@datafusion.apache.org

Reply via email to