http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/pipeline.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
new file mode 100644
index 0000000..ec87f46
--- /dev/null
+++ b/sdks/python/apache_beam/pipeline.py
@@ -0,0 +1,435 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pipeline, the top-level Dataflow object.
+
+A pipeline holds a DAG of data transforms. Conceptually the nodes of the DAG
+are transforms (PTransform objects) and the edges are values (mostly
+PCollection objects). The transforms take as inputs one or more PValues and
+output one or more PValues.
+
+The pipeline offers functionality to traverse the graph. The actual operation
+to be executed for each node visited is specified through a runner object.
+
+Typical usage:
+
+  # Create a pipeline object using a local runner for execution.
+  pipeline = Pipeline(runner=DirectPipelineRunner())
+
+  # Add to the pipeline a "Create" transform. When executed, this
+  # transform will produce a PCollection object with the specified values.
+  pcoll = pipeline | Create('label', [1, 2, 3])
+
+  # run() will execute the DAG stored in the pipeline. The execution of the
+  # nodes visited is done using the specified local runner.
+  pipeline.run()
+
+"""
+
+from __future__ import absolute_import
+
+import collections
+import logging
+import os
+import shutil
+import tempfile
+
+from google.cloud.dataflow import pvalue
+from google.cloud.dataflow import typehints
+from google.cloud.dataflow.internal import pickler
+from google.cloud.dataflow.runners import create_runner
+from google.cloud.dataflow.runners import PipelineRunner
+from google.cloud.dataflow.transforms import format_full_label
+from google.cloud.dataflow.transforms import ptransform
+from google.cloud.dataflow.typehints import TypeCheckError
+from google.cloud.dataflow.utils.options import PipelineOptions
+from google.cloud.dataflow.utils.options import SetupOptions
+from google.cloud.dataflow.utils.options import StandardOptions
+from google.cloud.dataflow.utils.options import TypeOptions
+from google.cloud.dataflow.utils.pipeline_options_validator import PipelineOptionsValidator
+
+
+class Pipeline(object):
+  """A pipeline object that manages a DAG of PValues and their PTransforms.
+
+  Conceptually the PValues are the DAG's nodes and the PTransforms computing
+  the PValues are the edges.
+
+  All the transforms applied to the pipeline must have distinct full labels.
+  If the same transform instance needs to be applied again, a clone should be
+  created with a new label (e.g., transform.clone('new label')).
+  """
+
+  def __init__(self, runner=None, options=None, argv=None):
+    """Initialize a pipeline object.
+
+    Args:
+      runner: An object of type 'PipelineRunner' that will be used to execute
+        the pipeline.
+        For registered runners, the runner name can be specified;
+        otherwise a runner object must be supplied.
+      options: A configured 'PipelineOptions' object containing arguments
+        that should be used for running the Dataflow job.
+      argv: A list of arguments (such as sys.argv) to be used for building a
+        'PipelineOptions' object. This will only be used if argument 'options'
+        is None.
+
+    Raises:
+      ValueError: if either the runner or options argument is not of the
+        expected type.
+    """
+
+    if options is not None:
+      if isinstance(options, PipelineOptions):
+        self.options = options
+      else:
+        raise ValueError(
+            'Parameter options, if specified, must be of type PipelineOptions. '
+            'Received: %r' % options)
+    elif argv is not None:
+      if isinstance(argv, list):
+        self.options = PipelineOptions(argv)
+      else:
+        raise ValueError(
+            'Parameter argv, if specified, must be a list. Received: %r' % argv)
+    else:
+      self.options = None
+
+    if runner is None and self.options is not None:
+      runner = self.options.view_as(StandardOptions).runner
+      if runner is None:
+        runner = StandardOptions.DEFAULT_RUNNER
+        logging.info(('Missing pipeline option (runner). Executing pipeline '
+                      'using the default runner: %s.'), runner)
+
+    if isinstance(runner, str):
+      runner = create_runner(runner)
+    elif not isinstance(runner, PipelineRunner):
+      raise TypeError('Runner must be a PipelineRunner object or the '
+                      'name of a registered runner.')
+
+    # Validate pipeline options.
+    if self.options is not None:
+      errors = PipelineOptionsValidator(self.options, runner).validate()
+      if errors:
+        raise ValueError(
+            'Pipeline has validation errors: \n' + '\n'.join(errors))
+
+    # Default runner to be used.
+    self.runner = runner
+    # Stack of transforms generated by nested apply() calls. The stack will
+    # contain a root node as an enclosing (parent) node for top transforms.
+    self.transforms_stack = [AppliedPTransform(None, None, '', None)]
+    # Set of transform labels (full labels) applied to the pipeline.
+    # If a transform is applied and the full label is already in the set
+    # then the transform will have to be cloned with a new label.
+    self.applied_labels = set()
+    # Store cache of views created from PCollections. For reference, see
+    # pvalue._cache_view().
+    self._view_cache = {}
+
+  def _current_transform(self):
+    """Returns the transform currently on the top of the stack."""
+    return self.transforms_stack[-1]
+
+  def _root_transform(self):
+    """Returns the root transform of the transform stack."""
+    return self.transforms_stack[0]
+
+  def run(self):
+    """Runs the pipeline. Returns whatever our runner returns after running."""
+    if not self.options or self.options.view_as(SetupOptions).save_main_session:
+      # If this option is chosen, verify we can pickle the main session early.
+      tmpdir = tempfile.mkdtemp()
+      try:
+        pickler.dump_session(os.path.join(tmpdir, 'main_session.pickle'))
+      finally:
+        shutil.rmtree(tmpdir)
+    return self.runner.run(self)
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    if not exc_type:
+      self.run()
+
+  def visit(self, visitor):
+    """Visits depth-first every node of a pipeline's DAG.
+
+    Args:
+      visitor: PipelineVisitor object whose callbacks will be called for each
+        node visited. See PipelineVisitor comments.
+
+    Raises:
+      TypeError: if node is specified and is not a PValue.
+      pipeline.PipelineError: if node is specified and does not belong to this
+        pipeline instance.
+ """ + + visited = set() + self._root_transform().visit(visitor, self, visited) + + def apply(self, transform, pvalueish=None): + """Applies a custom transform using the pvalueish specified. + + Args: + transform: the PTranform (or callable) to apply. + pvalueish: the input for the PTransform (typically a PCollection). + + Raises: + TypeError: if the transform object extracted from the argument list is + not a callable type or a descendant from PTransform. + RuntimeError: if the transform object was already applied to this pipeline + and needs to be cloned in order to apply again. + """ + if not isinstance(transform, ptransform.PTransform): + transform = _CallableWrapperPTransform(transform) + + full_label = format_full_label(self._current_transform(), transform) + if full_label in self.applied_labels: + raise RuntimeError( + 'Transform "%s" does not have a stable unique label. ' + 'This will prevent updating of pipelines. ' + 'To clone a transform with a new label use: ' + 'transform.clone("NEW LABEL").' + % full_label) + self.applied_labels.add(full_label) + + pvalueish, inputs = transform._extract_input_pvalues(pvalueish) + try: + inputs = tuple(inputs) + for leaf_input in inputs: + if not isinstance(leaf_input, pvalue.PValue): + raise TypeError + except TypeError: + raise NotImplementedError( + 'Unable to extract PValue inputs from %s; either %s does not accept ' + 'inputs of this format, or it does not properly override ' + '_extract_input_values' % (pvalueish, transform)) + + current = AppliedPTransform( + self._current_transform(), transform, full_label, inputs) + self._current_transform().add_part(current) + self.transforms_stack.append(current) + + if self.options is not None: + type_options = self.options.view_as(TypeOptions) + else: + type_options = None + + if type_options is not None and type_options.pipeline_type_check: + transform.type_check_inputs(pvalueish) + + pvalueish_result = self.runner.apply(transform, pvalueish) + + if type_options is not None and type_options.pipeline_type_check: + transform.type_check_outputs(pvalueish_result) + + for result in ptransform.GetPValues().visit(pvalueish_result): + assert isinstance(result, (pvalue.PValue, pvalue.DoOutputsTuple)) + + # Make sure we set the producer only for a leaf node in the transform DAG. + # This way we preserve the last transform of a composite transform as + # being the real producer of the result. + if result.producer is None: + result.producer = current + # TODO(robertwb): Multi-input, multi-output inference. + # TODO(robertwb): Ideally we'd do intersection here. 
+ if (type_options is not None and type_options.pipeline_type_check and + isinstance(result, (pvalue.PCollection, pvalue.PCollectionView)) + and not result.element_type): + input_element_type = ( + inputs[0].element_type + if len(inputs) == 1 + else typehints.Any) + type_hints = transform.get_type_hints() + declared_output_type = type_hints.simple_output_type(transform.label) + if declared_output_type: + input_types = type_hints.input_types + if input_types and input_types[0]: + declared_input_type = input_types[0][0] + result.element_type = typehints.bind_type_variables( + declared_output_type, + typehints.match_type_variables(declared_input_type, + input_element_type)) + else: + result.element_type = declared_output_type + else: + result.element_type = transform.infer_output_type(input_element_type) + + assert isinstance(result.producer.inputs, tuple) + current.add_output(result) + + if (type_options is not None and + type_options.type_check_strictness == 'ALL_REQUIRED' and + transform.get_type_hints().output_types is None): + ptransform_name = '%s(%s)' % (transform.__class__.__name__, full_label) + raise TypeCheckError('Pipeline type checking is enabled, however no ' + 'output type-hint was found for the ' + 'PTransform %s' % ptransform_name) + + current.update_input_refcounts() + self.transforms_stack.pop() + return pvalueish_result + + +class _CallableWrapperPTransform(ptransform.PTransform): + + def __init__(self, callee): + assert callable(callee) + super(_CallableWrapperPTransform, self).__init__( + label=getattr(callee, '__name__', 'Callable')) + self._callee = callee + + def apply(self, *args, **kwargs): + return self._callee(*args, **kwargs) + + +class PipelineVisitor(object): + """Visitor pattern class used to traverse a DAG of transforms. + + This is an internal class used for bookkeeping by a Pipeline. + """ + + def visit_value(self, value, producer_node): + """Callback for visiting a PValue in the pipeline DAG. + + Args: + value: PValue visited (typically a PCollection instance). + producer_node: AppliedPTransform object whose transform produced the + pvalue. + """ + pass + + def visit_transform(self, transform_node): + """Callback for visiting a transform node in the pipeline DAG.""" + pass + + def enter_composite_transform(self, transform_node): + """Callback for entering traversal of a composite transform node.""" + pass + + def leave_composite_transform(self, transform_node): + """Callback for leaving traversal of a composite transform node.""" + pass + + +class AppliedPTransform(object): + """A transform node representing an instance of applying a PTransform. + + This is an internal class used for bookkeeping by a Pipeline. + """ + + def __init__(self, parent, transform, full_label, inputs): + self.parent = parent + self.transform = transform + # Note that we want the PipelineVisitor classes to use the full_label, + # inputs, side_inputs, and outputs fields from this instance instead of the + # ones of the PTransform instance associated with it. Doing this permits + # reusing PTransform instances in different contexts (apply() calls) without + # any interference. This is particularly useful for composite transforms. + self.full_label = full_label + self.inputs = inputs or () + self.side_inputs = () if transform is None else tuple(transform.side_inputs) + self.outputs = {} + self.parts = [] + + # Per tag refcount dictionary for PValues for which this node is a + # root producer. 
+ self.refcounts = collections.defaultdict(int) + + def update_input_refcounts(self): + """Increment refcounts for all transforms providing inputs.""" + + def real_producer(pv): + real = pv.producer + while real.parts: + real = real.parts[-1] + return real + + if not self.is_composite(): + for main_input in self.inputs: + if not isinstance(main_input, pvalue.PBegin): + real_producer(main_input).refcounts[main_input.tag] += 1 + for side_input in self.side_inputs: + real_producer(side_input).refcounts[side_input.tag] += 1 + + def add_output(self, output, tag=None): + assert (isinstance(output, pvalue.PValue) or + isinstance(output, pvalue.DoOutputsTuple)) + if tag is None: + tag = len(self.outputs) + assert tag not in self.outputs + self.outputs[tag] = output + + def add_part(self, part): + assert isinstance(part, AppliedPTransform) + self.parts.append(part) + + def is_composite(self): + """Returns whether this is a composite transform. + + A composite transform has parts (inner transforms) or isn't the + producer for any of its outputs. (An example of a transform that + is not a producer is one that returns its inputs instead.) + """ + return bool(self.parts) or all( + pval.producer is not self for pval in self.outputs.values()) + + def visit(self, visitor, pipeline, visited): + """Visits all nodes reachable from the current node.""" + + for pval in self.inputs: + if pval not in visited and not isinstance(pval, pvalue.PBegin): + assert pval.producer is not None + pval.producer.visit(visitor, pipeline, visited) + # The value should be visited now since we visit outputs too. + assert pval in visited, pval + + # Visit side inputs. + for pval in self.side_inputs: + if isinstance(pval, pvalue.PCollectionView) and pval not in visited: + assert pval.producer is not None + pval.producer.visit(visitor, pipeline, visited) + # The value should be visited now since we visit outputs too. + assert pval in visited + # TODO(silviuc): Is there a way to signal that we are visiting a side + # value? The issue is that the same PValue can be reachable through + # multiple paths and therefore it is not guaranteed that the value + # will be visited as a side value. + + # Visit a composite or primitive transform. + if self.is_composite(): + visitor.enter_composite_transform(self) + for part in self.parts: + part.visit(visitor, pipeline, visited) + visitor.leave_composite_transform(self) + else: + visitor.visit_transform(self) + + # Visit the outputs (one or more). It is essential to mark as visited the + # tagged PCollections of the DoOutputsTuple object. A tagged PCollection is + # connected directly with its producer (a multi-output ParDo), but the + # output of such a transform is the containing DoOutputsTuple, not the + # PCollection inside it. Without the code below a tagged PCollection will + # not be marked as visited while visiting its producer. + for pval in self.outputs.values(): + if isinstance(pval, pvalue.DoOutputsTuple): + pvals = (v for v in pval) + else: + pvals = (pval,) + for v in pvals: + if v not in visited: + visited.add(v) + visitor.visit_value(v, self)
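As a quick end-to-end illustration of the visitor API defined in this file, the following is a minimal sketch (the LabelCollector name is invented here, not code from this commit) that collects the full label of every primitive transform in a small pipeline, using the DirectPipelineRunner and the Create and FlatMap transforms exercised by the tests below:

  from google.cloud.dataflow.pipeline import Pipeline
  from google.cloud.dataflow.pipeline import PipelineVisitor
  from google.cloud.dataflow.transforms import Create
  from google.cloud.dataflow.transforms import FlatMap

  class LabelCollector(PipelineVisitor):
    """Records the full label of every primitive transform node visited."""

    def __init__(self):
      self.labels = []

    def visit_transform(self, transform_node):
      self.labels.append(transform_node.full_label)

  pipeline = Pipeline('DirectPipelineRunner')
  pcoll = pipeline | Create('create', [1, 2, 3])
  pcoll | FlatMap('inc', lambda x: [x + 1])  # pylint: disable=expression-not-assigned

  visitor = LabelCollector()
  pipeline.visit(visitor)  # Depth-first traversal from the root transform.
  # visitor.labels now holds full labels such as 'create' and 'inc'.

Because traversal starts at the root AppliedPTransform and composite nodes only trigger enter/leave callbacks, visit_transform fires once per primitive transform, which is what makes this kind of bookkeeping visitor cheap to write.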
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/pipeline_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py new file mode 100644 index 0000000..ce3bd6d --- /dev/null +++ b/sdks/python/apache_beam/pipeline_test.py @@ -0,0 +1,345 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the Pipeline class.""" + +import gc +import logging +import unittest + +from google.cloud.dataflow.io.iobase import NativeSource +from google.cloud.dataflow.pipeline import Pipeline +from google.cloud.dataflow.pipeline import PipelineOptions +from google.cloud.dataflow.pipeline import PipelineVisitor +from google.cloud.dataflow.pvalue import AsIter +from google.cloud.dataflow.pvalue import SideOutputValue +from google.cloud.dataflow.transforms import CombinePerKey +from google.cloud.dataflow.transforms import Create +from google.cloud.dataflow.transforms import FlatMap +from google.cloud.dataflow.transforms import Flatten +from google.cloud.dataflow.transforms import Map +from google.cloud.dataflow.transforms import PTransform +from google.cloud.dataflow.transforms import Read +from google.cloud.dataflow.transforms.util import assert_that, equal_to + + +class FakeSource(NativeSource): + """Fake source returning a fixed list of values.""" + + class _Reader(object): + + def __init__(self, vals): + self._vals = vals + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + pass + + def __iter__(self): + for v in self._vals: + yield v + + def __init__(self, vals): + self._vals = vals + + def reader(self): + return FakeSource._Reader(self._vals) + + +class PipelineTest(unittest.TestCase): + + def setUp(self): + self.runner_name = 'DirectPipelineRunner' + + @staticmethod + def custom_callable(pcoll): + return pcoll | FlatMap('+1', lambda x: [x + 1]) + + # Some of these tests designate a runner by name, others supply a runner. + # This variation is just to verify that both means of runner specification + # work and is not related to other aspects of the tests. + + class CustomTransform(PTransform): + + def apply(self, pcoll): + return pcoll | FlatMap('+1', lambda x: [x + 1]) + + class Visitor(PipelineVisitor): + + def __init__(self, visited): + self.visited = visited + self.enter_composite = [] + self.leave_composite = [] + + def visit_value(self, value, _): + self.visited.append(value) + + def enter_composite_transform(self, transform_node): + self.enter_composite.append(transform_node) + + def leave_composite_transform(self, transform_node): + self.leave_composite.append(transform_node) + + def test_create(self): + pipeline = Pipeline(self.runner_name) + pcoll = pipeline | Create('label1', [1, 2, 3]) + assert_that(pcoll, equal_to([1, 2, 3])) + + # Test if initial value is an iterator object. 
+ pcoll2 = pipeline | Create('label2', iter((4, 5, 6))) + pcoll3 = pcoll2 | FlatMap('do', lambda x: [x + 10]) + assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3') + pipeline.run() + + def test_create_singleton_pcollection(self): + pipeline = Pipeline(self.runner_name) + pcoll = pipeline | Create('label', [[1, 2, 3]]) + assert_that(pcoll, equal_to([[1, 2, 3]])) + pipeline.run() + + def test_read(self): + pipeline = Pipeline(self.runner_name) + pcoll = pipeline | Read('read', FakeSource([1, 2, 3])) + assert_that(pcoll, equal_to([1, 2, 3])) + pipeline.run() + + def test_visit_entire_graph(self): + pipeline = Pipeline(self.runner_name) + pcoll1 = pipeline | Create('pcoll', [1, 2, 3]) + pcoll2 = pcoll1 | FlatMap('do1', lambda x: [x + 1]) + pcoll3 = pcoll2 | FlatMap('do2', lambda x: [x + 1]) + pcoll4 = pcoll2 | FlatMap('do3', lambda x: [x + 1]) + transform = PipelineTest.CustomTransform() + pcoll5 = pcoll4 | transform + + visitor = PipelineTest.Visitor(visited=[]) + pipeline.visit(visitor) + self.assertEqual(set([pcoll1, pcoll2, pcoll3, pcoll4, pcoll5]), + set(visitor.visited)) + self.assertEqual(set(visitor.enter_composite), + set(visitor.leave_composite)) + self.assertEqual(2, len(visitor.enter_composite)) + self.assertEqual(visitor.enter_composite[1].transform, transform) + self.assertEqual(visitor.leave_composite[0].transform, transform) + + def test_apply_custom_transform(self): + pipeline = Pipeline(self.runner_name) + pcoll = pipeline | Create('pcoll', [1, 2, 3]) + result = pcoll | PipelineTest.CustomTransform() + assert_that(result, equal_to([2, 3, 4])) + pipeline.run() + + def test_reuse_custom_transform_instance(self): + pipeline = Pipeline(self.runner_name) + pcoll1 = pipeline | Create('pcoll1', [1, 2, 3]) + pcoll2 = pipeline | Create('pcoll2', [4, 5, 6]) + transform = PipelineTest.CustomTransform() + pcoll1 | transform + with self.assertRaises(RuntimeError) as cm: + pipeline.apply(transform, pcoll2) + self.assertEqual( + cm.exception.message, + 'Transform "CustomTransform" does not have a stable unique label. ' + 'This will prevent updating of pipelines. ' + 'To clone a transform with a new label use: ' + 'transform.clone("NEW LABEL").') + + def test_reuse_cloned_custom_transform_instance(self): + pipeline = Pipeline(self.runner_name) + pcoll1 = pipeline | Create('pcoll1', [1, 2, 3]) + pcoll2 = pipeline | Create('pcoll2', [4, 5, 6]) + transform = PipelineTest.CustomTransform() + result1 = pcoll1 | transform + result2 = pcoll2 | transform.clone('new label') + assert_that(result1, equal_to([2, 3, 4]), label='r1') + assert_that(result2, equal_to([5, 6, 7]), label='r2') + pipeline.run() + + def test_apply_custom_callable(self): + pipeline = Pipeline(self.runner_name) + pcoll = pipeline | Create('pcoll', [1, 2, 3]) + result = pipeline.apply(PipelineTest.custom_callable, pcoll) + assert_that(result, equal_to([2, 3, 4])) + pipeline.run() + + def test_transform_no_super_init(self): + class AddSuffix(PTransform): + + def __init__(self, suffix): + # No call to super(...).__init__ + self.suffix = suffix + + def apply(self, pcoll): + return pcoll | Map(lambda x: x + self.suffix) + + self.assertEqual( + ['a-x', 'b-x', 'c-x'], + sorted(['a', 'b', 'c'] | AddSuffix('-x'))) + + def test_cached_pvalues_are_refcounted(self): + """Test that cached PValues are refcounted and deleted. 
+
+    The intermediary PValues computed by the workflow below contain one
+    million elements, so if the refcounting does not work, the number of
+    objects tracked by the garbage collector will have increased by a few
+    million by the time we execute the final Map that checks the tracked
+    objects. Anything much larger than what we started with will fail the
+    test.
+    """
+    def check_memory(value, count_threshold):
+      gc.collect()
+      objects_count = len(gc.get_objects())
+      if objects_count > count_threshold:
+        raise RuntimeError(
+            'PValues are not refcounted: %s, %s' % (
+                objects_count, count_threshold))
+      return value
+
+    def create_dupes(o, _):
+      yield o
+      yield SideOutputValue('side', o)
+
+    pipeline = Pipeline('DirectPipelineRunner')
+
+    gc.collect()
+    count_threshold = len(gc.get_objects()) + 10000
+    biglist = pipeline | Create('oom:create', ['x'] * 1000000)
+    dupes = (
+        biglist
+        | Map('oom:addone', lambda x: (x, 1))
+        | FlatMap('oom:dupes', create_dupes,
+                  AsIter(biglist)).with_outputs('side', main='main'))
+    result = (
+        (dupes.side, dupes.main, dupes.side)
+        | Flatten('oom:flatten')
+        | CombinePerKey('oom:combine', sum)
+        | Map('oom:check', check_memory, count_threshold))
+
+    assert_that(result, equal_to([('x', 3000000)]))
+    pipeline.run()
+    self.assertEqual(
+        pipeline.runner.debug_counters['element_counts'],
+        {
+            'oom:flatten': 3000000,
+            ('oom:combine/GroupByKey/reify_windows', None): 3000000,
+            ('oom:dupes/oom:dupes', 'side'): 1000000,
+            ('oom:dupes/oom:dupes', None): 1000000,
+            'oom:create': 1000000,
+            ('oom:addone', None): 1000000,
+            'oom:combine/GroupByKey/group_by_key': 1,
+            ('oom:check', None): 1,
+            'assert_that/singleton': 1,
+            ('assert_that/Map(match)', None): 1,
+            ('oom:combine/GroupByKey/group_by_window', None): 1,
+            ('oom:combine/Combine/ParDo(CombineValuesDoFn)', None): 1})
+
+  def test_pipeline_as_context(self):
+    def raise_exception(exn):
+      raise exn
+    with self.assertRaises(ValueError):
+      with Pipeline(self.runner_name) as p:
+        # pylint: disable=expression-not-assigned
+        p | Create([ValueError]) | Map(raise_exception)
+
+  def test_eager_pipeline(self):
+    p = Pipeline('EagerPipelineRunner')
+    self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x * x))
+
+
+class DiskCachedRunnerPipelineTest(PipelineTest):
+
+  def setUp(self):
+    self.runner_name = 'DiskCachedPipelineRunner'
+
+  def test_cached_pvalues_are_refcounted(self):
+    # Takes a long time with disk spilling.
+ pass + + def test_eager_pipeline(self): + # Tests eager runner only + pass + + +class Bacon(PipelineOptions): + + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument('--slices', type=int) + + +class Eggs(PipelineOptions): + + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument('--style', default='scrambled') + + +class Breakfast(Bacon, Eggs): + pass + + +class PipelineOptionsTest(unittest.TestCase): + + def test_flag_parsing(self): + options = Breakfast(['--slices=3', '--style=sunny side up', '--ignored']) + self.assertEquals(3, options.slices) + self.assertEquals('sunny side up', options.style) + + def test_keyword_parsing(self): + options = Breakfast( + ['--slices=3', '--style=sunny side up', '--ignored'], + slices=10) + self.assertEquals(10, options.slices) + self.assertEquals('sunny side up', options.style) + + def test_attribute_setting(self): + options = Breakfast(slices=10) + self.assertEquals(10, options.slices) + options.slices = 20 + self.assertEquals(20, options.slices) + + def test_view_as(self): + generic_options = PipelineOptions(['--slices=3']) + self.assertEquals(3, generic_options.view_as(Bacon).slices) + self.assertEquals(3, generic_options.view_as(Breakfast).slices) + + generic_options.view_as(Breakfast).slices = 10 + self.assertEquals(10, generic_options.view_as(Bacon).slices) + + with self.assertRaises(AttributeError): + generic_options.slices # pylint: disable=pointless-statement + + with self.assertRaises(AttributeError): + generic_options.view_as(Eggs).slices # pylint: disable=expression-not-assigned + + def test_defaults(self): + options = Breakfast(['--slices=3']) + self.assertEquals(3, options.slices) + self.assertEquals('scrambled', options.style) + + def test_dir(self): + options = Breakfast() + self.assertEquals( + ['from_dictionary', 'get_all_options', 'slices', 'style', 'view_as'], + [attr for attr in dir(options) if not attr.startswith('_')]) + self.assertEquals( + ['from_dictionary', 'get_all_options', 'style', 'view_as'], + [attr for attr in dir(options.view_as(Eggs)) + if not attr.startswith('_')]) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.DEBUG) + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/pvalue.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py new file mode 100644 index 0000000..5e40706 --- /dev/null +++ b/sdks/python/apache_beam/pvalue.py @@ -0,0 +1,459 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""PValue, PCollection: one node of a dataflow graph. + +A node of a dataflow processing graph is a PValue. Currently, there is only +one type: PCollection (a potentially very large set of arbitrary values). 
+Once created, a PValue belongs to a pipeline and has an associated +transform (of type PTransform), which describes how the value will be +produced when the pipeline gets executed. +""" + +from __future__ import absolute_import + +import collections + + +class PValue(object): + """Base class for PCollection. + + Dataflow users should not construct PValue objects directly in their + pipelines. + + A PValue has the following main characteristics: + (1) Belongs to a pipeline. Added during object initialization. + (2) Has a transform that can compute the value if executed. + (3) Has a value which is meaningful if the transform was executed. + """ + + def __init__(self, pipeline, tag=None, element_type=None): + """Initializes a PValue with all arguments hidden behind keyword arguments. + + Args: + pipeline: Pipeline object for this PValue. + tag: Tag of this PValue. + element_type: The type of this PValue. + """ + self.pipeline = pipeline + self.tag = tag + self.element_type = element_type + # The AppliedPTransform instance for the application of the PTransform + # generating this PValue. The field gets initialized when a transform + # gets applied. + self.producer = None + + def __str__(self): + return '<%s>' % self._str_internal() + + def __repr__(self): + return '<%s at %s>' % (self._str_internal(), hex(id(self))) + + def _str_internal(self): + return '%s transform=%s' % ( + self.__class__.__name__, + self.producer.transform if self.producer else 'n/a') + + def apply(self, *args, **kwargs): + """Applies a transform or callable to a PValue. + + Args: + *args: positional arguments. + **kwargs: keyword arguments. + + The method will insert the pvalue as the next argument following an + optional first label and a transform/callable object. It will call the + pipeline.apply() method with this modified argument list. + """ + if isinstance(args[0], str): + # TODO(robertwb): Make sure labels are properly passed during + # ptransform construction and drop this argument. + args = args[1:] + arglist = list(args) + arglist.insert(1, self) + return self.pipeline.apply(*arglist, **kwargs) + + def __or__(self, ptransform): + return self.pipeline.apply(ptransform, self) + + +class PCollection(PValue): + """A multiple values (potentially huge) container. + + Dataflow users should not construct PCollection objects directly in their + pipelines. + """ + + def __init__(self, pipeline, **kwargs): + """Initializes a PCollection. Do not call directly.""" + super(PCollection, self).__init__(pipeline, **kwargs) + + @property + def windowing(self): + if not hasattr(self, '_windowing'): + self._windowing = self.producer.transform.get_windowing( + self.producer.inputs) + return self._windowing + + def __reduce_ex__(self, unused_version): + # Pickling a PCollection is almost always the wrong thing to do, but we + # can't prohibit it as it often gets implicitly picked up (e.g. as part + # of a closure). + return _InvalidUnpickledPCollection, () + + +class _InvalidUnpickledPCollection(object): + pass + + +class PBegin(PValue): + """A pipeline begin marker used as input to create/read transforms. + + The class is used internally to represent inputs to Create and Read + transforms. This allows us to have transforms that uniformly take PValue(s) + as inputs. + """ + pass + + +class PDone(PValue): + """PDone is the output of a transform that has a trivial result such as Write. 
+ """ + pass + + +class DoOutputsTuple(object): + """An object grouping the multiple outputs of a ParDo or FlatMap transform.""" + + def __init__(self, pipeline, transform, tags, main_tag): + self._pipeline = pipeline + self._tags = tags + self._main_tag = main_tag + self._transform = transform + # The ApplyPTransform instance for the application of the multi FlatMap + # generating this value. The field gets initialized when a transform + # gets applied. + self.producer = None + # Dictionary of PCollections already associated with tags. + self._pcolls = {} + + def __str__(self): + return '<%s>' % self._str_internal() + + def __repr__(self): + return '<%s at %s>' % (self._str_internal(), hex(id(self))) + + def _str_internal(self): + return '%s main_tag=%s tags=%s transform=%s' % ( + self.__class__.__name__, self._main_tag, self._tags, self._transform) + + def __iter__(self): + """Iterates over tags returning for each call a (tag, pvalue) pair.""" + if self._main_tag is not None: + yield self[self._main_tag] + for tag in self._tags: + yield self[tag] + + def __getattr__(self, tag): + # Special methods which may be accessed before the object is + # fully constructed (e.g. in unpickling). + if tag[:2] == tag[-2:] == '__': + return object.__getattr__(self, tag) + return self[tag] + + def __getitem__(self, tag): + # Accept int tags so that we can look at Partition tags with the + # same ints that we used in the partition function. + # TODO(gildea): Consider requiring string-based tags everywhere. + # This will require a partition function that does not return ints. + if isinstance(tag, int): + tag = str(tag) + if tag == self._main_tag: + tag = None + elif self._tags and tag not in self._tags: + raise ValueError( + 'Tag %s is neither the main tag %s nor any of the side tags %s' % ( + tag, self._main_tag, self._tags)) + # Check if we accessed this tag before. + if tag in self._pcolls: + return self._pcolls[tag] + if tag is not None: + self._transform.side_output_tags.add(tag) + pcoll = PCollection(self._pipeline, tag=tag) + # Transfer the producer from the DoOutputsTuple to the resulting + # PCollection. + pcoll.producer = self.producer + self.producer.add_output(pcoll, tag) + self._pcolls[tag] = pcoll + return pcoll + + +class SideOutputValue(object): + """An object representing a tagged value. + + ParDo, Map, and FlatMap transforms can emit values on multiple outputs which + are distinguished by string tags. The DoFn will return plain values + if it wants to emit on the main output and SideOutputValue objects + if it wants to emit a value on a specific tagged output. + """ + + def __init__(self, tag, value): + if not isinstance(tag, basestring): + raise TypeError( + 'Attempting to create a SideOutputValue with non-string tag %s' % tag) + self.tag = tag + self.value = value + + +class PCollectionView(PValue): + """An immutable view of a PCollection that can be used as a side input.""" + + def __init__(self, pipeline): + """Initializes a PCollectionView. Do not call directly.""" + super(PCollectionView, self).__init__(pipeline) + + @property + def windowing(self): + if not hasattr(self, '_windowing'): + self._windowing = self.producer.transform.get_windowing( + self.producer.inputs) + return self._windowing + + def _view_options(self): + """Internal options corresponding to specific view. + + Intended for internal use by runner implementations. + + Returns: + Tuple of options for the given view. 
+ """ + return () + + +class SingletonPCollectionView(PCollectionView): + """A PCollectionView that contains a single object.""" + + def __init__(self, pipeline, has_default, default_value): + super(SingletonPCollectionView, self).__init__(pipeline) + self.has_default = has_default + self.default_value = default_value + + def _view_options(self): + return (self.has_default, self.default_value) + + +class IterablePCollectionView(PCollectionView): + """A PCollectionView that can be treated as an iterable.""" + pass + + +class ListPCollectionView(PCollectionView): + """A PCollectionView that can be treated as a list.""" + pass + + +class DictPCollectionView(PCollectionView): + """A PCollectionView that can be treated as a dict.""" + pass + + +def _get_cached_view(pipeline, key): + return pipeline._view_cache.get(key, None) # pylint: disable=protected-access + + +def _cache_view(pipeline, key, view): + pipeline._view_cache[key] = view # pylint: disable=protected-access + + +def can_take_label_as_first_argument(callee): + """Decorator to allow the "label" kwarg to be passed as the first argument. + + For example, since AsSingleton is annotated with this decorator, this allows + the call "AsSingleton(pcoll, label='label1')" to be written more succinctly + as "AsSingleton('label1', pcoll)". + + Args: + callee: The callable to be called with an optional label argument. + + Returns: + Callable that allows (but does not require) a string label as its first + argument. + """ + def _inner(maybe_label, *args, **kwargs): + if isinstance(maybe_label, basestring): + return callee(*args, label=maybe_label, **kwargs) + return callee(*((maybe_label,) + args), **kwargs) + return _inner + + +def _format_view_label(pcoll): + # The monitoring UI doesn't like '/' character in transform labels. + if not pcoll.producer: + return str(pcoll.tag) + return '%s.%s' % (pcoll.producer.full_label.replace('/', '|'), + pcoll.tag) + + +_SINGLETON_NO_DEFAULT = object() + + +@can_take_label_as_first_argument +def AsSingleton(pcoll, default_value=_SINGLETON_NO_DEFAULT, label=None): # pylint: disable=invalid-name + """Create a SingletonPCollectionView from the contents of input PCollection. + + The input PCollection should contain at most one element (per window) and the + resulting PCollectionView can then be used as a side input to PTransforms. If + the PCollectionView is empty (for a given window), the side input value will + be the default_value, if specified; otherwise, it will be an EmptySideInput + object. + + Args: + pcoll: Input pcollection. + default_value: Default value for the singleton view. + label: Label to be specified if several AsSingleton's with different + defaults for the same PCollection. + + Returns: + A singleton PCollectionView containing the element as above. + """ + label = label or _format_view_label(pcoll) + has_default = default_value is not _SINGLETON_NO_DEFAULT + if not has_default: + default_value = None + + # Don't recreate the view if it was already created. + hashable_default_value = ('val', default_value) + if not isinstance(default_value, collections.Hashable): + # Massage default value to treat as hash key. 
+ hashable_default_value = ('id', id(default_value)) + cache_key = (pcoll, AsSingleton, has_default, hashable_default_value) + cached_view = _get_cached_view(pcoll.pipeline, cache_key) + if cached_view: + return cached_view + + # Local import is required due to dependency loop; even though the + # implementation of this function requires concepts defined in modules that + # depend on pvalue, it lives in this module to reduce user workload. + from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top + view = (pcoll | sideinputs.ViewAsSingleton(has_default, default_value, + label=label)) + _cache_view(pcoll.pipeline, cache_key, view) + return view + + +@can_take_label_as_first_argument +def AsIter(pcoll, label=None): # pylint: disable=invalid-name + """Create an IterablePCollectionView from the elements of input PCollection. + + The contents of the given PCollection will be available as an iterable in + PTransforms that use the returned PCollectionView as a side input. + + Args: + pcoll: Input pcollection. + label: Label to be specified if several AsIter's for the same PCollection. + + Returns: + An iterable PCollectionView containing the elements as above. + """ + label = label or _format_view_label(pcoll) + + # Don't recreate the view if it was already created. + cache_key = (pcoll, AsIter) + cached_view = _get_cached_view(pcoll.pipeline, cache_key) + if cached_view: + return cached_view + + # Local import is required due to dependency loop; even though the + # implementation of this function requires concepts defined in modules that + # depend on pvalue, it lives in this module to reduce user workload. + from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top + view = (pcoll | sideinputs.ViewAsIterable(label=label)) + _cache_view(pcoll.pipeline, cache_key, view) + return view + + +@can_take_label_as_first_argument +def AsList(pcoll, label=None): # pylint: disable=invalid-name + """Create a ListPCollectionView from the elements of input PCollection. + + The contents of the given PCollection will be available as a list-like object + in PTransforms that use the returned PCollectionView as a side input. + + Args: + pcoll: Input pcollection. + label: Label to be specified if several AsList's for the same PCollection. + + Returns: + A list PCollectionView containing the elements as above. + """ + label = label or _format_view_label(pcoll) + + # Don't recreate the view if it was already created. + cache_key = (pcoll, AsList) + cached_view = _get_cached_view(pcoll.pipeline, cache_key) + if cached_view: + return cached_view + + # Local import is required due to dependency loop; even though the + # implementation of this function requires concepts defined in modules that + # depend on pvalue, it lives in this module to reduce user workload. + from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top + view = (pcoll | sideinputs.ViewAsList(label=label)) + _cache_view(pcoll.pipeline, cache_key, view) + return view + + +@can_take_label_as_first_argument +def AsDict(pcoll, label=None): # pylint: disable=invalid-name + """Create a DictPCollectionView from the elements of input PCollection. + + The contents of the given PCollection whose elements are 2-tuples of key and + value will be available as a dict-like object in PTransforms that use the + returned PCollectionView as a side input. + + Args: + pcoll: Input pcollection containing 2-tuples of key and value. 
+ label: Label to be specified if several AsDict's for the same PCollection. + + Returns: + A dict PCollectionView containing the dict as above. + """ + label = label or _format_view_label(pcoll) + + # Don't recreate the view if it was already created. + cache_key = (pcoll, AsDict) + cached_view = _get_cached_view(pcoll.pipeline, cache_key) + if cached_view: + return cached_view + + # Local import is required due to dependency loop; even though the + # implementation of this function requires concepts defined in modules that + # depend on pvalue, it lives in this module to reduce user workload. + from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top + view = (pcoll | sideinputs.ViewAsDict(label=label)) + _cache_view(pcoll.pipeline, cache_key, view) + return view + + +class EmptySideInput(object): + """Value indicating when a singleton side input was empty. + + If a PCollection was furnished as a singleton side input to a PTransform, and + that PCollection was empty, then this value is supplied to the DoFn in the + place where a value from a non-empty PCollection would have gone. This alerts + the DoFn that the side input PCollection was empty. Users may want to check + whether side input values are EmptySideInput, but they will very likely never + want to create new instances of this class themselves. + """ + pass http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/pvalue_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pvalue_test.py b/sdks/python/apache_beam/pvalue_test.py new file mode 100644 index 0000000..d3c1c44 --- /dev/null +++ b/sdks/python/apache_beam/pvalue_test.py @@ -0,0 +1,63 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for the PValue and PCollection classes.""" + +import unittest + +from google.cloud.dataflow.pipeline import Pipeline +from google.cloud.dataflow.pvalue import AsDict +from google.cloud.dataflow.pvalue import AsIter +from google.cloud.dataflow.pvalue import AsList +from google.cloud.dataflow.pvalue import AsSingleton +from google.cloud.dataflow.pvalue import PValue +from google.cloud.dataflow.transforms import Create + + +class FakePipeline(Pipeline): + """Fake pipeline object used to check if apply() receives correct args.""" + + def apply(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + +class PValueTest(unittest.TestCase): + + def test_pvalue_expected_arguments(self): + pipeline = Pipeline('DirectPipelineRunner') + value = PValue(pipeline) + self.assertEqual(pipeline, value.pipeline) + + def test_pcollectionview_not_recreated(self): + pipeline = Pipeline('DirectPipelineRunner') + value = pipeline | Create('create1', [1, 2, 3]) + value2 = pipeline | Create('create2', [(1, 1), (2, 2), (3, 3)]) + self.assertEqual(AsSingleton(value), AsSingleton(value)) + self.assertEqual(AsSingleton('new', value, default_value=1), + AsSingleton('new', value, default_value=1)) + self.assertNotEqual(AsSingleton(value), + AsSingleton('new', value, default_value=1)) + self.assertEqual(AsIter(value), AsIter(value)) + self.assertEqual(AsList(value), AsList(value)) + self.assertEqual(AsDict(value2), AsDict(value2)) + + self.assertNotEqual(AsSingleton(value), AsSingleton(value2)) + self.assertNotEqual(AsIter(value), AsIter(value2)) + self.assertNotEqual(AsList(value), AsList(value2)) + self.assertNotEqual(AsDict(value), AsDict(value2)) + + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/python_sdk_releases.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/python_sdk_releases.py b/sdks/python/apache_beam/python_sdk_releases.py new file mode 100644 index 0000000..52e07aa --- /dev/null +++ b/sdks/python/apache_beam/python_sdk_releases.py @@ -0,0 +1,53 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Descriptions of the versions of the SDK. + +This manages the features and tests supported by different versions of the +Dataflow SDK for Python. + +To add feature 'foo' to a particular release, add a 'properties' value with +'feature_foo': True. To remove feature 'foo' from a particular release, add a +'properties' value with 'feature_foo': False. Features are cumulative and can +be added and removed multiple times. + +By default, all tests are enabled. To remove test 'bar' from a particular +release, add a 'properties' value with 'test_bar': False. To add it back to a +subsequent release, add a 'properties' value with 'test_bar': True. Tests are +cumulative and can be removed and added multiple times. + +See go/dataflow-testing for more information. 
+""" + +OLDEST_SUPPORTED_PYTHON_SDK = 'python-0.1.4' + +RELEASES = [ + {'name': 'python-0.2.7',}, + {'name': 'python-0.2.6',}, + {'name': 'python-0.2.5',}, + {'name': 'python-0.2.4',}, + {'name': 'python-0.2.3',}, + {'name': 'python-0.2.2',}, + {'name': 'python-0.2.1',}, + {'name': 'python-0.2.0',}, + {'name': 'python-0.1.5',}, + {'name': 'python-0.1.4',}, + {'name': 'python-0.1.3',}, + {'name': 'python-0.1.2',}, + {'name': 'python-0.1.1', + 'properties': { + 'feature_python': True, + } + }, +] http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/runners/__init__.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/__init__.py b/sdks/python/apache_beam/runners/__init__.py new file mode 100644 index 0000000..06d1af4 --- /dev/null +++ b/sdks/python/apache_beam/runners/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runner objects execute a Pipeline. + +This package defines runners, which are used to execute a pipeline. +""" + +from google.cloud.dataflow.runners.dataflow_runner import DataflowPipelineRunner +from google.cloud.dataflow.runners.direct_runner import DirectPipelineRunner +from google.cloud.dataflow.runners.runner import create_runner +from google.cloud.dataflow.runners.runner import PipelineRunner +from google.cloud.dataflow.runners.runner import PipelineState http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/runners/common.pxd ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd new file mode 100644 index 0000000..fa1e3d6 --- /dev/null +++ b/sdks/python/apache_beam/runners/common.pxd @@ -0,0 +1,28 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +cdef type SideOutputValue, TimestampedValue, WindowedValue + +cdef class DoFnRunner(object): + + cdef object dofn + cdef object window_fn + cdef object context + cdef object tagged_receivers + cdef object logger + cdef object step_name + + cdef object main_receivers + + cpdef _process_outputs(self, element, results) http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/runners/common.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py new file mode 100644 index 0000000..34e480b --- /dev/null +++ b/sdks/python/apache_beam/runners/common.py @@ -0,0 +1,181 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# cython: profile=True + +"""Worker operations executor.""" + +import sys + +from google.cloud.dataflow.internal import util +from google.cloud.dataflow.pvalue import SideOutputValue +from google.cloud.dataflow.transforms import core +from google.cloud.dataflow.transforms.window import TimestampedValue +from google.cloud.dataflow.transforms.window import WindowedValue +from google.cloud.dataflow.transforms.window import WindowFn + + +class FakeLogger(object): + def PerThreadLoggingContext(self, *unused_args, **unused_kwargs): + return self + def __enter__(self): + pass + def __exit__(self, *unused_args): + pass + + +class DoFnRunner(object): + """A helper class for executing ParDo operations. + """ + + def __init__(self, + fn, + args, + kwargs, + side_inputs, + windowing, + context, + tagged_receivers, + logger=None, + step_name=None): + if not args and not kwargs: + self.dofn = fn + else: + args, kwargs = util.insert_values_in_args(args, kwargs, side_inputs) + + class CurriedFn(core.DoFn): + + def start_bundle(self, context): + return fn.start_bundle(context, *args, **kwargs) + + def process(self, context): + return fn.process(context, *args, **kwargs) + + def finish_bundle(self, context): + return fn.finish_bundle(context, *args, **kwargs) + self.dofn = CurriedFn() + self.window_fn = windowing.windowfn + self.context = context + self.tagged_receivers = tagged_receivers + self.logger = logger or FakeLogger() + self.step_name = step_name + + # Optimize for the common case. 
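+    # tagged_receivers maps side-output tags to receivers; the None key is
+    # the main output, cached here so _process_outputs can dispatch
+    # main-output values without a dictionary lookup.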
+ self.main_receivers = tagged_receivers[None] + + def start(self): + self.context.set_element(None) + try: + self._process_outputs(None, self.dofn.start_bundle(self.context)) + except BaseException as exn: + self.reraise_augmented(exn) + + def finish(self): + self.context.set_element(None) + try: + self._process_outputs(None, self.dofn.finish_bundle(self.context)) + except BaseException as exn: + self.reraise_augmented(exn) + + def process(self, element): + try: + with self.logger.PerThreadLoggingContext(step_name=self.step_name): + self.context.set_element(element) + self._process_outputs(element, self.dofn.process(self.context)) + except BaseException as exn: + self.reraise_augmented(exn) + + def reraise_augmented(self, exn): + if getattr(exn, '_tagged_with_step', False) or not self.step_name: + raise + args = exn.args + if args and isinstance(args[0], str): + args = (args[0] + " [while running '%s']" % self.step_name,) + args[1:] + # Poor man's exception chaining. + raise type(exn), args, sys.exc_info()[2] + else: + raise + + def _process_outputs(self, element, results): + """Dispatch the result of computation to the appropriate receivers. + + A value wrapped in a SideOutputValue object will be unwrapped and + then dispatched to the appropriate indexed output. + """ + if results is None: + return + for result in results: + tag = None + if isinstance(result, SideOutputValue): + tag = result.tag + if not isinstance(tag, basestring): + raise TypeError('In %s, tag %s is not a string' % (self, tag)) + result = result.value + if isinstance(result, WindowedValue): + windowed_value = result + elif element is None: + # Start and finish have no element from which to grab context, + # but may emit elements. + if isinstance(result, TimestampedValue): + value = result.value + timestamp = result.timestamp + assign_context = NoContext(value, timestamp) + else: + value = result + timestamp = -1 + assign_context = NoContext(value) + windowed_value = WindowedValue( + value, timestamp, self.window_fn.assign(assign_context)) + elif isinstance(result, TimestampedValue): + assign_context = WindowFn.AssignContext( + result.timestamp, result.value, element.windows) + windowed_value = WindowedValue( + result.value, result.timestamp, + self.window_fn.assign(assign_context)) + else: + windowed_value = element.with_value(result) + if tag is None: + self.main_receivers.output(windowed_value) + else: + self.tagged_receivers[tag].output(windowed_value) + +class NoContext(WindowFn.AssignContext): + """An uninspectable WindowFn.AssignContext.""" + NO_VALUE = object() + def __init__(self, value, timestamp=NO_VALUE): + self.value = value + self._timestamp = timestamp + @property + def timestamp(self): + if self._timestamp is self.NO_VALUE: + raise ValueError('No timestamp in this context.') + else: + return self._timestamp + @property + def existing_windows(self): + raise ValueError('No existing_windows in this context.') + + +class DoFnState(object): + """Keeps track of state that DoFns want, currently, user counters. + """ + + def __init__(self, counter_factory): + self.step_name = '' + self._counter_factory = counter_factory + + def counter_for(self, aggregator): + """Looks up the counter for this aggregator, creating one if necessary.""" + return self._counter_factory.get_aggregator_counter( + self.step_name, aggregator)
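To make the output routing in _process_outputs above concrete, here is a small sketch (the classify function and the tag names are invented here, not code from this commit) of how a user-level function feeds that machinery: plain yielded values go to the main receivers, while SideOutputValue-wrapped values are dispatched to the receiver registered for their tag, mirroring the create_dupes pattern in pipeline_test.py:

  from google.cloud.dataflow.pipeline import Pipeline
  from google.cloud.dataflow.pvalue import SideOutputValue
  from google.cloud.dataflow.transforms import Create
  from google.cloud.dataflow.transforms import FlatMap

  def classify(x):
    # Plain yields go to the main output; SideOutputValue yields are routed
    # by DoFnRunner._process_outputs to the receiver for the 'odd' tag.
    if x % 2:
      yield SideOutputValue('odd', x)
    else:
      yield x

  pipeline = Pipeline('DirectPipelineRunner')
  results = (pipeline
             | Create('nums', [1, 2, 3, 4])
             | FlatMap('classify', classify).with_outputs('odd', main='even'))
  evens = results.even  # Main output PCollection; here [2, 4].
  odds = results.odd    # Tagged side output PCollection; here [1, 3].
  pipeline.run()

The resulting DoOutputsTuple exposes each tagged PCollection by attribute (results.even, results.odd), which is the same access pattern the refcounting test above relies on with dupes.main and dupes.side.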