jingz-db commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1815811970
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
finally:
input_dir.cleanup()
+ def _test_transform_with_state_in_pandas_proc_timer(
+ self, stateful_processor, check_results):
+ input_path = tempfile.mkdtemp()
+ self._prepare_test_resource3(input_path)
+ self._prepare_test_resource1(input_path)
+ self._prepare_test_resource2(input_path)
+
+ df = self._build_test_df(input_path)
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("countAsString", StringType(), True),
+ StructField("timeValues", StringType(), True)
+ ]
+ )
+
+ query_name = "processing_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="processingtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_proc_timer(self):
+
+ # helper function to check expired timestamp is smaller than current processing time
+ def check_timestamp(batch_df):
+ expired_df = batch_df.filter(batch_df["countAsString"] == "-1") \
+ .select("id", "timeValues").withColumnRenamed("timeValues", "expiredTimestamp")
+ count_df = batch_df.filter(batch_df["countAsString"] != "-1") \
+ .select("id", "timeValues").withColumnRenamed("timeValues", "countStateTimestamp")
+ joined_df = expired_df.join(count_df, on="id")
+ for row in joined_df.collect():
+ assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="1"),
+ Row(id="1", countAsString="1"),
+ }
+ elif batch_id == 1:
+ # for key 0, the accumulated count is emitted before the count state is cleared
+ # during the timer process
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="3"),
+ }
+ self.first_expired_timestamp = \
+ batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+ check_timestamp(batch_df)
+
+ else:
+ assert set(batch_df.sort("id").select("id", "countAsString").collect()) == {
+ Row(id="0", countAsString="3"),
+ Row(id="0", countAsString="-1"),
+ Row(id="1", countAsString="5")
+ }
+ # The expired timestamp in current batch is larger than expiry timestamp in batch 1
+ # because this is a new timer registered in batch1 and
+ # different from the one registered in batch 0
+ current_batch_expired_timestamp = \
+ batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+ assert(current_batch_expired_timestamp > self.first_expired_timestamp)
+
+ self._test_transform_with_state_in_pandas_proc_timer(ProcTimeStatefulProcessor(), check_results)
+
+ def _test_transform_with_state_in_pandas_event_time(self, stateful_processor, check_results):
+ import pyspark.sql.functions as f
+
+ input_path = tempfile.mkdtemp()
+
+ def prepare_batch1(input_path):
+ with open(input_path + "/text-test3.txt", "w") as fw:
+ fw.write("a, 20\n")
+
+ def prepare_batch2(input_path):
+ with open(input_path + "/text-test1.txt", "w") as fw:
+ fw.write("a, 4\n")
+
+ def prepare_batch3(input_path):
+ with open(input_path + "/text-test2.txt", "w") as fw:
+ fw.write("a, 11\n")
+ fw.write("a, 13\n")
+ fw.write("a, 15\n")
+
+ prepare_batch1(input_path)
+ prepare_batch2(input_path)
+ prepare_batch3(input_path)
+
+ df = self._build_test_df(input_path)
+ df = df.select("id",
+ f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")) \
+ .withWatermark("eventTime", "10 seconds")
+
+ for q in self.spark.streams.active:
+ q.stop()
+ self.assertTrue(df.isStreaming)
+
+ output_schema = StructType(
+ [
+ StructField("id", StringType(), True),
+ StructField("timestamp", StringType(), True)
+ ]
+ )
+
+ query_name = "event_time_test_query"
+ q = (
+ df.groupBy("id")
+ .transformWithStateInPandas(
+ statefulProcessor=stateful_processor,
+ outputStructType=output_schema,
+ outputMode="Update",
+ timeMode="eventtime",
+ )
+ .writeStream.queryName(query_name)
+ .foreachBatch(check_results)
+ .outputMode("update")
+ .start()
+ )
+
+ self.assertEqual(q.name, query_name)
+ self.assertTrue(q.isActive)
+ q.processAllAvailable()
+ q.awaitTermination(10)
+ self.assertTrue(q.exception() is None)
+
+ def test_transform_with_state_in_pandas_event_time(self):
+ def check_results(batch_df, batch_id):
+ if batch_id == 0:
+ assert set(batch_df.sort("id").collect()) == {
+ Row(id="a", timestamp="20")
+ }
+ elif batch_id == 1:
+ # event time = 4 will be discarded because the watermark = 15 - 10 = 5
Review Comment:
Thanks for leaving the comments! Reading them, I realized I had not properly understood the difference between the watermark used for eviction and the watermark used for filtering late records.
The test case itself should still be fine; I have just deleted the comments. Dropping late records will be tested more thoroughly in the operator-chaining PR.
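For context, here is a minimal sketch of the distinction (not from this PR; the `rate` source and column names are made up for illustration). My understanding is that a single `withWatermark` column drives both roles: within a micro-batch, the late-record check is based on the watermark carried over from the previous batch, while state/timer eviction uses the watermark advanced at the end of the current batch.

```python
# Illustration only: one watermark column, two roles (late-record filtering vs. eviction).
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()

# Hypothetical stream with an event-time column derived from the rate source.
events = (
    spark.readStream.format("rate").option("rowsPerSecond", 1).load()
    .withColumn("eventTime", f.col("timestamp"))
    # The 10-second watermark is used both to drop records that arrive behind
    # the late-record watermark and to decide when windows/state/timers expire
    # (the eviction watermark).
    .withWatermark("eventTime", "10 seconds")
)

# Windowed count: state for a window is evicted once the eviction watermark
# passes the window end; records behind the late-record watermark are dropped
# before they ever update that state.
counts = events.groupBy(f.window("eventTime", "5 seconds")).count()

query = counts.writeStream.outputMode("update").format("console").start()
```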
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]