This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cd815ae [SPARK-26078][SQL] Dedup self-join attributes on IN subqueries
cd815ae is described below
commit cd815ae6c5ce3edb8aec3add942549f76a20e586
Author: Marco Gaido <[email protected]>
AuthorDate: Sun Dec 16 10:57:11 2018 +0800
[SPARK-26078][SQL] Dedup self-join attributes on IN subqueries
## What changes were proposed in this pull request?
When an IN subquery results in a self-join, the join condition may be
invalid, producing trivially true predicates and returning wrong results.
The PR deduplicates the subquery output in order to avoid the issue.
## How was this patch tested?
Added a unit test (UT) to SubquerySuite.
Closes #23057 from mgaido91/SPARK-26078.
Authored-by: Marco Gaido <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/optimizer/subquery.scala | 99 +++++++++++++---------
.../scala/org/apache/spark/sql/SubquerySuite.scala | 37 ++++++++
2 files changed, 98 insertions(+), 38 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index e9b7a8b..34840c6 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -43,31 +43,53 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper
{
- private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
+
+ private def buildJoin(
+ outerPlan: LogicalPlan,
+ subplan: LogicalPlan,
+ joinType: JoinType,
+ condition: Option[Expression]): Join = {
+ // Deduplicate conflicting attributes if any.
+ val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None,
condition)
+ Join(outerPlan, dedupSubplan, joinType, condition)
+ }
+
+ private def dedupSubqueryOnSelfJoin(
+ outerPlan: LogicalPlan,
+ subplan: LogicalPlan,
+ valuesOpt: Option[Seq[Expression]],
+ condition: Option[Expression] = None): LogicalPlan = {
// SPARK-21835: It is possibly that the two sides of the join have
conflicting attributes,
// the produced join then becomes unresolved and break structural
integrity. We should
- // de-duplicate conflicting attributes. We don't use transformation here
because we only
- // care about the most top join converted from correlated predicate
subquery.
- case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti |
ExistenceJoin(_)), joinCond) =>
- val duplicates = right.outputSet.intersect(left.outputSet)
- if (duplicates.nonEmpty) {
- val aliasMap = AttributeMap(duplicates.map { dup =>
- dup -> Alias(dup, dup.toString)()
- }.toSeq)
- val aliasedExpressions = right.output.map { ref =>
- aliasMap.getOrElse(ref, ref)
- }
- val newRight = Project(aliasedExpressions, right)
- val newJoinCond = joinCond.map { condExpr =>
- condExpr transform {
- case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
+ // de-duplicate conflicting attributes.
+ // SPARK-26078: it may also happen that the subquery has conflicting
attributes with the outer
+ // values. In this case, the resulting join would contain trivially true
conditions (eg.
+ // id#3 = id#3) which cannot be de-duplicated after. In this method, if
there are conflicting
+ // attributes in the join condition, the subquery's conflicting attributes
are changed using
+ // a projection which aliases them and resolves the problem.
+ val outerReferences = valuesOpt.map(values =>
+
AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty)
+ val outerRefs = outerPlan.outputSet ++ outerReferences
+ val duplicates = outerRefs.intersect(subplan.outputSet)
+ if (duplicates.nonEmpty) {
+ condition.foreach { e =>
+ val conflictingAttrs = e.references.intersect(duplicates)
+ if (conflictingAttrs.nonEmpty) {
+ throw new AnalysisException("Found conflicting attributes " +
+ s"${conflictingAttrs.mkString(",")} in the condition joining
outer plan:\n " +
+ s"$outerPlan\nand subplan:\n $subplan")
}
- }
- Join(left, newRight, joinType, newJoinCond)
- } else {
- j
}
- case _ => joinPlan
+ val rewrites = AttributeMap(duplicates.map { dup =>
+ dup -> Alias(dup, dup.toString)()
+ }.toSeq)
+ val aliasedExpressions = subplan.output.map { ref =>
+ rewrites.getOrElse(ref, ref)
+ }
+ Project(aliasedExpressions, subplan)
+ } else {
+ subplan
+ }
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -85,17 +107,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
- // Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+ buildJoin(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
- // Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
+ buildJoin(outerPlan, sub, LeftAnti, joinCond)
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
- val inConditions = values.zip(sub.output).map(EqualTo.tupled)
- val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++
conditions, p)
// Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+ val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
+ val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++
conditions, p)
+ Join(outerPlan, newSub, LeftSemi, joinCond)
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is
regarded as a positive
@@ -103,7 +124,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
// Note that will almost certainly be planned as a Broadcast Nested
Loop join.
// Use EXISTS if performance matters to you.
- val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+
+ // Deduplicate conflicting attributes if any.
+ val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
+ val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
@@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
// will have the final conditions in the LEFT ANTI as
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 >
1
val finalJoinCond = (nullAwareJoinConds ++
conditions).reduceLeft(And)
- // Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
+ Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
@@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan]
with PredicateHelper {
e transformUp {
case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable =
false)()
- // Deduplicate conflicting attributes if any.
- newPlan = dedupJoin(
- Join(newPlan, sub, ExistenceJoin(exists),
conditions.reduceLeftOption(And)))
+ newPlan =
+ buildJoin(newPlan, sub, ExistenceJoin(exists),
conditions.reduceLeftOption(And))
exists
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
val exists = AttributeReference("exists", BooleanType, nullable =
false)()
- val inConditions = values.zip(sub.output).map(EqualTo.tupled)
- val newConditions = (inConditions ++
conditions).reduceLeftOption(And)
// Deduplicate conflicting attributes if any.
- newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists),
newConditions))
+ val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
+ val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
+ val newConditions = (inConditions ++
conditions).reduceLeftOption(And)
+ newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
exists
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5088821..c95c52f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -1280,4 +1281,40 @@ class SubquerySuite extends QueryTest with
SharedSQLContext {
assert(subqueries.length == 1)
}
}
+
+ test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
+ withTempView("a", "b") {
+ Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
+ Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
+
+ val df1 = spark.sql(
+ """
+ |SELECT id,num,source FROM (
+ | SELECT id, num, 'a' as source FROM a
+ | UNION ALL
+ | SELECT id, num, 'b' as source FROM b
+ |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
+ """.stripMargin)
+ checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
+ val df2 = spark.sql(
+ """
+ |SELECT id,num,source FROM (
+ | SELECT id, num, 'a' as source FROM a
+ | UNION ALL
+ | SELECT id, num, 'b' as source FROM b
+ |) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
+ """.stripMargin)
+ checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
+ val df3 = spark.sql(
+ """
+ |SELECT id,num,source FROM (
+ | SELECT id, num, 'a' as source FROM a
+ | UNION ALL
+ | SELECT id, num, 'b' as source FROM b
+ |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
+ |c.id IN (SELECT id FROM b WHERE num = 3)
+ """.stripMargin)
+ checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]