[ 
https://issues.apache.org/jira/browse/BEAM-3886?focusedWorklogId=93495&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-93495
 ]

ASF GitHub Bot logged work on BEAM-3886:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 20/Apr/18 23:33
            Start Date: 20/Apr/18 23:33
    Worklog Time Spent: 10m 
      Work Description: robertwb closed pull request #5192: [BEAM-3886] Have 
the Python SDK use the state api service descriptor on the process bundle 
descriptor.
URL: https://github.com/apache/beam/pull/5192
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request came from a fork, GitHub does not display its
diff once it has been merged, so the full diff is reproduced below:

diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py 
b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 1bad6c2781e..d52b0383469 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -43,6 +43,7 @@
 from apache_beam.portability.api import beam_fn_api_pb2
 from apache_beam.portability.api import beam_fn_api_pb2_grpc
 from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.portability.api import endpoints_pb2
 from apache_beam.runners import pipeline_context
 from apache_beam.runners import runner
 from apache_beam.runners.worker import bundle_processor
@@ -847,11 +848,11 @@ def run_stage(
       self, controller, pipeline_components, stage, pcoll_buffers, 
safe_coders):
 
     context = pipeline_context.PipelineContext(pipeline_components)
-    data_operation_spec = controller.data_operation_spec()
+    data_api_service_descriptor = controller.data_api_service_descriptor()
 
     def extract_endpoints(stage):
       # Returns maps of transform names to PCollection identifiers.
-      # Also mutates IO stages to point to the data data_operation_spec.
+      # Also mutates IO stages to point to the data ApiServiceDescriptor.
       data_input = {}
       data_side_input = {}
       data_output = {}
@@ -867,8 +868,11 @@ def extract_endpoints(stage):
             data_output[target] = pcoll_id
           else:
             raise NotImplementedError
-          if data_operation_spec:
-            transform.spec.payload = data_operation_spec.SerializeToString()
+          if data_api_service_descriptor:
+            data_spec = beam_fn_api_pb2.RemoteGrpcPort()
+            data_spec.api_service_descriptor.url = (
+              data_api_service_descriptor.url)
+            transform.spec.payload = data_spec.SerializeToString()
           else:
             transform.spec.payload = ""
         elif transform.spec.urn == common_urns.PARDO_TRANSFORM:
@@ -894,6 +898,10 @@ def extract_endpoints(stage):
             pipeline_components.windowing_strategies.items()),
         environments=dict(pipeline_components.environments.items()))
 
+    if controller.state_api_service_descriptor():
+      process_bundle_descriptor.state_api_service_descriptor.url = (
+        controller.state_api_service_descriptor().url)
+
     # Store the required side inputs into state.
     for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
       elements_by_window = _WindowGroupingBuffer(si)
@@ -990,15 +998,30 @@ def State(self, request_stream, context=None):
               id=request.id,
               clear=beam_fn_api_pb2.ClearResponse())
 
+  class SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory):
+    """A singleton cache for a StateServicer."""
+
+    def __init__(self, state_handler):
+      self._state_handler = state_handler
+
+    def create_state_handler(self, api_service_descriptor):
+      """Returns the singleton state handler."""
+      return self._state_handler
+
+    def close(self):
+      """Does nothing."""
+      pass
+
   class DirectController(object):
     """An in-memory controller for fn API control, state and data planes."""
 
     def __init__(self):
-      self.state_handler = FnApiRunner.StateServicer()
       self.control_handler = self
       self.data_plane_handler = data_plane.InMemoryDataChannel()
+      self.state_handler = FnApiRunner.StateServicer()
       self.worker = sdk_worker.SdkWorker(
-          self.state_handler, data_plane.InMemoryDataChannelFactory(
+          FnApiRunner.SingletonStateHandlerFactory(self.state_handler),
+          data_plane.InMemoryDataChannelFactory(
               self.data_plane_handler.inverse()), {})
       self._uid_counter = 0
 
@@ -1017,7 +1040,10 @@ def done(self):
     def close(self):
       pass
 
-    def data_operation_spec(self):
+    def data_api_service_descriptor(self):
+      return None
+
+    def state_api_service_descriptor(self):
       return None
 
   class GrpcController(object):
@@ -1032,6 +1058,10 @@ def __init__(self, sdk_harness_factory=None):
       self.data_server = 
grpc.server(futures.ThreadPoolExecutor(max_workers=10))
       self.data_port = self.data_server.add_insecure_port('[::]:0')
 
+      self.state_server = grpc.server(
+          futures.ThreadPoolExecutor(max_workers=10))
+      self.state_port = self.state_server.add_insecure_port('[::]:0')
+
       self.control_handler = BeamFnControlServicer()
       beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
           self.control_handler, self.control_server)
