funrollloops commented on code in PR #56464:
URL: https://github.com/apache/spark/pull/56464#discussion_r3461827665


##########
python/pyspark/worker.py:
##########
@@ -3373,37 +3349,147 @@ def map_batch(batch):
         return func, None, ser, ser
 
     if eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
-        # We assume there is only one UDF here because grouped map doesn't
-        # support combining multiple UDFs.
-        assert num_udfs == 1
+        import pyarrow as pa
+        import pandas as pd
+
+        assert num_udfs == 1, "One TRANSFORM_WITH_STATE_PANDAS UDF expected here."
+        udf, arg_offsets, return_type = udfs[0]
 
         # See TransformWithStateInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
-        ser.key_offsets = parsed_offsets[0][0]
+        assert len(parsed_offsets) == 1, (
+            "Expected one pair of offsets for TRANSFORM_WITH_STATE_PANDAS UDF."
+        )
+
+        key_offsets = parsed_offsets[0][0]
+        value_offsets = parsed_offsets[0][1]
+        output_schema = StructType([StructField("_0", return_type)])
+
         stateful_processor_api_client = StatefulProcessorApiClient(
             eval_conf.state_server_socket_port, eval_conf.grouping_key_schema
         )
 
-        def mapper(a):
-            mode = a[0]
+        arrow_max_records_per_batch = runner_conf.arrow_max_records_per_batch
+        arrow_max_records_per_batch = (
+            arrow_max_records_per_batch if arrow_max_records_per_batch > 0 else 2**31 - 1
+        )
+        arrow_max_bytes_per_batch = runner_conf.arrow_max_bytes_per_batch
 
-            if mode == TransformWithStateInPandasFuncMode.PROCESS_DATA:
-                key = a[1]
+        def transform_with_state_func(
+            split_index: int,
+            batches: Iterator[pa.RecordBatch],
+        ) -> Iterator[pa.RecordBatch]:
+            """Apply transformWithStateInPandas UDF.
+
+            Data chunks for the same grouping key appear sequentially in the
+            input batches but may span batch boundaries, so rows are regrouped
+            by key and re-chunked into pandas DataFrames bounded by
+            arrow_max_records_per_batch and arrow_max_bytes_per_batch. The UDF
+            is invoked once per grouping key with a lazy iterator of chunks,
+            then once for PROCESS_TIMER and once for COMPLETE.
+            """
+            total_bytes = 0
+            total_rows = 0
+            average_arrow_row_size = 0.0
+
+            def row_stream():
+                nonlocal total_bytes, total_rows, average_arrow_row_size
+                for batch in batches:
+                    # Short circuit batch size stats if the batch size is
+                    # unlimited as computing batch size is computationally
+                    # expensive.
+                    if arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+                        total_bytes += sum(
+                            buf.size
+                            for col in batch.columns
+                            for buf in col.buffers()
+                            if buf is not None
+                        )
+                        total_rows += batch.num_rows
+                        average_arrow_row_size = total_bytes / total_rows
+                    data_pandas = ArrowBatchTransformer.to_pandas(
+                        batch,
+                        timezone=runner_conf.timezone,
+                        prefer_int_ext_dtype=runner_conf.prefer_int_ext_dtype,

Review Comment:
   We're passing far fewer arguments than in 
python/pyspark/sql/pandas/serializers.py:1037; why?
   
   In particular, we're missing schema, struct_in_pandas, ndarray_as_list, and 
df_for_struct.



##########
python/pyspark/worker.py:
##########
@@ -3373,37 +3349,147 @@ def map_batch(batch):
         return func, None, ser, ser
 
     if eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
-        # We assume there is only one UDF here because grouped map doesn't
-        # support combining multiple UDFs.
-        assert num_udfs == 1
+        import pyarrow as pa
+        import pandas as pd
+
+        assert num_udfs == 1, "One TRANSFORM_WITH_STATE_PANDAS UDF expected here."
+        udf, arg_offsets, return_type = udfs[0]
 
         # See TransformWithStateInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, f = udfs[0]
         parsed_offsets = extract_key_value_indexes(arg_offsets)
-        ser.key_offsets = parsed_offsets[0][0]
+        assert len(parsed_offsets) == 1, (
+            "Expected one pair of offsets for TRANSFORM_WITH_STATE_PANDAS UDF."
+        )
+
+        key_offsets = parsed_offsets[0][0]
+        value_offsets = parsed_offsets[0][1]
+        output_schema = StructType([StructField("_0", return_type)])

Review Comment:
   Did we previously wrap the returned value with a struct of one field? I 
can't find that logic anywhere in the touched code.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to