This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 95e2374  [BEAM-7746] Fix typing in runners
     new 5156634  Merge pull request #13060 from [BEAM-7746] Fix typing in 
runners
95e2374 is described below

commit 95e2374c659ba9e7fbe42ac217c0b367120b6ee1
Author: Chad Dombrova <[email protected]>
AuthorDate: Mon Oct 5 16:51:34 2020 -0700

    [BEAM-7746] Fix typing in runners
---
 .../runners/portability/artifact_service.py        |  18 +--
 .../runners/portability/fn_api_runner/execution.py |  82 ++++++++++--
 .../runners/portability/fn_api_runner/fn_runner.py | 121 ++++++++++++------
 .../portability/fn_api_runner/translations.py      |  25 ++--
 .../portability/fn_api_runner/worker_handlers.py   | 137 ++++++++++++++-------
 .../runners/portability/portable_runner.py         |  17 ++-
 .../apache_beam/runners/portability/stager.py      |   1 +
 sdks/python/apache_beam/utils/profiler.py          |  24 ++--
 sdks/python/mypy.ini                               |  14 ---
 9 files changed, 302 insertions(+), 137 deletions(-)

diff --git a/sdks/python/apache_beam/runners/portability/artifact_service.py 
b/sdks/python/apache_beam/runners/portability/artifact_service.py
index 1f3ec1c..18537f4 100644
--- a/sdks/python/apache_beam/runners/portability/artifact_service.py
+++ b/sdks/python/apache_beam/runners/portability/artifact_service.py
@@ -33,9 +33,15 @@ import queue
 import sys
 import tempfile
 import threading
-import typing
 from io import BytesIO
+from typing import Any
+from typing import BinaryIO  # pylint: disable=unused-import
 from typing import Callable
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import Optional
+from typing import Tuple
 
 import grpc
 from future.moves.urllib.request import urlopen
@@ -48,11 +54,6 @@ from apache_beam.portability.api import 
beam_artifact_api_pb2_grpc
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.utils import proto_utils
 
-if typing.TYPE_CHECKING:
-  from typing import BinaryIO  # pylint: disable=ungrouped-imports
-  from typing import Iterable
-  from typing import MutableMapping
-
 
 class ArtifactRetrievalService(
     beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer):
@@ -61,7 +62,7 @@ class ArtifactRetrievalService(
 
   def __init__(
       self,
-      file_reader,  # type: Callable[[str], BinaryIO],
+      file_reader,  # type: Callable[[str], BinaryIO]
       chunk_size=None,
   ):
     self._file_reader = file_reader
@@ -105,7 +106,8 @@ class ArtifactStagingService(
       file_writer,  # type: Callable[[str, Optional[str]], Tuple[BinaryIO, 
str]]
     ):
     self._lock = threading.Lock()
