This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ebf793  [SPARK-38206][SS] Ignore nullability on comparing the data 
type of join keys on stream-stream join
5ebf793 is described below

commit 5ebf7938b6882d343a6aa9e125f24bee394bb25f
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Tue Feb 22 15:58:07 2022 +0900

    [SPARK-38206][SS] Ignore nullability on comparing the data type of join 
keys on stream-stream join
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to change the data type assertion on join keys in 
stream-stream join so that it ignores nullability.
    
    ### Why are the changes needed?
    
    The existing requirement on checking data types of joining keys is too 
restrictive, as it also requires the same nullability. In batch queries (I checked 
with HashJoinExec), nullability is ignored when checking the data types of joining 
keys.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, end users will no longer encounter the assertion error when the join 
keys on the two sides differ in nullability.
    
    ### How was this patch tested?
    
    New test added.
    
    Closes #35599 from HeartSaVioR/SPARK-38206.
    
    Authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../streaming/StreamingSymmetricHashJoinExec.scala |   8 +-
 .../spark/sql/streaming/StreamingJoinSuite.scala   | 158 +++++++++++++++++++++
 2 files changed, 165 insertions(+), 1 deletion(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index adb84a3..81888e0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -174,7 +174,13 @@ case class StreamingSymmetricHashJoinExec(
     joinType == Inner || joinType == LeftOuter || joinType == RightOuter || 
joinType == FullOuter ||
     joinType == LeftSemi,
     errorMessageForJoinType)
-  require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType))
+
+  // The assertion against join keys is same as hash join for batch query.
+  require(leftKeys.length == rightKeys.length &&
+    leftKeys.map(_.dataType)
+      .zip(rightKeys.map(_.dataType))
+      .forall(types => types._1.sameType(types._2)),
+    "Join keys from two sides should have same length and types")
 
   private val storeConf = new StateStoreConf(conf)
   private val hadoopConfBcast = sparkContext.broadcast(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index e0926ef..2fbe6c4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.streaming
 
 import java.io.File
+import java.lang.{Integer => JInteger}
 import java.sql.Timestamp
 import java.util.{Locale, UUID}
 
@@ -702,6 +703,53 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
         total = Seq(2), updated = Seq(1), droppedByWatermark = Seq(0), removed 
= Some(Seq(0)))
     )
   }
+
+  test("joining non-nullable left join key with nullable right join key") {
+    val input1 = MemoryStream[Int]
+    val input2 = MemoryStream[JInteger]
+
+    val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF())
+    testStream(joined)(
+      AddData(input1, 1, 5),
+      AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), 
JInteger.valueOf(10), null),
+      CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15))
+    )
+  }
+
+  test("joining nullable left join key with non-nullable right join key") {
+    val input1 = MemoryStream[JInteger]
+    val input2 = MemoryStream[Int]
+
+    val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF())
+    testStream(joined)(
+      AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), 
JInteger.valueOf(10), null),
+      AddData(input2, 1, 5),
+      CheckNewAnswer(Row(1, 1, 2, 3), Row(5, 5, 10, 15))
+    )
+  }
+
+  test("joining nullable left join key with nullable right join key") {
+    val input1 = MemoryStream[JInteger]
+    val input2 = MemoryStream[JInteger]
+
+    val joined = testForJoinKeyNullability(input1.toDF(), input2.toDF())
+    testStream(joined)(
+      AddData(input1, JInteger.valueOf(1), JInteger.valueOf(5), 
JInteger.valueOf(10), null),
+      AddData(input2, JInteger.valueOf(1), JInteger.valueOf(5), null),
+      CheckNewAnswer(
+        Row(JInteger.valueOf(1), JInteger.valueOf(1), JInteger.valueOf(2), 
JInteger.valueOf(3)),
+        Row(JInteger.valueOf(5), JInteger.valueOf(5), JInteger.valueOf(10), 
JInteger.valueOf(15)),
+        Row(null, null, null, null))
+    )
+  }
+
+  private def testForJoinKeyNullability(left: DataFrame, right: DataFrame): 
DataFrame = {
+    val df1 = left.selectExpr("value as leftKey", "value * 2 as leftValue")
+    val df2 = right.selectExpr("value as rightKey", "value * 3 as rightValue")
+
+    df1.join(df2, expr("leftKey <=> rightKey"))
+      .select("leftKey", "rightKey", "leftValue", "rightValue")
+  }
 }
 
 
@@ -1168,6 +1216,116 @@ class StreamingOuterJoinSuite extends 
StreamingJoinSuite {
       CheckNewAnswer(expectedOutput.head, expectedOutput.tail: _*)
     )
   }
