leewyang opened a new pull request, #40967:
URL: https://github.com/apache/spark/pull/40967
### What changes were proposed in this pull request?
This is a follow-up to #39817 that handles another error condition: the
input batch is a single scalar value. The previous fix addressed a single
scalar value as the output.
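The plain NumPy sketch below shows how a one-row batch can end up as a single scalar value. It does not use Spark, and it does not claim that `predict_batch_udf` calls `np.squeeze` internally. It only demonstrates that a shape `(1,)` array loses its only axis under a squeeze and becomes a 0-d array.

```python
import numpy as np

# A one-row batch, shaped like a 1-D column of inputs.
batch = np.array([1.0])
print(batch.shape, len(batch))  # (1,) 1

# Squeezing drops every length-1 axis, so a one-row batch
# collapses into a 0-d array holding a single scalar value.
scalar_batch = np.squeeze(batch)
print(scalar_batch.shape, scalar_batch.ndim)  # () 0
```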
### Why are the changes needed?
Using `predict_batch_udf` fails when the input batch size is one.
```
import numpy as np
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([[1.0], [2.0]], schema=["a"])

def make_predict_fn():
    def predict(inputs):
        # Identity "model": return the input batch unchanged.
        return inputs
    return predict

identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(), batch_size=1)
df.withColumn("preds", identity("a")).show()
```
fails with:
```
  File "/.../spark/python/pyspark/worker.py", line 869, in main
    process()
  File "/.../spark/python/pyspark/worker.py", line 861, in process
    serializer.dump_stream(out_iter, outfile)
  File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 354, in dump_stream
    return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
  File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 86, in dump_stream
    for batch in iterator:
  File "/.../spark/python/pyspark/sql/pandas/serializers.py", line 347, in init_stream_yield_batches
    for series in iterator:
  File "/.../spark/python/pyspark/worker.py", line 555, in func
    for result_batch, result_type in result_iter:
  File "/.../spark/python/pyspark/ml/functions.py", line 818, in predict
    yield _validate_and_transform_prediction_result(
  File "/.../spark/python/pyspark/ml/functions.py", line 339, in _validate_and_transform_prediction_result
    if len(preds_array) != num_input_rows:
TypeError: len() of unsized object
```
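`TypeError: len() of unsized object` is the error NumPy raises when `len()` is called on a 0-d array. The sketch below reproduces that error and then shows one possible guard: promoting a 0-d result to 1-D with `np.atleast_1d` before comparing its length to the input row count. `check_prediction_count` is a hypothetical helper written for illustration. It is not the change made in this PR and is not part of `pyspark.ml.functions`.

```python
import numpy as np


def check_prediction_count(preds, num_input_rows):
    """Hypothetical guard: promote a 0-d result to 1-D before counting rows."""
    preds_array = np.asarray(preds)
    if preds_array.ndim == 0:
        # A 0-d array has no length; treat it as a one-element batch.
        preds_array = np.atleast_1d(preds_array)
    if len(preds_array) != num_input_rows:
        raise ValueError(
            f"Prediction count {len(preds_array)} != input row count {num_input_rows}"
        )
    return preds_array


# Calling len() on a 0-d array reproduces the error from the traceback.
try:
    len(np.array(1.0))
except TypeError as exc:
    print(exc)  # len() of unsized object

print(check_prediction_count(np.array(1.0), 1))  # [1.]
```

Promoting only when `ndim == 0` leaves correctly shaped batches untouched, so the row-count check still catches genuine mismatches.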
After the fix:
```
+---+-----+
|  a|preds|
+---+-----+
|1.0|  1.0|
|2.0|  2.0|
+---+-----+
```
### Does this PR introduce _any_ user-facing change?
This fixes a bug in `predict_batch_udf`, a feature released in Spark 3.4.0.
### How was this patch tested?
Added a unit test.