HeartSaVioR commented on a change in pull request #30076:
URL: https://github.com/apache/spark/pull/30076#discussion_r509245494



##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
##########
@@ -1041,3 +1077,204 @@ class StreamingOuterJoinSuite extends StreamTest with 
StateStoreMetricsTest with
     )
   }
 }
+
+class StreamingLeftSemiJoinSuite extends StreamingJoinSuite {
+
+  import testImplicits._
+
+  test("windowed left semi join") {
+    val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi")
+
+    testStream(joined)(
+      MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
+      CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)),
+      // states
+      // left: 1, 2, 3, 4 ,5
+      // right: 3, 4, 5, 6, 7
+      assertNumStateRows(total = 10, updated = 10),
+      MultiAddData(leftInput, 21)(rightInput, 22),
+      // Watermark = 11, should remove rows having window=[0,10].
+      CheckNewAnswer(),
+      // states
+      // left: 21
+      // right: 22
+      //
+      // states evicted
+      // left: 1, 2, 3, 4 ,5 (below watermark)
+      // right: 3, 4, 5, 6, 7 (below watermark)
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(leftInput, 22),
+      CheckNewAnswer(Row(22, 30, 44)),
+      // Unlike inner/outer joins, given left input row matches with right 
input row,
+      // we don't buffer the matched left input row to the state store.
+      //
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0),
+      StopStream,
+      StartStream(),
+
+      AddData(leftInput, 1),
+      // Row not add as 1 < state key watermark = 12.
+      CheckNewAnswer(),
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1),
+      AddData(rightInput, 5),
+      // Row not add as 5 < state key watermark = 12.
+      CheckNewAnswer(),
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1)
+    )
+  }
+
+  test("left semi early state exclusion on left") {
+    val (leftInput, rightInput, joined) = 
setupWindowedJoinWithLeftCondition("left_semi")
+
+    testStream(joined)(
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
+      // The left rows with leftValue <= 4 should not generate their semi join 
rows and
+      // not get added to the state.
+      CheckNewAnswer(Row(3, 10, 6)),
+      // states
+      // left: 3
+      // right: 3, 4, 5
+      assertNumStateRows(total = 4, updated = 4),
+      // We shouldn't get more semi join rows when the watermark advances.
+      MultiAddData(leftInput, 20)(rightInput, 21),
+      CheckNewAnswer(),
+      // states
+      // left: 20
+      // right: 21
+      //
+      // states evicted
+      // left: 3 (below watermark)
+      // right: 3, 4, 5 (below watermark)
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(rightInput, 20),
+      CheckNewAnswer((20, 30, 40)),
+      // states
+      // left: 20
+      // right: 21, 20
+      assertNumStateRows(total = 3, updated = 1)
+    )
+  }
+
+  test("left semi early state exclusion on right") {
+    val (leftInput, rightInput, joined) = 
setupWindowedJoinWithRightCondition("left_semi")
+
+    testStream(joined)(
+      MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
+      // The right rows with rightValue <= 7 should never be added to the 
state.
+      // The right row with rightValue = 9 > 7, hence joined and added to 
state.
+      CheckNewAnswer(Row(3, 10, 6)),
+      // states
+      // left: 3, 4, 5
+      // right: 3
+      assertNumStateRows(total = 4, updated = 4),
+      // We shouldn't get more semi join rows when the watermark advances.
+      MultiAddData(leftInput, 20)(rightInput, 21),
+      CheckNewAnswer(),
+      // states
+      // left: 20
+      // right: 21
+      //
+      // states evicted
+      // left: 3, 4, 5 (below watermark)
+      // right: 3 (below watermark)
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(rightInput, 20),
+      CheckNewAnswer((20, 30, 40)),
+      // states
+      // left: 20
+      // right: 21, 20
+      assertNumStateRows(total = 3, updated = 1)
+    )
+  }
+
+  test("left semi join with watermark range condition") {
+    val (leftInput, rightInput, joined) = 
setupWindowedJoinWithRangeCondition("left_semi")
+
+    testStream(joined)(
+      AddData(leftInput, (1, 5), (3, 5)),
+      CheckNewAnswer(),
+      // states
+      // left: (1, 5), (3, 5)
+      // right: nothing
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(rightInput, (1, 10), (2, 5)),
+      // Match left row in the state.
+      CheckNewAnswer((1, 5)),
+      // states
+      // left: (1, 5), (3, 5)
+      // right: (1, 10), (2, 5)
+      assertNumStateRows(total = 4, updated = 2),
+      AddData(rightInput, (1, 11)),
+      // No match as left time is too low and left row is already matched.

Review comment:
       That said, it'd be nice if we added or changed a row in the right input 
to verify the single condition `left row is already matched`. Probably 
something like `(1, 9)`?

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
##########
@@ -1041,3 +1077,204 @@ class StreamingOuterJoinSuite extends StreamTest with 
StateStoreMetricsTest with
     )
   }
 }
