This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new e1d4069de Fix bug where optimizer was removing `Partitioning::DistributeBy` expressions (#3229)
e1d4069de is described below

commit e1d4069de9d2aaf85655edd72744f176431c03d3
Author: Andy Grove <[email protected]>
AuthorDate: Tue Aug 23 09:02:16 2022 -0600

    Fix bug where optimizer was removing `Partitioning::DistributeBy` expressions (#3229)
    
    * Add support for Partitioning::DistributeBy in expressions function
    
    * add integration test
---
 datafusion/expr/src/logical_plan/plan.rs       |   3 +-
 datafusion/optimizer/Cargo.toml                |   1 +
 datafusion/optimizer/tests/integration-test.rs | 149 +++++++++++++++++++++++++
 3 files changed, 152 insertions(+), 1 deletion(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 94b24e2ad..2d5eb4680 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -220,7 +220,8 @@ impl LogicalPlan {
                 ..
             }) => match partitioning_scheme {
                 Partitioning::Hash(expr, _) => expr.clone(),
-                _ => vec![],
+                Partitioning::DistributeBy(expr) => expr.clone(),
+                Partitioning::RoundRobinBatch(_) => vec![],
             },
             LogicalPlan::Window(Window { window_expr, .. }) => window_expr.clone(),
             LogicalPlan::Aggregate(Aggregate {
diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml
index b1a64384c..277f2d95f 100644
--- a/datafusion/optimizer/Cargo.toml
+++ b/datafusion/optimizer/Cargo.toml
@@ -48,5 +48,6 @@ log = "^0.4"
 
 [dev-dependencies]
 ctor = "0.1.22"
+datafusion-sql = { path = "../sql", version = "11.0.0" }
 env_logger = "0.9.0"
 
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
new file mode 100644
index 000000000..b9d4d3b63
--- /dev/null
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -0,0 +1,149 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion_common::{DataFusionError, Result};
+use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource};
+use datafusion_optimizer::common_subexpr_eliminate::CommonSubexprEliminate;
+use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery;
+use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
+use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
+use datafusion_optimizer::eliminate_filter::EliminateFilter;
+use datafusion_optimizer::eliminate_limit::EliminateLimit;
+use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
+use datafusion_optimizer::filter_push_down::FilterPushDown;
+use datafusion_optimizer::limit_push_down::LimitPushDown;
+use datafusion_optimizer::optimizer::Optimizer;
+use datafusion_optimizer::projection_push_down::ProjectionPushDown;
+use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin;
+use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
+use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
+use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
+use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
+use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
+use datafusion_sql::planner::{ContextProvider, SqlToRel};
+use datafusion_sql::sqlparser::ast::Statement;
+use datafusion_sql::sqlparser::dialect::GenericDialect;
+use datafusion_sql::sqlparser::parser::Parser;
+use datafusion_sql::TableReference;
+use std::any::Any;
+use std::collections::HashMap;
+use std::sync::Arc;
+
+#[test]
+fn distribute_by() -> Result<()> {
+    // regression test for https://github.com/apache/arrow-datafusion/issues/3234
+    let sql = "SELECT col_int32, col_utf8 FROM test DISTRIBUTE BY (col_utf8)";
+    let plan = test_sql(sql)?;
+    let expected = "Repartition: DistributeBy(#col_utf8)\
+    \n  Projection: #test.col_int32, #test.col_utf8\
+    \n    TableScan: test projection=[col_int32, col_utf8]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
+fn test_sql(sql: &str) -> Result<LogicalPlan> {
+    let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
+        // Simplify expressions first to maximize the chance
+        // of applying other optimizations
+        Arc::new(SimplifyExpressions::new()),
+        Arc::new(DecorrelateWhereExists::new()),
+        Arc::new(DecorrelateWhereIn::new()),
+        Arc::new(DecorrelateScalarSubquery::new()),
+        Arc::new(SubqueryFilterToJoin::new()),
+        Arc::new(EliminateFilter::new()),
+        Arc::new(CommonSubexprEliminate::new()),
+        Arc::new(EliminateLimit::new()),
+        Arc::new(ProjectionPushDown::new()),
+        Arc::new(RewriteDisjunctivePredicate::new()),
+        Arc::new(FilterNullJoinKeys::default()),
+        Arc::new(ReduceOuterJoin::new()),
+        Arc::new(FilterPushDown::new()),
+        Arc::new(LimitPushDown::new()),
+        Arc::new(SingleDistinctToGroupBy::new()),
+    ];
+
+    let optimizer = Optimizer::new(rules);
+
+    // parse the SQL
+    let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
+    let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
+    let statement = &ast[0];
+
+    // create a logical query plan
+    let schema_provider = MySchemaProvider {};
+    let sql_to_rel = SqlToRel::new(&schema_provider);
+    let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
+
+    // optimize the logical plan
+    let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
+    optimizer.optimize(&plan, &mut config, &observe)
+}
+
+struct MySchemaProvider {}
+
+impl ContextProvider for MySchemaProvider {
+    fn get_table_provider(
+        &self,
+        name: TableReference,
+    ) -> datafusion_common::Result<Arc<dyn TableSource>> {
+        let table_name = name.table();
+        if table_name.starts_with("test") {
+            let schema = Schema::new_with_metadata(
+                vec![
+                    Field::new("col_int32", DataType::Int32, true),
+                    Field::new("col_utf8", DataType::Utf8, true),
+                ],
+                HashMap::new(),
+            );
+
+            Ok(Arc::new(MyTableSource {
+                schema: Arc::new(schema),
+            }))
+        } else {
+            Err(DataFusionError::Plan("table does not exist".to_string()))
+        }
+    }
+
+    fn get_function_meta(&self, _name: &str) -> Option<Arc<ScalarUDF>> {
+        None
+    }
+
+    fn get_aggregate_meta(&self, _name: &str) -> Option<Arc<AggregateUDF>> {
+        None
+    }
+
+    fn get_variable_type(&self, _variable_names: &[String]) -> Option<DataType> {
+        None
+    }
+}
+
+fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}
+
+struct MyTableSource {
+    schema: SchemaRef,
+}
+
+impl TableSource for MyTableSource {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}

Reply via email to