GitHub user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19327#discussion_r142344009
--- Diff:
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
---
@@ -470,3 +475,283 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo
}
}
}
+
+class StreamingOuterJoinSuite extends StreamTest with
StateStoreMetricsTest with BeforeAndAfter {
+
+ import testImplicits._
+ import org.apache.spark.sql.functions._
+
+ before {
+ SparkSession.setActiveSession(spark) // set this before force
initializing 'joinExec'
+ spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+ }
+
+ after {
+ StateStore.stop()
+ }
+
+ private def setupStream(prefix: String, multiplier: Int):
(MemoryStream[Int], DataFrame) = {
+ val input = MemoryStream[Int]
+
+ val df = input.toDF
+ .select(
+ 'value as "key",
+ 'value.cast("timestamp") as s"${prefix}Time",
+ ('value * multiplier) as s"${prefix}Value")
+ .withWatermark(s"${prefix}Time", "10 seconds")
+
+ return (input, df)
+ }
+
+ private def setupWindowedJoin(joinType: String) = {
+ val (input1, df1) = setupStream("left", 2)
+ val (input2, df2) = setupStream("right", 3)
+
+ val windowed1 = df1.select('key, window('leftTime, "10 second"),
'leftValue)
+
+ val windowed2 = df2.select('key, window('rightTime, "10 second"),
'rightValue)
+
+ val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
+ .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+ (input1, input2, joined)
+ }
+
+ test("left stream batch outer join") {
+ val stream = MemoryStream[Int]
+ .toDF()
+ .withColumn("timestamp", 'value.cast("timestamp"))
+ .withWatermark("timestamp", "1 second")
+ val joined =
+ stream.join(Seq(1).toDF(), Seq("value"), "left_outer")
+
+ // This test is in the suite just to confirm the validations below
don't block this valid join.
+ // We don't need to check results, just that the join can happen.
+ testStream(joined)()
+ }
+
+ test("left batch stream outer join") {
+ val stream = MemoryStream[Int]
+ .toDF()
+ .withColumn("timestamp", 'value.cast("timestamp"))
+ .withWatermark("timestamp", "1 second")
+ val joined =
+ Seq(1).toDF().join(stream, Seq("value"), "left_outer")
+
+ val thrown = intercept[AnalysisException] {
+ testStream(joined)()
+ }
+
+ assert(thrown.getMessage.contains(
+ "Left outer join with a streaming DataFrame/Dataset on the right and
a static"))
+ }
+
+ test("right stream batch outer join") {
+ val stream = MemoryStream[Int]
+ .toDF()
+ .withColumn("timestamp", 'value.cast("timestamp"))
+ .withWatermark("timestamp", "1 second")
+ val joined =
+ stream.join(Seq(1).toDF(), Seq("value"), "right_outer")
+
+ val thrown = intercept[AnalysisException] {
+ testStream(joined)()
+ }
+
+ assert(thrown.getMessage.contains(
+ "Right outer join with a streaming DataFrame/Dataset on the left and
a static"))
+ }
+
+ test("left outer join with no watermark") {
+ val joined =
+ MemoryStream[Int].toDF().join(MemoryStream[Int].toDF(),
Seq("value"), "left_outer")
+
+ val thrown = intercept[AnalysisException] {
+ testStream(joined)()
+ }
+
+ assert(thrown.getMessage.contains(
+ "Stream-stream outer join between two streaming DataFrame/Datasets
is not supported " +
+ "without a watermark"))
+ }
+
+ test("right outer join with no watermark") {
+ val joined =
+ MemoryStream[Int].toDF().join(MemoryStream[Int].toDF(),
Seq("value"), "right_outer")
+
+ val thrown = intercept[AnalysisException] {
+ testStream(joined)()
+ }
+
+ assert(thrown.getMessage.contains(
+ "Stream-stream outer join between two streaming DataFrame/Datasets
is not supported " +
+ "without a watermark"))
+ }
+
+ test("windowed left outer join") {
+ val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer")
+
+ testStream(joined)(
+ // Test inner part of the join.
+ AddData(leftInput, 1, 2, 3, 4, 5),
+ AddData(rightInput, 3, 4, 5, 6, 7),
+ CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+ // Old state doesn't get dropped until the batch *after* it gets
introduced, so the
+ // nulls won't show up until the next batch after the watermark
advances.
+ AddData(leftInput, 21),
+ AddData(rightInput, 22),
+ CheckLastBatch(),
+ AddData(leftInput, 22),
+ CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10,
4, null))
+ )
+ }
+
+ test("windowed right outer join") {
+ val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer")
+
+ testStream(joined)(
+ // Test inner part of the join.
+ AddData(leftInput, 1, 2, 3, 4, 5),
+ AddData(rightInput, 3, 4, 5, 6, 7),
+ CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+ // Old state doesn't get dropped until the batch *after* it gets
introduced, so the
+ // nulls won't show up until the next batch after the watermark
advances.
+ AddData(leftInput, 21),
+ AddData(rightInput, 22),
+ CheckLastBatch(),
+ AddData(leftInput, 22),
+ CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10,
null, 21))
+ )
+ }
+
+ Seq(
+ ("left_outer", Row(3, null, 5, null)),
+ ("right_outer", Row(null, 2, null, 5))
+ ).foreach { case (joinType: String, outerResult) =>
+ test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range
condition") {
+ import org.apache.spark.sql.functions._
+
+ val leftInput = MemoryStream[(Int, Int)]
+ val rightInput = MemoryStream[(Int, Int)]
+
+ val df1 = leftInput.toDF.toDF("leftKey", "time")
+ .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey
* 2) as "leftValue")
+ .withWatermark("leftTime", "10 seconds")
+
+ val df2 = rightInput.toDF.toDF("rightKey", "time")
+ .select('rightKey, 'time.cast("timestamp") as "rightTime",
('rightKey * 3) as "rightValue")
+ .withWatermark("rightTime", "10 seconds")
+
+ val joined =
+ df1.join(
+ df2,
+ expr("leftKey = rightKey AND " +
+ "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime
+ interval 5 seconds"),
+ joinType)
+ .select('leftKey, 'rightKey, 'leftTime.cast("int"),
'rightTime.cast("int"))
+ testStream(joined)(
+ AddData(leftInput, (1, 5), (3, 5)),
+ CheckAnswer(),
+ AddData(rightInput, (1, 10), (2, 5)),
+ CheckLastBatch((1, 1, 5, 10)),
+ AddData(rightInput, (1, 11)),
+ CheckLastBatch(), // no match as left time is too low
+ assertNumStateRows(total = 5, updated = 1),
+
+ // Increase event time watermark to 20s by adding data with time =
30s on both inputs
+ AddData(leftInput, (1, 7), (1, 30)),
+ CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)),
+ assertNumStateRows(total = 7, updated = 2),
+ AddData(rightInput, (0, 30)),
+ CheckLastBatch(),
+ assertNumStateRows(total = 8, updated = 1),
+ AddData(rightInput, (0, 30)),
+ CheckLastBatch(outerResult))
+ }
+ }
+
+ // When the join condition isn't true, the outer null rows must be
generated, even if the join
+ // keys themselves have a match.
+ test("outer join with non-key condition violated on left") {
--- End diff --
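Suggestion: rename this test from "outer join with non-key condition violated on left" to "left outer join with non-key condition violated on left" (i.e., "outer join" -> "left outer join").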
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]