This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 47063a6a75f7 [SPARK-50128][SS] Add stateful processor handle APIs using implicit encoders in Scala
47063a6a75f7 is described below

commit 47063a6a75f729b43dec8bdeb78c46ea29e5f98f
Author: Anish Shrigondekar <[email protected]>
AuthorDate: Tue Nov 5 12:20:16 2024 +0900

    [SPARK-50128][SS] Add stateful processor handle APIs using implicit encoders in Scala
    
    ### What changes were proposed in this pull request?
    Add stateful processor handle APIs using implicit encoders in Scala
    
    ### Why are the changes needed?
    Without these changes, users have to pass explicit SQL encoders for state types when acquiring an instance of the underlying state variable.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    Users can now use the implicits available in Scala through `import spark.implicits._` and provide only the type when getting the state objects. For example:
    
    ```
          override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
             _myValueState = getHandle.getValueState[Long]("myValueState", TTLConfig.NONE)
          }
    ```
    
    ### How was this patch tested?
    Existing unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #48728 from anishshri-db/task/SPARK-50128.
    
    Authored-by: Anish Shrigondekar <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../sql/streaming/StatefulProcessorHandle.scala    |  98 +++++----
 .../org/apache/spark/sql/streaming/TTLConfig.scala |  15 ++
 .../TransformWithStateInPandasStateServer.scala    |  26 ++-
 .../sql/execution/streaming/ListStateImpl.scala    |   5 +-
 .../execution/streaming/ListStateImplWithTTL.scala |   9 +-
 .../sql/execution/streaming/MapStateImpl.scala     |  11 +-
 .../execution/streaming/MapStateImplWithTTL.scala  |  17 +-
 .../streaming/StateTypesEncoderUtils.scala         |   2 +-
 .../streaming/StatefulProcessorHandleImpl.scala    | 226 ++++++++++++---------
 .../spark/sql/execution/streaming/TTLState.scala   |   3 +-
 .../sql/execution/streaming/ValueStateImpl.scala   |   5 +-
 .../streaming/ValueStateImplWithTTL.scala          |  10 +-
 .../apache/spark/sql/TestStatefulProcessor.java    |   6 +-
 .../sql/TestStatefulProcessorWithInitialState.java |   2 +-
 .../StateDataSourceTransformWithStateSuite.scala   |   4 +-
 ...ransformWithStateInPandasStateServerSuite.scala |   8 +-
 .../execution/streaming/state/ListStateSuite.scala |  26 ++-
 .../execution/streaming/state/MapStateSuite.scala  |  22 +-
 .../state/StatefulProcessorHandleSuite.scala       |  24 ++-
 .../streaming/state/ValueStateSuite.scala          |  26 +--
 .../streaming/TransformWithListStateSuite.scala    |   7 +-
 .../sql/streaming/TransformWithMapStateSuite.scala |   3 +-
 .../TransformWithStateInitialStateSuite.scala      |  20 +-
 .../sql/streaming/TransformWithStateSuite.scala    |  45 ++--
 .../TransformWithValueStateTTLSuite.scala          |   2 +-
 25 files changed, 367 insertions(+), 255 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index d1eca0f3967d..f458f0de37cb 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -29,25 +29,12 @@ import org.apache.spark.sql.Encoder
 @Evolving
 private[sql] trait StatefulProcessorHandle extends Serializable {
 
-  /**
-   * Function to create new or return existing single value state variable of 
given type. The user
-   * must ensure to call this function only within the `init()` method of the 
StatefulProcessor.
-   *
-   * @param stateName
-   *   \- name of the state variable
-   * @param valEncoder
-   *   \- SQL encoder for state variable
-   * @tparam T
-   *   \- type of state variable
-   * @return
-   *   \- instance of ValueState of type T that can be used to store state 
persistently
-   */
-  def getValueState[T](stateName: String, valEncoder: Encoder[T]): 
ValueState[T]
-
   /**
    * Function to create new or return existing single value state variable of 
given type with ttl.
    * State values will not be returned past ttlDuration, and will be 
eventually removed from the
    * state store. Any state update resets the ttl to current processing time 
plus ttlDuration.
+   * Users can use the helper method `TTLConfig.NONE` in Scala or 
`TTLConfig.NONE()` in Java for
+   * the TTLConfig parameter to disable TTL for the state variable.
    *
    * The user must ensure to call this function only within the `init()` 
method of the
    * StatefulProcessor.
@@ -69,25 +56,34 @@ private[sql] trait StatefulProcessorHandle extends 
Serializable {
       ttlConfig: TTLConfig): ValueState[T]
 
   /**
-   * Creates new or returns existing list state associated with stateName. The 
ListState persists
-   * values of type T.
+   * (Scala-specific) Function to create new or return existing single value 
state variable of
+   * given type with ttl. State values will not be returned past ttlDuration, 
and will be
+   * eventually removed from the state store. Any state update resets the ttl 
to current
+   * processing time plus ttlDuration. Users can use the helper method 
`TTLConfig.NONE` in Scala
+   * or `TTLConfig.NONE()` in Java for the TTLConfig parameter to disable TTL 
for the state
+   * variable.
+   *
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor. Note that this API uses the implicit SQL encoder in 
Scala.
    *
    * @param stateName
    *   \- name of the state variable
-   * @param valEncoder
-   *   \- SQL encoder for state variable
+   * @param ttlConfig
+   *   \- the ttl configuration (time to live duration etc.)
    * @tparam T
    *   \- type of state variable
    * @return
-   *   \- instance of ListState of type T that can be used to store state 
persistently
+   *   \- instance of ValueState of type T that can be used to store state 
persistently
    */
-  def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T]
+  def getValueState[T: Encoder](stateName: String, ttlConfig: TTLConfig): 
ValueState[T]
 
   /**
    * Function to create new or return existing list state variable of given 
type with ttl. State
    * values will not be returned past ttlDuration, and will be eventually 
removed from the state
    * store. Any values in listState which have expired after ttlDuration will 
not be returned on
-   * get() and will be eventually removed from the state.
+   * get() and will be eventually removed from the state. Users can use the 
helper method
+   * `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig 
parameter to
+   * disable TTL for the state variable.
    *
    * The user must ensure to call this function only within the `init()` 
method of the
    * StatefulProcessor.
@@ -109,32 +105,34 @@ private[sql] trait StatefulProcessorHandle extends 
Serializable {
       ttlConfig: TTLConfig): ListState[T]
 
   /**
-   * Creates new or returns existing map state associated with stateName. The 
MapState persists
-   * Key-Value pairs of type [K, V].
+   * (Scala-specific) Function to create new or return existing list state 
variable of given type
+   * with ttl. State values will not be returned past ttlDuration, and will be 
eventually removed
+   * from the state store. Any values in listState which have expired after 
ttlDuration will not
+   * be returned on get() and will be eventually removed from the state. Users 
can use the helper
+   * method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the 
TTLConfig parameter to
+   * disable TTL for the state variable.
+   *
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor. Note that this API uses the implicit SQL encoder in 
Scala.
    *
    * @param stateName
    *   \- name of the state variable
-   * @param userKeyEnc
-   *   \- spark sql encoder for the map key
-   * @param valEncoder
-   *   \- spark sql encoder for the map value
-   * @tparam K
-   *   \- type of key for map state variable
-   * @tparam V
-   *   \- type of value for map state variable
+   * @param ttlConfig
+   *   \- the ttl configuration (time to live duration etc.)
+   * @tparam T
+   *   \- type of state variable
    * @return
-   *   \- instance of MapState of type [K,V] that can be used to store state 
persistently
+   *   \- instance of ListState of type T that can be used to store state 
persistently
    */
