This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new b8b0c5584f feat(substrait): introduce consume_rel and consume_expression (#13963)
b8b0c5584f is described below

commit b8b0c5584f9f3a3aeca730ef1ac23dafc3e76dde
Author: Victor Barua <[email protected]>
AuthorDate: Sat Jan 4 07:05:04 2025 -0800

    feat(substrait): introduce consume_rel and consume_expression (#13963)
    
    * feat(substrait): introduce consume_rel and consume_expression
    
    Route calls to from_substrait_rel and from_substrait_rex through the
    SubstraitConsumer in order to allow users to provide their own behaviour
    
    * feat(substrait): consume nulls of user-defined types
    
    * docs(substrait): consume_rel and consume_expression docstrings
---
 datafusion/substrait/src/logical_plan/consumer.rs | 373 ++++++----------------
 1 file changed, 106 insertions(+), 267 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index 5155531526..0ee87afe32 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -17,7 +17,7 @@
 
 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer};
 use async_recursion::async_recursion;
-use datafusion::arrow::array::{GenericListArray, MapArray};
+use datafusion::arrow::array::MapArray;
 use datafusion::arrow::datatypes::{
     DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
 };
@@ -173,9 +173,9 @@ use substrait::proto::{
 ///
 ///     // You can implement a fully custom consumer method if you need 
special handling
 ///     async fn consume_filter(&self, rel: &FilterRel) -> Result<LogicalPlan> 
{
-///         let input = from_substrait_rel(self, 
rel.input.as_ref().unwrap()).await?;
+///         let input = self.consume_rel(rel.input.as_ref().unwrap()).await?;
 ///         let expression =
-///             from_substrait_rex(self, rel.condition.as_ref().unwrap(), 
input.schema())
+///             self.consume_expression(rel.condition.as_ref().unwrap(), 
input.schema())
 ///                 .await?;
 ///         // though this one is quite boring
 ///         LogicalPlanBuilder::from(input).filter(expression)?.build()
@@ -233,6 +233,12 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
     // These methods have default implementations calling the common handler 
code, to allow for users
     // to re-use common handling logic.
 
+    /// All [Rel]s to be converted pass through this method.
+    /// You can provide your own implementation if you wish to customize the 
conversion behaviour.
+    async fn consume_rel(&self, rel: &Rel) -> Result<LogicalPlan> {
+        from_substrait_rel(self, rel).await
+    }
+
     async fn consume_read(&self, rel: &ReadRel) -> Result<LogicalPlan> {
         from_read_rel(self, rel).await
     }
@@ -285,6 +291,16 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
     // These methods have default implementations calling the common handler 
code, to allow for users
     // to re-use common handling logic.
 
+    /// All [Expression]s to be converted pass through this method.
+    /// You can provide your own implementation if you wish to customize the 
conversion behaviour.
+    async fn consume_expression(
+        &self,
+        expr: &Expression,
+        input_schema: &DFSchema,
+    ) -> Result<Expr> {
+        from_substrait_rex(self, expr, input_schema).await
+    }
+
     async fn consume_literal(&self, expr: &Literal) -> Result<Expr> {
         from_literal(self, expr).await
     }
@@ -535,7 +551,7 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
                     "ExtensionSingleRel missing input rel, try using 
ExtensionLeafRel instead"
                 );
         };
-        let input_plan = from_substrait_rel(self, input_rel).await?;
+        let input_plan = self.consume_rel(input_rel).await?;
         let plan = plan.with_exprs_and_inputs(plan.expressions(), 
vec![input_plan])?;
         Ok(LogicalPlan::Extension(Extension { node: plan }))
     }
@@ -553,7 +569,7 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
             .deserialize_logical_plan(&ext_detail.type_url, 
&ext_detail.value)?;
         let mut inputs = Vec::with_capacity(rel.inputs.len());
         for input in &rel.inputs {
-            let input_plan = from_substrait_rel(self, input).await?;
+            let input_plan = self.consume_rel(input).await?;
             inputs.push(input_plan);
         }
         let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