-    self._jobs_to_stage = {}
+    self._jobs_to_stage = {
+    }  # type: Dict[str, Tuple[Dict[Any, 
List[beam_runner_api_pb2.ArtifactInformation]], threading.Event]]
     self._file_writer = file_writer
 
   def register_job(
diff --git 
a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py 
b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
index 5b8e91c..bc69123 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
@@ -17,6 +17,8 @@
 
 """Set of utilities for execution of a pipeline by the FnApiRunner."""
 
+# mypy: disallow-untyped-defs
+
 from __future__ import absolute_import
 
 import collections
@@ -26,10 +28,12 @@ from typing import TYPE_CHECKING
 from typing import Any
 from typing import DefaultDict
 from typing import Dict
+from typing import Iterable
 from typing import Iterator
 from typing import List
 from typing import MutableMapping
 from typing import Optional
+from typing import Set
 from typing import Tuple
 
 from typing_extensions import Protocol
@@ -59,9 +63,13 @@ from apache_beam.utils import proto_utils
 from apache_beam.utils import windowed_value
 
 if TYPE_CHECKING:
-  from apache_beam.coders.coder_impl import CoderImpl
+  from apache_beam.coders.coder_impl import CoderImpl, WindowedValueCoderImpl
+  from apache_beam.portability.api import endpoints_pb2
   from apache_beam.runners.portability.fn_api_runner import worker_handlers
+  from apache_beam.runners.portability.fn_api_runner.fn_runner import 
DataOutput
+  from apache_beam.runners.portability.fn_api_runner.fn_runner import 
OutputTimers
   from apache_beam.runners.portability.fn_api_runner.translations import 
DataSideInput
+  from apache_beam.transforms import core
   from apache_beam.transforms.window import BoundedWindow
 
 ENCODED_IMPULSE_VALUE = WindowedValueCoder(
@@ -87,13 +95,27 @@ class PartitionableBuffer(Buffer, Protocol):
     # type: (int) -> List[List[bytes]]
     pass
 
+  @property
+  def cleared(self):
+    # type: () -> bool
+    pass
+
+  def clear(self):
+    # type: () -> None
+    pass
+
+  def reset(self):
+    # type: () -> None
+    pass
+
 
 class ListBuffer(object):
  """Used to support partitioning of a list."""
   def __init__(self, coder_impl):
+    # type: (CoderImpl) -> None
     self._coder_impl = coder_impl
     self._inputs = []  # type: List[bytes]
-    self._grouped_output = None
+    self._grouped_output = None  # type: Optional[List[List[bytes]]]
     self.cleared = False
 
   def append(self, element):
@@ -139,6 +161,8 @@ class ListBuffer(object):
     self._grouped_output = None
 
   def reset(self):
+    # type: () -> None
+
     """Resets a cleared buffer for reuse."""
     if not self.cleared:
       raise RuntimeError('Trying to reset a non-cleared ListBuffer.')
@@ -150,7 +174,7 @@ class GroupingBuffer(object):
   def __init__(self,
                pre_grouped_coder,  # type: coders.Coder
                post_grouped_coder,  # type: coders.Coder
-               windowing
+               windowing  # type: core.Windowing
               ):
     # type: (...) -> None
     self._key_coder = pre_grouped_coder.key_coder()
@@ -227,12 +251,24 @@ class GroupingBuffer(object):
     """
     return itertools.chain(*self.partition(1))
 
+  # these should never be accessed, but they allow this class to meet the
+  # PartitionableBuffer protocol
+  cleared = False
+
+  def clear(self):
+    # type: () -> None
+    pass
+
+  def reset(self):
+    # type: () -> None
+    pass
+
 
 class WindowGroupingBuffer(object):
   """Used to partition windowed side inputs."""
   def __init__(
       self,
-      access_pattern,
+      access_pattern,  # type: beam_runner_api_pb2.FunctionSpec
       coder  # type: WindowedValueCoder
   ):
     # type: (...) -> None
@@ -283,17 +319,21 @@ class 
GenericNonMergingWindowFn(window.NonMergingWindowFn):
   URN = 'internal-generic-non-merging'
 
   def __init__(self, coder):
+    # type: (coders.Coder) -> None
     self._coder = coder
 
   def assign(self, assign_context):
+    # type: (window.WindowFn.AssignContext) -> Iterable[BoundedWindow]
     raise NotImplementedError()
 
   def get_window_coder(self):
+    # type: () -> coders.Coder
     return self._coder
 
   @staticmethod
   @window.urns.RunnerApiFn.register_urn(URN, bytes)
   def from_runner_api_parameter(window_coder_id, context):
+    # type: (bytes, Any) -> GenericNonMergingWindowFn
     return GenericNonMergingWindowFn(
         context.coders[window_coder_id.decode('utf-8')])
 
@@ -308,9 +348,11 @@ class FnApiRunnerExecutionContext(object):
       stages,  # type: List[translations.Stage]
       worker_handler_manager,  # type: worker_handlers.WorkerHandlerManager
       pipeline_components,  # type: beam_runner_api_pb2.Components
-      safe_coders,
-      data_channel_coders,
+      safe_coders,  # type: Dict[str, str]
+      data_channel_coders,  # type: Dict[str, str]
                ):
+    # type: (...) -> None
+
     """
     :param worker_handler_manager: This class manages the set of worker
         handlers, and the communication with state / control APIs.
@@ -365,7 +407,7 @@ class FnApiRunnerExecutionContext(object):
       return all_side_inputs
 
     all_side_inputs = frozenset(get_all_side_inputs())
-    data_side_inputs_by_producing_stage = {}
+    data_side_inputs_by_producing_stage = {}  # type: Dict[str, DataSideInput]
 
     producing_stages_by_pcoll = {}
 
@@ -397,6 +439,7 @@ class FnApiRunnerExecutionContext(object):
     return data_side_inputs_by_producing_stage
 
   def _make_safe_windowing_strategy(self, id):
+    # type: (str) -> str
     windowing_strategy_proto = 
self.pipeline_components.windowing_strategies[id]
     if windowing_strategy_proto.window_fn.urn in SAFE_WINDOW_FNS:
       return id
@@ -420,15 +463,17 @@ class FnApiRunnerExecutionContext(object):
 
   @property
   def state_servicer(self):
+    # type: () -> worker_handlers.StateServicer
     # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer
     return self.worker_handler_manager.state_servicer
 
   def next_uid(self):
+    # type: () -> str
     self._last_uid += 1
     return str(self._last_uid)
 
   def _iterable_state_write(self, values, element_coder_impl):
-    # type: (...) -> bytes
+    # type: (Iterable, CoderImpl) -> bytes
     token = unique_name(None, 'iter').encode('ascii')
     out = create_OutputStream()
     for element in values:
@@ -484,21 +529,23 @@ class BundleContextManager(object):
                stage,  # type: translations.Stage
                num_workers,  # type: int
               ):
+    # type: (...) -> None
     self.execution_context = execution_context
     self.stage = stage
     self.bundle_uid = self.execution_context.next_uid()
     self.num_workers = num_workers
 
     # Properties that are lazily initialized
-    self._process_bundle_descriptor = None
-    self._worker_handlers = None
+    self._process_bundle_descriptor = None  # type: 
Optional[beam_fn_api_pb2.ProcessBundleDescriptor]
+    self._worker_handlers = None  # type: 
Optional[List[worker_handlers.WorkerHandler]]
     # a mapping of {(transform_id, timer_family_id): timer_coder_id}. The map
     # is built after self._process_bundle_descriptor is initialized.
     # This field can be used to tell whether current bundle has timers.
-    self._timer_coder_ids = None
+    self._timer_coder_ids = None  # type: Optional[Dict[Tuple[str, str], str]]
 
   @property
   def worker_handlers(self):
+    # type: () -> List[worker_handlers.WorkerHandler]
     if self._worker_handlers is None:
       self._worker_handlers = (
           self.execution_context.worker_handler_manager.get_worker_handlers(
@@ -506,23 +553,27 @@ class BundleContextManager(object):
     return self._worker_handlers
 
   def data_api_service_descriptor(self):
+    # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor]
     # All worker_handlers share the same grpc server, so we can read grpc 
server
     # info from any worker_handler and read from the first worker_handler.
     return self.worker_handlers[0].data_api_service_descriptor()
 
   def state_api_service_descriptor(self):
+    # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor]
     # All worker_handlers share the same grpc server, so we can read grpc 
server
     # info from any worker_handler and read from the first worker_handler.
     return self.worker_handlers[0].state_api_service_descriptor()
 
   @property
   def process_bundle_descriptor(self):
+    # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
     if self._process_bundle_descriptor is None:
       self._process_bundle_descriptor = self._build_process_bundle_descriptor()
       self._timer_coder_ids = self._build_timer_coders_id_map()
     return self._process_bundle_descriptor
 
   def _build_process_bundle_descriptor(self):
+    # type: () -> beam_fn_api_pb2.ProcessBundleDescriptor
     # Cannot be invoked until *after* _extract_endpoints is called.
     # Always populate the timer_api_service_descriptor.
     return beam_fn_api_pb2.ProcessBundleDescriptor(
@@ -543,7 +594,7 @@ class BundleContextManager(object):
         timer_api_service_descriptor=self.data_api_service_descriptor())
 
   def extract_bundle_inputs_and_outputs(self):
-    # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, 
Dict[Tuple[str, str], str]]
+    # type: () -> Tuple[Dict[str, PartitionableBuffer], DataOutput, 
Dict[Tuple[str, str], bytes]]
 
     """Returns maps of transform names to PCollection identifiers.
 
@@ -560,7 +611,7 @@ class BundleContextManager(object):
     data_input = {}  # type: Dict[str, PartitionableBuffer]
     data_output = {}  # type: DataOutput
     # A mapping of {(transform_id, timer_family_id) : buffer_id}
-    expected_timer_output = {}  # type: Dict[Tuple[str, str], str]
+    expected_timer_output = {}  # type: OutputTimers
     for transform in self.stage.transforms:
       if transform.spec.urn in (bundle_processor.DATA_INPUT_URN,
                                 bundle_processor.DATA_OUTPUT_URN):
@@ -609,6 +660,8 @@ class BundleContextManager(object):
     return self.get_coder_impl(coder_id)
 
   def _build_timer_coders_id_map(self):
+    # type: () -> Dict[Tuple[str, str], str]
+    assert self._process_bundle_descriptor is not None
     timer_coder_ids = {}
     for transform_id, transform_proto in (self._process_bundle_descriptor
         .transforms.items()):
@@ -621,6 +674,7 @@ class BundleContextManager(object):
     return timer_coder_ids
 
   def get_coder_impl(self, coder_id):
+    # type: (str) -> CoderImpl
     if coder_id in self.execution_context.safe_coders:
       return self.execution_context.pipeline_context.coders[
           self.execution_context.safe_coders[coder_id]].get_impl()
@@ -628,6 +682,8 @@ class BundleContextManager(object):
       return 
self.execution_context.pipeline_context.coders[coder_id].get_impl()
 
   def get_timer_coder_impl(self, transform_id, timer_family_id):
+    # type: (str, str) -> CoderImpl
+    assert self._timer_coder_ids is not None
     return self.get_coder_impl(
         self._timer_coder_ids[(transform_id, timer_family_id)])
 
diff --git 
a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py 
b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
index 9c43f33..404261f 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py
@@ -18,6 +18,7 @@
 """A PipelineRunner using the SDK harness.
 """
 # pytype: skip-file
+# mypy: check-untyped-defs
 
 from __future__ import absolute_import
 from __future__ import print_function
@@ -36,12 +37,14 @@ from builtins import object
 from typing import TYPE_CHECKING
 from typing import Callable
 from typing import Dict
+from typing import Iterator
 from typing import List
 from typing import Mapping
 from typing import MutableMapping
 from typing import Optional
 from typing import Tuple
 from typing import TypeVar
+from typing import Union
 
 from apache_beam.coders.coder_impl import create_OutputStream
 from apache_beam.metrics import metric
@@ -62,13 +65,14 @@ from 
apache_beam.runners.portability.fn_api_runner.translations import create_bu
 from apache_beam.runners.portability.fn_api_runner.translations import 
only_element
 from apache_beam.runners.portability.fn_api_runner.worker_handlers import 
WorkerHandlerManager
 from apache_beam.transforms import environments
-from apache_beam.utils import profiler
 from apache_beam.utils import proto_utils
 from apache_beam.utils import thread_pool_executor
+from apache_beam.utils.profiler import Profile
 
 if TYPE_CHECKING:
   from apache_beam.pipeline import Pipeline
   from apache_beam.portability.api import metrics_pb2
+  from apache_beam.runners.portability.fn_api_runner.worker_handlers import 
WorkerHandler
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -77,6 +81,7 @@ T = TypeVar('T')
 DataSideInput = Dict[Tuple[str, str],
                      Tuple[bytes, beam_runner_api_pb2.FunctionSpec]]
 DataOutput = Dict[str, bytes]
+OutputTimers = Dict[Tuple[str, str], bytes]
 BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse,
                             List[beam_fn_api_pb2.ProcessBundleSplitResponse]]
 
@@ -88,11 +93,12 @@ class FnApiRunner(runner.PipelineRunner):
   def __init__(
       self,
       default_environment=None,  # type: Optional[environments.Environment]
-      bundle_repeat=0,
-      use_state_iterables=False,
+      bundle_repeat=0,  # type: int
+      use_state_iterables=False,  # type: bool
       provision_info=None,  # type: Optional[ExtendedProvisionInfo]
-      progress_request_frequency=None,
-      is_drain=False):
+      progress_request_frequency=None,  # type: Optional[float]
+      is_drain=False  # type: bool
+  ):
     # type: (...) -> None
 
     """Creates a new Fn API Runner.
@@ -114,7 +120,7 @@ class FnApiRunner(runner.PipelineRunner):
     self._bundle_repeat = bundle_repeat
     self._num_workers = 1
     self._progress_frequency = progress_request_frequency
-    self._profiler_factory = None  # type: Optional[Callable[..., 
profiler.Profile]]
+    self._profiler_factory = None  # type: Optional[Callable[..., Profile]]
     self._use_state_iterables = use_state_iterables
     self._is_drain = is_drain
     self._provision_info = provision_info or ExtendedProvisionInfo(
@@ -123,6 +129,7 @@ class FnApiRunner(runner.PipelineRunner):
 
   @staticmethod
   def supported_requirements():
+    # type: () -> Tuple[str, ...]
     return (
         common_urns.requirements.REQUIRES_STATEFUL_PROCESSING.urn,
         common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn,
@@ -169,7 +176,7 @@ class FnApiRunner(runner.PipelineRunner):
       self._default_environment = environments.SubprocessSDKEnvironment(
           command_string=command_string)
 
-    self._profiler_factory = profiler.Profile.factory_from_options(
+    self._profiler_factory = Profile.factory_from_options(
         options.view_as(pipeline_options.ProfilingOptions))
 
     self._latest_run_result = self.run_via_runner_api(
@@ -187,6 +194,7 @@ class FnApiRunner(runner.PipelineRunner):
 
   @contextlib.contextmanager
   def maybe_profile(self):
+    # type: () -> Iterator[None]
     if self._profiler_factory:
       try:
         profile_id = 'direct-' + subprocess.check_output([
@@ -194,7 +202,8 @@ class FnApiRunner(runner.PipelineRunner):
         ]).decode(errors='ignore').strip()
       except subprocess.CalledProcessError:
         profile_id = 'direct-unknown'
-      profiler = self._profiler_factory(profile_id, time_prefix='')
+      profiler = self._profiler_factory(
+          profile_id, time_prefix='')  # type: Optional[Profile]
     else:
       profiler = None
 
@@ -231,10 +240,13 @@ class FnApiRunner(runner.PipelineRunner):
       yield
 
   def _validate_requirements(self, pipeline_proto):
+    # type: (beam_runner_api_pb2.Pipeline) -> None
+
     """As a test runner, validate requirements were set correctly."""
     expected_requirements = set()
 
     def add_requirements(transform_id):
+      # type: (str) -> None
       transform = pipeline_proto.components.transforms[transform_id]
       if transform.spec.urn in translations.PAR_DO_URNS:
         payload = proto_utils.parse_Bytes(
@@ -266,6 +278,8 @@ class FnApiRunner(runner.PipelineRunner):
           (expected_requirements - set(pipeline_proto.requirements)))
 
   def _check_requirements(self, pipeline_proto):
+    # type: (beam_runner_api_pb2.Pipeline) -> None
+
     """Check that this runner can satisfy all pipeline requirements."""
     supported_requirements = set(self.supported_requirements())
     for requirement in pipeline_proto.requirements:
@@ -355,10 +369,10 @@ class FnApiRunner(runner.PipelineRunner):
       self,
       runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
       bundle_manager,  # type: BundleManager
-      data_input,
+      data_input,  # type: Dict[str, execution.PartitionableBuffer]
       data_output,  # type: DataOutput
-      fired_timers,
-      expected_output_timers,
+      fired_timers,  # type: Mapping[Tuple[str, str], 
execution.PartitionableBuffer]
+      expected_output_timers,  # type: Dict[Tuple[str, str], bytes]
   ):
     # type: (...) -> None
 
@@ -407,7 +421,12 @@ class FnApiRunner(runner.PipelineRunner):
         written_timers.clear()
 
   def _add_sdk_delayed_applications_to_deferred_inputs(
-      self, bundle_context_manager, bundle_result, deferred_inputs):
+      self,
+      bundle_context_manager,  # type: execution.BundleContextManager
+      bundle_result,  # type: beam_fn_api_pb2.InstructionResponse
+      deferred_inputs  # type: MutableMapping[str, 
execution.PartitionableBuffer]
+  ):
+    # type: (...) -> None
     for delayed_application in bundle_result.process_bundle.residual_roots:
       name = bundle_context_manager.input_for(
           delayed_application.application.transform_id,
@@ -421,8 +440,8 @@ class FnApiRunner(runner.PipelineRunner):
       self,
       splits,  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]
       bundle_context_manager,  # type: execution.BundleContextManager
