This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 7eca60d4f30 [SPARK-41162][SQL][3.3] Fix anti- and semi-join for self-join with aggregations
7eca60d4f30 is described below

commit 7eca60d4f304d4a1a66add9fd04166d8eed1dd4f
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Fri Jan 6 11:32:45 2023 +0800

    [SPARK-41162][SQL][3.3] Fix anti- and semi-join for self-join with aggregations
    
    ### What changes were proposed in this pull request?
    Backport #39131 to branch-3.3.
    
    Rule `PushDownLeftSemiAntiJoin` should not push an anti-join or semi-join below an `Aggregate` when the join condition references an attribute that exists both in the join's right plan and in the child of its left plan. This usually happens when the anti-join / semi-join is a self-join while `DeduplicateRelations` cannot deduplicate those attributes (in the example below, because `value` is projected to `id`).
    
    This behaviour already exists for `Project` and `Union`, but `Aggregate` lacks this safety guard.
    
    ### Why are the changes needed?
    Without this change, the optimizer creates an incorrect plan.
    
    This example fails with `distinct()` (an aggregation) and succeeds without `distinct()`, although both queries are semantically equivalent (the input values are already distinct):
    ```scala
    import spark.implicits._  // toDF and $"..." (already imported in spark-shell)
    val ids = Seq(1, 2, 3).toDF("id").distinct()
    val result = ids.withColumn("id", $"id" + 1).join(ids, Seq("id"), "left_anti").collect()
    assert(result.length == 1)
    ```
    With `distinct()`, rule `PushDownLeftSemiAntiJoin` creates the join condition `(value#907 + 1) = value#907`, which can never be true. The anti-join therefore filters out nothing and is effectively removed.
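    To see why an always-false condition produces a wrong answer, here is a plain Scala collections sketch (not Spark code; `leftSemi` and `leftAnti` are hypothetical helpers with the usual semi-/anti-join semantics). It compares the intended joins with the anti-join that the faulty push-down effectively evaluates below the `Aggregate`:

    ```scala
    object AntiJoinSemanticsSketch {
      // Left semi join: keep left rows that have at least one matching right row.
      def leftSemi[A](left: Seq[A], right: Seq[A])(cond: (A, A) => Boolean): Seq[A] =
        left.filter(l => right.exists(r => cond(l, r)))

      // Left anti join: keep left rows that have no matching right row.
      def leftAnti[A](left: Seq[A], right: Seq[A])(cond: (A, A) => Boolean): Seq[A] =
        left.filterNot(l => right.exists(r => cond(l, r)))

      def main(args: Array[String]): Unit = {
        val ids = Seq(1, 2, 3).distinct
        val shifted = ids.map(_ + 1) // List(2, 3, 4)

        // Intended results: compare the shifted left value with the right value.
        println(leftSemi(shifted, ids)(_ == _)) // List(2, 3)
        println(leftAnti(shifted, ids)(_ == _)) // List(4)

        // Faulty push-down: (value + 1) = value only looks at the left row and is
        // never true, so the anti-join keeps every row before the +1 is applied.
        println(leftAnti(ids, ids)((l, _) => l + 1 == l).map(_ + 1)) // List(2, 3, 4)
      }
    }
    ```

    The intended anti-join returns one row and the semi-join two, matching the expectations in `DataFrameJoinSuite`, while the always-false condition keeps all three rows.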
    
    **Before this PR:**
    The anti-join is fully removed from the plan.
    ```
    == Physical Plan ==
    AdaptiveSparkPlan (16)
    +- == Final Plan ==
       LocalTableScan (1)
    
    (16) AdaptiveSparkPlan
    Output [1]: [id#900]
    Arguments: isFinalPlan=true
    ```
    
    This is caused by `PushDownLeftSemiAntiJoin` adding the join condition `(value#907 + 1) = value#907`, which is wrong because `id#910` in `(id#910 + 1) AS id#912` exists in the right child of the join as well as in the left grandchild:
    ```
    === Applying Rule org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin ===
    !Join LeftAnti, (id#912 = id#910)                  Aggregate [id#910], [(id#910 + 1) AS id#912]
    !:- Aggregate [id#910], [(id#910 + 1) AS id#912]   +- Project [value#907 AS id#910]
    !:  +- Project [value#907 AS id#910]                  +- Join LeftAnti, ((value#907 + 1) = value#907)
    !:     +- LocalRelation [value#907]                      :- LocalRelation [value#907]
    !+- Aggregate [id#910], [id#910]                         +- Aggregate [id#910], [id#910]
    !   +- Project [value#914 AS id#910]                        +- Project [value#914 AS id#910]
    !      +- LocalRelation [value#914]                            +- LocalRelation [value#914]
    ```
    
    The right child of the join and the left grandchild would become the children of the pushed-down join. Since both of them output `id#910`, the rewritten join condition cannot tell the two sides apart and is invalid.
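    The ambiguity can be reproduced with a toy expression tree (hypothetical `AttrRef`, `PlusConst`, and `IsEqual` types, not Catalyst's classes). This sketch performs only the first rewrite step: substituting the alias `(id#910 + 1) AS id#912` into the join condition `id#912 = id#910` leaves a condition whose only referenced attribute, `910`, is output by both the left grandchild and the right child:

    ```scala
    // Hypothetical toy expression tree; not Catalyst's classes.
    sealed trait ExprSketch
    final case class AttrRef(name: String, exprId: Long) extends ExprSketch
    final case class PlusConst(child: ExprSketch, value: Int) extends ExprSketch
    final case class IsEqual(left: ExprSketch, right: ExprSketch) extends ExprSketch

    object AliasSubstitutionSketch {
      // Replace every attribute produced by an alias with the aliased expression.
      def substitute(e: ExprSketch, aliases: Map[Long, ExprSketch]): ExprSketch = e match {
        case a: AttrRef      => aliases.getOrElse(a.exprId, a)
        case PlusConst(c, v) => PlusConst(substitute(c, aliases), v)
        case IsEqual(l, r)   => IsEqual(substitute(l, aliases), substitute(r, aliases))
      }

      // Collect the expression ids referenced by an expression.
      def references(e: ExprSketch): Set[Long] = e match {
        case AttrRef(_, id)  => Set(id)
        case PlusConst(c, _) => references(c)
        case IsEqual(l, r)   => references(l) ++ references(r)
      }

      def main(args: Array[String]): Unit = {
        val id910 = AttrRef("id", 910L)
        val joinCond = IsEqual(AttrRef("id", 912L), id910) // id#912 = id#910
        val aliases = Map(912L -> PlusConst(id910, 1))     // (id#910 + 1) AS id#912
        val pushed = substitute(joinCond, aliases)
        println(pushed)             // IsEqual(PlusConst(AttrRef(id,910),1),AttrRef(id,910))
        println(references(pushed)) // Set(910)
      }
    }
    ```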
    
    **After this PR:**
    The join condition, rewritten through the alias `(id#910 + 1) AS id#912`, is now recognized as ambiguous because both sides of the prospective join contain `id#910`. Hence, the join is not pushed down and the rule no longer applies.
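    A minimal sketch of the safety guard, for illustration only: it is not Spark's code, and it hypothetically reduces every plan to the set of expression ids it outputs (e.g. `910` for `id#910`). It follows the intersection idea described in the `canPushThroughCondition` doc comment: the push-down is refused when an attribute referenced by the join condition is output both by the right side and by the children the join would be pushed into.

    ```scala
    // Hypothetical model: a plan is reduced to the expression ids it outputs.
    final case class PlanSketch(output: Set[Long])

    object PushDownGuardSketch {
      // True when the join can be pushed into `children` without any referenced
      // attribute becoming ambiguous between the two sides.
      def canPushThroughCondition(
          children: Seq[PlanSketch],
          conditionRefs: Option[Set[Long]],
          rightOp: PlanSketch): Boolean = {
        val childAttrs = children.flatMap(_.output).toSet
        conditionRefs match {
          // Ambiguous if a referenced attribute is output on both sides.
          case Some(refs) => (refs intersect rightOp.output intersect childAttrs).isEmpty
          // Without a join condition nothing can become ambiguous.
          case None => true
        }
      }

      def main(args: Array[String]): Unit = {
        val aggChild = PlanSketch(Set(910L))      // Project [value#907 AS id#910]
        val rightSide = PlanSketch(Set(910L))     // Aggregate [id#910], [id#910]
        val joinCondRefs = Some(Set(912L, 910L))  // (id#912 = id#910)
        println(canPushThroughCondition(Seq(aggChild), joinCondRefs, rightSide)) // false
      }
    }
    ```

    With the ids from the plans above, `910` is shared by the `Aggregate`'s child and the right side, so the sketch refuses the push-down, as the fixed rule does.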
    
    The final plan contains the anti-join:
    ```
    == Physical Plan ==
    AdaptiveSparkPlan (24)
    +- == Final Plan ==
       * BroadcastHashJoin LeftSemi BuildRight (14)
       :- * HashAggregate (7)
       :  +- AQEShuffleRead (6)
       :     +- ShuffleQueryStage (5), Statistics(sizeInBytes=48.0 B, rowCount=3)
       :        +- Exchange (4)
       :           +- * HashAggregate (3)
       :              +- * Project (2)
       :                 +- * LocalTableScan (1)
       +- BroadcastQueryStage (13), Statistics(sizeInBytes=1024.0 KiB, rowCount=3)
          +- BroadcastExchange (12)
             +- * HashAggregate (11)
                +- AQEShuffleRead (10)
                   +- ShuffleQueryStage (9), Statistics(sizeInBytes=48.0 B, rowCount=3)
                      +- ReusedExchange (8)
    
    (8) ReusedExchange [Reuses operator id: 4]
    Output [1]: [id#898]
    
    (24) AdaptiveSparkPlan
    Output [1]: [id#900]
    Arguments: isFinalPlan=true
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it fixes a correctness bug: such self-joins now return the correct result.
    
    ### How was this patch tested?
    Unit tests in `DataFrameJoinSuite` and `LeftSemiAntiJoinPushDownSuite`.
    
    Closes #39409 from EnricoMi/branch-antijoin-selfjoin-fix-3.3.
    
    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit b97f79da04acc9bde1cb4def7dc33c22cfc11372)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/PushDownLeftSemiAntiJoin.scala       | 13 ++---
 .../optimizer/LeftSemiAntiJoinPushDownSuite.scala  | 57 ++++++++++++++--------
 .../org/apache/spark/sql/DataFrameJoinSuite.scala  | 18 +++++++
 3 files changed, 63 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
index 31b9d604060..8a146c4d688 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala
@@ -56,9 +56,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
       }
 
     // LeftSemi/LeftAnti over Aggregate, only push down if join can be planned as broadcast join.
-    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
+    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
         if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
           !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
+          canPushThroughCondition(agg.children, joinCond, rightOp) &&
           canPlanAsBroadcastHashJoin(join, conf) =>
       val aliasMap = getAliasMap(agg)
       val canPushDownPredicate = (predicate: Expression) => {
@@ -105,11 +106,11 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan]
   }
 
   /**
-   * Check if we can safely push a join through a project or union by making sure that attributes
-   * referred in join condition do not contain the same attributes as the plan they are moved
-   * into. This can happen when both sides of join refers to the same source (self join). This
-   * function makes sure that the join condition refers to attributes that are not ambiguous (i.e
-   * present in both the legs of the join) or else the resultant plan will be invalid.
+   * Check if we can safely push a join through a project, aggregate, or union by making sure that
+   * attributes referred in join condition do not contain the same attributes as the plan they are
+   * moved into. This can happen when both sides of join refers to the same source (self join).
+   * This function makes sure that the join condition refers to attributes that are not ambiguous
+   * (i.e present in both the legs of the join) or else the resultant plan will be invalid.
    */
   private def canPushThroughCondition(
       plans: Seq[LogicalPlan],
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
index 88c29c9274a..0b5a7f76607 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType
 
-class LeftSemiPushdownSuite extends PlanTest {
+class LeftSemiAntiJoinPushDownSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches =
@@ -46,7 +46,7 @@ class LeftSemiPushdownSuite extends PlanTest {
   val testRelation1 = LocalRelation('d.int)
   val testRelation2 = LocalRelation('e.int)
 
-  test("Project: LeftSemiAnti join pushdown") {
+  test("Project: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -59,7 +59,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") {
+  test("Project: LeftSemi join no pushdown - non-deterministic proj exprs") {
     val originalQuery = testRelation
       .select(Rand(1), 'b, 'c)
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -68,7 +68,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Project: LeftSemiAnti join non correlated scalar subq") {
+  test("Project: LeftSemi join pushdown - non-correlated scalar subq") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .select(subq.as("sum"))
@@ -83,7 +83,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") {
+  test("Project: LeftSemi join no pushdown - correlated scalar subq in projection list") {
     val testRelation2 = LocalRelation('e.int, 'f.int)
     val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a)
     val subqExpr = ScalarSubquery(subqPlan)
@@ -95,7 +95,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Aggregate: LeftSemiAnti join pushdown") {
+  test("Aggregate: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -109,7 +109,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") {
+  test("Aggregate: LeftSemi join no pushdown - non-deterministic aggr expressions") {
     val originalQuery = testRelation
       .groupBy('b)('b, Rand(10).as('c))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd))
@@ -142,7 +142,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("LeftSemiAnti join over aggregate - no pushdown") {
+  test("Aggregate: LeftSemi join no pushdown") {
     val originalQuery = testRelation
       .groupBy('b)('b, sum('c).as('sum))
       .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd))
@@ -151,7 +151,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") {
+  test("Aggregate: LeftSemi join pushdown - non-correlated scalar subq aggr exprs") {
     val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze)
     val originalQuery = testRelation
       .groupBy('a) ('a, subq.as("sum"))
@@ -166,7 +166,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("LeftSemiAnti join over Window") {
+  test("Window: LeftSemi join pushdown") {
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
 
     val originalQuery = testRelation
@@ -184,7 +184,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Window: LeftSemi partial pushdown") {
+  test("Window: LeftSemi join partial pushdown") {
     // Attributes from join condition which does not refer to the window partition spec
     // are kept up in the plan as a Filter operator above Window.
     val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame))
@@ -224,7 +224,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Union: LeftSemiAnti join pushdown") {
+  test("Union: LeftSemi join pushdown") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
 
     val originalQuery = Union(Seq(testRelation, testRelation2))
@@ -240,7 +240,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Union: LeftSemiAnti join pushdown in self join scenario") {
+  test("Union: LeftSemi join pushdown in self join scenario") {
     val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int)
     val attrX = testRelation2.output.head
 
@@ -259,7 +259,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemiAnti join pushdown") {
+  test("Unary: LeftSemi join pushdown") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -274,7 +274,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemiAnti join pushdown - empty join condition") {
+  test("Unary: LeftSemi join pushdown - empty join condition") {
     val originalQuery = testRelation
       .select(star())
       .repartition(1)
@@ -289,7 +289,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftSemi join pushdown - partial pushdown") {
+  test("Unary: LeftSemi join partial pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -305,7 +305,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
-  test("Unary: LeftAnti join pushdown - no pushdown") {
+  test("Unary: LeftAnti join no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -315,7 +315,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Unary: LeftSemiAnti join pushdown - no pushdown") {
+  test("Unary: LeftSemi join - no pushdown") {
     val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
     val originalQuery = testRelationWithArrayType
       .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col"))
@@ -325,7 +325,7 @@ class LeftSemiPushdownSuite extends PlanTest {
     comparePlans(optimized, originalQuery.analyze)
   }
 
-  test("Unary: LeftSemi join push down through Expand") {
+  test("Unary: LeftSemi join pushdown through Expand") {
     val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)),
       Seq('a, 'b, 'c), testRelation)
     val originalQuery = expand
@@ -431,6 +431,25 @@ class LeftSemiPushdownSuite extends PlanTest {
     }
   }
 
+  Seq(LeftSemi, LeftAnti).foreach { case jt =>
+    test(s"Aggregate: $jt join no pushdown - join condition refers left leg and right leg child") {
+      val aggregation = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)('id, sum('c).as("sum"))
+
+      // reference "b" exists in left leg, and the children of the right leg of the join
+      val originalQuery = aggregation.select(('id + 1).as("id_plus_1"), 'sum)
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = testRelation
+        .select('b.as("id"), 'c)
+        .groupBy('id)(('id + 1).as("id_plus_1"), sum('c).as("sum"))
+        .join(aggregation, joinType = jt, condition = Some('id === 'id_plus_1))
+        .analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
   Seq(LeftSemi, LeftAnti).foreach { case outerJT =>
     Seq(Inner, LeftOuter, RightOuter, Cross).foreach { case innerJT =>
       test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index 1fda13f996a..4298d503b10 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -288,6 +288,24 @@ class DataFrameJoinSuite extends QueryTest
     }
   }
 
+  Seq("left_semi", "left_anti").foreach { joinType =>
+    test(s"SPARK-41162: $joinType self-joined aggregated dataframe") {
+      // aggregated dataframe
+      val ids = Seq(1, 2, 3).toDF("id").distinct()
+
+      // self-joined via joinType
+      val result = ids.withColumn("id", $"id" + 1)
+        .join(ids, usingColumns = Seq("id"), joinType = joinType).collect()
+
+      val expected = joinType match {
+        case "left_semi" => 2
+        case "left_anti" => 1
+        case _ => -1  // unsupported test type, test will always fail
+      }
+      assert(result.length == expected)
+    }
+  }
+
   def extractLeftDeepInnerJoins(plan: LogicalPlan): Seq[LogicalPlan] = plan match {
     case j @ Join(left, right, _: InnerLike, _, _) => right +: extractLeftDeepInnerJoins(left)
     case Filter(_, child) => extractLeftDeepInnerJoins(child)

