This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 47063a6a75f7 [SPARK-50128][SS] Add stateful processor handle APIs
using implicit encoders in Scala
47063a6a75f7 is described below
commit 47063a6a75f729b43dec8bdeb78c46ea29e5f98f
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Tue Nov 5 12:20:16 2024 +0900
[SPARK-50128][SS] Add stateful processor handle APIs using implicit
encoders in Scala
### What changes were proposed in this pull request?
Add stateful processor handle APIs using implicit encoders in Scala
### Why are the changes needed?
Without the changes, users have to pass explicit SQL encoders for state
types while acquiring an instance of the underlying state variable
### Does this PR introduce _any_ user-facing change?
Yes
Users can now use implicits available in Scala through `import
spark.implicits._` and only provide the type while getting the state objects.
For example:
```
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit =
{
_myValueState = getHandle.getValueState[Long]("myValueState",
TTLConfig.NONE)
}
```
### How was this patch tested?
Existing unit tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48728 from anishshri-db/task/SPARK-50128.
Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../sql/streaming/StatefulProcessorHandle.scala | 98 +++++----
.../org/apache/spark/sql/streaming/TTLConfig.scala | 15 ++
.../TransformWithStateInPandasStateServer.scala | 26 ++-
.../sql/execution/streaming/ListStateImpl.scala | 5 +-
.../execution/streaming/ListStateImplWithTTL.scala | 9 +-
.../sql/execution/streaming/MapStateImpl.scala | 11 +-
.../execution/streaming/MapStateImplWithTTL.scala | 17 +-
.../streaming/StateTypesEncoderUtils.scala | 2 +-
.../streaming/StatefulProcessorHandleImpl.scala | 226 ++++++++++++---------
.../spark/sql/execution/streaming/TTLState.scala | 3 +-
.../sql/execution/streaming/ValueStateImpl.scala | 5 +-
.../streaming/ValueStateImplWithTTL.scala | 10 +-
.../apache/spark/sql/TestStatefulProcessor.java | 6 +-
.../sql/TestStatefulProcessorWithInitialState.java | 2 +-
.../StateDataSourceTransformWithStateSuite.scala | 4 +-
...ransformWithStateInPandasStateServerSuite.scala | 8 +-
.../execution/streaming/state/ListStateSuite.scala | 26 ++-
.../execution/streaming/state/MapStateSuite.scala | 22 +-
.../state/StatefulProcessorHandleSuite.scala | 24 ++-
.../streaming/state/ValueStateSuite.scala | 26 +--
.../streaming/TransformWithListStateSuite.scala | 7 +-
.../sql/streaming/TransformWithMapStateSuite.scala | 3 +-
.../TransformWithStateInitialStateSuite.scala | 20 +-
.../sql/streaming/TransformWithStateSuite.scala | 45 ++--
.../TransformWithValueStateTTLSuite.scala | 2 +-
25 files changed, 367 insertions(+), 255 deletions(-)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index d1eca0f3967d..f458f0de37cb 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -29,25 +29,12 @@ import org.apache.spark.sql.Encoder
@Evolving
private[sql] trait StatefulProcessorHandle extends Serializable {
- /**
- * Function to create new or return existing single value state variable of
given type. The user
- * must ensure to call this function only within the `init()` method of the
StatefulProcessor.
- *
- * @param stateName
- * \- name of the state variable
- * @param valEncoder
- * \- SQL encoder for state variable
- * @tparam T
- * \- type of state variable
- * @return
- * \- instance of ValueState of type T that can be used to store state
persistently
- */
- def getValueState[T](stateName: String, valEncoder: Encoder[T]):
ValueState[T]
-
/**
* Function to create new or return existing single value state variable of
given type with ttl.
* State values will not be returned past ttlDuration, and will be
eventually removed from the
* state store. Any state update resets the ttl to current processing time
plus ttlDuration.
+ * Users can use the helper method `TTLConfig.NONE` in Scala or
`TTLConfig.NONE()` in Java for
+ * the TTLConfig parameter to disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()`
method of the
* StatefulProcessor.
@@ -69,25 +56,34 @@ private[sql] trait StatefulProcessorHandle extends
Serializable {
ttlConfig: TTLConfig): ValueState[T]
/**
- * Creates new or returns existing list state associated with stateName. The
ListState persists
- * values of type T.
+ * (Scala-specific) Function to create new or return existing single value
state variable of
+ * given type with ttl. State values will not be returned past ttlDuration,
and will be
+ * eventually removed from the state store. Any state update resets the ttl
to current
+ * processing time plus ttlDuration. Users can use the helper method
`TTLConfig.NONE` in Scala
+ * or `TTLConfig.NONE()` in Java for the TTLConfig parameter to disable TTL
for the state
+ * variable.
+ *
+ * The user must ensure to call this function only within the `init()`
method of the
+ * StatefulProcessor. Note that this API uses the implicit SQL encoder in
Scala.
*
* @param stateName
* \- name of the state variable
- * @param valEncoder
- * \- SQL encoder for state variable
+ * @param ttlConfig
+ * \- the ttl configuration (time to live duration etc.)
* @tparam T
* \- type of state variable
* @return
- * \- instance of ListState of type T that can be used to store state
persistently
+ * \- instance of ValueState of type T that can be used to store state
persistently
*/
- def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]
+ def getValueState[T: Encoder](stateName: String, ttlConfig: TTLConfig):
ValueState[T]
/**
* Function to create new or return existing list state variable of given
type with ttl. State
* values will not be returned past ttlDuration, and will be eventually
removed from the state
* store. Any values in listState which have expired after ttlDuration will
not be returned on
- * get() and will be eventually removed from the state.
+ * get() and will be eventually removed from the state. Users can use the
helper method
+ * `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig
parameter to
+ * disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()`
method of the
* StatefulProcessor.
@@ -109,32 +105,34 @@ private[sql] trait StatefulProcessorHandle extends
Serializable {
ttlConfig: TTLConfig): ListState[T]
/**
- * Creates new or returns existing map state associated with stateName. The
MapState persists
- * Key-Value pairs of type [K, V].
+ * (Scala-specific) Function to create new or return existing list state
variable of given type
+ * with ttl. State values will not be returned past ttlDuration, and will be
eventually removed
+ * from the state store. Any values in listState which have expired after
ttlDuration will not
+ * be returned on get() and will be eventually removed from the state. Users
can use the helper
+ * method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the
TTLConfig parameter to
+ * disable TTL for the state variable.
+ *
+ * The user must ensure to call this function only within the `init()`
method of the
+ * StatefulProcessor. Note that this API uses the implicit SQL encoder in
Scala.
*
* @param stateName
* \- name of the state variable
- * @param userKeyEnc
- * \- spark sql encoder for the map key
- * @param valEncoder
- * \- spark sql encoder for the map value
- * @tparam K
- * \- type of key for map state variable
- * @tparam V
- * \- type of value for map state variable
+ * @param ttlConfig
+ * \- the ttl configuration (time to live duration etc.)
+ * @tparam T
+ * \- type of state variable
* @return
- * \- instance of MapState of type [K,V] that can be used to store state
persistently
+ * \- instance of ListState of type T that can be used to store state
persistently
*/
- def getMapState[K, V](
- stateName: String,
- userKeyEnc: Encoder[K],
- valEncoder: Encoder[V]): MapState[K, V]
+ def getListState[T: Encoder](stateName: String, ttlConfig: TTLConfig):
ListState[T]
/**
* Function to create new or return existing map state variable of given
type with ttl. State
* values will not be returned past ttlDuration, and will be eventually
removed from the state
* store. Any values in mapState which have expired after ttlDuration will
not returned on get()
- * and will be eventually removed from the state.
+ * and will be eventually removed from the state. Users can use the helper
method
+ * `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig
parameter to
+ * disable TTL for the state variable.
*
* The user must ensure to call this function only within the `init()`
method of the
* StatefulProcessor.
@@ -160,6 +158,30 @@ private[sql] trait StatefulProcessorHandle extends
Serializable {
valEncoder: Encoder[V],
ttlConfig: TTLConfig): MapState[K, V]
+ /**
+ * (Scala-specific) Function to create new or return existing map state
variable of given type
+ * with ttl. State values will not be returned past ttlDuration, and will be
eventually removed
+ * from the state store. Any values in mapState which have expired after
ttlDuration will not be
+ * returned on get() and will be eventually removed from the state. Users
can use the helper
+ * method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the
TTLConfig parameter to
+ * disable TTL for the state variable.
+ *
+ * The user must ensure to call this function only within the `init()`
method of the
+ * StatefulProcessor. Note that this API uses the implicit SQL encoder in
Scala.
+ *
+ * @param stateName
+ * \- name of the state variable
+ * @param ttlConfig
+ * \- the ttl configuration (time to live duration etc.)
+ * @tparam K
+ * \- type of key for map state variable
+ * @tparam V
+ * \- type of value for map state variable
+ * @return
+ * \- instance of MapState of type [K,V] that can be used to store state
persistently
+ */
+ def getMapState[K: Encoder, V: Encoder](stateName: String, ttlConfig:
TTLConfig): MapState[K, V]
+
/** Function to return queryInfo for currently running task */
def getQueryInfo(): QueryInfo
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
index ce786aa943d8..7ec4fbc8c1b5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
@@ -24,7 +24,22 @@ import java.time.Duration
* will be eventually removed from the state store. Any state update resets
the ttl to current
* processing time plus ttlDuration.
*
+ * Passing a TTL duration of zero will disable the TTL for the state variable.
Users can also use
+ * the helper method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java
to disable TTL for
+ * the state variable.
+ *
* @param ttlDuration
* time to live duration for state stored in the state variable.
*/
case class TTLConfig(ttlDuration: Duration)
+
+object TTLConfig {
+
+ /**
+ * Helper method to create a TTLConfig with expiry duration as Zero
+ * @return
+ * \- TTLConfig with expiry duration as Zero
+ */
+ def NONE: TTLConfig = TTLConfig(Duration.ZERO)
+
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 8a67d5d47f05..5f3ebd87e75e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -591,20 +591,23 @@ class TransformWithStateInPandasStateServer(
stateType match {
case StateVariableType.ValueState => if
(!valueStates.contains(stateName)) {
val state = if (ttlDurationMs.isEmpty) {
- statefulProcessorHandle.getValueState[Row](stateName,
Encoders.row(schema))
+ statefulProcessorHandle.getValueState[Row](stateName,
Encoders.row(schema),
+ TTLConfig.NONE)
+ } else {
+ statefulProcessorHandle.getValueState(
+ stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+ }
+ valueStates.put(stateName,
+ ValueStateInfo(state, schema,
expressionEncoder.createDeserializer()))
+ sendResponse(0)
} else {
- statefulProcessorHandle.getValueState(
- stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+ sendResponse(1, s"Value state $stateName already exists")
}
- valueStates.put(stateName,
- ValueStateInfo(state, schema,
expressionEncoder.createDeserializer()))
- sendResponse(0)
- } else {
- sendResponse(1, s"Value state $stateName already exists")
- }
+
case StateVariableType.ListState => if (!listStates.contains(stateName))
{
val state = if (ttlDurationMs.isEmpty) {
- statefulProcessorHandle.getListState[Row](stateName,
Encoders.row(schema))
+ statefulProcessorHandle.getListState[Row](stateName,
Encoders.row(schema),
+ TTLConfig.NONE)
} else {
statefulProcessorHandle.getListState(
stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
@@ -616,12 +619,13 @@ class TransformWithStateInPandasStateServer(
} else {
sendResponse(1, s"List state $stateName already exists")
}
+
case StateVariableType.MapState => if (!mapStates.contains(stateName)) {
val valueSchema = StructType.fromString(mapStateValueSchemaString)
val valueExpressionEncoder =
ExpressionEncoder(valueSchema).resolveAndBind()
val state = if (ttlDurationMs.isEmpty) {
statefulProcessorHandle.getMapState[Row, Row](stateName,
- Encoders.row(schema), Encoders.row(valueSchema))
+ Encoders.row(schema), Encoders.row(valueSchema), TTLConfig.NONE)
} else {
statefulProcessorHandle.getMapState[Row, Row](stateName,
Encoders.row(schema),
Encoders.row(valueSchema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
index 77c481a8ba0b..32683aebd8c1 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec,
StateStore, StateStoreErrors}
@@ -39,7 +38,7 @@ class ListStateImpl[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
- valEncoder: Encoder[S],
+ valEncoder: ExpressionEncoder[Any],
metrics: Map[String, SQLMetric] = Map.empty)
extends ListStateMetricsImpl
with ListState[S]
@@ -75,7 +74,7 @@ class ListStateImpl[S](
override def next(): S = {
val valueUnsafeRow = unsafeRowValuesIterator.next()
- stateTypesEncoder.decodeValue(valueUnsafeRow)
+ stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S]
}
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
index be47f566bc6a..4c8dd6a193c2 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
@@ -16,7 +16,6 @@
*/
package org.apache.spark.sql.execution.streaming
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -43,7 +42,7 @@ class ListStateImplWithTTL[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
- valEncoder: Encoder[S],
+ valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
@@ -91,7 +90,7 @@ class ListStateImplWithTTL[S](
if (iter.hasNext) {
val currentRow = iter.next()
- stateTypesEncoder.decodeValue(currentRow)
+ stateTypesEncoder.decodeValue(currentRow).asInstanceOf[S]
} else {
finished = true
null.asInstanceOf[S]
@@ -223,7 +222,7 @@ class ListStateImplWithTTL[S](
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey,
stateName)
unsafeRowValuesIterator.map { valueUnsafeRow =>
- stateTypesEncoder.decodeValue(valueUnsafeRow)
+ stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S]
}
}
@@ -234,7 +233,7 @@ class ListStateImplWithTTL[S](
val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey,
stateName)
unsafeRowValuesIterator.map { valueUnsafeRow =>
- (stateTypesEncoder.decodeValue(valueUnsafeRow),
+ (stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S],
stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
index cb3db19496dd..4e608a5d5dbb 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
@@ -40,8 +39,8 @@ class MapStateImpl[K, V](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
- userKeyEnc: Encoder[K],
- valEncoder: Encoder[V],
+ userKeyEnc: ExpressionEncoder[Any],
+ valEncoder: ExpressionEncoder[Any],
metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with
Logging {
// Pack grouping key and user key together as a prefixed composite key
@@ -67,7 +66,7 @@ class MapStateImpl[K, V](
val unsafeRowValue = store.get(encodedCompositeKey, stateName)
if (unsafeRowValue == null) return null.asInstanceOf[V]
- stateTypesEncoder.decodeValue(unsafeRowValue)
+ stateTypesEncoder.decodeValue(unsafeRowValue).asInstanceOf[V]
}
/** Check if the user key is contained in the map */
@@ -92,8 +91,8 @@ class MapStateImpl[K, V](
store.prefixScan(encodedGroupingKey, stateName)
.map {
case iter: UnsafeRowPair =>
- (stateTypesEncoder.decodeCompositeKey(iter.key),
- stateTypesEncoder.decodeValue(iter.value))
+ (stateTypesEncoder.decodeCompositeKey(iter.key).asInstanceOf[K],
+ stateTypesEncoder.decodeValue(iter.value).asInstanceOf[V])
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
index 6a3685ad6c46..19704b6d1bd5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,8 +44,8 @@ class MapStateImplWithTTL[K, V](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
- userKeyEnc: Encoder[K],
- valEncoder: Encoder[V],
+ userKeyEnc: ExpressionEncoder[Any],
+ valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
@@ -83,7 +82,7 @@ class MapStateImplWithTTL[K, V](
if (retRow != null) {
if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
- stateTypesEncoder.decodeValue(retRow)
+ stateTypesEncoder.decodeValue(retRow).asInstanceOf[V]
} else {
null.asInstanceOf[V]
}
@@ -126,7 +125,9 @@ class MapStateImplWithTTL[K, V](
if (iter.hasNext) {
val currentRowPair = iter.next()
val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+ .asInstanceOf[K]
val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+ .asInstanceOf[V]
(key, value)
} else {
finished = true
@@ -213,7 +214,7 @@ class MapStateImplWithTTL[K, V](
val retRow = store.get(encodedCompositeKey, stateName)
if (retRow != null) {
- val resState = stateTypesEncoder.decodeValue(retRow)
+ val resState = stateTypesEncoder.decodeValue(retRow).asInstanceOf[V]
Some(resState)
} else {
None
@@ -231,7 +232,9 @@ class MapStateImplWithTTL[K, V](
// ttlExpiration
Option(retRow).flatMap { row =>
val ttlExpiration = stateTypesEncoder.decodeTtlExpirationMs(row)
- ttlExpiration.map(expiration => (stateTypesEncoder.decodeValue(row),
expiration))
+ ttlExpiration.map { expiration =>
+ (stateTypesEncoder.decodeValue(row).asInstanceOf[V], expiration)
+ }
}
}
@@ -253,7 +256,7 @@ class MapStateImplWithTTL[K, V](
0, keyExprEnc.schema.length)) {
val userKey = stateTypesEncoder.decodeUserKey(
nextTtlValue.userKey)
- nextValue = Some(userKey, nextTtlValue.expirationMs)
+ nextValue = Some(userKey.asInstanceOf[K],
nextTtlValue.expirationMs)
}
}
nextValue.isDefined
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
index b70f9699195d..d87de4c69c40 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -254,7 +254,7 @@ class SingleKeyTTLEncoder(
/** Class for TTL with composite key serialization */
class CompositeKeyTTLEncoder[K](
keyExprEnc: ExpressionEncoder[Any],
- userKeyEnc: Encoder[K]) {
+ userKeyEnc: ExpressionEncoder[Any]) {
private val ttlKeyProjection = UnsafeProjection.create(
getCompositeKeyTTLRowSchema(keyExprEnc.schema, userKeyEnc.schema))
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 762dfc7d0892..0f90fa8d9e49 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE_INIT
import org.apache.spark.sql.execution.streaming.state._
@@ -135,31 +135,6 @@ class StatefulProcessorHandleImpl(
private lazy val currQueryInfo: QueryInfo = buildQueryInfo()
- override def getValueState[T](
- stateName: String,
- valEncoder: Encoder[T]): ValueState[T] = {
- verifyStateVarOperations("get_value_state", CREATED)
- val resultState = new ValueStateImpl[T](store, stateName, keyEncoder,
valEncoder, metrics)
- TWSMetricsUtils.incrementMetric(metrics, "numValueStateVars")
- resultState
- }
-
- override def getValueState[T](
- stateName: String,
- valEncoder: Encoder[T],
- ttlConfig: TTLConfig): ValueState[T] = {
- verifyStateVarOperations("get_value_state", CREATED)
- validateTTLConfig(ttlConfig, stateName)
-
- assert(batchTimestampMs.isDefined)
- val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
- keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics)
- ttlStates.add(valueStateWithTTL)
- TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
-
- valueStateWithTTL
- }
-
override def getQueryInfo(): QueryInfo = currQueryInfo
private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder)
@@ -230,11 +205,39 @@ class StatefulProcessorHandleImpl(
}
}
- override def getListState[T](stateName: String, valEncoder: Encoder[T]):
ListState[T] = {
- verifyStateVarOperations("get_list_state", CREATED)
- val resultState = new ListStateImpl[T](store, stateName, keyEncoder,
valEncoder, metrics)
- TWSMetricsUtils.incrementMetric(metrics, "numListStateVars")
- resultState
+ override def getValueState[T](
+ stateName: String,
+ valEncoder: Encoder[T],
+ ttlConfig: TTLConfig): ValueState[T] = {
+ getValueState(stateName, ttlConfig)(valEncoder)
+ }
+
+ override def getValueState[T: Encoder](
+ stateName: String,
+ ttlConfig: TTLConfig): ValueState[T] = {
+ verifyStateVarOperations("get_value_state", CREATED)
+ val ttlEnabled = if (ttlConfig.ttlDuration != null &&
ttlConfig.ttlDuration.isZero) {
+ false
+ } else {
+ true
+ }
+
+ val stateEncoder = encoderFor[T].asInstanceOf[ExpressionEncoder[Any]]
+ val result = if (ttlEnabled) {
+ validateTTLConfig(ttlConfig, stateName)
+ assert(batchTimestampMs.isDefined)
+ val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
+ keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
+ ttlStates.add(valueStateWithTTL)
+ TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
+ valueStateWithTTL
+ } else {
+ val valueStateWithoutTTL = new ValueStateImpl[T](store, stateName,
+ keyEncoder, stateEncoder, metrics)
+ TWSMetricsUtils.incrementMetric(metrics, "numValueStateVars")
+ valueStateWithoutTTL
+ }
+ result
}
/**
@@ -256,45 +259,72 @@ class StatefulProcessorHandleImpl(
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T] = {
+ getListState(stateName, ttlConfig)(valEncoder)
+ }
+ override def getListState[T: Encoder](stateName: String, ttlConfig:
TTLConfig): ListState[T] = {
verifyStateVarOperations("get_list_state", CREATED)
- validateTTLConfig(ttlConfig, stateName)
- assert(batchTimestampMs.isDefined)
- val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
- keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics)
- TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
- ttlStates.add(listStateWithTTL)
+ val ttlEnabled = if (ttlConfig.ttlDuration != null &&
ttlConfig.ttlDuration.isZero) {
+ false
+ } else {
+ true
+ }
- listStateWithTTL
+ val stateEncoder = encoderFor[T].asInstanceOf[ExpressionEncoder[Any]]
+ val result = if (ttlEnabled) {
+ validateTTLConfig(ttlConfig, stateName)
+ assert(batchTimestampMs.isDefined)
+ val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
+ keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
+ TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
+ ttlStates.add(listStateWithTTL)
+ listStateWithTTL
+ } else {
+ val listStateWithoutTTL = new ListStateImpl[T](store, stateName,
keyEncoder,
+ stateEncoder, metrics)
+ TWSMetricsUtils.incrementMetric(metrics, "numListStateVars")
+ listStateWithoutTTL
+ }
+ result
}
override def getMapState[K, V](
stateName: String,
userKeyEnc: Encoder[K],
- valEncoder: Encoder[V]): MapState[K, V] = {
- verifyStateVarOperations("get_map_state", CREATED)
- val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder,
- userKeyEnc, valEncoder, metrics)
- TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars")
- resultState
+ valEncoder: Encoder[V],
+ ttlConfig: TTLConfig): MapState[K, V] = {
+ getMapState(stateName, ttlConfig)(userKeyEnc, valEncoder)
}
- override def getMapState[K, V](
+ override def getMapState[K: Encoder, V: Encoder](
stateName: String,
- userKeyEnc: Encoder[K],
- valEncoder: Encoder[V],
ttlConfig: TTLConfig): MapState[K, V] = {
verifyStateVarOperations("get_map_state", CREATED)
- validateTTLConfig(ttlConfig, stateName)
- assert(batchTimestampMs.isDefined)
- val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName,
keyEncoder, userKeyEnc,
- valEncoder, ttlConfig, batchTimestampMs.get, metrics)
- TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
- ttlStates.add(mapStateWithTTL)
+ val ttlEnabled = if (ttlConfig.ttlDuration != null &&
ttlConfig.ttlDuration.isZero) {
+ false
+ } else {
+ true
+ }
- mapStateWithTTL
+ val userKeyEnc = encoderFor[K].asInstanceOf[ExpressionEncoder[Any]]
+ val valEncoder = encoderFor[V].asInstanceOf[ExpressionEncoder[Any]]
+ val result = if (ttlEnabled) {
+ validateTTLConfig(ttlConfig, stateName)
+ assert(batchTimestampMs.isDefined)
+ val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName,
keyEncoder, userKeyEnc,
+ valEncoder, ttlConfig, batchTimestampMs.get, metrics)
+ TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
+ ttlStates.add(mapStateWithTTL)
+ mapStateWithTTL
+ } else {
+ val mapStateWithoutTTL = new MapStateImpl[K, V](store, stateName,
keyEncoder,
+ userKeyEnc, valEncoder, metrics)
+ TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars")
+ mapStateWithoutTTL
+ }
+ result
}
private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit
= {
@@ -350,56 +380,58 @@ class DriverStatefulProcessorHandleImpl(timeMode:
TimeMode, keyExprEnc: Expressi
stateVariableInfos.put(stateName, stateVariableInfo)
}
- override def getValueState[T](stateName: String, valEncoder: Encoder[T]):
ValueState[T] = {
- verifyStateVarOperations("get_value_state", PRE_INIT)
- val colFamilySchema = StateStoreColumnFamilySchemaUtils.
- getValueStateSchema(stateName, keyExprEnc, valEncoder, false)
- checkIfDuplicateVariableDefined(stateName)
- columnFamilySchemas.put(stateName, colFamilySchema)
- val stateVariableInfo = TransformWithStateVariableUtils.
- getValueState(stateName, ttlEnabled = false)
- stateVariableInfos.put(stateName, stateVariableInfo)
- null.asInstanceOf[ValueState[T]]
- }
-
override def getValueState[T](
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ValueState[T] = {
- verifyStateVarOperations("get_value_state", PRE_INIT)
- val colFamilySchema = StateStoreColumnFamilySchemaUtils.
- getValueStateSchema(stateName, keyExprEnc, valEncoder, true)
- checkIfDuplicateVariableDefined(stateName)
- columnFamilySchemas.put(stateName, colFamilySchema)
- val stateVariableInfo = TransformWithStateVariableUtils.
- getValueState(stateName, ttlEnabled = true)
- stateVariableInfos.put(stateName, stateVariableInfo)
- null.asInstanceOf[ValueState[T]]
+ getValueState(stateName, ttlConfig)(valEncoder)
}
- override def getListState[T](stateName: String, valEncoder: Encoder[T]):
ListState[T] = {
- verifyStateVarOperations("get_list_state", PRE_INIT)
+ override def getValueState[T: Encoder](
+ stateName: String,
+ ttlConfig: TTLConfig): ValueState[T] = {
+ verifyStateVarOperations("get_value_state", PRE_INIT)
+ val ttlEnabled = if (ttlConfig.ttlDuration != null &&
ttlConfig.ttlDuration.isZero) {
+ false
+ } else {
+ true
+ }
+
+ val stateEncoder = encoderFor[T]
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
- getListStateSchema(stateName, keyExprEnc, valEncoder, false)
+ getValueStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
- getListState(stateName, ttlEnabled = false)
+ getValueState(stateName, ttlEnabled = ttlEnabled)
stateVariableInfos.put(stateName, stateVariableInfo)
- null.asInstanceOf[ListState[T]]
+ null.asInstanceOf[ValueState[T]]
}
override def getListState[T](
stateName: String,
valEncoder: Encoder[T],
ttlConfig: TTLConfig): ListState[T] = {
+ getListState(stateName, ttlConfig)(valEncoder)
+ }
+
+ override def getListState[T: Encoder](
+ stateName: String,
+ ttlConfig: TTLConfig): ListState[T] = {
verifyStateVarOperations("get_list_state", PRE_INIT)
+ val ttlEnabled = if (ttlConfig.ttlDuration != null &&
ttlConfig.ttlDuration.isZero) {
+ false
+ } else {
+ true
+ }
+
+ val stateEncoder = encoderFor[T]
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
- getListStateSchema(stateName, keyExprEnc, valEncoder, true)
+ getListStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
checkIfDuplicateVariableDefined(stateName)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
- getListState(stateName, ttlEnabled = true)
+ getListState(stateName, ttlEnabled = ttlEnabled)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[ListState[T]]
}
@@ -407,29 +439,29 @@ class DriverStatefulProcessorHandleImpl(timeMode:
TimeMode, keyExprEnc: Expressi
override def getMapState[K, V](
stateName: String,
userKeyEnc: Encoder[K],
- valEncoder: Encoder[V]): MapState[K, V] = {
- verifyStateVarOperations("get_map_state", PRE_INIT)
- val colFamilySchema = StateStoreColumnFamilySchemaUtils.
- getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false)
- checkIfDuplicateVariableDefined(stateName)
- columnFamilySchemas.put(stateName, colFamilySchema)
- val stateVariableInfo = TransformWithStateVariableUtils.
- getMapState(stateName, ttlEnabled = false)
- stateVariableInfos.put(stateName, stateVariableInfo)
- null.asInstanceOf[MapState[K, V]]
+ valEncoder: Encoder[V],
+ ttlConfig: TTLConfig): MapState[K, V] = {
+ getMapState(stateName, ttlConfig)(userKeyEnc, valEncoder)
}
- override def getMapState[K, V](
+ override def getMapState[K: Encoder, V: Encoder](
stateName: String,
- userKeyEnc: Encoder[K],
- valEncoder: Encoder[V],
ttlConfig: TTLConfig): MapState[K, V] = {
verifyStateVarOperations("get_map_state", PRE_INIT)
+
+ val ttlEnabled = if (ttlConfig.ttlDuration != null &&
ttlConfig.ttlDuration.isZero) {
+ false
+ } else {
+ true
+ }
+
+ val userKeyEnc = encoderFor[K]
+ val valEncoder = encoderFor[V]
val colFamilySchema = StateStoreColumnFamilySchemaUtils.
- getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
+ getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder,
ttlEnabled)
columnFamilySchemas.put(stateName, colFamilySchema)
val stateVariableInfo = TransformWithStateVariableUtils.
- getMapState(stateName, ttlEnabled = true)
+ getMapState(stateName, ttlEnabled = ttlEnabled)
stateVariableInfos.put(stateName, stateVariableInfo)
null.asInstanceOf[MapState[K, V]]
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
index 8811c59a5074..87d1a15dff1a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming
import java.time.Duration
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
@@ -199,7 +198,7 @@ abstract class CompositeKeyTTLStateImpl[K](
stateName: String,
store: StateStore,
keyExprEnc: ExpressionEncoder[Any],
- userKeyEncoder: Encoder[K],
+ userKeyEncoder: ExpressionEncoder[Any],
ttlExpirationMs: Long)
extends TTLState {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index b1b87feeb263..cd66bf99d4e1 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec,
StateStore}
@@ -37,7 +36,7 @@ class ValueStateImpl[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
- valEncoder: Encoder[S],
+ valEncoder: ExpressionEncoder[Any],
metrics: Map[String, SQLMetric] = Map.empty)
extends ValueState[S] with Logging {
@@ -66,7 +65,7 @@ class ValueStateImpl[S](
val retRow = store.get(encodedGroupingKey, stateName)
if (retRow != null) {
- stateTypesEncoder.decodeValue(retRow)
+ stateTypesEncoder.decodeValue(retRow).asInstanceOf[S]
} else {
null.asInstanceOf[S]
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
index 145cd9026491..60eea5842645 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
@@ -16,7 +16,6 @@
*/
package org.apache.spark.sql.execution.streaming
-import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -41,7 +40,7 @@ class ValueStateImplWithTTL[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
- valEncoder: Encoder[S],
+ valEncoder: ExpressionEncoder[Any],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
@@ -80,7 +79,7 @@ class ValueStateImplWithTTL[S](
val resState = stateTypesEncoder.decodeValue(retRow)
if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
- resState
+ resState.asInstanceOf[S]
} else {
null.asInstanceOf[S]
}
@@ -136,7 +135,7 @@ class ValueStateImplWithTTL[S](
val retRow = store.get(encodedGroupingKey, stateName)
if (retRow != null) {
- val resState = stateTypesEncoder.decodeValue(retRow)
+ val resState = stateTypesEncoder.decodeValue(retRow).asInstanceOf[S]
Some(resState)
} else {
None
@@ -154,7 +153,8 @@ class ValueStateImplWithTTL[S](
// ttlExpiration
if (retRow != null) {
val ttlExpiration = stateTypesEncoder.decodeTtlExpirationMs(retRow)
- ttlExpiration.map(expiration => (stateTypesEncoder.decodeValue(retRow),
expiration))
+ ttlExpiration.map(expiration =>
(stateTypesEncoder.decodeValue(retRow).asInstanceOf[S],
+ expiration))
} else {
None
}
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
index 3f0efe2e14af..9b0beb39bf13 100644
---
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
+++
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
@@ -42,13 +42,13 @@ public class TestStatefulProcessor extends
StatefulProcessor<Integer, String, St
OutputMode outputMode,
TimeMode timeMode) {
countState = this.getHandle().getValueState("countState",
- Encoders.LONG());
+ Encoders.LONG(), TTLConfig.NONE());
keyCountMap = this.getHandle().getMapState("keyCountMap",
- Encoders.STRING(), Encoders.LONG());
+ Encoders.STRING(), Encoders.LONG(), TTLConfig.NONE());
keysList = this.getHandle().getListState("keyList",
- Encoders.STRING());
+ Encoders.STRING(), TTLConfig.NONE());
}
@Override
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
index 7e356abf2d05..63cac8c36915 100644
---
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
+++
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
@@ -41,7 +41,7 @@ public class TestStatefulProcessorWithInitialState
OutputMode outputMode,
TimeMode timeMode) {
testState = this.getHandle().getValueState("testState",
- Encoders.STRING());
+ Encoders.STRING(), TTLConfig.NONE());
}
@Override
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index 293ec52cc871..d00827fbd3b2 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -34,7 +34,7 @@ class StatefulProcessorWithSingleValueVar extends
RunningCountStatefulProcessor
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_valueState = getHandle.getValueState[TestClass](
- "valueState", Encoders.product[TestClass])
+ "valueState", Encoders.product[TestClass], TTLConfig.NONE)
}
override def handleInputRows(
@@ -81,7 +81,7 @@ class SessionGroupsStatefulProcessor extends
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _groupsList = getHandle.getListState("groupsList", Encoders.STRING)
+ _groupsList = getHandle.getListState("groupsList", Encoders.STRING,
TTLConfig.NONE)
}
override def handleInputRows(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index aaaa53059126..3925c3d62da3 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -130,7 +130,8 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
verify(statefulProcessorHandle)
.getValueState[Row](any[String], any[Encoder[Row]], any[TTLConfig])
} else {
- verify(statefulProcessorHandle).getValueState[Row](any[String],
any[Encoder[Row]])
+ verify(statefulProcessorHandle).getValueState[Row](any[String],
any[Encoder[Row]],
+ any[TTLConfig])
}
verify(outputStream).writeInt(0)
}
@@ -153,7 +154,8 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
verify(statefulProcessorHandle)
.getListState[Row](any[String], any[Encoder[Row]], any[TTLConfig])
} else {
- verify(statefulProcessorHandle).getListState[Row](any[String],
any[Encoder[Row]])
+ verify(statefulProcessorHandle).getListState[Row](any[String],
any[Encoder[Row]],
+ any[TTLConfig])
}
verify(outputStream).writeInt(0)
}
@@ -178,7 +180,7 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
.getMapState[Row, Row](any[String], any[Encoder[Row]],
any[Encoder[Row]], any[TTLConfig])
} else {
verify(statefulProcessorHandle).getMapState[Row, Row](any[String],
any[Encoder[Row]],
- any[Encoder[Row]])
+ any[Encoder[Row]], any[TTLConfig])
}
verify(outputStream).writeInt(0)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
index e9300464af8d..22876831c00d 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.streaming.{ListState, TimeMode,
TTLConfig, ValueStat
* operators such as transformWithState
*/
class ListStateSuite extends StateVariableSuiteBase {
+ import testImplicits._
+
// overwrite useMultipleValuesPerKey in base suite to be true for list state
override def useMultipleValuesPerKey: Boolean = true
@@ -40,7 +42,8 @@ class ListStateSuite extends StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())
- val listState: ListState[Long] = handle.getListState[Long]("listState",
Encoders.scalaLong)
+ val listState: ListState[Long] = handle.getListState[Long]("listState",
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
val e = intercept[SparkIllegalArgumentException] {
@@ -73,7 +76,8 @@ class ListStateSuite extends StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())
- val testState: ListState[Long] = handle.getListState[Long]("testState",
Encoders.scalaLong)
+ val testState: ListState[Long] = handle.getListState[Long]("testState",
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
// simple put and get test
@@ -101,8 +105,10 @@ class ListStateSuite extends StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())
- val testState1: ListState[Long] =
handle.getListState[Long]("testState1", Encoders.scalaLong)
- val testState2: ListState[Long] =
handle.getListState[Long]("testState2", Encoders.scalaLong)
+ val testState1: ListState[Long] = handle.getListState[Long]("testState1",
+ TTLConfig.NONE)
+ val testState2: ListState[Long] = handle.getListState[Long]("testState2",
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
@@ -139,10 +145,12 @@ class ListStateSuite extends StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())
- val listState1: ListState[Long] =
handle.getListState[Long]("listState1", Encoders.scalaLong)
- val listState2: ListState[Long] =
handle.getListState[Long]("listState2", Encoders.scalaLong)
+ val listState1: ListState[Long] = handle.getListState[Long]("listState1",
+ TTLConfig.NONE)
+ val listState2: ListState[Long] = handle.getListState[Long]("listState2",
+ TTLConfig.NONE)
val valueState: ValueState[Long] = handle.getValueState[Long](
- "valueState", Encoders.scalaLong)
+ "valueState", TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
// simple put and get test
@@ -218,7 +226,7 @@ class ListStateSuite extends StateVariableSuiteBase {
}
}
- test("test negative or zero TTL duration throws error") {
+ test("test null or negative TTL duration throws error") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val batchTimestampMs = 10
@@ -226,7 +234,7 @@ class ListStateSuite extends StateVariableSuiteBase {
stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
- Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+ Seq(null, Duration.ofMinutes(-1)).foreach { ttlDuration =>
val ttlConfig = TTLConfig(ttlDuration)
val ex = intercept[SparkUnsupportedOperationException] {
handle.getListState[String]("testState", Encoders.STRING, ttlConfig)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
index b067d589de90..9a0a891d538e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
@@ -36,6 +36,8 @@ class MapStateSuite extends StateVariableSuiteBase {
.add("key", BinaryType)
.add("userKey", BinaryType)
+ import testImplicits._
+
test("Map state operations for single instance") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
@@ -43,7 +45,8 @@ class MapStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val testState: MapState[String, Double] =
- handle.getMapState[String, Double]("testState", Encoders.STRING,
Encoders.scalaDouble)
+ handle.getMapState[String, Double]("testState", Encoders.STRING,
Encoders.scalaDouble,
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
// put initial value
testState.updateValue("k1", 1.0)
@@ -77,9 +80,10 @@ class MapStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val testState1: MapState[Long, Double] =
- handle.getMapState[Long, Double]("testState1", Encoders.scalaLong,
Encoders.scalaDouble)
+ handle.getMapState[Long, Double]("testState1", TTLConfig.NONE)
val testState2: MapState[Long, Int] =
- handle.getMapState[Long, Int]("testState2", Encoders.scalaLong,
Encoders.scalaInt)
+ handle.getMapState[Long, Int]("testState2",
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
// put initial value
testState1.updateValue(1L, 1.0)
@@ -116,13 +120,13 @@ class MapStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val mapTestState1: MapState[String, Int] =
- handle.getMapState[String, Int]("mapTestState1", Encoders.STRING,
Encoders.scalaInt)
+ handle.getMapState[String, Int]("mapTestState1", TTLConfig.NONE)
val mapTestState2: MapState[String, Int] =
- handle.getMapState[String, Int]("mapTestState2", Encoders.STRING,
Encoders.scalaInt)
+ handle.getMapState[String, Int]("mapTestState2", TTLConfig.NONE)
val valueTestState: ValueState[String] =
- handle.getValueState[String]("valueTestState", Encoders.STRING)
+ handle.getValueState[String]("valueTestState", TTLConfig.NONE)
val listTestState: ListState[String] =
- handle.getListState[String]("listTestState", Encoders.STRING)
+ handle.getListState[String]("listTestState", TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
// put initial values
@@ -227,7 +231,7 @@ class MapStateSuite extends StateVariableSuiteBase {
}
}
- test("test negative or zero TTL duration throws error") {
+ test("test null or negative TTL duration throws error") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val batchTimestampMs = 10
@@ -235,7 +239,7 @@ class MapStateSuite extends StateVariableSuiteBase {
stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
- Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+ Seq(null, Duration.ofMinutes(-1)).foreach { ttlDuration =>
val ttlConfig = TTLConfig(ttlDuration)
val ex = intercept[SparkUnsupportedOperationException] {
handle.getMapState[String, String](
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index 48a6fd836a46..0d74aade6719 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -32,6 +32,8 @@ import org.apache.spark.sql.streaming.{TimeMode, TTLConfig}
*/
class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
+ import testImplicits._
+
private def getTimeMode(timeMode: String): TimeMode = {
timeMode match {
case "None" => TimeMode.None()
@@ -48,7 +50,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store,
UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
assert(handle.getHandleState === StatefulProcessorHandleState.CREATED)
- handle.getValueState[Long]("testState", Encoders.scalaLong)
+ handle.getValueState[Long]("testState", TTLConfig.NONE)
}
}
}
@@ -74,7 +76,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
}
private def createValueStateInstance(handle: StatefulProcessorHandleImpl):
Unit = {
- handle.getValueState[Long]("testState", Encoders.scalaLong)
+ handle.getValueState[Long]("testState", TTLConfig.NONE)
}
private def registerTimer(handle: StatefulProcessorHandleImpl): Unit = {
@@ -222,11 +224,11 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
batchTimestampMs = Some(10))
- val valueStateWithTTL = handle.getValueState("testState",
- Encoders.STRING, TTLConfig(Duration.ofHours(1)))
+ val valueStateWithTTL = handle.getValueState[String]("testState",
+ TTLConfig(Duration.ofHours(1)))
// create another state without TTL, this should not be captured in the
handle
- handle.getValueState("testState", Encoders.STRING)
+ handle.getValueState[String]("testState", TTLConfig.NONE)
assert(handle.ttlStates.size() === 1)
assert(handle.ttlStates.get(0) === valueStateWithTTL)
@@ -244,7 +246,7 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
Encoders.STRING, TTLConfig(Duration.ofHours(1)))
// create another state without TTL, this should not be captured in the
handle
- handle.getListState("testState", Encoders.STRING)
+ handle.getListState("testState", Encoders.STRING, TTLConfig.NONE)
assert(handle.ttlStates.size() === 1)
assert(handle.ttlStates.get(0) === listStateWithTTL)
@@ -262,7 +264,8 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
Encoders.STRING, Encoders.STRING, TTLConfig(Duration.ofHours(1)))
// create another state without TTL, this should not be captured in the
handle
- handle.getMapState("testState", Encoders.STRING, Encoders.STRING)
+ handle.getMapState("testState", Encoders.STRING, Encoders.STRING,
+ TTLConfig.NONE)
assert(handle.ttlStates.size() === 1)
assert(handle.ttlStates.get(0) === mapStateWithTTL)
@@ -275,9 +278,10 @@ class StatefulProcessorHandleSuite extends
StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store,
UUID.randomUUID(), stringEncoder, TimeMode.None())
- handle.getValueState("testValueState", Encoders.STRING)
- handle.getListState("testListState", Encoders.STRING)
- handle.getMapState("testMapState", Encoders.STRING, Encoders.STRING)
+ handle.getValueState("testValueState", Encoders.STRING, TTLConfig.NONE)
+ handle.getListState("testListState", Encoders.STRING, TTLConfig.NONE)
+ handle.getMapState("testMapState", Encoders.STRING, Encoders.STRING,
+ TTLConfig.NONE)
assert(handle.ttlStates.isEmpty)
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 13d758eb1b88..55d08cd8f12a 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -44,6 +44,7 @@ case class TestClass(var id: Long, var name: String)
class ValueStateSuite extends StateVariableSuiteBase {
import StateStoreTestsHelper._
+ import testImplicits._
test("Implicit key operations") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
@@ -52,7 +53,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val stateName = "testState"
- val testState: ValueState[Long] =
handle.getValueState[Long]("testState", Encoders.scalaLong)
+ val testState: ValueState[Long] = handle.getValueState[Long]("testState",
+ TTLConfig.NONE)
assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
val ex = intercept[Exception] {
testState.update(123)
@@ -95,7 +97,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
stringEncoder, TimeMode.None())
- val testState: ValueState[Long] =
handle.getValueState[Long]("testState", Encoders.scalaLong)
+ val testState: ValueState[Long] = handle.getValueState[Long]("testState",
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState.update(123)
assert(testState.get() === 123)
@@ -122,9 +125,9 @@ class ValueStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val testState1: ValueState[Long] = handle.getValueState[Long](
- "testState1", Encoders.scalaLong)
+ "testState1", TTLConfig.NONE)
val testState2: ValueState[Long] = handle.getValueState[Long](
- "testState2", Encoders.scalaLong)
+ "testState2", TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState1.update(123)
assert(testState1.get() === 123)
@@ -168,7 +171,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
val cfName = "$testState"
val ex = intercept[SparkUnsupportedOperationException] {
- handle.getValueState[Long](cfName, Encoders.scalaLong)
+ handle.getValueState[Long](cfName, TTLConfig.NONE)
}
checkError(
ex,
@@ -207,7 +210,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val testState: ValueState[Double] =
handle.getValueState[Double]("testState",
- Encoders.scalaDouble)
+ Encoders.scalaDouble, TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState.update(1.0)
assert(testState.get().equals(1.0))
@@ -233,7 +236,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val testState: ValueState[Long] = handle.getValueState[Long]("testState",
- Encoders.scalaLong)
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState.update(1L)
assert(testState.get().equals(1L))
@@ -259,7 +262,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val testState: ValueState[TestClass] =
handle.getValueState[TestClass]("testState",
- Encoders.product[TestClass])
+ TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState.update(TestClass(1, "testcase1"))
assert(testState.get().equals(TestClass(1, "testcase1")))
@@ -285,7 +288,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
stringEncoder, TimeMode.None())
val testState: ValueState[POJOTestClass] =
handle.getValueState[POJOTestClass]("testState",
- Encoders.bean(classOf[POJOTestClass]))
+ Encoders.bean(classOf[POJOTestClass]), TTLConfig.NONE)
ImplicitGroupingKeyTracker.setImplicitKey("test_key")
testState.update(new POJOTestClass("testcase1", 1))
assert(testState.get().equals(new POJOTestClass("testcase1", 1)))
@@ -304,7 +307,6 @@ class ValueStateSuite extends StateVariableSuiteBase {
}
}
-
test(s"test Value state TTL") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
@@ -361,7 +363,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
}
}
- test("test negative or zero TTL duration throws error") {
+ test("test null or negative TTL duration throws error") {
tryWithProviderResource(newStoreProviderWithStateVariable(true)) {
provider =>
val store = provider.getStore(0)
val batchTimestampMs = 10
@@ -369,7 +371,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
stringEncoder,
TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
- Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+ Seq(null, Duration.ofMinutes(-1)).foreach { ttlDuration =>
val ttlConfig = TTLConfig(ttlDuration)
val ex = intercept[SparkUnsupportedOperationException] {
handle.getValueState[String]("testState", Encoders.STRING, ttlConfig)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index 20f04cc66c0a..88862e2ad079 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -33,7 +33,7 @@ class TestListStateProcessor
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _listState = getHandle.getListState("testListState", Encoders.STRING)
+ _listState = getHandle.getListState("testListState", Encoders.STRING,
TTLConfig.NONE)
}
override def handleInputRows(
@@ -89,8 +89,9 @@ class ToggleSaveAndEmitProcessor
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _listState = getHandle.getListState("testListState", Encoders.STRING)
- _valueState = getHandle.getValueState("testValueState",
Encoders.scalaBoolean)
+ _listState = getHandle.getListState("testListState", Encoders.STRING,
TTLConfig.NONE)
+ _valueState = getHandle.getValueState("testValueState",
Encoders.scalaBoolean,
+ TTLConfig.NONE)
}
override def handleInputRows(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index da4218949a11..76c5cbeee424 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -33,7 +33,8 @@ class TestMapStateProcessor
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _mapState = getHandle.getMapState("sessionState", Encoders.STRING,
Encoders.STRING)
+ _mapState = getHandle.getMapState("sessionState", Encoders.STRING,
Encoders.STRING,
+ TTLConfig.NONE)
}
override def handleInputRows(
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
index 300785611fd0..35ac8a4687eb 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -39,10 +39,13 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _valState = getHandle.getValueState[Double]("testValueInit",
Encoders.scalaDouble)
- _listState = getHandle.getListState[Double]("testListInit",
Encoders.scalaDouble)
+ _valState = getHandle.getValueState[Double]("testValueInit",
Encoders.scalaDouble,
+ TTLConfig.NONE)
+ _listState = getHandle.getListState[Double]("testListInit",
Encoders.scalaDouble,
+ TTLConfig.NONE)
_mapState = getHandle.getMapState[Double, Int](
- "testMapInit", Encoders.scalaDouble, Encoders.scalaInt)
+ "testMapInit", Encoders.scalaDouble, Encoders.scalaInt,
+ TTLConfig.NONE)
}
override def handleInputRows(
@@ -162,8 +165,10 @@ class StatefulProcessorWithInitialStateProcTimerClass
override def init(
outputMode: OutputMode,
timeMode: TimeMode) : Unit = {
- _countState = getHandle.getValueState[Long]("countState",
Encoders.scalaLong)
- _timerState = getHandle.getValueState[Long]("timerState",
Encoders.scalaLong)
+ _countState = getHandle.getValueState[Long]("countState",
Encoders.scalaLong,
+ TTLConfig.NONE)
+ _timerState = getHandle.getValueState[Long]("timerState",
Encoders.scalaLong,
+ TTLConfig.NONE)
}
override def handleInitialState(
@@ -209,8 +214,9 @@ class StatefulProcessorWithInitialStateEventTimerClass
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState",
- Encoders.scalaLong)
- _timerState = getHandle.getValueState[Long]("timerState",
Encoders.scalaLong)
+ Encoders.scalaLong, TTLConfig.NONE)
+ _timerState = getHandle.getValueState[Long]("timerState",
Encoders.scalaLong,
+ TTLConfig.NONE)
}
private def processUnexpiredRows(maxEventTimeSec: Long): Unit = {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 7d61c6fc7084..3ef5c57ee3d0 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -50,7 +50,8 @@ class RunningCountStatefulProcessor extends
StatefulProcessor[String, String, (S
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _countState = getHandle.getValueState[Long]("countState",
Encoders.scalaLong)
+ _countState = getHandle.getValueState[Long]("countState",
+ Encoders.scalaLong, TTLConfig.NONE)
}
override def handleInputRows(
@@ -76,8 +77,8 @@ class RunningCountStatefulProcessorWithTTL
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _countState = getHandle.getValueState[Long]("countState",
- Encoders.scalaLong, TTLConfig(Duration.ofMillis(1000)))
+ _countState = getHandle.getValueState[Long]("countState",
Encoders.scalaLong,
+ TTLConfig(Duration.ofMillis(1000)))
}
override def handleInputRows(
@@ -106,7 +107,7 @@ class RunningCountListStatefulProcessor
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
_countState = getHandle.getListState[Long](
- "countState", Encoders.scalaLong)
+ "countState", Encoders.scalaLong, TTLConfig.NONE)
}
override def handleInputRows(
@@ -124,7 +125,8 @@ class RunningCountStatefulProcessorInt
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
- _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt)
+ _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt,
+ TTLConfig.NONE)
}
override def handleInputRows(
@@ -183,7 +185,8 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates
outputMode: OutputMode,
timeMode: TimeMode) : Unit = {
super.init(outputMode, timeMode)
-    _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong)
+    _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong,
+      TTLConfig.NONE)
}
protected def processUnexpiredRows(
@@ -267,8 +270,9 @@ class MaxEventTimeStatefulProcessor
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
     _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState",
-      Encoders.scalaLong)
-    _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong)
+      Encoders.scalaLong, TTLConfig.NONE)
+    _timerState = getHandle.getValueState[Long]("timerState", Encoders.scalaLong,
+      TTLConfig.NONE)
}
protected def processUnexpiredRows(maxEventTimeSec: Long): Unit = {
@@ -313,8 +317,10 @@ class RunningCountMostRecentStatefulProcessor
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
-    _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING)
+    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong,
+      TTLConfig.NONE)
+    _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING,
+      TTLConfig.NONE)
}
override def handleInputRows(
@@ -343,7 +349,8 @@ class MostRecentStatefulProcessorWithDeletion
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
getHandle.deleteIfExists("countState")
-    _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING)
+    _mostRecent = getHandle.getValueState[String]("mostRecent", Encoders.STRING,
+      TTLConfig.NONE)
}
override def handleInputRows(
@@ -370,7 +377,8 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess
inputRows: Iterator[String],
timerValues: TimerValues): Iterator[(String, String)] = {
// Trying to create value state here should fail
- _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong)
+ _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong,
+ TTLConfig.NONE)
Iterator.empty
}
}
@@ -383,11 +391,13 @@ class StatefulProcessorWithCompositeTypes extends RunningCountStatefulProcessor
override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong)
+    _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong,
+      TTLConfig.NONE)
_listState = getHandle.getListState[TestClass](
- "listState", Encoders.product[TestClass])
+ "listState", Encoders.product[TestClass], TTLConfig.NONE)
_mapState = getHandle.getMapState[POJOTestClass, String](
- "mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING)
+ "mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING,
+ TTLConfig.NONE)
}
}
@@ -421,6 +431,9 @@ class TransformWithStateSuite extends StateStoreMetricsTest
}
test("transformWithState - lazy iterators can properly get/set keyed state")
{
+ val spark = this.spark
+ import spark.implicits._
+
class ProcessorWithLazyIterators
extends StatefulProcessor[Long, Long, Long] {
@transient protected var _myValueState: ValueState[Long] = _
@@ -429,7 +442,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
_myValueState = getHandle.getValueState[Long](
"myValueState",
- Encoders.scalaLong
+ TTLConfig.NONE
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index e7b394db0c3c..21c3beb79314 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -130,7 +130,7 @@ class MultipleValueStatesTTLProcessor(
.getValueState("valueStateTTL", Encoders.scalaInt, ttlConfig)
.asInstanceOf[ValueStateImplWithTTL[Int]]
_valueStateWithoutTTL = getHandle
- .getValueState("valueState", Encoders.scalaInt)
+ .getValueState[Int]("valueState", Encoders.scalaInt, TTLConfig.NONE)
.asInstanceOf[ValueStateImpl[Int]]
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]