Use the cythonized DoFnContext everywhere. In particular, we want it to be strongly typed in DoFnRunner, since the runner interacts with the context for every element it processes.
Also moves the creation of this context, which is an implementation detail, from the runner to the API for future flexibility. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/0529a127 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/0529a127 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/0529a127 Branch: refs/heads/python-sdk Commit: 0529a12771f55dc35a4f4e1a7b63af6ba4fe4d84 Parents: d72ffb0 Author: Robert Bradshaw <[email protected]> Authored: Thu Aug 4 10:55:18 2016 -0700 Committer: Dan Halperin <[email protected]> Committed: Wed Aug 10 09:42:00 2016 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/runners/common.pxd | 2 +- sdks/python/apache_beam/runners/common.py | 18 ++++++++++++++---- sdks/python/apache_beam/runners/direct_runner.py | 9 +++------ .../runners/inprocess/transform_evaluator.py | 8 +++----- 4 files changed, 21 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0529a127/sdks/python/apache_beam/runners/common.pxd ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd index 7acd049..5cd4cf8 100644 --- a/sdks/python/apache_beam/runners/common.pxd +++ b/sdks/python/apache_beam/runners/common.pxd @@ -32,7 +32,7 @@ cdef class DoFnRunner(Receiver): cdef object dofn cdef object dofn_process cdef object window_fn - cdef object context # TODO(robertwb): Make this a DoFnContext + cdef DoFnContext context cdef object tagged_receivers cdef LoggingContext logging_context cdef object step_name http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0529a127/sdks/python/apache_beam/runners/common.py 
---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index c017704..67277c3 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -59,12 +59,16 @@ class DoFnRunner(Receiver): kwargs, side_inputs, windowing, - context, - tagged_receivers, + context=None, + tagged_receivers=None, logger=None, step_name=None, # Preferred alternative to logger - logging_context=None): + # TODO(robertwb): Remove once all runners are updated. + logging_context=None, + # Preferred alternative to context + # TODO(robertwb): Remove once all runners are updated. + state=None): if not args and not kwargs: self.dofn = fn self.dofn_process = fn.process @@ -85,10 +89,16 @@ class DoFnRunner(Receiver): self.dofn_process = lambda context: fn.process(context, *args, **kwargs) self.window_fn = windowing.windowfn - self.context = context self.tagged_receivers = tagged_receivers self.step_name = step_name + if state: + assert context is None + self.context = DoFnContext(self.step_name, state=state) + else: + assert context is not None + self.context = context + if logging_context: self.logging_context = logging_context else: http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0529a127/sdks/python/apache_beam/runners/direct_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/direct_runner.py b/sdks/python/apache_beam/runners/direct_runner.py index e0df439..a62ddf7 100644 --- a/sdks/python/apache_beam/runners/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct_runner.py @@ -43,7 +43,6 @@ from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState from apache_beam.runners.runner import PValueCache -from apache_beam.transforms import DoFnProcessContext 
from apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue from apache_beam.typehints.typecheck import OutputCheckWrapperDoFn @@ -138,9 +137,6 @@ class DirectPipelineRunner(PipelineRunner): @skip_if_cached def run_ParDo(self, transform_node): transform = transform_node.transform - # TODO(gildea): what is the appropriate object to attach the state to? - context = DoFnProcessContext(label=transform.label, - state=DoFnState(self._counter_factory)) side_inputs = [self._cache.get_pvalue(view) for view in transform_node.side_inputs] @@ -176,8 +172,9 @@ class DirectPipelineRunner(PipelineRunner): runner = DoFnRunner(transform.dofn, transform.args, transform.kwargs, side_inputs, transform_node.inputs[0].windowing, - context, TaggedReceivers(), - step_name=transform_node.full_label) + tagged_receivers=TaggedReceivers(), + step_name=transform_node.full_label, + state=DoFnState(self._counter_factory)) runner.start() for v in self._cache.get_pvalue(transform_node.inputs[0]): runner.process(v) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/0529a127/sdks/python/apache_beam/runners/inprocess/transform_evaluator.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/inprocess/transform_evaluator.py b/sdks/python/apache_beam/runners/inprocess/transform_evaluator.py index 138ea87..9aeda46 100644 --- a/sdks/python/apache_beam/runners/inprocess/transform_evaluator.py +++ b/sdks/python/apache_beam/runners/inprocess/transform_evaluator.py @@ -30,7 +30,6 @@ from apache_beam.runners.common import DoFnState from apache_beam.runners.inprocess.inprocess_watermark_manager import InProcessWatermarkManager from apache_beam.runners.inprocess.inprocess_transform_result import InProcessTransformResult from apache_beam.transforms import core -from apache_beam.transforms import DoFnProcessContext from apache_beam.transforms import sideinputs from 
apache_beam.transforms.window import GlobalWindows from apache_beam.transforms.window import WindowedValue @@ -337,8 +336,6 @@ class _ParDoEvaluator(_TransformEvaluator): self._tagged_receivers[None].tag = None # main_tag is None. self._counter_factory = counters.CounterFactory() - context = DoFnProcessContext(label=transform.label, - state=DoFnState(self._counter_factory)) dofn = copy.deepcopy(transform.dofn) @@ -351,8 +348,9 @@ class _ParDoEvaluator(_TransformEvaluator): self.runner = DoFnRunner(dofn, transform.args, transform.kwargs, self._side_inputs, self._applied_ptransform.inputs[0].windowing, - context, self._tagged_receivers, - step_name=self._applied_ptransform.full_label) + tagged_receivers=self._tagged_receivers, + step_name=self._applied_ptransform.full_label, + state=DoFnState(self._counter_factory)) self.runner.start() def process_element(self, element):