-      last_sent,
-      deferred_inputs  # type: MutableMapping[str, PartitionableBuffer]
+      last_sent,  # type: Dict[str, execution.PartitionableBuffer]
+      deferred_inputs  # type: MutableMapping[str, 
execution.PartitionableBuffer]
   ):
     # type: (...) -> None
 
@@ -485,7 +504,8 @@ class FnApiRunner(runner.PipelineRunner):
     """
     data_input, data_output, expected_timer_output = (
         bundle_context_manager.extract_bundle_inputs_and_outputs())
-    input_timers = {}
+    input_timers = {
+    }  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
 
     worker_handler_manager = runner_execution_context.worker_handler_manager
     _LOGGER.info('Running %s', bundle_context_manager.stage.name)
@@ -509,9 +529,11 @@ class FnApiRunner(runner.PipelineRunner):
         self._progress_frequency,
         cache_token_generator=cache_token_generator)
 
-    final_result = None
+    final_result = None  # type: Optional[beam_fn_api_pb2.InstructionResponse]
 
     def merge_results(last_result):
+      # type: (beam_fn_api_pb2.InstructionResponse) -> 
beam_fn_api_pb2.InstructionResponse
+
       """ Merge the latest result with other accumulated results. """
       return (
           last_result
@@ -539,7 +561,6 @@ class FnApiRunner(runner.PipelineRunner):
       else:
         data_input = deferred_inputs
         input_timers = fired_timers
-        bundle_manager._registered = True
 
     # Store the required downstream side inputs into state so it is accessible
     # for the worker when it runs bundles that consume this stage's output.
@@ -552,13 +573,16 @@ class FnApiRunner(runner.PipelineRunner):
 
   def _run_bundle(
       self,
-      runner_execution_context,
-      bundle_context_manager,
-      data_input,
-      data_output,
-      input_timers,
-      expected_timer_output,
-      bundle_manager):
+      runner_execution_context,  # type: execution.FnApiRunnerExecutionContext
+      bundle_context_manager,  # type: execution.BundleContextManager
+      data_input,  # type: Dict[str, execution.PartitionableBuffer]
+      data_output,  # type: DataOutput
+      input_timers,  # type: Mapping[Tuple[str, str], 
execution.PartitionableBuffer]
+      expected_timer_output,  # type: Dict[Tuple[str, str], bytes]
+      bundle_manager  # type: BundleManager
+  ):
+    # type: (...) -> Tuple[beam_fn_api_pb2.InstructionResponse, Dict[str, 
execution.PartitionableBuffer], Dict[Tuple[str, str], ListBuffer]]
+
     """Execute a bundle, and return a result object, and deferred inputs."""
     self._run_bundle_multiple_times_for_testing(
         runner_execution_context,
@@ -576,7 +600,7 @@ class FnApiRunner(runner.PipelineRunner):
     # - SDK-initiated deferred applications of root elements
     # - Runner-initiated deferred applications of root elements
     deferred_inputs = {}  # type: Dict[str, execution.PartitionableBuffer]
-    fired_timers = {}
+    fired_timers = {}  # type: Dict[Tuple[str, str], ListBuffer]
 
     self._collect_written_timers_and_add_to_fired_timers(
         bundle_context_manager, fired_timers)
@@ -601,12 +625,15 @@ class FnApiRunner(runner.PipelineRunner):
 
   @staticmethod
   def get_cache_token_generator(static=True):
+    # type: (bool) -> Iterator[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]
+
     """A generator for cache tokens.
          :arg static If True, generator always returns the same cache token
                      If False, generator returns a new cache token each time
          :return A generator which returns a cache token on next(generator)
      """
     def generate_token(identifier):
+      # type: (int) -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken
       return beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
           user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.UserState(
           ),
@@ -614,44 +641,55 @@ class FnApiRunner(runner.PipelineRunner):
 
     class StaticGenerator(object):
       def __init__(self):
+        # type: () -> None
         self._token = generate_token(1)
 
       def __iter__(self):
+        # type: () -> StaticGenerator
         # pylint: disable=non-iterator-returned
         return self
 
       def __next__(self):
+        # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken
         return self._token
 
     class DynamicGenerator(object):
       def __init__(self):
+        # type: () -> None
         self._counter = 0
         self._lock = threading.Lock()
 
       def __iter__(self):
+        # type: () -> DynamicGenerator
         # pylint: disable=non-iterator-returned
         return self
 
       def __next__(self):
+        # type: () -> beam_fn_api_pb2.ProcessBundleRequest.CacheToken
         with self._lock:
           self._counter += 1
           return generate_token(self._counter)
 
-    return StaticGenerator() if static else DynamicGenerator()
+    if static:
+      return StaticGenerator()
+    else:
+      return DynamicGenerator()
 
 
 class ExtendedProvisionInfo(object):
   def __init__(self,
                provision_info=None,  # type: 
Optional[beam_provision_api_pb2.ProvisionInfo]
-               artifact_staging_dir=None,
+               artifact_staging_dir=None,  # type: Optional[str]
                job_name=None,  # type: Optional[str]
               ):
+    # type: (...) -> None
     self.provision_info = (
         provision_info or beam_provision_api_pb2.ProvisionInfo())
     self.artifact_staging_dir = artifact_staging_dir
     self.job_name = job_name
 
   def for_environment(self, env):
+    # type: (...) -> ExtendedProvisionInfo
     if env.dependencies:
       provision_info_with_deps = copy.deepcopy(self.provision_info)
       provision_info_with_deps.dependencies.extend(env.dependencies)
@@ -699,9 +737,11 @@ class BundleManager(object):
 
   def __init__(self,
                bundle_context_manager,  # type: execution.BundleContextManager
-               progress_frequency=None,
+               progress_frequency=None,  # type: Optional[float]
                cache_token_generator=FnApiRunner.get_cache_token_generator()
               ):
+    # type: (...) -> None
+
     """Set up a bundle manager.
 
     Args:
