GitHub user brkyvz commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19271#discussion_r140051116
  
    --- Diff: 
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala 
---
    @@ -0,0 +1,585 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.streaming
    +
    +import java.util.UUID
    +
    +import scala.util.Random
    +
    +import org.apache.hadoop.conf.Configuration
    +import org.scalatest.BeforeAndAfter
    +
    +import org.apache.spark.scheduler.ExecutorCacheTaskLocation
    +import org.apache.spark.sql.SparkSession
    +import org.apache.spark.sql.catalyst.expressions.{Attribute, 
AttributeReference, AttributeSet, BoundReference, Expression, 
GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
    +import 
org.apache.spark.sql.catalyst.expressions.codegen.{GeneratePredicate}
    +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, 
Filter}
    +import org.apache.spark.sql.execution.LogicalRDD
    +import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper}
    +import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
    +import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreConf, StateStoreProviderId, SymmetricHashJoinStateManager}
    +import org.apache.spark.sql.functions._
    +import org.apache.spark.sql.types._
    +import org.apache.spark.util.Utils
    +
    +
    +class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest 
with BeforeAndAfter {
    +
    +  // Runs before every test in this suite.
    +  before {
    +    SparkSession.setActiveSession(spark)  // set this before force 
initializing 'joinExec'
    +    spark.streams.stateStoreCoordinator   // initialize the lazy 
coordinator
    +  }
    +
    +  // Runs after every test: stop StateStore so the next test does not see its leftovers
    +  // (NOTE(review): presumably this unloads state store providers created by the test).
    +  after {
    +    StateStore.stop()
    +  }
    +
    +  // Presumably supplies the encoders and column syntax the tests rely on
    +  // (e.g. `MemoryStream[Int]`, `'value`, `$"window.end"`).
    +  import testImplicits._
    +
    +  test("SymmetricHashJoinStateManager - all operations") {
    +    // Drives SymmetricHashJoinStateManager directly: appends (key, value) rows, reads them back,
    +    // and evicts state both by a predicate on the join key and by a predicate on the value.
    +    val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
    +    val inputValueSchema = new StructType()
    +      .add(StructField("time", IntegerType, metadata = watermarkMetadata))
    +      .add(StructField("value", BooleanType))
    +    val inputValueAttribs = inputValueSchema.toAttributes
    +    val inputValueAttribWithWatermark = inputValueAttribs(0)
    +    // The join key embeds the watermarked "time" column at ordinal 1, between two constants.
    +    val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0))
    +
    +    val inputValueGen = UnsafeProjection.create(inputValueAttribs.map(_.dataType).toArray)
    +    val joinKeyGen = UnsafeProjection.create(joinKeyExprs.map(_.dataType).toArray)
    +
    +    /** Builds an input value row with `time = i` and `value = false`. */
    +    def toInputValue(i: Int): UnsafeRow = {
    +      inputValueGen.apply(new GenericInternalRow(Array[Any](i, false)))
    +    }
    +
    +    /** Builds a join key row whose "time" component (ordinal 1) is `i`. */
    +    def toJoinKeyRow(i: Int): UnsafeRow = {
    +      joinKeyGen.apply(new GenericInternalRow(Array[Any](false, i, 10.0)))
    +    }
    +
    +    /** Extracts the "time" column (ordinal 0) from an input value row. */
    +    def toValueInt(inputValueRow: UnsafeRow): Int = inputValueRow.getInt(0)
    +
    +    withJoinStateManager(inputValueAttribs, joinKeyExprs) { manager =>
    +      def append(key: Int, value: Int): Unit = {
    +        manager.append(toJoinKeyRow(key), toInputValue(value))
    +      }
    +
    +      /** Returns the values stored under `key`, sorted so assertions are order-independent. */
    +      def get(key: Int): Seq[Int] = manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
    +
    +      /** Remove keys (and corresponding values) where `time <= threshold` */
    +      def removeByKey(threshold: Long): Unit = {
    +        // The predicate is evaluated against the join key row, so bind to its ordinal 1 ("time").
    +        val expr =
    +          LessThanOrEqual(
    +            BoundReference(
    +              1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable),
    +            Literal(threshold))
    +        manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _)
    +      }
    +
    +      /** Remove values where `time <= threshold` */
    +      def removeByValue(threshold: Long): Unit = {
    +        val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(threshold))
    +        manager.removeByValueCondition(
    +          GeneratePredicate.generate(expr, inputValueAttribs).eval _)
    +      }
    +
    +      // `metrics.numKeys` is asserted below to equal the total number of stored values
    +      // (e.g. 3 after appending three values under the single key 20), not distinct join keys.
    +      def numRows: Long = {
    +        manager.metrics.numKeys
    +      }
    +
    +      assert(get(20) === Seq.empty)     // initially empty
    +      append(20, 2)
    +      assert(get(20) === Seq(2))        // should get first value correctly
    +      assert(numRows === 1)
    +
    +      append(20, 3)
    +      assert(get(20) === Seq(2, 3))     // should append new values
    +      append(20, 3)
    +      assert(get(20) === Seq(2, 3, 3))  // should append another copy if same value added again
    +      assert(numRows === 3)
    +
    +      assert(get(30) === Seq.empty)
    +      append(30, 1)
    +      assert(get(30) === Seq(1))
    +      assert(get(20) === Seq(2, 3, 3))  // adding another key should not affect existing ones
    +      assert(numRows === 4)
    +
    +      removeByKey(25)
    +      assert(get(20) === Seq.empty)
    +      assert(get(30) === Seq(1))        // should remove 20, not 30
    +      assert(numRows === 1)
    +
    +      removeByKey(30)
    +      assert(get(30) === Seq.empty)     // should remove 30
    +      assert(numRows === 0)
    +
    +      /** Appends `values` (expected in sorted order) under `key` and checks they read back. */
    +      def appendAndTest(key: Int, values: Int*): Unit = {
    +        values.foreach { value => append(key, value) }
    +        // assert (not require): a mismatch should surface as a test failure with a diff message,
    +        // not as an IllegalArgumentException.
    +        assert(get(key) === values)
    +      }
    +
    +      appendAndTest(40, 100, 200, 300)
    +      appendAndTest(50, 125)
    +      appendAndTest(60, 275)              // prepare for testing removeByValue
    +      assert(numRows === 5)
    +
    +      removeByValue(125)
    +      assert(get(40) === Seq(200, 300))
    +      assert(get(50) === Seq.empty)
    +      assert(get(60) === Seq(275))        // should remove only some values, not all
    +      assert(numRows === 3)
    +
    +      append(40, 50)
    +      assert(get(40) === Seq(50, 200, 300))
    +      assert(numRows === 4)
    +
    +      removeByValue(200)
    +      assert(get(40) === Seq(300))
    +      assert(get(60) === Seq(275))        // should remove only some values, not all
    +      assert(numRows === 2)
    +
    +      removeByValue(300)
    +      assert(get(40) === Seq.empty)
    +      assert(get(60) === Seq.empty)       // should remove all values now
    +      assert(numRows === 0)
    +    }
    +  }
    +
    +  test("stream stream inner join on non-time column") {
    +    // Inner-joins two MemoryStreams on "key" with no watermark. Each side emits
    +    // (key, key * k), so every joined row is (key, 2 * key, 3 * key).
    +    val input1 = MemoryStream[Int]
    +    val input2 = MemoryStream[Int]
    +
    +    val df1 = input1.toDF.select('value as "key", ('value * 2) as 
"leftValue")
    +    val df2 = input2.toDF.select('value as "key", ('value * 3) as 