@@ -1040,14 +1070,13 @@ def __init__(self, sdk_harness_factory=None):
       beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
           self.data_plane_handler, self.data_server)
 
-      # TODO(robertwb): Is sharing the control channel fine?  Alternatively,
-      # how should this be plumbed?
       self.state_handler = FnApiRunner.GrpcStateServicer()
       beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
-          self.state_handler, self.control_server)
+          self.state_handler, self.state_server)
 
       logging.info('starting control server on port %s', self.control_port)
       logging.info('starting data server on port %s', self.data_port)
+      self.state_server.start()
       self.data_server.start()
       self.control_server.start()
 
@@ -1061,11 +1090,17 @@ def __init__(self, sdk_harness_factory=None):
       logging.info('starting worker')
       self.worker_thread.start()
 
-    def data_operation_spec(self):
+    def data_api_service_descriptor(self):
       url = 'localhost:%s' % self.data_port
-      remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
-      remote_grpc_port.api_service_descriptor.url = url
-      return remote_grpc_port
+      api_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
+      api_service_descriptor.url = url
+      return api_service_descriptor
+
+    def state_api_service_descriptor(self):
+      url = 'localhost:%s' % self.state_port
+      api_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
+      api_service_descriptor.url = url
+      return api_service_descriptor
 
     def close(self):
       self.control_handler.done()
@@ -1073,6 +1108,7 @@ def close(self):
       self.data_plane_handler.close()
       self.control_server.stop(5).wait()
       self.data_server.stop(5).wait()
+      self.state_server.stop(5).wait()
 
 
 class BundleManager(object):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py 
b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 972e94c0025..b8fa422536b 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -20,6 +20,7 @@
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import contextlib
 import logging
 import Queue as queue
@@ -56,6 +57,7 @@ def __init__(self, control_address, worker_count, 
credentials=None):
         self._control_channel, WorkerIdInterceptor())
     self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
         credentials)
+    self._state_handler_factory = GrpcStateHandlerFactory()
     self.workers = queue.Queue()
     # one thread is enough for getting the progress report.
     # Assumption:
@@ -78,9 +80,6 @@ def run(self):
 
     # Create workers
     for _ in range(self._worker_count):
-      state_handler = GrpcStateHandler(
-          beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel))
-      state_handler.start()
       # SdkHarness manage function registration and share self._fns with all
       # the workers. This is needed because function registration (register)
       # and exceution(process_bundle) are send over different request and we
@@ -91,7 +90,7 @@ def run(self):
       # centralized function list shared among all the workers.
       self.workers.put(
           SdkWorker(
-              state_handler=state_handler,
+              state_handler_factory=self._state_handler_factory,
               data_channel_factory=self._data_channel_factory,
               fns=self._fns))
 
@@ -118,10 +117,9 @@ def get_responses():
     # get_responses may be blocked on responses.get(), but we need to return
     # control to its caller.
     self._responses.put(no_more_work)
-    self._data_channel_factory.close()
     # Stop all the workers and clean all the associated resources
-    for worker in self.workers.queue:
-      worker.state_handler.done()
+    self._data_channel_factory.close()
+    self._state_handler_factory.close()
     logging.info('Done consuming work.')
 
   def _execute(self, task, request):
@@ -196,9 +194,9 @@ def task():
 
 class SdkWorker(object):
 
-  def __init__(self, state_handler, data_channel_factory, fns):
+  def __init__(self, state_handler_factory, data_channel_factory, fns):
     self.fns = fns
-    self.state_handler = state_handler
+    self.state_handler_factory = state_handler_factory
     self.data_channel_factory = data_channel_factory
     self.bundle_processors = {}
 
@@ -219,12 +217,16 @@ def register(self, request, instruction_id):
         register=beam_fn_api_pb2.RegisterResponse())
 
   def process_bundle(self, request, instruction_id):
+    process_bundle_desc = self.fns[request.process_bundle_descriptor_reference]
+    state_handler = self.state_handler_factory.create_state_handler(
+        process_bundle_desc.state_api_service_descriptor)
     self.bundle_processors[
         instruction_id] = processor = bundle_processor.BundleProcessor(
-            self.fns[request.process_bundle_descriptor_reference],
-            self.state_handler, self.data_channel_factory)
+            process_bundle_desc,
+            state_handler,
+            self.data_channel_factory)
     try:
