Repository: beam
Updated Branches:
  refs/heads/master 9df1da493 -> 36ed3f6c5


http://git-wip-us.apache.org/repos/asf/beam/blob/4337c3ec/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers.py
----------------------------------------------------------------------
diff --git 
a/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers.py
 
b/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers.py
new file mode 100644
index 0000000..4dda47a
--- /dev/null
+++ 
b/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers.py
@@ -0,0 +1,124 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hamcrest.core.base_matcher import BaseMatcher
+
+
# Unique sentinel meaning "property not provided"; distinct by identity from
# every real value (including None and real dicts), so matchers can skip it.
IGNORED = object()
+
+
class MetricStructuredNameMatcher(BaseMatcher):
  """Matches a MetricStructuredName.

  Any property left at IGNORED is skipped during matching.
  """

  def __init__(self,
               name=IGNORED,
               origin=IGNORED,
               context=IGNORED):
    """Creates a MetricStructuredNameMatcher.

    Any property not passed in to the constructor will be ignored when
    matching.

    Args:
      name: A string with the metric name.
      origin: A string with the metric namespace.
      context: A key:value dictionary that will be matched to the
        structured name. A value of IGNORED requires the key to exist but
        does not check its value.

    Raises:
      ValueError: If context is passed in but is not a dictionary.
    """
    # Compare against the sentinel by identity, not equality: IGNORED is a
    # unique object, and `!=` on arbitrary values could invoke custom
    # __ne__ implementations.
    if context is not IGNORED and not isinstance(context, dict):
      raise ValueError('context must be a Python dictionary.')

    self.name = name
    self.origin = origin
    self.context = context

  def _matches(self, item):
    """Returns True if every non-IGNORED property matches `item`."""
    if self.name is not IGNORED and item.name != self.name:
      return False
    if self.origin is not IGNORED and item.origin != self.origin:
      return False
    if self.context is not IGNORED:
      # `dict.iteritems()` is Python-2-only; `items()` works on both 2 and 3.
      for key, value in self.context.items():
        if key not in item.context:
          return False
        if value is not IGNORED and item.context[key] != value:
          return False
    return True

  def describe_to(self, description):
    """Appends a human-readable description of the expected name."""
    descriptors = []
    if self.name is not IGNORED:
      descriptors.append('name is {}'.format(self.name))
    if self.origin is not IGNORED:
      descriptors.append('origin is {}'.format(self.origin))
    if self.context is not IGNORED:
      descriptors.append('context is ({})'.format(str(self.context)))

    item_description = ' and '.join(descriptors)
    description.append(item_description)
+
+
class MetricUpdateMatcher(BaseMatcher):
  """Matches a metrics update protocol buffer.

  Any property left at IGNORED is skipped during matching.
  """

  def __init__(self,
               cumulative=IGNORED,
               name=IGNORED,
               scalar=IGNORED,
               kind=IGNORED):
    """Creates a MetricUpdateMatcher.

    Any property not passed in to the constructor will be ignored when
    matching.

    Args:
      cumulative: A boolean.
      name: A MetricStructuredNameMatcher object that matches the name.
      scalar: An integer with the metric update.
      kind: A string defining the kind of counter.

    Raises:
      ValueError: If name is passed in but is not a
        MetricStructuredNameMatcher.
    """
    # Identity check against the sentinel; see MetricStructuredNameMatcher.
    if name is not IGNORED and not isinstance(name,
                                              MetricStructuredNameMatcher):
      raise ValueError('name must be a MetricStructuredNameMatcher.')

    self.cumulative = cumulative
    self.name = name
    self.scalar = scalar
    self.kind = kind

  def _matches(self, item):
    """Returns True if every non-IGNORED property matches `item`."""
    if self.cumulative is not IGNORED and item.cumulative != self.cumulative:
      return False
    if self.name is not IGNORED and not self.name._matches(item.name):
      return False
    if self.kind is not IGNORED and item.kind != self.kind:
      return False
    if self.scalar is not IGNORED:
      # The scalar is a JSON object whose 'value' property carries the
      # integer payload.
      value_property = [p
                        for p in item.scalar.object_value.properties
                        if p.key == 'value']
      # Report a mismatch rather than raising IndexError when the 'value'
      # property is absent — matchers should never crash while matching.
      if not value_property:
        return False
      if value_property[0].value.integer_value != self.scalar:
        return False
    return True

  def describe_to(self, description):
    """Appends a human-readable description of the expected update."""
    descriptors = []
    if self.cumulative is not IGNORED:
      descriptors.append('cumulative is {}'.format(self.cumulative))
    if self.name is not IGNORED:
      descriptors.append('name is {}'.format(self.name))
    if self.kind is not IGNORED:
      # Previously omitted: kind was matched but never described, making
      # kind-mismatch failures unreadable.
      descriptors.append('kind is {}'.format(self.kind))
    if self.scalar is not IGNORED:
      descriptors.append('scalar is ({})'.format(str(self.scalar)))

    item_description = ' and '.join(descriptors)
    description.append(item_description)

