This is an automated email from the ASF dual-hosted git repository.

tdas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dfd7b02  [SPARK-35800][SS] Improving GroupState testability by 
introducing TestGroupState
dfd7b02 is described below

commit dfd7b026dc7c3c38bef9afab82852aff902a25d2
Author: Li Zhang <li.zh...@databricks.com>
AuthorDate: Tue Jun 22 15:04:01 2021 -0400

    [SPARK-35800][SS] Improving GroupState testability by introducing 
TestGroupState
    
    ### What changes were proposed in this pull request?
    Proposed changes in this pull request:
    
    1. Introducing the `TestGroupState` interface, which extends 
`GroupState`, so that test-related getters can be exposed in a controlled 
manner
    2. Changing `GroupStateImpl` to inherit from `TestGroupState` interface, 
instead of directly from `GroupState`
    3. Implementing the `TestGroupState` companion object with a `create()` 
method that forwards inputs to the private `GroupStateImpl` constructor
    4. User input validations have been added to `GroupStateImpl`'s 
`createForStreaming()` method to prevent users from creating invalid GroupState 
objects.
    5. Replacing existing `GroupStateImpl` usages in the sql package's internal 
unit tests with the newly added `TestGroupState` to demonstrate best practices 
for `TestGroupState` usage.
    
    With the changes in this PR, the class hierarchy is changed from 
`GroupStateImpl` -> `GroupState` to `GroupStateImpl` -> `TestGroupState` -> 
`GroupState` (-> means inherits from)
    
    ### Why are the changes needed?
    The internal `GroupStateImpl` implementation for the `GroupState` interface 
has no public constructors accessible outside of the sql pkg. However, the 
user-provided state transition function for `[map|flatMap]GroupsWithState` 
requires a `GroupState` object as the prevState input.
    
    Currently, users have to run the Structured Streaming engine in their unit 
tests just to obtain such `GroupState` instances, which makes these unit tests 
cumbersome.
    
    The proposed `TestGroupState` interface gives users controlled access to 
