This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 5ebf793 [SPARK-38206][SS] Ignore nullability on comparing the data type of join keys on stream-stream join 5ebf793 is described below commit 5ebf7938b6882d343a6aa9e125f24bee394bb25f Author: Jungtaek Lim <kabhwan.opensou...@gmail.com> AuthorDate: Tue Feb 22 15:58:07 2022 +0900 [SPARK-38206][SS] Ignore nullability on comparing the data type of join keys on stream-stream join ### What changes were proposed in this pull request? This PR proposes to change the data type assertion on the join keys of a stream-stream join so that it ignores nullability. ### Why are the changes needed? The existing requirement on checking data types of joining keys is too restrictive, as it also requires the same nullability. In batch queries (I checked with HashJoinExec), nullability is ignored when checking data types of joining keys. ### Does this PR introduce _any_ user-facing change? Yes, end users will no longer encounter the assertion error when the join keys on the two sides differ only in nullability. ### How was this patch tested? New tests added. Closes #35599 from HeartSaVioR/SPARK-38206. 
Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../streaming/StreamingSymmetricHashJoinExec.scala | 8 +- .../spark/sql/streaming/StreamingJoinSuite.scala | 158 +++++++++++++++++++++ 2 files changed, 165 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index adb84a3..81888e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -174,7 +174,13 @@ case class StreamingSymmetricHashJoinExec( joinType == Inner || joinType == LeftOuter || joinType == RightOuter || joinType == FullOuter || joinType == LeftSemi, errorMessageForJoinType) - require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) + + // The assertion against join keys is same as hash join for batch query. 
+ require(leftKeys.length == rightKeys.length && + leftKeys.map(_.dataType) + .zip(rightKeys.map(_.dataType)) + .forall(types => types._1.sameType(types._2)), + "Join keys from two sides should have same length and types") private val storeConf = new StateStoreConf(conf) private val hadoopConfBcast = sparkContext.broadcast( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index e0926ef..2fbe6c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.lang.{Integer => JInteger} import java.sql.Timestamp import java.util.{Locale, UUID} @@ -702,6 +703,53 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { total = Seq(2), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(0))) ) } + + test("joining non-nullable left join key with nullable right join key") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[JInteger] + + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined)( + AddData(input1, 1, 5), + AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } + + test("joining nullable left join key with non-nullable right join key") { + val input1 = MemoryStream[JInteger] + val input2 = MemoryStream[Int] + + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + AddData(input2, 1, 5), + CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15)) + ) + } + + test("joining nullable left join key with nullable right join key") { + val input1 = MemoryStream[JInteger] 
+ val input2 = MemoryStream[JInteger] + + val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF()) + testStream(joined)( + AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), JInteger.valueOf(10), null), + AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), null), + CheckNewAnswer( + Row(JInteger.valueOf(1), JInteger.valueOf(1), JInteger.valueOf(2), JInteger.valueOf(3)), + Row(JInteger.valueOf(5), JInteger.valueOf(5), JInteger.valueOf(10), JInteger.valueOf(15)), + Row(null, null, null, null)) + ) + } + + private def testForJoinKeyNullability(left: DataFrame, right: DataFrame): DataFrame = { + val df1 = left.selectExpr("value as leftKey", "value * 2 as leftValue") + val df2 = right.selectExpr("value as rightKey", "value * 3 as rightValue") + + df1.join(df2, expr("leftKey <=> rightKey")) + .select("leftKey", "rightKey", "leftValue", "rightValue") + } } @@ -1168,6 +1216,116 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { CheckNewAnswer(expectedOutput.head, expectedOutput.tail: _*) ) } + + test("left-outer: joining non-nullable left join key with nullable right join key") { + val input1 = MemoryStream[(Int, Int)] + val input2 = MemoryStream[(JInteger, Int)] + + val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), input2.toDF()) + + testStream(joined)( + AddData(input1, (1, 1), (1, 2), (1, 3), (1, 4), (1, 5)), + AddData(input2, + (JInteger.valueOf(1), 3), + (JInteger.valueOf(1), 4), + (JInteger.valueOf(1), 5), + (JInteger.valueOf(1), 6) + ), + CheckNewAnswer( + Row(1, 1, 3, 3, 10, 6, 9), + Row(1, 1, 4, 4, 10, 8, 12), + Row(1, 1, 5, 5, 10, 10, 15)), + AddData(input1, (1, 21)), + // right-null join + AddData(input2, (JInteger.valueOf(1), 22)), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer( + Row(1, null, 1, null, 10, 2, null), + Row(1, null, 2, null, 10, 4, null) + ) + ) + } + + test("left-outer: joining nullable left join key with non-nullable right join key") { + val input1 = MemoryStream[(JInteger, 
Int)] + val input2 = MemoryStream[(Int, Int)] + + val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), input2.toDF()) + + testStream(joined)( + AddData(input1, + (JInteger.valueOf(1), 1), + (null, 2), + (JInteger.valueOf(1), 3), + (JInteger.valueOf(1), 4), + (JInteger.valueOf(1), 5)), + AddData(input2, (1, 3), (1, 4), (1, 5), (1, 6)), + CheckNewAnswer( + Row(1, 1, 3, 3, 10, 6, 9), + Row(1, 1, 4, 4, 10, 8, 12), + Row(1, 1, 5, 5, 10, 10, 15)), + // right-null join + AddData(input1, (JInteger.valueOf(1), 21)), + AddData(input2, (1, 22)), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer( + Row(1, null, 1, null, 10, 2, null), + Row(null, null, 2, null, 10, 4, null) + ) + ) + } + + test("left-outer: joining nullable left join key with nullable right join key") { + val input1 = MemoryStream[(JInteger, Int)] + val input2 = MemoryStream[(JInteger, Int)] + + val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), input2.toDF()) + + testStream(joined)( + AddData(input1, + (JInteger.valueOf(1), 1), + (null, 2), + (JInteger.valueOf(1), 3), + (null, 4), + (JInteger.valueOf(1), 5)), + AddData(input2, + (JInteger.valueOf(1), 3), + (null, 4), + (JInteger.valueOf(1), 5), + (JInteger.valueOf(1), 6)), + CheckNewAnswer( + Row(1, 1, 3, 3, 10, 6, 9), + Row(null, null, 4, 4, 10, 8, 12), + Row(1, 1, 5, 5, 10, 10, 15)), + // right-null join + AddData(input1, (JInteger.valueOf(1), 21)), + AddData(input2, (JInteger.valueOf(1), 22)), // watermark = 11, no-data-batch computes nulls + CheckNewAnswer( + Row(1, null, 1, null, 10, 2, null), + Row(null, null, 2, null, 10, 4, null) + ) + ) + } + + private def testForLeftOuterJoinKeyNullability(left: DataFrame, right: DataFrame): DataFrame = { + val df1 = left + .selectExpr("_1 as leftKey1", "_2 as leftKey2", "timestamp_seconds(_2) as leftTime", + "_2 * 2 as leftValue") + .withWatermark("leftTime", "10 seconds") + val df2 = right + .selectExpr( + "_1 as rightKey1", "_2 as rightKey2", "timestamp_seconds(_2) as 
rightTime", + "_2 * 3 as rightValue") + .withWatermark("rightTime", "10 seconds") + + val windowed1 = df1.select('leftKey1, 'leftKey2, + window('leftTime, "10 second").as('leftWindow), 'leftValue) + val windowed2 = df2.select('rightKey1, 'rightKey2, + window('rightTime, "10 second").as('rightWindow), 'rightValue) + windowed1.join(windowed2, + expr("leftKey1 <=> rightKey1 AND leftKey2 = rightKey2 AND leftWindow = rightWindow"), + "left_outer" + ).select('leftKey1, 'rightKey1, 'leftKey2, 'rightKey2, $"leftWindow.end".cast("long"), + 'leftValue, 'rightValue) + } } class StreamingFullOuterJoinSuite extends StreamingJoinSuite { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org