This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch branch-3.1 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push: new 6b398a4 [SPARK-35652][SQL] joinWith on two table generated from same one 6b398a4 is described below commit 6b398a4d832d78e2d0caee5acf22eac961024ee3 Author: dgd-contributor <dgd_contribu...@viettel.com.vn> AuthorDate: Fri Jun 11 20:36:50 2021 +0800 [SPARK-35652][SQL] joinWith on two table generated from same one It seems like spark inner join is performing a cartesian join in self joining using `joinWith` To reproduce this issue: ``` val df = spark.range(0,3) df.joinWith(df, df("id") === df("id")).show() ``` Before this pull request, the result is +---+---+ | _1 | _2 | +---+---+ | 0 | 0 | | 0 | 1 | | 0 | 2 | | 1 | 0 | | 1 | 1 | | 1 | 2 | | 2 | 0 | | 2 | 1 | | 2 | 2 | +---+---+ The expected result is +---+---+ | _1 | _2 | +---+---+ | 0 | 0 | | 1 | 1 | | 2 | 2 | +---+---+ correctness no add test Closes #32863 from dgd-contributor/SPARK-35652_join_and_joinWith_in_seft_joining. Authored-by: dgd-contributor <dgd_contribu...@viettel.com.vn> Signed-off-by: Wenchen Fan <wenc...@databricks.com> (cherry picked from commit 6e1aa15679b5fed249c62b2340151a0299401b18) Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../main/scala/org/apache/spark/sql/Dataset.scala | 50 ++++++++++++++-------- .../scala/org/apache/spark/sql/DatasetSuite.scala | 18 ++++++++ 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1c76f4c..5d83d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1042,6 +1042,30 @@ class Dataset[T] private[sql]( def join(right: Dataset[_], joinExprs: Column): DataFrame = join(right, joinExprs, "inner") /** + * find the trivially true predicates and automatically resolves them to both sides. 
+ */ + private def resolveSelfJoinCondition(plan: Join): Join = { + val resolver = sparkSession.sessionState.analyzer.resolver + val cond = plan.condition.map { _.transform { + case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualTo( + plan.left.resolveQuoted(a.name, resolver) + .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)), + plan.right.resolveQuoted(b.name, resolver) + .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames))) + case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) + if a.sameRef(b) => + catalyst.expressions.EqualNullSafe( + plan.left.resolveQuoted(a.name, resolver) + .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)), + plan.right.resolveQuoted(b.name, resolver) + .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames))) + }} + plan.copy(condition = cond) + } + + /** * Join with another `DataFrame`, using the given join expression. The following performs * a full outer join between `df1` and `df2`. * @@ -1095,26 +1119,9 @@ class Dataset[T] private[sql]( // Otherwise, find the trivially true predicates and automatically resolves them to both sides. // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. 
- val resolver = sparkSession.sessionState.analyzer.resolver - val cond = plan.condition.map { _.transform { - case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) - if a.sameRef(b) => - catalyst.expressions.EqualTo( - plan.left.resolveQuoted(a.name, resolver) - .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)), - plan.right.resolveQuoted(b.name, resolver) - .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames))) - case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) - if a.sameRef(b) => - catalyst.expressions.EqualNullSafe( - plan.left.resolveQuoted(a.name, resolver) - .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)), - plan.right.resolveQuoted(b.name, resolver) - .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames))) - }} withPlan { - plan.copy(condition = cond) + resolveSelfJoinCondition(plan) } } @@ -1156,7 +1163,7 @@ class Dataset[T] private[sql]( def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, // etc. 
- val joined = sparkSession.sessionState.executePlan( + var joined = sparkSession.sessionState.executePlan( Join( this.logicalPlan, other.logicalPlan, @@ -1168,6 +1175,11 @@ class Dataset[T] private[sql]( throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql) } + // If auto self join alias is enable + if (sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity) { + joined = resolveSelfJoinCondition(joined) + } + implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 69fbb9b..1b8bb3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1980,6 +1980,24 @@ class DatasetSuite extends QueryTest checkAnswer(withUDF, Row(Row(1), null, null) :: Row(Row(1), null, null) :: Nil) } + + test("SPARK-35652: joinWith on two table generated from same one performing a cartesian join," + + " which should be inner join") { + val df = Seq(1, 2, 3).toDS() + + val joined = df.joinWith(df, df("value") === df("value"), "inner") + + val expectedSchema = StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", IntegerType, nullable = false) + )) + + assert(joined.schema === expectedSchema) + + checkDataset( + joined, + (1, 1), (2, 2), (3, 3)) + } } case class Bar(a: Int) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org