This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d23023202185 [SPARK-49745][SS] Add change to read registered timers
through state data source reader
d23023202185 is described below
commit d23023202185f9fd175059caf7499251848c0758
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Wed Sep 25 22:41:26 2024 +0900
[SPARK-49745][SS] Add change to read registered timers through state data
source reader
### What changes were proposed in this pull request?
Add change to read registered timers through state data source reader
### Why are the changes needed?
Without this change, users cannot read registered timers per grouping key within
the transformWithState operator.
### Does this PR introduce _any_ user-facing change?
Yes
Users can now read registered timers using the following query:
```
val stateReaderDf = spark.read
.format("statestore")
.option(StateSourceOptions.PATH, <checkpoint_loc>)
.option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
.load()
```
### How was this patch tested?
Added unit tests
```
[info] Run completed in 20 seconds, 834 milliseconds.
[info] Total number of tests run: 4
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48205 from anishshri-db/task/SPARK-49745.
Lead-authored-by: Anish Shrigondekar <[email protected]>
Co-authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../datasources/v2/state/StateDataSource.scala | 50 ++++++++--
.../v2/state/StatePartitionReader.scala | 5 +-
.../datasources/v2/state/utils/SchemaUtil.scala | 33 +++++++
.../StateStoreColumnFamilySchemaUtils.scala | 12 +++
.../streaming/StateTypesEncoderUtils.scala | 3 +
.../streaming/StatefulProcessorHandleImpl.scala | 16 +++
.../sql/execution/streaming/TimerStateImpl.scala | 9 ++
.../TransformWithStateVariableUtils.scala | 6 +-
.../v2/state/StateDataSourceReadSuite.scala | 19 ++++
.../StateDataSourceTransformWithStateSuite.scala | 109 ++++++++++++++++++++-
.../TransformWithValueStateTTLSuite.scala | 21 ++--
11 files changed, 263 insertions(+), 20 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 429464ea5438..39bc4dd9fb9c 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -29,15 +29,16 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
-import
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues,
STATE_VAR_NAME}
+import
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues,
READ_REGISTERED_TIMERS, STATE_VAR_NAME}
import
org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
import
org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader,
StateMetadataTableEntry}
import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil
-import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog,
OffsetSeqMetadata, TransformWithStateOperatorProperties,
TransformWithStateVariableInfo}
+import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog,
OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties,
TransformWithStateVariableInfo}
import
org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS,
DIR_NAME_OFFSETS, DIR_NAME_STATE}
import
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide,
RightSide}
import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec,
NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec,
StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema,
StateStoreConf, StateStoreId, StateStoreProviderId}
import org.apache.spark.sql.sources.DataSourceRegister
+import org.apache.spark.sql.streaming.TimeMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
@@ -132,7 +133,7 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
sourceOptions: StateSourceOptions,
stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
val twsShortName = "transformWithStateExec"
- if (sourceOptions.stateVarName.isDefined) {
+ if (sourceOptions.stateVarName.isDefined ||
sourceOptions.readRegisteredTimers) {
// Perform checks for transformWithState operator in case state variable
name is provided
require(stateStoreMetadata.size == 1)
val opMetadata = stateStoreMetadata.head
@@ -153,10 +154,21 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
"No state variable names are defined for the transformWithState
operator")
}
+ val twsOperatorProperties =
TransformWithStateOperatorProperties.fromJson(operatorProperties)
+ val timeMode = twsOperatorProperties.timeMode
+ if (sourceOptions.readRegisteredTimers && timeMode ==
TimeMode.None().toString) {
+ throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS,
+ "Registered timers are not available in TimeMode=None.")
+ }
+
// if the state variable is not one of the defined/available state
variables, then we
// fail the query
- val stateVarName = sourceOptions.stateVarName.get
- val twsOperatorProperties =
TransformWithStateOperatorProperties.fromJson(operatorProperties)
+ val stateVarName = if (sourceOptions.readRegisteredTimers) {
+ TimerStateUtils.getTimerStateVarName(timeMode)
+ } else {
+ sourceOptions.stateVarName.get
+ }
+
val stateVars = twsOperatorProperties.stateVariables
if (stateVars.filter(stateVar => stateVar.stateName ==
stateVarName).size != 1) {
throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
@@ -196,9 +208,10 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None
var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None
var transformWithStateVariableInfoOpt:
Option[TransformWithStateVariableInfo] = None
+ var timeMode: String = TimeMode.None.toString
if (sourceOptions.joinSide == JoinSideValues.none) {
- val stateVarName = sourceOptions.stateVarName
+ var stateVarName = sourceOptions.stateVarName
.getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)
// Read the schema file path from operator metadata version v2 onwards
@@ -208,6 +221,12 @@ class StateDataSource extends TableProvider with
DataSourceRegister with Logging
val storeMetadataEntry = storeMetadata.head
val operatorProperties = TransformWithStateOperatorProperties.fromJson(
storeMetadataEntry.operatorPropertiesJson)
+ timeMode = operatorProperties.timeMode
+
+ if (sourceOptions.readRegisteredTimers) {
+ stateVarName = TimerStateUtils.getTimerStateVarName(timeMode)
+ }
+
val stateVarInfoList = operatorProperties.stateVariables
.filter(stateVar => stateVar.stateName == stateVarName)
require(stateVarInfoList.size == 1, s"Failed to find unique state
variable info " +
@@ -304,6 +323,7 @@ case class StateSourceOptions(
fromSnapshotOptions: Option[FromSnapshotOptions],
readChangeFeedOptions: Option[ReadChangeFeedOptions],
stateVarName: Option[String],
+ readRegisteredTimers: Boolean,
flattenCollectionTypes: Boolean) {
def stateCheckpointLocation: Path = new Path(resolvedCpLocation,
DIR_NAME_STATE)
@@ -336,6 +356,7 @@ object StateSourceOptions extends DataSourceOptions {
val CHANGE_START_BATCH_ID = newOption("changeStartBatchId")
val CHANGE_END_BATCH_ID = newOption("changeEndBatchId")
val STATE_VAR_NAME = newOption("stateVarName")
+ val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers")
val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes")
object JoinSideValues extends Enumeration {
@@ -377,6 +398,19 @@ object StateSourceOptions extends DataSourceOptions {
val stateVarName = Option(options.get(STATE_VAR_NAME))
.map(_.trim)
+ val readRegisteredTimers = try {
+ Option(options.get(READ_REGISTERED_TIMERS))
+ .map(_.toBoolean).getOrElse(false)
+ } catch {
+ case _: IllegalArgumentException =>
+ throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS,
+ "Boolean value is expected")
+ }
+
+ if (readRegisteredTimers && stateVarName.isDefined) {
+ throw StateDataSourceErrors.conflictOptions(Seq(READ_REGISTERED_TIMERS,
STATE_VAR_NAME))
+ }
+
val flattenCollectionTypes = try {
Option(options.get(FLATTEN_COLLECTION_TYPES))
.map(_.toBoolean).getOrElse(true)
@@ -489,8 +523,8 @@ object StateSourceOptions extends DataSourceOptions {
StateSourceOptions(
resolvedCpLocation, batchId.get, operatorId, storeName, joinSide,
- readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName,
- flattenCollectionTypes)
+ readChangeFeed, fromSnapshotOptions, readChangeFeedOptions,
+ stateVarName, readRegisteredTimers, flattenCollectionTypes)
}
private def resolvedCheckpointLocation(
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index ae12b18c1f62..d77d97f0057f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -107,6 +107,8 @@ abstract class StatePartitionReaderBase(
useColumnFamilies = useColFamilies, storeConf, hadoopConf.value,
useMultipleValuesPerKey = useMultipleValuesPerKey)
+ val isInternal = partition.sourceOptions.readRegisteredTimers
+
if (useColFamilies) {
val store = provider.getStore(partition.sourceOptions.batchId + 1)
require(stateStoreColFamilySchemaOpt.isDefined)
@@ -117,7 +119,8 @@ abstract class StatePartitionReaderBase(
stateStoreColFamilySchema.keySchema,
stateStoreColFamilySchema.valueSchema,
stateStoreColFamilySchema.keyStateEncoderSpec.get,
- useMultipleValuesPerKey = useMultipleValuesPerKey)
+ useMultipleValuesPerKey = useMultipleValuesPerKey,
+ isInternal = isInternal)
}
provider
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
index dc0d6af95114..c337d548fa42 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala
@@ -230,6 +230,7 @@ object SchemaUtil {
"map_value" -> classOf[MapType],
"user_map_key" -> classOf[StructType],
"user_map_value" -> classOf[StructType],
+ "expiration_timestamp_ms" -> classOf[LongType],
"partition_id" -> classOf[IntegerType])
val expectedFieldNames = if (sourceOptions.readChangeFeed) {
@@ -256,6 +257,9 @@ object SchemaUtil {
Seq("key", "map_value", "partition_id")
}
+ case TimerState =>
+ Seq("key", "expiration_timestamp_ms", "partition_id")
+
case _ =>
throw StateDataSourceErrors
.internalError(s"Unsupported state variable type $stateVarType")
@@ -322,6 +326,14 @@ object SchemaUtil {
.add("partition_id", IntegerType)
}
+ case TimerState =>
+ val groupingKeySchema = SchemaUtil.getSchemaAsDataType(
+ stateStoreColFamilySchema.keySchema, "key")
+ new StructType()
+ .add("key", groupingKeySchema)
+ .add("expiration_timestamp_ms", LongType)
+ .add("partition_id", IntegerType)
+
case _ =>
throw StateDataSourceErrors.internalError(s"Unsupported state variable
type $stateVarType")
}
@@ -407,9 +419,30 @@ object SchemaUtil {
unifyMapStateRowPair(store.iterator(stateVarName),
compositeKeySchema, partitionId, stateSourceOptions)
+ case StateVariableType.TimerState =>
+ store
+ .iterator(stateVarName)
+ .map { pair =>
+ unifyTimerRow(pair.key, compositeKeySchema, partitionId)
+ }
+
case _ =>
throw new IllegalStateException(
s"Unsupported state variable type: $stateVarType")
}
}
+
+ private def unifyTimerRow(
+ rowKey: UnsafeRow,
+ groupingKeySchema: StructType,
+ partitionId: Int): InternalRow = {
+ val groupingKey = rowKey.get(0, groupingKeySchema).asInstanceOf[UnsafeRow]
+ val expirationTimestamp = rowKey.getLong(1)
+
+ val row = new GenericInternalRow(3)
+ row.update(0, groupingKey)
+ row.update(1, expirationTimestamp)
+ row.update(2, partitionId)
+ row
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
index 99229c6132eb..7da8408f98b0 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala
@@ -20,6 +20,7 @@ import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
import
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec,
PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema}
+import org.apache.spark.sql.types.StructType
object StateStoreColumnFamilySchemaUtils {
@@ -61,4 +62,15 @@ object StateStoreColumnFamilySchemaUtils {
Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)),
Some(userKeyEnc.schema))
}
+
+ def getTimerStateSchema(
+ stateName: String,
+ keySchema: StructType,
+ valSchema: StructType): StateStoreColFamilySchema = {
+ StateStoreColFamilySchema(
+ stateName,
+ keySchema,
+ valSchema,
+ Some(PrefixKeyScanStateEncoderSpec(keySchema, 1)))
+ }
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
index 1f5ad2fc8547..b70f9699195d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -288,6 +288,9 @@ class TimerKeyEncoder(keyExprEnc: ExpressionEncoder[Any]) {
.add("key", new StructType(keyExprEnc.schema.fields))
.add("expiryTimestampMs", LongType, nullable = false)
+ val schemaForValueRow: StructType =
+ StructType(Array(StructField("__dummy__", NullType)))
+
private val keySerializer = keyExprEnc.createSerializer()
private val keyDeserializer =
keyExprEnc.resolveAndBind().createDeserializer()
private val prefixKeyProjection = UnsafeProjection.create(schemaForPrefixKey)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 942d395dec0e..8beacbec7e6e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -308,6 +308,12 @@ class DriverStatefulProcessorHandleImpl(timeMode:
TimeMode, keyExprEnc: Expressi
private val stateVariableInfos: mutable.Map[String,
TransformWithStateVariableInfo] =
new mutable.HashMap[String, TransformWithStateVariableInfo]()
+ // If timeMode is not None, add a timer column family schema to the operator
metadata so that
+ // registered timers can be read using the state data source reader.
+ if (timeMode != TimeMode.None()) {
+ addTimerColFamily()
+ }
+
def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] =
columnFamilySchemas.toMap
def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] =
stateVariableInfos.toMap
@@ -318,6 +324,16 @@ class DriverStatefulProcessorHandleImpl(timeMode:
TimeMode, keyExprEnc: Expressi
}
}
+ private def addTimerColFamily(): Unit = {
+ val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString)
+ val timerEncoder = new TimerKeyEncoder(keyExprEnc)
+ val colFamilySchema = StateStoreColumnFamilySchemaUtils.
+ getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow,
timerEncoder.schemaForValueRow)
+ columnFamilySchemas.put(stateName, colFamilySchema)
+ val stateVariableInfo =
TransformWithStateVariableUtils.getTimerState(stateName)
+ stateVariableInfos.put(stateName, stateVariableInfo)
+ }
+
override def getValueState[T](stateName: String, valEncoder: Encoder[T]):
ValueState[T] = {
verifyStateVarOperations("get_value_state", PRE_INIT)
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
index 82a4226fcfd5..d0fbaf660060 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala
@@ -34,6 +34,15 @@ object TimerStateUtils {
val EVENT_TIMERS_STATE_NAME = "$eventTimers"
val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp"
val TIMESTAMP_TO_KEY_CF = "_timestampToKey"
+
+ def getTimerStateVarName(timeMode: String): String = {
+ assert(timeMode == TimeMode.EventTime.toString || timeMode ==
TimeMode.ProcessingTime.toString)
+ if (timeMode == TimeMode.EventTime.toString) {
+ TimerStateUtils.EVENT_TIMERS_STATE_NAME +
TimerStateUtils.KEY_TO_TIMESTAMP_CF
+ } else {
+ TimerStateUtils.PROC_TIMERS_STATE_NAME +
TimerStateUtils.KEY_TO_TIMESTAMP_CF
+ }
+ }
}
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
index 0a32564f973a..4a192b3e51c7 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala
@@ -43,12 +43,16 @@ object TransformWithStateVariableUtils {
def getMapState(stateName: String, ttlEnabled: Boolean):
TransformWithStateVariableInfo = {
TransformWithStateVariableInfo(stateName, StateVariableType.MapState,
ttlEnabled)
}
+
+ def getTimerState(stateName: String): TransformWithStateVariableInfo = {
+ TransformWithStateVariableInfo(stateName, StateVariableType.TimerState,
ttlEnabled = false)
+ }
}
// Enum of possible State Variable types
object StateVariableType extends Enumeration {
type StateVariableType = Value
- val ValueState, ListState, MapState = Value
+ val ValueState, ListState, MapState, TimerState = Value
}
case class TransformWithStateVariableInfo(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index 8707facc4c12..5f55848d540d 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -288,6 +288,25 @@ class StateDataSourceNegativeTestSuite extends
StateDataSourceTestBase {
}
}
+ test("ERROR: trying to specify state variable name along with " +
+ "readRegisteredTimers should fail") {
+ withTempDir { tempDir =>
+ val exc = intercept[StateDataSourceConflictOptions] {
+ spark.read.format("statestore")
+ // trick to bypass getting the last committed batch before
validating operator ID
+ .option(StateSourceOptions.BATCH_ID, 0)
+ .option(StateSourceOptions.STATE_VAR_NAME, "test")
+ .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+ .load(tempDir.getAbsolutePath)
+ }
+ checkError(exc, "STDS_CONFLICT_OPTIONS", "42613",
+ Map("options" ->
+ s"['${
+ StateSourceOptions.READ_REGISTERED_TIMERS
+ }', '${StateSourceOptions.STATE_VAR_NAME}']"))
+ }
+ }
+
test("ERROR: trying to specify non boolean value for " +
"flattenCollectionTypes") {
withTempDir { tempDir =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index 69df86fd5f74..bd047d1132fb 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -21,9 +21,9 @@ import java.time.Duration
import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.execution.streaming.MemoryStream
import
org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled,
RocksDBStateStoreProvider, TestClass}
-import org.apache.spark.sql.functions.explode
+import org.apache.spark.sql.functions.{explode, timestamp_seconds}
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow,
ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, OutputMode,
RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest,
TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils,
Trigger, TTLConfig, ValueState}
+import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow,
ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor,
MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor,
RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor,
StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues,
TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState}
import org.apache.spark.sql.streaming.util.StreamManualClock
/** Stateful processor of single value state var with non-primitive type */
@@ -176,8 +176,19 @@ class StateDataSourceTransformWithStateSuite extends
StateStoreMetricsTest
assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue])
assert(ex.getMessage.contains("State variable non-exist is not
defined"))
- // TODO: this should be removed when readChangeFeed is supported for
value state
+ // Verify that trying to read timers in TimeMode as None fails
val ex1 = intercept[Exception] {
+ spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+ .load()
+ }
+ assert(ex1.isInstanceOf[StateDataSourceInvalidOptionValue])
+ assert(ex1.getMessage.contains("Registered timers are not available"))
+
+ // TODO: this should be removed when readChangeFeed is supported for
value state
+ val ex2 = intercept[Exception] {
spark.read
.format("statestore")
.option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
@@ -186,7 +197,7 @@ class StateDataSourceTransformWithStateSuite extends
StateStoreMetricsTest
.option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
.load()
}
- assert(ex1.isInstanceOf[StateDataSourceConflictOptions])
+ assert(ex2.isInstanceOf[StateDataSourceConflictOptions])
}
}
}
@@ -563,4 +574,94 @@ class StateDataSourceTransformWithStateSuite extends
StateStoreMetricsTest
}
}
}
+
+ test("state data source - processing-time timers integration") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName,
+ SQLConf.SHUFFLE_PARTITIONS.key ->
+ TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+ withTempDir { tempDir =>
+ val clock = new StreamManualClock
+
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS()
+ .groupByKey(x => x)
+ .transformWithState(
+ new RunningCountStatefulProcessorWithProcTimeTimerUpdates(),
+ TimeMode.ProcessingTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock,
+ checkpointLocation = tempDir.getCanonicalPath),
+ AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
+ CheckNewAnswer(("a", "1")), // at batch 0, ts = 1, timer = "a" ->
[6] (= 1 + 5)
+ AddData(inputData, "a"),
+ AdvanceManualClock(2 * 1000),
+ CheckNewAnswer(("a", "2")), // at batch 1, ts = 3, timer = "a" ->
[10.5] (3 + 7.5)
+ StopStream)
+
+ val stateReaderDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+ .load()
+
+ val resultDf = stateReaderDf.selectExpr(
+ "key.value AS groupingKey",
+ "expiration_timestamp_ms AS expiryTimestamp",
+ "partition_id")
+
+ checkAnswer(resultDf,
+ Seq(Row("a", 10500L, 0)))
+ }
+ }
+ }
+
+ test("state data source - event-time timers integration") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName,
+ SQLConf.SHUFFLE_PARTITIONS.key ->
+ TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) {
+ withTempDir { tempDir =>
+ val inputData = MemoryStream[(String, Int)]
+ val result =
+ inputData.toDS()
+ .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime"))
+ .withWatermark("eventTime", "10 seconds")
+ .as[(String, Long)]
+ .groupByKey(_._1)
+ .transformWithState(
+ new MaxEventTimeStatefulProcessor(),
+ TimeMode.EventTime(),
+ OutputMode.Update())
+
+ testStream(result, OutputMode.Update())(
+ StartStream(checkpointLocation = tempDir.getCanonicalPath),
+
+ AddData(inputData, ("a", 11), ("a", 13), ("a", 15)),
+ // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20.
Watermark = 15 - 10 = 5.
+ CheckNewAnswer(("a", 15)), // Output = max event time of a
+
+ AddData(inputData, ("a", 4)), // Add data older than watermark for
"a"
+ CheckNewAnswer(), // No output as data should get filtered by
watermark
+ StopStream)
+
+ val stateReaderDf = spark.read
+ .format("statestore")
+ .option(StateSourceOptions.PATH, tempDir.getAbsolutePath)
+ .option(StateSourceOptions.READ_REGISTERED_TIMERS, true)
+ .load()
+
+ val resultDf = stateReaderDf.selectExpr(
+ "key.value AS groupingKey",
+ "expiration_timestamp_ms AS expiryTimestamp",
+ "partition_id")
+
+ checkAnswer(resultDf,
+ Seq(Row("a", 20000L, 0)))
+ }
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index 45056d104e84..1fbeaeb817bd 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoders
-import org.apache.spark.sql.execution.streaming.{CheckpointFileManager,
ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl,
ValueStateImplWithTTL}
+import org.apache.spark.sql.execution.streaming.{CheckpointFileManager,
ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, TimerStateUtils,
ValueStateImpl, ValueStateImplWithTTL}
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -265,7 +265,16 @@ class TransformWithValueStateTTLSuite extends
TransformWithStateTTLTest {
val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
val keySchema = new StructType().add("value", StringType)
+ val schemaForKeyRow: StructType = new StructType()
+ .add("key", new StructType(keySchema.fields))
+ .add("expiryTimestampMs", LongType, nullable = false)
+ val schemaForValueRow: StructType =
StructType(Array(StructField("__dummy__", NullType)))
val schema0 = StateStoreColFamilySchema(
+
TimerStateUtils.getTimerStateVarName(TimeMode.ProcessingTime().toString),
+ schemaForKeyRow,
+ schemaForValueRow,
+ Some(PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1)))
+ val schema1 = StateStoreColFamilySchema(
"valueStateTTL",
keySchema,
new StructType().add("value",
@@ -275,14 +284,14 @@ class TransformWithValueStateTTLSuite extends
TransformWithStateTTLTest {
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
- val schema1 = StateStoreColFamilySchema(
+ val schema2 = StateStoreColFamilySchema(
"valueState",
keySchema,
new StructType().add("value", IntegerType, false),
Some(NoPrefixKeyStateEncoderSpec(keySchema)),
None
)
- val schema2 = StateStoreColFamilySchema(
+ val schema3 = StateStoreColFamilySchema(
"listState",
keySchema,
new StructType().add("value",
@@ -300,7 +309,7 @@ class TransformWithValueStateTTLSuite extends
TransformWithStateTTLTest {
val compositeKeySchema = new StructType()
.add("key", new StructType().add("value", StringType))
.add("userKey", userKeySchema)
- val schema3 = StateStoreColFamilySchema(
+ val schema4 = StateStoreColFamilySchema(
"mapState",
compositeKeySchema,
new StructType().add("value",
@@ -351,9 +360,9 @@ class TransformWithValueStateTTLSuite extends
TransformWithStateTTLTest {
q.lastProgress.stateOperators.head.customMetrics
.get("numMapStateWithTTLVars").toInt)
- assert(colFamilySeq.length == 4)
+ assert(colFamilySeq.length == 5)
assert(colFamilySeq.map(_.toString).toSet == Set(
- schema0, schema1, schema2, schema3
+ schema0, schema1, schema2, schema3, schema4
).map(_.toString))
},
StopStream
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]