This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 95e2374 [BEAM-7746] Fix typing in runners
new 5156634 Merge pull request #13060 from [BEAM-7746] Fix typing in
runners
95e2374 is described below
commit 95e2374c659ba9e7fbe42ac217c0b367120b6ee1
Author: Chad Dombrova <[email protected]>
AuthorDate: Mon Oct 5 16:51:34 2020 -0700
[BEAM-7746] Fix typing in runners
---
.../runners/portability/artifact_service.py | 18 +--
.../runners/portability/fn_api_runner/execution.py | 82 ++++++++++--
.../runners/portability/fn_api_runner/fn_runner.py | 121 ++++++++++++------
.../portability/fn_api_runner/translations.py | 25 ++--
.../portability/fn_api_runner/worker_handlers.py | 137 ++++++++++++++-------
.../runners/portability/portable_runner.py | 17 ++-
.../apache_beam/runners/portability/stager.py | 1 +
sdks/python/apache_beam/utils/profiler.py | 24 ++--
sdks/python/mypy.ini | 14 ---
9 files changed, 302 insertions(+), 137 deletions(-)
diff --git a/sdks/python/apache_beam/runners/portability/artifact_service.py
b/sdks/python/apache_beam/runners/portability/artifact_service.py
index 1f3ec1c..18537f4 100644
--- a/sdks/python/apache_beam/runners/portability/artifact_service.py
+++ b/sdks/python/apache_beam/runners/portability/artifact_service.py
@@ -33,9 +33,15 @@ import queue
import sys
import tempfile
import threading
-import typing
from io import BytesIO
+from typing import Any
+from typing import BinaryIO # pylint: disable=unused-import
from typing import Callable
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Tuple
import grpc
from future.moves.urllib.request import urlopen
@@ -48,11 +54,6 @@ from apache_beam.portability.api import
beam_artifact_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.utils import proto_utils
-if typing.TYPE_CHECKING:
- from typing import BinaryIO # pylint: disable=ungrouped-imports
- from typing import Iterable
- from typing import MutableMapping
-
class ArtifactRetrievalService(
beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer):
@@ -61,7 +62,7 @@ class ArtifactRetrievalService(
def __init__(
self,
- file_reader, # type: Callable[[str], BinaryIO],
+ file_reader, # type: Callable[[str], BinaryIO]
chunk_size=None,
):
self._file_reader = file_reader
@@ -105,7 +106,8 @@ class ArtifactStagingService(
file_writer, # type: Callable[[str, Optional[str]], Tuple[BinaryIO,
str]]
):
self._lock = threading.Lock()
- self._jobs_to_stage = {}
+ self._jobs_to_stage = {
+ } # type: Dict[str, Tuple[Dict[Any,
List[beam_runner_api_pb2.ArtifactInformation]], threading.Event]]
self._file_writer = file_writer
def register_job(
diff --git
a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 5b8e91c..bc69123 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -17,6 +17,8 @@
"""Set of utilities for execution of a pipeline by the FnApiRunner."""
+# mypy: disallow-untyped-defs
+
from __future__ import absolute_import
import collections
@@ -26,10 +28,12 @@ from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import MutableMapping
from typing import Optional
+from typing import Set
from typing import Tuple
from typing_extensions import Protocol
@@ -59,9 +63,13 @@ from apache_beam.utils import proto_utils
from apache_beam.utils import windowed_value
if TYPE_CHECKING:
- from apache_beam.coders.coder_impl import CoderImpl
+ from apache_beam.coders.coder_impl import CoderImpl, WindowedValueCoderImpl
+ from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.portability.fn_api_runner import worker_handlers
+ from apache_beam.runners.portability.fn_api_runner.fn_runner import
DataOutput
+ from apache_beam.runners.portability.fn_api_runner.fn_runner import
OutputTimers
from apache_beam.runners.portability.fn_api_runner.translations import
DataSideInput
+ from apache_beam.transforms import core
from apache_beam.transforms.window import BoundedWindow
ENCODED_IMPULSE_VALUE = WindowedValueCoder(
@@ -87,13 +95,27 @@ class PartitionableBuffer(Buffer, Protocol):
# type: (int) -> List[List[bytes]]
pass
+ @property
+ def cleared(self):
+ # type: () -> bool
+ pass
+
+ def clear(self):
+ # type: () -> None
+ pass
+
+ def reset(self):
+ # type: () -> None
+ pass
+
class ListBuffer(object):
"""Used to support parititioning of a list."""
def __init__(self, coder_impl):
+ # type: (CoderImpl) -> None
self._coder_impl = coder_impl
self._inputs = [] # type: List[bytes]
- self._grouped_output = None
+ self._grouped_output = None # type: Optional[List[List[bytes]]]
self.cleared = False
def append(self, element):
@@ -139,6 +161,8 @@ class ListBuffer(object):
self._grouped_output = None
def reset(self):
+ # type: () -> None
+
"""Resets a cleared buffer for reuse."""
if not self.cleared:
raise RuntimeError('Trying to reset a non-cleared ListBuffer.')
@@ -150,7 +174,7 @@ class GroupingBuffer(object):
def __init__(self,
pre_grouped_coder, # type: coders.Coder
post_grouped_coder, # type: coders.Coder
- windowing
+ windowing # type: core.Windowing
):
# type: (...) -> None
self._key_coder = pre_grouped_coder.key_coder()
@@ -227,12 +251,24 @@ class GroupingBuffer(object):
"""
return itertools.chain(*self.partition(1))
+ # these should never be accessed, but they allow this class to meet the
+ # PartitionableBuffer protocol
+ cleared = False
+
+ def clear(self):
+ # type: () -> None
+ pass
+
+ def reset(self):
+ # type: () -> None
+ pass
+
class WindowGroupingBuffer(object):
"""Used to partition windowed side inputs."""
def __init__(
self,
- access_pattern,
+ access_pattern, # type: beam_runner_api_pb2.FunctionSpec
coder # type: WindowedValueCoder
):
# type: (...) -> None
@@ -283,17 +319,21 @@ class
GenericNonMergingWindowFn(window.NonMergingWindowFn):
URN = 'internal-generic-non-merging'
def __init__(self, coder):
+ # type: (coders.Coder) -> None
self._coder = coder
def assign(self, assign_context):
+ # type: (window.WindowFn.AssignContext) -> Iterable[BoundedWindow]
raise NotImplementedError()
def get_window_coder(self):
+ # type: () -> coders.Coder
return self._coder
@staticmethod
@window.urns.RunnerApiFn.register_urn(URN, bytes)
def from_runner_api_parameter(window_coder_id, context):
+ # type: (bytes, Any) -> GenericNonMergingWindowFn
return GenericNonMergingWindowFn(
context.coders[window_coder_id.decode('utf-8')])
@@ -308,9 +348,11 @@ class FnApiRunnerExecutionContext(object):
stages, # type: List[translations.Stage]
worker_handler_manager, # type: worker_handlers.WorkerHandlerManager
pipeline_components, # type: beam_runner_api_pb2.Components
- safe_coders,
- data_channel_coders,
+ safe_coders, # type: Dict[str, str]
+ data_channel_coders, # type: Dict[str, str]
):
+ # type: (...) -> None
+
"""
:param worker_handler_manager: This class manages the set of worker
handlers, and the communication with state / control APIs.
@@ -365,7 +407,7 @@ class FnApiRunnerExecutionContext(object):
return all_side_inputs
all_side_inputs = frozenset(get_all_side_inputs())
- data_side_inputs_by_producing_stage = {}
+ data_side_inputs_by_producing_stage = {} # type: Dict[str, DataSideInput]
producing_stages_by_pcoll = {}
@@ -397,6 +439,7 @@ class FnApiRunnerExecutionContext(object):
return data_side_inputs_by_producing_stage
def _make_safe_windowing_strategy(self, id):
+ # type: (str) -> str
windowing_strategy_proto =
self.pipeline_components.windowing_strategies[id]
if windowing_strategy_proto.window_fn.urn in SAFE_WINDOW_FNS:
return id
@@ -420,15 +463,17 @@ class FnApiRunnerExecutionContext(object):
@property
def state_servicer(self):
+ # type: () -> worker_handlers.StateServicer
# TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
return self.worker_handler_manager.state_servicer
def next_uid(self):
+ # type: () -> str
self._last_uid += 1
return str(self._last_uid)
def _iterable_state_write(self, values, element_coder_impl):
- # type: (...) -> bytes
+ # type: (Iterable, CoderImpl) -> bytes
token = unique_name(None, 'iter').encode('ascii')
out = create_OutputStream()
for element in values:
@@ -484,21 +529,23 @@ class BundleContextManager(object):
stage, # type: translations.Stage
num_workers, # type: int
):
+ # type: (...) -> None
self.execution_context = execution_context
self.stage = stage
self.bundle_uid = self.execution_context.next_uid()
self.num_workers = num_workers
# Properties that are lazily initialized
- self._process_bundle_descriptor = None
- self._worker_handlers = None
+ self._process_bundle_descriptor = None # type:
Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
+ self._worker_handlers = None # type:
Optional[List[worker_handlers.WorkerHandler]]
# a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
# is built after self._process_bundle_descriptor is initialized.
# This field can be used to tell whether current bundle has timers.
- self._timer_coder_ids = None
+ self._timer_coder_ids = None # type: Optional[Dict[Tuple[str, str], str]]
@property
def worker_handlers(self):
+ # type: () -> List[worker_handlers.WorkerHandler]
if self._worker_handlers is None:
self._worker_handlers = (
self.execution_context.worker_handler_manager.get_worker_handlers(
@@ -506,23 +553,27 @@ class BundleContextManager(object):
return self._worker_handlers
def data_api_service_descriptor(self):
+ # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor]
# All worker_handlers share the same grpc server, so we can read grpc
server
# info from any worker_handler and read from the first worker_handler.
return self.worker_handlers[0].data_api_service_descriptor()
def state_api_service_descriptor(self):
+ # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor]
# All worker_handlers share the same grpc server, so we can read grpc
server
# info from any worker_handler and read from the first worker_handler.
return self.worker_handlers[0].state_api_service_descriptor()
@property
def process_bundle_descriptor(self):
+ # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
if self._process_bundle_descriptor is None:
self._process_bundle_descriptor = self._build_process_bundle_descriptor()
self._timer_coder_ids = self._build_timer_coders_id_map()
return self._process_bundle_descriptor
def _build_process_bundle_descriptor(self):
+ # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
# Cannot be invoked until *after* _extract_endpoints is called.
# Always populate the timer_api_service_descriptor.
return beam_fn_api_pb2.ProcessBundleDescriptor(
@@ -543,7 +594,7 @@ class BundleContextManager(object):
timer_api_service_descriptor=self.data_api_service_descriptor())
def extract_bundle_inputs_and_outputs(self):
- # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput,
Dict[Tuple[str, str], str]]
+ # type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput,
Dict[Tuple[str, str], bytes]]
"""Returns maps of transform names to PCollection identifiers.
@@ -560,7 +611,7 @@ class BundleContextManager(object):
data_input = {} # type: Dict[str, PartitionableBuffer]
data_output = {} # type: DataOutput
# A mapping of {(transform_id, timer_family_id) : buffer_id}
- expected_timer_output = {} # type: Dict[Tuple[str, str], str]
+ expected_timer_output = {} # type: OutputTimers
for transform in self.stage.transforms:
if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
bundle_processor.DATA_OUTPUT_URN):
@@ -609,6 +660,8 @@ class BundleContextManager(object):
return self.get_coder_impl(coder_id)
def _build_timer_coders_id_map(self):
+ # type: () -> Dict[Tuple[str, str], str]
+ assert self._process_bundle_descriptor is not None
timer_coder_ids = {}
for transform_id, transform_proto in (self._process_bundle_descriptor
.transforms.items()):
@@ -621,6 +674,7 @@ class BundleContextManager(object):
return timer_coder_ids
def get_coder_impl(self, coder_id):
+ # type: (str) -> CoderImpl
if coder_id in self.execution_context.safe_coders:
return self.execution_context.pipeline_context.coders[
self.execution_context.safe_coders[coder_id]].get_impl()
@@ -628,6 +682,8 @@ class BundleContextManager(object):
return
self.execution_context.pipeline_context.coders[coder_id].get_impl()
def get_timer_coder_impl(self, transform_id, timer_family_id):
+ # type: (str, str) -> CoderImpl
+ assert self._timer_coder_ids is not None
return self.get_coder_impl(
self._timer_coder_ids[(transform_id, timer_family_id)])
diff --git
a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index 9c43f33..404261f 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -18,6 +18,7 @@
"""A PipelineRunner using the SDK harness.
"""
# pytype: skip-file
+# mypy: check-untyped-defs
from __future__ import absolute_import
from __future__ import print_function
@@ -36,12 +37,14 @@ from builtins import object
from typing import TYPE_CHECKING
from typing import Callable
from typing import Dict
+from typing import Iterator
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
+from typing import Union
from apache_beam.coders.coder_impl import create_OutputStream
from apache_beam.metrics import metric
@@ -62,13 +65,14 @@ from
apache_beam.runners.portability.fn_api_runner.translations import create_bu
from apache_beam.runners.portability.fn_api_runner.translations import
only_element
from apache_beam.runners.portability.fn_api_runner.worker_handlers import
WorkerHandlerManager
from apache_beam.transforms import environments
-from apache_beam.utils import profiler
from apache_beam.utils import proto_utils
from apache_beam.utils import thread_pool_executor
+from apache_beam.utils.profiler import Profile
if TYPE_CHECKING:
from apache_beam.pipeline import Pipeline
from apache_beam.portability.api import metrics_pb2
+ from apache_beam.runners.portability.fn_api_runner.worker_handlers import
WorkerHandler
_LOGGER = logging.getLogger(__name__)
@@ -77,6 +81,7 @@ T = TypeVar('T')
DataSideInput = Dict[Tuple[str, str],
Tuple[bytes, beam_runner_api_pb2.FunctionSpec]]
DataOutput = Dict[str, bytes]
+OutputTimers = Dict[Tuple[str, str], bytes]
BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse,
List[beam_fn_api_pb2.ProcessBundleSplitResponse]]
@@ -88,11 +93,12 @@ class FnApiRunner(runner.PipelineRunner):
def __init__(
self,
default_environment=None, # type: Optional[environments.Environment]
- bundle_repeat=0,
- use_state_iterables=False,
+ bundle_repeat=0, # type: int
+ use_state_iterables=False, # type: bool
provision_info=None, # type: Optional[ExtendedProvisionInfo]
- progress_request_frequency=None,
- is_drain=False):
+ progress_request_frequency=None, # type: Optional[float]
+ is_drain=False # type: bool
+ ):
# type: (...) -> None
"""Creates a new Fn API Runner.
@@ -114,7 +120,7 @@ class FnApiRunner(runner.PipelineRunner):
self._bundle_repeat = bundle_repeat
self._num_workers = 1
self._progress_frequency = progress_request_frequency
- self._profiler_factory = None # type: Optional[Callable[...,
profiler.Profile]]
+ self._profiler_factory = None # type: Optional[Callable[..., Profile]]
self._use_state_iterables = use_state_iterables
self._is_drain = is_drain
self._provision_info = provision_info or ExtendedProvisionInfo(
@@ -123,6 +129,7 @@ class FnApiRunner(runner.PipelineRunner):
@staticmethod
def supported_requirements():
+ # type: () -> Tuple[str, ...]
return (
common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn,
common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn,
@@ -169,7 +176,7 @@ class FnApiRunner(runner.PipelineRunner):
self._default_environment = environments.SubprocessSDKEnvironment(
command_string=command_string)
- self._profiler_factory = profiler.Profile.factory_from_options(
+ self._profiler_factory = Profile.factory_from_options(
options.view_as(pipeline_options.ProfilingOptions))
self._latest_run_result = self.run_via_runner_api(
@@ -187,6 +194,7 @@ class FnApiRunner(runner.PipelineRunner):
@contextlib.contextmanager
def maybe_profile(self):
+ # type: () -> Iterator[None]
if self._profiler_factory:
try:
profile_id = 'direct-' + subprocess.check_output([
@@ -194,7 +202,8 @@ class FnApiRunner(runner.PipelineRunner):
]).decode(errors='ignore').strip()
except subprocess.CalledProcessError:
profile_id = 'direct-unknown'
- profiler = self._profiler_factory(profile_id, time_prefix='')
+ profiler = self._profiler_factory(
+ profile_id, time_prefix='') # type: Optional[Profile]
else:
profiler = None
@@ -231,10 +240,13 @@ class FnApiRunner(runner.PipelineRunner):
yield
def _validate_requirements(self, pipeline_proto):
+ # type: (beam_runner_api_pb2.Pipeline) -> None
+
"""As a test runner, validate requirements were set correctly."""
expected_requirements = set()
def add_requirements(transform_id):
+ # type: (str) -> None
transform = pipeline_proto.components.transforms[transform_id]
if transform.spec.urn in translations.PAR_DO_URNS:
payload = proto_utils.parse_Bytes(
@@ -266,6 +278,8 @@ class FnApiRunner(runner.PipelineRunner):
(expected_requirements - set(pipeline_proto.requirements)))
def _check_requirements(self, pipeline_proto):
+ # type: (beam_runner_api_pb2.Pipeline) -> None
+
"""Check that this runner can satisfy all pipeline requirements."""
supported_requirements = set(self.supported_requirements())
for requirement in pipeline_proto.requirements:
@@ -355,10 +369,10 @@ class FnApiRunner(runner.PipelineRunner):
self,
runner_execution_context, # type: execution.FnApiRunnerExecutionContext
bundle_manager, # type: BundleManager
- data_input,
+ data_input, # type: Dict[str, execution.PartitionableBuffer]
data_output, # type: DataOutput
- fired_timers,
- expected_output_timers,
+ fired_timers, # type: Mapping[Tuple[str, str],
execution.PartitionableBuffer]
+ expected_output_timers, # type: Dict[Tuple[str, str], bytes]
):
# type: (...) -> None
@@ -407,7 +421,12 @@ class FnApiRunner(runner.PipelineRunner):
written_timers.clear()
def _add_sdk_delayed_applications_to_deferred_inputs(
- self, bundle_context_manager, bundle_result, deferred_inputs):
+ self,
+ bundle_context_manager, # type: execution.BundleContextManager
+ bundle_result, # type: beam_fn_api_pb2.InstructionResponse
+ deferred_inputs # type: MutableMapping[str,
execution.PartitionableBuffer]
+ ):
+ # type: (...) -> None
for delayed_application in bundle_result.process_bundle.residual_roots:
name = bundle_context_manager.input_for(
delayed_application.application.transform_id,
@@ -421,8 +440,8 @@ class FnApiRunner(runner.PipelineRunner):
self,
splits, # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
bundle_context_manager, # type: execution.BundleContextManager
- last_sent,
- deferred_inputs # type: MutableMapping[str, PartitionableBuffer]
+ last_sent, # type: Dict[str, execution.PartitionableBuffer]
+ deferred_inputs # type: MutableMapping[str,
execution.PartitionableBuffer]
):
# type: (...) -> None
@@ -485,7 +504,8 @@ class FnApiRunner(runner.PipelineRunner):
"""
data_input, data_output, expected_timer_output = (
bundle_context_manager.extract_bundle_inputs_and_outputs())
- input_timers = {}
+ input_timers = {
+ } # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
worker_handler_manager = runner_execution_context.worker_handler_manager
_LOGGER.info('Running %s', bundle_context_manager.stage.name)
@@ -509,9 +529,11 @@ class FnApiRunner(runner.PipelineRunner):
self._progress_frequency,
cache_token_generator=cache_token_generator)
- final_result = None
+ final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse]
def merge_results(last_result):
+ # type: (beam_fn_api_pb2.InstructionResponse) ->
beam_fn_api_pb2.InstructionResponse
+
""" Merge the latest result with other accumulated results. """
return (
last_result
@@ -539,7 +561,6 @@ class FnApiRunner(runner.PipelineRunner):
else:
data_input = deferred_inputs
input_timers = fired_timers
- bundle_manager._registered = True
# Store the required downstream side inputs into state so it is accessible
# for the worker when it runs bundles that consume this stage's output.
@@ -552,13 +573,16 @@ class FnApiRunner(runner.PipelineRunner):
def _run_bundle(
self,
- runner_execution_context,
- bundle_context_manager,
- data_input,
- data_output,
- input_timers,
- expected_timer_output,
- bundle_manager):
+ runner_execution_context, # type: execution.FnApiRunnerExecutionContext
+ bundle_context_manager, # type: execution.BundleContextManager
+ data_input, # type: Dict[str, execution.PartitionableBuffer]
+ data_output, # type: DataOutput
+ input_timers, # type: Mapping[Tuple[str, str],
execution.PartitionableBuffer]
+ expected_timer_output, # type: Dict[Tuple[str, str], bytes]
+ bundle_manager # type: BundleManager
+ ):
+ # type: (...) -> Tuple[beam_fn_api_pb2.InstructionResponse, Dict[str,
execution.PartitionableBuffer], Dict[Tuple[str, str], ListBuffer]]
+
"""Execute a bundle, and return a result object, and deferred inputs."""
self._run_bundle_multiple_times_for_testing(
runner_execution_context,
@@ -576,7 +600,7 @@ class FnApiRunner(runner.PipelineRunner):
# - SDK-initiated deferred applications of root elements
# - Runner-initiated deferred applications of root elements
deferred_inputs = {} # type: Dict[str, execution.PartitionableBuffer]
- fired_timers = {}
+ fired_timers = {} # type: Dict[Tuple[str, str], ListBuffer]
self._collect_written_timers_and_add_to_fired_timers(
bundle_context_manager, fired_timers)
@@ -601,12 +625,15 @@ class FnApiRunner(runner.PipelineRunner):
@staticmethod
def get_cache_token_generator(static=True):
+ # type: (bool) -> Iterator[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]
+
"""A generator for cache tokens.
:arg static If True, generator always returns the same cache token
If False, generator returns a new cache token each time
:return A generator which returns a cache token on next(generator)
"""
def generate_token(identifier):
+ # type: (int) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken
return beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState(
),
@@ -614,44 +641,55 @@ class FnApiRunner(runner.PipelineRunner):
class StaticGenerator(object):
def __init__(self):
+ # type: () -> None
self._token = generate_token(1)
def __iter__(self):
+ # type: () -> StaticGenerator
# pylint: disable=non-iterator-returned
return self
def __next__(self):
+ # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken
return self._token
class DynamicGenerator(object):
def __init__(self):
+ # type: () -> None
self._counter = 0
self._lock = threading.Lock()
def __iter__(self):
+ # type: () -> DynamicGenerator
# pylint: disable=non-iterator-returned
return self
def __next__(self):
+ # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken
with self._lock:
self._counter += 1
return generate_token(self._counter)
- return StaticGenerator() if static else DynamicGenerator()
+ if static:
+ return StaticGenerator()
+ else:
+ return DynamicGenerator()
class ExtendedProvisionInfo(object):
def __init__(self,
provision_info=None, # type:
Optional[beam_provision_api_pb2.ProvisionInfo]
- artifact_staging_dir=None,
+ artifact_staging_dir=None, # type: Optional[str]
job_name=None, # type: Optional[str]
):
+ # type: (...) -> None
self.provision_info = (
provision_info or beam_provision_api_pb2.ProvisionInfo())
self.artifact_staging_dir = artifact_staging_dir
self.job_name = job_name
def for_environment(self, env):
+ # type: (...) -> ExtendedProvisionInfo
if env.dependencies:
provision_info_with_deps = copy.deepcopy(self.provision_info)
provision_info_with_deps.dependencies.extend(env.dependencies)
@@ -699,9 +737,11 @@ class BundleManager(object):
def __init__(self,
bundle_context_manager, # type: execution.BundleContextManager
- progress_frequency=None,
+ progress_frequency=None, # type: Optional[float]
cache_token_generator=FnApiRunner.get_cache_token_generator()
):
+ # type: (...) -> None
+
"""Set up a bundle manager.
Args:
@@ -709,7 +749,7 @@ class BundleManager(object):
"""
self.bundle_context_manager = bundle_context_manager # type:
execution.BundleContextManager
self._progress_frequency = progress_frequency
- self._worker_handler = None # type: Optional[execution.WorkerHandler]
+ self._worker_handler = None # type: Optional[WorkerHandler]
self._cache_token_generator = cache_token_generator
def _send_input_to_worker(self,
@@ -727,6 +767,7 @@ class BundleManager(object):
def _send_timers_to_worker(
self, process_bundle_id, transform_id, timer_family_id, timers):
+ # type: (...) -> None
assert self._worker_handler is not None
timer_out = self._worker_handler.data_conn.output_timer_stream(
process_bundle_id, transform_id, timer_family_id)
@@ -751,8 +792,9 @@ class BundleManager(object):
def _generate_splits_for_testing(self,
split_manager,
- inputs, # type: Mapping[str,
PartitionableBuffer]
- process_bundle_id):
+ inputs, # type: Mapping[str,
execution.PartitionableBuffer]
+ process_bundle_id
+ ):
# type: (...) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse]
split_results = [] # type:
List[beam_fn_api_pb2.ProcessBundleSplitResponse]
read_transform_id, buffer_data = only_element(inputs.items())
@@ -819,8 +861,8 @@ class BundleManager(object):
inputs, # type: Mapping[str,
execution.PartitionableBuffer]
expected_outputs, # type: DataOutput
fired_timers, # type: Mapping[Tuple[str, str],
execution.PartitionableBuffer]
- expected_output_timers, # type: Dict[Tuple[str, str],
str]
- dry_run=False,
+ expected_output_timers, # type: OutputTimers
+ dry_run=False, # type: bool
):
# type: (...) -> BundleProcessResult
# Unique id for the instruction processing this bundle.
@@ -863,7 +905,8 @@ class BundleManager(object):
split_results = self._generate_splits_for_testing(
split_manager, inputs, process_bundle_id)
- expect_reads = list(expected_outputs.keys())
+ expect_reads = list(
+ expected_outputs.keys()) # type: List[Union[str, Tuple[str, str]]]
expect_reads.extend(list(expected_output_timers.keys()))
# Gather all output data.
@@ -871,7 +914,7 @@ class BundleManager(object):
process_bundle_id,
expect_reads,
abort_callback=lambda:
- (result_future.is_done() and result_future.get().error)):
+ (result_future.is_done() and bool(result_future.get().error))):
if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run:
with BundleManager._lock:
timer_buffer = self.bundle_context_manager.get_buffer(
@@ -910,7 +953,7 @@ class ParallelBundleManager(BundleManager):
def __init__(
self,
bundle_context_manager, # type: execution.BundleContextManager
- progress_frequency=None,
+ progress_frequency=None, # type: Optional[float]
cache_token_generator=None,
**kwargs):
# type: (...) -> None
@@ -924,8 +967,8 @@ class ParallelBundleManager(BundleManager):
inputs, # type: Mapping[str,
execution.PartitionableBuffer]
expected_outputs, # type: DataOutput
fired_timers, # type: Mapping[Tuple[str, str],
execution.PartitionableBuffer]
- expected_output_timers, # type: Dict[Tuple[str, str],
str]
- dry_run=False,
+ expected_output_timers, # type: OutputTimers
+ dry_run=False, # type: bool
):
# type: (...) -> BundleProcessResult
part_inputs = [{} for _ in range(self._num_workers)
@@ -993,7 +1036,7 @@ class ProgressRequester(threading.Thread):
self._instruction_id = instruction_id
self._frequency = frequency
self._done = False
- self._latest_progress = None
+ self._latest_progress = None # type:
Optional[beam_fn_api_pb2.ProcessBundleProgressResponse]
self._callback = callback
self.daemon = True
diff --git
a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
index e4d03c8..d29ab3b 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -18,6 +18,7 @@
"""Pipeline transformations for the FnApiRunner.
"""
# pytype: skip-file
+# mypy: check-untyped-defs
from __future__ import absolute_import
from __future__ import print_function
@@ -307,6 +308,9 @@ class Stage(object):
for side in side_inputs
},
main_input=main_input_id)
+ # at this point we should have resolved an environment, as the key of
+ # components.environments cannot be None.
+ assert self.environment is not None
exec_payload = beam_runner_api_pb2.ExecutableStagePayload(
environment=components.environments[self.environment],
input=main_input_id,
@@ -373,7 +377,7 @@ class TransformContext(object):
None) # type: ignore[arg-type]
self.bytes_coder_id = self.add_or_get_coder_id(coder_proto, 'bytes_coder')
self.safe_coders = {self.bytes_coder_id: self.bytes_coder_id}
- self.data_channel_coders = {}
+ self.data_channel_coders = {} # type: Dict[str, str]
def add_or_get_coder_id(
self,
@@ -812,7 +816,8 @@ def pack_combiners(stages, context):
def _get_fallback_coder_id():
return context.add_or_get_coder_id(
- coders.registry.get_coder(object).to_runner_api(None))
+ # passing None works here because there are no component coders
+ coders.registry.get_coder(object).to_runner_api(None)) # type:
ignore[arg-type]
def _get_component_coder_id_from_kv_coder(coder, index):
assert index < 2
@@ -922,11 +927,14 @@ def pack_combiners(stages, context):
is_bounded=input_pcoll.is_bounded))
# Set up Pack stage.
+ # TODO(BEAM-7746): classes that inherit from RunnerApiFn are expected to
+ # accept a PipelineContext for from_runner_api/to_runner_api. Determine
+ # how to accommodate this.
pack_combine_fn = combiners.SingleInputTupleCombineFn(
*[
- core.CombineFn.from_runner_api(combine_payload.combine_fn, context)
+ core.CombineFn.from_runner_api(combine_payload.combine_fn,
context) # type: ignore[arg-type]
for combine_payload in combine_payloads
- ]).to_runner_api(context)
+ ]).to_runner_api(context) # type: ignore[arg-type]
pack_transform = beam_runner_api_pb2.PTransform(
unique_name=pack_combine_name + '/Pack',
spec=beam_runner_api_pb2.FunctionSpec(
@@ -1480,14 +1488,17 @@ def sink_flattens(stages, pipeline_context):
def greedily_fuse(stages, pipeline_context):
+ # type: (Iterable[Stage], TransformContext) -> FrozenSet[Stage]
+
"""Places transforms sharing an edge in the same stage, whenever possible.
"""
- producers_by_pcoll = {}
- consumers_by_pcoll = collections.defaultdict(list)
+ producers_by_pcoll = {} # type: Dict[str, Stage]
+ consumers_by_pcoll = collections.defaultdict(
+ list) # type: DefaultDict[str, List[Stage]]
# Used to always reference the correct stage as the producer and
# consumer maps are not updated when stages are fused away.
- replacements = {}
+ replacements = {} # type: Dict[Stage, Stage]
def replacement(s):
old_ss = []
diff --git
a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index 1f1d483..756e4ba 100644
---
a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++
b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -17,6 +17,8 @@
"""Code for communicating with the Workers."""
+# mypy: disallow-untyped-defs
+
from __future__ import absolute_import
import collections
@@ -29,6 +31,9 @@ import sys
import threading
import time
from builtins import object
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import BinaryIO # pylint: disable=unused-import
from typing import Callable
from typing import DefaultDict
from typing import Dict
@@ -38,6 +43,8 @@ from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
+from typing import Type
+from typing import TypeVar
from typing import Union
from typing import cast
from typing import overload
@@ -64,6 +71,12 @@ from apache_beam.runners.worker.sdk_worker import _Future
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.utils import proto_utils
from apache_beam.utils import thread_pool_executor
+from apache_beam.utils.sentinel import Sentinel
+
+if TYPE_CHECKING:
+ from grpc import ServicerContext
+ from google.protobuf import message
+ from apache_beam.runners.portability.fn_api_runner.fn_runner import
ExtendedProvisionInfo # pylint: disable=ungrouped-imports
# State caching is enabled in the fn_api_runner for testing, except for one
# test which runs without state caching (FnApiRunnerTestWithDisabledCaching).
@@ -75,10 +88,11 @@ DATA_BUFFER_TIME_LIMIT_MS = 1000
_LOGGER = logging.getLogger(__name__)
+T = TypeVar('T')
ConstructorFn = Callable[[
Union['message.Message', bytes],
- 'StateServicer',
- Optional['ExtendedProvisionInfo'],
+ 'sdk_worker.StateHandler',
+ 'ExtendedProvisionInfo',
'GrpcServer'
],
'WorkerHandler']
@@ -90,8 +104,9 @@ class ControlConnection(object):
_lock = threading.Lock()
def __init__(self):
+ # type: () -> None
self._push_queue = queue.Queue(
- ) # type: queue.Queue[beam_fn_api_pb2.InstructionRequest]
+ ) # type: queue.Queue[Union[beam_fn_api_pb2.InstructionRequest, Sentinel]]
self._input = None # type:
Optional[Iterable[beam_fn_api_pb2.InstructionResponse]]
self._futures_by_id = dict() # type: Dict[str, ControlFuture]
self._read_thread = threading.Thread(
@@ -99,12 +114,14 @@ class ControlConnection(object):
self._state = BeamFnControlServicer.UNSTARTED_STATE
def _read(self):
+ # type: () -> None
+ assert self._input is not None
for data in self._input:
self._futures_by_id.pop(data.instruction_id).set(data)
@overload
def push(self, req):
- # type: (BeamFnControlServicer.DoneMarker) -> None
+ # type: (Sentinel) -> None
pass
@overload
@@ -113,7 +130,8 @@ class ControlConnection(object):
pass
def push(self, req):
- if req == BeamFnControlServicer._DONE_MARKER:
+ # type: (Union[Sentinel, beam_fn_api_pb2.InstructionRequest]) ->
Optional[ControlFuture]
+ if req is BeamFnControlServicer._DONE_MARKER:
self._push_queue.put(req)
return None
if not req.instruction_id:
@@ -126,7 +144,7 @@ class ControlConnection(object):
return future
def get_req(self):
- # type: () -> beam_fn_api_pb2.InstructionRequest
+ # type: () -> Union[Sentinel, beam_fn_api_pb2.InstructionRequest]
return self._push_queue.get()
def set_input(self, input):
@@ -147,6 +165,7 @@ class ControlConnection(object):
self._state = BeamFnControlServicer.DONE_STATE
def abort(self, exn):
+ # type: (Exception) -> None
for future in self._futures_by_id.values():
future.abort(exn)
@@ -158,23 +177,20 @@ class
BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
STARTED_STATE = 'started'
DONE_STATE = 'done'
- class DoneMarker(object):
- pass
-
- _DONE_MARKER = DoneMarker()
+ _DONE_MARKER = Sentinel.sentinel
def __init__(
self,
worker_manager, # type: WorkerHandlerManager
):
+ # type: (...) -> None
self._worker_manager = worker_manager
self._lock = threading.Lock()
self._uid_counter = 0
self._state = self.UNSTARTED_STATE
# following self._req_* variables are used for debugging purposes, data is
# added only when self._log_req is True.
- self._req_sent = collections.defaultdict(int)
- self._req_worker_mapping = {}
+ self._req_sent = collections.defaultdict(int) # type: DefaultDict[str,
int]
self._log_req = logging.getLogger().getEffectiveLevel() <= logging.DEBUG
self._connections_by_worker_id = collections.defaultdict(
ControlConnection) # type: DefaultDict[str, ControlConnection]
@@ -186,7 +202,7 @@ class
BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
def Control(self,
iterator, # type: Iterable[beam_fn_api_pb2.InstructionResponse]
- context
+ context # type: ServicerContext
):
# type: (...) -> Iterator[beam_fn_api_pb2.InstructionRequest]
with self._lock:
@@ -213,16 +229,14 @@ class
BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
self._req_sent[to_push.instruction_id] += 1
def done(self):
+ # type: () -> None
self._state = self.DONE_STATE
_LOGGER.debug(
'Runner: Requests sent by runner: %s',
[(str(req), cnt) for req, cnt in self._req_sent.items()])
- _LOGGER.debug(
- 'Runner: Requests multiplexing info: %s',
- [(str(req), worker)
- for req, worker in self._req_worker_mapping.items()])
def GetProcessBundleDescriptor(self, id, context=None):
+ # type: (beam_fn_api_pb2.GetProcessBundleDescriptorRequest, Any) ->
beam_fn_api_pb2.ProcessBundleDescriptor
return self._worker_manager.get_process_bundle_descriptor(id)
@@ -242,10 +256,10 @@ class WorkerHandler(object):
data_conn = None # type: data_plane._GrpcDataChannel
def __init__(self,
- control_handler,
- data_plane_handler,
- state, # type: StateServicer
- provision_info # type: Optional[ExtendedProvisionInfo]
+ control_handler, # type: Any
+ data_plane_handler, # type: Any
+ state, # type: sdk_worker.StateHandler
+ provision_info # type: ExtendedProvisionInfo
):
# type: (...) -> None
@@ -278,6 +292,14 @@ class WorkerHandler(object):
# type: () -> None
raise NotImplementedError
+ def control_api_service_descriptor(self):
+ # type: () -> endpoints_pb2.ApiServiceDescriptor
+ raise NotImplementedError
+
+ def artifact_api_service_descriptor(self):
+ # type: () -> endpoints_pb2.ApiServiceDescriptor
+ raise NotImplementedError
+
def data_api_service_descriptor(self):
# type: () -> Optional[endpoints_pb2.ApiServiceDescriptor]
raise NotImplementedError
@@ -296,9 +318,10 @@ class WorkerHandler(object):
urn, # type: str
payload_type # type: Optional[Type[T]]
):
- # type: (...) -> Callable[[Callable[[T, StateServicer,
Optional[ExtendedProvisionInfo], GrpcServer], WorkerHandler]], Callable[[T,
StateServicer, Optional[ExtendedProvisionInfo], GrpcServer], WorkerHandler]]
+ # type: (...) -> Callable[[Callable[[T, sdk_worker.StateHandler,
ExtendedProvisionInfo, GrpcServer], WorkerHandler]], Callable[[T,
sdk_worker.StateHandler, ExtendedProvisionInfo, GrpcServer], WorkerHandler]]
def wrapper(constructor):
- cls._registered_environments[urn] = constructor, payload_type
+ # type: (Callable) -> Callable
+ cls._registered_environments[urn] = constructor, payload_type # type:
ignore[assignment]
return constructor
return wrapper
@@ -306,8 +329,8 @@ class WorkerHandler(object):
@classmethod
def create(cls,
environment, # type: beam_runner_api_pb2.Environment
- state, # type: StateServicer
- provision_info, # type: Optional[ExtendedProvisionInfo]
+ state, # type: sdk_worker.StateHandler
+ provision_info, # type: ExtendedProvisionInfo
grpc_server # type: GrpcServer
):
# type: (...) -> WorkerHandler
@@ -319,14 +342,17 @@ class WorkerHandler(object):
grpc_server)
[email protected]_environment(python_urns.EMBEDDED_PYTHON, None)
+# This takes a WorkerHandlerManager instead of GrpcServer, so it is not
+# compatible with WorkerHandler.register_environment. There is a special case
+# in WorkerHandlerManager.get_worker_handlers() that allows it to work.
[email protected]_environment(python_urns.EMBEDDED_PYTHON, None) #
type: ignore[arg-type]
class EmbeddedWorkerHandler(WorkerHandler):
"""An in-memory worker_handler for fn API control, state and data planes."""
def __init__(self,
unused_payload, # type: None
state, # type: sdk_worker.StateHandler
- provision_info, # type: Optional[ExtendedProvisionInfo]
+ provision_info, # type: ExtendedProvisionInfo
worker_manager, # type: WorkerHandlerManager
):
# type: (...) -> None
@@ -347,6 +373,7 @@ class EmbeddedWorkerHandler(WorkerHandler):
self._uid_counter = 0
def push(self, request):
+ # type: (beam_fn_api_pb2.InstructionRequest) -> ControlFuture
if not request.instruction_id:
self._uid_counter += 1
request.instruction_id = 'control_%s' % self._uid_counter
@@ -394,6 +421,7 @@ class
BasicLoggingService(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
}
def Logging(self, log_messages, context=None):
+ # type: (Iterable[beam_fn_api_pb2.LogEntry.List], Any) ->
Iterator[beam_fn_api_pb2.LogControl]
yield beam_fn_api_pb2.LogControl()
for log_message in log_messages:
for log in log_message.log_entries:
@@ -403,12 +431,12 @@ class
BasicLoggingService(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
class
BasicProvisionService(beam_provision_api_pb2_grpc.ProvisionServiceServicer
):
def __init__(self, base_info, worker_manager):
- # type: (Optional[beam_provision_api_pb2.ProvisionInfo],
WorkerHandlerManager) -> None
+ # type: (beam_provision_api_pb2.ProvisionInfo, WorkerHandlerManager) ->
None
self._base_info = base_info
self._worker_manager = worker_manager
def GetProvisionInfo(self, request, context=None):
- # type: (...) -> beam_provision_api_pb2.GetProvisionInfoResponse
+ # type: (Any, Optional[ServicerContext]) ->
beam_provision_api_pb2.GetProvisionInfoResponse
if context:
worker_id = dict(context.invocation_metadata())['worker_id']
worker = self._worker_manager.get_worker(worker_id)
@@ -466,6 +494,7 @@ class GrpcServer(object):
self.control_server)
def open_uncompressed(f):
+ # type: (str) -> BinaryIO
return filesystems.FileSystems.open(
f, compression_type=CompressionTypes.UNCOMPRESSED)
@@ -499,6 +528,7 @@ class GrpcServer(object):
self.control_server.start()
def close(self):
+ # type: () -> None
self.control_handler.done()
to_wait = [
self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS),
@@ -515,7 +545,7 @@ class GrpcWorkerHandler(WorkerHandler):
def __init__(self,
state, # type: StateServicer
- provision_info, # type: Optional[ExtendedProvisionInfo]
+ provision_info, # type: ExtendedProvisionInfo
grpc_server # type: GrpcServer
):
# type: (...) -> None
@@ -566,9 +596,11 @@ class GrpcWorkerHandler(WorkerHandler):
super(GrpcWorkerHandler, self).close()
def port_from_worker(self, port):
+ # type: (int) -> str
return '%s:%s' % (self.host_from_worker(), port)
def host_from_worker(self):
+ # type: () -> str
return 'localhost'
@@ -578,7 +610,7 @@ class ExternalWorkerHandler(GrpcWorkerHandler):
def __init__(self,
external_payload, # type: beam_runner_api_pb2.ExternalPayload
state, # type: StateServicer
- provision_info, # type: Optional[ExtendedProvisionInfo]
+ provision_info, # type: ExtendedProvisionInfo
grpc_server # type: GrpcServer
):
# type: (...) -> None
@@ -610,6 +642,7 @@ class ExternalWorkerHandler(GrpcWorkerHandler):
pass
def host_from_worker(self):
+ # type: () -> str
# TODO(BEAM-8646): Reconcile across platforms.
if sys.platform in ['win32', 'darwin']:
return 'localhost'
@@ -622,7 +655,7 @@ class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
def __init__(self,
payload, # type: bytes
state, # type: StateServicer
- provision_info, # type: Optional[ExtendedProvisionInfo]
+ provision_info, # type: ExtendedProvisionInfo
grpc_server # type: GrpcServer
):
# type: (...) -> None
@@ -662,7 +695,7 @@ class SubprocessSdkWorkerHandler(GrpcWorkerHandler):
def __init__(self,
worker_command_line, # type: bytes
state, # type: StateServicer
- provision_info, # type: Optional[ExtendedProvisionInfo]
+ provision_info, # type: ExtendedProvisionInfo
grpc_server # type: GrpcServer
):
# type: (...) -> None
@@ -690,7 +723,7 @@ class DockerSdkWorkerHandler(GrpcWorkerHandler):
def __init__(self,
payload, # type: beam_runner_api_pb2.DockerPayload
state, # type: StateServicer
- provision_info, # type: Optional[ExtendedProvisionInfo]
+ provision_info, # type: ExtendedProvisionInfo
grpc_server # type: GrpcServer
):
# type: (...) -> None
@@ -700,6 +733,7 @@ class DockerSdkWorkerHandler(GrpcWorkerHandler):
self._container_id = None # type: Optional[bytes]
def host_from_worker(self):
+ # type: () -> str
if sys.platform == "darwin":
# See https://docs.docker.com/docker-for-mac/networking/
return 'host.docker.internal'
@@ -753,7 +787,9 @@ class DockerSdkWorkerHandler(GrpcWorkerHandler):
t.start()
def watch_container(self):
+ # type: () -> None
while not self._done:
+ assert self._container_id is not None
status = subprocess.check_output(
['docker', 'inspect', '-f', '{{.State.Status}}',
self._container_id]).strip()
@@ -788,7 +824,7 @@ class WorkerHandlerManager(object):
"""
def __init__(self,
environments, # type: Mapping[str,
beam_runner_api_pb2.Environment]
- job_provision_info # type: Optional[ExtendedProvisionInfo]
+ job_provision_info # type: ExtendedProvisionInfo
):
# type: (...) -> None
self._environments = environments
@@ -798,13 +834,16 @@ class WorkerHandlerManager(object):
self._workers_by_id = {} # type: Dict[str, WorkerHandler]
self.state_servicer = StateServicer()
self._grpc_server = None # type: Optional[GrpcServer]
- self._process_bundle_descriptors = {}
+ self._process_bundle_descriptors = {
+ } # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
def register_process_bundle_descriptor(self, process_bundle_descriptor):
+ # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None
self._process_bundle_descriptors[
process_bundle_descriptor.id] = process_bundle_descriptor
def get_process_bundle_descriptor(self, request):
+ # type: (beam_fn_api_pb2.GetProcessBundleDescriptorRequest) ->
beam_fn_api_pb2.ProcessBundleDescriptor
return self._process_bundle_descriptors[
request.process_bundle_descriptor_id]
@@ -852,6 +891,7 @@ class WorkerHandlerManager(object):
return self._cached_handlers[environment_id][:num_workers]
def close_all(self):
+ # type: () -> None
for worker_handler_list in self._cached_handlers.values():
for worker_handler in set(worker_handler_list):
try:
@@ -859,13 +899,14 @@ class WorkerHandlerManager(object):
except Exception:
_LOGGER.error(
"Error closing worker_handler %s" % worker_handler,
exc_info=True)
- self._cached_handlers = {}
+ self._cached_handlers = {} # type: ignore[assignment]
self._workers_by_id = {}
if self._grpc_server is not None:
self._grpc_server.close()
self._grpc_server = None
def get_worker(self, worker_id):
+ # type: (str) -> WorkerHandler
return self._workers_by_id[worker_id]
@@ -953,6 +994,7 @@ class
StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer,
@contextlib.contextmanager
def process_instruction_id(self, unused_instruction_id):
+ # type: (Any) -> Iterator
yield
def get_raw(self,
@@ -1016,7 +1058,7 @@ class
GrpcStateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer):
def State(self,
request_stream, # type: Iterable[beam_fn_api_pb2.StateRequest]
- context=None
+ context=None # type: Any
):
# type: (...) -> Iterator[beam_fn_api_pb2.StateResponse]
# Note that this eagerly mutates state, assuming any failures are fatal.
@@ -1062,24 +1104,29 @@ class
SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory):
class ControlFuture(object):
- def __init__(self, instruction_id, response=None):
+ def __init__(self,
+ instruction_id, # type: str
+ response=None # type:
Optional[beam_fn_api_pb2.InstructionResponse]
+ ):
+ # type: (...) -> None
self.instruction_id = instruction_id
- if response:
- self._response = response
- else:
- self._response = None
+ self._response = response
+ if response is None:
self._condition = threading.Condition()
- self._exception = None
+ self._exception = None # type: Optional[Exception]
def is_done(self):
+ # type: () -> bool
return self._response is not None
def set(self, response):
+ # type: (beam_fn_api_pb2.InstructionResponse) -> None
with self._condition:
self._response = response
self._condition.notify_all()
def get(self, timeout=None):
+ # type: (Optional[float]) -> beam_fn_api_pb2.InstructionResponse
if not self._response and not self._exception:
with self._condition:
if not self._response and not self._exception:
@@ -1087,9 +1134,11 @@ class ControlFuture(object):
if self._exception:
raise self._exception
else:
+ assert self._response is not None
return self._response
def abort(self, exception):
+ # type: (Exception) -> None
with self._condition:
self._exception = exception
self._condition.notify_all()
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py
b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 713c762..15c8606 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -16,6 +16,7 @@
#
# pytype: skip-file
+# mypy: check-untyped-defs
from __future__ import absolute_import
from __future__ import division
@@ -27,7 +28,6 @@ import logging
import threading
import time
from typing import TYPE_CHECKING
-from typing import Any
from typing import Iterator
from typing import Optional
from typing import Tuple
@@ -197,8 +197,12 @@ class JobServiceHandle(object):
pipeline_options=self.get_pipeline_options()),
timeout=self.timeout)
- def stage(self, pipeline, artifact_staging_endpoint, staging_session_token):
- # type: (...) -> Optional[Any]
+ def stage(self,
+ proto_pipeline, # type: beam_runner_api_pb2.Pipeline
+ artifact_staging_endpoint,
+ staging_session_token
+ ):
+ # type: (...) -> None
"""Stage artifacts"""
if artifact_staging_endpoint:
@@ -288,6 +292,7 @@ class PortableRunner(runner.PipelineRunner):
'use, such as --runner=FlinkRunner or --runner=SparkRunner.')
def create_job_service_handle(self, job_service, options):
+ # type: (...) -> JobServiceHandle
return JobServiceHandle(job_service, options)
def create_job_service(self, options):
@@ -299,7 +304,7 @@ class PortableRunner(runner.PipelineRunner):
job_endpoint = options.view_as(PortableOptions).job_endpoint
if job_endpoint:
if job_endpoint == 'embed':
- server = job_server.EmbeddedJobServer()
+ server = job_server.EmbeddedJobServer() # type: job_server.JobServer
else:
job_server_timeout =
options.view_as(PortableOptions).job_server_timeout
server = job_server.ExternalJobServer(job_endpoint, job_server_timeout)
@@ -463,6 +468,7 @@ class PipelineResult(runner.PipelineResult):
self._runtime_exception = None
def cancel(self):
+ # type: () -> None
try:
self._job_service.Cancel(
beam_job_api_pb2.CancelJobRequest(job_id=self._job_id))
@@ -496,6 +502,7 @@ class PipelineResult(runner.PipelineResult):
return self._metrics
def _last_error_message(self):
+ # type: () -> str
# Filter only messages with the "message_response" and error messages.
messages = [
m.message_response for m in self._messages
@@ -517,6 +524,7 @@ class PipelineResult(runner.PipelineResult):
:return: The result of the pipeline, i.e. PipelineResult.
"""
def read_messages():
+ # type: () -> None
previous_state = -1
for message in self._message_stream:
if message.HasField('message_response'):
@@ -576,6 +584,7 @@ class PipelineResult(runner.PipelineResult):
self._cleanup()
def _cleanup(self, on_exit=False):
+ # type: (bool) -> None
if on_exit and self._cleanup_callbacks:
_LOGGER.info(
'Running cleanup on exit. If your pipeline should continue running, '
diff --git a/sdks/python/apache_beam/runners/portability/stager.py
b/sdks/python/apache_beam/runners/portability/stager.py
index e6288fc..f1a820f 100644
--- a/sdks/python/apache_beam/runners/portability/stager.py
+++ b/sdks/python/apache_beam/runners/portability/stager.py
@@ -57,6 +57,7 @@ import sys
import tempfile
from typing import List
from typing import Optional
+from typing import Tuple
import pkg_resources
from future.moves.urllib.parse import urlparse
diff --git a/sdks/python/apache_beam/utils/profiler.py
b/sdks/python/apache_beam/utils/profiler.py
index dae625a..81bd578 100644
--- a/sdks/python/apache_beam/utils/profiler.py
+++ b/sdks/python/apache_beam/utils/profiler.py
@@ -21,6 +21,7 @@ For internal use only; no backwards-compatibility guarantees.
"""
# pytype: skip-file
+# mypy: check-untyped-defs
from __future__ import absolute_import
@@ -47,6 +48,9 @@ class Profile(object):
SORTBY = 'cumulative'
+ profile_output = None # type: str
+ stats = None # type: pstats.Stats
+
def __init__(
self,
profile_id, # type: str
@@ -72,13 +76,11 @@ class Profile(object):
profiling session, the profiler only records the newly allocated
objects
in this session.
"""
- self.stats = None
self.profile_id = str(profile_id)
self.profile_location = profile_location
self.log_results = log_results
self.file_copy_fn = file_copy_fn or self.default_file_copy_fn
self.time_prefix = time_prefix
- self.profile_output = None
self.enable_cpu_profiling = enable_cpu_profiling
self.enable_memory_profiling = enable_memory_profiling
@@ -104,7 +106,8 @@ class Profile(object):
if self.enable_cpu_profiling:
self.profile.create_stats()
self.profile_output = self._upload_profile_data(
- 'cpu_profile', self.profile.stats)
+ # typing: seems stats attr is missing from typeshed
+ self.profile_location, 'cpu_profile', self.profile.stats) # type:
ignore[attr-defined]
if self.enable_memory_profiling:
if not self.hpy:
@@ -113,7 +116,10 @@ class Profile(object):
h = self.hpy.heap()
heap_dump_data = '%s\n%s' % (h, h.more)
self._upload_profile_data(
- 'memory_profile', heap_dump_data, write_binary=False)
+ self.profile_location,
+ 'memory_profile',
+ heap_dump_data,
+ write_binary=False)
if self.log_results:
if self.enable_cpu_profiling:
@@ -156,18 +162,20 @@ class Profile(object):
return create_profiler
return None
- def _upload_profile_data(self, dir, data, write_binary=True):
+ def _upload_profile_data(
+ self, profile_location, dir, data, write_binary=True):
+ # type: (...) -> str
dump_location = os.path.join(
- self.profile_location,
+ profile_location,
dir,
time.strftime(self.time_prefix + self.profile_id))
fd, filename = tempfile.mkstemp()
try:
os.close(fd)
if write_binary:
- with open(filename, 'wb') as f:
+ with open(filename, 'wb') as fb:
import marshal
- marshal.dump(data, f)
+ marshal.dump(data, fb)
else:
with open(filename, 'w') as f:
f.write(data)
diff --git a/sdks/python/mypy.ini b/sdks/python/mypy.ini
index 121d1b4..6095c02 100644
--- a/sdks/python/mypy.ini
+++ b/sdks/python/mypy.ini
@@ -61,8 +61,6 @@ ignore_errors = true
# TODO(BEAM-7746): Remove the lines below.
-[mypy-apache_beam.coders.coders]
-ignore_errors = true
[mypy-apache_beam.io.*]
ignore_errors = true
@@ -88,18 +86,6 @@ ignore_errors = true
[mypy-apache_beam.runners.interactive.*]
ignore_errors = true
-[mypy-apache_beam.runners.portability.artifact_service]
-ignore_errors = true
-
-[mypy-apache_beam.runners.portability.fn_api_runner.*]
-ignore_errors = true
-
-[mypy-apache_beam.runners.portability.portable_runner]
-ignore_errors = true
-
-[mypy-apache_beam.runners.portability.stager]
-ignore_errors = true
-
[mypy-apache_beam.testing.synthetic_pipeline]
ignore_errors = true