This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 06a8f6b6d6f6 [SPARK-49744][SS][PYTHON] Implement TTL support for 
ListState in TransformWithStateInPandas
06a8f6b6d6f6 is described below

commit 06a8f6b6d6f674d72fdb0d17e20041c442a2a6fd
Author: bogao007 <[email protected]>
AuthorDate: Sun Oct 6 16:57:02 2024 +0900

    [SPARK-49744][SS][PYTHON] Implement TTL support for ListState in 
TransformWithStateInPandas
    
    ### What changes were proposed in this pull request?
    
    Implement TTL support for ListState in TransformWithStateInPandas.
    
    ### Why are the changes needed?
    
    Allow users to add TTL to specific list state.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes
    
    ### How was this patch tested?
    
    Added unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #48253 from bogao007/ttl-list-state.
    
    Authored-by: bogao007 <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 python/pyspark/sql/streaming/stateful_processor.py | 13 ++++--
 .../sql/streaming/stateful_processor_api_client.py |  6 ++-
 .../pandas/test_pandas_transform_with_state.py     | 48 +++++++++++++++++++++-
 .../TransformWithStateInPandasStateServer.scala    | 18 +++++---
 ...ransformWithStateInPandasStateServerSuite.scala | 23 +++++++++++
 5 files changed, 97 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/streaming/stateful_processor.py 
b/python/pyspark/sql/streaming/stateful_processor.py
index 0011b62132ad..6b8de0f8ac4e 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -56,7 +56,7 @@ class ValueState:
         """
         return self._value_state_client.get(self._state_name)
 
-    def update(self, new_value: Any) -> None:
+    def update(self, new_value: Tuple) -> None:
         """
         Update the value of the state.
         """
@@ -156,7 +156,9 @@ class StatefulProcessorHandle:
         self.stateful_processor_api_client.get_value_state(state_name, schema, 
ttl_duration_ms)
         return 
ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, 
schema)
 
-    def getListState(self, state_name: str, schema: Union[StructType, str]) -> 
ListState:
+    def getListState(
+        self, state_name: str, schema: Union[StructType, str], 
ttl_duration_ms: Optional[int] = None
+    ) -> ListState:
         """
         Function to create new or return existing single value state variable 
of given type.
         The user must ensure to call this function only within the `init()` 
method of the
@@ -169,8 +171,13 @@ class StatefulProcessorHandle:
         schema : :class:`pyspark.sql.types.DataType` or str
             The schema of the state variable. The value can be either a
             :class:`pyspark.sql.types.DataType` object or a DDL-formatted type 
