This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new a281b9c  [SPARK-31549][PYSPARK] Add a developer API invoking collect on Python RDD with user-specified job group

a281b9c is described below

commit a281b9cc64910f7f708341e8379cd2878461186a
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Fri May 1 10:08:16 2020 +0900

    [SPARK-31549][PYSPARK] Add a developer API invoking collect on Python RDD with user-specified job group

    ### What changes were proposed in this pull request?

    This adds a new API to the PySpark `RDD` class:

        def collectWithJobGroup(self, groupId, description, interruptOnCancel=False)

    This API does the same thing as `rdd.collect`, but it lets the caller specify the job
    group for the collect. The reason for adding this API is that the following pattern:

    ```
    sc.setJobGroup("group-id...")
    rdd.collect()
    ```

    does not work correctly in PySpark: the `setJobGroup` call may not apply to the job
    started by the collect. This relates to a bug discussed in
    https://issues.apache.org/jira/browse/SPARK-31549 (a usage sketch follows the Scala
    diff below).

    Note: This PR is a rather temporary workaround for `PYSPARK_PIN_THREAD`, and a step
    toward migrating to `PYSPARK_PIN_THREAD` smoothly. It targets Spark 3.0.

    - `PYSPARK_PIN_THREAD` is unstable at this moment and affects whole PySpark applications.
    - It is impossible to make it a runtime configuration, as it has to be set before the
      JVM is launched.
    - There is a thread-leak issue between Python and the JVM. We should address it, but it
      is not a release blocker for Spark 3.0 since the feature is experimental. I plan to
      handle it after Spark 3.0 due to stability.

    Once `PYSPARK_PIN_THREAD` is enabled by default, we should ideally remove this API. I
    will target deprecating this API in Spark 3.1.

    ### Why are the changes needed?

    Fix bug.

    ### Does this PR introduce any user-facing change?

    A developer API in PySpark: `pyspark.RDD.collectWithJobGroup`

    ### How was this patch tested?

    Unit test.

    Closes #28395 from WeichenXu123/collect_with_job_group.

    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
    (cherry picked from commit ee1de66fe4e05754ea3f33b75b83c54772b00112)
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 .../org/apache/spark/api/python/PythonRDD.scala    | 15 ++++++
 python/pyspark/rdd.py                              | 13 +++++
 python/pyspark/tests/test_rdd.py                   | 62 ++++++++++++++++++++++
 3 files changed, 90 insertions(+)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6dc1721f..a577194 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -169,6 +169,21 @@ private[spark] object PythonRDD extends Logging {
   }
 
   /**
+   * A helper function to collect an RDD as an iterator, then serve it via socket.
+   * This method is similar to `PythonRDD.collectAndServe`, but the caller can specify the
+   * job group id, job description, and interruptOnCancel option.
+   */
+  def collectAndServeWithJobGroup[T](
+      rdd: RDD[T],
+      groupId: String,
+      description: String,
+      interruptOnCancel: Boolean): Array[Any] = {
+    val sc = rdd.sparkContext
+    sc.setJobGroup(groupId, description, interruptOnCancel)
+    serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
+  }
+
+  /**
    * A helper function to create a local RDD iterator and serve it via socket. Partitions are
    * collected as separate jobs, by order of index. Partition data is first requested by a
    * non-zero integer to start a collection job. The response is prefaced by an integer with 1
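The sketch below illustrates the intended use of the new API (it is not part of the patch): one thread runs a long collect under a named job group while another thread cancels that group. The `local[2]` master, group name, and sleep durations are illustrative assumptions.

```
import threading
import time

from pyspark import SparkContext

# Assumed local setup; any SparkContext with this patch applied would do.
sc = SparkContext("local[2]", "collect-with-job-group-sketch")


def long_collect():
    # Each task sleeps 30 seconds, so the job stays alive long enough to cancel.
    rdd = sc.parallelize([30, 30], 2).map(lambda s: time.sleep(s))
    try:
        # The job group is attached to this specific collect on the JVM side,
        # so cancellation does not depend on Python-to-JVM thread affinity.
        rdd.collectWithJobGroup("demo-group", "long-running collect",
                                interruptOnCancel=True)
        print("collect finished")
    except Exception:
        # Cancellation surfaces as an error on the collecting thread.
        print("collect was cancelled")


t = threading.Thread(target=long_collect)
t.start()
time.sleep(5)                    # give the job time to start
sc.cancelJobGroup("demo-group")  # cancels the collect running in `t`
t.join()
sc.stop()
```

This is exactly the situation where plain `sc.setJobGroup()` followed by `rdd.collect()` is unreliable without `PYSPARK_PIN_THREAD`: the two calls may land on different JVM threads, so the group is never associated with the collect job.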
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 52ab86c..b5b72da 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -889,6 +889,19 @@ class RDD(object):
             sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
         return list(_load_from_socket(sock_info, self._jrdd_deserializer))
 
+    def collectWithJobGroup(self, groupId, description, interruptOnCancel=False):
+        """
+        .. note:: Experimental
+
+        When collecting the RDD, use this method to specify the job group.
+
+        .. versionadded:: 3.0.0
+        """
+        with SCCallSiteSync(self.context) as css:
+            sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
+                self._jrdd.rdd(), groupId, description, interruptOnCancel)
+        return list(_load_from_socket(sock_info, self._jrdd_deserializer))
+
     def reduce(self, f):
         """
         Reduces the elements of this RDD using the specified commutative and
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 15cc48ae2..6b11d68 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -784,6 +784,68 @@ class RDDTests(ReusedPySparkTestCase):
             self.assertEqual(i, next(it))
 
+    def test_multiple_group_jobs(self):
+        import threading
+        group_a = "job_ids_to_cancel"
+        group_b = "job_ids_to_run"
+
+        threads = []
+        thread_ids = range(4)
+        thread_ids_to_cancel = [i for i in thread_ids if i % 2 == 0]
+        thread_ids_to_run = [i for i in thread_ids if i % 2 != 0]
+
+        # A list which records whether each job was cancelled.
+        # The index of the list is the index of the thread the job ran in.
+        is_job_cancelled = [False for _ in thread_ids]
+
+        def run_job(job_group, index):
+            """
+            Executes a job in the group ``job_group``. Each job sleeps for 15 seconds
+            and then exits.
+            """
+            try:
+                self.sc.parallelize([15]).map(lambda x: time.sleep(x)) \
+                    .collectWithJobGroup(job_group, "test rdd collect with setting job group")
+                is_job_cancelled[index] = False
+            except Exception:
+                # Assume that an exception means the job was cancelled.
+                is_job_cancelled[index] = True
+
+        # Test that the job succeeds when not cancelled.
+        run_job(group_a, 0)
+        self.assertFalse(is_job_cancelled[0])
+
+        # Run jobs
+        for i in thread_ids_to_cancel:
+            t = threading.Thread(target=run_job, args=(group_a, i))
+            t.start()
+            threads.append(t)
+
+        for i in thread_ids_to_run:
+            t = threading.Thread(target=run_job, args=(group_b, i))
+            t.start()
+            threads.append(t)
+
+        # Wait to make sure all jobs have started executing.
+        time.sleep(3)
+        # And then, cancel one job group.
+        self.sc.cancelJobGroup(group_a)
+
+        # Wait until all threads launching jobs are finished.
+        for t in threads:
+            t.join()
+
+        for i in thread_ids_to_cancel:
+            self.assertTrue(
+                is_job_cancelled[i],
+                "Thread {i}: Job in group A was not cancelled.".format(i=i))
+
+        for i in thread_ids_to_run:
+            self.assertFalse(
+                is_job_cancelled[i],
+                "Thread {i}: Job in group B did not succeed.".format(i=i))
+
+
 
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_rdd import *

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org