-      with self.state_handler.process_instruction_id(instruction_id):
+      with state_handler.process_instruction_id(instruction_id):
         processor.process_bundle(instruction_id)
     finally:
       del self.bundle_processors[instruction_id]
@@ -243,6 +245,84 @@ def process_bundle_progress(self, request, instruction_id):
             metrics=processor.metrics() if processor else None))
 
 
+class StateHandlerFactory(object):
+  """An abstract factory for creating ``DataChannel``."""
+
+  __metaclass__ = abc.ABCMeta
+
+  @abc.abstractmethod
+  def create_state_handler(self, api_service_descriptor):
+    """Returns a ``StateHandler`` from the given ApiServiceDescriptor."""
+    raise NotImplementedError(type(self))
+
+  @abc.abstractmethod
+  def close(self):
+    """Close all channels that this factory owns."""
+    raise NotImplementedError(type(self))
+
+
+class GrpcStateHandlerFactory(StateHandlerFactory):
+  """A factory for ``GrpcStateHandler``.
+
+  Caches the created channels by ``state descriptor url``.
+  """
+
+  def __init__(self):
+    self._state_handler_cache = {}
+    self._lock = threading.Lock()
+    self._throwing_state_handler = ThrowingStateHandler()
+
+  def create_state_handler(self, api_service_descriptor):
+    if not api_service_descriptor:
+      return self._throwing_state_handler
+    url = api_service_descriptor.url
+    if url not in self._state_handler_cache:
+      with self._lock:
+        if url not in self._state_handler_cache:
+          logging.info('Creating channel for %s', url)
+          grpc_channel = grpc.insecure_channel(
+              url,
+              # Options to have no limits (-1) on the size of the messages
+              # received or sent over the data plane. The actual buffer size is
+              # controlled in a layer above.
+              options=[("grpc.max_receive_message_length", -1),
+                       ("grpc.max_send_message_length", -1)])
+          # Add workerId to the grpc channel
+          grpc_channel = grpc.intercept_channel(grpc_channel,
+                                                WorkerIdInterceptor())
+          self._state_handler_cache[url] = GrpcStateHandler(
+              beam_fn_api_pb2_grpc.BeamFnStateStub(grpc_channel))
+    return self._state_handler_cache[url]
+
+  def close(self):
+    logging.info('Closing all cached gRPC state handlers.')
+    for _, state_handler in self._state_handler_cache.items():
+      state_handler.done()
+    self._state_handler_cache.clear()
+
+
+class ThrowingStateHandler(object):
+  """A state handler that errors on any requests."""
+
+  def blocking_get(self, state_key, instruction_reference):
+    raise RuntimeError(
+        'Unable to handle state requests for ProcessBundleDescriptor without '
+        'out state ApiServiceDescriptor for instruction %s and state key %s.'
+        % (state_key, instruction_reference))
+
+  def blocking_append(self, state_key, data, instruction_reference):
+    raise RuntimeError(
+        'Unable to handle state requests for ProcessBundleDescriptor without '
+        'out state ApiServiceDescriptor for instruction %s and state key %s.'
+        % (state_key, instruction_reference))
+
+  def blocking_clear(self, state_key, instruction_reference):
+    raise RuntimeError(
+        'Unable to handle state requests for ProcessBundleDescriptor without '
+        'out state ApiServiceDescriptor for instruction %s and state key %s.'
+        % (state_key, instruction_reference))
+
+
 class GrpcStateHandler(object):
 
   _DONE = object()
@@ -255,6 +335,7 @@ def __init__(self, state_stub):
     self._last_id = 0
     self._exc_info = None
     self._context = threading.local()
+    self.start()
 
   @contextlib.contextmanager
   def process_instruction_id(self, bundle_id):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 93495)
    Time Spent: 0.5h  (was: 20m)

> Python SDK harness uses State API from ProcessBundleDescriptor
> --------------------------------------------------------------
>
>                 Key: BEAM-3886
>                 URL: https://issues.apache.org/jira/browse/BEAM-3886
>             Project: Beam
>          Issue Type: Sub-task
>          Components: sdk-py-core
>            Reporter: Ben Sidhom
>            Assignee: Ahmet Altay
>            Priority: Minor
>          Time Spent: 0.5h
>  Remaining Estimate: 0h
>
> The Python harness should pull the state api descriptor from the current 
> process bundle descriptor when processing bundles.
> As a minor optimization and to make implementing new runners easier, the 
> harness should not talk to the State server unless it's actually needed.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to