This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0e9c189a2e Substrait insubquery (#8363)
0e9c189a2e is described below

commit 0e9c189a2e4f8f6304239d6cbe14f5114a6d0406
Author: Tanmay Gujar <[email protected]>
AuthorDate: Wed Dec 20 15:48:11 2023 -0500

    Substrait insubquery (#8363)
    
    * testing IN subquery support for substrait producer
    
    * consumer fails with table not found
    
    * testing roundtrip check
    
    * pass SessionContext (ctx) into expression conversion functions
    
    * basic test for InSubquery
    
    * fix: outer refs in consumer
    
    * fix: merge issues
    
    * minor fixes
    
    * fix: fmt and clippy CI errors
    
    * improve error msg in consumer
    
    * minor fixes
---
 datafusion/substrait/src/logical_plan/consumer.rs  | 151 +++++++++++++++-----
 datafusion/substrait/src/logical_plan/producer.rs  | 155 ++++++++++++++++-----
 .../tests/cases/roundtrip_logical_plan.rs          |  18 +++
 3 files changed, 256 insertions(+), 68 deletions(-)

diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index b7fee96bba..9931dd15ae 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -28,7 +28,7 @@ use datafusion::logical_expr::{
 };
 use datafusion::logical_expr::{
     expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
-    Repartition, WindowFrameBound, WindowFrameUnits,
+    Repartition, Subquery, WindowFrameBound, WindowFrameUnits,
 };
 use datafusion::prelude::JoinType;
 use datafusion::sql::TableReference;
@@ -39,6 +39,7 @@ use datafusion::{
     scalar::ScalarValue,
 };
 use substrait::proto::exchange_rel::ExchangeKind;
+use substrait::proto::expression::subquery::SubqueryType;
 use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
 use substrait::proto::{
     aggregate_function::AggregationInvocation,
@@ -61,7 +62,7 @@ use substrait::proto::{
 use substrait::proto::{FunctionArgument, SortField};
 
 use datafusion::common::plan_err;
-use datafusion::logical_expr::expr::{InList, Sort};
+use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
 use std::collections::HashMap;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -230,7 +231,8 @@ pub async fn from_substrait_rel(
                 let mut exprs: Vec<Expr> = vec![];
                 for e in &p.expressions {
                     let x =
-                        from_substrait_rex(e, input.clone().schema(), 
extensions).await?;
+                        from_substrait_rex(ctx, e, input.clone().schema(), 
extensions)
+                            .await?;
                     // if the expression is WindowFunction, wrap in a Window 
relation
                     //   before returning and do not add to list of this 
Projection's expression list
                     // otherwise, add expression to the Projection's 
expression list
@@ -256,7 +258,8 @@ pub async fn from_substrait_rel(
                 );
                 if let Some(condition) = filter.condition.as_ref() {
                     let expr =
-                        from_substrait_rex(condition, input.schema(), 
extensions).await?;
+                        from_substrait_rex(ctx, condition, input.schema(), 
extensions)
+                            .await?;
                     input.filter(expr.as_ref().clone())?.build()
                 } else {
                     not_impl_err!("Filter without an condition is not valid")
@@ -288,7 +291,8 @@ pub async fn from_substrait_rel(
                     from_substrait_rel(ctx, input, extensions).await?,
                 );
                 let sorts =
-                    from_substrait_sorts(&sort.sorts, input.schema(), 
extensions).await?;
+                    from_substrait_sorts(ctx, &sort.sorts, input.schema(), 
extensions)
+                        .await?;
                 input.sort(sorts)?.build()
             } else {
                 not_impl_err!("Sort without an input is not valid")
@@ -306,7 +310,8 @@ pub async fn from_substrait_rel(
                     1 => {
                         for e in &agg.groupings[0].grouping_expressions {
                             let x =
-                                from_substrait_rex(e, input.schema(), 
extensions).await?;
+                                from_substrait_rex(ctx, e, input.schema(), 
extensions)
+                                    .await?;
                             group_expr.push(x.as_ref().clone());
                         }
                     }
@@ -315,8 +320,13 @@ pub async fn from_substrait_rel(
                         for grouping in &agg.groupings {
                             let mut grouping_set = vec![];
                             for e in &grouping.grouping_expressions {
-                                let x = from_substrait_rex(e, input.schema(), 
extensions)
-                                    .await?;
+                                let x = from_substrait_rex(
+                                    ctx,
+                                    e,
+                                    input.schema(),
+                                    extensions,
+                                )
+                                .await?;
                                 grouping_set.push(x.as_ref().clone());
                             }
                             grouping_sets.push(grouping_set);
@@ -334,7 +344,7 @@ pub async fn from_substrait_rel(
                 for m in &agg.measures {
                     let filter = match &m.filter {
                         Some(fil) => Some(Box::new(
-                            from_substrait_rex(fil, input.schema(), extensions)
+                            from_substrait_rex(ctx, fil, input.schema(), 
extensions)
                                 .await?
                                 .as_ref()
                                 .clone(),
@@ -402,8 +412,8 @@ pub async fn from_substrait_rel(
             // Otherwise, build join with only the filter, without join keys
             match &join.expression.as_ref() {
                 Some(expr) => {
-                    let on =
-                        from_substrait_rex(expr, &in_join_schema, 
extensions).await?;
+                    let on = from_substrait_rex(ctx, expr, &in_join_schema, 
extensions)
+                        .await?;
                     // The join expression can contain both equal and 
non-equal ops.
                     // As of datafusion 31.0.0, the equal and non equal join 
conditions are in separate fields.
                     // So we extract each part as follows:
@@ -612,14 +622,16 @@ fn from_substrait_jointype(join_type: i32) -> 
Result<JoinType> {
 
 /// Convert Substrait Sorts to DataFusion Exprs
 pub async fn from_substrait_sorts(
+    ctx: &SessionContext,
     substrait_sorts: &Vec<SortField>,
     input_schema: &DFSchema,
     extensions: &HashMap<u32, &String>,
 ) -> Result<Vec<Expr>> {
     let mut sorts: Vec<Expr> = vec![];
     for s in substrait_sorts {
-        let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema, 
extensions)
-            .await?;
+        let expr =
+            from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, 
extensions)
+                .await?;
         let asc_nullfirst = match &s.sort_kind {
             Some(k) => match k {
                 Direction(d) => {
@@ -660,13 +672,14 @@ pub async fn from_substrait_sorts(
 
 /// Convert Substrait Expressions to DataFusion Exprs
 pub async fn from_substrait_rex_vec(
+    ctx: &SessionContext,
     exprs: &Vec<Expression>,
     input_schema: &DFSchema,
     extensions: &HashMap<u32, &String>,
 ) -> Result<Vec<Expr>> {
     let mut expressions: Vec<Expr> = vec![];
     for expr in exprs {
-        let expression = from_substrait_rex(expr, input_schema, 
extensions).await?;
+        let expression = from_substrait_rex(ctx, expr, input_schema, 
extensions).await?;
         expressions.push(expression.as_ref().clone());
     }
     Ok(expressions)
@@ -674,6 +687,7 @@ pub async fn from_substrait_rex_vec(
 
 /// Convert Substrait FunctionArguments to DataFusion Exprs
 pub async fn from_substriat_func_args(
+    ctx: &SessionContext,
     arguments: &Vec<FunctionArgument>,
     input_schema: &DFSchema,
     extensions: &HashMap<u32, &String>,
@@ -682,7 +696,7 @@ pub async fn from_substriat_func_args(
     for arg in arguments {
         let arg_expr = match &arg.arg_type {
             Some(ArgType::Value(e)) => {
-                from_substrait_rex(e, input_schema, extensions).await
+                from_substrait_rex(ctx, e, input_schema, extensions).await
             }
             _ => {
                 not_impl_err!("Aggregated function argument non-Value type not 
supported")
@@ -707,7 +721,7 @@ pub async fn from_substrait_agg_func(
     for arg in &f.arguments {
         let arg_expr = match &arg.arg_type {
             Some(ArgType::Value(e)) => {
-                from_substrait_rex(e, input_schema, extensions).await
+                from_substrait_rex(ctx, e, input_schema, extensions).await
             }
             _ => {
                 not_impl_err!("Aggregated function argument non-Value type not 
supported")
@@ -745,6 +759,7 @@ pub async fn from_substrait_agg_func(
 /// Convert Substrait Rex to DataFusion Expr
 #[async_recursion]
 pub async fn from_substrait_rex(
+    ctx: &SessionContext,
     e: &Expression,
     input_schema: &DFSchema,
     extensions: &HashMap<u32, &String>,
@@ -755,13 +770,18 @@ pub async fn from_substrait_rex(
             let substrait_list = s.options.as_ref();
             Ok(Arc::new(Expr::InList(InList {
                 expr: Box::new(
-                    from_substrait_rex(substrait_expr, input_schema, 
extensions)
+                    from_substrait_rex(ctx, substrait_expr, input_schema, 
extensions)
                         .await?
                         .as_ref()
                         .clone(),
                 ),
-                list: from_substrait_rex_vec(substrait_list, input_schema, 
extensions)
-                    .await?,
+                list: from_substrait_rex_vec(
+                    ctx,
+                    substrait_list,
+                    input_schema,
+                    extensions,
+                )
+                .await?,
                 negated: false,
             })))
         }
@@ -779,6 +799,7 @@ pub async fn from_substrait_rex(
                     if if_expr.then.is_none() {
                         expr = Some(Box::new(
                             from_substrait_rex(
+                                ctx,
                                 if_expr.r#if.as_ref().unwrap(),
                                 input_schema,
                                 extensions,
@@ -793,6 +814,7 @@ pub async fn from_substrait_rex(
                 when_then_expr.push((
                     Box::new(
                         from_substrait_rex(
+                            ctx,
                             if_expr.r#if.as_ref().unwrap(),
                             input_schema,
                             extensions,
@@ -803,6 +825,7 @@ pub async fn from_substrait_rex(
                     ),
                     Box::new(
                         from_substrait_rex(
+                            ctx,
                             if_expr.then.as_ref().unwrap(),
                             input_schema,
                             extensions,
@@ -816,7 +839,7 @@ pub async fn from_substrait_rex(
             // Parse `else`
             let else_expr = match &if_then.r#else {
                 Some(e) => Some(Box::new(
-                    from_substrait_rex(e, input_schema, extensions)
+                    from_substrait_rex(ctx, e, input_schema, extensions)
                         .await?
                         .as_ref()
                         .clone(),
@@ -843,7 +866,7 @@ pub async fn from_substrait_rex(
                     for arg in &f.arguments {
                         let arg_expr = match &arg.arg_type {
                             Some(ArgType::Value(e)) => {
-                                from_substrait_rex(e, input_schema, 
extensions).await
+                                from_substrait_rex(ctx, e, input_schema, 
extensions).await
                             }
                             _ => not_impl_err!(
                                 "Aggregated function argument non-Value type 
not supported"
@@ -868,14 +891,14 @@ pub async fn from_substrait_rex(
                         (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
                             Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
                                 left: Box::new(
-                                    from_substrait_rex(l, input_schema, 
extensions)
+                                    from_substrait_rex(ctx, l, input_schema, 
extensions)
                                         .await?
                                         .as_ref()
                                         .clone(),
                                 ),
                                 op,
                                 right: Box::new(
-                                    from_substrait_rex(r, input_schema, 
extensions)
+                                    from_substrait_rex(ctx, r, input_schema, 
extensions)
                                         .await?
                                         .as_ref()
                                         .clone(),
@@ -888,7 +911,7 @@ pub async fn from_substrait_rex(
                     }
                 }
                 ScalarFunctionType::Expr(builder) => {
-                    builder.build(f, input_schema, extensions).await
+                    builder.build(ctx, f, input_schema, extensions).await
                 }
             }
         }
@@ -900,6 +923,7 @@ pub async fn from_substrait_rex(
             Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new(
                 Box::new(
                     from_substrait_rex(
+                        ctx,
                         cast.as_ref().input.as_ref().unwrap().as_ref(),
                         input_schema,
                         extensions,
@@ -921,7 +945,8 @@ pub async fn from_substrait_rex(
                 ),
             };
             let order_by =
-                from_substrait_sorts(&window.sorts, input_schema, 
extensions).await?;
+                from_substrait_sorts(ctx, &window.sorts, input_schema, 
extensions)
+                    .await?;
             // Substrait does not encode WindowFrameUnits so we're using a 
simple logic to determine the units
             // If there is no `ORDER BY`, then by default, the frame counts 
each row from the lower up to upper boundary
             // If there is `ORDER BY`, then by default, each frame is a range 
starting from unbounded preceding to current row
@@ -934,12 +959,14 @@ pub async fn from_substrait_rex(
             Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
                 fun: fun?.unwrap(),
                 args: from_substriat_func_args(
+                    ctx,
                     &window.arguments,
                     input_schema,
                     extensions,
                 )
                 .await?,
                 partition_by: from_substrait_rex_vec(
+                    ctx,
                     &window.partitions,
                     input_schema,
                     extensions,
@@ -953,6 +980,51 @@ pub async fn from_substrait_rex(
                 },
             })))
         }
+        Some(RexType::Subquery(subquery)) => match 
&subquery.as_ref().subquery_type {
+            Some(subquery_type) => match subquery_type {
+                SubqueryType::InPredicate(in_predicate) => {
+                    if in_predicate.needles.len() != 1 {
+                        Err(DataFusionError::Substrait(
+                            "InPredicate Subquery type must have exactly one 
Needle expression"
+                                .to_string(),
+                        ))
+                    } else {
+                        let needle_expr = &in_predicate.needles[0];
+                        let haystack_expr = &in_predicate.haystack;
+                        if let Some(haystack_expr) = haystack_expr {
+                            let haystack_expr =
+                                from_substrait_rel(ctx, haystack_expr, 
extensions)
+                                    .await?;
+                            let outer_refs = haystack_expr.all_out_ref_exprs();
+                            Ok(Arc::new(Expr::InSubquery(InSubquery {
+                                expr: Box::new(
+                                    from_substrait_rex(
+                                        ctx,
+                                        needle_expr,
+                                        input_schema,
+                                        extensions,
+                                    )
+                                    .await?
+                                    .as_ref()
+                                    .clone(),
+                                ),
+                                subquery: Subquery {
+                                    subquery: Arc::new(haystack_expr),
+                                    outer_ref_columns: outer_refs,
+                                },
+                                negated: false,
+                            })))
+                        } else {
+                            substrait_err!("InPredicate Subquery type must 
have a Haystack expression")
+                        }
+                    }
+                }
+                _ => substrait_err!("Subquery type not implemented"),
+            },
+            None => {
+                substrait_err!("Subquery experssion without SubqueryType is 
not allowed")
+            }
+        },
         _ => not_impl_err!("unsupported rex_type"),
     }
 }
@@ -1312,16 +1384,22 @@ impl BuiltinExprBuilder {
 
     pub async fn build(
         self,
+        ctx: &SessionContext,
         f: &ScalarFunction,
         input_schema: &DFSchema,
         extensions: &HashMap<u32, &String>,
     ) -> Result<Arc<Expr>> {
         match self.expr_name.as_str() {
-            "like" => Self::build_like_expr(false, f, input_schema, 
extensions).await,
-            "ilike" => Self::build_like_expr(true, f, input_schema, 
extensions).await,
+            "like" => {
+                Self::build_like_expr(ctx, false, f, input_schema, 
extensions).await
+            }
+            "ilike" => {
+                Self::build_like_expr(ctx, true, f, input_schema, 
extensions).await
+            }
             "not" | "negative" | "is_null" | "is_not_null" | "is_true" | 
"is_false"
             | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" 
=> {
-                Self::build_unary_expr(&self.expr_name, f, input_schema, 
extensions).await
+                Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, 
extensions)
+                    .await
             }
             _ => {
                 not_impl_err!("Unsupported builtin expression: {}", 
self.expr_name)
@@ -1330,6 +1408,7 @@ impl BuiltinExprBuilder {
     }
 
     async fn build_unary_expr(
+        ctx: &SessionContext,
         fn_name: &str,
         f: &ScalarFunction,
         input_schema: &DFSchema,
@@ -1341,7 +1420,7 @@ impl BuiltinExprBuilder {
         let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type 
else {
             return substrait_err!("Invalid arguments type for {fn_name} expr");
         };
-        let arg = from_substrait_rex(expr_substrait, input_schema, extensions)
+        let arg = from_substrait_rex(ctx, expr_substrait, input_schema, 
extensions)
             .await?
             .as_ref()
             .clone();
@@ -1365,6 +1444,7 @@ impl BuiltinExprBuilder {
     }
 
     async fn build_like_expr(
+        ctx: &SessionContext,
         case_insensitive: bool,
         f: &ScalarFunction,
         input_schema: &DFSchema,
@@ -1378,22 +1458,23 @@ impl BuiltinExprBuilder {
         let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type 
else {
             return substrait_err!("Invalid arguments type for `{fn_name}` 
expr");
         };
-        let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
+        let expr = from_substrait_rex(ctx, expr_substrait, input_schema, 
extensions)
             .await?
             .as_ref()
             .clone();
         let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type 
else {
             return substrait_err!("Invalid arguments type for `{fn_name}` 
expr");
         };
-        let pattern = from_substrait_rex(pattern_substrait, input_schema, 
extensions)
-            .await?
-            .as_ref()
-            .clone();
+        let pattern =
+            from_substrait_rex(ctx, pattern_substrait, input_schema, 
extensions)
+                .await?
+                .as_ref()
+                .clone();
         let Some(ArgType::Value(escape_char_substrait)) = 
&f.arguments[2].arg_type else {
             return substrait_err!("Invalid arguments type for `{fn_name}` 
expr");
         };
         let escape_char_expr =
-            from_substrait_rex(escape_char_substrait, input_schema, extensions)
+            from_substrait_rex(ctx, escape_char_substrait, input_schema, 
extensions)
                 .await?
                 .as_ref()
                 .clone();
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 50f8725442..926883251a 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -36,12 +36,13 @@ use datafusion::common::{substrait_err, DFSchemaRef};
 use datafusion::logical_expr::aggregate_function;
 use datafusion::logical_expr::expr::{
     AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, 
InList,
-    ScalarFunctionDefinition, Sort, WindowFunction,
+    InSubquery, ScalarFunctionDefinition, Sort, WindowFunction,
 };
 use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, 
Operator};
 use datafusion::prelude::Expr;
 use prost_types::Any as ProtoAny;
 use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::subquery::InPredicate;
 use substrait::proto::expression::window_function::BoundsType;
 use substrait::proto::{CrossRel, ExchangeRel};
 use substrait::{
@@ -58,7 +59,8 @@ use substrait::{
             window_function::bound::Kind as BoundKind,
             window_function::Bound,
             FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment, 
RexType,
-            ScalarFunction, SingularOrList, WindowFunction as 
SubstraitWindowFunction,
+            ScalarFunction, SingularOrList, Subquery,
+            WindowFunction as SubstraitWindowFunction,
         },
         extensions::{
             self,
@@ -167,7 +169,7 @@ pub fn to_substrait_rel(
             let expressions = p
                 .expr
                 .iter()
-                .map(|e| to_substrait_rex(e, p.input.schema(), 0, 
extension_info))
+                .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, 
extension_info))
                 .collect::<Result<Vec<_>>>()?;
             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Project(Box::new(ProjectRel {
@@ -181,6 +183,7 @@ pub fn to_substrait_rel(
         LogicalPlan::Filter(filter) => {
             let input = to_substrait_rel(filter.input.as_ref(), ctx, 
extension_info)?;
             let filter_expr = to_substrait_rex(
+                ctx,
                 &filter.predicate,
                 filter.input.schema(),
                 0,
@@ -214,7 +217,9 @@ pub fn to_substrait_rel(
             let sort_fields = sort
                 .expr
                 .iter()
-                .map(|e| substrait_sort_field(e, sort.input.schema(), 
extension_info))
+                .map(|e| {
+                    substrait_sort_field(ctx, e, sort.input.schema(), 
extension_info)
+                })
                 .collect::<Result<Vec<_>>>()?;
             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Sort(Box::new(SortRel {
@@ -228,6 +233,7 @@ pub fn to_substrait_rel(
         LogicalPlan::Aggregate(agg) => {
             let input = to_substrait_rel(agg.input.as_ref(), ctx, 
extension_info)?;
             let groupings = to_substrait_groupings(
+                ctx,
                 &agg.group_expr,
                 agg.input.schema(),
                 extension_info,
@@ -235,7 +241,9 @@ pub fn to_substrait_rel(
             let measures = agg
                 .aggr_expr
                 .iter()
-                .map(|e| to_substrait_agg_measure(e, agg.input.schema(), 
extension_info))
+                .map(|e| {
+                    to_substrait_agg_measure(ctx, e, agg.input.schema(), 
extension_info)
+                })
                 .collect::<Result<Vec<_>>>()?;
 
             Ok(Box::new(Rel {
@@ -283,6 +291,7 @@ pub fn to_substrait_rel(
             let in_join_schema = join.left.schema().join(join.right.schema())?;
             let join_filter = match &join.filter {
                 Some(filter) => Some(to_substrait_rex(
+                    ctx,
                     filter,
                     &Arc::new(in_join_schema),
                     0,
@@ -299,6 +308,7 @@ pub fn to_substrait_rel(
                 Operator::Eq
             };
             let join_on = to_substrait_join_expr(
+                ctx,
                 &join.on,
                 eq_op,
                 join.left.schema(),
@@ -401,6 +411,7 @@ pub fn to_substrait_rel(
             let mut window_exprs = vec![];
             for expr in &window.window_expr {
                 window_exprs.push(to_substrait_rex(
+                    ctx,
                     expr,
                     window.input.schema(),
                     0,
@@ -500,6 +511,7 @@ pub fn to_substrait_rel(
 }
 
 fn to_substrait_join_expr(
+    ctx: &SessionContext,
     join_conditions: &Vec<(Expr, Expr)>,
     eq_op: Operator,
     left_schema: &DFSchemaRef,
@@ -513,9 +525,10 @@ fn to_substrait_join_expr(
     let mut exprs: Vec<Expression> = vec![];
     for (left, right) in join_conditions {
         // Parse left
-        let l = to_substrait_rex(left, left_schema, 0, extension_info)?;
+        let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?;
         // Parse right
         let r = to_substrait_rex(
+            ctx,
             right,
             right_schema,
             left_schema.fields().len(), // offset to return the correct index
@@ -576,6 +589,7 @@ pub fn operator_to_name(op: Operator) -> &'static str {
 }
 
 pub fn parse_flat_grouping_exprs(
+    ctx: &SessionContext,
     exprs: &[Expr],
     schema: &DFSchemaRef,
     extension_info: &mut (
@@ -585,7 +599,7 @@ pub fn parse_flat_grouping_exprs(
 ) -> Result<Grouping> {
     let grouping_expressions = exprs
         .iter()
-        .map(|e| to_substrait_rex(e, schema, 0, extension_info))
+        .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info))
         .collect::<Result<Vec<_>>>()?;
     Ok(Grouping {
         grouping_expressions,
@@ -593,6 +607,7 @@ pub fn parse_flat_grouping_exprs(
 }
 
 pub fn to_substrait_groupings(
+    ctx: &SessionContext,
     exprs: &Vec<Expr>,
     schema: &DFSchemaRef,
     extension_info: &mut (
@@ -608,7 +623,9 @@ pub fn to_substrait_groupings(
                 )),
                 GroupingSet::GroupingSets(sets) => Ok(sets
                     .iter()
-                    .map(|set| parse_flat_grouping_exprs(set, schema, 
extension_info))
+                    .map(|set| {
+                        parse_flat_grouping_exprs(ctx, set, schema, 
extension_info)
+                    })
                     .collect::<Result<Vec<_>>>()?),
                 GroupingSet::Rollup(set) => {
                     let mut sets: Vec<Vec<Expr>> = vec![vec![]];
@@ -618,17 +635,21 @@ pub fn to_substrait_groupings(
                     Ok(sets
                         .iter()
                         .rev()
-                        .map(|set| parse_flat_grouping_exprs(set, schema, 
extension_info))
+                        .map(|set| {
+                            parse_flat_grouping_exprs(ctx, set, schema, 
extension_info)
+                        })
                         .collect::<Result<Vec<_>>>()?)
                 }
             },
             _ => Ok(vec![parse_flat_grouping_exprs(
+                ctx,
                 exprs,
                 schema,
                 extension_info,
             )?]),
         },
         _ => Ok(vec![parse_flat_grouping_exprs(
+            ctx,
             exprs,
             schema,
             extension_info,
@@ -638,6 +659,7 @@ pub fn to_substrait_groupings(
 
 #[allow(deprecated)]
 pub fn to_substrait_agg_measure(
+    ctx: &SessionContext,
     expr: &Expr,
     schema: &DFSchemaRef,
     extension_info: &mut (
@@ -650,13 +672,13 @@ pub fn to_substrait_agg_measure(
             match func_def {
                 AggregateFunctionDefinition::BuiltIn (fun) => {
                     let sorts = if let Some(order_by) = order_by {
-                        order_by.iter().map(|expr| 
to_substrait_sort_field(expr, schema, 
extension_info)).collect::<Result<Vec<_>>>()?
+                        order_by.iter().map(|expr| 
to_substrait_sort_field(ctx, expr, schema, 
extension_info)).collect::<Result<Vec<_>>>()?
                     } else {
                         vec![]
                     };
                     let mut arguments: Vec<FunctionArgument> = vec![];
                     for arg in args {
-                        arguments.push(FunctionArgument { arg_type: 
Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+                        arguments.push(FunctionArgument { arg_type: 
Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
                     }
                     let function_anchor = _register_function(fun.to_string(), 
extension_info);
                     Ok(Measure {
@@ -674,20 +696,20 @@ pub fn to_substrait_agg_measure(
                             options: vec![],
                         }),
                         filter: match filter {
-                            Some(f) => Some(to_substrait_rex(f, schema, 0, 
extension_info)?),
+                            Some(f) => Some(to_substrait_rex(ctx, f, schema, 
0, extension_info)?),
                             None => None
                         }
                     })
                 }
                 AggregateFunctionDefinition::UDF(fun) => {
                     let sorts = if let Some(order_by) = order_by {
-                        order_by.iter().map(|expr| 
to_substrait_sort_field(expr, schema, 
extension_info)).collect::<Result<Vec<_>>>()?
+                        order_by.iter().map(|expr| 
to_substrait_sort_field(ctx, expr, schema, 
extension_info)).collect::<Result<Vec<_>>>()?
                     } else {
                         vec![]
                     };
                     let mut arguments: Vec<FunctionArgument> = vec![];
                     for arg in args {
-                        arguments.push(FunctionArgument { arg_type: 
Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+                        arguments.push(FunctionArgument { arg_type: 
Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
                     }
                     let function_anchor = 
_register_function(fun.name().to_string(), extension_info);
                     Ok(Measure {
@@ -702,7 +724,7 @@ pub fn to_substrait_agg_measure(
                             options: vec![],
                         }),
                         filter: match filter {
-                            Some(f) => Some(to_substrait_rex(f, schema, 0, 
extension_info)?),
+                            Some(f) => Some(to_substrait_rex(ctx, f, schema, 
0, extension_info)?),
                             None => None
                         }
                     })
@@ -714,7 +736,7 @@ pub fn to_substrait_agg_measure(
 
         }
         Expr::Alias(Alias{expr,..})=> {
-            to_substrait_agg_measure(expr, schema, extension_info)
+            to_substrait_agg_measure(ctx, expr, schema, extension_info)
         }
         _ => internal_err!(
             "Expression must be compatible with aggregation. Unsupported 
expression: {:?}. ExpressionType: {:?}",
@@ -726,6 +748,7 @@ pub fn to_substrait_agg_measure(
 
 /// Converts sort expression to corresponding substrait `SortField`
 fn to_substrait_sort_field(
+    ctx: &SessionContext,
     expr: &Expr,
     schema: &DFSchemaRef,
     extension_info: &mut (
@@ -743,6 +766,7 @@ fn to_substrait_sort_field(
             };
             Ok(SortField {
                 expr: Some(to_substrait_rex(
+                    ctx,
                     sort.expr.deref(),
                     schema,
                     0,
@@ -851,6 +875,7 @@ pub fn make_binary_op_scalar_func(
 /// * `extension_info` - Substrait extension info. Contains registered 
function information
 #[allow(deprecated)]
 pub fn to_substrait_rex(
+    ctx: &SessionContext,
     expr: &Expr,
     schema: &DFSchemaRef,
     col_ref_offset: usize,
@@ -867,10 +892,10 @@ pub fn to_substrait_rex(
         }) => {
             let substrait_list = list
                 .iter()
-                .map(|x| to_substrait_rex(x, schema, col_ref_offset, 
extension_info))
+                .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, 
extension_info))
                 .collect::<Result<Vec<Expression>>>()?;
             let substrait_expr =
-                to_substrait_rex(expr, schema, col_ref_offset, 
extension_info)?;
+                to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extension_info)?;
 
             let substrait_or_list = Expression {
                 rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList 
{
@@ -903,6 +928,7 @@ pub fn to_substrait_rex(
             for arg in &fun.args {
                 arguments.push(FunctionArgument {
                     arg_type: Some(ArgType::Value(to_substrait_rex(
+                        ctx,
                         arg,
                         schema,
                         col_ref_offset,
@@ -937,11 +963,11 @@ pub fn to_substrait_rex(
             if *negated {
                 // `expr NOT BETWEEN low AND high` can be translated into 
(expr < low OR high < expr)
                 let substrait_expr =
-                    to_substrait_rex(expr, schema, col_ref_offset, 
extension_info)?;
+                    to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extension_info)?;
                 let substrait_low =
-                    to_substrait_rex(low, schema, col_ref_offset, 
extension_info)?;
+                    to_substrait_rex(ctx, low, schema, col_ref_offset, 
extension_info)?;
                 let substrait_high =
-                    to_substrait_rex(high, schema, col_ref_offset, 
extension_info)?;
+                    to_substrait_rex(ctx, high, schema, col_ref_offset, 
extension_info)?;
 
                 let l_expr = make_binary_op_scalar_func(
                     &substrait_expr,
@@ -965,11 +991,11 @@ pub fn to_substrait_rex(
             } else {
                 // `expr BETWEEN low AND high` can be translated into (low <= 
expr AND expr <= high)
                 let substrait_expr =
-                    to_substrait_rex(expr, schema, col_ref_offset, 
extension_info)?;
+                    to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extension_info)?;
                 let substrait_low =
-                    to_substrait_rex(low, schema, col_ref_offset, 
extension_info)?;
+                    to_substrait_rex(ctx, low, schema, col_ref_offset, 
extension_info)?;
                 let substrait_high =
-                    to_substrait_rex(high, schema, col_ref_offset, 
extension_info)?;
+                    to_substrait_rex(ctx, high, schema, col_ref_offset, 
extension_info)?;
 
                 let l_expr = make_binary_op_scalar_func(
                     &substrait_low,
@@ -997,8 +1023,8 @@ pub fn to_substrait_rex(
             substrait_field_ref(index + col_ref_offset)
         }
         Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-            let l = to_substrait_rex(left, schema, col_ref_offset, 
extension_info)?;
-            let r = to_substrait_rex(right, schema, col_ref_offset, 
extension_info)?;
+            let l = to_substrait_rex(ctx, left, schema, col_ref_offset, 
extension_info)?;
+            let r = to_substrait_rex(ctx, right, schema, col_ref_offset, 
extension_info)?;
 
             Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
         }
@@ -1013,6 +1039,7 @@ pub fn to_substrait_rex(
                 // Base expression exists
                 ifs.push(IfClause {
                     r#if: Some(to_substrait_rex(
+                        ctx,
                         e,
                         schema,
                         col_ref_offset,
@@ -1025,12 +1052,14 @@ pub fn to_substrait_rex(
             for (r#if, then) in when_then_expr {
                 ifs.push(IfClause {
                     r#if: Some(to_substrait_rex(
+                        ctx,
                         r#if,
                         schema,
                         col_ref_offset,
                         extension_info,
                     )?),
                     then: Some(to_substrait_rex(
+                        ctx,
                         then,
                         schema,
                         col_ref_offset,
@@ -1042,6 +1071,7 @@ pub fn to_substrait_rex(
             // Parse outer `else`
             let r#else: Option<Box<Expression>> = match else_expr {
                 Some(e) => Some(Box::new(to_substrait_rex(
+                    ctx,
                     e,
                     schema,
                     col_ref_offset,
@@ -1060,6 +1090,7 @@ pub fn to_substrait_rex(
                     substrait::proto::expression::Cast {
                         r#type: Some(to_substrait_type(data_type)?),
                         input: Some(Box::new(to_substrait_rex(
+                            ctx,
                             expr,
                             schema,
                             col_ref_offset,
@@ -1072,7 +1103,7 @@ pub fn to_substrait_rex(
         }
         Expr::Literal(value) => to_substrait_literal(value),
         Expr::Alias(Alias { expr, .. }) => {
-            to_substrait_rex(expr, schema, col_ref_offset, extension_info)
+            to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)
         }
         Expr::WindowFunction(WindowFunction {
             fun,
@@ -1088,6 +1119,7 @@ pub fn to_substrait_rex(
             for arg in args {
                 arguments.push(FunctionArgument {
                     arg_type: Some(ArgType::Value(to_substrait_rex(
+                        ctx,
                         arg,
                         schema,
                         col_ref_offset,
@@ -1098,12 +1130,12 @@ pub fn to_substrait_rex(
             // partition by expressions
             let partition_by = partition_by
                 .iter()
-                .map(|e| to_substrait_rex(e, schema, col_ref_offset, 
extension_info))
+                .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, 
extension_info))
                 .collect::<Result<Vec<_>>>()?;
             // order by expressions
             let order_by = order_by
                 .iter()
-                .map(|e| substrait_sort_field(e, schema, extension_info))
+                .map(|e| substrait_sort_field(ctx, e, schema, extension_info))
                 .collect::<Result<Vec<_>>>()?;
             // window frame
             let bounds = to_substrait_bounds(window_frame)?;
@@ -1124,6 +1156,7 @@ pub fn to_substrait_rex(
             escape_char,
             case_insensitive,
         }) => make_substrait_like_expr(
+            ctx,
             *case_insensitive,
             *negated,
             expr,
@@ -1133,7 +1166,50 @@ pub fn to_substrait_rex(
             col_ref_offset,
             extension_info,
         ),
+        Expr::InSubquery(InSubquery {
+            expr,
+            subquery,
+            negated,
+        }) => {
+            let substrait_expr =
+                to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extension_info)?;
+
+            let subquery_plan =
+                to_substrait_rel(subquery.subquery.as_ref(), ctx, 
extension_info)?;
+
+            let substrait_subquery = Expression {
+                rex_type: Some(RexType::Subquery(Box::new(Subquery {
+                    subquery_type: Some(
+                        
substrait::proto::expression::subquery::SubqueryType::InPredicate(
+                            Box::new(InPredicate {
+                                needles: (vec![substrait_expr]),
+                                haystack: Some(subquery_plan),
+                            }),
+                        ),
+                    ),
+                }))),
+            };
+            if *negated {
+                let function_anchor =
+                    _register_function("not".to_string(), extension_info);
+
+                Ok(Expression {
+                    rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+                        function_reference: function_anchor,
+                        arguments: vec![FunctionArgument {
+                            arg_type: Some(ArgType::Value(substrait_subquery)),
+                        }],
+                        output_type: None,
+                        args: vec![],
+                        options: vec![],
+                    })),
+                })
+            } else {
+                Ok(substrait_subquery)
+            }
+        }
         Expr::Not(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "not",
             arg,
             schema,
@@ -1141,6 +1217,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsNull(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_null",
             arg,
             schema,
@@ -1148,6 +1225,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_not_null",
             arg,
             schema,
@@ -1155,6 +1233,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsTrue(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_true",
             arg,
             schema,
@@ -1162,6 +1241,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsFalse(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_false",
             arg,
             schema,
@@ -1169,6 +1249,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_unknown",
             arg,
             schema,
@@ -1176,6 +1257,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_not_true",
             arg,
             schema,
@@ -1183,6 +1265,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_not_false",
             arg,
             schema,
@@ -1190,6 +1273,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "is_not_unknown",
             arg,
             schema,
@@ -1197,6 +1281,7 @@ pub fn to_substrait_rex(
             extension_info,
         ),
         Expr::Negative(arg) => to_substrait_unary_scalar_fn(
+            ctx,
             "negative",
             arg,
             schema,
@@ -1421,6 +1506,7 @@ fn make_substrait_window_function(
 #[allow(deprecated)]
 #[allow(clippy::too_many_arguments)]
 fn make_substrait_like_expr(
+    ctx: &SessionContext,
     ignore_case: bool,
     negated: bool,
     expr: &Expr,
@@ -1438,8 +1524,8 @@ fn make_substrait_like_expr(
     } else {
         _register_function("like".to_string(), extension_info)
     };
-    let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
-    let pattern = to_substrait_rex(pattern, schema, col_ref_offset, 
extension_info)?;
+    let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extension_info)?;
+    let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, 
extension_info)?;
     let escape_char =
         to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| 
c.to_string())))?;
     let arguments = vec![
@@ -1669,6 +1755,7 @@ fn to_substrait_literal(value: &ScalarValue) -> 
Result<Expression> {
 
 /// Util to generate substrait [RexType::ScalarFunction] with one argument
 fn to_substrait_unary_scalar_fn(
+    ctx: &SessionContext,
     fn_name: &str,
     arg: &Expr,
     schema: &DFSchemaRef,
@@ -1679,7 +1766,8 @@ fn to_substrait_unary_scalar_fn(
     ),
 ) -> Result<Expression> {
     let function_anchor = _register_function(fn_name.to_string(), 
extension_info);
-    let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset, 
extension_info)?;
+    let substrait_expr =
+        to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?;
 
     Ok(Expression {
         rex_type: Some(RexType::ScalarFunction(ScalarFunction {
@@ -1880,6 +1968,7 @@ fn try_to_substrait_field_reference(
 }
 
 fn substrait_sort_field(
+    ctx: &SessionContext,
     expr: &Expr,
     schema: &DFSchemaRef,
     extension_info: &mut (
@@ -1893,7 +1982,7 @@ fn substrait_sort_field(
             asc,
             nulls_first,
         }) => {
-            let e = to_substrait_rex(expr, schema, 0, extension_info)?;
+            let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?;
             let d = match (asc, nulls_first) {
                 (true, true) => SortDirection::AscNullsFirst,
                 (true, false) => SortDirection::AscNullsLast,
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 47eb5a8f73..d7327caee4 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -394,6 +394,24 @@ async fn roundtrip_inlist_4() -> Result<()> {
     roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await
 }
 
+#[tokio::test]
+async fn roundtrip_inlist_5() -> Result<()> {
+    // On roundtrip the TableScan gains an additional projection that includes all columns 
of the table, so a plain `roundtrip` comparison would not match; assert an explicit expected plan instead.
+    // Note: the expected plan lists the IN subquery twice, under the Filter and again in the TableScan partial_filters.
+    assert_expected_plan(
+    "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a 
FROM data2 WHERE f IN ('b', 'c', 'd')))",
+    "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = 
Utf8(\"c\") OR data.a IN (<subquery>)\
+    \n  Subquery:\
+    \n    Projection: data2.a\
+    \n      Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
+    \n        TableScan: data2 projection=[a, b, c, d, e, f]\
+    \n  TableScan: data projection=[a, f], partial_filters=[data.f = 
Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN 
(<subquery>)]\
+    \n    Subquery:\
+    \n      Projection: data2.a\
+    \n        Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
+    \n          TableScan: data2 projection=[a, b, c, d, e, f]").await
+}
+
 #[tokio::test]
 async fn roundtrip_cross_join() -> Result<()> {
     roundtrip("SELECT * FROM data CROSS JOIN data2").await


Reply via email to