http://git-wip-us.apache.org/repos/asf/beam/blob/4337c3ec/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers_test.py
----------------------------------------------------------------------
diff --git 
a/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers_test.py
 
b/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers_test.py
new file mode 100644
index 0000000..2b56ae1
--- /dev/null
+++ 
b/sdks/python/apache_beam/runners/google_cloud_dataflow/internal/clients/dataflow/message_matchers_test.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import hamcrest as hc
+
+import apache_beam.runners.google_cloud_dataflow.internal.clients.dataflow as 
dataflow
+from apache_beam.internal.json_value import to_json_value
+from apache_beam.runners.google_cloud_dataflow.internal.clients.dataflow 
import message_matchers
+
+
class TestMatchers(unittest.TestCase):
  """Unit tests for the Dataflow message matchers."""

  def test_structured_name_matcher_basic(self):
    """A structured name matches on name and origin."""
    structured_name = dataflow.MetricStructuredName()
    structured_name.name = 'metric1'
    structured_name.origin = 'origin2'

    matching = message_matchers.MetricStructuredNameMatcher(
        name='metric1', origin='origin2')
    hc.assert_that(structured_name, hc.is_(matching))

    # A different origin must not match.
    mismatching = message_matchers.MetricStructuredNameMatcher(
        name='metric1', origin='origin1')
    with self.assertRaises(AssertionError):
      hc.assert_that(structured_name, hc.is_(mismatching))

  def test_metric_update_basic(self):
    """A metric update matches on name, kind and scalar value."""
    update = dataflow.MetricUpdate()
    update.name = dataflow.MetricStructuredName()
    update.name.name = 'metric1'
    update.name.origin = 'origin1'
    update.cumulative = False
    update.kind = 'sum'
    update.scalar = to_json_value(1, with_type=True)

    matcher = message_matchers.MetricUpdateMatcher(
        name=message_matchers.MetricStructuredNameMatcher(
            name='metric1', origin='origin1'),
        kind='sum',
        scalar=1)
    hc.assert_that(update, hc.is_(matcher))

    # Mutating the expected kind must make the same update fail to match.
    matcher.kind = 'suma'
    with self.assertRaises(AssertionError):
      hc.assert_that(update, hc.is_(matcher))
