This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 0cf563062 Fix output schema generated by CommonSubExprEliminate (#3726)
0cf563062 is described below
commit 0cf5630626ba7e3d814ef00477bcae2bc3cae9c2
Author: Alexander Spies <[email protected]>
AuthorDate: Tue Oct 11 14:14:45 2022 +0200
Fix output schema generated by CommonSubExprEliminate (#3726)
* CommonSubexprEliminate: Fix additional col schema
* Use correct types in test id_array_visitor
* Re-enable fallback schema for datatype resolution
Fall back to the merged schema from the whole logical plan if the input
schema was not sufficient to resolve the datatype of a sub-expression.
This re-enables the fallback logic added in 3860cd3 (#1925).
* Add comment on fallback logic using all schemas
Point out that it can likely be removed.
---
.../optimizer/src/common_subexpr_eliminate.rs | 183 +++++++++++++++++----
1 file changed, 149 insertions(+), 34 deletions(-)
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 552c03d3d..cea5e8c46 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -19,7 +19,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
-use datafusion_common::{DFField, DFSchema, DataFusionError, Result};
+use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError,
Result};
use datafusion_expr::{
col,
expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion},
@@ -94,7 +94,10 @@ fn optimize(
schema,
alias,
}) => {
- let arrays = to_arrays(expr, input, &mut expr_set)?;
+ let input_schema = Arc::clone(input.schema());
+ let all_schemas: Vec<DFSchemaRef> =
+ plan.all_schemas().into_iter().cloned().collect();
+ let arrays = to_arrays(expr, input_schema, all_schemas, &mut
expr_set)?;
let (mut new_expr, new_input) = rewrite_expr(
&[expr],
@@ -112,22 +115,18 @@ fn optimize(
)?))
}
LogicalPlan::Filter(Filter { predicate, input }) => {
- let schema = plan.schema().as_ref().clone();
- let data_type = if let Ok(data_type) = predicate.get_type(&schema)
{
- data_type
- } else {
- // predicate type could not be resolved in schema, fall back
to all schemas
- let schemas = plan.all_schemas();
- let all_schema =
- schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs|
{
- lhs.merge(rhs);
- lhs
- });
- predicate.get_type(&all_schema)?
- };
+ let input_schema = Arc::clone(input.schema());
+ let all_schemas: Vec<DFSchemaRef> =
+ plan.all_schemas().into_iter().cloned().collect();
let mut id_array = vec![];
- expr_to_identifier(predicate, &mut expr_set, &mut id_array,
data_type)?;
+ expr_to_identifier(
+ predicate,
+ &mut expr_set,
+ &mut id_array,
+ input_schema,
+ all_schemas,
+ )?;
let (mut new_expr, new_input) = rewrite_expr(
&[&[predicate.clone()]],
@@ -153,7 +152,11 @@ fn optimize(
window_expr,
schema,
}) => {
- let arrays = to_arrays(window_expr, input, &mut expr_set)?;
+ let input_schema = Arc::clone(input.schema());
+ let all_schemas: Vec<DFSchemaRef> =
+ plan.all_schemas().into_iter().cloned().collect();
+ let arrays =
+ to_arrays(window_expr, input_schema, all_schemas, &mut
expr_set)?;
let (mut new_expr, new_input) = rewrite_expr(
&[window_expr],
@@ -175,8 +178,17 @@ fn optimize(
input,
schema,
}) => {
- let group_arrays = to_arrays(group_expr, input, &mut expr_set)?;
- let aggr_arrays = to_arrays(aggr_expr, input, &mut expr_set)?;
+ let input_schema = Arc::clone(input.schema());
+ let all_schemas: Vec<DFSchemaRef> =
+ plan.all_schemas().into_iter().cloned().collect();
+ let group_arrays = to_arrays(
+ group_expr,
+ Arc::clone(&input_schema),
+ all_schemas.clone(),
+ &mut expr_set,
+ )?;
+ let aggr_arrays =
+ to_arrays(aggr_expr, input_schema, all_schemas, &mut
expr_set)?;
let (mut new_expr, new_input) = rewrite_expr(
&[group_expr, aggr_expr],
@@ -197,7 +209,10 @@ fn optimize(
)?))
}
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
- let arrays = to_arrays(expr, input, &mut expr_set)?;
+ let input_schema = Arc::clone(input.schema());
+ let all_schemas: Vec<DFSchemaRef> =
+ plan.all_schemas().into_iter().cloned().collect();
+ let arrays = to_arrays(expr, input_schema, all_schemas, &mut
expr_set)?;
let (mut new_expr, new_input) = rewrite_expr(
&[expr],
@@ -255,14 +270,20 @@ fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) ->
Result<Vec<Expr>> {
fn to_arrays(
expr: &[Expr],
- input: &LogicalPlan,
+ input_schema: DFSchemaRef,
+ all_schemas: Vec<DFSchemaRef>,
expr_set: &mut ExprSet,
) -> Result<Vec<Vec<(usize, String)>>> {
expr.iter()
.map(|e| {
- let data_type = e.get_type(input.schema())?;
let mut id_array = vec![];
- expr_to_identifier(e, expr_set, &mut id_array, data_type)?;
+ expr_to_identifier(
+ e,
+ expr_set,
+ &mut id_array,
+ Arc::clone(&input_schema),
+ all_schemas.clone(),
+ )?;
Ok(id_array)
})
@@ -370,7 +391,15 @@ struct ExprIdentifierVisitor<'a> {
expr_set: &'a mut ExprSet,
/// series number (usize) and identifier.
id_array: &'a mut Vec<(usize, Identifier)>,
- data_type: DataType,
+ /// input schema for the node that we're optimizing, so we can determine
the correct datatype
+ /// for each subexpression
+ input_schema: DFSchemaRef,
+ /// all schemas in the logical plan, as a fall back if we cannot resolve
an expression type
+ /// from the input schema alone
+ // This fallback should never be necessary as the expression datatype
should always be
+ // resolvable from the input schema of the node that's being optimized.
+ // todo: This can likely be removed if we are sure it's safe to do so.
+ all_schemas: Vec<DFSchemaRef>,
// inner states
visit_stack: Vec<VisitRecord>,
@@ -448,7 +477,25 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
self.id_array[idx] = (self.series_number, desc.clone());
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
- let data_type = self.data_type.clone();
+
+ let data_type = if let Ok(data_type) =
expr.get_type(&self.input_schema) {
+ data_type
+ } else {
+ // Expression type could not be resolved in schema, fall back to
all schemas.
+ //
+ // This fallback should never be necessary as the expression
datatype should always be
+ // resolvable from the input schema of the node that's being
optimized.
+ // todo: This else-branch can likely be removed if we are sure
it's safe to do so.
+ let merged_schema =
+ self.all_schemas
+ .iter()
+ .fold(DFSchema::empty(), |mut lhs, rhs| {
+ lhs.merge(rhs);
+ lhs
+ });
+ expr.get_type(&merged_schema)?
+ };
+
self.expr_set
.entry(desc)
.or_insert_with(|| (expr.clone(), 0, data_type))
@@ -462,12 +509,14 @@ fn expr_to_identifier(
expr: &Expr,
expr_set: &mut ExprSet,
id_array: &mut Vec<(usize, Identifier)>,
- data_type: DataType,
+ input_schema: DFSchemaRef,
+ all_schemas: Vec<DFSchemaRef>,
) -> Result<()> {
expr.accept(ExprIdentifierVisitor {
expr_set,
id_array,
- data_type,
+ input_schema,
+ all_schemas,
visit_stack: vec![],
node_count: 0,
series_number: 0,
@@ -577,7 +626,8 @@ fn replace_common_expr(
mod test {
use super::*;
use crate::test::*;
- use datafusion_expr::logical_plan::JoinType;
+ use arrow::datatypes::{Field, Schema};
+ use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::{
avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder,
sum,
Operator,
@@ -597,7 +647,7 @@ mod test {
fn id_array_visitor() -> Result<()> {
let expr = binary_expr(
binary_expr(
- sum(binary_expr(col("a"), Operator::Plus, lit("1"))),
+ sum(binary_expr(col("a"), Operator::Plus, lit(1))),
Operator::Minus,
avg(col("c")),
),
@@ -605,14 +655,28 @@ mod test {
lit(2),
);
+ let schema = Arc::new(DFSchema::new_with_metadata(
+ vec![
+ DFField::new(None, "a", DataType::Int64, false),
+ DFField::new(None, "c", DataType::Int64, false),
+ ],
+ Default::default(),
+ )?);
+
let mut id_array = vec![];
- expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array,
DataType::Int64)?;
+ expr_to_identifier(
+ &expr,
+ &mut HashMap::new(),
+ &mut id_array,
+ Arc::clone(&schema),
+ vec![schema],
+ )?;
let expected = vec![
- (9, "SUM(a + Utf8(\"1\")) - AVG(c) * Int32(2)Int32(2)SUM(a +
Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"),
- (7, "SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a +
Utf8(\"1\")Utf8(\"1\")a"),
- (4, "SUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"),
- (3, "a + Utf8(\"1\")Utf8(\"1\")a"),
+ (9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a +
Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
+ (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a +
Int32(1)Int32(1)a"),
+ (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
+ (3, "a + Int32(1)Int32(1)a"),
(1, ""),
(2, ""),
(6, "AVG(c)c"),
@@ -796,4 +860,55 @@ mod test {
assert!(field_set.insert(field.qualified_name()));
}
}
+
+ #[test]
+ fn eliminated_subexpr_datatype() {
+ use datafusion_expr::cast;
+
+ let schema = Schema::new(vec![
+ Field::new("a", DataType::UInt64, false),
+ Field::new("b", DataType::UInt64, false),
+ Field::new("c", DataType::UInt64, false),
+ ]);
+
+ let plan = table_scan(Some("table"), &schema, None)
+ .unwrap()
+ .filter(
+ cast(col("a"), DataType::Int64)
+ .lt(lit(1_i64))
+ .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
+ )
+ .unwrap()
+ .build()
+ .unwrap();
+ let rule = CommonSubexprEliminate {};
+ let optimized_plan = rule.optimize(&plan, &mut
OptimizerConfig::new()).unwrap();
+
+ let schema = optimized_plan.schema();
+ let fields_with_datatypes: Vec<_> = schema
+ .fields()
+ .iter()
+ .map(|field| (field.name(), field.data_type()))
+ .collect();
+ let formatted_fields_with_datatype =
format!("{fields_with_datatypes:#?}");
+ let expected = r###"[
+ (
+ "CAST(table.a AS Int64)table.a",
+ Int64,
+ ),
+ (
+ "a",
+ UInt64,
+ ),
+ (
+ "b",
+ UInt64,
+ ),
+ (
+ "c",
+ UInt64,
+ ),
+]"###;
+ assert_eq!(expected, formatted_fields_with_datatype);
+ }
}