jiateoh commented on code in PR #52536:
URL: https://github.com/apache/spark/pull/52536#discussion_r2412044093
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -1595,6 +1599,98 @@ def check_results(batch_df, batch_id):
StatefulProcessorCompositeTypeFactory(), check_results,
output_schema=output_schema
)
+    # run a test with composite types where the output of TWS (not just the states) is complex.
+ def test_composite_output_schema(self):
+ def check_results(batch_df, batch_id):
+ batch_df.collect()
+            # Cannot use set() wrapper because Row objects contain unhashable types (lists)
+ if batch_id == 0:
+ assert batch_df.sort("primitiveValue").collect() == [
+ Row(
+ primitiveValue="key_0_count_2",
+ listOfPrimitive=["item_0", "item_1"],
+ mapOfPrimitive={"key0": "value0", "key1": "value1"},
+ listOfComposite=[
+ Row(intValue=0, doubleValue=0.0),
+ Row(intValue=1, doubleValue=1.5)
+ ],
+ mapOfComposite={
+ "nested_key0": Row(intValue=0, doubleValue=0.0),
+ "nested_key1": Row(intValue=10, doubleValue=2.5)
+ }
+ ),
+ Row(
+ primitiveValue="key_1_count_2",
+ listOfPrimitive=["item_0", "item_1"],
+ mapOfPrimitive={"key0": "value0", "key1": "value1"},
+ listOfComposite=[
+ Row(intValue=0, doubleValue=0.0),
+ Row(intValue=1, doubleValue=1.5)
+ ],
+ mapOfComposite={
+ "nested_key0": Row(intValue=0, doubleValue=0.0),
+ "nested_key1": Row(intValue=10, doubleValue=2.5)
+ }
+ )
+ ]
+ else:
+ assert batch_df.sort("primitiveValue").collect() == [
+ Row(
+ primitiveValue="key_0_count_5",
+                        listOfPrimitive=["item_0", "item_1", "item_2", "item_3", "item_4"],
+                        mapOfPrimitive={"key0": "value0", "key1": "value1", "key2": "value2", "key3": "value3", "key4": "value4"},
+ listOfComposite=[
+ Row(intValue=0, doubleValue=0.0),
+ Row(intValue=1, doubleValue=1.5),
+ Row(intValue=2, doubleValue=3.0),
+ Row(intValue=3, doubleValue=4.5),
+ Row(intValue=4, doubleValue=6.0)
+ ],
+ mapOfComposite={
+ "nested_key0": Row(intValue=0, doubleValue=0.0),
+ "nested_key1": Row(intValue=10, doubleValue=2.5),
+ "nested_key2": Row(intValue=20, doubleValue=5.0),
+ "nested_key3": Row(intValue=30, doubleValue=7.5),
+ "nested_key4": Row(intValue=40, doubleValue=10.0)
+ }
+ ),
+ Row(
+ primitiveValue="key_1_count_4",
+                        listOfPrimitive=["item_0", "item_1", "item_2", "item_3"],
+                        mapOfPrimitive={"key0": "value0", "key1": "value1", "key2": "value2", "key3": "value3"},
+ listOfComposite=[
+ Row(intValue=0, doubleValue=0.0),
+ Row(intValue=1, doubleValue=1.5),
+ Row(intValue=2, doubleValue=3.0),
+ Row(intValue=3, doubleValue=4.5)
+ ],
+ mapOfComposite={
+ "nested_key0": Row(intValue=0, doubleValue=0.0),
+ "nested_key1": Row(intValue=10, doubleValue=2.5),
+ "nested_key2": Row(intValue=20, doubleValue=5.0),
+ "nested_key3": Row(intValue=30, doubleValue=7.5)
+ }
+ )
+ ]
+
+ # Define the output schema matching Scala case class
+ inner_nested_class_schema = StructType([
Review Comment:
Lost this commit due to a rebase/force push, but I've added both an array
and map type into the inner nested class (which is in turn used in the array
and map outputs)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]