HyukjinKwon opened a new pull request, #39817: URL: https://github.com/apache/spark/pull/39817
### What changes were proposed in this pull request?

This PR is, in a sense, a followup of https://github.com/apache/spark/pull/37734; it handles the case when the batch is a single scalar value. Essentially, it proposes to work around the pandas behaviour below by explicitly casting the result back to the original NumPy data type (a short sketch of this cast is included at the end of this description):

```python
>>> import numpy as np
>>> import pandas as pd
>>> np.squeeze(np.array([1.])).dtype
dtype('float64')
>>> pd.Series(np.squeeze(np.array([1.]))).dtype
dtype('O')
>>> pd.Series(np.squeeze(np.array([1., 1.]))).dtype
dtype('float64')
```

### Why are the changes needed?

Using `predict_batch_udf` fails when a batch happens to contain a single value. For example, even with the batch size set to 10, a dataset of 21 rows fails because the last batch consists of a single value.

```python
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField
from typing import Mapping

df = spark.createDataFrame([[[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0]]], schema=["t1", "t2"])

def make_multi_sum_fn():
    def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
        return np.sum(x1, axis=1) + np.sum(x2, axis=1)
    return predict

multi_sum_udf = predict_batch_udf(
    make_multi_sum_fn,
    return_type=FloatType(),
    batch_size=1,
    input_tensor_shapes=[[4], [3]],
)

df.select(multi_sum_udf("t1", "t2")).collect()
```

**Before**

```
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 829, in main
    process()
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 821, in process
    serializer.dump_stream(out_iter, outfile)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 345, in dump_stream
    return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 86, in dump_stream
    for batch in iterator:
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 339, in init_stream_yield_batches
    batch = self._create_batch(series)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 275, in _create_batch
    arrs.append(create_array(s, t))
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 245, in create_array
    raise e
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 233, in create_array
    array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
  File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
  File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
  File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert array(569.) with type numpy.ndarray: tried to convert to float32
```

**After**

```
[Row(predict(t1, t2)=9.0)]
```

### Does this PR introduce _any_ user-facing change?

No. `predict_batch_udf` has not been released yet, so there is no change for end users; this fixes a bug in the unreleased feature.

### How was this patch tested?

A unit test was added.
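For reference, here is a minimal sketch of the cast-back workaround described in the first section. It is illustrative only and not the exact change made in this PR; the variable names and the explicit `astype` call are assumptions, and the object-dtype behaviour of `pd.Series` on a 0-d array is the pandas behaviour shown above.

```python
import numpy as np
import pandas as pd

# A single-element batch squeezed to a 0-d array loses its dtype once it is
# wrapped in a pandas Series: as shown above, the Series dtype becomes object,
# which is what later makes the pandas -> Arrow conversion fail with
# "Could not convert array(...) ... tried to convert to float32".
squeezed = np.squeeze(np.array([1.0], dtype=np.float32))  # 0-d float32 array
series = pd.Series(squeezed)                              # dtype('O')

# The workaround described above: cast the Series back to the original
# NumPy dtype before it reaches the Arrow serializer.
series = series.astype(squeezed.dtype)
print(series.dtype)  # float32
```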