+
+
+if __name__ == '__main__':
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/4337c3ec/sdks/python/apache_beam/runners/google_cloud_dataflow/template_runner_test.py
----------------------------------------------------------------------
diff --git 
a/sdks/python/apache_beam/runners/google_cloud_dataflow/template_runner_test.py 
b/sdks/python/apache_beam/runners/google_cloud_dataflow/template_runner_test.py
new file mode 100644
index 0000000..bbcf340
--- /dev/null
+++ 
b/sdks/python/apache_beam/runners/google_cloud_dataflow/template_runner_test.py
@@ -0,0 +1,88 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for templated pipelines."""
+
+from __future__ import absolute_import
+
+import json
+import tempfile
+import unittest
+
+import apache_beam as beam
+from apache_beam.pipeline import Pipeline
+from apache_beam.runners.google_cloud_dataflow.dataflow_runner import 
DataflowRunner
+from apache_beam.runners.google_cloud_dataflow.internal import apiclient
+from apache_beam.utils.pipeline_options import PipelineOptions
+
+
class TemplatingDataflowRunnerTest(unittest.TestCase):
  """TemplatingDataflow tests."""

  def test_full_completion(self):
    """Runs a templated pipeline and validates the saved job template."""
    # Local imports keep this edit self-contained; both are stdlib.
    import os
    import shutil

    # Create dummy file and close it.  Note that we need to do this because
    # Windows does not allow NamedTemporaryFiles to be reopened elsewhere
    # before the temporary file is closed.
    dummy_file = tempfile.NamedTemporaryFile(delete=False)
    dummy_file_name = dummy_file.name
    dummy_file.close()
    # delete=False means nothing removes the file for us; clean up even if
    # the test fails, otherwise every run leaks a temp file.
    self.addCleanup(os.remove, dummy_file_name)

    dummy_dir = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, dummy_dir, ignore_errors=True)

    remote_runner = DataflowRunner()
    pipeline = Pipeline(remote_runner,
                        options=PipelineOptions([
                            '--dataflow_endpoint=ignored',
                            '--sdk_location=' + dummy_file_name,
                            '--job_name=test-job',
                            '--project=test-project',
                            '--staging_location=' + dummy_dir,
                            '--temp_location=/dev/null',
                            '--template_location=' + dummy_file_name,
                            '--no_auth=True']))

    pipeline | beam.Create([1, 2, 3]) | beam.Map(lambda x: x)  # pylint: disable=expression-not-assigned
    pipeline.run().wait_until_finish()
    with open(dummy_file_name) as template_file:
      saved_job_dict = json.load(template_file)
      self.assertEqual(
          saved_job_dict['environment']['sdkPipelineOptions']
          ['options']['project'], 'test-project')
      self.assertEqual(
          saved_job_dict['environment']['sdkPipelineOptions']
          ['options']['job_name'], 'test-job')

  def test_bad_path(self):
    """An unwritable template_location raises IOError at run time."""
    dummy_sdk_file = tempfile.NamedTemporaryFile()
    remote_runner = DataflowRunner()
    pipeline = Pipeline(remote_runner,
                        options=PipelineOptions([
                            '--dataflow_endpoint=ignored',
                            '--sdk_location=' + dummy_sdk_file.name,
                            '--job_name=test-job',
                            '--project=test-project',
                            '--staging_location=ignored',
                            '--temp_location=/dev/null',
                            '--template_location=/bad/path',
                            '--no_auth=True']))
    remote_runner.job = apiclient.Job(pipeline.options)

    with self.assertRaises(IOError):
      pipeline.run().wait_until_finish()
+
+
+if __name__ == '__main__':
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/4337c3ec/sdks/python/apache_beam/runners/runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/runner_test.py 
b/sdks/python/apache_beam/runners/runner_test.py
index d7a50aa..7ec5d8a 100644
--- a/sdks/python/apache_beam/runners/runner_test.py
+++ b/sdks/python/apache_beam/runners/runner_test.py
@@ -22,30 +22,28 @@ the other unit tests. In this file we choose to test only 
aspects related to
 caching and clearing values that are not tested elsewhere.
 """
 
-from datetime import datetime
 import json
 import unittest
+from datetime import datetime
 
 import hamcrest as hc
 
 import apache_beam as beam
-
-from apache_beam.internal import apiclient
+import apache_beam.transforms as ptransform
+from apache_beam.metrics.cells import DistributionData
+from apache_beam.metrics.cells import DistributionResult
+from apache_beam.metrics.execution import MetricKey
+from apache_beam.metrics.execution import MetricResult
+from apache_beam.metrics.metricbase import MetricName
 from apache_beam.pipeline import Pipeline
