This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4be760e  [BEAM-9608] Increase reliance on Context Managers for FnApiRunner
     new a55ee53  Merge pull request #11229 from [BEAM-9608] Increasing scope of context managers for FnApiRunner
4be760e is described below

commit 4be760ee74be196d2edf3b6f05cefdce820c5e26
Author: Pablo Estrada <[email protected]>
AuthorDate: Wed Mar 25 14:27:56 2020 -0700

    [BEAM-9608] Increase reliance on Context Managers for FnApiRunner
---
 .../runners/portability/fn_api_runner/execution.py | 119 ++++++++++++---
 .../runners/portability/fn_api_runner/fn_runner.py | 166 +++++++--------------
 .../portability/fn_api_runner/worker_handlers.py   |   6 +-
 3 files changed, 153 insertions(+), 138 deletions(-)

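The patch moves per-stage setup behind execution.BundleContextManager, which is now constructed from an execution context, a stage and a worker count, and builds its worker handlers and ProcessBundleDescriptor lazily on first access. A minimal sketch of that lazy-initialization pattern (simplified stand-in names, not code from the patch):

# Illustrative sketch: expensive per-stage resources are only built the
# first time they are accessed, mirroring the lazy properties added to
# BundleContextManager below.
class LazyBundleContext(object):
  def __init__(self, stage_name, num_workers):
    self.stage_name = stage_name
    self.num_workers = num_workers
    self._descriptor = None  # counterpart of _process_bundle_descriptor

  @property
  def descriptor(self):
    if self._descriptor is None:
      # The real class assembles a ProcessBundleDescriptor proto from the
      # stage's transforms, coders and windowing strategies.
      self._descriptor = {'id': self.stage_name, 'workers': self.num_workers}
    return self._descriptor
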
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 99c2438..02855dc 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -21,6 +21,8 @@ from __future__ import absolute_import
 
 import collections
 import itertools
+from typing import TYPE_CHECKING
+from typing import MutableMapping
 
 from typing_extensions import Protocol
 
@@ -29,14 +31,21 @@ from apache_beam.coders.coder_impl import create_InputStream
 from apache_beam.coders.coder_impl import create_OutputStream
 from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_fn_api_pb2
+from apache_beam.runners import pipeline_context
 from apache_beam.runners.portability.fn_api_runner.translations import only_element
 from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
+from apache_beam.runners.portability.fn_api_runner.translations import unique_name
 from apache_beam.runners.worker import bundle_processor
 from apache_beam.transforms import trigger
 from apache_beam.transforms.window import GlobalWindow
 from apache_beam.transforms.window import GlobalWindows
 from apache_beam.utils import windowed_value
 
+if TYPE_CHECKING:
+  from apache_beam.coders.coder_impl import CoderImpl
+  from apache_beam.runners.portability.fn_api_runner import translations
+  from apache_beam.runners.portability.fn_api_runner import worker_handlers
+
 
 class Buffer(Protocol):
   def __iter__(self):
@@ -245,37 +254,108 @@ class FnApiRunnerExecutionContext(object):
        ``beam.PCollection``.
  """
   def __init__(self,
-      worker_handler_factory,  # type: Callable[[Optional[str], int], List[WorkerHandler]]
+      worker_handler_manager,  # type: worker_handlers.WorkerHandlerManager
       pipeline_components,  # type: beam_runner_api_pb2.Components
       safe_coders,
       data_channel_coders,
                ):
     """
-    :param worker_handler_factory: A ``callable`` that takes in an environment
-        id and a number of workers, and returns a list of ``WorkerHandler``s.
+    :param worker_handler_manager: This class manages the set of worker
+        handlers, and the communication with state / control APIs.
     :param pipeline_components:  (beam_runner_api_pb2.Components): TODO
     :param safe_coders:
     :param data_channel_coders:
     """
     self.pcoll_buffers = {}  # type: MutableMapping[bytes, PartitionableBuffer]
-    self.worker_handler_factory = worker_handler_factory
+    self.worker_handler_manager = worker_handler_manager
     self.pipeline_components = pipeline_components
     self.safe_coders = safe_coders
     self.data_channel_coders = data_channel_coders
 
+    self.pipeline_context = pipeline_context.PipelineContext(
+        self.pipeline_components,
+        iterable_state_write=self._iterable_state_write)
+    self._last_uid = -1
+
+  @property
+  def state_servicer(self):
+    # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
+    return self.worker_handler_manager.state_servicer
+
+  def next_uid(self):
+    self._last_uid += 1
+    return str(self._last_uid)
+
+  def _iterable_state_write(self, values, element_coder_impl):
+    # type: (...) -> bytes
+    token = unique_name(None, 'iter').encode('ascii')
+    out = create_OutputStream()
+    for element in values:
+      element_coder_impl.encode_to_stream(element, out, True)
+    self.worker_handler_manager.state_servicer.append_raw(
+        beam_fn_api_pb2.StateKey(
+            runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
+        out.get())
+    return token
+
 
 class BundleContextManager(object):
 
   def __init__(self,
-      execution_context, # type: FnApiRunnerExecutionContext
-      process_bundle_descriptor,  # type: beam_fn_api_pb2.ProcessBundleDescriptor
-      worker_handler,  # type: fn_runner.WorkerHandler
-      p_context,  # type: pipeline_context.PipelineContext
-               ):
+               execution_context, # type: FnApiRunnerExecutionContext
+               stage,  # type: translations.Stage
+               num_workers,  # type: int
+              ):
     self.execution_context = execution_context
-    self.process_bundle_descriptor = process_bundle_descriptor
-    self.worker_handler = worker_handler
-    self.pipeline_context = p_context
+    self.stage = stage
+    self.bundle_uid = self.execution_context.next_uid()
+    self.num_workers = num_workers
+
+    # Properties that are lazily initialized
+    self._process_bundle_descriptor = None
+    self._worker_handlers = None
+
+  @property
+  def worker_handlers(self):
+    if self._worker_handlers is None:
+      self._worker_handlers = (
+          self.execution_context.worker_handler_manager.get_worker_handlers(
+              self.stage.environment, self.num_workers))
+    return self._worker_handlers
+
+  def data_api_service_descriptor(self):
+    # All worker_handlers share the same grpc server, so we can read grpc server
+    # info from any worker_handler and read from the first worker_handler.
+    return self.worker_handlers[0].data_api_service_descriptor()
+
+  def state_api_service_descriptor(self):
+    # All worker_handlers share the same grpc server, so we can read grpc server
+    # info from any worker_handler and read from the first worker_handler.
+    return self.worker_handlers[0].state_api_service_descriptor()
+
+  @property
+  def process_bundle_descriptor(self):
+    if self._process_bundle_descriptor is None:
+      self._process_bundle_descriptor = self._build_process_bundle_descriptor()
+    return self._process_bundle_descriptor
+
+  def _build_process_bundle_descriptor(self):
+    res = beam_fn_api_pb2.ProcessBundleDescriptor(
+        id=self.bundle_uid,
+        transforms={
+            transform.unique_name: transform
+            for transform in self.stage.transforms
+        },
+        pcollections=dict(
+            self.execution_context.pipeline_components.pcollections.items()),
+        coders=dict(self.execution_context.pipeline_components.coders.items()),
+        windowing_strategies=dict(
+            self.execution_context.pipeline_components.windowing_strategies.
+            items()),
+        environments=dict(
+            self.execution_context.pipeline_components.environments.items()),
+        state_api_service_descriptor=self.state_api_service_descriptor())
+    return res
 
   def get_input_coder_impl(self, transform_id):
     # type: (str) -> CoderImpl
@@ -284,10 +364,10 @@ class BundleContextManager(object):
     ).coder_id
     assert coder_id
     if coder_id in self.execution_context.safe_coders:
-      return self.pipeline_context.coders[
+      return self.execution_context.pipeline_context.coders[
           self.execution_context.safe_coders[coder_id]].get_impl()
     else:
-      return self.pipeline_context.coders[coder_id].get_impl()
+      return self.execution_context.pipeline_context.coders[coder_id].get_impl()
 
   def get_buffer(self, buffer_id, transform_id):
     # type: (bytes, str) -> PartitionableBuffer
@@ -310,15 +390,16 @@ class BundleContextManager(object):
             original_gbk_transform]
         input_pcoll = only_element(list(transform_proto.inputs.values()))
         output_pcoll = only_element(list(transform_proto.outputs.values()))
-        pre_gbk_coder = self.pipeline_context.coders[
+        pre_gbk_coder = self.execution_context.pipeline_context.coders[
             self.execution_context.safe_coders[
                 self.execution_context.data_channel_coders[input_pcoll]]]
-        post_gbk_coder = self.pipeline_context.coders[
+        post_gbk_coder = self.execution_context.pipeline_context.coders[
             self.execution_context.safe_coders[
                 self.execution_context.data_channel_coders[output_pcoll]]]
-        windowing_strategy = self.pipeline_context.windowing_strategies[
-            self.execution_context.pipeline_components.
-            pcollections[output_pcoll].windowing_strategy_id]
+        windowing_strategy = (
+            self.execution_context.pipeline_context.windowing_strategies[
+                self.execution_context.pipeline_components.
+                pcollections[output_pcoll].windowing_strategy_id])
         self.execution_context.pcoll_buffers[buffer_id] = GroupingBuffer(
             pre_gbk_coder, post_gbk_coder, windowing_strategy)
     else:
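
The iterable-state helper that used to be a closure inside _run_stage now lives on FnApiRunnerExecutionContext as _iterable_state_write (see the hunk above). A minimal sketch of the mechanism, with plain-Python stand-ins for the coder impl and the state servicer (none of these names are from the patch):

import io

def iterable_state_write(values, encode, state_store, token=b'iter0'):
  # Encode every element into one byte buffer, then append it to the
  # state blob stored under a runner-owned token, append_raw-style.
  out = io.BytesIO()
  for element in values:
    out.write(encode(element))
  state_store.setdefault(token, b'')
  state_store[token] += out.getvalue()
  return token

state = {}
tok = iterable_state_write([1, 2, 3], lambda x: bytes([x]), state)
assert state[tok] == b'\x01\x02\x03'
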
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index 13646da..fa0af3b 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -55,7 +55,6 @@ from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_provision_api_pb2
 from apache_beam.portability.api import beam_runner_api_pb2
-from apache_beam.runners import pipeline_context
 from apache_beam.runners import runner
 from apache_beam.runners.portability import portable_metrics
 from apache_beam.runners.portability.fn_api_runner import execution
@@ -65,7 +64,6 @@ from apache_beam.runners.portability.fn_api_runner.execution import WindowGroupi
 from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id
 from apache_beam.runners.portability.fn_api_runner.translations import only_element
 from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
-from apache_beam.runners.portability.fn_api_runner.translations import unique_name
 from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager
 from apache_beam.runners.worker import bundle_processor
 from apache_beam.transforms import environments
@@ -120,7 +118,6 @@ class FnApiRunner(runner.PipelineRunner):
           waits before requesting progress from the SDK.
     """
     super(FnApiRunner, self).__init__()
-    self._last_uid = -1
     self._default_environment = (
         default_environment or environments.EmbeddedPythonEnvironment())
     self._bundle_repeat = bundle_repeat
@@ -132,10 +129,6 @@ class FnApiRunner(runner.PipelineRunner):
         beam_provision_api_pb2.ProvisionInfo(
             retrieval_token='unused-retrieval-token'))
 
-  def _next_uid(self):
-    self._last_uid += 1
-    return str(self._last_uid)
-
   @staticmethod
   def supported_requirements():
     return (
@@ -329,7 +322,7 @@ class FnApiRunner(runner.PipelineRunner):
     monitoring_infos_by_stage = {}
 
     runner_execution_context = execution.FnApiRunnerExecutionContext(
-        worker_handler_manager.get_worker_handlers,
+        worker_handler_manager,
         stage_context.components,
         stage_context.safe_coders,
         stage_context.data_channel_coders)
@@ -337,7 +330,12 @@ class FnApiRunner(runner.PipelineRunner):
     try:
       with self.maybe_profile():
         for stage in stages:
-          stage_results = self._run_stage(runner_execution_context, stage)
+          bundle_context_manager = execution.BundleContextManager(
+              runner_execution_context, stage, self._num_workers)
+          stage_results = self._run_stage(
+              runner_execution_context,
+              bundle_context_manager,
+          )
           metrics_by_stage[stage.name] = stage_results.process_bundle.metrics
           monitoring_infos_by_stage[stage.name] = (
               stage_results.process_bundle.monitoring_infos)
@@ -347,14 +345,13 @@ class FnApiRunner(runner.PipelineRunner):
         runner.PipelineState.DONE, monitoring_infos_by_stage, metrics_by_stage)
 
   def _store_side_inputs_in_state(self,
-                                  bundle_context_manager,  # type: execution.BundleContextManager
                                   runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                                   data_side_input,  # type: DataSideInput
                                  ):
     # type: (...) -> None
     for (transform_id, tag), (buffer_id, si) in data_side_input.items():
       _, pcoll_id = split_buffer_id(buffer_id)
-      value_coder = bundle_context_manager.pipeline_context.coders[
+      value_coder = runner_execution_context.pipeline_context.coders[
           runner_execution_context.safe_coders[
               runner_execution_context.data_channel_coders[pcoll_id]]]
       elements_by_window = WindowGroupingBuffer(si, value_coder)
@@ -369,8 +366,9 @@ class FnApiRunner(runner.PipelineRunner):
           state_key = beam_fn_api_pb2.StateKey(
               iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                   transform_id=transform_id, side_input_id=tag, window=window))
-          bundle_context_manager.worker_handler.state.append_raw(
-              state_key, elements_data)
+          (
+              runner_execution_context.worker_handler_manager.state_servicer.
+              append_raw(state_key, elements_data))
       elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
         for key, window, elements_data in elements_by_window.encoded_items():
           state_key = beam_fn_api_pb2.StateKey(
@@ -379,14 +377,15 @@ class FnApiRunner(runner.PipelineRunner):
                   side_input_id=tag,
                   window=window,
                   key=key))
-          bundle_context_manager.worker_handler.state.append_raw(
-              state_key, elements_data)
+          (
+              runner_execution_context.worker_handler_manager.state_servicer.
+              append_raw(state_key, elements_data))
       else:
         raise ValueError("Unknown access pattern: '%s'" % si.urn)
 
   def _run_bundle_multiple_times_for_testing(
       self,
-      worker_handler_list,  # type: Sequence[WorkerHandler]
+      runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
       bundle_context_manager,  # type: execution.BundleContextManager
       data_input,
       data_output,  # type: DataOutput
@@ -398,12 +397,11 @@ class FnApiRunner(runner.PipelineRunner):
     If bundle_repeat > 0, replay every bundle for profiling and debugging.
     """
     # all workers share state, so use any worker_handler.
-    worker_handler = worker_handler_list[0]
     for k in range(self._bundle_repeat):
       try:
-        worker_handler.state.checkpoint()
+        runner_execution_context.state_servicer.checkpoint()
         testing_bundle_manager = ParallelBundleManager(
-            worker_handler_list,
+            bundle_context_manager.worker_handlers,
             lambda pcoll_id,
             transform_id: ListBuffer(
                 coder_impl=bundle_context_manager.get_input_coder_impl),
@@ -415,24 +413,24 @@ class FnApiRunner(runner.PipelineRunner):
             cache_token_generator=cache_token_generator)
         testing_bundle_manager.process_bundle(data_input, data_output)
       finally:
-        worker_handler.state.restore()
+        runner_execution_context.state_servicer.restore()
 
   def _collect_written_timers_and_add_to_deferred_inputs(
       self,
-      pipeline_components,  # type: beam_runner_api_pb2.Components
-      stage,  # type: translations.Stage
+      runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
       bundle_context_manager,  # type: execution.BundleContextManager
       deferred_inputs,  # type: MutableMapping[str, PartitionableBuffer]
-      data_channel_coders,  # type: Mapping[str, str]
   ):
     # type: (...) -> None
 
-    for transform_id, timer_writes in stage.timer_pcollections:
+    for (transform_id,
+         timer_writes) in bundle_context_manager.stage.timer_pcollections:
 
       # Queue any set timers as new inputs.
       windowed_timer_coder_impl = (
-          bundle_context_manager.pipeline_context.coders[
-              data_channel_coders[timer_writes]].get_impl())
+          runner_execution_context.pipeline_context.coders[
+              runner_execution_context.data_channel_coders[timer_writes]].
+          get_impl())
       written_timers = bundle_context_manager.get_buffer(
           create_buffer_id(timer_writes, kind='timers'), transform_id)
       if not written_timers.cleared:
@@ -511,87 +509,32 @@ class FnApiRunner(runner.PipelineRunner):
 
   def _run_stage(self,
                 runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
-                 stage,  # type: translations.Stage
+                 bundle_context_manager,  # type: execution.BundleContextManager
                 ):
     # type: (...) -> beam_fn_api_pb2.InstructionResponse
 
     """Run an individual stage.
 
     Args:
-      worker_handler_factory: A ``callable`` that takes in an environment id
-        and a number of workers, and returns a list of ``WorkerHandler``s.
-      stage (translations.Stage)
+      runner_execution_context (execution.FnApiRunnerExecutionContext): An
+        object containing execution information for the pipeline.
+      bundle_context_manager (execution.BundleContextManager): A description of the stage to execute and its context.
     """
-    def iterable_state_write(values, element_coder_impl):
-      # type: (...) -> bytes
-      token = unique_name(None, 'iter').encode('ascii')
-      out = create_OutputStream()
-      for element in values:
-        element_coder_impl.encode_to_stream(element, out, True)
-      worker_handler.state.append_raw(
-          beam_fn_api_pb2.StateKey(
-              runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
-          out.get())
-      return token
-
-    worker_handler_list = runner_execution_context.worker_handler_factory(
-        stage.environment, self._num_workers)
-
-    # All worker_handlers share the same grpc server, so we can read grpc server
-    # info from any worker_handler and read from the first worker_handler.
-    worker_handler = next(iter(worker_handler_list))
-    context = pipeline_context.PipelineContext(
-        runner_execution_context.pipeline_components,
-        iterable_state_write=iterable_state_write)
-    data_api_service_descriptor = worker_handler.data_api_service_descriptor()
-
-    _LOGGER.info('Running %s', stage.name)
-    data_input, data_side_input, data_output = self._extract_endpoints(
-        stage,
-        runner_execution_context.pipeline_components,
-        data_api_service_descriptor,
-        runner_execution_context.pcoll_buffers,
-        context,
-        runner_execution_context.safe_coders,
-        runner_execution_context.data_channel_coders)
-
-    process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
-        id=self._next_uid(),
-        transforms={
-            transform.unique_name: transform
-            for transform in stage.transforms
-        },
-        pcollections=dict(
-            runner_execution_context.pipeline_components.pcollections.items()),
-        coders=dict(
-            runner_execution_context.pipeline_components.coders.items()),
-        windowing_strategies=dict(
-            runner_execution_context.pipeline_components.windowing_strategies.
-            items()),
-        environments=dict(
-            runner_execution_context.pipeline_components.environments.items()))
-
-    bundle_context_manager = execution.BundleContextManager(
-        runner_execution_context,
-        process_bundle_descriptor,
-        worker_handler,
-        context)
+    worker_handler_list = bundle_context_manager.worker_handlers
 
-    state_api_service_descriptor = worker_handler.state_api_service_descriptor()
-    if state_api_service_descriptor:
-      process_bundle_descriptor.state_api_service_descriptor.url = (
-          state_api_service_descriptor.url)
+    _LOGGER.info('Running %s', bundle_context_manager.stage.name)
+    data_input, data_side_input, data_output = self._extract_endpoints(
+        bundle_context_manager, runner_execution_context)
 
     # Store the required side inputs into state so it is accessible for the
     # worker when it runs this bundle.
-    self._store_side_inputs_in_state(
-        bundle_context_manager, runner_execution_context, data_side_input)
+    self._store_side_inputs_in_state(runner_execution_context, data_side_input)
 
     # Change cache token across bundle repeats
     cache_token_generator = FnApiRunner.get_cache_token_generator(static=False)
 
     self._run_bundle_multiple_times_for_testing(
-        worker_handler_list,
+        runner_execution_context,
         bundle_context_manager,
         data_input,
         data_output,
@@ -601,7 +544,7 @@ class FnApiRunner(runner.PipelineRunner):
         worker_handler_list,
         bundle_context_manager.get_buffer,
         bundle_context_manager.get_input_coder_impl,
-        process_bundle_descriptor,
+        bundle_context_manager.process_bundle_descriptor,
         self._progress_frequency,
         num_workers=self._num_workers,
         cache_token_generator=cache_token_generator)
@@ -619,11 +562,7 @@ class FnApiRunner(runner.PipelineRunner):
       deferred_inputs = {}  # type: Dict[str, PartitionableBuffer]
 
       self._collect_written_timers_and_add_to_deferred_inputs(
-          runner_execution_context.pipeline_components,
-          stage,
-          bundle_context_manager,
-          deferred_inputs,
-          runner_execution_context.data_channel_coders)
+          runner_execution_context, bundle_context_manager, deferred_inputs)
       # Queue any process-initiated delayed bundle applications.
       for delayed_application in last_result.process_bundle.residual_roots:
         name = bundle_context_manager.input_for(
@@ -663,13 +602,8 @@ class FnApiRunner(runner.PipelineRunner):
     return result
 
   @staticmethod
-  def _extract_endpoints(stage,  # type: translations.Stage
-                         pipeline_components,  # type: beam_runner_api_pb2.Components
-                         data_api_service_descriptor, # type: Optional[endpoints_pb2.ApiServiceDescriptor]
-                         pcoll_buffers,  # type: MutableMapping[bytes, PartitionableBuffer]
-                         context,  # type: pipeline_context.PipelineContext
-                         safe_coders,  # type: Mapping[str, str]
-                         data_channel_coders,  # type: Mapping[str, str]
+  def _extract_endpoints(bundle_context_manager,  # type: execution.BundleContextManager
+                         runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
                          ):
    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput]
 
@@ -680,11 +614,7 @@ class FnApiRunner(runner.PipelineRunner):
     Args:
       stage (translations.Stage): The stage to extract endpoints
         for.
-      pipeline_components (beam_runner_api_pb2.Components): Components of the
-        pipeline to include coders, transforms, PCollections, etc.
       data_api_service_descriptor: A GRPC endpoint descriptor for data plane.
-      pcoll_buffers (dict): A dictionary containing buffers for PCollection
-        elements.
     Returns:
       A tuple of (data_input, data_side_input, data_output) dictionaries.
         `data_input` is a dictionary mapping (transform_name, output_name) to a
@@ -694,30 +624,34 @@ class FnApiRunner(runner.PipelineRunner):
     data_input = {}  # type: Dict[str, PartitionableBuffer]
     data_side_input = {}  # type: DataSideInput
     data_output = {}  # type: DataOutput
-
-    for transform in stage.transforms:
+    for transform in bundle_context_manager.stage.transforms:
       if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                 bundle_processor.DATA_OUTPUT_URN):
         pcoll_id = transform.spec.payload
         if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
-          coder_id = data_channel_coders[only_element(
+          coder_id = runner_execution_context.data_channel_coders[only_element(
               transform.outputs.values())]
-          coder = context.coders[safe_coders.get(coder_id, coder_id)]
+          coder = runner_execution_context.pipeline_context.coders[
+              runner_execution_context.safe_coders.get(coder_id, coder_id)]
           if pcoll_id == translations.IMPULSE_BUFFER:
             data_input[transform.unique_name] = ListBuffer(
                 coder_impl=coder.get_impl())
             data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
           else:
-            if pcoll_id not in pcoll_buffers:
-              pcoll_buffers[pcoll_id] = ListBuffer(coder_impl=coder.get_impl())
-            data_input[transform.unique_name] = pcoll_buffers[pcoll_id]
+            if pcoll_id not in runner_execution_context.pcoll_buffers:
+              runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
+                  coder_impl=coder.get_impl())
+            data_input[transform.unique_name] = (
+                runner_execution_context.pcoll_buffers[pcoll_id])
         elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
           data_output[transform.unique_name] = pcoll_id
-          coder_id = data_channel_coders[only_element(
+          coder_id = runner_execution_context.data_channel_coders[only_element(
               transform.inputs.values())]
         else:
           raise NotImplementedError
         data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
+        data_api_service_descriptor = (
+            bundle_context_manager.data_api_service_descriptor())
         if data_api_service_descriptor:
           data_spec.api_service_descriptor.url = (
               data_api_service_descriptor.url)
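
With the fn_runner.py changes above, _run_stage no longer builds worker handlers, a pipeline context or the bundle descriptor inline; it receives one BundleContextManager per stage and reads everything from it. A rough, self-contained sketch of the resulting control flow (hypothetical trimmed-down classes, not the real FnApiRunner API):

class SketchBundleContext(object):
  def __init__(self, stage_name, num_workers):
    self.stage_name = stage_name
    self.worker_handlers = ['handler-%d' % i for i in range(num_workers)]
    self.descriptor = {'id': 'bundle-' + stage_name}

def run_stage(bundle_context):
  # Handlers and the descriptor are supplied by the bundle context.
  return (bundle_context.descriptor['id'], len(bundle_context.worker_handlers))

def run_stages(stage_names, num_workers=1):
  # One bundle context per stage, as in run_pipeline/_run_stage above.
  return [run_stage(SketchBundleContext(name, num_workers))
          for name in stage_names]

print(run_stages(['read', 'gbk', 'write']))
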
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index 2b690d2..d22d911 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -759,7 +759,7 @@ class WorkerHandlerManager(object):
     self._cached_handlers = collections.defaultdict(
         list)  # type: DefaultDict[str, List[WorkerHandler]]
     self._workers_by_id = {}  # type: Dict[str, WorkerHandler]
-    self._state = StateServicer()  # rename?
+    self.state_servicer = StateServicer()
     self._grpc_server = None  # type: Optional[GrpcServer]
 
   def get_worker_handlers(
@@ -781,14 +781,14 @@ class WorkerHandlerManager(object):
       self._grpc_server = cast(GrpcServer, None)
     elif self._grpc_server is None:
       self._grpc_server = GrpcServer(
-          self._state, self._job_provision_info, self)
+          self.state_servicer, self._job_provision_info, self)
 
     worker_handler_list = self._cached_handlers[environment_id]
     if len(worker_handler_list) < num_workers:
       for _ in range(len(worker_handler_list), num_workers):
         worker_handler = WorkerHandler.create(
             environment,
-            self._state,
+            self.state_servicer,
             self._job_provision_info,
             self._grpc_server)
         _LOGGER.info(

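The worker_handlers.py change above simply promotes the state servicer from the private self._state to the public self.state_servicer, so the execution context can expose it through its own state_servicer property. A trimmed-down sketch of the resulting access path (simplified classes, not the real implementations):

class StateServicer(object):
  def __init__(self):
    self.blobs = {}

  def append_raw(self, key, data):
    # Accumulate raw bytes under a state key, as the real servicer does.
    self.blobs[key] = self.blobs.get(key, b'') + data

class WorkerHandlerManager(object):
  def __init__(self):
    self.state_servicer = StateServicer()  # was the private self._state

class ExecutionContext(object):
  def __init__(self, manager):
    self.worker_handler_manager = manager

  @property
  def state_servicer(self):
    # Mirrors FnApiRunnerExecutionContext.state_servicer in the patch.
    return self.worker_handler_manager.state_servicer

ctx = ExecutionContext(WorkerHandlerManager())
ctx.state_servicer.append_raw(b'side-input', b'\x00\x01')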