Ryan Avery created SPARK-46175:
----------------------------------
Summary: can't return prediction that has different length than ml input
Key: SPARK-46175
URL: https://issues.apache.org/jira/browse/SPARK-46175
Project: Spark
Issue Type: Improvement
Components: MLlib, PySpark
Affects Versions: 3.4.1
Environment: I'm on spark 3.4
Reporter: Ryan Avery
I'm using `from pyspark.ml.functions import predict_batch_udf` to construct a pandas UDF that runs a computer vision model to predict classification labels for images. The model takes a 4D array as input and returns a 4D array as output (Batch, Channels, Height, Width).
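For context, here is a minimal sketch of the kind of `make_predict_fn` that `predict_batch_udf` takes: a zero-argument factory that loads the model once and returns a function mapping a NumPy batch to a NumPy batch. The model is a hypothetical stand-in (a random per-pixel channel mix, with an assumed `NUM_CLASSES`), not the reporter's model, and the Spark wiring is left as a comment because it needs a running SparkSession. The `input_tensor_shapes` value in that comment is likewise an assumed image size.

```python
import numpy as np

NUM_CLASSES = 4  # hypothetical number of output channels


def make_predict_fn():
    """Factory called once per worker; returns a batch -> batch function."""
    rng = np.random.default_rng(0)
    # Hypothetical stand-in for a real vision model: a per-pixel linear map
    # from 3 input channels to NUM_CLASSES output channels.
    weights = rng.normal(size=(NUM_CLASSES, 3)).astype(np.float32)

    def predict(batch: np.ndarray) -> np.ndarray:
        # batch: (Batch, Channels, Height, Width)
        # einsum mixes channels at every pixel -> (Batch, NUM_CLASSES, H, W)
        out = np.einsum("kc,bchw->bkhw", weights, batch.astype(np.float32))
        # One flat float array per input row, for an ArrayType(FloatType()) result.
        return out.reshape(out.shape[0], -1)

    return predict


# Wiring into Spark (not executed here; requires a SparkSession and a
# column holding each image as a flat numeric array):
# from pyspark.ml.functions import predict_batch_udf
# from pyspark.sql.types import ArrayType, FloatType
# classify_raw = predict_batch_udf(
#     make_predict_fn,
#     return_type=ArrayType(FloatType()),
#     batch_size=16,
#     input_tensor_shapes=[[3, 32, 32]],  # hypothetical per-row image shape
# )
```

Note that the leading batch dimension is preserved: the output has one row per input row even though each row's shape changes.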
However, I'd like to run some additional processing in the pandas UDF to convert the 4D float output array into text labels, since that is a more useful output and we are trying to register pandas UDFs ahead of time for Spark SQL users.
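One way to do that post-processing while still returning exactly one value per input row is sketched below. It assumes, hypothetically, that a row's label is the output channel with the highest mean score over Height and Width; `LABELS` and this reduction rule are placeholders, not taken from the issue. The important property is that the result is a 1-D string array with the same length as the input batch.

```python
import numpy as np

# Hypothetical class names, one per output channel.
LABELS = np.array(["background", "cat", "dog", "bird"])


def to_labels(outputs: np.ndarray) -> np.ndarray:
    """Map (Batch, Channels, Height, Width) float scores to one label per row."""
    if outputs.ndim != 4:
        raise ValueError(f"expected a 4D array, got shape {outputs.shape}")
    # Average each channel over the spatial dims -> (Batch, Channels).
    per_class = outputs.mean(axis=(2, 3))
    # Index of the highest-scoring channel for each row -> (Batch,).
    best = per_class.argmax(axis=1)
    # Fancy indexing keeps the batch dimension, so len(result) == len(outputs).
    return LABELS[best]
```

Inside the predict function this would become something like `return to_labels(model_output)`, paired with `return_type=StringType()`.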
When I set the return type to `StringType`, though, I get this error:
```
23/11/29 02:43:04 WARN TaskSetManager: Lost task 0.0 in stage 8.0 (TID 16) (172.18.0.2 executor 0): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/pyspark/ml/functions.py", line 809, in predict
    yield _validate_and_transform_prediction_result(
  File "/opt/spark/python/lib/pyspark.zip/pyspark/ml/functions.py", line 331, in _validate_and_transform_prediction_result
    raise ValueError("Prediction results must have same length as input data.")
ValueError: Prediction results must have same length as input data.
```
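Judging from the error message, the check compares the number of predictions with the number of input rows in the batch, rather than comparing per-row shapes. The sketch below is a simplified, hypothetical re-implementation for a scalar return type such as `StringType` (`check_scalar_predictions` is an illustrative name, not PySpark's function). It shows why a 1-D array with one label per row passes, while, for example, one label per pixel would not.

```python
import numpy as np


def check_scalar_predictions(preds: np.ndarray, num_input_rows: int) -> np.ndarray:
    """Simplified sketch of a same-length check for scalar UDF results."""
    if len(preds) != num_input_rows:
        # Same message as the ValueError in the traceback.
        raise ValueError("Prediction results must have same length as input data.")
    # This sketch accepts (N,) or (N, 1) and flattens to (N,).
    if not (preds.ndim == 1 or (preds.ndim == 2 and preds.shape[1] == 1)):
        raise ValueError("Invalid shape for scalar prediction result.")
    return preds.reshape(num_input_rows)
```

A batch of 4 images labelled `["cat", "dog", "cat", "bird"]` passes; 64 per-pixel labels for the same 4 rows raise the `ValueError`.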
This seems like an unnecessary limitation, since it is common for ML models, especially in computer vision, to produce outputs whose shape differs from the shape of their inputs.
I think the issue is in `_validate_and_transform_prediction_result` here:
[https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/functions.html]
Can this check that enforces the same length be removed?
To illustrate the problem, here are my schemas. The `Raw` one works but the `Str` one does not:
```
from collections import namedtuple

from pyspark.sql.types import (
    ArrayType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

# Define the schemas using namedtuple
Task = namedtuple('TaskSchema', ['inference_input', 'inference_result'])

# Works: returns the raw model scores as an array of floats per row
SingleLabelClassificationRaw = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),
        StructField("shape", ArrayType(IntegerType()), nullable=False),
    ]),
    inference_result=ArrayType(FloatType()),
)

# Fails with the ValueError above: returns one text label per row
# TODO handle non ints
SingleLabelClassificationStr = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),
        StructField("shape", ArrayType(IntegerType()), nullable=False),
    ]),
    inference_result=StringType(),
)
```
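With the `array`/`shape` input struct above, each row carries a flattened integer image plus its shape. A minimal, hypothetical sketch of rebuilding a batch from such rows and returning exactly one label per row is below. The helper names (`decode_rows`, `label_batch`) and the labelling rule are assumptions, and it further assumes that `shape` holds (Channels, Height, Width) and that every row in a batch has the same shape, so the rows can be stacked.

```python
from typing import Sequence

import numpy as np


def decode_rows(
    arrays: Sequence[Sequence[int]], shapes: Sequence[Sequence[int]]
) -> np.ndarray:
    """Rebuild a (Batch, Channels, Height, Width) tensor from flattened rows."""
    images = [
        np.asarray(flat, dtype=np.float32).reshape(tuple(shape))
        for flat, shape in zip(arrays, shapes)
    ]
    # np.stack needs equal shapes; it adds the leading batch dimension.
    return np.stack(images)


def label_batch(
    arrays: Sequence[Sequence[int]], shapes: Sequence[Sequence[int]]
) -> np.ndarray:
    batch = decode_rows(arrays, shapes)
    # Placeholder for model + post-processing (hypothetical rule): label each
    # row by the channel with the largest mean value.
    best = batch.mean(axis=(2, 3)).argmax(axis=1)
    labels = np.array([f"class_{i}" for i in best])
    if len(labels) != len(arrays):
        raise ValueError("expected one label per input row")
    return labels
```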
--
This message was sent by Atlassian Jira
(v8.20.10#820010)