GitHub user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/22085#discussion_r209846054
--- Diff: python/pyspark/taskcontext.py ---
@@ -95,3 +95,92 @@ def getLocalProperty(self, key):
Get a local property set upstream in the driver, or None if it is
missing.
"""
return self._localProperties.get(key, None)
+
+
class BarrierTaskContext(TaskContext):

    """
    .. note:: Experimental

    A TaskContext with extra info and tooling for a barrier stage. To access the
    BarrierTaskContext for a running task, use:
    L{BarrierTaskContext.get()}.

    .. versionadded:: 2.4.0
    """

    # Handle to the underlying barrier context. It is assigned at class level by
    # _initialize(), so every instance reads the same value; None means
    # _initialize() has not been called yet.
    _barrierContext = None

    def __init__(self):
        """Construct a BarrierTaskContext, use get instead"""
        pass

    @classmethod
    def _getOrCreate(cls):
        """
        Internal function to get or create global BarrierTaskContext.

        ``_taskContext`` is not declared on this class, so attribute lookup falls
        through to the parent TaskContext. A value found there may be a plain
        TaskContext (lacking barrier() and getTaskInfos()), so anything that is
        not a BarrierTaskContext is replaced instead of returned.
        """
        if not isinstance(cls._taskContext, BarrierTaskContext):
            cls._taskContext = BarrierTaskContext()
        return cls._taskContext

    @classmethod
    def get(cls):
        """
        Return the currently active BarrierTaskContext. This can be called inside of
        user functions to access contextual information about running tasks.

        .. note:: Must be called on the worker, not the driver. Returns None if not
            initialized.
        """
        ctx = cls._taskContext
        # A plain TaskContext inherited from the parent class is not a barrier
        # context; report "not initialized" instead of returning it.
        return ctx if isinstance(ctx, BarrierTaskContext) else None

    @classmethod
    def _initialize(cls, ctx):
        """
        Initialize BarrierTaskContext, other methods within BarrierTaskContext can
        only be called after BarrierTaskContext is initialized.

        :param ctx: the barrier context object that barrier() and getTaskInfos()
            delegate to.
        """
        cls._barrierContext = ctx

    def _checkInitialized(self, method):
        """
        Raise RuntimeError if _initialize() has not supplied a barrier context.

        :param method: name of the public method being called, used in the message.
        """
        if self._barrierContext is None:
            # Adjacent literals are joined before ``%`` applies, so the message
            # is identical to the one the public methods used to build inline.
            raise RuntimeError("Not supported to call %s() before initialize "
                               "BarrierTaskContext." % method)

    def barrier(self):
        """
        .. note:: Experimental

        Sets a global barrier and waits until all tasks in this stage hit this
        barrier. Note this method is only allowed for a BarrierTaskContext.

        :raises RuntimeError: if BarrierTaskContext has not been initialized.

        .. versionadded:: 2.4.0
        """
        self._checkInitialized("barrier")
        self._barrierContext.barrier()

    def getTaskInfos(self):
        """
        .. note:: Experimental

        Returns all task infos in this barrier stage, ordered by partitionId, as a
        list of BarrierTaskInfo. Note this method is only allowed for a
        BarrierTaskContext.

        :raises RuntimeError: if BarrierTaskContext has not been initialized.

        .. versionadded:: 2.4.0
        """
        self._checkInitialized("getTaskInfos")
        java_list = self._barrierContext.getTaskInfos()
        return [BarrierTaskInfo(info) for info in java_list]
+
+
class BarrierTaskInfo(object):
    """
    .. note:: Experimental

    Carries all task infos of a barrier task.

    :param info: Java-side task info object (as returned by the barrier context's
        getTaskInfos()); only its ``address()`` accessor is read.
    :ivar address: the value returned by ``info.address()``.

    .. versionadded:: 2.4.0
    """

    def __init__(self, info):
        # ``info`` comes from Java, where ``address`` is an accessor method:
        # call it to store the value, not a reference to the Java method.
        self.address = info.address()
--- End diff --
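* This should be `info.address()`: `info` is a Java object, so `address` is an accessor method that must be called.
* Consider renaming `info` to `jobj` to make it clear that this object comes from Java.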
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]