Hyukjin Kwon created SPARK-42250:
------------------------------------
Summary: batch_infer_udf (predict_batch_udf) with FloatType return type fails when a
batch consists of a single value
Key: SPARK-42250
URL: https://issues.apache.org/jira/browse/SPARK-42250
Project: Spark
Issue Type: Bug
Components: ML, PySpark
Affects Versions: 3.4.0
Reporter: Hyukjin Kwon
{code}
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField
from typing import Mapping

# Reproduction for SPARK-42250.
# NOTE: assumes an active SparkSession bound to `spark` (e.g. the pyspark
# shell); the session is not created in this snippet.

# Two rows; each has a 4-element array in "t1" and a 3-element array in "t2".
df = spark.createDataFrame(
    [
        [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0]],
        [[4.0, 5.0, 6.0, 7.0], [4.0, 5.0, 6.0]],
    ],
    schema=["t1", "t2"],
)


def make_multi_sum_fn():
    """Return a two-input predict function that adds the axis-1 sums of its inputs."""

    def predict(x1: np.ndarray, x2: np.ndarray) -> np.ndarray:
        # Sum each input along axis 1 (per row, assuming (batch, n) inputs
        # as implied by input_tensor_shapes below), then add the two results.
        return np.sum(x1, axis=1) + np.sum(x2, axis=1)

    return predict


# batch_size=1 makes every batch a single row -- the case this report says
# fails, with ArrowInvalid while converting the result to float32.
multi_sum_udf = predict_batch_udf(
    make_multi_sum_fn,
    return_type=FloatType(),
    batch_size=1,
    input_tensor_shapes=[[4], [3]],
)
df.select(multi_sum_udf("t1", "t2")).collect()
{code}
fails as follows:
{code}
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 829, in main
process()
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 821, in
process
serializer.dump_stream(out_iter, outfile)
File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py",
line 345, in dump_stream
return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(),
stream)
File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py",
line 86, in dump_stream
for batch in iterator:
File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py",
line 339, in init_stream_yield_batches
batch = self._create_batch(series)
File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py",
line 275, in _create_batch
arrs.append(create_array(s, t))
File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py",
line 245, in create_array
raise e
File "/.../spark/python/lib/pyspark.zip/pyspark/sql/pandas/serializers.py",
line 233, in create_array
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
File "pyarrow/array.pxi", line 1044, in pyarrow.lib.Array.from_pandas
File "pyarrow/array.pxi", line 316, in pyarrow.lib.array
File "pyarrow/array.pxi", line 83, in pyarrow.lib._ndarray_to_array
File "pyarrow/error.pxi", line 100, in pyarrow.lib.check_status
pyarrow.lib.ArrowInvalid: Could not convert array(569.) with type
numpy.ndarray: tried to convert to float32
at
org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:554)
at
org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:118)
at
org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:507)
at
org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown
Source)
at
org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at
org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
at
org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:391)
at
org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:888)
at
org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:888)
at
org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:92)
at
org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
at org.apache.spark.scheduler.Task.run(Task.scala:139)
at
org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1520)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
at
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
{code}
--
This message was sent by Atlassian Jira
(v8.20.10#820010)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]