[ https://issues.apache.org/jira/browse/SPARK-26078?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16722363#comment-16722363 ]

ASF GitHub Bot commented on SPARK-26078:
----------------------------------------

asfgit closed pull request #23057: [SPARK-26078][SQL] Dedup self-join attributes on IN subqueries
URL: https://github.com/apache/spark/pull/23057

This is a PR merged from a forked repository. As GitHub hides the original
diff on merge, it is displayed below for the sake of provenance:

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index e9b7a8b76e683..34840c6c977a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -43,31 +43,53 @@ import org.apache.spark.sql.types._
  *    condition.
  */
 object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
-  private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
+
+  private def buildJoin(
+      outerPlan: LogicalPlan,
+      subplan: LogicalPlan,
+      joinType: JoinType,
+      condition: Option[Expression]): Join = {
+    // Deduplicate conflicting attributes if any.
+    val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition)
+    Join(outerPlan, dedupSubplan, joinType, condition)
+  }
+
+  private def dedupSubqueryOnSelfJoin(
+      outerPlan: LogicalPlan,
+      subplan: LogicalPlan,
+      valuesOpt: Option[Seq[Expression]],
+      condition: Option[Expression] = None): LogicalPlan = {
     // SPARK-21835: It is possible that the two sides of the join have conflicting attributes,
     // the produced join then becomes unresolved and breaks structural integrity. We should
-    // de-duplicate conflicting attributes. We don't use transformation here because we only
-    // care about the most top join converted from correlated predicate subquery.
-    case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
-      val duplicates = right.outputSet.intersect(left.outputSet)
-      if (duplicates.nonEmpty) {
-        val aliasMap = AttributeMap(duplicates.map { dup =>
-          dup -> Alias(dup, dup.toString)()
-        }.toSeq)
-        val aliasedExpressions = right.output.map { ref =>
-          aliasMap.getOrElse(ref, ref)
-        }
-        val newRight = Project(aliasedExpressions, right)
-        val newJoinCond = joinCond.map { condExpr =>
-          condExpr transform {
-            case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
+    // de-duplicate conflicting attributes.
+    // SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer
+    // values. In this case, the resulting join would contain trivially true conditions (e.g.
+    // id#3 = id#3) which cannot be de-duplicated afterwards. In this method, if there are
+    // conflicting attributes in the join condition, the subquery's conflicting attributes are
+    // changed using a projection which aliases them and resolves the problem.
+    val outerReferences = valuesOpt.map(values =>
+      AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty)
+    val outerRefs = outerPlan.outputSet ++ outerReferences
+    val duplicates = outerRefs.intersect(subplan.outputSet)
+    if (duplicates.nonEmpty) {
+      condition.foreach { e =>
+          val conflictingAttrs = e.references.intersect(duplicates)
+          if (conflictingAttrs.nonEmpty) {
+            throw new AnalysisException("Found conflicting attributes " +
+              s"${conflictingAttrs.mkString(",")} in the condition joining 
outer plan:\n  " +
+              s"$outerPlan\nand subplan:\n  $subplan")
           }
-        }
-        Join(left, newRight, joinType, newJoinCond)
-      } else {
-        j
       }
-    case _ => joinPlan
+      val rewrites = AttributeMap(duplicates.map { dup =>
+        dup -> Alias(dup, dup.toString)()
+      }.toSeq)
+      val aliasedExpressions = subplan.output.map { ref =>
+        rewrites.getOrElse(ref, ref)
+      }
+      Project(aliasedExpressions, subplan)
+    } else {
+      subplan
+    }
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -85,17 +107,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       withSubquery.foldLeft(newFilter) {
         case (p, Exists(sub, conditions, _)) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          // Deduplicate conflicting attributes if any.
-          dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+          buildJoin(outerPlan, sub, LeftSemi, joinCond)
         case (p, Not(Exists(sub, conditions, _))) =>
           val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
-          // Deduplicate conflicting attributes if any.
-          dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
+          buildJoin(outerPlan, sub, LeftAnti, joinCond)
         case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
-          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
-          val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
           // Deduplicate conflicting attributes if any.
-          dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+          val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
+          val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
+          val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
+          Join(outerPlan, newSub, LeftSemi, joinCond)
         case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
           // This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
           // Construct the condition. A NULL in one of the conditions is regarded as a positive
@@ -103,7 +124,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
 
           // Note that this will almost certainly be planned as a Broadcast Nested Loop join.
           // Use EXISTS if performance matters to you.
-          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+
+          // Deduplicate conflicting attributes if any.
+          val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
+          val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
           val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
           // Expand the NOT IN expression with the NULL-aware semantic
           // to its full form. That is from:
@@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
           // will have the final conditions in the LEFT ANTI as
           // (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
           val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
-          // Deduplicate conflicting attributes if any.
-          dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
+          Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
         case (p, predicate) =>
           val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
           Project(p.output, Filter(newCond.get, inputPlan))
@@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       e transformUp {
         case Exists(sub, conditions, _) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
-          // Deduplicate conflicting attributes if any.
-          newPlan = dedupJoin(
-            Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
+          newPlan =
+            buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
           exists
         case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
           val exists = AttributeReference("exists", BooleanType, nullable = false)()
-          val inConditions = values.zip(sub.output).map(EqualTo.tupled)
-          val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
           // Deduplicate conflicting attributes if any.
-          newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
+          val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
+          val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
+          val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
+          newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
           exists
       }
     }
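
For illustration, here is a minimal, self-contained Scala sketch of the
aliasing idea behind dedupSubqueryOnSelfJoin above. It is not Catalyst code:
Attr, freshId and dedup are hypothetical stand-ins for AttributeReference,
NamedExpression.newExprId and the Project-building logic in the patch.

{code}
object DedupSketch {
  // Attributes are identified by an expression ID, not by name; a "fake"
  // self join arises when the outer plan and the subquery share the same ID.
  final case class Attr(name: String, exprId: Long)

  private var nextId = 100L
  private def freshId(): Long = { nextId += 1; nextId }

  // Re-alias every subquery attribute whose ID collides with the outer side,
  // mirroring Alias(dup, dup.toString)() wrapped in a Project.
  def dedup(outerOutput: Set[Attr], subOutput: Seq[Attr]): Seq[Attr] =
    subOutput.map { a =>
      if (outerOutput.contains(a)) Attr(a.name, freshId()) else a
    }

  def main(args: Array[String]): Unit = {
    val outer = Set(Attr("id", 4), Attr("num", 5))
    val sub   = Seq(Attr("id", 4))  // same exprId as the outer "id"
    println(dedup(outer, sub))      // List(Attr(id,101)): conflict gone, so the
                                    // join condition can no longer degenerate
                                    // into the trivially true id#4 = id#4
  }
}
{code}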
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5088821ad7361..c95c52f1d3a9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
 
 class SubquerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -1280,4 +1281,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
       assert(subqueries.length == 1)
     }
   }
+
+  test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
+    withTempView("a", "b") {
+      Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
+      Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
+
+      val df1 = spark.sql(
+        """
+          |SELECT id,num,source FROM (
+          |  SELECT id, num, 'a' as source FROM a
+          |  UNION ALL
+          |  SELECT id, num, 'b' as source FROM b
+          |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
+        """.stripMargin)
+      checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
+      val df2 = spark.sql(
+        """
+          |SELECT id,num,source FROM (
+          |  SELECT id, num, 'a' as source FROM a
+          |  UNION ALL
+          |  SELECT id, num, 'b' as source FROM b
+          |) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
+        """.stripMargin)
+      checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
+      val df3 = spark.sql(
+        """
+          |SELECT id,num,source FROM (
+          |  SELECT id, num, 'a' as source FROM a
+          |  UNION ALL
+          |  SELECT id, num, 'b' as source FROM b
+          |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
+          |c.id IN (SELECT id FROM b WHERE num = 3)
+        """.stripMargin)
+      checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
+    }
+  }
 }
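
The NULL-aware NOT IN expansion mentioned in the optimizer change above can be
summarized with a small sketch. This is plain Scala over Option values, not the
Catalyst rewrite itself; eqNullAware and notIn are hypothetical helpers that
model SQL's three-valued comparison as it affects the LEFT ANTI join:

{code}
// A NULL on either side of the equality is regarded as a positive (matching)
// result, so the anti join filters the row out -- the NOT IN semantic.
def eqNullAware(a: Option[Int], b: Option[Int]): Boolean = (a, b) match {
  case (Some(x), Some(y)) => x == y   // both non-NULL: ordinary equality
  case _                  => true     // ISNULL(a = b): treated as a match
}

// col NOT IN (subquery), expressed as a left anti join under the condition above.
def notIn(left: Seq[Option[Int]], sub: Seq[Option[Int]]): Seq[Option[Int]] =
  left.filterNot(l => sub.exists(r => eqNullAware(l, r)))

println(notIn(Seq(Some(1), Some(2), None), Seq(Some(2))))
// List(Some(1)): 2 matches and is dropped; NULL compares as a match, so it
// is dropped too, exactly as SQL's NOT IN requires.
{code}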


 



> WHERE .. IN fails to filter rows when used in combination with UNION
> --------------------------------------------------------------------
>
>                 Key: SPARK-26078
>                 URL: https://issues.apache.org/jira/browse/SPARK-26078
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 2.3.1, 2.4.0
>            Reporter: Arttu Voutilainen
>            Assignee: Marco Gaido
>            Priority: Blocker
>              Labels: correctness
>             Fix For: 3.0.0
>
>
> Hey,
> We encountered a case where Spark SQL does not seem to handle WHERE .. IN
> correctly when used in combination with UNION: it also returns rows that do
> not fulfill the condition. Swapping the order of the datasets in the UNION
> makes the problem go away. Repro below:
>  
> {code}
> from pyspark.sql import SQLContext  # run in a pyspark shell, where sc and spark are defined
> sql = SQLContext(sc)
> a = spark.createDataFrame([{'id': 'a', 'num': 2}, {'id': 'b', 'num': 1}])
> b = spark.createDataFrame([{'id': 'a', 'num': 2}, {'id': 'b', 'num': 1}])
> a.registerTempTable('a')
> b.registerTempTable('b')
> bug = sql.sql("""
>     SELECT id,num,source FROM
>     (
>         SELECT id, num, 'a' as source FROM a
>         UNION ALL
>         SELECT id, num, 'b' as source FROM b
>     ) AS c
>     WHERE c.id IN (SELECT id FROM b WHERE num = 2)
> """)
> no_bug = sql.sql("""
>     SELECT id,num,source FROM
>     (
>         SELECT id, num, 'b' as source FROM b
>         UNION ALL
>         SELECT id, num, 'a' as source FROM a
>     ) AS c
>     WHERE c.id IN (SELECT id FROM b WHERE num = 2)
> """)
> bug.show()
> no_bug.show()
> bug.explain(True)
> no_bug.explain(True)
> {code}
> This results in one extra row in the "bug" DF, coming from DF "b", that should
> not be there, as its id is not returned by the IN subquery:
> {code:java}
> >>> bug.show()
> +---+---+------+
> | id|num|source|
> +---+---+------+
> |  a|  2|     a|
> |  a|  2|     b|
> |  b|  1|     b|
> +---+---+------+
> >>> no_bug.show()
> +---+---+------+
> | id|num|source|
> +---+---+------+
> |  a|  2|     b|
> |  a|  2|     a|
> +---+---+------+
> {code}
>  The reason can be seen in the query plans:
> {code:java}
> >>> bug.explain(True)
> ...
> == Optimized Logical Plan ==
> Union
> :- Project [id#0, num#1L, a AS source#136]
> :  +- Join LeftSemi, (id#0 = id#4)
> :     :- LogicalRDD [id#0, num#1L], false
> :     +- Project [id#4]
> :        +- Filter (isnotnull(num#5L) && (num#5L = 2))
> :           +- LogicalRDD [id#4, num#5L], false
> +- Join LeftSemi, (id#4#172 = id#4#172)
>    :- Project [id#4, num#5L, b AS source#137]
>    :  +- LogicalRDD [id#4, num#5L], false
>    +- Project [id#4 AS id#4#172]
>       +- Filter (isnotnull(num#5L) && (num#5L = 2))
>          +- LogicalRDD [id#4, num#5L], false
> {code}
> Note the line *+- Join LeftSemi, (id#4#172 = id#4#172)* - this condition 
> seems wrong, and I believe it causes the LeftSemi to return true for all rows 
> in the left-hand-side table, thus failing to filter as the WHERE .. IN 
> should. Compare with the non-buggy version, where both LeftSemi joins have
> distinct expression IDs (the #-suffixes) on each side:
> {code:java}
> >>> no_bug.explain()
> ...
> == Optimized Logical Plan ==
> Union
> :- Project [id#4, num#5L, b AS source#142]
> :  +- Join LeftSemi, (id#4 = id#4#173)
> :     :- LogicalRDD [id#4, num#5L], false
> :     +- Project [id#4 AS id#4#173]
> :        +- Filter (isnotnull(num#5L) && (num#5L = 2))
> :           +- LogicalRDD [id#4, num#5L], false
> +- Project [id#0, num#1L, a AS source#143]
>    +- Join LeftSemi, (id#0 = id#4#173)
>       :- LogicalRDD [id#0, num#1L], false
>       +- Project [id#4 AS id#4#173]
>          +- Filter (isnotnull(num#5L) && (num#5L = 2))
>             +- LogicalRDD [id#4, num#5L], false
> {code}
>  
> Best,
> -Arttu 
>  
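
The reporter's diagnosis is easy to verify in miniature. Below is a hedged,
plain-Scala model (not Spark code; leftSemi is a hypothetical helper) of a left
semi join, showing that a condition whose two sides are the same attribute is
trivially true and therefore stops filtering anything:

{code}
def leftSemi[L, R](left: Seq[L], right: Seq[R])(cond: (L, R) => Boolean): Seq[L] =
  left.filter(l => right.exists(r => cond(l, r)))

val rows   = Seq(("a", 2), ("b", 1))  // one UNION branch: (id, num)
val subIds = Seq("a")                 // SELECT id FROM b WHERE num = 2

// Correct condition: the left id is compared with the subquery's id.
println(leftSemi(rows, subIds) { case ((id, _), rid) => id == rid })
// List((a,2))

// Degenerate condition (id#4#172 = id#4#172): both sides resolve to the same
// attribute, so it is always true and every left row survives -- the bug.
println(leftSemi(rows, subIds) { case ((id, _), _) => id == id })
// List((a,2), (b,1))
{code}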


