Repository: spark
Updated Branches:
  refs/heads/master 7a82e93b3 -> 223df5d9d


[SPARK-24397][PYSPARK] Added TaskContext.getLocalProperty(key) in Python

## What changes were proposed in this pull request?

This adds a new API `TaskContext.getLocalProperty(key)` to the Python
`TaskContext`. It mirrors the Java `TaskContext` API: it returns the property's
string value if the key exists, or `None` if the key does not exist.
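
A minimal usage sketch (assuming an existing `SparkContext` named `sc`; the key
`"myKey"` is illustrative):

```python
from pyspark.taskcontext import TaskContext

# Set a local property on the driver; it propagates to tasks
# submitted from this thread.
sc.setLocalProperty("myKey", "myValue")

def read_prop(_):
    # Inside a running task, read the property back via the new API.
    return TaskContext.get().getLocalProperty("myKey")

print(sc.parallelize(range(1), 1).map(read_prop).collect())  # ['myValue']
print(sc.parallelize(range(1), 1)
        .map(lambda _: TaskContext.get().getLocalProperty("absent"))
        .collect())  # [None]
```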

## How was this patch tested?
A new test was added in `python/pyspark/tests.py`.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #21437 from tdas/SPARK-24397.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/223df5d9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/223df5d9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/223df5d9

Branch: refs/heads/master
Commit: 223df5d9d4fbf48db017edb41f9b7e4033679f35
Parents: 7a82e93
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Thu May 31 11:23:57 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Thu May 31 11:23:57 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRunner.scala    |  7 +++++++
 python/pyspark/taskcontext.py                         |  7 +++++++
 python/pyspark/tests.py                               | 14 ++++++++++++++
 python/pyspark/worker.py                              |  6 ++++++
 4 files changed, 34 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index f075a7e..41eac10 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -183,6 +183,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
         dataOut.writeLong(context.taskAttemptId())
+        val localProps = context.asInstanceOf[TaskContextImpl].getLocalProperties.asScala
+        dataOut.writeInt(localProps.size)
+        localProps.foreach { case (k, v) =>
+          PythonRDD.writeUTF(k, dataOut)
+          PythonRDD.writeUTF(v, dataOut)
+        }
+
         // sparkFilesDir
         PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
         // Python includes (*.zip and *.egg files)
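
For reference, the properties travel over the existing worker handshake as a
length-prefixed list: a count, then UTF-encoded key/value pairs, which
`worker.py` reads back in the same order (see the last hunk below). A
standalone Python sketch of that framing (the `write_utf` helper is
hypothetical, mirroring what `PythonRDD.writeUTF` does: a 4-byte big-endian
length followed by UTF-8 bytes):

```python
import struct

def write_utf(s, out):
    # Mirrors PythonRDD.writeUTF: 4-byte big-endian length, then UTF-8 bytes.
    data = s.encode("utf-8")
    out.write(struct.pack(">i", len(data)))
    out.write(data)

def write_local_properties(props, out):
    # Count first, then each key/value pair, matching the Scala writer above.
    out.write(struct.pack(">i", len(props)))
    for k, v in props.items():
        write_utf(k, out)
        write_utf(v, out)
```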

http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/python/pyspark/taskcontext.py
----------------------------------------------------------------------
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index e5218d9..63ae1f3 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -34,6 +34,7 @@ class TaskContext(object):
     _partitionId = None
     _stageId = None
     _taskAttemptId = None
+    _localProperties = None
 
     def __new__(cls):
         """Even if users construct TaskContext instead of using get, give them 
the singleton."""
@@ -88,3 +89,9 @@ class TaskContext(object):
         TaskAttemptID.
         """
         return self._taskAttemptId
+
+    def getLocalProperty(self, key):
+        """
+        Get a local property set upstream in the driver, or None if it is missing.
+        """
+        return self._localProperties.get(key, None)

http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 3b37cc0..30723b8 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -574,6 +574,20 @@ class TaskContextTests(PySparkTestCase):
         tc = TaskContext.get()
         self.assertTrue(tc is None)
 
+    def test_get_local_property(self):
+        """Verify that local properties set on the driver are available in 
TaskContext."""
+        key = "testkey"
+        value = "testvalue"
+        self.sc.setLocalProperty(key, value)
+        try:
+            rdd = self.sc.parallelize(range(1), 1)
+            prop1 = rdd.map(lambda x: TaskContext.get().getLocalProperty(key)).collect()[0]
+            self.assertEqual(prop1, value)
+            prop2 = rdd.map(lambda x: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
+            self.assertTrue(prop2 is None)
+        finally:
+            self.sc.setLocalProperty(key, None)
+
 
 class RDDTests(ReusedPySparkTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/223df5d9/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 5d2e58b..fbcb8af 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -222,6 +222,12 @@ def main(infile, outfile):
         taskContext._partitionId = read_int(infile)
         taskContext._attemptNumber = read_int(infile)
         taskContext._taskAttemptId = read_long(infile)
+        taskContext._localProperties = dict()
+        for i in range(read_int(infile)):
+            k = utf8_deserializer.loads(infile)
+            v = utf8_deserializer.loads(infile)
+            taskContext._localProperties[k] = v
+
         shuffle.MemoryBytesSpilled = 0
         shuffle.DiskBytesSpilled = 0
         _accumulatorRegistry.clear()