"rightValue")
    +    val joined = df1.join(df2, "key")
    +
    +    // The expectations below require rows from earlier batches to remain joinable,
    +    // including across query restarts (StopStream / StartStream).
    +    testStream(joined)(
    +      AddData(input1, 1),
    +      CheckAnswer(),
    +      AddData(input2, 1, 10),       // 1 arrived on input1 first, then 
input2, should join
    +      CheckLastBatch((1, 2, 3)),
    +      AddData(input1, 10),          // 10 arrived on input2 first, then 
input1, should join
    +      CheckLastBatch((10, 20, 30)),
    +      AddData(input2, 1),           // another 1 in input2 should join 
with 1 input1
    +      CheckLastBatch((1, 2, 3)),
    +      // Restart: the two 1s buffered on input2 must survive the restart.
    +      StopStream,
    +      StartStream(),
    +      AddData(input1, 1), // multiple 1s should be kept in state causing 
multiple (1, 2, 3)
    +      CheckLastBatch((1, 2, 3), (1, 2, 3)),
    +      StopStream,
    +      StartStream(),
    +      AddData(input1, 100),
    +      AddData(input2, 100),
    +      CheckLastBatch((100, 200, 300))
    +    )
    +  }
    +
    +
    +  test("stream stream inner join on windows - without watermark") {
    +    val input1 = MemoryStream[Int]
    +    val input2 = MemoryStream[Int]
    +
    +    val df1 = input1.toDF
    +      .select('value as "key", 'value.cast("timestamp") as "timestamp", 
('value * 2) as "leftValue")
    +      .select('key, window('timestamp, "10 second"), 'leftValue)
    +
    +    val df2 = input2.toDF
    +      .select('value as "key", 'value.cast("timestamp") as "timestamp",
    +        ('value * 3) as "rightValue")
    +      .select('key, window('timestamp, "10 second"), 'rightValue)
    +
    +    val joined = df1.join(df2, Seq("key", "window"))
    +      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
    +
    +    testStream(joined)(
    +      AddData(input1, 1),
    +      CheckLastBatch(),
    +      AddData(input2, 1),
    +      CheckLastBatch((1, 10, 2, 3)),
    +      StopStream,
    +      StartStream(),
    +      AddData(input1, 25),
    +      CheckLastBatch(),
    +      StopStream,
    +      StartStream(),
    +      AddData(input2, 25),
    +      CheckLastBatch((25, 30, 50, 75)),
    +      AddData(input1, 1),
    +      CheckLastBatch((1, 10, 2, 3)),      // State for 1 still around as 
there is not watermark
    --- End diff --
    
    nit: `there is not watermark` should read `there is no watermark`


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to