This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 90e03e487 [fix][plan] relax the check for distinct, order by for dataframe (#5258)
90e03e487 is described below

commit 90e03e4870dacd0a526e9e856c77a471f82b2242
Author: xyz <[email protected]>
AuthorDate: Thu Feb 16 18:57:36 2023 +0800

    [fix][plan] relax the check for distinct, order by for dataframe (#5258)
    
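    In PR #5132, we added a check that every ORDER BY expression must appear in
    the DISTINCT expression list. This restriction may be overly eager for
    DataFrame users, who sometimes want to sort by other expressions for their
    own reasons.
    
    This PR relaxes the DISTINCT / ORDER BY check for the DataFrame API; the
    check is now triggered only by the SQL planner. Note that with the DataFrame
    API, sorting a distinct result by a column outside the projection can yield
    repeated values of the projected columns, as `test_distinct_sort_by` shows.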
    
    Signed-off-by: xyz <[email protected]>
---
 datafusion/core/src/dataframe.rs            | 42 +++++++++++++++++++++++
 datafusion/core/tests/dataframe.rs          | 34 +++++++++++-------
 datafusion/expr/src/logical_plan/builder.rs | 28 +++++----------
 datafusion/sql/src/query.rs                 | 53 ++++++++++++++++++++++++++---
 4 files changed, 121 insertions(+), 36 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 26fe5c051..36135bd1e 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -1072,6 +1072,48 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_distinct() -> Result<()> {
+        let t = test_table().await?;
+        let plan = t
+            .select(vec![col("c1")])
+            .unwrap()
+            .distinct()
+            .unwrap()
+            .plan
+            .clone();
+
+        let sql_plan = create_plan("select distinct c1 from aggregate_test_100").await?;
+
+        assert_same_plan(&plan, &sql_plan);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_distinct_sort_by() -> Result<()> {
+        let t = test_table().await?;
+        let plan = t
+            .select(vec![col("c1")])
+            .unwrap()
+            .distinct()
+            .unwrap()
+            .sort(vec![col("c2").sort(true, true)])
+            .unwrap();
+        let df_results = plan.clone().collect().await?;
+        assert_batches_sorted_eq!(
+            vec![
+                "+----+", "| c1 |", "+----+", "| a  |", "| a  |", "| a  |", "| a  |",
+                "| a  |", "| b  |", "| b  |", "| b  |", "| b  |", "| b  |", "| c  |",
+                "| c  |", "| c  |", "| c  |", "| c  |", "| d  |", "| d  |", "| d  |",
+                "| d  |", "| d  |", "| e  |", "| e  |", "| e  |", "| e  |", "| e  |",
+                "+----+",
+            ],
+            &df_results
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn join() -> Result<()> {
         let left = test_table().await?.select_columns(&["c1", "c2"])?;
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index 2259fc3ac..6ca42f55f 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -24,7 +24,6 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use datafusion::from_slice::FromSlice;
-use datafusion_common::DataFusionError;
 use std::sync::Arc;
 
 use datafusion::dataframe::DataFrame;
@@ -146,18 +145,29 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> {
 
     let ctx = SessionContext::new();
     ctx.register_batch("t", batch).unwrap();
+    let df = ctx
+        .table("t")
+        .await
+        .unwrap()
+        .select(vec![col("a")])
+        .unwrap()
+        .distinct()
+        .unwrap()
+        .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))])
+        .unwrap();
+    let results = df.collect().await.unwrap();
 
-    assert!(matches!(
-        ctx.table("t")
-            .await
-            .unwrap()
-            .select(vec![col("a")])
-            .unwrap()
-            .distinct()
-            .unwrap()
-            .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]),
-        Err(DataFusionError::Plan(_))
-    ));
+    #[rustfmt::skip]
+    let expected = vec![
+        "+-----+",
+        "| a   |",
+        "+-----+",
+        "| 100 |",
+        "| 10  |",
+        "| 1   |",
+        "+-----+",
+    ];
+    assert_batches_eq!(expected, &results);
     Ok(())
 }
 
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index e1092b96b..f979e1f76 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -406,27 +406,15 @@ impl LogicalPlanBuilder {
                 Ok(())
             })?;
 
-        // if current plan is distinct or current plan is repartition and its child plan is distinct,
-        // then this plan is a select distinct plan
-        let is_select_distinct = match self.plan {
-            LogicalPlan::Distinct(_) => true,
-            LogicalPlan::Repartition(Repartition { ref input, .. }) => {
-                matches!(input.as_ref(), &LogicalPlan::Distinct(_))
-            }
-            _ => false,
-        };
+        self.create_sort_plan(exprs, missing_cols)
+    }
 
-        // for select distinct, order by expressions must exist in select list
-        if is_select_distinct && !missing_cols.is_empty() {
-            let missing_col_names = missing_cols
-                .iter()
-                .map(|col| col.flat_name())
-                .collect::<String>();
-            let error_msg = format!(
-                "For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list",
-            );
-            return Err(DataFusionError::Plan(error_msg));
-        }
+    pub fn create_sort_plan(
+        self,
+        exprs: impl IntoIterator<Item = impl Into<Expr>> + Clone,
+        missing_cols: Vec<Column>,
+    ) -> Result<Self> {
+        let schema = self.plan.schema();
 
         if missing_cols.is_empty() {
             return Ok(Self::from(LogicalPlan::Sort(Sort {
diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs
index eb7ece87d..c59c42e93 100644
--- a/datafusion/sql/src/query.rs
+++ b/datafusion/sql/src/query.rs
@@ -17,8 +17,9 @@
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use crate::utils::normalize_ident;
-use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
-use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
+use datafusion_common::{Column, DFSchema, DataFusionError, Result, ScalarValue};
+use datafusion_expr::expr_rewriter::rewrite_sort_cols_by_aggs;
+use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Repartition};
 use sqlparser::ast::{Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query};
 
 use sqlparser::parser::ParserError::ParserError;
@@ -150,11 +151,55 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             return Ok(plan);
         }
 
-        let order_by_rex = order_by
+        let mut order_by_rex = order_by
             .into_iter()
             .map(|e| self.order_by_to_sort_expr(e, plan.schema()))
             .collect::<Result<Vec<_>>>()?;
 
-        LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build()
+        order_by_rex = rewrite_sort_cols_by_aggs(order_by_rex, &plan)?;
+        let schema = plan.schema();
+
+        // if current plan is distinct or current plan is repartition and its child plan is distinct,
+        // then this plan is a select distinct plan
+        let is_select_distinct = match plan {
+            LogicalPlan::Distinct(_) => true,
+            LogicalPlan::Repartition(Repartition { ref input, .. }) => {
+                matches!(input.as_ref(), &LogicalPlan::Distinct(_))
+            }
+            _ => false,
+        };
+
+        let mut missing_cols: Vec<Column> = vec![];
+        // Collect sort columns that are missing in the input plan's schema
+        order_by_rex
+            .clone()
+            .into_iter()
+            .try_for_each::<_, Result<()>>(|expr| {
+                let columns = expr.to_columns()?;
+
+                columns.into_iter().for_each(|c| {
+                    if schema.field_from_column(&c).is_err() {
+                        missing_cols.push(c);
+                    }
+                });
+
+                Ok(())
+            })?;
+
+        // for select distinct, order by expressions must exist in select list
+        if is_select_distinct && !missing_cols.is_empty() {
+            let missing_col_names = missing_cols
+                .iter()
+                .map(|col| col.flat_name())
+                .collect::<String>();
+            let error_msg = format!(
+                "For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list",
+            );
+            return Err(DataFusionError::Plan(error_msg));
+        }
+
+        LogicalPlanBuilder::from(plan)
+            .create_sort_plan(order_by_rex, missing_cols)?
+            .build()
     }
 }

Reply via email to