Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r142530623
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ---
@@ -470,3 +475,222 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo
}
}
}
+
+class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+ import testImplicits._
+ import org.apache.spark.sql.functions._
+
+ before {
+ SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec'
+ spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+ }
+
+ after {
+ StateStore.stop()
+ }
+
+ private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = {
+ val input = MemoryStream[Int]
+ val df = input.toDF
+ .select(
+ 'value as "key",
+ 'value.cast("timestamp") as s"${prefix}Time",
+ ('value * multiplier) as s"${prefix}Value")
+ .withWatermark(s"${prefix}Time", "10 seconds")
+
+ (input, df)
+ }
+
+ private def setupWindowedJoin(joinType: String):
+ (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+ val (input1, df1) = setupStream("left", 2)
+ val (input2, df2) = setupStream("right", 3)
+ val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue)
+ val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue)
+ val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
+ .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+ (input1, input2, joined)
+ }
+
+ test("windowed left outer join") {
+ val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer")
+
+ testStream(joined)(
+ // Test inner part of the join.
+ AddData(leftInput, 1, 2, 3, 4, 5),
+ AddData(rightInput, 3, 4, 5, 6, 7),
+ CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+ // Old state doesn't get dropped until the batch *after* it gets introduced, so the
+ // nulls won't show up until the next batch after the watermark advances.
+ AddData(leftInput, 21),
+ AddData(rightInput, 22),
+ CheckLastBatch(),
+ assertNumStateRows(total = 12, updated = 2),
+ AddData(leftInput, 22),
+ CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)),
+ assertNumStateRows(total = 3, updated = 1)
+ )
+ }
+
+ test("windowed right outer join") {
+ val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer")
+
+ testStream(joined)(
+ // Test inner part of the join.
+ AddData(leftInput, 1, 2, 3, 4, 5),
+ AddData(rightInput, 3, 4, 5, 6, 7),
+ CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+ // Old state doesn't get dropped until the batch *after* it gets introduced, so the
+ // nulls won't show up until the next batch after the watermark advances.
+ AddData(leftInput, 21),
+ AddData(rightInput, 22),
+ CheckLastBatch(),
+ assertNumStateRows(total = 12, updated = 2),
+ AddData(leftInput, 22),
+ CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)),
+ assertNumStateRows(total = 3, updated = 1)
+ )
+ }
+
+ Seq(
+ ("left_outer", Row(3, null, 5, null)),
+ ("right_outer", Row(null, 2, null, 5))
+ ).foreach { case (joinType: String, outerResult) =>
+ test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range
condition") {
+ import org.apache.spark.sql.functions._
+
+ val leftInput = MemoryStream[(Int, Int)]
+ val rightInput = MemoryStream[(Int, Int)]
+
+ val df1 = leftInput.toDF.toDF("leftKey", "time")
+ .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue")
+ .withWatermark("leftTime", "10 seconds")
+
+ val df2 = rightInput.toDF.toDF("rightKey", "time")
+ .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue")
+ .withWatermark("rightTime", "10 seconds")
+
+ val joined =
+ df1.join(
+ df2,
+ expr("leftKey = rightKey AND " +
+ "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime
+ interval 5 seconds"),
+ joinType)
+ .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+ testStream(joined)(
+ AddData(leftInput, (1, 5), (3, 5)),
+ CheckAnswer(),
+ AddData(rightInput, (1, 10), (2, 5)),
+ CheckLastBatch((1, 1, 5, 10)),
+ AddData(rightInput, (1, 11)),
+ CheckLastBatch(), // no match as left time is too low
+ assertNumStateRows(total = 5, updated = 1),
+
+ // Increase event time watermark to 20s by adding data with time = 30s on both inputs
+ AddData(leftInput, (1, 7), (1, 30)),
+ CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)),
+ assertNumStateRows(total = 7, updated = 2),
+ AddData(rightInput, (0, 30)),
+ CheckLastBatch(),
+ assertNumStateRows(total = 8, updated = 1),
+ AddData(rightInput, (0, 30)),
+ CheckLastBatch(outerResult),
+ assertNumStateRows(total = 3, updated = 1)
+ )
+ }
+ }
+
+ // When the join condition isn't true, the outer null rows must be generated, even if the join
+ // keys themselves have a match.
+ test("left outer join with non-key condition violated on left") {
+ val (leftInput, simpleLeftDf) = setupStream("left", 2)
+ val (rightInput, simpleRightDf) = setupStream("right", 3)
+
+ val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue)
+ val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue)
+
+ val joined = left.join(
+ right,
+ left("key") === right("key") && left("window") === right("window")
&&
+ 'leftValue > 20 && 'rightValue < 200,
+ "left_outer")
+ .select(left("key"), left("window.end").cast("long"), 'leftValue,
'rightValue)
+
+ testStream(joined)(
+ // leftValue <= 20 should generate outer join rows even though it matches right keys
+ AddData(leftInput, 1, 2, 3),
+ AddData(rightInput, 1, 2, 3),
+ CheckLastBatch(),
+ AddData(leftInput, 30),
+ AddData(rightInput, 31),
+ CheckLastBatch(),
+ assertNumStateRows(total = 8, updated = 2),
+ AddData(rightInput, 32),
--- End diff ---
In fact, the next one can also be collapsed. That avoids a lot of duplicate code.
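
Something like the following would mirror the Seq(...).foreach pattern already used for the watermark range condition tests. This is a rough, untested sketch of the shape only: the Row values and the elided per-side steps are placeholders, since the body of the second test isn't visible at this diff position.

Seq(
  // Placeholder expected rows; substitute the real null-padded rows from
  // each of the two original test bodies.
  ("left", Seq(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null))),
  ("right", Seq(Row(1, 10, null, 3), Row(2, 10, null, 6), Row(3, 10, null, 9)))
).foreach { case (violatedSide, expectedOuterRows) =>
  test(s"left outer join with non-key condition violated on $violatedSide") {
    val (leftInput, simpleLeftDf) = setupStream("left", 2)
    val (rightInput, simpleRightDf) = setupStream("right", 3)

    val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue)
    val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue)

    // The join itself is identical in both tests; only the data that trips
    // the condition and the expected output rows differ per case.
    val joined = left.join(
      right,
      left("key") === right("key") && left("window") === right("window") &&
        'leftValue > 20 && 'rightValue < 200,
      "left_outer")
      .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)

    testStream(joined)(
      // ... per-side AddData / CheckLastBatch / assertNumStateRows steps
      // from the original tests go here, ending with ...
      CheckLastBatch(expectedOuterRows: _*)
    )
  }
}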