This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 19d9174150 Add support for Substrait Struct literals and type (#10622)
19d9174150 is described below

commit 19d91741509ea975f32f99b17d3e0595a85c0e09
Author: Arttu <[email protected]>
AuthorDate: Thu May 23 19:33:08 2024 +0200

    Add support for Substrait Struct literals and type (#10622)
    
    * Add support for (un-named) Substrait Struct literal
    
    Adds support for converting from DataFusion Struct ScalarValues into
    Substrait Struct Literals and back.
    All structs are assumed to be unnamed, i.e. fields are renamed
    to "c0", "c1", etc.
    
    * add converting from Substrait Struct type
    
    * cargo fmt --all
    
    * Unit test for NULL inside Struct
    
    * retry ci
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 22 +++++++++++++
 datafusion/substrait/src/logical_plan/producer.rs  | 36 +++++++++++++++++++++-
 .../tests/cases/roundtrip_logical_plan.rs          | 10 ++++++
 3 files changed, 67 insertions(+), 1 deletion(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index 5a71ab91db..a08485fd35 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -63,6 +63,7 @@ use substrait::proto::{FunctionArgument, SortField};
 
 use datafusion::arrow::array::GenericListArray;
 use datafusion::common::plan_err;
+use datafusion::common::scalar::ScalarStructBuilder;
 use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
 use std::collections::HashMap;
 use std::str::FromStr;
@@ -1159,6 +1160,15 @@ pub(crate) fn from_substrait_type(dt: 
&substrait::proto::Type) -> Result<DataTyp
                     "Unsupported Substrait type variation {v} of type 
{s_kind:?}"
                 ),
             },
+            r#type::Kind::Struct(s) => {
+                let mut fields = vec![];
+                for (i, f) in s.types.iter().enumerate() {
+                    let field =
+                        Field::new(&format!("c{i}"), from_substrait_type(f)?, 
true);
+                    fields.push(field);
+                }
+                Ok(DataType::Struct(fields.into()))
+            }
             _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"),
         },
         _ => not_impl_err!("`None` Substrait kind is not supported"),
@@ -1318,6 +1328,18 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> 
Result<ScalarValue> {
                 }
             }
         }
+        Some(LiteralType::Struct(s)) => {
+            let mut builder = ScalarStructBuilder::new();
+            for (i, field) in s.fields.iter().enumerate() {
+                let sv = from_substrait_literal(field)?;
+                // c0, c1, ... align with e.g. SqlToRel::create_named_struct
+                builder = builder.with_scalar(
+                    Field::new(&format!("c{i}"), sv.data_type(), 
field.nullable),
+                    sv,
+                );
+            }
+            builder.build()?
+        }
         Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?,
         _ => return not_impl_err!("Unsupported literal_type: {:?}", 
lit.literal_type),
     };
diff --git a/datafusion/substrait/src/logical_plan/producer.rs 
b/datafusion/substrait/src/logical_plan/producer.rs
index bfdffdc3a2..e216008c73 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -43,7 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, 
LogicalPlan, Opera
 use datafusion::prelude::Expr;
 use prost_types::Any as ProtoAny;
 use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
-use substrait::proto::expression::literal::List;
+use substrait::proto::expression::literal::{List, Struct};
 use substrait::proto::expression::subquery::InPredicate;
 use substrait::proto::expression::window_function::BoundsType;
 use substrait::proto::{CrossRel, ExchangeRel};
@@ -1751,6 +1751,18 @@ fn to_substrait_literal(value: &ScalarValue) -> 
Result<Literal> {
         ScalarValue::LargeList(l) if !value.is_null() => {
             (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF)
         }
+        ScalarValue::Struct(s) if !value.is_null() => (
+            LiteralType::Struct(Struct {
+                fields: s
+                    .columns()
+                    .iter()
+                    .map(|col| {
+                        to_substrait_literal(&ScalarValue::try_from_array(col, 
0)?)
+                    })
+                    .collect::<Result<Vec<_>>>()?,
+            }),
+            DEFAULT_TYPE_REF,
+        ),
         _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF),
     };
 
@@ -1979,6 +1991,9 @@ fn try_to_substrait_null(v: &ScalarValue) -> 
Result<LiteralType> {
         ScalarValue::LargeList(l) => {
             Ok(LiteralType::Null(to_substrait_type(l.data_type())?))
         }
+        ScalarValue::Struct(s) => {
+            Ok(LiteralType::Null(to_substrait_type(s.data_type())?))
+        }
         // TODO: Extend support for remaining data types
         _ => not_impl_err!("Unsupported literal: {v:?}"),
     }
@@ -2061,6 +2076,7 @@ mod test {
     use crate::logical_plan::consumer::{from_substrait_literal, 
from_substrait_type};
     use datafusion::arrow::array::GenericListArray;
     use datafusion::arrow::datatypes::Field;
+    use datafusion::common::scalar::ScalarStructBuilder;
 
     use super::*;
 
@@ -2125,6 +2141,17 @@ mod test {
             ),
         )))?;
 
+        let c0 = Field::new("c0", DataType::Boolean, true);
+        let c1 = Field::new("c1", DataType::Int32, true);
+        let c2 = Field::new("c2", DataType::Utf8, true);
+        round_trip_literal(
+            ScalarStructBuilder::new()
+                .with_scalar(c0, ScalarValue::Boolean(Some(true)))
+                .with_scalar(c1, ScalarValue::Int32(Some(1)))
+                .with_scalar(c2, ScalarValue::Utf8(None))
+                .build()?,
+        )?;
+
         Ok(())
     }
 
@@ -2169,6 +2196,13 @@ mod test {
         round_trip_type(DataType::LargeList(
             Field::new_list_field(DataType::Int32, true).into(),
         ))?;
+        round_trip_type(DataType::Struct(
+            vec![
+                Field::new("c0", DataType::Int32, true),
+                Field::new("c1", DataType::Utf8, true),
+            ]
+            .into(),
+        ))?;
 
         Ok(())
     }
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 02371063ef..8d0e96cedd 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -675,6 +675,16 @@ async fn roundtrip_literal_list() -> Result<()> {
     .await
 }
 
+#[tokio::test]
+async fn roundtrip_literal_struct() -> Result<()> {
+    assert_expected_plan(
+        "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
+        "Projection: Struct({c0:1,c1:true,c2:})\
+        \n  TableScan: data projection=[]",
+    )
+    .await
+}
+
 /// Construct a plan that cast columns. Only those SQL types are supported for 
now.
 #[tokio::test]
 async fn new_test_grammar() -> Result<()> {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to