Github user brkyvz commented on a diff in the pull request:
https://github.com/apache/spark/pull/19271#discussion_r140051116
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ---
@@ -0,0 +1,585 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.UUID
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter}
+import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreProviderId, SymmetricHashJoinStateManager}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+
+class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+ before {
+    SparkSession.setActiveSession(spark)  // set this before force initializing 'joinExec'
+    spark.streams.stateStoreCoordinator   // initialize the lazy coordinator
+ }
+
+ after {
+ StateStore.stop()
+ }
+
+ import testImplicits._
+
+ test("SymmetricHashJoinStateManager - all operations") {
+    val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
+ val inputValueSchema = new StructType()
+ .add(StructField("time", IntegerType, metadata = watermarkMetadata))
+ .add(StructField("value", BooleanType))
+ val inputValueAttribs = inputValueSchema.toAttributes
+ val inputValueAttribWithWatermark = inputValueAttribs(0)
+    val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0))
+
+    val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray)
+    val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray)
+
+ def toInputValue(i: Int): UnsafeRow = {
+ inputValueGen.apply(new GenericInternalRow(Array[Any](i, false)))
+ }
+
+ def toJoinKeyRow(i: Int): UnsafeRow = {
+ joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0)))
+ }
+
+ def toKeyInt(joinKeyRow: UnsafeRow): Int = joinKeyRow.getInt(1)
+
+ def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0)
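+    // Note: the join key embeds constant literals around the watermarked attribute, so
+    // only field 1 of a key row (and field 0 of a value row) carries the varying integer.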
+
+ withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager =>
+ def append(key: Int, value: Int): Unit = {
+ manager.append(toJoinKeyRow(key), toInputValue(value))
+ }
+
+      def get(key: Int): Seq[Int] = manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
+
+      /** Remove keys (and corresponding values) where `time <= threshold` */
+      def removeByKey(threshold: Long): Unit = {
+        val expr =
+          LessThanOrEqual(
+            BoundReference(
+              1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable),
+            Literal(threshold))
+        manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
+      }
+
+      /** Remove values where `time <= watermark` */
+ def removeByValue(watermark: Long): Unit = {
+        val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark))
+ manager.removeByValueCondition(
+ GeneratePredicate.generate(expr, inputValueAttribs).eval _)
+ }
+
+ def numRows: Long = {
+ manager.metrics.numKeys
+ }
+
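+      // Note: metrics.numKeys counts stored key-value rows rather than distinct keys;
+      // the row-count assertions below depend on that behavior.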
+ assert(get(20) === Seq.empty) // initially empty
+ append(20, 2)
+      assert(get(20) === Seq(2))  // should return first value correctly
+ assert(numRows === 1)
+
+ append(20, 3)
+ assert(get(20) === Seq(2, 3)) // should append new values
+ append(20, 3)
+      assert(get(20) === Seq(2, 3, 3))  // should append another copy if same value added again
+ assert(numRows === 3)
+
+ assert(get(30) === Seq.empty)
+ append(30, 1)
+ assert(get(30) === Seq(1))
+      assert(get(20) === Seq(2, 3, 3))  // adding another key-value pair should not affect existing ones
+ assert(numRows === 4)
+
+ removeByKey(25)
+ assert(get(20) === Seq.empty)
+ assert(get(30) === Seq(1)) // should remove 20, not 30
+ assert(numRows === 1)
+
+ removeByKey(30)
+ assert(get(30) === Seq.empty) // should remove 30
+ assert(numRows === 0)
+
+ def appendAndTest(key: Int, values: Int*): Unit = {
+        values.foreach { value => append(key, value) }
+ require(get(key) === values)
+ }
+
+ appendAndTest(40, 100, 200, 300)
+ appendAndTest(50, 125)
+      appendAndTest(60, 275)  // prepare for testing removeByValue
+ assert(numRows === 5)
+
+ removeByValue(125)
+ assert(get(40) === Seq(200, 300))
+ assert(get(50) === Seq.empty)
+      assert(get(60) === Seq(275))  // should remove only some values, not all
+ assert(numRows === 3)
+
+ append(40, 50)
+ assert(get(40) === Seq(50, 200, 300))
+ assert(numRows === 4)
+
+ removeByValue(200)
+ assert(get(40) === Seq(300))
+      assert(get(60) === Seq(275))  // should remove only some values, not all
+ assert(numRows === 2)
+
+ removeByValue(300)
+ assert(get(40) === Seq.empty)
+ assert(get(60) === Seq.empty) // should remove all values now
+ assert(numRows === 0)
+ }
+ }
+
+ test("stream stream inner join on non-time column") {
+ val input1 = MemoryStream[Int]
+ val input2 = MemoryStream[Int]
+
+    val df1 = input1.toDF.select('value as "key", ('value * 2) as "leftValue")
+    val df2 = input2.toDF.select('value as "key", ('value * 3) as "rightValue")
+ val joined = df1.join(df2, "key")
+
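+    // With no watermark on either side, rows from both inputs stay in state
+    // indefinitely, so a key can keep matching across batches (see the repeated 1s).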
+ testStream(joined)(
+ AddData(input1, 1),
+ CheckAnswer(),
+      AddData(input2, 1, 10),  // 1 arrived on input1 first, then on input2, should join
+      CheckLastBatch((1, 2, 3)),
+      AddData(input1, 10),     // 10 arrived on input2 first, then on input1, should join
+      CheckLastBatch((10, 20, 30)),
+      AddData(input2, 1),      // another 1 in input2 should join with the 1 in input1
+      CheckLastBatch((1, 2, 3)),
+ StopStream,
+ StartStream(),
+      AddData(input1, 1),  // multiple 1s should be kept in state causing multiple (1, 2, 3)
+ CheckLastBatch((1, 2, 3), (1, 2, 3)),
+ StopStream,
+ StartStream(),
+ AddData(input1, 100),
+ AddData(input2, 100),
+ CheckLastBatch((100, 200, 300))
+ )
+ }
+
+
+ test("stream stream inner join on windows - without watermark") {
+ val input1 = MemoryStream[Int]
+ val input2 = MemoryStream[Int]
+
+ val df1 = input1.toDF
+ .select('value as "key", 'value.cast("timestamp") as "timestamp",
('value * 2) as "leftValue")
+ .select('key, window('timestamp, "10 second"), 'leftValue)
+
+ val df2 = input2.toDF
+ .select('value as "key", 'value.cast("timestamp") as "timestamp",
+ ('value * 3) as "rightValue")
+ .select('key, window('timestamp, "10 second"), 'rightValue)
+
+ val joined = df1.join(df2, Seq("key", "window"))
+ .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
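+    // The join key is (key, window), so rows match only when they fall in the same
+    // 10-second window; window.end is surfaced as a long (e.g. timestamp 1 -> end 10).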
+ testStream(joined)(
+ AddData(input1, 1),
+ CheckLastBatch(),
+ AddData(input2, 1),
+ CheckLastBatch((1, 10, 2, 3)),
+ StopStream,
+ StartStream(),
+ AddData(input1, 25),
+ CheckLastBatch(),
+ StopStream,
+ StartStream(),
+ AddData(input2, 25),
+ CheckLastBatch((25, 30, 50, 75)),
+ AddData(input1, 1),
+      CheckLastBatch((1, 10, 2, 3)),  // State for 1 still around as there is not watermark
--- End diff ---
nit: `there is no watermark`