This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 4ab6e61edca [SPARK-42250][PYTHON][ML] `predict_batch_udf` with float fails when the batch size consists of single value
4ab6e61edca is described below

commit 4ab6e61edca357b1c8583fa0818c00df83389d4e
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Jan 31 19:41:54 2023 +0900

[SPARK-42250][PYTHON][ML] `predict_batch_udf` with float fails when the batch size consists of single value

### What changes were proposed in this pull request?

This PR is a followup of https://github.com/apache/spark/pull/37734 that handles the case where a batch is a single scalar value.

Essentially, it works around the pandas behaviour by explicitly casting back to the original data type of the NumPy array:

```python
>>> import numpy as np
>>> import pandas as pd
>>> np.squeeze(np.array([1.])).dtype
dtype('float64')
>>> pd.Series(np.squeeze(np.array([1.]))).dtype
dtype('O')
>>> pd.Series(np.squeeze(np.array([1., 1.]))).dtype
dtype('float64')
```

### Why are the changes needed?

`predict_batch_udf` fails when a batch happens to contain a single value. For example, even with the batch size set to 10, 21 rows of data fail because the last batch consists of a single value.

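The arithmetic behind the "21 rows, batch size 10" case can be sketched as follows. This is a simplified model only, not Spark's batching code: `batch_sizes` is a hypothetical helper introduced for illustration, and it assumes rows are split into consecutive chunks of at most `batch_size` rows.

```python
from typing import List


def batch_sizes(num_rows: int, batch_size: int) -> List[int]:
    """Sizes of consecutive chunks of at most batch_size rows.

    Hypothetical helper for illustration; a simplified model, not the
    batching logic used by predict_batch_udf.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, remainder = divmod(num_rows, batch_size)
    sizes = [batch_size] * full
    if remainder:
        # The leftover rows form a final, smaller batch.
        sizes.append(remainder)
    return sizes


print(batch_sizes(21, 10))  # [10, 10, 1]
print(batch_sizes(1, 1))    # [1]
```

Under this model, any row count that leaves a remainder of 1 produces a trailing single-value batch, which is the case the reproducer below forces with `batch_size=1`.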
```python
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField
from typing import Mapping

df = spark.createDataFrame([[[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0]]], schema=["t1", "t2"])

def make_multi_sum_fn():
    def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
        return np.sum(x1, axis=1) + np.sum(x2, axis=1)
    return predict

multi_sum_udf = predict_batch_udf(
    make_multi_sum_fn,
    return_type=FloatType(),
    batch_size=1,
    input_tensor_shapes=[[4], [3]],
)

df.select(multi_sum_udf("t1", "t2")).collect()
```

**Before**

```
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 829, in main
    process()
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 821, in process
    serializer.dump_stream(out_iter, outfile)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 345, in dump_stream
    return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 86, in dump_stream
    for batch in iterator:
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 339, in init_stream_yield_batches
    batch = self._create_batch(series)
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 275, in _create_batch
    arrs.append(create_array(s, t))
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 245, in create_array
    raise e
  File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 233, in create_array
    array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
  File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
  File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
  File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert array(569.) with type numpy.ndarray: tried to convert to float32
```

**After**

```
[Row(predict(t1, t2)=9.0)]
```

### Does this PR introduce _any_ user-facing change?

This feature has not been released yet, so no user-facing change to the end users.
It fixes a bug in the unreleased feature.

### How was this patch tested?

Unittest was added.

Closes #39817 from HyukjinKwon/SPARK-42250.

Authored-by: Hyukjin Kwon <gurwls...@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
(cherry picked from commit ccbc075f39f9322423541ed34a9c6b6bbf60a280)
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/ml/functions.py            |  3 ++-
 python/pyspark/ml/tests/test_functions.py | 24 +++++++++++++++++++++++-
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index 1ed2c294356..977f9a7b5be 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -340,7 +340,8 @@ def _validate_and_transform_prediction_result(
             ):
                 raise ValueError("Invalid shape for scalar prediction result.")
 
-            return pd.Series(np.squeeze(preds))  # type: ignore
+            output = np.squeeze(preds)  # type: ignore[arg-type]
+            return pd.Series(output).astype(output.dtype)
         else:
             raise ValueError("Unsupported return type")
 
diff --git a/python/pyspark/ml/tests/test_functions.py b/python/pyspark/ml/tests/test_functions.py
index 04bb3ee7035..6c2268b0968 100644
--- a/python/pyspark/ml/tests/test_functions.py
+++ b/python/pyspark/ml/tests/test_functions.py
@@ -20,7 +20,7 @@ import unittest
 
 from pyspark.ml.functions import predict_batch_udf
 from pyspark.sql.functions import array, struct, col
-from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType, StructField
+from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType, StructField, FloatType
 from pyspark.testing.mlutils import SparkSessionTestCase
 
 
@@ -488,6 +488,28 @@ class PredictBatchUDFTests(SparkSessionTestCase):
             .toPandas()
         )
 
+    def test_single_value_in_batch(self):
+        # SPARK-42250: batches consisting of single float value should work
+        df = self.spark.createDataFrame(
+            [[[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0]]], schema=["t1", "t2"]
+        )
+
+        def make_multi_sum_fn():
+            def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
+                return np.sum(x1, axis=1) + np.sum(x2, axis=1)
+
+            return predict
+
+        multi_sum_udf = predict_batch_udf(
+            make_multi_sum_fn,
+            return_type=FloatType(),
+            batch_size=1,
+            input_tensor_shapes=[[4], [3]],
+        )
+
+        [value] = df.select(multi_sum_udf("t1", "t2")).first()
+        self.assertEqual(value, 9.0)
+
 
 if __name__ == "__main__":
     from pyspark.ml.tests.test_functions import *  # noqa: F401

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org