allisonwang-db commented on code in PR #39375:
URL: https://github.com/apache/spark/pull/39375#discussion_r1067334361


##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala:
##########
@@ -850,12 +850,6 @@ class AnalysisErrorSuite extends AnalysisTest {
       LocalRelation(a))
     assertAnalysisError(plan2, "Accessing outer query column is not allowed 
in" :: Nil)
 
-    val plan3 = Filter(
-      Exists(Union(LocalRelation(b),

Review Comment:
   Instead of deleting this, we can change it to use a set operation other than
   Union.



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala:
##########
@@ -283,6 +283,51 @@ class DecorrelateInnerQuerySuite extends PlanTest {
     check(innerPlan, outerPlan, correctAnswer, Seq(y <=> y, x === a, y === z))
   }
 
+  test("union in correlation path") {
+    val outerPlan = testRelation2
+    val innerPlan =
+      Union(
+        Filter(And(OuterReference(x) === a, c === 3),
+          testRelation),
+        Filter(And(OuterReference(y) === b, c === 6),
+          testRelation))
+    val correctAnswer =
+      Union(
+        Filter(And(x === a, c === 3),
+          DomainJoin(Seq(x, y),
+            testRelation)),
+        Filter(And(y === b, c === 6),
+          DomainJoin(Seq(x, y),

Review Comment:
   So regardless of whether the condition is equality or non-equality, we need 
to insert domain joins here, right?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -667,6 +720,24 @@ object DecorrelateInnerQuery extends PredicateHelper {
             val newJoin = j.copy(left = newLeft, right = newRight, condition = 
newCondition)
             (newJoin, newJoinCond, newOuterReferenceMap)
 
+          case u: Union =>
+            // First collect outer references from all children
+            val collectedChildOuterReferences = 
collectOuterReferencesInPlanTree(u)

Review Comment:
   Why do we need to collect all outer references from all children first? Can
   you add a high-level explanation of how decorrelation through Union works,
   preferably with examples?



##########
sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala:
##########
@@ -938,6 +938,46 @@ class SubquerySuite extends QueryTest
     }
   }
 
+  test("SPARK-36124: Correlated subqueries with set operations") {

Review Comment:
   ```suggestion
     test("SPARK-36124: Correlated subqueries with union") {
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -385,7 +431,9 @@ object DecorrelateInnerQuery extends PredicateHelper {
             val (correlated, uncorrelated) = 
conditions.partition(containsOuter)
             // Find outer references that can be substituted by attributes 
from the inner
             // query using the equality predicates.
-            val equivalences = collectEquivalentOuterReferences(correlated)
+            val equivalences =
+              if (underSetOp) AttributeMap.empty[Attribute]

Review Comment:
   Can we add a comment here explaining why the attribute map is empty when
   underSetOp is true?



##########
sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala:
##########
@@ -938,6 +938,46 @@ class SubquerySuite extends QueryTest
     }
   }
 
+  test("SPARK-36124: Correlated subqueries with set operations") {
+    withTempView("t0", "t1", "t2") {
+      Seq((1, 1), (2, 0)).toDF("t0a", "t0b").createOrReplaceTempView("t0")
+      Seq((1, 1, 3)).toDF("t1a", "t1b", "t1c").createOrReplaceTempView("t1")
+      Seq((1, 1, 5), (2, 2, 7)).toDF("t2a", "t2b", 
"t2c").createOrReplaceTempView("t2")
+
+      // Union with different outer refs
+      checkAnswer(

Review Comment:
   I think the SQL query tests already check the answer. Here we can check the
   plan structure instead of the answer, if necessary.



##########
sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql:
##########
@@ -235,3 +235,76 @@ SELECT c, (
     FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b)
     WHERE a + b = c
 ) FROM (VALUES (6)) t2(c);
+
+-- Set operations in correlation path
+
+CREATE OR REPLACE TEMP VIEW t0(t0a, t0b) as values (1, 1), (2, 0);
+CREATE OR REPLACE TEMP VIEW t1(t1a, t1b, t1c) as values (1, 1, 3);
+CREATE OR REPLACE TEMP VIEW t2(t2a, t2b, t2c) as values (1, 1, 5), (2, 2, 7);
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a > t0a
+  UNION ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM

Review Comment:
   Can we add more test cases?
   1. Aggregate functions that can cause the COUNT bug (e.g. count(*)).
   2. Correlated references in Project and Aggregate.



##########
sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql:
##########
@@ -235,3 +235,76 @@ SELECT c, (
     FROM (VALUES (0, 6), (1, 5), (2, 4), (3, 3)) t1(a, b)
     WHERE a + b = c
 ) FROM (VALUES (6)) t2(c);
+
+-- Set operations in correlation path
+
+CREATE OR REPLACE TEMP VIEW t0(t0a, t0b) as values (1, 1), (2, 0);

Review Comment:
   ```suggestion
   CREATE OR REPLACE TEMP VIEW t0(t0a, t0b) AS VALUES (1, 1), (2, 0);
   ```



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -242,14 +251,38 @@ object DecorrelateInnerQuery extends PredicateHelper {
     }.toMap
   }
 
+  /**
+   * Rewrites a domain join cond so that it can be pushed to the right side of 
a
+   * union/intersect/except operator.
+   */
+  def pushConditionsThroughUnion(
+    conditions: Seq[Expression],

Review Comment:
   nit: use 4-space indentation for the parameters.



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -1209,6 +1209,12 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
         case p @ (_: ResolvedHint | _: LeafNode | _: Repartition | _: 
SubqueryAlias) =>
           p.children.foreach(child => checkPlan(child, aggregated, 
canContainOuter))
 
+        case p @ (_ : Union) =>
+          // Set operations (e.g. UNION) containing correlated values are only 
supported
+          // with decorrelateInnerQueryEnabled.
+          val childCanContainOuter = canContainOuter && 
SQLConf.get.decorrelateInnerQueryEnabled

Review Comment:
   I think we also need to check whether it is inside a scalar or lateral
   subquery. I don't think we support this in IN/EXISTS subqueries.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to