This is an automated email from the ASF dual-hosted git repository. pabloem pushed a commit to branch revert-11270-fn-ref-more in repository https://gitbox.apache.org/repos/asf/beam.git
commit e5e52694bcaf9a4cdc9fd4a130f8cca4dcc6fe6a Author: Pablo <[email protected]> AuthorDate: Mon Apr 27 17:27:34 2020 -0700 Revert "[BEAM-9639][BEAM-9608] Improvements for FnApiRunner" --- .../runners/portability/fn_api_runner/execution.py | 179 +-------- .../runners/portability/fn_api_runner/fn_runner.py | 406 +++++++++++++-------- .../portability/fn_api_runner/fn_runner_test.py | 24 -- .../portability/fn_api_runner/translations.py | 11 - 4 files changed, 255 insertions(+), 365 deletions(-) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py index 2f29515..e62d8a8 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py @@ -22,29 +22,17 @@ from __future__ import absolute_import import collections import itertools from typing import TYPE_CHECKING -from typing import Any -from typing import DefaultDict -from typing import Dict -from typing import Iterator -from typing import List from typing import MutableMapping -from typing import Optional -from typing import Tuple from typing_extensions import Protocol from apache_beam import coders -from apache_beam.coders import BytesCoder from apache_beam.coders.coder_impl import create_InputStream from apache_beam.coders.coder_impl import create_OutputStream -from apache_beam.coders.coders import GlobalWindowCoder -from apache_beam.coders.coders import WindowedValueCoder from apache_beam.portability import common_urns from apache_beam.portability.api import beam_fn_api_pb2 from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners import pipeline_context -from apache_beam.runners.portability.fn_api_runner import translations -from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id from apache_beam.runners.portability.fn_api_runner.translations import only_element from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id from apache_beam.runners.portability.fn_api_runner.translations import unique_name @@ -57,13 +45,8 @@ from apache_beam.utils import windowed_value if TYPE_CHECKING: from apache_beam.coders.coder_impl import CoderImpl + from apache_beam.runners.portability.fn_api_runner import translations from apache_beam.runners.portability.fn_api_runner import worker_handlers - from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput - from apache_beam.transforms.window import BoundedWindow - -ENCODED_IMPULSE_VALUE = WindowedValueCoder( - BytesCoder(), GlobalWindowCoder()).get_impl().encode_nested( - GlobalWindows.windowed_value(b'')) class Buffer(Protocol): @@ -221,7 +204,7 @@ class WindowGroupingBuffer(object): def __init__( self, access_pattern, - coder # type: WindowedValueCoder + coder # type: coders.WindowedValueCoder ): # type: (...) -> None # Here's where we would use a different type of partitioning @@ -268,12 +251,11 @@ class WindowGroupingBuffer(object): class FnApiRunnerExecutionContext(object): """ - :var pcoll_buffers: (dict): Mapping of + :var pcoll_buffers: (collections.defaultdict of str: list): Mapping of PCollection IDs to list that functions as buffer for the ``beam.PCollection``. """ def __init__(self, - stages, # type: List[translations.Stage] worker_handler_manager, # type: worker_handlers.WorkerHandlerManager pipeline_components, # type: beam_runner_api_pb2.Components safe_coders, @@ -286,9 +268,6 @@ class FnApiRunnerExecutionContext(object): :param safe_coders: :param data_channel_coders: """ - self.stages = stages - self.side_input_descriptors_by_stage = ( - self._build_data_side_inputs_map(stages)) self.pcoll_buffers = {} # type: MutableMapping[bytes, PartitionableBuffer] self.timer_buffers = {} # type: MutableMapping[bytes, ListBuffer] self.worker_handler_manager = worker_handler_manager @@ -301,63 +280,6 @@ class FnApiRunnerExecutionContext(object): iterable_state_write=self._iterable_state_write) self._last_uid = -1 - @staticmethod - def _build_data_side_inputs_map(stages): - # type: (Iterable[translations.Stage]) -> MutableMapping[str, DataSideInput] - - """Builds an index mapping stages to side input descriptors. - - A side input descriptor is a map of side input IDs to side input access - patterns for all of the outputs of a stage that will be consumed as a - side input. - """ - transform_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[beam_runner_api_pb2.PTransform]] - stage_consumers = collections.defaultdict( - list) # type: DefaultDict[str, List[translations.Stage]] - - def get_all_side_inputs(): - # type: () -> Set[str] - all_side_inputs = set() # type: Set[str] - for stage in stages: - for transform in stage.transforms: - for input in transform.inputs.values(): - transform_consumers[input].append(transform) - stage_consumers[input].append(stage) - for si in stage.side_inputs(): - all_side_inputs.add(si) - return all_side_inputs - - all_side_inputs = frozenset(get_all_side_inputs()) - data_side_inputs_by_producing_stage = {} - - producing_stages_by_pcoll = {} - - for s in stages: - data_side_inputs_by_producing_stage[s.name] = {} - for transform in s.transforms: - for o in transform.outputs.values(): - if o in s.side_inputs(): - continue - producing_stages_by_pcoll[o] = s - - for side_pc in all_side_inputs: - for consuming_transform in transform_consumers[side_pc]: - if consuming_transform.spec.urn not in translations.PAR_DO_URNS: - continue - producing_stage = producing_stages_by_pcoll[side_pc] - payload = proto_utils.parse_Bytes( - consuming_transform.spec.payload, beam_runner_api_pb2.ParDoPayload) - for si_tag in payload.side_inputs: - if consuming_transform.inputs[si_tag] == side_pc: - side_input_id = (consuming_transform.unique_name, si_tag) - data_side_inputs_by_producing_stage[ - producing_stage.name][side_input_id] = ( - translations.create_buffer_id(side_pc), - payload.side_inputs[si_tag].access_pattern) - - return data_side_inputs_by_producing_stage - @property def state_servicer(self): # TODO(BEAM-9625): Ensure FnApiRunnerExecutionContext owns StateServicer @@ -379,43 +301,6 @@ class FnApiRunnerExecutionContext(object): out.get()) return token - def commit_side_inputs_to_state( - self, - data_side_input, # type: DataSideInput - ): - # type: (...) -> None - for (consuming_transform_id, tag), (buffer_id, - func_spec) in data_side_input.items(): - _, pcoll_id = split_buffer_id(buffer_id) - value_coder = self.pipeline_context.coders[self.safe_coders[ - self.data_channel_coders[pcoll_id]]] - elements_by_window = WindowGroupingBuffer(func_spec, value_coder) - if buffer_id not in self.pcoll_buffers: - self.pcoll_buffers[buffer_id] = ListBuffer( - coder_impl=value_coder.get_impl()) - for element_data in self.pcoll_buffers[buffer_id]: - elements_by_window.append(element_data) - - if func_spec.urn == common_urns.side_inputs.ITERABLE.urn: - for _, window, elements_data in elements_by_window.encoded_items(): - state_key = beam_fn_api_pb2.StateKey( - iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput( - transform_id=consuming_transform_id, - side_input_id=tag, - window=window)) - self.state_servicer.append_raw(state_key, elements_data) - elif func_spec.urn == common_urns.side_inputs.MULTIMAP.urn: - for key, window, elements_data in elements_by_window.encoded_items(): - state_key = beam_fn_api_pb2.StateKey( - multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput( - transform_id=consuming_transform_id, - side_input_id=tag, - window=window, - key=key)) - self.state_servicer.append_raw(state_key, elements_data) - else: - raise ValueError("Unknown access pattern: '%s'" % func_spec.urn) - class BundleContextManager(object): @@ -482,64 +367,6 @@ class BundleContextManager(object): state_api_service_descriptor=self.state_api_service_descriptor(), timer_api_service_descriptor=self.data_api_service_descriptor()) - def extract_bundle_inputs_and_outputs(self): - # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataOutput, Dict[Tuple[str, str], str]] - - """Returns maps of transform names to PCollection identifiers. - - Also mutates IO stages to point to the data ApiServiceDescriptor. - - Returns: - A tuple of (data_input, data_output, expected_timer_output) dictionaries. - `data_input` is a dictionary mapping (transform_name, output_name) to a - PCollection buffer; `data_output` is a dictionary mapping - (transform_name, output_name) to a PCollection ID. - `expected_timer_output` is a dictionary mapping transform_id and - timer family ID to a buffer id for timers. - """ - data_input = {} # type: Dict[str, PartitionableBuffer] - data_output = {} # type: DataOutput - # A mapping of {(transform_id, timer_family_id) : buffer_id} - expected_timer_output = {} # type: Dict[Tuple[str, str], str] - for transform in self.stage.transforms: - if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, - bundle_processor.DATA_OUTPUT_URN): - pcoll_id = transform.spec.payload - if transform.spec.urn == bundle_processor.DATA_INPUT_URN: - coder_id = self.execution_context.data_channel_coders[only_element( - transform.outputs.values())] - coder = self.execution_context.pipeline_context.coders[ - self.execution_context.safe_coders.get(coder_id, coder_id)] - if pcoll_id == translations.IMPULSE_BUFFER: - data_input[transform.unique_name] = ListBuffer( - coder_impl=coder.get_impl()) - data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE) - else: - if pcoll_id not in self.execution_context.pcoll_buffers: - self.execution_context.pcoll_buffers[pcoll_id] = ListBuffer( - coder_impl=coder.get_impl()) - data_input[transform.unique_name] = ( - self.execution_context.pcoll_buffers[pcoll_id]) - elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN: - data_output[transform.unique_name] = pcoll_id - coder_id = self.execution_context.data_channel_coders[only_element( - transform.inputs.values())] - else: - raise NotImplementedError - data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) - data_api_service_descriptor = self.data_api_service_descriptor() - if data_api_service_descriptor: - data_spec.api_service_descriptor.url = ( - data_api_service_descriptor.url) - transform.spec.payload = data_spec.SerializeToString() - elif transform.spec.urn in translations.PAR_DO_URNS: - payload = proto_utils.parse_Bytes( - transform.spec.payload, beam_runner_api_pb2.ParDoPayload) - for timer_family_id in payload.timer_family_specs.keys(): - expected_timer_output[(transform.unique_name, timer_family_id)] = ( - create_buffer_id(timer_family_id, 'timers')) - return data_input, data_output, expected_timer_output - def get_input_coder_impl(self, transform_id): # type: (str) -> CoderImpl coder_id = beam_fn_api_pb2.RemoteGrpcPort.FromString( diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py index 1f162ac..8e7ba2d 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner.py @@ -40,9 +40,11 @@ from typing import List from typing import Mapping from typing import MutableMapping from typing import Optional +from typing import Sequence from typing import Tuple from typing import TypeVar +import apache_beam as beam # pylint: disable=ungrouped-imports from apache_beam.coders.coder_impl import create_OutputStream from apache_beam.metrics import metric from apache_beam.metrics import monitoring_infos @@ -58,9 +60,12 @@ from apache_beam.runners.portability import portable_metrics from apache_beam.runners.portability.fn_api_runner import execution from apache_beam.runners.portability.fn_api_runner import translations from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer +from apache_beam.runners.portability.fn_api_runner.execution import WindowGroupingBuffer from apache_beam.runners.portability.fn_api_runner.translations import create_buffer_id from apache_beam.runners.portability.fn_api_runner.translations import only_element +from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id from apache_beam.runners.portability.fn_api_runner.worker_handlers import WorkerHandlerManager +from apache_beam.runners.worker import bundle_processor from apache_beam.transforms import environments from apache_beam.utils import profiler from apache_beam.utils import proto_utils @@ -68,6 +73,7 @@ from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor if TYPE_CHECKING: from apache_beam.pipeline import Pipeline + from apache_beam.coders.coder_impl import CoderImpl from apache_beam.portability.api import metrics_pb2 _LOGGER = logging.getLogger(__name__) @@ -82,6 +88,11 @@ BundleProcessResult = Tuple[beam_fn_api_pb2.InstructionResponse, # This module is experimental. No backwards-compatibility guarantees. +ENCODED_IMPULSE_VALUE = beam.coders.WindowedValueCoder( + beam.coders.BytesCoder(), + beam.coders.coders.GlobalWindowCoder()).get_impl().encode_nested( + beam.transforms.window.GlobalWindows.windowed_value(b'')) + class FnApiRunner(runner.PipelineRunner): @@ -314,7 +325,6 @@ class FnApiRunner(runner.PipelineRunner): monitoring_infos_by_stage = {} runner_execution_context = execution.FnApiRunnerExecutionContext( - stages, worker_handler_manager, stage_context.components, stage_context.safe_coders, @@ -325,7 +335,6 @@ class FnApiRunner(runner.PipelineRunner): for stage in stages: bundle_context_manager = execution.BundleContextManager( runner_execution_context, stage, self._num_workers) - stage_results = self._run_stage( runner_execution_context, bundle_context_manager, @@ -336,14 +345,54 @@ class FnApiRunner(runner.PipelineRunner): worker_handler_manager.close_all() return RunnerResult(runner.PipelineState.DONE, monitoring_infos_by_stage) + def _store_side_inputs_in_state(self, + runner_execution_context, # type: execution.FnApiRunnerExecutionContext + data_side_input, # type: DataSideInput + ): + # type: (...) -> None + for (transform_id, tag), (buffer_id, si) in data_side_input.items(): + _, pcoll_id = split_buffer_id(buffer_id) + value_coder = runner_execution_context.pipeline_context.coders[ + runner_execution_context.safe_coders[ + runner_execution_context.data_channel_coders[pcoll_id]]] + elements_by_window = WindowGroupingBuffer(si, value_coder) + if buffer_id not in runner_execution_context.pcoll_buffers: + runner_execution_context.pcoll_buffers[buffer_id] = ListBuffer( + coder_impl=value_coder.get_impl()) + for element_data in runner_execution_context.pcoll_buffers[buffer_id]: + elements_by_window.append(element_data) + + if si.urn == common_urns.side_inputs.ITERABLE.urn: + for _, window, elements_data in elements_by_window.encoded_items(): + state_key = beam_fn_api_pb2.StateKey( + iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput( + transform_id=transform_id, side_input_id=tag, window=window)) + ( + runner_execution_context.worker_handler_manager.state_servicer. + append_raw(state_key, elements_data)) + elif si.urn == common_urns.side_inputs.MULTIMAP.urn: + for key, window, elements_data in elements_by_window.encoded_items(): + state_key = beam_fn_api_pb2.StateKey( + multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput( + transform_id=transform_id, + side_input_id=tag, + window=window, + key=key)) + ( + runner_execution_context.worker_handler_manager.state_servicer. + append_raw(state_key, elements_data)) + else: + raise ValueError("Unknown access pattern: '%s'" % si.urn) + def _run_bundle_multiple_times_for_testing( self, runner_execution_context, # type: execution.FnApiRunnerExecutionContext - bundle_manager, # type: BundleManager + bundle_context_manager, # type: execution.BundleContextManager data_input, data_output, # type: DataOutput fired_timers, expected_output_timers, + cache_token_generator ): # type: (...) -> None @@ -354,12 +403,18 @@ class FnApiRunner(runner.PipelineRunner): for _ in range(self._bundle_repeat): try: runner_execution_context.state_servicer.checkpoint() - bundle_manager.process_bundle( - data_input, - data_output, - fired_timers, - expected_output_timers, - dry_run=True) + testing_bundle_manager = ParallelBundleManager( + bundle_context_manager.worker_handlers, + lambda pcoll_id, + transform_id: ListBuffer( + coder_impl=bundle_context_manager.get_input_coder_impl), + bundle_context_manager.get_input_coder_impl, + bundle_context_manager.process_bundle_descriptor, + self._progress_frequency, + num_workers=self._num_workers, + cache_token_generator=cache_token_generator) + testing_bundle_manager.process_bundle( + data_input, data_output, fired_timers, expected_output_timers) finally: runner_execution_context.state_servicer.restore() @@ -389,17 +444,6 @@ class FnApiRunner(runner.PipelineRunner): fired_timers[(transform_id, timer_family_id)].append(out.get()) written_timers.clear() - def _add_sdk_delayed_applications_to_deferred_inputs( - self, bundle_context_manager, bundle_result, deferred_inputs): - for delayed_application in bundle_result.process_bundle.residual_roots: - name = bundle_context_manager.input_for( - delayed_application.application.transform_id, - delayed_application.application.input_id) - if name not in deferred_inputs: - deferred_inputs[name] = ListBuffer( - coder_impl=bundle_context_manager.get_input_coder_impl(name)) - deferred_inputs[name].append(delayed_application.application.element) - def _add_residuals_and_channel_splits_to_deferred_inputs( self, splits, # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] @@ -463,115 +507,164 @@ class FnApiRunner(runner.PipelineRunner): Args: runner_execution_context (execution.FnApiRunnerExecutionContext): An object containing execution information for the pipeline. - bundle_context_manager (execution.BundleContextManager): A description of - the stage to execute, and its context. + stage (translations.Stage): A description of the stage to execute. """ - data_input, data_output, expected_timer_output = ( - bundle_context_manager.extract_bundle_inputs_and_outputs()) - input_timers = {} - + worker_handler_list = bundle_context_manager.worker_handlers worker_handler_manager = runner_execution_context.worker_handler_manager _LOGGER.info('Running %s', bundle_context_manager.stage.name) + (data_input, data_side_input, data_output, + expected_timer_output) = self._extract_endpoints( + bundle_context_manager, runner_execution_context) worker_handler_manager.register_process_bundle_descriptor( bundle_context_manager.process_bundle_descriptor) - # We create the bundle manager here, as it can be reused for bundles of the - # same stage, but it may have to be created by-bundle later on. + # Store the required side inputs into state so it is accessible for the + # worker when it runs this bundle. + self._store_side_inputs_in_state(runner_execution_context, data_side_input) + + # Change cache token across bundle repeats cache_token_generator = FnApiRunner.get_cache_token_generator(static=False) - bundle_manager = ParallelBundleManager( + + self._run_bundle_multiple_times_for_testing( + runner_execution_context, bundle_context_manager, + data_input, + data_output, {}, + expected_timer_output, + cache_token_generator=cache_token_generator) + + bundle_manager = ParallelBundleManager( + worker_handler_list, + bundle_context_manager.get_buffer, + bundle_context_manager.get_input_coder_impl, + bundle_context_manager.process_bundle_descriptor, self._progress_frequency, + num_workers=self._num_workers, cache_token_generator=cache_token_generator) - final_result = None + # For the first time of processing, we don't have fired timers as inputs. + result, splits = bundle_manager.process_bundle(data_input, + data_output, + {}, + expected_timer_output) - def merge_results(last_result): - """ Merge the latest result with other accumulated results. """ - return ( - last_result - if final_result is None else beam_fn_api_pb2.InstructionResponse( - process_bundle=beam_fn_api_pb2.ProcessBundleResponse( - monitoring_infos=monitoring_infos.consolidate( - itertools.chain( - final_result.process_bundle.monitoring_infos, - last_result.process_bundle.monitoring_infos))), - error=final_result.error or last_result.error)) + last_result = result + last_sent = data_input + # We cannot split deferred_input until we include residual_roots to + # merged results. Without residual_roots, pipeline stops earlier and we + # may miss some data. + # We also don't partition fired timer inputs for the same reason. + bundle_manager._num_workers = 1 while True: - last_result, deferred_inputs, fired_timers = self._run_bundle( - runner_execution_context, - bundle_context_manager, - data_input, - data_output, - input_timers, - expected_timer_output, - bundle_manager) - - final_result = merge_results(last_result) - if not deferred_inputs and not fired_timers: - break + deferred_inputs = {} # type: Dict[str, PartitionableBuffer] + fired_timers = {} + + self._collect_written_timers_and_add_to_fired_timers( + bundle_context_manager, fired_timers) + # Queue any process-initiated delayed bundle applications. + for delayed_application in last_result.process_bundle.residual_roots: + name = bundle_context_manager.input_for( + delayed_application.application.transform_id, + delayed_application.application.input_id) + if name not in deferred_inputs: + deferred_inputs[name] = ListBuffer( + coder_impl=bundle_context_manager.get_input_coder_impl(name)) + deferred_inputs[name].append(delayed_application.application.element) + # Queue any runner-initiated delayed bundle applications. + self._add_residuals_and_channel_splits_to_deferred_inputs( + splits, bundle_context_manager, last_sent, deferred_inputs) + + if deferred_inputs or fired_timers: + # The worker will be waiting on these inputs as well. + for other_input in data_input: + if other_input not in deferred_inputs: + deferred_inputs[other_input] = ListBuffer( + coder_impl=bundle_context_manager.get_input_coder_impl( + other_input)) + # TODO(robertwb): merge results + last_result, splits = bundle_manager.process_bundle( + deferred_inputs, data_output, fired_timers, expected_timer_output) + last_sent = deferred_inputs + result = beam_fn_api_pb2.InstructionResponse( + process_bundle=beam_fn_api_pb2.ProcessBundleResponse( + monitoring_infos=monitoring_infos.consolidate( + itertools.chain( + result.process_bundle.monitoring_infos, + last_result.process_bundle.monitoring_infos))), + error=result.error or last_result.error) else: - data_input = deferred_inputs - input_timers = fired_timers - bundle_manager._registered = True + break - # Store the required downstream side inputs into state so it is accessible - # for the worker when it runs bundles that consume this stage's output. - data_side_input = ( - runner_execution_context.side_input_descriptors_by_stage.get( - bundle_context_manager.stage.name, {})) - runner_execution_context.commit_side_inputs_to_state(data_side_input) + return result - return final_result + @staticmethod + def _extract_endpoints(bundle_context_manager, # type: execution.BundleContextManager + runner_execution_context, # type: execution.FnApiRunnerExecutionContext + ): + # type: (...) -> Tuple[Dict[str, PartitionableBuffer], DataSideInput, DataOutput] - def _run_bundle( - self, - runner_execution_context, - bundle_context_manager, - data_input, - data_output, - input_timers, - expected_timer_output, - bundle_manager): - """Execute a bundle, and return a result object, and deferred inputs.""" - self._run_bundle_multiple_times_for_testing( - runner_execution_context, - bundle_manager, - data_input, - data_output, - input_timers, - expected_timer_output) - - result, splits = bundle_manager.process_bundle( - data_input, data_output, input_timers, expected_timer_output) - # Now we collect all the deferred inputs remaining from bundle execution. - # Deferred inputs can be: - # - timers - # - SDK-initiated deferred applications of root elements - # - Runner-initiated deferred applications of root elements - deferred_inputs = {} # type: Dict[str, execution.PartitionableBuffer] - fired_timers = {} - - self._collect_written_timers_and_add_to_fired_timers( - bundle_context_manager, fired_timers) - - self._add_sdk_delayed_applications_to_deferred_inputs( - bundle_context_manager, result, deferred_inputs) - - self._add_residuals_and_channel_splits_to_deferred_inputs( - splits, bundle_context_manager, data_input, deferred_inputs) - - # After collecting deferred inputs, we 'pad' the structure with empty - # buffers for other expected inputs. - if deferred_inputs or fired_timers: - # The worker will be waiting on these inputs as well. - for other_input in data_input: - if other_input not in deferred_inputs: - deferred_inputs[other_input] = ListBuffer( - coder_impl=bundle_context_manager.get_input_coder_impl( - other_input)) - - return result, deferred_inputs, fired_timers + """Returns maps of transform names to PCollection identifiers. + + Also mutates IO stages to point to the data ApiServiceDescriptor. + + Args: + stage (translations.Stage): The stage to extract endpoints + for. + data_api_service_descriptor: A GRPC endpoint descriptor for data plane. + Returns: + A tuple of (data_input, data_side_input, data_output) dictionaries. + `data_input` is a dictionary mapping (transform_name, output_name) to a + PCollection buffer; `data_output` is a dictionary mapping + (transform_name, output_name) to a PCollection ID. + """ + data_input = {} # type: Dict[str, PartitionableBuffer] + data_side_input = {} # type: DataSideInput + data_output = {} # type: DataOutput + # A mapping of {(transform_id, timer_family_id) : buffer_id} + expected_timer_output = {} # type: Dict[Tuple(str, str), str] + for transform in bundle_context_manager.stage.transforms: + if transform.spec.urn in (bundle_processor.DATA_INPUT_URN, + bundle_processor.DATA_OUTPUT_URN): + pcoll_id = transform.spec.payload + if transform.spec.urn == bundle_processor.DATA_INPUT_URN: + coder_id = runner_execution_context.data_channel_coders[only_element( + transform.outputs.values())] + coder = runner_execution_context.pipeline_context.coders[ + runner_execution_context.safe_coders.get(coder_id, coder_id)] + if pcoll_id == translations.IMPULSE_BUFFER: + data_input[transform.unique_name] = ListBuffer( + coder_impl=coder.get_impl()) + data_input[transform.unique_name].append(ENCODED_IMPULSE_VALUE) + else: + if pcoll_id not in runner_execution_context.pcoll_buffers: + runner_execution_context.pcoll_buffers[pcoll_id] = ListBuffer( + coder_impl=coder.get_impl()) + data_input[transform.unique_name] = ( + runner_execution_context.pcoll_buffers[pcoll_id]) + elif transform.spec.urn == bundle_processor.DATA_OUTPUT_URN: + data_output[transform.unique_name] = pcoll_id + coder_id = runner_execution_context.data_channel_coders[only_element( + transform.inputs.values())] + else: + raise NotImplementedError + data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id) + data_api_service_descriptor = ( + bundle_context_manager.data_api_service_descriptor()) + if data_api_service_descriptor: + data_spec.api_service_descriptor.url = ( + data_api_service_descriptor.url) + transform.spec.payload = data_spec.SerializeToString() + elif transform.spec.urn in translations.PAR_DO_URNS: + payload = proto_utils.parse_Bytes( + transform.spec.payload, beam_runner_api_pb2.ParDoPayload) + for tag, si in payload.side_inputs.items(): + data_side_input[transform.unique_name, tag] = ( + create_buffer_id(transform.inputs[tag]), si.access_pattern) + for timer_family_id in payload.timer_family_specs.keys(): + expected_timer_output[(transform.unique_name, timer_family_id)] = ( + create_buffer_id(timer_family_id, 'timers')) + return data_input, data_side_input, data_output, expected_timer_output @staticmethod def get_cache_token_generator(static=True): @@ -672,18 +765,28 @@ class BundleManager(object): _lock = threading.Lock() def __init__(self, - bundle_context_manager, # type: execution.BundleContextManager + worker_handler_list, # type: Sequence[WorkerHandler] + get_buffer, # type: Callable[[bytes, str], PartitionableBuffer] + get_input_coder_impl, # type: Callable[[str], CoderImpl] + bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor progress_frequency=None, cache_token_generator=FnApiRunner.get_cache_token_generator() ): """Set up a bundle manager. Args: + worker_handler_list + get_buffer (Callable[[str], list]) + get_input_coder_impl (Callable[[str], Coder]) + bundle_descriptor (beam_fn_api_pb2.ProcessBundleDescriptor) progress_frequency """ - self.bundle_context_manager = bundle_context_manager # type: execution.BundleContextManager + self._worker_handler_list = worker_handler_list + self._get_buffer = get_buffer + self._get_input_coder_impl = get_input_coder_impl + self._bundle_descriptor = bundle_descriptor self._progress_frequency = progress_frequency - self._worker_handler = None # type: Optional[execution.WorkerHandler] + self._worker_handler = None # type: Optional[WorkerHandler] self._cache_token_generator = cache_token_generator def _send_input_to_worker(self, @@ -711,8 +814,7 @@ class BundleManager(object): def _select_split_manager(self): """TODO(pabloem) WHAT DOES THIS DO""" unique_names = set( - t.unique_name for t in self.bundle_context_manager. - process_bundle_descriptor.transforms.values()) + t.unique_name for t in self._bundle_descriptor.transforms.values()) for stage_name, candidate in reversed(_split_managers): if (stage_name in unique_names or (stage_name + '/Process') in unique_names): @@ -733,8 +835,8 @@ class BundleManager(object): byte_stream = b''.join(buffer_data) num_elements = len( list( - self.bundle_context_manager.get_input_coder_impl( - read_transform_id).decode_all(byte_stream))) + self._get_input_coder_impl(read_transform_id).decode_all( + byte_stream))) # Start the split manager in case it wants to set any breakpoints. split_manager_generator = split_manager(num_elements) @@ -787,20 +889,18 @@ class BundleManager(object): return split_results def process_bundle(self, - inputs, # type: Mapping[str, execution.PartitionableBuffer] + inputs, # type: Mapping[str, PartitionableBuffer] expected_outputs, # type: DataOutput - fired_timers, # type: Mapping[Tuple[str, str], execution.PartitionableBuffer] - expected_output_timers, # type: Dict[Tuple[str, str], str] - dry_run=False, + fired_timers, # type: Mapping[Tuple[str, str], PartitionableBuffer] + expected_output_timers # type: Dict[str, Dict[str, str]] ): # type: (...) -> BundleProcessResult # Unique id for the instruction processing this bundle. with BundleManager._lock: BundleManager._uid_counter += 1 process_bundle_id = 'bundle_%s' % BundleManager._uid_counter - self._worker_handler = self.bundle_context_manager.worker_handlers[ - BundleManager._uid_counter % - len(self.bundle_context_manager.worker_handlers)] + self._worker_handler = self._worker_handler_list[ + BundleManager._uid_counter % len(self._worker_handler_list)] split_manager = self._select_split_manager() if not split_manager: @@ -820,8 +920,7 @@ class BundleManager(object): process_bundle_req = beam_fn_api_pb2.InstructionRequest( instruction_id=process_bundle_id, process_bundle=beam_fn_api_pb2.ProcessBundleRequest( - process_bundle_descriptor_id=self.bundle_context_manager. - process_bundle_descriptor.id, + process_bundle_descriptor_id=self._bundle_descriptor.id, cache_tokens=[next(self._cache_token_generator)])) result_future = self._worker_handler.control_conn.push(process_bundle_req) @@ -843,15 +942,15 @@ class BundleManager(object): expect_reads, abort_callback=lambda: (result_future.is_done() and result_future.get().error)): - if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run: + if isinstance(output, beam_fn_api_pb2.Elements.Timers): with BundleManager._lock: - self.bundle_context_manager.get_buffer( + self._get_buffer( expected_output_timers[( output.transform_id, output.timer_family_id)], output.transform_id).append(output.timers) - if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run: + if isinstance(output, beam_fn_api_pb2.Elements.Data): with BundleManager._lock: - self.bundle_context_manager.get_buffer( + self._get_buffer( expected_outputs[output.transform_id], output.transform_id).append(output.data) @@ -874,32 +973,32 @@ class ParallelBundleManager(BundleManager): def __init__( self, - bundle_context_manager, # type: execution.BundleContextManager + worker_handler_list, # type: Sequence[WorkerHandler] + get_buffer, # type: Callable[[bytes, str], PartitionableBuffer] + get_input_coder_impl, # type: Callable[[str], CoderImpl] + bundle_descriptor, # type: beam_fn_api_pb2.ProcessBundleDescriptor progress_frequency=None, cache_token_generator=None, **kwargs): # type: (...) -> None super(ParallelBundleManager, self).__init__( - bundle_context_manager, + worker_handler_list, + get_buffer, + get_input_coder_impl, + bundle_descriptor, progress_frequency, cache_token_generator=cache_token_generator) - self._num_workers = bundle_context_manager.num_workers + self._num_workers = kwargs.pop('num_workers', 1) def process_bundle(self, - inputs, # type: Mapping[str, execution.PartitionableBuffer] + inputs, # type: Mapping[str, PartitionableBuffer] expected_outputs, # type: DataOutput - fired_timers, # type: Mapping[Tuple[str, str], execution.PartitionableBuffer] - expected_output_timers, # type: Dict[Tuple[str, str], str] - dry_run=False, - ): + fired_timers, # type: Mapping[Tuple[str, str], PartitionableBuffer] + expected_output_timers # type: Dict[Tuple[str, str], str] + ): # type: (...) -> BundleProcessResult part_inputs = [{} for _ in range(self._num_workers) ] # type: List[Dict[str, List[bytes]]] - # Timers are only executed on the first worker - # TODO(BEAM-9741): Split timers to multiple workers - timer_inputs = [ - fired_timers if i == 0 else {} for i in range(self._num_workers) - ] for name, input in inputs.items(): for ix, part in enumerate(input.partition(self._num_workers)): part_inputs[ix][name] = part @@ -908,23 +1007,21 @@ class ParallelBundleManager(BundleManager): split_result_list = [ ] # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse] - def execute(part_map_input_timers): + def execute(part_map): # type: (...) -> BundleProcessResult - part_map, input_timers = part_map_input_timers bundle_manager = BundleManager( - self.bundle_context_manager, + self._worker_handler_list, + self._get_buffer, + self._get_input_coder_impl, + self._bundle_descriptor, self._progress_frequency, cache_token_generator=self._cache_token_generator) return bundle_manager.process_bundle( - part_map, - expected_outputs, - input_timers, - expected_output_timers, - dry_run) + part_map, expected_outputs, fired_timers, expected_output_timers) with UnboundedThreadPoolExecutor() as executor: - for result, split_result in executor.map(execute, zip(part_inputs, # pylint: disable=zip-builtin-not-iterating - timer_inputs)): + for result, split_result in executor.map(execute, part_inputs): + split_result_list += split_result if merged_result is None: merged_result = result @@ -937,6 +1034,7 @@ class ParallelBundleManager(BundleManager): merged_result.process_bundle.monitoring_infos))), error=result.error or merged_result.error) assert merged_result is not None + return merged_result, split_result_list diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py index ce99d87..0ff0853 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py @@ -240,30 +240,6 @@ class FnApiRunnerTest(unittest.TestCase): lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)), equal_to([('a', [1, 3]), ('b', [2])])) - def test_multimap_multiside_input(self): - # A test where two transforms in the same stage consume the same PCollection - # twice as side input. - with self.create_pipeline() as p: - main = p | 'main' >> beam.Create(['a', 'b']) - side = ( - p | 'side' >> beam.Create([('a', 1), ('b', 2), ('a', 3)]) - # TODO(BEAM-4782): Obviate the need for this map. - | beam.Map(lambda kv: (kv[0], kv[1]))) - assert_that( - main | 'first map' >> beam.Map( - lambda k, - d, - l: (k, sorted(d[k]), sorted([e[1] for e in l])), - beam.pvalue.AsMultiMap(side), - beam.pvalue.AsList(side)) - | 'second map' >> beam.Map( - lambda k, - d, - l: (k[0], sorted(d[k[0]]), sorted([e[1] for e in l])), - beam.pvalue.AsMultiMap(side), - beam.pvalue.AsList(side)), - equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])])) - def test_multimap_side_input_type_coercion(self): with self.create_pipeline() as p: main = p | 'main' >> beam.Create(['a', 'b']) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py index 235aec8..5d18d29 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py @@ -75,17 +75,6 @@ PAR_DO_URNS = frozenset([ IMPULSE_BUFFER = b'impulse' -# SideInputId is identified by a consumer ParDo + tag. -SideInputId = Tuple[str, str] -SideInputAccessPattern = beam_runner_api_pb2.FunctionSpec - -DataOutput = Dict[str, bytes] - -# DataSideInput maps SideInputIds to a tuple of the encoded bytes of the side -# input content, and a payload specification regarding the type of side input -# (MultiMap / Iterable). -DataSideInput = Dict[SideInputId, Tuple[bytes, SideInputAccessPattern]] - class Stage(object): """A set of Transforms that can be sent to the worker for processing."""
