vbarua commented on code in PR #13931: URL: https://github.com/apache/datafusion/pull/13931#discussion_r1899715775
########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -998,450 +1304,418 @@ pub fn make_binary_op_scalar_func( /// Convert DataFusion Expr to Substrait Rex /// /// # Arguments -/// -/// * `expr` - DataFusion expression to be parse into a Substrait expression -/// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. -/// This should only be set by caller with more than one input relations i.e. Join. -/// Substrait expects one set of indices when joining two relations. -/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` -/// relation will have column indices from `0` to `n-1`, however, Substrait will expect -/// the `right` indices to be offset by the `left`. This means Substrait will expect to -/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: -/// ```SELECT * -/// FROM t1 -/// JOIN t2 -/// ON t1.c1 = t2.c0;``` -/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] -/// the join condition should become -/// `col_ref(1) = col_ref(3 + 0)` -/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index -/// of the join key column from `right` -/// * `extensions` - Substrait extension info. 
Contains registered function information -#[allow(deprecated)] +/// * `producer` - SubstraitProducer implementation which the handles the actual conversion +/// * `expr` - DataFusion expression to convert into a Substrait expression +/// * `schema` - DataFusion input schema for looking up columns pub fn to_substrait_rex( - state: &dyn SubstraitPlanningState, + producer: &mut impl SubstraitProducer, expr: &Expr, schema: &DFSchemaRef, - col_ref_offset: usize, - extensions: &mut Extensions, ) -> Result<Expression> { match expr { - Expr::InList(InList { - expr, - list, - negated, - }) => { - let substrait_list = list - .iter() - .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) - .collect::<Result<Vec<Expression>>>()?; - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let substrait_or_list = Expression { - rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { - value: Some(Box::new(substrait_expr)), - options: substrait_list, - }))), - }; - - if *negated { - let function_anchor = extensions.register_function("not".to_string()); - - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments: vec![FunctionArgument { - arg_type: Some(ArgType::Value(substrait_or_list)), - }], - output_type: None, - args: vec![], - options: vec![], - })), - }) - } else { - Ok(substrait_or_list) - } + Expr::Alias(expr) => producer.consume_alias(expr, schema), + Expr::Column(expr) => producer.consume_column(expr, schema), + Expr::Literal(expr) => producer.consume_literal(expr), + Expr::BinaryExpr(expr) => producer.consume_binary_expr(expr, schema), + Expr::Like(expr) => producer.consume_like(expr, schema), + Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"), + Expr::Not(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNull(_) => 
producer.consume_unary_expr(expr, schema), + Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema), + Expr::IsNotUnknown(_) => producer.consume_unary_expr(expr, schema), + Expr::Negative(_) => producer.consume_unary_expr(expr, schema), + Expr::Between(expr) => producer.consume_between(expr, schema), + Expr::Case(expr) => producer.consume_case(expr, schema), + Expr::Cast(expr) => producer.consume_cast(expr, schema), + Expr::TryCast(expr) => producer.consume_try_cast(expr, schema), + Expr::ScalarFunction(expr) => producer.consume_scalar_function(expr, schema), + Expr::AggregateFunction(_) => { + internal_err!( + "AggregateFunction should only be encountered as part of a LogicalPlan::Aggregate" + ) } - Expr::ScalarFunction(fun) => { - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in &fun.args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } + Expr::WindowFunction(expr) => producer.consume_window_function(expr, schema), + Expr::InList(expr) => producer.consume_in_list(expr, schema), + Expr::InSubquery(expr) => producer.consume_in_subquery(expr, schema), + _ => not_impl_err!("Cannot convert {expr:?} to Substrait"), + } +} - let function_anchor = extensions.register_function(fun.name().to_string()); - Ok(Expression { - rex_type: Some(RexType::ScalarFunction(ScalarFunction { - function_reference: function_anchor, - arguments, - output_type: None, - args: vec![], - options: vec![], - })), - }) - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - if *negated { - // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let 
substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_low, - Operator::Lt, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_high, - &substrait_expr, - Operator::Lt, - extensions, - ); +pub fn from_in_list( + producer: &mut impl SubstraitProducer, + in_list: &InList, + schema: &DFSchemaRef, +) -> Result<Expression> { + let InList { + expr, + list, + negated, + } = in_list; + let substrait_list = list + .iter() + .map(|x| producer.consume_expr(x, schema)) + .collect::<Result<Vec<Expression>>>()?; + let substrait_expr = producer.consume_expr(expr, schema)?; + + let substrait_or_list = Expression { + rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { + value: Some(Box::new(substrait_expr)), + options: substrait_list, + }))), + }; - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::Or, - extensions, - )) - } else { - // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - let substrait_low = - to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; - let substrait_high = - to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; - - let l_expr = make_binary_op_scalar_func( - &substrait_low, - &substrait_expr, - Operator::LtEq, - extensions, - ); - let r_expr = make_binary_op_scalar_func( - &substrait_expr, - &substrait_high, - Operator::LtEq, - extensions, - ); + if *negated { + let function_anchor = producer.register_function("not".to_string()); - Ok(make_binary_op_scalar_func( - &l_expr, - &r_expr, - Operator::And, - extensions, - )) - } - } - Expr::Column(col) => { - let index = 
schema.index_of_column(col)?; - substrait_field_ref(index + col_ref_offset) - } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; + #[allow(deprecated)] + Ok(Expression { + rex_type: Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(substrait_or_list)), + }], + output_type: None, + args: vec![], + options: vec![], + })), + }) + } else { + Ok(substrait_or_list) + } +} - Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) - } - Expr::Case(Case { - expr, - when_then_expr, - else_expr, - }) => { - let mut ifs: Vec<IfClause> = vec![]; - // Parse base - if let Some(e) = expr { - // Base expression exists - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?), - then: None, - }); - } - // Parse `when`s - for (r#if, then) in when_then_expr { - ifs.push(IfClause { - r#if: Some(to_substrait_rex( - state, - r#if, - schema, - col_ref_offset, - extensions, - )?), - then: Some(to_substrait_rex( - state, - then, - schema, - col_ref_offset, - extensions, - )?), - }); - } +pub fn from_scalar_function( + producer: &mut impl SubstraitProducer, + fun: &expr::ScalarFunction, + schema: &DFSchemaRef, +) -> Result<Expression> { + let mut arguments: Vec<FunctionArgument> = vec![]; + for arg in &fun.args { + arguments.push(FunctionArgument { + arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, schema)?)), + }); + } - // Parse outer `else` - let r#else: Option<Box<Expression>> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex( - state, - e, - schema, - col_ref_offset, - extensions, - )?)), - None => None, - }; + let function_anchor = producer.register_function(fun.name().to_string()); + #[allow(deprecated)] + Ok(Expression { + rex_type: 
Some(RexType::ScalarFunction(ScalarFunction { + function_reference: function_anchor, + arguments, + output_type: None, + options: vec![], + args: vec![], + })), + }) +} - Ok(Expression { - rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))), - }) - } - Expr::Cast(Cast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ThrowException.into(), - }, - ))), - }), - Expr::TryCast(TryCast { expr, data_type }) => Ok(Expression { - rex_type: Some(RexType::Cast(Box::new( - substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type, true)?), - input: Some(Box::new(to_substrait_rex( - state, - expr, - schema, - col_ref_offset, - extensions, - )?)), - failure_behavior: FailureBehavior::ReturnNull.into(), - }, - ))), - }), - Expr::Literal(value) => to_substrait_literal_expr(value, extensions), - Expr::Alias(Alias { expr, .. 
}) => { - to_substrait_rex(state, expr, schema, col_ref_offset, extensions) - } - Expr::WindowFunction(WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: _, - }) => { - // function reference - let function_anchor = extensions.register_function(fun.to_string()); - // arguments - let mut arguments: Vec<FunctionArgument> = vec![]; - for arg in args { - arguments.push(FunctionArgument { - arg_type: Some(ArgType::Value(to_substrait_rex( - state, - arg, - schema, - col_ref_offset, - extensions, - )?)), - }); - } - // partition by expressions - let partition_by = partition_by - .iter() - .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) - .collect::<Result<Vec<_>>>()?; - // order by expressions - let order_by = order_by - .iter() - .map(|e| substrait_sort_field(state, e, schema, extensions)) - .collect::<Result<Vec<_>>>()?; - // window frame - let bounds = to_substrait_bounds(window_frame)?; - let bound_type = to_substrait_bound_type(window_frame)?; - Ok(make_substrait_window_function( - function_anchor, - arguments, - partition_by, - order_by, - bounds, - bound_type, - )) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => make_substrait_like_expr( - state, - *case_insensitive, - *negated, - expr, - pattern, - *escape_char, - schema, - col_ref_offset, - extensions, - ), - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let substrait_expr = - to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; - - let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; - - let substrait_subquery = Expression { - rex_type: Some(RexType::Subquery(Box::new(Subquery { - subquery_type: Some( - substrait::proto::expression::subquery::SubqueryType::InPredicate( - Box::new(InPredicate { - needles: (vec![substrait_expr]), - haystack: Some(subquery_plan), - }), - ), +pub fn from_between( + producer: &mut impl 
SubstraitProducer, + between: &Between, + schema: &DFSchemaRef, +) -> Result<Expression> { + let Between { + expr, + negated, + low, + high, + } = between; + if *negated { + // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) + let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?; + let substrait_low = producer.consume_expr(low.as_ref(), schema)?; + let substrait_high = producer.consume_expr(high.as_ref(), schema)?; Review Comment: > probs better not to change now to keep diff small(er), I agree. Do small code improvements also require a full issue to be linked to them? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org