[ 
https://issues.apache.org/jira/browse/SPARK-42250?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
 ]

Apache Spark reassigned SPARK-42250:
------------------------------------

    Assignee: Apache Spark

> predict_batch_udf with FloatType fails when the batch consists of a single value
> -------------------------------------------------------------------------------
>
>                 Key: SPARK-42250
>                 URL: https://issues.apache.org/jira/browse/SPARK-42250
>             Project: Spark
>          Issue Type: Bug
>          Components: ML, PySpark
>    Affects Versions: 3.4.0
>            Reporter: Hyukjin Kwon
>            Assignee: Apache Spark
>            Priority: Major
>
> {code}
> import numpy as np
> import pandas as pd
> from pyspark.ml.functions import predict_batch_udf
> from pyspark.sql.types import ArrayType, FloatType, StructType, StructField
> from typing import Mapping
> df = spark.createDataFrame([[[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0]], [[4.0, 5.0, 6.0, 7.0], [4.0, 5.0, 6.0]]], schema=["t1", "t2"])
> def make_multi_sum_fn():
>     def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
>         return np.sum(x1, axis=1) + np.sum(x2, axis=1)
>     return predict
> multi_sum_udf = predict_batch_udf(
>     make_multi_sum_fn,
>     return_type=FloatType(),
>     batch_size=1,
>     input_tensor_shapes=[[4], [3]],
> )
> df.select(multi_sum_udf("t1", "t2")).collect()
> {code}
> fails with the following error:
> {code}
>   File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 829, in main
>     process()
>   File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 821, in process
>     serializer.dump_stream(out_iter, outfile)
>   File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 345, in dump_stream
>     return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
>   File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 86, in dump_stream
>     for batch in iterator:
>   File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 339, in init_stream_yield_batches
>     batch = self._create_batch(series)
>   File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 275, in _create_batch
>     arrs.append(create_array(s, t))
>   File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 245, in create_array
>     raise e
>   File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py", line 233, in create_array
>     array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
>   File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
>   File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
>   File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
>   File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
> pyarrow.lib.ArrowInvalid: Could not convert array(569.) with type numpy.ndarray: tried to convert to float32
>       at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:554)
>       at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:118)
>       at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:507)
>       at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
>       at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
>       at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
>       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
>       at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
>       at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
>       at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:391)
>       at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:888)
>       at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:888)
>       at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
>       at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
>       at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
>       at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:92)
>       at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
>       at org.apache.spark.scheduler.Task.run(Task.scala:139)
>       at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
>       at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1520)
>       at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
>       at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
>       at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
>       at java.lang.Thread.run(Thread.java:748)
> {code}



--
This message was sent by Atlassian Jira
(v8.20.10#820010)

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to