jingz-db commented on code in PR #47878:
URL: https://github.com/apache/spark/pull/47878#discussion_r1815811970


##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -336,6 +341,263 @@ def check_results(batch_df, batch_id):
         finally:
             input_dir.cleanup()
 
+    def _test_transform_with_state_in_pandas_proc_timer(
+            self, stateful_processor, check_results):
+        input_path = tempfile.mkdtemp()
+        self._prepare_test_resource3(input_path)
+        self._prepare_test_resource1(input_path)
+        self._prepare_test_resource2(input_path)
+
+        df = self._build_test_df(input_path)
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("countAsString", StringType(), True),
+                StructField("timeValues", StringType(), True)
+            ]
+        )
+
+        query_name = "processing_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="processingtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_proc_timer(self):
+
+        # helper function to check that the expired timestamp is smaller than
+        # the current processing time
+        def check_timestamp(batch_df):
+            expired_df = batch_df.filter(batch_df["countAsString"] == "-1") \
+                .select("id", "timeValues").withColumnRenamed("timeValues", 
"expiredTimestamp")
+            count_df = batch_df.filter(batch_df["countAsString"] != "-1") \
+                .select("id", "timeValues").withColumnRenamed("timeValues", 
"countStateTimestamp")
+            joined_df = expired_df.join(count_df, on="id")
+            for row in joined_df.collect():
+                assert row["expiredTimestamp"] < row["countStateTimestamp"]
+
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="1"),
+                    Row(id="1", countAsString="1"),
+                }
+            elif batch_id == 1:
+                # for key 0, the accumulated count is emitted before the count
+                # state is cleared during the timer process
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="3"),
+                }
+                self.first_expired_timestamp = \
+                    batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+                check_timestamp(batch_df)
+
+            else:
+                assert set(batch_df.sort("id").select("id", 
"countAsString").collect()) == {
+                    Row(id="0", countAsString="3"),
+                    Row(id="0", countAsString="-1"),
+                    Row(id="1", countAsString="5")
+                }
+                # The expired timestamp in the current batch is larger than the expiry
+                # timestamp in batch 1 because this is a new timer registered in batch 1,
+                # different from the one registered in batch 0
+                current_batch_expired_timestamp = \
+                    batch_df.filter(batch_df["countAsString"] == -1).first()['timeValues']
+                assert current_batch_expired_timestamp > self.first_expired_timestamp
+
+        self._test_transform_with_state_in_pandas_proc_timer(
+            ProcTimeStatefulProcessor(), check_results)
+
+    def _test_transform_with_state_in_pandas_event_time(self, stateful_processor, check_results):
+        import pyspark.sql.functions as f
+
+        input_path = tempfile.mkdtemp()
+
+        def prepare_batch1(input_path):
+            with open(input_path + "/text-test3.txt", "w") as fw:
+                fw.write("a, 20\n")
+
+        def prepare_batch2(input_path):
+            with open(input_path + "/text-test1.txt", "w") as fw:
+                fw.write("a, 4\n")
+
+        def prepare_batch3(input_path):
+            with open(input_path + "/text-test2.txt", "w") as fw:
+                fw.write("a, 11\n")
+                fw.write("a, 13\n")
+                fw.write("a, 15\n")
+
+        prepare_batch1(input_path)
+        prepare_batch2(input_path)
+        prepare_batch3(input_path)
+
+        df = self._build_test_df(input_path)
+        df = df.select("id",
+                       f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")) \
+            .withWatermark("eventTime", "10 seconds")
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("timestamp", StringType(), True)
+            ]
+        )
+
+        query_name = "event_time_test_query"
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Update",
+                timeMode="eventtime",
+            )
+            .writeStream.queryName(query_name)
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, query_name)
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+        self.assertTrue(q.exception() is None)
+
+    def test_transform_with_state_in_pandas_event_time(self):
+        def check_results(batch_df, batch_id):
+            if batch_id == 0:
+                assert set(batch_df.sort("id").collect()) == {
+                    Row(id="a", timestamp="20")
+                }
+            elif batch_id == 1:
+                # event time = 4 will be discarded because the watermark = 15 - 10 = 5

Review Comment:
   Thanks for leaving the comments! Reading your comments, I realized I had not quite understood the difference between the watermark for eviction and the watermark for late records before.
   The test case should still be fine; I just deleted the comments. Dropping late records will be tested more thoroughly in the operator-chaining PR.
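   As an aside, here is a minimal sketch (not part of this PR) of how I now understand the two thresholds; the rate source and the `eventTime` column name are illustrative assumptions only:
   ```python
   from pyspark.sql import SparkSession
   import pyspark.sql.functions as f

   spark = SparkSession.builder.appName("watermark_sketch").getOrCreate()

   # Streaming source: the built-in rate source emits a `timestamp` column we can watermark on.
   events = (
       spark.readStream.format("rate").load()
       .withColumn("eventTime", f.col("timestamp"))
       .withWatermark("eventTime", "10 seconds")  # 10-second watermark delay
   )

   # As I understand it, the same watermark column drives two thresholds within a micro-batch:
   #  * filtering of late records compares rows against the watermark carried over
   #    from the previous micro-batch, so a record may still be accepted even though
   #    it is older than the value used for eviction;
   #  * eviction of state/timers uses the watermark updated for the current micro-batch.
   counts = events.groupBy(f.window("eventTime", "10 seconds")).count()
   ```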



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

