GitHub user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19271#discussion_r140063861
  
    --- Diff: 
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala 
---
    @@ -0,0 +1,585 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.streaming
    +
    +import java.util.UUID
    +
    +import scala.util.Random
    +
    +import org.apache.hadoop.conf.Configuration
    +import org.scalatest.BeforeAndAfter
    +
    +import org.apache.spark.scheduler.ExecutorCacheTaskLocation
    +import org.apache.spark.sql.SparkSession
    +import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, AttributeSet, BoundReference, Expression, 
GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
    +import 
org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate}
    +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, 
Filter}
    +import org.apache.spark.sql.execution.LogicalRDD
    +import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
    +import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
    +import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreConf, StateStoreProviderId, SymmetricHashJoinStateManager}
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types._
    +import org.apache.spark.util.Utils
    +
    +
    +class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest 
with BeforeAndAfter {
    +
    +  before {
    +    SparkSession.setActiveSession(spark)  // set this before force 
initializing 'joinExec'
    +    spark.streams.stateStoreCoordinator   // initialize the lazy 
coordinator
    +  }
    +
    +  after {
    +    StateStore.stop()
    +  }
    +
    +  import testImplicits._
    +
    +  test("SymmetricHashJoinStateManager - all operations") {
    +    val watermarkMetadata = new 
MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
    +    val inputValueSchema = new StructType()
    +      .add(StructField("time", IntegerType, metadata = watermarkMetadata))
    +      .add(StructField("value", BooleanType))
    +    val inputValueAttribs = inputValueSchema.toAttributes
    +    val inputValueAttribWithWatermark = inputValueAttribs(0)
    +    val joinKeyExprs = Seq[Expression](Literal(false), 
inputValueAttribWithWatermark, Literal(10.0))
    +
    +    val inputValueGen = 
UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray)
    +    val joinKeyGen = 
UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray)
    +
    +    def toInputValue(i: Int): UnsafeRow = {
    +      inputValueGen.apply(new GenericInternalRow(Array[Any](i, false)))
    +    }
    +
    +    def toJoinKeyRow(i: Int): UnsafeRow = {
    +      joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0)))
    +    }
    +
    +    def toKeyInt(joinKeyRow: UnsafeRow): Int = joinKeyRow.getInt(1)
    +
    +    def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0)
    +
    +    withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager =>
    +      def append(key: Int, value: Int): Unit = {
    +        manager.append(toJoinKeyRow(key), toInputValue(value))
    +      }
    +
    +      def get(key: Int): Seq[Int] = 
manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
    +
    +      /** Remove keys (and corresponding values) where `time <= threshold` 
*/
    +      def removeByKey(threshold: Long): Unit = {
    +        val expr =
    +          LessThanOrEqual(
    +            BoundReference(
    +              1, inputValueAttribWithWatermark.dataType, 
inputValueAttribWithWatermark.nullable),
    +            Literal(threshold))
    +        manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval 
_)
    +      }
    +
    +      /** Remove values where `time <= threshold` */
    +      def removeByValue(watermark: Long): Unit = {
    +        val expr = LessThanOrEqual(inputValueAttribWithWatermark, 
Literal(watermark))
    +        manager.removeByValueCondition(
    +          GeneratePredicate.generate(expr, inputValueAttribs).eval _)
    +      }
    +
    +      def numRows: Long = {
    +        manager.metrics.numKeys
    +      }
    +
    +      assert(get(20) === Seq.empty)     // initially empty
    +      append(20, 2)
    +      assert(get(20) === Seq(2))        // should first value correctly
    +      assert(numRows === 1)
    +
    +      append(20, 3)
    +      assert(get(20) === Seq(2, 3))     // should append new values
    +      append(20, 3)
    +      assert(get(20) === Seq(2, 3, 3))  // should append another copy if 
same value added again
    +      assert(numRows === 3)
    +
    +      assert(get(30) === Seq.empty)
    +      append(30, 1)
    +      assert(get(30) === Seq(1))
    +      assert(get(20) === Seq(2, 3, 3))  // add another key-value should 
not affect existing ones
    +      assert(numRows === 4)
    +
    +      removeByKey(25)
    +      assert(get(20) === Seq.empty)
    +      assert(get(30) === Seq(1))        // should remove 20, not 30
    +      assert(numRows === 1)
    +
    +      removeByKey(30)
    +      assert(get(30) === Seq.empty)     // should remove 30
    +      assert(numRows === 0)
    +
    +      def appendAndTest(key: Int, values: Int*): Unit = {
    +        values.foreach { value => append(key, value)}
    +        require(get(key) === values)
    +      }
    +
    +      appendAndTest(40, 100, 200, 300)
    +      appendAndTest(50, 125)
    +      appendAndTest(60, 275)              // prepare for testing 
