This is an automated email from the ASF dual-hosted git repository.
yichi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 5742cb9 BEAM-12122, BEAM-12119 Add integer and string _id keys
support to Python IO MongoDB
new 7494d14 Merge pull request #14460 from
MaksymSkorupskyi/BEAM-12122-Python-IO-MongoDB-integer-and-string-`_id`-keys-are-not-supported
5742cb9 is described below
commit 5742cb929af6b6cb74f591f554108451c6fc32d5
Author: Maksym Skorupskyi <[email protected]>
AuthorDate: Wed Apr 7 15:46:58 2021 +0300
BEAM-12122, BEAM-12119 Add integer and string _id keys support to Python IO
MongoDB
* handle integer and string _id keys when reading from MongoDB collections
* RangeTracker is selected based on the _id key type:
- int -> OffsetRangeTracker
- ObjectId -> _ObjectIdRangeTracker
- string -> LexicographicKeyRangeTracker
* update LexicographicKeyRangeTracker to properly handle both bytes and str
keys
* make range_trackers.OrderedPositionRangeTracker.try_split() logic
consistent with range_trackers.OffsetRangeTracker.try_split()
* add unittests for MongoDB _id of int and str types
* update range_tracker_test
---
sdks/python/apache_beam/io/mongodbio.py | 491 +++++++++++++++-------
sdks/python/apache_beam/io/mongodbio_test.py | 230 ++++++++--
sdks/python/apache_beam/io/range_trackers.py | 131 +++---
sdks/python/apache_beam/io/range_trackers_test.py | 102 ++++-
4 files changed, 706 insertions(+), 248 deletions(-)
diff --git a/sdks/python/apache_beam/io/mongodbio.py
b/sdks/python/apache_beam/io/mongodbio.py
index a56c274..c6f7d97 100644
--- a/sdks/python/apache_beam/io/mongodbio.py
+++ b/sdks/python/apache_beam/io/mongodbio.py
@@ -71,9 +71,12 @@ import json
import logging
import math
import struct
+from typing import Union
import apache_beam as beam
from apache_beam.io import iobase
+from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
+from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.io.range_trackers import OrderedPositionRangeTracker
from apache_beam.transforms import DoFn
from apache_beam.transforms import PTransform
@@ -83,11 +86,13 @@ from apache_beam.utils.annotations import experimental
_LOGGER = logging.getLogger(__name__)
try:
- # Mongodb has its own bundled bson, which is not compatible with bson
pakcage.
+ # Mongodb has its own bundled bson, which is not compatible with bson
package.
# (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
# it fails because bson package is installed, MongoDB IO will not work but at
# least rest of the SDK will work.
+ from bson import json_util
from bson import objectid
+ from bson.objectid import ObjectId
# pymongo also internally depends on bson.
from pymongo import ASCENDING
@@ -96,24 +101,30 @@ try:
from pymongo import ReplaceOne
except ImportError:
objectid = None
+ json_util = None
+ ObjectId = None
+ ASCENDING = 1
+ DESCENDING = -1
+ MongoClient = None
+ ReplaceOne = None
_LOGGER.warning("Could not find a compatible bson package.")
-__all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
+__all__ = ["ReadFromMongoDB", "WriteToMongoDB"]
@experimental()
class ReadFromMongoDB(PTransform):
- """A ``PTransform`` to read MongoDB documents into a ``PCollection``.
- """
+ """A ``PTransform`` to read MongoDB documents into a ``PCollection``."""
def __init__(
self,
- uri='mongodb://localhost:27017',
+ uri="mongodb://localhost:27017",
db=None,
coll=None,
filter=None,
projection=None,
extra_client_params=None,
- bucket_auto=False):
+ bucket_auto=False,
+ ):
"""Initialize a :class:`ReadFromMongoDB`
Args:
@@ -136,16 +147,14 @@ class ReadFromMongoDB(PTransform):
Returns:
:class:`~apache_beam.transforms.ptransform.PTransform`
-
"""
if extra_client_params is None:
extra_client_params = {}
if not isinstance(db, str):
- raise ValueError('ReadFromMongDB db param must be specified as a string')
+ raise ValueError("ReadFromMongDB db param must be specified as a string")
if not isinstance(coll, str):
raise ValueError(
- 'ReadFromMongDB coll param must be specified as a '
- 'string')
+ "ReadFromMongDB coll param must be specified as a string")
self._mongo_source = _BoundedMongoSource(
uri=uri,
db=db,
@@ -153,13 +162,88 @@ class ReadFromMongoDB(PTransform):
filter=filter,
projection=projection,
extra_client_params=extra_client_params,
- bucket_auto=bucket_auto)
+ bucket_auto=bucket_auto,
+ )
def expand(self, pcoll):
return pcoll | iobase.Read(self._mongo_source)
+class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
+ """RangeTracker for tracking mongodb _id of bson ObjectId type."""
+ def position_to_fraction(
+ self,
+ pos: ObjectId,
+ start: ObjectId,
+ end: ObjectId,
+ ):
+ """Returns the fraction of keys in the range [start, end) that
+ are less than the given key.
+ """
+ pos_number = _ObjectIdHelper.id_to_int(pos)
+ start_number = _ObjectIdHelper.id_to_int(start)
+ end_number = _ObjectIdHelper.id_to_int(end)
+ return (pos_number - start_number) / (end_number - start_number)
+
+ def fraction_to_position(
+ self,
+ fraction: float,
+ start: ObjectId,
+ end: ObjectId,
+ ):
+ """Converts a fraction between 0 and 1
+ to a position between start and end.
+ """
+ start_number = _ObjectIdHelper.id_to_int(start)
+ end_number = _ObjectIdHelper.id_to_int(end)
+ total = end_number - start_number
+ pos = int(total * fraction + start_number)
+ # make sure split position is larger than start position and smaller than
+ # end position.
+ if pos <= start_number:
+ return _ObjectIdHelper.increment_id(start, 1)
+
+ if pos >= end_number:
+ return _ObjectIdHelper.increment_id(end, -1)
+
+ return _ObjectIdHelper.int_to_id(pos)
+
+
class _BoundedMongoSource(iobase.BoundedSource):
+ """A MongoDB source that reads a finite amount of input records.
+
+ This class defines following operations which can be used to read
+ MongoDB source efficiently.
+
+ * Size estimation - method ``estimate_size()`` may return an accurate
+ estimation in bytes for the size of the source.
+ * Splitting into bundles of a given size - method ``split()`` can be used to
+ split the source into a set of sub-sources (bundles) based on a desired
+ bundle size.
+ * Getting a RangeTracker - method ``get_range_tracker()`` should return a
+ ``RangeTracker`` object for a given position range for the position type
+ of the records returned by the source.
+ * Reading the data - method ``read()`` can be used to read data from the
+ source while respecting the boundaries defined by a given
+ ``RangeTracker``.
+
+ A runner will perform reading the source in two steps.
+
+ (1) Method ``get_range_tracker()`` will be invoked with start and end
+ positions to obtain a ``RangeTracker`` for the range of positions the
+ runner intends to read. Source must define a default initial start and
end
+ position range. These positions must be used if the start and/or end
+ positions passed to the method ``get_range_tracker()`` are ``None``
+ (2) Method read() will be invoked with the ``RangeTracker`` obtained in the
+ previous step.
+
+ **Mutability**
+
+ A ``_BoundedMongoSource`` object should not be mutated while
+ its methods (for example, ``read()``) are being invoked by a runner. Runner
+ implementations may invoke methods of ``_BoundedMongoSource`` objects through
+ multi-threaded and/or reentrant execution modes.
+ """
def __init__(
self,
uri=None,
@@ -168,7 +252,8 @@ class _BoundedMongoSource(iobase.BoundedSource):
filter=None,
projection=None,
extra_client_params=None,
- bucket_auto=False):
+ bucket_auto=False,
+ ):
if extra_client_params is None:
extra_client_params = {}
if filter is None:
@@ -183,13 +268,33 @@ class _BoundedMongoSource(iobase.BoundedSource):
def estimate_size(self):
with MongoClient(self.uri, **self.spec) as client:
- return client[self.db].command('collstats', self.coll).get('size')
+ return client[self.db].command("collstats", self.coll).get("size")
def _estimate_average_document_size(self):
with MongoClient(self.uri, **self.spec) as client:
- return client[self.db].command('collstats', self.coll).get('avgObjSize')
+ return client[self.db].command("collstats", self.coll).get("avgObjSize")
+
+ def split(
+ self,
+ desired_bundle_size: int,
+ start_position: Union[int, str, bytes, ObjectId] = None,
+ stop_position: Union[int, str, bytes, ObjectId] = None,
+ ):
+ """Splits the source into a set of bundles.
+
+ Bundles should be approximately of size ``desired_bundle_size`` bytes.
+
+ Args:
+ desired_bundle_size: the desired size (in bytes) of the bundles returned.
+ start_position: if specified the given position must be used as the
+ starting position of the first bundle.
+ stop_position: if specified the given position must be used as the ending
+ position of the last bundle.
+ Returns:
+ an iterator of objects of type 'SourceBundle' that gives information
about
+ the generated bundles.
+ """
- def split(self, desired_bundle_size, start_position=None,
stop_position=None):
desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
# for desired bundle size, if desired chunk size smaller than 1mb, use
@@ -199,18 +304,21 @@ class _BoundedMongoSource(iobase.BoundedSource):
is_initial_split = start_position is None and stop_position is None
start_position, stop_position = self._replace_none_positions(
- start_position, stop_position)
+ start_position, stop_position
+ )
if self.bucket_auto:
# Use $bucketAuto for bundling
split_keys = []
weights = []
- for bucket in self._get_auto_buckets(desired_bundle_size_in_mb,
- start_position,
- stop_position,
- is_initial_split):
- split_keys.append({'_id': bucket['_id']['max']})
- weights.append(bucket['count'])
+ for bucket in self._get_auto_buckets(
+ desired_bundle_size_in_mb,
+ start_position,
+ stop_position,
+ is_initial_split,
+ ):
+ split_keys.append({"_id": bucket["_id"]["max"]})
+ weights.append(bucket["count"])
else:
# Use splitVector for bundling
split_keys = self._get_split_keys(
@@ -221,12 +329,13 @@ class _BoundedMongoSource(iobase.BoundedSource):
for split_key_id, weight in zip(split_keys, weights):
if bundle_start >= stop_position:
break
- bundle_end = min(stop_position, split_key_id['_id'])
+ bundle_end = min(stop_position, split_key_id["_id"])
yield iobase.SourceBundle(
weight=weight,
source=self,
start_position=bundle_start,
- stop_position=bundle_end)
+ stop_position=bundle_end,
+ )
bundle_start = bundle_end
# add range of last split_key to stop_position
if bundle_start < stop_position:
@@ -236,60 +345,146 @@ class _BoundedMongoSource(iobase.BoundedSource):
weight=weight,
source=self,
start_position=bundle_start,
- stop_position=stop_position)
+ stop_position=stop_position,
+ )
- def get_range_tracker(self, start_position, stop_position):
+ def get_range_tracker(
+ self,
+ start_position: Union[int, str, ObjectId] = None,
+ stop_position: Union[int, str, ObjectId] = None,
+ ) -> Union[
+ _ObjectIdRangeTracker, OffsetRangeTracker, LexicographicKeyRangeTracker]:
+ """Returns a RangeTracker for a given position range depending on type.
+
+ Args:
+ start_position: starting position of the range. If 'None' default start
+ position of the source must be used.
+ stop_position: ending position of the range. If 'None' default stop
+ position of the source must be used.
+ Returns:
+ a ``_ObjectIdRangeTracker``, ``OffsetRangeTracker``
+ or ``LexicographicKeyRangeTracker`` depending on the given position
range.
+ """
start_position, stop_position = self._replace_none_positions(
- start_position, stop_position)
- return _ObjectIdRangeTracker(start_position, stop_position)
+ start_position, stop_position
+ )
+
+ if isinstance(start_position, ObjectId):
+ return _ObjectIdRangeTracker(start_position, stop_position)
+
+ if isinstance(start_position, int):
+ return OffsetRangeTracker(start_position, stop_position)
+
+ if isinstance(start_position, str):
+ return LexicographicKeyRangeTracker(start_position, stop_position)
+
+ raise NotImplementedError(
+ f"RangeTracker for {type(start_position)} not implemented!")
def read(self, range_tracker):
+ """Returns an iterator that reads data from the source.
+
+ The returned set of data must respect the boundaries defined by the given
+ ``RangeTracker`` object. For example:
+
+ * Returned set of data must be for the range
+ ``[range_tracker.start_position, range_tracker.stop_position)``. Note
+ that a source may decide to return records that start after
+ ``range_tracker.stop_position``. See documentation in class
+ ``RangeTracker`` for more details. Also, note that framework might
+ invoke ``range_tracker.try_split()`` to perform dynamic split
+ operations. range_tracker.stop_position may be updated
+ dynamically due to successful dynamic split operations.
+ * Method ``range_tracker.try_split()`` must be invoked for every record
+ that starts at a split point.
+ * Method ``range_tracker.record_current_position()`` may be invoked for
+ records that do not start at split points.
+
+ Args:
+ range_tracker: a ``RangeTracker`` whose boundaries must be respected
+ when reading data from the source. A runner that reads
this
+ source muss pass a ``RangeTracker`` object that is not
+ ``None``.
+ Returns:
+ an iterator of data read by the source.
+ """
with MongoClient(self.uri, **self.spec) as client:
all_filters = self._merge_id_filter(
range_tracker.start_position(), range_tracker.stop_position())
- docs_cursor = client[self.db][self.coll].find(
- filter=all_filters,
- projection=self.projection).sort([('_id', ASCENDING)])
+ docs_cursor = (
+ client[self.db][self.coll].find(
+ filter=all_filters,
+ projection=self.projection).sort([("_id", ASCENDING)]))
for doc in docs_cursor:
- if not range_tracker.try_claim(doc['_id']):
+ if not range_tracker.try_claim(doc["_id"]):
return
yield doc
def display_data(self):
- res = super(_BoundedMongoSource, self).display_data()
- res['database'] = self.db
- res['collection'] = self.coll
- res['filter'] = json.dumps(
- self.filter, default=lambda x: 'not_serializable(%s)' % str(x))
- res['projection'] = str(self.projection)
- res['bucket_auto'] = self.bucket_auto
+ """Returns the display data associated to a pipeline component."""
+ res = super().display_data()
+ res["database"] = self.db
+ res["collection"] = self.coll
+ res["filter"] = json.dumps(self.filter, default=json_util.default)
+ res["projection"] = str(self.projection)
+ res["bucket_auto"] = self.bucket_auto
return res
- def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
- # calls mongodb splitVector command to get document ids at split position
- if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
- # single document not splittable
+ @staticmethod
+ def _range_is_not_splittable(
+ start_pos: Union[int, str, ObjectId],
+ end_pos: Union[int, str, ObjectId],
+ ):
+ """Return `True` if splitting range doesn't make sense
+ (single document is not splittable),
+ Return `False` otherwise.
+ """
+ return ((
+ isinstance(start_pos, ObjectId) and
+ start_pos >= _ObjectIdHelper.increment_id(end_pos, -1)) or
+ (isinstance(start_pos, int) and start_pos >= end_pos - 1) or
+ (isinstance(start_pos, str) and start_pos >= end_pos))
+
+ def _get_split_keys(
+ self,
+ desired_chunk_size_in_mb: int,
+ start_pos: Union[int, str, ObjectId],
+ end_pos: Union[int, str, ObjectId],
+ ):
+ """Calls MongoDB `splitVector` command
+ to get document ids at split position.
+ """
+ # single document not splittable
+ if self._range_is_not_splittable(start_pos, end_pos):
return []
+
with MongoClient(self.uri, **self.spec) as client:
- name_space = '%s.%s' % (self.db, self.coll)
- return (
- client[self.db].command(
- 'splitVector',
- name_space,
- keyPattern={'_id': 1}, # Ascending index
- min={'_id': start_pos},
- max={'_id': end_pos},
- maxChunkSize=desired_chunk_size_in_mb)['splitKeys'])
+ name_space = "%s.%s" % (self.db, self.coll)
+ return client[self.db].command(
+ "splitVector",
+ name_space,
+ keyPattern={"_id": 1}, # Ascending index
+ min={"_id": start_pos},
+ max={"_id": end_pos},
+ maxChunkSize=desired_chunk_size_in_mb,
+ )["splitKeys"]
def _get_auto_buckets(
- self, desired_chunk_size_in_mb, start_pos, end_pos, is_initial_split):
-
- if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
- # single document not splittable
+ self,
+ desired_chunk_size_in_mb: int,
+ start_pos: Union[int, str, ObjectId],
+ end_pos: Union[int, str, ObjectId],
+ is_initial_split: bool,
+ ) -> list:
+ """Use MongoDB `$bucketAuto` aggregation to split collection into bundles
+ instead of `splitVector` command, which does not work with MongoDB Atlas.
+ """
+ # single document not splittable
+ if self._range_is_not_splittable(start_pos, end_pos):
return []
if is_initial_split and not self.filter:
- # total collection size
+ # total collection size in MB
size_in_mb = self.estimate_size() / float(1 << 20)
else:
# size of documents within start/end id range and possibly filtered
@@ -306,34 +501,46 @@ class _BoundedMongoSource(iobase.BoundedSource):
pipeline = [
{
# filter by positions and by the custom filter if any
- '$match': self._merge_id_filter(start_pos, end_pos)
+ "$match": self._merge_id_filter(start_pos, end_pos)
},
{
- '$bucketAuto': {
- 'groupBy': '$_id', 'buckets': bucket_count
+ "$bucketAuto": {
+ "groupBy": "$_id", "buckets": bucket_count
}
- }
+ },
]
- buckets = list(client[self.db][self.coll].aggregate(pipeline))
+ buckets = list(
+ # Use `allowDiskUse` option to avoid aggregation limit of 100 Mb RAM
+ client[self.db][self.coll].aggregate(pipeline, allowDiskUse=True))
if buckets:
- buckets[-1]['_id']['max'] = end_pos
+ buckets[-1]["_id"]["max"] = end_pos
return buckets
- def _merge_id_filter(self, start_position, stop_position):
- # Merge the default filter (if any) with refined _id field range
- # of range_tracker.
- # $gte specifies start position (inclusive)
- # and $lt specifies the end position (exclusive),
- # see more at
- # https://docs.mongodb.com/manual/reference/operator/query/gte/ and
- # https://docs.mongodb.com/manual/reference/operator/query/lt/
- id_filter = {'_id': {'$gte': start_position, '$lt': stop_position}}
+ def _merge_id_filter(
+ self,
+ start_position: Union[int, str, bytes, ObjectId],
+ stop_position: Union[int, str, bytes, ObjectId] = None,
+ ) -> dict:
+ """Merge the default filter (if any) with refined _id field range
+ of range_tracker.
+ $gte specifies start position (inclusive)
+ and $lt specifies the end position (exclusive),
+ see more at
+ https://docs.mongodb.com/manual/reference/operator/query/gte/ and
+ https://docs.mongodb.com/manual/reference/operator/query/lt/
+ """
+
+ if stop_position is None:
+ id_filter = {"_id": {"$gte": start_position}}
+ else:
+ id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}}
+
if self.filter:
all_filters = {
# see more at
# https://docs.mongodb.com/manual/reference/operator/query/and/
- '$and': [self.filter.copy(), id_filter]
+ "$and": [self.filter.copy(), id_filter]
}
else:
all_filters = id_filter
@@ -342,45 +549,58 @@ class _BoundedMongoSource(iobase.BoundedSource):
def _get_head_document_id(self, sort_order):
with MongoClient(self.uri, **self.spec) as client:
- cursor = client[self.db][self.coll].find(
- filter={}, projection=[]).sort([('_id', sort_order)]).limit(1)
+ cursor = (
+ client[self.db][self.coll].find(filter={}, projection=[]).sort([
+ ("_id", sort_order)
+ ]).limit(1))
try:
- return cursor[0]['_id']
+ return cursor[0]["_id"]
+
except IndexError:
- raise ValueError('Empty Mongodb collection')
+ raise ValueError("Empty Mongodb collection")
def _replace_none_positions(self, start_position, stop_position):
+
if start_position is None:
start_position = self._get_head_document_id(ASCENDING)
if stop_position is None:
last_doc_id = self._get_head_document_id(DESCENDING)
# increment last doc id binary value by 1 to make sure the last document
# is not excluded
- stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
+ if isinstance(last_doc_id, ObjectId):
+ stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
+ elif isinstance(last_doc_id, int):
+ stop_position = last_doc_id + 1
+ elif isinstance(last_doc_id, str):
+ stop_position = last_doc_id + '\x00'
+
return start_position, stop_position
def _count_id_range(self, start_position, stop_position):
- # Number of documents between start_position (inclusive)
- # and stop_position (exclusive), respecting the custom filter if any.
+ """Number of documents between start_position (inclusive)
+ and stop_position (exclusive), respecting the custom filter if any.
+ """
with MongoClient(self.uri, **self.spec) as client:
return client[self.db][self.coll].count_documents(
filter=self._merge_id_filter(start_position, stop_position))
-class _ObjectIdHelper(object):
+class _ObjectIdHelper:
"""A Utility class to manipulate bson object ids."""
@classmethod
- def id_to_int(cls, id):
+ def id_to_int(cls, _id: Union[int, ObjectId]) -> int:
"""
Args:
- id: ObjectId required for each MongoDB document _id field.
+ _id: ObjectId required for each MongoDB document _id field.
Returns: Converted integer value of ObjectId's 12 bytes binary value.
-
"""
+ if isinstance(_id, int):
+ return _id
+
# converts object id binary to integer
# id object is bytes type with size of 12
- ints = struct.unpack('>III', id.binary)
+ ints = struct.unpack(">III", _id.binary)
return (ints[0] << 64) + (ints[1] << 32) + ints[2]
@classmethod
@@ -391,61 +611,45 @@ class _ObjectIdHelper(object):
Returns: The ObjectId that has the 12 bytes binary converted from the
integer value.
-
"""
# converts integer value to object id. Int value should be less than
# (2 ^ 96) so it can be convert to 12 bytes required by object id.
if number < 0 or number >= (1 << 96):
- raise ValueError('number value must be within [0, %s)' % (1 << 96))
- ints = [(number & 0xffffffff0000000000000000) >> 64,
- (number & 0x00000000ffffffff00000000) >> 32,
- number & 0x0000000000000000ffffffff]
+ raise ValueError("number value must be within [0, %s)" % (1 << 96))
+ ints = [
+ (number & 0xFFFFFFFF0000000000000000) >> 64,
+ (number & 0x00000000FFFFFFFF00000000) >> 32,
+ number & 0x0000000000000000FFFFFFFF,
+ ]
- bytes = struct.pack('>III', *ints)
- return objectid.ObjectId(bytes)
+ number_bytes = struct.pack(">III", *ints)
+ return ObjectId(number_bytes)
@classmethod
- def increment_id(cls, object_id, inc):
+ def increment_id(
+ cls,
+ _id: ObjectId,
+ inc: int,
+ ) -> ObjectId:
"""
+ Increment object_id binary value by inc value and return new object id.
+
Args:
- object_id: The ObjectId to change.
- inc(int): The incremental int value to be added to ObjectId.
+ _id: The `_id` to change.
+ inc(int): The incremental int value to be added to `_id`.
Returns:
-
+ `_id` incremented by `inc` value
"""
- # increment object_id binary value by inc value and return new object id.
- id_number = _ObjectIdHelper.id_to_int(object_id)
+ id_number = _ObjectIdHelper.id_to_int(_id)
new_number = id_number + inc
if new_number < 0 or new_number >= (1 << 96):
raise ValueError(
- 'invalid incremental, inc value must be within ['
- '%s, %s)' % (0 - id_number, 1 << 96 - id_number))
+ "invalid incremental, inc value must be within ["
+ "%s, %s)" % (0 - id_number, 1 << 96 - id_number))
return _ObjectIdHelper.int_to_id(new_number)
-class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
- """RangeTracker for tracking mongodb _id of bson ObjectId type."""
- def position_to_fraction(self, pos, start, end):
- pos_number = _ObjectIdHelper.id_to_int(pos)
- start_number = _ObjectIdHelper.id_to_int(start)
- end_number = _ObjectIdHelper.id_to_int(end)
- return (pos_number - start_number) / (end_number - start_number)
-
- def fraction_to_position(self, fraction, start, end):
- start_number = _ObjectIdHelper.id_to_int(start)
- end_number = _ObjectIdHelper.id_to_int(end)
- total = end_number - start_number
- pos = int(total * fraction + start_number)
- # make sure split position is larger than start position and smaller than
- # end position.
- if pos <= start_number:
- return _ObjectIdHelper.increment_id(start, 1)
- if pos >= end_number:
- return _ObjectIdHelper.increment_id(end, -1)
- return _ObjectIdHelper.int_to_id(pos)
-
-
@experimental()
class WriteToMongoDB(PTransform):
"""WriteToMongoDB is a ``PTransform`` that writes a ``PCollection`` of
@@ -472,11 +676,12 @@ class WriteToMongoDB(PTransform):
"""
def __init__(
self,
- uri='mongodb://localhost:27017',
+ uri="mongodb://localhost:27017",
db=None,
coll=None,
batch_size=100,
- extra_client_params=None):
+ extra_client_params=None,
+ ):
"""
Args:
@@ -496,11 +701,10 @@ class WriteToMongoDB(PTransform):
if extra_client_params is None:
extra_client_params = {}
if not isinstance(db, str):
- raise ValueError('WriteToMongoDB db param must be specified as a string')
+ raise ValueError("WriteToMongoDB db param must be specified as a string")
if not isinstance(coll, str):
raise ValueError(
- 'WriteToMongoDB coll param must be specified as a '
- 'string')
+ "WriteToMongoDB coll param must be specified as a string")
self._uri = uri
self._db = db
self._coll = coll
@@ -508,25 +712,27 @@ class WriteToMongoDB(PTransform):
self._spec = extra_client_params
def expand(self, pcoll):
- return pcoll \
- | beam.ParDo(_GenerateObjectIdFn()) \
- | Reshuffle() \
- | beam.ParDo(_WriteMongoFn(self._uri, self._db, self._coll,
- self._batch_size, self._spec))
+ return (
+ pcoll
+ | beam.ParDo(_GenerateObjectIdFn())
+ | Reshuffle()
+ | beam.ParDo(
+ _WriteMongoFn(
+ self._uri, self._db, self._coll, self._batch_size,
self._spec)))
class _GenerateObjectIdFn(DoFn):
def process(self, element, *args, **kwargs):
# if _id field already exist we keep it as it is, otherwise the ptransform
# generates a new _id field to achieve idempotent write to mongodb.
- if '_id' not in element:
+ if "_id" not in element:
# object.ObjectId() generates a unique identifier that follows mongodb
# default format, if _id is not present in document, mongodb server
# generates it with this same function upon write. However the
# uniqueness of generated id may not be guaranteed if the work load are
# distributed across too many processes. See more on the ObjectId format
# https://docs.mongodb.com/manual/reference/bson-types/#objectid.
- element['_id'] = objectid.ObjectId()
+ element["_id"] = objectid.ObjectId()
yield element
@@ -560,13 +766,13 @@ class _WriteMongoFn(DoFn):
def display_data(self):
res = super(_WriteMongoFn, self).display_data()
- res['database'] = self.db
- res['collection'] = self.coll
- res['batch_size'] = self.batch_size
+ res["database"] = self.db
+ res["collection"] = self.coll
+ res["batch_size"] = self.batch_size
return res
-class _MongoSink(object):
+class _MongoSink:
def __init__(self, uri=None, db=None, coll=None, extra_params=None):
if extra_params is None:
extra_params = {}
@@ -585,17 +791,18 @@ class _MongoSink(object):
# insert new one, otherwise overwrite it.
requests.append(
ReplaceOne(
- filter={'_id': doc.get('_id', None)},
+ filter={"_id": doc.get("_id", None)},
replacement=doc,
upsert=True))
resp = self.client[self.db][self.coll].bulk_write(requests)
_LOGGER.debug(
- 'BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, '
- 'nMatched:%d, Errors:%s' % (
+ "BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, "
+ "nMatched:%d, Errors:%s" % (
resp.modified_count,
resp.upserted_count,
resp.matched_count,
- resp.bulk_api_result.get('writeErrors')))
+ resp.bulk_api_result.get("writeErrors"),
+ ))
def __enter__(self):
if self.client is None:
diff --git a/sdks/python/apache_beam/io/mongodbio_test.py
b/sdks/python/apache_beam/io/mongodbio_test.py
index f6f467f..150eac2 100644
--- a/sdks/python/apache_beam/io/mongodbio_test.py
+++ b/sdks/python/apache_beam/io/mongodbio_test.py
@@ -20,9 +20,11 @@ import datetime
import logging
import random
import unittest
+from typing import Union
from unittest import TestCase
import mock
+from bson import ObjectId
from bson import objectid
from parameterized import parameterized_class
from pymongo import ASCENDING
@@ -38,6 +40,8 @@ from apache_beam.io.mongodbio import _MongoSink
from apache_beam.io.mongodbio import _ObjectIdHelper
from apache_beam.io.mongodbio import _ObjectIdRangeTracker
from apache_beam.io.mongodbio import _WriteMongoFn
+from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
+from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
@@ -120,7 +124,7 @@ class _MockMongoColl(object):
def count_documents(self, filter):
return len(self._filter(filter))
- def aggregate(self, pipeline):
+ def aggregate(self, pipeline, **kwargs):
# Simulate $bucketAuto aggregate pipeline.
# Example splits doc count for the total of 5 docs:
# - 1 bucket: [5]
@@ -175,7 +179,7 @@ class _MockMongoDb(object):
def command(self, command, *args, **kwargs):
if command == 'collstats':
return {'size': 5 * 1024 * 1024, 'avgObjSize': 1 * 1024 * 1024}
- elif command == 'splitVector':
+ if command == 'splitVector':
return self.get_split_keys(command, *args, **kwargs)
def get_split_keys(self, command, ns, min, max, maxChunkSize, **kwargs):
@@ -205,7 +209,7 @@ class _MockMongoDb(object):
}
-class _MockMongoClient(object):
+class _MockMongoClient:
def __init__(self, docs):
self.docs = docs
@@ -219,15 +223,87 @@ class _MockMongoClient(object):
pass
-@parameterized_class(('bucket_auto', ), [(None, ), (True, )])
+# Generate test data for MongoDB collections of different types
+OBJECT_IDS = [
+ objectid.ObjectId.from_datetime(
+ datetime.datetime(year=2020, month=i + 1, day=i + 1)) for i in range(5)
+]
+
+INT_IDS = [n for n in range(5)] # [0, 1, 2, 3, 4]
+
+STR_IDS_1 = [str(n) for n in range(5)] # ['0', '1', '2', '3', '4']
+
+# ['aaaaa', 'bbbbb', 'ccccc', 'ddddd', 'eeeee']
+STR_IDS_2 = [chr(97 + n) * 5 for n in range(5)]
+
+# ['AAAAAAAAAAAAAAAAAAAA', 'BBBBBBBBBBBBBBBBBBBB', ..., 'EEEEEEEEEEEEEEEEEEEE']
+STR_IDS_3 = [chr(65 + n) * 20 for n in range(5)]
+
+
+@parameterized_class(('bucket_auto', '_ids', 'min_id', 'max_id'),
+ [
+ (
+ None,
+ OBJECT_IDS,
+ _ObjectIdHelper.int_to_id(0),
+ _ObjectIdHelper.int_to_id(2**96 - 1)),
+ (
+ True,
+ OBJECT_IDS,
+ _ObjectIdHelper.int_to_id(0),
+ _ObjectIdHelper.int_to_id(2**96 - 1)),
+ (
+ None,
+ INT_IDS,
+ 0,
+ 2**96 - 1,
+ ),
+ (
+ True,
+ INT_IDS,
+ 0,
+ 2**96 - 1,
+ ),
+ (
+ None,
+ STR_IDS_1,
+ chr(0),
+ chr(0x10ffff),
+ ),
+ (
+ True,
+ STR_IDS_1,
+ chr(0),
+ chr(0x10ffff),
+ ),
+ (
+ None,
+ STR_IDS_2,
+ chr(0),
+ chr(0x10ffff),
+ ),
+ (
+ True,
+ STR_IDS_2,
+ chr(0),
+ chr(0x10ffff),
+ ),
+ (
+ None,
+ STR_IDS_3,
+ chr(0),
+ chr(0x10ffff),
+ ),
+ (
+ True,
+ STR_IDS_3,
+ chr(0),
+ chr(0x10ffff),
+ ),
+ ])
class MongoSourceTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def setUp(self, mock_client):
- self._ids = [
- objectid.ObjectId.from_datetime(
- datetime.datetime(year=2020, month=i + 1, day=i + 1))
- for i in range(5)
- ]
self._docs = [{'_id': self._ids[i], 'x': i} for i in range(len(self._ids))]
mock_client.return_value = _MockMongoClient(self._docs)
@@ -242,6 +318,27 @@ class MongoSourceTest(unittest.TestCase):
kwargs['bucket_auto'] = bucket_auto
return _BoundedMongoSource('mongodb://test', 'testdb', 'testcoll',
**kwargs)
+ def _increment_id(
+ self,
+ _id: Union[ObjectId, int, str],
+ inc: int,
+ ) -> Union[ObjectId, int, str]:
+ """Helper method to increment `_id` of different types."""
+
+ if isinstance(_id, ObjectId):
+ return _ObjectIdHelper.increment_id(_id, inc)
+
+ if isinstance(_id, int):
+ return _id + inc
+
+ if isinstance(_id, str):
+ index = self._ids.index(_id) + inc
+ if index <= 0:
+ return self._ids[0]
+ if index >= len(self._ids):
+ return self._ids[-1]
+ return self._ids[index]
+
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_estimate_size(self, mock_client):
mock_client.return_value = _MockMongoClient(self._docs)
@@ -282,10 +379,12 @@ class MongoSourceTest(unittest.TestCase):
start_position=None, stop_position=None,
desired_bundle_size=size))
self.assertEqual(len(splits), 1)
- self.assertEqual(splits[0].start_position, self._docs[0]['_id'])
- self.assertEqual(
- splits[0].stop_position,
- _ObjectIdHelper.increment_id(self._docs[0]['_id'], 1))
+ _id = self._docs[0]['_id']
+ assert _id == splits[0].start_position
+ assert _id <= splits[0].stop_position
+ if isinstance(_id, (ObjectId, int)):
+ # We can unambiguously determine next `_id`
+ assert self._increment_id(_id, 1) == splits[0].stop_position
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_split_no_documents(self, mock_client):
@@ -357,8 +456,8 @@ class MongoSourceTest(unittest.TestCase):
reference_info = (
filtered_mongo_source,
# range to match no documents:
- _ObjectIdHelper.increment_id(self._docs[-1]['_id'], 1),
- _ObjectIdHelper.increment_id(self._docs[-1]['_id'], 2),
+ self._increment_id(self._docs[-1]['_id'], 1),
+ self._increment_id(self._docs[-1]['_id'], 2),
)
sources_info = ([
(split.source, split.start_position, split.stop_position)
@@ -379,8 +478,21 @@ class MongoSourceTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_get_range_tracker(self, mock_client):
mock_client.return_value = _MockMongoClient(self._docs)
- self.assertIsInstance(
- self.mongo_source.get_range_tracker(None, None), _ObjectIdRangeTracker)
+ if self._ids == OBJECT_IDS:
+ self.assertIsInstance(
+ self.mongo_source.get_range_tracker(None, None),
+ _ObjectIdRangeTracker,
+ )
+ elif self._ids == INT_IDS:
+ self.assertIsInstance(
+ self.mongo_source.get_range_tracker(None, None),
+ OffsetRangeTracker,
+ )
+ elif self._ids == STR_IDS_1:
+ self.assertIsInstance(
+ self.mongo_source.get_range_tracker(None, None),
+ LexicographicKeyRangeTracker,
+ )
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_read(self, mock_client):
@@ -394,26 +506,26 @@ class MongoSourceTest(unittest.TestCase):
},
{
# range covers from the first to the third documents
- 'start': _ObjectIdHelper.int_to_id(0), # smallest possible id
+ 'start': self.min_id, # smallest possible id
'stop': self._ids[2],
'expected': self._docs[0:2]
},
{
# range covers from the third to last documents
'start': self._ids[2],
- 'stop': _ObjectIdHelper.int_to_id(2**96 - 1), # largest possible id
+ 'stop': self.max_id, # largest possible id
'expected': self._docs[2:]
},
{
# range covers all documents
- 'start': _ObjectIdHelper.int_to_id(0),
- 'stop': _ObjectIdHelper.int_to_id(2**96 - 1),
+ 'start': self.min_id,
+ 'stop': self.max_id,
'expected': self._docs
},
{
# range doesn't include any document
- 'start': _ObjectIdHelper.increment_id(self._ids[2], 1),
- 'stop': _ObjectIdHelper.increment_id(self._ids[3], -1),
+ 'start': self._increment_id(self._ids[2], 1),
+ 'stop': self._increment_id(self._ids[3], -1),
'expected': []
},
]
@@ -429,6 +541,32 @@ class MongoSourceTest(unittest.TestCase):
self.assertTrue('database' in data)
self.assertTrue('collection' in data)
+ def test_range_is_not_splittable(self):
+ self.assertTrue(
+ self.mongo_source._range_is_not_splittable(
+ _ObjectIdHelper.int_to_id(1),
+ _ObjectIdHelper.int_to_id(1),
+ ))
+ self.assertTrue(
+ self.mongo_source._range_is_not_splittable(
+ _ObjectIdHelper.int_to_id(1),
+ _ObjectIdHelper.int_to_id(2),
+ ))
+ self.assertFalse(
+ self.mongo_source._range_is_not_splittable(
+ _ObjectIdHelper.int_to_id(1),
+ _ObjectIdHelper.int_to_id(3),
+ ))
+
+ self.assertTrue(self.mongo_source._range_is_not_splittable(0, 0))
+ self.assertTrue(self.mongo_source._range_is_not_splittable(0, 1))
+ self.assertFalse(self.mongo_source._range_is_not_splittable(0, 2))
+
+ self.assertTrue(self.mongo_source._range_is_not_splittable("AAA", "AAA"))
+ self.assertFalse(
+ self.mongo_source._range_is_not_splittable("AAA", "AAA\x00"))
+ self.assertFalse(self.mongo_source._range_is_not_splittable("AAA", "AAB"))
+
@parameterized_class(('bucket_auto', ), [(False, ), (True, )])
class ReadFromMongoDBTest(unittest.TestCase):
@@ -500,9 +638,13 @@ class MongoSinkTest(unittest.TestCase):
class WriteToMongoDBTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_write_to_mongodb_with_existing_id(self, mock_client):
- id = objectid.ObjectId()
- docs = [{'x': 1, '_id': id}]
- expected_update = [ReplaceOne({'_id': id}, {'x': 1, '_id': id}, True, None)]
+ _id = objectid.ObjectId()
+ docs = [{'x': 1, '_id': _id}]
+ expected_update = [
+ ReplaceOne({'_id': _id}, {
+ 'x': 1, '_id': _id
+ }, True, None)
+ ]
with TestPipeline() as p:
_ = (
p | "Create" >> beam.Create(docs)
@@ -538,34 +680,36 @@ class ObjectIdHelperTest(TestCase):
(objectid.ObjectId('00000000ffffffffffffffff'), 2**64 - 1),
(objectid.ObjectId('ffffffffffffffffffffffff'), 2**96 - 1),
]
- for (id, number) in test_cases:
- self.assertEqual(id, _ObjectIdHelper.int_to_id(number))
- self.assertEqual(number, _ObjectIdHelper.id_to_int(id))
+ for (_id, number) in test_cases:
+ self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
+ self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))
# random tests
for _ in range(100):
- id = objectid.ObjectId()
- number = int(id.binary.hex(), 16)
- self.assertEqual(id, _ObjectIdHelper.int_to_id(number))
- self.assertEqual(number, _ObjectIdHelper.id_to_int(id))
+ _id = objectid.ObjectId()
+ number = int(_id.binary.hex(), 16)
+ self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
+ self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))
def test_increment_id(self):
test_cases = [
(
- objectid.ObjectId('000000000000000100000000'),
- objectid.ObjectId('0000000000000000ffffffff')),
+ objectid.ObjectId("000000000000000100000000"),
+ objectid.ObjectId("0000000000000000ffffffff"),
+ ),
(
- objectid.ObjectId('000000010000000000000000'),
- objectid.ObjectId('00000000ffffffffffffffff')),
+ objectid.ObjectId("000000010000000000000000"),
+ objectid.ObjectId("00000000ffffffffffffffff"),
+ ),
]
- for (first, second) in test_cases:
+ for first, second in test_cases:
self.assertEqual(second, _ObjectIdHelper.increment_id(first, -1))
self.assertEqual(first, _ObjectIdHelper.increment_id(second, 1))
for _ in range(100):
- id = objectid.ObjectId()
- self.assertLess(id, _ObjectIdHelper.increment_id(id, 1))
- self.assertGreater(id, _ObjectIdHelper.increment_id(id, -1))
+ _id = objectid.ObjectId()
+ self.assertLess(_id, _ObjectIdHelper.increment_id(_id, 1))
+ self.assertGreater(_id, _ObjectIdHelper.increment_id(_id, -1))
class ObjectRangeTrackerTest(TestCase):
@@ -578,10 +722,10 @@ class ObjectRangeTrackerTest(TestCase):
[random.randint(start_int, stop_int) for _ in range(100)])
tracker = _ObjectIdRangeTracker()
for pos in test_cases:
- id = _ObjectIdHelper.int_to_id(pos - start_int)
+ _id = _ObjectIdHelper.int_to_id(pos - start_int)
desired_fraction = (pos - start_int) / (stop_int - start_int)
self.assertAlmostEqual(
- tracker.position_to_fraction(id, start, stop),
+ tracker.position_to_fraction(_id, start, stop),
desired_fraction,
places=20)
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index a53ecb6..33b15d5 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -23,6 +23,7 @@ import codecs
import logging
import math
import threading
+from typing import Union
from apache_beam.io import iobase
@@ -265,16 +266,17 @@ class OrderedPositionRangeTracker(iobase.RangeTracker):
if ((self._stop_position is not None and position >= self._stop_position)
or (self._start_position is not None and
position <= self._start_position)):
- raise ValueError(
- "Split at '%s' not in range %s" %
- (position, [self._start_position, self._stop_position]))
+ _LOGGER.debug(
+ 'Refusing to split %r at %d: proposed split position out of range',
+ self,
+ position)
+ return
+
if self._last_claim is self.UNSTARTED or self._last_claim < position:
fraction = self.position_to_fraction(
position, start=self._start_position, end=self._stop_position)
self._stop_position = position
return position, fraction
- else:
- return None
def fraction_consumed(self):
if self._last_claim is self.UNSTARTED:
@@ -289,6 +291,12 @@ class OrderedPositionRangeTracker(iobase.RangeTracker):
"""
raise NotImplementedError
+ def position_to_fraction(self, position, start, end):
+ """Returns the fraction of keys in the range [start, end) that
+ are less than the given key.
+ """
+ raise NotImplementedError
+
class UnsplittableRangeTracker(iobase.RangeTracker):
"""A RangeTracker that always ignores split requests.
@@ -339,87 +347,112 @@ class UnsplittableRangeTracker(iobase.RangeTracker):
class LexicographicKeyRangeTracker(OrderedPositionRangeTracker):
- """
- A range tracker that tracks progress through a lexicographically
+ """A range tracker that tracks progress through a lexicographically
ordered keyspace of strings.
"""
@classmethod
- def fraction_to_position(cls, fraction, start=None, end=None):
- """
- Linearly interpolates a key that is lexicographically
+ def fraction_to_position(
+ cls,
+ fraction: float,
+ start: Union[bytes, str] = None,
+ end: Union[bytes, str] = None,
+ ) -> Union[bytes, str]:
+ """Linearly interpolates a key that is lexicographically
fraction of the way between start and end.
"""
assert 0 <= fraction <= 1, fraction
+
if start is None:
start = b''
+
+ if fraction == 0:
+ return start
+
if fraction == 1:
return end
- elif fraction == 0:
- return start
+
+ if not end:
+ common_prefix_len = len(start) - len(start.lstrip(b'\xFF'))
else:
- if not end:
- common_prefix_len = len(start) - len(start.lstrip(b'\xFF'))
+ for ix, (s, e) in enumerate(zip(start, end)):
+ if s != e:
+ common_prefix_len = ix
+ break
else:
- for ix, (s, e) in enumerate(zip(start, end)):
- if s != e:
- common_prefix_len = ix
- break
- else:
- common_prefix_len = min(len(start), len(end))
- # Convert the relative precision of fraction (~53 bits) to an absolute
- # precision needed to represent values between start and end distinctly.
- prec = common_prefix_len + int(-math.log(fraction, 256)) + 7
- istart = cls._bytestring_to_int(start, prec)
- iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
- ikey = istart + int((iend - istart) * fraction)
- # Could be equal due to rounding.
- # Adjust to ensure we never return the actual start and end
- # unless fraction is exatly 0 or 1.
- if ikey == istart:
- ikey += 1
- elif ikey == iend:
- ikey -= 1
- return cls._bytestring_from_int(ikey, prec).rstrip(b'\0')
+ common_prefix_len = min(len(start), len(end))
+
+ # Convert the relative precision of fraction (~53 bits) to an absolute
+ # precision needed to represent values between start and end distinctly.
+ prec = common_prefix_len + int(-math.log(fraction, 256)) + 7
+ istart = cls._bytestring_to_int(start, prec)
+ iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
+ ikey = istart + int((iend - istart) * fraction)
+
+ # Could be equal due to rounding.
+ # Adjust to ensure we never return the actual start and end
+  # unless fraction is exactly 0 or 1.
+ if ikey == istart:
+ ikey += 1
+ elif ikey == iend:
+ ikey -= 1
+
+ position: bytes = cls._bytestring_from_int(ikey, prec).rstrip(b'\0')
+
+ if isinstance(start, bytes):
+ return position
+
+ return position.decode(encoding='unicode_escape', errors='replace')
@classmethod
- def position_to_fraction(cls, key, start=None, end=None):
- """
- Returns the fraction of keys in the range [start, end) that
+ def position_to_fraction(
+ cls,
+ key: Union[bytes, str] = None,
+ start: Union[bytes, str] = None,
+ end: Union[bytes, str] = None,
+ ) -> float:
+ """Returns the fraction of keys in the range [start, end) that
are less than the given key.
"""
if not key:
return 0
+
if start is None:
- start = b''
+ start = '' if isinstance(key, str) else b''
+
prec = len(start) + 7
if key.startswith(start):
# Higher absolute precision needed for very small values of fixed
# relative position.
- prec = max(prec, len(key) - len(key[len(start):].strip(b'\0')) + 7)
+ trailing_symbol = '\0' if isinstance(key, str) else b'\0'
+ prec = max(
+ prec, len(key) - len(key[len(start):].strip(trailing_symbol)) + 7)
istart = cls._bytestring_to_int(start, prec)
ikey = cls._bytestring_to_int(key, prec)
iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
return float(ikey - istart) / (iend - istart)
@staticmethod
- def _bytestring_to_int(s, prec):
- """
- Returns int(256**prec * f) where f is the fraction
+ def _bytestring_to_int(s: Union[bytes, str], prec: int) -> int:
+ """Returns int(256**prec * f) where f is the fraction
represented by interpreting '.' + s as a base-256
floating point number.
"""
if not s:
return 0
- elif len(s) < prec:
+
+ if isinstance(s, str):
+ s = s.encode() # str -> bytes
+
+ if len(s) < prec:
s += b'\0' * (prec - len(s))
else:
s = s[:prec]
- return int(codecs.encode(s, 'hex'), 16)
+
+ h = codecs.encode(s, encoding='hex')
+ return int(h, base=16)
@staticmethod
- def _bytestring_from_int(i, prec):
- """
- Inverse of _bytestring_to_int.
- """
+ def _bytestring_from_int(i: int, prec: int) -> bytes:
+ """Inverse of _bytestring_to_int."""
h = '%x' % i
- return codecs.decode('0' * (2 * prec - len(h)) + h, 'hex')
+ return codecs.decode('0' * (2 * prec - len(h)) + h, encoding='hex')
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index 02a1e4e..0bf3799 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -22,6 +22,8 @@ import copy
import logging
import math
import unittest
+from typing import Optional
+from typing import Union
from apache_beam.io import range_trackers
@@ -257,25 +259,26 @@ class OrderedPositionRangeTrackerTest(unittest.TestCase):
def test_out_of_range(self):
tracker = self.DoubleRangeTracker(10, 20)
+
# Can't claim before range.
with self.assertRaises(ValueError):
tracker.try_claim(-5)
+
# Can't split before range.
- with self.assertRaises(ValueError):
- tracker.try_split(-5)
+ self.assertFalse(tracker.try_split(-5))
+
# Reject useless split at start position.
- with self.assertRaises(ValueError):
- tracker.try_split(10)
+ self.assertFalse(tracker.try_split(10))
+
# Can't split after range.
- with self.assertRaises(ValueError):
- tracker.try_split(25)
+ self.assertFalse(tracker.try_split(25))
tracker.try_split(15)
+
# Can't split after modified range.
- with self.assertRaises(ValueError):
- tracker.try_split(17)
+ self.assertFalse(tracker.try_split(17))
+
# Reject useless split at end position.
- with self.assertRaises(ValueError):
- tracker.try_split(15)
+ self.assertFalse(tracker.try_split(15))
self.assertTrue(tracker.try_split(14))
@@ -303,16 +306,20 @@ class UnsplittableRangeTrackerTest(unittest.TestCase):
class LexicographicKeyRangeTrackerTest(unittest.TestCase):
- """
- Tests of LexicographicKeyRangeTracker.
- """
+ """Tests of LexicographicKeyRangeTracker."""
key_to_fraction = (
range_trackers.LexicographicKeyRangeTracker.position_to_fraction)
fraction_to_key = (
range_trackers.LexicographicKeyRangeTracker.fraction_to_position)
- def _check(self, fraction=None, key=None, start=None, end=None, delta=0):
+ def _check(
+ self,
+ fraction: Optional[float] = None,
+ key: Union[bytes, str] = None,
+ start: Union[bytes, str] = None,
+ end: Union[bytes, str] = None,
+ delta: float = 0.0):
assert key is not None or fraction is not None
if fraction is None:
fraction = self.key_to_fraction(key, start, end)
@@ -341,14 +348,42 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
self._check(key=b'\x07', fraction=7 / 256.)
self._check(key=b'\xFF', fraction=255 / 256.)
self._check(key=b'\x01\x02\x03', fraction=(2**16 + 2**9 + 3) / (2.0**24))
+ self._check(key=b'UUUUUUT', fraction=1 / 3)
+ self._check(key=b'3333334', fraction=1 / 5)
+ self._check(key=b'$\x92I$\x92I$', fraction=1 / 7, delta=1e-3)
+ self._check(key=b'\x01\x02\x03', fraction=(2**16 + 2**9 + 3) / (2.0**24))
def test_key_to_fraction(self):
+ # test no key, no start
+ self._check(end=b'eeeeee', fraction=0.0)
+ self._check(end='eeeeee', fraction=0.0)
+
+ # test no fraction
+ self._check(key=b'bbbbbb', start=b'aaaaaa', end=b'eeeeee')
+ self._check(key='bbbbbb', start='aaaaaa', end='eeeeee')
+
+ # test no start
+ self._check(key=b'eeeeee', end=b'eeeeee', fraction=1.0)
+ self._check(key='eeeeee', end='eeeeee', fraction=1.0)
+ self._check(key=b'\x19YYYYY@', end=b'eeeeee', fraction=0.25)
+ self._check(key=b'2\xb2\xb2\xb2\xb2\xb2\x80', end='eeeeee', fraction=0.5)
+ self._check(key=b'L\x0c\x0c\x0c\x0c\x0b\xc0', end=b'eeeeee', fraction=0.75)
+
+ # test bytes keys
self._check(key=b'\x87', start=b'\x80', fraction=7 / 128.)
self._check(key=b'\x07', end=b'\x10', fraction=7 / 16.)
self._check(key=b'\x47', start=b'\x40', end=b'\x80', fraction=7 / 64.)
self._check(key=b'\x47\x80', start=b'\x40', end=b'\x80', fraction=15 / 128.)
+ # test string keys
+ self._check(key='aaaaaa', start='aaaaaa', end='eeeeee', fraction=0.0)
+ self._check(key='bbbbbb', start='aaaaaa', end='eeeeee', fraction=0.25)
+ self._check(key='cccccc', start='aaaaaa', end='eeeeee', fraction=0.5)
+ self._check(key='dddddd', start='aaaaaa', end='eeeeee', fraction=0.75)
+ self._check(key='eeeeee', start='aaaaaa', end='eeeeee', fraction=1.0)
+
def test_key_to_fraction_common_prefix(self):
+ # test bytes keys
self._check(
key=b'a' * 100 + b'b',
start=b'a' * 100 + b'a',
@@ -370,7 +405,35 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
end=b'foob\x00\x00\x00\x00\x00\x00\x00\x00\x02',
fraction=0.5)
+ # test string keys
+ self._check(
+ key='a' * 100 + 'a',
+ start='a' * 100 + 'a',
+ end='a' * 100 + 'e',
+ fraction=0.0)
+ self._check(
+ key='a' * 100 + 'b',
+ start='a' * 100 + 'a',
+ end='a' * 100 + 'e',
+ fraction=0.25)
+ self._check(
+ key='a' * 100 + 'c',
+ start='a' * 100 + 'a',
+ end='a' * 100 + 'e',
+ fraction=0.5)
+ self._check(
+ key='a' * 100 + 'd',
+ start='a' * 100 + 'a',
+ end='a' * 100 + 'e',
+ fraction=0.75)
+ self._check(
+ key='a' * 100 + 'e',
+ start='a' * 100 + 'a',
+ end='a' * 100 + 'e',
+ fraction=1.0)
+
def test_tiny(self):
+ # test bytes keys
self._check(fraction=.5**20, key=b'\0\0\x10')
self._check(fraction=.5**20, start=b'a', end=b'b', key=b'a\0\0\x10')
self._check(fraction=.5**20, start=b'a', end=b'c', key=b'a\0\0\x20')
@@ -386,6 +449,11 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
delta=1e-15)
self._check(fraction=.5**100, key=b'\0' * 12 + b'\x10')
+ # test string keys
+ self._check(fraction=.5**20, start='a', end='b', key='a\0\0\x10')
+ self._check(fraction=.5**20, start='a', end='c', key='a\0\0\x20')
+ self._check(fraction=.5**20, start='xy_a', end='xy_c', key='xy_a\0\0\x20')
+
def test_lots(self):
for fraction in (0, 1, .5, .75, 7. / 512, 1 - 7. / 4096):
self._check(fraction)
@@ -418,6 +486,12 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
# (beyond the common prefix of start and end).
self._check(
1 / math.e,
+ start='AAAAAAA',
+ end='zzzzzzz',
+ key='VNg/ot\x82',
+ delta=1e-14)
+ self._check(
+ 1 / math.e,
start=b'abc_abc',
end=b'abc_xyz',
key=b'abc_i\xe0\xf4\x84\x86\x99\x96',