@@ -709,7 +749,7 @@ class BundleManager(object):
     """
     self.bundle_context_manager = bundle_context_manager  # type: 
execution.BundleContextManager
     self._progress_frequency = progress_frequency
-    self._worker_handler = None  # type: Optional[execution.WorkerHandler]
+    self._worker_handler = None  # type: Optional[WorkerHandler]
     self._cache_token_generator = cache_token_generator
 
   def _send_input_to_worker(self,
@@ -727,6 +767,7 @@ class BundleManager(object):
 
   def _send_timers_to_worker(
       self, process_bundle_id, transform_id, timer_family_id, timers):
+    # type: (...) -> None
     assert self._worker_handler is not None
     timer_out = self._worker_handler.data_conn.output_timer_stream(
         process_bundle_id, transform_id, timer_family_id)
@@ -751,8 +792,9 @@ class BundleManager(object):
 
   def _generate_splits_for_testing(self,
                                    split_manager,
-                                   inputs,  # type: Mapping[str, 
PartitionableBuffer]
-                                   process_bundle_id):
+                                   inputs,  # type: Mapping[str, 
execution.PartitionableBuffer]
+                                   process_bundle_id
+                                  ):
     # type: (...) -> List[beam_fn_api_pb2.ProcessBundleSplitResponse]
     split_results = []  # type: 
List[beam_fn_api_pb2.ProcessBundleSplitResponse]
     read_transform_id, buffer_data = only_element(inputs.items())
@@ -819,8 +861,8 @@ class BundleManager(object):
                      inputs,  # type: Mapping[str, 
execution.PartitionableBuffer]
                      expected_outputs,  # type: DataOutput
                      fired_timers,  # type: Mapping[Tuple[str, str], 
execution.PartitionableBuffer]
-                     expected_output_timers,  # type: Dict[Tuple[str, str], 
str]
-                     dry_run=False,
+                     expected_output_timers,  # type: OutputTimers
+                     dry_run=False,  # type: bool
                     ):
     # type: (...) -> BundleProcessResult
     # Unique id for the instruction processing this bundle.
@@ -863,7 +905,8 @@ class BundleManager(object):
         split_results = self._generate_splits_for_testing(
             split_manager, inputs, process_bundle_id)
 
-      expect_reads = list(expected_outputs.keys())
+      expect_reads = list(
+          expected_outputs.keys())  # type: List[Union[str, Tuple[str, str]]]
       expect_reads.extend(list(expected_output_timers.keys()))
 
       # Gather all output data.
@@ -871,7 +914,7 @@ class BundleManager(object):
           process_bundle_id,
           expect_reads,
           abort_callback=lambda:
-          (result_future.is_done() and result_future.get().error)):
+          (result_future.is_done() and bool(result_future.get().error))):
         if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run:
           with BundleManager._lock:
             timer_buffer = self.bundle_context_manager.get_buffer(
@@ -910,7 +953,7 @@ class ParallelBundleManager(BundleManager):
   def __init__(
       self,
       bundle_context_manager,  # type: execution.BundleContextManager
-      progress_frequency=None,
+      progress_frequency=None,  # type: Optional[float]
       cache_token_generator=None,
       **kwargs):
     # type: (...) -> None
@@ -924,8 +967,8 @@ class ParallelBundleManager(BundleManager):
                      inputs,  # type: Mapping[str, 
execution.PartitionableBuffer]
                      expected_outputs,  # type: DataOutput
                      fired_timers,  # type: Mapping[Tuple[str, str], 
execution.PartitionableBuffer]
-                     expected_output_timers,  # type: Dict[Tuple[str, str], 
str]
-                     dry_run=False,
+                     expected_output_timers,  # type: OutputTimers
+                     dry_run=False,  # type: bool
                     ):
     # type: (...) -> BundleProcessResult
     part_inputs = [{} for _ in range(self._num_workers)
@@ -993,7 +1036,7 @@ class ProgressRequester(threading.Thread):
     self._instruction_id = instruction_id
     self._frequency = frequency
     self._done = False
-    self._latest_progress = None
+    self._latest_progress = None  # type: 
Optional[beam_fn_api_pb2.ProcessBundleProgressResponse]
     self._callback = callback
     self.daemon = True
 
diff --git 
a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py 
b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
index e4d03c8..d29ab3b 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py
@@ -18,6 +18,7 @@
 """Pipeline transformations for the FnApiRunner.
 """
 # pytype: skip-file
+# mypy: check-untyped-defs
 
 from __future__ import absolute_import
 from __future__ import print_function
@@ -307,6 +308,9 @@ class Stage(object):
           for side in side_inputs
       },
                           main_input=main_input_id)
+      # at this point we should have resolved an environment, as the key of
+      # components.environments cannot be None.
+      assert self.environment is not None
       exec_payload = beam_runner_api_pb2.ExecutableStagePayload(
           environment=components.environments[self.environment],
           input=main_input_id,
@@ -373,7 +377,7 @@ class TransformContext(object):
         None)  # type: ignore[arg-type]
     self.bytes_coder_id = self.add_or_get_coder_id(coder_proto, 'bytes_coder')
     self.safe_coders = {self.bytes_coder_id: self.bytes_coder_id}
-    self.data_channel_coders = {}
+    self.data_channel_coders = {}  # type: Dict[str, str]
 
   def add_or_get_coder_id(
       self,
@@ -812,7 +816,8 @@ def pack_combiners(stages, context):
 
   def _get_fallback_coder_id():
     return context.add_or_get_coder_id(
-        coders.registry.get_coder(object).to_runner_api(None))
+        # passing None works here because there are no component coders
+        coders.registry.get_coder(object).to_runner_api(None))  # type: 
ignore[arg-type]
 
   def _get_component_coder_id_from_kv_coder(coder, index):
     assert index < 2
@@ -922,11 +927,14 @@ def pack_combiners(stages, context):
             is_bounded=input_pcoll.is_bounded))
 
     # Set up Pack stage.
+    # TODO(BEAM-7746): classes that inherit from RunnerApiFn are expected to
+    #  accept a PipelineContext for from_runner_api/to_runner_api.  Determine
+    #  how to accommodate this.
     pack_combine_fn = combiners.SingleInputTupleCombineFn(
         *[
-            core.CombineFn.from_runner_api(combine_payload.combine_fn, context)
+            core.CombineFn.from_runner_api(combine_payload.combine_fn, 
context)  # type: ignore[arg-type]
             for combine_payload in combine_payloads
-        ]).to_runner_api(context)
+        ]).to_runner_api(context)  # type: ignore[arg-type]
     pack_transform = beam_runner_api_pb2.PTransform(
         unique_name=pack_combine_name + '/Pack',
         spec=beam_runner_api_pb2.FunctionSpec(
@@ -1480,14 +1488,17 @@ def sink_flattens(stages, pipeline_context):
 
 
 def greedily_fuse(stages, pipeline_context):
+  # type: (Iterable[Stage], TransformContext) -> FrozenSet[Stage]
+
   """Places transforms sharing an edge in the same stage, whenever possible.
   """
-  producers_by_pcoll = {}
-  consumers_by_pcoll = collections.defaultdict(list)
+  producers_by_pcoll = {}  # type: Dict[str, Stage]
+  consumers_by_pcoll = collections.defaultdict(
+      list)  # type: DefaultDict[str, List[Stage]]
 
   # Used to always reference the correct stage as the producer and
   # consumer maps are not updated when stages are fused away.
-  replacements = {}
+  replacements = {}  # type: Dict[Stage, Stage]
 
   def replacement(s):
     old_ss = []
diff --git 
a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py 
b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
index 1f1d483..756e4ba 100644
--- 
a/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
+++ 
b/sdks/python/apache_beam/runners/portability/fn_api_runner/worker_handlers.py
@@ -17,6 +17,8 @@
 
 """Code for communicating with the Workers."""
 
+# mypy: disallow-untyped-defs
+
 from __future__ import absolute_import
 
 import collections
@@ -29,6 +31,9 @@ import sys
 import threading
 import time
 from builtins import object
+from typing import TYPE_CHECKING
+from typing import Any
+from typing import BinaryIO  # pylint: disable=unused-import
 from typing import Callable
 from typing import DefaultDict
 from typing import Dict
@@ -38,6 +43,8 @@ from typing import List
 from typing import Mapping
 from typing import Optional
 from typing import Tuple
+from typing import Type
+from typing import TypeVar
 from typing import Union
 from typing import cast
 from typing import overload
@@ -64,6 +71,12 @@ from apache_beam.runners.worker.sdk_worker import _Future
 from apache_beam.runners.worker.statecache import StateCache
 from apache_beam.utils import proto_utils
 from apache_beam.utils import thread_pool_executor
+from apache_beam.utils.sentinel import Sentinel
+
+if TYPE_CHECKING:
+  from grpc import ServicerContext
+  from google.protobuf import message
+  from apache_beam.runners.portability.fn_api_runner.fn_runner import 
ExtendedProvisionInfo  # pylint: disable=ungrouped-imports
 
 # State caching is enabled in the fn_api_runner for testing, except for one
 # test which runs without state caching (FnApiRunnerTestWithDisabledCaching).
@@ -75,10 +88,11 @@ DATA_BUFFER_TIME_LIMIT_MS = 1000
 
 _LOGGER = logging.getLogger(__name__)
 
+T = TypeVar('T')
 ConstructorFn = Callable[[
     Union['message.Message', bytes],
-    'StateServicer',
-    Optional['ExtendedProvisionInfo'],
+    'sdk_worker.StateHandler',
+    'ExtendedProvisionInfo',
     'GrpcServer'
 ],
                          'WorkerHandler']
@@ -90,8 +104,9 @@ class ControlConnection(object):
   _lock = threading.Lock()
 
   def __init__(self):
+    # type: () -> None
     self._push_queue = queue.Queue(
-    )  # type: queue.Queue[beam_fn_api_pb2.InstructionRequest]
+    )  # type: queue.Queue[Union[beam_fn_api_pb2.InstructionRequest, Sentinel]]
     self._input = None  # type: 
Optional[Iterable[beam_fn_api_pb2.InstructionResponse]]
     self._futures_by_id = dict()  # type: Dict[str, ControlFuture]
     self._read_thread = threading.Thread(
@@ -99,12 +114,14 @@ class ControlConnection(object):
     self._state = BeamFnControlServicer.UNSTARTED_STATE
 
   def _read(self):
+    # type: () -> None
+    assert self._input is not None
     for data in self._input:
       self._futures_by_id.pop(data.instruction_id).set(data)
 
   @overload
   def push(self, req):
-    # type: (BeamFnControlServicer.DoneMarker) -> None
+    # type: (Sentinel) -> None
     pass
 
   @overload
@@ -113,7 +130,8 @@ class ControlConnection(object):
     pass
 
   def push(self, req):
-    if req == BeamFnControlServicer._DONE_MARKER:
+    # type: (Union[Sentinel, beam_fn_api_pb2.InstructionRequest]) -> 
Optional[ControlFuture]
+    if req is BeamFnControlServicer._DONE_MARKER:
       self._push_queue.put(req)
       return None
     if not req.instruction_id:
@@ -126,7 +144,7 @@ class ControlConnection(object):
     return future
 
   def get_req(self):
-    # type: () -> beam_fn_api_pb2.InstructionRequest
+    # type: () -> Union[Sentinel, beam_fn_api_pb2.InstructionRequest]
     return self._push_queue.get()
 
   def set_input(self, input):
@@ -147,6 +165,7 @@ class ControlConnection(object):
       self._state = BeamFnControlServicer.DONE_STATE
 
   def abort(self, exn):
+    # type: (Exception) -> None
     for future in self._futures_by_id.values():
       future.abort(exn)
 
@@ -158,23 +177,20 @@ class 
BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
   STARTED_STATE = 'started'
   DONE_STATE = 'done'
 
-  class DoneMarker(object):
-    pass
-
-  _DONE_MARKER = DoneMarker()
+  _DONE_MARKER = Sentinel.sentinel
 
   def __init__(
       self,
       worker_manager,  # type: WorkerHandlerManager
   ):
+    # type: (...) -> None
     self._worker_manager = worker_manager
     self._lock = threading.Lock()
     self._uid_counter = 0
     self._state = self.UNSTARTED_STATE
     # following self._req_* variables are used for debugging purpose, data is
     # added only when self._log_req is True.
-    self._req_sent = collections.defaultdict(int)
-    self._req_worker_mapping = {}
+    self._req_sent = collections.defaultdict(int)  # type: DefaultDict[str, 
int]
     self._log_req = logging.getLogger().getEffectiveLevel() <= logging.DEBUG
     self._connections_by_worker_id = collections.defaultdict(
         ControlConnection)  # type: DefaultDict[str, ControlConnection]
@@ -186,7 +202,7 @@ class 
BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
 
   def Control(self,
               iterator,  # type: Iterable[beam_fn_api_pb2.InstructionResponse]
-              context
+              context  # type: ServicerContext
              ):
     # type: (...) -> Iterator[beam_fn_api_pb2.InstructionRequest]
     with self._lock:
@@ -213,16 +229,14 @@ class 
BeamFnControlServicer(beam_fn_api_pb2_grpc.BeamFnControlServicer):
         self._req_sent[to_push.instruction_id] += 1
 
   def done(self):
+    # type: () -> None
     self._state = self.DONE_STATE
     _LOGGER.debug(
         'Runner: Requests sent by runner: %s',
         [(str(req), cnt) for req, cnt in self._req_sent.items()])
-    _LOGGER.debug(
-        'Runner: Requests multiplexing info: %s',
-        [(str(req), worker)
-         for req, worker in self._req_worker_mapping.items()])
 
   def GetProcessBundleDescriptor(self, id, context=None):
+    # type: (beam_fn_api_pb2.GetProcessBundleDescriptorRequest, Any) -> 
beam_fn_api_pb2.ProcessBundleDescriptor
     return self._worker_manager.get_process_bundle_descriptor(id)
 
 
@@ -242,10 +256,10 @@ class WorkerHandler(object):
   data_conn = None  # type: data_plane._GrpcDataChannel
 
   def __init__(self,
-               control_handler,
-               data_plane_handler,
-               state,  # type: StateServicer
-               provision_info  # type: Optional[ExtendedProvisionInfo]
+               control_handler,  # type: Any
+               data_plane_handler,  # type: Any
+               state,  # type: sdk_worker.StateHandler
+               provision_info  # type: ExtendedProvisionInfo
               ):
     # type: (...) -> None
 
@@ -278,6 +292,14 @@ class WorkerHandler(object):
     # type: () -> None
     raise NotImplementedError
 
+  def control_api_service_descriptor(self):
+    # type: () -> endpoints_pb2.ApiServiceDescriptor
+    raise NotImplementedError
+
+  def artifact_api_service_descriptor(self):
+    # type: () -> endpoints_pb2.ApiServiceDescriptor
+    raise NotImplementedError
+
   def data_api_service_descriptor(self):
     # type: () -> Optional[endpoints_pb2.ApiServiceDescriptor]
     raise NotImplementedError
@@ -296,9 +318,10 @@ class WorkerHandler(object):
       urn,  # type: str
       payload_type  # type: Optional[Type[T]]
   ):
-    # type: (...) -> Callable[[Callable[[T, StateServicer, 
Optional[ExtendedProvisionInfo], GrpcServer], WorkerHandler]], Callable[[T, 
StateServicer, Optional[ExtendedProvisionInfo], GrpcServer], WorkerHandler]]
+    # type: (...) -> Callable[[Callable[[T, sdk_worker.StateHandler, 
ExtendedProvisionInfo, GrpcServer], WorkerHandler]], Callable[[T, 
sdk_worker.StateHandler, ExtendedProvisionInfo, GrpcServer], WorkerHandler]]
     def wrapper(constructor):
-      cls._registered_environments[urn] = constructor, payload_type
+      # type: (Callable) -> Callable
+      cls._registered_environments[urn] = constructor, payload_type  # type: 
ignore[assignment]
       return constructor
 
     return wrapper
@@ -306,8 +329,8 @@ class WorkerHandler(object):
   @classmethod
   def create(cls,
              environment,  # type: beam_runner_api_pb2.Environment
-             state,  # type: StateServicer
-             provision_info,  # type: Optional[ExtendedProvisionInfo]
+             state,  # type: sdk_worker.StateHandler
+             provision_info,  # type: ExtendedProvisionInfo
              grpc_server  # type: GrpcServer
             ):
     # type: (...) -> WorkerHandler
@@ -319,14 +342,17 @@ class WorkerHandler(object):
         grpc_server)
 
 
[email protected]_environment(python_urns.EMBEDDED_PYTHON, None)
+# This takes a WorkerHandlerManager instead of GrpcServer, so it is not
+# compatible with WorkerHandler.register_environment.  There is a special case
+# in WorkerHandlerManager.get_worker_handlers() that allows it to work.
[email protected]_environment(python_urns.EMBEDDED_PYTHON, None)  # 
type: ignore[arg-type]
 class EmbeddedWorkerHandler(WorkerHandler):
   """An in-memory worker_handler for fn API control, state and data planes."""
 
   def __init__(self,
                unused_payload,  # type: None
                state,  # type: sdk_worker.StateHandler
-               provision_info,  # type: Optional[ExtendedProvisionInfo]
+               provision_info,  # type: ExtendedProvisionInfo
                worker_manager,  # type: WorkerHandlerManager
               ):
     # type: (...) -> None
@@ -347,6 +373,7 @@ class EmbeddedWorkerHandler(WorkerHandler):
     self._uid_counter = 0
 
   def push(self, request):
+    # type: (beam_fn_api_pb2.InstructionRequest) -> ControlFuture
     if not request.instruction_id:
       self._uid_counter += 1
       request.instruction_id = 'control_%s' % self._uid_counter
@@ -394,6 +421,7 @@ class 
BasicLoggingService(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
   }
 
   def Logging(self, log_messages, context=None):
+    # type: (Iterable[beam_fn_api_pb2.LogEntry.List], Any) -> 
Iterator[beam_fn_api_pb2.LogControl]
     yield beam_fn_api_pb2.LogControl()
     for log_message in log_messages:
       for log in log_message.log_entries:
@@ -403,12 +431,12 @@ class 
BasicLoggingService(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
 class 
BasicProvisionService(beam_provision_api_pb2_grpc.ProvisionServiceServicer
                             ):
   def __init__(self, base_info, worker_manager):
-    # type: (Optional[beam_provision_api_pb2.ProvisionInfo], 
WorkerHandlerManager) -> None
+    # type: (beam_provision_api_pb2.ProvisionInfo, WorkerHandlerManager) -> 
None
     self._base_info = base_info
     self._worker_manager = worker_manager
 
   def GetProvisionInfo(self, request, context=None):
-    # type: (...) -> beam_provision_api_pb2.GetProvisionInfoResponse
+    # type: (Any, Optional[ServicerContext]) -> 
beam_provision_api_pb2.GetProvisionInfoResponse
     if context:
       worker_id = dict(context.invocation_metadata())['worker_id']
       worker = self._worker_manager.get_worker(worker_id)
@@ -466,6 +494,7 @@ class GrpcServer(object):
             self.control_server)
 
       def open_uncompressed(f):
+        # type: (str) -> BinaryIO
         return filesystems.FileSystems.open(
             f, compression_type=CompressionTypes.UNCOMPRESSED)
 
@@ -499,6 +528,7 @@ class GrpcServer(object):
     self.control_server.start()
 
   def close(self):
+    # type: () -> None
     self.control_handler.done()
     to_wait = [
         self.control_server.stop(self._DEFAULT_SHUTDOWN_TIMEOUT_SECS),
@@ -515,7 +545,7 @@ class GrpcWorkerHandler(WorkerHandler):
 
   def __init__(self,
                state,  # type: StateServicer
-               provision_info,  # type: Optional[ExtendedProvisionInfo]
+               provision_info,  # type: ExtendedProvisionInfo
                grpc_server  # type: GrpcServer
               ):
     # type: (...) -> None
@@ -566,9 +596,11 @@ class GrpcWorkerHandler(WorkerHandler):
     super(GrpcWorkerHandler, self).close()
 
   def port_from_worker(self, port):
+    # type: (int) -> str
     return '%s:%s' % (self.host_from_worker(), port)
 
   def host_from_worker(self):
+    # type: () -> str
     return 'localhost'
 
 
@@ -578,7 +610,7 @@ class ExternalWorkerHandler(GrpcWorkerHandler):
   def __init__(self,
                external_payload,  # type: beam_runner_api_pb2.ExternalPayload
                state,  # type: StateServicer
-               provision_info,  # type: Optional[ExtendedProvisionInfo]
+               provision_info,  # type: ExtendedProvisionInfo
                grpc_server  # type: GrpcServer
               ):
     # type: (...) -> None
@@ -610,6 +642,7 @@ class ExternalWorkerHandler(GrpcWorkerHandler):
     pass
 
   def host_from_worker(self):
+    # type: () -> str
     # TODO(BEAM-8646): Reconcile across platforms.
     if sys.platform in ['win32', 'darwin']:
       return 'localhost'
@@ -622,7 +655,7 @@ class EmbeddedGrpcWorkerHandler(GrpcWorkerHandler):
   def __init__(self,
                payload,  # type: bytes
                state,  # type: StateServicer
-               provision_info,  # type: Optional[ExtendedProvisionInfo]
+               provision_info,  # type: ExtendedProvisionInfo
                grpc_server  # type: GrpcServer
               ):
     # type: (...) -> None
@@ -662,7 +695,7 @@ class SubprocessSdkWorkerHandler(GrpcWorkerHandler):
   def __init__(self,
                worker_command_line,  # type: bytes
                state,  # type: StateServicer
-               provision_info,  # type: Optional[ExtendedProvisionInfo]
+               provision_info,  # type: ExtendedProvisionInfo
                grpc_server  # type: GrpcServer
               ):
     # type: (...) -> None
@@ -690,7 +723,7 @@ class DockerSdkWorkerHandler(GrpcWorkerHandler):
   def __init__(self,
                payload,  # type: beam_runner_api_pb2.DockerPayload
                state,  # type: StateServicer
-               provision_info,  # type: Optional[ExtendedProvisionInfo]
+               provision_info,  # type: ExtendedProvisionInfo
                grpc_server  # type: GrpcServer
               ):
     # type: (...) -> None
@@ -700,6 +733,7 @@ class DockerSdkWorkerHandler(GrpcWorkerHandler):
     self._container_id = None  # type: Optional[bytes]
 
   def host_from_worker(self):
+    # type: () -> str
     if sys.platform == "darwin":
       # See https://docs.docker.com/docker-for-mac/networking/
       return 'host.docker.internal'
@@ -753,7 +787,9 @@ class DockerSdkWorkerHandler(GrpcWorkerHandler):
     t.start()
 
   def watch_container(self):
+    # type: () -> None
     while not self._done:
+      assert self._container_id is not None
       status = subprocess.check_output(
           ['docker', 'inspect', '-f', '{{.State.Status}}',
            self._container_id]).strip()
@@ -788,7 +824,7 @@ class WorkerHandlerManager(object):
   """
   def __init__(self,
                environments,  # type: Mapping[str, 
beam_runner_api_pb2.Environment]
-               job_provision_info  # type: Optional[ExtendedProvisionInfo]
+               job_provision_info  # type: ExtendedProvisionInfo
               ):
     # type: (...) -> None
     self._environments = environments
