[
https://issues.apache.org/jira/browse/BEAM-3886?focusedWorklogId=93495&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-93495
]
ASF GitHub Bot logged work on BEAM-3886:
----------------------------------------
Author: ASF GitHub Bot
Created on: 20/Apr/18 23:33
Start Date: 20/Apr/18 23:33
Worklog Time Spent: 10m
Work Description: robertwb closed pull request #5192: [BEAM-3886] Have
the Python SDK use the state api service descriptor on the process bundle
descriptor.
URL: https://github.com/apache/beam/pull/5192
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request merged from a
fork, the diff is reproduced below for the sake of provenance:
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 1bad6c2781e..d52b0383469 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -43,6 +43,7 @@
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners import runner
from apache_beam.runners.worker import bundle_processor
@@ -847,11 +848,11 @@ def run_stage(
self, controller, pipeline_components, stage, pcoll_buffers,
safe_coders):
context = pipeline_context.PipelineContext(pipeline_components)
- data_operation_spec = controller.data_operation_spec()
+ data_api_service_descriptor = controller.data_api_service_descriptor()
def extract_endpoints(stage):
# Returns maps of transform names to PCollection identifiers.
- # Also mutates IO stages to point to the data data_operation_spec.
+ # Also mutates IO stages to point to the data ApiServiceDescriptor.
data_input = {}
data_side_input = {}
data_output = {}
@@ -867,8 +868,11 @@ def extract_endpoints(stage):
data_output[target] = pcoll_id
else:
raise NotImplementedError
- if data_operation_spec:
- transform.spec.payload = data_operation_spec.SerializeToString()
+ if data_api_service_descriptor:
+ data_spec = beam_fn_api_pb2.RemoteGrpcPort()
+ data_spec.api_service_descriptor.url = (
+ data_api_service_descriptor.url)
+ transform.spec.payload = data_spec.SerializeToString()
else:
transform.spec.payload = ""
elif transform.spec.urn == common_urns.PARDO_TRANSFORM:
@@ -894,6 +898,10 @@ def extract_endpoints(stage):
pipeline_components.windowing_strategies.items()),
environments=dict(pipeline_components.environments.items()))
+ if controller.state_api_service_descriptor():
+ process_bundle_descriptor.state_api_service_descriptor.url = (
+ controller.state_api_service_descriptor().url)
+
# Store the required side inputs into state.
for (transform_id, tag), (pcoll_id, si) in data_side_input.items():
elements_by_window = _WindowGroupingBuffer(si)
@@ -990,15 +998,30 @@ def State(self, request_stream, context=None):
id=request.id,
clear=beam_fn_api_pb2.ClearResponse())
+ class SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory):
+ """A singleton cache for a StateServicer."""
+
+ def __init__(self, state_handler):
+ self._state_handler = state_handler
+
+ def create_state_handler(self, api_service_descriptor):
+ """Returns the singleton state handler."""
+ return self._state_handler
+
+ def close(self):
+ """Does nothing."""
+ pass
+
class DirectController(object):
"""An in-memory controller for fn API control, state and data planes."""
def __init__(self):
- self.state_handler = FnApiRunner.StateServicer()
self.control_handler = self
self.data_plane_handler = data_plane.InMemoryDataChannel()
+ self.state_handler = FnApiRunner.StateServicer()
self.worker = sdk_worker.SdkWorker(
- self.state_handler, data_plane.InMemoryDataChannelFactory(
+ FnApiRunner.SingletonStateHandlerFactory(self.state_handler),
+ data_plane.InMemoryDataChannelFactory(
self.data_plane_handler.inverse()), {})
self._uid_counter = 0
@@ -1017,7 +1040,10 @@ def done(self):
def close(self):
pass
- def data_operation_spec(self):
+ def data_api_service_descriptor(self):
+ return None
+
+ def state_api_service_descriptor(self):
return None
class GrpcController(object):
@@ -1032,6 +1058,10 @@ def __init__(self, sdk_harness_factory=None):
self.data_server =
grpc.server(futures.ThreadPoolExecutor(max_workers=10))
self.data_port = self.data_server.add_insecure_port('[::]:0')
+ self.state_server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=10))
+ self.state_port = self.state_server.add_insecure_port('[::]:0')
+
self.control_handler = BeamFnControlServicer()
beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
self.control_handler, self.control_server)
@@ -1040,14 +1070,13 @@ def __init__(self, sdk_harness_factory=None):
beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
self.data_plane_handler, self.data_server)
- # TODO(robertwb): Is sharing the control channel fine? Alternatively,
- # how should this be plumbed?
self.state_handler = FnApiRunner.GrpcStateServicer()
beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
- self.state_handler, self.control_server)
+ self.state_handler, self.state_server)
logging.info('starting control server on port %s', self.control_port)
logging.info('starting data server on port %s', self.data_port)
+ self.state_server.start()
self.data_server.start()
self.control_server.start()
@@ -1061,11 +1090,17 @@ def __init__(self, sdk_harness_factory=None):
logging.info('starting worker')
self.worker_thread.start()
- def data_operation_spec(self):
+ def data_api_service_descriptor(self):
url = 'localhost:%s' % self.data_port
- remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort()
- remote_grpc_port.api_service_descriptor.url = url
- return remote_grpc_port
+ api_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
+ api_service_descriptor.url = url
+ return api_service_descriptor
+
+ def state_api_service_descriptor(self):
+ url = 'localhost:%s' % self.state_port
+ api_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
+ api_service_descriptor.url = url
+ return api_service_descriptor
def close(self):
self.control_handler.done()
@@ -1073,6 +1108,7 @@ def close(self):
self.data_plane_handler.close()
self.control_server.stop(5).wait()
self.data_server.stop(5).wait()
+ self.state_server.stop(5).wait()
class BundleManager(object):
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 972e94c0025..b8fa422536b 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -20,6 +20,7 @@
from __future__ import division
from __future__ import print_function
+import abc
import contextlib
import logging
import Queue as queue
@@ -56,6 +57,7 @@ def __init__(self, control_address, worker_count,
credentials=None):
self._control_channel, WorkerIdInterceptor())
self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
credentials)
+ self._state_handler_factory = GrpcStateHandlerFactory()
self.workers = queue.Queue()
# one thread is enough for getting the progress report.
# Assumption:
@@ -78,9 +80,6 @@ def run(self):
# Create workers
for _ in range(self._worker_count):
- state_handler = GrpcStateHandler(
- beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel))
- state_handler.start()
# SdkHarness manage function registration and share self._fns with all
# the workers. This is needed because function registration (register)
# and exceution(process_bundle) are send over different request and we
@@ -91,7 +90,7 @@ def run(self):
# centralized function list shared among all the workers.
self.workers.put(
SdkWorker(
- state_handler=state_handler,
+ state_handler_factory=self._state_handler_factory,
data_channel_factory=self._data_channel_factory,
fns=self._fns))
@@ -118,10 +117,9 @@ def get_responses():
# get_responses may be blocked on responses.get(), but we need to return
# control to its caller.
self._responses.put(no_more_work)
- self._data_channel_factory.close()
# Stop all the workers and clean all the associated resources
- for worker in self.workers.queue:
- worker.state_handler.done()
+ self._data_channel_factory.close()
+ self._state_handler_factory.close()
logging.info('Done consuming work.')
def _execute(self, task, request):
@@ -196,9 +194,9 @@ def task():
class SdkWorker(object):
- def __init__(self, state_handler, data_channel_factory, fns):
+ def __init__(self, state_handler_factory, data_channel_factory, fns):
self.fns = fns
- self.state_handler = state_handler
+ self.state_handler_factory = state_handler_factory
self.data_channel_factory = data_channel_factory
self.bundle_processors = {}
@@ -219,12 +217,16 @@ def register(self, request, instruction_id):
register=beam_fn_api_pb2.RegisterResponse())
def process_bundle(self, request, instruction_id):
+ process_bundle_desc = self.fns[request.process_bundle_descriptor_reference]
+ state_handler = self.state_handler_factory.create_state_handler(
+ process_bundle_desc.state_api_service_descriptor)
self.bundle_processors[
instruction_id] = processor = bundle_processor.BundleProcessor(
- self.fns[request.process_bundle_descriptor_reference],
- self.state_handler, self.data_channel_factory)
+ process_bundle_desc,
+ state_handler,
+ self.data_channel_factory)
try:
- with self.state_handler.process_instruction_id(instruction_id):
+ with state_handler.process_instruction_id(instruction_id):
processor.process_bundle(instruction_id)
finally:
del self.bundle_processors[instruction_id]
@@ -243,6 +245,84 @@ def process_bundle_progress(self, request, instruction_id):
metrics=processor.metrics() if processor else None))
+class StateHandlerFactory(object):
+ """An abstract factory for creating ``DataChannel``."""
+
+ __metaclass__ = abc.ABCMeta
+
+ @abc.abstractmethod
+ def create_state_handler(self, api_service_descriptor):
+ """Returns a ``StateHandler`` from the given ApiServiceDescriptor."""
+ raise NotImplementedError(type(self))
+
+ @abc.abstractmethod
+ def close(self):
+ """Close all channels that this factory owns."""
+ raise NotImplementedError(type(self))
+
+
+class GrpcStateHandlerFactory(StateHandlerFactory):
+ """A factory for ``GrpcStateHandler``.
+
+ Caches the created channels by ``state descriptor url``.
+ """
+
+ def __init__(self):
+ self._state_handler_cache = {}
+ self._lock = threading.Lock()
+ self._throwing_state_handler = ThrowingStateHandler()
+
+ def create_state_handler(self, api_service_descriptor):
+ if not api_service_descriptor:
+ return self._throwing_state_handler
+ url = api_service_descriptor.url
+ if url not in self._state_handler_cache:
+ with self._lock:
+ if url not in self._state_handler_cache:
+ logging.info('Creating channel for %s', url)
+ grpc_channel = grpc.insecure_channel(
+ url,
+ # Options to have no limits (-1) on the size of the messages
+ # received or sent over the data plane. The actual buffer size is
+ # controlled in a layer above.
+ options=[("grpc.max_receive_message_length", -1),
+ ("grpc.max_send_message_length", -1)])
+ # Add workerId to the grpc channel
+ grpc_channel = grpc.intercept_channel(grpc_channel,
+ WorkerIdInterceptor())
+ self._state_handler_cache[url] = GrpcStateHandler(
+ beam_fn_api_pb2_grpc.BeamFnStateStub(grpc_channel))
+ return self._state_handler_cache[url]
+
+ def close(self):
+ logging.info('Closing all cached gRPC state handlers.')
+ for _, state_handler in self._state_handler_cache.items():
+ state_handler.done()
+ self._state_handler_cache.clear()
+
+
+class ThrowingStateHandler(object):
+ """A state handler that errors on any requests."""
+
+ def blocking_get(self, state_key, instruction_reference):
+ raise RuntimeError(
+ 'Unable to handle state requests for ProcessBundleDescriptor without '
+ 'out state ApiServiceDescriptor for instruction %s and state key %s.'
+ % (state_key, instruction_reference))
+
+ def blocking_append(self, state_key, data, instruction_reference):
+ raise RuntimeError(
+ 'Unable to handle state requests for ProcessBundleDescriptor without '
+ 'out state ApiServiceDescriptor for instruction %s and state key %s.'
+ % (state_key, instruction_reference))
+
+ def blocking_clear(self, state_key, instruction_reference):
+ raise RuntimeError(
+ 'Unable to handle state requests for ProcessBundleDescriptor without '
+ 'out state ApiServiceDescriptor for instruction %s and state key %s.'
+ % (state_key, instruction_reference))
+
+
class GrpcStateHandler(object):
_DONE = object()
@@ -255,6 +335,7 @@ def __init__(self, state_stub):
self._last_id = 0
self._exc_info = None
self._context = threading.local()
+ self.start()
@contextlib.contextmanager
def process_instruction_id(self, bundle_id):
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
Issue Time Tracking
-------------------
Worklog Id: (was: 93495)
Time Spent: 0.5h (was: 20m)
> Python SDK harness uses State API from ProcessBundleDescriptor
> --------------------------------------------------------------
>
> Key: BEAM-3886
> URL: https://issues.apache.org/jira/browse/BEAM-3886
> Project: Beam
> Issue Type: Sub-task
> Components: sdk-py-core
> Reporter: Ben Sidhom
> Assignee: Ahmet Altay
> Priority: Minor
> Time Spent: 0.5h
> Remaining Estimate: 0h
>
> The Python SDK harness should read the State API service descriptor from
> the current ProcessBundleDescriptor when processing bundles.
> As a minor optimization, and to make implementing new runners easier, the
> harness should not talk to the State server unless it is actually needed.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)