http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/iobase.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py new file mode 100644 index 0000000..26ebeb5 --- /dev/null +++ b/sdks/python/apache_beam/io/iobase.py @@ -0,0 +1,1073 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sources and sinks. + +A Source manages record-oriented data input from a particular kind of source +(e.g. a set of files, a database table, etc.). The reader() method of a source +returns a reader object supporting the iterator protocol; iteration yields +raw records of unprocessed, serialized data. + + +A Sink manages record-oriented data output to a particular kind of sink +(e.g. a set of files, a database table, etc.). The writer() method of a sink +returns a writer object supporting writing records of serialized data to +the sink. +""" + +from collections import namedtuple + +import logging +import random +import uuid + +from google.cloud.dataflow import pvalue +from google.cloud.dataflow.coders import PickleCoder +from google.cloud.dataflow.pvalue import AsIter +from google.cloud.dataflow.pvalue import AsSingleton +from google.cloud.dataflow.transforms import core +from google.cloud.dataflow.transforms import ptransform +from google.cloud.dataflow.transforms import window + + +def _dict_printable_fields(dict_object, skip_fields): + """Returns a list of strings for the interesting fields of a dict.""" + return ['%s=%r' % (name, value) + for name, value in dict_object.iteritems() + # want to output value 0 but not None nor [] + if (value or value == 0) + and name not in skip_fields] + +_minor_fields = ['coder', 'key_coder', 'value_coder', + 'config_bytes', 'elements', + 'append_trailing_newlines', 'strip_trailing_newlines', + 'compression_type'] + + +class NativeSource(object): + """A source implemented by Dataflow service. + + This class is to be only inherited by sources natively implemented by Cloud + Dataflow service, hence should not be sub-classed by users. + + This class is deprecated and should not be used to define new sources. 
+ """ + + def reader(self): + """Returns a NativeSourceReader instance associated with this source.""" + raise NotImplementedError + + def __repr__(self): + return '<{name} {vals}>'.format( + name=self.__class__.__name__, + vals=', '.join(_dict_printable_fields(self.__dict__, + _minor_fields))) + + +class NativeSourceReader(object): + """A reader for a source implemented by Dataflow service.""" + + def __enter__(self): + """Opens everything necessary for a reader to function properly.""" + raise NotImplementedError + + def __exit__(self, exception_type, exception_value, traceback): + """Cleans up after a reader executed.""" + raise NotImplementedError + + def __iter__(self): + """Returns an iterator over all the records of the source.""" + raise NotImplementedError + + @property + def returns_windowed_values(self): + """Returns whether this reader returns windowed values.""" + return False + + def get_progress(self): + """Returns a representation of how far the reader has read. + + Returns: + A SourceReaderProgress object that gives the current progress of the + reader. + """ + return + + def request_dynamic_split(self, dynamic_split_request): + """Attempts to split the input in two parts. + + The two parts are named the "primary" part and the "residual" part. The + current 'NativeSourceReader' keeps processing the primary part, while the + residual part will be processed elsewhere (e.g. perhaps on a different + worker). + + The primary and residual parts, if concatenated, must represent the + same input as the current input of this 'NativeSourceReader' before this + call. + + The boundary between the primary part and the residual part is + specified in a framework-specific way using 'DynamicSplitRequest' e.g., + if the framework supports the notion of positions, it might be a + position at which the input is asked to split itself (which is not + necessarily the same position at which it *will* split itself); it + might be an approximate fraction of input, or something else. + + This function returns a 'DynamicSplitResult', which encodes, in a + framework-specific way, the information sufficient to construct a + description of the resulting primary and residual inputs. For example, it + might, again, be a position demarcating these parts, or it might be a pair + of fully-specified input descriptions, or something else. + + After a successful call to 'request_dynamic_split()', subsequent calls + should be interpreted relative to the new primary. + + Args: + dynamic_split_request: A 'DynamicSplitRequest' describing the split + request. + + Returns: + 'None' if the 'DynamicSplitRequest' cannot be honored (in that + case the input represented by this 'NativeSourceReader' stays the same), + or a 'DynamicSplitResult' describing how the input was split into a + primary and residual part. + """ + logging.debug( + 'SourceReader %r does not support dynamic splitting. Ignoring dynamic ' + 'split request: %r', + self, dynamic_split_request) + return + + +class ReaderProgress(object): + """A representation of how far a NativeSourceReader has read.""" + + def __init__(self, position=None, percent_complete=None, remaining_time=None): + + self._position = position + + if percent_complete is not None: + percent_complete = float(percent_complete) + if percent_complete < 0 or percent_complete > 1: + raise ValueError( + 'The percent_complete argument was %f. Must be in range [0, 1].' 
+ % percent_complete) + self._percent_complete = percent_complete + + self._remaining_time = remaining_time + + @property + def position(self): + """Returns progress, represented as a ReaderPosition object.""" + return self._position + + @property + def percent_complete(self): + """Returns progress, represented as a percentage of total work. + + Progress range from 0.0 (beginning, nothing complete) to 1.0 (end of the + work range, entire WorkItem complete). + + Returns: + Progress represented as a percentage of total work. + """ + return self._percent_complete + + @property + def remaining_time(self): + """Returns progress, represented as an estimated time remaining.""" + return self._remaining_time + + +class ReaderPosition(object): + """A representation of position in an iteration of a 'NativeSourceReader'.""" + + def __init__(self, end=None, key=None, byte_offset=None, record_index=None, + shuffle_position=None, concat_position=None): + """Initializes ReaderPosition. + + A ReaderPosition may get instantiated for one of these position types. Only + one of these should be specified. + + Args: + end: position is past all other positions. For example, this may be used + to represent the end position of an unbounded range. + key: position is a string key. + byte_offset: position is a byte offset. + record_index: position is a record index + shuffle_position: position is a base64 encoded shuffle position. + concat_position: position is a 'ConcatPosition'. + """ + + self.end = end + self.key = key + self.byte_offset = byte_offset + self.record_index = record_index + self.shuffle_position = shuffle_position + + if concat_position is not None: + assert isinstance(concat_position, ConcatPosition) + self.concat_position = concat_position + + +class ConcatPosition(object): + """A position that encapsulate an inner position and an index. + + This is used to represent the position of a source that encapsulate several + other sources. + """ + + def __init__(self, index, position): + """Initializes ConcatPosition. + + Args: + index: index of the source currently being read. + position: inner position within the source currently being read. + """ + + if position is not None: + assert isinstance(position, ReaderPosition) + self.index = index + self.position = position + + +class DynamicSplitRequest(object): + """Specifies how 'NativeSourceReader.request_dynamic_split' should split. + """ + + def __init__(self, progress): + assert isinstance(progress, ReaderProgress) + self.progress = progress + + +class DynamicSplitResult(object): + pass + + +class DynamicSplitResultWithPosition(DynamicSplitResult): + + def __init__(self, stop_position): + assert isinstance(stop_position, ReaderPosition) + self.stop_position = stop_position + + +class NativeSink(object): + """A sink implemented by Dataflow service. + + This class is to be only inherited by sinks natively implemented by Cloud + Dataflow service, hence should not be sub-classed by users. 
+ """ + + def writer(self): + """Returns a SinkWriter for this source.""" + raise NotImplementedError + + def __repr__(self): + return '<{name} {vals}>'.format( + name=self.__class__.__name__, + vals=_dict_printable_fields(self.__dict__, _minor_fields)) + + +class NativeSinkWriter(object): + """A writer for a sink implemented by Dataflow service.""" + + def __enter__(self): + """Opens everything necessary for a writer to function properly.""" + raise NotImplementedError + + def __exit__(self, exception_type, exception_value, traceback): + """Cleans up after a writer executed.""" + raise NotImplementedError + + @property + def takes_windowed_values(self): + """Returns whether this writer takes windowed values.""" + return False + + def Write(self, o): # pylint: disable=invalid-name + """Writes a record to the sink associated with this writer.""" + raise NotImplementedError + + +# Encapsulates information about a bundle of a source generated when method +# BoundedSource.split() is invoked. +# This is a named 4-tuple that has following fields. +# * weight - a number that represents the size of the bundle. This value will +# be used to compare the relative sizes of bundles generated by the +# current source. +# The weight returned here could be specified using a unit of your +# choice (for example, bundles of sizes 100MB, 200MB, and 700MB may +# specify weights 100, 200, 700 or 1, 2, 7) but all bundles of a +# source should specify the weight using the same unit. +# * source - a BoundedSource object for the bundle. +# * start_position - starting position of the bundle +# * stop_position - ending position of the bundle. +# +# Type for start and stop positions are specific to the bounded source and must +# be consistent throughout. +SourceBundle = namedtuple( + 'SourceBundle', + 'weight source start_position stop_position') + + +class BoundedSource(object): + """A Dataflow source that reads a finite amount of input records. + + This class defines following operations which can be used to read the source + efficiently. + + * Size estimation - method ``estimate_size()`` may return an accurate + estimation in bytes for the size of the source. + * Splitting into bundles of a given size - method ``split()`` can be used to + split the source into a set of sub-sources (bundles) based on a desired + bundle size. + * Getting a RangeTracker - method ``get_range_tracker() should return a + ``RangeTracker`` object for a given position range for the position type + of the records returned by the source. + * Reading the data - method ``read()`` can be used to read data from the + source while respecting the boundaries defined by a given + ``RangeTracker``. + """ + + def estimate_size(self): + """Estimates the size of source in bytes. + + An estimate of the total size (in bytes) of the data that would be read + from this source. This estimate is in terms of external storage size, + before performing decompression or other processing. + + Returns: + estimated size of the source if the size can be determined, ``None`` + otherwise. + """ + raise NotImplementedError + + def split(self, desired_bundle_size, start_position=None, stop_position=None): + """Splits the source into a set of bundles. + + Bundles should be approximately of size ``desired_bundle_size`` bytes. + + Args: + desired_bundle_size: the desired size (in bytes) of the bundles returned. + start_position: if specified the given position must be used as the + starting position of the first bundle. 
+ stop_position: if specified the given position must be used as the ending + position of the last bundle. + Returns: + an iterator of objects of type 'SourceBundle' that gives information about + the generated bundles. + """ + raise NotImplementedError + + def get_range_tracker(self, start_position, stop_position): + """Returns a RangeTracker for a given position range. + + Framework may invoke ``read()`` method with the RangeTracker object returned + here to read data from the source. + Args: + start_position: starting position of the range. + stop_position: ending position of the range. + Returns: + a ``RangeTracker`` for the given position range. + """ + raise NotImplementedError + + def read(self, range_tracker): + """Returns an iterator that reads data from the source. + + The returned set of data must respect the boundaries defined by the given + ``RangeTracker`` object. For example: + + * Returned set of data must be for the range + ``[range_tracker.start_position, range_tracker.stop_position)``. Note + that a source may decide to return records that start after + ``range_tracker.stop_position``. See documentation in class + ``RangeTracker`` for more details. Also, note that framework might + invoke ``range_tracker.try_split()`` to perform dynamic split + operations. range_tracker.stop_position may be updated + dynamically due to successful dynamic split operations. + * Method ``range_tracker.try_split()`` must be invoked for every record + that starts at a split point. + * Method ``range_tracker.record_current_position()`` may be invoked for + records that do not start at split points. + + Args: + range_tracker: a ``RangeTracker`` whose boundaries must be respected + when reading data from the source. If 'None' all records + represented by the current source should be read. + Returns: + an iterator of data read by the source. + """ + raise NotImplementedError + + def default_output_coder(self): + """Coder that should be used for the records returned by the source.""" + return PickleCoder() + + +class RangeTracker(object): + """A thread safe object used by Dataflow source framework. + + A Dataflow source is defined using a ''BoundedSource'' and a ''RangeTracker'' + pair. A ''RangeTracker'' is used by Dataflow source framework to perform + dynamic work rebalancing of position-based sources. + + **Position-based sources** + + A position-based source is one where the source can be described by a range + of positions of an ordered type and the records returned by the reader can be + described by positions of the same type. + + In case a record occupies a range of positions in the source, the most + important thing about the record is the position where it starts. + + Defining the semantics of positions for a source is entirely up to the source + class, however the chosen definitions have to obey certain properties in order + to make it possible to correctly split the source into parts, including + dynamic splitting. Two main aspects need to be defined: + + 1. How to assign starting positions to records. + 2. Which records should be read by a source with a range '[A, B)'. + + Moreover, reading a range must be *efficient*, i.e., the performance of + reading a range should not significantly depend on the location of the range. + For example, reading the range [A, B) should not require reading all data + before 'A'. + + The sections below explain exactly what properties these definitions must + satisfy, and how to use a ``RangeTracker`` with a properly defined source. 
+ + **Properties of position-based sources** + + The main requirement for position-based sources is *associativity*: reading + records from '[A, B)' and records from '[B, C)' should give the same + records as reading from '[A, C)', where 'A <= B <= C'. This property + ensures that no matter how a range of positions is split into arbitrarily many + sub-ranges, the total set of records described by them stays the same. + + The other important property is how the source's range relates to positions of + records in the source. In many sources each record can be identified by a + unique starting position. In this case: + + * All records returned by a source '[A, B)' must have starting positions in + this range. + * All but the last record should end within this range. The last record may or + may not extend past the end of the range. + * Records should not overlap. + + Such sources should define "read '[A, B)'" as "read from the first record + starting at or after 'A', up to but not including the first record starting + at or after 'B'". + + Some examples of such sources include reading lines or CSV from a text file, + reading keys and values from a BigTable, etc. + + The concept of *split points* allows to extend the definitions for dealing + with sources where some records cannot be identified by a unique starting + position. + + In all cases, all records returned by a source '[A, B)' must *start* at or + after 'A'. + + **Split points** + + Some sources may have records that are not directly addressable. For example, + imagine a file format consisting of a sequence of compressed blocks. Each + block can be assigned an offset, but records within the block cannot be + directly addressed without decompressing the block. Let us refer to this + hypothetical format as <i>CBF (Compressed Blocks Format)</i>. + + Many such formats can still satisfy the associativity property. For example, + in CBF, reading '[A, B)' can mean "read all the records in all blocks whose + starting offset is in '[A, B)'". + + To support such complex formats, we introduce the notion of *split points*. We + say that a record is a split point if there exists a position 'A' such that + the record is the first one to be returned when reading the range + '[A, infinity)'. In CBF, the only split points would be the first records + in each block. + + Split points allow us to define the meaning of a record's position and a + source's range in all cases: + + * For a record that is at a split point, its position is defined to be the + largest 'A' such that reading a source with the range '[A, infinity)' + returns this record. + * Positions of other records are only required to be non-decreasing. + * Reading the source '[A, B)' must return records starting from the first + split point at or after 'A', up to but not including the first split point + at or after 'B'. In particular, this means that the first record returned + by a source MUST always be a split point. + * Positions of split points must be unique. + + As a result, for any decomposition of the full range of the source into + position ranges, the total set of records will be the full set of records in + the source, and each record will be read exactly once. + + **Consumed positions** + + As the source is being read, and records read from it are being passed to the + downstream transforms in the pipeline, we say that positions in the source are + being *consumed*. 
When a reader has read a record (or promised to a caller + that a record will be returned), positions up to and including the record's + start position are considered *consumed*. + + Dynamic splitting can happen only at *unconsumed* positions. If the reader + just returned a record at offset 42 in a file, dynamic splitting can happen + only at offset 43 or beyond, as otherwise that record could be read twice (by + the current reader and by a reader of the task starting at 43). + """ + + def start_position(self): + """Returns the starting position of the current range, inclusive.""" + raise NotImplementedError + + def stop_position(self): + """Returns the ending position of the current range, exclusive.""" + raise NotImplementedError + + def try_claim(self, position): # pylint: disable=unused-argument + """Atomically determines if a record at a split point is within the range. + + This method should be called **if and only if** the record is at a split + point. This method may modify the internal state of the ``RangeTracker`` by + updating the last-consumed position to ``position``. + + ** Thread safety ** + + This method along with several other methods of this class may be invoked by + multiple threads, hence must be made thread-safe, e.g. by using a single + lock object. + + Args: + position: starting position of a record being read by a source. + + Returns: + ``True``, if the given position falls within the current range, returns + ``False`` otherwise. + """ + raise NotImplementedError + + def set_current_position(self, position): + """Updates the last-consumed position to the given position. + + A source may invoke this method for records that do not start at split + points. This may modify the internal state of the ``RangeTracker``. If the + record starts at a split point, method ``try_claim()`` **must** be invoked + instead of this method. + + Args: + position: starting position of a record being read by a source. + """ + raise NotImplementedError + + def position_at_fraction(self, fraction): + """Returns the position at the given fraction. + + Given a fraction within the range [0.0, 1.0) this method will return the + position at the given fraction compared the the position range + [self.start_position, self.stop_position). + + ** Thread safety ** + + This method along with several other methods of this class may be invoked by + multiple threads, hence must be made thread-safe, e.g. by using a single + lock object. + + Args: + fraction: a float value within the range [0.0, 1.0). + Returns: + a position within the range [self.start_position, self.stop_position). + """ + raise NotImplementedError + + def try_split(self, position): + """Atomically splits the current range. + + Determines a position to split the current range, split_position, based on + the given position. In most cases split_position and position will be the + same. + + Splits the current range '[self.start_position, self.stop_position)' + into a "primary" part '[self.start_position, split_position)' and a + "residual" part '[split_position, self.stop_position)', assuming the + current last-consumed position is within + '[self.start_position, split_position)' (i.e., split_position has not been + consumed yet). + + If successful, updates the current range to be the primary and returns a + tuple (split_position, split_fraction). split_fraction should be the + fraction of size of range '[self.start_position, split_position)' compared + to the original (before split) range + '[self.start_position, self.stop_position)'. 
+ + If the split_position has already been consumed, returns ``None``. + + ** Thread safety ** + + This method along with several other methods of this class may be invoked by + multiple threads, hence must be made thread-safe, e.g. by using a single + lock object. + + Args: + position: suggested position where the current range should try to + be split at. + Returns: + a tuple containing the split position and split fraction. + """ + raise NotImplementedError + + def fraction_consumed(self): + """Returns the approximate fraction of consumed positions in the source. + + ** Thread safety ** + + This method along with several other methods of this class may be invoked by + multiple threads, hence must be made thread-safe, e.g. by using a single + lock object. + + Returns: + the approximate fraction of positions that have been consumed by + successful 'try_split()' and 'report_current_position()' calls, or + 0.0 if no such calls have happened. + """ + raise NotImplementedError + + +class Sink(object): + """A resource that can be written to using the ``df.io.Write`` transform. + + Here ``df`` stands for Dataflow Python code imported in following manner. + ``import google.cloud.dataflow as df``. + + A parallel write to an ``iobase.Sink`` consists of three phases: + + 1. A sequential *initialization* phase (e.g., creating a temporary output + directory, etc.) + 2. A parallel write phase where workers write *bundles* of records + 3. A sequential *finalization* phase (e.g., committing the writes, merging + output files, etc.) + + For exact definition of a Dataflow bundle please see + https://cloud.google.com/dataflow/faq. + + Implementing a new sink requires extending two classes. + + 1. iobase.Sink + + ``iobase.Sink`` is an immutable logical description of the location/resource + to write to. Depending on the type of sink, it may contain fields such as the + path to an output directory on a filesystem, a database table name, + etc. ``iobase.Sink`` provides methods for performing a write operation to the + sink described by it. To this end, implementors of an extension of + ``iobase.Sink`` must implement three methods: + ``initialize_write()``, ``open_writer()``, and ``finalize_write()``. + + 2. iobase.Writer + + ``iobase.Writer`` is used to write a single bundle of records. An + ``iobase.Writer`` defines two methods: ``write()`` which writes a + single record from the bundle and ``close()`` which is called once + at the end of writing a bundle. + + See also ``df.io.fileio.FileSink`` which provides a simpler API for writing + sinks that produce files. + + **Execution of the Write transform** + + ``initialize_write()`` and ``finalize_write()`` are conceptually called once: + at the beginning and end of a ``Write`` transform. However, implementors must + ensure that these methods are *idempotent*, as they may be called multiple + times on different machines in the case of failure/retry or for redundancy. + + ``initialize_write()`` should perform any initialization that needs to be done + prior to writing to the sink. ``initialize_write()`` may return a result + (let's call this ``init_result``) that contains any parameters it wants to + pass on to its writers about the sink. For example, a sink that writes to a + file system may return an ``init_result`` that contains a dynamically + generated unique directory to which data should be written. 
+ + To perform writing of a bundle of elements, Dataflow execution engine will + create an ``iobase.Writer`` using the implementation of + ``iobase.Sink.open_writer()``. When invoking ``open_writer()`` execution + engine will provide the ``init_result`` returned by ``initialize_write()`` + invocation as well as a *bundle id* (let's call this ``bundle_id``) that is + unique for each invocation of ``open_writer()``. + + Execution engine will then invoke ``iobase.Writer.write()`` implementation for + each element that has to be written. Once all elements of a bundle are + written, execution engine will invoke ``iobase.Writer.close()`` implementation + which should return a result (let's call this ``write_result``) that contains + information that encodes the result of the write and, in most cases, some + encoding of the unique bundle id. For example, if each bundle is written to a + unique temporary file, ``close()`` method may return an object that contains + the temporary file name. After writing of all bundles is complete, execution + engine will invoke ``finalize_write()`` implementation. As parameters to this + invocation execution engine will provide ``init_result`` as well as an + iterable of ``write_result``. + + The execution of a write transform can be illustrated using following pseudo + code (assume that the outer for loop happens in parallel across many + machines):: + + init_result = sink.initialize_write() + write_results = [] + for bundle in partition(pcoll): + writer = sink.open_writer(init_result, generate_bundle_id()) + for elem in bundle: + writer.write(elem) + write_results.append(writer.close()) + sink.finalize_write(init_result, write_results) + + + **init_result** + + Methods of 'iobase.Sink' should agree on the 'init_result' type that will be + returned when initializing the sink. This type can be a client-defined object + or an existing type. The returned type must be picklable using Dataflow coder + ``coders.PickleCoder``. Returning an init_result is optional. + + **bundle_id** + + In order to ensure fault-tolerance, a bundle may be executed multiple times + (e.g., in the event of failure/retry or for redundancy). However, exactly one + of these executions will have its result passed to the + ``iobase.Sink.finalize_write()`` method. Each call to + ``iobase.Sink.open_writer()`` is passed a unique bundle id when it is called + by the ``WriteImpl`` transform, so even redundant or retried bundles will have + a unique way of identifying their output. + + The bundle id should be used to guarantee that a bundle's output is unique. + This uniqueness guarantee is important; if a bundle is to be output to a file, + for example, the name of the file must be unique to avoid conflicts with other + writers. The bundle id should be encoded in the writer result returned by the + writer and subsequently used by the ``finalize_write()`` method to identify + the results of successful writes. + + For example, consider the scenario where a Writer writes files containing + serialized records and the ``finalize_write()`` is to merge or rename these + output files. In this case, a writer may use its unique id to name its output + file (to avoid conflicts) and return the name of the file it wrote as its + writer result. The ``finalize_write()`` will then receive an ``Iterable`` of + output file names that it can then merge or rename using some bundle naming + scheme. 
+ + **write_result** + + ``iobase.Writer.close()`` and ``finalize_write()`` implementations must agree + on type of the ``write_result`` object returned when invoking + ``iobase.Writer.close()``. This type can be a client-defined object or + an existing type. The returned type must be picklable using Dataflow coder + ``coders.PickleCoder``. Returning a ``write_result`` when + ``iobase.Writer.close()`` is invoked is optional but if unique + ``write_result`` objects are not returned, sink should, guarantee idempotency + when same bundle is written multiple times due to failure/retry or redundancy. + + + **More information** + + For more information on creating new sinks please refer to the official + documentation at + ``https://cloud.google.com/dataflow/model/custom-io#creating-sinks``. + """ + + def initialize_write(self): + """Initializes the sink before writing begins. + + Invoked before any data is written to the sink. + + + Please see documentation in ``iobase.Sink`` for an example. + + Returns: + An object that contains any sink specific state generated by + initialization. This object will be passed to open_writer() and + finalize_write() methods. + """ + raise NotImplementedError + + def open_writer(self, init_result, uid): + """Opens a writer for writing a bundle of elements to the sink. + + Args: + init_result: the result of initialize_write() invocation. + uid: a unique identifier generated by the system. + Returns: + an ``iobase.Writer`` that can be used to write a bundle of records to the + current sink. + """ + raise NotImplementedError + + def finalize_write(self, init_result, writer_results): + """Finalizes the sink after all data is written to it. + + Given the result of initialization and an iterable of results from bundle + writes, performs finalization after writing and closes the sink. Called + after all bundle writes are complete. + + The bundle write results that are passed to finalize are those returned by + bundles that completed successfully. Although bundles may have been run + multiple times (for fault-tolerance), only one writer result will be passed + to finalize for each bundle. An implementation of finalize should perform + clean up of any failed and successfully retried bundles. Note that these + failed bundles will not have their writer result passed to finalize, so + finalize should be capable of locating any temporary/partial output written + by failed bundles. + + If all retries of a bundle fails, the whole pipeline will fail *without* + finalize_write() being invoked. + + A best practice is to make finalize atomic. If this is impossible given the + semantics of the sink, finalize should be idempotent, as it may be called + multiple times in the case of failure/retry or for redundancy. + + Note that the iteration order of the writer results is not guaranteed to be + consistent if finalize is called multiple times. + + Args: + init_result: the result of ``initialize_write()`` invocation. + writer_results: an iterable containing results of ``Writer.close()`` + invocations. This will only contain results of successful writes, and + will only contain the result of a single successful write for a given + bundle. + """ + raise NotImplementedError + + +class Writer(object): + """Writes a bundle of elements from a ``PCollection`` to a sink. + + A Writer ``iobase.Writer.write()`` writes and elements to the sink while + ``iobase.Writer.close()`` is called after all elements in the bundle have been + written. 
+ + See ``iobase.Sink`` for more detailed documentation about the process of + writing to a sink. + """ + + def write(self, value): + """Writes a value to the sink using the current writer.""" + raise NotImplementedError + + def close(self): + """Closes the current writer. + + Please see documentation in ``iobase.Sink`` for an example. + + Returns: + An object representing the writes that were performed by the current + writer. + """ + raise NotImplementedError + + +class _NativeWrite(ptransform.PTransform): + """A PTransform for writing to a Dataflow native sink. + + These are sinks that are implemented natively by the Dataflow service + and hence should not be updated by users. These sinks are processed + using a Dataflow native write transform. + + Applying this transform results in a ``pvalue.PDone``. + """ + + def __init__(self, *args, **kwargs): + """Initializes a Write transform. + + Args: + *args: A tuple of position arguments. + **kwargs: A dictionary of keyword arguments. + + The *args, **kwargs are expected to be (label, sink) or (sink). + """ + label, sink = self.parse_label_and_arg(args, kwargs, 'sink') + super(_NativeWrite, self).__init__(label) + self.sink = sink + + def apply(self, pcoll): + self._check_pcollection(pcoll) + return pvalue.PDone(pcoll.pipeline) + + +class Read(ptransform.PTransform): + """A transform that reads a PCollection.""" + + def __init__(self, *args, **kwargs): + """Initializes a Read transform. + + Args: + *args: A tuple of position arguments. + **kwargs: A dictionary of keyword arguments. + + The *args, **kwargs are expected to be (label, source) or (source). + """ + label, source = self.parse_label_and_arg(args, kwargs, 'source') + super(Read, self).__init__(label) + self.source = source + + def apply(self, pbegin): + assert isinstance(pbegin, pvalue.PBegin) + self.pipeline = pbegin.pipeline + return pvalue.PCollection(self.pipeline) + + def get_windowing(self, unused_inputs): + return core.Windowing(window.GlobalWindows()) + + +class Write(ptransform.PTransform): + """A ``PTransform`` that writes to a sink. + + A sink should inherit ``iobase.Sink``. Such implementations are + handled using a composite transform that consists of three ``ParDo``s - + (1) a ``ParDo`` performing a global initialization (2) a ``ParDo`` performing + a parallel write and (3) a ``ParDo`` performing a global finalization. In the + case of an empty ``PCollection``, only the global initialization and + finalization will be performed. Currently only batch workflows support custom + sinks. + + Example usage:: + + pcollection | df.io.Write(MySink()) + + This returns a ``pvalue.PValue`` object that represents the end of the + Pipeline. + + The sink argument may also be a full PTransform, in which case it will be + applied directly. This allows composite sink-like transforms (e.g. a sink + with some pre-processing DoFns) to be used the same as all other sinks. + + This transform also supports sinks that inherit ``iobase.NativeSink``. These + are sinks that are implemented natively by the Dataflow service and hence + should not be updated by users. These sinks are processed using a Dataflow + native write transform. + """ + + def __init__(self, *args, **kwargs): + """Initializes a Write transform. + + Args: + *args: A tuple of position arguments. + **kwargs: A dictionary of keyword arguments. + + The *args, **kwargs are expected to be (label, sink) or (sink). 
+ """ + label, sink = self.parse_label_and_arg(args, kwargs, 'sink') + super(Write, self).__init__(label) + self.sink = sink + + def apply(self, pcoll): + from google.cloud.dataflow.io import iobase + if isinstance(self.sink, iobase.NativeSink): + # A native sink + return pcoll | _NativeWrite('native_write', self.sink) + elif isinstance(self.sink, iobase.Sink): + # A custom sink + return pcoll | WriteImpl(self.sink) + elif isinstance(self.sink, ptransform.PTransform): + # This allows "composite" sinks to be used like non-composite ones. + return pcoll | self.sink + else: + raise ValueError('A sink must inherit iobase.Sink, iobase.NativeSink, ' + 'or be a PTransform. Received : %r', self.sink) + + +class WriteImpl(ptransform.PTransform): + """Implements the writing of custom sinks.""" + + def __init__(self, sink): + super(WriteImpl, self).__init__() + self.sink = sink + + def apply(self, pcoll): + do_once = pcoll.pipeline | core.Create('DoOnce', [None]) + init_result_coll = do_once | core.Map( + 'initialize_write', lambda _, sink: sink.initialize_write(), self.sink) + if getattr(self.sink, 'num_shards', 0): + min_shards = self.sink.num_shards + if min_shards == 1: + keyed_pcoll = pcoll | core.Map(lambda x: (None, x)) + else: + keyed_pcoll = pcoll | core.ParDo(_RoundRobinKeyFn(min_shards)) + write_result_coll = (keyed_pcoll + | core.WindowInto(window.GlobalWindows()) + | core.GroupByKey() + | core.Map('write_bundles', + _write_keyed_bundle, self.sink, + AsSingleton(init_result_coll))) + else: + min_shards = 1 + write_result_coll = pcoll | core.ParDo('write_bundles', + _WriteBundleDoFn(), self.sink, + AsSingleton(init_result_coll)) + return do_once | core.FlatMap( + 'finalize_write', + _finalize_write, + self.sink, + AsSingleton(init_result_coll), + AsIter(write_result_coll), + min_shards) + + +class _WriteBundleDoFn(core.DoFn): + """A DoFn for writing elements to an iobase.Writer. + + Opens a writer at the first element and closes the writer at finish_bundle(). + """ + + def __init__(self): + self.writer = None + + def process(self, context, sink, init_result): + if self.writer is None: + self.writer = sink.open_writer(init_result, str(uuid.uuid4())) + self.writer.write(context.element) + + def finish_bundle(self, context, *args, **kwargs): + if self.writer is not None: + yield window.TimestampedValue(self.writer.close(), window.MAX_TIMESTAMP) + + +def _write_keyed_bundle(bundle, sink, init_result): + writer = sink.open_writer(init_result, str(uuid.uuid4())) + for element in bundle[1]: # values + writer.write(element) + return window.TimestampedValue(writer.close(), window.MAX_TIMESTAMP) + + +def _finalize_write(_, sink, init_result, write_results, min_shards): + write_results = list(write_results) + extra_shards = [] + if len(write_results) < min_shards: + logging.debug( + 'Creating %s empty shard(s).', min_shards - len(write_results)) + for _ in range(min_shards - len(write_results)): + writer = sink.open_writer(init_result, str(uuid.uuid4())) + extra_shards.append(writer.close()) + outputs = sink.finalize_write(init_result, write_results + extra_shards) + if outputs: + return (window.TimestampedValue(v, window.MAX_TIMESTAMP) for v in outputs) + + +class _RoundRobinKeyFn(core.DoFn): + + def __init__(self, count): + self.count = count + + def start_bundle(self, context): + self.counter = random.randint(0, self.count - 1) + + def process(self, context): + self.counter += 1 + if self.counter >= self.count: + self.counter -= self.count + yield self.counter, context.element
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/pubsub.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/pubsub.py b/sdks/python/apache_beam/io/pubsub.py new file mode 100644 index 0000000..88aa7f5 --- /dev/null +++ b/sdks/python/apache_beam/io/pubsub.py @@ -0,0 +1,73 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Google Cloud PubSub sources and sinks. + +Cloud Pub/Sub sources and sinks are currently supported only in streaming +pipelines, during remote execution. +""" + +from __future__ import absolute_import + +from google.cloud.dataflow import coders +from google.cloud.dataflow.io import iobase + + +class PubSubSource(iobase.NativeSource): + """Source for reading from a given Cloud Pub/Sub topic. + + Attributes: + topic: Cloud Pub/Sub topic in the form "/topics/<project>/<topic>". + subscription: Optional existing Cloud Pub/Sub subscription to use in the + form "projects/<project>/subscriptions/<subscription>". + id_label: The attribute on incoming Pub/Sub messages to use as a unique + record identifier. When specified, the value of this attribute (which can + be any string that uniquely identifies the record) will be used for + deduplication of messages. If not provided, Dataflow cannot guarantee + that no duplicate data will be delivered on the Pub/Sub stream. In this + case, deduplication of the stream will be strictly best effort. + coder: The Coder to use for decoding incoming Pub/Sub messages. + """ + + def __init__(self, topic, subscription=None, id_label=None, + coder=coders.StrUtf8Coder()): + self.topic = topic + self.subscription = subscription + self.id_label = id_label + self.coder = coder + + @property + def format(self): + """Source format name required for remote execution.""" + return 'pubsub' + + def reader(self): + raise NotImplementedError( + 'PubSubSource is not supported in local execution.') + + +class PubSubSink(iobase.NativeSink): + """Sink for writing to a given Cloud Pub/Sub topic.""" + + def __init__(self, topic, coder=coders.StrUtf8Coder()): + self.topic = topic + self.coder = coder + + @property + def format(self): + """Sink format name required for remote execution.""" + return 'pubsub' + + def writer(self): + raise NotImplementedError( + 'PubSubSink is not supported in local execution.') http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/range_trackers.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py new file mode 100644 index 0000000..2cdcd5b --- /dev/null +++ b/sdks/python/apache_beam/io/range_trackers.py @@ -0,0 +1,270 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""iobase.RangeTracker implementations provided with Dataflow SDK. +""" + +import logging +import math +import threading + +from google.cloud.dataflow.io import iobase + + +class OffsetRangeTracker(iobase.RangeTracker): + """A 'RangeTracker' for non-negative positions of type 'long'.""" + + # Offset corresponding to infinity. This can only be used as the upper-bound + # of a range, and indicates reading all of the records until the end without + # specifying exactly what the end is. + # Infinite ranges cannot be split because it is impossible to estimate + # progress within them. + OFFSET_INFINITY = float('inf') + + def __init__(self, start, end): + super(OffsetRangeTracker, self).__init__() + self._start_offset = start + self._stop_offset = end + self._last_record_start = -1 + self._offset_of_last_split_point = -1 + self._lock = threading.Lock() + + def start_position(self): + return self._start_offset + + def stop_position(self): + return self._stop_offset + + @property + def last_record_start(self): + return self._last_record_start + + def _validate_record_start(self, record_start, split_point): + # This function must only be called under the lock self.lock. + if not self._lock.locked(): + raise ValueError( + 'This function must only be called under the lock self.lock.') + + if record_start < self._last_record_start: + raise ValueError( + 'Trying to return a record [starting at %d] which is before the ' + 'last-returned record [starting at %d]' % + (record_start, self._last_record_start)) + + if split_point: + if (self._offset_of_last_split_point != -1 and + record_start == self._offset_of_last_split_point): + raise ValueError( + 'Record at a split point has same offset as the previous split ' + 'point: %d' % record_start) + elif self._last_record_start == -1: + raise ValueError( + 'The first record [starting at %d] must be at a split point' % + record_start) + + if (split_point and self._offset_of_last_split_point is not -1 and + record_start is self._offset_of_last_split_point): + raise ValueError( + 'Record at a split point has same offset as the previous split ' + 'point: %d' % record_start) + + if not split_point and self._last_record_start == -1: + raise ValueError( + 'The first record [starting at %d] must be at a split point' % + record_start) + + def try_claim(self, record_start): + with self._lock: + self._validate_record_start(record_start, True) + if record_start >= self.stop_position(): + return False + self._offset_of_last_split_point = record_start + self._last_record_start = record_start + return True + + def set_current_position(self, record_start): + with self._lock: + self._validate_record_start(record_start, False) + self._last_record_start = record_start + + def try_split(self, split_offset): + with self._lock: + if self._stop_offset == OffsetRangeTracker.OFFSET_INFINITY: + logging.debug('refusing to split %r at %d: stop position unspecified', + self, split_offset) + return + if self._last_record_start == -1: + logging.debug('Refusing to split %r at %d: unstarted', self, + split_offset) + return + + if split_offset <= self._last_record_start: + 
logging.debug( + 'Refusing to split %r at %d: already past proposed stop offset', + self, split_offset) + return + if (split_offset < self.start_position() + or split_offset >= self.stop_position()): + logging.debug( + 'Refusing to split %r at %d: proposed split position out of range', + self, split_offset) + return + + logging.debug('Agreeing to split %r at %d', self, split_offset) + self._stop_offset = split_offset + + split_fraction = (float(split_offset - self._start_offset) / ( + self._stop_offset - self._start_offset)) + + return self._stop_offset, split_fraction + + def fraction_consumed(self): + with self._lock: + fraction = ((1.0 * (self._last_record_start - self.start_position()) / + (self.stop_position() - self.start_position())) if + self.stop_position() != self.start_position() else 0.0) + + # self.last_record_start may become larger than self.end_offset when + # reading the records since any record that starts before the first 'split + # point' at or after the defined 'stop offset' is considered to be within + # the range of the OffsetRangeTracker. Hence fraction could be > 1. + # self.last_record_start is initialized to -1, hence fraction may be < 0. + # Bounding the to range [0, 1]. + return max(0.0, min(1.0, fraction)) + + def position_at_fraction(self, fraction): + if self.stop_position() == OffsetRangeTracker.OFFSET_INFINITY: + raise Exception( + 'get_position_for_fraction_consumed is not applicable for an ' + 'unbounded range') + return (math.ceil(self.start_position() + fraction * ( + self.stop_position() - self.start_position()))) + + +class GroupedShuffleRangeTracker(iobase.RangeTracker): + """A 'RangeTracker' for positions used by'GroupedShuffleReader'. + + These positions roughly correspond to hashes of keys. In case of hash + collisions, multiple groups can have the same position. In that case, the + first group at a particular position is considered a split point (because + it is the first to be returned when reading a position range starting at this + position), others are not. + """ + + def __init__(self, decoded_start_pos, decoded_stop_pos): + super(GroupedShuffleRangeTracker, self).__init__() + self._decoded_start_pos = decoded_start_pos + self._decoded_stop_pos = decoded_stop_pos + self._decoded_last_group_start = None + self._last_group_was_at_a_split_point = False + self._lock = threading.Lock() + + def start_position(self): + return self._decoded_start_pos + + def stop_position(self): + return self._decoded_stop_pos + + def last_group_start(self): + return self._decoded_last_group_start + + def _validate_decoded_group_start(self, decoded_group_start, split_point): + if self.start_position() and decoded_group_start < self.start_position(): + raise ValueError('Trying to return record at %r which is before the' + ' starting position at %r' % + (decoded_group_start, self.start_position())) + + if (self.last_group_start() and + decoded_group_start < self.last_group_start()): + raise ValueError('Trying to return group at %r which is before the' + ' last-returned group at %r' % + (decoded_group_start, self.last_group_start())) + if (split_point and self.last_group_start() and + self.last_group_start() == decoded_group_start): + raise ValueError('Trying to return a group at a split point with ' + 'same position as the previous group: both at %r, ' + 'last group was %sat a split point.' 
% + (decoded_group_start, + ('' if self._last_group_was_at_a_split_point + else 'not '))) + if not split_point: + if self.last_group_start() is None: + raise ValueError('The first group [at %r] must be at a split point' % + decoded_group_start) + if self.last_group_start() != decoded_group_start: + # This case is not a violation of general RangeTracker semantics, but it + # is contrary to how GroupingShuffleReader in particular works. Hitting + # it would mean it's behaving unexpectedly. + raise ValueError('Trying to return a group not at a split point, but ' + 'with a different position than the previous group: ' + 'last group was %r at %r, current at a %s split' + ' point.' % + (self.last_group_start() + , decoded_group_start + , ('' if self._last_group_was_at_a_split_point + else 'non-'))) + + def try_claim(self, decoded_group_start): + with self._lock: + self._validate_decoded_group_start(decoded_group_start, True) + if (self.stop_position() + and decoded_group_start >= self.stop_position()): + return False + + self._decoded_last_group_start = decoded_group_start + self._last_group_was_at_a_split_point = True + return True + + def set_current_position(self, decoded_group_start): + with self._lock: + self._validate_decoded_group_start(decoded_group_start, False) + self._decoded_last_group_start = decoded_group_start + self._last_group_was_at_a_split_point = False + + def try_split(self, decoded_split_position): + with self._lock: + if self.last_group_start() is None: + logging.info('Refusing to split %r at %r: unstarted' + , self, decoded_split_position) + return + + if decoded_split_position <= self.last_group_start(): + logging.info('Refusing to split %r at %r: already past proposed split ' + 'position' + , self, decoded_split_position) + return + + if ((self.stop_position() + and decoded_split_position >= self.stop_position()) + or (self.start_position() + and decoded_split_position <= self.start_position())): + logging.error('Refusing to split %r at %r: proposed split position out ' + 'of range', self, decoded_split_position) + return + + logging.debug('Agreeing to split %r at %r' + , self, decoded_split_position) + self._decoded_stop_pos = decoded_split_position + + # Since GroupedShuffleRangeTracker cannot determine relative sizes of the + # two splits, returning 0.5 as the fraction below so that the framework + # assumes the splits to be of the same size. + return self._decoded_stop_pos, 0.5 + + def fraction_consumed(self): + # GroupingShuffle sources have special support on the service and the + # service will estimate progress from positions for us. + raise RuntimeError('GroupedShuffleRangeTracker does not measure fraction' + ' consumed due to positions being opaque strings' + ' that are interpretted by the service') http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/range_trackers_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py new file mode 100644 index 0000000..709d594 --- /dev/null +++ b/sdks/python/apache_beam/io/range_trackers_test.py @@ -0,0 +1,318 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the range_trackers module.""" + +import array +import copy +import logging +import unittest + + +from google.cloud.dataflow.io import range_trackers + + +class OffsetRangeTrackerTest(unittest.TestCase): + + def test_try_return_record_simple_sparse(self): + tracker = range_trackers.OffsetRangeTracker(100, 200) + self.assertTrue(tracker.try_claim(110)) + self.assertTrue(tracker.try_claim(140)) + self.assertTrue(tracker.try_claim(183)) + self.assertFalse(tracker.try_claim(210)) + + def test_try_return_record_simple_dense(self): + tracker = range_trackers.OffsetRangeTracker(3, 6) + self.assertTrue(tracker.try_claim(3)) + self.assertTrue(tracker.try_claim(4)) + self.assertTrue(tracker.try_claim(5)) + self.assertFalse(tracker.try_claim(6)) + + def test_try_return_record_continuous_until_split_point(self): + tracker = range_trackers.OffsetRangeTracker(9, 18) + # Return records with gaps of 2; every 3rd record is a split point. + self.assertTrue(tracker.try_claim(10)) + tracker.set_current_position(12) + tracker.set_current_position(14) + self.assertTrue(tracker.try_claim(16)) + # Out of range, but not a split point... + tracker.set_current_position(18) + tracker.set_current_position(20) + # Out of range AND a split point. + self.assertFalse(tracker.try_claim(22)) + + def test_split_at_offset_fails_if_unstarted(self): + tracker = range_trackers.OffsetRangeTracker(100, 200) + self.assertFalse(tracker.try_split(150)) + + def test_split_at_offset(self): + tracker = range_trackers.OffsetRangeTracker(100, 200) + self.assertTrue(tracker.try_claim(110)) + # Example positions we shouldn't split at, when last record starts at 110: + self.assertFalse(tracker.try_split(109)) + self.assertFalse(tracker.try_split(110)) + self.assertFalse(tracker.try_split(200)) + self.assertFalse(tracker.try_split(210)) + # Example positions we *should* split at: + self.assertTrue(copy.copy(tracker).try_split(111)) + self.assertTrue(copy.copy(tracker).try_split(129)) + self.assertTrue(copy.copy(tracker).try_split(130)) + self.assertTrue(copy.copy(tracker).try_split(131)) + self.assertTrue(copy.copy(tracker).try_split(150)) + self.assertTrue(copy.copy(tracker).try_split(199)) + + # If we split at 170 and then at 150: + self.assertTrue(tracker.try_split(170)) + self.assertTrue(tracker.try_split(150)) + # Should be able to return a record starting before the new stop offset. + # Returning records starting at the same offset is ok. + self.assertTrue(copy.copy(tracker).try_claim(135)) + self.assertTrue(copy.copy(tracker).try_claim(135)) + # Should be able to return a record starting right before the new stop + # offset. + self.assertTrue(copy.copy(tracker).try_claim(149)) + # Should not be able to return a record starting at or after the new stop + # offset. + self.assertFalse(tracker.try_claim(150)) + self.assertFalse(tracker.try_claim(151)) + # Should accept non-splitpoint records starting after stop offset. 
+ tracker.set_current_position(135) + tracker.set_current_position(152) + tracker.set_current_position(160) + tracker.set_current_position(171) + + def test_get_position_for_fraction_dense(self): + # Represents positions 3, 4, 5. + tracker = range_trackers.OffsetRangeTracker(3, 6) + # [3, 3) represents 0.0 of [3, 6) + self.assertEqual(3, tracker.position_at_fraction(0.0)) + # [3, 4) represents up to 1/3 of [3, 6) + self.assertEqual(4, tracker.position_at_fraction(1.0 / 6)) + self.assertEqual(4, tracker.position_at_fraction(0.333)) + # [3, 5) represents up to 2/3 of [3, 6) + self.assertEqual(5, tracker.position_at_fraction(0.334)) + self.assertEqual(5, tracker.position_at_fraction(0.666)) + # Any fraction consumed over 2/3 means the whole [3, 6) has been consumed. + self.assertEqual(6, tracker.position_at_fraction(0.667)) + + def test_get_fraction_consumed_dense(self): + tracker = range_trackers.OffsetRangeTracker(3, 6) + self.assertEqual(0, tracker.fraction_consumed()) + self.assertTrue(tracker.try_claim(3)) + self.assertEqual(0.0, tracker.fraction_consumed()) + self.assertTrue(tracker.try_claim(4)) + self.assertEqual(1.0 / 3, tracker.fraction_consumed()) + self.assertTrue(tracker.try_claim(5)) + self.assertEqual(2.0 / 3, tracker.fraction_consumed()) + tracker.set_current_position(6) + self.assertEqual(1.0, tracker.fraction_consumed()) + tracker.set_current_position(7) + self.assertFalse(tracker.try_claim(7)) + + def test_get_fraction_consumed_sparse(self): + tracker = range_trackers.OffsetRangeTracker(100, 200) + self.assertEqual(0, tracker.fraction_consumed()) + self.assertTrue(tracker.try_claim(110)) + # Consumed positions through 110 = total 10 positions of 100 done. + self.assertEqual(0.10, tracker.fraction_consumed()) + self.assertTrue(tracker.try_claim(150)) + self.assertEqual(0.50, tracker.fraction_consumed()) + self.assertTrue(tracker.try_claim(195)) + self.assertEqual(0.95, tracker.fraction_consumed()) + + def test_everything_with_unbounded_range(self): + tracker = range_trackers.OffsetRangeTracker( + 100, + range_trackers.OffsetRangeTracker.OFFSET_INFINITY) + self.assertTrue(tracker.try_claim(150)) + self.assertTrue(tracker.try_claim(250)) + # get_position_for_fraction_consumed should fail for an unbounded range + with self.assertRaises(Exception): + tracker.position_at_fraction(0.5) + + def test_try_return_first_record_not_split_point(self): + with self.assertRaises(Exception): + range_trackers.OffsetRangeTracker(100, 200).set_current_position(120) + + def test_try_return_record_non_monotonic(self): + tracker = range_trackers.OffsetRangeTracker(100, 200) + self.assertTrue(tracker.try_claim(120)) + with self.assertRaises(Exception): + tracker.try_claim(110) + + +class GroupedShuffleRangeTrackerTest(unittest.TestCase): + + def bytes_to_position(self, bytes_array): + return array.array('B', bytes_array).tostring() + + def test_try_return_record_in_infinite_range(self): + tracker = range_trackers.GroupedShuffleRangeTracker('', '') + self.assertTrue(tracker.try_claim( + self.bytes_to_position([1, 2, 3]))) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([1, 2, 5]))) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([3, 6, 8, 10]))) + + def test_try_return_record_finite_range(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0])) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([1, 2, 3]))) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([1, 2, 5]))) + 
self.assertTrue(tracker.try_claim( + self.bytes_to_position([3, 6, 8, 10]))) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([4, 255, 255, 255]))) + # Should fail for positions that are lexicographically equal to or larger + # than the defined stop position. + self.assertFalse(copy.copy(tracker).try_claim( + self.bytes_to_position([5, 0, 0]))) + self.assertFalse(copy.copy(tracker).try_claim( + self.bytes_to_position([5, 0, 1]))) + self.assertFalse(copy.copy(tracker).try_claim( + self.bytes_to_position([6, 0, 0]))) + + def test_try_return_record_with_non_split_point(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([1, 0, 0]), self.bytes_to_position([5, 0, 0])) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([1, 2, 3]))) + tracker.set_current_position(self.bytes_to_position([1, 2, 3])) + tracker.set_current_position(self.bytes_to_position([1, 2, 3])) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([1, 2, 5]))) + tracker.set_current_position(self.bytes_to_position([1, 2, 5])) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([3, 6, 8, 10]))) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([4, 255, 255, 255]))) + + def test_first_record_non_split_point(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) + with self.assertRaises(ValueError): + tracker.set_current_position(self.bytes_to_position([3, 4, 5])) + + def test_non_split_point_record_with_different_position(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) + self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5]))) + with self.assertRaises(ValueError): + tracker.set_current_position(self.bytes_to_position([3, 4, 6])) + + def test_try_return_record_before_start(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) + with self.assertRaises(ValueError): + tracker.try_claim(self.bytes_to_position([1, 2, 3])) + + def test_try_return_non_monotonic(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) + self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 5]))) + self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 4, 6]))) + with self.assertRaises(ValueError): + tracker.try_claim(self.bytes_to_position([3, 2, 1])) + + def test_try_return_identical_positions(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([3, 0, 0]), self.bytes_to_position([5, 0, 0])) + self.assertTrue(tracker.try_claim( + self.bytes_to_position([3, 4, 5]))) + with self.assertRaises(ValueError): + tracker.try_claim(self.bytes_to_position([3, 4, 5])) + + def test_try_split_at_position_infinite_range(self): + tracker = range_trackers.GroupedShuffleRangeTracker('', '') + # Should fail before first record is returned. + self.assertFalse(tracker.try_split( + self.bytes_to_position([3, 4, 5, 6]))) + + self.assertTrue(tracker.try_claim( + self.bytes_to_position([1, 2, 3]))) + + # Should now succeed. + self.assertIsNotNone(tracker.try_split( + self.bytes_to_position([3, 4, 5, 6]))) + # Should not split at same or larger position. 
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6, 7])))
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([4, 5, 6, 7])))
+
+    # Should split at smaller position.
+    self.assertIsNotNone(tracker.try_split(
+        self.bytes_to_position([3, 2, 1])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([2, 3, 4])))
+
+    # Should not split at a position we're already past.
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([2, 3, 4])))
+    self.assertIsNone(tracker.try_split(
+        self.bytes_to_position([2, 3, 3])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 2, 0])))
+    self.assertFalse(tracker.try_claim(
+        self.bytes_to_position([3, 2, 1])))
+
+  def test_try_split_at_position_finite_range(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([0, 0, 0]),
+        self.bytes_to_position([10, 20, 30]))
+    # Should fail before first record is returned.
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([0, 0, 0])))
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([1, 2, 3])))
+
+    # Should now succeed.
+    self.assertTrue(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+    # Should not split at same or larger position.
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6])))
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([3, 4, 5, 6, 7])))
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([4, 5, 6, 7])))
+
+    # Should split at smaller position.
+    self.assertTrue(tracker.try_split(
+        self.bytes_to_position([3, 2, 1])))
+    # But not at a position at or before last returned record.
+    self.assertFalse(tracker.try_split(
+        self.bytes_to_position([1, 2, 3])))
+
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([2, 3, 4])))
+    self.assertTrue(tracker.try_claim(
+        self.bytes_to_position([3, 2, 0])))
+    self.assertFalse(tracker.try_claim(
+        self.bytes_to_position([3, 2, 1])))
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/b14dfadd/sdks/python/apache_beam/io/sources_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/sources_test.py b/sdks/python/apache_beam/io/sources_test.py
new file mode 100644
index 0000000..512dc1a
--- /dev/null
+++ b/sdks/python/apache_beam/io/sources_test.py
@@ -0,0 +1,65 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Unit tests for the sources framework.""" + +import logging +import tempfile +import unittest + +import google.cloud.dataflow as df + +from google.cloud.dataflow.io import iobase +from google.cloud.dataflow.transforms.util import assert_that +from google.cloud.dataflow.transforms.util import equal_to + + +class LineSource(iobase.BoundedSource): + """A simple source that reads lines from a given file.""" + + def __init__(self, file_name): + self._file_name = file_name + + def read(self, _): + with open(self._file_name) as f: + for line in f: + yield line.rstrip('\n') + + +class SourcesTest(unittest.TestCase): + + def _create_temp_file(self, contents): + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(contents) + return f.name + + def test_read_from_source(self): + file_name = self._create_temp_file('aaaa\nbbbb\ncccc\ndddd') + + source = LineSource(file_name) + result = [line for line in source.read(None)] + + self.assertItemsEqual(['aaaa', 'bbbb', 'cccc', 'dddd'], result) + + def test_run_direct(self): + file_name = self._create_temp_file('aaaa\nbbbb\ncccc\ndddd') + pipeline = df.Pipeline('DirectPipelineRunner') + pcoll = pipeline | df.Read(LineSource(file_name)) + assert_that(pcoll, equal_to(['aaaa', 'bbbb', 'cccc', 'dddd'])) + + pipeline.run() + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()
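As a rough illustration of how the range trackers tested above are meant to be driven (not code from this commit), the sketch below shows a reader loop over an OffsetRangeTracker: every split-point offset is claimed with try_claim() before its record is emitted, non-split-point offsets are only reported with set_current_position(), and a dynamic split request is answered with try_split(). The records_starting_at callable and its (offset, record, is_split_point) tuples are hypothetical names invented for the example; only the tracker methods come from the code in this commit.

from google.cloud.dataflow.io import range_trackers


def read_with_tracker(records_starting_at, start, end, proposed_split=None):
  """Yields (offset, record) pairs from [start, end) under a range tracker.

  Assumes records_starting_at(start) yields (offset, record, is_split_point)
  tuples in increasing offset order, with the first tuple at a split point.
  """
  tracker = range_trackers.OffsetRangeTracker(start, end)
  for offset, record, is_split_point in records_starting_at(start):
    if is_split_point:
      # Split points must be claimed before the record is emitted; a False
      # return means the (possibly dynamically split) range ends before this
      # offset and reading must stop.
      if not tracker.try_claim(offset):
        return
    else:
      # Non-split-point records are only reported, never claimed.
      tracker.set_current_position(offset)
    yield offset, record
    if proposed_split is not None:
      # A dynamic split request: on success the tracker's stop position
      # shrinks, so later try_claim() calls past it return False.
      tracker.try_split(proposed_split)
      proposed_split = None

The same claim-versus-report pattern applies to GroupedShuffleRangeTracker, with positions being opaque byte strings compared lexicographically; as its code above notes, that tracker additionally refuses fraction_consumed(), since only the service can interpret its positions.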