@@ -798,13 +834,16 @@ class WorkerHandlerManager(object):
     self._workers_by_id = {}  # type: Dict[str, WorkerHandler]
     self.state_servicer = StateServicer()
     self._grpc_server = None  # type: Optional[GrpcServer]
-    self._process_bundle_descriptors = {}
+    self._process_bundle_descriptors = {
+    }  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
 
   def register_process_bundle_descriptor(self, process_bundle_descriptor):
+    # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None
     self._process_bundle_descriptors[
         process_bundle_descriptor.id] = process_bundle_descriptor
 
   def get_process_bundle_descriptor(self, request):
+    # type: (beam_fn_api_pb2.GetProcessBundleDescriptorRequest) -> 
beam_fn_api_pb2.ProcessBundleDescriptor
     return self._process_bundle_descriptors[
         request.process_bundle_descriptor_id]
 
@@ -852,6 +891,7 @@ class WorkerHandlerManager(object):
     return self._cached_handlers[environment_id][:num_workers]
 
   def close_all(self):
+    # type: () -> None
     for worker_handler_list in self._cached_handlers.values():
       for worker_handler in set(worker_handler_list):
         try:
@@ -859,13 +899,14 @@ class WorkerHandlerManager(object):
         except Exception:
           _LOGGER.error(
               "Error closing worker_handler %s" % worker_handler, 
exc_info=True)
-    self._cached_handlers = {}
+    self._cached_handlers = {}  # type: ignore[assignment]
     self._workers_by_id = {}
     if self._grpc_server is not None:
       self._grpc_server.close()
       self._grpc_server = None
 
   def get_worker(self, worker_id):
