jingz-db commented on code in PR #49156:
URL: https://github.com/apache/spark/pull/49156#discussion_r1889350550
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -907,6 +923,328 @@ def test_transform_with_state_in_pandas_batch_query_initial_state(self):
Row(id="1", value=str(146 + 346)),
}
+ # This test covers mapState with TTL, an empty state variable,
+ # and an additional test against the initial state Python runner
+ def test_transform_with_map_state_metadata(self):
+ checkpoint_path = tempfile.mktemp()
+
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+ else:
+ # check for state metadata source
+ metadata_df = self.spark.read.format("state-metadata").load(checkpoint_path)
+ assert set(
+ metadata_df.select(
+ "operatorId",
+ "operatorName",
+ "stateStoreName",
+ "numPartitions",
+ "minBatchId",
+ "maxBatchId",
+ ).collect()
+ ) == {
+ Row(
+ operatorId=0,
+ operatorName="transformWithStateInPandasExec",
+ stateStoreName="default",
+ numPartitions=5,
+ minBatchId=0,
+ maxBatchId=0,
+ )
+ }
+ operator_properties_json_obj = json.loads(
+ metadata_df.select("operatorProperties").collect()[0][0]
+ )
+ assert operator_properties_json_obj["timeMode"] == "ProcessingTime"
+ assert operator_properties_json_obj["outputMode"] == "Update"
+
+ state_var_list = operator_properties_json_obj["stateVariables"]
+ assert len(state_var_list) == 3
+ for state_var in state_var_list:
+ if state_var["stateName"] == "mapState":
+ assert state_var["stateVariableType"] == "MapState"
+ assert state_var["ttlEnabled"]
+ elif state_var["stateName"] == "listState":
+ assert state_var["stateVariableType"] == "ListState"
+ assert not state_var["ttlEnabled"]
+ else:
+ assert state_var["stateName"] == "$procTimers_keyToTimestamp"
+ assert state_var["stateVariableType"] == "TimerState"
+
+ # check for state data source
+ map_state_df = (
+ self.spark.read.format("statestore")
+ .option("path", checkpoint_path)
+ .option("stateVarName", "mapState")
+ .load()
+ )
+ assert set(
+ map_state_df.selectExpr(
+ "key.id AS groupingKey",
+ "user_map_key.name AS mapKey",
+ "user_map_value.value.count AS mapValue",
+ )
+ .sort("groupingKey")
+ .collect()
+ ) == {
+ Row(groupingKey="0", mapKey="key2", mapValue=2),
+ Row(groupingKey="1", mapKey="key2", mapValue=2),
+ }
+
+ ttl_df = map_state_df.selectExpr(
+ "user_map_value.ttlExpirationMs AS TTLVal"
+ ).collect()
+ # check if there are two rows containing TTL value in map state dataframe
+ assert len(ttl_df) == 2
+ # check if two rows are of the same TTL value
+ assert len(set(ttl_df)) == 1
+
+ list_state_df = (
+ self.spark.read.format("statestore")
+ .option("path", checkpoint_path)
+ .option("stateVarName", "listState")
+ .load()
+ )
+ assert list_state_df.isEmpty()
+
+ for q in self.spark.streams.active:
+ q.stop()
+
+ self._test_transform_with_state_in_pandas_basic(
+ MapStateLargeTTLProcessor(),
+ check_results,
+ True,
+ "processingTime",
+ checkpoint_path=checkpoint_path,
+ initial_state=None,
+ )
+
+ # run the same test suite again but with a no-op initial state;
+ # TWS with initial state uses a different Python runner
+ init_data = [("0", 789), ("3", 987)]
+ initial_state = self.spark.createDataFrame(init_data, "id string, temperature int").groupBy(
+ "id"
+ )
+ self._test_transform_with_state_in_pandas_basic(
+ MapStateLargeTTLProcessor(),
+ check_results,
+ True,
+ "processingTime",
+ checkpoint_path=checkpoint_path,
+ initial_state=initial_state,
+ )
+
+ # This test covers multiple list state variables and flatten option
+ def test_transform_with_list_state_metadata(self):
+ checkpoint_path = tempfile.mktemp()
+
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+ else:
+ # check for state metadata source
+ metadata_df = self.spark.read.format("state-metadata").load(checkpoint_path)
+ operator_properties_json_obj = json.loads(
+ metadata_df.select("operatorProperties").collect()[0][0]
+ )
+ state_var_list = operator_properties_json_obj["stateVariables"]
+ assert len(state_var_list) == 3
+ for state_var in state_var_list:
+ if state_var["stateName"] in ["listState1", "listState2"]:
+ state_var["stateVariableType"] == "ListState"
+ else:
+ assert state_var["stateName"] == "$procTimers_keyToTimestamp"
+ assert state_var["stateVariableType"] == "TimerState"
+
+ # check for state data source and flatten option
+ list_state_1_df = (
+ self.spark.read.format("statestore")
+ .option("path", checkpoint_path)
+ .option("stateVarName", "listState1")
+ .option("flattenCollectionTypes", True)
+ .load()
+ )
+ assert list_state_1_df.selectExpr(
+ "key.id AS groupingKey",
+ "list_element.temperature AS listElement",
+ ).sort("groupingKey", "listElement").collect() == [
+ Row(groupingKey="0", listElement=20),
+ Row(groupingKey="0", listElement=20),
+ Row(groupingKey="0", listElement=111),
+ Row(groupingKey="0", listElement=120),
+ Row(groupingKey="0", listElement=120),
+ Row(groupingKey="1", listElement=20),
+ Row(groupingKey="1", listElement=20),
+ Row(groupingKey="1", listElement=111),
+ Row(groupingKey="1", listElement=120),
+ Row(groupingKey="1", listElement=120),
+ ]
+
+ list_state_2_df = (
+ self.spark.read.format("statestore")
+ .option("path", checkpoint_path)
+ .option("stateVarName", "listState2")
+ .option("flattenCollectionTypes", False)
+ .load()
+ )
+ assert list_state_2_df.selectExpr(
+ "key.id AS groupingKey", "list_value.temperature AS
valueList"
+ ).sort("groupingKey").withColumn(
+ "valueSortedList", array_sort(col("valueList"))
+ ).select(
+ "groupingKey", "valueSortedList"
+ ).collect() == [
+ Row(groupingKey="0", valueSortedList=[20, 20, 120, 120,
222]),
+ Row(groupingKey="1", valueSortedList=[20, 20, 120, 120,
222]),
+ ]
+
+ for q in self.spark.streams.active:
+ q.stop()
+
+ self._test_transform_with_state_in_pandas_basic(
+ ListStateProcessor(),
+ check_results,
+ True,
+ "processingTime",
+ checkpoint_path=checkpoint_path,
+ initial_state=None,
+ )
+
+ # This test covers value state variable and read change feed
+ def test_transform_with_value_state_metadata(self):
+ checkpoint_path = tempfile.mktemp()
+
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="2"),
+ Row(id="1", countAsString="2"),
+ }
+ else:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="1", countAsString="2"),
+ }
+
+ # check for state metadata source
+ metadata_df = self.spark.read.format("state-metadata").load(checkpoint_path)
+ operator_properties_json_obj = json.loads(
+ metadata_df.select("operatorProperties").collect()[0][0]
+ )
+ state_var_list = operator_properties_json_obj["stateVariables"]
+
+ # TODO "tempState" should not appear in the metadata as it is
already deleted
Review Comment:
I personally feel this should not appear in the metadata if we call
`deleteIfExists` on "tempState". Currently, `DriverStatefulProcessorHandleImpl`
does nothing when `deleteIfExists` is called:
https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala#L493.
Shall we make the change to remove the deleted variable from the tracking map
used for metadata?
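
For illustration only, a rough sketch of what I have in mind. The
`stateVariableInfos` map is an assumed name for whatever structure the
driver-side handle uses to track declared state variables for metadata, not
the actual field:

```scala
// Hypothetical sketch, not the real implementation: assumes the driver-side
// handle tracks declared state variables in a mutable map keyed by state name,
// and that this map is what gets serialized into operatorProperties.
override def deleteIfExists(stateName: String): Unit = {
  // Remove the variable from the tracking map so a deleted state
  // (e.g. "tempState") no longer shows up in the state-metadata source.
  stateVariableInfos.remove(stateName)
}
```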