This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4be760e [BEAM-9608] Increase reliance on Context Managers for FnApiRunner
new a55ee53 Merge pull request #11229 from [BEAM-9608] Increasing scope of context managers for FnApiRunner
4be760e is described below
commit 4be760ee74be196d2edf3b6f05cefdce820c5e26
Author: Pablo Estrada <[email protected]>
AuthorDate: Wed Mar 25 14:27:56 2020 -0700
[BEAM-9608] Increase reliance on Context Managers for FnApiRunner
---
.../runners/portability/fn_api_runner/execution.py | 119 ++++++++++++---
.../runners/portability/fn_api_runner/fn_runner.py | 166 +++++++--------------
.../portability/fn_api_runner/worker_handlers.py | 6 +-
3 files changed, 153 insertions(+), 138 deletions(-)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 99c2438..02855dc 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -21,6 +21,8 @@ from __future__ import absolute_import
import collections
import itertools
+from typing import TYPE_CHECKING
+from typing import MutableMapping
from typing_extensions import Protocol
@@ -29,14 +31,21 @@ from apache_beam.coders.coder_impl import create_InputStream
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
+from apache_beam.runners import pipeline_context
from apache_beam.runners.portability.fn_api_runner.translations import only_element
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
+from apache_beam.runners.portability.fn_api_runner.translations import unique_name
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import trigger
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.utils import windowed_value
+if TYPE_CHECKING:
+ from apache_beam.coders.coder_impl import CoderImpl
+ from apache_beam.runners.portability.fn_api_runner import translations
+ from apache_beam.runners.portability.fn_api_runner import worker_handlers
+
class Buffer(Protocol):
def __iter__(self):
@@ -245,37 +254,108 @@ class FnApiRunnerExecutionContext(object):
``beam.PCollection``.
"""
def __init__(self,
- worker_handler_factory, # type: Callable[[Optional[str], int], List[WorkerHandler]]
+ worker_handler_manager, # type: worker_handlers.WorkerHandlerManager
pipeline_components, # type: beam_runner_api_pb2.Components
safe_coders,
data_channel_coders,
):
"""
- :param worker_handler_factory: A ``callable`` that takes in an environment
- id and a number of workers, and returns a list of ``WorkerHandler``s.
+ :param worker_handler_manager: This class manages the set of worker
+ handlers, and the communication with state / control APIs.
:param pipeline_components: (beam_runner_api_pb2.Components): TODO
:param safe_coders:
:param data_channel_coders:
"""
self.pcoll_buffers = {} # type: MutableMapping[bytes, PartitionableBuffer]
- self.worker_handler_factory = worker_handler_factory
+ self.worker_handler_manager = worker_handler_manager
self.pipeline_components = pipeline_components
self.safe_coders = safe_coders
self.data_channel_coders = data_channel_coders
+ self.pipeline_context = pipeline_context.PipelineContext(
+ self.pipeline_components,
+ iterable_state_write=self._iterable_state_write)
+ self._last_uid = -1
+
+ @property
+ def state_servicer(self):
+ # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
+ return self.worker_handler_manager.state_servicer
+
+ def next_uid(self):
+ self._last_uid += 1
+ return str(self._last_uid)
+
+ def _iterable_state_write(self, values, element_coder_impl):
+ # type: (...) -> bytes
+ token = unique_name(None, 'iter').encode('ascii')
+ out = create_OutputStream()
+ for element in values:
+ element_coder_impl.encode_to_stream(element, out, True)
+ self.worker_handler_manager.state_servicer.append_raw(
+ beam_fn_api_pb2.StateKey(
+ runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
+ out.get())
+ return token
+
class BundleContextManager(object):
def __init__(self,
- execution_context, # type: FnApiRunnerExecutionContext
- process_bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor
- worker_handler, # type: fn_runner.WorkerHandler
- p_context, # type: pipeline_context.PipelineContext
- ):
+ execution_context, # type: FnApiRunnerExecutionContext
+ stage, # type: translations.Stage
+ num_workers, # type: int
+ ):
self.execution_context = execution_context
- self.process_bundle_descriptor = process_bundle_descriptor
- self.worker_handler = worker_handler
- self.pipeline_context = p_context
+ self.stage = stage
+ self.bundle_uid = self.execution_context.next_uid()
+ self.num_workers = num_workers
+
+ # Properties that are lazily initialized
+ self._process_bundle_descriptor = None
+ self._worker_handlers = None
+
+ @property
+ def worker_handlers(self):
+ if self._worker_handlers is None:
+ self._worker_handlers = (
+ self.execution_context.worker_handler_manager.get_worker_handlers(
+ self.stage.environment, self.num_workers))
+ return self._worker_handlers
+
+ def data_api_service_descriptor(self):
+ # All worker_handlers share the same grpc server, so we can read grpc server
+ # info from any worker_handler and read from the first worker_handler.
+ return self.worker_handlers[0].data_api_service_descriptor()
+
+ def state_api_service_descriptor(self):
+ # All worker_handlers share the same grpc server, so we can read grpc server
+ # info from any worker_handler and read from the first worker_handler.
+ return self.worker_handlers[0].state_api_service_descriptor()
+
+ @property
+ def process_bundle_descriptor(self):
+ if self._process_bundle_descriptor is None:
+ self._process_bundle_descriptor = self._build_process_bundle_descriptor()
+ return self._process_bundle_descriptor
+
+ def _build_process_bundle_descriptor(self):
+ res = beam_fn_api_pb2.ProcessBundleDescriptor(
+ id=self.bundle_uid,
+ transforms={
+ transform.unique_name: transform
+ for transform in self.stage.transforms
+ },
+ pcollections=dict(
+ self.execution_context.pipeline_components.pcollections.items()),
+ coders=dict(self.execution_context.pipeline_components.coders.items()),
+ windowing_strategies=dict(
+ self.execution_context.pipeline_components.windowing_strategies.
+ items()),
+ environments=dict(
+ self.execution_context.pipeline_components.environments.items()),
+ state_api_service_descriptor=self.state_api_service_descriptor())
+ return res
def get_input_coder_impl(self, transform_id):
# type: (str) -> CoderImpl
@@ -284,10 +364,10 @@ class BundleContextManager(object):
).coder_id
assert coder_id
if coder_id in self.execution_context.safe_coders:
- return self.pipeline_context.coders[
+ return self.execution_context.pipeline_context.coders[
self.execution_context.safe_coders[coder_id]].get_impl()
else:
- return self.pipeline_context.coders[coder_id].get_impl()
+ return self.execution_context.pipeline_context.coders[coder_id].get_impl()
def get_buffer(self, buffer_id, transform_id):
# type: (bytes, str) -> PartitionableBuffer
@@ -310,15 +390,16 @@ class BundleContextManager(object):
original_gbk_transform]
input_pcoll = only_element(list(transform_proto.inputs.values()))
output_pcoll = only_element(list(transform_proto.outputs.values()))
- pre_gbk_coder = self.pipeline_context.coders[
+ pre_gbk_coder = self.execution_context.pipeline_context.coders[
self.execution_context.safe_coders[
self.execution_context.data_channel_coders[input_pcoll]]]
- post_gbk_coder = self.pipeline_context.coders[
+ post_gbk_coder = self.execution_context.pipeline_context.coders[
self.execution_context.safe_coders[
self.execution_context.data_channel_coders[output_pcoll]]]
- windowing_strategy = self.pipeline_context.windowing_strategies[
- self.execution_context.pipeline_components.
- pcollections[output_pcoll].windowing_strategy_id]
+ windowing_strategy = (
+ self.execution_context.pipeline_context.windowing_strategies[
+ self.execution_context.pipeline_components.
+ pcollections[output_pcoll].windowing_strategy_id])
self.execution_context.pcoll_buffers[buffer_id] = GroupingBuffer(
pre_gbk_coder, post_gbk_coder, windowing_strategy)
else:
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index 13646da..fa0af3b 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -55,7 +55,6 @@ from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_provision_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
-from apache_beam.runners import pipeline_context
from apache_beam.runners import runner
from apache_beam.runners.portability import portable_metrics
from apache_beam.runners.portability.fn_api_runner import execution
@@ -65,7 +64,6 @@ from apache_beam.runners.portability.fn_api_runner.execution
import WindowGroupi
from apache_beam.runners.portability.fn_api_runner.translations import
create_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import
only_element
from apache_beam.runners.portability.fn_api_runner.translations import
split_buffer_id
-from apache_beam.runners.portability.fn_api_runner.translations import
unique_name
from apache_beam.runners.portability.fn_api_runner.worker_handlers import
WorkerHandlerManager
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import environments
@@ -120,7 +118,6 @@ class FnApiRunner(runner.PipelineRunner):
waits before requesting progress from the SDK.
"""
super(FnApiRunner, self).__init__()
- self._last_uid = -1
self._default_environment = (
default_environment or environments.EmbeddedPythonEnvironment())
self._bundle_repeat = bundle_repeat
@@ -132,10 +129,6 @@ class FnApiRunner(runner.PipelineRunner):
beam_provision_api_pb2.ProvisionInfo(
retrieval_token='unused-retrieval-token'))
- def _next_uid(self):
- self._last_uid += 1
- return str(self._last_uid)
-
@staticmethod
def supported_requirements():
return (
@@ -329,7 +322,7 @@ class FnApiRunner(runner.PipelineRunner):
monitoring_infos_by_stage = {}
runner_execution_context = execution.FnApiRunnerExecutionContext(
- worker_handler_manager.get_worker_handlers,
+ worker_handler_manager,
stage_context.components,
stage_context.safe_coders,
stage_context.data_channel_coders)
@@ -337,7 +330,12 @@ class FnApiRunner(runner.PipelineRunner):
try:
with self.maybe_profile():
for stage in stages:
- stage_results = self._run_stage(runner_execution_context, stage)
+ bundle_context_manager = execution.BundleContextManager(
+ runner_execution_context, stage, self._num_workers)
+ stage_results = self._run_stage(
+ runner_execution_context,
+ bundle_context_manager,
+ )
metrics_by_stage[stage.name] = stage_results.process_bundle.metrics
monitoring_infos_by_stage[stage.name] = (
stage_results.process_bundle.monitoring_infos)
@@ -347,14 +345,13 @@ class FnApiRunner(runner.PipelineRunner):
runner.PipelineState.DONE, monitoring_infos_by_stage, metrics_by_stage)
def _store_side_inputs_in_state(self,
- bundle_context_manager, # type: execution.BundleContextManager
runner_execution_context, # type: execution.FnApiRunnerExecutionContext
data_side_input, # type: DataSideInput
):
# type: (...) -> None
for (transform_id, tag), (buffer_id, si) in data_side_input.items():
_, pcoll_id = split_buffer_id(buffer_id)
- value_coder = bundle_context_manager.pipeline_context.coders[
+ value_coder = runner_execution_context.pipeline_context.coders[
runner_execution_context.safe_coders[
runner_execution_context.data_channel_coders[pcoll_id]]]
elements_by_window = WindowGroupingBuffer(si, value_coder)
@@ -369,8 +366,9 @@ class FnApiRunner(runner.PipelineRunner):
state_key = beam_fn_api_pb2.StateKey(
iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
transform_id=transform_id, side_input_id=tag, window=window))
- bundle_context_manager.worker_handler.state.append_raw(
- state_key, elements_data)
+ (
+ runner_execution_context.worker_handler_manager.state_servicer.
+ append_raw(state_key, elements_data))
elif si.urn == common_urns.side_inputs.MULTIMAP.urn:
for key, window, elements_data in elements_by_window.encoded_items():
state_key = beam_fn_api_pb2.StateKey(
@@ -379,14 +377,15 @@ class FnApiRunner(runner.PipelineRunner):
side_input_id=tag,
window=window,
key=key))
- bundle_context_manager.worker_handler.state.append_raw(
- state_key, elements_data)
+ (
+ runner_execution_context.worker_handler_manager.state_servicer.
+ append_raw(state_key, elements_data))
else:
raise ValueError("Unknown access pattern: '%s'" % si.urn)
def _run_bundle_multiple_times_for_testing(
self,
- worker_handler_list, # type: Sequence[WorkerHandler]
+ runner_execution_context, # type: execution.FnApiRunnerExecutionContext
bundle_context_manager, # type: execution.BundleContextManager
data_input,
data_output, # type: DataOutput
@@ -398,12 +397,11 @@ class FnApiRunner(runner.PipelineRunner):
If bundle_repeat > 0, replay every bundle for profiling and debugging.
"""
# all workers share state, so use any worker_handler.
- worker_handler = worker_handler_list[0]
for k in range(self._bundle_repeat):
try:
- worker_handler.state.checkpoint()
+ runner_execution_context.state_servicer.checkpoint()
testing_bundle_manager = ParallelBundleManager(
- worker_handler_list,
+ bundle_context_manager.worker_handlers,
lambda pcoll_id,
transform_id: ListBuffer(
coder_impl=bundle_context_manager.get_input_coder_impl),
@@ -415,24 +413,24 @@ class FnApiRunner(runner.PipelineRunner):
cache_token_generator=cache_token_generator)
testing_bundle_manager.process_bundle(data_input, data_output)
finally:
- worker_handler.state.restore()
+ runner_execution_context.state_servicer.restore()
def _collect_written_timers_and_add_to_deferred_inputs(
self,
- pipeline_components, # type: beam_runner_api_pb2.Components
- stage, # type: translations.Stage
+ runner_execution_context, # type: execution.FnApiRunnerExecutionContext
bundle_context_manager, # type: execution.BundleContextManager
deferred_inputs, # type: MutableMapping[str, PartitionableBuffer]
- data_channel_coders, # type: Mapping[str, str]
):
# type: (...) -> None
- for transform_id, timer_writes in stage.timer_pcollections:
+ for (transform_id,
+ timer_writes) in bundle_context_manager.stage.timer_pcollections:
# Queue any set timers as new inputs.
windowed_timer_coder_impl = (
- bundle_context_manager.pipeline_context.coders[
- data_channel_coders[timer_writes]].get_impl())
+ runner_execution_context.pipeline_context.coders[
+ runner_execution_context.data_channel_coders[timer_writes]].
+ get_impl())
written_timers = bundle_context_manager.get_buffer(
create_buffer_id(timer_writes, kind='timers'), transform_id)
if not written_timers.cleared:
@@ -511,87 +509,32 @@ class FnApiRunner(runner.PipelineRunner):
def _run_stage(self,
runner_execution_context, # type: execution.FnApiRunnerExecutionContext
- stage, # type: translations.Stage
+ bundle_context_manager, # type: execution.BundleContextManager
):
# type: (...) -> beam_fn_api_pb2.InstructionResponse
"""Run an individual stage.
Args:
- worker_handler_factory: A ``callable`` that takes in an environment id
- and a number of workers, and returns a list of ``WorkerHandler``s.
- stage (translations.Stage)
+ runner_execution_context (execution.FnApiRunnerExecutionContext): An
+ object containing execution information for the pipeline.
+ stage (translations.Stage): A description of the stage to execute.
"""
- def iterable_state_write(values, element_coder_impl):
- # type: (...) -> bytes
- token = unique_name(None, 'iter').encode('ascii')
- out = create_OutputStream()
- for element in values:
- element_coder_impl.encode_to_stream(element, out, True)
- worker_handler.state.append_raw(
- beam_fn_api_pb2.StateKey(
- runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
- out.get())
- return token
-
- worker_handler_list = runner_execution_context.worker_handler_factory(
- stage.environment, self._num_workers)
-
- # All worker_handlers share the same grpc server, so we can read grpc server
- # info from any worker_handler and read from the first worker_handler.
- worker_handler = next(iter(worker_handler_list))
- context = pipeline_context.PipelineContext(
- runner_execution_context.pipeline_components,
- iterable_state_write=iterable_state_write)
- data_api_service_descriptor = worker_handler.data_api_service_descriptor()
-
- _LOGGER.info('Running %s', stage.name)
- data_input, data_side_input, data_output = self._extract_endpoints(
- stage,
- runner_execution_context.pipeline_components,
- data_api_service_descriptor,
- runner_execution_context.pcoll_buffers,
- context,
- runner_execution_context.safe_coders,
- runner_execution_context.data_channel_coders)
-
- process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
- id=self._next_uid(),
- transforms={
- transform.unique_name: transform
- for transform in stage.transforms
- },
- pcollections=dict(
- runner_execution_context.pipeline_components.pcollections.items()),
- coders=dict(
- runner_execution_context.pipeline_components.coders.items()),
- windowing_strategies=dict(
- runner_execution_context.pipeline_components.windowing_strategies.
- items()),
- environments=dict(
- runner_execution_context.pipeline_components.environments.items()))
-
- bundle_context_manager = execution.BundleContextManager(
- runner_execution_context,
- process_bundle_descriptor,
- worker_handler,
- context)
+ worker_handler_list = bundle_context_manager.worker_handlers
- state_api_service_descriptor = worker_handler.state_api_service_descriptor()
- if state_api_service_descriptor:
- process_bundle_descriptor.state_api_service_descriptor.url = (
- state_api_service_descriptor.url)
+ _LOGGER.info('Running %s', bundle_context_manager.stage.name)
+ data_input, data_side_input, data_output = self._extract_endpoints(
+ bundle_context_manager, runner_execution_context)
# Store the required side inputs into state so it is accessible for the
# worker when it runs this bundle.
- self._store_side_inputs_in_state(
- bundle_context_manager, runner_execution_context, data_side_input)
+ self._store_side_inputs_in_state(runner_execution_context, data_side_input)
# Change cache token across bundle repeats
cache_token_generator = FnApiRunner.get_cache_token_generator(static=False)
self._run_bundle_multiple_times_for_testing(
- worker_handler_list,
+ runner_execution_context,
bundle_context_manager,
data_input,
data_output,
@@ -601,7 +544,7 @@ class FnApiRunner(runner.PipelineRunner):
worker_handler_list,
bundle_context_manager.get_buffer,
bundle_context_manager.get_input_coder_impl,
- process_bundle_descriptor,
+ bundle_context_manager.process_bundle_descriptor,
self._progress_frequency,
num_workers=self._num_workers,
cache_token_generator=cache_token_generator)
@@ -619,11 +562,7 @@ class FnApiRunner(runner.PipelineRunner):
deferred_inputs = {} # type: Dict[str, PartitionableBuffer]
self._collect_written_timers_and_add_to_deferred_inputs(
- runner_execution_context.pipeline_components,
- stage,
- bundle_context_manager,
- deferred_inputs,
- runner_execution_context.data_channel_coders)
+ runner_execution_context, bundle_context_manager, deferred_inputs)
# Queue any process-initiated delayed bundle applications.
for delayed_application in last_result.process_bundle.residual_roots:
name = bundle_context_manager.input_for(
@@ -663,13 +602,8 @@ class FnApiRunner(runner.PipelineRunner):
return result
@staticmethod
- def _extract_endpoints(stage, # type: translations.Stage
- pipeline_components, # type: beam_runner_api_pb2.Components
- data_api_service_descriptor, # type: Optional[endpoints_pb2.ApiServiceDescriptor]
- pcoll_buffers, # type: MutableMapping[bytes, PartitionableBuffer]
- context, # type: pipeline_context.PipelineContext
- safe_coders, # type: Mapping[str, str]
- data_channel_coders, # type: Mapping[str, str]
+ def _extract_endpoints(bundle_context_manager, # type: execution.BundleContextManager
+ runner_execution_context, # type: execution.FnApiRunnerExecutionContext
):
# type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput]
@@ -680,11 +614,7 @@ class FnApiRunner(runner.PipelineRunner):
Args:
stage (translations.Stage): The stage to extract endpoints
for.
- pipeline_components (beam_runner_api_pb2.Components): Components of the
- pipeline to include coders, transforms, PCollections, etc.
data_api_service_descriptor: A GRPC endpoint descriptor for data plane.
- pcoll_buffers (dict): A dictionary containing buffers for PCollection
- elements.
Returns:
A tuple of (data_input, data_side_input, data_output) dictionaries.
`data_input` is a dictionary mapping (transform_name, output_name) to a
@@ -694,30 +624,34 @@ class FnApiRunner(runner.PipelineRunner):
data_input = {} # type: Dict[str, PartitionableBuffer]
data_side_input = {} # type: DataSideInput
data_output = {} # type: DataOutput
-
- for transform in stage.transforms:
+ for transform in bundle_context_manager.stage.transforms:
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN):
pcoll_id = transform.spec.payload
if transform.spec.urn == bundle_processor.DATA_INPUT_URN:
- coder_id = data_channel_coders[only_element(
+ coder_id = runner_execution_context.data_channel_coders[only_element(
transform.outputs.values())]
- coder = context.coders[safe_coders.get(coder_id, coder_id)]
+ coder = runner_execution_context.pipeline_context.coders[
+ runner_execution_context.safe_coders.get(coder_id, coder_id)]
if pcoll_id == translations.IMPULSE_BUFFER:
data_input[transform.unique_name] = ListBuffer(
coder_impl=coder.get_impl())
data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE)
else:
- if pcoll_id not in pcoll_buffers:
- pcoll_buffers[pcoll_id] = ListBuffer(coder_impl=coder.get_impl())
- data_input[transform.unique_name] = pcoll_buffers[pcoll_id]
+ if pcoll_id not in runner_execution_context.pcoll_buffers:
+ runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer(
+ coder_impl=coder.get_impl())
+ data_input[transform.unique_name] = (
+ runner_execution_context.pcoll_buffers[pcoll_id])
elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN:
data_output[transform.unique_name] = pcoll_id
- coder_id = data_channel_coders[only_element(
+ coder_id = runner_execution_context.data_channel_coders[only_element(
transform.inputs.values())]
else:
raise NotImplementedError
data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
+ data_api_service_descriptor = (
+ bundle_context_manager.data_api_service_descriptor())
if data_api_service_descriptor:
data_spec.api_service_descriptor.url = (
data_api_service_descriptor.url)
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index 2b690d2..d22d911 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -759,7 +759,7 @@ class WorkerHandlerManager(object):
self._cached_handlers = collections.defaultdict(
list) # type: DefaultDict[str, List[WorkerHandler]]
self._workers_by_id = {} # type: Dict[str, WorkerHandler]
- self._state = StateServicer() # rename?
+ self.state_servicer = StateServicer()
self._grpc_server = None # type: Optional[GrpcServer]
def get_worker_handlers(
@@ -781,14 +781,14 @@ class WorkerHandlerManager(object):
self._grpc_server = cast(GrpcServer, None)
elif self._grpc_server is None:
self._grpc_server = GrpcServer(
- self._state, self._job_provision_info, self)
+ self.state_servicer, self._job_provision_info, self)
worker_handler_list = self._cached_handlers[environment_id]
if len(worker_handler_list) < num_workers:
for _ in range(len(worker_handler_list), num_workers):
worker_handler = WorkerHandler.create(
environment,
- self._state,
+ self.state_servicer,
self._job_provision_info,
self._grpc_server)
_LOGGER.info(