@@ -666,10 +682,10 @@ async fn union_rels(
     is_all: bool,
 ) -> Result<LogicalPlan> {
     let mut union_builder = Ok(LogicalPlanBuilder::from(
-        from_substrait_rel(consumer, &rels[0]).await?,
+        consumer.consume_rel(&rels[0]).await?,
     ));
     for input in &rels[1..] {
-        let rel_plan = from_substrait_rel(consumer, input).await?;
+        let rel_plan = consumer.consume_rel(input).await?;
 
         union_builder = if is_all {
             union_builder?.union(rel_plan)
@@ -685,12 +701,12 @@ async fn intersect_rels(
     rels: &[Rel],
     is_all: bool,
 ) -> Result<LogicalPlan> {
-    let mut rel = from_substrait_rel(consumer, &rels[0]).await?;
+    let mut rel = consumer.consume_rel(&rels[0]).await?;
 
     for input in &rels[1..] {
         rel = LogicalPlanBuilder::intersect(
             rel,
-            from_substrait_rel(consumer, input).await?,
+            consumer.consume_rel(input).await?,
             is_all,
         )?
     }
@@ -703,14 +719,10 @@ async fn except_rels(
     rels: &[Rel],
     is_all: bool,
 ) -> Result<LogicalPlan> {
-    let mut rel = from_substrait_rel(consumer, &rels[0]).await?;
+    let mut rel = consumer.consume_rel(&rels[0]).await?;
 
     for input in &rels[1..] {
-        rel = LogicalPlanBuilder::except(
-            rel,
-            from_substrait_rel(consumer, input).await?,
-            is_all,
-        )?
+        rel = LogicalPlanBuilder::except(rel, 
consumer.consume_rel(input).await?, is_all)?
     }
 
     Ok(rel)
@@ -743,11 +755,9 @@ pub async fn from_substrait_plan_with_consumer(
         1 => {
             match plan.relations[0].rel_type.as_ref() {
                 Some(rt) => match rt {
-                    plan_rel::RelType::Rel(rel) => {
-                        Ok(from_substrait_rel(consumer, rel).await?)
-                    },
+                    plan_rel::RelType::Rel(rel) => 
Ok(consumer.consume_rel(rel).await?),
                     plan_rel::RelType::Root(root) => {
-                        let plan = from_substrait_rel(consumer, 
root.input.as_ref().unwrap()).await?;
+                        let plan = 
consumer.consume_rel(root.input.as_ref().unwrap()).await?;
                         if root.names.is_empty() {
                             // Backwards compatibility for plans missing names
                             return Ok(plan);
@@ -841,7 +851,9 @@ pub async fn from_substrait_extended_expr(
                 plan_err!("required property `expr_type` missing from 
Substrait ExpressionReference message")
             }
         }?;
-        let expr = from_substrait_rex(&consumer, scalar_expr, 
&input_schema).await?;
+        let expr = consumer
+            .consume_expression(scalar_expr, &input_schema)
+            .await?;
         let (output_type, expected_nullability) =
             expr.data_type_and_nullable(&input_schema)?;
         let output_field = Field::new("", output_type, expected_nullability);
@@ -1034,8 +1046,7 @@ pub async fn from_project_rel(
     p: &ProjectRel,
 ) -> Result<LogicalPlan> {
     if let Some(input) = p.input.as_ref() {
-        let mut input =
-            LogicalPlanBuilder::from(from_substrait_rel(consumer, 
input).await?);
+        let mut input = 
LogicalPlanBuilder::from(consumer.consume_rel(input).await?);
         let original_schema = input.schema().clone();
 
         // Ensure that all expressions have a unique display name, so that
@@ -1052,7 +1063,9 @@ pub async fn from_project_rel(
 
         let mut explicit_exprs: Vec<Expr> = vec![];
         for expr in &p.expressions {
-            let e = from_substrait_rex(consumer, expr, 
input.clone().schema()).await?;
+            let e = consumer
+                .consume_expression(expr, input.clone().schema())
+                .await?;
             // if the expression is WindowFunction, wrap in a Window relation
             if let Expr::WindowFunction(_) = &e {
                 // Adding the same expression here and in the project below
@@ -1081,9 +1094,11 @@ pub async fn from_filter_rel(
     filter: &FilterRel,
 ) -> Result<LogicalPlan> {
     if let Some(input) = filter.input.as_ref() {
-        let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, 
input).await?);
+        let input = 
LogicalPlanBuilder::from(consumer.consume_rel(input).await?);
         if let Some(condition) = filter.condition.as_ref() {
-            let expr = from_substrait_rex(consumer, condition, 
input.schema()).await?;
+            let expr = consumer
+                .consume_expression(condition, input.schema())
+                .await?;
             input.filter(expr)?.build()
         } else {
             not_impl_err!("Filter without an condition is not valid")
@@ -1099,12 +1114,12 @@ pub async fn from_fetch_rel(
     fetch: &FetchRel,
 ) -> Result<LogicalPlan> {
     if let Some(input) = fetch.input.as_ref() {
-        let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, 
input).await?);
+        let input = 
LogicalPlanBuilder::from(consumer.consume_rel(input).await?);
         let empty_schema = DFSchemaRef::new(DFSchema::empty());
         let offset = match &fetch.offset_mode {
             Some(fetch_rel::OffsetMode::Offset(offset)) => Some(lit(*offset)),
             Some(fetch_rel::OffsetMode::OffsetExpr(expr)) => {
-                Some(from_substrait_rex(consumer, expr, &empty_schema).await?)
+                Some(consumer.consume_expression(expr, &empty_schema).await?)
             }
             None => None,
         };
@@ -1114,7 +1129,7 @@ pub async fn from_fetch_rel(
                 (*count != -1).then(|| lit(*count))
             }
             Some(fetch_rel::CountMode::CountExpr(expr)) => {
-                Some(from_substrait_rex(consumer, expr, &empty_schema).await?)
+                Some(consumer.consume_expression(expr, &empty_schema).await?)
             }
             None => None,
         };
@@ -1129,7 +1144,7 @@ pub async fn from_sort_rel(
     sort: &SortRel,
 ) -> Result<LogicalPlan> {
     if let Some(input) = sort.input.as_ref() {
-        let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, 
input).await?);
+        let input = 
LogicalPlanBuilder::from(consumer.consume_rel(input).await?);
         let sorts = from_substrait_sorts(consumer, &sort.sorts, 
input.schema()).await?;
         input.sort(sorts)?.build()
     } else {
@@ -1142,11 +1157,11 @@ pub async fn from_aggregate_rel(
     agg: &AggregateRel,
 ) -> Result<LogicalPlan> {
     if let Some(input) = agg.input.as_ref() {
-        let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, 
input).await?);
+        let input = 
LogicalPlanBuilder::from(consumer.consume_rel(input).await?);
         let mut ref_group_exprs = vec![];
 
         for e in &agg.grouping_expressions {
-            let x = from_substrait_rex(consumer, e, input.schema()).await?;
+            let x = consumer.consume_expression(e, input.schema()).await?;
             ref_group_exprs.push(x);
         }
 
@@ -1189,7 +1204,7 @@ pub async fn from_aggregate_rel(
         for m in &agg.measures {
             let filter = match &m.filter {
                 Some(fil) => Some(Box::new(
-                    from_substrait_rex(consumer, fil, input.schema()).await?,
+                    consumer.consume_expression(fil, input.schema()).await?,
                 )),
                 None => None,
             };
@@ -1242,10 +1257,10 @@ pub async fn from_join_rel(
     }
 
     let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
-        from_substrait_rel(consumer, join.left.as_ref().unwrap()).await?,
+        consumer.consume_rel(join.left.as_ref().unwrap()).await?,
     );
     let right = LogicalPlanBuilder::from(
-        from_substrait_rel(consumer, join.right.as_ref().unwrap()).await?,
+        consumer.consume_rel(join.right.as_ref().unwrap()).await?,
     );
     let (left, right) = requalify_sides_if_needed(left, right)?;
 
@@ -1258,7 +1273,7 @@ pub async fn from_join_rel(
     // Otherwise, build join with only the filter, without join keys
     match &join.expression.as_ref() {
         Some(expr) => {
-            let on = from_substrait_rex(consumer, expr, 
&in_join_schema).await?;
+            let on = consumer.consume_expression(expr, &in_join_schema).await?;
             // The join expression can contain both equal and non-equal ops.
             // As of datafusion 31.0.0, the equal and non equal join 
conditions are in separate fields.
             // So we extract each part as follows:
@@ -1290,10 +1305,10 @@ pub async fn from_cross_rel(
     cross: &CrossRel,
 ) -> Result<LogicalPlan> {
     let left = LogicalPlanBuilder::from(
-        from_substrait_rel(consumer, cross.left.as_ref().unwrap()).await?,
+        consumer.consume_rel(cross.left.as_ref().unwrap()).await?,
     );
     let right = LogicalPlanBuilder::from(
-        from_substrait_rel(consumer, cross.right.as_ref().unwrap()).await?,
+        consumer.consume_rel(cross.right.as_ref().unwrap()).await?,
     );
     let (left, right) = requalify_sides_if_needed(left, right)?;
     left.cross_join(right.build()?)?.build()
@@ -1466,7 +1481,7 @@ pub async fn from_set_rel(
             SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await,
             SetOp::UnionDistinct => union_rels(consumer, &set.inputs, 
false).await,
             SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect(
-                from_substrait_rel(consumer, &set.inputs[0]).await?,
+                consumer.consume_rel(&set.inputs[0]).await?,
                 union_rels(consumer, &set.inputs[1..], true).await?,
                 false,
             ),
@@ -1490,7 +1505,7 @@ pub async fn from_exchange_rel(
     let Some(input) = exchange.input.as_ref() else {
         return substrait_err!("Unexpected empty input in ExchangeRel");
     };
-    let input = Arc::new(from_substrait_rel(consumer, input).await?);
+    let input = Arc::new(consumer.consume_rel(input).await?);
 
     let Some(exchange_kind) = &exchange.exchange_kind else {
         return substrait_err!("Unexpected empty input in ExchangeRel");
@@ -1822,8 +1837,9 @@ pub async fn from_substrait_sorts(
 ) -> Result<Vec<Sort>> {
     let mut sorts: Vec<Sort> = vec![];
     for s in substrait_sorts {
-        let expr =
-            from_substrait_rex(consumer, s.expr.as_ref().unwrap(), 
input_schema).await?;
+        let expr = consumer
+            .consume_expression(s.expr.as_ref().unwrap(), input_schema)
+            .await?;
         let asc_nullfirst = match &s.sort_kind {
             Some(k) => match k {
                 Direction(d) => {
@@ -1870,7 +1886,7 @@ pub async fn from_substrait_rex_vec(
 ) -> Result<Vec<Expr>> {
     let mut expressions: Vec<Expr> = vec![];
     for expr in exprs {
-        let expression = from_substrait_rex(consumer, expr, 
input_schema).await?;
+        let expression = consumer.consume_expression(expr, 
input_schema).await?;
         expressions.push(expression);
     }
     Ok(expressions)
@@ -1885,9 +1901,7 @@ pub async fn from_substrait_func_args(
     let mut args: Vec<Expr> = vec![];
     for arg in arguments {
         let arg_expr = match &arg.arg_type {
-            Some(ArgType::Value(e)) => {
-                from_substrait_rex(consumer, e, input_schema).await
-            }
+            Some(ArgType::Value(e)) => consumer.consume_expression(e, 
input_schema).await,
             _ => not_impl_err!("Function argument non-Value type not 
supported"),
         };
         args.push(arg_expr?);
@@ -1991,7 +2005,11 @@ pub async fn from_singular_or_list(
     let substrait_expr = expr.value.as_ref().unwrap();
     let substrait_list = expr.options.as_ref();
     Ok(Expr::InList(InList {
-        expr: Box::new(from_substrait_rex(consumer, substrait_expr, 
input_schema).await?),
+        expr: Box::new(
+            consumer
+                .consume_expression(substrait_expr, input_schema)
+                .await?,
+        ),
         list: from_substrait_rex_vec(consumer, substrait_list, 
input_schema).await?,
         negated: false,
     }))
@@ -2019,39 +2037,30 @@ pub async fn from_if_then(
             // Check if the first element is type base expression
             if if_expr.then.is_none() {
                 expr = Some(Box::new(
-                    from_substrait_rex(
-                        consumer,
-                        if_expr.r#if.as_ref().unwrap(),
-                        input_schema,
-                    )
-                    .await?,
+                    consumer
+                        .consume_expression(if_expr.r#if.as_ref().unwrap(), 
input_schema)
+                        .await?,
                 ));
                 continue;
             }
         }
         when_then_expr.push((
             Box::new(
-                from_substrait_rex(
-                    consumer,
-                    if_expr.r#if.as_ref().unwrap(),
-                    input_schema,
-                )
-                .await?,
+                consumer
+                    .consume_expression(if_expr.r#if.as_ref().unwrap(), 
input_schema)
+                    .await?,
             ),
             Box::new(
-                from_substrait_rex(
-                    consumer,
-                    if_expr.then.as_ref().unwrap(),
-                    input_schema,
-                )
-                .await?,
+                consumer
+                    .consume_expression(if_expr.then.as_ref().unwrap(), 
input_schema)
+                    .await?,
             ),
         ));
     }
     // Parse `else`
     let else_expr = match &if_then.r#else {
         Some(e) => Some(Box::new(
-            from_substrait_rex(consumer, e, input_schema).await?,
+            consumer.consume_expression(e, input_schema).await?,
         )),
         None => None,
     };
@@ -2134,12 +2143,12 @@ pub async fn from_cast(
     match cast.r#type.as_ref() {
         Some(output_type) => {
             let input_expr = Box::new(
-                from_substrait_rex(
-                    consumer,
-                    cast.input.as_ref().unwrap().as_ref(),
-                    input_schema,
-                )
-                .await?,
+                consumer
+                    .consume_expression(
+                        cast.input.as_ref().unwrap().as_ref(),
+                        input_schema,
+                    )
+                    .await?,
             );
             let data_type = from_substrait_type_without_names(consumer, 
output_type)?;
             if cast.failure_behavior() == ReturnNull {
@@ -2229,12 +2238,12 @@ pub async fn from_subquery(
                     let needle_expr = &in_predicate.needles[0];
                     let haystack_expr = &in_predicate.haystack;
                     if let Some(haystack_expr) = haystack_expr {
-                        let haystack_expr =
-                            from_substrait_rel(consumer, haystack_expr).await?;
+                        let haystack_expr = 
consumer.consume_rel(haystack_expr).await?;
                         let outer_refs = haystack_expr.all_out_ref_exprs();
                         Ok(Expr::InSubquery(InSubquery {
                             expr: Box::new(
-                                from_substrait_rex(consumer, needle_expr, 
input_schema)
+                                consumer
+                                    .consume_expression(needle_expr, 
input_schema)
                                     .await?,
                             ),
                             subquery: Subquery {
@@ -2251,11 +2260,9 @@ pub async fn from_subquery(
                 }
             }
             SubqueryType::Scalar(query) => {
-                let plan = from_substrait_rel(
-                    consumer,
-                    &(query.input.clone()).unwrap_or_default(),
-                )
-                .await?;
+                let plan = consumer
+                    .consume_rel(&(query.input.clone()).unwrap_or_default())
+                    .await?;
                 let outer_ref_columns = plan.all_out_ref_exprs();
                 Ok(Expr::ScalarSubquery(Subquery {
                     subquery: Arc::new(plan),
@@ -2267,11 +2274,9 @@ pub async fn from_subquery(
                     // exist
                     PredicateOp::Exists => {
                         let relation = &predicate.tuples;
-                        let plan = from_substrait_rel(
-                            consumer,
-                            &relation.clone().unwrap_or_default(),
-                        )
-                        .await?;
+                        let plan = consumer
+                            .consume_rel(&relation.clone().unwrap_or_default())
+                            .await?;
                         let outer_ref_columns = plan.all_out_ref_exprs();
                         Ok(Expr::Exists(Exists::new(
                             Subquery {
@@ -2909,8 +2914,10 @@ fn from_substrait_literal(
             }
             builder.build()?
         }
-        Some(LiteralType::Null(ntype)) => {
-            from_substrait_null(consumer, ntype, dfs_names, name_idx)?
+        Some(LiteralType::Null(null_type)) => {
+            let data_type =
+                from_substrait_type(consumer, null_type, dfs_names, name_idx)?;
+            ScalarValue::try_from(&data_type)?
         }
         Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond {
             days,
@@ -3082,180 +3089,6 @@ fn from_substrait_literal(
     Ok(scalar_value)
 }
 
-fn from_substrait_null(
-    consumer: &impl SubstraitConsumer,
-    null_type: &Type,
-    dfs_names: &[String],
-    name_idx: &mut usize,
-) -> Result<ScalarValue> {
-    if let Some(kind) = &null_type.kind {
-        match kind {
-            r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)),
-            r#type::Kind::I8(integer) => match 
integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int8(None)),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => 
Ok(ScalarValue::UInt8(None)),
-                v => not_impl_err!(
-                    "Unsupported Substrait type variation {v} of type {kind:?}"
-                ),
-            },
-            r#type::Kind::I16(integer) => match 
integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int16(None)),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => 
Ok(ScalarValue::UInt16(None)),
-                v => not_impl_err!(
-                    "Unsupported Substrait type variation {v} of type {kind:?}"
-                ),
-            },
-            r#type::Kind::I32(integer) => match 
integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int32(None)),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => 
Ok(ScalarValue::UInt32(None)),
-                v => not_impl_err!(
-                    "Unsupported Substrait type variation {v} of type {kind:?}"
-                ),
-            },
-            r#type::Kind::I64(integer) => match 
integer.type_variation_reference {
-                DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int64(None)),
-                UNSIGNED_INTEGER_TYPE_VARIATION_REF => 
Ok(ScalarValue::UInt64(None)),
-                v => not_impl_err!(
-                    "Unsupported Substrait type variation {v} of type {kind:?}"
-                ),
-            },
-            r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)),
-            r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)),
-            r#type::Kind::Timestamp(ts) => {
-                // Kept for backwards compatibility, new plans should use 
PrecisionTimestamp(Tz) instead
-                #[allow(deprecated)]
-                match ts.type_variation_reference {
-                    TIMESTAMP_SECOND_TYPE_VARIATION_REF => {
-                        Ok(ScalarValue::TimestampSecond(None, None))
-                    }
-                    TIMESTAMP_MILLI_TYPE_VARIATION_REF => {
-                        Ok(ScalarValue::TimestampMillisecond(None, None))
-                    }
-                    TIMESTAMP_MICRO_TYPE_VARIATION_REF => {
-                        Ok(ScalarValue::TimestampMicrosecond(None, None))
-                    }
-                    TIMESTAMP_NANO_TYPE_VARIATION_REF => {
-                        Ok(ScalarValue::TimestampNanosecond(None, None))
-                    }
-                    v => not_impl_err!(
-                        "Unsupported Substrait type variation {v} of type 
{kind:?}"
-                    ),
-                }
-            }
-            r#type::Kind::PrecisionTimestamp(pts) => match pts.precision {
-                0 => Ok(ScalarValue::TimestampSecond(None, None)),
-                3 => Ok(ScalarValue::TimestampMillisecond(None, None)),
-                6 => Ok(ScalarValue::TimestampMicrosecond(None, None)),
-                9 => Ok(ScalarValue::TimestampNanosecond(None, None)),
-                p => not_impl_err!(
-                    "Unsupported Substrait precision {p} for 
PrecisionTimestamp"
-                ),
-            },
-            r#type::Kind::PrecisionTimestampTz(pts) => match pts.precision {
-                0 => Ok(ScalarValue::TimestampSecond(
-                    None,
-                    Some(DEFAULT_TIMEZONE.into()),
-                )),
-                3 => Ok(ScalarValue::TimestampMillisecond(
-                    None,
-                    Some(DEFAULT_TIMEZONE.into()),
-                )),
-                6 => Ok(ScalarValue::TimestampMicrosecond(
-                    None,
-                    Some(DEFAULT_TIMEZONE.into()),
-                )),
-                9 => Ok(ScalarValue::TimestampNanosecond(
-                    None,
-                    Some(DEFAULT_TIMEZONE.into()),
-                )),
-                p => not_impl_err!(
-                    "Unsupported Substrait precision {p} for 
PrecisionTimestamp"
-                ),
-            },
-            r#type::Kind::Date(date) => match date.type_variation_reference {
-                DATE_32_TYPE_VARIATION_REF => Ok(ScalarValue::Date32(None)),
-                DATE_64_TYPE_VARIATION_REF => Ok(ScalarValue::Date64(None)),
-                v => not_impl_err!(
-                    "Unsupported Substrait type variation {v} of type {kind:?}"
-                ),
-            },
-            r#type::Kind::Binary(binary) => match 
binary.type_variation_reference {
-                DEFAULT_CONTAINER_TYPE_VARIATION_REF => 
Ok(ScalarValue::Binary(None)),
-                LARGE_CONTAINER_TYPE_VARIATION_REF => 
Ok(ScalarValue::LargeBinary(None)),
-                v => not_impl_err!(
-                    "Unsupported Substrait type variation {v} of type {kind:?}"
-                ),
-            },
-            // FixedBinary is not supported because `None` doesn't have length
-            r#type::Kind::String(string) => match 
string.type_variation_reference {
-                DEFAULT_CONTAINER_TYPE_VARIATION_REF => 
Ok(ScalarValue::Utf8(None)),
-                LARGE_CONTAINER_TYPE_VARIATION_REF => 
Ok(ScalarValue::LargeUtf8(None)),
-                v => not_impl_err!(
-                    "Unsupported Substrait type variation {v} of type {kind:?}"
-                ),
-            },
-            r#type::Kind::Decimal(d) => Ok(ScalarValue::Decimal128(
-                None,
-                d.precision as u8,
-                d.scale as i8,
-            )),
-            r#type::Kind::List(l) => {
-                let field = Field::new_list_field(
-                    from_substrait_type(
-                        consumer,
-                        l.r#type.clone().unwrap().as_ref(),
-                        dfs_names,
-                        name_idx,
-                    )?,
-                    true,
-                );
-                match l.type_variation_reference {
-                    DEFAULT_CONTAINER_TYPE_VARIATION_REF => 
Ok(ScalarValue::List(
-                        Arc::new(GenericListArray::new_null(field.into(), 1)),
-                    )),
-                    LARGE_CONTAINER_TYPE_VARIATION_REF => 
Ok(ScalarValue::LargeList(
-                        Arc::new(GenericListArray::new_null(field.into(), 1)),
-                    )),
-                    v => not_impl_err!(
-                        "Unsupported Substrait type variation {v} of type 
{kind:?}"
-                    ),
-                }
-            }
-            r#type::Kind::Map(map) => {
-                let key_type = map.key.as_ref().ok_or_else(|| {
-                    substrait_datafusion_err!("Map type must have key type")
-                })?;
-                let value_type = map.value.as_ref().ok_or_else(|| {
-                    substrait_datafusion_err!("Map type must have value type")
-                })?;
-
-                let key_type =
-                    from_substrait_type(consumer, key_type, dfs_names, 
name_idx)?;
-                let value_type =
-                    from_substrait_type(consumer, value_type, dfs_names, 
name_idx)?;
-                let entries_field = Arc::new(Field::new_struct(
-                    "entries",
-                    vec![
-                        Field::new("key", key_type, false),
-                        Field::new("value", value_type, true),
-                    ],
-                    false,
-                ));
-
-                DataType::Map(entries_field, false /* keys sorted 
*/).try_into()
-            }
-            r#type::Kind::Struct(s) => {
-                let fields =
-                    from_substrait_struct_type(consumer, s, dfs_names, 
name_idx)?;
-                Ok(ScalarStructBuilder::new_null(fields))
-            }
-            _ => not_impl_err!("Unsupported Substrait type for null: 
{kind:?}"),
-        }
-    } else {
-        not_impl_err!("Null type without kind is not supported")
-    }
-}
-
 #[allow(deprecated)]
 async fn from_substrait_grouping(
     consumer: &impl SubstraitConsumer,
@@ -3266,7 +3099,7 @@ async fn from_substrait_grouping(
     let mut group_exprs = vec![];
     if !grouping.grouping_expressions.is_empty() {
         for e in &grouping.grouping_expressions {
-            let expr = from_substrait_rex(consumer, e, input_schema).await?;
+            let expr = consumer.consume_expression(e, input_schema).await?;
             group_exprs.push(expr);
         }
         return Ok(group_exprs);
@@ -3349,7 +3182,9 @@ impl BuiltinExprBuilder {
         let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type 
else {
             return substrait_err!("Invalid arguments type for {fn_name} expr");
         };
-        let arg = from_substrait_rex(consumer, expr_substrait, 
input_schema).await?;
+        let arg = consumer
+            .consume_expression(expr_substrait, input_schema)
+            .await?;
         let arg = Box::new(arg);
 
         let expr = match fn_name {
@@ -3383,12 +3218,15 @@ impl BuiltinExprBuilder {
         let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type 
else {
             return substrait_err!("Invalid arguments type for `{fn_name}` 
expr");
         };
-        let expr = from_substrait_rex(consumer, expr_substrait, 
input_schema).await?;
+        let expr = consumer
+            .consume_expression(expr_substrait, input_schema)
+            .await?;
         let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type 
else {
             return substrait_err!("Invalid arguments type for `{fn_name}` 
expr");
         };
-        let pattern =
-            from_substrait_rex(consumer, pattern_substrait, 
input_schema).await?;
+        let pattern = consumer
+            .consume_expression(pattern_substrait, input_schema)
+            .await?;
 
         // Default case: escape character is Literal(Utf8(None))
         let escape_char = if f.arguments.len() == 3 {
@@ -3397,8 +3235,9 @@ impl BuiltinExprBuilder {
                 return substrait_err!("Invalid arguments type for `{fn_name}` 
expr");
             };
 
-            let escape_char_expr =
-                from_substrait_rex(consumer, escape_char_substrait, 
input_schema).await?;
+            let escape_char_expr = consumer
+                .consume_expression(escape_char_substrait, input_schema)
+                .await?;
 
             match escape_char_expr {
                 Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to