Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/20295#discussion_r214797748
--- Diff: python/pyspark/sql/tests.py ---
@@ -4588,6 +4613,80 @@ def test_timestamp_dst(self):
         result = df.groupby('time').apply(foo_udf).sort('time')
         self.assertPandasEqual(df.toPandas(), result.toPandas())
+    def test_udf_with_key(self):
+        from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+        df = self.data
+        pdf = df.toPandas()
+
+        def foo1(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+
+            return pdf.assign(v1=key[0],
+                              v2=pdf.v * key[0],
+                              v3=pdf.v * pdf.id,
+                              v4=pdf.v * pdf.id.mean())
+
+        def foo2(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+            assert type(key[1]) == np.int32
+
+            return pdf.assign(v1=key[0],
+                              v2=key[1],
+                              v3=pdf.v * key[0],
+                              v4=pdf.v + key[1])
+
+        def foo3(key, pdf):
+            assert type(key) == tuple
+            assert len(key) == 0
+            return pdf.assign(v1=pdf.v * pdf.id)
+
+        # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+        # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
+        udf1 = pandas_udf(
+            foo1,
+            'id long, v int, v1 long, v2 int, v3 long, v4 double',
+            PandasUDFType.GROUPED_MAP)
+
+        udf2 = pandas_udf(
+            foo2,
+            'id long, v int, v1 long, v2 int, v3 int, v4 int',
+            PandasUDFType.GROUPED_MAP)
+
+        udf3 = pandas_udf(
+            foo3,
+            'id long, v int, v1 long',
+            PandasUDFType.GROUPED_MAP)
+
+        # Test groupby column
+        result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+        expected1 = pdf.groupby('id')\
+            .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected1, result1)
+
+        # Test groupby expression
+        result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+        expected2 = pdf.groupby(pdf.id % 2)\
+            .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected2, result2)
+
+        # Test complex groupby
+        result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
--- End diff --
In that case, any error will be thrown as is from the worker.py side, which is read and redirected to the user's end via the JVM. For instance:
```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

def test_func(key, pdf):
    assert len(key) == 0
    return pdf

udf1 = pandas_udf(test_func, "id long, v1 double",
                  PandasUDFType.GROUPED_MAP)
spark.range(10).groupby('id').apply(udf1).sort('id').show()
```
```
18/09/04 14:22:52 ERROR TaskSetManager: Task 1 in stage 0.0 failed 1 times; aborting job
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/dataframe.py", line 378, in show
    print(self._jdf.showString(n, 20, vertical))
  File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
  File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o68.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 0.0 failed 1 times, most recent failure: Lost task 1.0 in stage 0.0 (TID 1, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
    process()
  File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
    func = lambda _, it: map(mapper, it)
  File "<string>", line 1, in <lambda>
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
    result = f(key, pd.concat(value_series, axis=1))
  File "/.../spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<stdin>", line 2, in test_func
AssertionError
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
    at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
    at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
    at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
    at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:128)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
Driver stacktrace:
    at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1822)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1810)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1809)
    at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
    at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1809)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
    at scala.Option.foreach(Option.scala:257)
    at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2043)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1992)
    at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1981)
    at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
    at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
    at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1029)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.reduce(RDD.scala:1011)
    at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1.apply(RDD.scala:1433)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.takeOrdered(RDD.scala:1420)
    at org.apache.spark.sql.execution.TakeOrderedAndProjectExec.executeCollect(limit.scala:207)
    at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
    at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
    at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
    at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
    at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
    at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
    process()
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
    func = lambda _, it: map(mapper, it)
  File "<string>", line 1, in <lambda>
  File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
    result = f(key, pd.concat(value_series, axis=1))
  File "/.../spark/python/pyspark/util.py", line 99, in wrapper
    return f(*args, **kwargs)
  File "<stdin>", line 2, in test_func
AssertionError
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
    at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
    at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
    at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
    at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
    at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:128)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    ... 1 more
```
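FWIW, a minimal sketch of a version that does not hit this error (hypothetical, assuming the same active `spark` session as above): grouping by a single column hands the UDF a one-element key tuple, which is exactly why `assert len(key) == 0` raises the `AssertionError` shown in the traceback.
```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Hypothetical fixed version of test_func: with groupby('id') the grouping
# key arrives as a one-element tuple, so assert on that instead.
def test_func_fixed(key, pdf):
    assert isinstance(key, tuple) and len(key) == 1
    # key[0] is the group's id value; expose it as the v1 column.
    return pdf.assign(v1=float(key[0]))

udf_fixed = pandas_udf(test_func_fixed, "id long, v1 double",
                       PandasUDFType.GROUPED_MAP)

# Assumes an active SparkSession named `spark`, as in the snippet above.
spark.range(10).groupby('id').apply(udf_fixed).sort('id').show()
```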
---