http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/portability/fn_api_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py index f88fe53..db34ef9 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py @@ -17,32 +17,27 @@ """A PipelineRunner using the SDK harness. """ -import base64 import collections +import json import logging import Queue as queue import threading from concurrent import futures -from google.protobuf import wrappers_pb2 import grpc -import apache_beam as beam # pylint: disable=ungrouped-imports +import apache_beam as beam from apache_beam.coders import WindowedValueCoder from apache_beam.coders.coder_impl import create_InputStream from apache_beam.coders.coder_impl import create_OutputStream from apache_beam.internal import pickler from apache_beam.io import iobase from apache_beam.transforms.window import GlobalWindows -from apache_beam.portability.api import beam_fn_api_pb2 -from apache_beam.portability.api import beam_runner_api_pb2 -from apache_beam.runners import pipeline_context +from apache_beam.runners.api import beam_fn_api_pb2 from apache_beam.runners.portability import maptask_executor_runner -from apache_beam.runners.worker import bundle_processor from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import operation_specs from apache_beam.runners.worker import sdk_worker -from apache_beam.utils import proto_utils # This module is experimental. No backwards-compatibility guarantees. 
@@ -128,145 +123,191 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): def _map_task_registration(self, map_task, state_handler, data_operation_spec): - input_data, side_input_data, runner_sinks, process_bundle_descriptor = ( - self._map_task_to_protos(map_task, data_operation_spec)) - # Side inputs will be accessed over the state API. - for key, elements_data in side_input_data.items(): - state_key = beam_fn_api_pb2.StateKey.MultimapSideInput(key=key) - state_handler.Clear(state_key) - state_handler.Append(state_key, [elements_data]) - return beam_fn_api_pb2.InstructionRequest( - instruction_id=self._next_uid(), - register=beam_fn_api_pb2.RegisterRequest( - process_bundle_descriptor=[process_bundle_descriptor]) - ), runner_sinks, input_data - - def _map_task_to_protos(self, map_task, data_operation_spec): input_data = {} - side_input_data = {} runner_sinks = {} - - context = pipeline_context.PipelineContext() - transform_protos = {} - used_pcollections = {} - - def uniquify(*names): - # An injective mapping from string* to string. - return ':'.join("%s:%d" % (name, len(name)) for name in names) - - def pcollection_id(op_ix, out_ix): - if (op_ix, out_ix) not in used_pcollections: - used_pcollections[op_ix, out_ix] = uniquify( - map_task[op_ix][0], 'out', str(out_ix)) - return used_pcollections[op_ix, out_ix] - - def get_inputs(op): - if hasattr(op, 'inputs'): - inputs = op.inputs - elif hasattr(op, 'input'): - inputs = [op.input] - else: - inputs = [] - return {'in%s' % ix: pcollection_id(*input) - for ix, input in enumerate(inputs)} - - def get_outputs(op_ix): - op = map_task[op_ix][1] - return {tag: pcollection_id(op_ix, out_ix) - for out_ix, tag in enumerate(getattr(op, 'output_tags', ['out']))} - - def only_element(iterable): - element, = iterable - return element + transforms = [] + transform_index_to_id = {} + + # Maps coders to new coder objects and references. 
+ coders = {} + + def coder_id(coder): + if coder not in coders: + coders[coder] = beam_fn_api_pb2.Coder( + function_spec=sdk_worker.pack_function_spec_data( + json.dumps(coder.as_cloud_object()), + sdk_worker.PYTHON_CODER_URN, id=self._next_uid())) + + return coders[coder].function_spec.id + + def output_tags(op): + return getattr(op, 'output_tags', ['out']) + + def as_target(op_input): + input_op_index, input_output_index = op_input + input_op = map_task[input_op_index][1] + return { + 'ignored_input_tag': + beam_fn_api_pb2.Target.List(target=[ + beam_fn_api_pb2.Target( + primitive_transform_reference=transform_index_to_id[ + input_op_index], + name=output_tags(input_op)[input_output_index]) + ]) + } + + def outputs(op): + return { + tag: beam_fn_api_pb2.PCollection(coder_reference=coder_id(coder)) + for tag, coder in zip(output_tags(op), op.output_coders) + } for op_ix, (stage_name, operation) in enumerate(map_task): - transform_id = uniquify(stage_name) - + transform_id = transform_index_to_id[op_ix] = self._next_uid() if isinstance(operation, operation_specs.WorkerInMemoryWrite): # Write this data back to the runner. - target_name = only_element(get_inputs(operation).keys()) - runner_sinks[(transform_id, target_name)] = operation - transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=bundle_processor.DATA_OUTPUT_URN, - parameter=proto_utils.pack_Any(data_operation_spec)) + fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_OUTPUT_URN, + id=self._next_uid()) + if data_operation_spec: + fn.data.Pack(data_operation_spec) + inputs = as_target(operation.input) + side_inputs = {} + runner_sinks[(transform_id, 'out')] = operation elif isinstance(operation, operation_specs.WorkerRead): - # A Read from an in-memory source is done over the data plane. + # A Read is either translated to a direct injection of windowed values + # into the sdk worker, or an injection of the source object into the + # sdk worker as data followed by an SDF that reads that source. 
if (isinstance(operation.source.source, - maptask_executor_runner.InMemorySource) + worker_runner_base.InMemorySource) and isinstance(operation.source.source.default_output_coder(), WindowedValueCoder)): - target_name = only_element(get_outputs(op_ix).keys()) - input_data[(transform_id, target_name)] = self._reencode_elements( - operation.source.source.read(None), - operation.source.source.default_output_coder()) - transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=bundle_processor.DATA_INPUT_URN, - parameter=proto_utils.pack_Any(data_operation_spec)) - + output_stream = create_OutputStream() + element_coder = ( + operation.source.source.default_output_coder().get_impl()) + # Re-encode the elements in the nested context and + # concatenate them together + for element in operation.source.source.read(None): + element_coder.encode_to_stream(element, output_stream, True) + target_name = self._next_uid() + input_data[(transform_id, target_name)] = output_stream.get() + fn = beam_fn_api_pb2.FunctionSpec(urn=sdk_worker.DATA_INPUT_URN, + id=self._next_uid()) + if data_operation_spec: + fn.data.Pack(data_operation_spec) + inputs = {target_name: beam_fn_api_pb2.Target.List()} + side_inputs = {} else: - # Otherwise serialize the source and execute it there. - # TODO: Use SDFs with an initial impulse. - # The Dataflow runner harness strips the base64 encoding. do the same - # here until we get the same thing back that we sent in. - transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=bundle_processor.PYTHON_SOURCE_URN, - parameter=proto_utils.pack_Any( - wrappers_pb2.BytesValue( - value=base64.b64decode( - pickler.dumps(operation.source.source))))) + # Read the source object from the runner. 
+ source_coder = beam.coders.DillCoder() + input_transform_id = self._next_uid() + output_stream = create_OutputStream() + source_coder.get_impl().encode_to_stream( + GlobalWindows.windowed_value(operation.source), + output_stream, + True) + target_name = self._next_uid() + input_data[(input_transform_id, target_name)] = output_stream.get() + input_ptransform = beam_fn_api_pb2.PrimitiveTransform( + id=input_transform_id, + function_spec=beam_fn_api_pb2.FunctionSpec( + urn=sdk_worker.DATA_INPUT_URN, + id=self._next_uid()), + # TODO(robertwb): Possible name collision. + step_name=stage_name + '/inject_source', + inputs={target_name: beam_fn_api_pb2.Target.List()}, + outputs={ + 'out': + beam_fn_api_pb2.PCollection( + coder_reference=coder_id(source_coder)) + }) + if data_operation_spec: + input_ptransform.function_spec.data.Pack(data_operation_spec) + transforms.append(input_ptransform) + + # Read the elements out of the source. + fn = sdk_worker.pack_function_spec_data( + OLDE_SOURCE_SPLITTABLE_DOFN_DATA, + sdk_worker.PYTHON_DOFN_URN, + id=self._next_uid()) + inputs = { + 'ignored_input_tag': + beam_fn_api_pb2.Target.List(target=[ + beam_fn_api_pb2.Target( + primitive_transform_reference=input_transform_id, + name='out') + ]) + } + side_inputs = {} elif isinstance(operation, operation_specs.WorkerDoFn): - # Record the contents of each side input for access via the state api. - side_input_extras = [] + fn = sdk_worker.pack_function_spec_data( + operation.serialized_fn, + sdk_worker.PYTHON_DOFN_URN, + id=self._next_uid()) + inputs = as_target(operation.input) + # Store the contents of each side input for state access. for si in operation.side_inputs: assert isinstance(si.source, iobase.BoundedSource) element_coder = si.source.default_output_coder() + view_id = self._next_uid() # TODO(robertwb): Actually flesh out the ViewFn API. 
- side_input_extras.append((si.tag, element_coder)) - side_input_data[ - bundle_processor.side_input_tag(transform_id, si.tag)] = ( - self._reencode_elements( - si.source.read(si.source.get_range_tracker(None, None)), - element_coder)) - augmented_serialized_fn = pickler.dumps( - (operation.serialized_fn, side_input_extras)) - transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=bundle_processor.PYTHON_DOFN_URN, - parameter=proto_utils.pack_Any( - wrappers_pb2.BytesValue(value=augmented_serialized_fn))) + side_inputs[si.tag] = beam_fn_api_pb2.SideInput( + view_fn=sdk_worker.serialize_and_pack_py_fn( + element_coder, urn=sdk_worker.PYTHON_ITERABLE_VIEWFN_URN, + id=view_id)) + # Re-encode the elements in the nested context and + # concatenate them together + output_stream = create_OutputStream() + for element in si.source.read( + si.source.get_range_tracker(None, None)): + element_coder.get_impl().encode_to_stream( + element, output_stream, True) + elements_data = output_stream.get() + state_key = beam_fn_api_pb2.StateKey(function_spec_reference=view_id) + state_handler.Clear(state_key) + state_handler.Append( + beam_fn_api_pb2.SimpleStateAppendRequest( + state_key=state_key, data=[elements_data])) elif isinstance(operation, operation_specs.WorkerFlatten): - # Flatten is nice and simple. 
- transform_spec = beam_runner_api_pb2.FunctionSpec( - urn=bundle_processor.IDENTITY_DOFN_URN) + fn = sdk_worker.pack_function_spec_data( + operation.serialized_fn, + sdk_worker.IDENTITY_DOFN_URN, + id=self._next_uid()) + inputs = { + 'ignored_input_tag': + beam_fn_api_pb2.Target.List(target=[ + beam_fn_api_pb2.Target( + primitive_transform_reference=transform_index_to_id[ + input_op_index], + name=output_tags(map_task[input_op_index][1])[ + input_output_index]) + for input_op_index, input_output_index in operation.inputs + ]) + } + side_inputs = {} else: - raise NotImplementedError(operation) - - transform_protos[transform_id] = beam_runner_api_pb2.PTransform( - unique_name=stage_name, - spec=transform_spec, - inputs=get_inputs(operation), - outputs=get_outputs(op_ix)) - - pcollection_protos = { - name: beam_runner_api_pb2.PCollection( - unique_name=name, - coder_id=context.coders.get_id( - map_task[op_id][1].output_coders[out_id])) - for (op_id, out_id), name in used_pcollections.items() - } - # Must follow creation of pcollection_protos to capture used coders. 
- context_proto = context.to_runner_api() + raise TypeError(operation) + + ptransform = beam_fn_api_pb2.PrimitiveTransform( + id=transform_id, + function_spec=fn, + step_name=stage_name, + inputs=inputs, + side_inputs=side_inputs, + outputs=outputs(operation)) + transforms.append(ptransform) + process_bundle_descriptor = beam_fn_api_pb2.ProcessBundleDescriptor( - id=self._next_uid(), - transforms=transform_protos, - pcollections=pcollection_protos, - coders=dict(context_proto.coders.items()), - windowing_strategies=dict(context_proto.windowing_strategies.items()), - environments=dict(context_proto.environments.items())) - return input_data, side_input_data, runner_sinks, process_bundle_descriptor + id=self._next_uid(), coders=coders.values(), + primitive_transform=transforms) + return beam_fn_api_pb2.InstructionRequest( + instruction_id=self._next_uid(), + register=beam_fn_api_pb2.RegisterRequest( + process_bundle_descriptor=[process_bundle_descriptor + ])), runner_sinks, input_data def _run_map_task( self, map_task, control_handler, state_handler, data_plane_handler, @@ -317,7 +358,7 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): sink_op.output_buffer.append(e) return - def execute_map_tasks(self, ordered_map_tasks, direct=False): + def execute_map_tasks(self, ordered_map_tasks, direct=True): if direct: controller = FnApiRunner.DirectController() else: @@ -341,8 +382,9 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): return beam_fn_api_pb2.Elements.Data( data=''.join(self._all[self._to_key(state_key)])) - def Append(self, state_key, data): - self._all[self._to_key(state_key)].extend(data) + def Append(self, append_request): + self._all[self._to_key(append_request.state_key)].extend( + append_request.data) def Clear(self, state_key): try: @@ -352,7 +394,8 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): @staticmethod def _to_key(state_key): - return state_key.window, state_key.key + return 
(state_key.function_spec_reference, state_key.window, + state_key.key) class DirectController(object): """An in-memory controller for fn API control, state and data planes.""" @@ -428,10 +471,3 @@ class FnApiRunner(maptask_executor_runner.MapTaskExecutorRunner): self.data_plane_handler.close() self.control_server.stop(5).wait() self.data_server.stop(5).wait() - - @staticmethod - def _reencode_elements(elements, element_coder): - output_stream = create_OutputStream() - for element in elements: - element_coder.get_impl().encode_to_stream(element, output_stream, True) - return output_stream.get()
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index 163e980..633602f 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -20,41 +20,18 @@ import unittest import apache_beam as beam from apache_beam.runners.portability import fn_api_runner -from apache_beam.runners.portability import maptask_executor_runner_test -from apache_beam.testing.util import assert_that -from apache_beam.testing.util import equal_to +from apache_beam.runners.portability import maptask_executor_runner -class FnApiRunnerTest( - maptask_executor_runner_test.MapTaskExecutorRunnerTest): +class FnApiRunnerTest(maptask_executor_runner.MapTaskExecutorRunner): def create_pipeline(self): - return beam.Pipeline( - runner=fn_api_runner.FnApiRunner()) + return beam.Pipeline(runner=fn_api_runner.FnApiRunner()) def test_combine_per_key(self): - # TODO(BEAM-1348): Enable once Partial GBK is supported in fn API. + # TODO(robertwb): Implement PGBKCV operation. pass - def test_combine_per_key(self): - # TODO(BEAM-1348): Enable once Partial GBK is supported in fn API. - pass - - def test_pardo_side_inputs(self): - # TODO(BEAM-1348): Enable once side inputs are supported in fn API. - pass - - def test_pardo_unfusable_side_inputs(self): - # TODO(BEAM-1348): Enable once side inputs are supported in fn API. - pass - - def test_assert_that(self): - # TODO: figure out a way for fn_api_runner to parse and raise the - # underlying exception. 
- with self.assertRaisesRegexp(RuntimeError, 'BeamAssertException'): - with self.create_pipeline() as p: - assert_that(p | beam.Create(['a', 'b']), equal_to(['a'])) - # Inherits all tests from maptask_executor_runner.MapTaskExecutorRunner http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/bundle_processor.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py deleted file mode 100644 index 2669bfc..0000000 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ /dev/null @@ -1,426 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -"""SDK harness for executing Python Fns via the Fn API.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import base64 -import collections -import json -import logging - -from google.protobuf import wrappers_pb2 - -from apache_beam.coders import coder_impl -from apache_beam.coders import WindowedValueCoder -from apache_beam.internal import pickler -from apache_beam.io import iobase -from apache_beam.portability.api import beam_fn_api_pb2 -from apache_beam.runners.dataflow.native_io import iobase as native_iobase -from apache_beam.runners import pipeline_context -from apache_beam.runners.worker import operation_specs -from apache_beam.runners.worker import operations -from apache_beam.utils import counters -from apache_beam.utils import proto_utils - -# This module is experimental. No backwards-compatibility guarantees. - - -try: - from apache_beam.runners.worker import statesampler -except ImportError: - from apache_beam.runners.worker import statesampler_fake as statesampler - - -DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1' -DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1' -IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1' -PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1' -PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1' -# TODO(vikasrk): Fix this once runner sends appropriate python urns. 
-PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1' -PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1' - - -def side_input_tag(transform_id, tag): - return str("%d[%s][%s]" % (len(transform_id), transform_id, tag)) - - -class RunnerIOOperation(operations.Operation): - """Common baseclass for runner harness IO operations.""" - - def __init__(self, operation_name, step_name, consumers, counter_factory, - state_sampler, windowed_coder, target, data_channel): - super(RunnerIOOperation, self).__init__( - operation_name, None, counter_factory, state_sampler) - self.windowed_coder = windowed_coder - self.step_name = step_name - # target represents the consumer for the bytes in the data plane for a - # DataInputOperation or a producer of these bytes for a DataOutputOperation. - self.target = target - self.data_channel = data_channel - for _, consumer_ops in consumers.items(): - for consumer in consumer_ops: - self.add_receiver(consumer, 0) - - -class DataOutputOperation(RunnerIOOperation): - """A sink-like operation that gathers outputs to be sent back to the runner. - """ - - def set_output_stream(self, output_stream): - self.output_stream = output_stream - - def process(self, windowed_value): - self.windowed_coder.get_impl().encode_to_stream( - windowed_value, self.output_stream, True) - - def finish(self): - self.output_stream.close() - super(DataOutputOperation, self).finish() - - -class DataInputOperation(RunnerIOOperation): - """A source-like operation that gathers input from the runner. - """ - - def __init__(self, operation_name, step_name, consumers, counter_factory, - state_sampler, windowed_coder, input_target, data_channel): - super(DataInputOperation, self).__init__( - operation_name, step_name, consumers, counter_factory, state_sampler, - windowed_coder, target=input_target, data_channel=data_channel) - # We must do this manually as we don't have a spec or spec.output_coders. 
- self.receivers = [ - operations.ConsumerSet(self.counter_factory, self.step_name, 0, - consumers.itervalues().next(), - self.windowed_coder)] - - def process(self, windowed_value): - self.output(windowed_value) - - def process_encoded(self, encoded_windowed_values): - input_stream = coder_impl.create_InputStream(encoded_windowed_values) - while input_stream.size() > 0: - decoded_value = self.windowed_coder.get_impl().decode_from_stream( - input_stream, True) - self.output(decoded_value) - - -# TODO(robertwb): Revise side input API to not be in terms of native sources. -# This will enable lookups, but there's an open question as to how to handle -# custom sources without forcing intermediate materialization. This seems very -# related to the desire to inject key and window preserving [Splittable]DoFns -# into the view computation. -class SideInputSource(native_iobase.NativeSource, - native_iobase.NativeSourceReader): - """A 'source' for reading side inputs via state API calls. - """ - - def __init__(self, state_handler, state_key, coder): - self._state_handler = state_handler - self._state_key = state_key - self._coder = coder - - def reader(self): - return self - - @property - def returns_windowed_values(self): - return True - - def __enter__(self): - return self - - def __exit__(self, *exn_info): - pass - - def __iter__(self): - # TODO(robertwb): Support pagination. - input_stream = coder_impl.create_InputStream( - self._state_handler.Get(self._state_key).data) - while input_stream.size() > 0: - yield self._coder.get_impl().decode_from_stream(input_stream, True) - - -def memoize(func): - cache = {} - missing = object() - - def wrapper(*args): - result = cache.get(args, missing) - if result is missing: - result = cache[args] = func(*args) - return result - return wrapper - - -def only_element(iterable): - element, = iterable - return element - - -class BundleProcessor(object): - """A class for processing bundles of elements. 
- """ - def __init__( - self, process_bundle_descriptor, state_handler, data_channel_factory): - self.process_bundle_descriptor = process_bundle_descriptor - self.state_handler = state_handler - self.data_channel_factory = data_channel_factory - - def create_execution_tree(self, descriptor): - # TODO(robertwb): Figure out the correct prefix to use for output counters - # from StateSampler. - counter_factory = counters.CounterFactory() - state_sampler = statesampler.StateSampler( - 'fnapi-step%s-' % descriptor.id, counter_factory) - - transform_factory = BeamTransformFactory( - descriptor, self.data_channel_factory, counter_factory, state_sampler, - self.state_handler) - - pcoll_consumers = collections.defaultdict(list) - for transform_id, transform_proto in descriptor.transforms.items(): - for pcoll_id in transform_proto.inputs.values(): - pcoll_consumers[pcoll_id].append(transform_id) - - @memoize - def get_operation(transform_id): - transform_consumers = { - tag: [get_operation(op) for op in pcoll_consumers[pcoll_id]] - for tag, pcoll_id - in descriptor.transforms[transform_id].outputs.items() - } - return transform_factory.create_operation( - transform_id, transform_consumers) - - # Operations must be started (hence returned) in order. - @memoize - def topological_height(transform_id): - return 1 + max( - [0] + - [topological_height(consumer) - for pcoll in descriptor.transforms[transform_id].outputs.values() - for consumer in pcoll_consumers[pcoll]]) - - return [get_operation(transform_id) - for transform_id in sorted( - descriptor.transforms, key=topological_height, reverse=True)] - - def process_bundle(self, instruction_id): - ops = self.create_execution_tree(self.process_bundle_descriptor) - - expected_inputs = [] - for op in ops: - if isinstance(op, DataOutputOperation): - # TODO(robertwb): Is there a better way to pass the instruction id to - # the operation? 
- op.set_output_stream(op.data_channel.output_stream( - instruction_id, op.target)) - elif isinstance(op, DataInputOperation): - # We must wait until we receive "end of stream" for each of these ops. - expected_inputs.append(op) - - # Start all operations. - for op in reversed(ops): - logging.info('start %s', op) - op.start() - - # Inject inputs from data plane. - for input_op in expected_inputs: - for data in input_op.data_channel.input_elements( - instruction_id, [input_op.target]): - # ignores input name - input_op.process_encoded(data.data) - - # Finish all operations. - for op in ops: - logging.info('finish %s', op) - op.finish() - - -class BeamTransformFactory(object): - """Factory for turning transform_protos into executable operations.""" - def __init__(self, descriptor, data_channel_factory, counter_factory, - state_sampler, state_handler): - self.descriptor = descriptor - self.data_channel_factory = data_channel_factory - self.counter_factory = counter_factory - self.state_sampler = state_sampler - self.state_handler = state_handler - self.context = pipeline_context.PipelineContext(descriptor) - - _known_urns = {} - - @classmethod - def register_urn(cls, urn, parameter_type): - def wrapper(func): - cls._known_urns[urn] = func, parameter_type - return func - return wrapper - - def create_operation(self, transform_id, consumers): - transform_proto = self.descriptor.transforms[transform_id] - creator, parameter_type = self._known_urns[transform_proto.spec.urn] - parameter = proto_utils.unpack_Any( - transform_proto.spec.parameter, parameter_type) - return creator(self, transform_id, transform_proto, parameter, consumers) - - def get_coder(self, coder_id): - coder_proto = self.descriptor.coders[coder_id] - if coder_proto.spec.spec.urn: - return self.context.coders.get_by_id(coder_id) - else: - # No URN, assume cloud object encoding json bytes. 
- return operation_specs.get_coder_from_spec( - json.loads( - proto_utils.unpack_Any(coder_proto.spec.spec.parameter, - wrappers_pb2.BytesValue).value)) - - def get_output_coders(self, transform_proto): - return { - tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) - for tag, pcoll_id in transform_proto.outputs.items() - } - - def get_only_output_coder(self, transform_proto): - return only_element(self.get_output_coders(transform_proto).values()) - - def get_input_coders(self, transform_proto): - return { - tag: self.get_coder(self.descriptor.pcollections[pcoll_id].coder_id) - for tag, pcoll_id in transform_proto.inputs.items() - } - - def get_only_input_coder(self, transform_proto): - return only_element(self.get_input_coders(transform_proto).values()) - - # TODO(robertwb): Update all operations to take these in the constructor. - @staticmethod - def augment_oldstyle_op(op, step_name, consumers, tag_list=None): - op.step_name = step_name - for tag, op_consumers in consumers.items(): - for consumer in op_consumers: - op.add_receiver(consumer, tag_list.index(tag) if tag_list else 0) - return op - - -@BeamTransformFactory.register_urn( - DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) -def create(factory, transform_id, transform_proto, grpc_port, consumers): - target = beam_fn_api_pb2.Target( - primitive_transform_reference=transform_id, - name=only_element(transform_proto.outputs.keys())) - return DataInputOperation( - transform_proto.unique_name, - transform_proto.unique_name, - consumers, - factory.counter_factory, - factory.state_sampler, - factory.get_only_output_coder(transform_proto), - input_target=target, - data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) - - -@BeamTransformFactory.register_urn( - DATA_OUTPUT_URN, beam_fn_api_pb2.RemoteGrpcPort) -def create(factory, transform_id, transform_proto, grpc_port, consumers): - target = beam_fn_api_pb2.Target( - primitive_transform_reference=transform_id, - 
name=only_element(transform_proto.inputs.keys())) - return DataOutputOperation( - transform_proto.unique_name, - transform_proto.unique_name, - consumers, - factory.counter_factory, - factory.state_sampler, - # TODO(robertwb): Perhaps this could be distinct from the input coder? - factory.get_only_input_coder(transform_proto), - target=target, - data_channel=factory.data_channel_factory.create_data_channel(grpc_port)) - - -@BeamTransformFactory.register_urn(PYTHON_SOURCE_URN, wrappers_pb2.BytesValue) -def create(factory, transform_id, transform_proto, parameter, consumers): - # The Dataflow runner harness strips the base64 encoding. - source = pickler.loads(base64.b64encode(parameter.value)) - spec = operation_specs.WorkerRead( - iobase.SourceBundle(1.0, source, None, None), - [WindowedValueCoder(source.default_output_coder())]) - return factory.augment_oldstyle_op( - operations.ReadOperation( - transform_proto.unique_name, - spec, - factory.counter_factory, - factory.state_sampler), - transform_proto.unique_name, - consumers) - - -@BeamTransformFactory.register_urn(PYTHON_DOFN_URN, wrappers_pb2.BytesValue) -def create(factory, transform_id, transform_proto, parameter, consumers): - dofn_data = pickler.loads(parameter.value) - if len(dofn_data) == 2: - # Has side input data. - serialized_fn, side_input_data = dofn_data - else: - # No side input data. - serialized_fn, side_input_data = parameter.value, [] - - def create_side_input(tag, coder): - # TODO(robertwb): Extract windows (and keys) out of element data. - # TODO(robertwb): Extract state key from ParDoPayload. 
- return operation_specs.WorkerSideInputSource( - tag=tag, - source=SideInputSource( - factory.state_handler, - beam_fn_api_pb2.StateKey.MultimapSideInput( - key=side_input_tag(transform_id, tag)), - coder=coder)) - output_tags = list(transform_proto.outputs.keys()) - output_coders = factory.get_output_coders(transform_proto) - spec = operation_specs.WorkerDoFn( - serialized_fn=serialized_fn, - output_tags=output_tags, - input=None, - side_inputs=[ - create_side_input(tag, coder) for tag, coder in side_input_data], - output_coders=[output_coders[tag] for tag in output_tags]) - return factory.augment_oldstyle_op( - operations.DoOperation( - transform_proto.unique_name, - spec, - factory.counter_factory, - factory.state_sampler), - transform_proto.unique_name, - consumers, - output_tags) - - -@BeamTransformFactory.register_urn(IDENTITY_DOFN_URN, None) -def create(factory, transform_id, transform_proto, unused_parameter, consumers): - return factory.augment_oldstyle_op( - operations.FlattenOperation( - transform_proto.unique_name, - None, - factory.counter_factory, - factory.state_sampler), - transform_proto.unique_name, - consumers) http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/data_plane.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py index 26f65ee..5edd0b4 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane.py +++ b/sdks/python/apache_beam/runners/worker/data_plane.py @@ -28,7 +28,7 @@ import Queue as queue import threading from apache_beam.coders import coder_impl -from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.runners.api import beam_fn_api_pb2 import grpc # This module is experimental. No backwards-compatibility guarantees. 
@@ -167,18 +167,12 @@ class _GrpcDataChannel(DataChannel): yield data def output_stream(self, instruction_id, target): - # TODO: Return an output stream that sends data - # to the Runner once a fixed size buffer is full. - # Currently we buffer all the data before sending - # any messages. def add_to_send_queue(data): - if data: - self._to_send.put( - beam_fn_api_pb2.Elements.Data( - instruction_reference=instruction_id, - target=target, - data=data)) - # End of stream marker. + self._to_send.put( + beam_fn_api_pb2.Elements.Data( + instruction_reference=instruction_id, + target=target, + data=data)) self._to_send.put( beam_fn_api_pb2.Elements.Data( instruction_reference=instruction_id, @@ -246,8 +240,8 @@ class DataChannelFactory(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def create_data_channel(self, remote_grpc_port): - """Returns a ``DataChannel`` from the given RemoteGrpcPort.""" + def create_data_channel(self, function_spec): + """Returns a ``DataChannel`` from the given function_spec.""" raise NotImplementedError(type(self)) @abc.abstractmethod @@ -265,7 +259,9 @@ class GrpcClientDataChannelFactory(DataChannelFactory): def __init__(self): self._data_channel_cache = {} - def create_data_channel(self, remote_grpc_port): + def create_data_channel(self, function_spec): + remote_grpc_port = beam_fn_api_pb2.RemoteGrpcPort() + function_spec.data.Unpack(remote_grpc_port) url = remote_grpc_port.api_service_descriptor.url if url not in self._data_channel_cache: logging.info('Creating channel for %s', url) @@ -287,7 +283,7 @@ class InMemoryDataChannelFactory(DataChannelFactory): def __init__(self, in_memory_data_channel): self._in_memory_data_channel = in_memory_data_channel - def create_data_channel(self, unused_remote_grpc_port): + def create_data_channel(self, unused_function_spec): return self._in_memory_data_channel def close(self): http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/data_plane_test.py 
---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/data_plane_test.py b/sdks/python/apache_beam/runners/worker/data_plane_test.py index 360468a..e3e01ac 100644 --- a/sdks/python/apache_beam/runners/worker/data_plane_test.py +++ b/sdks/python/apache_beam/runners/worker/data_plane_test.py @@ -29,7 +29,7 @@ import unittest from concurrent import futures import grpc -from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker import data_plane http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/log_handler.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/log_handler.py b/sdks/python/apache_beam/runners/worker/log_handler.py index b8f6352..59ffbf4 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler.py +++ b/sdks/python/apache_beam/runners/worker/log_handler.py @@ -21,7 +21,7 @@ import math import Queue as queue import threading -from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.runners.api import beam_fn_api_pb2 import grpc # This module is experimental. No backwards-compatibility guarantees. 
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/log_handler_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/log_handler_test.py b/sdks/python/apache_beam/runners/worker/log_handler_test.py index 2256bb5..8720ca8 100644 --- a/sdks/python/apache_beam/runners/worker/log_handler_test.py +++ b/sdks/python/apache_beam/runners/worker/log_handler_test.py @@ -22,7 +22,7 @@ import unittest from concurrent import futures import grpc -from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker import log_handler http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/operation_specs.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/operation_specs.py b/sdks/python/apache_beam/runners/worker/operation_specs.py index bdafbea..db5eb76 100644 --- a/sdks/python/apache_beam/runners/worker/operation_specs.py +++ b/sdks/python/apache_beam/runners/worker/operation_specs.py @@ -326,12 +326,11 @@ def get_coder_from_spec(coder_spec): assert len(coder_spec['component_encodings']) == 2 value_coder, window_coder = [ get_coder_from_spec(c) for c in coder_spec['component_encodings']] - return coders.coders.WindowedValueCoder( - value_coder, window_coder=window_coder) + return coders.WindowedValueCoder(value_coder, window_coder=window_coder) elif coder_spec['@type'] == 'kind:interval_window': assert ('component_encodings' not in coder_spec or not coder_spec['component_encodings']) - return coders.coders.IntervalWindowCoder() + return coders.IntervalWindowCoder() elif coder_spec['@type'] == 'kind:global_window': assert ('component_encodings' not in coder_spec or not coder_spec['component_encodings']) @@ -340,10 +339,6 @@ def get_coder_from_spec(coder_spec): assert 
len(coder_spec['component_encodings']) == 1 return coders.coders.LengthPrefixCoder( get_coder_from_spec(coder_spec['component_encodings'][0])) - elif coder_spec['@type'] == 'kind:bytes': - assert ('component_encodings' not in coder_spec - or len(coder_spec['component_encodings'] == 0)) - return coders.BytesCoder() # We pass coders in the form "<coder_name>$<pickled_data>" to make the job # description JSON more readable. http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/operations.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/operations.py b/sdks/python/apache_beam/runners/worker/operations.py index c4f945b..a44561d 100644 --- a/sdks/python/apache_beam/runners/worker/operations.py +++ b/sdks/python/apache_beam/runners/worker/operations.py @@ -129,7 +129,6 @@ class Operation(object): self.operation_name + '-finish') # TODO(ccy): the '-abort' state can be added when the abort is supported in # Operations. 
- self.scoped_metrics_container = None def start(self): """Start operation.""" http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/sdk_worker.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py index 6a23680..33c50ad 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py @@ -21,21 +21,198 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections +import json import logging import Queue as queue import threading import traceback +import zlib -from apache_beam.portability.api import beam_fn_api_pb2 -from apache_beam.runners.worker import bundle_processor -from apache_beam.runners.worker import data_plane +import dill +from google.protobuf import wrappers_pb2 + +from apache_beam.coders import coder_impl +from apache_beam.coders import WindowedValueCoder +from apache_beam.internal import pickler +from apache_beam.io import iobase +from apache_beam.runners.dataflow.native_io import iobase as native_iobase +from apache_beam.utils import counters +from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.runners.worker import operation_specs +from apache_beam.runners.worker import operations + +# This module is experimental. No backwards-compatibility guarantees. 
+ + +try: + from apache_beam.runners.worker import statesampler +except ImportError: + from apache_beam.runners.worker import statesampler_fake as statesampler +from apache_beam.runners.worker.data_plane import GrpcClientDataChannelFactory + + +DATA_INPUT_URN = 'urn:org.apache.beam:source:runner:0.1' +DATA_OUTPUT_URN = 'urn:org.apache.beam:sink:runner:0.1' +IDENTITY_DOFN_URN = 'urn:org.apache.beam:dofn:identity:0.1' +PYTHON_ITERABLE_VIEWFN_URN = 'urn:org.apache.beam:viewfn:iterable:python:0.1' +PYTHON_CODER_URN = 'urn:org.apache.beam:coder:python:0.1' +# TODO(vikasrk): Fix this once runner sends appropriate python urns. +PYTHON_DOFN_URN = 'urn:org.apache.beam:dofn:java:0.1' +PYTHON_SOURCE_URN = 'urn:org.apache.beam:source:java:0.1' + + +class RunnerIOOperation(operations.Operation): + """Common baseclass for runner harness IO operations.""" + + def __init__(self, operation_name, step_name, consumers, counter_factory, + state_sampler, windowed_coder, target, data_channel): + super(RunnerIOOperation, self).__init__( + operation_name, None, counter_factory, state_sampler) + self.windowed_coder = windowed_coder + self.step_name = step_name + # target represents the consumer for the bytes in the data plane for a + # DataInputOperation or a producer of these bytes for a DataOutputOperation. + self.target = target + self.data_channel = data_channel + for _, consumer_ops in consumers.items(): + for consumer in consumer_ops: + self.add_receiver(consumer, 0) + + +class DataOutputOperation(RunnerIOOperation): + """A sink-like operation that gathers outputs to be sent back to the runner. 
+ """ + + def set_output_stream(self, output_stream): + self.output_stream = output_stream + + def process(self, windowed_value): + self.windowed_coder.get_impl().encode_to_stream( + windowed_value, self.output_stream, True) + + def finish(self): + self.output_stream.close() + super(DataOutputOperation, self).finish() + + +class DataInputOperation(RunnerIOOperation): + """A source-like operation that gathers input from the runner. + """ + + def __init__(self, operation_name, step_name, consumers, counter_factory, + state_sampler, windowed_coder, input_target, data_channel): + super(DataInputOperation, self).__init__( + operation_name, step_name, consumers, counter_factory, state_sampler, + windowed_coder, target=input_target, data_channel=data_channel) + # We must do this manually as we don't have a spec or spec.output_coders. + self.receivers = [ + operations.ConsumerSet(self.counter_factory, self.step_name, 0, + consumers.itervalues().next(), + self.windowed_coder)] + + def process(self, windowed_value): + self.output(windowed_value) + + def process_encoded(self, encoded_windowed_values): + input_stream = coder_impl.create_InputStream(encoded_windowed_values) + while input_stream.size() > 0: + decoded_value = self.windowed_coder.get_impl().decode_from_stream( + input_stream, True) + self.output(decoded_value) + + +# TODO(robertwb): Revise side input API to not be in terms of native sources. +# This will enable lookups, but there's an open question as to how to handle +# custom sources without forcing intermediate materialization. This seems very +# related to the desire to inject key and window preserving [Splittable]DoFns +# into the view computation. +class SideInputSource(native_iobase.NativeSource, + native_iobase.NativeSourceReader): + """A 'source' for reading side inputs via state API calls. 
+ """ + + def __init__(self, state_handler, state_key, coder): + self._state_handler = state_handler + self._state_key = state_key + self._coder = coder + + def reader(self): + return self + + @property + def returns_windowed_values(self): + return True + + def __enter__(self): + return self + + def __exit__(self, *exn_info): + pass + + def __iter__(self): + # TODO(robertwb): Support pagination. + input_stream = coder_impl.create_InputStream( + self._state_handler.Get(self._state_key).data) + while input_stream.size() > 0: + yield self._coder.get_impl().decode_from_stream(input_stream, True) + + +def unpack_and_deserialize_py_fn(function_spec): + """Returns unpacked and deserialized object from function spec proto.""" + return pickler.loads(unpack_function_spec_data(function_spec)) + + +def unpack_function_spec_data(function_spec): + """Returns unpacked data from function spec proto.""" + data = wrappers_pb2.BytesValue() + function_spec.data.Unpack(data) + return data.value + + +# pylint: disable=redefined-builtin +def serialize_and_pack_py_fn(fn, urn, id=None): + """Returns serialized and packed function in a function spec proto.""" + return pack_function_spec_data(pickler.dumps(fn), urn, id) +# pylint: enable=redefined-builtin + + +# pylint: disable=redefined-builtin +def pack_function_spec_data(value, urn, id=None): + """Returns packed data in a function spec proto.""" + data = wrappers_pb2.BytesValue(value=value) + fn_proto = beam_fn_api_pb2.FunctionSpec(urn=urn) + fn_proto.data.Pack(data) + if id: + fn_proto.id = id + return fn_proto +# pylint: enable=redefined-builtin + + +# TODO(vikasrk): move this method to ``coders.py`` in the SDK. +def load_compressed(compressed_data): + """Returns a decompressed and deserialized python object.""" + # Note: SDK uses ``pickler.dumps`` to serialize certain python objects + # (like sources), which involves serialization, compression and base64 + # encoding. 
We cannot directly use ``pickler.loads`` for + # deserialization, as the runner would have already base64 decoded the + # data. So we only need to decompress and deserialize. + + data = zlib.decompress(compressed_data) + try: + return dill.loads(data) + except Exception: # pylint: disable=broad-except + dill.dill._trace(True) # pylint: disable=protected-access + return dill.loads(data) + finally: + dill.dill._trace(False) # pylint: disable=protected-access class SdkHarness(object): def __init__(self, control_channel): self._control_channel = control_channel - self._data_channel_factory = data_plane.GrpcClientDataChannelFactory() + self._data_channel_factory = GrpcClientDataChannelFactory() def run(self): contol_stub = beam_fn_api_pb2.BeamFnControlStub(self._control_channel) @@ -59,10 +236,6 @@ class SdkHarness(object): try: response = self.worker.do_instruction(work_request) except Exception: # pylint: disable=broad-except - logging.error( - 'Error processing instruction %s', - work_request.instruction_id, - exc_info=True) response = beam_fn_api_pb2.InstructionResponse( instruction_id=work_request.instruction_id, error=traceback.format_exc()) @@ -100,12 +273,185 @@ class SdkWorker(object): def register(self, request, unused_instruction_id=None): for process_bundle_descriptor in request.process_bundle_descriptor: self.fns[process_bundle_descriptor.id] = process_bundle_descriptor + for p_transform in list(process_bundle_descriptor.primitive_transform): + self.fns[p_transform.function_spec.id] = p_transform.function_spec return beam_fn_api_pb2.RegisterResponse() + def initial_source_split(self, request, unused_instruction_id=None): + source_spec = self.fns[request.source_reference] + assert source_spec.urn == PYTHON_SOURCE_URN + source_bundle = unpack_and_deserialize_py_fn( + self.fns[request.source_reference]) + splits = source_bundle.source.split(request.desired_bundle_size_bytes, + source_bundle.start_position, + source_bundle.stop_position) + response = 
beam_fn_api_pb2.InitialSourceSplitResponse() + response.splits.extend([ + beam_fn_api_pb2.SourceSplit( + source=serialize_and_pack_py_fn(split, PYTHON_SOURCE_URN), + relative_size=split.weight, + ) + for split in splits + ]) + return response + + def create_execution_tree(self, descriptor): + # TODO(vikasrk): Add an id field to Coder proto and use that instead. + coders = {coder.function_spec.id: operation_specs.get_coder_from_spec( + json.loads(unpack_function_spec_data(coder.function_spec))) + for coder in descriptor.coders} + + counter_factory = counters.CounterFactory() + # TODO(robertwb): Figure out the correct prefix to use for output counters + # from StateSampler. + state_sampler = statesampler.StateSampler( + 'fnapi-step%s-' % descriptor.id, counter_factory) + consumers = collections.defaultdict(lambda: collections.defaultdict(list)) + ops_by_id = {} + reversed_ops = [] + + for transform in reversed(descriptor.primitive_transform): + # TODO(robertwb): Figure out how to plumb through the operation name (e.g. + # "s3") from the service through the FnAPI so that msec counters can be + # reported and correctly plumbed through the service and the UI. 
+ operation_name = 'fnapis%s' % transform.id + + def only_element(iterable): + element, = iterable + return element + + if transform.function_spec.urn == DATA_OUTPUT_URN: + target = beam_fn_api_pb2.Target( + primitive_transform_reference=transform.id, + name=only_element(transform.outputs.keys())) + + op = DataOutputOperation( + operation_name, + transform.step_name, + consumers[transform.id], + counter_factory, + state_sampler, + coders[only_element(transform.outputs.values()).coder_reference], + target, + self.data_channel_factory.create_data_channel( + transform.function_spec)) + + elif transform.function_spec.urn == DATA_INPUT_URN: + target = beam_fn_api_pb2.Target( + primitive_transform_reference=transform.id, + name=only_element(transform.inputs.keys())) + op = DataInputOperation( + operation_name, + transform.step_name, + consumers[transform.id], + counter_factory, + state_sampler, + coders[only_element(transform.outputs.values()).coder_reference], + target, + self.data_channel_factory.create_data_channel( + transform.function_spec)) + + elif transform.function_spec.urn == PYTHON_DOFN_URN: + def create_side_input(tag, si): + # TODO(robertwb): Extract windows (and keys) out of element data. + return operation_specs.WorkerSideInputSource( + tag=tag, + source=SideInputSource( + self.state_handler, + beam_fn_api_pb2.StateKey( + function_spec_reference=si.view_fn.id), + coder=unpack_and_deserialize_py_fn(si.view_fn))) + output_tags = list(transform.outputs.keys()) + spec = operation_specs.WorkerDoFn( + serialized_fn=unpack_function_spec_data(transform.function_spec), + output_tags=output_tags, + input=None, + side_inputs=[create_side_input(tag, si) + for tag, si in transform.side_inputs.items()], + output_coders=[coders[transform.outputs[out].coder_reference] + for out in output_tags]) + + op = operations.DoOperation(operation_name, spec, counter_factory, + state_sampler) + # TODO(robertwb): Move these to the constructor. 
+ op.step_name = transform.step_name + for tag, op_consumers in consumers[transform.id].items(): + for consumer in op_consumers: + op.add_receiver( + consumer, output_tags.index(tag)) + + elif transform.function_spec.urn == IDENTITY_DOFN_URN: + op = operations.FlattenOperation(operation_name, None, counter_factory, + state_sampler) + # TODO(robertwb): Move these to the constructor. + op.step_name = transform.step_name + for tag, op_consumers in consumers[transform.id].items(): + for consumer in op_consumers: + op.add_receiver(consumer, 0) + + elif transform.function_spec.urn == PYTHON_SOURCE_URN: + source = load_compressed(unpack_function_spec_data( + transform.function_spec)) + # TODO(vikasrk): Remove this once custom source is implemented with + # splittable dofn via the data plane. + spec = operation_specs.WorkerRead( + iobase.SourceBundle(1.0, source, None, None), + [WindowedValueCoder(source.default_output_coder())]) + op = operations.ReadOperation(operation_name, spec, counter_factory, + state_sampler) + op.step_name = transform.step_name + output_tags = list(transform.outputs.keys()) + for tag, op_consumers in consumers[transform.id].items(): + for consumer in op_consumers: + op.add_receiver( + consumer, output_tags.index(tag)) + + else: + raise NotImplementedError + + # Record consumers. 
+ for _, inputs in transform.inputs.items(): + for target in inputs.target: + consumers[target.primitive_transform_reference][target.name].append( + op) + + reversed_ops.append(op) + ops_by_id[transform.id] = op + + return list(reversed(reversed_ops)), ops_by_id + def process_bundle(self, request, instruction_id): - bundle_processor.BundleProcessor( - self.fns[request.process_bundle_descriptor_reference], - self.state_handler, - self.data_channel_factory).process_bundle(instruction_id) + ops, ops_by_id = self.create_execution_tree( + self.fns[request.process_bundle_descriptor_reference]) + + expected_inputs = [] + for _, op in ops_by_id.items(): + if isinstance(op, DataOutputOperation): + # TODO(robertwb): Is there a better way to pass the instruction id to + # the operation? + op.set_output_stream(op.data_channel.output_stream( + instruction_id, op.target)) + elif isinstance(op, DataInputOperation): + # We must wait until we receive "end of stream" for each of these ops. + expected_inputs.append(op) + + # Start all operations. + for op in reversed(ops): + logging.info('start %s', op) + op.start() + + # Inject inputs from data plane. + for input_op in expected_inputs: + for data in input_op.data_channel.input_elements( + instruction_id, [input_op.target]): + # ignores input name + target_op = ops_by_id[data.target.primitive_transform_reference] + # lacks coder for non-input ops + target_op.process_encoded(data.data) + + # Finish all operations. 
+ for op in ops: + logging.info('finish %s', op) + op.finish() return beam_fn_api_pb2.ProcessBundleResponse() http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py index f3f1e02..b891779 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_main.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_main.py @@ -24,7 +24,7 @@ import sys import grpc from google.protobuf import text_format -from apache_beam.portability.api import beam_fn_api_pb2 +from apache_beam.runners.api import beam_fn_api_pb2 from apache_beam.runners.worker.log_handler import FnApiLogRecordHandler from apache_beam.runners.worker.sdk_worker import SdkHarness http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py index dc72a5f..0d0811b 100644 --- a/sdks/python/apache_beam/runners/worker/sdk_worker_test.py +++ b/sdks/python/apache_beam/runners/worker/sdk_worker_test.py @@ -27,8 +27,10 @@ import unittest from concurrent import futures import grpc -from apache_beam.portability.api import beam_fn_api_pb2 -from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.io.concat_source_test import RangeSource +from apache_beam.io.iobase import SourceBundle +from apache_beam.runners.api import beam_fn_api_pb2 +from apache_beam.runners.worker import data_plane from apache_beam.runners.worker import sdk_worker @@ -62,12 +64,13 @@ class BeamFnControlServicer(beam_fn_api_pb2.BeamFnControlServicer): class SdkWorkerTest(unittest.TestCase): def 
test_fn_registration(self): - process_bundle_descriptors = [ - beam_fn_api_pb2.ProcessBundleDescriptor( - id=str(100+ix), - transforms={ - str(ix): beam_runner_api_pb2.PTransform(unique_name=str(ix))}) - for ix in range(4)] + fns = [beam_fn_api_pb2.FunctionSpec(id=str(ix)) for ix in range(4)] + + process_bundle_descriptors = [beam_fn_api_pb2.ProcessBundleDescriptor( + id=str(100+ix), + primitive_transform=[ + beam_fn_api_pb2.PrimitiveTransform(function_spec=fn)]) + for ix, fn in enumerate(fns)] test_controller = BeamFnControlServicer([beam_fn_api_pb2.InstructionRequest( register=beam_fn_api_pb2.RegisterRequest( @@ -83,7 +86,81 @@ class SdkWorkerTest(unittest.TestCase): harness.run() self.assertEqual( harness.worker.fns, - {item.id: item for item in process_bundle_descriptors}) + {item.id: item for item in fns + process_bundle_descriptors}) + + @unittest.skip("initial splitting not in proto") + def test_source_split(self): + source = RangeSource(0, 100) + expected_splits = list(source.split(30)) + + worker = sdk_harness.SdkWorker( + None, data_plane.GrpcClientDataChannelFactory()) + worker.register( + beam_fn_api_pb2.RegisterRequest( + process_bundle_descriptor=[beam_fn_api_pb2.ProcessBundleDescriptor( + primitive_transform=[beam_fn_api_pb2.PrimitiveTransform( + function_spec=sdk_harness.serialize_and_pack_py_fn( + SourceBundle(1.0, source, None, None), + sdk_harness.PYTHON_SOURCE_URN, + id="src"))])])) + split_response = worker.initial_source_split( + beam_fn_api_pb2.InitialSourceSplitRequest( + desired_bundle_size_bytes=30, + source_reference="src")) + + self.assertEqual( + expected_splits, + [sdk_harness.unpack_and_deserialize_py_fn(s.source) + for s in split_response.splits]) + + self.assertEqual( + [s.weight for s in expected_splits], + [s.relative_size for s in split_response.splits]) + + @unittest.skip("initial splitting not in proto") + def test_source_split_via_instruction(self): + + source = RangeSource(0, 100) + expected_splits = list(source.split(30)) + 
+ test_controller = BeamFnControlServicer([ + beam_fn_api_pb2.InstructionRequest( + instruction_id="register_request", + register=beam_fn_api_pb2.RegisterRequest( + process_bundle_descriptor=[ + beam_fn_api_pb2.ProcessBundleDescriptor( + primitive_transform=[beam_fn_api_pb2.PrimitiveTransform( + function_spec=sdk_harness.serialize_and_pack_py_fn( + SourceBundle(1.0, source, None, None), + sdk_harness.PYTHON_SOURCE_URN, + id="src"))])])), + beam_fn_api_pb2.InstructionRequest( + instruction_id="split_request", + initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest( + desired_bundle_size_bytes=30, + source_reference="src")) + ]) + + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server) + test_port = server.add_insecure_port("[::]:0") + server.start() + + channel = grpc.insecure_channel("localhost:%s" % test_port) + harness = sdk_harness.SdkHarness(channel) + harness.run() + + split_response = test_controller.responses[ + "split_request"].initial_source_split + + self.assertEqual( + expected_splits, + [sdk_harness.unpack_and_deserialize_py_fn(s.source) + for s in split_response.splits]) + + self.assertEqual( + [s.weight for s in expected_splits], + [s.relative_size for s in split_response.splits]) if __name__ == "__main__": http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/testing/test_stream.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/testing/test_stream.py b/sdks/python/apache_beam/testing/test_stream.py index 7989fb2..a06bcd0 100644 --- a/sdks/python/apache_beam/testing/test_stream.py +++ b/sdks/python/apache_beam/testing/test_stream.py @@ -24,10 +24,8 @@ from abc import ABCMeta from abc import abstractmethod from apache_beam import coders -from apache_beam import core from apache_beam import pvalue from apache_beam.transforms import PTransform -from 
apache_beam.transforms import window from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp from apache_beam.utils.windowed_value import WindowedValue @@ -101,9 +99,6 @@ class TestStream(PTransform): self.current_watermark = timestamp.MIN_TIMESTAMP self.events = [] - def get_windowing(self, unused_inputs): - return core.Windowing(window.GlobalWindows()) - def expand(self, pbegin): assert isinstance(pbegin, pvalue.PBegin) self.pipeline = pbegin.pipeline http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/testing/test_stream_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/testing/test_stream_test.py b/sdks/python/apache_beam/testing/test_stream_test.py index b7ca141..e32dda2 100644 --- a/sdks/python/apache_beam/testing/test_stream_test.py +++ b/sdks/python/apache_beam/testing/test_stream_test.py @@ -19,16 +19,10 @@ import unittest -import apache_beam as beam -from apache_beam.options.pipeline_options import PipelineOptions -from apache_beam.options.pipeline_options import StandardOptions -from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.test_stream import ElementEvent from apache_beam.testing.test_stream import ProcessingTimeEvent from apache_beam.testing.test_stream import TestStream from apache_beam.testing.test_stream import WatermarkEvent -from apache_beam.testing.util import assert_that, equal_to -from apache_beam.transforms.window import FixedWindows from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp from apache_beam.utils.windowed_value import WindowedValue @@ -84,68 +78,6 @@ class TestStreamTest(unittest.TestCase): TimestampedValue('a', timestamp.MAX_TIMESTAMP) ])) - def test_basic_execution(self): - test_stream = (TestStream() - .advance_watermark_to(10) - .add_elements(['a', 'b', 'c']) - .advance_watermark_to(20) - .add_elements(['d']) 
- .add_elements(['e']) - .advance_processing_time(10) - .advance_watermark_to(300) - .add_elements([TimestampedValue('late', 12)]) - .add_elements([TimestampedValue('last', 310)])) - - class RecordFn(beam.DoFn): - def process(self, element=beam.DoFn.ElementParam, - timestamp=beam.DoFn.TimestampParam): - yield (element, timestamp) - - options = PipelineOptions() - options.view_as(StandardOptions).streaming = True - p = TestPipeline(options=options) - my_record_fn = RecordFn() - records = p | test_stream | beam.ParDo(my_record_fn) - assert_that(records, equal_to([ - ('a', timestamp.Timestamp(10)), - ('b', timestamp.Timestamp(10)), - ('c', timestamp.Timestamp(10)), - ('d', timestamp.Timestamp(20)), - ('e', timestamp.Timestamp(20)), - ('late', timestamp.Timestamp(12)), - ('last', timestamp.Timestamp(310)),])) - p.run() - - def test_gbk_execution(self): - test_stream = (TestStream() - .advance_watermark_to(10) - .add_elements(['a', 'b', 'c']) - .advance_watermark_to(20) - .add_elements(['d']) - .add_elements(['e']) - .advance_processing_time(10) - .advance_watermark_to(300) - .add_elements([TimestampedValue('late', 12)]) - .add_elements([TimestampedValue('last', 310)])) - - options = PipelineOptions() - options.view_as(StandardOptions).streaming = True - p = TestPipeline(options=options) - records = (p - | test_stream - | beam.WindowInto(FixedWindows(15)) - | beam.Map(lambda x: ('k', x)) - | beam.GroupByKey()) - # TODO(BEAM-2519): timestamp assignment for elements from a GBK should - # respect the TimestampCombiner. The test below should also verify the - # timestamps of the outputted elements once this is implemented. 
- assert_that(records, equal_to([ - ('k', ['a', 'b', 'c']), - ('k', ['d', 'e']), - ('k', ['late']), - ('k', ['last'])])) - p.run() - if __name__ == '__main__': unittest.main() http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/transforms/combiners.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/combiners.py b/sdks/python/apache_beam/transforms/combiners.py index 875306f..fa0742d 100644 --- a/sdks/python/apache_beam/transforms/combiners.py +++ b/sdks/python/apache_beam/transforms/combiners.py @@ -149,7 +149,6 @@ class Top(object): """Combiners for obtaining extremal elements.""" # pylint: disable=no-self-argument - @staticmethod @ptransform.ptransform_fn def Of(pcoll, n, compare=None, *args, **kwargs): """Obtain a list of the compare-most N elements in a PCollection. @@ -178,7 +177,6 @@ class Top(object): return pcoll | core.CombineGlobally( TopCombineFn(n, compare, key, reverse), *args, **kwargs) - @staticmethod @ptransform.ptransform_fn def PerKey(pcoll, n, compare=None, *args, **kwargs): """Identifies the compare-most N elements associated with each key. 
@@ -212,25 +210,21 @@ class Top(object): return pcoll | core.CombinePerKey( TopCombineFn(n, compare, key, reverse), *args, **kwargs) - @staticmethod @ptransform.ptransform_fn def Largest(pcoll, n): """Obtain a list of the greatest N elements in a PCollection.""" return pcoll | Top.Of(n) - @staticmethod @ptransform.ptransform_fn def Smallest(pcoll, n): """Obtain a list of the least N elements in a PCollection.""" return pcoll | Top.Of(n, reverse=True) - @staticmethod @ptransform.ptransform_fn def LargestPerKey(pcoll, n): """Identifies the N greatest elements associated with each key.""" return pcoll | Top.PerKey(n) - @staticmethod @ptransform.ptransform_fn def SmallestPerKey(pcoll, n, reverse=True): """Identifies the N least elements associated with each key.""" @@ -375,12 +369,10 @@ class Sample(object): """Combiners for sampling n elements without replacement.""" # pylint: disable=no-self-argument - @staticmethod @ptransform.ptransform_fn def FixedSizeGlobally(pcoll, n): return pcoll | core.CombineGlobally(SampleCombineFn(n)) - @staticmethod @ptransform.ptransform_fn def FixedSizePerKey(pcoll, n): return pcoll | core.CombinePerKey(SampleCombineFn(n)) http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/transforms/combiners_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/combiners_test.py b/sdks/python/apache_beam/transforms/combiners_test.py index cd2b595..c79fec8 100644 --- a/sdks/python/apache_beam/transforms/combiners_test.py +++ b/sdks/python/apache_beam/transforms/combiners_test.py @@ -156,11 +156,14 @@ class CombineTest(unittest.TestCase): def test_combine_sample_display_data(self): def individual_test_per_key_dd(sampleFn, args, kwargs): - trs = [sampleFn(*args, **kwargs)] + trs = [beam.CombinePerKey(sampleFn(*args, **kwargs)), + beam.CombineGlobally(sampleFn(*args, **kwargs))] for transform in trs: dd = DisplayData.create_from(transform) 
expected_items = [ - DisplayDataItemMatcher('fn', transform._fn.__name__)] + DisplayDataItemMatcher('fn', sampleFn.fn.__name__), + DisplayDataItemMatcher('combine_fn', + transform.fn.__class__)] if args: expected_items.append( DisplayDataItemMatcher('args', str(args))) http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/transforms/core.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 8018219..0e497f9 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -27,7 +27,7 @@ from apache_beam import pvalue from apache_beam import typehints from apache_beam.coders import typecoders from apache_beam.internal import util -from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners.api import beam_runner_api_pb2 from apache_beam.transforms import ptransform from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.display import HasDisplayData @@ -1078,6 +1078,40 @@ class GroupByKey(PTransform): key_type, value_type = trivial_inference.key_value_types(input_type) return Iterable[KV[key_type, typehints.WindowedValue[value_type]]] + class GroupAlsoByWindow(DoFn): + # TODO(robertwb): Support combiner lifting. 
+ + def __init__(self, windowing): + super(GroupByKey.GroupAlsoByWindow, self).__init__() + self.windowing = windowing + + def infer_output_type(self, input_type): + key_type, windowed_value_iter_type = trivial_inference.key_value_types( + input_type) + value_type = windowed_value_iter_type.inner_type.inner_type + return Iterable[KV[key_type, Iterable[value_type]]] + + def start_bundle(self): + # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam.transforms.trigger import InMemoryUnmergedState + from apache_beam.transforms.trigger import create_trigger_driver + # pylint: enable=wrong-import-order, wrong-import-position + self.driver = create_trigger_driver(self.windowing, True) + self.state_type = InMemoryUnmergedState + + def process(self, element): + k, vs = element + state = self.state_type() + # TODO(robertwb): Conditionally process in smaller chunks. + for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP): + yield wvalue.with_value((k, wvalue.value)) + while state.timers: + fired = state.get_and_clear_timers() + for timer_window, (name, time_domain, fire_time) in fired: + for wvalue in self.driver.process_timer( + timer_window, name, time_domain, fire_time, state): + yield wvalue.with_value((k, wvalue.value)) + def expand(self, pcoll): # This code path is only used in the local direct runner. For Dataflow # runner execution, the GroupByKey transform is expanded on the service. 
@@ -1102,7 +1136,8 @@ class GroupByKey(PTransform): | 'GroupByKey' >> (_GroupByKeyOnly() .with_input_types(reify_output_type) .with_output_types(gbk_input_type)) - | ('GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing) + | ('GroupByWindow' >> ParDo( + self.GroupAlsoByWindow(pcoll.windowing)) .with_input_types(gbk_input_type) .with_output_types(gbk_output_type))) else: @@ -1110,7 +1145,8 @@ class GroupByKey(PTransform): return (pcoll | 'ReifyWindows' >> ParDo(self.ReifyWindows()) | 'GroupByKey' >> _GroupByKeyOnly() - | 'GroupByWindow' >> _GroupAlsoByWindow(pcoll.windowing)) + | 'GroupByWindow' >> ParDo( + self.GroupAlsoByWindow(pcoll.windowing))) @typehints.with_input_types(typehints.KV[K, V]) @@ -1126,55 +1162,6 @@ class _GroupByKeyOnly(PTransform): return pvalue.PCollection(pcoll.pipeline) -@typehints.with_input_types(typehints.KV[K, typehints.Iterable[V]]) -@typehints.with_output_types(typehints.KV[K, typehints.Iterable[V]]) -class _GroupAlsoByWindow(ParDo): - """The GroupAlsoByWindow transform.""" - def __init__(self, windowing): - super(_GroupAlsoByWindow, self).__init__( - _GroupAlsoByWindowDoFn(windowing)) - self.windowing = windowing - - def expand(self, pcoll): - self._check_pcollection(pcoll) - return pvalue.PCollection(pcoll.pipeline) - - -class _GroupAlsoByWindowDoFn(DoFn): - # TODO(robertwb): Support combiner lifting. 
- - def __init__(self, windowing): - super(_GroupAlsoByWindowDoFn, self).__init__() - self.windowing = windowing - - def infer_output_type(self, input_type): - key_type, windowed_value_iter_type = trivial_inference.key_value_types( - input_type) - value_type = windowed_value_iter_type.inner_type.inner_type - return Iterable[KV[key_type, Iterable[value_type]]] - - def start_bundle(self): - # pylint: disable=wrong-import-order, wrong-import-position - from apache_beam.transforms.trigger import InMemoryUnmergedState - from apache_beam.transforms.trigger import create_trigger_driver - # pylint: enable=wrong-import-order, wrong-import-position - self.driver = create_trigger_driver(self.windowing, True) - self.state_type = InMemoryUnmergedState - - def process(self, element): - k, vs = element - state = self.state_type() - # TODO(robertwb): Conditionally process in smaller chunks. - for wvalue in self.driver.process_elements(state, vs, MIN_TIMESTAMP): - yield wvalue.with_value((k, wvalue.value)) - while state.timers: - fired = state.get_and_clear_timers() - for timer_window, (name, time_domain, fire_time) in fired: - for wvalue in self.driver.process_timer( - timer_window, name, time_domain, fire_time, state): - yield wvalue.with_value((k, wvalue.value)) - - class Partition(PTransformWithSideInputs): """Split a PCollection into several partitions. 
@@ -1444,18 +1431,15 @@ class Create(PTransform): return Any return Union[[trivial_inference.instance_to_type(v) for v in self.value]] - def get_output_type(self): - return (self.get_type_hints().simple_output_type(self.label) or - self.infer_output_type(None)) - def expand(self, pbegin): from apache_beam.io import iobase assert isinstance(pbegin, pvalue.PBegin) self.pipeline = pbegin.pipeline - coder = typecoders.registry.get_coder(self.get_output_type()) + ouput_type = (self.get_type_hints().simple_output_type(self.label) or + self.infer_output_type(None)) + coder = typecoders.registry.get_coder(ouput_type) source = self._create_source_from_iterable(self.value, coder) - return (pbegin.pipeline - | iobase.Read(source).with_output_types(self.get_output_type())) + return pbegin.pipeline | iobase.Read(source).with_output_types(ouput_type) def get_windowing(self, unused_inputs): return Windowing(GlobalWindows()) http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/transforms/ptransform.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index cd84122..bd2a120 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -430,7 +430,7 @@ class PTransform(WithTypeHints, HasDisplayData): cls._known_urns[urn] = parameter_type, constructor def to_runner_api(self, context): - from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.runners.api import beam_runner_api_pb2 urn, typed_param = self.to_runner_api_parameter(context) return beam_runner_api_pb2.FunctionSpec( urn=urn, @@ -595,23 +595,32 @@ class PTransformWithSideInputs(PTransform): return '%s(%s)' % (self.__class__.__name__, self.fn.default_label()) -class _PTransformFnPTransform(PTransform): +class CallablePTransform(PTransform): """A class wrapper for a function-based 
transform.""" - def __init__(self, fn, *args, **kwargs): - super(_PTransformFnPTransform, self).__init__() - self._fn = fn - self._args = args - self._kwargs = kwargs + def __init__(self, fn): + # pylint: disable=super-init-not-called + # This is a helper class for a function decorator. Only when the class + # is called (and __call__ invoked) we will have all the information + # needed to initialize the super class. + self.fn = fn + self._args = () + self._kwargs = {} def display_data(self): - res = {'fn': (self._fn.__name__ - if hasattr(self._fn, '__name__') - else self._fn.__class__), + res = {'fn': (self.fn.__name__ + if hasattr(self.fn, '__name__') + else self.fn.__class__), 'args': DisplayDataItem(str(self._args)).drop_if_default('()'), 'kwargs': DisplayDataItem(str(self._kwargs)).drop_if_default('{}')} return res + def __call__(self, *args, **kwargs): + super(CallablePTransform, self).__init__() + self._args = args + self._kwargs = kwargs + return self + def expand(self, pcoll): # Since the PTransform will be implemented entirely as a function # (once called), we need to pass through any type-hinting information that @@ -620,18 +629,18 @@ class _PTransformFnPTransform(PTransform): kwargs = dict(self._kwargs) args = tuple(self._args) try: - if 'type_hints' in inspect.getargspec(self._fn).args: + if 'type_hints' in inspect.getargspec(self.fn).args: args = (self.get_type_hints(),) + args except TypeError: # Might not be a function. 
pass - return self._fn(pcoll, *args, **kwargs) + return self.fn(pcoll, *args, **kwargs) def default_label(self): if self._args: return '%s(%s)' % ( - label_from_callable(self._fn), label_from_callable(self._args[0])) - return label_from_callable(self._fn) + label_from_callable(self.fn), label_from_callable(self._args[0])) + return label_from_callable(self.fn) def ptransform_fn(fn): @@ -675,11 +684,7 @@ def ptransform_fn(fn): operator (i.e., `|`) will inject the pcoll argument in its proper place (first argument if no label was specified and second argument otherwise). """ - # TODO(robertwb): Consider removing staticmethod to allow for self parameter. - - def callable_ptransform_factory(*args, **kwargs): - return _PTransformFnPTransform(fn, *args, **kwargs) - return callable_ptransform_factory + return CallablePTransform(fn) def label_from_callable(fn): http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/transforms/trigger.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index f77fa1a..4200995 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -33,10 +33,9 @@ from apache_beam.transforms.window import GlobalWindow from apache_beam.transforms.window import TimestampCombiner from apache_beam.transforms.window import WindowedValue from apache_beam.transforms.window import WindowFn -from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners.api import beam_runner_api_pb2 from apache_beam.utils.timestamp import MAX_TIMESTAMP from apache_beam.utils.timestamp import MIN_TIMESTAMP -from apache_beam.utils.timestamp import TIME_GRANULARITY # AfterCount is experimental. No backwards compatibility guarantees. 
@@ -1067,8 +1066,6 @@ class InMemoryUnmergedState(UnmergedState): def clear_timer(self, window, name, time_domain): self.timers[window].pop((name, time_domain), None) - if not self.timers[window]: - del self.timers[window] def get_window(self, window_id): return window_id @@ -1105,34 +1102,17 @@ class InMemoryUnmergedState(UnmergedState): if not self.state[window]: self.state.pop(window, None) - def get_timers(self, clear=False, watermark=MAX_TIMESTAMP): + def get_and_clear_timers(self, watermark=MAX_TIMESTAMP): expired = [] for window, timers in list(self.timers.items()): for (name, time_domain), timestamp in list(timers.items()): if timestamp <= watermark: expired.append((window, (name, time_domain, timestamp))) - if clear: - del timers[(name, time_domain)] - if not timers and clear: + del timers[(name, time_domain)] + if not timers: del self.timers[window] return expired - def get_and_clear_timers(self, watermark=MAX_TIMESTAMP): - return self.get_timers(clear=True, watermark=watermark) - - def get_earliest_hold(self): - earliest_hold = MAX_TIMESTAMP - for unused_window, tagged_states in self.state.iteritems(): - # TODO(BEAM-2519): currently, this assumes that the watermark hold tag is - # named "watermark". This is currently only true because the only place - # watermark holds are set is in the GeneralTriggerDriver, where we use - # this name. We should fix this by allowing enumeration of the tag types - # used in adding state. 
- if 'watermark' in tagged_states and tagged_states['watermark']: - hold = min(tagged_states['watermark']) - TIME_GRANULARITY - earliest_hold = min(earliest_hold, hold) - return earliest_hold - def __repr__(self): state_str = '\n'.join('%s: %s' % (key, dict(state)) for key, state in self.state.items()) http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/transforms/window.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/window.py b/sdks/python/apache_beam/transforms/window.py index 458fb74..e87a007 100644 --- a/sdks/python/apache_beam/transforms/window.py +++ b/sdks/python/apache_beam/transforms/window.py @@ -55,8 +55,8 @@ from google.protobuf import duration_pb2 from google.protobuf import timestamp_pb2 from apache_beam.coders import coders -from apache_beam.portability.api import beam_runner_api_pb2 -from apache_beam.portability.api import standard_window_fns_pb2 +from apache_beam.runners.api import beam_runner_api_pb2 +from apache_beam.runners.api import standard_window_fns_pb2 from apache_beam.transforms import timeutil from apache_beam.utils import proto_utils from apache_beam.utils import urns