+
+class StreamingLeftSemiJoinSuite extends StreamingJoinSuite {
+
+  import testImplicits._
+
+  test("windowed left semi join") {
+    val (leftInput, rightInput, joined) = setupWindowedJoin("left_semi")
+
+    testStream(joined)(
+      MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
+      CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)),
+      // states
+      // left: 1, 2, 3, 4 ,5
+      // right: 3, 4, 5, 6, 7
+      assertNumStateRows(total = 10, updated = 10),
+      MultiAddData(leftInput, 21)(rightInput, 22),
+      // Watermark = 11, should remove rows having window=[0,10].
+      CheckNewAnswer(),
+      // states
+      // left: 21
+      // right: 22
+      //
+      // states evicted
+      // left: 1, 2, 3, 4 ,5 (below watermark)
+      // right: 3, 4, 5, 6, 7 (below watermark)
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(leftInput, 22),
+      CheckNewAnswer(Row(22, 30, 44)),
+      // Unlike inner/outer joins, given left input row matches with right 
input row,
+      // we don't buffer the matched left input row to the state store.
+      //
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0),
+      StopStream,
+      StartStream(),
+
+      AddData(leftInput, 1),
+      // Row not add as 1 < state key watermark = 12.
+      CheckNewAnswer(),
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1),
+      AddData(rightInput, 5),
+      // Row not add as 5 < state key watermark = 12.
+      CheckNewAnswer(),
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1)
+    )
+  }
+
+  test("left semi early state exclusion on left") {
+    val (leftInput, rightInput, joined) = 
setupWindowedJoinWithLeftCondition("left_semi")
+
+    testStream(joined)(
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
+      // The left rows with leftValue <= 4 should not generate their semi join 
rows and
+      // not get added to the state.
+      CheckNewAnswer(Row(3, 10, 6)),
+      // states
+      // left: 3
+      // right: 3, 4, 5
+      assertNumStateRows(total = 4, updated = 4),
+      // We shouldn't get more semi join rows when the watermark advances.
+      MultiAddData(leftInput, 20)(rightInput, 21),
+      CheckNewAnswer(),
+      // states
+      // left: 20
+      // right: 21
+      //
+      // states evicted
+      // left: 3 (below watermark)
+      // right: 3, 4, 5 (below watermark)
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(rightInput, 20),
+      CheckNewAnswer((20, 30, 40)),
+      // states
+      // left: 20
+      // right: 21, 20
+      assertNumStateRows(total = 3, updated = 1)
+    )
+  }
+
+  test("left semi early state exclusion on right") {
+    val (leftInput, rightInput, joined) = 
setupWindowedJoinWithRightCondition("left_semi")
+
+    testStream(joined)(
+      MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
+      // The right rows with rightValue <= 7 should never be added to the 
state.
+      // The right row with rightValue = 9 > 7, hence joined and added to 
state.
+      CheckNewAnswer(Row(3, 10, 6)),
+      // states
+      // left: 3, 4, 5
+      // right: 3
+      assertNumStateRows(total = 4, updated = 4),
+      // We shouldn't get more semi join rows when the watermark advances.
+      MultiAddData(leftInput, 20)(rightInput, 21),
+      CheckNewAnswer(),
+      // states
+      // left: 20
+      // right: 21
+      //
+      // states evicted
+      // left: 3, 4, 5 (below watermark)
+      // right: 3 (below watermark)
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(rightInput, 20),
+      CheckNewAnswer((20, 30, 40)),
+      // states
+      // left: 20
+      // right: 21, 20
+      assertNumStateRows(total = 3, updated = 1)
+    )
+  }
+
+  test("left semi join with watermark range condition") {
+    val (leftInput, rightInput, joined) = 
setupWindowedJoinWithRangeCondition("left_semi")
+
+    testStream(joined)(
+      AddData(leftInput, (1, 5), (3, 5)),
+      CheckNewAnswer(),
+      // states
+      // left: (1, 5), (3, 5)
+      // right: nothing
+      assertNumStateRows(total = 2, updated = 2),
+      AddData(rightInput, (1, 10), (2, 5)),
+      // Match left row in the state.
+      CheckNewAnswer((1, 5)),
+      // states
+      // left: (1, 5), (3, 5)
+      // right: (1, 10), (2, 5)
+      assertNumStateRows(total = 4, updated = 2),
+      AddData(rightInput, (1, 11)),
+      // No match as left time is too low and left row is already matched.
+      CheckNewAnswer(),
+      // states
+      // left: (1, 5), (3, 5)
+      // right: (1, 10), (2, 5), (1, 11)
+      assertNumStateRows(total = 5, updated = 1),
+      // Increase event time watermark to 20s by adding data with time = 30s 
on both inputs.
+      AddData(leftInput, (1, 7), (1, 30)),
+      CheckNewAnswer((1, 7)),
+      // states
+      // left: (1, 5), (3, 5), (1, 30)
+      // right: (1, 10), (2, 5), (1, 11)
+      assertNumStateRows(total = 6, updated = 1),
+      // Watermark = 30 - 10 = 20, no matched row.
+      AddData(rightInput, (0, 30)),
+      CheckNewAnswer(),
+      // states
+      // left: (1, 30)
+      // right: (0, 30)
+      //
+      // states evicted
+      // left: (1, 5), (3, 5) (below watermark = 20)
+      // right: (1, 10), (2, 5), (1, 11) (below watermark = 20)
+      assertNumStateRows(total = 2, updated = 1)
+    )
+  }
+
+  test("self left semi join") {
+    val (inputStream, query) = setupWindowedSelfJoin("left_semi")
+
+    testStream(query)(
+      AddData(inputStream, (1, 1L), (2, 2L), (3, 3L), (4, 4L), (5, 5L)),
+      CheckNewAnswer((2, 2), (4, 4)),
+      // batch 1 - global watermark = 0
+      // states
+      // left: (2, 2L), (4, 4L)
+      //       (left rows with value % 2 != 0 is filtered per 
[[PushDownLeftSemiAntiJoin]])
+      // right: (2, 2L), (4, 4L)

Review comment:
       It would probably be better to leave a similar explanation for the right 
state (like you do for the left state).

##########
File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
##########
@@ -99,13 +99,20 @@ class SymmetricHashJoinStateManager(
   /**
    * Get all the matched values for given join condition, with marking matched.
    * This method is designed to mark joined rows properly without exposing 
internal index of row.
+   *
+   * @param excludeRowsAlreadyMatched Do not join with rows already matched 
previously.
+   *                                  This is used for right side of left semi 
join in
+   *                                  [[StreamingSymmetricHashJoinExec]] only.
    */
   def getJoinedRows(
       key: UnsafeRow,
       generateJoinedRow: InternalRow => JoinedRow,
-      predicate: JoinedRow => Boolean): Iterator[JoinedRow] = {
+      predicate: JoinedRow => Boolean,
+      excludeRowsAlreadyMatched: Boolean = false): Iterator[JoinedRow] = {
     val numValues = keyToNumValues.get(key)
-    keyWithIndexToValue.getAll(key, numValues).map { keyIdxToValue =>
+    keyWithIndexToValue.getAll(key, numValues).filterNot { keyIdxToValue =>

Review comment:
       I thought of this first but did not propose it, because the current 
predicate cannot check the condition. We can still do this by adjusting the 
type of the predicate a bit, but I guess the follow-up PR will try to separate 
out the left semi case for performance, which would let us revert the change 
here. For that reason, I prefer the small change for now.

##########
File path: 
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
##########
@@ -1088,39 +1088,79 @@ class StreamingLeftSemiJoinSuite extends 
StreamingJoinSuite {
     testStream(joined)(
       MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
       CheckNewAnswer(Row(3, 10, 6), Row(4, 10, 8), Row(5, 10, 10)),
+      // states
+      // left: 1, 2, 3, 4 ,5
+      // right: 3, 4, 5, 6, 7
+      assertNumStateRows(total = 10, updated = 10),
       MultiAddData(leftInput, 21)(rightInput, 22),
-      // Watermark = 11, should remove rows having window=[0,10]
+      // Watermark = 11, should remove rows having window=[0,10].
       CheckNewAnswer(),
-      assertNumStateRows(total = 2, updated = 12),
+      // states
+      // left: 21
+      // right: 22
+      //
+      // states evicted
+      // left: 1, 2, 3, 4 ,5 (below watermark)
+      // right: 3, 4, 5, 6, 7 (below watermark)
+      assertNumStateRows(total = 2, updated = 2),
       AddData(leftInput, 22),
       CheckNewAnswer(Row(22, 30, 44)),
+      // Unlike inner/outer joins, given left input row matches with right 
input row,
+      // we don't buffer the matched left input row to the state store.
+      //
+      // states
+      // left: 21
+      // right: 22
       assertNumStateRows(total = 2, updated = 0),
       StopStream,
       StartStream(),
 
       AddData(leftInput, 1),
-      // Row not add as 1 < state key watermark = 12
+      // Row not add as 1 < state key watermark = 12.
       CheckNewAnswer(),
-      AddData(rightInput, 11),
-      // Row not add as 11 < state key watermark = 12
-      CheckNewAnswer()
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1),
+      AddData(rightInput, 5),
+      // Row not add as 5 < state key watermark = 12.
+      CheckNewAnswer(),
+      // states
+      // left: 21
+      // right: 22
+      assertNumStateRows(total = 2, updated = 0, droppedByWatermark = 1)
     )
   }
 
   test("left semi early state exclusion on left") {
     val (leftInput, rightInput, joined) = 
setupWindowedJoinWithLeftCondition("left_semi")
 
     testStream(joined)(
-      MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3, 4, 5),
-      // The left rows with leftValue <= 4 should not generate their semi join 
row and
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),

Review comment:
       Ah OK, thanks for the explanation. I agree it's probably less confusing 
to just remove them from the right side instead of having them be affected by 
the optimizer.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to