http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/pubsub.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/pubsub.py b/sdks/python/apache_beam/io/gcp/pubsub.py
index 32d388a..1ba8ac0 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub.py
@@ -24,35 +24,29 @@ This API is currently under development and is subject to change.
 
 from __future__ import absolute_import
 
-import re
-
 from apache_beam import coders
 from apache_beam.io.iobase import Read
 from apache_beam.io.iobase import Write
 from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
-from apache_beam.transforms import core
 from apache_beam.transforms import PTransform
-from apache_beam.transforms import Map
-from apache_beam.transforms import window
+from apache_beam.transforms import ParDo
 from apache_beam.transforms.display import DisplayDataItem
 
-__all__ = ['ReadStringsFromPubSub', 'WriteStringsToPubSub']
+__all__ = ['ReadStringsFromPubSub', 'WriteStringsToPubSub',
+           'PubSubSource', 'PubSubSink']
 
 
 class ReadStringsFromPubSub(PTransform):
   """A ``PTransform`` for reading utf-8 string payloads from Cloud Pub/Sub."""
 
-  def __init__(self, topic=None, subscription=None, id_label=None):
+  def __init__(self, topic, subscription=None, id_label=None):
     """Initializes ``ReadStringsFromPubSub``.
 
     Attributes:
-      topic: Cloud Pub/Sub topic in the form "projects/<project>/topics/
-        <topic>". If provided, subscription must be None.
-      subscription: Existing Cloud Pub/Sub subscription to use in the
-        form "projects/<project>/subscriptions/<subscription>". If not
-        specified, a temporary subscription will be created from the specified
-        topic. If provided, topic must be None.
+      topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
+      subscription: Optional existing Cloud Pub/Sub subscription to use in the
+        form "projects/<project>/subscriptions/<subscription>".
       id_label: The attribute on incoming Pub/Sub messages to use as a unique
         record identifier. When specified, the value of this attribute (which
         can be any string that uniquely identifies the record) will be used for
@@ -66,13 +60,10 @@ class ReadStringsFromPubSub(PTransform):
         subscription=subscription, id_label=id_label)
 
-  def get_windowing(self, unused_inputs):
-    return core.Windowing(window.GlobalWindows())
-
   def expand(self, pvalue):
     pcoll = pvalue.pipeline | Read(self._source)
     pcoll.element_type = bytes
-    pcoll = pcoll | 'DecodeString' >> Map(lambda b: b.decode('utf-8'))
+    pcoll = pcoll | 'decode string' >> ParDo(_decodeUtf8String)
     pcoll.element_type = unicode
     return pcoll
 
@@ -90,50 +81,18 @@ class WriteStringsToPubSub(PTransform):
     self._sink = _PubSubPayloadSink(topic)
 
   def expand(self, pcoll):
-    pcoll = pcoll | 'EncodeString' >> Map(lambda s: s.encode('utf-8'))
+    pcoll = pcoll | 'encode string' >> ParDo(_encodeUtf8String)
     pcoll.element_type = bytes
     return pcoll | Write(self._sink)
 
 
-PROJECT_ID_REGEXP = '[a-z][-a-z0-9:.]{4,61}[a-z0-9]'
-SUBSCRIPTION_REGEXP = 'projects/([^/]+)/subscriptions/(.+)'
-TOPIC_REGEXP = 'projects/([^/]+)/topics/(.+)'
-
-
-def parse_topic(full_topic):
-  match = re.match(TOPIC_REGEXP, full_topic)
-  if not match:
-    raise ValueError(
-        'PubSub topic must be in the form "projects/<project>/topics'
-        '/<topic>" (got %r).' % full_topic)
-  project, topic_name = match.group(1), match.group(2)
-  if not re.match(PROJECT_ID_REGEXP, project):
-    raise ValueError('Invalid PubSub project name: %r.'
-                     % project)
-  return project, topic_name
-
-
-def parse_subscription(full_subscription):
-  match = re.match(SUBSCRIPTION_REGEXP, full_subscription)
-  if not match:
-    raise ValueError(
-        'PubSub subscription must be in the form "projects/<project>'
-        '/subscriptions/<subscription>" (got %r).' % full_subscription)
-  project, subscription_name = match.group(1), match.group(2)
-  if not re.match(PROJECT_ID_REGEXP, project):
-    raise ValueError('Invalid PubSub project name: %r.' % project)
-  return project, subscription_name
-
-
 class _PubSubPayloadSource(dataflow_io.NativeSource):
   """Source for the payload of a message as bytes from a Cloud Pub/Sub topic.
 
   Attributes:
-    topic: Cloud Pub/Sub topic in the form "projects/<project>/topics/<topic>".
-      If provided, subscription must be None.
-    subscription: Existing Cloud Pub/Sub subscription to use in the
-      form "projects/<project>/subscriptions/<subscription>". If not specified,
-      a temporary subscription will be created from the specified topic. If
-      provided, topic must be None.
+    topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
+    subscription: Optional existing Cloud Pub/Sub subscription to use in the
+      form "projects/<project>/subscriptions/<subscription>".
     id_label: The attribute on incoming Pub/Sub messages to use as a unique
       record identifier. When specified, the value of this attribute (which can
      be any string that uniquely identifies the record) will be used for
@@ -142,27 +101,11 @@ class _PubSubPayloadSource(dataflow_io.NativeSource):
     case, deduplication of the stream will be strictly best effort.
   """
 
-  def __init__(self, topic=None, subscription=None, id_label=None):
-    # We are using this coder explicitly for portability reasons of PubsubIO
-    # across implementations in languages.
-    self.coder = coders.BytesCoder()
-    self.full_topic = topic
-    self.full_subscription = subscription
-    self.topic_name = None
-    self.subscription_name = None
+  def __init__(self, topic, subscription=None, id_label=None):
+    self.topic = topic
+    self.subscription = subscription
     self.id_label = id_label
 
-    # Perform some validation on the topic and subscription.
-    if not (topic or subscription):
-      raise ValueError('Either a topic or subscription must be provided.')
-    if topic and subscription:
-      raise ValueError('Only one of topic or subscription should be provided.')
-
-    if topic:
-      self.project, self.topic_name = parse_topic(topic)
-    if subscription:
-      self.project, self.subscription_name = parse_subscription(subscription)
-
   @property
   def format(self):
     """Source format name required for remote execution."""
@@ -173,10 +116,10 @@ class _PubSubPayloadSource(dataflow_io.NativeSource):
             DisplayDataItem(self.id_label,
                             label='ID Label Attribute').drop_if_none(),
         'topic':
-            DisplayDataItem(self.full_topic,
-                            label='Pubsub Topic').drop_if_none(),
+            DisplayDataItem(self.topic,
+                            label='Pubsub Topic'),
         'subscription':
-            DisplayDataItem(self.full_subscription,
+            DisplayDataItem(self.subscription,
                             label='Pubsub Subscription').drop_if_none()}
 
   def reader(self):
@@ -188,12 +131,7 @@ class _PubSubPayloadSink(dataflow_io.NativeSink):
   """Sink for the payload of a message as bytes to a Cloud Pub/Sub topic."""
 
   def __init__(self, topic):
-    # we are using this coder explicitly for portability reasons of PubsubIO
-    # across implementations in languages.
-    self.coder = coders.BytesCoder()
-    self.full_topic = topic
-
-    self.project, self.topic_name = parse_topic(topic)
+    self.topic = topic
 
   @property
   def format(self):
@@ -201,8 +139,86 @@ class _PubSubPayloadSink(dataflow_io.NativeSink):
     return 'pubsub'
 
   def display_data(self):
-    return {'topic': DisplayDataItem(self.full_topic, label='Pubsub Topic')}
+    return {'topic': DisplayDataItem(self.topic, label='Pubsub Topic')}
 
   def writer(self):
     raise NotImplementedError(
         'PubSubPayloadSink is not supported in local execution.')
+
+
+def _decodeUtf8String(encoded_value):
+  """Decodes a string in utf-8 format from bytes"""
+  return encoded_value.decode('utf-8')
+
+
+def _encodeUtf8String(value):
+  """Encodes a string in utf-8 format to bytes"""
+  return value.encode('utf-8')
+
+
+class PubSubSource(dataflow_io.NativeSource):
+  """Deprecated: do not use.
+
+  Source for reading from a given Cloud Pub/Sub topic.
+
+  Attributes:
+    topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>".
+    subscription: Optional existing Cloud Pub/Sub subscription to use in the
+      form "projects/<project>/subscriptions/<subscription>".
+    id_label: The attribute on incoming Pub/Sub messages to use as a unique
+      record identifier. When specified, the value of this attribute (which can
+      be any string that uniquely identifies the record) will be used for
+      deduplication of messages. If not provided, Dataflow cannot guarantee
+      that no duplicate data will be delivered on the Pub/Sub stream. In this
+      case, deduplication of the stream will be strictly best effort.
+    coder: The Coder to use for decoding incoming Pub/Sub messages.
+  """
+
+  def __init__(self, topic, subscription=None, id_label=None,
+               coder=coders.StrUtf8Coder()):
+    self.topic = topic
+    self.subscription = subscription
+    self.id_label = id_label
+    self.coder = coder
+
+  @property
+  def format(self):
+    """Source format name required for remote execution."""
+    return 'pubsub'
+
+  def display_data(self):
+    return {'id_label':
+                DisplayDataItem(self.id_label,
+                                label='ID Label Attribute').drop_if_none(),
+            'topic':
+                DisplayDataItem(self.topic,
+                                label='Pubsub Topic'),
+            'subscription':
+                DisplayDataItem(self.subscription,
+                                label='Pubsub Subscription').drop_if_none()}
+
+  def reader(self):
+    raise NotImplementedError(
+        'PubSubSource is not supported in local execution.')
+
+
+class PubSubSink(dataflow_io.NativeSink):
+  """Deprecated: do not use.
+
+  Sink for writing to a given Cloud Pub/Sub topic."""
+
+  def __init__(self, topic, coder=coders.StrUtf8Coder()):
+    self.topic = topic
+    self.coder = coder
+
+  @property
+  def format(self):
+    """Sink format name required for remote execution."""
+    return 'pubsub'
+
+  def display_data(self):
+    return {'topic': DisplayDataItem(self.topic, label='Pubsub Topic')}
+
+  def writer(self):
+    raise NotImplementedError(
+        'PubSubSink is not supported in local execution.')
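
[For orientation, the two public transforms above compose into a straight
read-transform-write pipeline. The following is a minimal sketch, not part
of the patch: the project and topic names are placeholders, and since the
underlying native source and sink raise NotImplementedError locally, it
assumes a runner with Pub/Sub support.]

    import apache_beam as beam
    from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub
    from apache_beam.io.gcp.pubsub import WriteStringsToPubSub

    # 'my-project' and both topic names are illustrative placeholders.
    p = beam.Pipeline()
    (p
     | 'read' >> ReadStringsFromPubSub('/topics/my-project/input_topic')
     | 'upper' >> beam.Map(lambda line: line.upper())  # unicode in, unicode out
     | 'write' >> WriteStringsToPubSub('/topics/my-project/output_topic'))
    p.run()
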
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/pubsub_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/pubsub_test.py b/sdks/python/apache_beam/io/gcp/pubsub_test.py
index 0dcc3c3..322d08a 100644
--- a/sdks/python/apache_beam/io/gcp/pubsub_test.py
+++ b/sdks/python/apache_beam/io/gcp/pubsub_test.py
@@ -22,6 +22,8 @@ import unittest
 
 import hamcrest as hc
 
+from apache_beam.io.gcp.pubsub import _decodeUtf8String
+from apache_beam.io.gcp.pubsub import _encodeUtf8String
 from apache_beam.io.gcp.pubsub import _PubSubPayloadSink
 from apache_beam.io.gcp.pubsub import _PubSubPayloadSource
 from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub
@@ -31,112 +33,77 @@ from apache_beam.transforms.display import DisplayData
 from apache_beam.transforms.display_test import DisplayDataItemMatcher
 
 
-# Protect against environments where the PubSub library is not available.
-# pylint: disable=wrong-import-order, wrong-import-position
-try:
-  from google.cloud import pubsub
-except ImportError:
-  pubsub = None
-# pylint: enable=wrong-import-order, wrong-import-position
-
-
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
 class TestReadStringsFromPubSub(unittest.TestCase):
-  def test_expand_with_topic(self):
+  def test_expand(self):
     p = TestPipeline()
-    pcoll = p | ReadStringsFromPubSub('projects/fakeprj/topics/a_topic',
-                                      None, 'a_label')
+    pcoll = p | ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label')
     # Ensure that the output type is str
     self.assertEqual(unicode, pcoll.element_type)
 
-    # Ensure that the properties passed through correctly
-    source = pcoll.producer.transform._source
-    self.assertEqual('a_topic', source.topic_name)
-    self.assertEqual('a_label', source.id_label)
-
-  def test_expand_with_subscription(self):
-    p = TestPipeline()
-    pcoll = p | ReadStringsFromPubSub(
-        None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label')
-    # Ensure that the output type is str
-    self.assertEqual(unicode, pcoll.element_type)
+    # Ensure that the type on the intermediate read output PCollection is bytes
+    read_pcoll = pcoll.producer.inputs[0]
+    self.assertEqual(bytes, read_pcoll.element_type)
 
     # Ensure that the properties passed through correctly
-    source = pcoll.producer.transform._source
-    self.assertEqual('a_subscription', source.subscription_name)
+    source = read_pcoll.producer.transform.source
+    self.assertEqual('a_topic', source.topic)
+    self.assertEqual('a_subscription', source.subscription)
     self.assertEqual('a_label', source.id_label)
 
-  def test_expand_with_no_topic_or_subscription(self):
-    with self.assertRaisesRegexp(
-        ValueError, "Either a topic or subscription must be provided."):
-      ReadStringsFromPubSub(None, None, 'a_label')
-
-  def test_expand_with_both_topic_and_subscription(self):
-    with self.assertRaisesRegexp(
-        ValueError, "Only one of topic or subscription should be provided."):
-      ReadStringsFromPubSub('a_topic', 'a_subscription', 'a_label')
-
 
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
 class TestWriteStringsToPubSub(unittest.TestCase):
   def test_expand(self):
     p = TestPipeline()
-    pdone = (p
-             | ReadStringsFromPubSub('projects/fakeprj/topics/baz')
-             | WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
+    pdone = p | ReadStringsFromPubSub('baz') | WriteStringsToPubSub('a_topic')
 
     # Ensure that the properties passed through correctly
-    self.assertEqual('a_topic', pdone.producer.transform.dofn.topic_name)
+    sink = pdone.producer.transform.sink
+    self.assertEqual('a_topic', sink.topic)
 
+    # Ensure that the type on the intermediate payload transformer output
+    # PCollection is bytes
+    write_pcoll = pdone.producer.inputs[0]
+    self.assertEqual(bytes, write_pcoll.element_type)
 
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
-class TestPubSubSource(unittest.TestCase):
-  def test_display_data_topic(self):
-    source = _PubSubPayloadSource(
-        'projects/fakeprj/topics/a_topic',
-        None,
-        'a_label')
-    dd = DisplayData.create_from(source)
-    expected_items = [
-        DisplayDataItemMatcher(
-            'topic', 'projects/fakeprj/topics/a_topic'),
-        DisplayDataItemMatcher('id_label', 'a_label')]
-
-    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
 
-  def test_display_data_subscription(self):
-    source = _PubSubPayloadSource(
-        None,
-        'projects/fakeprj/subscriptions/a_subscription',
-        'a_label')
+class TestPubSubSource(unittest.TestCase):
+  def test_display_data(self):
+    source = _PubSubPayloadSource('a_topic', 'a_subscription', 'a_label')
     dd = DisplayData.create_from(source)
     expected_items = [
-        DisplayDataItemMatcher(
-            'subscription', 'projects/fakeprj/subscriptions/a_subscription'),
+        DisplayDataItemMatcher('topic', 'a_topic'),
+        DisplayDataItemMatcher('subscription', 'a_subscription'),
         DisplayDataItemMatcher('id_label', 'a_label')]
 
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
 
   def test_display_data_no_subscription(self):
-    source = _PubSubPayloadSource('projects/fakeprj/topics/a_topic')
+    source = _PubSubPayloadSource('a_topic')
     dd = DisplayData.create_from(source)
     expected_items = [
-        DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic')]
+        DisplayDataItemMatcher('topic', 'a_topic')]
 
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
 
 
-@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
 class TestPubSubSink(unittest.TestCase):
   def test_display_data(self):
-    sink = _PubSubPayloadSink('projects/fakeprj/topics/a_topic')
+    sink = _PubSubPayloadSink('a_topic')
     dd = DisplayData.create_from(sink)
     expected_items = [
-        DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic')]
+        DisplayDataItemMatcher('topic', 'a_topic')]
 
     hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
 
 
+class TestEncodeDecodeUtf8String(unittest.TestCase):
+  def test_encode(self):
+    self.assertEqual(b'test_data', _encodeUtf8String('test_data'))
+
+  def test_decode(self):
+    self.assertEqual('test_data', _decodeUtf8String(b'test_data'))
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()

http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
index d6f0e97..844cbc5 100644
--- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
+++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher.py
@@ -92,9 +92,9 @@ class BigqueryMatcher(BaseMatcher):
     page_token = None
     results = []
     while True:
-      for row in query.fetch_data(page_token=page_token):
-        results.append(row)
-      if results:
+      rows, _, page_token = query.fetch_data(page_token=page_token)
+      results.extend(rows)
+      if not page_token:
         break
     return results
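
[The matcher change above tracks the newer google-cloud-bigquery client,
where query.fetch_data() returns a (rows, total_rows, page_token) triple
rather than a plain iterable. As a standalone sketch of that pagination
contract; the query object is assumed to come from something like
client.run_sync_query():]

    def fetch_all_rows(query):
      """Collects every result row by following page tokens."""
      results = []
      page_token = None
      while True:
        # Each call returns one page plus a token for the next page (or None).
        rows, _, page_token = query.fetch_data(page_token=page_token)
        results.extend(rows)
        if not page_token:  # no more pages left
          return results
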
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
index 5b72285..f12293e 100644
--- a/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
+++ b/sdks/python/apache_beam/io/gcp/tests/bigquery_matcher_test.py
@@ -53,7 +53,7 @@ class BigqueryMatcherTest(unittest.TestCase):
     matcher = bq_verifier.BigqueryMatcher(
         'mock_project',
         'mock_query',
-        '59f9d6bdee30d67ea73b8aded121c3a0280f9cd8')
+        'da39a3ee5e6b4b0d3255bfef95601890afd80709')
     hc_assert_that(self._mock_result, matcher)
 
   @patch.object(bigquery, 'Client')

http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/range_trackers.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index bef77d4..9cb36e7 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -193,6 +193,136 @@ class OffsetRangeTracker(iobase.RangeTracker):
     self._split_points_unclaimed_callback = callback
 
 
+class GroupedShuffleRangeTracker(iobase.RangeTracker):
+  """For internal use only; no backwards-compatibility guarantees.
+
+  A 'RangeTracker' for positions used by 'GroupedShuffleReader'.
+
+  These positions roughly correspond to hashes of keys. In case of hash
+  collisions, multiple groups can have the same position. In that case, the
+  first group at a particular position is considered a split point (because
+  it is the first to be returned when reading a position range starting at this
+  position), others are not.
+  """
+
+  def __init__(self, decoded_start_pos, decoded_stop_pos):
+    super(GroupedShuffleRangeTracker, self).__init__()
+    self._decoded_start_pos = decoded_start_pos
+    self._decoded_stop_pos = decoded_stop_pos
+    self._decoded_last_group_start = None
+    self._last_group_was_at_a_split_point = False
+    self._split_points_seen = 0
+    self._lock = threading.Lock()
+
+  def start_position(self):
+    return self._decoded_start_pos
+
+  def stop_position(self):
+    return self._decoded_stop_pos
+
+  def last_group_start(self):
+    return self._decoded_last_group_start
+
+  def _validate_decoded_group_start(self, decoded_group_start, split_point):
+    if self.start_position() and decoded_group_start < self.start_position():
+      raise ValueError('Trying to return record at %r which is before the'
+                       ' starting position at %r' %
+                       (decoded_group_start, self.start_position()))
+
+    if (self.last_group_start() and
+        decoded_group_start < self.last_group_start()):
+      raise ValueError('Trying to return group at %r which is before the'
+                       ' last-returned group at %r' %
+                       (decoded_group_start, self.last_group_start()))
+    if (split_point and self.last_group_start() and
+        self.last_group_start() == decoded_group_start):
+      raise ValueError('Trying to return a group at a split point with '
+                       'same position as the previous group: both at %r, '
+                       'last group was %sat a split point.' %
+                       (decoded_group_start,
+                        ('' if self._last_group_was_at_a_split_point
+                         else 'not ')))
+    if not split_point:
+      if self.last_group_start() is None:
+        raise ValueError('The first group [at %r] must be at a split point' %
+                         decoded_group_start)
+      if self.last_group_start() != decoded_group_start:
+        # This case is not a violation of general RangeTracker semantics, but it
+        # is contrary to how GroupingShuffleReader in particular works. Hitting
+        # it would mean it's behaving unexpectedly.
+        raise ValueError('Trying to return a group not at a split point, but '
+                         'with a different position than the previous group: '
+                         'last group was %r at %r, current at a %s split'
+                         ' point.' %
+                         (self.last_group_start()
+                          , decoded_group_start
+                          , ('' if self._last_group_was_at_a_split_point
+                             else 'non-')))
+
+  def try_claim(self, decoded_group_start):
+    with self._lock:
+      self._validate_decoded_group_start(decoded_group_start, True)
+      if (self.stop_position()
+          and decoded_group_start >= self.stop_position()):
+        return False
+
+      self._decoded_last_group_start = decoded_group_start
+      self._last_group_was_at_a_split_point = True
+      self._split_points_seen += 1
+      return True
+
+  def set_current_position(self, decoded_group_start):
+    with self._lock:
+      self._validate_decoded_group_start(decoded_group_start, False)
+      self._decoded_last_group_start = decoded_group_start
+      self._last_group_was_at_a_split_point = False
+
+  def try_split(self, decoded_split_position):
+    with self._lock:
+      if self.last_group_start() is None:
+        logging.info('Refusing to split %r at %r: unstarted'
+                     , self, decoded_split_position)
+        return
+
+      if decoded_split_position <= self.last_group_start():
+        logging.info('Refusing to split %r at %r: already past proposed split '
+                     'position'
+                     , self, decoded_split_position)
+        return
+
+      if ((self.stop_position()
+           and decoded_split_position >= self.stop_position())
+          or (self.start_position()
+              and decoded_split_position <= self.start_position())):
+        logging.error('Refusing to split %r at %r: proposed split position out '
+                      'of range', self, decoded_split_position)
+        return
+
+      logging.debug('Agreeing to split %r at %r'
+                    , self, decoded_split_position)
+      self._decoded_stop_pos = decoded_split_position
+
+      # Since GroupedShuffleRangeTracker cannot determine relative sizes of the
+      # two splits, returning 0.5 as the fraction below so that the framework
+      # assumes the splits to be of the same size.
+      return self._decoded_stop_pos, 0.5
+
+  def fraction_consumed(self):
+    # GroupingShuffle sources have special support on the service and the
+    # service will estimate progress from positions for us.
+    raise RuntimeError('GroupedShuffleRangeTracker does not measure fraction'
+                       ' consumed due to positions being opaque strings'
+                       ' that are interpreted by the service')
+
+  def split_points(self):
+    with self._lock:
+      splits_points_consumed = (
+          0 if self._split_points_seen <= 1 else (self._split_points_seen - 1))
+
+      return (splits_points_consumed,
+              iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
+
+
 class OrderedPositionRangeTracker(iobase.RangeTracker):
   """
   An abstract base class for range trackers whose positions are comparable.
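
[To make the claim/split protocol above concrete, here is a small sketch,
not from the patch, driving the new tracker with illustrative byte-string
positions. Positions compare lexicographically; the first record of each
group is claimed at a split point, and follow-up records in the same group
are reported with set_current_position:]

    from apache_beam.io import iobase
    from apache_beam.io import range_trackers

    # Range over opaque, lexicographically ordered byte-string positions.
    tracker = range_trackers.GroupedShuffleRangeTracker(b'\x01', b'\x05')

    assert tracker.try_claim(b'\x01\x02\x03')      # first group: a split point
    tracker.set_current_position(b'\x01\x02\x03')  # same group, not a split point

    # Narrows the range; the tracker returns (new_stop_position, 0.5).
    assert tracker.try_split(b'\x03') is not None
    assert not tracker.try_claim(b'\x04')  # at/past the new stop position

    # One split point seen so far counts as zero consumed.
    assert tracker.split_points() == (
        0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
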
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/io/range_trackers_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index 3e92663..edb6386 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -17,11 +17,14 @@
 
 """Unit tests for the range_trackers module."""
 
+import array
 import copy
 import logging
 import math
 import unittest
 
+
+from apache_beam.io import iobase
 from apache_beam.io import range_trackers
 
 
@@ -186,6 +189,189 @@ class OffsetRangeTrackerTest(unittest.TestCase):
                      (3, 41))
 
 
+class GroupedShuffleRangeTrackerTest(unittest.TestCase):
+
+  def bytes_to_position(self, bytes_array):
+    return array.array('B', bytes_array).tostring()
+
+  def test_try_return_record_in_infinite_range(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker('', '')
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 3])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 5])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 6, 8, 10])))
+
+  def test_try_return_record_finite_range(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0]))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 3])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 5])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 6, 8, 10])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([4, 255, 255, 255])))
+    # Should fail for positions that are lexicographically equal to or larger
+    # than the defined stop position.
+    self.assertFalse(copy.copy(tracker).try_claim(
+        self.bytes_to_position([5, 0, 0])))
+    self.assertFalse(copy.copy(tracker).try_claim(
+        self.bytes_to_position([5, 0, 1])))
+    self.assertFalse(copy.copy(tracker).try_claim(
+        self.bytes_to_position([6, 0, 0])))
+
+  def test_try_return_record_with_non_split_point(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0]))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 3])))
+    tracker.set_current_position(self.bytes_to_position([1, 2, 3]))
+    tracker.set_current_position(self.bytes_to_position([1, 2, 3]))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 5])))
+    tracker.set_current_position(self.bytes_to_position([1, 2, 5]))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 6, 8, 10])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([4, 255, 255, 255])))
+
+  def test_first_record_non_split_point(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+    with self.assertRaises(ValueError):
+      tracker.set_current_position(self.bytes_to_position([3, 4, 5]))
+
+  def test_non_split_point_record_with_different_position(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5])))
+    with self.assertRaises(ValueError):
+      tracker.set_current_position(self.bytes_to_position([3, 4, 6]))
+
+  def test_try_return_record_before_start(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+    with self.assertRaises(ValueError):
+      tracker.try_claim(self.bytes_to_position([1, 2, 3]))
+
+  def test_try_return_non_monotonic(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5])))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 6])))
+    with self.assertRaises(ValueError):
+      tracker.try_claim(self.bytes_to_position([3, 2, 1]))
+
+  def test_try_return_identical_positions(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0]))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 4, 5])))
+    with self.assertRaises(ValueError):
+      tracker.try_claim(self.bytes_to_position([3, 4, 5]))
+
+  def test_try_split_at_position_infinite_range(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker('', '')
+    # Should fail before first record is returned.
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 3])))
+
+    # Should now succeed.
+    self.assertIsNotNone(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+    # Should not split at same or larger position.
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6, 7])))
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([4, 5, 6, 7])))
+
+    # Should split at smaller position.
+    self.assertIsNotNone(tracker.try_split(
+        self.bytes_to_position([3, 2, 1])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([2, 3, 4])))
+
+    # Should not split at a position we're already past.
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([2, 3, 4])))
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([2, 3, 3])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 2, 0])))
+    self.assertFalse(tracker.try_claim(
+        self.bytes_to_position([3, 2, 1])))
+
+  def test_try_test_split_at_position_finite_range(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([0, 0, 0]),
+        self.bytes_to_position([10, 20, 30]))
+    # Should fail before first record is returned.
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([0, 0, 0])))
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 3])))
+
+    # Should now succeed.
+    self.assertTrue(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+    # Should not split at same or larger position.
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6, 7])))
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([4, 5, 6, 7])))
+
+    # Should split at smaller position.
+    self.assertTrue(tracker.try_split(
+        self.bytes_to_position([3, 2, 1])))
+    # But not at a position at or before last returned record.
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([1, 2, 3])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([2, 3, 4])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 2, 0])))
+    self.assertFalse(tracker.try_claim(
+        self.bytes_to_position([3, 2, 1])))
+
+  def test_split_points(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([1, 0, 0]),
+        self.bytes_to_position([5, 0, 0]))
+    self.assertEqual(tracker.split_points(),
+                     (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 3])))
+    self.assertEqual(tracker.split_points(),
+                     (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 5])))
+    self.assertEqual(tracker.split_points(),
+                     (1, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 6, 8])))
+    self.assertEqual(tracker.split_points(),
+                     (2, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([4, 255, 255])))
+    self.assertEqual(tracker.split_points(),
+                     (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertFalse(tracker.try_claim(self.bytes_to_position([5, 1, 0])))
+    self.assertEqual(tracker.split_points(),
+                     (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+
+
 class OrderedPositionRangeTrackerTest(unittest.TestCase):
 
   class DoubleRangeTracker(range_trackers.OrderedPositionRangeTracker):
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/options/pipeline_options.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index ea996a3..daef3a7 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -18,6 +18,7 @@
 """Pipeline options obtained from command line parsing."""
 
 import argparse
+import warnings
 
 from apache_beam.transforms.display import HasDisplayData
 from apache_beam.options.value_provider import StaticValueProvider
@@ -278,6 +279,14 @@ class StandardOptions(PipelineOptions):
         action='store_true',
         help='Whether to enable streaming mode.')
 
+  # TODO(BEAM-1265): Remove this warning, once at least one runner supports
+  # streaming pipelines.
+  def validate(self, validator):
+    errors = []
+    if self.view_as(StandardOptions).streaming:
+      warnings.warn('Streaming pipelines are not supported.')
+    return errors
+
 
 class TypeOptions(PipelineOptions):
 
@@ -465,14 +474,7 @@ class WorkerOptions(PipelineOptions):
     parser.add_argument(
         '--use_public_ips',
         default=None,
-        action='store_true',
-        help='Whether to assign public IP addresses to the worker VMs.')
-    parser.add_argument(
-        '--no_use_public_ips',
-        dest='use_public_ips',
-        default=None,
-        action='store_false',
-        help='Whether to assign only private IP addresses to the worker VMs.')
+        help='Whether to assign public IP addresses to the worker machines.')
 
   def validate(self, validator):
     errors = []
@@ -552,18 +554,6 @@ class SetupOptions(PipelineOptions):
         'worker will install the resulting package before running any custom '
         'code.'))
     parser.add_argument(
-        '--beam_plugin', '--beam_plugin',
-        dest='beam_plugins',
-        action='append',
-        default=None,
-        help=
-        ('Bootstrap the python process before executing any code by importing '
-         'all the plugins used in the pipeline. Please pass a comma separated'
-         'list of import paths to be included. This is currently an '
-         'experimental flag and provides no stability. Multiple '
-         '--beam_plugin options can be specified if more than one plugin '
-         'is needed.'))
-    parser.add_argument(
         '--save_main_session',
         default=False,
         action='store_true',
@@ -609,11 +599,6 @@ class TestOptions(PipelineOptions):
         help=('Verify state/output of e2e test pipeline. This is pickled '
               'version of the matcher which should extends '
              'hamcrest.core.base_matcher.BaseMatcher.'))
-    parser.add_argument(
-        '--dry_run',
-        default=False,
-        help=('Used in unit testing runners without submitting the '
-              'actual job.'))
 
   def validate(self, validator):
     errors = []
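
[With the validate hook added above, merely setting --streaming now emits a
user-visible warning while no runner supports streaming. A quick sketch of
the observable behavior; validate is called directly for illustration, and
its validator argument is unused by this particular check:]

    import warnings
    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.options.pipeline_options import StandardOptions

    options = PipelineOptions(['--streaming'])
    with warnings.catch_warnings(record=True) as caught:
      warnings.simplefilter('always')
      options.view_as(StandardOptions).validate(None)
    assert str(caught[0].message) == 'Streaming pipelines are not supported.'
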
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/options/pipeline_options_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py
index f4dd4d9..1a644b4 100644
--- a/sdks/python/apache_beam/options/pipeline_options_test.py
+++ b/sdks/python/apache_beam/options/pipeline_options_test.py
@@ -192,52 +192,47 @@ class PipelineOptionsTest(unittest.TestCase):
     options = PipelineOptions(['--redefined_flag'])
     self.assertTrue(options.get_all_options()['redefined_flag'])
 
-  # TODO(BEAM-1319): Require unique names only within a test.
-  # For now, <file name acronym>_vp_arg<number> will be the convention
-  # to name value-provider arguments in tests, as opposed to
-  # <file name acronym>_non_vp_arg<number> for non-value-provider arguments.
-  # The number will grow per file as tests are added.
   def test_value_provider_options(self):
     class UserOptions(PipelineOptions):
       @classmethod
       def _add_argparse_args(cls, parser):
         parser.add_value_provider_argument(
-            '--pot_vp_arg1',
+            '--vp_arg',
            help='This flag is a value provider')
 
         parser.add_value_provider_argument(
-            '--pot_vp_arg2',
+            '--vp_arg2',
             default=1,
             type=int)
 
         parser.add_argument(
-            '--pot_non_vp_arg1',
+            '--non_vp_arg',
             default=1,
             type=int
         )
 
     # Provide values: if not provided, the option becomes of the type runtime vp
-    options = UserOptions(['--pot_vp_arg1', 'hello'])
-    self.assertIsInstance(options.pot_vp_arg1, StaticValueProvider)
-    self.assertIsInstance(options.pot_vp_arg2, RuntimeValueProvider)
-    self.assertIsInstance(options.pot_non_vp_arg1, int)
+    options = UserOptions(['--vp_arg', 'hello'])
+    self.assertIsInstance(options.vp_arg, StaticValueProvider)
+    self.assertIsInstance(options.vp_arg2, RuntimeValueProvider)
+    self.assertIsInstance(options.non_vp_arg, int)
 
     # Values can be overwritten
-    options = UserOptions(pot_vp_arg1=5,
-                          pot_vp_arg2=StaticValueProvider(value_type=str,
-                                                          value='bye'),
-                          pot_non_vp_arg1=RuntimeValueProvider(
+    options = UserOptions(vp_arg=5,
+                          vp_arg2=StaticValueProvider(value_type=str,
+                                                      value='bye'),
+                          non_vp_arg=RuntimeValueProvider(
                               option_name='foo',
                               value_type=int,
                               default_value=10))
-    self.assertEqual(options.pot_vp_arg1, 5)
-    self.assertTrue(options.pot_vp_arg2.is_accessible(),
-                    '%s is not accessible' % options.pot_vp_arg2)
-    self.assertEqual(options.pot_vp_arg2.get(), 'bye')
-    self.assertFalse(options.pot_non_vp_arg1.is_accessible())
+    self.assertEqual(options.vp_arg, 5)
+    self.assertTrue(options.vp_arg2.is_accessible(),
+                    '%s is not accessible' % options.vp_arg2)
+    self.assertEqual(options.vp_arg2.get(), 'bye')
+    self.assertFalse(options.non_vp_arg.is_accessible())
 
     with self.assertRaises(RuntimeError):
-      options.pot_non_vp_arg1.get()
+      options.non_vp_arg.get()
 
 
 if __name__ == '__main__':
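
[Between these two test modules it may help to see the user-facing
value-provider pattern they exercise, in one place. A condensed sketch;
the option name and the runtime value are illustrative, and note that
set_runtime_options mutates process-global state:]

    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.options.value_provider import RuntimeValueProvider

    class MyOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument('--input_path', type=str)

    options = MyOptions([])  # not set: becomes a RuntimeValueProvider
    assert not options.input_path.is_accessible()

    # Later, at job-execution time, the runner supplies the value.
    RuntimeValueProvider.set_runtime_options({'input_path': 'gs://bucket/file'})
    assert options.input_path.get() == 'gs://bucket/file'
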
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/options/value_provider_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/options/value_provider_test.py b/sdks/python/apache_beam/options/value_provider_test.py
index 17e9590..3a45e8b 100644
--- a/sdks/python/apache_beam/options/value_provider_test.py
+++ b/sdks/python/apache_beam/options/value_provider_test.py
@@ -24,77 +24,72 @@ from apache_beam.options.value_provider import RuntimeValueProvider
 from apache_beam.options.value_provider import StaticValueProvider
 
 
-# TODO(BEAM-1319): Require unique names only within a test.
-# For now, <file name acronym>_vp_arg<number> will be the convention
-# to name value-provider arguments in tests, as opposed to
-# <file name acronym>_non_vp_arg<number> for non-value-provider arguments.
-# The number will grow per file as tests are added.
 class ValueProviderTests(unittest.TestCase):
   def test_static_value_provider_keyword_argument(self):
     class UserDefinedOptions(PipelineOptions):
       @classmethod
       def _add_argparse_args(cls, parser):
         parser.add_value_provider_argument(
-            '--vpt_vp_arg1',
+            '--vp_arg',
             help='This keyword argument is a value provider',
             default='some value')
-    options = UserDefinedOptions(['--vpt_vp_arg1', 'abc'])
-    self.assertTrue(isinstance(options.vpt_vp_arg1, StaticValueProvider))
-    self.assertTrue(options.vpt_vp_arg1.is_accessible())
-    self.assertEqual(options.vpt_vp_arg1.get(), 'abc')
+    options = UserDefinedOptions(['--vp_arg', 'abc'])
+    self.assertTrue(isinstance(options.vp_arg, StaticValueProvider))
+    self.assertTrue(options.vp_arg.is_accessible())
+    self.assertEqual(options.vp_arg.get(), 'abc')
 
   def test_runtime_value_provider_keyword_argument(self):
     class UserDefinedOptions(PipelineOptions):
       @classmethod
       def _add_argparse_args(cls, parser):
         parser.add_value_provider_argument(
-            '--vpt_vp_arg2',
+            '--vp_arg',
             help='This keyword argument is a value provider')
     options = UserDefinedOptions()
-    self.assertTrue(isinstance(options.vpt_vp_arg2, RuntimeValueProvider))
-    self.assertFalse(options.vpt_vp_arg2.is_accessible())
+    self.assertTrue(isinstance(options.vp_arg, RuntimeValueProvider))
+    self.assertFalse(options.vp_arg.is_accessible())
     with self.assertRaises(RuntimeError):
-      options.vpt_vp_arg2.get()
+      options.vp_arg.get()
 
   def test_static_value_provider_positional_argument(self):
     class UserDefinedOptions(PipelineOptions):
       @classmethod
       def _add_argparse_args(cls, parser):
         parser.add_value_provider_argument(
-            'vpt_vp_arg3',
+            'vp_pos_arg',
             help='This positional argument is a value provider',
             default='some value')
     options = UserDefinedOptions(['abc'])
-    self.assertTrue(isinstance(options.vpt_vp_arg3, StaticValueProvider))
-    self.assertTrue(options.vpt_vp_arg3.is_accessible())
-    self.assertEqual(options.vpt_vp_arg3.get(), 'abc')
+    self.assertTrue(isinstance(options.vp_pos_arg, StaticValueProvider))
+    self.assertTrue(options.vp_pos_arg.is_accessible())
+    self.assertEqual(options.vp_pos_arg.get(), 'abc')
 
   def test_runtime_value_provider_positional_argument(self):
     class UserDefinedOptions(PipelineOptions):
       @classmethod
       def _add_argparse_args(cls, parser):
         parser.add_value_provider_argument(
-            'vpt_vp_arg4',
+            'vp_pos_arg',
             help='This positional argument is a value provider')
     options = UserDefinedOptions([])
-    self.assertTrue(isinstance(options.vpt_vp_arg4, RuntimeValueProvider))
-    self.assertFalse(options.vpt_vp_arg4.is_accessible())
+    self.assertTrue(isinstance(options.vp_pos_arg, RuntimeValueProvider))
+    self.assertFalse(options.vp_pos_arg.is_accessible())
     with self.assertRaises(RuntimeError):
-      options.vpt_vp_arg4.get()
+      options.vp_pos_arg.get()
 
   def test_static_value_provider_type_cast(self):
     class UserDefinedOptions(PipelineOptions):
       @classmethod
       def _add_argparse_args(cls, parser):
         parser.add_value_provider_argument(
-            '--vpt_vp_arg5',
+            '--vp_arg',
             type=int,
             help='This flag is a value provider')
 
-    options = UserDefinedOptions(['--vpt_vp_arg5', '123'])
-    self.assertTrue(isinstance(options.vpt_vp_arg5, StaticValueProvider))
-    self.assertTrue(options.vpt_vp_arg5.is_accessible())
-    self.assertEqual(options.vpt_vp_arg5.get(), 123)
+    options = UserDefinedOptions(['--vp_arg', '123'])
+    self.assertTrue(isinstance(options.vp_arg, StaticValueProvider))
+    self.assertTrue(options.vp_arg.is_accessible())
+    self.assertEqual(options.vp_arg.get(), 123)
 
   def test_set_runtime_option(self):
     # define ValueProvider options, with and without default values
@@ -102,25 +97,25 @@ class ValueProviderTests(unittest.TestCase):
       @classmethod
       def _add_argparse_args(cls, parser):
         parser.add_value_provider_argument(
-            '--vpt_vp_arg6',
+            '--vp_arg',
             help='This keyword argument is a value provider')   # set at runtime
 
         parser.add_value_provider_argument(         # not set, had default int
-            '-v', '--vpt_vp_arg7',                      # with short form
+            '-v', '--vp_arg2',                      # with short form
             default=123,
             type=int)
 
         parser.add_value_provider_argument(         # not set, had default str
-            '--vpt_vp-arg8',                        # with dash in name
+            '--vp-arg3',                            # with dash in name
             default='123',
             type=str)
 
         parser.add_value_provider_argument(         # not set and no default
-            '--vpt_vp_arg9',
+            '--vp_arg4',
             type=float)
 
         parser.add_value_provider_argument(         # positional argument set
-            'vpt_vp_arg10',                         # default & runtime ignored
+            'vp_pos_arg',                           # default & runtime ignored
             help='This positional argument is a value provider',
             type=float,
             default=5.4)
 
     # provide values at graph-construction time
     # (options not provided here become of the type RuntimeValueProvider)
     options = UserDefinedOptions1(['1.2'])
-    self.assertFalse(options.vpt_vp_arg6.is_accessible())
-    self.assertFalse(options.vpt_vp_arg7.is_accessible())
-    self.assertFalse(options.vpt_vp_arg8.is_accessible())
-    self.assertFalse(options.vpt_vp_arg9.is_accessible())
-    self.assertTrue(options.vpt_vp_arg10.is_accessible())
+    self.assertFalse(options.vp_arg.is_accessible())
+    self.assertFalse(options.vp_arg2.is_accessible())
+    self.assertFalse(options.vp_arg3.is_accessible())
+    self.assertFalse(options.vp_arg4.is_accessible())
+    self.assertTrue(options.vp_pos_arg.is_accessible())
 
     # provide values at job-execution time
     # (options not provided here will use their default, if they have one)
-    RuntimeValueProvider.set_runtime_options({'vpt_vp_arg6': 'abc',
-                                              'vpt_vp_arg10':'3.2'})
-    self.assertTrue(options.vpt_vp_arg6.is_accessible())
-    self.assertEqual(options.vpt_vp_arg6.get(), 'abc')
-    self.assertTrue(options.vpt_vp_arg7.is_accessible())
-    self.assertEqual(options.vpt_vp_arg7.get(), 123)
-    self.assertTrue(options.vpt_vp_arg8.is_accessible())
-    self.assertEqual(options.vpt_vp_arg8.get(), '123')
-    self.assertTrue(options.vpt_vp_arg9.is_accessible())
-    self.assertIsNone(options.vpt_vp_arg9.get())
-    self.assertTrue(options.vpt_vp_arg10.is_accessible())
-    self.assertEqual(options.vpt_vp_arg10.get(), 1.2)
+    RuntimeValueProvider.set_runtime_options({'vp_arg': 'abc',
+                                              'vp_pos_arg':'3.2'})
+    self.assertTrue(options.vp_arg.is_accessible())
+    self.assertEqual(options.vp_arg.get(), 'abc')
+    self.assertTrue(options.vp_arg2.is_accessible())
+    self.assertEqual(options.vp_arg2.get(), 123)
+    self.assertTrue(options.vp_arg3.is_accessible())
+    self.assertEqual(options.vp_arg3.get(), '123')
+    self.assertTrue(options.vp_arg4.is_accessible())
+    self.assertIsNone(options.vp_arg4.get())
+    self.assertTrue(options.vp_pos_arg.is_accessible())
+    self.assertEqual(options.vp_pos_arg.get(), 1.2)
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/pipeline.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index fe36d85..9093abf 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -45,7 +45,6 @@ Typical usage:
 
 from __future__ import absolute_import
 
-import abc
 import collections
 import logging
 import os
@@ -54,7 +53,6 @@ import tempfile
 
 from apache_beam import pvalue
 from apache_beam.internal import pickler
-from apache_beam.pvalue import PCollection
 from apache_beam.runners import create_runner
 from apache_beam.runners import PipelineRunner
 from apache_beam.transforms import ptransform
@@ -159,157 +157,6 @@ class Pipeline(object):
     """Returns the root transform of the transform stack."""
     return self.transforms_stack[0]
 
-  def _remove_labels_recursively(self, applied_transform):
-    for part in applied_transform.parts:
-      if part.full_label in self.applied_labels:
-        self.applied_labels.remove(part.full_label)
-        if part.parts:
-          for part2 in part.parts:
-            self._remove_labels_recursively(part2)
-
-  def _replace(self, override):
-
-    assert isinstance(override, PTransformOverride)
-    matcher = override.get_matcher()
-
-    output_map = {}
-    output_replacements = {}
-    input_replacements = {}
-
-    class TransformUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
-      """"A visitor that replaces the matching PTransforms."""
-
-      def __init__(self, pipeline):
-        self.pipeline = pipeline
-
-      def _replace_if_needed(self, transform_node):
-        if matcher(transform_node):
-          replacement_transform = override.get_replacement_transform(
-              transform_node.transform)
-          inputs = transform_node.inputs
-          # TODO: Support replacing PTransforms with multiple inputs.
-          if len(inputs) > 1:
-            raise NotImplementedError(
-                'PTransform overriding is only supported for PTransforms that '
-                'have a single input. Tried to replace input of '
-                'AppliedPTransform %r that has %d inputs',
-                transform_node, len(inputs))
-          transform_node.transform = replacement_transform
-          self.pipeline.transforms_stack.append(transform_node)
-
-          # Keeping the same label for the replaced node but recursively
-          # removing labels of child transforms since they will be replaced
-          # during the expand below.
-          self.pipeline._remove_labels_recursively(transform_node)
-
-          new_output = replacement_transform.expand(inputs[0])
-          if new_output.producer is None:
-            # When current transform is a primitive, we set the producer here.
-            new_output.producer = transform_node
-
-          # We only support replacing transforms with a single output with
-          # another transform that produces a single output.
-          # TODO: Support replacing PTransforms with multiple outputs.
-          if (len(transform_node.outputs) > 1 or
-              not isinstance(transform_node.outputs[None], PCollection) or
-              not isinstance(new_output, PCollection)):
-            raise NotImplementedError(
-                'PTransform overriding is only supported for PTransforms that '
-                'have a single output. Tried to replace output of '
-                'AppliedPTransform %r with %r.'
-                , transform_node, new_output)
-
-          # Recording updated outputs. This cannot be done in the same visitor
-          # since if we dynamically update output type here, we'll run into
-          # errors when visiting child nodes.
-          output_map[transform_node.outputs[None]] = new_output
-
-          self.pipeline.transforms_stack.pop()
-
-      def enter_composite_transform(self, transform_node):
-        self._replace_if_needed(transform_node)
-
-      def visit_transform(self, transform_node):
-        self._replace_if_needed(transform_node)
-
-    self.visit(TransformUpdater(self))
-
-    # Adjusting inputs and outputs
-    class InputOutputUpdater(PipelineVisitor): # pylint: disable=used-before-assignment
-      """"A visitor that records input and output values to be replaced.
-
-      Input and output values that should be updated are recorded in maps
-      input_replacements and output_replacements respectively.
-
-      We cannot update input and output values while visiting since that results
-      in validation errors.
-      """
-
-      def __init__(self, pipeline):
-        self.pipeline = pipeline
-
-      def enter_composite_transform(self, transform_node):
-        self.visit_transform(transform_node)
-
-      def visit_transform(self, transform_node):
-        if (None in transform_node.outputs and
-            transform_node.outputs[None] in output_map):
-          output_replacements[transform_node] = (
-              output_map[transform_node.outputs[None]])
-
-        replace_input = False
-        for input in transform_node.inputs:
-          if input in output_map:
-            replace_input = True
-            break
-
-        if replace_input:
-          new_input = [
-              input if not input in output_map else output_map[input]
-              for input in transform_node.inputs]
-          input_replacements[transform_node] = new_input
-
-    self.visit(InputOutputUpdater(self))
-
-    for transform in output_replacements:
-      transform.replace_output(output_replacements[transform])
-
-    for transform in input_replacements:
-      transform.inputs = input_replacements[transform]
-
-  def _check_replacement(self, override):
-    matcher = override.get_matcher()
-
-    class ReplacementValidator(PipelineVisitor):
-      def visit_transform(self, transform_node):
-        if matcher(transform_node):
-          raise RuntimeError('Transform node %r was not replaced as expected.',
-                             transform_node)
-
-    self.visit(ReplacementValidator())
-
-  def replace_all(self, replacements):
-    """ Dynamically replaces PTransforms in the currently populated hierarchy.
-
-    Currently this only works for replacements where input and output types
-    are exactly the same.
-    TODO: Update this to also work for transform overrides where input and
-    output types are different.
-
-    Args:
-      replacements a list of PTransformOverride objects.
-    """
-    for override in replacements:
-      assert isinstance(override, PTransformOverride)
-      self._replace(override)
-
-    # Checking if the PTransforms have been successfully replaced. This will
-    # result in a failure if a PTransform that was replaced in a given override
-    # gets re-added in a subsequent override. This is not allowed and ordering
-    # of PTransformOverride objects in 'replacements' is important.
-    for override in replacements:
-      self._check_replacement(override)
-
   def run(self, test_runner_api=True):
     """Runs the pipeline. Returns whatever our runner returns after running."""
 
@@ -466,20 +313,10 @@ class Pipeline(object):
       self.transforms_stack.pop()
     return pvalueish_result
 
-  def __reduce__(self):
-    # Some transforms contain a reference to their enclosing pipeline,
-    # which in turn reference all other transforms (resulting in quadratic
-    # time/space to pickle each transform individually). As we don't
-    # require pickled pipelines to be executable, break the chain here.
-    return str, ('Pickled pipeline stub.',)
-
   def _verify_runner_api_compatible(self):
     class Visitor(PipelineVisitor):  # pylint: disable=used-before-assignment
       ok = True  # Really a nonlocal.
 
-      def enter_composite_transform(self, transform_node):
-        self.visit_transform(transform_node)
-
       def visit_transform(self, transform_node):
         if transform_node.side_inputs:
           # No side inputs (yet).
@@ -502,7 +339,7 @@ class Pipeline(object):
   def to_runner_api(self):
     """For internal use only; no backwards-compatibility guarantees."""
     from apache_beam.runners import pipeline_context
-    from apache_beam.portability.api import beam_runner_api_pb2
+    from apache_beam.runners.api import beam_runner_api_pb2
    context = pipeline_context.PipelineContext()
     # Mutates context; placing inline would force dependence on
     # argument evaluation order.
@@ -525,18 +362,7 @@ class Pipeline(object):
     p.applied_labels = set([
         t.unique_name for t in proto.components.transforms.values()])
     for id in proto.components.pcollections:
-      pcollection = context.pcollections.get_by_id(id)
-      pcollection.pipeline = p
-
-    # Inject PBegin input where necessary.
-    from apache_beam.io.iobase import Read
-    from apache_beam.transforms.core import Create
-    has_pbegin = [Read, Create]
-    for id in proto.components.transforms:
-      transform = context.transforms.get_by_id(id)
-      if not transform.inputs and transform.transform.__class__ in has_pbegin:
-        transform.inputs = (pvalue.PBegin(p),)
-
+      context.pcollections.get_by_id(id).pipeline = p
     return p
 
 
@@ -558,7 +384,7 @@ class PipelineVisitor(object):
     pass
 
   def visit_transform(self, transform_node):
-    """Callback for visiting a transform leaf node in the pipeline DAG."""
+    """Callback for visiting a transform node in the pipeline DAG."""
     pass
 
   def enter_composite_transform(self, transform_node):
@@ -615,20 +441,6 @@ class AppliedPTransform(object):
       for side_input in self.side_inputs:
         real_producer(side_input.pvalue).refcounts[side_input.pvalue.tag] += 1
 
-  def replace_output(self, output, tag=None):
-    """Replaces the output defined by the given tag with the given output.
-
-    Args:
-      output: replacement output
-      tag: tag of the output to be replaced.
-    """
-    if isinstance(output, pvalue.DoOutputsTuple):
-      self.replace_output(output[output._main_tag])
-    elif isinstance(output, pvalue.PValue):
-      self.outputs[tag] = output
-    else:
-      raise TypeError("Unexpected output type: %s" % output)
-
   def add_output(self, output, tag=None):
     if isinstance(output, pvalue.DoOutputsTuple):
       self.add_output(output[output._main_tag])
@@ -713,7 +525,7 @@ class AppliedPTransform(object):
             if isinstance(output, pvalue.PCollection)}
 
   def to_runner_api(self, context):
-    from apache_beam.portability.api import beam_runner_api_pb2
+    from apache_beam.runners.api import beam_runner_api_pb2
 
     def transform_to_runner_api(transform, context):
       if transform is None:
@@ -752,37 +564,3 @@ class AppliedPTransform(object):
         pc.tag = tag
     result.update_input_refcounts()
     return result
-
-
-class PTransformOverride(object):
-  """For internal use only; no backwards-compatibility guarantees.
-
-  Gives a matcher and replacements for matching PTransforms.
-
-  TODO: Update this to support cases where input and/our output types are
-  different.
-  """
-  __metaclass__ = abc.ABCMeta
-
-  @abc.abstractmethod
-  def get_matcher(self):
-    """Gives a matcher that will be used to to perform this override.
-
-    Returns:
-      a callable that takes an AppliedPTransform as a parameter and returns a
-      boolean as a result.
-    """
-    raise NotImplementedError
-
-  @abc.abstractmethod
-  def get_replacement_transform(self, ptransform):
-    """Provides a runner specific override for a given PTransform.
-
-    Args:
-      ptransform: PTransform to be replaced.
-    Returns:
-      A PTransform that will be the replacement for the PTransform given as an
-      argument.
-    """
-    # Returns a PTransformReplacement
-    raise NotImplementedError
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/portability/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/portability/__init__.py b/sdks/python/apache_beam/portability/__init__.py
deleted file mode 100644
index 0bce5d6..0000000
--- a/sdks/python/apache_beam/portability/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""For internal use only; no backwards-compatibility guarantees."""

http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/portability/api/__init__.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/portability/api/__init__.py b/sdks/python/apache_beam/portability/api/__init__.py
deleted file mode 100644
index 2750859..0000000
--- a/sdks/python/apache_beam/portability/api/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""For internal use only; no backwards-compatibility guarantees.
-
-Automatically generated when running setup.py sdist or build[_py].
-"""
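The net effect of deleting these two package stubs (together with the new runners/api/__init__.py below) is that the generated runner API protos move back under apache_beam.runners.api, so every import site in this patch changes the same way:

    # Before this commit:
    from apache_beam.portability.api import beam_runner_api_pb2
    # After this commit:
    from apache_beam.runners.api import beam_runner_api_pb2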
-""" http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/pvalue.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index 34a483e..7385e82 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -128,7 +128,7 @@ class PCollection(PValue): return _InvalidUnpickledPCollection, () def to_runner_api(self, context): - from apache_beam.portability.api import beam_runner_api_pb2 + from apache_beam.runners.api import beam_runner_api_pb2 from apache_beam.internal import pickler return beam_runner_api_pb2.PCollection( unique_name='%d%s.%s' % ( http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/api/__init__.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/api/__init__.py b/sdks/python/apache_beam/runners/api/__init__.py new file mode 100644 index 0000000..2750859 --- /dev/null +++ b/sdks/python/apache_beam/runners/api/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""For internal use only; no backwards-compatibility guarantees. + +Automatically generated when running setup.py sdist or build[_py]. 
+""" http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py index 059e139..3fc8983 100644 --- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py +++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py @@ -39,17 +39,13 @@ from apache_beam.runners.dataflow.internal import names from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api from apache_beam.runners.dataflow.internal.names import PropertyNames from apache_beam.runners.dataflow.internal.names import TransformNames -from apache_beam.runners.dataflow.ptransform_overrides import CreatePTransformOverride from apache_beam.runners.runner import PValueCache from apache_beam.runners.runner import PipelineResult from apache_beam.runners.runner import PipelineRunner from apache_beam.runners.runner import PipelineState from apache_beam.transforms.display import DisplayData from apache_beam.typehints import typehints -from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import StandardOptions -from apache_beam.options.pipeline_options import TestOptions -from apache_beam.utils.plugin import BeamPlugin __all__ = ['DataflowRunner'] @@ -65,15 +61,11 @@ class DataflowRunner(PipelineRunner): if blocking is set to False. """ - # A list of PTransformOverride objects to be applied before running a pipeline - # using DataflowRunner. - # Currently this only works for overrides where the input and output types do - # not change. - # For internal SDK use only. This should not be updated by Beam pipeline - # authors. - _PTRANSFORM_OVERRIDES = [ - CreatePTransformOverride(), - ] + # Environment version information. It is passed to the service during a + # a job submission and is used by the service to establish what features + # are expected by the workers. + BATCH_ENVIRONMENT_MAJOR_VERSION = '6' + STREAMING_ENVIRONMENT_MAJOR_VERSION = '1' def __init__(self, cache=None): # Cache of CloudWorkflowStep protos generated while the runner @@ -223,6 +215,7 @@ class DataflowRunner(PipelineRunner): return FlattenInputVisitor() + # TODO(mariagh): Make this method take pipepline_options def run(self, pipeline): """Remotely executes entire pipeline or parts reachable from node.""" # Import here to avoid adding the dependency for local running scenarios. @@ -233,17 +226,6 @@ class DataflowRunner(PipelineRunner): raise ImportError( 'Google Cloud Dataflow runner not available, ' 'please install apache_beam[gcp]') - - # Performing configured PTransform overrides. - pipeline.replace_all(DataflowRunner._PTRANSFORM_OVERRIDES) - - # Add setup_options for all the BeamPlugin imports - setup_options = pipeline._options.view_as(SetupOptions) - plugins = BeamPlugin.get_all_plugin_paths() - if setup_options.beam_plugins is not None: - plugins = list(set(plugins + setup_options.beam_plugins)) - setup_options.beam_plugins = plugins - self.job = apiclient.Job(pipeline._options) # Dataflow runner requires a KV type for GBK inputs, hence we enforce that @@ -257,14 +239,15 @@ class DataflowRunner(PipelineRunner): # The superclass's run will trigger a traversal of all reachable nodes. 
     super(DataflowRunner, self).run(pipeline)
 
-    test_options = pipeline._options.view_as(TestOptions)
-    # If it is a dry run, return without submitting the job.
-    if test_options.dry_run:
-      return None
+    standard_options = pipeline._options.view_as(StandardOptions)
+    if standard_options.streaming:
+      job_version = DataflowRunner.STREAMING_ENVIRONMENT_MAJOR_VERSION
+    else:
+      job_version = DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION
 
     # Get a Dataflow API client and set its options
     self.dataflow_client = apiclient.DataflowApplicationClient(
-        pipeline._options)
+        pipeline._options, job_version)
 
     # Create the job
     result = DataflowPipelineResult(
@@ -377,26 +360,6 @@ class DataflowRunner(PipelineRunner):
         PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])
     return step
 
-  def run_Impulse(self, transform_node):
-    standard_options = (
-        transform_node.outputs[None].pipeline._options.view_as(StandardOptions))
-    if standard_options.streaming:
-      step = self._add_step(
-          TransformNames.READ, transform_node.full_label, transform_node)
-      step.add_property(PropertyNames.FORMAT, 'pubsub')
-      step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION, '_starting_signal/')
-
-      step.encoding = self._get_encoded_output_coder(transform_node)
-      step.add_property(
-          PropertyNames.OUTPUT_INFO,
-          [{PropertyNames.USER_NAME: (
-              '%s.%s' % (
-                  transform_node.full_label, PropertyNames.OUT)),
-            PropertyNames.ENCODING: step.encoding,
-            PropertyNames.OUTPUT_NAME: PropertyNames.OUT}])
-    else:
-      ValueError('Impulse source for batch pipelines has not been defined.')
-
   def run_Flatten(self, transform_node):
     step = self._add_step(TransformNames.FLATTEN,
                           transform_node.full_label, transform_node)
@@ -655,13 +618,10 @@ class DataflowRunner(PipelineRunner):
     if not standard_options.streaming:
       raise ValueError('PubSubPayloadSource is currently available for use '
                        'only in streaming pipelines.')
-    # Only one of topic or subscription should be set.
-    if transform.source.full_subscription:
+    step.add_property(PropertyNames.PUBSUB_TOPIC, transform.source.topic)
+    if transform.source.subscription:
       step.add_property(PropertyNames.PUBSUB_SUBSCRIPTION,
-                        transform.source.full_subscription)
-    elif transform.source.full_topic:
-      step.add_property(PropertyNames.PUBSUB_TOPIC,
-                        transform.source.full_topic)
+                        transform.source.topic)
     if transform.source.id_label:
       step.add_property(PropertyNames.PUBSUB_ID_LABEL,
                         transform.source.id_label)
@@ -679,12 +639,7 @@ class DataflowRunner(PipelineRunner):
     # step should be the type of value outputted by each step. Read steps
     # automatically wrap output values in a WindowedValue wrapper, if necessary.
     # This is also necessary for proper encoding for size estimation.
-    # Using a GlobalWindowCoder as a place holder instead of the default
-    # PickleCoder because GlobalWindowCoder is known coder.
-    # TODO(robertwb): Query the collection for the windowfn to extract the
-    # correct coder.
-    coder = coders.WindowedValueCoder(transform._infer_output_coder(),
-                                      coders.coders.GlobalWindowCoder())  # pylint: disable=protected-access
+    coder = coders.WindowedValueCoder(transform._infer_output_coder())  # pylint: disable=protected-access
 
     step.encoding = self._get_cloud_encoding(coder)
     step.add_property(
@@ -745,7 +700,7 @@ class DataflowRunner(PipelineRunner):
     if not standard_options.streaming:
       raise ValueError('PubSubPayloadSink is currently available for use '
                        'only in streaming pipelines.')
-    step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.full_topic)
+    step.add_property(PropertyNames.PUBSUB_TOPIC, transform.sink.topic)
   else:
     raise ValueError(
         'Sink %r has unexpected format %s.' % (
@@ -753,12 +708,8 @@ class DataflowRunner(PipelineRunner):
     step.add_property(PropertyNames.FORMAT, transform.sink.format)
 
     # Wrap coder in WindowedValueCoder: this is necessary for proper encoding
-    # for size estimation. Using a GlobalWindowCoder as a place holder instead
-    # of the default PickleCoder because GlobalWindowCoder is known coder.
-    # TODO(robertwb): Query the collection for the windowfn to extract the
-    # correct coder.
-    coder = coders.WindowedValueCoder(transform.sink.coder,
-                                      coders.coders.GlobalWindowCoder())
+    # for size estimation.
+    coder = coders.WindowedValueCoder(transform.sink.coder)
     step.encoding = self._get_cloud_encoding(coder)
     step.add_property(PropertyNames.ENCODING, step.encoding)
     step.add_property(
@@ -770,7 +721,7 @@ class DataflowRunner(PipelineRunner):
   @classmethod
   def serialize_windowing_strategy(cls, windowing):
     from apache_beam.runners import pipeline_context
-    from apache_beam.portability.api import beam_runner_api_pb2
+    from apache_beam.runners.api import beam_runner_api_pb2
     context = pipeline_context.PipelineContext()
     windowing_proto = windowing.to_runner_api(context)
     return cls.byte_array_to_json_string(
@@ -783,7 +734,7 @@ class DataflowRunner(PipelineRunner):
     # Imported here to avoid circular dependencies.
     # pylint: disable=wrong-import-order, wrong-import-position
     from apache_beam.runners import pipeline_context
-    from apache_beam.portability.api import beam_runner_api_pb2
+    from apache_beam.runners.api import beam_runner_api_pb2
     from apache_beam.transforms.core import Windowing
     proto = beam_runner_api_pb2.MessageWithComponents()
     proto.ParseFromString(cls.json_string_to_byte_array(serialized_data))

http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index a9b8fdb..74fd01d 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -59,8 +59,7 @@ class DataflowRunnerTest(unittest.TestCase):
         '--project=test-project',
         '--staging_location=ignored',
         '--temp_location=/dev/null',
-        '--no_auth=True',
-        '--dry_run=True']
+        '--no_auth=True']
 
   @mock.patch('time.sleep', return_value=None)
   def test_wait_until_finish(self, patched_time_sleep):
@@ -109,22 +108,8 @@ class DataflowRunnerTest(unittest.TestCase):
     (p | ptransform.Create([1, 2, 3])  # pylint: disable=expression-not-assigned
      | 'Do' >> ptransform.FlatMap(lambda x: [(x, x)])
      | ptransform.GroupByKey())
-    p.run()
-
-  def test_streaming_create_translation(self):
-    remote_runner = DataflowRunner()
-    self.default_properties.append("--streaming")
-    p = Pipeline(remote_runner, PipelineOptions(self.default_properties))
-    p | ptransform.Create([1])  # pylint: disable=expression-not-assigned
-    p.run()
-    job_dict = json.loads(str(remote_runner.job))
-    self.assertEqual(len(job_dict[u'steps']), 2)
-
-    self.assertEqual(job_dict[u'steps'][0][u'kind'], u'ParallelRead')
-    self.assertEqual(
-        job_dict[u'steps'][0][u'properties'][u'pubsub_subscription'],
-        '_starting_signal/')
-    self.assertEqual(job_dict[u'steps'][1][u'kind'], u'ParallelDo')
+    remote_runner.job = apiclient.Job(p._options)
+    super(DataflowRunner, remote_runner).run(p)
 
   def test_remote_runner_display_data(self):
     remote_runner = DataflowRunner()
@@ -157,7 +142,8 @@ class DataflowRunnerTest(unittest.TestCase):
     (p | ptransform.Create([1, 2, 3, 4, 5])
      | 'Do' >> SpecialParDo(SpecialDoFn(), now))
-    p.run()
+    remote_runner.job = apiclient.Job(p._options)
+    super(DataflowRunner, remote_runner).run(p)
 
     job_dict = json.loads(str(remote_runner.job))
     steps = [step for step in job_dict['steps']
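With _use_fnapi gone (see apiclient.py below), the environment major version is again chosen purely by the streaming flag. Condensed from the new DataflowRunner.run() body above into a standalone sketch (make_dataflow_client is an illustrative wrapper, not part of the patch):

    from apache_beam.options.pipeline_options import StandardOptions
    from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
    from apache_beam.runners.dataflow.internal import apiclient


    def make_dataflow_client(pipeline):
      # Streaming jobs report environment major version '1', batch jobs '6'.
      if pipeline._options.view_as(StandardOptions).streaming:
        job_version = DataflowRunner.STREAMING_ENVIRONMENT_MAJOR_VERSION
      else:
        job_version = DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION
      return apiclient.DataflowApplicationClient(pipeline._options, job_version)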
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
index 33dfe19..df1a3f2 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
@@ -38,6 +38,7 @@ from apache_beam.io.filesystems import FileSystems
 from apache_beam.io.gcp.internal.clients import storage
 from apache_beam.runners.dataflow.internal import dependency
 from apache_beam.runners.dataflow.internal.clients import dataflow
+from apache_beam.runners.dataflow.internal.dependency import get_required_container_version
 from apache_beam.runners.dataflow.internal.dependency import get_sdk_name_and_version
 from apache_beam.runners.dataflow.internal.names import PropertyNames
 from apache_beam.transforms import cy_combiners
@@ -49,13 +50,6 @@ from apache_beam.options.pipeline_options import StandardOptions
 from apache_beam.options.pipeline_options import WorkerOptions
 
 
-# Environment version information. It is passed to the service during a
-# a job submission and is used by the service to establish what features
-# are expected by the workers.
-_LEGACY_ENVIRONMENT_MAJOR_VERSION = '6'
-_FNAPI_ENVIRONMENT_MAJOR_VERSION = '1'
-
-
 class Step(object):
   """Wrapper for a dataflow Step protobuf."""
 
@@ -153,10 +147,7 @@ class Environment(object):
     if self.standard_options.streaming:
       job_type = 'FNAPI_STREAMING'
     else:
-      if _use_fnapi(options):
-        job_type = 'FNAPI_BATCH'
-      else:
-        job_type = 'PYTHON_BATCH'
+      job_type = 'PYTHON_BATCH'
     self.proto.version.additionalProperties.extend([
         dataflow.Environment.VersionValue.AdditionalProperty(
             key='job_type',
@@ -214,8 +205,11 @@ class Environment(object):
       pool.workerHarnessContainerImage = (
          self.worker_options.worker_harness_container_image)
     else:
+      # Default to using the worker harness container image for the current SDK
+      # version.
       pool.workerHarnessContainerImage = (
-          dependency.get_default_container_image_for_current_sdk(job_type))
+          'dataflow.gcr.io/v1beta3/python:%s' %
+          get_required_container_version())
     if self.worker_options.use_public_ips is not None:
       if self.worker_options.use_public_ips:
         pool.ipConfiguration = (
@@ -370,16 +364,11 @@ class Job(object):
 class DataflowApplicationClient(object):
   """A Dataflow API client used by application code to create and query jobs."""
 
-  def __init__(self, options):
+  def __init__(self, options, environment_version):
     """Initializes a Dataflow API client object."""
     self.standard_options = options.view_as(StandardOptions)
     self.google_cloud_options = options.view_as(GoogleCloudOptions)
-
-    if _use_fnapi(options):
-      self.environment_version = _FNAPI_ENVIRONMENT_MAJOR_VERSION
-    else:
-      self.environment_version = _LEGACY_ENVIRONMENT_MAJOR_VERSION
-
+    self.environment_version = environment_version
    if self.google_cloud_options.no_auth:
      credentials = None
    else:
@@ -721,14 +710,6 @@ def translate_mean(accumulator, metric_update):
     metric_update.kind = None
 
 
-def _use_fnapi(pipeline_options):
-  standard_options = pipeline_options.view_as(StandardOptions)
-  debug_options = pipeline_options.view_as(DebugOptions)
-
-  return standard_options.streaming or (
-      debug_options.experiments and 'beam_fn_api' in debug_options.experiments)
-
-
 # To enable a counter on the service, add it to this dictionary.
 metric_translations = {
     cy_combiners.CountCombineFn: ('sum', translate_scalar),
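After this change the default worker harness image is derived directly from the required container version instead of going through dependency.get_default_container_image_for_current_sdk. Equivalent standalone form of the new code above (a sketch; the version string is whatever get_required_container_version() reports for this SDK build):

    from apache_beam.runners.dataflow.internal.dependency import (
        get_required_container_version)

    image = 'dataflow.gcr.io/v1beta3/python:%s' % get_required_container_version()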
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b2b96a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
index 407ffcf..67cf77f 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
@@ -22,6 +22,7 @@ from mock import Mock
 
 from apache_beam.metrics.cells import DistributionData
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
 from apache_beam.runners.dataflow.internal.clients import dataflow
 
 # Protect against environments where apitools library is not available.
@@ -39,7 +40,9 @@ class UtilTest(unittest.TestCase):
   @unittest.skip("Enable once BEAM-1080 is fixed.")
   def test_create_application_client(self):
     pipeline_options = PipelineOptions()
-    apiclient.DataflowApplicationClient(pipeline_options)
+    apiclient.DataflowApplicationClient(
+        pipeline_options,
+        DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION)
 
   def test_set_network(self):
     pipeline_options = PipelineOptions(
@@ -119,30 +122,6 @@ class UtilTest(unittest.TestCase):
     self.assertEqual(
         metric_update.floatingPointMean.count.lowBits, accumulator.count)
 
-  def test_default_ip_configuration(self):
-    pipeline_options = PipelineOptions(
-        ['--temp_location', 'gs://any-location/temp'])
-    env = apiclient.Environment([], pipeline_options, '2.0.0')
-    self.assertEqual(env.proto.workerPools[0].ipConfiguration, None)
-
-  def test_public_ip_configuration(self):
-    pipeline_options = PipelineOptions(
-        ['--temp_location', 'gs://any-location/temp',
-         '--use_public_ips'])
-    env = apiclient.Environment([], pipeline_options, '2.0.0')
-    self.assertEqual(
-        env.proto.workerPools[0].ipConfiguration,
-        dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PUBLIC)
-
-  def test_private_ip_configuration(self):
-    pipeline_options = PipelineOptions(
-        ['--temp_location', 'gs://any-location/temp',
-         '--no_use_public_ips'])
-    env = apiclient.Environment([], pipeline_options, '2.0.0')
-    self.assertEqual(
-        env.proto.workerPools[0].ipConfiguration,
-        dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE)
-
 
 if __name__ == '__main__':
   unittest.main()
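Because DataflowApplicationClient now requires an explicit environment_version argument, callers outside DataflowRunner.run() construct it the way the updated test does; for example (a sketch mirroring test_create_application_client above):

    from apache_beam.options.pipeline_options import PipelineOptions
    from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
    from apache_beam.runners.dataflow.internal import apiclient

    client = apiclient.DataflowApplicationClient(
        PipelineOptions(), DataflowRunner.BATCH_ENVIRONMENT_MAJOR_VERSION)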
