Ryan Avery created SPARK-46175:
----------------------------------

             Summary: can't return prediction that has a different length than the ML input
                 Key: SPARK-46175
                 URL: https://issues.apache.org/jira/browse/SPARK-46175
             Project: Spark
          Issue Type: Improvement
          Components: MLlib, PySpark
    Affects Versions: 3.4.1
         Environment: I'm on spark 3.4
            Reporter: Ryan Avery


I'm using `predict_batch_udf` from `pyspark.ml.functions` to construct a pandas 
UDF that runs a computer vision model to predict classification labels for 
images. The model takes a 4D array as input and returns a 4D array as output 
(Batch, Channels, Height, Width).
 
However, I'd like to run some additional post-processing in the pandas UDF to 
convert the 4D output array (floats) to text labels, since this is a more 
useful output and we are trying to register pandas UDFs ahead of time for 
spark.sql users.
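A minimal sketch of the post-processing described above: a predict function that maps a 4D float output (Batch, Channels, Height, Width) to one string label per image. The label set, the spatial averaging, and the argmax strategy are illustrative assumptions, not part of this report.

```python
import numpy as np

CLASS_LABELS = ["cat", "dog", "bird"]  # hypothetical label set

def predict(batch: np.ndarray) -> np.ndarray:
    # batch: (B, C, H, W) float scores; average over H and W,
    # then take the argmax over the channel (class) axis
    per_class = batch.mean(axis=(2, 3))  # shape (B, C)
    idx = per_class.argmax(axis=1)       # shape (B,)
    return np.array([CLASS_LABELS[i] for i in idx])

# Example: two images, three channels
scores = np.zeros((2, 3, 4, 4))
scores[0, 1] = 1.0  # image 0 scores highest on class 1 ("dog")
scores[1, 2] = 1.0  # image 1 scores highest on class 2 ("bird")
labels = predict(scores)  # one string label per input image
```

Because the output here is one string per image rather than the model's raw 4D array, this is exactly the shape mismatch that trips the length check described below.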
 
 
When I set the return type to StringType, though, I get an error:
 
```
23/11/29 02:43:04 WARN TaskSetManager: Lost task 0.0 in stage 8.0 (TID 16) (172.18.0.2 executor 0): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/spark/python/pyspark/ml/functions.py", line 809, in predict
    yield _validate_and_transform_prediction_result(
  File "/opt/spark/python/lib/pyspark.zip/pyspark/ml/functions.py", line 331, in _validate_and_transform_prediction_result
    raise ValueError("Prediction results must have same length as input data.")
ValueError: Prediction results must have same length as input data.
```
 
This seems like an unnecessary limitation, since it is common for ML models, 
especially in computer vision, to take input shapes different from output 
shapes.
 
I think the issue is here: 
[https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/functions.html]
 
Could this check, which enforces equal input and output lengths, be removed?
 
 
To illustrate the problem, here are my StructTypes. The Raw one works but the 
Str one does not:
 
```
from collections import namedtuple
from pyspark.sql.types import (
    ArrayType, FloatType, IntegerType, StringType, StructField, StructType
)

# Define the schemas using namedtuple
Task = namedtuple('TaskSchema', ['inference_input', 'inference_result'])

SingleLabelClassificationRaw = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),
        StructField("shape", ArrayType(IntegerType()), nullable=False)
    ]),
    inference_result=ArrayType(FloatType())
)

# TODO handle non ints
SingleLabelClassificationStr = Task(
    inference_input=StructType([
        StructField("array", ArrayType(IntegerType()), nullable=False),
        StructField("shape", ArrayType(IntegerType()), nullable=False)
    ]),
    inference_result=StringType()
)
```



--
This message was sent by Atlassian Jira
(v8.20.10#820010)