+
+  test("left-outer: joining non-nullable left join key with nullable right 
join key") {
+    val input1 = MemoryStream[(Int, Int)]
+    val input2 = MemoryStream[(JInteger, Int)]
+
+    val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), 
input2.toDF())
+
+    testStream(joined)(
+      AddData(input1, (1, 1), (1, 2), (1, 3), (1, 4), (1, 5)),
+      AddData(input2,
+        (JInteger.valueOf(1), 3),
+        (JInteger.valueOf(1), 4),
+        (JInteger.valueOf(1), 5),
+        (JInteger.valueOf(1), 6)
+      ),
+      CheckNewAnswer(
+        Row(1, 1, 3, 3, 10, 6, 9),
+        Row(1, 1, 4, 4, 10, 8, 12),
+        Row(1, 1, 5, 5, 10, 10, 15)),
+      AddData(input1, (1, 21)),
+      // right-null join
+      AddData(input2, (JInteger.valueOf(1), 22)), // watermark = 11, 
no-data-batch computes nulls
+      CheckNewAnswer(
+        Row(1, null, 1, null, 10, 2, null),
+        Row(1, null, 2, null, 10, 4, null)
+      )
+    )
+  }
+
+  test("left-outer: joining nullable left join key with non-nullable right 
join key") {
+    val input1 = MemoryStream[(JInteger, Int)]
+    val input2 = MemoryStream[(Int, Int)]
+
+    val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), 
input2.toDF())
+
+    testStream(joined)(
+      AddData(input1,
+        (JInteger.valueOf(1), 1),
+        (null, 2),
+        (JInteger.valueOf(1), 3),
+        (JInteger.valueOf(1), 4),
+        (JInteger.valueOf(1), 5)),
+      AddData(input2, (1, 3), (1, 4), (1, 5), (1, 6)),
+      CheckNewAnswer(
+        Row(1, 1, 3, 3, 10, 6, 9),
+        Row(1, 1, 4, 4, 10, 8, 12),
+        Row(1, 1, 5, 5, 10, 10, 15)),
+      // right-null join
+      AddData(input1, (JInteger.valueOf(1), 21)),
+      AddData(input2, (1, 22)), // watermark = 11, no-data-batch computes nulls
+      CheckNewAnswer(
+        Row(1, null, 1, null, 10, 2, null),
+        Row(null, null, 2, null, 10, 4, null)
+      )
+    )
+  }
+
+  test("left-outer: joining nullable left join key with nullable right join 
key") {
+    val input1 = MemoryStream[(JInteger, Int)]
+    val input2 = MemoryStream[(JInteger, Int)]
+
+    val joined = testForLeftOuterJoinKeyNullability(input1.toDF(), 
input2.toDF())
+
+    testStream(joined)(
+      AddData(input1,
+        (JInteger.valueOf(1), 1),
+        (null, 2),
+        (JInteger.valueOf(1), 3),
+        (null, 4),
+        (JInteger.valueOf(1), 5)),
+      AddData(input2,
+        (JInteger.valueOf(1), 3),
+        (null, 4),
+        (JInteger.valueOf(1), 5),
+        (JInteger.valueOf(1), 6)),
+      CheckNewAnswer(
+        Row(1, 1, 3, 3, 10, 6, 9),
+        Row(null, null, 4, 4, 10, 8, 12),
+        Row(1, 1, 5, 5, 10, 10, 15)),
+      // right-null join
+      AddData(input1, (JInteger.valueOf(1), 21)),
+      AddData(input2, (JInteger.valueOf(1), 22)), // watermark = 11, 
no-data-batch computes nulls
+      CheckNewAnswer(
+        Row(1, null, 1, null, 10, 2, null),
+        Row(null, null, 2, null, 10, 4, null)
+      )
+    )
+  }
+
+  private def testForLeftOuterJoinKeyNullability(left: DataFrame, right: 
DataFrame): DataFrame = {
+    val df1 = left
+      .selectExpr("_1 as leftKey1", "_2 as leftKey2", "timestamp_seconds(_2) 
as leftTime",
+        "_2 * 2 as leftValue")
+      .withWatermark("leftTime", "10 seconds")
+    val df2 = right
+      .selectExpr(
+        "_1 as rightKey1", "_2 as rightKey2", "timestamp_seconds(_2) as 
rightTime",
+        "_2 * 3 as rightValue")
+      .withWatermark("rightTime", "10 seconds")
+
+    val windowed1 = df1.select('leftKey1, 'leftKey2,
+      window('leftTime, "10 second").as('leftWindow), 'leftValue)
+    val windowed2 = df2.select('rightKey1, 'rightKey2,
+      window('rightTime, "10 second").as('rightWindow), 'rightValue)
+    windowed1.join(windowed2,
+      expr("leftKey1 <=> rightKey1 AND leftKey2 = rightKey2 AND leftWindow = 
rightWindow"),
+      "left_outer"
+    ).select('leftKey1, 'rightKey1, 'leftKey2, 'rightKey2, 
$"leftWindow.end".cast("long"),
+      'leftValue, 'rightValue)
+  }
 }
 
 class StreamingFullOuterJoinSuite extends StreamingJoinSuite {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to