This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 0e9c189a2e Substrait insubquery (#8363)
0e9c189a2e is described below
commit 0e9c189a2e4f8f6304239d6cbe14f5114a6d0406
Author: Tanmay Gujar <[email protected]>
AuthorDate: Wed Dec 20 15:48:11 2023 -0500
Substrait insubquery (#8363)
* testing IN subquery support for the Substrait producer
* consumer fails with table not found
* testing roundtrip check
* pass in ctx to expr
* basic test for InSubquery
* fix: outer refs in consumer
* fix: merge issues
* minor fixes
* fix: fmt and clippy CI errors
* improve error msg in consumer
* minor fixes
---
datafusion/substrait/src/logical_plan/consumer.rs | 151 +++++++++++++++-----
datafusion/substrait/src/logical_plan/producer.rs | 155 ++++++++++++++++-----
.../tests/cases/roundtrip_logical_plan.rs | 18 +++
3 files changed, 256 insertions(+), 68 deletions(-)
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index b7fee96bba..9931dd15ae 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -28,7 +28,7 @@ use datafusion::logical_expr::{
};
use datafusion::logical_expr::{
expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
- Repartition, WindowFrameBound, WindowFrameUnits,
+ Repartition, Subquery, WindowFrameBound, WindowFrameUnits,
};
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
@@ -39,6 +39,7 @@ use datafusion::{
scalar::ScalarValue,
};
use substrait::proto::exchange_rel::ExchangeKind;
+use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{FieldReference, Literal, ScalarFunction};
use substrait::proto::{
aggregate_function::AggregationInvocation,
@@ -61,7 +62,7 @@ use substrait::proto::{
use substrait::proto::{FunctionArgument, SortField};
use datafusion::common::plan_err;
-use datafusion::logical_expr::expr::{InList, Sort};
+use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
@@ -230,7 +231,8 @@ pub async fn from_substrait_rel(
let mut exprs: Vec<Expr> = vec![];
for e in &p.expressions {
let x =
- from_substrait_rex(e, input.clone().schema(),
extensions).await?;
+ from_substrait_rex(ctx, e, input.clone().schema(),
extensions)
+ .await?;
// if the expression is WindowFunction, wrap in a Window
relation
// before returning and do not add to list of this
Projection's expression list
// otherwise, add expression to the Projection's
expression list
@@ -256,7 +258,8 @@ pub async fn from_substrait_rel(
);
if let Some(condition) = filter.condition.as_ref() {
let expr =
- from_substrait_rex(condition, input.schema(),
extensions).await?;
+ from_substrait_rex(ctx, condition, input.schema(),
extensions)
+ .await?;
input.filter(expr.as_ref().clone())?.build()
} else {
not_impl_err!("Filter without an condition is not valid")
@@ -288,7 +291,8 @@ pub async fn from_substrait_rel(
from_substrait_rel(ctx, input, extensions).await?,
);
let sorts =
- from_substrait_sorts(&sort.sorts, input.schema(),
extensions).await?;
+ from_substrait_sorts(ctx, &sort.sorts, input.schema(),
extensions)
+ .await?;
input.sort(sorts)?.build()
} else {
not_impl_err!("Sort without an input is not valid")
@@ -306,7 +310,8 @@ pub async fn from_substrait_rel(
1 => {
for e in &agg.groupings[0].grouping_expressions {
let x =
- from_substrait_rex(e, input.schema(),
extensions).await?;
+ from_substrait_rex(ctx, e, input.schema(),
extensions)
+ .await?;
group_expr.push(x.as_ref().clone());
}
}
@@ -315,8 +320,13 @@ pub async fn from_substrait_rel(
for grouping in &agg.groupings {
let mut grouping_set = vec![];
for e in &grouping.grouping_expressions {
- let x = from_substrait_rex(e, input.schema(),
extensions)
- .await?;
+ let x = from_substrait_rex(
+ ctx,
+ e,
+ input.schema(),
+ extensions,
+ )
+ .await?;
grouping_set.push(x.as_ref().clone());
}
grouping_sets.push(grouping_set);
@@ -334,7 +344,7 @@ pub async fn from_substrait_rel(
for m in &agg.measures {
let filter = match &m.filter {
Some(fil) => Some(Box::new(
- from_substrait_rex(fil, input.schema(), extensions)
+ from_substrait_rex(ctx, fil, input.schema(),
extensions)
.await?
.as_ref()
.clone(),
@@ -402,8 +412,8 @@ pub async fn from_substrait_rel(
// Otherwise, build join with only the filter, without join keys
match &join.expression.as_ref() {
Some(expr) => {
- let on =
- from_substrait_rex(expr, &in_join_schema,
extensions).await?;
+ let on = from_substrait_rex(ctx, expr, &in_join_schema,
extensions)
+ .await?;
// The join expression can contain both equal and
non-equal ops.
// As of datafusion 31.0.0, the equal and non equal join
conditions are in separate fields.
// So we extract each part as follows:
@@ -612,14 +622,16 @@ fn from_substrait_jointype(join_type: i32) ->
Result<JoinType> {
/// Convert Substrait Sorts to DataFusion Exprs
pub async fn from_substrait_sorts(
+ ctx: &SessionContext,
substrait_sorts: &Vec<SortField>,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Vec<Expr>> {
let mut sorts: Vec<Expr> = vec![];
for s in substrait_sorts {
- let expr = from_substrait_rex(s.expr.as_ref().unwrap(), input_schema,
extensions)
- .await?;
+ let expr =
+ from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema,
extensions)
+ .await?;
let asc_nullfirst = match &s.sort_kind {
Some(k) => match k {
Direction(d) => {
@@ -660,13 +672,14 @@ pub async fn from_substrait_sorts(
/// Convert Substrait Expressions to DataFusion Exprs
pub async fn from_substrait_rex_vec(
+ ctx: &SessionContext,
exprs: &Vec<Expression>,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Vec<Expr>> {
let mut expressions: Vec<Expr> = vec![];
for expr in exprs {
- let expression = from_substrait_rex(expr, input_schema,
extensions).await?;
+ let expression = from_substrait_rex(ctx, expr, input_schema,
extensions).await?;
expressions.push(expression.as_ref().clone());
}
Ok(expressions)
@@ -674,6 +687,7 @@ pub async fn from_substrait_rex_vec(
/// Convert Substrait FunctionArguments to DataFusion Exprs
pub async fn from_substriat_func_args(
+ ctx: &SessionContext,
arguments: &Vec<FunctionArgument>,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
@@ -682,7 +696,7 @@ pub async fn from_substriat_func_args(
for arg in arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
- from_substrait_rex(e, input_schema, extensions).await
+ from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => {
not_impl_err!("Aggregated function argument non-Value type not
supported")
@@ -707,7 +721,7 @@ pub async fn from_substrait_agg_func(
for arg in &f.arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
- from_substrait_rex(e, input_schema, extensions).await
+ from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => {
not_impl_err!("Aggregated function argument non-Value type not
supported")
@@ -745,6 +759,7 @@ pub async fn from_substrait_agg_func(
/// Convert Substrait Rex to DataFusion Expr
#[async_recursion]
pub async fn from_substrait_rex(
+ ctx: &SessionContext,
e: &Expression,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
@@ -755,13 +770,18 @@ pub async fn from_substrait_rex(
let substrait_list = s.options.as_ref();
Ok(Arc::new(Expr::InList(InList {
expr: Box::new(
- from_substrait_rex(substrait_expr, input_schema,
extensions)
+ from_substrait_rex(ctx, substrait_expr, input_schema,
extensions)
.await?
.as_ref()
.clone(),
),
- list: from_substrait_rex_vec(substrait_list, input_schema,
extensions)
- .await?,
+ list: from_substrait_rex_vec(
+ ctx,
+ substrait_list,
+ input_schema,
+ extensions,
+ )
+ .await?,
negated: false,
})))
}
@@ -779,6 +799,7 @@ pub async fn from_substrait_rex(
if if_expr.then.is_none() {
expr = Some(Box::new(
from_substrait_rex(
+ ctx,
if_expr.r#if.as_ref().unwrap(),
input_schema,
extensions,
@@ -793,6 +814,7 @@ pub async fn from_substrait_rex(
when_then_expr.push((
Box::new(
from_substrait_rex(
+ ctx,
if_expr.r#if.as_ref().unwrap(),
input_schema,
extensions,
@@ -803,6 +825,7 @@ pub async fn from_substrait_rex(
),
Box::new(
from_substrait_rex(
+ ctx,
if_expr.then.as_ref().unwrap(),
input_schema,
extensions,
@@ -816,7 +839,7 @@ pub async fn from_substrait_rex(
// Parse `else`
let else_expr = match &if_then.r#else {
Some(e) => Some(Box::new(
- from_substrait_rex(e, input_schema, extensions)
+ from_substrait_rex(ctx, e, input_schema, extensions)
.await?
.as_ref()
.clone(),
@@ -843,7 +866,7 @@ pub async fn from_substrait_rex(
for arg in &f.arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
- from_substrait_rex(e, input_schema,
extensions).await
+ from_substrait_rex(ctx, e, input_schema,
extensions).await
}
_ => not_impl_err!(
"Aggregated function argument non-Value type
not supported"
@@ -868,14 +891,14 @@ pub async fn from_substrait_rex(
(Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(
- from_substrait_rex(l, input_schema,
extensions)
+ from_substrait_rex(ctx, l, input_schema,
extensions)
.await?
.as_ref()
.clone(),
),
op,
right: Box::new(
- from_substrait_rex(r, input_schema,
extensions)
+ from_substrait_rex(ctx, r, input_schema,
extensions)
.await?
.as_ref()
.clone(),
@@ -888,7 +911,7 @@ pub async fn from_substrait_rex(
}
}
ScalarFunctionType::Expr(builder) => {
- builder.build(f, input_schema, extensions).await
+ builder.build(ctx, f, input_schema, extensions).await
}
}
}
@@ -900,6 +923,7 @@ pub async fn from_substrait_rex(
Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new(
Box::new(
from_substrait_rex(
+ ctx,
cast.as_ref().input.as_ref().unwrap().as_ref(),
input_schema,
extensions,
@@ -921,7 +945,8 @@ pub async fn from_substrait_rex(
),
};
let order_by =
- from_substrait_sorts(&window.sorts, input_schema,
extensions).await?;
+ from_substrait_sorts(ctx, &window.sorts, input_schema,
extensions)
+ .await?;
// Substrait does not encode WindowFrameUnits so we're using a
simple logic to determine the units
// If there is no `ORDER BY`, then by default, the frame counts
each row from the lower up to upper boundary
// If there is `ORDER BY`, then by default, each frame is a range
starting from unbounded preceding to current row
@@ -934,12 +959,14 @@ pub async fn from_substrait_rex(
Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
fun: fun?.unwrap(),
args: from_substriat_func_args(
+ ctx,
&window.arguments,
input_schema,
extensions,
)
.await?,
partition_by: from_substrait_rex_vec(
+ ctx,
&window.partitions,
input_schema,
extensions,
@@ -953,6 +980,51 @@ pub async fn from_substrait_rex(
},
})))
}
+ Some(RexType::Subquery(subquery)) => match
&subquery.as_ref().subquery_type {
+ Some(subquery_type) => match subquery_type {
+ SubqueryType::InPredicate(in_predicate) => {
+ if in_predicate.needles.len() != 1 {
+ Err(DataFusionError::Substrait(
+ "InPredicate Subquery type must have exactly one
Needle expression"
+ .to_string(),
+ ))
+ } else {
+ let needle_expr = &in_predicate.needles[0];
+ let haystack_expr = &in_predicate.haystack;
+ if let Some(haystack_expr) = haystack_expr {
+ let haystack_expr =
+ from_substrait_rel(ctx, haystack_expr,
extensions)
+ .await?;
+ let outer_refs = haystack_expr.all_out_ref_exprs();
+ Ok(Arc::new(Expr::InSubquery(InSubquery {
+ expr: Box::new(
+ from_substrait_rex(
+ ctx,
+ needle_expr,
+ input_schema,
+ extensions,
+ )
+ .await?
+ .as_ref()
+ .clone(),
+ ),
+ subquery: Subquery {
+ subquery: Arc::new(haystack_expr),
+ outer_ref_columns: outer_refs,
+ },
+ negated: false,
+ })))
+ } else {
+ substrait_err!("InPredicate Subquery type must
have a Haystack expression")
+ }
+ }
+ }
+ _ => substrait_err!("Subquery type not implemented"),
+ },
+ None => {
+ substrait_err!("Subquery experssion without SubqueryType is
not allowed")
+ }
+ },
_ => not_impl_err!("unsupported rex_type"),
}
}
@@ -1312,16 +1384,22 @@ impl BuiltinExprBuilder {
pub async fn build(
self,
+ ctx: &SessionContext,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
match self.expr_name.as_str() {
- "like" => Self::build_like_expr(false, f, input_schema,
extensions).await,
- "ilike" => Self::build_like_expr(true, f, input_schema,
extensions).await,
+ "like" => {
+ Self::build_like_expr(ctx, false, f, input_schema,
extensions).await
+ }
+ "ilike" => {
+ Self::build_like_expr(ctx, true, f, input_schema,
extensions).await
+ }
"not" | "negative" | "is_null" | "is_not_null" | "is_true" |
"is_false"
| "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown"
=> {
- Self::build_unary_expr(&self.expr_name, f, input_schema,
extensions).await
+ Self::build_unary_expr(ctx, &self.expr_name, f, input_schema,
extensions)
+ .await
}
_ => {
not_impl_err!("Unsupported builtin expression: {}",
self.expr_name)
@@ -1330,6 +1408,7 @@ impl BuiltinExprBuilder {
}
async fn build_unary_expr(
+ ctx: &SessionContext,
fn_name: &str,
f: &ScalarFunction,
input_schema: &DFSchema,
@@ -1341,7 +1420,7 @@ impl BuiltinExprBuilder {
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type
else {
return substrait_err!("Invalid arguments type for {fn_name} expr");
};
- let arg = from_substrait_rex(expr_substrait, input_schema, extensions)
+ let arg = from_substrait_rex(ctx, expr_substrait, input_schema,
extensions)
.await?
.as_ref()
.clone();
@@ -1365,6 +1444,7 @@ impl BuiltinExprBuilder {
}
async fn build_like_expr(
+ ctx: &SessionContext,
case_insensitive: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
@@ -1378,22 +1458,23 @@ impl BuiltinExprBuilder {
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
- let expr = from_substrait_rex(expr_substrait, input_schema, extensions)
+ let expr = from_substrait_rex(ctx, expr_substrait, input_schema,
extensions)
.await?
.as_ref()
.clone();
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
- let pattern = from_substrait_rex(pattern_substrait, input_schema,
extensions)
- .await?
- .as_ref()
- .clone();
+ let pattern =
+ from_substrait_rex(ctx, pattern_substrait, input_schema,
extensions)
+ .await?
+ .as_ref()
+ .clone();
let Some(ArgType::Value(escape_char_substrait)) =
&f.arguments[2].arg_type else {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
let escape_char_expr =
- from_substrait_rex(escape_char_substrait, input_schema, extensions)
+ from_substrait_rex(ctx, escape_char_substrait, input_schema,
extensions)
.await?
.as_ref()
.clone();
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 50f8725442..926883251a 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -36,12 +36,13 @@ use datafusion::common::{substrait_err, DFSchemaRef};
use datafusion::logical_expr::aggregate_function;
use datafusion::logical_expr::expr::{
AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet,
InList,
- ScalarFunctionDefinition, Sort, WindowFunction,
+ InSubquery, ScalarFunctionDefinition, Sort, WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan,
Operator};
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields};
+use substrait::proto::expression::subquery::InPredicate;
use substrait::proto::expression::window_function::BoundsType;
use substrait::proto::{CrossRel, ExchangeRel};
use substrait::{
@@ -58,7 +59,8 @@ use substrait::{
window_function::bound::Kind as BoundKind,
window_function::Bound,
FieldReference, IfThen, Literal, MaskExpression, ReferenceSegment,
RexType,
- ScalarFunction, SingularOrList, WindowFunction as
SubstraitWindowFunction,
+ ScalarFunction, SingularOrList, Subquery,
+ WindowFunction as SubstraitWindowFunction,
},
extensions::{
self,
@@ -167,7 +169,7 @@ pub fn to_substrait_rel(
let expressions = p
.expr
.iter()
- .map(|e| to_substrait_rex(e, p.input.schema(), 0,
extension_info))
+ .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0,
extension_info))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
@@ -181,6 +183,7 @@ pub fn to_substrait_rel(
LogicalPlan::Filter(filter) => {
let input = to_substrait_rel(filter.input.as_ref(), ctx,
extension_info)?;
let filter_expr = to_substrait_rex(
+ ctx,
&filter.predicate,
filter.input.schema(),
0,
@@ -214,7 +217,9 @@ pub fn to_substrait_rel(
let sort_fields = sort
.expr
.iter()
- .map(|e| substrait_sort_field(e, sort.input.schema(),
extension_info))
+ .map(|e| {
+ substrait_sort_field(ctx, e, sort.input.schema(),
extension_info)
+ })
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Sort(Box::new(SortRel {
@@ -228,6 +233,7 @@ pub fn to_substrait_rel(
LogicalPlan::Aggregate(agg) => {
let input = to_substrait_rel(agg.input.as_ref(), ctx,
extension_info)?;
let groupings = to_substrait_groupings(
+ ctx,
&agg.group_expr,
agg.input.schema(),
extension_info,
@@ -235,7 +241,9 @@ pub fn to_substrait_rel(
let measures = agg
.aggr_expr
.iter()
- .map(|e| to_substrait_agg_measure(e, agg.input.schema(),
extension_info))
+ .map(|e| {
+ to_substrait_agg_measure(ctx, e, agg.input.schema(),
extension_info)
+ })
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
@@ -283,6 +291,7 @@ pub fn to_substrait_rel(
let in_join_schema = join.left.schema().join(join.right.schema())?;
let join_filter = match &join.filter {
Some(filter) => Some(to_substrait_rex(
+ ctx,
filter,
&Arc::new(in_join_schema),
0,
@@ -299,6 +308,7 @@ pub fn to_substrait_rel(
Operator::Eq
};
let join_on = to_substrait_join_expr(
+ ctx,
&join.on,
eq_op,
join.left.schema(),
@@ -401,6 +411,7 @@ pub fn to_substrait_rel(
let mut window_exprs = vec![];
for expr in &window.window_expr {
window_exprs.push(to_substrait_rex(
+ ctx,
expr,
window.input.schema(),
0,
@@ -500,6 +511,7 @@ pub fn to_substrait_rel(
}
fn to_substrait_join_expr(
+ ctx: &SessionContext,
join_conditions: &Vec<(Expr, Expr)>,
eq_op: Operator,
left_schema: &DFSchemaRef,
@@ -513,9 +525,10 @@ fn to_substrait_join_expr(
let mut exprs: Vec<Expression> = vec![];
for (left, right) in join_conditions {
// Parse left
- let l = to_substrait_rex(left, left_schema, 0, extension_info)?;
+ let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?;
// Parse right
let r = to_substrait_rex(
+ ctx,
right,
right_schema,
left_schema.fields().len(), // offset to return the correct index
@@ -576,6 +589,7 @@ pub fn operator_to_name(op: Operator) -> &'static str {
}
pub fn parse_flat_grouping_exprs(
+ ctx: &SessionContext,
exprs: &[Expr],
schema: &DFSchemaRef,
extension_info: &mut (
@@ -585,7 +599,7 @@ pub fn parse_flat_grouping_exprs(
) -> Result<Grouping> {
let grouping_expressions = exprs
.iter()
- .map(|e| to_substrait_rex(e, schema, 0, extension_info))
+ .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info))
.collect::<Result<Vec<_>>>()?;
Ok(Grouping {
grouping_expressions,
@@ -593,6 +607,7 @@ pub fn parse_flat_grouping_exprs(
}
pub fn to_substrait_groupings(
+ ctx: &SessionContext,
exprs: &Vec<Expr>,
schema: &DFSchemaRef,
extension_info: &mut (
@@ -608,7 +623,9 @@ pub fn to_substrait_groupings(
)),
GroupingSet::GroupingSets(sets) => Ok(sets
.iter()
- .map(|set| parse_flat_grouping_exprs(set, schema,
extension_info))
+ .map(|set| {
+ parse_flat_grouping_exprs(ctx, set, schema,
extension_info)
+ })
.collect::<Result<Vec<_>>>()?),
GroupingSet::Rollup(set) => {
let mut sets: Vec<Vec<Expr>> = vec![vec![]];
@@ -618,17 +635,21 @@ pub fn to_substrait_groupings(
Ok(sets
.iter()
.rev()
- .map(|set| parse_flat_grouping_exprs(set, schema,
extension_info))
+ .map(|set| {
+ parse_flat_grouping_exprs(ctx, set, schema,
extension_info)
+ })
.collect::<Result<Vec<_>>>()?)
}
},
_ => Ok(vec![parse_flat_grouping_exprs(
+ ctx,
exprs,
schema,
extension_info,
)?]),
},
_ => Ok(vec![parse_flat_grouping_exprs(
+ ctx,
exprs,
schema,
extension_info,
@@ -638,6 +659,7 @@ pub fn to_substrait_groupings(
#[allow(deprecated)]
pub fn to_substrait_agg_measure(
+ ctx: &SessionContext,
expr: &Expr,
schema: &DFSchemaRef,
extension_info: &mut (
@@ -650,13 +672,13 @@ pub fn to_substrait_agg_measure(
match func_def {
AggregateFunctionDefinition::BuiltIn (fun) => {
let sorts = if let Some(order_by) = order_by {
- order_by.iter().map(|expr|
to_substrait_sort_field(expr, schema,
extension_info)).collect::<Result<Vec<_>>>()?
+ order_by.iter().map(|expr|
to_substrait_sort_field(ctx, expr, schema,
extension_info)).collect::<Result<Vec<_>>>()?
} else {
vec![]
};
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
- arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+ arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
}
let function_anchor = _register_function(fun.to_string(),
extension_info);
Ok(Measure {
@@ -674,20 +696,20 @@ pub fn to_substrait_agg_measure(
options: vec![],
}),
filter: match filter {
- Some(f) => Some(to_substrait_rex(f, schema, 0,
extension_info)?),
+ Some(f) => Some(to_substrait_rex(ctx, f, schema,
0, extension_info)?),
None => None
}
})
}
AggregateFunctionDefinition::UDF(fun) => {
let sorts = if let Some(order_by) = order_by {
- order_by.iter().map(|expr|
to_substrait_sort_field(expr, schema,
extension_info)).collect::<Result<Vec<_>>>()?
+ order_by.iter().map(|expr|
to_substrait_sort_field(ctx, expr, schema,
extension_info)).collect::<Result<Vec<_>>>()?
} else {
vec![]
};
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
- arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
+ arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) });
}
let function_anchor =
_register_function(fun.name().to_string(), extension_info);
Ok(Measure {
@@ -702,7 +724,7 @@ pub fn to_substrait_agg_measure(
options: vec![],
}),
filter: match filter {
- Some(f) => Some(to_substrait_rex(f, schema, 0,
extension_info)?),
+ Some(f) => Some(to_substrait_rex(ctx, f, schema,
0, extension_info)?),
None => None
}
})
@@ -714,7 +736,7 @@ pub fn to_substrait_agg_measure(
}
Expr::Alias(Alias{expr,..})=> {
- to_substrait_agg_measure(expr, schema, extension_info)
+ to_substrait_agg_measure(ctx, expr, schema, extension_info)
}
_ => internal_err!(
"Expression must be compatible with aggregation. Unsupported
expression: {:?}. ExpressionType: {:?}",
@@ -726,6 +748,7 @@ pub fn to_substrait_agg_measure(
/// Converts sort expression to corresponding substrait `SortField`
fn to_substrait_sort_field(
+ ctx: &SessionContext,
expr: &Expr,
schema: &DFSchemaRef,
extension_info: &mut (
@@ -743,6 +766,7 @@ fn to_substrait_sort_field(
};
Ok(SortField {
expr: Some(to_substrait_rex(
+ ctx,
sort.expr.deref(),
schema,
0,
@@ -851,6 +875,7 @@ pub fn make_binary_op_scalar_func(
/// * `extension_info` - Substrait extension info. Contains registered
function information
#[allow(deprecated)]
pub fn to_substrait_rex(
+ ctx: &SessionContext,
expr: &Expr,
schema: &DFSchemaRef,
col_ref_offset: usize,
@@ -867,10 +892,10 @@ pub fn to_substrait_rex(
}) => {
let substrait_list = list
.iter()
- .map(|x| to_substrait_rex(x, schema, col_ref_offset,
extension_info))
+ .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset,
extension_info))
.collect::<Result<Vec<Expression>>>()?;
let substrait_expr =
- to_substrait_rex(expr, schema, col_ref_offset,
extension_info)?;
+ to_substrait_rex(ctx, expr, schema, col_ref_offset,
extension_info)?;
let substrait_or_list = Expression {
rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList
{
@@ -903,6 +928,7 @@ pub fn to_substrait_rex(
for arg in &fun.args {
arguments.push(FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
+ ctx,
arg,
schema,
col_ref_offset,
@@ -937,11 +963,11 @@ pub fn to_substrait_rex(
if *negated {
// `expr NOT BETWEEN low AND high` can be translated into
(expr < low OR high < expr)
let substrait_expr =
- to_substrait_rex(expr, schema, col_ref_offset,
extension_info)?;
+ to_substrait_rex(ctx, expr, schema, col_ref_offset,
extension_info)?;
let substrait_low =
- to_substrait_rex(low, schema, col_ref_offset,
extension_info)?;
+ to_substrait_rex(ctx, low, schema, col_ref_offset,
extension_info)?;
let substrait_high =
- to_substrait_rex(high, schema, col_ref_offset,
extension_info)?;
+ to_substrait_rex(ctx, high, schema, col_ref_offset,
extension_info)?;
let l_expr = make_binary_op_scalar_func(
&substrait_expr,
@@ -965,11 +991,11 @@ pub fn to_substrait_rex(
} else {
// `expr BETWEEN low AND high` can be translated into (low <=
expr AND expr <= high)
let substrait_expr =
- to_substrait_rex(expr, schema, col_ref_offset,
extension_info)?;
+ to_substrait_rex(ctx, expr, schema, col_ref_offset,
extension_info)?;
let substrait_low =
- to_substrait_rex(low, schema, col_ref_offset,
extension_info)?;
+ to_substrait_rex(ctx, low, schema, col_ref_offset,
extension_info)?;
let substrait_high =
- to_substrait_rex(high, schema, col_ref_offset,
extension_info)?;
+ to_substrait_rex(ctx, high, schema, col_ref_offset,
extension_info)?;
let l_expr = make_binary_op_scalar_func(
&substrait_low,
@@ -997,8 +1023,8 @@ pub fn to_substrait_rex(
substrait_field_ref(index + col_ref_offset)
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- let l = to_substrait_rex(left, schema, col_ref_offset,
extension_info)?;
- let r = to_substrait_rex(right, schema, col_ref_offset,
extension_info)?;
+ let l = to_substrait_rex(ctx, left, schema, col_ref_offset,
extension_info)?;
+ let r = to_substrait_rex(ctx, right, schema, col_ref_offset,
extension_info)?;
Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
}
@@ -1013,6 +1039,7 @@ pub fn to_substrait_rex(
// Base expression exists
ifs.push(IfClause {
r#if: Some(to_substrait_rex(
+ ctx,
e,
schema,
col_ref_offset,
@@ -1025,12 +1052,14 @@ pub fn to_substrait_rex(
for (r#if, then) in when_then_expr {
ifs.push(IfClause {
r#if: Some(to_substrait_rex(
+ ctx,
r#if,
schema,
col_ref_offset,
extension_info,
)?),
then: Some(to_substrait_rex(
+ ctx,
then,
schema,
col_ref_offset,
@@ -1042,6 +1071,7 @@ pub fn to_substrait_rex(
// Parse outer `else`
let r#else: Option<Box<Expression>> = match else_expr {
Some(e) => Some(Box::new(to_substrait_rex(
+ ctx,
e,
schema,
col_ref_offset,
@@ -1060,6 +1090,7 @@ pub fn to_substrait_rex(
substrait::proto::expression::Cast {
r#type: Some(to_substrait_type(data_type)?),
input: Some(Box::new(to_substrait_rex(
+ ctx,
expr,
schema,
col_ref_offset,
@@ -1072,7 +1103,7 @@ pub fn to_substrait_rex(
}
Expr::Literal(value) => to_substrait_literal(value),
Expr::Alias(Alias { expr, .. }) => {
- to_substrait_rex(expr, schema, col_ref_offset, extension_info)
+ to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)
}
Expr::WindowFunction(WindowFunction {
fun,
@@ -1088,6 +1119,7 @@ pub fn to_substrait_rex(
for arg in args {
arguments.push(FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
+ ctx,
arg,
schema,
col_ref_offset,
@@ -1098,12 +1130,12 @@ pub fn to_substrait_rex(
// partition by expressions
let partition_by = partition_by
.iter()
- .map(|e| to_substrait_rex(e, schema, col_ref_offset,
extension_info))
+ .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset,
extension_info))
.collect::<Result<Vec<_>>>()?;
// order by expressions
let order_by = order_by
.iter()
- .map(|e| substrait_sort_field(e, schema, extension_info))
+ .map(|e| substrait_sort_field(ctx, e, schema, extension_info))
.collect::<Result<Vec<_>>>()?;
// window frame
let bounds = to_substrait_bounds(window_frame)?;
@@ -1124,6 +1156,7 @@ pub fn to_substrait_rex(
escape_char,
case_insensitive,
}) => make_substrait_like_expr(
+ ctx,
*case_insensitive,
*negated,
expr,
@@ -1133,7 +1166,50 @@ pub fn to_substrait_rex(
col_ref_offset,
extension_info,
),
+ Expr::InSubquery(InSubquery {
+ expr,
+ subquery,
+ negated,
+ }) => {
+ let substrait_expr =
+ to_substrait_rex(ctx, expr, schema, col_ref_offset,
extension_info)?;
+
+ let subquery_plan =
+ to_substrait_rel(subquery.subquery.as_ref(), ctx,
extension_info)?;
+
+ let substrait_subquery = Expression {
+ rex_type: Some(RexType::Subquery(Box::new(Subquery {
+ subquery_type: Some(
+
substrait::proto::expression::subquery::SubqueryType::InPredicate(
+ Box::new(InPredicate {
+ needles: (vec![substrait_expr]),
+ haystack: Some(subquery_plan),
+ }),
+ ),
+ ),
+ }))),
+ };
+ if *negated {
+ let function_anchor =
+ _register_function("not".to_string(), extension_info);
+
+ Ok(Expression {
+ rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+ function_reference: function_anchor,
+ arguments: vec![FunctionArgument {
+ arg_type: Some(ArgType::Value(substrait_subquery)),
+ }],
+ output_type: None,
+ args: vec![],
+ options: vec![],
+ })),
+ })
+ } else {
+ Ok(substrait_subquery)
+ }
+ }
Expr::Not(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"not",
arg,
schema,
@@ -1141,6 +1217,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsNull(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_null",
arg,
schema,
@@ -1148,6 +1225,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_not_null",
arg,
schema,
@@ -1155,6 +1233,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsTrue(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_true",
arg,
schema,
@@ -1162,6 +1241,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsFalse(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_false",
arg,
schema,
@@ -1169,6 +1249,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_unknown",
arg,
schema,
@@ -1176,6 +1257,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_not_true",
arg,
schema,
@@ -1183,6 +1265,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_not_false",
arg,
schema,
@@ -1190,6 +1273,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"is_not_unknown",
arg,
schema,
@@ -1197,6 +1281,7 @@ pub fn to_substrait_rex(
extension_info,
),
Expr::Negative(arg) => to_substrait_unary_scalar_fn(
+ ctx,
"negative",
arg,
schema,
@@ -1421,6 +1506,7 @@ fn make_substrait_window_function(
#[allow(deprecated)]
#[allow(clippy::too_many_arguments)]
fn make_substrait_like_expr(
+ ctx: &SessionContext,
ignore_case: bool,
negated: bool,
expr: &Expr,
@@ -1438,8 +1524,8 @@ fn make_substrait_like_expr(
} else {
_register_function("like".to_string(), extension_info)
};
- let expr = to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
- let pattern = to_substrait_rex(pattern, schema, col_ref_offset,
extension_info)?;
+ let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset,
extension_info)?;
+ let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset,
extension_info)?;
let escape_char =
to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c|
c.to_string())))?;
let arguments = vec![
@@ -1669,6 +1755,7 @@ fn to_substrait_literal(value: &ScalarValue) ->
Result<Expression> {
/// Util to generate substrait [RexType::ScalarFunction] with one argument
fn to_substrait_unary_scalar_fn(
+ ctx: &SessionContext,
fn_name: &str,
arg: &Expr,
schema: &DFSchemaRef,
@@ -1679,7 +1766,8 @@ fn to_substrait_unary_scalar_fn(
),
) -> Result<Expression> {
let function_anchor = _register_function(fn_name.to_string(),
extension_info);
- let substrait_expr = to_substrait_rex(arg, schema, col_ref_offset,
extension_info)?;
+ let substrait_expr =
+ to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?;
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
@@ -1880,6 +1968,7 @@ fn try_to_substrait_field_reference(
}
fn substrait_sort_field(
+ ctx: &SessionContext,
expr: &Expr,
schema: &DFSchemaRef,
extension_info: &mut (
@@ -1893,7 +1982,7 @@ fn substrait_sort_field(
asc,
nulls_first,
}) => {
- let e = to_substrait_rex(expr, schema, 0, extension_info)?;
+ let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?;
let d = match (asc, nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
(true, false) => SortDirection::AscNullsLast,
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index 47eb5a8f73..d7327caee4 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -394,6 +394,24 @@ async fn roundtrip_inlist_4() -> Result<()> {
roundtrip("SELECT * FROM data WHERE f NOT IN ('a', 'b', 'c', 'd')").await
}
+#[tokio::test]
+async fn roundtrip_inlist_5() -> Result<()> {
+ // on roundtrip there is an additional projection during TableScan which
includes all column of the table,
+ // using assert_expected_plan here as a workaround
+ assert_expected_plan(
+ "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a
FROM data2 WHERE f IN ('b', 'c', 'd')))",
+ "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f =
Utf8(\"c\") OR data.a IN (<subquery>)\
+ \n Subquery:\
+ \n Projection: data2.a\
+ \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
+ \n TableScan: data2 projection=[a, b, c, d, e, f]\
+ \n TableScan: data projection=[a, f], partial_filters=[data.f =
Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN
(<subquery>)]\
+ \n Subquery:\
+ \n Projection: data2.a\
+ \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\
+ \n TableScan: data2 projection=[a, b, c, d, e, f]").await
+}
+
#[tokio::test]
async fn roundtrip_cross_join() -> Result<()> {
roundtrip("SELECT * FROM data CROSS JOIN data2").await