This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c7a9c1e0776e [SPARK-49021][SS] Add support for reading 
transformWithState value state variables with state data source reader
c7a9c1e0776e is described below

commit c7a9c1e0776e9a8df3af5141626e494aab8734d6
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Fri Aug 30 07:38:39 2024 +0900

    [SPARK-49021][SS] Add support for reading transformWithState value state 
variables with state data source reader
    
    ### What changes were proposed in this pull request?
    Add support for reading transformWithState value state variables with the state data source reader
    
    Co-authored with jingz-db
    
    ### Why are the changes needed?
    These changes are needed to integrate state reading with the new operator metadata and state schema format for the value state types used by state variables within transformWithState.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    Users can now read valueState variables used in the `transformWithState` 
operator using the state data source reader.
    
    ```
    spark
       .read
       .format("statestore")
       .option("operatorId", <operatorId>)
       .option("stateVarName", <varName>)
       .load(<state path>)
    
    ```
    
    ### How was this patch tested?
    Added unit tests
    ```
    ===== POSSIBLE THREAD LEAK IN SUITE 
o.a.s.sql.streaming.TransformWithStateSuite, threads: 
ForkJoinPool.commonPool-worker-4 (daemon=true), Idle Worker Monitor for python3 
(daemon=true), rpc-boss-3-1 (daemon=true), ForkJoinPool.commonPool-worker-5 
(daemon=true), ForkJoinPool.commonPool-worker-3 (daemon=true), 
ForkJoinPool.commonPool-worker-2 (daemon=true), shuffle-boss-6-1 (daemon=true), 
ForkJoinPool.commonPool-worker-1 (daemon=true) =====
    [info] Run completed in 2 minutes, 28 seconds.
    [info] Total number of tests run: 42
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 42, failed 0, canceled 0, ignored 1, pending 0
    [info] All tests passed.
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47574 from anishshri-db/task/SPARK-49021.
    
    Authored-by: Anish Shrigondekar <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../datasources/v2/state/StateDataSource.scala     | 245 +++++++++++++++++----
 .../v2/state/StatePartitionReader.scala            | 114 ++++++----
 .../datasources/v2/state/StateScanBuilder.scala    |  17 +-
 .../datasources/v2/state/StateTable.scala          |  40 +---
 .../v2/state/metadata/StateMetadataSource.scala    |  55 +++--
 .../datasources/v2/state/utils/SchemaUtil.scala    | 126 ++++++++++-
 .../v2/state/StateDataSourceReadSuite.scala        |  55 ++++-
 .../StateDataSourceTransformWithStateSuite.scala   | 220 ++++++++++++++++++
 8 files changed, 731 insertions(+), 141 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index acd5303350de..83399e2cac01 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -24,26 +24,28 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{RuntimeConfig, SparkSession}
 import org.apache.spark.sql.catalyst.DataSourceOptions
 import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
-import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues,
 STATE_VAR_NAME}
 import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
-import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
-import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, 
OffsetSeqMetadata}
+import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader,
 StateMetadataTableEntry}
+import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, 
OffsetSeqMetadata, TransformWithStateOperatorProperties, 
TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS,
 DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide,
 RightSide}
-import 
org.apache.spark.sql.execution.streaming.state.{StateSchemaCompatibilityChecker,
 StateStore, StateStoreConf, StateStoreId, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, 
StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, 
StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, 
StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
 
 /**
  * An implementation of [[TableProvider]] with [[DataSourceRegister]] for 
State Store data source.
  */
-class StateDataSource extends TableProvider with DataSourceRegister {
+class StateDataSource extends TableProvider with DataSourceRegister with 
Logging {
   private lazy val session: SparkSession = SparkSession.active
 
   private lazy val hadoopConf: Configuration = 
session.sessionState.newHadoopConf()
@@ -58,24 +60,26 @@ class StateDataSource extends TableProvider with 
DataSourceRegister {
       properties: util.Map[String, String]): Table = {
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, 
properties)
     val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, 
sourceOptions.batchId)
-    // Read the operator metadata once to see if we can find the information 
for prefix scan
-    // encoder used in session window aggregation queries.
-    val allStateStoreMetadata = new StateMetadataPartitionReader(
-      sourceOptions.stateCheckpointLocation.getParent.toString, 
serializedHadoopConf,
-      sourceOptions.batchId)
-      .stateMetadata.toArray
-    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
-      entry.operatorId == sourceOptions.operatorId &&
-        entry.stateStoreName == sourceOptions.storeName
+    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(sourceOptions)
+
+    // The key state encoder spec should be available for all operators except 
stream-stream joins
+    val keyStateEncoderSpec = if 
(stateStoreReaderInfo.keyStateEncoderSpecOpt.isDefined) {
+      stateStoreReaderInfo.keyStateEncoderSpecOpt.get
+    } else {
+      val keySchema = SchemaUtil.getSchemaAsDataType(schema, 
"key").asInstanceOf[StructType]
+      NoPrefixKeyStateEncoderSpec(keySchema)
     }
 
-    new StateTable(session, schema, sourceOptions, stateConf, 
stateStoreMetadata)
+    new StateTable(session, schema, sourceOptions, stateConf, 
keyStateEncoderSpec,
+      stateStoreReaderInfo.transformWithStateVariableInfoOpt,
+      stateStoreReaderInfo.stateStoreColFamilySchemaOpt)
   }
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
-    val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
     val sourceOptions = StateSourceOptions.apply(session, hadoopConf, options)
 
+    val stateStoreReaderInfo: StateStoreReaderInfo = 
getStoreMetadataAndRunChecks(sourceOptions)
+
     val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
     try {
       val (keySchema, valueSchema) = sourceOptions.joinSide match {
@@ -88,34 +92,24 @@ class StateDataSource extends TableProvider with 
DataSourceRegister {
             sourceOptions.operatorId, RightSide)
 
         case JoinSideValues.none =>
-          val storeId = new StateStoreId(stateCheckpointLocation.toString, 
sourceOptions.operatorId,
-            partitionId, sourceOptions.storeName)
-          val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
-          val manager = new StateSchemaCompatibilityChecker(providerId, 
hadoopConf)
-          val stateSchema = manager.readSchemaFile().head
-          (stateSchema.keySchema, stateSchema.valueSchema)
-      }
-
-      if (sourceOptions.readChangeFeed) {
-        new StructType()
-          .add("batch_id", LongType)
-          .add("change_type", StringType)
-          .add("key", keySchema)
-          .add("value", valueSchema)
-          .add("partition_id", IntegerType)
-      } else {
-        new StructType()
-          .add("key", keySchema)
-          .add("value", valueSchema)
-          .add("partition_id", IntegerType)
+          // we should have the schema for the state store if joinSide is none
+          require(stateStoreReaderInfo.stateStoreColFamilySchemaOpt.isDefined)
+          val resultSchema = 
stateStoreReaderInfo.stateStoreColFamilySchemaOpt.get
+          (resultSchema.keySchema, resultSchema.valueSchema)
       }
 
+      SchemaUtil.getSourceSchema(sourceOptions, keySchema,
+        valueSchema,
+        stateStoreReaderInfo.transformWithStateVariableInfoOpt,
+        stateStoreReaderInfo.stateStoreColFamilySchemaOpt)
     } catch {
       case NonFatal(e) =>
         throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, e)
     }
   }
 
+  override def supportsExternalMetadata(): Boolean = false
+
   private def buildStateStoreConf(checkpointLocation: String, batchId: Long): 
StateStoreConf = {
     val offsetLog = new OffsetSeqLog(session,
       new Path(checkpointLocation, DIR_NAME_OFFSETS).toString)
@@ -134,7 +128,161 @@ class StateDataSource extends TableProvider with 
DataSourceRegister {
     }
   }
 
-  override def supportsExternalMetadata(): Boolean = false
+  private def runStateVarChecks(
+      sourceOptions: StateSourceOptions,
+      stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
+    val twsShortName = "transformWithStateExec"
+    if (sourceOptions.stateVarName.isDefined) {
+      // Perform checks for transformWithState operator in case state variable 
name is provided
+      require(stateStoreMetadata.size == 1)
+      val opMetadata = stateStoreMetadata.head
+      if (opMetadata.operatorName != twsShortName) {
+        // if we are trying to query state source with state variable name, 
then the operator
+        // should be transformWithState
+        val errorMsg = "Providing state variable names is only supported with 
the " +
+          s"transformWithState operator. Found 
operator=${opMetadata.operatorName}. " +
+          s"Please remove this option and re-run the query."
+        throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, 
errorMsg)
+      }
+
+      // if the operator is transformWithState, but the operator properties 
are empty, then
+      // the user has not defined any state variables for the operator
+      val operatorProperties = opMetadata.operatorPropertiesJson
+      if (operatorProperties.isEmpty) {
+        throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
+          "No state variable names are defined for the transformWithState 
operator")
+      }
+
+      // if the state variable is not one of the defined/available state 
variables, then we
+      // fail the query
+      val stateVarName = sourceOptions.stateVarName.get
+      val twsOperatorProperties = 
TransformWithStateOperatorProperties.fromJson(operatorProperties)
+      val stateVars = twsOperatorProperties.stateVariables
+      if (stateVars.filter(stateVar => stateVar.stateName == 
stateVarName).size != 1) {
+        throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
+          s"State variable $stateVarName is not defined for the 
transformWithState operator.")
+      }
+
+      // TODO: Support change feed and transformWithState together
+      if (sourceOptions.readChangeFeed) {
+        throw 
StateDataSourceErrors.conflictOptions(Seq(StateSourceOptions.READ_CHANGE_FEED,
+          StateSourceOptions.STATE_VAR_NAME))
+      }
+    } else {
+      // if the operator is transformWithState, then a state variable argument 
is mandatory
+      if (stateStoreMetadata.size == 1 &&
+        stateStoreMetadata.head.operatorName == twsShortName) {
+        throw StateDataSourceErrors.requiredOptionUnspecified("stateVarName")
+      }
+    }
+  }
+
+  private def getStateStoreMetadata(stateSourceOptions: StateSourceOptions):
+    Array[StateMetadataTableEntry] = {
+    val allStateStoreMetadata = new StateMetadataPartitionReader(
+      stateSourceOptions.stateCheckpointLocation.getParent.toString,
+      serializedHadoopConf, stateSourceOptions.batchId).stateMetadata.toArray
+    val stateStoreMetadata = allStateStoreMetadata.filter { entry =>
+      entry.operatorId == stateSourceOptions.operatorId &&
+        entry.stateStoreName == stateSourceOptions.storeName
+    }
+    stateStoreMetadata
+  }
+
+  private def getStoreMetadataAndRunChecks(sourceOptions: StateSourceOptions):
+    StateStoreReaderInfo = {
+    val storeMetadata = getStateStoreMetadata(sourceOptions)
+    runStateVarChecks(sourceOptions, storeMetadata)
+    var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
+    var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None
+    var transformWithStateVariableInfoOpt: 
Option[TransformWithStateVariableInfo] = None
+
+    if (sourceOptions.joinSide == JoinSideValues.none) {
+      val stateVarName = sourceOptions.stateVarName
+        .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
+
+      // Read the schema file path from operator metadata version v2 onwards
+      // for the transformWithState operator
+      val oldSchemaFilePath = if (storeMetadata.length > 0 && 
storeMetadata.head.version == 2
+        && storeMetadata.head.operatorName.contains("transformWithStateExec")) 
{
+        val storeMetadataEntry = storeMetadata.head
+        val operatorProperties = TransformWithStateOperatorProperties.fromJson(
+          storeMetadataEntry.operatorPropertiesJson)
+        val stateVarInfoList = operatorProperties.stateVariables
+          .filter(stateVar => stateVar.stateName == stateVarName)
+        require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
+          s"for state variable $stateVarName in operator 
${sourceOptions.operatorId}")
+        val stateVarInfo = stateVarInfoList.head
+        transformWithStateVariableInfoOpt = Some(stateVarInfo)
+        val schemaFilePath = new 
Path(storeMetadataEntry.stateSchemaFilePath.get)
+        Some(schemaFilePath)
+      } else {
+        None
+      }
+
+      try {
+        // Read the actual state schema from the provided path for v2 or from 
the dedicated path
+        // for v1
+        val partitionId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
+        val stateCheckpointLocation = sourceOptions.stateCheckpointLocation
+        val storeId = new StateStoreId(stateCheckpointLocation.toString, 
sourceOptions.operatorId,
+          partitionId, sourceOptions.storeName)
+        val providerId = new StateStoreProviderId(storeId, UUID.randomUUID())
+        val manager = new StateSchemaCompatibilityChecker(providerId, 
hadoopConf,
+          oldSchemaFilePath = oldSchemaFilePath)
+        val stateSchema = manager.readSchemaFile()
+
+        // Based on the version and read schema, populate the 
keyStateEncoderSpec used for
+        // reading the column families
+        val resultSchema = stateSchema.filter(_.colFamilyName == 
stateVarName).head
+        keyStateEncoderSpecOpt = Some(getKeyStateEncoderSpec(resultSchema, 
storeMetadata))
+        stateStoreColFamilySchemaOpt = Some(resultSchema)
+      } catch {
+        case NonFatal(ex) =>
+          throw StateDataSourceErrors.failedToReadStateSchema(sourceOptions, 
ex)
+      }
+    }
+
+    StateStoreReaderInfo(
+      keyStateEncoderSpecOpt,
+      stateStoreColFamilySchemaOpt,
+      transformWithStateVariableInfoOpt
+    )
+  }
+
+  private def getKeyStateEncoderSpec(
+      colFamilySchema: StateStoreColFamilySchema,
+      storeMetadata: Array[StateMetadataTableEntry]): KeyStateEncoderSpec = {
+    // If operator metadata is not found, then log a warning and continue with 
using the no-prefix
+    // key state encoder
+    val keyStateEncoderSpec = if (storeMetadata.length == 0) {
+      logWarning("Metadata for state store not found, possible cause is this 
checkpoint " +
+        "is created by older version of spark. If the query has session window 
aggregation, " +
+        "the state can't be read correctly and runtime exception will be 
thrown. " +
+        "Run the streaming query in newer spark version to generate state 
metadata " +
+        "can fix the issue.")
+      NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema)
+    } else {
+      require(storeMetadata.length == 1)
+      val storeMetadataEntry = storeMetadata.head
+      // if version has metadata info, then use numColsPrefixKey as specified
+      if (storeMetadataEntry.version == 1 && 
storeMetadataEntry.numColsPrefixKey == 0) {
+        NoPrefixKeyStateEncoderSpec(colFamilySchema.keySchema)
+      } else if (storeMetadataEntry.version == 1 && 
storeMetadataEntry.numColsPrefixKey > 0) {
+        PrefixKeyScanStateEncoderSpec(colFamilySchema.keySchema,
+          storeMetadataEntry.numColsPrefixKey)
+      } else if (storeMetadataEntry.version == 2) {
+        // for version 2, we have the encoder spec recorded to the state 
schema file. so we just
+        // use that directly
+        require(colFamilySchema.keyStateEncoderSpec.isDefined)
+        colFamilySchema.keyStateEncoderSpec.get
+      } else {
+        throw StateDataSourceErrors.internalError(s"Failed to read " +
+          s"key state encoder spec for 
operator=${storeMetadataEntry.operatorId}")
+      }
+    }
+    keyStateEncoderSpec
+  }
 }
 
 case class FromSnapshotOptions(
@@ -154,12 +302,14 @@ case class StateSourceOptions(
     joinSide: JoinSideValues,
     readChangeFeed: Boolean,
     fromSnapshotOptions: Option[FromSnapshotOptions],
-    readChangeFeedOptions: Option[ReadChangeFeedOptions]) {
+    readChangeFeedOptions: Option[ReadChangeFeedOptions],
+    stateVarName: Option[String]) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, 
DIR_NAME_STATE)
 
   override def toString: String = {
     var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, 
batchId=$batchId, " +
-      s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide"
+      s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " +
+      s"stateVarName=${stateVarName.getOrElse("None")}"
     if (fromSnapshotOptions.isDefined) {
       desc += s", 
snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}"
       desc += s", 
snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}"
@@ -183,6 +333,7 @@ object StateSourceOptions extends DataSourceOptions {
   val READ_CHANGE_FEED = newOption("readChangeFeed")
   val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
   val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")
+  val STATE_VAR_NAME = newOption("stateVarName")
 
   object JoinSideValues extends Enumeration {
     type JoinSideValues = Value
@@ -219,6 +370,10 @@ object StateSourceOptions extends DataSourceOptions {
       throw StateDataSourceErrors.invalidOptionValueIsEmpty(STORE_NAME)
     }
 
+    // Check if the state variable name is provided. Used with the 
transformWithState operator.
+    val stateVarName = Option(options.get(STATE_VAR_NAME))
+      .map(_.trim)
+
     val joinSide = try {
       Option(options.get(JOIN_SIDE))
         .map(JoinSideValues.withName).getOrElse(JoinSideValues.none)
@@ -322,7 +477,7 @@ object StateSourceOptions extends DataSourceOptions {
 
     StateSourceOptions(
       resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
-      readChangeFeed, fromSnapshotOptions, readChangeFeedOptions)
+      readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName)
   }
 
   private def resolvedCheckpointLocation(
@@ -342,3 +497,11 @@ object StateSourceOptions extends DataSourceOptions {
     }
   }
 }
+
+// Case class to store information around the key state encoder, col family 
schema and
+// operator specific state used primarily for the transformWithState operator.
+case class StateStoreReaderInfo(
+    keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
+    transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo]
+)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 6201cf1157ab..53576c335cb0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -20,8 +20,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeRow}
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, 
PartitionReaderFactory}
-import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
+import org.apache.spark.sql.execution.streaming.{StateVariableType, 
TransformWithStateVariableInfo}
 import org.apache.spark.sql.execution.streaming.state._
 import 
org.apache.spark.sql.execution.streaming.state.RecordType.{getRecordTypeAsString,
 RecordType}
 import org.apache.spark.sql.types.StructType
@@ -36,16 +36,21 @@ class StatePartitionReaderFactory(
     storeConf: StateStoreConf,
     hadoopConf: SerializableConfiguration,
     schema: StructType,
-    stateStoreMetadata: Array[StateMetadataTableEntry]) extends 
PartitionReaderFactory {
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+  extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
     val stateStoreInputPartition = 
partition.asInstanceOf[StateStoreInputPartition]
     if (stateStoreInputPartition.sourceOptions.readChangeFeed) {
       new StateStoreChangeDataPartitionReader(storeConf, hadoopConf,
-        stateStoreInputPartition, schema, stateStoreMetadata)
+        stateStoreInputPartition, schema, keyStateEncoderSpec, 
stateVariableInfoOpt,
+        stateStoreColFamilySchemaOpt)
     } else {
       new StatePartitionReader(storeConf, hadoopConf,
-        stateStoreInputPartition, schema, stateStoreMetadata)
+        stateStoreInputPartition, schema, keyStateEncoderSpec, 
stateVariableInfoOpt,
+        stateStoreColFamilySchemaOpt)
     }
   }
 }
@@ -59,40 +64,44 @@ abstract class StatePartitionReaderBase(
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
     schema: StructType,
-    stateStoreMetadata: Array[StateMetadataTableEntry])
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
   extends PartitionReader[InternalRow] with Logging {
-  private val keySchema = SchemaUtil.getSchemaAsDataType(schema, 
"key").asInstanceOf[StructType]
-  private val valueSchema = SchemaUtil.getSchemaAsDataType(schema, 
"value").asInstanceOf[StructType]
+  protected val keySchema = SchemaUtil.getSchemaAsDataType(
+    schema, "key").asInstanceOf[StructType]
+  protected val valueSchema = SchemaUtil.getSchemaAsDataType(
+    schema, "value").asInstanceOf[StructType]
 
   protected lazy val provider: StateStoreProvider = {
     val stateStoreId = 
StateStoreId(partition.sourceOptions.stateCheckpointLocation.toString,
       partition.sourceOptions.operatorId, partition.partition, 
partition.sourceOptions.storeName)
     val stateStoreProviderId = StateStoreProviderId(stateStoreId, 
partition.queryId)
-    val numColsPrefixKey = if (stateStoreMetadata.isEmpty) {
-      logWarning("Metadata for state store not found, possible cause is this 
checkpoint " +
-        "is created by older version of spark. If the query has session window 
aggregation, " +
-        "the state can't be read correctly and runtime exception will be 
thrown. " +
-        "Run the streaming query in newer spark version to generate state 
metadata " +
-        "can fix the issue.")
-      0
-    } else {
-      require(stateStoreMetadata.length == 1)
-      stateStoreMetadata.head.numColsPrefixKey
-    }
 
-    // TODO: currently we don't support RangeKeyScanStateEncoderSpec. Support 
for this will be
-    // added in the future along with state metadata changes.
-    // Filed JIRA here: https://issues.apache.org/jira/browse/SPARK-47524
-    val keyStateEncoderType = if (numColsPrefixKey > 0) {
-      PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey)
+    val useColFamilies = if (stateVariableInfoOpt.isDefined) {
+      true
     } else {
-      NoPrefixKeyStateEncoderSpec(keySchema)
+      false
     }
 
-    StateStoreProvider.createAndInit(
-      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderType,
-      useColumnFamilies = false, storeConf, hadoopConf.value,
+    val provider = StateStoreProvider.createAndInit(
+      stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec,
+      useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
       useMultipleValuesPerKey = false)
+
+    if (useColFamilies) {
+      val store = provider.getStore(partition.sourceOptions.batchId + 1)
+      require(stateStoreColFamilySchemaOpt.isDefined)
+      val stateStoreColFamilySchema = stateStoreColFamilySchemaOpt.get
+      require(stateStoreColFamilySchema.keyStateEncoderSpec.isDefined)
+      store.createColFamilyIfAbsent(
+        stateStoreColFamilySchema.colFamilyName,
+        stateStoreColFamilySchema.keySchema,
+        stateStoreColFamilySchema.valueSchema,
+        stateStoreColFamilySchema.keyStateEncoderSpec.get,
+        useMultipleValuesPerKey = false)
+    }
+    provider
   }
 
   protected val iter: Iterator[InternalRow]
@@ -126,8 +135,11 @@ class StatePartitionReader(
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
     schema: StructType,
-    stateStoreMetadata: Array[StateMetadataTableEntry])
-  extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, 
stateStoreMetadata) {
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+  extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
+    keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt) {
 
   private lazy val store: ReadStateStore = {
     partition.sourceOptions.fromSnapshotOptions match {
@@ -146,21 +158,40 @@ class StatePartitionReader(
   }
 
   override lazy val iter: Iterator[InternalRow] = {
-    store.iterator().map(pair => unifyStateRowPair((pair.key, pair.value)))
+    val stateVarName = stateVariableInfoOpt
+      .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
+    store
+      .iterator(stateVarName)
+      .map { pair =>
+        stateVariableInfoOpt match {
+          case Some(stateVarInfo) =>
+            val stateVarType = stateVarInfo.stateVariableType
+            val hasTTLEnabled = stateVarInfo.ttlEnabled
+
+            stateVarType match {
+              case StateVariableType.ValueState =>
+                if (hasTTLEnabled) {
+                  SchemaUtil.unifyStateRowPairWithTTL((pair.key, pair.value), 
valueSchema,
+                    partition.partition)
+                } else {
+                  SchemaUtil.unifyStateRowPair((pair.key, pair.value), 
partition.partition)
+                }
+
+              case _ =>
+                throw new IllegalStateException(
+                  s"Unsupported state variable type: $stateVarType")
+            }
+
+          case None =>
+            SchemaUtil.unifyStateRowPair((pair.key, pair.value), 
partition.partition)
+        }
+      }
   }
 
   override def close(): Unit = {
     store.abort()
     super.close()
   }
-
-  private def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow)): InternalRow = {
-    val row = new GenericInternalRow(3)
-    row.update(0, pair._1)
-    row.update(1, pair._2)
-    row.update(2, partition.partition)
-    row
-  }
 }
 
 /**
@@ -172,8 +203,11 @@ class StateStoreChangeDataPartitionReader(
     hadoopConf: SerializableConfiguration,
     partition: StateStoreInputPartition,
     schema: StructType,
-    stateStoreMetadata: Array[StateMetadataTableEntry])
-  extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema, 
stateStoreMetadata) {
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
+  extends StatePartitionReaderBase(storeConf, hadoopConf, partition, schema,
+    keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt) {
 
   private lazy val changeDataReader:
     NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
index 01f966ae948a..1bb992eb9add 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
@@ -25,9 +25,9 @@ import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.read.{Batch, InputPartition, 
PartitionReaderFactory, Scan, ScanBuilder}
 import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
-import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide,
 RightSide}
-import org.apache.spark.sql.execution.streaming.state.{StateStoreConf, 
StateStoreErrors}
+import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
StateStoreColFamilySchema, StateStoreConf, StateStoreErrors}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
 
@@ -37,9 +37,11 @@ class StateScanBuilder(
     schema: StructType,
     sourceOptions: StateSourceOptions,
     stateStoreConf: StateStoreConf,
-    stateStoreMetadata: Array[StateMetadataTableEntry]) extends ScanBuilder {
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends 
ScanBuilder {
   override def build(): Scan = new StateScan(session, schema, sourceOptions, 
stateStoreConf,
-    stateStoreMetadata)
+    keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt)
 }
 
 /** An implementation of [[InputPartition]] for State Store data source. */
@@ -54,7 +56,10 @@ class StateScan(
     schema: StructType,
     sourceOptions: StateSourceOptions,
     stateStoreConf: StateStoreConf,
-    stateStoreMetadata: Array[StateMetadataTableEntry]) extends Scan with 
Batch {
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]) extends 
Scan
+  with Batch {
 
   // A Hadoop Configuration can be about 10 KB, which is pretty big, so 
broadcast it
   private val hadoopConfBroadcast = session.sparkContext.broadcast(
@@ -123,7 +128,7 @@ class StateScan(
 
     case JoinSideValues.none =>
       new StatePartitionReaderFactory(stateStoreConf, 
hadoopConfBroadcast.value, schema,
-        stateStoreMetadata)
+        keyStateEncoderSpec, stateVariableInfoOpt, 
stateStoreColFamilySchemaOpt)
   }
 
   override def toBatch: Batch = this
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index 2fc85cd8aa96..4069a52f38b1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -24,12 +24,11 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.catalog.{MetadataColumn, 
SupportsMetadataColumns, SupportsRead, Table, TableCapability}
 import org.apache.spark.sql.connector.read.ScanBuilder
 import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
-import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataTableEntry
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.state.StateStoreConf
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, 
StructType}
+import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo
+import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
StateStoreColFamilySchema, StateStoreConf}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
-import org.apache.spark.util.ArrayImplicits._
 
 /** An implementation of [[Table]] with [[SupportsRead]] for State Store data 
source. */
 class StateTable(
@@ -37,12 +36,14 @@ class StateTable(
     override val schema: StructType,
     sourceOptions: StateSourceOptions,
     stateConf: StateStoreConf,
-    stateStoreMetadata: Array[StateMetadataTableEntry])
+    keyStateEncoderSpec: KeyStateEncoderSpec,
+    stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
+    stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema])
   extends Table with SupportsRead with SupportsMetadataColumns {
 
   import StateTable._
 
-  if (!isValidSchema(schema)) {
+  if (!SchemaUtil.isValidSchema(sourceOptions, schema, stateVariableInfoOpt)) {
     throw StateDataSourceErrors.internalError(
       s"Invalid schema is provided. Provided schema: $schema for " +
         s"checkpoint location: ${sourceOptions.stateCheckpointLocation} , 
operatorId: " +
@@ -77,34 +78,11 @@ class StateTable(
   override def capabilities(): util.Set[TableCapability] = CAPABILITY
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
-    new StateScanBuilder(session, schema, sourceOptions, stateConf, 
stateStoreMetadata)
+    new StateScanBuilder(session, schema, sourceOptions, stateConf, 
keyStateEncoderSpec,
+      stateVariableInfoOpt, stateStoreColFamilySchemaOpt)
 
   override def properties(): util.Map[String, String] = Map.empty[String, 
String].asJava
 
-  private def isValidSchema(schema: StructType): Boolean = {
-    val expectedFieldNames =
-      if (sourceOptions.readChangeFeed) {
-        Seq("batch_id", "change_type", "key", "value", "partition_id")
-      } else {
-        Seq("key", "value", "partition_id")
-      }
-    val expectedTypes = Map(
-      "batch_id" -> classOf[LongType],
-      "change_type" -> classOf[StringType],
-      "key" -> classOf[StructType],
-      "value" -> classOf[StructType],
-      "partition_id" -> classOf[IntegerType])
-
-    if (schema.fieldNames.toImmutableArraySeq != expectedFieldNames) {
-      false
-    } else {
-      schema.fieldNames.forall { fieldName =>
-        expectedTypes(fieldName).isAssignableFrom(
-          SchemaUtil.getSchemaAsDataType(schema, fieldName).getClass)
-      }
-    }
-  }
-
   override def metadataColumns(): Array[MetadataColumn] = Array.empty
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
index afd6a190b0ca..64fdfb799762 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/metadata/StateMetadataSource.scala
@@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, PathFilter}
 
+import org.apache.spark.internal.{Logging, LogKeys, MDC}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
@@ -46,6 +47,7 @@ case class StateMetadataTableEntry(
     numPartitions: Int,
     minBatchId: Long,
     maxBatchId: Long,
+    version: Int,
     operatorPropertiesJson: String,
     numColsPrefixKey: Int,
     stateSchemaFilePath: Option[String]) {
@@ -87,7 +89,7 @@ class StateMetadataSource extends TableProvider with 
DataSourceRegister {
 
   override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
     // The schema of state metadata table is static.
-   StateMetadataTableEntry.schema
+    StateMetadataTableEntry.schema
   }
 }
 
@@ -159,7 +161,7 @@ case class StateMetadataPartitionReaderFactory(
 class StateMetadataPartitionReader(
     checkpointLocation: String,
     serializedHadoopConf: SerializableConfiguration,
-    batchId: Long) extends PartitionReader[InternalRow] {
+    batchId: Long) extends PartitionReader[InternalRow] with Logging {
 
   override def next(): Boolean = {
     stateMetadata.hasNext
@@ -205,26 +207,35 @@ class StateMetadataPartitionReader(
 
   // Need this to be accessible from IncrementalExecution for the planning 
rule.
   private[sql] def allOperatorStateMetadata: Array[OperatorStateMetadata] = {
-    val stateDir = new Path(checkpointLocation, "state")
-    val opIds = fileManager
-      .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => 
pathToLong(f.getPath)).sorted
-    opIds.map { opId =>
-      val operatorIdPath = new Path(stateDir, opId.toString)
-      // check if OperatorStateMetadataV2 path exists, if it does, read it
-      // otherwise, fall back to OperatorStateMetadataV1
-      val operatorStateMetadataV2Path = 
OperatorStateMetadataV2.metadataDirPath(operatorIdPath)
-      val operatorStateMetadataVersion = if 
(fileManager.exists(operatorStateMetadataV2Path)) {
-        2
-      } else {
-        1
-      }
-
-      OperatorStateMetadataReader.createReader(
-        operatorIdPath, hadoopConf, operatorStateMetadataVersion, 
batchId).read() match {
-        case Some(metadata) => metadata
-        case None => throw 
StateDataSourceErrors.failedToReadOperatorMetadata(checkpointLocation,
-          batchId)
+    try {
+      val stateDir = new Path(checkpointLocation, "state")
+      val opIds = fileManager
+        .list(stateDir, pathNameCanBeParsedAsLongFilter).map(f => 
pathToLong(f.getPath)).sorted
+      opIds.map { opId =>
+        val operatorIdPath = new Path(stateDir, opId.toString)
+        // check if OperatorStateMetadataV2 path exists, if it does, read it
+        // otherwise, fall back to OperatorStateMetadataV1
+        val operatorStateMetadataV2Path = 
OperatorStateMetadataV2.metadataDirPath(operatorIdPath)
+        val operatorStateMetadataVersion = if 
(fileManager.exists(operatorStateMetadataV2Path)) {
+          2
+        } else {
+          1
+        }
+        OperatorStateMetadataReader.createReader(
+          operatorIdPath, hadoopConf, operatorStateMetadataVersion, 
batchId).read() match {
+          case Some(metadata) => metadata
+          case None => throw 
StateDataSourceErrors.failedToReadOperatorMetadata(checkpointLocation,
+            batchId)
+        }
       }
+    } catch {
+      // if the operator metadata is not present, catch the exception
+      // and return an empty array
+      case ex: Exception =>
+        logWarning(log"Failed to find operator metadata for " +
+          log"path=${MDC(LogKeys.CHECKPOINT_LOCATION, checkpointLocation)} " +
+          log"with exception=${MDC(LogKeys.EXCEPTION, ex)}")
+        Array.empty
     }
   }
 
@@ -242,6 +253,7 @@ class StateMetadataPartitionReader(
               stateStoreMetadata.numPartitions,
               if (batchIds.nonEmpty) batchIds.head else -1,
               if (batchIds.nonEmpty) batchIds.last else -1,
+              operatorStateMetadata.version,
               null,
               stateStoreMetadata.numColsPrefixKey,
               None
@@ -255,6 +267,7 @@ class StateMetadataPartitionReader(
               stateStoreMetadata.numPartitions,
               if (batchIds.nonEmpty) batchIds.head else -1,
               if (batchIds.nonEmpty) batchIds.last else -1,
+              operatorStateMetadata.version,
               v2.operatorPropertiesJson,
               -1, // numColsPrefixKey is not available in 
OperatorStateMetadataV2
               Some(stateStoreMetadata.stateSchemaFilePath)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index 54c6b34db972..9dd357530ec4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -17,7 +17,13 @@
 package org.apache.spark.sql.execution.datasources.v2.state.utils
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeRow}
+import 
org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, 
StateSourceOptions}
+import org.apache.spark.sql.execution.streaming.{StateVariableType, 
TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.state.StateStoreColFamilySchema
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType, 
StringType, StructType}
+import org.apache.spark.util.ArrayImplicits._
 
 object SchemaUtil {
   def getSchemaAsDataType(schema: StructType, fieldName: String): DataType = {
@@ -30,4 +36,122 @@ object SchemaUtil {
           "schema" -> schema.toString()))
     }
   }
+
+  def getSourceSchema(
+      sourceOptions: StateSourceOptions,
+      keySchema: StructType,
+      valueSchema: StructType,
+      transformWithStateVariableInfoOpt: 
Option[TransformWithStateVariableInfo],
+      stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema]): 
StructType = {
+    if (sourceOptions.readChangeFeed) {
+      new StructType()
+        .add("batch_id", LongType)
+        .add("change_type", StringType)
+        .add("key", keySchema)
+        .add("value", valueSchema)
+        .add("partition_id", IntegerType)
+    } else if (transformWithStateVariableInfoOpt.isDefined) {
+      require(stateStoreColFamilySchemaOpt.isDefined)
+      generateSchemaForStateVar(transformWithStateVariableInfoOpt.get,
+        stateStoreColFamilySchemaOpt.get)
+    } else {
+      new StructType()
+        .add("key", keySchema)
+        .add("value", valueSchema)
+        .add("partition_id", IntegerType)
+    }
+  }
+
+  def unifyStateRowPair(pair: (UnsafeRow, UnsafeRow), partition: Int): 
InternalRow = {
+    val row = new GenericInternalRow(3)
+    row.update(0, pair._1)
+    row.update(1, pair._2)
+    row.update(2, partition)
+    row
+  }
+
+  def unifyStateRowPairWithTTL(
+      pair: (UnsafeRow, UnsafeRow),
+      valueSchema: StructType,
+      partition: Int): InternalRow = {
+    val row = new GenericInternalRow(4)
+    row.update(0, pair._1)
+    row.update(1, pair._2.get(0, valueSchema))
+    row.update(2, pair._2.get(1, LongType))
+    row.update(3, partition)
+    row
+  }
+
+  def isValidSchema(
+      sourceOptions: StateSourceOptions,
+      schema: StructType,
+      transformWithStateVariableInfoOpt: 
Option[TransformWithStateVariableInfo]): Boolean = {
+  val expectedTypes = Map(
+      "batch_id" -> classOf[LongType],
+      "change_type" -> classOf[StringType],
+      "key" -> classOf[StructType],
+      "value" -> classOf[StructType],
+      "partition_id" -> classOf[IntegerType],
+      "expiration_timestamp" -> classOf[LongType])
+
+    val expectedFieldNames = if (sourceOptions.readChangeFeed) {
+      Seq("batch_id", "change_type", "key", "value", "partition_id")
+    } else if (transformWithStateVariableInfoOpt.isDefined) {
+      val stateVarInfo = transformWithStateVariableInfoOpt.get
+      val hasTTLEnabled = stateVarInfo.ttlEnabled
+      val stateVarType = stateVarInfo.stateVariableType
+
+      stateVarType match {
+        case StateVariableType.ValueState =>
+          if (hasTTLEnabled) {
+            Seq("key", "value", "expiration_timestamp", "partition_id")
+          } else {
+            Seq("key", "value", "partition_id")
+          }
+
+        case _ =>
+          throw StateDataSourceErrors
+            .internalError(s"Unsupported state variable type $stateVarType")
+      }
+    } else {
+      Seq("key", "value", "partition_id")
+    }
+
+    if (schema.fieldNames.toImmutableArraySeq != expectedFieldNames) {
+      false
+    } else {
+      schema.fieldNames.forall { fieldName =>
+        expectedTypes(fieldName).isAssignableFrom(
+          SchemaUtil.getSchemaAsDataType(schema, fieldName).getClass)
+      }
+    }
+  }
+
+  private def generateSchemaForStateVar(
+      stateVarInfo: TransformWithStateVariableInfo,
+      stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = {
+    val stateVarType = stateVarInfo.stateVariableType
+    val hasTTLEnabled = stateVarInfo.ttlEnabled
+
+    stateVarType match {
+      case StateVariableType.ValueState =>
+        if (hasTTLEnabled) {
+          val ttlValueSchema = SchemaUtil.getSchemaAsDataType(
+            stateStoreColFamilySchema.valueSchema, 
"value").asInstanceOf[StructType]
+          new StructType()
+            .add("key", stateStoreColFamilySchema.keySchema)
+            .add("value", ttlValueSchema)
+            .add("expiration_timestamp", LongType)
+            .add("partition_id", IntegerType)
+        } else {
+          new StructType()
+            .add("key", stateStoreColFamilySchema.keySchema)
+            .add("value", stateStoreColFamilySchema.valueSchema)
+            .add("partition_id", IntegerType)
+        }
+
+      case _ =>
+        throw StateDataSourceErrors.internalError(s"Unsupported state variable 
type $stateVarType")
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index e6cdd0dce9ef..97c88037a717 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.streaming.{CommitLog, 
MemoryStream, Offset
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.streaming.{OutputMode, TimeMode, 
TransformWithStateSuiteUtils}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 
 class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase {
@@ -268,6 +268,25 @@ class StateDataSourceNegativeTestSuite extends 
StateDataSourceTestBase {
           "message" -> s"value should be less than or equal to $endBatchId"))
     }
   }
+
+  test("ERROR: trying to specify state variable name with " +
+    "non-transformWithState operator") {
+    withTempDir { tempDir =>
+      runDropDuplicatesQuery(tempDir.getAbsolutePath)
+
+      val exc = intercept[StateDataSourceInvalidOptionValue] {
+        spark.read.format("statestore")
+          // trick to bypass getting the last committed batch before 
validating operator ID
+          .option(StateSourceOptions.BATCH_ID, 0)
+          .option(StateSourceOptions.STATE_VAR_NAME, "test")
+          .load(tempDir.getAbsolutePath)
+      }
+      checkError(exc, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", Some("42616"),
+        Map("optionName" -> StateSourceOptions.STATE_VAR_NAME,
+          "message" -> ".*"),
+        matchPVals = true)
+    }
+  }
 }
 
 /**
@@ -429,6 +448,40 @@ class RocksDBStateDataSourceReadSuite extends 
StateDataSourceReadSuite {
     
spark.conf.set("spark.sql.streaming.stateStore.rocksdb.changelogCheckpointing.enabled",
       "false")
   }
+
+  test("ERROR: Do not provide state variable name with " +
+    "transformWithState operator") {
+    import testImplicits._
+    withTempDir { tempDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new StatefulProcessorWithSingleValueVar(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+
+        val e = intercept[StateDataSourceUnspecifiedRequiredOption] {
+          spark.read
+            .format("statestore")
+            .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+            .load()
+        }
+        checkError(e, "STDS_REQUIRED_OPTION_UNSPECIFIED", Some("42601"),
+          Map("optionName" -> "stateVarName"))
+      }
+    }
+  }
 }
 
 class RocksDBWithChangelogCheckpointStateDataSourceReaderSuite extends
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
new file mode 100644
index 000000000000..ccd4e005756a
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.state
+
+import java.time.Duration
+
+import org.apache.spark.sql.{Encoders, Row}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBStateStoreProvider, TestClass}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{ExpiredTimerInfo, OutputMode, 
RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, 
TimeMode, TimerValues, TransformWithStateSuiteUtils, TTLConfig, ValueState}
+
+/** Stateful processor of single value state var with non-primitive type */
+class StatefulProcessorWithSingleValueVar extends 
RunningCountStatefulProcessor {
+  @transient private var _valueState: ValueState[TestClass] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _valueState = getHandle.getValueState[TestClass](
+      "valueState", Encoders.product[TestClass])
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
+    val count = _valueState.getOption().getOrElse(TestClass(0L, 
"dummyKey")).id + 1
+    _valueState.update(TestClass(count, "dummyKey"))
+    Iterator((key, count.toString))
+  }
+}
+
+class StatefulProcessorWithTTL
+  extends StatefulProcessor[String, String, (String, String)] {
+  @transient protected var _countState: ValueState[Long] = _
+
+  override def init(
+      outputMode: OutputMode,
+      timeMode: TimeMode): Unit = {
+    _countState = getHandle.getValueState[Long]("countState",
+      Encoders.scalaLong, TTLConfig(Duration.ofMillis(30000)))
+  }
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = {
+    val count = _countState.getOption().getOrElse(0L) + 1
+    if (count == 3) {
+      _countState.clear()
+      Iterator.empty
+    } else {
+      _countState.update(count)
+      Iterator((key, count.toString))
+    }
+  }
+}
+
+/**
+ * Test suite to verify integration of state data source reader with the 
transformWithState operator
+ */
+class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest
+  with AlsoTestWithChangelogCheckpointingEnabled {
+
+  import testImplicits._
+
+  test("state data source integration - value state with single variable") {
+    withTempDir { tempDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new StatefulProcessorWithSingleValueVar(),
+            TimeMode.None(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, "a"),
+          CheckNewAnswer(("a", "1")),
+          AddData(inputData, "b"),
+          CheckNewAnswer(("b", "1")),
+          StopStream
+        )
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, "valueState")
+          .load()
+
+        val resultDf = stateReaderDf.selectExpr(
+          "key.value AS groupingKey",
+          "value.id AS valueId", "value.name AS valueName",
+          "partition_id")
+
+        checkAnswer(resultDf,
+          Seq(Row("a", 1L, "dummyKey", 0), Row("b", 1L, "dummyKey", 1)))
+
+        // non existent state variable should fail
+        val ex = intercept[Exception] {
+          spark.read
+            .format("statestore")
+            .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+            .option(StateSourceOptions.STATE_VAR_NAME, "non-exist")
+            .load()
+        }
+        assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue])
+        assert(ex.getMessage.contains("State variable non-exist is not 
defined"))
+
+        // TODO: this should be removed when readChangeFeed is supported for 
value state
+        val ex1 = intercept[Exception] {
+          spark.read
+            .format("statestore")
+            .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+            .option(StateSourceOptions.STATE_VAR_NAME, "valueState")
+            .option(StateSourceOptions.READ_CHANGE_FEED, "true")
+            .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
+            .load()
+        }
+        assert(ex1.isInstanceOf[StateDataSourceConflictOptions])
+      }
+    }
+  }
+
+  test("state data source integration - value state with single variable and 
TTL") {
+    withTempDir { tempDir =>
+      withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+        SQLConf.SHUFFLE_PARTITIONS.key ->
+          TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new StatefulProcessorWithTTL(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getAbsolutePath),
+          AddData(inputData, "a"),
+          AddData(inputData, "b"),
+          Execute { _ =>
+            // wait for the batch to run since we are using processing time
+            Thread.sleep(5000)
+          },
+          StopStream
+        )
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+          .load()
+
+        val resultDf = stateReaderDf.selectExpr(
+          "key.value", "value.value", "expiration_timestamp", "partition_id")
+
+        var count = 0L
+        resultDf.collect().foreach { row =>
+          count = count + 1
+          assert(row.getLong(2) > 0)
+        }
+
+        // verify that 2 state rows are present
+        assert(count === 2)
+
+        val answerDf = stateReaderDf.selectExpr(
+          "key.value AS groupingKey",
+          "value.value AS valueId", "partition_id")
+        checkAnswer(answerDf,
+          Seq(Row("a", 1L, 0), Row("b", 1L, 1)))
+
+        // non existent state variable should fail
+        val ex = intercept[Exception] {
+          spark.read
+            .format("statestore")
+            .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+            .option(StateSourceOptions.STATE_VAR_NAME, "non-exist")
+            .load()
+        }
+        assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue])
+        assert(ex.getMessage.contains("State variable non-exist is not 
defined"))
+
+        // TODO: this should be removed when readChangeFeed is supported for 
TTL based state
+        // variables
+        val ex1 = intercept[Exception] {
+          spark.read
+            .format("statestore")
+            .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+            .option(StateSourceOptions.STATE_VAR_NAME, "countState")
+            .option(StateSourceOptions.READ_CHANGE_FEED, "true")
+            .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
+            .load()
+        }
+        assert(ex1.isInstanceOf[StateDataSourceConflictOptions])
+      }
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to