+    # type: (str) -> WorkerHandler
     return self._workers_by_id[worker_id]
 
 
@@ -953,6 +994,7 @@ class 
StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer,
 
   @contextlib.contextmanager
   def process_instruction_id(self, unused_instruction_id):
+    # type: (Any) -> Iterator
     yield
 
   def get_raw(self,
@@ -1016,7 +1058,7 @@ class 
GrpcStateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer):
 
   def State(self,
       request_stream,  # type: Iterable[beam_fn_api_pb2.StateRequest]
-      context=None
+      context=None  # type: Any
             ):
     # type: (...) -> Iterator[beam_fn_api_pb2.StateResponse]
     # Note that this eagerly mutates state, assuming any failures are fatal.
@@ -1062,24 +1104,29 @@ class 
SingletonStateHandlerFactory(sdk_worker.StateHandlerFactory):
 
 
 class ControlFuture(object):
-  def __init__(self, instruction_id, response=None):
+  def __init__(self,
+               instruction_id,  # type: str
+               response=None  # type: 
Optional[beam_fn_api_pb2.InstructionResponse]
+              ):
+    # type: (...) -> None
     self.instruction_id = instruction_id
-    if response:
-      self._response = response
-    else:
-      self._response = None
+    self._response = response
+    if response is None:
       self._condition = threading.Condition()
