bogao007 commented on code in PR #52536:
URL: https://github.com/apache/spark/pull/52536#discussion_r2411935951


##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -1595,6 +1599,98 @@ def check_results(batch_df, batch_id):
             StatefulProcessorCompositeTypeFactory(), check_results, 
output_schema=output_schema
         )
 
+    # run a test with composite types where the output of TWS (not just the 
states) is complex.
+    def test_composite_output_schema(self):
+        def check_results(batch_df, batch_id):
+            batch_df.collect()
+            # Cannot use set() wrapper because Row objects contain unhashable 
types (lists)
+            if batch_id == 0:
+                assert batch_df.sort("primitiveValue").collect() == [
+                    Row(
+                        primitiveValue="key_0_count_2",
+                        listOfPrimitive=["item_0", "item_1"],
+                        mapOfPrimitive={"key0": "value0", "key1": "value1"},
+                        listOfComposite=[
+                            Row(intValue=0, doubleValue=0.0),
+                            Row(intValue=1, doubleValue=1.5)
+                        ],
+                        mapOfComposite={
+                            "nested_key0": Row(intValue=0, doubleValue=0.0),
+                            "nested_key1": Row(intValue=10, doubleValue=2.5)
+                        }
+                    ),
+                    Row(
+                        primitiveValue="key_1_count_2",
+                        listOfPrimitive=["item_0", "item_1"],
+                        mapOfPrimitive={"key0": "value0", "key1": "value1"},
+                        listOfComposite=[
+                            Row(intValue=0, doubleValue=0.0),
+                            Row(intValue=1, doubleValue=1.5)
+                        ],
+                        mapOfComposite={
+                            "nested_key0": Row(intValue=0, doubleValue=0.0),
+                            "nested_key1": Row(intValue=10, doubleValue=2.5)
+                        }
+                    )
+                ]
+            else:
+                assert batch_df.sort("primitiveValue").collect() == [
+                    Row(
+                        primitiveValue="key_0_count_5",
+                        listOfPrimitive=["item_0", "item_1", "item_2", 
"item_3", "item_4"],
+                        mapOfPrimitive={"key0": "value0", "key1": "value1", 
"key2": "value2", "key3": "value3", "key4": "value4"},
+                        listOfComposite=[
+                            Row(intValue=0, doubleValue=0.0),
+                            Row(intValue=1, doubleValue=1.5),
+                            Row(intValue=2, doubleValue=3.0),
+                            Row(intValue=3, doubleValue=4.5),
+                            Row(intValue=4, doubleValue=6.0)
+                        ],
+                        mapOfComposite={
+                            "nested_key0": Row(intValue=0, doubleValue=0.0),
+                            "nested_key1": Row(intValue=10, doubleValue=2.5),
+                            "nested_key2": Row(intValue=20, doubleValue=5.0),
+                            "nested_key3": Row(intValue=30, doubleValue=7.5),
+                            "nested_key4": Row(intValue=40, doubleValue=10.0)
+                        }
+                    ),
+                    Row(
+                        primitiveValue="key_1_count_4",
+                        listOfPrimitive=["item_0", "item_1", "item_2", 
"item_3"],
+                        mapOfPrimitive={"key0": "value0", "key1": "value1", 
"key2": "value2", "key3": "value3"},
+                        listOfComposite=[
+                            Row(intValue=0, doubleValue=0.0),
+                            Row(intValue=1, doubleValue=1.5),
+                            Row(intValue=2, doubleValue=3.0),
+                            Row(intValue=3, doubleValue=4.5)
+                        ],
+                        mapOfComposite={
+                            "nested_key0": Row(intValue=0, doubleValue=0.0),
+                            "nested_key1": Row(intValue=10, doubleValue=2.5),
+                            "nested_key2": Row(intValue=20, doubleValue=5.0),
+                            "nested_key3": Row(intValue=30, doubleValue=7.5)
+                        }
+                    )
+                ]
+
+        # Define the output schema matching Scala case class
+        inner_nested_class_schema = StructType([

Review Comment:
   Can we add an Array type to the inner schema of the map state?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to