MaksymSkorupskyi commented on a change in pull request #14460: URL: https://github.com/apache/beam/pull/14460#discussion_r615879905
########## File path: sdks/python/apache_beam/io/mongodbio.py ########## @@ -66,542 +66,771 @@ # pytype: skip-file +import codecs import itertools import json import logging import math import struct +from typing import Union + +from bson import json_util +from bson.objectid import ObjectId import apache_beam as beam from apache_beam.io import iobase from apache_beam.io.range_trackers import OrderedPositionRangeTracker -from apache_beam.transforms import DoFn -from apache_beam.transforms import PTransform -from apache_beam.transforms import Reshuffle +from apache_beam.transforms import DoFn, PTransform, Reshuffle from apache_beam.utils.annotations import experimental _LOGGER = logging.getLogger(__name__) try: - # Mongodb has its own bundled bson, which is not compatible with bson pakcage. - # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if - # it fails because bson package is installed, MongoDB IO will not work but at - # least rest of the SDK will work. - from bson import objectid - - # pymongo also internally depends on bson. - from pymongo import ASCENDING - from pymongo import DESCENDING - from pymongo import MongoClient - from pymongo import ReplaceOne + # Mongodb has its own bundled bson, which is not compatible with bson pakcage. + # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if + # it fails because bson package is installed, MongoDB IO will not work but at + # least rest of the SDK will work. + from bson import objectid + + # pymongo also internally depends on bson. + from pymongo import ASCENDING + from pymongo import DESCENDING + from pymongo import MongoClient + from pymongo import ReplaceOne except ImportError: - objectid = None - _LOGGER.warning("Could not find a compatible bson package.") + objectid = None + _LOGGER.warning("Could not find a compatible bson package.") -__all__ = ['ReadFromMongoDB', 'WriteToMongoDB'] +__all__ = ["ReadFromMongoDB", "WriteToMongoDB"] @experimental() class ReadFromMongoDB(PTransform): - """A ``PTransform`` to read MongoDB documents into a ``PCollection``. - """ - def __init__( - self, - uri='mongodb://localhost:27017', - db=None, - coll=None, - filter=None, - projection=None, - extra_client_params=None, - bucket_auto=False): - """Initialize a :class:`ReadFromMongoDB` - - Args: - uri (str): The MongoDB connection string following the URI format. - db (str): The MongoDB database name. - coll (str): The MongoDB collection name. - filter: A `bson.SON - <https://api.mongodb.com/python/current/api/bson/son.html>`_ object - specifying elements which must be present for a document to be included - in the result set. - projection: A list of field names that should be returned in the result - set or a dict specifying the fields to include or exclude. - extra_client_params(dict): Optional `MongoClient - <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_ - parameters. - bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto` aggregation - to split collection into bundles instead of `splitVector` command, - which does not work with MongoDB Atlas. - If :data:`False` (the default), use `splitVector` command for bundling. - - Returns: - :class:`~apache_beam.transforms.ptransform.PTransform` - - """ - if extra_client_params is None: - extra_client_params = {} - if not isinstance(db, str): - raise ValueError('ReadFromMongDB db param must be specified as a string') - if not isinstance(coll, str): - raise ValueError( - 'ReadFromMongDB coll param must be specified as a ' - 'string') - self._mongo_source = _BoundedMongoSource( - uri=uri, - db=db, - coll=coll, - filter=filter, - projection=projection, - extra_client_params=extra_client_params, - bucket_auto=bucket_auto) - - def expand(self, pcoll): - return pcoll | iobase.Read(self._mongo_source) + """A ``PTransform`` to read MongoDB documents into a ``PCollection``.""" + + def __init__( + self, + uri="mongodb://localhost:27017", + db=None, + coll=None, + filter=None, + projection=None, + extra_client_params=None, + bucket_auto=False, + ): + """Initialize a :class:`ReadFromMongoDB` + + Args: + uri (str): The MongoDB connection string following the URI format. + db (str): The MongoDB database name. + coll (str): The MongoDB collection name. + filter: A `bson.SON + <https://api.mongodb.com/python/current/api/bson/son.html>`_ object + specifying elements which must be present for a document to be included + in the result set. + projection: A list of field names that should be returned in the result + set or a dict specifying the fields to include or exclude. + extra_client_params(dict): Optional `MongoClient + <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_ + parameters. + bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto` aggregation + to split collection into bundles instead of `splitVector` command, + which does not work with MongoDB Atlas. + If :data:`False` (the default), use `splitVector` command for bundling. + + Returns: + :class:`~apache_beam.transforms.ptransform.PTransform` + + """ + if extra_client_params is None: + extra_client_params = {} + if not isinstance(db, str): + raise ValueError("ReadFromMongDB db param must be specified as a string") + if not isinstance(coll, str): + raise ValueError("ReadFromMongDB coll param must be specified as a string") + self._mongo_source = _BoundedMongoSource( + uri=uri, + db=db, + coll=coll, + filter=filter, + projection=projection, + extra_client_params=extra_client_params, + bucket_auto=bucket_auto, + ) + + def expand(self, pcoll): + return pcoll | iobase.Read(self._mongo_source) class _BoundedMongoSource(iobase.BoundedSource): - def __init__( - self, - uri=None, - db=None, - coll=None, - filter=None, - projection=None, - extra_client_params=None, - bucket_auto=False): - if extra_client_params is None: - extra_client_params = {} - if filter is None: - filter = {} - self.uri = uri - self.db = db - self.coll = coll - self.filter = filter - self.projection = projection - self.spec = extra_client_params - self.bucket_auto = bucket_auto - - def estimate_size(self): - with MongoClient(self.uri, **self.spec) as client: - return client[self.db].command('collstats', self.coll).get('size') - - def _estimate_average_document_size(self): - with MongoClient(self.uri, **self.spec) as client: - return client[self.db].command('collstats', self.coll).get('avgObjSize') - - def split(self, desired_bundle_size, start_position=None, stop_position=None): - desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024 - - # for desired bundle size, if desired chunk size smaller than 1mb, use - # MongoDB default split size of 1mb. - if desired_bundle_size_in_mb < 1: - desired_bundle_size_in_mb = 1 - - is_initial_split = start_position is None and stop_position is None - start_position, stop_position = self._replace_none_positions( - start_position, stop_position) - - if self.bucket_auto: - # Use $bucketAuto for bundling - split_keys = [] - weights = [] - for bucket in self._get_auto_buckets(desired_bundle_size_in_mb, - start_position, - stop_position, - is_initial_split): - split_keys.append({'_id': bucket['_id']['max']}) - weights.append(bucket['count']) - else: - # Use splitVector for bundling - split_keys = self._get_split_keys( - desired_bundle_size_in_mb, start_position, stop_position) - weights = itertools.cycle((desired_bundle_size_in_mb, )) - - bundle_start = start_position - for split_key_id, weight in zip(split_keys, weights): - if bundle_start >= stop_position: - break - bundle_end = min(stop_position, split_key_id['_id']) - yield iobase.SourceBundle( - weight=weight, - source=self, - start_position=bundle_start, - stop_position=bundle_end) - bundle_start = bundle_end - # add range of last split_key to stop_position - if bundle_start < stop_position: - # bucket_auto mode can come here if not split due to single document - weight = 1 if self.bucket_auto else desired_bundle_size_in_mb - yield iobase.SourceBundle( - weight=weight, - source=self, - start_position=bundle_start, - stop_position=stop_position) - - def get_range_tracker(self, start_position, stop_position): - start_position, stop_position = self._replace_none_positions( - start_position, stop_position) - return _ObjectIdRangeTracker(start_position, stop_position) - - def read(self, range_tracker): - with MongoClient(self.uri, **self.spec) as client: - all_filters = self._merge_id_filter( - range_tracker.start_position(), range_tracker.stop_position()) - docs_cursor = client[self.db][self.coll].find( - filter=all_filters, - projection=self.projection).sort([('_id', ASCENDING)]) - for doc in docs_cursor: - if not range_tracker.try_claim(doc['_id']): - return - yield doc - - def display_data(self): - res = super(_BoundedMongoSource, self).display_data() - res['database'] = self.db - res['collection'] = self.coll - res['filter'] = json.dumps( - self.filter, default=lambda x: 'not_serializable(%s)' % str(x)) - res['projection'] = str(self.projection) - res['bucket_auto'] = self.bucket_auto - return res - - def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos): - # calls mongodb splitVector command to get document ids at split position - if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1): - # single document not splittable - return [] - with MongoClient(self.uri, **self.spec) as client: - name_space = '%s.%s' % (self.db, self.coll) - return ( - client[self.db].command( - 'splitVector', - name_space, - keyPattern={'_id': 1}, # Ascending index - min={'_id': start_pos}, - max={'_id': end_pos}, - maxChunkSize=desired_chunk_size_in_mb)['splitKeys']) - - def _get_auto_buckets( - self, desired_chunk_size_in_mb, start_pos, end_pos, is_initial_split): - - if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1): - # single document not splittable - return [] - - if is_initial_split and not self.filter: - # total collection size - size_in_mb = self.estimate_size() / float(1 << 20) - else: - # size of documents within start/end id range and possibly filtered - documents_count = self._count_id_range(start_pos, end_pos) - avg_document_size = self._estimate_average_document_size() - size_in_mb = documents_count * avg_document_size / float(1 << 20) - - if size_in_mb == 0: - # no documents not splittable (maybe a result of filtering) - return [] - - bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb) - with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client: - pipeline = [ - { - # filter by positions and by the custom filter if any - '$match': self._merge_id_filter(start_pos, end_pos) - }, - { - '$bucketAuto': { - 'groupBy': '$_id', 'buckets': bucket_count - } - } - ] - buckets = list(client[self.db][self.coll].aggregate(pipeline)) - if buckets: - buckets[-1]['_id']['max'] = end_pos - - return buckets - - def _merge_id_filter(self, start_position, stop_position): - # Merge the default filter (if any) with refined _id field range - # of range_tracker. - # $gte specifies start position (inclusive) - # and $lt specifies the end position (exclusive), - # see more at - # https://docs.mongodb.com/manual/reference/operator/query/gte/ and - # https://docs.mongodb.com/manual/reference/operator/query/lt/ - id_filter = {'_id': {'$gte': start_position, '$lt': stop_position}} - if self.filter: - all_filters = { - # see more at - # https://docs.mongodb.com/manual/reference/operator/query/and/ - '$and': [self.filter.copy(), id_filter] - } - else: - all_filters = id_filter - - return all_filters - - def _get_head_document_id(self, sort_order): - with MongoClient(self.uri, **self.spec) as client: - cursor = client[self.db][self.coll].find( - filter={}, projection=[]).sort([('_id', sort_order)]).limit(1) - try: - return cursor[0]['_id'] - except IndexError: - raise ValueError('Empty Mongodb collection') - - def _replace_none_positions(self, start_position, stop_position): - if start_position is None: - start_position = self._get_head_document_id(ASCENDING) - if stop_position is None: - last_doc_id = self._get_head_document_id(DESCENDING) - # increment last doc id binary value by 1 to make sure the last document - # is not excluded - stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1) - return start_position, stop_position - - def _count_id_range(self, start_position, stop_position): - # Number of documents between start_position (inclusive) - # and stop_position (exclusive), respecting the custom filter if any. - with MongoClient(self.uri, **self.spec) as client: - return client[self.db][self.coll].count_documents( - filter=self._merge_id_filter(start_position, stop_position)) - - -class _ObjectIdHelper(object): - """A Utility class to manipulate bson object ids.""" - @classmethod - def id_to_int(cls, id): + """A MongoDB source that reads a finite amount of input records. + + This class defines following operations which can be used to read the source + efficiently. + + * Size estimation - method ``estimate_size()`` may return an accurate + estimation in bytes for the size of the source. + * Splitting into bundles of a given size - method ``split()`` can be used to + split the source into a set of sub-sources (bundles) based on a desired + bundle size. + * Getting a MongoDBRangeTracker - method ``get_range_tracker()`` should return a + ``MongoDBRangeTracker`` object for a given position range for the position type + of the records returned by the source. + * Reading the data - method ``read()`` can be used to read data from the + source while respecting the boundaries defined by a given + ``MongoDBRangeTracker``. + + A runner will perform reading the source in two steps. + + (1) Method ``get_range_tracker()`` will be invoked with start and end + positions to obtain a ``MongoDBRangeTracker`` for the range of positions the + runner intends to read. Source must define a default initial start and end + position range. These positions must be used if the start and/or end + positions passed to the method ``get_range_tracker()`` are ``None`` + (2) Method read() will be invoked with the ``MongoDBRangeTracker`` obtained in the + previous step. + + **Mutability** + + A ``_BoundedMongoSource`` object should not be mutated while + its methods (for example, ``read()``) are being invoked by a runner. Runner + implementations may invoke methods of ``_BoundedMongoSource`` objects through + multi-threaded and/or reentrant execution modes. """ - Args: - id: ObjectId required for each MongoDB document _id field. - - Returns: Converted integer value of ObjectId's 12 bytes binary value. - + def __init__( + self, + uri=None, + db=None, + coll=None, + filter=None, + projection=None, + extra_client_params=None, + bucket_auto=False, + ): + if extra_client_params is None: + extra_client_params = {} + if filter is None: + filter = {} + self.uri = uri + self.db = db + self.coll = coll + self.filter = filter + self.projection = projection + self.spec = extra_client_params + self.bucket_auto = bucket_auto + + def estimate_size(self): + with MongoClient(self.uri, **self.spec) as client: + return client[self.db].command("collstats", self.coll).get("size") + + def _estimate_average_document_size(self): + with MongoClient(self.uri, **self.spec) as client: + return client[self.db].command("collstats", self.coll).get("avgObjSize") + + def split(self, desired_bundle_size, start_position=None, stop_position=None): + desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024 + + # for desired bundle size, if desired chunk size smaller than 1mb, use + # MongoDB default split size of 1mb. + if desired_bundle_size_in_mb < 1: + desired_bundle_size_in_mb = 1 + + is_initial_split = start_position is None and stop_position is None + start_position, stop_position = self._replace_none_positions( + start_position, stop_position + ) + + if self.bucket_auto: + # Use $bucketAuto for bundling + split_keys = [] + weights = [] + for bucket in self._get_auto_buckets( + desired_bundle_size_in_mb, + start_position, + stop_position, + is_initial_split, + ): + split_keys.append({"_id": bucket["_id"]["max"]}) + weights.append(bucket["count"]) + else: + # Use splitVector for bundling + split_keys = self._get_split_keys( + desired_bundle_size_in_mb, start_position, stop_position + ) + weights = itertools.cycle((desired_bundle_size_in_mb,)) + + bundle_start = start_position + for split_key_id, weight in zip(split_keys, weights): + if bundle_start >= stop_position: + break + bundle_end = min(stop_position, split_key_id["_id"]) + yield iobase.SourceBundle( + weight=weight, + source=self, + start_position=bundle_start, + stop_position=bundle_end, + ) + bundle_start = bundle_end + # add range of last split_key to stop_position + if bundle_start < stop_position: + # bucket_auto mode can come here if not split due to single document + weight = 1 if self.bucket_auto else desired_bundle_size_in_mb + yield iobase.SourceBundle( + weight=weight, + source=self, + start_position=bundle_start, + stop_position=stop_position, + ) + + def get_range_tracker( + self, + start_position: Union[int, str, ObjectId] = None, + stop_position: Union[int, str, ObjectId] = None, + ): + """Returns a MongoDBRangeTracker for a given position range. + + Framework may invoke ``read()`` method with the MongoDBRangeTracker object + returned here to read data from the source. + + Args: + start_position: starting position of the range. If 'None' default start + position of the source must be used. + stop_position: ending position of the range. If 'None' default stop + position of the source must be used. + Returns: + a ``MongoDBRangeTracker`` for the given position range. + """ + start_position, stop_position = self._replace_none_positions( + start_position, stop_position + ) + return MongoDBRangeTracker(start_position, stop_position) + + def read(self, range_tracker): + with MongoClient(self.uri, **self.spec) as client: + all_filters = self._merge_id_filter( + range_tracker.start_position(), range_tracker.stop_position() + ) + docs_cursor = ( + client[self.db][self.coll] + .find(filter=all_filters, projection=self.projection) + .sort([("_id", ASCENDING)]) + ) + for doc in docs_cursor: + if not range_tracker.try_claim(doc["_id"]): + return + yield doc + + def display_data(self): + res = super().display_data() + res["database"] = self.db + res["collection"] = self.coll + res["filter"] = json.dumps(self.filter, default=json_util.default) + res["projection"] = str(self.projection) + res["bucket_auto"] = self.bucket_auto + return res + + def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos): + # calls mongodb splitVector command to get document ids at split position + if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1): + # single document not splittable + return [] + with MongoClient(self.uri, **self.spec) as client: + name_space = "%s.%s" % (self.db, self.coll) + return client[self.db].command( + "splitVector", + name_space, + keyPattern={"_id": 1}, # Ascending index + min={"_id": start_pos}, + max={"_id": end_pos}, + maxChunkSize=desired_chunk_size_in_mb, + )["splitKeys"] + + def _get_auto_buckets( + self, + desired_chunk_size_in_mb, + start_pos: ObjectId, + end_pos: ObjectId, + is_initial_split: bool, + ) -> list: + if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1): + # single document not splittable + return [] + + if is_initial_split and not self.filter: + # total collection size in MB + size_in_mb = self.estimate_size() / float(1 << 20) + else: + # size of documents within start/end id range and possibly filtered + documents_count = self._count_id_range(start_pos, end_pos) + avg_document_size = self._estimate_average_document_size() + size_in_mb = documents_count * avg_document_size / float(1 << 20) + + if size_in_mb == 0: + # no documents not splittable (maybe a result of filtering) + return [] + + bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb) + with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client: + pipeline = [ + { + # filter by positions and by the custom filter if any + "$match": self._merge_id_filter(start_pos, end_pos) + }, + {"$bucketAuto": {"groupBy": "$_id", "buckets": bucket_count}}, + ] + buckets = list( + # Use `allowDiskUse` option to avoid aggregation limit of 100 Mb RAM + client[self.db][self.coll].aggregate(pipeline, allowDiskUse=True) + ) + if buckets: + buckets[-1]["_id"]["max"] = end_pos + + return buckets + + def _merge_id_filter( + self, + start_position: Union[int, str, ObjectId], + stop_position: Union[int, str, ObjectId] = None, + ) -> dict: + """Merge the default filter (if any) with refined _id field range + of range_tracker. + $gte specifies start position (inclusive) + and $lt specifies the end position (exclusive), + see more at + https://docs.mongodb.com/manual/reference/operator/query/gte/ and + https://docs.mongodb.com/manual/reference/operator/query/lt/ + """ + if stop_position is None: + id_filter = {"_id": {"$gte": start_position}} + else: + id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}} + + if self.filter: + all_filters = { + # see more at + # https://docs.mongodb.com/manual/reference/operator/query/and/ + "$and": [self.filter.copy(), id_filter] + } + else: + all_filters = id_filter + + return all_filters + + def _get_head_document_id(self, sort_order): + with MongoClient(self.uri, **self.spec) as client: + cursor = ( + client[self.db][self.coll] + .find(filter={}, projection=[]) + .sort([("_id", sort_order)]) + .limit(1) + ) + try: + return cursor[0]["_id"] + + except IndexError: + raise ValueError("Empty Mongodb collection") + + def _replace_none_positions(self, start_position, stop_position): + + if start_position is None: + start_position = self._get_head_document_id(ASCENDING) + if stop_position is None: + last_doc_id = self._get_head_document_id(DESCENDING) + # increment last doc id binary value by 1 to make sure the last document + # is not excluded + stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1) + + return start_position, stop_position + + def _count_id_range(self, start_position, stop_position): + """Number of documents between start_position (inclusive) + and stop_position (exclusive), respecting the custom filter if any. + """ + with MongoClient(self.uri, **self.spec) as client: + return client[self.db][self.coll].count_documents( + filter=self._merge_id_filter(start_position, stop_position) + ) + + +class MongoDBRangeTracker(OrderedPositionRangeTracker): + """RangeTracker for tracking MongoDB `_id` of following types: + - int + - bytes + - str + - bson ObjectId + + For bytes/string keys tracks progress through a lexicographically + ordered keyspace of strings. """ - # converts object id binary to integer - # id object is bytes type with size of 12 - ints = struct.unpack('>III', id.binary) - return (ints[0] << 64) + (ints[1] << 32) + ints[2] - @classmethod - def int_to_id(cls, number): - """ - Args: - number(int): The integer value to be used to convert to ObjectId. - - Returns: The ObjectId that has the 12 bytes binary converted from the - integer value. - - """ - # converts integer value to object id. Int value should be less than - # (2 ^ 96) so it can be convert to 12 bytes required by object id. - if number < 0 or number >= (1 << 96): - raise ValueError('number value must be within [0, %s)' % (1 << 96)) - ints = [(number & 0xffffffff0000000000000000) >> 64, - (number & 0x00000000ffffffff00000000) >> 32, - number & 0x0000000000000000ffffffff] - - bytes = struct.pack('>III', *ints) - return objectid.ObjectId(bytes) - - @classmethod - def increment_id(cls, object_id, inc): - """ - Args: - object_id: The ObjectId to change. - inc(int): The incremental int value to be added to ObjectId. - - Returns: - - """ - # increment object_id binary value by inc value and return new object id. - id_number = _ObjectIdHelper.id_to_int(object_id) - new_number = id_number + inc - if new_number < 0 or new_number >= (1 << 96): - raise ValueError( - 'invalid incremental, inc value must be within [' - '%s, %s)' % (0 - id_number, 1 << 96 - id_number)) - return _ObjectIdHelper.int_to_id(new_number) - - -class _ObjectIdRangeTracker(OrderedPositionRangeTracker): - """RangeTracker for tracking mongodb _id of bson ObjectId type.""" - def position_to_fraction(self, pos, start, end): - pos_number = _ObjectIdHelper.id_to_int(pos) - start_number = _ObjectIdHelper.id_to_int(start) - end_number = _ObjectIdHelper.id_to_int(end) - return (pos_number - start_number) / (end_number - start_number) - - def fraction_to_position(self, fraction, start, end): - start_number = _ObjectIdHelper.id_to_int(start) - end_number = _ObjectIdHelper.id_to_int(end) - total = end_number - start_number - pos = int(total * fraction + start_number) - # make sure split position is larger than start position and smaller than - # end position. - if pos <= start_number: - return _ObjectIdHelper.increment_id(start, 1) - if pos >= end_number: - return _ObjectIdHelper.increment_id(end, -1) - return _ObjectIdHelper.int_to_id(pos) + def position_to_fraction( + self, + pos: Union[int, bytes, str, ObjectId] = None, + start: Union[int, bytes, str, ObjectId] = None, + end: Union[int, bytes, str, ObjectId] = None, + ) -> float: + """Returns the fraction of keys in the range [start, end) that + are less than the given key. + """ + # Handle integer `_id` + if isinstance(pos, int) and isinstance(start, int) and isinstance(end, int): + return (pos - start) / (end - start) + + # Handle ObjectId `_id` + if ( + isinstance(pos, ObjectId) + and isinstance(start, ObjectId) + and isinstance(end, ObjectId) + ): + pos = _ObjectIdHelper.id_to_int(pos) + start = _ObjectIdHelper.id_to_int(start) + end = _ObjectIdHelper.id_to_int(end) + return (pos - start) / (end - start) + + if not pos: + return 0 + + if start is None: + start = b"" + + prec = len(start) + 7 + if pos.startswith(start): + # Higher absolute precision needed for very small values of fixed + # relative position. + prec = max(prec, len(pos) - len(pos[len(start) :].strip(b"\0")) + 7) + pos = self._bytestring_to_int(pos, prec) + start = self._bytestring_to_int(start, prec) + end = self._bytestring_to_int(end, prec) if end else 1 << (prec * 8) + + return (pos - start) / (end - start) + + def fraction_to_position( + self, + fraction: float, + start: Union[int, bytes, str, ObjectId] = None, + end: Union[int, bytes, str, ObjectId] = None, + ) -> Union[int, bytes, str, ObjectId]: + """Converts a fraction between 0 and 1 to a position between start and end. + For string keys linearly interpolates a key that is lexicographically + fraction of the way between start and end. + """ + if not 0 <= fraction <= 1: + raise ValueError(f"Invalid fraction: {fraction}! Must be in range [0, 1]") + + if isinstance(start, (int, ObjectId)) and isinstance(end, (int, ObjectId)): + start = _ObjectIdHelper.id_to_int(start) + end = _ObjectIdHelper.id_to_int(end) + total = end - start + pos = int(total * fraction + start) + # make sure split position is larger than start position and smaller than + # end position. + if pos <= start: + return _ObjectIdHelper.increment_id(start, 1) + if pos >= end: + return _ObjectIdHelper.increment_id(end, -1) + return _ObjectIdHelper.int_to_id(pos) + + if start is None: + start = b"" + + if fraction == 1: + return end + + if fraction == 0: + return start + + if not end: + common_prefix_len = len(start) - len(start.lstrip(b"\xFF")) + else: + for ix, (s, e) in enumerate(zip(start, end)): + if s != e: + common_prefix_len = ix + break + else: + common_prefix_len = min(len(start), len(end)) + # Convert the relative precision of fraction (~53 bits) to an absolute + # precision needed to represent values between start and end distinctly. + prec = common_prefix_len + int(-math.log(fraction, 256)) + 7 + start = self._bytestring_to_int(start, prec) + end = self._bytestring_to_int(end, prec) if end else 1 << (prec * 8) + pos = start + int((end - start) * fraction) + # Could be equal due to rounding. + # Adjust to ensure we never return the actual start and end + # unless fraction is exactly 0 or 1. + if pos == start: + pos += 1 + elif pos == end: + pos -= 1 + return self._bytestring_from_int(pos, prec).rstrip(b"\0") + + @staticmethod + def _bytestring_to_int(s: Union[bytes, str], prec: int) -> int: + """Returns int(256**prec * f) where f is the fraction + represented by interpreting '.' + s as a base-256 + floating point number. + """ + if not s: + return 0 + + if isinstance(s, str): + s = s.encode() + + if len(s) < prec: + s += b"\0" * (prec - len(s)) + else: + s = s[:prec] + + return int(codecs.encode(s, "hex"), 16) + + @staticmethod + def _bytestring_from_int(i: int, prec: int) -> bytes: + """Inverse of _bytestring_to_int.""" + h: str = "%x" % i + return codecs.decode("0" * (2 * prec - len(h)) + h, "hex") + + +class _ObjectIdHelper: + """A Utility class to manipulate bson object ids.""" + + @classmethod + def id_to_int(cls, _id: Union[int, ObjectId]) -> int: + """ + Args: + _id: ObjectId required for each MongoDB document _id field. + + Returns: Converted integer value of ObjectId's 12 bytes binary value. + + """ + if isinstance(_id, int): + return _id + + # converts object id binary to integer + # id object is bytes type with size of 12 + ints = struct.unpack(">III", _id.binary) + return (ints[0] << 64) + (ints[1] << 32) + ints[2] + + @classmethod + def int_to_id(cls, number): + """ + Args: + number(int): The integer value to be used to convert to ObjectId. + + Returns: The ObjectId that has the 12 bytes binary converted from the + integer value. + + """ + # converts integer value to object id. Int value should be less than + # (2 ^ 96) so it can be convert to 12 bytes required by object id. + if number < 0 or number >= (1 << 96): + raise ValueError("number value must be within [0, %s)" % (1 << 96)) + ints = [ + (number & 0xFFFFFFFF0000000000000000) >> 64, + (number & 0x00000000FFFFFFFF00000000) >> 32, + number & 0x0000000000000000FFFFFFFF, + ] + + number_bytes = struct.pack(">III", *ints) + return objectid.ObjectId(number_bytes) + + @classmethod + def increment_id( + cls, + object_id: Union[int, str, ObjectId], + inc: int, + ) -> Union[int, str, ObjectId]: + """ + Args: + object_id: The `_id` to change. + inc(int): The incremental int value to be added to `_id`. + + Returns: + `_id` incremented by `inc` value + """ + # Handle integer id: + if isinstance(object_id, int): + return object_id + inc + + # Handle string id: + if isinstance(object_id, str): + object_id = object_id or chr(31) # handle empty string ('') + # incrementing the latest symbol of the string Review comment: unittests added -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
