Repository: spark
Updated Branches:
  refs/heads/master 5f6943345 -> 3099c574c


http://git-wip-us.apache.org/repos/asf/spark/blob/3099c574/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 533e116..a6593b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -24,8 +24,9 @@ import scala.util.Random
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
-import org.apache.spark.sql.{AnalysisException, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
+import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter}
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
@@ -35,7 +36,7 @@ import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 
-class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
 
   before {
     SparkSession.setActiveSession(spark)  // set this before force initializing 'joinExec'
@@ -322,111 +323,6 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo
     assert(e.toString.contains("Stream stream joins without equality predicate is not supported"))
   }
 
-  testQuietly("extract watermark from time condition") {
-    val attributesToFindConstraintFor = Seq(
-      AttributeReference("leftTime", TimestampType)(),
-      AttributeReference("leftOther", IntegerType)())
-    val metadataWithWatermark = new MetadataBuilder()
-      .putLong(EventTimeWatermark.delayKey, 1000)
-      .build()
-    val attributesWithWatermark = Seq(
-      AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(),
-      AttributeReference("rightOther", IntegerType)())
-
-    def watermarkFrom(
-        conditionStr: String,
-        rightWatermark: Option[Long] = Some(10000)): Option[Long] = {
-      val conditionExpr = Some(conditionStr).map { str =>
-        val plan =
-          Filter(
-            spark.sessionState.sqlParser.parseExpression(str),
-            LogicalRDD(
-              attributesToFindConstraintFor ++ attributesWithWatermark,
-              spark.sparkContext.emptyRDD)(spark))
-        plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition
-      }
-      StreamingSymmetricHashJoinHelper.getStateValueWatermark(
-        AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark),
-        conditionExpr, rightWatermark)
-    }
-
-    // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark,
-    // then cannot define constraint on leftTime.
-    assert(watermarkFrom("leftTime > rightTime") === Some(10000))
-    assert(watermarkFrom("leftTime >= rightTime") === Some(9999))
-    assert(watermarkFrom("leftTime < rightTime") === None)
-    assert(watermarkFrom("leftTime <= rightTime") === None)
-    assert(watermarkFrom("rightTime > leftTime") === None)
-    assert(watermarkFrom("rightTime >= leftTime") === None)
-    assert(watermarkFrom("rightTime < leftTime") === Some(10000))
-    assert(watermarkFrom("rightTime <= leftTime") === Some(9999))
-
-    // Test type conversions
-    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None)
-    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000))
-    assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000))
-    assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None)
-
-    // Test with timestamp type + calendar interval on either side of equation
-    // Note: timestamp type and calendar interval don't commute, so fewer valid combinations to test.
-    assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000))
-    assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000))
-    assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000))
-    assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000))
-    assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second")
-      === Some(12000))
-
-    // Test with casted long type + constants on either side of equation
-    // Note: long type and constants commute, so more combinations to test.
-    // -- Constants on the right
-    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)")
-      === Some(11000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500))
-    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000))
-    assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1")
-      === Some(10100))
-    assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2")
-      === Some(10200))
-    // -- Constants on the left
-    assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000))
-    assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000))
-    assert(watermarkFrom("CAST((leftTime  + interval 3 second) AS LONG) > CAST(rightTime AS LONG)")
-      === Some(7000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000))
-    assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500))
-    assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0")
-      === Some(12000))
-    assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0")
-      === Some(10100))
-    // -- Constants on both sides, mixed types
-    assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1")
-      === Some(13000))
-
-    // Test multiple conditions, should return minimum watermark
-    assert(watermarkFrom(
-      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") ===
-      Some(7000))  // first condition wins
-    assert(watermarkFrom(
-      "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") ===
-      Some(6000))  // second condition wins
-
-    // Test invalid comparisons
-    assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None)      // non-time attributes
-    assert(watermarkFrom("leftOther > rightOther") === None)                  // non-time attributes
-    assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000))
-    assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None)  // non-time attributes
-    assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed
-
-    // Test static comparisons
-    assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000))
-  }
-
   test("locality preferences of StateStoreAwareZippedRDD") {
     import StreamingSymmetricHashJoinHelper._
 
@@ -470,3 +366,189 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo
     }
   }
 }
+
+class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+  import testImplicits._
+  import org.apache.spark.sql.functions._
+
+  before {
+    SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec'
+    spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+  }
+
+  after {
+    StateStore.stop()
+  }
+
+  private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = {
+    val input = MemoryStream[Int]
+    val df = input.toDF
+      .select(
+        'value as "key",
+        'value.cast("timestamp") as s"${prefix}Time",
+        ('value * multiplier) as s"${prefix}Value")
+      .withWatermark(s"${prefix}Time", "10 seconds")
+
+    (input, df)
+  }
+
+  private def setupWindowedJoin(joinType: String):
+  (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+    val (input1, df1) = setupStream("left", 2)
+    val (input2, df2) = setupStream("right", 3)
+    val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue)
+    val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue)
+    val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
+      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+    (input1, input2, joined)
+  }
+
+  test("windowed left outer join") {
+    val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer")
+
+    testStream(joined)(
+      // Test inner part of the join.
+      AddData(leftInput, 1, 2, 3, 4, 5),
+      AddData(rightInput, 3, 4, 5, 6, 7),
+      CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+      // Old state doesn't get dropped until the batch *after* it gets introduced, so the
+      // nulls won't show up until the next batch after the watermark advances.
+      AddData(leftInput, 21),
+      AddData(rightInput, 22),
+      CheckLastBatch(),
+      assertNumStateRows(total = 12, updated = 2),
+      AddData(leftInput, 22),
+      CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)),
+      assertNumStateRows(total = 3, updated = 1)
+    )
+  }
+
+  test("windowed right outer join") {
+    val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer")
+
+    testStream(joined)(
+      // Test inner part of the join.
+      AddData(leftInput, 1, 2, 3, 4, 5),
+      AddData(rightInput, 3, 4, 5, 6, 7),
+      CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
+      // Old state doesn't get dropped until the batch *after* it gets introduced, so the
+      // nulls won't show up until the next batch after the watermark advances.
+      AddData(leftInput, 21),
+      AddData(rightInput, 22),
+      CheckLastBatch(),
+      assertNumStateRows(total = 12, updated = 2),
+      AddData(leftInput, 22),
+      CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)),
+      assertNumStateRows(total = 3, updated = 1)
+    )
+  }
+
+  Seq(
+    ("left_outer", Row(3, null, 5, null)),
+    ("right_outer", Row(null, 2, null, 5))
+  ).foreach { case (joinType: String, outerResult) =>
+    test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") {
+      import org.apache.spark.sql.functions._
+
+      val leftInput = MemoryStream[(Int, Int)]
+      val rightInput = MemoryStream[(Int, Int)]
+
+      val df1 = leftInput.toDF.toDF("leftKey", "time")
+        .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue")
+        .withWatermark("leftTime", "10 seconds")
+
+      val df2 = rightInput.toDF.toDF("rightKey", "time")
+        .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue")
+        .withWatermark("rightTime", "10 seconds")
+
+      val joined =
+        df1.join(
+          df2,
+          expr("leftKey = rightKey AND " +
+            "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"),
+          joinType)
+          .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+      testStream(joined)(
+        AddData(leftInput, (1, 5), (3, 5)),
+        CheckAnswer(),
+        AddData(rightInput, (1, 10), (2, 5)),
+        CheckLastBatch((1, 1, 5, 10)),
+        AddData(rightInput, (1, 11)),
+        CheckLastBatch(), // no match as left time is too low
+        assertNumStateRows(total = 5, updated = 1),
+
+        // Increase event time watermark to 20s by adding data with time = 30s on both inputs
+        AddData(leftInput, (1, 7), (1, 30)),
+        CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)),
+        assertNumStateRows(total = 7, updated = 2),
+        AddData(rightInput, (0, 30)),
+        CheckLastBatch(),
+        assertNumStateRows(total = 8, updated = 1),
+        AddData(rightInput, (0, 30)),
+        CheckLastBatch(outerResult),
+        assertNumStateRows(total = 3, updated = 1)
+      )
+    }
+  }
+
+  // When the join condition isn't true, the outer null rows must be generated, even if the join
+  // keys themselves have a match.
+  test("left outer join with non-key condition violated on left") {
+    val (leftInput, simpleLeftDf) = setupStream("left", 2)
+    val (rightInput, simpleRightDf) = setupStream("right", 3)
+
+    val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue)
+    val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue)
+
+    val joined = left.join(
+        right,
+        left("key") === right("key") && left("window") === right("window") &&
+            'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000),
+        "left_outer")
+      .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
+
+    testStream(joined)(
+      // leftValue <= 10 should generate outer join rows even though it matches right keys
+      AddData(leftInput, 1, 2, 3),
+      AddData(rightInput, 1, 2, 3),
+      CheckLastBatch(),
+      AddData(leftInput, 20),
+      AddData(rightInput, 21),
+      CheckLastBatch(),
+      assertNumStateRows(total = 8, updated = 2),
+      AddData(rightInput, 20),
+      CheckLastBatch(
+        Row(20, 30, 40, 60), Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)),
+      assertNumStateRows(total = 3, updated = 1),
+      // leftValue and rightValue both satisfying condition should not generate outer join rows
+      AddData(leftInput, 40, 41),
+      AddData(rightInput, 40, 41),
+      CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)),
+      AddData(leftInput, 70),
+      AddData(rightInput, 71),
+      CheckLastBatch(),
+      assertNumStateRows(total = 6, updated = 2),
+      AddData(rightInput, 70),
+      CheckLastBatch((70, 80, 140, 210)),
+      assertNumStateRows(total = 3, updated = 1),
+      // rightValue between 300 and 1000 should generate outer join rows even though it matches left
+      AddData(leftInput, 101, 102, 103),
+      AddData(rightInput, 101, 102, 103),
+      CheckLastBatch(),
+      AddData(leftInput, 1000),
+      AddData(rightInput, 1001),
+      CheckLastBatch(),
+      assertNumStateRows(total = 8, updated = 2),
+      AddData(rightInput, 1000),
+      CheckLastBatch(
+        Row(1000, 1010, 2000, 3000),
+        Row(101, 110, 202, null),
+        Row(102, 110, 204, null),
+        Row(103, 110, 206, null)),
+      assertNumStateRows(total = 3, updated = 1)
+    )
+  }
+}
+
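----------------------------------------------------------------------
For context, the pattern the new StreamingOuterJoinSuite exercises can be reproduced
outside the test harness. Below is a minimal sketch, assuming a local SparkSession and
substituting the built-in rate source for the suite's MemoryStream; the object name
OuterJoinSketch and the console sink are illustrative, while the watermarks and the
BETWEEN range condition mirror the "with watermark range condition" tests above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object OuterJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("outer-join-sketch").getOrCreate()
    import spark.implicits._

    // Two rate streams stand in for the MemoryStreams used in the suite.
    val left = spark.readStream.format("rate").option("rowsPerSecond", "1").load()
      .select($"value" as "leftKey", $"timestamp" as "leftTime")
      .withWatermark("leftTime", "10 seconds")

    val right = spark.readStream.format("rate").option("rowsPerSecond", "1").load()
      .select($"value" as "rightKey", $"timestamp" as "rightTime")
      .withWatermark("rightTime", "10 seconds")

    // Equality key plus a watermark range condition; unmatched left rows
    // surface with nulls on the right only after the watermark moves past
    // them, which is why the tests above add a later batch before checking
    // for the null-extended rows.
    val joined = left.join(
      right,
      expr("leftKey = rightKey AND " +
        "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"),
      "left_outer")

    joined.writeStream.format("console").start().awaitTermination()
  }
}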


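The comment in the last test ("the outer null rows must be generated, even if the join
keys themselves have a match") also holds in batch mode, since outer-join semantics
apply to the full join condition, not just the equality keys. A spark-shell-sized
sketch with illustrative data, assuming spark.implicits._ is in scope:

// Even when keys match, a failing non-key predicate turns the row into
// an outer-null result under left_outer.
val l = Seq((1, 2), (20, 40)).toDF("key", "leftValue")
val r = Seq((1, 3), (20, 60)).toDF("key", "rightValue")
l.join(r, l("key") === r("key") && $"leftValue" > 10, "left_outer").show()
// key=1  -> (1, 2, null, null)  because leftValue = 2 fails leftValue > 10
// key=20 -> (20, 40, 20, 60)    both the key and the predicate match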