string.
+        ttl_duration_ms : int, optional
+            Time to live duration of the state in milliseconds. State values 
will not be returned
+            past the TTL duration and will eventually be removed from the state 
store. Any state update
+            resets the expiration time to the current processing time plus the 
TTL duration.
+            If no TTL is specified, the state will never expire.
         """
-        self.stateful_processor_api_client.get_list_state(state_name, schema)
+        self.stateful_processor_api_client.get_list_state(state_name, schema, 
ttl_duration_ms)
         return ListState(ListStateClient(self.stateful_processor_api_client), 
state_name, schema)
 
 
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py 
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 2a5e55159e76..449d5a2ad55d 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -131,7 +131,9 @@ class StatefulProcessorApiClient:
             # TODO(SPARK-49233): Classify user facing errors.
             raise PySparkRuntimeError(f"Error initializing value state: " 
f"{response_message[1]}")
 
-    def get_list_state(self, state_name: str, schema: Union[StructType, str]) 
-> None:
+    def get_list_state(
+        self, state_name: str, schema: Union[StructType, str], 
ttl_duration_ms: Optional[int]
+    ) -> None:
         import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
 
         if isinstance(schema, str):
@@ -140,6 +142,8 @@ class StatefulProcessorApiClient:
         state_call_command = stateMessage.StateCallCommand()
         state_call_command.stateName = state_name
         state_call_command.schema = schema.json()
+        if ttl_duration_ms is not None:
+            state_call_command.ttl.durationMs = ttl_duration_ms
         call = 
stateMessage.StatefulProcessorCall(getListState=state_call_command)
         message = stateMessage.StateRequest(statefulProcessorCall=call)
 
diff --git 
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py 
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 99333ae6f5c2..01cd441941d9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -221,6 +221,18 @@ class TransformWithStateInPandasTestsMixin:
 
         self._test_transform_with_state_in_pandas_basic(ListStateProcessor(), 
check_results, True)
 
+    # test list state with ttl has the same behavior as list state when state 
doesn't expire.
+    def test_transform_with_state_in_pandas_list_state_large_ttl(self):
+        def check_results(batch_df, _):
+            assert set(batch_df.sort("id").collect()) == {
+                Row(id="0", countAsString="2"),
+                Row(id="1", countAsString="2"),
+            }
+
+        self._test_transform_with_state_in_pandas_basic(
+            ListStateLargeTTLProcessor(), check_results, True, "processingTime"
+        )
+
     # test value state with ttl has the same behavior as value state when
     # state doesn't expire.
     def test_value_state_ttl_basic(self):
@@ -248,8 +260,10 @@ class TransformWithStateInPandasTestsMixin:
                     [
                         Row(id="ttl-count-0", count=1),
                         Row(id="count-0", count=1),
+                        Row(id="ttl-list-state-count-0", count=1),
                         Row(id="ttl-count-1", count=1),
                         Row(id="count-1", count=1),
+                        Row(id="ttl-list-state-count-1", count=1),
                     ],
                 )
             elif batch_id == 1:
@@ -258,21 +272,29 @@ class TransformWithStateInPandasTestsMixin:
                     [
                         Row(id="ttl-count-0", count=2),
                         Row(id="count-0", count=2),
+                        Row(id="ttl-list-state-count-0", count=3),
                         Row(id="ttl-count-1", count=2),
                         Row(id="count-1", count=2),
+                        Row(id="ttl-list-state-count-1", count=3),
                     ],
                 )
             elif batch_id == 2:
                 # ttl-count-0 expire and restart from count 0.
-                # ttl-count-1 get reset in batch 1 and keep the state
+                # The TTL for value state ttl_count_state gets reset in batch 
1 because of the
+                # update operation and ttl-count-1 keeps the state.
+                # ttl-list-state-count-0 expires and restarts from count 0.
+                # The TTL for list state ttl_list_state gets reset in batch 1 
because of the
+                # put operation and ttl-list-state-count-1 keeps the state.
                 # non-ttl state never expires
                 assertDataFrameEqual(
                     batch_df,
                     [
                         Row(id="ttl-count-0", count=1),
                         Row(id="count-0", count=3),
+                        Row(id="ttl-list-state-count-0", count=1),
                         Row(id="ttl-count-1", count=3),
                         Row(id="count-1", count=3),
+                        Row(id="ttl-list-state-count-1", count=7),
                     ],
                 )
             if batch_id == 0 or batch_id == 1:
@@ -362,25 +384,38 @@ class TTLStatefulProcessor(StatefulProcessor):
         state_schema = StructType([StructField("value", IntegerType(), True)])
         self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 
10000)
         self.count_state = handle.getValueState("state", state_schema)
+        self.ttl_list_state = handle.getListState("ttl-list-state", 
state_schema, 10000)
 
     def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
         count = 0
         ttl_count = 0
+        ttl_list_state_count = 0
         id = key[0]
         if self.count_state.exists():
             count = self.count_state.get()[0]
         if self.ttl_count_state.exists():
             ttl_count = self.ttl_count_state.get()[0]
+        if self.ttl_list_state.exists():
+            iter = self.ttl_list_state.get()
+            for s in iter:
+                ttl_list_state_count += s[0]
         for pdf in rows:
             pdf_count = pdf.count().get("temperature")
             count += pdf_count
             ttl_count += pdf_count
+            ttl_list_state_count += pdf_count
 
         self.count_state.update((count,))
         # skip updating state for the 2nd batch so that ttl state expire
         if not (ttl_count == 2 and id == "0"):
             self.ttl_count_state.update((ttl_count,))
-        yield pd.DataFrame({"id": [f"ttl-count-{id}", f"count-{id}"], "count": 
[ttl_count, count]})
+            self.ttl_list_state.put([(ttl_list_state_count,), 
(ttl_list_state_count,)])
+        yield pd.DataFrame(
+            {
+                "id": [f"ttl-count-{id}", f"count-{id}", 
f"ttl-list-state-count-{id}"],
+                "count": [ttl_count, count, ttl_list_state_count],
+            }
+        )
 
     def close(self) -> None:
         pass
@@ -457,6 +492,15 @@ class ListStateProcessor(StatefulProcessor):
         pass
 
 
+# A stateful processor that inherits all behavior of ListStateProcessor except 
that it uses
+# TTL state with a large timeout.
+class ListStateLargeTTLProcessor(ListStateProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        state_schema = StructType([StructField("temperature", IntegerType(), 
True)])
+        self.list_state1 = handle.getListState("listState1", state_schema, 
30000)
+        self.list_state2 = handle.getListState("listState2", state_schema, 
30000)
+
+
 class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, 
ReusedSQLTestCase):
     pass
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index d293e7a4a5bb..fed1843acfa5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -189,8 +189,12 @@ class TransformWithStateInPandasStateServer(
       case StatefulProcessorCall.MethodCase.GETLISTSTATE =>
         val stateName = message.getGetListState.getStateName
         val schema = message.getGetListState.getSchema
-        // TODO(SPARK-49744): Add ttl support for list state.
-        initializeStateVariable(stateName, schema, 
StateVariableType.ListState, None)
+        val ttlDurationMs = if (message.getGetListState.hasTtl) {
+          Some(message.getGetListState.getTtl.getDurationMs)
+        } else {
+            None
+        }
+        initializeStateVariable(stateName, schema, 
StateVariableType.ListState, ttlDurationMs)
       case _ =>
         throw new IllegalArgumentException("Invalid method call")
     }
@@ -372,10 +376,14 @@ class TransformWithStateInPandasStateServer(
           sendResponse(1, s"Value state $stateName already exists")
         }
         case StateVariableType.ListState => if 
(!listStates.contains(stateName)) {
-          // TODO(SPARK-49744): Add ttl support for list state.
+          val state = if (ttlDurationMs.isEmpty) {
+            statefulProcessorHandle.getListState[Row](stateName, 
Encoders.row(schema))
+          } else {
+            statefulProcessorHandle.getListState(
+              stateName, Encoders.row(schema), 
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+          }
           listStates.put(stateName,
-            ListStateInfo(statefulProcessorHandle.getListState[Row](stateName,
-              Encoders.row(schema)), schema, 
expressionEncoder.createDeserializer(),
+            ListStateInfo(state, schema, 
expressionEncoder.createDeserializer(),
               expressionEncoder.createSerializer()))
           sendResponse(0)
         } else {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index 137e2531f4f4..776772bb51ca 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -118,6 +118,29 @@ class TransformWithStateInPandasStateServerSuite extends 
SparkFunSuite with Befo
     }
   }
 
+  Seq(true, false).foreach { useTTL =>
+    test(s"get list state, useTTL=$useTTL") {
+      val stateCallCommandBuilder = StateCallCommand.newBuilder()
+        .setStateName("newName")
+        .setSchema("StructType(List(StructField(value,IntegerType,true)))")
+      if (useTTL) {
+        
stateCallCommandBuilder.setTtl(StateMessage.TTLConfig.newBuilder().setDurationMs(1000))
+      }
+      val message = StatefulProcessorCall
+        .newBuilder()
+        .setGetListState(stateCallCommandBuilder.build())
+        .build()
+      stateServer.handleStatefulProcessorCall(message)
+      if (useTTL) {
+        verify(statefulProcessorHandle)
+          .getListState[Row](any[String], any[Encoder[Row]], any[TTLConfig])
+      } else {
+        verify(statefulProcessorHandle).getListState[Row](any[String], 
any[Encoder[Row]])
+      }
+      verify(outputStream).writeInt(0)
+    }
+  }
+
   test("value state exists") {
     val message = ValueStateCall.newBuilder().setStateName(stateName)
       .setExists(Exists.newBuilder().build()).build()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to