hjtran commented on code in PR #27631: URL: https://github.com/apache/beam/pull/27631#discussion_r1307955392
########## sdks/python/apache_beam/runners/trivial_runner.py: ########## @@ -0,0 +1,415 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import collections +import logging +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import TypeVar + +from apache_beam import coders +from apache_beam.coders.coder_impl import create_InputStream +from apache_beam.coders.coder_impl import create_OutputStream +from apache_beam.portability import common_urns +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners import common +from apache_beam.runners import pipeline_context +from apache_beam.runners import runner +from apache_beam.runners.worker import bundle_processor +from apache_beam.runners.portability.fn_api_runner import translations +from apache_beam.runners.portability.fn_api_runner import worker_handlers +from apache_beam.transforms import core +from apache_beam.transforms import trigger +from apache_beam.utils import windowed_value + +T = TypeVar("T") + +_LOGGER = logging.getLogger(__name__) + + +class TrivialRunner(runner.PipelineRunner): + """A bare-bones batch Python pipeline runner illistrating how to use the + RunnerAPI and FnAPI to execute pipelines. + + Note that this runner is primarily for pedagogical purposes and is missing + several features in order to keep it as simple as possible. Where possible + pointers are provided which this should serve as a useful starting point. + """ + def run_portable_pipeline(self, pipeline, options): + # First ensure we are able to run this pipeline. + # Specifically, that it does not depend on requirements that were + # added since this runner was developed. + self.check_requirements(pipeline, self.supported_requirements()) + + # Now we optimize the pipeline, notably performing pipeline fusion, + # to turn it into a DAG where all the operations are one of + # Impulse, Flatten, GroupByKey, or beam:runner:executable_stage. + optimized_pipeline = translations.optimize_pipeline( + pipeline, + phases=translations.standard_optimize_phases(), + known_runner_urns=frozenset([ + common_urns.primitives.IMPULSE.urn, + common_urns.primitives.FLATTEN.urn, + common_urns.primitives.GROUP_BY_KEY.urn + ]), + # This boolean indicates we want fused executable_stages. + partial=False) + + # standard_optimize_phases() has a final step giving the stages in + # topological order, so now we can just walk over them and execute + # them. (This are not quite so simple if we were attempting to execute + # a streaming pipeline, but this is a trivial runner that only supports + # batch...) 
+ execution_state = ExecutionState(optimized_pipeline) + for transform_id in optimized_pipeline.root_transform_ids: + self.execute_transform(transform_id, execution_state) + + # A more sophisticated runner may perform the execution in the background + # and return a PipelineResult that can be used to monitor/cancel the + # concurrent execution. + return runner.PipelineResult(runner.PipelineState.DONE) + + def execute_transform(self, transform_id, execution_state): + """Execute a single transform.""" + transform_proto = execution_state.optimized_pipeline.components.transforms[ + transform_id] + _LOGGER.info( + "Executing stage %s %s", transform_id, transform_proto.unique_name) + if not is_primitive_transform(transform_proto): + # A composite is simply executed by executing its parts. + for sub_transform in transform_proto.subtransforms: + self.execute_transform(sub_transform, execution_state) + + elif transform_proto.spec.urn == common_urns.primitives.IMPULSE.urn: + # An impulse has no inputs and produces a single output (which happens + # to be an empty byte string in the global window). + execution_state.set_pcollection_contents( + only_element(transform_proto.outputs.values()), + [common.ENCODED_IMPULSE_VALUE]) + + elif transform_proto.spec.urn == common_urns.primitives.FLATTEN.urn: + # The output of a flatten is simply the union of its inputs. + output_pcoll_id = only_element(transform_proto.outputs.values()) + execution_state.set_pcollection_contents( + output_pcoll_id, + sum([ + execution_state.get_pcollection_contents(pc) + for pc in transform_proto.inputs.values() + ], [])) + + elif transform_proto.spec.urn == 'beam:runner:executable_stage:v1': + # This is a collection of user DoFns. + self.execute_executable_stage(transform_proto, execution_state) + + elif transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn: + # Execute the grouping operation. + self.group_by_key_and_window( + only_element(transform_proto.inputs.values()), + only_element(transform_proto.outputs.values()), + execution_state) + + else: + raise RuntimeError( + f"Unsupported transform {transform_id}" + " of type {transform_proto.spec.urn}") + + def execute_executable_stage(self, transform_proto, execution_state): + # Stage here is like a mini pipeline, with PTransforms, PCollections, etc. + # inside of it. + stage = beam_runner_api_pb2.ExecutableStagePayload.FromString( + transform_proto.spec.payload) + if stage.side_inputs: + # To support these we would need to make the side input PCollections + # available over the state API before processing this bundle. + raise NotImplementedError() Review Comment: ```suggestion raise NotImplementedError("Side inputs not supported") ``` ########## sdks/python/apache_beam/runners/trivial_runner.py: ########## @@ -0,0 +1,415 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import collections +import logging +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import TypeVar + +from apache_beam import coders +from apache_beam.coders.coder_impl import create_InputStream +from apache_beam.coders.coder_impl import create_OutputStream +from apache_beam.portability import common_urns +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners import common +from apache_beam.runners import pipeline_context +from apache_beam.runners import runner +from apache_beam.runners.worker import bundle_processor +from apache_beam.runners.portability.fn_api_runner import translations +from apache_beam.runners.portability.fn_api_runner import worker_handlers +from apache_beam.transforms import core +from apache_beam.transforms import trigger +from apache_beam.utils import windowed_value + +T = TypeVar("T") + +_LOGGER = logging.getLogger(__name__) + + +class TrivialRunner(runner.PipelineRunner): + """A bare-bones batch Python pipeline runner illistrating how to use the + RunnerAPI and FnAPI to execute pipelines. + + Note that this runner is primarily for pedagogical purposes and is missing + several features in order to keep it as simple as possible. Where possible + pointers are provided which this should serve as a useful starting point. + """ + def run_portable_pipeline(self, pipeline, options): + # First ensure we are able to run this pipeline. + # Specifically, that it does not depend on requirements that were + # added since this runner was developed. + self.check_requirements(pipeline, self.supported_requirements()) + + # Now we optimize the pipeline, notably performing pipeline fusion, + # to turn it into a DAG where all the operations are one of + # Impulse, Flatten, GroupByKey, or beam:runner:executable_stage. + optimized_pipeline = translations.optimize_pipeline( + pipeline, + phases=translations.standard_optimize_phases(), + known_runner_urns=frozenset([ + common_urns.primitives.IMPULSE.urn, + common_urns.primitives.FLATTEN.urn, + common_urns.primitives.GROUP_BY_KEY.urn + ]), + # This boolean indicates we want fused executable_stages. + partial=False) + + # standard_optimize_phases() has a final step giving the stages in + # topological order, so now we can just walk over them and execute + # them. (This are not quite so simple if we were attempting to execute + # a streaming pipeline, but this is a trivial runner that only supports + # batch...) + execution_state = ExecutionState(optimized_pipeline) + for transform_id in optimized_pipeline.root_transform_ids: + self.execute_transform(transform_id, execution_state) + + # A more sophisticated runner may perform the execution in the background + # and return a PipelineResult that can be used to monitor/cancel the + # concurrent execution. + return runner.PipelineResult(runner.PipelineState.DONE) + + def execute_transform(self, transform_id, execution_state): + """Execute a single transform.""" + transform_proto = execution_state.optimized_pipeline.components.transforms[ + transform_id] + _LOGGER.info( + "Executing stage %s %s", transform_id, transform_proto.unique_name) + if not is_primitive_transform(transform_proto): + # A composite is simply executed by executing its parts. 
+ for sub_transform in transform_proto.subtransforms: + self.execute_transform(sub_transform, execution_state) + + elif transform_proto.spec.urn == common_urns.primitives.IMPULSE.urn: + # An impulse has no inputs and produces a single output (which happens + # to be an empty byte string in the global window). + execution_state.set_pcollection_contents( + only_element(transform_proto.outputs.values()), + [common.ENCODED_IMPULSE_VALUE]) + + elif transform_proto.spec.urn == common_urns.primitives.FLATTEN.urn: + # The output of a flatten is simply the union of its inputs. + output_pcoll_id = only_element(transform_proto.outputs.values()) + execution_state.set_pcollection_contents( + output_pcoll_id, + sum([ + execution_state.get_pcollection_contents(pc) + for pc in transform_proto.inputs.values() + ], [])) + + elif transform_proto.spec.urn == 'beam:runner:executable_stage:v1': + # This is a collection of user DoFns. + self.execute_executable_stage(transform_proto, execution_state) + + elif transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn: + # Execute the grouping operation. + self.group_by_key_and_window( + only_element(transform_proto.inputs.values()), + only_element(transform_proto.outputs.values()), + execution_state) + + else: + raise RuntimeError( + f"Unsupported transform {transform_id}" + " of type {transform_proto.spec.urn}") + + def execute_executable_stage(self, transform_proto, execution_state): + # Stage here is like a mini pipeline, with PTransforms, PCollections, etc. + # inside of it. + stage = beam_runner_api_pb2.ExecutableStagePayload.FromString( + transform_proto.spec.payload) + if stage.side_inputs: + # To support these we would need to make the side input PCollections + # available over the state API before processing this bundle. + raise NotImplementedError() + + # This is the set of transforms that were fused together. + stage_transforms = { + id: stage.components.transforms[id] + for id in stage.transforms + } + # The executable stage has bare PCollections as its inputs and outputs. + # + # We need an operation to feed data into the bundle (runner -> SDK). + # This is done by attaching special transform that reads from the data + # channel. + input_transform = execution_state.new_id('stage_input') + input_pcoll = stage.input + stage_transforms[input_transform] = beam_runner_api_pb2.PTransform( + # Read data, encoded with the given coder, from the data channel. + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_INPUT_URN, + payload=beam_fn_api_pb2.RemoteGrpcPort( + # If we were using a cross-process data channel, we would also + # need to set the address of the data channel itself here. + coder_id=execution_state.windowed_coder_id( + stage.input)).SerializeToString()), + # Wire its "output" to the required PCollection. + outputs={'out': input_pcoll}) + # Also add operations to consume data produced by the bundle and writes + # them to the data channel (SDK -> runner). + output_ops_to_pcoll = {} + for output_pcoll in stage.outputs: + output_transform = execution_state.new_id('stage_output') + stage_transforms[output_transform] = beam_runner_api_pb2.PTransform( + # Data will be written here, with the given coder. + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_OUTPUT_URN, + payload=beam_fn_api_pb2.RemoteGrpcPort( + # Again, the grpc address itself is implicit. + coder_id=execution_state.windowed_coder_id( + output_pcoll)).SerializeToString()), + # This operation takes as input the bundle's output pcollection. 
+ inputs={'input': output_pcoll}) + output_ops_to_pcoll[output_transform] = output_pcoll + + # Now we can create a description of what it means to process a bundle + # of this type. When processing the bundle, we simply refer to this + # descriptor by id. (In our case we only do so once, but in a more general + # runner one may want to process the same "type" of bundle of many distinct + # partitions of the input, especially in streaming where one may have + # hundreds of concurrently processed bundles per key.) + process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor( + id=execution_state.new_id('descriptor'), + transforms=stage_transforms, + pcollections=stage.components.pcollections, + coders=execution_state.optimized_pipeline.components.coders, + windowing_strategies=stage.components.windowing_strategies, + environments=stage.components.environments, + # Were timers and state supported, their endpoints would be listed here. + ) + execution_state.register_process_bundle_descriptor( + process_bundle_descriptor) + + # Now we are ready to actually execute the bundle. + process_bundle_id = execution_state.new_id('bundle') + + # First, push the all input data onto the data channel. + # Timers would be sent over this channel as well. + # In a real runner we could do this after (or concurrently with) starting + # the bundle to avoid having to hold all the data to process in memory, + # but for simplicity our bundle invocation (below) is synchronous so the + # data must be available for processing right away. + to_worker = execution_state.worker_handler.data_conn.output_stream( + process_bundle_id, input_transform) + for encoded_data in execution_state.get_pcollection_contents(input_pcoll): + to_worker.write(encoded_data) + to_worker.close() + + # Now we send a process bundle request over the control plane. + process_bundle_request = beam_fn_api_pb2.InstructionRequest( + instruction_id=process_bundle_id, + process_bundle=beam_fn_api_pb2.ProcessBundleRequest( + process_bundle_descriptor_id=process_bundle_descriptor.id)) + result_future = execution_state.worker_handler.control_conn.push( + process_bundle_request) + + # Read the results off the data channel. + # Note that if there are multiple outputs, we may get them in any order, + # possibly interleaved. + for output in execution_state.worker_handler.data_conn.input_elements( + process_bundle_id, list(output_ops_to_pcoll.keys())): + if isinstance(output, beam_fn_api_pb2.Elements.Data): + # Adds the output to the appropriate PCollection. + execution_state.set_pcollection_contents( + output_ops_to_pcoll[output.transform_id], [output.data]) + else: + # E.g. timers to set. + raise RuntimeError("Unexpected data type: %s" % output) + + # Ensure the operation completed successfully. + # This result contains things like metrics and continuation tokens as well. + result = result_future.get() + if result.error: + raise RuntimeError(result.error) + if result.process_bundle.residual_roots: + # We would need to re-schedule execution of this bundle with this data. + raise NotImplementedError('SDF continuation') + if result.process_bundle.requires_finalization: + # We would need to invoke the finalization callback, on a best effort + # basis, *after* the outputs are durably committed. + raise NotImplementedError('finalization') + if result.process_bundle.elements.data: + # These should be processed just like outputs from the data channel. 
+ raise NotImplementedError('control-channel data') + if result.process_bundle.elements.timers: + # These should be processed just like outputs from the data channel. + raise NotImplementedError('timers') + + def group_by_key_and_window(self, input_pcoll, output_pcoll, execution_state): + """Groups the elements of input_pcoll, placing their output in output_pcoll. + """ + # Note that we are using the designated coders to decode and deeply inspect + # the elements. This could be useful for other operations as well (e.g. + # splitting the returned bundles of encoded elements into smaller chunks). + # One issue that we're sweeping under the rug here is that for a truly + # portable multi-language runner we may not be able to instanciate + # every coder in Python. This can be overcome by wrapping unknown coders as + # length-prefixing (or otherwise delimiting) variants. + + # Decode the input elements to get at their individual keys and values. + input_coder = execution_state.windowed_coder(input_pcoll) + key_coder = input_coder.key_coder() + input_elements = [] + for encoded_elements in execution_state.get_pcollection_contents( + input_pcoll): + for element in decode_all(encoded_elements, input_coder): + input_elements.append(element) + + # Now perform the actual grouping. + components = execution_state.optimized_pipeline.components + windowing = components.windowing_strategies[ + components.pcollections[input_pcoll].windowing_strategy_id] + + if (windowing.merge_status == + beam_runner_api_pb2.MergeStatus.Enum.NON_MERGING and + windowing.output_time == + beam_runner_api_pb2.OutputTime.Enum.END_OF_WINDOW): + # This is the "easy" case, show how to do it by hand. + # Note that we're grouping by encoded key, and also by the window. + grouped = collections.defaultdict(list) + for element in input_elements: + for window in element.windows: + key, value = element.value + grouped[window, key_coder.encode(key)].append(value) + output_elements = [ + windowed_value.WindowedValue( + (key_coder.decode(encoded_key), values), + window.end, [window], + trigger.BatchGlobalTriggerDriver.ONLY_FIRING) + for ((window, encoded_key), values) in grouped.items() + ] + else: + # This handles generic merging and triggering. + trigger_driver = trigger.create_trigger_driver( + execution_state.windowing_strategy(input_pcoll), True) + grouped_by_key = collections.defaultdict(list) + for element in input_elements: + key, value = element.value + grouped_by_key[key_coder.encode(key)].append(element.with_value(value)) + output_elements = [] + for encoded_key, windowed_values in grouped_by_key.items(): + for grouping in trigger_driver.process_entire_key( + key_coder.decode(encoded_key), windowed_values): + output_elements.append(grouping) + + # Store the grouped values in the output PCollection. + output_coder = execution_state.windowed_coder(output_pcoll) + execution_state.set_pcollection_contents( + output_pcoll, [encode_all(output_elements, output_coder)]) + + def supported_requirements(self) -> Iterable[str]: + # Nothing non-trivial is supported. + return [] + + +class ExecutionState: + """A helper class holding various values and context during execution.""" + def __init__(self, optimized_pipeline): + self.optimized_pipeline = optimized_pipeline + self._pcollections_to_encoded_elements = {} + self._counter = 0 + self._process_bundle_descriptors = {} + # This emulates a connection to an SDK worker (e.g. its data, control, + # etc. channels). + # There are other variants available as well (e.g. GRPC, Docker, ...) 
+ self.worker_handler = worker_handlers.EmbeddedWorkerHandler( + None, + state=worker_handlers.StateServicer(), + provision_info=None, + worker_manager=self) + self._windowed_coders = {} + # Populate all windowed coders before creating _pipeline_context. + for pcoll_id in self.optimized_pipeline.components.pcollections.keys(): + self.windowed_coder_id(pcoll_id) + self._pipeline_context = pipeline_context.PipelineContext( + optimized_pipeline.components) + + def register_process_bundle_descriptor( + self, process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor): + self._process_bundle_descriptors[ + process_bundle_descriptor.id] = process_bundle_descriptor + + def get_pcollection_contents(self, pcoll_id: str) -> List[bytes]: + return self._pcollections_to_encoded_elements[pcoll_id] + + def set_pcollection_contents(self, pcoll_id: str, chunks: List[bytes]): + self._pcollections_to_encoded_elements[pcoll_id] = chunks Review Comment: `_pcollections_to_encoded_chunks` might be slightly more correct? I got tripped up because I thought each value in these lists encoded a single element. Not sure the suggested name is really that much better though. Feel free to ignore ########## sdks/python/apache_beam/runners/trivial_runner.py: ########## @@ -0,0 +1,415 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import collections +import logging +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import TypeVar + +from apache_beam import coders +from apache_beam.coders.coder_impl import create_InputStream +from apache_beam.coders.coder_impl import create_OutputStream +from apache_beam.portability import common_urns +from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners import common +from apache_beam.runners import pipeline_context +from apache_beam.runners import runner +from apache_beam.runners.worker import bundle_processor +from apache_beam.runners.portability.fn_api_runner import translations +from apache_beam.runners.portability.fn_api_runner import worker_handlers +from apache_beam.transforms import core +from apache_beam.transforms import trigger +from apache_beam.utils import windowed_value + +T = TypeVar("T") + +_LOGGER = logging.getLogger(__name__) + + +class TrivialRunner(runner.PipelineRunner): + """A bare-bones batch Python pipeline runner illistrating how to use the + RunnerAPI and FnAPI to execute pipelines. + + Note that this runner is primarily for pedagogical purposes and is missing + several features in order to keep it as simple as possible. Where possible + pointers are provided which this should serve as a useful starting point. 
+ """ + def run_portable_pipeline(self, pipeline, options): + # First ensure we are able to run this pipeline. + # Specifically, that it does not depend on requirements that were + # added since this runner was developed. + self.check_requirements(pipeline, self.supported_requirements()) + + # Now we optimize the pipeline, notably performing pipeline fusion, + # to turn it into a DAG where all the operations are one of + # Impulse, Flatten, GroupByKey, or beam:runner:executable_stage. + optimized_pipeline = translations.optimize_pipeline( + pipeline, + phases=translations.standard_optimize_phases(), + known_runner_urns=frozenset([ + common_urns.primitives.IMPULSE.urn, + common_urns.primitives.FLATTEN.urn, + common_urns.primitives.GROUP_BY_KEY.urn + ]), + # This boolean indicates we want fused executable_stages. + partial=False) + + # standard_optimize_phases() has a final step giving the stages in + # topological order, so now we can just walk over them and execute + # them. (This are not quite so simple if we were attempting to execute + # a streaming pipeline, but this is a trivial runner that only supports + # batch...) + execution_state = ExecutionState(optimized_pipeline) + for transform_id in optimized_pipeline.root_transform_ids: + self.execute_transform(transform_id, execution_state) + + # A more sophisticated runner may perform the execution in the background + # and return a PipelineResult that can be used to monitor/cancel the + # concurrent execution. + return runner.PipelineResult(runner.PipelineState.DONE) + + def execute_transform(self, transform_id, execution_state): + """Execute a single transform.""" + transform_proto = execution_state.optimized_pipeline.components.transforms[ + transform_id] + _LOGGER.info( + "Executing stage %s %s", transform_id, transform_proto.unique_name) + if not is_primitive_transform(transform_proto): + # A composite is simply executed by executing its parts. + for sub_transform in transform_proto.subtransforms: + self.execute_transform(sub_transform, execution_state) + + elif transform_proto.spec.urn == common_urns.primitives.IMPULSE.urn: + # An impulse has no inputs and produces a single output (which happens + # to be an empty byte string in the global window). + execution_state.set_pcollection_contents( + only_element(transform_proto.outputs.values()), + [common.ENCODED_IMPULSE_VALUE]) + + elif transform_proto.spec.urn == common_urns.primitives.FLATTEN.urn: + # The output of a flatten is simply the union of its inputs. + output_pcoll_id = only_element(transform_proto.outputs.values()) + execution_state.set_pcollection_contents( + output_pcoll_id, + sum([ + execution_state.get_pcollection_contents(pc) + for pc in transform_proto.inputs.values() + ], [])) + + elif transform_proto.spec.urn == 'beam:runner:executable_stage:v1': + # This is a collection of user DoFns. + self.execute_executable_stage(transform_proto, execution_state) + + elif transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn: + # Execute the grouping operation. + self.group_by_key_and_window( + only_element(transform_proto.inputs.values()), + only_element(transform_proto.outputs.values()), + execution_state) + + else: + raise RuntimeError( + f"Unsupported transform {transform_id}" + " of type {transform_proto.spec.urn}") + + def execute_executable_stage(self, transform_proto, execution_state): + # Stage here is like a mini pipeline, with PTransforms, PCollections, etc. + # inside of it. 
+ stage = beam_runner_api_pb2.ExecutableStagePayload.FromString( + transform_proto.spec.payload) + if stage.side_inputs: + # To support these we would need to make the side input PCollections + # available over the state API before processing this bundle. + raise NotImplementedError() + + # This is the set of transforms that were fused together. + stage_transforms = { + id: stage.components.transforms[id] + for id in stage.transforms + } + # The executable stage has bare PCollections as its inputs and outputs. + # + # We need an operation to feed data into the bundle (runner -> SDK). + # This is done by attaching special transform that reads from the data + # channel. + input_transform = execution_state.new_id('stage_input') + input_pcoll = stage.input + stage_transforms[input_transform] = beam_runner_api_pb2.PTransform( + # Read data, encoded with the given coder, from the data channel. + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_INPUT_URN, + payload=beam_fn_api_pb2.RemoteGrpcPort( + # If we were using a cross-process data channel, we would also + # need to set the address of the data channel itself here. + coder_id=execution_state.windowed_coder_id( + stage.input)).SerializeToString()), + # Wire its "output" to the required PCollection. + outputs={'out': input_pcoll}) + # Also add operations to consume data produced by the bundle and writes + # them to the data channel (SDK -> runner). + output_ops_to_pcoll = {} + for output_pcoll in stage.outputs: + output_transform = execution_state.new_id('stage_output') + stage_transforms[output_transform] = beam_runner_api_pb2.PTransform( + # Data will be written here, with the given coder. + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_OUTPUT_URN, + payload=beam_fn_api_pb2.RemoteGrpcPort( + # Again, the grpc address itself is implicit. + coder_id=execution_state.windowed_coder_id( + output_pcoll)).SerializeToString()), + # This operation takes as input the bundle's output pcollection. + inputs={'input': output_pcoll}) + output_ops_to_pcoll[output_transform] = output_pcoll + + # Now we can create a description of what it means to process a bundle + # of this type. When processing the bundle, we simply refer to this + # descriptor by id. (In our case we only do so once, but in a more general + # runner one may want to process the same "type" of bundle of many distinct + # partitions of the input, especially in streaming where one may have + # hundreds of concurrently processed bundles per key.) + process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor( + id=execution_state.new_id('descriptor'), + transforms=stage_transforms, + pcollections=stage.components.pcollections, + coders=execution_state.optimized_pipeline.components.coders, + windowing_strategies=stage.components.windowing_strategies, + environments=stage.components.environments, + # Were timers and state supported, their endpoints would be listed here. + ) + execution_state.register_process_bundle_descriptor( + process_bundle_descriptor) + + # Now we are ready to actually execute the bundle. + process_bundle_id = execution_state.new_id('bundle') + + # First, push the all input data onto the data channel. + # Timers would be sent over this channel as well. + # In a real runner we could do this after (or concurrently with) starting + # the bundle to avoid having to hold all the data to process in memory, + # but for simplicity our bundle invocation (below) is synchronous so the + # data must be available for processing right away. 
+ to_worker = execution_state.worker_handler.data_conn.output_stream( + process_bundle_id, input_transform) + for encoded_data in execution_state.get_pcollection_contents(input_pcoll): + to_worker.write(encoded_data) + to_worker.close() + + # Now we send a process bundle request over the control plane. + process_bundle_request = beam_fn_api_pb2.InstructionRequest( + instruction_id=process_bundle_id, + process_bundle=beam_fn_api_pb2.ProcessBundleRequest( + process_bundle_descriptor_id=process_bundle_descriptor.id)) + result_future = execution_state.worker_handler.control_conn.push( + process_bundle_request) + + # Read the results off the data channel. + # Note that if there are multiple outputs, we may get them in any order, + # possibly interleaved. + for output in execution_state.worker_handler.data_conn.input_elements( + process_bundle_id, list(output_ops_to_pcoll.keys())): + if isinstance(output, beam_fn_api_pb2.Elements.Data): + # Adds the output to the appropriate PCollection. + execution_state.set_pcollection_contents( + output_ops_to_pcoll[output.transform_id], [output.data]) + else: + # E.g. timers to set. + raise RuntimeError("Unexpected data type: %s" % output) + + # Ensure the operation completed successfully. + # This result contains things like metrics and continuation tokens as well. + result = result_future.get() + if result.error: + raise RuntimeError(result.error) + if result.process_bundle.residual_roots: + # We would need to re-schedule execution of this bundle with this data. + raise NotImplementedError('SDF continuation') + if result.process_bundle.requires_finalization: + # We would need to invoke the finalization callback, on a best effort + # basis, *after* the outputs are durably committed. + raise NotImplementedError('finalization') + if result.process_bundle.elements.data: + # These should be processed just like outputs from the data channel. + raise NotImplementedError('control-channel data') + if result.process_bundle.elements.timers: + # These should be processed just like outputs from the data channel. + raise NotImplementedError('timers') + + def group_by_key_and_window(self, input_pcoll, output_pcoll, execution_state): + """Groups the elements of input_pcoll, placing their output in output_pcoll. + """ + # Note that we are using the designated coders to decode and deeply inspect + # the elements. This could be useful for other operations as well (e.g. + # splitting the returned bundles of encoded elements into smaller chunks). + # One issue that we're sweeping under the rug here is that for a truly + # portable multi-language runner we may not be able to instanciate + # every coder in Python. This can be overcome by wrapping unknown coders as + # length-prefixing (or otherwise delimiting) variants. + + # Decode the input elements to get at their individual keys and values. + input_coder = execution_state.windowed_coder(input_pcoll) + key_coder = input_coder.key_coder() + input_elements = [] + for encoded_elements in execution_state.get_pcollection_contents( + input_pcoll): + for element in decode_all(encoded_elements, input_coder): + input_elements.append(element) + + # Now perform the actual grouping. 
+ components = execution_state.optimized_pipeline.components + windowing = components.windowing_strategies[ + components.pcollections[input_pcoll].windowing_strategy_id] + + if (windowing.merge_status == + beam_runner_api_pb2.MergeStatus.Enum.NON_MERGING and + windowing.output_time == + beam_runner_api_pb2.OutputTime.Enum.END_OF_WINDOW): + # This is the "easy" case, show how to do it by hand. + # Note that we're grouping by encoded key, and also by the window. + grouped = collections.defaultdict(list) + for element in input_elements: + for window in element.windows: + key, value = element.value + grouped[window, key_coder.encode(key)].append(value) + output_elements = [ + windowed_value.WindowedValue( + (key_coder.decode(encoded_key), values), + window.end, [window], + trigger.BatchGlobalTriggerDriver.ONLY_FIRING) + for ((window, encoded_key), values) in grouped.items() + ] + else: + # This handles generic merging and triggering. + trigger_driver = trigger.create_trigger_driver( + execution_state.windowing_strategy(input_pcoll), True) + grouped_by_key = collections.defaultdict(list) + for element in input_elements: + key, value = element.value + grouped_by_key[key_coder.encode(key)].append(element.with_value(value)) + output_elements = [] + for encoded_key, windowed_values in grouped_by_key.items(): + for grouping in trigger_driver.process_entire_key( + key_coder.decode(encoded_key), windowed_values): + output_elements.append(grouping) + + # Store the grouped values in the output PCollection. + output_coder = execution_state.windowed_coder(output_pcoll) + execution_state.set_pcollection_contents( + output_pcoll, [encode_all(output_elements, output_coder)]) + + def supported_requirements(self) -> Iterable[str]: + # Nothing non-trivial is supported. + return [] + + +class ExecutionState: + """A helper class holding various values and context during execution.""" + def __init__(self, optimized_pipeline): + self.optimized_pipeline = optimized_pipeline + self._pcollections_to_encoded_elements = {} + self._counter = 0 + self._process_bundle_descriptors = {} + # This emulates a connection to an SDK worker (e.g. its data, control, + # etc. channels). + # There are other variants available as well (e.g. GRPC, Docker, ...) + self.worker_handler = worker_handlers.EmbeddedWorkerHandler( + None, + state=worker_handlers.StateServicer(), + provision_info=None, + worker_manager=self) + self._windowed_coders = {} + # Populate all windowed coders before creating _pipeline_context. + for pcoll_id in self.optimized_pipeline.components.pcollections.keys(): + self.windowed_coder_id(pcoll_id) + self._pipeline_context = pipeline_context.PipelineContext( + optimized_pipeline.components) Review Comment: `PipelineContext` is marked as an internal only class ########## sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py: ########## @@ -603,7 +603,7 @@ def pipeline_from_stages( components.transforms.clear() components.pcollections.clear() - roots = set() + roots = {} # order preserving but still has fast contains checking Review Comment: Might it be worth adding a test for this? ########## sdks/python/apache_beam/runners/trivial_runner_test.py: ########## @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import apache_beam as beam
+
+from apache_beam.runners.trivial_runner import TrivialRunner
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+
+class TrivialRunnerTest(unittest.TestCase):
+  def test_trivial(self):
+    # The most trivial pipeline, to ensure at least something is working.
+    # (Notably avoids the non-trivial complexity within assert_that.)
+    with beam.Pipeline(runner=TrivialRunner()) as p:
+      _ = p | beam.Impulse()
+
+  def test_assert_that(self):
+    # If this fails, the other tests may be vacuous.
+    with self.assertRaisesRegex(Exception, 'Failed assert'):
+      with beam.Pipeline(runner=TrivialRunner()) as p:
+        assert_that(p | beam.Impulse(), equal_to(['a']))
+
+  def test_impulse(self):
+    with beam.Pipeline(runner=TrivialRunner()) as p:
+      assert_that(p | beam.Impulse(), equal_to([b'']))
+
+  def test_create(self):
+    with beam.Pipeline(runner=TrivialRunner()) as p:
+      assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
+
+  def test_flatten(self):
+    with beam.Pipeline(runner=TrivialRunner()) as p:
+      ab = p | 'AB' >> beam.Create(['a', 'b'])
+      c = p | 'C' >> beam.Create(['c'])
+      assert_that((ab, c, c) | beam.Flatten(), equal_to(['a', 'b', 'c', 'c']))
+
+  def test_map(self):
+    with beam.Pipeline(runner=TrivialRunner()) as p:
+      assert_that(
+          p | beam.Create(['a', 'b']) | beam.Map(str.upper),
+          equal_to(['A', 'B']))
+

Review Comment:
   `Create` adds quite a number of transforms by default, which makes these test cases a tad more complicated. I found setting `shuffle=False` made things easier to debug and develop.
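   For illustration, a pared-down variant of `test_create` along those lines might look like the sketch below. It assumes `Create`'s `reshuffle` keyword is the option meant by `shuffle` above, and the test class and method names are made up; the point is just that skipping the reshuffle step keeps the expanded graph much smaller while developing the runner.

   ```python
   import unittest

   import apache_beam as beam

   from apache_beam.runners.trivial_runner import TrivialRunner
   from apache_beam.testing.util import assert_that
   from apache_beam.testing.util import equal_to


   class SimplifiedCreateTest(unittest.TestCase):
     def test_create_without_reshuffle(self):
       # Passing reshuffle=False keeps Create's expansion to a handful of
       # transforms (no Reshuffle stage), so the fused, optimized pipeline
       # is smaller and easier to step through in a debugger.
       with beam.Pipeline(runner=TrivialRunner()) as p:
         assert_that(
             p | beam.Create(['a', 'b'], reshuffle=False),
             equal_to(['a', 'b']))


   if __name__ == '__main__':
     unittest.main()
   ```

   Whether `assert_that` should also be avoided here (as `test_trivial` does) is a judgment call, since it contributes its own transforms to the graph.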