-    self._exception = None
+    self._exception = None  # type: Optional[Exception]
 
   def is_done(self):
+    # type: () -> bool
     return self._response is not None
 
   def set(self, response):
+    # type: (beam_fn_api_pb2.InstructionResponse) -> None
     with self._condition:
       self._response = response
       self._condition.notify_all()
 
   def get(self, timeout=None):
+    # type: (Optional[float]) -> beam_fn_api_pb2.InstructionResponse
     if not self._response and not self._exception:
       with self._condition:
         if not self._response and not self._exception:
@@ -1087,9 +1134,11 @@ class ControlFuture(object):
     if self._exception:
       raise self._exception
     else:
+      assert self._response is not None
       return self._response
 
   def abort(self, exception):
+    # type: (Exception) -> None
     with self._condition:
       self._exception = exception
       self._condition.notify_all()
diff --git a/sdks/python/apache_beam/runners/portability/portable_runner.py 
b/sdks/python/apache_beam/runners/portability/portable_runner.py
index 713c762..15c8606 100644
--- a/sdks/python/apache_beam/runners/portability/portable_runner.py
+++ b/sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -16,6 +16,7 @@
 #
 
 # pytype: skip-file
+# mypy: check-untyped-defs
 
 from __future__ import absolute_import
 from __future__ import division
