This is an automated email from the ASF dual-hosted git repository. tdas pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dfd7b02 [SPARK-35800][SS] Improving GroupState testability by introducing TestGroupState

dfd7b02 is described below

commit dfd7b026dc7c3c38bef9afab82852aff902a25d2
Author: Li Zhang <li.zh...@databricks.com>
AuthorDate: Tue Jun 22 15:04:01 2021 -0400

[SPARK-35800][SS] Improving GroupState testability by introducing TestGroupState

### What changes were proposed in this pull request?

Proposed changes in this pull request:
1. Introducing the `TestGroupState` interface, which inherits from `GroupState`, so that testing-related getters can be exposed in a controlled manner.
2. Changing `GroupStateImpl` to inherit from the `TestGroupState` interface instead of directly from `GroupState`.
3. Implementing the `TestGroupState` companion object with a `create()` method that forwards its inputs to the private `GroupStateImpl` constructor.
4. Adding user input validation to `GroupStateImpl`'s `createForStreaming()` method to prevent users from creating invalid `GroupState` objects.
5. Replacing existing `GroupStateImpl` usages in the sql package's internal unit tests with the newly added `TestGroupState`, to show users best practices for `TestGroupState` usage.

With the changes in this PR, the class hierarchy changes from `GroupStateImpl` -> `GroupState` to `GroupStateImpl` -> `TestGroupState` -> `GroupState` (where -> means "inherits from").

### Why are the changes needed?

The internal `GroupStateImpl` implementation of the `GroupState` interface has no public constructors accessible outside of the sql package. However, the user-provided state transition function for `[map|flatMap]GroupsWithState` requires a `GroupState` object as the prevState input. Currently, users are calling the Structured Streaming engine in their unit tests in order to instantiate such `GroupState` instances, which makes unit tests cumbersome.
The proposed `TestGroupState` interface gives users controlled access to the internal `GroupStateImpl` implementation, greatly improving the testability of Structured Streaming state transition functions.

**Usage Example**

```scala
import org.apache.spark.api.java.Optional
import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState}

test("Structured Streaming state update function") {
  val prevState = TestGroupState.create[UserStatus](
    optionalState = Optional.empty[UserStatus],
    timeoutConf = GroupStateTimeout.EventTimeTimeout,
    batchProcessingTimeMs = 1L,
    eventTimeWatermarkMs = Optional.of(1L),
    hasTimedOut = false)
  val userId: String = ...
  val actions: Iterator[UserAction] = ...
  assert(!prevState.isUpdated)
  updateState(userId, actions, prevState)
  assert(prevState.isUpdated)
}
```

### Does this PR introduce _any_ user-facing change?

Yes. The `TestGroupState` interface and the `create()` factory method in its companion object are introduced in this pull request for users to use in unit tests.

### How was this patch tested?

- New unit tests are added
- Existing GroupState unit tests are updated

