allisonwang-db commented on code in PR #39759:
URL: https://github.com/apache/spark/pull/39759#discussion_r1092794250


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -255,26 +256,72 @@ object DecorrelateInnerQuery extends PredicateHelper {
    * Rewrites a domain join cond so that it can be pushed to the right side of 
a
    * union/intersect/except operator.
    */
-  def pushConditionsThroughUnion(
+  def pushDomainConditionsThroughSetOperation(
       conditions: Seq[Expression],
-      union: Union,
+      setOp: LogicalPlan, // Union or SetOperation
       child: LogicalPlan): Seq[Expression] = {
     // The output attributes are always equal to the left child's output
-    assert(union.output.size == child.output.size)
-    val map = AttributeMap(union.output.zip(child.output))
+    assert(setOp.output.size == child.output.size)
+    val map = AttributeMap(setOp.output.zip(child.output))
     conditions.map {
       // The left hand side is the domain attribute used in the inner query 
and the right hand side
       // is the attribute from the outer query. (See comment above in 
buildDomainAttrMap.)
       // We need to remap the attribute names used in the inner query (left 
hand side) to account
       // for the different names in each union child. We should not remap the 
attribute names used
       // in the outer query.
+      //
+      // Note: the reason we can't just use the original joinCond from when 
the DomainJoin was
+      // constructed is that constructing the DomainJoins happens much earlier 
than rewriting the
+      // DomainJoins into actual joins, with many optimization steps in
+      // between, which could change the attributes involved (e.g. 
CollapseProject).
       case EqualNullSafe(left: Attribute, right: Expression) =>
         EqualNullSafe(map.getOrElse(left, left), right)
       case EqualTo(left: Attribute, right: Expression) =>

Review Comment:
   Since we always use `EqualNullSafe` to construct domain join conditions, can we drop the `EqualTo` case and use `.collect` instead of `.map`?



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala:
##########
@@ -333,6 +334,114 @@ class DecorrelateInnerQuerySuite extends PlanTest {
     check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y, z <=> z))
   }
 
+  test("INTERSECT ALL in correlation path") {
+    val outerPlan = testRelation2
+    val innerPlan =
+      Intersect(
+        Filter(And(OuterReference(x) === a, c === 3),
+          testRelation),
+        Filter(And(OuterReference(y) === b, c === 6),
+          testRelation),
+        isAll = true)
+    val correctAnswer =
+      Intersect(
+        Project(Seq(a, b, c, x, y),
+          Filter(And(x === a, c === 3),
+            DomainJoin(Seq(x, y),
+              testRelation))),
+        Project(Seq(a, b, c, x, y),
+          Filter(And(y === b, c === 6),
+            DomainJoin(Seq(x.newInstance(), y.newInstance()),
+              testRelation))),
+        isAll = true
+      )
+    // Disable checkAnalysis because otherwise duplicate attributes hit

Review Comment:
   Hmm, why do we hit this issue for set operations? Can we update the test so that it avoids the issue?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -336,9 +383,15 @@ object DecorrelateInnerQuery extends PredicateHelper {
         throw new IllegalStateException(
           s"Unable to rewrite domain join with conditions: $conditions\n$d.")
       }
-    case u: Union =>
-      u.mapChildren { child =>
-        rewriteDomainJoins(outerPlan, child, 
pushConditionsThroughUnion(conditions, u, child))
+    case s @ (_ : Union | _: SetOperation) =>
+      s.mapChildren { child =>
+        rewriteDomainJoins(outerPlan, child,
+          pushDomainConditionsThroughSetOperation(conditions, s, child))

Review Comment:
   It would be great to add a simple example here illustrating how `conditions` changes (before vs. after `pushDomainConditionsThroughSetOperation`).



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -336,9 +383,15 @@ object DecorrelateInnerQuery extends PredicateHelper {
         throw new IllegalStateException(
           s"Unable to rewrite domain join with conditions: $conditions\n$d.")
       }
-    case u: Union =>
-      u.mapChildren { child =>
-        rewriteDomainJoins(outerPlan, child, 
pushConditionsThroughUnion(conditions, u, child))
+    case s @ (_ : Union | _: SetOperation) =>
+      s.mapChildren { child =>
+        rewriteDomainJoins(outerPlan, child,
+          pushDomainConditionsThroughSetOperation(conditions, s, child))
+      }
+    case j: Join if j.joinType == LeftSemi || j.joinType == LeftAnti =>

Review Comment:
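   Ditto: please add a simple example here too, illustrating how `conditions` is rewritten for this case (before vs. after).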



##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala:
##########
@@ -333,6 +334,114 @@ class DecorrelateInnerQuerySuite extends PlanTest {
     check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y, z <=> z))
   }
 
+  test("INTERSECT ALL in correlation path") {

Review Comment:
   Please also add SPARK-36124 to the test name.



##########
sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql:
##########
@@ -470,3 +470,112 @@ HAVING   t1b NOT IN
                 FROM   t3)
 ORDER BY t1c DESC NULLS LAST, t1i;
 
+-- Correlation under set ops under IN - unsupported

Review Comment:
   ```suggestion
   -- Correlation under set ops - unsupported
   ```



##########
sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql:
##########
@@ -0,0 +1,621 @@
+-- Set operations in correlation path
+
+CREATE OR REPLACE TEMP VIEW t0(t0a, t0b) AS VALUES (1, 1), (2, 0);
+CREATE OR REPLACE TEMP VIEW t1(t1a, t1b, t1c) AS VALUES (1, 1, 3);
+CREATE OR REPLACE TEMP VIEW t2(t2a, t2b, t2c) AS VALUES (1, 1, 5), (2, 2, 7);
+
+
+-- UNION ALL
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a > t0a
+  UNION ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+  (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t0a as t2b, t2c as t1a, t0b as t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION ALL
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT t1a - t0a as d
+  FROM   t1
+  UNION ALL
+  SELECT t2a - t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT sum(t0a) as d
+  FROM   t1
+  UNION ALL
+  SELECT sum(t2a) + t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+
+
+-- UNION DISTINCT
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a > t0a
+  UNION DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION DISTINCT
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+  (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION DISTINCT
+  SELECT t0a as t2b, t2c as t1a, t0b as t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  UNION DISTINCT
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT t1a - t0a as d
+  FROM   t1
+  UNION DISTINCT
+  SELECT t2a - t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT sum(t0a) as d
+  FROM   t1
+  UNION DISTINCT
+  SELECT sum(t2a) + t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+
+-- INTERSECT ALL
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a > t0a
+  INTERSECT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT ALL
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+  (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT ALL
+  SELECT t0a as t2b, t2c as t1a, t0b as t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT ALL
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT t1a - t0a as d
+  FROM   t1
+  INTERSECT ALL
+  SELECT t2a - t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT sum(t0a) as d
+  FROM   t1
+  INTERSECT ALL
+  SELECT sum(t2a) + t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+
+
+-- INTERSECT DISTINCT
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a > t0a
+  INTERSECT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT DISTINCT
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+  (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT DISTINCT
+  SELECT t0a as t2b, t2c as t1a, t0b as t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  INTERSECT DISTINCT
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT t1a - t0a as d
+  FROM   t1
+  INTERSECT DISTINCT
+  SELECT t2a - t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT sum(t0a) as d
+  FROM   t1
+  INTERSECT DISTINCT
+  SELECT sum(t2a) + t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+
+
+-- EXCEPT ALL
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a > t0a
+  EXCEPT ALL
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT ALL
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+  (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT ALL
+  SELECT t0a as t2b, t2c as t1a, t0b as t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT ALL
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT t1a - t0a as d
+  FROM   t1
+  EXCEPT ALL
+  SELECT t2a - t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT sum(t0a) as d
+  FROM   t1
+  EXCEPT ALL
+  SELECT sum(t2a) + t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+
+
+-- EXCEPT DISTINCT
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+  (SELECT t1c as c
+  FROM   t1
+  WHERE  t1a > t0a
+  EXCEPT DISTINCT
+  SELECT t2c as c
+  FROM   t2
+  WHERE  t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT DISTINCT
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+  (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT DISTINCT
+  SELECT t0a as t2b, t2c as t1a, t0b as t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+  (SELECT t1c
+  FROM   t1
+  WHERE  t1a = t0a
+  EXCEPT DISTINCT
+  SELECT t2c
+  FROM   t2
+  WHERE  t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT t1a - t0a as d
+  FROM   t1
+  EXCEPT DISTINCT
+  SELECT t2a - t0a as d
+  FROM   t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+  (SELECT sum(t0a) as d
+  FROM   t1
+  EXCEPT DISTINCT
+  SELECT sum(t2a) + t0a as d
+  FROM   t2)
+)
+FROM t0;

Review Comment:
   Can we add a few more tests that combine these set operations? Something like `<query> UNION (<query> INTERSECT <query>)`.
   
   Also, let's add a test case with a correlated semi/anti join that does not come from rewriting a set operation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to