-  def getMapState[K, V](
-      stateName: String,
-      userKeyEnc: Encoder[K],
-      valEncoder: Encoder[V]): MapState[K, V]
+  def getListState[T: Encoder](stateName: String, ttlConfig: TTLConfig): 
ListState[T]
 
   /**
    * Function to create new or return existing map state variable of given 
type with ttl. State
    * values will not be returned past ttlDuration, and will be eventually 
removed from the state
    * store. Any values in mapState which have expired after ttlDuration will 
not returned on get()
-   * and will be eventually removed from the state.
+   * and will be eventually removed from the state. Users can use the helper 
method
+   * `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the TTLConfig 
parameter to
+   * disable TTL for the state variable.
    *
    * The user must ensure to call this function only within the `init()` 
method of the
    * StatefulProcessor.
@@ -160,6 +158,30 @@ private[sql] trait StatefulProcessorHandle extends 
Serializable {
       valEncoder: Encoder[V],
       ttlConfig: TTLConfig): MapState[K, V]
 
+  /**
+   * (Scala-specific) Function to create new or return existing map state 
variable of given type
+   * with ttl. State values will not be returned past ttlDuration, and will be 
eventually removed
+   * from the state store. Any values in mapState which have expired after 
ttlDuration will not be
+   * returned on get() and will be eventually removed from the state. Users 
can use the helper
+   * method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java for the 
TTLConfig parameter to
+   * disable TTL for the state variable.
+   *
+   * The user must ensure to call this function only within the `init()` 
method of the
+   * StatefulProcessor. Note that this API uses the implicit SQL encoder in 
Scala.
+   *
+   * @param stateName
+   *   \- name of the state variable
+   * @param ttlConfig
+   *   \- the ttl configuration (time to live duration etc.)
+   * @tparam K
+   *   \- type of key for map state variable
+   * @tparam V
+   *   \- type of value for map state variable
+   * @return
+   *   \- instance of MapState of type [K,V] that can be used to store state 
persistently
+   */
+  def getMapState[K: Encoder, V: Encoder](stateName: String, ttlConfig: 
TTLConfig): MapState[K, V]
+
   /** Function to return queryInfo for currently running task */
   def getQueryInfo(): QueryInfo
 
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
index ce786aa943d8..7ec4fbc8c1b5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/TTLConfig.scala
@@ -24,7 +24,22 @@ import java.time.Duration
  * will be eventually removed from the state store. Any state update resets 
the ttl to current
  * processing time plus ttlDuration.
  *
+ * Passing a TTL duration of zero will disable the TTL for the state variable. 
Users can also use
+ * the helper method `TTLConfig.NONE` in Scala or `TTLConfig.NONE()` in Java 
to disable TTL for
+ * the state variable.
+ *
  * @param ttlDuration
  *   time to live duration for state stored in the state variable.
  */
 case class TTLConfig(ttlDuration: Duration)