removeByValue
    +      assert(numRows === 5)
    +
    +      removeByValue(125)
    +      assert(get(40) === Seq(200, 300))
    +      assert(get(50) === Seq.empty)
    +      assert(get(60) === Seq(275))        // should remove only some 
values, not all
    +      assert(numRows === 3)
    +
    +      append(40, 50)
    +      assert(get(40) === Seq(50, 200, 300))
    +      assert(numRows === 4)
    +
    +      removeByValue(200)
    +      assert(get(40) === Seq(300))
    +      assert(get(60) === Seq(275))        // should remove only some 
values, not all
    +      assert(numRows === 2)
    +
    +      removeByValue(300)
    +      assert(get(40) === Seq.empty)
    +      assert(get(60) === Seq.empty)       // should remove all values now
    +      assert(numRows === 0)
    +    }
    +  }
    +
    +  test("stream stream inner join on non-time column") {
    +    val input1 = MemoryStream[Int]
    +    val input2 = MemoryStream[Int]
    +
    +    val df1 = input1.toDF.select('value as "key", ('value * 2) as 
"leftValue")
    +    val df2 = input2.toDF.select('value as "key", ('value * 3) as 
"rightValue")
    +    val joined = df1.join(df2, "key")
    +
    +    testStream(joined)(
    +      AddData(input1, 1),
    +      CheckAnswer(),
    +      AddData(input2, 1, 10),       // 1 arrived on input1 first, then 
input2, should join
    +      CheckLastBatch((1, 2, 3)),
    +      AddData(input1, 10),          // 10 arrived on input2 first, then 
input1, should join
    +      CheckLastBatch((10, 20, 30)),
    +      AddData(input2, 1),           // another 1 in input2 should join 
with 1 input1
    +      CheckLastBatch((1, 2, 3)),
    +      StopStream,
    +      StartStream(),
    +      AddData(input1, 1), // multiple 1s should be kept in state causing 
multiple (1, 2, 3)
    +      CheckLastBatch((1, 2, 3), (1, 2, 3)),
    +      StopStream,
    +      StartStream(),
    +      AddData(input1, 100),
    +      AddData(input2, 100),
    +      CheckLastBatch((100, 200, 300))
    +    )
    +  }
    +
    +
    +  test("stream stream inner join on windows - without watermark") {
    +    val input1 = MemoryStream[Int]
    +    val input2 = MemoryStream[Int]
    +
    +    val df1 = input1.toDF
    +      .select('value as "key", 'value.cast("timestamp") as "timestamp", 
('value * 2) as "leftValue")
    +      .select('key, window('timestamp, "10 second"), 'leftValue)
    +
    +    val df2 = input2.toDF
    +      .select('value as "key", 'value.cast("timestamp") as "timestamp",
    +        ('value * 3) as "rightValue")
    +      .select('key, window('timestamp, "10 second"), 'rightValue)
    +
    +    val joined = df1.join(df2, Seq("key", "window"))
    +      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
    +
    +    testStream(joined)(
    +      AddData(input1, 1),
    +      CheckLastBatch(),
    +      AddData(input2, 1),
    +      CheckLastBatch((1, 10, 2, 3)),
    +      StopStream,
    +      StartStream(),
    +      AddData(input1, 25),
    +      CheckLastBatch(),
    +      StopStream,
    +      StartStream(),
    +      AddData(input2, 25),
    +      CheckLastBatch((25, 30, 50, 75)),
    +      AddData(input1, 1),
    +      CheckLastBatch((1, 10, 2, 3)),      // State for 1 still around as 
there is not watermark
    +      StopStream,
    +      StartStream(),
    +      AddData(input1, 5),
    +      CheckLastBatch(),
    +      AddData(input2, 5),
    +      CheckLastBatch((5, 10, 10, 15))     // No filter by any watermark
    +    )
    +  }
    +
    +  test("stream stream inner join on windows - with watermark") {
    +    val input1 = MemoryStream[Int]
    +    val input2 = MemoryStream[Int]
    +
    +    val df1 = input1.toDF
    +      .select('value as "key", 'value.cast("timestamp") as "timestamp", 
('value * 2) as "leftValue")
    +      .withWatermark("timestamp", "10 seconds")
    +      .select('key, window('timestamp, "10 second"), 'leftValue)
    +
    +    val df2 = input2.toDF
    +      .select('value as "key", 'value.cast("timestamp") as "timestamp",
    +        ('value * 3) as "rightValue")
    +      .select('key, window('timestamp, "10 second"), 'rightValue)
    +
    +    val joined = df1.join(df2, Seq("key", "window"))
    +      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
    +
    +    testStream(joined)(
    +      AddData(input1, 1),
    +      CheckAnswer(),
    +      assertNumStateRows(total = 1, updated = 1),
    --- End diff --
    
    We are not including the rows that hold the counts in this number, because 
doing so would be confusing. The whole point of this metric is to convey the 
number of rows effectively stored as state, so it should not depend on 
implementation details such as the counts being kept separately.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to