[ https://issues.apache.org/jira/browse/SPARK-49793?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Xinrong Meng updated SPARK-49793:
---------------------------------
Description:
{code:python}
import numpy as np
import pandas as pd
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import struct
# Build a 9 x 4 float64 matrix holding 0..35 and wrap it as columns "a".."d".
data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4)
pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"])
# NOTE(review): `spark` is not defined in this snippet; presumably it is the
# active SparkSession provided by the pyspark shell — confirm.
df = spark.createDataFrame(pdf)
def make_predict_fn():
    """Build a fake model whose prediction is one random value fixed at build time.

    A single float is drawn with ``np.random.random()`` when this factory runs.
    The returned ``predict`` callable ignores the contents of its input and
    emits that same float once per element it iterates over. Every row scored
    by one built predictor therefore receives an identical value.
    """
    constant = np.random.random()

    def predict(inputs):
        # One copy of the build-time constant per element of ``inputs``.
        return np.array([constant for _ in inputs])

    return predict
# Despite its name, `identity` does not echo its input: each batch gets the
# constant produced by the predict function that make_predict_fn builds.
identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(),
batch_size=5)
# Run the same UDF twice. The expected output in this issue shows one value
# for every row of both df1 and df2, presumably because the predict function
# is expected to be built once and reused (see test_caching in the issue
# title). The NumPy 2.1.0 output instead shows a different value per row.
df1 = df.withColumn("preds", identity(struct("a"))).toPandas()
df2 = df.withColumn("preds", identity(struct("a"))).toPandas()
{code}
Actual output with NumPy 2.1.0:
{code:java}
>>> df1
a b c d preds
0 0.0 1.0 2.0 3.0 0.431752
1 4.0 5.0 6.0 7.0 0.912097
2 8.0 9.0 10.0 11.0 0.679628
3 12.0 13.0 14.0 15.0 0.853850
4 16.0 17.0 18.0 19.0 0.389971
5 20.0 21.0 22.0 23.0 0.654521
6 24.0 25.0 26.0 27.0 0.430569
7 28.0 29.0 30.0 31.0 0.331055
8 32.0 33.0 34.0 35.0 0.306073
>>> df2
a b c d preds
0 0.0 1.0 2.0 3.0 0.679628
1 4.0 5.0 6.0 7.0 0.430569
2 8.0 9.0 10.0 11.0 0.853850
3 12.0 13.0 14.0 15.0 0.306073
4 16.0 17.0 18.0 19.0 0.654521
5 20.0 21.0 22.0 23.0 0.389971
6 24.0 25.0 26.0 27.0 0.507598
7 28.0 29.0 30.0 31.0 0.912097
8 32.0 33.0 34.0 35.0 0.431752 {code}
Expected output (every row of both DataFrames should receive the same cached prediction value):
{code:java}
>>> df1
a b c d preds
0 0.0 1.0 2.0 3.0 0.685941
1 4.0 5.0 6.0 7.0 0.685941
2 8.0 9.0 10.0 11.0 0.685941
3 12.0 13.0 14.0 15.0 0.685941
4 16.0 17.0 18.0 19.0 0.685941
5 20.0 21.0 22.0 23.0 0.685941
6 24.0 25.0 26.0 27.0 0.685941
7 28.0 29.0 30.0 31.0 0.685941
8 32.0 33.0 34.0 35.0 0.685941
>>> df2
a b c d preds
0 0.0 1.0 2.0 3.0 0.685941
1 4.0 5.0 6.0 7.0 0.685941
2 8.0 9.0 10.0 11.0 0.685941
3 12.0 13.0 14.0 15.0 0.685941
4 16.0 17.0 18.0 19.0 0.685941
5 20.0 21.0 22.0 23.0 0.685941
6 24.0 25.0 26.0 27.0 0.685941
7 28.0 29.0 30.0 31.0 0.685941
8 32.0 33.0 34.0 35.0 0.685941 {code}
> Enable PredictBatchUDFTests.test_caching for NumPy 2
> ----------------------------------------------------
>
> Key: SPARK-49793
> URL: https://issues.apache.org/jira/browse/SPARK-49793
> Project: Spark
> Issue Type: Story
> Components: ML, Tests
> Affects Versions: 4.0.0
> Reporter: Xinrong Meng
> Priority: Major
>
>
> {code:python}
> import numpy as np
> import pandas as pd
> from pyspark.ml.functions import predict_batch_udf
> from pyspark.sql.types import DoubleType
> from pyspark.sql.functions import struct
> data = np.arange(0, 36, dtype=np.float64).reshape(-1, 4)
> pdf = pd.DataFrame(data, columns=["a", "b", "c", "d"])
> df = spark.createDataFrame(pdf)
> def make_predict_fn():
> fake_output = np.random.random()
> def predict(inputs):
> return np.array([fake_output for i in inputs])
> return predict
>
> identity = predict_batch_udf(make_predict_fn, return_type=DoubleType(),
> batch_size=5)
> df1 = df.withColumn("preds", identity(struct("a"))).toPandas()
> df2 = df.withColumn("preds", identity(struct("a"))).toPandas()
> {code}
> Actual output with NumPy 2.1.0:
> {code:java}
> >>> df1
> a b c d preds
> 0 0.0 1.0 2.0 3.0 0.431752
> 1 4.0 5.0 6.0 7.0 0.912097
> 2 8.0 9.0 10.0 11.0 0.679628
> 3 12.0 13.0 14.0 15.0 0.853850
> 4 16.0 17.0 18.0 19.0 0.389971
> 5 20.0 21.0 22.0 23.0 0.654521
> 6 24.0 25.0 26.0 27.0 0.430569
> 7 28.0 29.0 30.0 31.0 0.331055
> 8 32.0 33.0 34.0 35.0 0.306073
> >>> df2
> a b c d preds
> 0 0.0 1.0 2.0 3.0 0.679628
> 1 4.0 5.0 6.0 7.0 0.430569
> 2 8.0 9.0 10.0 11.0 0.853850
> 3 12.0 13.0 14.0 15.0 0.306073
> 4 16.0 17.0 18.0 19.0 0.654521
> 5 20.0 21.0 22.0 23.0 0.389971
> 6 24.0 25.0 26.0 27.0 0.507598
> 7 28.0 29.0 30.0 31.0 0.912097
> 8 32.0 33.0 34.0 35.0 0.431752 {code}
> Expected output (every row of both DataFrames should receive the same cached prediction value):
> {code:java}
> >>> df1
> a b c d preds
> 0 0.0 1.0 2.0 3.0 0.685941
> 1 4.0 5.0 6.0 7.0 0.685941
> 2 8.0 9.0 10.0 11.0 0.685941
> 3 12.0 13.0 14.0 15.0 0.685941
> 4 16.0 17.0 18.0 19.0 0.685941
> 5 20.0 21.0 22.0 23.0 0.685941
> 6 24.0 25.0 26.0 27.0 0.685941
> 7 28.0 29.0 30.0 31.0 0.685941
> 8 32.0 33.0 34.0 35.0 0.685941
> >>> df2
> a b c d preds
> 0 0.0 1.0 2.0 3.0 0.685941
> 1 4.0 5.0 6.0 7.0 0.685941
> 2 8.0 9.0 10.0 11.0 0.685941
> 3 12.0 13.0 14.0 15.0 0.685941
> 4 16.0 17.0 18.0 19.0 0.685941
> 5 20.0 21.0 22.0 23.0 0.685941
> 6 24.0 25.0 26.0 27.0 0.685941
> 7 28.0 29.0 30.0 31.0 0.685941
> 8 32.0 33.0 34.0 35.0 0.685941 {code}
>
--
This message was sent by Atlassian Jira
(v8.20.10#820010)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]