Github user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20295#discussion_r214797748
  
    --- Diff: python/pyspark/sql/tests.py ---
    @@ -4588,6 +4613,80 @@ def test_timestamp_dst(self):
             result = df.groupby('time').apply(foo_udf).sort('time')
             self.assertPandasEqual(df.toPandas(), result.toPandas())
     
    +    def test_udf_with_key(self):
    +        from pyspark.sql.functions import pandas_udf, col, PandasUDFType
    +        df = self.data
    +        pdf = df.toPandas()
    +
    +        def foo1(key, pdf):
    +            import numpy as np
    +            assert type(key) == tuple
    +            assert type(key[0]) == np.int64
    +
    +            return pdf.assign(v1=key[0],
    +                              v2=pdf.v * key[0],
    +                              v3=pdf.v * pdf.id,
    +                              v4=pdf.v * pdf.id.mean())
    +
    +        def foo2(key, pdf):
    +            import numpy as np
    +            assert type(key) == tuple
    +            assert type(key[0]) == np.int64
    +            assert type(key[1]) == np.int32
    +
    +            return pdf.assign(v1=key[0],
    +                              v2=key[1],
    +                              v3=pdf.v * key[0],
    +                              v4=pdf.v + key[1])
    +
    +        def foo3(key, pdf):
    +            assert type(key) == tuple
    +            assert len(key) == 0
    +            return pdf.assign(v1=pdf.v * pdf.id)
    +
    +        # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
    +        # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
    +        udf1 = pandas_udf(
    +            foo1,
    +            'id long, v int, v1 long, v2 int, v3 long, v4 double',
    +            PandasUDFType.GROUPED_MAP)
    +
    +        udf2 = pandas_udf(
    +            foo2,
    +            'id long, v int, v1 long, v2 int, v3 int, v4 int',
    +            PandasUDFType.GROUPED_MAP)
    +
    +        udf3 = pandas_udf(
    +            foo3,
    +            'id long, v int, v1 long',
    +            PandasUDFType.GROUPED_MAP)
    +
    +        # Test groupby column
    +        result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
    +        expected1 = pdf.groupby('id')\
    +            .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
    +            .sort_values(['id', 'v']).reset_index(drop=True)
    +        self.assertPandasEqual(expected1, result1)
    +
    +        # Test groupby expression
    +        result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
    +        expected2 = pdf.groupby(pdf.id % 2)\
    +            .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
    +            .sort_values(['id', 'v']).reset_index(drop=True)
    +        self.assertPandasEqual(expected2, result2)
    +
    +        # Test complex groupby
    +        result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
    --- End diff --
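
    (Side note, not part of the quoted diff: a minimal sketch of the dtype promotion that the `v2`/`v3` comment above relies on; the exact scalar-times-Series result can vary across pandas/NumPy versions, so treat it as illustrative only.)

    ```python
    import numpy as np
    import pandas as pd

    s32 = pd.Series([1, 2, 3], dtype='int32')   # stands in for pdf.v
    s64 = pd.Series([1, 2, 3], dtype='int64')   # stands in for pdf.id
    k = np.int64(2)                             # stands in for key[0]

    print((k * s32).dtype)    # the test's comment expects int32 here, hence 'v2 int'
    print((s64 * s32).dtype)  # int64, hence 'v3 long'
    ```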
    
    In that case, any error will be thrown as-is from the worker.py side, which is then read and redirected to the user's end via the JVM. For instance:
    
    ```python
    from pyspark.sql.functions import pandas_udf, PandasUDFType
    def test_func(key, pdf):
        assert len(key) == 0
        return pdf
    
    udf1 = pandas_udf(test_func, "id long, v1 double", PandasUDFType.GROUPED_MAP)
    spark.range(10).groupby('id').apply(udf1).sort('id').show()
    ```
    
    ```
    18/09/04 14:22:52 ERROR TaskSetManager: Task 1 in stage 0.0 failed 1 times; aborting job
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/.../spark/python/pyspark/sql/dataframe.py", line 378, in show
        print(self._jdf.showString(n, 20, vertical))
      File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
      File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
        return f(*a, **kw)
      File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
    py4j.protocol.Py4JJavaError: An error occurred while calling o68.showString.
    : org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 0.0 failed 1 times, most recent failure: Lost task 1.0 in stage 0.0 (TID 1, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
        process()
      File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
        serializer.dump_stream(func(split_index, iterator), outfile)
      File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
        func = lambda _, it: map(mapper, it)
      File "<string>", line 1, in <lambda>
      File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
        result = f(key, pd.concat(value_series, axis=1))
      File "/.../spark/python/pyspark/util.py", line 99, in wrapper
        return f(*args, **kwargs)
      File "<stdin>", line 2, in test_func
    AssertionError
    
        at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
        at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
        at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
        at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
        at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
        at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
        at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
        at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
        at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
        at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
        at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
        at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
        at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
        at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
        at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
        at org.apache.spark.scheduler.Task.run(Task.scala:128)
        at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:748)
    
    Driver stacktrace:
        at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1822)
        at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1810)
        at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1809)
        at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
        at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
        at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1809)
        at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
        at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
        at scala.Option.foreach(Option.scala:257)
        at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
        at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2043)
        at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1992)
        at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1981)
        at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
        at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
        at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
        at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
        at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1029)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
        at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
        at org.apache.spark.rdd.RDD.reduce(RDD.scala:1011)
        at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1.apply(RDD.scala:1433)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
        at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
        at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
        at org.apache.spark.rdd.RDD.takeOrdered(RDD.scala:1420)
        at org.apache.spark.sql.execution.TakeOrderedAndProjectExec.executeCollect(limit.scala:207)
        at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
        at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
        at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
        at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
        at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
        at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
        at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
        at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
        at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
        at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
        at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
        at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.lang.reflect.Method.invoke(Method.java:498)
        at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
        at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
        at py4j.Gateway.invoke(Gateway.java:282)
        at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
        at py4j.commands.CallCommand.execute(CallCommand.java:79)
        at py4j.GatewayConnection.run(GatewayConnection.java:238)
        at java.lang.Thread.run(Thread.java:748)
    Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
      File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
        process()
      File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
        serializer.dump_stream(func(split_index, iterator), outfile)
      File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
        func = lambda _, it: map(mapper, it)
      File "<string>", line 1, in <lambda>
      File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
        result = f(key, pd.concat(value_series, axis=1))
      File "/.../spark/python/pyspark/util.py", line 99, in wrapper
        return f(*args, **kwargs)
      File "<stdin>", line 2, in test_func
    AssertionError

        at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
        at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
        at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
        at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
        at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
        at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
        at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
        at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
        at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
        at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
        at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
        at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
        at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
        at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
        at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
        at org.apache.spark.scheduler.Task.run(Task.scala:128)
        at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        ... 1 more
    ```
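
    For completeness, here is a minimal sketch (not from the PR) of how that surfaced error can be inspected on the caller side, reusing the `spark` session and `udf1` from the snippet above: the worker-side traceback travels back through the JVM and ends up in the `Py4JJavaError` raised by the action.

    ```python
    from py4j.protocol import Py4JJavaError

    try:
        spark.range(10).groupby('id').apply(udf1).sort('id').show()
    except Py4JJavaError as e:
        # str(e) contains the redirected worker-side traceback, so the original
        # AssertionError raised inside test_func is visible to the user here.
        print(str(e))
    ```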

