Repository: spark
Updated Branches:
  refs/heads/branch-2.3 a55de387d -> 470cacd49


[SPARK-23754][PYTHON][FOLLOWUP][BACKPORT-2.3] Move UDF stop iteration wrapping from driver to executor

SPARK-23754 was fixed in #21383 by changing the UDF code to wrap the user
function, but this required a hack to save its argspec. This PR reverts that
change and instead fixes the `StopIteration` bug in the worker.

The root of the problem is that when a user-supplied function raises a
`StopIteration`, PySpark may silently stop processing data if that function is
called inside a for loop. The solution is to catch `StopIteration` exceptions
and re-raise them as `RuntimeError`s, so that the execution fails and the
error is reported to the user. This is done with the `fail_on_stopiteration`
wrapper (sketched below), applied in different ways depending on where the
function is used:
 - In RDDs, the user function is wrapped in the driver, because the function
is also called in the driver itself.
 - In SQL UDFs, the function is wrapped in the worker, since all processing
happens there. Moreover, the worker needs the signature of the user function,
which is lost when the function is wrapped in the driver; passing that
signature to the worker separately would require an ugly hack.
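
For context, a minimal sketch of how such a wrapper can work (illustrative
only; it mirrors the docstring and error message visible in the
python/pyspark/util.py and test diffs below, and is not necessarily the exact
upstream implementation):

def fail_on_stopiteration(f):
    # Re-raise a StopIteration coming from user code as a RuntimeError, so the
    # for loops that Spark runs over the data do not terminate silently and
    # the task fails with a visible error instead.
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except StopIteration as exc:
            raise RuntimeError(
                "Caught StopIteration thrown from user's code; failing the task",
                exc)
    return wrapper

In the worker, read_single_udf applies this wrapper to the chained row
function before passing it to wrap_udf and the pandas UDF wrappers, as shown
in the python/pyspark/worker.py hunk below.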

HyukjinKwon

Author: edorigatti <emilio.doriga...@gmail.com>
Author: e-dorigatti <emilio.doriga...@gmail.com>

Closes #21538 from e-dorigatti/branch-2.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/470cacd4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/470cacd4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/470cacd4

Branch: refs/heads/branch-2.3
Commit: 470cacd4982ca369ffd294ee37abfa1864d39967
Parents: a55de38
Author: edorigatti <emilio.doriga...@gmail.com>
Authored: Wed Jun 13 09:06:06 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Wed Jun 13 09:06:06 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 54 ++++++++++++++++++++++++++++------------
 python/pyspark/sql/udf.py   |  4 +--
 python/pyspark/tests.py     | 37 ++++++++++++++++-----------
 python/pyspark/util.py      |  2 +-
 python/pyspark/worker.py    | 11 +++++---
 5 files changed, 70 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 818ba83..aa7d8eb 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -853,22 +853,6 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
-    def test_stopiteration_in_udf(self):
-        # test for SPARK-23754
-        from pyspark.sql.functions import udf
-        from py4j.protocol import Py4JJavaError
-
-        def foo(x):
-            raise StopIteration()
-
-        with self.assertRaises(Py4JJavaError) as cm:
-            self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
-
-        self.assertIn(
-            "Caught StopIteration thrown from user's code; failing the task",
-            cm.exception.java_exception.toString()
-        )
-
     def test_validate_column_types(self):
         from pyspark.sql.functions import udf, to_json
         from pyspark.sql.column import _to_java_column
@@ -3917,6 +3901,44 @@ class PandasUDFTests(ReusedSQLTestCase):
                 def foo(k, v):
                     return k
 
+    def test_stopiteration_in_udf(self):
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+        from py4j.protocol import Py4JJavaError
+
+        def foo(x):
+            raise StopIteration()
+
+        def foofoo(x, y):
+            raise StopIteration()
+
+        exc_message = "Caught StopIteration thrown from user's code; failing the task"
+        df = self.spark.range(0, 100)
+
+        # plain udf (test for SPARK-23754)
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn('v', udf(foo)('id')).collect
+        )
+
+        # pandas scalar udf
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn(
+                'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
+            ).collect
+        )
+
+        # pandas grouped map
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').apply(
+                pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
+            ).collect
+        )
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 7d813af..671e568 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -24,7 +24,6 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \
     to_arrow_type, to_arrow_schema
-from pyspark.util import fail_on_stopiteration
 
 __all__ = ["UDFRegistration"]
 
@@ -155,8 +154,7 @@ class UserDefinedFunction(object):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext
 
-        func = fail_on_stopiteration(self.func)
-        wrapped_func = _wrap_function(sc, func, self.returnType)
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)

http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index af39450..81bff4b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1270,27 +1270,34 @@ class RDDTests(ReusedPySparkTestCase):
         self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
         self.assertEqual([], rdd.pipe('grep 4').collect())
 
-    def test_stopiteration_in_client_code(self):
+    def test_stopiteration_in_user_code(self):
 
         def stopit(*x):
             raise StopIteration()
 
         seq_rdd = self.sc.parallelize(range(10))
         keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
-
-        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
-        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
-
-        # the exception raised is non-deterministic
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+        msg = "Caught StopIteration thrown from user's code; failing the task"
+
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg,
+                                seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+
+        # these methods call the user function both in the driver and in the executor
+        # the exception raised is different according to where the StopIteration happens
+        # RuntimeError is raised if in the driver
+        # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, lambda *x: 1, stopit)
 
 
 class ProfilerTests(PySparkTestCase):

http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 83d528f..94f51ee 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -48,7 +48,7 @@ def _exception_message(excp):
 def fail_on_stopiteration(f):
     """
     Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
-    prevents silent loss of data when 'f' is used in a for loop
+    prevents silent loss of data when 'f' is used in a for loop in Spark code
     """
     def wrapper(*args, **kwargs):
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 44e9106..788b323 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -35,6 +35,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
+from pyspark.util import fail_on_stopiteration
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -122,13 +123,17 @@ def read_single_udf(pickleSer, infile, eval_type):
         else:
             row_func = chain(row_func, f)
 
+    # make sure StopIteration's raised in the user code are not ignored
+    # when they are processed in a for loop, raise them as RuntimeError's instead
+    func = fail_on_stopiteration(row_func)
+
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type)
     else:
-        return arg_offsets, wrap_udf(row_func, return_type)
+        return arg_offsets, wrap_udf(func, return_type)
 
 
 def read_udfs(pickleSer, infile, eval_type):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
