jchen5 commented on code in PR #41301:
URL: https://github.com/apache/spark/pull/41301#discussion_r1214962726
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
(d.copy(child = newChild), joinCond, outerReferenceMap)
case j @ Join(left, right, joinType, condition, _) =>
- val outerReferences = collectOuterReferences(j.expressions)
- // Join condition containing outer references is not supported.
- assert(outerReferences.isEmpty, s"Correlated column is not allowed
in join: $j")
- val newOuterReferences = parentOuterReferences ++ outerReferences
- val shouldPushToLeft = joinType match {
+ def splitCorrelatedPredicate(condition: Option[Expression],
+ isInnerJoin: Boolean,
+ shouldDecorrelatePredicates: Boolean):
+ (Seq[Expression], Seq[Expression], Seq[Expression],
+ Seq[Expression], AttributeMap[Attribute]) = {
+ // Similar to Filters above, we split the join condition (if
present) into correlated
+ // and uncorrelated predicates, and separately handle joins
under set and aggregation
+ // operations.
+ if (shouldDecorrelatePredicates) {
+ val conditions =
+ if (condition.isDefined)
splitConjunctivePredicates(condition.get)
+ else Seq.empty[Expression]
+ val (correlated, uncorrelated) =
conditions.partition(containsOuter)
+ val equivalences =
+ if (underSetOp) AttributeMap.empty[Attribute]
+ else collectEquivalentOuterReferences(correlated)
+ var (equalityCond, predicates) =
+ if (underSetOp) (Seq.empty[Expression], correlated)
+ else correlated.partition(canPullUpOverAgg)
Review Comment:
I think we can do
```
else if (aggregated) collectEquivalentOuterReferences(correlated)
else correlated
```
so that, if the result is not aggregated, we can use all the correlated
predicates directly without adding a DomainJoin. This is equivalent to the Filter logic around:
```
// Results of this sub-tree is not aggregated, so all correlated predicates
// can be directly used as outer query join conditions.
```
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
(d.copy(child = newChild), joinCond, outerReferenceMap)
case j @ Join(left, right, joinType, condition, _) =>
- val outerReferences = collectOuterReferences(j.expressions)
- // Join condition containing outer references is not supported.
- assert(outerReferences.isEmpty, s"Correlated column is not allowed
in join: $j")
- val newOuterReferences = parentOuterReferences ++ outerReferences
- val shouldPushToLeft = joinType match {
+ def splitCorrelatedPredicate(condition: Option[Expression],
+ isInnerJoin: Boolean,
+ shouldDecorrelatePredicates: Boolean):
+ (Seq[Expression], Seq[Expression], Seq[Expression],
+ Seq[Expression], AttributeMap[Attribute]) = {
+ // Similar to Filters above, we split the join condition (if
present) into correlated
+ // and uncorrelated predicates, and separately handle joins
under set and aggregation
+ // operations.
+ if (shouldDecorrelatePredicates) {
+ val conditions =
+ if (condition.isDefined)
splitConjunctivePredicates(condition.get)
+ else Seq.empty[Expression]
+ val (correlated, uncorrelated) =
conditions.partition(containsOuter)
+ val equivalences =
+ if (underSetOp) AttributeMap.empty[Attribute]
+ else collectEquivalentOuterReferences(correlated)
+ var (equalityCond, predicates) =
+ if (underSetOp) (Seq.empty[Expression], correlated)
+ else correlated.partition(canPullUpOverAgg)
+ // Fully preserve the join predicate for non-inner joins.
+ if (!isInnerJoin) {
+ predicates = predicates ++ equalityCond
Review Comment:
It might be slightly better to set `predicates = correlated` here, to preserve
the ordering of the original predicates.
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
(d.copy(child = newChild), joinCond, outerReferenceMap)
case j @ Join(left, right, joinType, condition, _) =>
- val outerReferences = collectOuterReferences(j.expressions)
- // Join condition containing outer references is not supported.
- assert(outerReferences.isEmpty, s"Correlated column is not allowed
in join: $j")
- val newOuterReferences = parentOuterReferences ++ outerReferences
- val shouldPushToLeft = joinType match {
+ def splitCorrelatedPredicate(condition: Option[Expression],
+ isInnerJoin: Boolean,
+ shouldDecorrelatePredicates: Boolean):
+ (Seq[Expression], Seq[Expression], Seq[Expression],
+ Seq[Expression], AttributeMap[Attribute]) = {
Review Comment:
Can you add a comment explaining that the result tuple is (correlated, uncorrelated,
equalityCond, predicates, equivalences), ideally with a brief description of each element?
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala:
##########
@@ -804,18 +804,67 @@ object DecorrelateInnerQuery extends PredicateHelper {
(d.copy(child = newChild), joinCond, outerReferenceMap)
case j @ Join(left, right, joinType, condition, _) =>
- val outerReferences = collectOuterReferences(j.expressions)
- // Join condition containing outer references is not supported.
- assert(outerReferences.isEmpty, s"Correlated column is not allowed
in join: $j")
- val newOuterReferences = parentOuterReferences ++ outerReferences
- val shouldPushToLeft = joinType match {
+ def splitCorrelatedPredicate(condition: Option[Expression],
+ isInnerJoin: Boolean,
+ shouldDecorrelatePredicates: Boolean):
+ (Seq[Expression], Seq[Expression], Seq[Expression],
+ Seq[Expression], AttributeMap[Attribute]) = {
+ // Similar to Filters above, we split the join condition (if
present) into correlated
+ // and uncorrelated predicates, and separately handle joins
under set and aggregation
+ // operations.
+ if (shouldDecorrelatePredicates) {
+ val conditions =
+ if (condition.isDefined)
splitConjunctivePredicates(condition.get)
+ else Seq.empty[Expression]
+ val (correlated, uncorrelated) =
conditions.partition(containsOuter)
+ val equivalences =
+ if (underSetOp) AttributeMap.empty[Attribute]
+ else collectEquivalentOuterReferences(correlated)
+ var (equalityCond, predicates) =
+ if (underSetOp) (Seq.empty[Expression], correlated)
+ else correlated.partition(canPullUpOverAgg)
+ // Fully preserve the join predicate for non-inner joins.
+ if (!isInnerJoin) {
+ predicates = predicates ++ equalityCond
+ }
+ (correlated, uncorrelated, equalityCond, predicates,
equivalences)
+ } else {
+ (Seq.empty[Expression],
+ if (condition.isEmpty) Seq.empty[Expression] else
Seq(condition.get),
+ Seq.empty[Expression],
+ Seq.empty[Expression],
+ AttributeMap.empty[Attribute])
+ }
+ }
+
+ val shouldDecorrelatePredicates =
+ SQLConf.get.getConf(SQLConf.DECORRELATE_JOIN_PREDICATE_ENABLED)
+ if (!shouldDecorrelatePredicates) {
+ val outerReferences = collectOuterReferences(j.expressions)
+ // Join condition containing outer references is not supported.
+ assert(outerReferences.isEmpty, s"Correlated column is not
allowed in join: $j")
+ }
+ val (correlated, uncorrelated, equalityCond, predicates,
equivalences) =
+ splitCorrelatedPredicate(condition, joinType == Inner,
shouldDecorrelatePredicates)
+ val outerReferences = collectOuterReferences(j.expressions) ++
+ collectOuterReferences(predicates)
+ val newOuterReferences =
+ parentOuterReferences ++ outerReferences -- equivalences.keySet
+ var shouldPushToLeft = joinType match {
case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
case _ => hasOuterReferences(left)
}
val shouldPushToRight = joinType match {
case RightOuter | FullOuter => true
case _ => hasOuterReferences(right)
}
+ if (shouldDecorrelatePredicates && !shouldPushToLeft &&
!shouldPushToRight
+ && !correlated.isEmpty) {
Review Comment:
Can we change this check from `correlated.isEmpty` to `predicates.isEmpty`? That is,
if all the correlated predicates are in `equalityCond` and can be used directly
as join conditions, then I think we shouldn't need to add a DomainJoin.
##########
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala:
##########
@@ -454,4 +460,85 @@ class DecorrelateInnerQuerySuite extends PlanTest {
DomainJoin(Seq(x), testRelation))))
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x))
}
+
+ test("SPARK-43780: aggregation in subquery with correlated equi-join") {
+ // Join in the subquery is on equi-predicates, so all the correlated
references can be
+ // substituted by equivalent ones from the outer query, and domain join is
not needed.
+ val outerPlan = testRelation
+ val innerPlan =
+ Aggregate(
+ Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+ Project(Seq(x, y, a3, b3),
+ Join(testRelation2, testRelation3, Inner,
+ Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))
+
+ val correctAnswer =
+ Aggregate(
+ Seq(y), Seq(Alias(count(Literal(1)), "a")(), y),
+ Project(Seq(x, y, a3, b3),
+ Join(testRelation2, testRelation3, Inner, Some(x === a3),
JoinHint.NONE)))
+ check(innerPlan, outerPlan, correctAnswer, Seq(y === a))
+ }
+
+ test("SPARK-43780: aggregation in subquery with correlated non-equi-join") {
+ // Join in the subquery is on non-equi-predicate, so we introduce a
DomainJoin.
+ val outerPlan = testRelation
+ val innerPlan =
+ Aggregate(
+ Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+ Project(Seq(x, y, a3, b3),
+ Join(testRelation2, testRelation3, Inner,
+ Some(And(x === a3, y > OuterReference(a))), JoinHint.NONE)))
+ val correctAnswer =
+ Aggregate(
+ Seq(a), Seq(Alias(count(Literal(1)), "a")(), a),
+ Project(Seq(x, y, a3, b3, a),
+ Join(
+ DomainJoin(Seq(a), testRelation2),
+ testRelation3, Inner, Some(And(x === a3, y > a)), JoinHint.NONE)))
+ check(innerPlan, outerPlan, correctAnswer, Seq(a <=> a))
+ }
+
+ test("SPARK-43780: aggregation in subquery with correlated left join") {
+ // Join in the subquery is on equi-predicates, so all the correlated
references can be
+ // substituted by equivalent ones from the outer query, and domain join is
not needed.
+ val outerPlan = testRelation
+ val innerPlan =
+ Aggregate(
+ Seq.empty[Expression], Seq(Alias(count(Literal(1)), "a")()),
+ Project(Seq(x, y, a3, b3),
+ Join(testRelation2, testRelation3, LeftOuter,
+ Some(And(x === a3, y === OuterReference(a))), JoinHint.NONE)))
Review Comment:
Can we also add a test with a left join where the correlated predicate
involves the right (outer) side? In that case we would need a DomainJoin, right?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]