the `GroupStateImpl` internal implementation, which greatly improves the 
testability of Structured Streaming state transition functions.
    
    **Usage Example**
    ```
    import org.apache.spark.api.java.Optional
    import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState}
    
    test("Structured Streaming state update function") {
      val prevState = TestGroupState.create[UserStatus](
        optionalState = Optional.empty[UserStatus],
        timeoutConf = GroupStateTimeout.EventTimeTimeout,
        batchProcessingTimeMs = 1L,
        eventTimeWatermarkMs = Optional.of(1L),
        hasTimedOut = false)
    
      val userId: String = ...
      val actions: Iterator[UserAction] = ...
    
      assert(!prevState.isUpdated)
    
      updateState(userId, actions, prevState)
    
      assert(prevState.isUpdated)
    }
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the `TestGroupState` interface and its corresponding `create()` 
factory function in its companion object are introduced in this pull request 
for users to use in unit tests.
    
    ### How was this patch tested?
    - New unit tests are added
    - Existing GroupState unit tests are updated
    
    Closes #32938 from lizhangdatabricks/improve-group-state-testability.
    
    Authored-by: Li Zhang <li.zh...@databricks.com>
    Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com>
---
 .../streaming/FlatMapGroupsWithStateExec.scala     |   8 +-
 .../sql/execution/streaming/GroupStateImpl.scala   |  36 ++-
 .../state/FlatMapGroupsWithStateExecHelper.scala   |   5 +-
 .../spark/sql/streaming/TestGroupState.scala       | 173 ++++++++++++++
 .../org/apache/spark/sql/JavaDatasetSuite.java     |  92 ++++++++
 .../streaming/FlatMapGroupsWithStateSuite.scala    | 255 +++++++++++++++------
 6 files changed, 475 insertions(+), 94 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index e626fc1..981586e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -60,8 +60,8 @@ case class FlatMapGroupsWithStateExec(
     child: SparkPlan
   ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with 
WatermarkSupport {
 
-  import GroupStateImpl._
   import FlatMapGroupsWithStateExecHelper._
+  import GroupStateImpl._
 
   private val isTimeoutEnabled = timeoutConf != NoTimeout
   private val watermarkPresent = child.output.exists {
@@ -229,13 +229,13 @@ case class FlatMapGroupsWithStateExec(
 
       // When the iterator is consumed, then write changes to state
       def onIteratorCompletion: Unit = {
-        if (groupState.hasRemoved && groupState.getTimeoutTimestamp == 
NO_TIMESTAMP) {
+        if (groupState.isRemoved && 
!groupState.getTimeoutTimestampMs.isPresent()) {
           stateManager.removeState(store, stateData.keyRow)
           numUpdatedStateRows += 1
         } else {
-          val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
+          val currentTimeoutTimestamp = 
groupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP)
           val hasTimeoutChanged = currentTimeoutTimestamp != 
stateData.timeoutTimestamp
-          val shouldWriteState = groupState.hasUpdated || 
groupState.hasRemoved || hasTimeoutChanged
+          val shouldWriteState = groupState.isUpdated || groupState.isRemoved 
|| hasTimeoutChanged
 
           if (shouldWriteState) {
             val updatedStateObj = if (groupState.exists) groupState.get else 
null
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
index 25756c2..b4f3712 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala
@@ -20,16 +20,16 @@ package org.apache.spark.sql.execution.streaming
 import java.sql.Date
 import java.util.concurrent.TimeUnit
 
-import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, 
ProcessingTimeTimeout}
+import org.apache.spark.api.java.Optional
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, 
NoTimeout, ProcessingTimeTimeout}
 import org.apache.spark.sql.catalyst.util.IntervalUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.streaming.GroupStateImpl._
-import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
+import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState}
 import org.apache.spark.unsafe.types.UTF8String
 
-
 /**
- * Internal implementation of the [[GroupState]] interface. Methods are not 
thread-safe.
+ * Internal implementation of the [[TestGroupState]] interface. Methods are 
not thread-safe.
  *
  * @param optionalValue Optional value of the state
  * @param batchProcessingTimeMs Processing time of current batch, used to 
calculate timestamp
@@ -45,7 +45,7 @@ private[sql] class GroupStateImpl[S] private(
     eventTimeWatermarkMs: Long,
     timeoutConf: GroupStateTimeout,
     override val hasTimedOut: Boolean,
-    watermarkPresent: Boolean) extends GroupState[S] {
+    watermarkPresent: Boolean) extends TestGroupState[S] {
 
   private var value: S = optionalValue.getOrElse(null.asInstanceOf[S])
   private var defined: Boolean = optionalValue.isDefined
@@ -147,14 +147,17 @@ private[sql] class GroupStateImpl[S] private(
 
   // ========= Internal API =========
 
-  /** Whether the state has been marked for removing */
-  def hasRemoved: Boolean = removed
+  override def isRemoved: Boolean = removed
 
-  /** Whether the state has been updated */
-  def hasUpdated: Boolean = updated
+  override def isUpdated: Boolean = updated
 
-  /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */
-  def getTimeoutTimestamp: Long = timeoutTimestamp
+  override def getTimeoutTimestampMs: Optional[Long] = {
+    if (timeoutTimestamp != NO_TIMESTAMP) {
+      Optional.of(timeoutTimestamp)
+    } else {
+      Optional.empty[Long]
+    }
+  }
 
   private def parseDuration(duration: String): Long = {
     val cal = IntervalUtils.stringToInterval(UTF8String.fromString(duration))
@@ -184,6 +187,17 @@ private[sql] object GroupStateImpl {
       timeoutConf: GroupStateTimeout,
       hasTimedOut: Boolean,
       watermarkPresent: Boolean): GroupStateImpl[S] = {
+    if (batchProcessingTimeMs < 0) {
+      throw new IllegalArgumentException("batchProcessingTimeMs must be 0 or 
positive")
+    }
+    if (watermarkPresent && eventTimeWatermarkMs < 0) {
+      throw new IllegalArgumentException("eventTimeWatermarkMs must be 0 or 
positive if present")
+    }
+    if (hasTimedOut && timeoutConf == NoTimeout) {
+      throw new UnsupportedOperationException(
+        "hasTimedOut is true however there's no timeout configured")
+    }
+
     new GroupStateImpl[S](
       optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs,
       timeoutConf, hasTimedOut, watermarkPresent)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
index cc785ee..2d9824e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.ObjectOperator
-import org.apache.spark.sql.execution.streaming.GroupStateImpl
 import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
 import org.apache.spark.sql.types._
 
@@ -168,7 +167,7 @@ object FlatMapGroupsWithStateExecHelper {
     override val stateSerializerExprs: Seq[Expression] = {
       val encoderSerializer = stateEncoder.namedExpressions
       if (shouldStoreTimestamp) {
-        encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
+        encoderSerializer :+ Literal(NO_TIMESTAMP)
       } else {
         encoderSerializer
       }
@@ -226,7 +225,7 @@ object FlatMapGroupsWithStateExecHelper {
       }
 
       if (shouldStoreTimestamp) {
-        Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP))
+        Seq(nullSafeNestedStateSerExpr, Literal(NO_TIMESTAMP))
       } else {
         Seq(nullSafeNestedStateSerExpr)
       }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala
new file mode 100644
index 0000000..d53d608
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.annotation.{Evolving, Experimental}
+import org.apache.spark.api.java.Optional
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.execution.streaming.GroupStateImpl._
+
+/**
+ * :: Experimental ::
+ *
+ * The extended version of [[GroupState]] interface with extra getters of 
state machine fields
+ * to improve testability of the [[GroupState]] implementations
+ * which inherit from the extended interface.
+ *
+ * Scala example of using `TestGroupState`:
+ * {{{
+ * // Please refer to ScalaDoc of `GroupState` for the Scala definition of 
`mappingFunction()`
+ *
+ * import org.apache.spark.api.java.Optional
+ * import org.apache.spark.sql.streaming.GroupStateTimeout
+ * import org.apache.spark.sql.streaming.TestGroupState
+ * // other imports
+ *
+ * // test class setups
+ *
+ * test("MapGroupsWithState state transition function") {
+ *   // Creates the prevState input for the state transition function
+ *   // with desired configs. The `create()` API would guarantee that
+ *   // the generated instance has the same behavior as the one built by
+ *   // engine with the same configs.
+ *   val prevState = TestGroupState.create[Int](
+ *     optionalState = Optional.empty[Int],
+ *     timeoutConf = GroupStateTimeout.NoTimeout,
+ *     batchProcessingTimeMs = 1L,
+ *     eventTimeWatermarkMs = Optional.of(1L),
+ *     hasTimedOut = false)
+ *
+ *   val key: String = ...
+ *   val values: Iterator[Int] = ...
+ *
+ *   // Asserts the prevState is in init state without updates.
+ *   assert(!prevState.isUpdated)
+ *
+ *   // Calls the state transition function with the test previous state
+ *   // with desired configs.
+ *   mappingFunction(key, values, prevState)
+ *
+ *   // Asserts the test GroupState object has been updated but not removed
+ *   // after calling the state transition function
+ *   assert(prevState.isUpdated)
+ *   assert(!prevState.isRemoved)
+ * }
+ * }}}
+ *
+ * Java example of using `TestGroupState`:
+ * {{{
+ * // Please refer to ScalaDoc of `GroupState` for the Java definition of 
`mappingFunction()`
+ *
+ * import org.apache.spark.api.java.Optional;
+ * import org.apache.spark.sql.streaming.GroupStateTimeout;
+ * import org.apache.spark.sql.streaming.TestGroupState;
+ * // other imports
+ *
+ * // test class setups
+ *
+ * // test `MapGroupsWithState` state transition function `mappingFunction()`
+ * public void testMappingFunctionWithTestGroupState() {
+ *   // Creates the prevState input for the state transition function
+ *   // with desired configs. The `create()` API would guarantee that
+ *   // the generated instance has the same behavior as the one built by
+ *   // engine with the same configs.
+ *   TestGroupState<Integer> prevState = TestGroupState.create(
+ *     Optional.empty(),
+ *     GroupStateTimeout.NoTimeout(),
+ *     1L,
+ *     Optional.of(1L),
+ *     false);
+ *
+ *   String key = ...;
+ *   Integer[] values = ...;
+ *
+ *   // Asserts the prevState is in init state without updates.
+ *   Assert.assertFalse(prevState.isUpdated());
+ *
+ *   // Calls the state transition function with the test previous state
+ *   // with desired configs.
+ *   mappingFunction.call(key, Arrays.asList(values).iterator(), prevState);
+ *
+ *   // Asserts the test GroupState object has been updated but not removed
+ *   // after calling the state transition function
+ *   Assert.assertTrue(prevState.isUpdated());
+ *   Assert.assertFalse(prevState.isRemoved());
+ * }
+ * }}}
+ *
+ * @tparam S User-defined type of the state to be stored for each group. Must 
be encodable into
+ *           Spark SQL types (see `Encoder` for more details).
+ * @since 3.2.0
+ */
+@Experimental
+@Evolving
+trait TestGroupState[S] extends GroupState[S] {
+  /** Whether the state has been marked for removing */
+  def isRemoved: Boolean
+
+  /** Whether the state has been updated but not removed */
+  def isUpdated: Boolean
+
+  /**
+   * Returns the timestamp if `setTimeoutTimestamp()` is called.
+   * Or, returns batch processing time + the duration when
+   * `setTimeoutDuration()` is called.
+   *
+   * Otherwise, returns `Optional.empty` if not set.
+   */
+  def getTimeoutTimestampMs: Optional[Long]
+}
+
+object TestGroupState {
+
+  /**
+   * Creates TestGroupState instances for general testing purposes.
+   *
+   * @param optionalState         Optional value of the state.
+   * @param timeoutConf           Type of timeout configured. Based on this, 
different operations
+   *                              will be supported.
+   * @param batchProcessingTimeMs Processing time of current batch, used to 
calculate timestamp
+   *                              for processing time timeouts.
+   * @param eventTimeWatermarkMs  Optional value of event time watermark in 
ms. Set as
+   *                              `Optional.empty` if watermark is not present.
+   *                              Otherwise, event time watermark should be a 
non-negative long
+   *                              and the timestampMs set through 
`setTimeoutTimestamp()`
+   *                              cannot be less than `eventTimeWatermarkMs`.
+   * @param hasTimedOut           Whether the key for which this state wrapped 
is being created is
+   *                              getting timed out or not.
+   * @return a [[TestGroupState]] instance built with the user specified 
configs.
+   */
+  @throws[IllegalArgumentException]("if 'batchProcessingTimeMs' is less than 
0")
+  @throws[IllegalArgumentException]("if 'eventTimeWatermarkMs' is present but 
less than 0")
+  @throws[UnsupportedOperationException](
+    "if 'hasTimedOut' is true however there's no timeout configured")
+  def create[S](
+      optionalState: Optional[S],
+      timeoutConf: GroupStateTimeout,
+      batchProcessingTimeMs: Long,
+      eventTimeWatermarkMs: Optional[Long],
+      hasTimedOut: Boolean): TestGroupState[S] = {
+    GroupStateImpl.createForStreaming[S](
+      Option(optionalState.orNull),
+      batchProcessingTimeMs,
+      eventTimeWatermarkMs.orElse(NO_TIMESTAMP),
+      timeoutConf,
+      hasTimedOut,
+      eventTimeWatermarkMs.isPresent())
+  }
+}
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 645c9e9..5e48dc6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -25,6 +25,7 @@ import java.time.*;
 import java.util.*;
 import javax.annotation.Nonnull;
 
+import org.apache.spark.api.java.Optional;
 import org.apache.spark.sql.streaming.GroupStateTimeout;
 import org.apache.spark.sql.streaming.OutputMode;
 import scala.Tuple2;
@@ -33,6 +34,7 @@ import scala.Tuple4;
 import scala.Tuple5;
 
 import com.google.common.base.Objects;
+import org.apache.spark.sql.streaming.TestGroupState;
 import org.junit.*;
 
 import org.apache.spark.api.java.JavaPairRDD;
@@ -159,6 +161,96 @@ public class JavaDatasetSuite implements Serializable {
   }
 
   @Test
+  public void testIllegalTestGroupStateCreations() {
+    // SPARK-35800: test code throws upon illegal TestGroupState create() calls
+    Assert.assertThrows(
+      "eventTimeWatermarkMs must be 0 or positive if present",
+      IllegalArgumentException.class,
+      () -> {
+        TestGroupState.create(
+          Optional.of(5), GroupStateTimeout.EventTimeTimeout(), 0L, 
Optional.of(-1000L), false);
+      });
+
+    Assert.assertThrows(
+      "batchProcessingTimeMs must be 0 or positive",
+      IllegalArgumentException.class,
+      () -> {
+        TestGroupState.create(
+          Optional.of(5), GroupStateTimeout.EventTimeTimeout(), -100L, 
Optional.of(1000L), false);
+      });
+
+    Assert.assertThrows(
+      "hasTimedOut is true however there's no timeout configured",
+      UnsupportedOperationException.class,
+      () -> {
+        TestGroupState.create(
+          Optional.of(5), GroupStateTimeout.NoTimeout(), 100L, 
Optional.empty(), true);
+      });
+  }
+
+  @Test
+  public void testMappingFunctionWithTestGroupState() throws Exception {
+    // SPARK-35800: test the mapping function with injected TestGroupState 
instance
+    MapGroupsWithStateFunction<Integer, Integer, Integer, Integer> 
mappingFunction =
+      (MapGroupsWithStateFunction<Integer, Integer, Integer, Integer>) (key, 
values, state) -> {
+        if (state.hasTimedOut()) {
+          state.remove();
+          return 0;
+        }
+
+        int existingState = 0;
+        if (state.exists()) {
+          existingState = state.get();
+        } else {
+          // Set state timeout timestamp upon initialization
+          state.setTimeoutTimestamp(1500L);
+        }
+
+        while (values.hasNext()) {
+          existingState += values.next();
+        }
+        state.update(existingState);
+
+        return state.get();
+    };
+
+    TestGroupState<Integer> prevState = TestGroupState.create(
+      Optional.empty(), GroupStateTimeout.EventTimeTimeout(), 0L, 
Optional.of(1000L), false);
+
+    Assert.assertFalse(prevState.isUpdated());
+    Assert.assertFalse(prevState.isRemoved());
+    Assert.assertFalse(prevState.exists());
+    Assert.assertEquals(Optional.empty(), prevState.getTimeoutTimestampMs());
+
+    Integer[] values = {1, 3, 5};
+    mappingFunction.call(1, Arrays.asList(values).iterator(), prevState);
+
+    Assert.assertTrue(prevState.isUpdated());
+    Assert.assertFalse(prevState.isRemoved());
+    Assert.assertTrue(prevState.exists());
+    Assert.assertEquals(new Integer(9), prevState.get());
+    Assert.assertEquals(0L, prevState.getCurrentProcessingTimeMs());
+    Assert.assertEquals(1000L, prevState.getCurrentWatermarkMs());
+    Assert.assertEquals(Optional.of(1500L), prevState.getTimeoutTimestampMs());
+
+    mappingFunction.call(1, Arrays.asList(values).iterator(), prevState);
+
+    Assert.assertTrue(prevState.isUpdated());
+    Assert.assertFalse(prevState.isRemoved());
+    Assert.assertTrue(prevState.exists());
+    Assert.assertEquals(new Integer(18), prevState.get());
+
+    prevState = TestGroupState.create(
+      Optional.of(9), GroupStateTimeout.EventTimeTimeout(), 0L, 
Optional.of(1000L), true);
+
+    mappingFunction.call(1, Arrays.asList(values).iterator(), prevState);
+
+    Assert.assertFalse(prevState.isUpdated());
+    Assert.assertTrue(prevState.isRemoved());
+    Assert.assertFalse(prevState.exists());
+  }
+
+  @Test
   public void testGroupBy() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 788be53..ad12d0d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -24,6 +24,7 @@ import org.apache.commons.io.FileUtils
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.SparkException
+import org.apache.spark.api.java.Optional
 import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
 import org.apache.spark.sql.{DataFrame, Encoder}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -48,12 +49,43 @@ case class Result(key: Long, count: Int)
 class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
 
   import testImplicits._
+
+  import FlatMapGroupsWithStateSuite._
   import GroupStateImpl._
   import GroupStateTimeout._
-  import FlatMapGroupsWithStateSuite._
+
+  test("SPARK-35800: ensure TestGroupState creates instances the same as 
prod") {
+    val testState = TestGroupState.create[Int](
+      Optional.of(5), EventTimeTimeout, 1L, Optional.of(1L), hasTimedOut = 
false)
+
+    val prodState = GroupStateImpl.createForStreaming[Int](
+      Some(5), 1L, 1L, EventTimeTimeout, false, true)
+
+    assert(testState.isInstanceOf[GroupStateImpl[Int]])
+
+    assert(testState.isRemoved === prodState.isRemoved)
+    assert(testState.isUpdated === prodState.isUpdated)
+    assert(testState.exists === prodState.exists)
+    assert(testState.get === prodState.get)
+    assert(testState.getTimeoutTimestampMs === prodState.getTimeoutTimestampMs)
+    assert(testState.hasTimedOut === prodState.hasTimedOut)
+    assert(testState.getCurrentProcessingTimeMs === 
prodState.getCurrentProcessingTimeMs)
+    assert(testState.getCurrentWatermarkMs === prodState.getCurrentWatermarkMs)
+
+    testState.update(6)
+    prodState.update(6)
+    assert(testState.isUpdated === prodState.isUpdated)
+    assert(testState.exists === prodState.exists)
+    assert(testState.get === prodState.get)
+
+    testState.remove()
+    prodState.remove()
+    assert(testState.exists === prodState.exists)
+    assert(testState.isRemoved === prodState.isRemoved)
+  }
 
   test("GroupState - get, exists, update, remove") {
-    var state: GroupStateImpl[String] = null
+    var state: TestGroupState[String] = null
 
     def testState(
         expectedData: Option[String],
@@ -69,21 +101,21 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
         }
       }
       assert(state.getOption === expectedData)
-      assert(state.hasUpdated === shouldBeUpdated)
-      assert(state.hasRemoved === shouldBeRemoved)
+      assert(state.isUpdated === shouldBeUpdated)
+      assert(state.isRemoved === shouldBeRemoved)
     }
 
     // === Tests for state in streaming queries ===
     // Updating empty state
-    state = GroupStateImpl.createForStreaming(
-      None, 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false)
+    state = TestGroupState.create[String](
+      Optional.empty[String], NoTimeout, 1, Optional.empty[Long], hasTimedOut 
= false)
     testState(None)
     state.update("")
     testState(Some(""), shouldBeUpdated = true)
 
     // Updating exiting state
-    state = GroupStateImpl.createForStreaming(
-      Some("2"), 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = 
false)
+    state = TestGroupState.create[String](
+      Optional.of("2"), NoTimeout, 1, Optional.empty[Long], hasTimedOut = 
false)
     testState(Some("2"))
     state.update("3")
     testState(Some("3"), shouldBeUpdated = true)
@@ -102,10 +134,10 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
   }
 
   test("GroupState - setTimeout - with NoTimeout") {
-    for (initValue <- Seq(None, Some(5))) {
+    for (initValue <- Seq(Optional.empty[Int], Optional.of((5)))) {
       val states = Seq(
-        GroupStateImpl.createForStreaming(
-          initValue, 1000, 1000, NoTimeout, hasTimedOut = false, 
watermarkPresent = false),
+        TestGroupState.create[Int](
+          initValue, NoTimeout, 1000, Optional.empty[Long], hasTimedOut = 
false),
         GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false)
       )
       for (state <- states) {
@@ -122,33 +154,36 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
 
   test("GroupState - setTimeout - with ProcessingTimeTimeout") {
     // for streaming queries
-    var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
-      None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, 
watermarkPresent = false)
-    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    var state = TestGroupState.create[Int](
+      Optional.empty[Int], ProcessingTimeTimeout, 1000, Optional.empty[Long], 
hasTimedOut = false)
+    assert(!state.getTimeoutTimestampMs.isPresent())
     state.setTimeoutDuration("-1 month 31 days 1 second")
-    assert(state.getTimeoutTimestamp === 2000)
+    assert(state.getTimeoutTimestampMs.isPresent())
+    assert(state.getTimeoutTimestampMs.get() === 2000)
     state.setTimeoutDuration(500)
-    assert(state.getTimeoutTimestamp === 1500) // can be set without 
initializing state
+    assert(state.getTimeoutTimestampMs.get() === 1500) // can be set without 
initializing state
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
     state.update(5)
-    assert(state.getTimeoutTimestamp === 1500) // does not change
+    assert(state.getTimeoutTimestampMs.isPresent())
+    assert(state.getTimeoutTimestampMs.get() === 1500) // does not change
     state.setTimeoutDuration(1000)
-    assert(state.getTimeoutTimestamp === 2000)
+    assert(state.getTimeoutTimestampMs.get() === 2000)
     state.setTimeoutDuration("2 second")
-    assert(state.getTimeoutTimestamp === 3000)
+    assert(state.getTimeoutTimestampMs.get() === 3000)
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
     state.remove()
-    assert(state.getTimeoutTimestamp === 3000) // does not change
+    assert(state.getTimeoutTimestampMs.isPresent())
+    assert(state.getTimeoutTimestampMs.get() === 3000) // does not change
     state.setTimeoutDuration(500) // can still be set
-    assert(state.getTimeoutTimestamp === 1500)
+    assert(state.getTimeoutTimestampMs.get() === 1500)
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
     // for batch queries
     state = GroupStateImpl.createForBatch(
       ProcessingTimeTimeout, watermarkPresent = 
false).asInstanceOf[GroupStateImpl[Int]]
-    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    assert(!state.getTimeoutTimestampMs.isPresent())
     state.setTimeoutDuration(500)
     testTimeoutTimestampNotAllowed[UnsupportedOperationException](state)
 
@@ -163,32 +198,31 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
   }
 
   test("GroupState - setTimeout - with EventTimeTimeout") {
-    var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming(
-      None, 1000, 1000, EventTimeTimeout, false, watermarkPresent = true)
-
-    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    var state = TestGroupState.create[Int](
+        Optional.empty[Int], EventTimeTimeout, 1000, Optional.of(1000), 
hasTimedOut = false)
+    assert(!state.getTimeoutTimestampMs.isPresent())
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
     state.setTimeoutTimestamp(5000)
-    assert(state.getTimeoutTimestamp === 5000) // can be set without 
initializing state
+    assert(state.getTimeoutTimestampMs.get() === 5000) // can be set without 
initializing state
 
     state.update(5)
-    assert(state.getTimeoutTimestamp === 5000) // does not change
+    assert(state.getTimeoutTimestampMs.get() === 5000) // does not change
     state.setTimeoutTimestamp(10000)
-    assert(state.getTimeoutTimestamp === 10000)
+    assert(state.getTimeoutTimestampMs.get() === 10000)
     state.setTimeoutTimestamp(new Date(20000))
-    assert(state.getTimeoutTimestamp === 20000)
+    assert(state.getTimeoutTimestampMs.get() === 20000)
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
 
     state.remove()
-    assert(state.getTimeoutTimestamp === 20000)
+    assert(state.getTimeoutTimestampMs.get() === 20000)
     state.setTimeoutTimestamp(5000)
-    assert(state.getTimeoutTimestamp === 5000) // can be set after removing 
state
+    assert(state.getTimeoutTimestampMs.get() === 5000) // can be set after 
removing state
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
 
     // for batch queries
-    state = GroupStateImpl.createForBatch(EventTimeTimeout, watermarkPresent = 
false)
-      .asInstanceOf[GroupStateImpl[Int]]
-    assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+    state = GroupStateImpl.createForBatch(
+      EventTimeTimeout, watermarkPresent = 
false).asInstanceOf[GroupStateImpl[Int]]
+    assert(!state.getTimeoutTimestampMs.isPresent())
     testTimeoutDurationNotAllowed[UnsupportedOperationException](state)
     state.setTimeoutTimestamp(5000)
 
@@ -203,18 +237,20 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
   }
 
   test("GroupState - illegal params to setTimeout") {
-    var state: GroupStateImpl[Int] = null
+    var state: TestGroupState[Int] = null
 
-    // Test setTimeout****() with illegal values
+    // Test setTimeout() with illegal values
     def testIllegalTimeout(body: => Unit): Unit = {
       intercept[IllegalArgumentException] {
         body
       }
-      assert(state.getTimeoutTimestamp === NO_TIMESTAMP)
+      assert(!state.getTimeoutTimestampMs.isPresent())
     }
 
-    state = GroupStateImpl.createForStreaming(
-      Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, 
watermarkPresent = false)
+    // Test setTimeout() with illegal values
+    state = TestGroupState.create[Int](
+      Optional.of(5), ProcessingTimeTimeout, 1000, Optional.empty[Long], 
hasTimedOut = false)
+
     testIllegalTimeout {
       state.setTimeoutDuration(-1000)
     }
@@ -232,8 +268,8 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       state.setTimeoutDuration("1 month -31 day")
     }
 
-    state = GroupStateImpl.createForStreaming(
-      Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false, 
watermarkPresent = false)
+    state = TestGroupState.create[Int](
+        Optional.of(5), EventTimeTimeout, 1000, Optional.of(1000), hasTimedOut 
= false)
     testIllegalTimeout {
       state.setTimeoutTimestamp(-10000)
     }
@@ -260,17 +296,64 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
     }
   }
 
+  test("SPARK-35800: illegal params to create") {
+    // eventTimeWatermarkMs >= 0 if present
+    var illegalArgument = intercept[IllegalArgumentException] {
+      TestGroupState.create[Int](
+        Optional.of(5), EventTimeTimeout, 100L, Optional.of(-1000), 
hasTimedOut = false)
+    }
+    assert(
+      illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or 
positive if present"))
+    illegalArgument = intercept[IllegalArgumentException] {
+      GroupStateImpl.createForStreaming[Int](
+        Some(5), 100L, -1000L, EventTimeTimeout, false, true)
+    }
+    assert(
+      illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or 
positive if present"))
+
+    // batchProcessingTimeMs must be positive
+    illegalArgument = intercept[IllegalArgumentException] {
+      TestGroupState.create[Int](
+        Optional.of(5), EventTimeTimeout, -100L, Optional.of(1000), 
hasTimedOut = false)
+    }
+    assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 
0 or positive"))
+    illegalArgument = intercept[IllegalArgumentException] {
+      GroupStateImpl.createForStreaming[Int](
+        Some(5), -100L, 1000L, EventTimeTimeout, false, true)
+    }
+    assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 
0 or positive"))
+
+    // hasTimedOut cannot be true if there's no timeout configured
+    var unsupportedOperation = intercept[UnsupportedOperationException] {
+      TestGroupState.create[Int](
+        Optional.of(5), NoTimeout, 100L, Optional.empty[Long], hasTimedOut = 
true)
+    }
+    assert(
+      unsupportedOperation
+        .getMessage.contains("hasTimedOut is true however there's no timeout 
configured"))
+    unsupportedOperation = intercept[UnsupportedOperationException] {
+      GroupStateImpl.createForStreaming[Int](
+        Some(5), 100L, NO_TIMESTAMP, NoTimeout, true, false)
+    }
+    assert(
+      unsupportedOperation
+        .getMessage.contains("hasTimedOut is true however there's no timeout 
configured"))
+  }
+
   test("GroupState - hasTimedOut") {
     for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, 
EventTimeTimeout)) {
       // for streaming queries
-      for (initState <- Seq(None, Some(5))) {
-        val state1 = GroupStateImpl.createForStreaming(
-          initState, 1000, 1000, timeoutConf, hasTimedOut = false, 
watermarkPresent = false)
+      for (initState <- Seq(Optional.empty[Int], Optional.of(5))) {
+        val state1 = TestGroupState.create[Int](
+          initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = 
false)
         assert(state1.hasTimedOut === false)
 
-        val state2 = GroupStateImpl.createForStreaming(
-          initState, 1000, 1000, timeoutConf, hasTimedOut = true, 
watermarkPresent = false)
-        assert(state2.hasTimedOut)
+        // hasTimedOut can only be set as true when timeoutConf isn't NoTimeout
+        if (timeoutConf != NoTimeout) {
+          val state2 = TestGroupState.create[Int](
+            initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = 
true)
+          assert(state2.hasTimedOut)
+        }
       }
 
       // for batch queries
@@ -280,10 +363,11 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
   }
 
   test("GroupState - getCurrentWatermarkMs") {
-    def streamingState(timeoutConf: GroupStateTimeout, watermark: 
Option[Long]): GroupState[Int] = {
-      GroupStateImpl.createForStreaming(
-        None, 1000, watermark.getOrElse(-1), timeoutConf,
-        hasTimedOut = false, watermark.nonEmpty)
+    def streamingState(
+        timeoutConf: GroupStateTimeout,
+        watermark: Optional[Long]): GroupState[Int] = {
+      TestGroupState.create[Int](
+        Optional.empty[Int], timeoutConf, 1000, watermark, hasTimedOut = false)
     }
 
     def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): 
GroupState[Any] = {
@@ -298,9 +382,13 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
 
     for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, 
ProcessingTimeTimeout)) {
       // Tests for getCurrentWatermarkMs in streaming queries
-      assertWrongTimeoutError { streamingState(timeoutConf, 
None).getCurrentWatermarkMs() }
-      assert(streamingState(timeoutConf, Some(1000)).getCurrentWatermarkMs() 
=== 1000)
-      assert(streamingState(timeoutConf, Some(2000)).getCurrentWatermarkMs() 
=== 2000)
+      assertWrongTimeoutError {
+        streamingState(timeoutConf, 
Optional.empty[Long]).getCurrentWatermarkMs()
+      }
+      assert(streamingState(timeoutConf, 
Optional.of(0)).getCurrentWatermarkMs() === 0)
+      assert(streamingState(timeoutConf, 
Optional.of(1000)).getCurrentWatermarkMs() === 1000)
+      assert(streamingState(timeoutConf, 
Optional.of(2000)).getCurrentWatermarkMs() === 2000)
+      assert(batchState(EventTimeTimeout, watermarkPresent = 
true).getCurrentWatermarkMs() === -1)
 
       // Tests for getCurrentWatermarkMs in batch queries
       assertWrongTimeoutError {
@@ -312,11 +400,15 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
 
   test("GroupState - getCurrentProcessingTimeMs") {
     def streamingState(
-        timeoutConf: GroupStateTimeout,
-        procTime: Long,
-        watermarkPresent: Boolean): GroupState[Int] = {
-      GroupStateImpl.createForStreaming(
-        None, procTime, -1, timeoutConf, hasTimedOut = false, watermarkPresent 
= false)
+      timeoutConf: GroupStateTimeout,
+      procTime: Long,
+      watermarkPresent: Boolean): GroupState[Int] = {
+      val eventTimeWatermarkMs = watermarkPresent match {
+        case true => Optional.of(1000L)
+        case false => Optional.empty[Long]
+      }
+      TestGroupState.create[Int](
+        Optional.of(1000), timeoutConf, procTime, eventTimeWatermarkMs, 
hasTimedOut = false)
     }
 
     def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): 
GroupState[Any] = {
@@ -326,8 +418,10 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
     for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, 
ProcessingTimeTimeout)) {
       for (watermarkPresent <- Seq(false, true)) {
         // Tests for getCurrentProcessingTimeMs in streaming queries
-        assert(streamingState(timeoutConf, NO_TIMESTAMP, watermarkPresent)
-            .getCurrentProcessingTimeMs() === -1)
+        // No negative processing time is allowed, and
+        // illegal input validation has been added in the separate test
+        assert(streamingState(timeoutConf, 0, watermarkPresent)
+          .getCurrentProcessingTimeMs() === 0)
         assert(streamingState(timeoutConf, 1000, watermarkPresent)
           .getCurrentProcessingTimeMs() === 1000)
         assert(streamingState(timeoutConf, 2000, watermarkPresent)
@@ -342,15 +436,24 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
 
 
   test("GroupState - primitive type") {
-    var intState = GroupStateImpl.createForStreaming[Int](
-      None, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = 
false)
+    var intState = TestGroupState.create[Int](
+      Optional.empty[Int],
+      NoTimeout,
+      1000,
+      Optional.empty[Long],
+      hasTimedOut = false)
     intercept[NoSuchElementException] {
       intState.get
     }
     assert(intState.getOption === None)
 
-    intState = GroupStateImpl.createForStreaming[Int](
-      Some(10), 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = 
false)
+    intState = TestGroupState.create[Int](
+      Optional.of(10),
+      NoTimeout,
+      1000,
+      Optional.empty[Long],
+      hasTimedOut = false)
+
     assert(intState.get == 10)
     intState.update(0)
     assert(intState.get == 0)
@@ -1291,24 +1394,24 @@ class FlatMapGroupsWithStateSuite extends 
StateStoreMetricsTest {
       }.get
   }
 
-  def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: 
GroupStateImpl[_]): Unit = {
-    val prevTimestamp = state.getTimeoutTimestamp
+  def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: 
TestGroupState[_]): Unit = {
+    val prevTimestamp = state.getTimeoutTimestampMs
     intercept[T] { state.setTimeoutDuration(1000) }
-    assert(state.getTimeoutTimestamp === prevTimestamp)
+    assert(state.getTimeoutTimestampMs === prevTimestamp)
     intercept[T] { state.setTimeoutDuration("2 second") }
-    assert(state.getTimeoutTimestamp === prevTimestamp)
+    assert(state.getTimeoutTimestampMs === prevTimestamp)
   }
 
-  def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: 
GroupStateImpl[_]): Unit = {
-    val prevTimestamp = state.getTimeoutTimestamp
+  def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: 
TestGroupState[_]): Unit = {
+    val prevTimestamp = state.getTimeoutTimestampMs
     intercept[T] { state.setTimeoutTimestamp(2000) }
-    assert(state.getTimeoutTimestamp === prevTimestamp)
+    assert(state.getTimeoutTimestampMs === prevTimestamp)
     intercept[T] { state.setTimeoutTimestamp(2000, "1 second") }
-    assert(state.getTimeoutTimestamp === prevTimestamp)
+    assert(state.getTimeoutTimestampMs === prevTimestamp)
     intercept[T] { state.setTimeoutTimestamp(new Date(2000)) }
-    assert(state.getTimeoutTimestamp === prevTimestamp)
+    assert(state.getTimeoutTimestampMs === prevTimestamp)
     intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") }
-    assert(state.getTimeoutTimestamp === prevTimestamp)
+    assert(state.getTimeoutTimestampMs === prevTimestamp)
   }
 
   def newStateStore(): StateStore = new MemoryStateStore()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to