Repository: spark
Updated Branches:
  refs/heads/master 7a82e93b3 -> 223df5d9d
[SPARK-24397][PYSPARK] Added TaskContext.getLocalProperty(key) in Python

## What changes were proposed in this pull request?

This adds a new API `TaskContext.getLocalProperty(key)` to the Python TaskContext. It mirrors the Java TaskContext API of returning a string value if the key exists, or None if the key does not exist.

## How was this patch tested?

New test added.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #21437 from tdas/SPARK-24397.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/223df5d9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/223df5d9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/223df5d9

Branch: refs/heads/master
Commit: 223df5d9d4fbf48db017edb41f9b7e4033679f35
Parents: 7a82e93
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Thu May 31 11:23:57 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Thu May 31 11:23:57 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRunner.scala |  7 +++++++
 python/pyspark/taskcontext.py                      |  7 +++++++
 python/pyspark/tests.py                            | 14 ++++++++++++++
 python/pyspark/worker.py                           |  6 ++++++
 4 files changed, 34 insertions(+)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index f075a7e..41eac10 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
         dataOut.writeLong(context.taskAttemptId())
+        val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala
+        dataOut.writeInt(localProps.size)
+        localProps.foreach { case (k, v) =>
+          PythonRDD.writeUTF(k, dataOut)
+          PythonRDD.writeUTF(v, dataOut)
+        }
+
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)

http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/python/pyspark/taskcontext.py
----------------------------------------------------------------------
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index e5218d9..63ae1f3 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -34,6 +34,7 @@ class TaskContext(object):
     _partitionId = None
     _stageId = None
     _taskAttemptId = None
+    _localProperties = None
 
     def __new__(cls):
         """Even if users construct TaskContext instead of using get, give them the singleton."""
@@ -88,3 +89,9 @@ class TaskContext(object):
         TaskAttemptID.
         """
         return self._taskAttemptId
+
+    def getLocalProperty(self, key):
+        """
+        Get a local property set upstream in the driver, or None if it is missing.
+        """
+        return self._localProperties.get(key, None)

http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 3b37cc0..30723b8 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -574,6 +574,20 @@ class TaskContextTests(PySparkTestCase):
         tc = TaskContext.get()
         self.assertTrue(tc is None)
 
+    def test_get_local_property(self):
+        """Verify that local properties set on the driver are available in TaskContext."""
+        key = "testkey"
+        value = "testvalue"
+        self.sc.setLocalProperty(key, value)
+        try:
+            rdd = self.sc.parallelize(range(1), 1)
+            prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0]
+            self.assertEqual(prop1, value)
+            prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
+            self.assertTrue(prop2 is None)
+        finally:
+            self.sc.setLocalProperty(key, None)
+
 
 class RDDTests(ReusedPySparkTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 5d2e58b..fbcb8af 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -222,6 +222,12 @@ def main(infile, outfile):
         taskContext._partitionId = read_int(infile)
         taskContext._attemptNumber = read_int(infile)
         taskContext._taskAttemptId = read_long(infile)
+        taskContext._localProperties = dict()
+        for i in range(read_int(infile)):
+            k = utf8_deserializer.loads(infile)
+            v = utf8_deserializer.loads(infile)
+            taskContext._localProperties[k] = v
+
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0
         _accumulatorRegistry.clear()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: 
commits-h...@spark.apache.org