allisonwang-db commented on code in PR #39759:
URL: https://github.com/apache/spark/pull/39759#discussion_r1092794250
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -255,26 +256,72 @@ object DecorrelateInnerQuery extends PredicateHelper {
* Rewrites a domain join cond so that it can be pushed to the right side of
a
* union/intersect/except operator.
*/
- def pushConditionsThroughUnion(
+ def pushDomainConditionsThroughSetOperation(
conditions: Seq[Expression],
- union: Union,
+ setOp: LogicalPlan, // Union or SetOperation
child: LogicalPlan): Seq[Expression] = {
// The output attributes are always equal to the left child's output
- assert(union.output.size == child.output.size)
- val map = AttributeMap(union.output.zip(child.output))
+ assert(setOp.output.size == child.output.size)
+ val map = AttributeMap(setOp.output.zip(child.output))
conditions.map {
// The left hand side is the domain attribute used in the inner query
and the right hand side
// is the attribute from the outer query. (See comment above in
buildDomainAttrMap.)
// We need to remap the attribute names used in the inner query (left
hand side) to account
// for the different names in each union child. We should not remap the
attribute names used
// in the outer query.
+ //
+ // Note: the reason we can't just use the original joinCond from when
the DomainJoin was
+ // constructed is that constructing the DomainJoins happens much earlier
than rewriting the
+ // DomainJoins into actual joins, with many optimization steps in
+ // between, which could change the attributes involved (e.g.
CollapseProject).
case EqualNullSafe(left: Attribute, right: Expression) =>
EqualNullSafe(map.getOrElse(left, left), right)
case EqualTo(left: Attribute, right: Expression) =>
Review Comment:
Since we always use EqualNullSafe to construct domain join conditions, can
we skip this EqualTo case and use `.collect` instead of `.map`?
##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala:
##########
@@ -333,6 +334,114 @@ class DecorrelateInnerQuerySuite extends PlanTest {
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y, z <=> z))
}
+ test("INTERSECT ALL in correlation path") {
+ val outerPlan = testRelation2
+ val innerPlan =
+ Intersect(
+ Filter(And(OuterReference(x) === a, c === 3),
+ testRelation),
+ Filter(And(OuterReference(y) === b, c === 6),
+ testRelation),
+ isAll = true)
+ val correctAnswer =
+ Intersect(
+ Project(Seq(a, b, c, x, y),
+ Filter(And(x === a, c === 3),
+ DomainJoin(Seq(x, y),
+ testRelation))),
+ Project(Seq(a, b, c, x, y),
+ Filter(And(y === b, c === 6),
+ DomainJoin(Seq(x.newInstance(), y.newInstance()),
+ testRelation))),
+ isAll = true
+ )
+ // Disable checkAnalysis because otherwise duplicate attributes hit
Review Comment:
Hmm, why do we have this issue for set operations? Can we update the test to
avoid this issue?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -336,9 +383,15 @@ object DecorrelateInnerQuery extends PredicateHelper {
throw new IllegalStateException(
s"Unable to rewrite domain join with conditions: $conditions\n$d.")
}
- case u: Union =>
- u.mapChildren { child =>
- rewriteDomainJoins(outerPlan, child,
pushConditionsThroughUnion(conditions, u, child))
+ case s @ (_ : Union | _: SetOperation) =>
+ s.mapChildren { child =>
+ rewriteDomainJoins(outerPlan, child,
+ pushDomainConditionsThroughSetOperation(conditions, s, child))
Review Comment:
It would be great to provide a simple example here to illustrate how
`conditions` changes (before vs. after pushDomainConditionsThroughSetOperation).
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -336,9 +383,15 @@ object DecorrelateInnerQuery extends PredicateHelper {
throw new IllegalStateException(
s"Unable to rewrite domain join with conditions: $conditions\n$d.")
}
- case u: Union =>
- u.mapChildren { child =>
- rewriteDomainJoins(outerPlan, child,
pushConditionsThroughUnion(conditions, u, child))
+ case s @ (_ : Union | _: SetOperation) =>
+ s.mapChildren { child =>
+ rewriteDomainJoins(outerPlan, child,
+ pushDomainConditionsThroughSetOperation(conditions, s, child))
+ }
+ case j: Join if j.joinType == LeftSemi || j.joinType == LeftAnti =>
Review Comment:
Ditto
##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala:
##########
@@ -333,6 +334,114 @@ class DecorrelateInnerQuerySuite extends PlanTest {
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y, z <=> z))
}
+ test("INTERSECT ALL in correlation path") {
Review Comment:
Please also add SPARK-36124 to the test name.
##########
sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql:
##########
@@ -470,3 +470,112 @@ HAVING t1b NOT IN
FROM t3)
ORDER BY t1c DESC NULLS LAST, t1i;
+-- Correlation under set ops under IN - unsupported
Review Comment:
```suggestion
-- Correlation under set ops - unsupported
```
##########
sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-set-op.sql:
##########
@@ -0,0 +1,621 @@
+-- Set operations in correlation path
+
+CREATE OR REPLACE TEMP VIEW t0(t0a, t0b) AS VALUES (1, 1), (2, 0);
+CREATE OR REPLACE TEMP VIEW t1(t1a, t1b, t1c) AS VALUES (1, 1, 3);
+CREATE OR REPLACE TEMP VIEW t2(t2a, t2b, t2c) AS VALUES (1, 1, 5), (2, 2, 7);
+
+
+-- UNION ALL
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ UNION ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ UNION ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ UNION ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a > t0a
+ UNION ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ UNION ALL
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+ (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+ FROM t1
+ WHERE t1a = t0a
+ UNION ALL
+ SELECT t0a as t2b, t2c as t1a, t0b as t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ UNION ALL
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT t1a - t0a as d
+ FROM t1
+ UNION ALL
+ SELECT t2a - t0a as d
+ FROM t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT sum(t0a) as d
+ FROM t1
+ UNION ALL
+ SELECT sum(t2a) + t0a as d
+ FROM t2)
+)
+FROM t0;
+
+
+
+-- UNION DISTINCT
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ UNION DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ UNION DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ UNION DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a > t0a
+ UNION DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ UNION DISTINCT
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+ (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+ FROM t1
+ WHERE t1a = t0a
+ UNION DISTINCT
+ SELECT t0a as t2b, t2c as t1a, t0b as t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ UNION DISTINCT
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT t1a - t0a as d
+ FROM t1
+ UNION DISTINCT
+ SELECT t2a - t0a as d
+ FROM t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT sum(t0a) as d
+ FROM t1
+ UNION DISTINCT
+ SELECT sum(t2a) + t0a as d
+ FROM t2)
+)
+FROM t0;
+
+
+-- INTERSECT ALL
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a > t0a
+ INTERSECT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT ALL
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+ (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT ALL
+ SELECT t0a as t2b, t2c as t1a, t0b as t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT ALL
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT t1a - t0a as d
+ FROM t1
+ INTERSECT ALL
+ SELECT t2a - t0a as d
+ FROM t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT sum(t0a) as d
+ FROM t1
+ INTERSECT ALL
+ SELECT sum(t2a) + t0a as d
+ FROM t2)
+)
+FROM t0;
+
+
+
+-- INTERSECT DISTINCT
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a > t0a
+ INTERSECT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT DISTINCT
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+ (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT DISTINCT
+ SELECT t0a as t2b, t2c as t1a, t0b as t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ INTERSECT DISTINCT
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT t1a - t0a as d
+ FROM t1
+ INTERSECT DISTINCT
+ SELECT t2a - t0a as d
+ FROM t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT sum(t0a) as d
+ FROM t1
+ INTERSECT DISTINCT
+ SELECT sum(t2a) + t0a as d
+ FROM t2)
+)
+FROM t0;
+
+
+
+-- EXCEPT ALL
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a > t0a
+ EXCEPT ALL
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT ALL
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+ (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT ALL
+ SELECT t0a as t2b, t2c as t1a, t0b as t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT ALL
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT t1a - t0a as d
+ FROM t1
+ EXCEPT ALL
+ SELECT t2a - t0a as d
+ FROM t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT sum(t0a) as d
+ FROM t1
+ EXCEPT ALL
+ SELECT sum(t2a) + t0a as d
+ FROM t2)
+)
+FROM t0;
+
+
+
+-- EXCEPT DISTINCT
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+SELECT * FROM t0 WHERE t0a <
+(SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b = t0b)
+);
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2a = t0a)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(c) FROM
+ (SELECT t1c as c
+ FROM t1
+ WHERE t1a > t0a
+ EXCEPT DISTINCT
+ SELECT t2c as c
+ FROM t2
+ WHERE t2b <= t0b)
+)
+FROM t0;
+
+SELECT t0a, (SELECT sum(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT DISTINCT
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Tests for column aliasing
+SELECT t0a, (SELECT sum(t1a + 3 * t1b + 5 * t1c) FROM
+ (SELECT t1c as t1a, t1a as t1b, t0a as t1c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT DISTINCT
+ SELECT t0a as t2b, t2c as t1a, t0b as t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Test handling of COUNT bug
+SELECT t0a, (SELECT count(t1c) FROM
+ (SELECT t1c
+ FROM t1
+ WHERE t1a = t0a
+ EXCEPT DISTINCT
+ SELECT t2c
+ FROM t2
+ WHERE t2b = t0b)
+)
+FROM t0;
+
+-- Correlated references in project
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT t1a - t0a as d
+ FROM t1
+ EXCEPT DISTINCT
+ SELECT t2a - t0a as d
+ FROM t2)
+)
+FROM t0;
+
+-- Correlated references in aggregate - unsupported
+SELECT t0a, (SELECT sum(d) FROM
+ (SELECT sum(t0a) as d
+ FROM t1
+ EXCEPT DISTINCT
+ SELECT sum(t2a) + t0a as d
+ FROM t2)
+)
+FROM t0;
Review Comment:
Can we add a few more tests that combine these set operations? Something
like `<query> UNION (<query> INTERSECT <query>)`.
Also, let's try a test case with a correlated semi/anti join that does not
come from rewriting a set operation.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]