@@ -27,7 +28,6 @@ import logging
 import threading
 import time
 from typing import TYPE_CHECKING
-from typing import Any
 from typing import Iterator
 from typing import Optional
 from typing import Tuple
@@ -197,8 +197,12 @@ class JobServiceHandle(object):
             pipeline_options=self.get_pipeline_options()),
         timeout=self.timeout)
 
-  def stage(self, pipeline, artifact_staging_endpoint, staging_session_token):
-    # type: (...) -> Optional[Any]
+  def stage(self,
+            proto_pipeline,  # type: beam_runner_api_pb2.Pipeline
+            artifact_staging_endpoint,
+            staging_session_token
+           ):
+    # type: (...) -> None
 
     """Stage artifacts"""
     if artifact_staging_endpoint:
@@ -288,6 +292,7 @@ class PortableRunner(runner.PipelineRunner):
         'use, such as --runner=FlinkRunner or --runner=SparkRunner.')
 
   def create_job_service_handle(self, job_service, options):
+    # type: (...) -> JobServiceHandle
     return JobServiceHandle(job_service, options)
 
   def create_job_service(self, options):
@@ -299,7 +304,7 @@ class PortableRunner(runner.PipelineRunner):
     job_endpoint = options.view_as(PortableOptions).job_endpoint
     if job_endpoint:
       if job_endpoint == 'embed':
-        server = job_server.EmbeddedJobServer()
+        server = job_server.EmbeddedJobServer()  # type: job_server.JobServer
       else:
         job_server_timeout = 
options.view_as(PortableOptions).job_server_timeout
         server = job_server.ExternalJobServer(job_endpoint, job_server_timeout)
@@ -463,6 +468,7 @@ class PipelineResult(runner.PipelineResult):
     self._runtime_exception = None
 
   def cancel(self):
+    # type: () -> None
     try:
       self._job_service.Cancel(
           beam_job_api_pb2.CancelJobRequest(job_id=self._job_id))
@@ -496,6 +502,7 @@ class PipelineResult(runner.PipelineResult):
     return self._metrics
 
   def _last_error_message(self):
+    # type: () -> str
     # Filter only messages with the "message_response" and error messages.
     messages = [
         m.message_response for m in self._messages
@@ -517,6 +524,7 @@ class PipelineResult(runner.PipelineResult):
     :return: The result of the pipeline, i.e. PipelineResult.
     """
     def read_messages():
+      # type: () -> None
       previous_state = -1
       for message in self._message_stream:
         if message.HasField('message_response'):
@@ -576,6 +584,7 @@ class PipelineResult(runner.PipelineResult):
       self._cleanup()
 
   def _cleanup(self, on_exit=False):
+    # type: (bool) -> None
     if on_exit and self._cleanup_callbacks:
       _LOGGER.info(
           'Running cleanup on exit. If your pipeline should continue running, '
diff --git a/sdks/python/apache_beam/runners/portability/stager.py 
b/sdks/python/apache_beam/runners/portability/stager.py
index e6288fc..f1a820f 100644
--- a/sdks/python/apache_beam/runners/portability/stager.py
+++ b/sdks/python/apache_beam/runners/portability/stager.py
@@ -57,6 +57,7 @@ import sys
 import tempfile
 from typing import List
 from typing import Optional
+from typing import Tuple
 
 import pkg_resources
 from future.moves.urllib.parse import urlparse
diff --git a/sdks/python/apache_beam/utils/profiler.py 
b/sdks/python/apache_beam/utils/profiler.py
index dae625a..81bd578 100644
--- a/sdks/python/apache_beam/utils/profiler.py
+++ b/sdks/python/apache_beam/utils/profiler.py
@@ -21,6 +21,7 @@ For internal use only; no backwards-compatibility guarantees.
 """
 
 # pytype: skip-file
+# mypy: check-untyped-defs
 
 from __future__ import absolute_import
 
@@ -47,6 +48,9 @@ class Profile(object):
 
   SORTBY = 'cumulative'
 
+  profile_output = None  # type: str
+  stats = None  # type: pstats.Stats
+
   def __init__(
       self,
       profile_id, # type: str
@@ -72,13 +76,11 @@ class Profile(object):
         profiling session, the profiler only records the newly allocated 
objects
         in this session.
     """
-    self.stats = None
     self.profile_id = str(profile_id)
     self.profile_location = profile_location
     self.log_results = log_results
     self.file_copy_fn = file_copy_fn or self.default_file_copy_fn
     self.time_prefix = time_prefix
-    self.profile_output = None
     self.enable_cpu_profiling = enable_cpu_profiling
     self.enable_memory_profiling = enable_memory_profiling
 
@@ -104,7 +106,8 @@ class Profile(object):
       if self.enable_cpu_profiling:
         self.profile.create_stats()
         self.profile_output = self._upload_profile_data(
-            'cpu_profile', self.profile.stats)
+            # typing: seems stats attr is missing from typeshed
+            self.profile_location, 'cpu_profile', self.profile.stats)  # type: 
ignore[attr-defined]
 
       if self.enable_memory_profiling:
         if not self.hpy:
@@ -113,7 +116,10 @@ class Profile(object):
           h = self.hpy.heap()
           heap_dump_data = '%s\n%s' % (h, h.more)
           self._upload_profile_data(
-              'memory_profile', heap_dump_data, write_binary=False)
+              self.profile_location,
+              'memory_profile',
+              heap_dump_data,
+              write_binary=False)
 
     if self.log_results:
       if self.enable_cpu_profiling:
@@ -156,18 +162,20 @@ class Profile(object):
       return create_profiler
     return None
 
-  def _upload_profile_data(self, dir, data, write_binary=True):
+  def _upload_profile_data(
+      self, profile_location, dir, data, write_binary=True):
+    # type: (...) -> str
     dump_location = os.path.join(
-        self.profile_location,
+        profile_location,
         dir,
         time.strftime(self.time_prefix + self.profile_id))
     fd, filename = tempfile.mkstemp()
     try:
       os.close(fd)
       if write_binary:
-        with open(filename, 'wb') as f:
+        with open(filename, 'wb') as fb:
           import marshal
-          marshal.dump(data, f)
+          marshal.dump(data, fb)
       else:
         with open(filename, 'w') as f:
           f.write(data)
diff --git a/sdks/python/mypy.ini b/sdks/python/mypy.ini
index 121d1b4..6095c02 100644
--- a/sdks/python/mypy.ini
+++ b/sdks/python/mypy.ini
@@ -61,8 +61,6 @@ ignore_errors = true
 
 
 # TODO(BEAM-7746): Remove the lines below.
-[mypy-apache_beam.coders.coders]
-ignore_errors = true
 
 [mypy-apache_beam.io.*]
 ignore_errors = true
@@ -88,18 +86,6 @@ ignore_errors = true
 [mypy-apache_beam.runners.interactive.*]
 ignore_errors = true
 
-[mypy-apache_beam.runners.portability.artifact_service]
-ignore_errors = true
-
-[mypy-apache_beam.runners.portability.fn_api_runner.*]
-ignore_errors = true
-
-[mypy-apache_beam.runners.portability.portable_runner]
-ignore_errors = true
-
-[mypy-apache_beam.runners.portability.stager]
-ignore_errors = true
-
 [mypy-apache_beam.testing.synthetic_pipeline]
 ignore_errors = true
 

Reply via email to