This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 6b398a4  [SPARK-35652][SQL] joinWith on two tables generated from the same one
6b398a4 is described below

commit 6b398a4d832d78e2d0caee5acf22eac961024ee3
Author: dgd-contributor <dgd_contribu...@viettel.com.vn>
AuthorDate: Fri Jun 11 20:36:50 2021 +0800

    [SPARK-35652][SQL] joinWith on two tables generated from the same one
    
    It seems that Spark's inner join performs a cartesian join when self-joining via `joinWith`.
    
    To reproduce this issue:
    ```
    val df = spark.range(0,3)
    df.joinWith(df, df("id") === df("id")).show()
    ```
    
    Before this pull request, the result is:
    ```
    +---+---+
    | _1| _2|
    +---+---+
    |  0|  0|
    |  0|  1|
    |  0|  2|
    |  1|  0|
    |  1|  1|
    |  1|  2|
    |  2|  0|
    |  2|  1|
    |  2|  2|
    +---+---+
    ```
    
    The expected result is:
    ```
    +---+---+
    | _1| _2|
    +---+---+
    |  0|  0|
    |  1|  1|
    |  2|  2|
    +---+---+
    ```
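    
    The root cause: both `df("id")` references point to the same attribute, so the
    join condition is trivially true and the inner join degenerates into a cartesian
    product. This patch reuses the logic that `join` already applies, re-resolving
    such predicates against the left and right join children. For unpatched versions,
    a workaround sketch (the aliases `l` and `r` are illustrative, not part of this
    patch) is to disambiguate the condition manually:
    ```
    import org.apache.spark.sql.functions.col

    val df = spark.range(0, 3)
    // Explicit aliases let each side of the condition resolve against a
    // distinct plan, avoiding the trivially true predicate.
    df.as("l").joinWith(df.as("r"), col("l.id") === col("r.id")).show()
    ```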
    Why are the changes needed?
    Correctness.
    
    Does this PR introduce any user-facing change?
    No.
    
    How was this patch tested?
    Added a test.
    
    Closes #32863 from dgd-contributor/SPARK-35652_join_and_joinWith_in_seft_joining.
    
    Authored-by: dgd-contributor <dgd_contribu...@viettel.com.vn>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 6e1aa15679b5fed249c62b2340151a0299401b18)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 50 ++++++++++++++--------
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 18 ++++++++
 2 files changed, 49 insertions(+), 19 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1c76f4c..5d83d1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1042,6 +1042,30 @@ class Dataset[T] private[sql](
   def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner")
 
   /**
+   * Finds the trivially true predicates and automatically resolves them to both sides.
+   */
+  private def resolveSelfJoinCondition(plan: Join): Join = {
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val cond = plan.condition.map { _.transform {
+      case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
+        if a.sameRef(b) =>
+        catalyst.expressions.EqualTo(
+          plan.left.resolveQuoted(a.name, resolver)
+            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
+          plan.right.resolveQuoted(b.name, resolver)
+            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
+      case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
+        if a.sameRef(b) =>
+        catalyst.expressions.EqualNullSafe(
+          plan.left.resolveQuoted(a.name, resolver)
+            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
+          plan.right.resolveQuoted(b.name, resolver)
+            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
+    }}
+    plan.copy(condition = cond)
+  }
+
+  /**
   * Join with another `DataFrame`, using the given join expression. The following performs
    * a full outer join between `df1` and `df2`.
    *
@@ -1095,26 +1119,9 @@ class Dataset[T] private[sql](
     // Otherwise, find the trivially true predicates and automatically resolves them to both sides.
     // By the time we get here, since we have already run analysis, all attributes should've been
     // resolved and become AttributeReference.
-    val resolver = sparkSession.sessionState.analyzer.resolver
-    val cond = plan.condition.map { _.transform {
-      case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference)
-          if a.sameRef(b) =>
-        catalyst.expressions.EqualTo(
-          plan.left.resolveQuoted(a.name, resolver)
-            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
-          plan.right.resolveQuoted(b.name, resolver)
-            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
-      case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference)
-        if a.sameRef(b) =>
-        catalyst.expressions.EqualNullSafe(
-          plan.left.resolveQuoted(a.name, resolver)
-            .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)),
-          plan.right.resolveQuoted(b.name, resolver)
-            .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames)))
-    }}
 
     withPlan {
-      plan.copy(condition = cond)
+      resolveSelfJoinCondition(plan)
     }
   }
 
@@ -1156,7 +1163,7 @@ class Dataset[T] private[sql](
   def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = {
     // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved,
     // etc.
-    val joined = sparkSession.sessionState.executePlan(
+    var joined = sparkSession.sessionState.executePlan(
       Join(
         this.logicalPlan,
         other.logicalPlan,
@@ -1168,6 +1175,11 @@ class Dataset[T] private[sql](
       throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql)
     }
 
+    // If auto self join alias is enabled
+    if (sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) {
+      joined = resolveSelfJoinCondition(joined)
+    }
+
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 69fbb9b..1b8bb3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1980,6 +1980,24 @@ class DatasetSuite extends QueryTest
 
     checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil)
   }
+
+  test("SPARK-35652: joinWith on two table generated from same one performing 
a cartesian join," +
+    " which should be inner join") {
+    val df = Seq(1, 2, 3).toDS()
+
+    val joined = df.joinWith(df, df("value") === df("value"), "inner")
+
+    val expectedSchema = StructType(Seq(
+      StructField("_1", IntegerType, nullable = false),
+      StructField("_2", IntegerType, nullable = false)
+    ))
+
+    assert(joined.schema === expectedSchema)
+
+    checkDataset(
+      joined,
+      (1, 1), (2, 2), (3, 3))
+  }
 }
 
 case class Bar(a: Int)
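
A minimal end-to-end sketch of the patched behavior (assumptions: a local
SparkSession named `spark`; the rewrite only fires when the
`dataFrameSelfJoinAutoResolveAmbiguity` flag checked above, exposed as
`spark.sql.selfJoinAutoResolveAmbiguity`, is enabled, which is the default):

```
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("demo").master("local[*]").getOrCreate()
val df = spark.range(0, 3)

// With this patch, the trivially true condition df("id") === df("id") is
// re-resolved against each join side, so the inner join keys on
// left.id = right.id instead of producing a 9-row cartesian product.
df.joinWith(df, df("id") === df("id")).show()
// +---+---+
// | _1| _2|
// +---+---+
// |  0|  0|
// |  1|  1|
// |  2|  2|
// +---+---+
```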

