This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 555fc2e24d fix: fold cast null to substrait typed null (#15854)
555fc2e24d is described below

commit 555fc2e24dd669e44ac23a9a1d8406f4ac58a9ed
Author: discord9 <55937128+disco...@users.noreply.github.com>
AuthorDate: Tue May 6 02:47:55 2025 +0800

    fix: fold cast null to substrait typed null (#15854)
    
    * fix: fold cast null to typed null
    
    * test: unit test
    
    * chore: clippy
    
    * fix: only handle ScalarValue::Null instead of all null-ed value
---
 datafusion/substrait/src/logical_plan/producer.rs | 85 +++++++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 07bf0cb96a..31304b73a0 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -1590,6 +1590,24 @@ pub fn from_cast(
     schema: &DFSchemaRef,
 ) -> Result<Expression> {
     let Cast { expr, data_type } = cast;
+    // since substrait Null must be typed, so if we see a cast(null, dt), we make it a typed null
+    if let Expr::Literal(lit) = expr.as_ref() {
+        // only the untyped(a null scalar value) null literal need this special handling
+        // since all other kind of nulls are already typed and can be handled by substrait
+        // e.g. null::<Int32Type> or null::<Utf8Type>
+        if matches!(lit, ScalarValue::Null) {
+            let lit = Literal {
+                nullable: true,
+                type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
+                literal_type: Some(LiteralType::Null(to_substrait_type(
+                    data_type, true,
+                )?)),
+            };
+            return Ok(Expression {
+                rex_type: Some(RexType::Literal(lit)),
+            });
+        }
+    }
     Ok(Expression {
         rex_type: Some(RexType::Cast(Box::new(
             substrait::proto::expression::Cast {
@@ -2575,6 +2593,7 @@ mod test {
     use datafusion::common::scalar::ScalarStructBuilder;
     use datafusion::common::DFSchema;
     use datafusion::execution::{SessionState, SessionStateBuilder};
+    use datafusion::logical_expr::ExprSchemable;
     use datafusion::prelude::SessionContext;
     use std::sync::LazyLock;
 
@@ -2912,4 +2931,70 @@ mod test {
 
         assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
     }
+
+    #[tokio::test]
+    async fn fold_cast_null() {
+        let state = SessionStateBuilder::default().build();
+        let empty_schema = DFSchemaRef::new(DFSchema::empty());
+        let field = Field::new("out", DataType::Int32, false);
+
+        let expr = Expr::Literal(ScalarValue::Null)
+            .cast_to(&DataType::Int32, &empty_schema)
+            .unwrap();
+
+        let typed_null =
+            to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)
+                .unwrap();
+
+        if let ExprType::Expression(expr) =
+            typed_null.referred_expr[0].expr_type.as_ref().unwrap()
+        {
+            let lit = Literal {
+                nullable: true,
+                type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
+                literal_type: Some(LiteralType::Null(
+                    to_substrait_type(&DataType::Int32, true).unwrap(),
+                )),
+            };
+            let expected = Expression {
+                rex_type: Some(RexType::Literal(lit)),
+            };
+            assert_eq!(*expr, expected);
+        } else {
+            panic!("Expected expression type");
+        }
+
+        // a typed null should not be folded
+        let expr = Expr::Literal(ScalarValue::Int64(None))
+            .cast_to(&DataType::Int32, &empty_schema)
+            .unwrap();
+
+        let typed_null =
+            to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)
+                .unwrap();
+
+        if let ExprType::Expression(expr) =
+            typed_null.referred_expr[0].expr_type.as_ref().unwrap()
+        {
+            let cast_expr = substrait::proto::expression::Cast {
+                r#type: Some(to_substrait_type(&DataType::Int32, true).unwrap()),
+                input: Some(Box::new(Expression {
+                    rex_type: Some(RexType::Literal(Literal {
+                        nullable: true,
+                        type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
+                        literal_type: Some(LiteralType::Null(
+                            to_substrait_type(&DataType::Int64, true).unwrap(),
+                        )),
+                    })),
+                })),
+                failure_behavior: FailureBehavior::ThrowException as i32,
+            };
+            let expected = Expression {
+                rex_type: Some(RexType::Cast(Box::new(cast_expr))),
+            };
+            assert_eq!(*expr, expected);
+        } else {
+            panic!("Expected expression type");
+        }
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to