This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 06a8f6b6d6f6 [SPARK-49744][SS][PYTHON] Implement TTL support for
ListState in TransformWithStateInPandas
06a8f6b6d6f6 is described below
commit 06a8f6b6d6f674d72fdb0d17e20041c442a2a6fd
Author: bogao007 <[email protected]>
AuthorDate: Sun Oct 6 16:57:02 2024 +0900
[SPARK-49744][SS][PYTHON] Implement TTL support for ListState in
TransformWithStateInPandas
### What changes were proposed in this pull request?
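Implement TTL support for ListState in TransformWithStateInPandas: `getListState` now accepts an optional TTL duration in milliseconds, which the Python client forwards to the JVM state server so the list state is created with a `TTLConfig`.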
### Why are the changes needed?
Allow users to set a TTL on a specific list state variable, matching the TTL support that already exists for value state.
### Does this PR introduce _any_ user-facing change?
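Yes. `StatefulProcessorHandle.getListState` accepts a new optional `ttl_duration_ms` argument; if it is not specified, the state never expires.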
### How was this patch tested?
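Added unit tests: Python tests for list state with a large TTL and for TTL expiry of list state, plus a Scala state-server test for `getListState` with and without TTL.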
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48253 from bogao007/ttl-list-state.
Authored-by: bogao007 <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
python/pyspark/sql/streaming/stateful_processor.py | 13 ++++--
.../sql/streaming/stateful_processor_api_client.py | 6 ++-
.../pandas/test_pandas_transform_with_state.py | 48 +++++++++++++++++++++-
.../TransformWithStateInPandasStateServer.scala | 18 +++++---
...ransformWithStateInPandasStateServerSuite.scala | 23 +++++++++++
5 files changed, 97 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/streaming/stateful_processor.py
b/python/pyspark/sql/streaming/stateful_processor.py
index 0011b62132ad..6b8de0f8ac4e 100644
--- a/python/pyspark/sql/streaming/stateful_processor.py
+++ b/python/pyspark/sql/streaming/stateful_processor.py
@@ -56,7 +56,7 @@ class ValueState:
"""
return self._value_state_client.get(self._state_name)
- def update(self, new_value: Any) -> None:
+ def update(self, new_value: Tuple) -> None:
"""
Update the value of the state.
"""
@@ -156,7 +156,9 @@ class StatefulProcessorHandle:
self.stateful_processor_api_client.get_value_state(state_name, schema,
ttl_duration_ms)
return
ValueState(ValueStateClient(self.stateful_processor_api_client), state_name,
schema)
- def getListState(self, state_name: str, schema: Union[StructType, str]) ->
ListState:
+ def getListState(
+ self, state_name: str, schema: Union[StructType, str],
ttl_duration_ms: Optional[int] = None
+ ) -> ListState:
"""
Function to create new or return existing single value state variable
of given type.
The user must ensure to call this function only within the `init()`
method of the
@@ -169,8 +171,13 @@ class StatefulProcessorHandle:
schema : :class:`pyspark.sql.types.DataType` or str
The schema of the state variable. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type
string.
+ ttlDurationMs: int
+ Time to live duration of the state in milliseconds. State values
will not be returned
+ past ttlDuration and will be eventually removed from the state
store. Any state update
+ resets the expiration time to current processing time plus
ttlDuration.
+ If ttl is not specified the state will never expire.
"""
- self.stateful_processor_api_client.get_list_state(state_name, schema)
+ self.stateful_processor_api_client.get_list_state(state_name, schema,
ttl_duration_ms)
return ListState(ListStateClient(self.stateful_processor_api_client),
state_name, schema)
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 2a5e55159e76..449d5a2ad55d 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -131,7 +131,9 @@ class StatefulProcessorApiClient:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing value state: "
f"{response_message[1]}")
- def get_list_state(self, state_name: str, schema: Union[StructType, str])
-> None:
+ def get_list_state(
+ self, state_name: str, schema: Union[StructType, str],
ttl_duration_ms: Optional[int]
+ ) -> None:
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
if isinstance(schema, str):
@@ -140,6 +142,8 @@ class StatefulProcessorApiClient:
state_call_command = stateMessage.StateCallCommand()
state_call_command.stateName = state_name
state_call_command.schema = schema.json()
+ if ttl_duration_ms is not None:
+ state_call_command.ttl.durationMs = ttl_duration_ms
call =
stateMessage.StatefulProcessorCall(getListState=state_call_command)
message = stateMessage.StateRequest(statefulProcessorCall=call)
diff --git
a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
index 99333ae6f5c2..01cd441941d9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py
@@ -221,6 +221,18 @@ class TransformWithStateInPandasTestsMixin:
self._test_transform_with_state_in_pandas_basic(ListStateProcessor(),
check_results, True)
+ # test list state with ttl has the same behavior as list state when state
doesn't expire.
+ def test_transform_with_state_in_pandas_list_state_large_ttl(self):
+ def check_results(batch_df, _):
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+
+ self._test_transform_with_state_in_pandas_basic(
+ ListStateLargeTTLProcessor(), check_results, True, "processingTime"
+ )
+
# test value state with ttl has the same behavior as value state when
# state doesn't expire.
def test_value_state_ttl_basic(self):
@@ -248,8 +260,10 @@ class TransformWithStateInPandasTestsMixin:
[
Row(id="ttl-count-0", count=1),
Row(id="count-0", count=1),
+ Row(id="ttl-list-state-count-0", count=1),
Row(id="ttl-count-1", count=1),
Row(id="count-1", count=1),
+ Row(id="ttl-list-state-count-1", count=1),
],
)
elif batch_id == 1:
@@ -258,21 +272,29 @@ class TransformWithStateInPandasTestsMixin:
[
Row(id="ttl-count-0", count=2),
Row(id="count-0", count=2),
+ Row(id="ttl-list-state-count-0", count=3),
Row(id="ttl-count-1", count=2),
Row(id="count-1", count=2),
+ Row(id="ttl-list-state-count-1", count=3),
],
)
elif batch_id == 2:
# ttl-count-0 expire and restart from count 0.
- # ttl-count-1 get reset in batch 1 and keep the state
+ # The TTL for value state ttl_count_state gets reset in batch
1 because of the
+ # update operation and ttl-count-1 keeps the state.
+ # ttl-list-state-count-0 expire and restart from count 0.
+ # The TTL for list state ttl_list_state gets reset in batch 1
because of the
+ # put operation and ttl-list-state-count-1 keeps the state.
# non-ttl state never expires
assertDataFrameEqual(
batch_df,
[
Row(id="ttl-count-0", count=1),
Row(id="count-0", count=3),
+ Row(id="ttl-list-state-count-0", count=1),
Row(id="ttl-count-1", count=3),
Row(id="count-1", count=3),
+ Row(id="ttl-list-state-count-1", count=7),
],
)
if batch_id == 0 or batch_id == 1:
@@ -362,25 +384,38 @@ class TTLStatefulProcessor(StatefulProcessor):
state_schema = StructType([StructField("value", IntegerType(), True)])
self.ttl_count_state = handle.getValueState("ttl-state", state_schema,
10000)
self.count_state = handle.getValueState("state", state_schema)
+ self.ttl_list_state = handle.getListState("ttl-list-state",
state_schema, 10000)
def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
count = 0
ttl_count = 0
+ ttl_list_state_count = 0
id = key[0]
if self.count_state.exists():
count = self.count_state.get()[0]
if self.ttl_count_state.exists():
ttl_count = self.ttl_count_state.get()[0]
+ if self.ttl_list_state.exists():
+ iter = self.ttl_list_state.get()
+ for s in iter:
+ ttl_list_state_count += s[0]
for pdf in rows:
pdf_count = pdf.count().get("temperature")
count += pdf_count
ttl_count += pdf_count
+ ttl_list_state_count += pdf_count
self.count_state.update((count,))
# skip updating state for the 2nd batch so that ttl state expire
if not (ttl_count == 2 and id == "0"):
self.ttl_count_state.update((ttl_count,))
- yield pd.DataFrame({"id": [f"ttl-count-{id}", f"count-{id}"], "count":
[ttl_count, count]})
+ self.ttl_list_state.put([(ttl_list_state_count,),
(ttl_list_state_count,)])
+ yield pd.DataFrame(
+ {
+ "id": [f"ttl-count-{id}", f"count-{id}",
f"ttl-list-state-count-{id}"],
+ "count": [ttl_count, count, ttl_list_state_count],
+ }
+ )
def close(self) -> None:
pass
@@ -457,6 +492,15 @@ class ListStateProcessor(StatefulProcessor):
pass
+# A stateful processor that inherit all behavior of ListStateProcessor except
that it use
+# ttl state with a large timeout.
+class ListStateLargeTTLProcessor(ListStateProcessor):
+ def init(self, handle: StatefulProcessorHandle) -> None:
+ state_schema = StructType([StructField("temperature", IntegerType(),
True)])
+ self.list_state1 = handle.getListState("listState1", state_schema,
30000)
+ self.list_state2 = handle.getListState("listState2", state_schema,
30000)
+
+
class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin,
ReusedSQLTestCase):
pass
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index d293e7a4a5bb..fed1843acfa5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -189,8 +189,12 @@ class TransformWithStateInPandasStateServer(
case StatefulProcessorCall.MethodCase.GETLISTSTATE =>
val stateName = message.getGetListState.getStateName
val schema = message.getGetListState.getSchema
- // TODO(SPARK-49744): Add ttl support for list state.
- initializeStateVariable(stateName, schema,
StateVariableType.ListState, None)
+ val ttlDurationMs = if (message.getGetListState.hasTtl) {
+ Some(message.getGetListState.getTtl.getDurationMs)
+ } else {
+ None
+ }
+ initializeStateVariable(stateName, schema,
StateVariableType.ListState, ttlDurationMs)
case _ =>
throw new IllegalArgumentException("Invalid method call")
}
@@ -372,10 +376,14 @@ class TransformWithStateInPandasStateServer(
sendResponse(1, s"Value state $stateName already exists")
}
case StateVariableType.ListState => if
(!listStates.contains(stateName)) {
- // TODO(SPARK-49744): Add ttl support for list state.
+ val state = if (ttlDurationMs.isEmpty) {
+ statefulProcessorHandle.getListState[Row](stateName,
Encoders.row(schema))
+ } else {
+ statefulProcessorHandle.getListState(
+ stateName, Encoders.row(schema),
TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
+ }
listStates.put(stateName,
- ListStateInfo(statefulProcessorHandle.getListState[Row](stateName,
- Encoders.row(schema)), schema,
expressionEncoder.createDeserializer(),
+ ListStateInfo(state, schema,
expressionEncoder.createDeserializer(),
expressionEncoder.createSerializer()))
sendResponse(0)
} else {
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index 137e2531f4f4..776772bb51ca 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -118,6 +118,29 @@ class TransformWithStateInPandasStateServerSuite extends
SparkFunSuite with Befo
}
}
+ Seq(true, false).foreach { useTTL =>
+ test(s"get list state, useTTL=$useTTL") {
+ val stateCallCommandBuilder = StateCallCommand.newBuilder()
+ .setStateName("newName")
+ .setSchema("StructType(List(StructField(value,IntegerType,true)))")
+ if (useTTL) {
+
stateCallCommandBuilder.setTtl(StateMessage.TTLConfig.newBuilder().setDurationMs(1000))
+ }
+ val message = StatefulProcessorCall
+ .newBuilder()
+ .setGetListState(stateCallCommandBuilder.build())
+ .build()
+ stateServer.handleStatefulProcessorCall(message)
+ if (useTTL) {
+ verify(statefulProcessorHandle)
+ .getListState[Row](any[String], any[Encoder[Row]], any[TTLConfig])
+ } else {
+ verify(statefulProcessorHandle).getListState[Row](any[String],
any[Encoder[Row]])
+ }
+ verify(outputStream).writeInt(0)
+ }
+ }
+
test("value state exists") {
val message = ValueStateCall.newBuilder().setStateName(stateName)
.setExists(Exists.newBuilder().build()).build()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]