+
+object TTLConfig {
+
+  /**
+   * Helper method to create a TTLConfig with expiry duration as Zero
+   * @return
+   *   \- TTLConfig with expiry duration as Zero
+   */
+  def NONE: TTLConfig = TTLConfig(Duration.ZERO)
+
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 8a67d5d47f05..5f3ebd87e75e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -591,20 +591,23 @@ class TransformWithStateInPandasStateServer(
     stateType match {
       case StateVariableType.ValueState => if 
(!valueStates.contains(stateName)) {
         val state = if (ttlDurationMs.isEmpty) {
-          statefulProcessorHandle.getValueState[Row](stateName, 
Encoders.row(schema))
+          statefulProcessorHandle.getValueState[Row](stateName, 
Encoders.row(schema),
+            TTLConfig.NONE)
+          } else {
+            statefulProcessorHandle.getValueState(
+              stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+          }
+          valueStates.put(stateName,
+            ValueStateInfo(state, schema, 
expressionEncoder.createDeserializer()))
+          sendResponse(0)
         } else {
-          statefulProcessorHandle.getValueState(
-            stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+          sendResponse(1, s"Value state $stateName already exists")
         }
-        valueStates.put(stateName,
-          ValueStateInfo(state, schema, 
expressionEncoder.createDeserializer()))
-        sendResponse(0)
-      } else {
-        sendResponse(1, s"Value state $stateName already exists")
-      }
+
       case StateVariableType.ListState => if (!listStates.contains(stateName)) 
{
         val state = if (ttlDurationMs.isEmpty) {
-          statefulProcessorHandle.getListState[Row](stateName, 
Encoders.row(schema))
+          statefulProcessorHandle.getListState[Row](stateName, 
Encoders.row(schema),
+            TTLConfig.NONE)
         } else {
           statefulProcessorHandle.getListState(
             stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
@@ -616,12 +619,13 @@ class TransformWithStateInPandasStateServer(
       } else {
         sendResponse(1, s"List state $stateName already exists")
       }
+
       case StateVariableType.MapState => if (!mapStates.contains(stateName)) {
         val valueSchema = StructType.fromString(mapStateValueSchemaString)
         val valueExpressionEncoder = 
ExpressionEncoder(valueSchema).resolveAndBind()
         val state = if (ttlDurationMs.isEmpty) {
           statefulProcessorHandle.getMapState[Row, Row](stateName,
-            Encoders.row(schema), Encoders.row(valueSchema))
+            Encoders.row(schema), Encoders.row(valueSchema), TTLConfig.NONE)
         } else {
           statefulProcessorHandle.getMapState[Row, Row](stateName, 
Encoders.row(schema),
             Encoders.row(valueSchema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
index 77c481a8ba0b..32683aebd8c1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.metric.SQLMetric
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateStore, StateStoreErrors}
@@ -39,7 +38,7 @@ class ListStateImpl[S](
      store: StateStore,
      stateName: String,
      keyExprEnc: ExpressionEncoder[Any],
-     valEncoder: Encoder[S],
+     valEncoder: ExpressionEncoder[Any],
      metrics: Map[String, SQLMetric] = Map.empty)
   extends ListStateMetricsImpl
   with ListState[S]
@@ -75,7 +74,7 @@ class ListStateImpl[S](
 
        override def next(): S = {
          val valueUnsafeRow = unsafeRowValuesIterator.next()
-         stateTypesEncoder.decodeValue(valueUnsafeRow)
+         stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S]
        }
      }
    }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
index be47f566bc6a..4c8dd6a193c2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala
@@ -16,7 +16,6 @@
  */
 package org.apache.spark.sql.execution.streaming
 
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -43,7 +42,7 @@ class ListStateImplWithTTL[S](
     store: StateStore,
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any],
-    valEncoder: Encoder[S],
+    valEncoder: ExpressionEncoder[Any],
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
     metrics: Map[String, SQLMetric] = Map.empty)
@@ -91,7 +90,7 @@ class ListStateImplWithTTL[S](
 
         if (iter.hasNext) {
           val currentRow = iter.next()
-          stateTypesEncoder.decodeValue(currentRow)
+          stateTypesEncoder.decodeValue(currentRow).asInstanceOf[S]
         } else {
           finished = true
           null.asInstanceOf[S]
@@ -223,7 +222,7 @@ class ListStateImplWithTTL[S](
     val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
     val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, 
stateName)
     unsafeRowValuesIterator.map { valueUnsafeRow =>
-      stateTypesEncoder.decodeValue(valueUnsafeRow)
+      stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S]
     }
   }
 
@@ -234,7 +233,7 @@ class ListStateImplWithTTL[S](
     val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
     val unsafeRowValuesIterator = store.valuesIterator(encodedGroupingKey, 
stateName)
     unsafeRowValuesIterator.map { valueUnsafeRow =>
-      (stateTypesEncoder.decodeValue(valueUnsafeRow),
+      (stateTypesEncoder.decodeValue(valueUnsafeRow).asInstanceOf[S],
         stateTypesEncoder.decodeTtlExpirationMs(valueUnsafeRow).get)
     }
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
index cb3db19496dd..4e608a5d5dbb 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.metric.SQLMetric
 import 
org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
@@ -40,8 +39,8 @@ class MapStateImpl[K, V](
     store: StateStore,
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any],
-    userKeyEnc: Encoder[K],
-    valEncoder: Encoder[V],
+    userKeyEnc: ExpressionEncoder[Any],
+    valEncoder: ExpressionEncoder[Any],
     metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with 
Logging {
 
   // Pack grouping key and user key together as a prefixed composite key
@@ -67,7 +66,7 @@ class MapStateImpl[K, V](
     val unsafeRowValue = store.get(encodedCompositeKey, stateName)
 
     if (unsafeRowValue == null) return null.asInstanceOf[V]
-    stateTypesEncoder.decodeValue(unsafeRowValue)
+    stateTypesEncoder.decodeValue(unsafeRowValue).asInstanceOf[V]
   }
 
   /** Check if the user key is contained in the map */
@@ -92,8 +91,8 @@ class MapStateImpl[K, V](
     store.prefixScan(encodedGroupingKey, stateName)
       .map {
         case iter: UnsafeRowPair =>
-          (stateTypesEncoder.decodeCompositeKey(iter.key),
-            stateTypesEncoder.decodeValue(iter.value))
+          (stateTypesEncoder.decodeCompositeKey(iter.key).asInstanceOf[K],
+            stateTypesEncoder.decodeValue(iter.value).asInstanceOf[V])
       }
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
index 6a3685ad6c46..19704b6d1bd5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,8 +44,8 @@ class MapStateImplWithTTL[K, V](
     store: StateStore,
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any],
-    userKeyEnc: Encoder[K],
-    valEncoder: Encoder[V],
+    userKeyEnc: ExpressionEncoder[Any],
+    valEncoder: ExpressionEncoder[Any],
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
     metrics: Map[String, SQLMetric] = Map.empty)
@@ -83,7 +82,7 @@ class MapStateImplWithTTL[K, V](
 
     if (retRow != null) {
       if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
-        stateTypesEncoder.decodeValue(retRow)
+        stateTypesEncoder.decodeValue(retRow).asInstanceOf[V]
       } else {
         null.asInstanceOf[V]
       }
@@ -126,7 +125,9 @@ class MapStateImplWithTTL[K, V](
         if (iter.hasNext) {
           val currentRowPair = iter.next()
           val key = stateTypesEncoder.decodeCompositeKey(currentRowPair.key)
+            .asInstanceOf[K]
           val value = stateTypesEncoder.decodeValue(currentRowPair.value)
+            .asInstanceOf[V]
           (key, value)
         } else {
           finished = true
@@ -213,7 +214,7 @@ class MapStateImplWithTTL[K, V](
     val retRow = store.get(encodedCompositeKey, stateName)
 
     if (retRow != null) {
-      val resState = stateTypesEncoder.decodeValue(retRow)
+      val resState = stateTypesEncoder.decodeValue(retRow).asInstanceOf[V]
       Some(resState)
     } else {
       None
@@ -231,7 +232,9 @@ class MapStateImplWithTTL[K, V](
     // ttlExpiration
     Option(retRow).flatMap { row =>
       val ttlExpiration = stateTypesEncoder.decodeTtlExpirationMs(row)
-      ttlExpiration.map(expiration => (stateTypesEncoder.decodeValue(row), 
expiration))
+      ttlExpiration.map { expiration =>
+        (stateTypesEncoder.decodeValue(row).asInstanceOf[V], expiration)
+      }
     }
   }
 
@@ -253,7 +256,7 @@ class MapStateImplWithTTL[K, V](
             0, keyExprEnc.schema.length)) {
             val userKey = stateTypesEncoder.decodeUserKey(
               nextTtlValue.userKey)
-            nextValue = Some(userKey, nextTtlValue.expirationMs)
+            nextValue = Some(userKey.asInstanceOf[K], 
nextTtlValue.expirationMs)
           }
         }
         nextValue.isDefined
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
index b70f9699195d..d87de4c69c40 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -254,7 +254,7 @@ class SingleKeyTTLEncoder(
 /** Class for TTL with composite key serialization */
 class CompositeKeyTTLEncoder[K](
     keyExprEnc: ExpressionEncoder[Any],
-    userKeyEnc: Encoder[K]) {
+    userKeyEnc: ExpressionEncoder[Any]) {
 
   private val ttlKeyProjection = UnsafeProjection.create(
     getCompositeKeyTTLRowSchema(keyExprEnc.schema, userKeyEnc.schema))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 762dfc7d0892..0f90fa8d9e49 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import 
org.apache.spark.sql.execution.streaming.StatefulProcessorHandleState.PRE_INIT
 import org.apache.spark.sql.execution.streaming.state._
@@ -135,31 +135,6 @@ class StatefulProcessorHandleImpl(
 
   private lazy val currQueryInfo: QueryInfo = buildQueryInfo()
 
-  override def getValueState[T](
-      stateName: String,
-      valEncoder: Encoder[T]): ValueState[T] = {
-    verifyStateVarOperations("get_value_state", CREATED)
-    val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, 
valEncoder, metrics)
-    TWSMetricsUtils.incrementMetric(metrics, "numValueStateVars")
-    resultState
-  }
-
-  override def getValueState[T](
-      stateName: String,
-      valEncoder: Encoder[T],
-      ttlConfig: TTLConfig): ValueState[T] = {
-    verifyStateVarOperations("get_value_state", CREATED)
-    validateTTLConfig(ttlConfig, stateName)
-
-    assert(batchTimestampMs.isDefined)
-    val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
-      keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics)
-    ttlStates.add(valueStateWithTTL)
-    TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
-
-    valueStateWithTTL
-  }
-
   override def getQueryInfo(): QueryInfo = currQueryInfo
 
   private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder)
@@ -230,11 +205,39 @@ class StatefulProcessorHandleImpl(
     }
   }
 
-  override def getListState[T](stateName: String, valEncoder: Encoder[T]): 
ListState[T] = {
-    verifyStateVarOperations("get_list_state", CREATED)
-    val resultState = new ListStateImpl[T](store, stateName, keyEncoder, 
valEncoder, metrics)
-    TWSMetricsUtils.incrementMetric(metrics, "numListStateVars")
-    resultState
+  override def getValueState[T](
+      stateName: String,
+      valEncoder: Encoder[T],
+      ttlConfig: TTLConfig): ValueState[T] = {
+    getValueState(stateName, ttlConfig)(valEncoder)
+  }
+
+  override def getValueState[T: Encoder](
+      stateName: String,
+      ttlConfig: TTLConfig): ValueState[T] = {
+    verifyStateVarOperations("get_value_state", CREATED)
+    val ttlEnabled = if (ttlConfig.ttlDuration != null && 
ttlConfig.ttlDuration.isZero) {
+      false
+    } else {
+      true
+    }
+
+    val stateEncoder = encoderFor[T].asInstanceOf[ExpressionEncoder[Any]]
+    val result = if (ttlEnabled) {
+      validateTTLConfig(ttlConfig, stateName)
+      assert(batchTimestampMs.isDefined)
+      val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName,
+        keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
+      ttlStates.add(valueStateWithTTL)
+      TWSMetricsUtils.incrementMetric(metrics, "numValueStateWithTTLVars")
+      valueStateWithTTL
+    } else {
+      val valueStateWithoutTTL = new ValueStateImpl[T](store, stateName,
+        keyEncoder, stateEncoder, metrics)
+      TWSMetricsUtils.incrementMetric(metrics, "numValueStateVars")
+      valueStateWithoutTTL
+    }
+    result
   }
 
   /**
@@ -256,45 +259,72 @@ class StatefulProcessorHandleImpl(
       stateName: String,
       valEncoder: Encoder[T],
       ttlConfig: TTLConfig): ListState[T] = {
+    getListState(stateName, ttlConfig)(valEncoder)
+  }
 
+  override def getListState[T: Encoder](stateName: String, ttlConfig: 
TTLConfig): ListState[T] = {
     verifyStateVarOperations("get_list_state", CREATED)
-    validateTTLConfig(ttlConfig, stateName)
 
-    assert(batchTimestampMs.isDefined)
-    val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
-      keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get, metrics)
-    TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
-    ttlStates.add(listStateWithTTL)
+    val ttlEnabled = if (ttlConfig.ttlDuration != null && 
ttlConfig.ttlDuration.isZero) {
+      false
+    } else {
+      true
+    }
 
-    listStateWithTTL
+    val stateEncoder = encoderFor[T].asInstanceOf[ExpressionEncoder[Any]]
+    val result = if (ttlEnabled) {
+      validateTTLConfig(ttlConfig, stateName)
+      assert(batchTimestampMs.isDefined)
+      val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName,
+        keyEncoder, stateEncoder, ttlConfig, batchTimestampMs.get, metrics)
+      TWSMetricsUtils.incrementMetric(metrics, "numListStateWithTTLVars")
+      ttlStates.add(listStateWithTTL)
+      listStateWithTTL
+    } else {
+      val listStateWithoutTTL = new ListStateImpl[T](store, stateName, 
keyEncoder,
+        stateEncoder, metrics)
+      TWSMetricsUtils.incrementMetric(metrics, "numListStateVars")
+      listStateWithoutTTL
+    }
+    result
   }
 
   override def getMapState[K, V](
       stateName: String,
       userKeyEnc: Encoder[K],
-      valEncoder: Encoder[V]): MapState[K, V] = {
-    verifyStateVarOperations("get_map_state", CREATED)
-    val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder,
-      userKeyEnc, valEncoder, metrics)
-    TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars")
-    resultState
+      valEncoder: Encoder[V],
+      ttlConfig: TTLConfig): MapState[K, V] = {
+    getMapState(stateName, ttlConfig)(userKeyEnc, valEncoder)
   }
 
-  override def getMapState[K, V](
+  override def getMapState[K: Encoder, V: Encoder](
       stateName: String,
-      userKeyEnc: Encoder[K],
-      valEncoder: Encoder[V],
       ttlConfig: TTLConfig): MapState[K, V] = {
     verifyStateVarOperations("get_map_state", CREATED)
-    validateTTLConfig(ttlConfig, stateName)
 
-    assert(batchTimestampMs.isDefined)
-    val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, 
keyEncoder, userKeyEnc,
-      valEncoder, ttlConfig, batchTimestampMs.get, metrics)
-    TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
-    ttlStates.add(mapStateWithTTL)
+    val ttlEnabled = if (ttlConfig.ttlDuration != null && 
ttlConfig.ttlDuration.isZero) {
+      false
+    } else {
+      true
+    }
 
-    mapStateWithTTL
+    val userKeyEnc = encoderFor[K].asInstanceOf[ExpressionEncoder[Any]]
+    val valEncoder = encoderFor[V].asInstanceOf[ExpressionEncoder[Any]]
+    val result = if (ttlEnabled) {
+      validateTTLConfig(ttlConfig, stateName)
+      assert(batchTimestampMs.isDefined)
+      val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, 
keyEncoder, userKeyEnc,
+        valEncoder, ttlConfig, batchTimestampMs.get, metrics)
+      TWSMetricsUtils.incrementMetric(metrics, "numMapStateWithTTLVars")
+      ttlStates.add(mapStateWithTTL)
+      mapStateWithTTL
+    } else {
+      val mapStateWithoutTTL = new MapStateImpl[K, V](store, stateName, 
keyEncoder,
+        userKeyEnc, valEncoder, metrics)
+      TWSMetricsUtils.incrementMetric(metrics, "numMapStateVars")
+      mapStateWithoutTTL
+    }
+    result
   }
 
   private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit 
= {
@@ -350,56 +380,58 @@ class DriverStatefulProcessorHandleImpl(timeMode: 
TimeMode, keyExprEnc: Expressi
     stateVariableInfos.put(stateName, stateVariableInfo)
   }
 
-  override def getValueState[T](stateName: String, valEncoder: Encoder[T]): 
ValueState[T] = {
-    verifyStateVarOperations("get_value_state", PRE_INIT)
-    val colFamilySchema = StateStoreColumnFamilySchemaUtils.
-      getValueStateSchema(stateName, keyExprEnc, valEncoder, false)
-    checkIfDuplicateVariableDefined(stateName)
-    columnFamilySchemas.put(stateName, colFamilySchema)
-    val stateVariableInfo = TransformWithStateVariableUtils.
-      getValueState(stateName, ttlEnabled = false)
-    stateVariableInfos.put(stateName, stateVariableInfo)
-    null.asInstanceOf[ValueState[T]]
-  }
-
   override def getValueState[T](
       stateName: String,
       valEncoder: Encoder[T],
       ttlConfig: TTLConfig): ValueState[T] = {
-    verifyStateVarOperations("get_value_state", PRE_INIT)
-    val colFamilySchema = StateStoreColumnFamilySchemaUtils.
-      getValueStateSchema(stateName, keyExprEnc, valEncoder, true)
-    checkIfDuplicateVariableDefined(stateName)
-    columnFamilySchemas.put(stateName, colFamilySchema)
-    val stateVariableInfo = TransformWithStateVariableUtils.
-      getValueState(stateName, ttlEnabled = true)
-    stateVariableInfos.put(stateName, stateVariableInfo)
-    null.asInstanceOf[ValueState[T]]
+    getValueState(stateName, ttlConfig)(valEncoder)
   }
 
-  override def getListState[T](stateName: String, valEncoder: Encoder[T]): 
ListState[T] = {
-    verifyStateVarOperations("get_list_state", PRE_INIT)
+  override def getValueState[T: Encoder](
+      stateName: String,
+      ttlConfig: TTLConfig): ValueState[T] = {
+    verifyStateVarOperations("get_value_state", PRE_INIT)
+    val ttlEnabled = if (ttlConfig.ttlDuration != null && 
ttlConfig.ttlDuration.isZero) {
+      false
+    } else {
+      true
+    }
+
+    val stateEncoder = encoderFor[T]
     val colFamilySchema = StateStoreColumnFamilySchemaUtils.
-      getListStateSchema(stateName, keyExprEnc, valEncoder, false)
+      getValueStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
     checkIfDuplicateVariableDefined(stateName)
     columnFamilySchemas.put(stateName, colFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.
-      getListState(stateName, ttlEnabled = false)
+      getValueState(stateName, ttlEnabled = ttlEnabled)
     stateVariableInfos.put(stateName, stateVariableInfo)
-    null.asInstanceOf[ListState[T]]
+    null.asInstanceOf[ValueState[T]]
   }
 
   override def getListState[T](
       stateName: String,
       valEncoder: Encoder[T],
       ttlConfig: TTLConfig): ListState[T] = {
+    getListState(stateName, ttlConfig)(valEncoder)
+  }
+
+  override def getListState[T: Encoder](
+      stateName: String,
+      ttlConfig: TTLConfig): ListState[T] = {
     verifyStateVarOperations("get_list_state", PRE_INIT)
+    val ttlEnabled = if (ttlConfig.ttlDuration != null && 
ttlConfig.ttlDuration.isZero) {
+      false
+    } else {
+      true
+    }
+
+    val stateEncoder = encoderFor[T]
     val colFamilySchema = StateStoreColumnFamilySchemaUtils.
-      getListStateSchema(stateName, keyExprEnc, valEncoder, true)
+      getListStateSchema(stateName, keyExprEnc, stateEncoder, ttlEnabled)
     checkIfDuplicateVariableDefined(stateName)
     columnFamilySchemas.put(stateName, colFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.
-      getListState(stateName, ttlEnabled = true)
+      getListState(stateName, ttlEnabled = ttlEnabled)
     stateVariableInfos.put(stateName, stateVariableInfo)
     null.asInstanceOf[ListState[T]]
   }
@@ -407,29 +439,29 @@ class DriverStatefulProcessorHandleImpl(timeMode: 
TimeMode, keyExprEnc: Expressi
   override def getMapState[K, V](
       stateName: String,
       userKeyEnc: Encoder[K],
-      valEncoder: Encoder[V]): MapState[K, V] = {
-    verifyStateVarOperations("get_map_state", PRE_INIT)
-    val colFamilySchema = StateStoreColumnFamilySchemaUtils.
-      getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, false)
-    checkIfDuplicateVariableDefined(stateName)
-    columnFamilySchemas.put(stateName, colFamilySchema)
-    val stateVariableInfo = TransformWithStateVariableUtils.
-      getMapState(stateName, ttlEnabled = false)
-    stateVariableInfos.put(stateName, stateVariableInfo)
-    null.asInstanceOf[MapState[K, V]]
+      valEncoder: Encoder[V],
+      ttlConfig: TTLConfig): MapState[K, V] = {
+    getMapState(stateName, ttlConfig)(userKeyEnc, valEncoder)
   }
 
-  override def getMapState[K, V](
+  override def getMapState[K: Encoder, V: Encoder](
       stateName: String,
-      userKeyEnc: Encoder[K],
-      valEncoder: Encoder[V],
       ttlConfig: TTLConfig): MapState[K, V] = {
     verifyStateVarOperations("get_map_state", PRE_INIT)
+
+    val ttlEnabled = if (ttlConfig.ttlDuration != null && 
ttlConfig.ttlDuration.isZero) {
+      false
+    } else {
+      true
+    }
+
+    val userKeyEnc = encoderFor[K]
+    val valEncoder = encoderFor[V]
     val colFamilySchema = StateStoreColumnFamilySchemaUtils.
-      getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, true)
+      getMapStateSchema(stateName, keyExprEnc, userKeyEnc, valEncoder, 
ttlEnabled)
     columnFamilySchemas.put(stateName, colFamilySchema)
     val stateVariableInfo = TransformWithStateVariableUtils.
-      getMapState(stateName, ttlEnabled = true)
+      getMapState(stateName, ttlEnabled = ttlEnabled)
     stateVariableInfos.put(stateName, stateVariableInfo)
     null.asInstanceOf[MapState[K, V]]
   }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
index 8811c59a5074..87d1a15dff1a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TTLState.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming
 
 import java.time.Duration
 
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
@@ -199,7 +198,7 @@ abstract class CompositeKeyTTLStateImpl[K](
     stateName: String,
     store: StateStore,
     keyExprEnc: ExpressionEncoder[Any],
-    userKeyEncoder: Encoder[K],
+    userKeyEncoder: ExpressionEncoder[Any],
     ttlExpirationMs: Long)
   extends TTLState {
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index b1b87feeb263..cd66bf99d4e1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -17,7 +17,6 @@
 package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.metric.SQLMetric
 import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
StateStore}
@@ -37,7 +36,7 @@ class ValueStateImpl[S](
     store: StateStore,
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any],
-    valEncoder: Encoder[S],
+    valEncoder: ExpressionEncoder[Any],
     metrics: Map[String, SQLMetric] = Map.empty)
   extends ValueState[S] with Logging {
 
@@ -66,7 +65,7 @@ class ValueStateImpl[S](
     val retRow = store.get(encodedGroupingKey, stateName)
 
     if (retRow != null) {
-      stateTypesEncoder.decodeValue(retRow)
+      stateTypesEncoder.decodeValue(retRow).asInstanceOf[S]
     } else {
       null.asInstanceOf[S]
     }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
index 145cd9026491..60eea5842645 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala
@@ -16,7 +16,6 @@
  */
 package org.apache.spark.sql.execution.streaming
 
-import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -41,7 +40,7 @@ class ValueStateImplWithTTL[S](
     store: StateStore,
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any],
-    valEncoder: Encoder[S],
+    valEncoder: ExpressionEncoder[Any],
     ttlConfig: TTLConfig,
     batchTimestampMs: Long,
     metrics: Map[String, SQLMetric] = Map.empty)
@@ -80,7 +79,7 @@ class ValueStateImplWithTTL[S](
       val resState = stateTypesEncoder.decodeValue(retRow)
 
       if (!stateTypesEncoder.isExpired(retRow, batchTimestampMs)) {
-        resState
+        resState.asInstanceOf[S]
       } else {
         null.asInstanceOf[S]
       }
@@ -136,7 +135,7 @@ class ValueStateImplWithTTL[S](
     val retRow = store.get(encodedGroupingKey, stateName)
 
     if (retRow != null) {
-      val resState = stateTypesEncoder.decodeValue(retRow)
+      val resState = stateTypesEncoder.decodeValue(retRow).asInstanceOf[S]
       Some(resState)
     } else {
       None
@@ -154,7 +153,8 @@ class ValueStateImplWithTTL[S](
     // ttlExpiration
     if (retRow != null) {
       val ttlExpiration = stateTypesEncoder.decodeTtlExpirationMs(retRow)
-      ttlExpiration.map(expiration => (stateTypesEncoder.decodeValue(retRow), 
expiration))
+      ttlExpiration.map(expiration => 
(stateTypesEncoder.decodeValue(retRow).asInstanceOf[S],
+        expiration))
     } else {
       None
     }
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
index 3f0efe2e14af..9b0beb39bf13 100644
--- 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
+++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessor.java
@@ -42,13 +42,13 @@ public class TestStatefulProcessor extends 
StatefulProcessor<Integer, String, St
       OutputMode outputMode,
       TimeMode timeMode) {
     countState = this.getHandle().getValueState("countState",
-      Encoders.LONG());
+      Encoders.LONG(), TTLConfig.NONE());
 
     keyCountMap = this.getHandle().getMapState("keyCountMap",
-      Encoders.STRING(), Encoders.LONG());
+      Encoders.STRING(), Encoders.LONG(), TTLConfig.NONE());
 
     keysList = this.getHandle().getListState("keyList",
-      Encoders.STRING());
+      Encoders.STRING(), TTLConfig.NONE());
   }
 
   @Override
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
index 7e356abf2d05..63cac8c36915 100644
--- 
a/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
+++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/TestStatefulProcessorWithInitialState.java
@@ -41,7 +41,7 @@ public class TestStatefulProcessorWithInitialState
       OutputMode outputMode,
       TimeMode timeMode) {
     testState = this.getHandle().getValueState("testState",
-      Encoders.STRING());
+      Encoders.STRING(), TTLConfig.NONE());
   }
 
   @Override
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
index 293ec52cc871..d00827fbd3b2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala
@@ -34,7 +34,7 @@ class StatefulProcessorWithSingleValueVar extends 
RunningCountStatefulProcessor
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
     _valueState = getHandle.getValueState[TestClass](
-      "valueState", Encoders.product[TestClass])
+      "valueState", Encoders.product[TestClass], TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -81,7 +81,7 @@ class SessionGroupsStatefulProcessor extends
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _groupsList = getHandle.getListState("groupsList", Encoders.STRING)
+    _groupsList = getHandle.getListState("groupsList", Encoders.STRING, 
TTLConfig.NONE)
   }
 
   override def handleInputRows(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index aaaa53059126..3925c3d62da3 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -130,7 +130,8 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
         verify(statefulProcessorHandle)
           .getValueState[Row](any[String], any[Encoder[Row]], any[TTLConfig])
       } else {
-        verify(statefulProcessorHandle).getValueState[Row](any[String], 
any[Encoder[Row]])
+        verify(statefulProcessorHandle).getValueState[Row](any[String], 
any[Encoder[Row]],
+          any[TTLConfig])
       }
       verify(outputStream).writeInt(0)
     }
@@ -153,7 +154,8 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
         verify(statefulProcessorHandle)
           .getListState[Row](any[String], any[Encoder[Row]], any[TTLConfig])
       } else {
-        verify(statefulProcessorHandle).getListState[Row](any[String], 
any[Encoder[Row]])
+        verify(statefulProcessorHandle).getListState[Row](any[String], 
any[Encoder[Row]],
+          any[TTLConfig])
       }
       verify(outputStream).writeInt(0)
     }
@@ -178,7 +180,7 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
           .getMapState[Row, Row](any[String], any[Encoder[Row]], 
any[Encoder[Row]], any[TTLConfig])
       } else {
         verify(statefulProcessorHandle).getMapState[Row, Row](any[String], 
any[Encoder[Row]],
-          any[Encoder[Row]])
+          any[Encoder[Row]], any[TTLConfig])
       }
       verify(outputStream).writeInt(0)
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
index e9300464af8d..22876831c00d 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ListStateSuite.scala
@@ -31,6 +31,8 @@ import org.apache.spark.sql.streaming.{ListState, TimeMode, 
TTLConfig, ValueStat
  * operators such as transformWithState
  */
 class ListStateSuite extends StateVariableSuiteBase {
+  import testImplicits._
+
   // overwrite useMultipleValuesPerKey in base suite to be true for list state
   override def useMultipleValuesPerKey: Boolean = true
 
@@ -40,7 +42,8 @@ class ListStateSuite extends StateVariableSuiteBase {
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         stringEncoder, TimeMode.None())
 
-      val listState: ListState[Long] = handle.getListState[Long]("listState", 
Encoders.scalaLong)
+      val listState: ListState[Long] = handle.getListState[Long]("listState",
+        TTLConfig.NONE)
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       val e = intercept[SparkIllegalArgumentException] {
@@ -73,7 +76,8 @@ class ListStateSuite extends StateVariableSuiteBase {
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         stringEncoder, TimeMode.None())
 
-      val testState: ListState[Long] = handle.getListState[Long]("testState", 
Encoders.scalaLong)
+      val testState: ListState[Long] = handle.getListState[Long]("testState",
+        TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
 
       // simple put and get test
@@ -101,8 +105,10 @@ class ListStateSuite extends StateVariableSuiteBase {
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         stringEncoder, TimeMode.None())
 
-      val testState1: ListState[Long] = 
handle.getListState[Long]("testState1", Encoders.scalaLong)
-      val testState2: ListState[Long] = 
handle.getListState[Long]("testState2", Encoders.scalaLong)
+      val testState1: ListState[Long] = handle.getListState[Long]("testState1",
+        TTLConfig.NONE)
+      val testState2: ListState[Long] = handle.getListState[Long]("testState2",
+        TTLConfig.NONE)
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
 
@@ -139,10 +145,12 @@ class ListStateSuite extends StateVariableSuiteBase {
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         stringEncoder, TimeMode.None())
 
-      val listState1: ListState[Long] = 
handle.getListState[Long]("listState1", Encoders.scalaLong)
-      val listState2: ListState[Long] = 
handle.getListState[Long]("listState2", Encoders.scalaLong)
+      val listState1: ListState[Long] = handle.getListState[Long]("listState1",
+        TTLConfig.NONE)
+      val listState2: ListState[Long] = handle.getListState[Long]("listState2",
+        TTLConfig.NONE)
       val valueState: ValueState[Long] = handle.getValueState[Long](
-        "valueState", Encoders.scalaLong)
+        "valueState", TTLConfig.NONE)
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       // simple put and get test
@@ -218,7 +226,7 @@ class ListStateSuite extends StateVariableSuiteBase {
     }
   }
 
-  test("test negative or zero TTL duration throws error") {
+  test("test null or negative TTL duration throws error") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val batchTimestampMs = 10
@@ -226,7 +234,7 @@ class ListStateSuite extends StateVariableSuiteBase {
         stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
 
-      Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+      Seq(null, Duration.ofMinutes(-1)).foreach { ttlDuration =>
         val ttlConfig = TTLConfig(ttlDuration)
         val ex = intercept[SparkUnsupportedOperationException] {
           handle.getListState[String]("testState", Encoders.STRING, ttlConfig)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
index b067d589de90..9a0a891d538e 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MapStateSuite.scala
@@ -36,6 +36,8 @@ class MapStateSuite extends StateVariableSuiteBase {
     .add("key", BinaryType)
     .add("userKey", BinaryType)
 
+  import testImplicits._
+
   test("Map state operations for single instance") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
@@ -43,7 +45,8 @@ class MapStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val testState: MapState[String, Double] =
-        handle.getMapState[String, Double]("testState", Encoders.STRING, 
Encoders.scalaDouble)
+        handle.getMapState[String, Double]("testState", Encoders.STRING, 
Encoders.scalaDouble,
+          TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       // put initial value
       testState.updateValue("k1", 1.0)
@@ -77,9 +80,10 @@ class MapStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val testState1: MapState[Long, Double] =
-        handle.getMapState[Long, Double]("testState1", Encoders.scalaLong, 
Encoders.scalaDouble)
+        handle.getMapState[Long, Double]("testState1", TTLConfig.NONE)
       val testState2: MapState[Long, Int] =
-        handle.getMapState[Long, Int]("testState2", Encoders.scalaLong, 
Encoders.scalaInt)
+        handle.getMapState[Long, Int]("testState2",
+          TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       // put initial value
       testState1.updateValue(1L, 1.0)
@@ -116,13 +120,13 @@ class MapStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val mapTestState1: MapState[String, Int] =
-        handle.getMapState[String, Int]("mapTestState1", Encoders.STRING, 
Encoders.scalaInt)
+        handle.getMapState[String, Int]("mapTestState1", TTLConfig.NONE)
       val mapTestState2: MapState[String, Int] =
-        handle.getMapState[String, Int]("mapTestState2", Encoders.STRING, 
Encoders.scalaInt)
+        handle.getMapState[String, Int]("mapTestState2", TTLConfig.NONE)
       val valueTestState: ValueState[String] =
-        handle.getValueState[String]("valueTestState", Encoders.STRING)
+        handle.getValueState[String]("valueTestState", TTLConfig.NONE)
       val listTestState: ListState[String] =
-        handle.getListState[String]("listTestState", Encoders.STRING)
+        handle.getListState[String]("listTestState", TTLConfig.NONE)
 
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       // put initial values
@@ -227,7 +231,7 @@ class MapStateSuite extends StateVariableSuiteBase {
     }
   }
 
-  test("test negative or zero TTL duration throws error") {
+  test("test null or negative TTL duration throws error") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val batchTimestampMs = 10
@@ -235,7 +239,7 @@ class MapStateSuite extends StateVariableSuiteBase {
         stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
 
-      Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+      Seq(null, Duration.ofMinutes(-1)).foreach { ttlDuration =>
         val ttlConfig = TTLConfig(ttlDuration)
         val ex = intercept[SparkUnsupportedOperationException] {
           handle.getMapState[String, String](
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
index 48a6fd836a46..0d74aade6719 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StatefulProcessorHandleSuite.scala
@@ -32,6 +32,8 @@ import org.apache.spark.sql.streaming.{TimeMode, TTLConfig}
  */
 class StatefulProcessorHandleSuite extends StateVariableSuiteBase {
 
+  import testImplicits._
+
   private def getTimeMode(timeMode: String): TimeMode = {
     timeMode match {
       case "None" => TimeMode.None()
@@ -48,7 +50,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
         val handle = new StatefulProcessorHandleImpl(store,
           UUID.randomUUID(), stringEncoder, getTimeMode(timeMode))
         assert(handle.getHandleState === StatefulProcessorHandleState.CREATED)
-        handle.getValueState[Long]("testState", Encoders.scalaLong)
+        handle.getValueState[Long]("testState", TTLConfig.NONE)
       }
     }
   }
@@ -74,7 +76,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
   }
 
   private def createValueStateInstance(handle: StatefulProcessorHandleImpl): 
Unit = {
-    handle.getValueState[Long]("testState", Encoders.scalaLong)
+    handle.getValueState[Long]("testState", TTLConfig.NONE)
   }
 
   private def registerTimer(handle: StatefulProcessorHandleImpl): Unit = {
@@ -222,11 +224,11 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
         UUID.randomUUID(), stringEncoder, TimeMode.ProcessingTime(),
         batchTimestampMs = Some(10))
 
-      val valueStateWithTTL = handle.getValueState("testState",
-        Encoders.STRING, TTLConfig(Duration.ofHours(1)))
+      val valueStateWithTTL = handle.getValueState[String]("testState",
+        TTLConfig(Duration.ofHours(1)))
 
       // create another state without TTL, this should not be captured in the 
handle
-      handle.getValueState("testState", Encoders.STRING)
+      handle.getValueState[String]("testState", TTLConfig.NONE)
 
       assert(handle.ttlStates.size() === 1)
       assert(handle.ttlStates.get(0) === valueStateWithTTL)
@@ -244,7 +246,7 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
         Encoders.STRING, TTLConfig(Duration.ofHours(1)))
 
       // create another state without TTL, this should not be captured in the 
handle
-      handle.getListState("testState", Encoders.STRING)
+      handle.getListState("testState", Encoders.STRING, TTLConfig.NONE)
 
       assert(handle.ttlStates.size() === 1)
       assert(handle.ttlStates.get(0) === listStateWithTTL)
@@ -262,7 +264,8 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
         Encoders.STRING, Encoders.STRING, TTLConfig(Duration.ofHours(1)))
 
       // create another state without TTL, this should not be captured in the 
handle
-      handle.getMapState("testState", Encoders.STRING, Encoders.STRING)
+      handle.getMapState("testState", Encoders.STRING, Encoders.STRING,
+        TTLConfig.NONE)
 
       assert(handle.ttlStates.size() === 1)
       assert(handle.ttlStates.get(0) === mapStateWithTTL)
@@ -275,9 +278,10 @@ class StatefulProcessorHandleSuite extends 
StateVariableSuiteBase {
       val handle = new StatefulProcessorHandleImpl(store,
         UUID.randomUUID(), stringEncoder, TimeMode.None())
 
-      handle.getValueState("testValueState", Encoders.STRING)
-      handle.getListState("testListState", Encoders.STRING)
-      handle.getMapState("testMapState", Encoders.STRING, Encoders.STRING)
+      handle.getValueState("testValueState", Encoders.STRING, TTLConfig.NONE)
+      handle.getListState("testListState", Encoders.STRING, TTLConfig.NONE)
+      handle.getMapState("testMapState", Encoders.STRING, Encoders.STRING,
+        TTLConfig.NONE)
 
       assert(handle.ttlStates.isEmpty)
     }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index 13d758eb1b88..55d08cd8f12a 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -44,6 +44,7 @@ case class TestClass(var id: Long, var name: String)
 class ValueStateSuite extends StateVariableSuiteBase {
 
   import StateStoreTestsHelper._
+  import testImplicits._
 
   test("Implicit key operations") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
@@ -52,7 +53,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val stateName = "testState"
-      val testState: ValueState[Long] = 
handle.getValueState[Long]("testState", Encoders.scalaLong)
+      val testState: ValueState[Long] = handle.getValueState[Long]("testState",
+        TTLConfig.NONE)
       assert(ImplicitGroupingKeyTracker.getImplicitKeyOption.isEmpty)
       val ex = intercept[Exception] {
         testState.update(123)
@@ -95,7 +97,8 @@ class ValueStateSuite extends StateVariableSuiteBase {
       val handle = new StatefulProcessorHandleImpl(store, UUID.randomUUID(),
         stringEncoder, TimeMode.None())
 
-      val testState: ValueState[Long] = 
handle.getValueState[Long]("testState", Encoders.scalaLong)
+      val testState: ValueState[Long] = handle.getValueState[Long]("testState",
+        TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(123)
       assert(testState.get() === 123)
@@ -122,9 +125,9 @@ class ValueStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val testState1: ValueState[Long] = handle.getValueState[Long](
-        "testState1", Encoders.scalaLong)
+        "testState1", TTLConfig.NONE)
       val testState2: ValueState[Long] = handle.getValueState[Long](
-        "testState2", Encoders.scalaLong)
+        "testState2", TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState1.update(123)
       assert(testState1.get() === 123)
@@ -168,7 +171,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
 
       val cfName = "$testState"
       val ex = intercept[SparkUnsupportedOperationException] {
-        handle.getValueState[Long](cfName, Encoders.scalaLong)
+        handle.getValueState[Long](cfName, TTLConfig.NONE)
       }
       checkError(
         ex,
@@ -207,7 +210,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val testState: ValueState[Double] = 
handle.getValueState[Double]("testState",
-        Encoders.scalaDouble)
+        Encoders.scalaDouble, TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(1.0)
       assert(testState.get().equals(1.0))
@@ -233,7 +236,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val testState: ValueState[Long] = handle.getValueState[Long]("testState",
-        Encoders.scalaLong)
+        TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(1L)
       assert(testState.get().equals(1L))
@@ -259,7 +262,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val testState: ValueState[TestClass] = 
handle.getValueState[TestClass]("testState",
-        Encoders.product[TestClass])
+        TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(TestClass(1, "testcase1"))
       assert(testState.get().equals(TestClass(1, "testcase1")))
@@ -285,7 +288,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         stringEncoder, TimeMode.None())
 
       val testState: ValueState[POJOTestClass] = 
handle.getValueState[POJOTestClass]("testState",
-        Encoders.bean(classOf[POJOTestClass]))
+         Encoders.bean(classOf[POJOTestClass]), TTLConfig.NONE)
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(new POJOTestClass("testcase1", 1))
       assert(testState.get().equals(new POJOTestClass("testcase1", 1)))
@@ -304,7 +307,6 @@ class ValueStateSuite extends StateVariableSuiteBase {
     }
   }
 
-
   test(s"test Value state TTL") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
@@ -361,7 +363,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
     }
   }
 
-  test("test negative or zero TTL duration throws error") {
+  test("test null or zero TTL duration throws error") {
     tryWithProviderResource(newStoreProviderWithStateVariable(true)) { 
provider =>
       val store = provider.getStore(0)
       val batchTimestampMs = 10
@@ -369,7 +371,7 @@ class ValueStateSuite extends StateVariableSuiteBase {
         stringEncoder,
         TimeMode.ProcessingTime(), batchTimestampMs = Some(batchTimestampMs))
 
-      Seq(null, Duration.ZERO, Duration.ofMinutes(-1)).foreach { ttlDuration =>
+      Seq(null, Duration.ofMinutes(-1)).foreach { ttlDuration =>
         val ttlConfig = TTLConfig(ttlDuration)
         val ex = intercept[SparkUnsupportedOperationException] {
           handle.getValueState[String]("testState", Encoders.STRING, ttlConfig)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
index 20f04cc66c0a..88862e2ad079 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -33,7 +33,7 @@ class TestListStateProcessor
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _listState = getHandle.getListState("testListState", Encoders.STRING)
+    _listState = getHandle.getListState("testListState", Encoders.STRING, 
TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -89,8 +89,9 @@ class ToggleSaveAndEmitProcessor
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _listState = getHandle.getListState("testListState", Encoders.STRING)
-    _valueState = getHandle.getValueState("testValueState", 
Encoders.scalaBoolean)
+    _listState = getHandle.getListState("testListState", Encoders.STRING, 
TTLConfig.NONE)
+    _valueState = getHandle.getValueState("testValueState", 
Encoders.scalaBoolean,
+      TTLConfig.NONE)
   }
 
   override def handleInputRows(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
index da4218949a11..76c5cbeee424 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithMapStateSuite.scala
@@ -33,7 +33,8 @@ class TestMapStateProcessor
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _mapState = getHandle.getMapState("sessionState", Encoders.STRING, 
Encoders.STRING)
+    _mapState = getHandle.getMapState("sessionState", Encoders.STRING, 
Encoders.STRING,
+      TTLConfig.NONE)
   }
 
   override def handleInputRows(
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
index 300785611fd0..35ac8a4687eb 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala
@@ -39,10 +39,13 @@ abstract class StatefulProcessorWithInitialStateTestClass[V]
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _valState = getHandle.getValueState[Double]("testValueInit", 
Encoders.scalaDouble)
-    _listState = getHandle.getListState[Double]("testListInit", 
Encoders.scalaDouble)
+    _valState = getHandle.getValueState[Double]("testValueInit", 
Encoders.scalaDouble,
+      TTLConfig.NONE)
+    _listState = getHandle.getListState[Double]("testListInit", 
Encoders.scalaDouble,
+      TTLConfig.NONE)
     _mapState = getHandle.getMapState[Double, Int](
-      "testMapInit", Encoders.scalaDouble, Encoders.scalaInt)
+      "testMapInit", Encoders.scalaDouble, Encoders.scalaInt,
+        TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -162,8 +165,10 @@ class StatefulProcessorWithInitialStateProcTimerClass
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode) : Unit = {
-    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong)
-    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong)
+    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong,
+      TTLConfig.NONE)
+    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong,
+      TTLConfig.NONE)
   }
 
   override def handleInitialState(
@@ -209,8 +214,9 @@ class StatefulProcessorWithInitialStateEventTimerClass
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
     _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState",
-      Encoders.scalaLong)
-    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong)
+      Encoders.scalaLong, TTLConfig.NONE)
+    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong,
+      TTLConfig.NONE)
   }
 
   private def processUnexpiredRows(maxEventTimeSec: Long): Unit = {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 7d61c6fc7084..3ef5c57ee3d0 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -50,7 +50,8 @@ class RunningCountStatefulProcessor extends 
StatefulProcessor[String, String, (S
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong)
+    _countState = getHandle.getValueState[Long]("countState",
+      Encoders.scalaLong, TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -76,8 +77,8 @@ class RunningCountStatefulProcessorWithTTL
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState",
-      Encoders.scalaLong, TTLConfig(Duration.ofMillis(1000)))
+    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong,
+      TTLConfig(Duration.ofMillis(1000)))
   }
 
   override def handleInputRows(
@@ -106,7 +107,7 @@ class RunningCountListStatefulProcessor
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
     _countState = getHandle.getListState[Long](
-      "countState", Encoders.scalaLong)
+      "countState", Encoders.scalaLong, TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -124,7 +125,8 @@ class RunningCountStatefulProcessorInt
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt)
+    _countState = getHandle.getValueState[Int]("countState", Encoders.scalaInt,
+      TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -183,7 +185,8 @@ class RunningCountStatefulProcessorWithProcTimeTimerUpdates
       outputMode: OutputMode,
       timeMode: TimeMode) : Unit = {
     super.init(outputMode, timeMode)
-    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong)
+    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong,
+      TTLConfig.NONE)
   }
 
   protected def processUnexpiredRows(
@@ -267,8 +270,9 @@ class MaxEventTimeStatefulProcessor
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
     _maxEventTimeState = getHandle.getValueState[Long]("maxEventTimeState",
-      Encoders.scalaLong)
-    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong)
+      Encoders.scalaLong, TTLConfig.NONE)
+    _timerState = getHandle.getValueState[Long]("timerState", 
Encoders.scalaLong,
+      TTLConfig.NONE)
   }
 
   protected def processUnexpiredRows(maxEventTimeSec: Long): Unit = {
@@ -313,8 +317,10 @@ class RunningCountMostRecentStatefulProcessor
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong)
-    _mostRecent = getHandle.getValueState[String]("mostRecent", 
Encoders.STRING)
+    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong,
+      TTLConfig.NONE)
+    _mostRecent = getHandle.getValueState[String]("mostRecent", 
Encoders.STRING,
+      TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -343,7 +349,8 @@ class MostRecentStatefulProcessorWithDeletion
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
     getHandle.deleteIfExists("countState")
-    _mostRecent = getHandle.getValueState[String]("mostRecent", 
Encoders.STRING)
+    _mostRecent = getHandle.getValueState[String]("mostRecent", 
Encoders.STRING,
+      TTLConfig.NONE)
   }
 
   override def handleInputRows(
@@ -370,7 +377,8 @@ class RunningCountStatefulProcessorWithError extends 
RunningCountStatefulProcess
       inputRows: Iterator[String],
       timerValues: TimerValues): Iterator[(String, String)] = {
     // Trying to create value state here should fail
-    _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong)
+    _tempState = getHandle.getValueState[Long]("tempState", Encoders.scalaLong,
+      TTLConfig.NONE)
     Iterator.empty
   }
 }
@@ -383,11 +391,13 @@ class StatefulProcessorWithCompositeTypes extends 
RunningCountStatefulProcessor
   override def init(
       outputMode: OutputMode,
       timeMode: TimeMode): Unit = {
-    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong)
+    _countState = getHandle.getValueState[Long]("countState", 
Encoders.scalaLong,
+      TTLConfig.NONE)
     _listState = getHandle.getListState[TestClass](
-      "listState", Encoders.product[TestClass])
+      "listState", Encoders.product[TestClass], TTLConfig.NONE)
     _mapState = getHandle.getMapState[POJOTestClass, String](
-      "mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING)
+      "mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING,
+      TTLConfig.NONE)
   }
 }
 
@@ -421,6 +431,9 @@ class TransformWithStateSuite extends StateStoreMetricsTest
   }
 
   test("transformWithState - lazy iterators can properly get/set keyed state") 
{
+    val spark = this.spark
+    import spark.implicits._
+
     class ProcessorWithLazyIterators
       extends StatefulProcessor[Long, Long, Long] {
       @transient protected var _myValueState: ValueState[Long] = _
@@ -429,7 +442,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
         _myValueState = getHandle.getValueState[Long](
           "myValueState",
-          Encoders.scalaLong
+          TTLConfig.NONE
         )
       }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
index e7b394db0c3c..21c3beb79314 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala
@@ -130,7 +130,7 @@ class MultipleValueStatesTTLProcessor(
       .getValueState("valueStateTTL", Encoders.scalaInt, ttlConfig)
       .asInstanceOf[ValueStateImplWithTTL[Int]]
     _valueStateWithoutTTL = getHandle
-      .getValueState("valueState", Encoders.scalaInt)
+      .getValueState[Int]("valueState", Encoders.scalaInt, TTLConfig.NONE)
       .asInstanceOf[ValueStateImpl[Int]]
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to