-from apache_beam.runners import create_runner
 from apache_beam.runners import DataflowRunner
 from apache_beam.runners import DirectRunner
 from apache_beam.runners import TestDataflowRunner
-import apache_beam.transforms as ptransform
+from apache_beam.runners import create_runner
+from apache_beam.runners.google_cloud_dataflow.internal import apiclient
 from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.utils.pipeline_options import PipelineOptions
 
-from apache_beam.metrics.cells import DistributionData
-from apache_beam.metrics.cells import DistributionResult
-from apache_beam.metrics.execution import MetricResult
-from apache_beam.metrics.execution import MetricKey
-from apache_beam.metrics.metricbase import MetricName
-
 
 class RunnerTest(unittest.TestCase):
   default_properties = [

http://git-wip-us.apache.org/repos/asf/beam/blob/4337c3ec/sdks/python/apache_beam/runners/template_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/template_runner_test.py 
b/sdks/python/apache_beam/runners/template_runner_test.py
deleted file mode 100644
index b73645c..0000000
--- a/sdks/python/apache_beam/runners/template_runner_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""Unit tests for templated pipelines."""
-
-from __future__ import absolute_import
-
-import json
-import tempfile
-import unittest
-
-import apache_beam as beam
-from apache_beam.internal import apiclient
-from apache_beam.pipeline import Pipeline
-from apache_beam.runners.google_cloud_dataflow.dataflow_runner import 
DataflowRunner
-from apache_beam.utils.pipeline_options import PipelineOptions
-
-
-class TemplatingDataflowRunnerTest(unittest.TestCase):
-  """TemplatingDataflow tests."""
-  def test_full_completion(self):
-    # Create dummy file and close it.  Note that we need to do this because
-    # Windows does not allow NamedTemporaryFiles to be reopened elsewhere
-    # before the temporary file is closed.
-    dummy_file = tempfile.NamedTemporaryFile(delete=False)
-    dummy_file_name = dummy_file.name
-    dummy_file.close()
-
-    dummy_dir = tempfile.mkdtemp()
-
-    remote_runner = DataflowRunner()
-    pipeline = Pipeline(remote_runner,
-                        options=PipelineOptions([
-                            '--dataflow_endpoint=ignored',
-                            '--sdk_location=' + dummy_file_name,
-                            '--job_name=test-job',
-                            '--project=test-project',
-                            '--staging_location=' + dummy_dir,
-                            '--temp_location=/dev/null',
-                            '--template_location=' + dummy_file_name,
-                            '--no_auth=True']))
-
-    pipeline | beam.Create([1, 2, 3]) | beam.Map(lambda x: x) # pylint: 
disable=expression-not-assigned
-    pipeline.run().wait_until_finish()
-    with open(dummy_file_name) as template_file:
-      saved_job_dict = json.load(template_file)
-      self.assertEqual(
-          saved_job_dict['environment']['sdkPipelineOptions']
-          ['options']['project'], 'test-project')
-      self.assertEqual(
-          saved_job_dict['environment']['sdkPipelineOptions']
-          ['options']['job_name'], 'test-job')
-
-  def test_bad_path(self):
-    dummy_sdk_file = tempfile.NamedTemporaryFile()
-    remote_runner = DataflowRunner()
-    pipeline = Pipeline(remote_runner,
-                        options=PipelineOptions([
-                            '--dataflow_endpoint=ignored',
-                            '--sdk_location=' + dummy_sdk_file.name,
-                            '--job_name=test-job',
-                            '--project=test-project',
-                            '--staging_location=ignored',
-                            '--temp_location=/dev/null',
-                            '--template_location=/bad/path',
-                            '--no_auth=True']))
-    remote_runner.job = apiclient.Job(pipeline.options)
-
-    with self.assertRaises(IOError):
-      pipeline.run().wait_until_finish()
-
-
-if __name__ == '__main__':
-  unittest.main()

Reply via email to