This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d23023202185 [SPARK-49745][SS] Add change to read registered timers through state data source reader
d23023202185 is described below

commit d23023202185f9fd175059caf7499251848c0758
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Wed Sep 25 22:41:26 2024 +0900

    [SPARK-49745][SS] Add change to read registered timers through state data source reader
    
    ### What changes were proposed in this pull request?
    Add support for reading registered timers through the state data source reader.
    
    ### Why are the changes needed?
    Without this change, users cannot read the timers registered for each grouping key within the transformWithState operator.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
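    Users can now read the registered timers of a transformWithState operator using the following query. The option is only valid for TimeMode.ProcessingTime or TimeMode.EventTime, and it cannot be combined with the `stateVarName` option. Each returned row contains `key` (the grouping key), `expiration_timestamp_ms`, and `partition_id`: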
    
    ```
            val stateReaderDf = spark.read
              .format("statestore")
              .option(StateSourceOptions.PATH, <checkpoint_loc>)
              .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
              .load()
    ```
    
    ### How was this patch tested?
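    Added unit tests to StateDataSourceReadSuite and StateDataSourceTransformWithStateSuite. TransformWithValueStateTTLSuite was updated to expect the new timer column family schema.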
    
    ```
    [info] Run completed in 20 seconds, 834 milliseconds.
    [info] Total number of tests run: 4
    [info] Suites: completed 1, aborted 0
    [info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0
    [info] All tests passed.
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48205 from anishshri-db/task/SPARK-49745.
    
    Lead-authored-by: Anish Shrigondekar <[email protected]>
    Co-authored-by: Jungtaek Lim <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../datasources/v2/state/StateDataSource.scala     |  50 ++++++++--
 .../v2/state/StatePartitionReader.scala            |   5 +-
 .../datasources/v2/state/utils/SchemaUtil.scala    |  33 +++++++
 .../StateStoreColumnFamilySchemaUtils.scala        |  12 +++
 .../streaming/StateTypesEncoderUtils.scala         |   3 +
 .../streaming/StatefulProcessorHandleImpl.scala    |  16 +++
 .../sql/execution/streaming/TimerStateImpl.scala   |   9 ++
 .../TransformWithStateVariableUtils.scala          |   6 +-
 .../v2/state/StateDataSourceReadSuite.scala        |  19 ++++
 .../StateDataSourceTransformWithStateSuite.scala   | 109 ++++++++++++++++++++-
 .../TransformWithValueStateTTLSuite.scala          |  21 ++--
 11 files changed, 263 insertions(+), 20 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 429464ea5438..39bc4dd9fb9c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -29,15 +29,16 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.DataSourceOptions
 import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
-import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues,
 STATE_VAR_NAME}
+import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues,
 READ_REGISTERED_TIMERS, STATE_VAR_NAME}
 import 
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
 import 
org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader,
 StateMetadataTableEntry}
 import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, 
OffsetSeqMetadata, TransformWithStateOperatorProperties, 
TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, 
OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties, 
TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS,
 DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide,
 RightSide}
 import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, 
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, 
StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, 
StateStoreConf, StateStoreId, StateStoreProviderId}
 import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.streaming.TimeMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
@@ -132,7 +133,7 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
       sourceOptions: StateSourceOptions,
       stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
     val twsShortName = "transformWithStateExec"
-    if (sourceOptions.stateVarName.isDefined) {
+    if (sourceOptions.stateVarName.isDefined || 
sourceOptions.readRegisteredTimers) {
       // Perform checks for transformWithState operator in case state variable 
name is provided
       require(stateStoreMetadata.size == 1)
       val opMetadata = stateStoreMetadata.head
@@ -153,10 +154,21 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
           "No state variable names are defined for the transformWithState 
operator")
       }
 
+      val twsOperatorProperties = 
TransformWithStateOperatorProperties.fromJson(operatorProperties)
+      val timeMode = twsOperatorProperties.timeMode
+      if (sourceOptions.readRegisteredTimers && timeMode == 
TimeMode.None().toString) {
+        throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS,
+          "Registered timers are not available in TimeMode=None.")
+      }
+
       // if the state variable is not one of the defined/available state 
variables, then we
       // fail the query
-      val stateVarName = sourceOptions.stateVarName.get
-      val twsOperatorProperties = 
TransformWithStateOperatorProperties.fromJson(operatorProperties)
+      val stateVarName = if (sourceOptions.readRegisteredTimers) {
+        TimerStateUtils.getTimerStateVarName(timeMode)
+      } else {
+        sourceOptions.stateVarName.get
+      }
+
       val stateVars = twsOperatorProperties.stateVariables
       if (stateVars.filter(stateVar => stateVar.stateName == 
stateVarName).size != 1) {
         throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
@@ -196,9 +208,10 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
     var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
     var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None
     var transformWithStateVariableInfoOpt: 
Option[TransformWithStateVariableInfo] = None
+    var timeMode: String = TimeMode.None.toString
 
     if (sourceOptions.joinSide == JoinSideValues.none) {
-      val stateVarName = sourceOptions.stateVarName
+      var stateVarName = sourceOptions.stateVarName
         .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
 
       // Read the schema file path from operator metadata version v2 onwards
@@ -208,6 +221,12 @@ class StateDataSource extends TableProvider with 
DataSourceRegister with Logging
         val storeMetadataEntry = storeMetadata.head
         val operatorProperties = TransformWithStateOperatorProperties.fromJson(
           storeMetadataEntry.operatorPropertiesJson)
+        timeMode = operatorProperties.timeMode
+
+        if (sourceOptions.readRegisteredTimers) {
+          stateVarName = TimerStateUtils.getTimerStateVarName(timeMode)
+        }
+
         val stateVarInfoList = operatorProperties.stateVariables
           .filter(stateVar => stateVar.stateName == stateVarName)
         require(stateVarInfoList.size == 1, s"Failed to find unique state 
variable info " +
@@ -304,6 +323,7 @@ case class StateSourceOptions(
     fromSnapshotOptions: Option[FromSnapshotOptions],
     readChangeFeedOptions: Option[ReadChangeFeedOptions],
     stateVarName: Option[String],
+    readRegisteredTimers: Boolean,
     flattenCollectionTypes: Boolean) {
   def stateCheckpointLocation: Path = new Path(resolvedCpLocation, 
DIR_NAME_STATE)
 
@@ -336,6 +356,7 @@ object StateSourceOptions extends DataSourceOptions {
   val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
   val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")
   val STATE_VAR_NAME = newOption("stateVarName")
+  val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
   val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
 
   object JoinSideValues extends Enumeration {
@@ -377,6 +398,19 @@ object StateSourceOptions extends DataSourceOptions {
     val stateVarName = Option(options.get(STATE_VAR_NAME))
       .map(_.trim)
 
+    val readRegisteredTimers = try {
+      Option(options.get(READ_REGISTERED_TIMERS))
+        .map(_.toBoolean).getOrElse(false)
+    } catch {
+      case _: IllegalArgumentException =>
+        throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS,
+          "Boolean value is expected")
+    }
+
+    if (readRegisteredTimers && stateVarName.isDefined) {
+      throw StateDataSourceErrors.conflictOptions(Seq(READ_REGISTERED_TIMERS, 
STATE_VAR_NAME))
+    }
+
     val flattenCollectionTypes = try {
       Option(options.get(FLATTEN_COLLECTION_TYPES))
         .map(_.toBoolean).getOrElse(true)
@@ -489,8 +523,8 @@ object StateSourceOptions extends DataSourceOptions {
 
     StateSourceOptions(
       resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
-      readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName,
-      flattenCollectionTypes)
+      readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
+      stateVarName, readRegisteredTimers, flattenCollectionTypes)
   }
 
   private def resolvedCheckpointLocation(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index ae12b18c1f62..d77d97f0057f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -107,6 +107,8 @@ abstract class StatePartitionReaderBase(
       useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
       useMultipleValuesPerKey = useMultipleValuesPerKey)
 
+    val isInternal = partition.sourceOptions.readRegisteredTimers
+
     if (useColFamilies) {
       val store = provider.getStore(partition.sourceOptions.batchId + 1)
       require(stateStoreColFamilySchemaOpt.isDefined)
@@ -117,7 +119,8 @@ abstract class StatePartitionReaderBase(
         stateStoreColFamilySchema.keySchema,
         stateStoreColFamilySchema.valueSchema,
         stateStoreColFamilySchema.keyStateEncoderSpec.get,
-        useMultipleValuesPerKey = useMultipleValuesPerKey)
+        useMultipleValuesPerKey = useMultipleValuesPerKey,
+        isInternal = isInternal)
     }
     provider
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index dc0d6af95114..c337d548fa42 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -230,6 +230,7 @@ object SchemaUtil {
       "map_value" -> classOf[MapType],
       "user_map_key" -> classOf[StructType],
       "user_map_value" -> classOf[StructType],
+      "expiration_timestamp_ms" -> classOf[LongType],
       "partition_id" -> classOf[IntegerType])
 
     val expectedFieldNames = if (sourceOptions.readChangeFeed) {
@@ -256,6 +257,9 @@ object SchemaUtil {
             Seq("key", "map_value", "partition_id")
           }
 
+        case TimerState =>
+          Seq("key", "expiration_timestamp_ms", "partition_id")
+
         case _ =>
           throw StateDataSourceErrors
             .internalError(s"Unsupported state variable type $stateVarType")
@@ -322,6 +326,14 @@ object SchemaUtil {
             .add("partition_id", IntegerType)
         }
 
+      case TimerState =>
+        val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+          stateStoreColFamilySchema.keySchema, "key")
+        new StructType()
+          .add("key", groupingKeySchema)
+          .add("expiration_timestamp_ms", LongType)
+          .add("partition_id", IntegerType)
+
       case _ =>
         throw StateDataSourceErrors.internalError(s"Unsupported state variable 
type $stateVarType")
     }
@@ -407,9 +419,30 @@ object SchemaUtil {
         unifyMapStateRowPair(store.iterator(stateVarName),
           compositeKeySchema, partitionId, stateSourceOptions)
 
+      case StateVariableType.TimerState =>
+        store
+          .iterator(stateVarName)
+          .map { pair =>
+            unifyTimerRow(pair.key, compositeKeySchema, partitionId)
+          }
+
       case _ =>
         throw new IllegalStateException(
           s"Unsupported state variable type: $stateVarType")
     }
   }
+
+  private def unifyTimerRow(
+      rowKey: UnsafeRow,
+      groupingKeySchema: StructType,
+      partitionId: Int): InternalRow = {
+    val groupingKey = rowKey.get(0, groupingKeySchema).asInstanceOf[UnsafeRow]
+    val expirationTimestamp = rowKey.getLong(1)
+
+    val row = new GenericInternalRow(3)
+    row.update(0, groupingKey)
+    row.update(1, expirationTimestamp)
+    row.update(2, partitionId)
+    row
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
index 99229c6132eb..7da8408f98b0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
@@ -20,6 +20,7 @@ import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
+import org.apache.spark.sql.types.StructType
 
 object StateStoreColumnFamilySchemaUtils {
 
@@ -61,4 +62,15 @@ object StateStoreColumnFamilySchemaUtils {
       Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
       Some(userKeyEnc.schema))
   }
+
+  def getTimerStateSchema(
+      stateName: String,
+      keySchema: StructType,
+      valSchema: StructType): StateStoreColFamilySchema = {
+    StateStoreColFamilySchema(
+      stateName,
+      keySchema,
+      valSchema,
+      Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)))
+  }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
index 1f5ad2fc8547..b70f9699195d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -288,6 +288,9 @@ class TimerKeyEncoder(keyExprEnc: ExpressionEncoder[Any]) {
     .add("key", new StructType(keyExprEnc.schema.fields))
     .add("expiryTimestampMs", LongType, nullable = false)
 
+  val schemaForValueRow: StructType =
+    StructType(Array(StructField("__dummy__", NullType)))
+
   private val keySerializer = keyExprEnc.createSerializer()
   private val keyDeserializer = 
keyExprEnc.resolveAndBind().createDeserializer()
   private val prefixKeyProjection = UnsafeProjection.create(schemaForPrefixKey)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 942d395dec0e..8beacbec7e6e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -308,6 +308,12 @@ class DriverStatefulProcessorHandleImpl(timeMode: 
TimeMode, keyExprEnc: Expressi
   private val stateVariableInfos: mutable.Map[String, 
TransformWithStateVariableInfo] =
     new mutable.HashMap[String, TransformWithStateVariableInfo]()
 
+  // If timeMode is not None, add a timer column family schema to the operator 
metadata so that
+  // registered timers can be read using the state data source reader.
+  if (timeMode != TimeMode.None()) {
+    addTimerColFamily()
+  }
+
   def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] = 
columnFamilySchemas.toMap
 
   def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = 
stateVariableInfos.toMap
@@ -318,6 +324,16 @@ class DriverStatefulProcessorHandleImpl(timeMode: 
TimeMode, keyExprEnc: Expressi
     }
   }
 
+  private def addTimerColFamily(): Unit = {
+    val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString)
+    val timerEncoder = new TimerKeyEncoder(keyExprEnc)
+    val colFamilySchema = StateStoreColumnFamilySchemaUtils.
+      getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, 
timerEncoder.schemaForValueRow)
+    columnFamilySchemas.put(stateName, colFamilySchema)
+    val stateVariableInfo = 
TransformWithStateVariableUtils.getTimerState(stateName)
+    stateVariableInfos.put(stateName, stateVariableInfo)
+  }
+
   override def getValueState[T](stateName: String, valEncoder: Encoder[T]): 
ValueState[T] = {
     verifyStateVarOperations("get_value_state", PRE_INIT)
     val colFamilySchema = StateStoreColumnFamilySchemaUtils.
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 82a4226fcfd5..d0fbaf660060 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -34,6 +34,15 @@ object TimerStateUtils {
   val EVENT_TIMERS_STATE_NAME = "$eventTimers"
   val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
   val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
+
+  def getTimerStateVarName(timeMode: String): String = {
+    assert(timeMode == TimeMode.EventTime.toString || timeMode == 
TimeMode.ProcessingTime.toString)
+    if (timeMode == TimeMode.EventTime.toString) {
+      TimerStateUtils.EVENT_TIMERS_STATE_NAME + 
TimerStateUtils.KEY_TO_TIMESTAMP_CF
+    } else {
+      TimerStateUtils.PROC_TIMERS_STATE_NAME + 
TimerStateUtils.KEY_TO_TIMESTAMP_CF
+    }
+  }
 }
 
 /**
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
index 0a32564f973a..4a192b3e51c7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
@@ -43,12 +43,16 @@ object TransformWithStateVariableUtils {
   def getMapState(stateName: String, ttlEnabled: Boolean): 
TransformWithStateVariableInfo = {
     TransformWithStateVariableInfo(stateName, StateVariableType.MapState, 
ttlEnabled)
   }
+
+  def getTimerState(stateName: String): TransformWithStateVariableInfo = {
+    TransformWithStateVariableInfo(stateName, StateVariableType.TimerState, 
ttlEnabled = false)
+  }
 }
 
 // Enum of possible State Variable types
 object StateVariableType extends Enumeration {
   type StateVariableType = Value
-  val ValueState, ListState, MapState = Value
+  val ValueState, ListState, MapState, TimerState = Value
 }
 
 case class TransformWithStateVariableInfo(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 8707facc4c12..5f55848d540d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -288,6 +288,25 @@ class StateDataSourceNegativeTestSuite extends 
StateDataSourceTestBase {
     }
   }
 
+  test("ERROR: trying to specify state variable name along with " +
+    "readRegisteredTimers should fail") {
+    withTempDir { tempDir =>
+      val exc = intercept[StateDataSourceConflictOptions] {
+        spark.read.format("statestore")
+          // trick to bypass getting the last committed batch before 
validating operator ID
+          .option(StateSourceOptions.BATCH_ID, 0)
+          .option(StateSourceOptions.STATE_VAR_NAME, "test")
+          .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+          .load(tempDir.getAbsolutePath)
+      }
+      checkError(exc, "STDS_CONFLICT_OPTIONS", "42613",
+        Map("options" ->
+          s"['${
+            StateSourceOptions.READ_REGISTERED_TIMERS
+          }', '${StateSourceOptions.STATE_VAR_NAME}']"))
+    }
+  }
+
   test("ERROR: trying to specify non boolean value for " +
     "flattenCollectionTypes") {
     withTempDir { tempDir =>
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index 69df86fd5f74..bd047d1132fb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -21,9 +21,9 @@ import java.time.Duration
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import 
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
 RocksDBStateStoreProvider, TestClass}
-import org.apache.spark.sql.functions.explode
+import org.apache.spark.sql.functions.{explode, timestamp_seconds}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, 
ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, OutputMode, 
RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, 
TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, 
Trigger, TTLConfig, ValueState}
+import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, 
ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, 
MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, 
RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, 
StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, 
TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
 import org.apache.spark.sql.streaming.util.StreamManualClock
 
 /** Stateful processor of single value state var with non-primitive type */
@@ -176,8 +176,19 @@ class StateDataSourceTransformWithStateSuite extends 
StateStoreMetricsTest
         assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue])
         assert(ex.getMessage.contains("State variable non-exist is not 
defined"))
 
-        // TODO: this should be removed when readChangeFeed is supported for 
value state
+        // Verify that trying to read timers in TimeMode as None fails
         val ex1 = intercept[Exception] {
+          spark.read
+            .format("statestore")
+            .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+            .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+            .load()
+        }
+        assert(ex1.isInstanceOf[StateDataSourceInvalidOptionValue])
+        assert(ex1.getMessage.contains("Registered timers are not available"))
+
+        // TODO: this should be removed when readChangeFeed is supported for 
value state
+        val ex2 = intercept[Exception] {
           spark.read
             .format("statestore")
             .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
@@ -186,7 +197,7 @@ class StateDataSourceTransformWithStateSuite extends 
StateStoreMetricsTest
             .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
             .load()
         }
-        assert(ex1.isInstanceOf[StateDataSourceConflictOptions])
+        assert(ex2.isInstanceOf[StateDataSourceConflictOptions])
       }
     }
   }
@@ -563,4 +574,94 @@ class StateDataSourceTransformWithStateSuite extends 
StateStoreMetricsTest
       }
     }
   }
+
+  test("state data source - processing-time timers integration") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { tempDir =>
+        val clock = new StreamManualClock
+
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS()
+          .groupByKey(x => x)
+          .transformWithState(
+            new RunningCountStatefulProcessorWithProcTimeTimerUpdates(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+            checkpointLocation = tempDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "1")), // at batch 0, ts = 1, timer = "a" -> 
[6] (= 1 + 5)
+          AddData(inputData, "a"),
+          AdvanceManualClock(2 * 1000),
+          CheckNewAnswer(("a", "2")), // at batch 1, ts = 3, timer = "a" -> 
[10.5] (3 + 7.5)
+          StopStream)
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+          .load()
+
+        val resultDf = stateReaderDf.selectExpr(
+       "key.value AS groupingKey",
+          "expiration_timestamp_ms AS expiryTimestamp",
+          "partition_id")
+
+        checkAnswer(resultDf,
+          Seq(Row("a", 10500L, 0)))
+      }
+    }
+  }
+
+  test("state data source - event-time timers integration") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+      withTempDir { tempDir =>
+        val inputData = MemoryStream[(String, Int)]
+        val result =
+          inputData.toDS()
+            .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+            .withWatermark("eventTime", "10 seconds")
+            .as[(String, Long)]
+            .groupByKey(_._1)
+            .transformWithState(
+              new MaxEventTimeStatefulProcessor(),
+              TimeMode.EventTime(),
+              OutputMode.Update())
+
+        testStream(result, OutputMode.Update())(
+          StartStream(checkpointLocation = tempDir.getCanonicalPath),
+
+          AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+          // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. 
Watermark = 15 - 10 = 5.
+          CheckNewAnswer(("a", 15)), // Output = max event time of a
+
+          AddData(inputData, ("a", 4)), // Add data older than watermark for 
"a"
+          CheckNewAnswer(), // No output as data should get filtered by 
watermark
+          StopStream)
+
+        val stateReaderDf = spark.read
+          .format("statestore")
+          .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+          .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+          .load()
+
+        val resultDf = stateReaderDf.selectExpr(
+          "key.value AS groupingKey",
+          "expiration_timestamp_ms AS expiryTimestamp",
+          "partition_id")
+
+        checkAnswer(resultDf,
+          Seq(Row("a", 20000L, 0)))
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index 45056d104e84..1fbeaeb817bd 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, 
ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, 
ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, 
ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, TimerStateUtils, 
ValueStateImpl, ValueStateImplWithTTL}
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -265,7 +265,16 @@ class TransformWithValueStateTTLSuite extends 
TransformWithStateTTLTest {
         val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
 
         val keySchema = new StructType().add("value", StringType)
+        val schemaForKeyRow: StructType = new StructType()
+          .add("key", new StructType(keySchema.fields))
+          .add("expiryTimestampMs", LongType, nullable = false)
+        val schemaForValueRow: StructType = 
StructType(Array(StructField("__dummy__", NullType)))
         val schema0 = StateStoreColFamilySchema(
+          
TimerStateUtils.getTimerStateVarName(TimeMode.ProcessingTime().toString),
+          schemaForKeyRow,
+          schemaForValueRow,
+          Some(PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1)))
+        val schema1 = StateStoreColFamilySchema(
           "valueStateTTL",
           keySchema,
           new StructType().add("value",
@@ -275,14 +284,14 @@ class TransformWithValueStateTTLSuite extends 
TransformWithStateTTLTest {
           Some(NoPrefixKeyStateEncoderSpec(keySchema)),
           None
         )
-        val schema1 = StateStoreColFamilySchema(
+        val schema2 = StateStoreColFamilySchema(
           "valueState",
           keySchema,
           new StructType().add("value", IntegerType, false),
           Some(NoPrefixKeyStateEncoderSpec(keySchema)),
           None
         )
-        val schema2 = StateStoreColFamilySchema(
+        val schema3 = StateStoreColFamilySchema(
           "listState",
           keySchema,
           new StructType().add("value",
@@ -300,7 +309,7 @@ class TransformWithValueStateTTLSuite extends 
TransformWithStateTTLTest {
         val compositeKeySchema = new StructType()
           .add("key", new StructType().add("value", StringType))
           .add("userKey", userKeySchema)
-        val schema3 = StateStoreColFamilySchema(
+        val schema4 = StateStoreColFamilySchema(
           "mapState",
           compositeKeySchema,
           new StructType().add("value",
@@ -351,9 +360,9 @@ class TransformWithValueStateTTLSuite extends 
TransformWithStateTTLTest {
               q.lastProgress.stateOperators.head.customMetrics
                 .get("numMapStateWithTTLVars").toInt)
 
-            assert(colFamilySeq.length == 4)
+            assert(colFamilySeq.length == 5)
             assert(colFamilySeq.map(_.toString).toSet == Set(
-              schema0, schema1, schema2, schema3
+              schema0, schema1, schema2, schema3, schema4
             ).map(_.toString))
           },
           StopStream


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to