asfgit closed pull request #23470: [SPARK-26549][PySpark] Fix for python worker reuse take no effect for parallelize lazy iterable range
URL: https://github.com/apache/spark/pull/23470
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6137ed25a0dd9..180a3e882dab6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -493,6 +493,14 @@ def getStart(split):
return start0 + int((split * size / numSlices)) * step
def f(split, iterator):
+ # it's an empty iterator here but we need this line for triggering the
+ # logic of signal handling in FramedSerializer.load_stream, for instance,
+ # SpecialLengths.END_OF_DATA_SECTION in _read_with_length. Since
+ # FramedSerializer.load_stream produces a generator, the control should
+ # at least be in that function once. Here we do it by explicitly converting
+ # the empty iterator to a list, thus make sure worker reuse takes effect.
+ # See more details in SPARK-26549.
+ assert len(list(iterator)) == 0
return xrange(getStart(split), getStart(split + 1), step)
return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index a33b77d983419..a4f108f18e17d 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -22,7 +22,7 @@
from py4j.protocol import Py4JJavaError
-from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest
+from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
if sys.version_info[0] >= 3:
xrange = range
@@ -145,6 +145,16 @@ def test_with_different_versions_of_python(self):
self.sc.pythonVer = version
+class WorkerReuseTest(PySparkTestCase):
+
+ def test_reuse_worker_of_parallelize_xrange(self):
+ rdd = self.sc.parallelize(xrange(20), 8)
+ previous_pids = rdd.map(lambda x: os.getpid()).collect()
+ current_pids = rdd.map(lambda x: os.getpid()).collect()
+ for pid in current_pids:
+ self.assertTrue(pid in previous_pids)
+
+
if __name__ == "__main__":
import unittest
from pyspark.tests.test_worker import *
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]