jingz-db commented on code in PR #48124:
URL: https://github.com/apache/spark/pull/48124#discussion_r1859419007


##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -701,6 +722,109 @@ def check_results(batch_df, batch_id):
             SimpleStatefulProcessorWithInitialState(), check_results, 
initial_state
         )
 
+    def _test_transform_with_state_in_pandas_chaining_ops(
+        self, stateful_processor, check_results, timeMode="None", 
grouping_cols=["outputTimestamp"]
+    ):
+        import pyspark.sql.functions as f
+
+        input_path = tempfile.mkdtemp()
+        self._prepare_test_resource1(input_path)
+
+        def prepare_batch1(input_path):
+            with open(input_path + "/text-test3.txt", "w") as fw:
+                fw.write("a, 20\n")
+
+        def prepare_batch2(input_path):
+            with open(input_path + "/text-test4.txt", "w") as fw:
+                fw.write("a, 3\n")
+
+        def prepare_batch3(input_path):
+            with open(input_path + "/text-test1.txt", "w") as fw:
+                fw.write("a, 4\n")
+
+        def prepare_batch4(input_path):
+            with open(input_path + "/text-test2.txt", "w") as fw:
+                fw.write("a, 20\n")
+
+        prepare_batch1(input_path)
+        time.sleep(2)
+        prepare_batch2(input_path)
+        time.sleep(2)
+        prepare_batch3(input_path)
+        time.sleep(2)
+        prepare_batch4(input_path)
+
+        df = self._build_test_df(input_path)
+        df = df.select(
+            "id", 
f.from_unixtime(f.col("temperature")).alias("eventTime").cast("timestamp")
+        ).withWatermark("eventTime", "10 seconds")
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_schema = StructType(
+            [
+                StructField("id", StringType(), True),
+                StructField("outputTimestamp", TimestampType(), True),
+            ]
+        )
+
+        q = (
+            df.groupBy("id")
+            .transformWithStateInPandas(
+                statefulProcessor=stateful_processor,
+                outputStructType=output_schema,
+                outputMode="Append",
+                timeMode=timeMode,
+                eventTimeColumnName="outputTimestamp",
+            )
+            .groupBy(grouping_cols)
+            .count()
+            .writeStream.queryName("chaining_ops_query")
+            .foreachBatch(check_results)
+            .outputMode("append")
+            .start()
+        )
+
+        self.assertEqual(q.name, "chaining_ops_query")
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        q.awaitTermination(10)
+
+    def test_transform_with_state_in_pandas_chaining_ops(self):
+        def check_results(batch_df, batch_id):
+            import datetime
+
+            if batch_id == 0:
+                assert set(
+                    batch_df.sort("outputTimestamp").select("outputTimestamp", 
"count").collect()
+                ) == {
+                    Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 
20), count=1),
+                }
+            elif batch_id == 1:
+                assert set(
+                    batch_df.sort("outputTimestamp").select("outputTimestamp", 
"count").collect()
+                ) == {
+                    Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 
3), count=1),
+                }
+            elif batch_id == 2:
+                # as the late event watermark is 10, eventTime=3 is dropped
+                assert batch_df.isEmpty()
+            else:
+                assert set(
+                    batch_df.sort("outputTimestamp").select("outputTimestamp", 
"count").collect()
+                ) == {
+                    Row(outputTimestamp=datetime.datetime(1970, 1, 1, 0, 0, 
20), count=2),
+                }
+
+        self._test_transform_with_state_in_pandas_chaining_ops(
+            StatefulProcessorChainingOps(), check_results, "eventTime"
+        )
+        self._test_transform_with_state_in_pandas_chaining_ops(

Review Comment:
   Yeah, it is not our main interest, as we are testing that the event-time
watermark is propagated properly for TWS here. I believe we already have
abundant tests for aggregation. I am keeping this for safety because,
previously, without the column-pruning change, `groupby` on a single column
and `groupby` on multiple columns would match different case classes in
`IncrementalExecution`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to