HyukjinKwon opened a new pull request, #39817:
URL: https://github.com/apache/spark/pull/39817

   ### What changes were proposed in this pull request?
   
   This PR is a sort of followup of https://github.com/apache/spark/pull/37734; it handles the case where a batch consists of a single scalar value.
   
   Essentially, it proposes to work around the pandas behaviour below by explicitly casting back to the original data type of the NumPy array:
   
   ```python
   >>> import numpy as np
   >>> import pandas as pd
   >>> np.squeeze(np.array([1.])).dtype
   dtype('float64')
   >>> pd.Series(np.squeeze(np.array([1.]))).dtype
   dtype('O')
   >>> pd.Series(np.squeeze(np.array([1., 1.]))).dtype
   dtype('float64')
   ```
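   
   A minimal sketch of the idea (not the exact patch): keep the original NumPy dtype around and cast the resulting pandas Series back to it, so the single-value case ends up with the same dtype as the multi-value case:
   
   ```python
   >>> arr = np.squeeze(np.array([1.]))
   >>> # Without the cast, pandas infers dtype('O'); casting back restores float64.
   >>> pd.Series(arr).astype(arr.dtype).dtype
   dtype('float64')
   ```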
   
   
   ### Why are the changes needed?
   
   Using `predict_batch_udf` fails when a batch happens to contain a single value. For example, even when the batch size is set to 10, if the data has 21 rows, it fails because the last batch consists of a single value.
   
   ```python
   import numpy as np
   import pandas as pd
   from pyspark.ml.functions import predict_batch_udf
   from pyspark.sql.types import ArrayType, FloatType, StructType, StructField
   from typing import Mapping
   
   df = spark.createDataFrame([[[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0]]], schema=["t1", "t2"])
   
   def make_multi_sum_fn():
       def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
           return np.sum(x1, axis=1) + np.sum(x2, axis=1)
       return predict
   
   multi_sum_udf = predict_batch_udf(
       make_multi_sum_fn,
       return_type=FloatType(),
       batch_size=1,
       input_tensor_shapes=[[4], [3]],
   )
   
   df.select(multi_sum_udf("t1", "t2")).collect()
   ```
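   
   For context, the failure chain can be reproduced in isolation, assuming the pandas behaviour shown above (which may vary across pandas versions): the last single-row batch gets squeezed to a 0-d NumPy array, pandas wraps it in an `object` Series, and Arrow then refuses to convert it to `float32`, producing the same error as in the traceback below.
   
   ```python
   import numpy as np
   import pandas as pd
   import pyarrow as pa
   
   # The last batch has a single row, so the model output is a length-1 array.
   out = np.array([9.0], dtype=np.float64)
   
   # np.squeeze collapses it to a 0-d array, which pandas wraps in an object Series.
   squeezed = np.squeeze(out)
   s = pd.Series(squeezed)
   print(s.dtype)  # object
   
   # Arrow cannot convert that object to float32 -- same error as in the traceback below.
   try:
       pa.Array.from_pandas(s, type=pa.float32())
   except pa.ArrowInvalid as e:
       print(e)
   ```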
   
   **Before**
   
   ```
    File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 829, in main
       process()
     File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 821, in process
       serializer.dump_stream(out_iter, outfile)
     File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 345, in dump_stream
       return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
     File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 86, in dump_stream
       for batch in iterator:
     File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 339, in init_stream_yield_batches
       batch = self._create_batch(series)
     File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 275, in _create_batch
       arrs.append(create_array(s, t))
     File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 245, in create_array
       raise e
     File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 233, in create_array
       array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
     File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
     File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
     File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
     File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
   pyarrow.lib.ArrowInvalid: Could not convert array(569.) with type numpy.ndarray: tried to convert to float32
   ```
   
   **After**
   
   ```
   [Row(predict(t1, t2)=9.0)]
   ```
   
   ### Does this PR introduce _any_ user-facing change?
   
   No. This feature has not been released yet, so there is no user-facing change for end users; it fixes a bug in an unreleased feature.
   
   ### How was this patch tested?
   
   A unit test was added.

