Repository: spark
Updated Branches:
  refs/heads/master 8d707b060 -> b3d88ac02


[SPARK-22187][SS] Update unsaferow format for saved state in flatMapGroupsWithState to allow timeouts with deleted state

## What changes were proposed in this pull request?

Currently, the group state of a user-defined type is encoded as top-level columns 
in the UnsafeRows stored in the state store. The timeout timestamp is also 
saved (when needed) as the last top-level column. Since the group state is 
serialized to top-level columns, "null" cannot be saved as the value of the state 
(setting null in all the top-level columns is not equivalent). So we do not let 
the user set a timeout without first initializing the state for a key. Based on 
user experience, this leads to confusion.
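
For example, with the current format, a function that removes its state and then 
asks for a timeout fails at runtime. A minimal sketch of the disallowed pattern 
(the function `countdown` and its types are illustrative, not part of this patch):

```scala
// Sketch of the previously-disallowed pattern: once the state is removed there
// is no row left in which to record the timeout, so the old format rejected
// this with IllegalStateException("Cannot set timeout when state is not defined ...").
def countdown(key: String, values: Iterator[Int], state: GroupState[Long]): Iterator[Long] = {
  state.remove()                          // delete the state for this key
  state.setTimeoutDuration("10 seconds")  // disallowed before this change
  Iterator.empty
}
```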

This PR changes the row format such that the state is saved as a nested 
column. This allows the state to be set to null and avoids these 
confusing corner cases. However, queries recovering from an existing checkpoint 
will continue to use the previous format, to maintain compatibility with existing 
production queries.
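
The two layouts, as illustrated in the scaladocs of the new `StateManager` 
implementations, and the internal conf that selects between them (the 
`spark.conf.set` call below is purely illustrative; the conf is internal and 
is normally left at its default):

```scala
// Format v1 (legacy): state fields are flattened into top-level columns, so
// "no state" cannot be represented while a timeout timestamp is still needed.
//   UnsafeRow[ col1 | col2 | col3 | timestamp ]
//
// Format v2 (new default): the state is a single nullable nested struct, so
// the whole struct can be null while the timestamp column remains set.
//   UnsafeRow[ nested-struct | timestamp ]  where nested-struct = UnsafeRow[ col1 | col2 | col3 ]

// Selects the format for new queries; recovered queries are pinned to the
// legacy version recorded in the checkpoint's offset log.
spark.conf.set("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion", "2")
```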

## How was this patch tested?
Refactored the existing end-to-end tests and added new tests that explicitly test 
object-to-row conversion for both state formats.
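
Both formats are driven through the same `StateManager` interface; a minimal 
sketch of the pattern used by the new helper suite, assuming an `Int`-typed 
state as in its primitive-type tests:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper

// Build a state manager for each supported format version and inspect the
// schema it writes to the state store: v1 is flat, v2 nests the state struct.
for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) {
  val encoder = ExpressionEncoder[Int]().asInstanceOf[ExpressionEncoder[Any]]
  val manager = FlatMapGroupsWithStateExecHelper.createStateManager(
    encoder, shouldStoreTimestamp = true, stateFormatVersion = version)
  println(s"v$version state schema: ${manager.stateSchema}")
}
```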

Author: Tathagata Das <[email protected]>

Closes #21739 from tdas/SPARK-22187-1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b3d88ac0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b3d88ac0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b3d88ac0

Branch: refs/heads/master
Commit: b3d88ac02940eff4c867d3acb79fe5ff9d724e83
Parents: 8d707b0
Author: Tathagata Das <[email protected]>
Authored: Thu Jul 19 13:17:28 2018 -0700
Committer: Tathagata Das <[email protected]>
Committed: Thu Jul 19 13:17:28 2018 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Expression.scala   |   3 +-
 .../org/apache/spark/sql/internal/SQLConf.scala |   8 +
 .../spark/sql/execution/SparkStrategies.scala   |   5 +-
 .../streaming/FlatMapGroupsWithStateExec.scala  | 136 +++-------
 .../sql/execution/streaming/OffsetSeq.scala     |  10 +-
 .../FlatMapGroupsWithStateExecHelper.scala      | 247 ++++++++++++++++++
 .../commits/0                                   |   2 +
 .../commits/1                                   |   2 +
 .../metadata                                    |   1 +
 .../offsets/0                                   |   3 +
 .../offsets/1                                   |   3 +
 .../state/0/0/1.delta                           | Bin 0 -> 84 bytes
 .../state/0/0/2.delta                           | Bin 0 -> 46 bytes
 .../state/0/1/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/1/2.delta                           | Bin 0 -> 46 bytes
 .../state/0/2/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/2/2.delta                           | Bin 0 -> 46 bytes
 .../state/0/3/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/3/2.delta                           | Bin 0 -> 46 bytes
 .../state/0/4/1.delta                           | Bin 0 -> 46 bytes
 .../state/0/4/2.delta                           | Bin 0 -> 46 bytes
 .../FlatMapGroupsWithStateExecHelperSuite.scala | 218 ++++++++++++++++
 .../streaming/FlatMapGroupsWithStateSuite.scala | 250 ++++++++++++++-----
 23 files changed, 708 insertions(+), 180 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index f7d1b10..a69b804 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -715,7 +715,8 @@ trait ComplexTypeMergingExpression extends Expression {
       "The collection of input data types must not be empty.")
     require(
       TypeCoercion.haveSameType(inputTypesForMerging),
-      "All input types must be the same except nullable, containsNull, 
valueContainsNull flags.")
+      "All input types must be the same except nullable, containsNull, 
valueContainsNull flags." +
+        s" The input types found 
are\n\t${inputTypesForMerging.mkString("\n\t")}")
     
inputTypesForMerging.reduceLeft(TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(_,
 _).get)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 9239d4e..fbb9a8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -843,6 +843,14 @@ object SQLConf {
       .intConf
       .createWithDefault(10)
 
+  val FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION =
+    buildConf("spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion")
+      .internal()
+      .doc("State format version used by flatMapGroupsWithState operation in a 
streaming query")
+      .intConf
+      .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
+      .createWithDefault(2)
+
   val CHECKPOINT_LOCATION = buildConf("spark.sql.streaming.checkpointLocation")
     .doc("The default location for storing checkpoint data for streaming 
queries.")
     .stringConf

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 02e095b..0c4ea85 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -504,9 +504,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case FlatMapGroupsWithState(
         func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
         timeout, child) =>
+        val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
         val execPlan = FlatMapGroupsWithStateExec(
-          func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, outputMode,
-          timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
+          func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateEnc, stateVersion,
+          outputMode, timeout, batchTimestampMs = None, eventTimeWatermark = None, planLater(child))
         execPlan :: Nil
       case _ =>
         Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 8e82ccc..bfe7d00 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
-import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.CompletionIterator
 
 /**
@@ -52,6 +50,7 @@ case class FlatMapGroupsWithStateExec(
     outputObjAttr: Attribute,
     stateInfo: Option[StatefulOperatorStateInfo],
     stateEncoder: ExpressionEncoder[Any],
+    stateFormatVersion: Int,
     outputMode: OutputMode,
     timeoutConf: GroupStateTimeout,
     batchTimestampMs: Option[Long],
@@ -60,32 +59,15 @@ case class FlatMapGroupsWithStateExec(
   ) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {
 
   import GroupStateImpl._
+  import FlatMapGroupsWithStateExecHelper._
 
   private val isTimeoutEnabled = timeoutConf != NoTimeout
-  private val timestampTimeoutAttribute =
-    AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
-  private val stateAttributes: Seq[Attribute] = {
-    val encSchemaAttribs = stateEncoder.schema.toAttributes
-    if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
-  }
-  // Get the serializer for the state, taking into account whether we need to save timestamps
-  private val stateSerializer = {
-    val encoderSerializer = stateEncoder.namedExpressions
-    if (isTimeoutEnabled) {
-      encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
-    } else {
-      encoderSerializer
-    }
-  }
-  // Get the deserializer for the state. Note that this must be done in the driver, as
-  // resolving and binding of deserializer expressions to the encoded type can be safely done
-  // only in the driver.
-  private val stateDeserializer = stateEncoder.resolveAndBind().deserializer
-
   private val watermarkPresent = child.output.exists {
     case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
     case _ => false
   }
+  private[sql] val stateManager =
+    createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
 
   /** Distribute by grouping attributes */
   override def requiredChildDistribution: Seq[Distribution] =
@@ -125,11 +107,11 @@ case class FlatMapGroupsWithStateExec(
     child.execute().mapPartitionsWithStateStore[InternalRow](
       getStateInfo,
       groupingAttributes.toStructType,
-      stateAttributes.toStructType,
+      stateManager.stateSchema,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
-        val updater = new StateStoreUpdater(store)
+        val processor = new InputProcessor(store)
 
         // If timeout is based on event time, then filter late data based on watermark
         val filteredIter = watermarkPredicateForData match {
@@ -143,7 +125,7 @@ case class FlatMapGroupsWithStateExec(
         // all the data has been processed. This is to ensure that the timeout information of all
         // the keys with data is updated before they are processed for timeouts.
         val outputIterator =
-          updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys()
+          processor.processNewData(filteredIter) ++ processor.processTimedOutState()

         // Return an iterator of all the rows generated by all the keys, such that when fully
         // consumed, all the state updates will be committed by the state store
@@ -158,7 +140,7 @@ case class FlatMapGroupsWithStateExec(
   }
 
   /** Helper class to update the state store */
-  class StateStoreUpdater(store: StateStore) {
+  class InputProcessor(store: StateStore) {
 
     // Converters for translating input keys, values, output data between rows and Java objects
     private val getKeyObj =
@@ -167,14 +149,6 @@ case class FlatMapGroupsWithStateExec(
       ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
     private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
 
-    // Converters for translating state between rows and Java objects
-    private val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
-      stateDeserializer, stateAttributes)
-    private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)
-
-    // Index of the additional metadata fields in the state row
-    private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)
-
     // Metrics
     private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
     private val numOutputRows = longMetric("numOutputRows")
@@ -183,20 +157,19 @@ case class FlatMapGroupsWithStateExec(
      * For every group, get the key, values and corresponding state and call the function,
      * and return an iterator of rows
      */
-    def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
       val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
       groupedIter.flatMap { case (keyRow, valueRowIter) =>
         val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
         callFunctionAndUpdateState(
-          keyUnsafeRow,
+          stateManager.getState(store, keyUnsafeRow),
           valueRowIter,
-          store.get(keyUnsafeRow),
           hasTimedOut = false)
       }
     }
 
     /** Find the groups that have timeout set and are timing out right now, and call the function */
-    def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
+    def processTimedOutState(): Iterator[InternalRow] = {
       if (isTimeoutEnabled) {
         val timeoutThreshold = timeoutConf match {
           case ProcessingTimeTimeout => batchTimestampMs.get
@@ -205,12 +178,11 @@ case class FlatMapGroupsWithStateExec(
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
         }
-        val timingOutPairs = store.getRange(None, None).filter { rowPair =>
-          val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
-          timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
+        val timingOutPairs = stateManager.getAllState(store).filter { state =>
+          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
         }
-        timingOutPairs.flatMap { rowPair =>
-          callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
+        timingOutPairs.flatMap { stateData =>
+          callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
         }
       } else Iterator.empty
     }
@@ -220,22 +192,19 @@ case class FlatMapGroupsWithStateExec(
      * iterator. Note that the store updating is lazy, that is, the store will be updated only
      * after the returned iterator is fully consumed.
      *
-     * @param keyRow Row representing the key, cannot be null
+     * @param stateData All the data related to the state to be updated
      * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
-     * @param prevStateRow Row representing the previous state, can be null
      * @param hasTimedOut Whether this function is being called for a key timeout
      */
     private def callFunctionAndUpdateState(
-        keyRow: UnsafeRow,
+        stateData: StateData,
         valueRowIter: Iterator[InternalRow],
-        prevStateRow: UnsafeRow,
         hasTimedOut: Boolean): Iterator[InternalRow] = {
 
-      val keyObj = getKeyObj(keyRow)  // convert key to objects
+      val keyObj = getKeyObj(stateData.keyRow)  // convert key to objects
       val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
-      val stateObj = getStateObj(prevStateRow)
-      val keyedState = GroupStateImpl.createForStreaming(
-        Option(stateObj),
+      val groupState = GroupStateImpl.createForStreaming(
+        Option(stateData.stateObj),
         batchTimestampMs.getOrElse(NO_TIMESTAMP),
         eventTimeWatermark.getOrElse(NO_TIMESTAMP),
         timeoutConf,
@@ -243,50 +212,24 @@ case class FlatMapGroupsWithStateExec(
         watermarkPresent)
 
       // Call function, get the returned objects and convert them to rows
-      val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
+      val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
         numOutputRows += 1
         getOutputRow(obj)
       }
 
       // When the iterator is consumed, then write changes to state
       def onIteratorCompletion: Unit = {
-
-        val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
-        // If the state has not yet been set but timeout has been set, then
-        // we have to generate a row to save the timeout. However, attempting serialize
-        // null using case class encoder throws -
-        //    java.lang.NullPointerException: Null value appeared in non-nullable field:
-        //    If the schema is inferred from a Scala tuple / case class, or a Java bean, please
-        //    try to use scala.Option[_] or other nullable types.
-        if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
-          throw new IllegalStateException(
-            "Cannot set timeout when state is not defined, that is, state has not been" +
-              "initialized or has been removed")
-        }
-
-        if (keyedState.hasRemoved) {
-          store.remove(keyRow)
+        if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
+          stateManager.removeState(store, stateData.keyRow)
           numUpdatedStateRows += 1
-
         } else {
-          val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
-          val stateRowToWrite = if (keyedState.hasUpdated) {
-            getStateRow(keyedState.get)
-          } else {
-            prevStateRow
-          }
-
-          val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
-          val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
+          val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
+          val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
+          val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged

           if (shouldWriteState) {
-            if (stateRowToWrite == null) {
-              // This should never happen because checks in GroupStateImpl should avoid cases
-              // where empty state would need to be written
-              throw new IllegalStateException("Attempting to write empty state")
-            }
-            setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
-            store.put(keyRow, stateRowToWrite)
+            val updatedStateObj = if (groupState.exists) groupState.get else null
+            stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
             numUpdatedStateRows += 1
           }
         }
@@ -295,28 +238,5 @@ case class FlatMapGroupsWithStateExec(
       // Return an iterator of rows such that fully consumed, the updated state value will be saved
       CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
     }
-
-    /** Returns the state as Java object if defined */
-    def getStateObj(stateRow: UnsafeRow): Any = {
-      if (stateRow != null) getStateObjFromRow(stateRow) else null
-    }
-
-    /** Returns the row for an updated state */
-    def getStateRow(obj: Any): UnsafeRow = {
-      assert(obj != null)
-      getStateRowFromObj(obj)
-    }
-
-    /** Returns the timeout timestamp of a state row is set */
-    def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
-      if (isTimeoutEnabled && stateRow != null) {
-        stateRow.getLong(timeoutTimestampIndex)
-      } else NO_TIMESTAMP
-    }
-
-    /** Set the timestamp in a state row */
-    def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
-      if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
-    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index 1ae3f36..9847756 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -22,7 +22,8 @@ import org.json4s.jackson.Serialization
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.RuntimeConfig
-import org.apache.spark.sql.internal.SQLConf._
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
+import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}
 
 /**
 * An ordered collection of offsets, used to track the progress of processing data from one or more
@@ -87,7 +88,8 @@ case class OffsetSeqMetadata(
 object OffsetSeqMetadata extends Logging {
   private implicit val format = Serialization.formats(NoTypeHints)
   private val relevantSQLConfs = Seq(
-    SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY)
+    SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
+    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
 
   /**
   * Default values of relevant configurations that are used for backward compatibility.
@@ -100,7 +102,9 @@ object OffsetSeqMetadata extends Logging {
   * with a specific default value for ensuring same behavior of the query as before.
    */
   private val relevantSQLConfDefaultValues = Map[String, String](
-    STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME
+    STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME,
+    FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
+      FlatMapGroupsWithStateExecHelper.legacyVersion.toString
   )
 
   def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
new file mode 100644
index 0000000..0a16a38
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.ObjectOperator
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.types._
+
+
+object FlatMapGroupsWithStateExecHelper {
+
+  val supportedVersions = Seq(1, 2)
+  val legacyVersion = 1
+
+  /**
+   * Class to capture deserialized state and timestamp returned by the state manager.
+   * This is intended for reuse.
+   */
+  case class StateData(
+      var keyRow: UnsafeRow = null,
+      var stateRow: UnsafeRow = null,
+      var stateObj: Any = null,
+      var timeoutTimestamp: Long = -1) {
+
+    private[FlatMapGroupsWithStateExecHelper] def withNew(
+        newKeyRow: UnsafeRow,
+        newStateRow: UnsafeRow,
+        newStateObj: Any,
+        newTimeout: Long): this.type = {
+      keyRow = newKeyRow
+      stateRow = newStateRow
+      stateObj = newStateObj
+      timeoutTimestamp = newTimeout
+      this
+    }
+  }
+
+  /** Interface for interacting with state data of FlatMapGroupsWithState */
+  sealed trait StateManager extends Serializable {
+    def stateSchema: StructType
+    def getState(store: StateStore, keyRow: UnsafeRow): StateData
+    def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timeoutTimestamp: Long): Unit
+    def removeState(store: StateStore, keyRow: UnsafeRow): Unit
+    def getAllState(store: StateStore): Iterator[StateData]
+  }
+
+  def createStateManager(
+      stateEncoder: ExpressionEncoder[Any],
+      shouldStoreTimestamp: Boolean,
+      stateFormatVersion: Int): StateManager = {
+    stateFormatVersion match {
+      case 1 => new StateManagerImplV1(stateEncoder, shouldStoreTimestamp)
+      case 2 => new StateManagerImplV2(stateEncoder, shouldStoreTimestamp)
+      case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
+    }
+  }
+
+  // ===============================================================================================
+  // =========================== Private implementations of StateManager ===========================
+  // ===============================================================================================
+
+  /** Common methods for StateManager implementations */
+  private abstract class StateManagerImplBase(shouldStoreTimestamp: Boolean)
+    extends StateManager {
+
+    protected def stateSerializerExprs: Seq[Expression]
+    protected def stateDeserializerExpr: Expression
+    protected def timeoutTimestampOrdinalInRow: Int
+
+    /** Get deserialized state and corresponding timeout timestamp for a key */
+    override def getState(store: StateStore, keyRow: UnsafeRow): StateData = {
+      val stateRow = store.get(keyRow)
+      stateDataForGets.withNew(keyRow, stateRow, getStateObject(stateRow), getTimestamp(stateRow))
+    }
+
+    /** Put state and timeout timestamp for a key */
+    override def putState(store: StateStore, key: UnsafeRow, state: Any, timestamp: Long): Unit = {
+      val stateRow = getStateRow(state)
+      setTimestamp(stateRow, timestamp)
+      store.put(key, stateRow)
+    }
+
+    override def removeState(store: StateStore, keyRow: UnsafeRow): Unit = {
+      store.remove(keyRow)
+    }
+
+    override def getAllState(store: StateStore): Iterator[StateData] = {
+      val stateData = StateData()
+      store.getRange(None, None).map { p =>
+        stateData.withNew(p.key, p.value, getStateObject(p.value), getTimestamp(p.value))
+      }
+    }
+
+    private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs)
+    private lazy val stateDeserializerFunc = {
+      ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes)
+    }
+    private lazy val stateDataForGets = StateData()
+
+    protected def getStateObject(row: UnsafeRow): Any = {
+      if (row != null) stateDeserializerFunc(row) else null
+    }
+
+    protected def getStateRow(obj: Any): UnsafeRow = {
+      stateSerializerFunc(obj)
+    }
+
+    /** Returns the timeout timestamp of a state row, if set */
+    private def getTimestamp(stateRow: UnsafeRow): Long = {
+      if (shouldStoreTimestamp && stateRow != null) {
+        stateRow.getLong(timeoutTimestampOrdinalInRow)
+      } else NO_TIMESTAMP
+    }
+
+    /** Set the timestamp in a state row */
+    private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
+      if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinalInRow, timeoutTimestamps)
+    }
+  }
+
+  /**
+   * Version 1 of the StateManager which stores the user-defined state as flattened columns in
+   * the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The
+   * unsafe rows will look like this.
+   *
+   *    UnsafeRow[ col1 | col2 | col3 | timestamp ]
+   *
+   * The limitation of this format is that timestamp cannot be set when the user-defined
+   * state has been removed. This is because the columns cannot be collectively marked to be
+   * empty/null.
+   */
+  private class StateManagerImplV1(
+      stateEncoder: ExpressionEncoder[Any],
+      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
+
+    private val timestampTimeoutAttribute =
+      AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
+
+    private val stateAttributes: Seq[Attribute] = {
+      val encSchemaAttribs = stateEncoder.schema.toAttributes
+      if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
+    }
+
+    override val stateSchema: StructType = stateAttributes.toStructType
+
+    override val timeoutTimestampOrdinalInRow: Int = {
+      stateAttributes.indexOf(timestampTimeoutAttribute)
+    }
+
+    override val stateSerializerExprs: Seq[Expression] = {
+      val encoderSerializer = stateEncoder.namedExpressions
+      if (shouldStoreTimestamp) {
+        encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
+      } else {
+        encoderSerializer
+      }
+    }
+
+    override val stateDeserializerExpr: Expression = {
+      // Note that this must be done in the driver, as resolving and binding of deserializer
+      // expressions to the encoded type can be safely done only in the driver.
+      stateEncoder.resolveAndBind().deserializer
+    }
+
+    override protected def getStateRow(obj: Any): UnsafeRow = {
+      require(obj != null, "State object cannot be null")
+      super.getStateRow(obj)
+    }
+  }
+
+  /**
+   * Version 2 of the StateManager which stores the user-defined state as a nested struct
+   * in the UnsafeRow. Say the user-defined state has 3 fields - col1, col2, col3. The
+   * unsafe rows will look like this.
+   *                    ___________________________
+   *                   |                           |
+   *                   |                           V
+   *    UnsafeRow[ nested-struct | timestamp |  UnsafeRow[ col1 | col2 | col3 ] ]
+   *
+   * This allows the entire user-defined state to be collectively marked as empty/null,
+   * thus allowing timestamp to be set without requiring the state to be present.
+   */
+  private class StateManagerImplV2(
+      stateEncoder: ExpressionEncoder[Any],
+      shouldStoreTimestamp: Boolean) extends StateManagerImplBase(shouldStoreTimestamp) {
+
+    /** Schema of the state rows saved in the state store */
+    override val stateSchema: StructType = {
+      var schema = new StructType().add("groupState", stateEncoder.schema, 
nullable = true)
+      if (shouldStoreTimestamp) schema = schema.add("timeoutTimestamp", 
LongType, nullable = false)
+      schema
+    }
+
+    // Ordinals of the information stored in the state row
+    private val nestedStateOrdinal = 0
+    override val timeoutTimestampOrdinalInRow = 1
+
+    override val stateSerializerExprs: Seq[Expression] = {
+      val boundRefToSpecificInternalRow = BoundReference(
+        0, stateEncoder.serializer.head.collect { case b: BoundReference => b.dataType }.head, true)
+
+      val nestedStateSerExpr =
+        CreateNamedStruct(stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e)))
+
+      val nullSafeNestedStateSerExpr = {
+        val nullLiteral = Literal(null, nestedStateSerExpr.dataType)
+        CaseWhen(Seq(IsNull(boundRefToSpecificInternalRow) -> nullLiteral), nestedStateSerExpr)
+      }
+
+      if (shouldStoreTimestamp) {
+        Seq(nullSafeNestedStateSerExpr, Literal(GroupStateImpl.NO_TIMESTAMP))
+      } else {
+        Seq(nullSafeNestedStateSerExpr)
+      }
+    }
+
+    override val stateDeserializerExpr: Expression = {
+      // Note that this must be done in the driver, as resolving and binding of deserializer
+      // expressions to the encoded type can be safely done only in the driver.
+      val boundRefToNestedState =
+        BoundReference(nestedStateOrdinal, stateEncoder.schema, nullable = true)
+      val deserExpr = stateEncoder.resolveAndBind().deserializer.transformUp {
+        case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal)
+      }
+      val nullLiteral = Literal(null, deserExpr.dataType)
+      CaseWhen(Seq(IsNull(boundRefToNestedState) -> nullLiteral), elseValue = deserExpr)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0
new file mode 100644
index 0000000..83321cd
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/0
@@ -0,0 +1,2 @@
+v1
+{}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1
new file mode 100644
index 0000000..83321cd
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/commits/1
@@ -0,0 +1,2 @@
+v1
+{}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata
new file mode 100644
index 0000000..372180b
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/metadata
@@ -0,0 +1 @@
+{"id":"04d960cd-d38f-4ce6-b8d0-ebcf84c9dccc"}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0
new file mode 100644
index 0000000..807d7b0
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/0
@@ -0,0 +1,3 @@
+v1
+{"batchWatermarkMs":0,"batchTimestampMs":1531292029003,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}}
+0
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1
new file mode 100644
index 0000000..cce5410
--- /dev/null
+++ 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/offsets/1
@@ -0,0 +1,3 @@
+v1
+{"batchWatermarkMs":5000,"batchTimestampMs":1531292030005,"conf":{"spark.sql.shuffle.partitions":"5","spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider"}}
+1
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta
new file mode 100644
index 0000000..193524f
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/0/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/1/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/2/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/3/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/1.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta
 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta
new file mode 100644
index 0000000..6352978
Binary files /dev/null and 
b/sql/core/src/test/resources/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/state/0/4/2.delta
 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
new file mode 100644
index 0000000..dec30fd
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelperSuite.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.GroupStateImpl._
+import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite._
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types._
+
+
+class FlatMapGroupsWithStateExecHelperSuite extends StreamTest {
+
+  import testImplicits._
+  import FlatMapGroupsWithStateExecHelper._
+
+  // ============================ StateManagerImplV1 ============================
+
+  test(s"StateManager v1 - primitive type - without timestamp") {
+    val schema = new StructType().add("value", IntegerType, nullable = false)
+    testStateManagerWithoutTimestamp[Int](version = 1, schema, Seq(0, 10))
+  }
+
+  test(s"StateManager v1 - primitive type - with timestamp") {
+    val schema = new StructType()
+      .add("value", IntegerType, nullable = false)
+      .add("timeoutTimestamp", IntegerType, nullable = false)
+    testStateManagerWithTimestamp[Int](version = 1, schema, Seq(0, 10))
+  }
+
+  test(s"StateManager v1 - nested type - without timestamp") {
+    val schema = StructType(Seq(
+      StructField("i", IntegerType, nullable = false),
+      StructField("nested", StructType(Seq(
+        StructField("d", DoubleType, nullable = false),
+        StructField("str", StringType))
+      ))
+    ))
+
+    val testValues = Seq(
+      NestedStruct(1, Struct(1.0, "someString")),
+      NestedStruct(0, Struct(0.0, "")),
+      NestedStruct(0, null))
+
+    testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues)
+
+    // Verify the limitation of v1 with null state
+    intercept[Exception] {
+      testStateManagerWithoutTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null))
+    }
+  }
+
+  test(s"StateManager v1 - nested type - with timestamp") {
+    val schema = StructType(Seq(
+      StructField("i", IntegerType, nullable = false),
+      StructField("nested", StructType(Seq(
+        StructField("d", DoubleType, nullable = false),
+        StructField("str", StringType))
+      )),
+      StructField("timeoutTimestamp", IntegerType, nullable = false)
+    ))
+
+    val testValues = Seq(
+      NestedStruct(1, Struct(1.0, "someString")),
+      NestedStruct(0, Struct(0.0, "")),
+      NestedStruct(0, null))
+
+    testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues)
+
+    // Verify the limitation of v1 with null state
+    intercept[Exception] {
+      testStateManagerWithTimestamp[NestedStruct](version = 1, schema, testValues = Seq(null))
+    }
+  }
+
+  // ============================ StateManagerImplV2 ============================
+
+  test(s"StateManager v2 - primitive type - without timestamp") {
+    val schema = new StructType()
+      .add("groupState", new StructType().add("value", IntegerType, nullable = 
false))
+    testStateManagerWithoutTimestamp[Int](version = 2, schema, Seq(0, 10))
+  }
+
+  test(s"StateManager v2 - primitive type - with timestamp") {
+    val schema = new StructType()
+      .add("groupState", new StructType().add("value", IntegerType, nullable = 
false))
+      .add("timeoutTimestamp", LongType, nullable = false)
+    testStateManagerWithTimestamp[Int](version = 2, schema, Seq(0, 10))
+  }
+
+  test(s"StateManager v2 - nested type - without timestamp") {
+    val schema = StructType(Seq(
+      StructField("groupState", StructType(Seq(
+        StructField("i", IntegerType, nullable = false),
+        StructField("nested", StructType(Seq(
+          StructField("d", DoubleType, nullable = false),
+          StructField("str", StringType)
+        )))
+      )))
+    ))
+
+    val testValues = Seq(
+      NestedStruct(1, Struct(1.0, "someString")),
+      NestedStruct(0, Struct(0.0, "")),
+      NestedStruct(0, null),
+      null)
+
+    testStateManagerWithoutTimestamp[NestedStruct](version = 2, schema, testValues)
+  }
+
+  test(s"StateManager v2 - nested type - with timestamp") {
+    val schema = StructType(Seq(
+      StructField("groupState", StructType(Seq(
+        StructField("i", IntegerType, nullable = false),
+        StructField("nested", StructType(Seq(
+          StructField("d", DoubleType, nullable = false),
+          StructField("str", StringType)
+        )))
+      ))),
+      StructField("timeoutTimestamp", LongType, nullable = false)
+    ))
+
+    val testValues = Seq(
+      NestedStruct(1, Struct(1.0, "someString")),
+      NestedStruct(0, Struct(0.0, "")),
+      NestedStruct(0, null),
+      null)
+
+    testStateManagerWithTimestamp[NestedStruct](version = 2, schema, testValues)
+  }
+
+
+  def testStateManagerWithoutTimestamp[T: Encoder](
+      version: Int,
+      expectedStateSchema: StructType,
+      testValues: Seq[T]): Unit = {
+    val stateManager = newStateManager[T](version, withTimestamp = false)
+    assert(stateManager.stateSchema === expectedStateSchema)
+    testStateManager(stateManager, testValues, NO_TIMESTAMP)
+  }
+
+  def testStateManagerWithTimestamp[T: Encoder](
+      version: Int,
+      expectedStateSchema: StructType,
+      testValues: Seq[T]): Unit = {
+    val stateManager = newStateManager[T](version, withTimestamp = true)
+    assert(stateManager.stateSchema === expectedStateSchema)
+    for (timestamp <- Seq(NO_TIMESTAMP, 1000)) {
+      testStateManager(stateManager, testValues, timestamp)
+    }
+  }
+
+  private def testStateManager[T: Encoder](
+      stateManager: StateManager,
+      values: Seq[T],
+      timestamp: Long): Unit = {
+    val keys = (1 to values.size).map(_ => newKey())
+    val store = new MemoryStateStore()
+
+    // Test stateManager.getState(), putState(), removeState()
+    keys.zip(values).foreach { case (key, value) =>
+      try {
+        stateManager.putState(store, key, value, timestamp)
+        val data = stateManager.getState(store, key)
+        assert(data.stateObj == value)
+        assert(data.timeoutTimestamp === timestamp)
+        stateManager.removeState(store, key)
+        assert(stateManager.getState(store, key).stateObj == null)
+      } catch {
+        case e: Throwable =>
+         fail(s"put/get/remove test with '$value' failed", e)
+      }
+    }
+
+    // Test stateManager.getAllState()
+    for (i <- keys.indices) {
+      stateManager.putState(store, keys(i), values(i), timestamp)
+    }
+    val allData = stateManager.getAllState(store).map(_.copy()).toArray
+    assert(allData.map(_.timeoutTimestamp).toSet == Set(timestamp))
+    assert(allData.map(_.stateObj).toSet == values.toSet)
+  }
+
+  private def newStateManager[T: Encoder](version: Int, withTimestamp: Boolean): StateManager = {
+    FlatMapGroupsWithStateExecHelper.createStateManager(
+      implicitly[Encoder[T]].asInstanceOf[ExpressionEncoder[Any]],
+      withTimestamp,
+      version)
+  }
+
+  private val proj = UnsafeProjection.create(Array[DataType](IntegerType))
+  private val keyCounter = new AtomicInteger(0)
+  private def newKey(): UnsafeRow = {
+    proj.apply(new GenericInternalRow(Array[Any](keyCounter.getAndDecrement()))).copy()
+  }
+}
+
+case class Struct(d: Double, str: String)
+case class NestedStruct(i: Int, nested: Struct)

http://git-wip-us.apache.org/repos/asf/spark/blob/b3d88ac0/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 988c8e6..82d7755 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql.streaming
 
+import java.io.File
 import java.sql.Date
 import java.util.concurrent.ConcurrentHashMap
 
+import org.apache.commons.io.FileUtils
 import org.scalatest.BeforeAndAfterAll
 import org.scalatest.exceptions.TestFailedException
 
@@ -31,10 +33,12 @@ import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
 import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
-import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.util.Utils
 
 /** Class to check custom state types */
 case class RunningCount(count: Long)
@@ -359,13 +363,13 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     }
   }
 
-  // Values used for testing StateStoreUpdater
+  // Values used for testing InputProcessor
   val currentBatchTimestamp = 1000
   val currentBatchWatermark = 1000
   val beforeTimeoutThreshold = 999
   val afterTimeoutThreshold = 1001
 
-  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout
+  // Tests for InputProcessor.processNewData() when timeout = NoTimeout
   for (priorState <- Seq(None, Some(0))) {
     val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state"
     val testName = s"NoTimeout - $priorStateStr - "
@@ -396,7 +400,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       expectedState = None)        // should be removed
   }
 
-  // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout
+  // Tests for InputProcessor.processNewData() when timeout != NoTimeout
   for (priorState <- Seq(None, Some(0))) {
     for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
       var testName = ""
@@ -443,6 +447,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
           expectedState = None)                                 // state should be removed
       }
 
+      // Tests with ProcessingTimeTimeout
+      if (priorState == None) {
+        testStateUpdateWithData(
+          s"ProcessingTimeTimeout - $testName - timeout updated without 
initializing state",
+          stateUpdates = state => { state.setTimeoutDuration(5000) },
+          timeoutConf = ProcessingTimeTimeout,
+          priorState = None,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = None,
+          expectedTimeoutTimestamp = currentBatchTimestamp + 5000)
+      }
+
       testStateUpdateWithData(
         s"ProcessingTimeTimeout - $testName - state and timeout duration 
updated",
         stateUpdates =
@@ -454,6 +470,30 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
         expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change
 
       testStateUpdateWithData(
+        s"ProcessingTimeTimeout - $testName - timeout updated after state 
removed",
+        stateUpdates = state => { state.remove(); 
state.setTimeoutDuration(5000) },
+        timeoutConf = ProcessingTimeTimeout,
+        priorState = priorState,
+        priorTimeoutTimestamp = priorTimeoutTimestamp,
+        expectedState = None,
+        expectedTimeoutTimestamp = currentBatchTimestamp + 5000)
+
+      // Tests with EventTimeTimeout
+
+      if (priorState == None) {
+        testStateUpdateWithData(
+          s"EventTimeTimeout - $testName - setting timeout without init state 
not allowed",
+          stateUpdates = state => {
+            state.setTimeoutTimestamp(10000)
+          },
+          timeoutConf = EventTimeTimeout,
+          priorState = None,
+          priorTimeoutTimestamp = priorTimeoutTimestamp,
+          expectedState = None,
+          expectedTimeoutTimestamp = 10000)
+      }
+
+      testStateUpdateWithData(
         s"EventTimeTimeout - $testName - state and timeout timestamp updated",
         stateUpdates =
           (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) },
@@ -477,48 +517,21 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
         priorTimeoutTimestamp = priorTimeoutTimestamp,
         expectedState = Some(5),                                 // state should change
         expectedTimeoutTimestamp = NO_TIMESTAMP)                 // timestamp should not update
-    }
-  }
 
-  // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(),
-  // Try to remove these cases in the future
-  for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) {
-    val testName =
-      if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout"
-    testStateUpdateWithData(
-      s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed",
-      stateUpdates = state => { state.setTimeoutDuration(5000) },
-      timeoutConf = ProcessingTimeTimeout,
-      priorState = None,
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
-
-    testStateUpdateWithData(
-      s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed",
-      stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) },
-      timeoutConf = ProcessingTimeTimeout,
-      priorState = Some(5),
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
-
-    testStateUpdateWithData(
-      s"EventTimeTimeout - $testName - setting timeout without init state not allowed",
-      stateUpdates = state => { state.setTimeoutTimestamp(10000) },
-      timeoutConf = EventTimeTimeout,
-      priorState = None,
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
-
-    testStateUpdateWithData(
-      s"EventTimeTimeout - $testName - setting timeout with state removal not allowed",
-      stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) },
-      timeoutConf = EventTimeTimeout,
-      priorState = Some(5),
-      priorTimeoutTimestamp = priorTimeoutTimestamp,
-      expectedException = classOf[IllegalStateException])
+      testStateUpdateWithData(
+        s"EventTimeTimeout - $testName - setting timeout with state removal 
not allowed",
+        stateUpdates = state => {
+          state.remove(); state.setTimeoutTimestamp(10000)
+        },
+        timeoutConf = EventTimeTimeout,
+        priorState = priorState,
+        priorTimeoutTimestamp = priorTimeoutTimestamp,
+        expectedState = None,
+        expectedTimeoutTimestamp = 10000)
+    }
   }
 
-  // Tests for StateStoreUpdater.updateStateForTimedOutKeys()
+  // Tests for InputProcessor.processTimedOutState()
   val preTimeoutState = Some(5)
   for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) {
     testStateUpdateWithTimeout(
@@ -590,7 +603,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     expectedState = Some(5),                                  // state should change
     expectedTimeoutTimestamp = 5000)                          // timestamp should change
 
-  test("flatMapGroupsWithState - streaming") {
+  testWithAllStateVersions("flatMapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count if state is defined, otherwise does not return anything
     val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
@@ -669,7 +682,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     )
   }
 
-  test("flatMapGroupsWithState - streaming + aggregation") {
+  testWithAllStateVersions("flatMapGroupsWithState - streaming + aggregation") 
{
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count (-1 if count reached beyond 2 and state 
was just removed)
     val stateFunc = (key: String, values: Iterator[String], state: 
GroupState[RunningCount]) => {
@@ -728,7 +741,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF)
   }
 
-  test("flatMapGroupsWithState - streaming with processing time timeout") {
+  testWithAllStateVersions("flatMapGroupsWithState - streaming with processing 
time timeout") {
     // Function to maintain the count as state and set the proc. time timeout 
delay of 10 seconds.
     // It returns the count if changed, or -1 if the state was removed by 
timeout.
     val stateFunc = (key: String, values: Iterator[String], state: 
GroupState[RunningCount]) => {
@@ -792,7 +805,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     )
   }
 
-  test("flatMapGroupsWithState - streaming with event time timeout + 
watermark") {
+  testWithAllStateVersions("flatMapGroupsWithState - streaming w/ event time 
timeout + watermark") {
     // Function to maintain the max event time as state and set the timeout 
timestamp based on the
     // current max event time seen. It returns the max event time in the 
state, or -1 if the state
     // was removed by timeout.
@@ -843,6 +856,105 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     )
   }
 
+  test("flatMapGroupsWithState - uses state format version 2 by default") {
+    val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+      state.update(RunningCount(count))
+      Iterator((key, count.toString))
+    }
+
+    val inputData = MemoryStream[String]
+    val result = inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)
+
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      CheckNewAnswer(("a", "1")),
+      Execute { query =>
+        // Verify state format = 2
+        val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f }
+        assert(f.size == 1)
+        assert(f.head.stateFormatVersion == 2)
+      }
+    )
+  }
+
+  test("flatMapGroupsWithState - recovery from checkpoint uses state format 
version 1") {
+    // Function to maintain the max event time as state and set the timeout 
timestamp based on the
+    // current max event time seen. It returns the max event time in the 
state, or -1 if the state
+    // was removed by timeout.
+    val stateFunc = (key: String, values: Iterator[(String, Long)], state: 
GroupState[Long]) => {
+      assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 }
+      assertCanGetWatermark { state.getCurrentWatermarkMs() >= -1 }
+
+      val timeoutDelaySec = 5
+      if (state.hasTimedOut) {
+        state.remove()
+        Iterator((key, -1))
+      } else {
+        val valuesSeq = values.toSeq
+        val maxEventTimeSec = math.max(valuesSeq.map(_._2).max, state.getOption.getOrElse(0L))
+        val timeoutTimestampSec = maxEventTimeSec + timeoutDelaySec
+        state.update(maxEventTimeSec)
+        state.setTimeoutTimestamp(timeoutTimestampSec * 1000)
+        Iterator((key, maxEventTimeSec.toInt))
+      }
+    }
+    val inputData = MemoryStream[(String, Int)]
+    val result =
+      inputData.toDS
+        .select($"_1".as("key"), $"_2".cast("timestamp").as("eventTime"))
+        .withWatermark("eventTime", "10 seconds")
+        .as[(String, Long)]
+        .groupByKey(_._1)
+        .flatMapGroupsWithState(Update, EventTimeTimeout)(stateFunc)
+
+    val resourceUri = this.getClass.getResource(
+      "/structured-streaming/checkpoint-version-2.3.1-flatMapGroupsWithState-state-format-1/").toURI
+
+    val checkpointDir = Utils.createTempDir().getCanonicalFile
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this will lead to the test passing on the first run, but failing on subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    inputData.addData(("a", 11), ("a", 13), ("a", 15))
+    inputData.addData(("a", 4))
+
+    testStream(result, Update)(
+      StartStream(
+        checkpointLocation = checkpointDir.getAbsolutePath,
+        additionalConfs = Map(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> "2")),
+      /*
+      Note: The checkpoint was generated using the following input in Spark version 2.3.1
+
+      AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+      // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5.
+      CheckNewAnswer(("a", 15)),  // Output = max event time of a
+
+      AddData(inputData, ("a", 4)),       // Add data older than watermark for "a"
+      CheckNewAnswer(),                   // No output as data should get filtered by watermark
+      */
+
+      AddData(inputData, ("a", 10)),      // Add data newer than watermark for 
"a"
+      CheckNewAnswer(("a", 15)),          // Max event time is still the same
+      // Timeout timestamp for "a" is still 20 as max event time for "a" is 
still 15.
+      // Watermark is still 5 as max event time for all data is still 15.
+
+      Execute { query =>
+        // Verify state format = 1
+        val f = query.lastExecution.executedPlan.collect { case f: FlatMapGroupsWithStateExec => f }
+        assert(f.size == 1)
+        assert(f.head.stateFormatVersion == 1)
+      },
+
+      AddData(inputData, ("b", 31)),      // Add data newer than watermark for 
"b", not "a"
+      // Watermark = 31 - 10 = 21, so "a" should be timed out as timeout 
timestamp for "a" is 20.
+      CheckNewAnswer(("a", -1), ("b", 31))           // State for "a" should 
timeout and emit -1
+    )
+  }
+
+
   test("mapGroupsWithState - streaming") {
     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
@@ -1032,7 +1144,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) {
       return // there can be no prior timestamp, when there is no prior state
     }
-    test(s"StateStoreUpdater - updates with data - $testName") {
+    test(s"InputProcessor - process new data - $testName") {
       val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
         assert(state.hasTimedOut === false, "hasTimedOut not false")
         assert(values.nonEmpty, "Some value is expected")
@@ -1054,7 +1166,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       expectedState: Option[Int],
       expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = {
 
-    test(s"StateStoreUpdater - updates for timeout - $testName") {
+    test(s"InputProcessor - process timed out state - $testName") {
       val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
         assert(state.hasTimedOut === true, "hasTimedOut not true")
         assert(values.isEmpty, "values not empty")
@@ -1081,21 +1193,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     val store = newStateStore()
     val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
       mapGroupsFunc, timeoutConf, currentBatchTimestamp)
-    val updater = new mapGroupsSparkPlan.StateStoreUpdater(store)
+    val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store)
+    val stateManager = mapGroupsSparkPlan.stateManager
     val key = intToRow(0)
     // Prepare store with prior state configs
-    if (priorState.nonEmpty) {
-      val row = updater.getStateRow(priorState.get)
-      updater.setTimeoutTimestamp(row, priorTimeoutTimestamp)
-      store.put(key.copy(), row.copy())
+    if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) {
+      stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp)
     }
 
     // Call updating function to update state store
     def callFunction() = {
       val returnedIter = if (testTimeoutUpdates) {
-        updater.updateStateForTimedOutKeys()
+        inputProcessor.processTimedOutState()
       } else {
-        updater.updateStateForKeysWithData(Iterator(key))
+        inputProcessor.processNewData(Iterator(key))
       }
       returnedIter.size // consume the iterator to force state updates
     }
@@ -1106,15 +1217,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
     } else {
       // Call function to update and verify updated state in store
       callFunction()
-      val updatedStateRow = store.get(key)
-      assert(
-        Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState,
+      val updatedState = stateManager.getState(store, key)
+      assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState,
         "final state not as expected")
-      if (updatedStateRow != null) {
-        assert(
-          updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp,
-          "final timeout timestamp not as expected")
-      }
+      assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp,
+        "final timeout timestamp not as expected")
     }
   }
 
@@ -1122,6 +1229,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       func: (Int, Iterator[Int], GroupState[Int]) => Iterator[Int],
       timeoutType: GroupStateTimeout = GroupStateTimeout.NoTimeout,
       batchTimestampMs: Long = NO_TIMESTAMP): FlatMapGroupsWithStateExec = {
+    val stateFormatVersion = spark.conf.get(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION)
     MemoryStream[Int]
       .toDS
       .groupByKey(x => x)
@@ -1129,7 +1237,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       .logicalPlan.collectFirst {
         case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
           FlatMapGroupsWithStateExec(
-            f, k, v, g, d, o, None, s, m, t,
+            f, k, v, g, d, o, None, s, stateFormatVersion, m, t,
             Some(currentBatchTimestamp), Some(currentBatchWatermark), RDDScanExec(g, null, "rdd"))
       }.get
   }
@@ -1162,6 +1270,16 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
   }
 
   def rowToInt(row: UnsafeRow): Int = row.getInt(0)
+
+  def testWithAllStateVersions(name: String)(func: => Unit): Unit = {
+    for (version <- FlatMapGroupsWithStateExecHelper.supportedVersions) {
+      test(s"$name - state format version $version") {
+        withSQLConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key -> version.toString) {
+          func
+        }
+      }
+    }
+  }
 }
 
 object FlatMapGroupsWithStateSuite {


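For context beyond the test suite: the sketch below shows how the new conf is
intended to interact with a query, based only on what the diff above shows
(the SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION conf, version 2 as
the default for new queries, and recovery keeping the version recorded in the
checkpoint). The rate source, key derivation, and console sink are arbitrary
choices to keep the example runnable; none of this code is from the patch.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

    object StateFormatVersionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
        import spark.implicits._

        // Pin the legacy row format. This is only consulted when the query starts
        // from a fresh checkpoint; on recovery, the version recorded in the
        // checkpoint wins, which is what the recovery test above verifies.
        spark.conf.set(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key, "1")

        // Running count per key, the same shape as the suite's stateFunc.
        val stateFunc = (key: String, values: Iterator[String], state: GroupState[Long]) => {
          val count = state.getOption.getOrElse(0L) + values.size
          state.update(count)
          Iterator((key, count))
        }

        val counts = spark.readStream
          .format("rate").load()                          // built-in test source
          .select(($"value" % 10).cast("string").as("key"))
          .as[String]
          .groupByKey(identity)
          .flatMapGroupsWithState(OutputMode.Update(), GroupStateTimeout.NoTimeout)(stateFunc)

        counts.writeStream.format("console").outputMode("update").start().awaitTermination()
      }
    }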
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