Closes #32938 from lizhangdatabricks/improve-group-state-testability.
Authored-by: Li Zhang <li.zh...@databricks.com> Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com> --- .../streaming/FlatMapGroupsWithStateExec.scala | 8 +- .../sql/execution/streaming/GroupStateImpl.scala | 36 ++- .../state/FlatMapGroupsWithStateExecHelper.scala | 5 +- .../spark/sql/streaming/TestGroupState.scala | 173 ++++++++++++++ .../org/apache/spark/sql/JavaDatasetSuite.java | 92 ++++++++ .../streaming/FlatMapGroupsWithStateSuite.scala | 255 +++++++++++++++------ 6 files changed, 475 insertions(+), 94 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index e626fc1..981586e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -60,8 +60,8 @@ case class FlatMapGroupsWithStateExec( child: SparkPlan ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport { - import GroupStateImpl._ import FlatMapGroupsWithStateExecHelper._ + import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout private val watermarkPresent = child.output.exists { @@ -229,13 +229,13 @@ case class FlatMapGroupsWithStateExec( // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + if (groupState.isRemoved && !groupState.getTimeoutTimestampMs.isPresent()) { stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 } else { - val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val currentTimeoutTimestamp = groupState.getTimeoutTimestampMs.orElse(NO_TIMESTAMP) val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp - val 
shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged + val shouldWriteState = groupState.isUpdated || groupState.isRemoved || hasTimeoutChanged if (shouldWriteState) { val updatedStateObj = if (groupState.exists) groupState.get else null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 25756c2..b4f3712 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -20,16 +20,16 @@ package org.apache.spark.sql.execution.streaming import java.sql.Date import java.util.concurrent.TimeUnit -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.api.java.Optional +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, NoTimeout, ProcessingTimeTimeout} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.streaming.GroupStateImpl._ -import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout} +import org.apache.spark.sql.streaming.{GroupStateTimeout, TestGroupState} import org.apache.spark.unsafe.types.UTF8String - /** - * Internal implementation of the [[GroupState]] interface. Methods are not thread-safe. + * Internal implementation of the [[TestGroupState]] interface. Methods are not thread-safe. 
* * @param optionalValue Optional value of the state * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp @@ -45,7 +45,7 @@ private[sql] class GroupStateImpl[S] private( eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, override val hasTimedOut: Boolean, - watermarkPresent: Boolean) extends GroupState[S] { + watermarkPresent: Boolean) extends TestGroupState[S] { private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -147,14 +147,17 @@ private[sql] class GroupStateImpl[S] private( // ========= Internal API ========= - /** Whether the state has been marked for removing */ - def hasRemoved: Boolean = removed + override def isRemoved: Boolean = removed - /** Whether the state has been updated */ - def hasUpdated: Boolean = updated + override def isUpdated: Boolean = updated - /** Return timeout timestamp or `TIMEOUT_TIMESTAMP_NOT_SET` if not set */ - def getTimeoutTimestamp: Long = timeoutTimestamp + override def getTimeoutTimestampMs: Optional[Long] = { + if (timeoutTimestamp != NO_TIMESTAMP) { + Optional.of(timeoutTimestamp) + } else { + Optional.empty[Long] + } + } private def parseDuration(duration: String): Long = { val cal = IntervalUtils.stringToInterval(UTF8String.fromString(duration)) @@ -184,6 +187,17 @@ private[sql] object GroupStateImpl { timeoutConf: GroupStateTimeout, hasTimedOut: Boolean, watermarkPresent: Boolean): GroupStateImpl[S] = { + if (batchProcessingTimeMs < 0) { + throw new IllegalArgumentException("batchProcessingTimeMs must be 0 or positive") + } + if (watermarkPresent && eventTimeWatermarkMs < 0) { + throw new IllegalArgumentException("eventTimeWatermarkMs must be 0 or positive if present") + } + if (hasTimedOut && timeoutConf == NoTimeout) { + throw new UnsupportedOperationException( + "hasTimedOut is true however there's no timeout configured") + } + new GroupStateImpl[S]( optionalValue, batchProcessingTimeMs, 
eventTimeWatermarkMs, timeoutConf, hasTimedOut, watermarkPresent) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index cc785ee..2d9824e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.ObjectOperator -import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.types._ @@ -168,7 +167,7 @@ object FlatMapGroupsWithStateExecHelper { override val stateSerializerExprs: Seq[Expression] = { val encoderSerializer = stateEncoder.namedExpressions if (shouldStoreTimestamp) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) + encoderSerializer :+ Literal(NO_TIMESTAMP) } else { encoderSerializer } @@ -226,7 +225,7 @@ object FlatMapGroupsWithStateExecHelper { } if (shouldStoreTimestamp) { - Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + Seq(nullSafeNestedStateSerExpr, Literal(NO_TIMESTAMP)) } else { Seq(nullSafeNestedStateSerExpr) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala new file mode 100644 index 0000000..d53d608 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/TestGroupState.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.api.java.Optional +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl._ + +/** + * :: Experimental :: + * + * The extended version of [[GroupState]] interface with extra getters of state machine fields + * to improve testability of the [[GroupState]] implementations + * which inherit from the extended interface. + * + * Scala example of using `TestGroupState`: + * {{{ + * // Please refer to ScalaDoc of `GroupState` for the Scala definition of `mappingFunction()` + * + * import org.apache.spark.api.java.Optional + * import org.apache.spark.sql.streaming.GroupStateTimeout + * import org.apache.spark.sql.streaming.TestGroupState + * // other imports + * + * // test class setups + * + * test("MapGroupsWithState state transition function") { + * // Creates the prevState input for the state transition function + * // with desired configs. The `create()` API would guarantee that + * // the generated instance has the same behavior as the one built by + * // engine with the same configs. 
+ * val prevState = TestGroupState.create[Int]( + * optionalState = Optional.empty[Int], + * timeoutConf = GroupStateTimeout.NoTimeout, + * batchProcessingTimeMs = 1L, + * eventTimeWatermarkMs = Optional.of(1L), + * hasTimedOut = false) + * + * val key: String = ... + * val values: Iterator[Int] = ... + * + * // Asserts the prevState is in init state without updates. + * assert(!prevState.isUpdated) + * + * // Calls the state transition function with the test previous state + * // with desired configs. + * mappingFunction(key, values, prevState) + * + * // Asserts the test GroupState object has been updated but not removed + * // after calling the state transition function + * assert(prevState.isUpdated) + * assert(!prevState.isRemoved) + * } + * }}} + * + * Java example of using `TestGroupState`: + * {{{ + * // Please refer to ScalaDoc of `GroupState` for the Java definition of `mappingFunction()` + * + * import org.apache.spark.api.java.Optional; + * import org.apache.spark.sql.streaming.GroupStateTimeout; + * import org.apache.spark.sql.streaming.TestGroupState; + * // other imports + * + * // test class setups + * + * // test `MapGroupsWithState` state transition function `mappingFunction()` + * public void testMappingFunctionWithTestGroupState() { + * // Creates the prevState input for the state transition function + * // with desired configs. The `create()` API would guarantee that + * // the generated instance has the same behavior as the one built by + * // engine with the same configs. + * TestGroupState<Integer> prevState = TestGroupState.create( + * Optional.empty(), + * GroupStateTimeout.NoTimeout(), + * 1L, + * Optional.of(1L), + * false); + * + * String key = ...; + * Integer[] values = ...; + * + * // Asserts the prevState is in init state without updates. + * Assert.assertFalse(prevState.isUpdated()); + * + * // Calls the state transition function with the test previous state + * // with desired configs. 
+ * mappingFunction.call(key, Arrays.asList(values).iterator(), prevState); + * + * // Asserts the test GroupState object has been updated but not removed + * // after calling the state transition function + * Assert.assertTrue(prevState.isUpdated()); + * Assert.assertFalse(prevState.isRemoved()); + * } + * }}} + * + * @tparam S User-defined type of the state to be stored for each group. Must be encodable into + * Spark SQL types (see `Encoder` for more details). + * @since 3.2.0 + */ +@Experimental +@Evolving +trait TestGroupState[S] extends GroupState[S] { + /** Whether the state has been marked for removal */ + def isRemoved: Boolean + + /** Whether the state has been updated but not removed */ + def isUpdated: Boolean + + /** + * Returns the timestamp if `setTimeoutTimestamp()` is called. + * Or, returns batch processing time + the duration when + * `setTimeoutDuration()` is called. + * + * Otherwise, returns `Optional.empty` if not set. + */ + def getTimeoutTimestampMs: Optional[Long] +} + +object TestGroupState { + + /** + * Creates TestGroupState instances for general testing purposes. + * + * @param optionalState Optional value of the state. + * @param timeoutConf Type of timeout configured. Based on this, different operations + * will be supported. + * @param batchProcessingTimeMs Processing time of current batch, used to calculate timestamp + * for processing time timeouts. Must be 0 or positive. + * @param eventTimeWatermarkMs Optional value of event time watermark in ms. Set as + * `Optional.empty` if watermark is not present. + * Otherwise, event time watermark should be a non-negative long + * and the timestampMs set through `setTimeoutTimestamp()` + * cannot be less than `eventTimeWatermarkMs`. + * @param hasTimedOut Whether the key for which this state is being created has + * timed out. Must be false when `timeoutConf` is `NoTimeout`. + * @return a [[TestGroupState]] instance built with the user-specified configs. 
+ */ + @throws[IllegalArgumentException]("if 'batchProcessingTimeMs' is less than 0") + @throws[IllegalArgumentException]("if 'eventTimeWatermarkMs' is present but less than 0") + @throws[UnsupportedOperationException]( + "if 'hasTimedOut' is true however there's no timeout configured") + def create[S]( + optionalState: Optional[S], + timeoutConf: GroupStateTimeout, + batchProcessingTimeMs: Long, + eventTimeWatermarkMs: Optional[Long], + hasTimedOut: Boolean): TestGroupState[S] = { + GroupStateImpl.createForStreaming[S]( + Option(optionalState.orNull), + batchProcessingTimeMs, + eventTimeWatermarkMs.orElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut, + eventTimeWatermarkMs.isPresent()) + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 645c9e9..5e48dc6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -25,6 +25,7 @@ import java.time.*; import java.util.*; import javax.annotation.Nonnull; +import org.apache.spark.api.java.Optional; import org.apache.spark.sql.streaming.GroupStateTimeout; import org.apache.spark.sql.streaming.OutputMode; import scala.Tuple2; @@ -33,6 +34,7 @@ import scala.Tuple4; import scala.Tuple5; import com.google.common.base.Objects; +import org.apache.spark.sql.streaming.TestGroupState; import org.junit.*; import org.apache.spark.api.java.JavaPairRDD; @@ -159,6 +161,96 @@ public class JavaDatasetSuite implements Serializable { } @Test + public void testIllegalTestGroupStateCreations() { + // SPARK-35800: test code throws upon illegal TestGroupState create() calls + Assert.assertThrows( + "eventTimeWatermarkMs must be 0 or positive if present", + IllegalArgumentException.class, + () -> { + TestGroupState.create( + Optional.of(5), GroupStateTimeout.EventTimeTimeout(), 0L, Optional.of(-1000L), false); + }); + + 
Assert.assertThrows( + "batchProcessingTimeMs must be 0 or positive", + IllegalArgumentException.class, + () -> { + TestGroupState.create( + Optional.of(5), GroupStateTimeout.EventTimeTimeout(), -100L, Optional.of(1000L), false); + }); + + Assert.assertThrows( + "hasTimedOut is true however there's no timeout configured", + UnsupportedOperationException.class, + () -> { + TestGroupState.create( + Optional.of(5), GroupStateTimeout.NoTimeout(), 100L, Optional.empty(), true); + }); + } + + @Test + public void testMappingFunctionWithTestGroupState() throws Exception { + // SPARK-35800: test the mapping function with injected TestGroupState instance + MapGroupsWithStateFunction<Integer, Integer, Integer, Integer> mappingFunction = + (MapGroupsWithStateFunction<Integer, Integer, Integer, Integer>) (key, values, state) -> { + if (state.hasTimedOut()) { + state.remove(); + return 0; + } + + int existingState = 0; + if (state.exists()) { + existingState = state.get(); + } else { + // Set state timeout timestamp upon initialization + state.setTimeoutTimestamp(1500L); + } + + while (values.hasNext()) { + existingState += values.next(); + } + state.update(existingState); + + return state.get(); + }; + + TestGroupState<Integer> prevState = TestGroupState.create( + Optional.empty(), GroupStateTimeout.EventTimeTimeout(), 0L, Optional.of(1000L), false); + + Assert.assertFalse(prevState.isUpdated()); + Assert.assertFalse(prevState.isRemoved()); + Assert.assertFalse(prevState.exists()); + Assert.assertEquals(Optional.empty(), prevState.getTimeoutTimestampMs()); + + Integer[] values = {1, 3, 5}; + mappingFunction.call(1, Arrays.asList(values).iterator(), prevState); + + Assert.assertTrue(prevState.isUpdated()); + Assert.assertFalse(prevState.isRemoved()); + Assert.assertTrue(prevState.exists()); + Assert.assertEquals(new Integer(9), prevState.get()); + Assert.assertEquals(0L, prevState.getCurrentProcessingTimeMs()); + Assert.assertEquals(1000L, prevState.getCurrentWatermarkMs()); + 
Assert.assertEquals(Optional.of(1500L), prevState.getTimeoutTimestampMs()); + + mappingFunction.call(1, Arrays.asList(values).iterator(), prevState); + + Assert.assertTrue(prevState.isUpdated()); + Assert.assertFalse(prevState.isRemoved()); + Assert.assertTrue(prevState.exists()); + Assert.assertEquals(new Integer(18), prevState.get()); + + prevState = TestGroupState.create( + Optional.of(9), GroupStateTimeout.EventTimeTimeout(), 0L, Optional.of(1000L), true); + + mappingFunction.call(1, Arrays.asList(values).iterator(), prevState); + + Assert.assertFalse(prevState.isUpdated()); + Assert.assertTrue(prevState.isRemoved()); + Assert.assertFalse(prevState.exists()); + } + + @Test public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 788be53..ad12d0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -24,6 +24,7 @@ import org.apache.commons.io.FileUtils import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException +import org.apache.spark.api.java.Optional import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction import org.apache.spark.sql.{DataFrame, Encoder} import org.apache.spark.sql.catalyst.InternalRow @@ -48,12 +49,43 @@ case class Result(key: Long, count: Int) class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { import testImplicits._ + + import FlatMapGroupsWithStateSuite._ import GroupStateImpl._ import GroupStateTimeout._ - import FlatMapGroupsWithStateSuite._ + + test("SPARK-35800: ensure TestGroupState creates instances the same as prod") { + val testState = 
TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, 1L, Optional.of(1L), hasTimedOut = false) + + val prodState = GroupStateImpl.createForStreaming[Int]( + Some(5), 1L, 1L, EventTimeTimeout, false, true) + + assert(testState.isInstanceOf[GroupStateImpl[Int]]) + + assert(testState.isRemoved === prodState.isRemoved) + assert(testState.isUpdated === prodState.isUpdated) + assert(testState.exists === prodState.exists) + assert(testState.get === prodState.get) + assert(testState.getTimeoutTimestampMs === prodState.getTimeoutTimestampMs) + assert(testState.hasTimedOut === prodState.hasTimedOut) + assert(testState.getCurrentProcessingTimeMs === prodState.getCurrentProcessingTimeMs) + assert(testState.getCurrentWatermarkMs === prodState.getCurrentWatermarkMs) + + testState.update(6) + prodState.update(6) + assert(testState.isUpdated === prodState.isUpdated) + assert(testState.exists === prodState.exists) + assert(testState.get === prodState.get) + + testState.remove() + prodState.remove() + assert(testState.exists === prodState.exists) + assert(testState.isRemoved === prodState.isRemoved) + } test("GroupState - get, exists, update, remove") { - var state: GroupStateImpl[String] = null + var state: TestGroupState[String] = null def testState( expectedData: Option[String], @@ -69,21 +101,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } } assert(state.getOption === expectedData) - assert(state.hasUpdated === shouldBeUpdated) - assert(state.hasRemoved === shouldBeRemoved) + assert(state.isUpdated === shouldBeUpdated) + assert(state.isRemoved === shouldBeRemoved) } // === Tests for state in streaming queries === // Updating empty state - state = GroupStateImpl.createForStreaming( - None, 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) + state = TestGroupState.create[String]( + Optional.empty[String], NoTimeout, 1, Optional.empty[Long], hasTimedOut = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = 
true) // Updating exiting state - state = GroupStateImpl.createForStreaming( - Some("2"), 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) + state = TestGroupState.create[String]( + Optional.of("2"), NoTimeout, 1, Optional.empty[Long], hasTimedOut = false) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -102,10 +134,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } test("GroupState - setTimeout - with NoTimeout") { - for (initValue <- Seq(None, Some(5))) { + for (initValue <- Seq(Optional.empty[Int], Optional.of((5)))) { val states = Seq( - GroupStateImpl.createForStreaming( - initValue, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false), + TestGroupState.create[Int]( + initValue, NoTimeout, 1000, Optional.empty[Long], hasTimedOut = false), GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false) ) for (state <- states) { @@ -122,33 +154,36 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { test("GroupState - setTimeout - with ProcessingTimeTimeout") { // for streaming queries - var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + var state = TestGroupState.create[Int]( + Optional.empty[Int], ProcessingTimeTimeout, 1000, Optional.empty[Long], hasTimedOut = false) + assert(!state.getTimeoutTimestampMs.isPresent()) state.setTimeoutDuration("-1 month 31 days 1 second") - assert(state.getTimeoutTimestamp === 2000) + assert(state.getTimeoutTimestampMs.isPresent()) + assert(state.getTimeoutTimestampMs.get() === 2000) state.setTimeoutDuration(500) - assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state + assert(state.getTimeoutTimestampMs.get() === 1500) // can be set without initializing state testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) 
state.update(5) - assert(state.getTimeoutTimestamp === 1500) // does not change + assert(state.getTimeoutTimestampMs.isPresent()) + assert(state.getTimeoutTimestampMs.get() === 1500) // does not change state.setTimeoutDuration(1000) - assert(state.getTimeoutTimestamp === 2000) + assert(state.getTimeoutTimestampMs.get() === 2000) state.setTimeoutDuration("2 second") - assert(state.getTimeoutTimestamp === 3000) + assert(state.getTimeoutTimestampMs.get() === 3000) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) state.remove() - assert(state.getTimeoutTimestamp === 3000) // does not change + assert(state.getTimeoutTimestampMs.isPresent()) + assert(state.getTimeoutTimestampMs.get() === 3000) // does not change state.setTimeoutDuration(500) // can still be set - assert(state.getTimeoutTimestamp === 1500) + assert(state.getTimeoutTimestampMs.get() === 1500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) // for batch queries state = GroupStateImpl.createForBatch( ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + assert(!state.getTimeoutTimestampMs.isPresent()) state.setTimeoutDuration(500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) @@ -163,32 +198,31 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } test("GroupState - setTimeout - with EventTimeTimeout") { - var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, EventTimeTimeout, false, watermarkPresent = true) - - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + var state = TestGroupState.create[Int]( + Optional.empty[Int], EventTimeTimeout, 1000, Optional.of(1000), hasTimedOut = false) + assert(!state.getTimeoutTimestampMs.isPresent()) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) - assert(state.getTimeoutTimestamp === 5000) // can be set without 
initializing state + assert(state.getTimeoutTimestampMs.get() === 5000) // can be set without initializing state state.update(5) - assert(state.getTimeoutTimestamp === 5000) // does not change + assert(state.getTimeoutTimestampMs.get() === 5000) // does not change state.setTimeoutTimestamp(10000) - assert(state.getTimeoutTimestamp === 10000) + assert(state.getTimeoutTimestampMs.get() === 10000) state.setTimeoutTimestamp(new Date(20000)) - assert(state.getTimeoutTimestamp === 20000) + assert(state.getTimeoutTimestampMs.get() === 20000) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.remove() - assert(state.getTimeoutTimestamp === 20000) + assert(state.getTimeoutTimestampMs.get() === 20000) state.setTimeoutTimestamp(5000) - assert(state.getTimeoutTimestamp === 5000) // can be set after removing state + assert(state.getTimeoutTimestampMs.get() === 5000) // can be set after removing state testTimeoutDurationNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(EventTimeTimeout, watermarkPresent = false) - .asInstanceOf[GroupStateImpl[Int]] - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + state = GroupStateImpl.createForBatch( + EventTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] + assert(!state.getTimeoutTimestampMs.isPresent()) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) @@ -203,18 +237,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } test("GroupState - illegal params to setTimeout") { - var state: GroupStateImpl[Int] = null + var state: TestGroupState[Int] = null - // Test setTimeout****() with illegal values + // Test setTimeout() with illegal values def testIllegalTimeout(body: => Unit): Unit = { intercept[IllegalArgumentException] { body } - assert(state.getTimeoutTimestamp === NO_TIMESTAMP) + assert(!state.getTimeoutTimestampMs.isPresent()) } - state = 
GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) + // Test setTimeout() with illegal values + state = TestGroupState.create[Int]( + Optional.of(5), ProcessingTimeTimeout, 1000, Optional.empty[Long], hasTimedOut = false) + testIllegalTimeout { state.setTimeoutDuration(-1000) } @@ -232,8 +268,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { state.setTimeoutDuration("1 month -31 day") } - state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false, watermarkPresent = false) + state = TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, 1000, Optional.of(1000), hasTimedOut = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } @@ -260,17 +296,64 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } } + test("SPARK-35800: illegal params to create") { + // eventTimeWatermarkMs >= 0 if present + var illegalArgument = intercept[IllegalArgumentException] { + TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, 100L, Optional.of(-1000), hasTimedOut = false) + } + assert( + illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or positive if present")) + illegalArgument = intercept[IllegalArgumentException] { + GroupStateImpl.createForStreaming[Int]( + Some(5), 100L, -1000L, EventTimeTimeout, false, true) + } + assert( + illegalArgument.getMessage.contains("eventTimeWatermarkMs must be 0 or positive if present")) + + // batchProcessingTimeMs must be positive + illegalArgument = intercept[IllegalArgumentException] { + TestGroupState.create[Int]( + Optional.of(5), EventTimeTimeout, -100L, Optional.of(1000), hasTimedOut = false) + } + assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 0 or positive")) + illegalArgument = intercept[IllegalArgumentException] { + GroupStateImpl.createForStreaming[Int]( + Some(5), -100L, 1000L, EventTimeTimeout, false, 
true) + } + assert(illegalArgument.getMessage.contains("batchProcessingTimeMs must be 0 or positive")) + + // hasTimedOut cannot be true if there's no timeout configured + var unsupportedOperation = intercept[UnsupportedOperationException] { + TestGroupState.create[Int]( + Optional.of(5), NoTimeout, 100L, Optional.empty[Long], hasTimedOut = true) + } + assert( + unsupportedOperation + .getMessage.contains("hasTimedOut is true however there's no timeout configured")) + unsupportedOperation = intercept[UnsupportedOperationException] { + GroupStateImpl.createForStreaming[Int]( + Some(5), 100L, NO_TIMESTAMP, NoTimeout, true, false) + } + assert( + unsupportedOperation + .getMessage.contains("hasTimedOut is true however there's no timeout configured")) + } + test("GroupState - hasTimedOut") { for (timeoutConf <- Seq(NoTimeout, ProcessingTimeTimeout, EventTimeTimeout)) { // for streaming queries - for (initState <- Seq(None, Some(5))) { - val state1 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = false, watermarkPresent = false) + for (initState <- Seq(Optional.empty[Int], Optional.of(5))) { + val state1 = TestGroupState.create[Int]( + initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = false) assert(state1.hasTimedOut === false) - val state2 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = true, watermarkPresent = false) - assert(state2.hasTimedOut) + // hasTimedOut can only be set as true when timeoutConf isn't NoTimeout + if (timeoutConf != NoTimeout) { + val state2 = TestGroupState.create[Int]( + initState, timeoutConf, 1000, Optional.empty[Long], hasTimedOut = true) + assert(state2.hasTimedOut) + } } // for batch queries @@ -280,10 +363,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { } test("GroupState - getCurrentWatermarkMs") { - def streamingState(timeoutConf: GroupStateTimeout, watermark: Option[Long]): GroupState[Int] = { - 
GroupStateImpl.createForStreaming( - None, 1000, watermark.getOrElse(-1), timeoutConf, - hasTimedOut = false, watermark.nonEmpty) + def streamingState( + timeoutConf: GroupStateTimeout, + watermark: Optional[Long]): GroupState[Int] = { + TestGroupState.create[Int]( + Optional.empty[Int], timeoutConf, 1000, watermark, hasTimedOut = false) } def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { @@ -298,9 +382,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { // Tests for getCurrentWatermarkMs in streaming queries - assertWrongTimeoutError { streamingState(timeoutConf, None).getCurrentWatermarkMs() } - assert(streamingState(timeoutConf, Some(1000)).getCurrentWatermarkMs() === 1000) - assert(streamingState(timeoutConf, Some(2000)).getCurrentWatermarkMs() === 2000) + assertWrongTimeoutError { + streamingState(timeoutConf, Optional.empty[Long]).getCurrentWatermarkMs() + } + assert(streamingState(timeoutConf, Optional.of(0)).getCurrentWatermarkMs() === 0) + assert(streamingState(timeoutConf, Optional.of(1000)).getCurrentWatermarkMs() === 1000) + assert(streamingState(timeoutConf, Optional.of(2000)).getCurrentWatermarkMs() === 2000) + assert(batchState(EventTimeTimeout, watermarkPresent = true).getCurrentWatermarkMs() === -1) // Tests for getCurrentWatermarkMs in batch queries assertWrongTimeoutError { @@ -312,11 +400,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { test("GroupState - getCurrentProcessingTimeMs") { def streamingState( - timeoutConf: GroupStateTimeout, - procTime: Long, - watermarkPresent: Boolean): GroupState[Int] = { - GroupStateImpl.createForStreaming( - None, procTime, -1, timeoutConf, hasTimedOut = false, watermarkPresent = false) + timeoutConf: GroupStateTimeout, + procTime: Long, + watermarkPresent: Boolean): GroupState[Int] = { + val eventTimeWatermarkMs = watermarkPresent match { + case 
true => Optional.of(1000L) + case false => Optional.empty[Long] + } + TestGroupState.create[Int]( + Optional.of(1000), timeoutConf, procTime, eventTimeWatermarkMs, hasTimedOut = false) } def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { @@ -326,8 +418,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { for (watermarkPresent <- Seq(false, true)) { // Tests for getCurrentProcessingTimeMs in streaming queries - assert(streamingState(timeoutConf, NO_TIMESTAMP, watermarkPresent) - .getCurrentProcessingTimeMs() === -1) + // No negative processing time is allowed, and + // illegal input validation has been added in the separate test + assert(streamingState(timeoutConf, 0, watermarkPresent) + .getCurrentProcessingTimeMs() === 0) assert(streamingState(timeoutConf, 1000, watermarkPresent) .getCurrentProcessingTimeMs() === 1000) assert(streamingState(timeoutConf, 2000, watermarkPresent) @@ -342,15 +436,24 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { test("GroupState - primitive type") { - var intState = GroupStateImpl.createForStreaming[Int]( - None, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) + var intState = TestGroupState.create[Int]( + Optional.empty[Int], + NoTimeout, + 1000, + Optional.empty[Long], + hasTimedOut = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) - intState = GroupStateImpl.createForStreaming[Int]( - Some(10), 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) + intState = TestGroupState.create[Int]( + Optional.of(10), + NoTimeout, + 1000, + Optional.empty[Long], + hasTimedOut = false) + assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -1291,24 +1394,24 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { }.get } - def testTimeoutDurationNotAllowed[T <: 
Exception: Manifest](state: GroupStateImpl[_]): Unit = { - val prevTimestamp = state.getTimeoutTimestamp + def testTimeoutDurationNotAllowed[T <: Exception: Manifest](state: TestGroupState[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestampMs intercept[T] { state.setTimeoutDuration(1000) } - assert(state.getTimeoutTimestamp === prevTimestamp) + assert(state.getTimeoutTimestampMs === prevTimestamp) intercept[T] { state.setTimeoutDuration("2 second") } - assert(state.getTimeoutTimestamp === prevTimestamp) + assert(state.getTimeoutTimestampMs === prevTimestamp) } - def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: GroupStateImpl[_]): Unit = { - val prevTimestamp = state.getTimeoutTimestamp + def testTimeoutTimestampNotAllowed[T <: Exception: Manifest](state: TestGroupState[_]): Unit = { + val prevTimestamp = state.getTimeoutTimestampMs intercept[T] { state.setTimeoutTimestamp(2000) } - assert(state.getTimeoutTimestamp === prevTimestamp) + assert(state.getTimeoutTimestampMs === prevTimestamp) intercept[T] { state.setTimeoutTimestamp(2000, "1 second") } - assert(state.getTimeoutTimestamp === prevTimestamp) + assert(state.getTimeoutTimestampMs === prevTimestamp) intercept[T] { state.setTimeoutTimestamp(new Date(2000)) } - assert(state.getTimeoutTimestamp === prevTimestamp) + assert(state.getTimeoutTimestampMs === prevTimestamp) intercept[T] { state.setTimeoutTimestamp(new Date(2000), "1 second") } - assert(state.getTimeoutTimestamp === prevTimestamp) + assert(state.getTimeoutTimestampMs === prevTimestamp) } def newStateStore(): StateStore = new MemoryStateStore() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org