HeartSaVioR commented on code in PR #48005:
URL: https://github.com/apache/spark/pull/48005#discussion_r1828618048
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -536,6 +551,167 @@ def check_results(batch_df, batch_id):
EventTimeStatefulProcessor(), check_results
)
+ def _test_transform_with_state_init_state_in_pandas(self, stateful_processor, check_results):
+ input_path = tempfile.mkdtemp()
+ self._prepare_test_resource1(input_path)
+ self._prepare_input_data(input_path + "/text-test2.txt", [0, 3], [67, 12])
+
+ df = self._build_test_df(input_path)
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("value", StringType(), True),
+ ]
+ )
+
+ data = [("0", 789), ("3", 987)]
+ initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")
+
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="None",
+ initialState=initial_state,
+ )
+ .writeStream.queryName("this_query")
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, "this_query")
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_init_state_in_pandas(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ # for key 0, initial state was processed and it was only processed once;
+ # for key 1, it did not appear in the initial state df;
+ # for key 3, it did not appear in the first batch of input keys
+ # so it won't be emitted
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", value=str(789 + 123 + 46)),
+ Row(id="1", value=str(146 + 346)),
+ }
+ else:
+ # for key 0, verify initial state was only processed once in the first batch;
+ # for key 3, verify init state was processed and reflected in the accumulated value
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="0", value=str(789 + 123 + 46 + 67)),
+ Row(id="3", value=str(987 + 12)),
+ }
+
+ self._test_transform_with_state_init_state_in_pandas(
+ SimpleStatefulProcessorWithInitialState(), check_results
+ )
+
+ def _test_transform_with_state_non_contiguous_grouping_cols(
Review Comment:
Shall we have the same test (non-contiguous grouping keys) for the path of
initial state as well, for completeness' sake?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]