asl3 commented on code in PR #51225: URL: https://github.com/apache/spark/pull/51225#discussion_r2217028842
########## python/pyspark/sql/pandas/serializers.py: ########## @@ -194,6 +203,190 @@ class ArrowStreamUDTFSerializer(ArrowStreamUDFSerializer): def load_stream(self, stream): return ArrowStreamSerializer.load_stream(self, stream) +class ArrowBatchUDFSerializer(ArrowStreamUDFSerializer): + """ + Serializer used by Python worker to evaluate Arrow UDFs + """ + + def __init__( + self, + assign_cols_by_name, + input_types, + struct_in_pandas="row", + ndarray_as_list=True, + return_type=None, + prefers_large_var_types=False, + ): + super(ArrowBatchUDFSerializer, self).__init__() + self._assign_cols_by_name = assign_cols_by_name + self._input_types = input_types + self._struct_in_pandas = struct_in_pandas + self._ndarray_as_list = ndarray_as_list + self._return_type = return_type + self._prefers_large_var_types = prefers_large_var_types + + def convert_arrow_to_rows(self, *args): + """ + Convert Arrow arrays to rows + """ + import pyarrow as pa + + arrays = [ + arg.combine_chunks() if isinstance(arg, pa.ChunkedArray) else arg + for arg in args + ] + + names = [f"_{n}" for n in range(len(arrays))] + table = pa.Table.from_arrays(arrays, names=names) + schema = StructType([ + StructField(f"_{i}", data_type, True) + for i, data_type in enumerate(self._input_types) + ]) + + rows = ArrowTableToRowsConversion.convert(table, schema=schema) + + return [tuple(row) for row in rows] + + def load_stream(self, stream): + """ + Load Arrow RecordBatches and convert to Python objects. + """ + import pyarrow as pa + import pyarrow.types as types + + def is_packed_udf_arguments(struct_type): + return (hasattr(struct_type, 'metadata') and + struct_type.metadata is not None and + b'__packed_udf_args__' in struct_type.metadata) + + batches = ArrowStreamSerializer.load_stream(self, stream) + for batch in batches: + if (batch.num_columns == 1 and + types.is_struct(batch.column(0).type) and + is_packed_udf_arguments(batch.column(0).type)): + # Packed UDF arguments from ArrowStreamUDFSerializer + # Flatten them back to individual arrays + first_column = batch.column(0) + flattened_batch = pa.RecordBatch.from_arrays(first_column.flatten(), schema=pa.schema(first_column.type)) + arrays = [flattened_batch.column(i) for i in range(flattened_batch.num_columns)] + else: + # Else data from JVM, preserve structure + arrays = [batch.column(i) for i in range(batch.num_columns)] + + if len(arrays) == 0: + # Zero-arg case: create empty tuples for each row in the batch + converted_rows = [() for _ in range(batch.num_rows)] + else: + converted_rows = self.convert_arrow_to_rows(*arrays) + + yield converted_rows + + def dump_stream(self, iterator, stream): + """ + Convert Python UDF results to Arrow RecordBatches and serialize. + """ + import pyarrow as pa + + def wrap_and_init_stream(): + should_write_start_length = True + + for i, batch_data in enumerate(iterator): + if isinstance(batch_data, tuple) and len(batch_data) == 3: + # Single UDF case + udf_results, arrow_return_type, return_type = batch_data + arr = self._convert_udf_results_to_arrow(udf_results, return_type) + batch = pa.RecordBatch.from_arrays([arr], ["_0"]) + else: + # Multiple UDFs case + arrs = [] + for j, (udf_results, arrow_return_type, return_type) in enumerate(batch_data): + arr = self._convert_udf_results_to_arrow(udf_results, return_type) + arrs.append(arr) + + batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, wrap_and_init_stream(), stream) + + def _convert_udf_results_to_arrow(self, udf_results, return_type): + """ + Convert UDF data to Arrow with LocalDataToArrowConversion. + """ + + coerced_results = self._apply_type_coercion_for_type(udf_results, return_type) + + temp_struct_type = StructType([StructField("result", return_type)]) + temp_data = [{"result": result} for result in coerced_results] + arrow_table = LocalDataToArrowConversion.convert( + temp_data, temp_struct_type, self._prefers_large_var_types + ) + return arrow_table.column(0).combine_chunks() + + def _apply_type_coercion_for_type(self, udf_results, return_type): Review Comment: I added this to match the legacy path behavior, without it this test `test_type_coercion_string_to_numeric` fails -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org