Github user tdas commented on a diff in the pull request:
https://github.com/apache/spark/pull/19271#discussion_r140063861
--- Diff: sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala ---
@@ -0,0 +1,585 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.UUID
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter}
+import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
+import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreProviderId, SymmetricHashJoinStateManager}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+
+class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter {
+
+ before {
+ SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec'
+ spark.streams.stateStoreCoordinator // initialize the lazy coordinator
+ }
+
+ after {
+ StateStore.stop()
+ }
+
+ import testImplicits._
+
+ test("SymmetricHashJoinStateManager - all operations") {
+ val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
+ val inputValueSchema = new StructType()
+ .add(StructField("time", IntegerType, metadata = watermarkMetadata))
+ .add(StructField("value", BooleanType))
+ val inputValueAttribs = inputValueSchema.toAttributes
+ val inputValueAttribWithWatermark = inputValueAttribs(0)
+ val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0))
+
+ val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray)
+ val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray)
+
+ def toInputValue(i: Int): UnsafeRow = {
+ inputValueGen.apply(new GenericInternalRow(Array[Any](i, false)))
+ }
+
+ def toJoinKeyRow(i: Int): UnsafeRow = {
+ joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0)))
+ }
+
+ def toKeyInt(joinKeyRow: UnsafeRow): Int = joinKeyRow.getInt(1)
+
+ def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0)
+
+ withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager =>
+ def append(key: Int, value: Int): Unit = {
+ manager.append(toJoinKeyRow(key), toInputValue(value))
+ }
+
+ def get(key: Int): Seq[Int] = manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
+
+ /** Remove keys (and corresponding values) where `time <= threshold` */
+ def removeByKey(threshold: Long): Unit = {
+ val expr =
+ LessThanOrEqual(
+ BoundReference(
+ 1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable),
+ Literal(threshold))
+ manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
+ }
+
+ /** Remove values where `time <= threshold` */
+ def removeByValue(watermark: Long): Unit = {
+ val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark))
+ manager.removeByValueCondition(
+ GeneratePredicate.generate(expr, inputValueAttribs).eval _)
+ }
+
+ def numRows: Long = {
+ manager.metrics.numKeys
+ }
+
+ assert(get(20) === Seq.empty) // initially empty
+ append(20, 2)
+ assert(get(20) === Seq(2)) // should return the first value correctly
+ assert(numRows === 1)
+
+ append(20, 3)
+ assert(get(20) === Seq(2, 3)) // should append new values
+ append(20, 3)
+ assert(get(20) === Seq(2, 3, 3)) // should append another copy if same value added again
+ assert(numRows === 3)
+
+ assert(get(30) === Seq.empty)
+ append(30, 1)
+ assert(get(30) === Seq(1))
+ assert(get(20) === Seq(2, 3, 3)) // adding another key-value pair should not affect existing ones
+ assert(numRows === 4)
+
+ removeByKey(25)
+ assert(get(20) === Seq.empty)
+ assert(get(30) === Seq(1)) // should remove 20, not 30
+ assert(numRows === 1)
+
+ removeByKey(30)
+ assert(get(30) === Seq.empty) // should remove 30
+ assert(numRows === 0)
+
+ def appendAndTest(key: Int, values: Int*): Unit = {
+ values.foreach { value => append(key, value)}
+ require(get(key) === values)
+ }
+
+ appendAndTest(40, 100, 200, 300)
+ appendAndTest(50, 125)
+ appendAndTest(60, 275) // prepare for testing removeByValue
+ assert(numRows === 5)
+
+ removeByValue(125)
+ assert(get(40) === Seq(200, 300))
+ assert(get(50) === Seq.empty)
+ assert(get(60) === Seq(275)) // should remove only some values, not all
+ assert(numRows === 3)
+
+ append(40, 50)
+ assert(get(40) === Seq(50, 200, 300))
+ assert(numRows === 4)
+
+ removeByValue(200)
+ assert(get(40) === Seq(300))
+ assert(get(60) === Seq(275)) // should remove only some values, not all
+ assert(numRows === 2)
+
+ removeByValue(300)
+ assert(get(40) === Seq.empty)
+ assert(get(60) === Seq.empty) // should remove all values now
+ assert(numRows === 0)
+ }
+ }
+
+ test("stream stream inner join on non-time column") {
+ val input1 = MemoryStream[Int]
+ val input2 = MemoryStream[Int]
+
+ val df1 = input1.toDF.select('value as "key", ('value * 2) as
"leftValue")
+ val df2 = input2.toDF.select('value as "key", ('value * 3) as
"rightValue")
+ val joined = df1.join(df2, "key")
+
+ testStream(joined)(
+ AddData(input1, 1),
+ CheckAnswer(),
+ AddData(input2, 1, 10), // 1 arrived on input1 first, then input2, should join
+ CheckLastBatch((1, 2, 3)),
+ AddData(input1, 10), // 10 arrived on input2 first, then input1, should join
+ CheckLastBatch((10, 20, 30)),
+ AddData(input2, 1), // another 1 in input2 should join with 1 in input1
+ CheckLastBatch((1, 2, 3)),
+ StopStream,
+ StartStream(),
+ AddData(input1, 1), // multiple 1s should be kept in state causing multiple (1, 2, 3)
+ CheckLastBatch((1, 2, 3), (1, 2, 3)),
+ StopStream,
+ StartStream(),
+ AddData(input1, 100),
+ AddData(input2, 100),
+ CheckLastBatch((100, 200, 300))
+ )
+ }
+
+
+ test("stream stream inner join on windows - without watermark") {
+ val input1 = MemoryStream[Int]
+ val input2 = MemoryStream[Int]
+
+ val df1 = input1.toDF
+ .select('value as "key", 'value.cast("timestamp") as "timestamp",
('value * 2) as "leftValue")
+ .select('key, window('timestamp, "10 second"), 'leftValue)
+
+ val df2 = input2.toDF
+ .select('value as "key", 'value.cast("timestamp") as "timestamp",
+ ('value * 3) as "rightValue")
+ .select('key, window('timestamp, "10 second"), 'rightValue)
+
+ val joined = df1.join(df2, Seq("key", "window"))
+ .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+ testStream(joined)(
+ AddData(input1, 1),
+ CheckLastBatch(),
+ AddData(input2, 1),
+ CheckLastBatch((1, 10, 2, 3)),
+ StopStream,
+ StartStream(),
+ AddData(input1, 25),
+ CheckLastBatch(),
+ StopStream,
+ StartStream(),
+ AddData(input2, 25),
+ CheckLastBatch((25, 30, 50, 75)),
+ AddData(input1, 1),
+ CheckLastBatch((1, 10, 2, 3)), // State for 1 still around as there is no watermark
+ StopStream,
+ StartStream(),
+ AddData(input1, 5),
+ CheckLastBatch(),
+ AddData(input2, 5),
+ CheckLastBatch((5, 10, 10, 15)) // No filter by any watermark
+ )
+ }
+
+ test("stream stream inner join on windows - with watermark") {
+ val input1 = MemoryStream[Int]
+ val input2 = MemoryStream[Int]
+
+ val df1 = input1.toDF
+ .select('value as "key", 'value.cast("timestamp") as "timestamp",
('value * 2) as "leftValue")
+ .withWatermark("timestamp", "10 seconds")
+ .select('key, window('timestamp, "10 second"), 'leftValue)
+
+ val df2 = input2.toDF
+ .select('value as "key", 'value.cast("timestamp") as "timestamp",
+ ('value * 3) as "rightValue")
+ .select('key, window('timestamp, "10 second"), 'rightValue)
+
+ val joined = df1.join(df2, Seq("key", "window"))
+ .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+
+ testStream(joined)(
+ AddData(input1, 1),
+ CheckAnswer(),
+ assertNumStateRows(total = 1, updated = 1),
--- End diff ---
We deliberately do not include the internal rows that hold the counts in this metric, as doing
so would be confusing. The whole point of the metric is to convey the number of rows
effectively stored as state, and it should not depend on implementation details such as
keeping the counts in separate rows.
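
For illustration, here is a minimal sketch of the intended semantics, reusing the `append`/`numRows` helpers defined earlier in this test (not part of the diff itself):

```scala
// Sketch only, using the test's own helpers (append, numRows) from above.
// numRows wraps manager.metrics.numKeys and should reflect the values
// effectively stored as state, not any internal count rows the state
// manager keeps as an implementation detail.
append(40, 100)
append(40, 200)       // two values stored under key 40
append(50, 125)       // one value stored under key 50
assert(numRows === 3) // 3 stored values => 3 rows, regardless of bookkeeping
```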
---