This is an automated email from the ASF dual-hosted git repository.

yichi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 5742cb9  BEAM-12122, BEAM-12119 Add integer and string _id keys 
support to Python IO MongoDB
     new 7494d14  Merge pull request #14460 from 
MaksymSkorupskyi/BEAM-12122-Python-IO-MongoDB-integer-and-string-`_id`-keys-are-not-supported
5742cb9 is described below

commit 5742cb929af6b6cb74f591f554108451c6fc32d5
Author: Maksym Skorupskyi <[email protected]>
AuthorDate: Wed Apr 7 15:46:58 2021 +0300

    BEAM-12122, BEAM-12119 Add integer and string _id keys support to Python IO 
MongoDB
    
    * handle integer and string _id keys when reading from MongoDB collections
    * the RangeTracker is selected by the _id key type:
    - int -> OffsetRangeTracker
    - ObjectId -> _ObjectIdRangeTracker
    - string -> LexicographicKeyRangeTracker
    * update LexicographicKeyRangeTracker to properly handle both bytes and
      str keys
    * make range_trackers.OrderedPositionRangeTracker.try_split() logic
      consistent with range_trackers.OffsetRangeTracker.try_split()
    * add unit tests for MongoDB _id values of int and str types
    * update range_trackers_test
---
 sdks/python/apache_beam/io/mongodbio.py           | 491 +++++++++++++++-------
 sdks/python/apache_beam/io/mongodbio_test.py      | 230 ++++++++--
 sdks/python/apache_beam/io/range_trackers.py      | 131 +++---
 sdks/python/apache_beam/io/range_trackers_test.py | 102 ++++-
 4 files changed, 706 insertions(+), 248 deletions(-)

diff --git a/sdks/python/apache_beam/io/mongodbio.py 
b/sdks/python/apache_beam/io/mongodbio.py
index a56c274..c6f7d97 100644
--- a/sdks/python/apache_beam/io/mongodbio.py
+++ b/sdks/python/apache_beam/io/mongodbio.py
@@ -71,9 +71,12 @@ import json
 import logging
 import math
 import struct
+from typing import Union
 
 import apache_beam as beam
 from apache_beam.io import iobase
+from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
+from apache_beam.io.range_trackers import OffsetRangeTracker
 from apache_beam.io.range_trackers import OrderedPositionRangeTracker
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import PTransform
@@ -83,11 +86,13 @@ from apache_beam.utils.annotations import experimental
 _LOGGER = logging.getLogger(__name__)
 
 try:
-  # Mongodb has its own bundled bson, which is not compatible with bson 
pakcage.
+  # Mongodb has its own bundled bson, which is not compatible with bson 
package.
   # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
   # it fails because bson package is installed, MongoDB IO will not work but at
   # least rest of the SDK will work.
+  from bson import json_util
   from bson import objectid
+  from bson.objectid import ObjectId
 
   # pymongo also internally depends on bson.
   from pymongo import ASCENDING
@@ -96,24 +101,30 @@ try:
   from pymongo import ReplaceOne
 except ImportError:
   objectid = None
+  json_util = None
+  ObjectId = None
+  ASCENDING = 1
+  DESCENDING = -1
+  MongoClient = None
+  ReplaceOne = None
   _LOGGER.warning("Could not find a compatible bson package.")
 
-__all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
+__all__ = ["ReadFromMongoDB", "WriteToMongoDB"]
 
 
 @experimental()
 class ReadFromMongoDB(PTransform):
-  """A ``PTransform`` to read MongoDB documents into a ``PCollection``.
-  """
+  """A ``PTransform`` to read MongoDB documents into a ``PCollection``."""
   def __init__(
       self,
-      uri='mongodb://localhost:27017',
+      uri="mongodb://localhost:27017",
       db=None,
       coll=None,
       filter=None,
       projection=None,
       extra_client_params=None,
-      bucket_auto=False):
+      bucket_auto=False,
+  ):
     """Initialize a :class:`ReadFromMongoDB`
 
     Args:
@@ -136,16 +147,14 @@ class ReadFromMongoDB(PTransform):
 
     Returns:
       :class:`~apache_beam.transforms.ptransform.PTransform`
-
     """
     if extra_client_params is None:
       extra_client_params = {}
     if not isinstance(db, str):
-      raise ValueError('ReadFromMongDB db param must be specified as a string')
+      raise ValueError("ReadFromMongDB db param must be specified as a string")
     if not isinstance(coll, str):
       raise ValueError(
-          'ReadFromMongDB coll param must be specified as a '
-          'string')
+          "ReadFromMongDB coll param must be specified as a string")
     self._mongo_source = _BoundedMongoSource(
         uri=uri,
         db=db,
@@ -153,13 +162,88 @@ class ReadFromMongoDB(PTransform):
         filter=filter,
         projection=projection,
         extra_client_params=extra_client_params,
-        bucket_auto=bucket_auto)
+        bucket_auto=bucket_auto,
+    )
 
   def expand(self, pcoll):
     return pcoll | iobase.Read(self._mongo_source)
 
 
+class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
+  """RangeTracker for tracking mongodb _id of bson ObjectId type."""
+  def position_to_fraction(
+      self,
+      pos: ObjectId,
+      start: ObjectId,
+      end: ObjectId,
+  ):
+    """Returns the fraction of keys in the range [start, end) that
+    are less than the given key.
+    """
+    pos_number = _ObjectIdHelper.id_to_int(pos)
+    start_number = _ObjectIdHelper.id_to_int(start)
+    end_number = _ObjectIdHelper.id_to_int(end)
+    return (pos_number - start_number) / (end_number - start_number)
+
+  def fraction_to_position(
+      self,
+      fraction: float,
+      start: ObjectId,
+      end: ObjectId,
+  ):
+    """Converts a fraction between 0 and 1
+    to a position between start and end.
+    """
+    start_number = _ObjectIdHelper.id_to_int(start)
+    end_number = _ObjectIdHelper.id_to_int(end)
+    total = end_number - start_number
+    pos = int(total * fraction + start_number)
+    # make sure split position is larger than start position and smaller than
+    # end position.
+    if pos <= start_number:
+      return _ObjectIdHelper.increment_id(start, 1)
+
+    if pos >= end_number:
+      return _ObjectIdHelper.increment_id(end, -1)
+
+    return _ObjectIdHelper.int_to_id(pos)
+
+
 class _BoundedMongoSource(iobase.BoundedSource):
+  """A MongoDB source that reads a finite amount of input records.
+
+  This class defines following operations which can be used to read
+  MongoDB source efficiently.
+
+  * Size estimation - method ``estimate_size()`` may return an accurate
+    estimation in bytes for the size of the source.
+  * Splitting into bundles of a given size - method ``split()`` can be used to
+    split the source into a set of sub-sources (bundles) based on a desired
+    bundle size.
+  * Getting a RangeTracker - method ``get_range_tracker()`` should return a
+    ``RangeTracker`` object for a given position range for the position type
+    of the records returned by the source.
+  * Reading the data - method ``read()`` can be used to read data from the
+    source while respecting the boundaries defined by a given
+    ``RangeTracker``.
+
+  A runner will perform reading the source in two steps.
+
+  (1) Method ``get_range_tracker()`` will be invoked with start and end
+      positions to obtain a ``RangeTracker`` for the range of positions the
+      runner intends to read. Source must define a default initial start and 
end
+      position range. These positions must be used if the start and/or end
+      positions passed to the method ``get_range_tracker()`` are ``None``
+  (2) Method read() will be invoked with the ``RangeTracker`` obtained in the
+      previous step.
+
+  **Mutability**
+
+  A ``_BoundedMongoSource`` object should not be mutated while
+  its methods (for example, ``read()``) are being invoked by a runner. Runner
+  implementations may invoke methods of ``_BoundedMongoSource`` objects through
+  multi-threaded and/or reentrant execution modes.
+  """
   def __init__(
       self,
       uri=None,
@@ -168,7 +252,8 @@ class _BoundedMongoSource(iobase.BoundedSource):
       filter=None,
       projection=None,
       extra_client_params=None,
-      bucket_auto=False):
+      bucket_auto=False,
+  ):
     if extra_client_params is None:
       extra_client_params = {}
     if filter is None:
@@ -183,13 +268,33 @@ class _BoundedMongoSource(iobase.BoundedSource):
 
   def estimate_size(self):
     with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db].command('collstats', self.coll).get('size')
+      return client[self.db].command("collstats", self.coll).get("size")
 
   def _estimate_average_document_size(self):
     with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db].command('collstats', self.coll).get('avgObjSize')
+      return client[self.db].command("collstats", self.coll).get("avgObjSize")
+
+  def split(
+      self,
+      desired_bundle_size: int,
+      start_position: Union[int, str, bytes, ObjectId] = None,
+      stop_position: Union[int, str, bytes, ObjectId] = None,
+  ):
+    """Splits the source into a set of bundles.
+
+    Bundles should be approximately of size ``desired_bundle_size`` bytes.
+
+    Args:
+      desired_bundle_size: the desired size (in bytes) of the bundles returned.
+      start_position: if specified the given position must be used as the
+                      starting position of the first bundle.
+      stop_position: if specified the given position must be used as the ending
+                     position of the last bundle.
+    Returns:
+      an iterator of objects of type 'SourceBundle' that gives information 
about
+      the generated bundles.
+    """
 
-  def split(self, desired_bundle_size, start_position=None, 
stop_position=None):
     desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
 
     # for desired bundle size, if desired chunk size smaller than 1mb, use
@@ -199,18 +304,21 @@ class _BoundedMongoSource(iobase.BoundedSource):
 
     is_initial_split = start_position is None and stop_position is None
     start_position, stop_position = self._replace_none_positions(
-      start_position, stop_position)
+      start_position, stop_position
+    )
 
     if self.bucket_auto:
       # Use $bucketAuto for bundling
       split_keys = []
       weights = []
-      for bucket in self._get_auto_buckets(desired_bundle_size_in_mb,
-                                           start_position,
-                                           stop_position,
-                                           is_initial_split):
-        split_keys.append({'_id': bucket['_id']['max']})
-        weights.append(bucket['count'])
+      for bucket in self._get_auto_buckets(
+          desired_bundle_size_in_mb,
+          start_position,
+          stop_position,
+          is_initial_split,
+      ):
+        split_keys.append({"_id": bucket["_id"]["max"]})
+        weights.append(bucket["count"])
     else:
       # Use splitVector for bundling
       split_keys = self._get_split_keys(
@@ -221,12 +329,13 @@ class _BoundedMongoSource(iobase.BoundedSource):
     for split_key_id, weight in zip(split_keys, weights):
       if bundle_start >= stop_position:
         break
-      bundle_end = min(stop_position, split_key_id['_id'])
+      bundle_end = min(stop_position, split_key_id["_id"])
       yield iobase.SourceBundle(
           weight=weight,
           source=self,
           start_position=bundle_start,
-          stop_position=bundle_end)
+          stop_position=bundle_end,
+      )
       bundle_start = bundle_end
     # add range of last split_key to stop_position
     if bundle_start < stop_position:
@@ -236,60 +345,146 @@ class _BoundedMongoSource(iobase.BoundedSource):
           weight=weight,
           source=self,
           start_position=bundle_start,
-          stop_position=stop_position)
+          stop_position=stop_position,
+      )
 
-  def get_range_tracker(self, start_position, stop_position):
+  def get_range_tracker(
+      self,
+      start_position: Union[int, str, ObjectId] = None,
+      stop_position: Union[int, str, ObjectId] = None,
+  ) -> Union[
+      _ObjectIdRangeTracker, OffsetRangeTracker, LexicographicKeyRangeTracker]:
+    """Returns a RangeTracker for a given position range depending on type.
+
+    Args:
+      start_position: starting position of the range. If 'None' default start
+                      position of the source must be used.
+      stop_position:  ending position of the range. If 'None' default stop
+                      position of the source must be used.
+    Returns:
+      a ``_ObjectIdRangeTracker``, ``OffsetRangeTracker``
+      or ``LexicographicKeyRangeTracker`` depending on the given position 
range.
+    """
     start_position, stop_position = self._replace_none_positions(
-        start_position, stop_position)
-    return _ObjectIdRangeTracker(start_position, stop_position)
+      start_position, stop_position
+    )
+
+    if isinstance(start_position, ObjectId):
+      return _ObjectIdRangeTracker(start_position, stop_position)
+
+    if isinstance(start_position, int):
+      return OffsetRangeTracker(start_position, stop_position)
+
+    if isinstance(start_position, str):
+      return LexicographicKeyRangeTracker(start_position, stop_position)
+
+    raise NotImplementedError(
+        f"RangeTracker for {type(start_position)} not implemented!")
 
   def read(self, range_tracker):
+    """Returns an iterator that reads data from the source.
+
+    The returned set of data must respect the boundaries defined by the given
+    ``RangeTracker`` object. For example:
+
+      * Returned set of data must be for the range
+        ``[range_tracker.start_position, range_tracker.stop_position)``. Note
+        that a source may decide to return records that start after
+        ``range_tracker.stop_position``. See documentation in class
+        ``RangeTracker`` for more details. Also, note that framework might
+        invoke ``range_tracker.try_split()`` to perform dynamic split
+        operations. range_tracker.stop_position may be updated
+        dynamically due to successful dynamic split operations.
+      * Method ``range_tracker.try_split()`` must be invoked for every record
+        that starts at a split point.
+      * Method ``range_tracker.record_current_position()`` may be invoked for
+        records that do not start at split points.
+
+    Args:
+      range_tracker: a ``RangeTracker`` whose boundaries must be respected
+                     when reading data from the source. A runner that reads 
this
+                     source muss pass a ``RangeTracker`` object that is not
+                     ``None``.
+    Returns:
+      an iterator of data read by the source.
+    """
     with MongoClient(self.uri, **self.spec) as client:
       all_filters = self._merge_id_filter(
           range_tracker.start_position(), range_tracker.stop_position())
-      docs_cursor = client[self.db][self.coll].find(
-          filter=all_filters,
-          projection=self.projection).sort([('_id', ASCENDING)])
+      docs_cursor = (
+          client[self.db][self.coll].find(
+              filter=all_filters,
+              projection=self.projection).sort([("_id", ASCENDING)]))
       for doc in docs_cursor:
-        if not range_tracker.try_claim(doc['_id']):
+        if not range_tracker.try_claim(doc["_id"]):
           return
         yield doc
 
   def display_data(self):
-    res = super(_BoundedMongoSource, self).display_data()
-    res['database'] = self.db
-    res['collection'] = self.coll
-    res['filter'] = json.dumps(
-        self.filter, default=lambda x: 'not_serializable(%s)' % str(x))
-    res['projection'] = str(self.projection)
-    res['bucket_auto'] = self.bucket_auto
+    """Returns the display data associated to a pipeline component."""
+    res = super().display_data()
+    res["database"] = self.db
+    res["collection"] = self.coll
+    res["filter"] = json.dumps(self.filter, default=json_util.default)
+    res["projection"] = str(self.projection)
+    res["bucket_auto"] = self.bucket_auto
     return res
 
-  def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
-    # calls mongodb splitVector command to get document ids at split position
-    if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
-      # single document not splittable
+  @staticmethod
+  def _range_is_not_splittable(
+      start_pos: Union[int, str, ObjectId],
+      end_pos: Union[int, str, ObjectId],
+  ):
+    """Return `True` if splitting range doesn't make sense
+    (single document is not splittable),
+    Return `False` otherwise.
+    """
+    return ((
+        isinstance(start_pos, ObjectId) and
+        start_pos >= _ObjectIdHelper.increment_id(end_pos, -1)) or
+            (isinstance(start_pos, int) and start_pos >= end_pos - 1) or
+            (isinstance(start_pos, str) and start_pos >= end_pos))
+
+  def _get_split_keys(
+      self,
+      desired_chunk_size_in_mb: int,
+      start_pos: Union[int, str, ObjectId],
+      end_pos: Union[int, str, ObjectId],
+  ):
+    """Calls MongoDB `splitVector` command
+    to get document ids at split position.
+    """
+    # single document not splittable
+    if self._range_is_not_splittable(start_pos, end_pos):
       return []
+
     with MongoClient(self.uri, **self.spec) as client:
-      name_space = '%s.%s' % (self.db, self.coll)
-      return (
-          client[self.db].command(
-              'splitVector',
-              name_space,
-              keyPattern={'_id': 1},  # Ascending index
-              min={'_id': start_pos},
-              max={'_id': end_pos},
-              maxChunkSize=desired_chunk_size_in_mb)['splitKeys'])
+      name_space = "%s.%s" % (self.db, self.coll)
+      return client[self.db].command(
+        "splitVector",
+        name_space,
+        keyPattern={"_id": 1},  # Ascending index
+        min={"_id": start_pos},
+        max={"_id": end_pos},
+        maxChunkSize=desired_chunk_size_in_mb,
+      )["splitKeys"]
 
   def _get_auto_buckets(
-      self, desired_chunk_size_in_mb, start_pos, end_pos, is_initial_split):
-
-    if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
-      # single document not splittable
+      self,
+      desired_chunk_size_in_mb: int,
+      start_pos: Union[int, str, ObjectId],
+      end_pos: Union[int, str, ObjectId],
+      is_initial_split: bool,
+  ) -> list:
+    """Use MongoDB `$bucketAuto` aggregation to split collection into bundles
+    instead of `splitVector` command, which does not work with MongoDB Atlas.
+    """
+    # single document not splittable
+    if self._range_is_not_splittable(start_pos, end_pos):
       return []
 
     if is_initial_split and not self.filter:
-      # total collection size
+      # total collection size in MB
       size_in_mb = self.estimate_size() / float(1 << 20)
     else:
       # size of documents within start/end id range and possibly filtered
@@ -306,34 +501,46 @@ class _BoundedMongoSource(iobase.BoundedSource):
       pipeline = [
           {
               # filter by positions and by the custom filter if any
-              '$match': self._merge_id_filter(start_pos, end_pos)
+              "$match": self._merge_id_filter(start_pos, end_pos)
           },
           {
-              '$bucketAuto': {
-                  'groupBy': '$_id', 'buckets': bucket_count
+              "$bucketAuto": {
+                  "groupBy": "$_id", "buckets": bucket_count
               }
-          }
+          },
       ]
-      buckets = list(client[self.db][self.coll].aggregate(pipeline))
+      buckets = list(
+          # Use `allowDiskUse` option to avoid aggregation limit of 100 Mb RAM
+          client[self.db][self.coll].aggregate(pipeline, allowDiskUse=True))
       if buckets:
-        buckets[-1]['_id']['max'] = end_pos
+        buckets[-1]["_id"]["max"] = end_pos
 
       return buckets
 
-  def _merge_id_filter(self, start_position, stop_position):
-    # Merge the default filter (if any) with refined _id field range
-    # of range_tracker.
-    # $gte specifies start position (inclusive)
-    # and $lt specifies the end position (exclusive),
-    # see more at
-    # https://docs.mongodb.com/manual/reference/operator/query/gte/ and
-    # https://docs.mongodb.com/manual/reference/operator/query/lt/
-    id_filter = {'_id': {'$gte': start_position, '$lt': stop_position}}
+  def _merge_id_filter(
+      self,
+      start_position: Union[int, str, bytes, ObjectId],
+      stop_position: Union[int, str, bytes, ObjectId] = None,
+  ) -> dict:
+    """Merge the default filter (if any) with refined _id field range
+    of range_tracker.
+    $gte specifies start position (inclusive)
+    and $lt specifies the end position (exclusive),
+    see more at
+    https://docs.mongodb.com/manual/reference/operator/query/gte/ and
+    https://docs.mongodb.com/manual/reference/operator/query/lt/
+    """
+
+    if stop_position is None:
+      id_filter = {"_id": {"$gte": start_position}}
+    else:
+      id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}}
+
     if self.filter:
       all_filters = {
           # see more at
           # https://docs.mongodb.com/manual/reference/operator/query/and/
-          '$and': [self.filter.copy(), id_filter]
+          "$and": [self.filter.copy(), id_filter]
       }
     else:
       all_filters = id_filter
@@ -342,45 +549,58 @@ class _BoundedMongoSource(iobase.BoundedSource):
 
   def _get_head_document_id(self, sort_order):
     with MongoClient(self.uri, **self.spec) as client:
-      cursor = client[self.db][self.coll].find(
-          filter={}, projection=[]).sort([('_id', sort_order)]).limit(1)
+      cursor = (
+          client[self.db][self.coll].find(filter={}, projection=[]).sort([
+              ("_id", sort_order)
+          ]).limit(1))
       try:
-        return cursor[0]['_id']
+        return cursor[0]["_id"]
+
       except IndexError:
-        raise ValueError('Empty Mongodb collection')
+        raise ValueError("Empty Mongodb collection")
 
   def _replace_none_positions(self, start_position, stop_position):
+
     if start_position is None:
       start_position = self._get_head_document_id(ASCENDING)
     if stop_position is None:
       last_doc_id = self._get_head_document_id(DESCENDING)
       # increment last doc id binary value by 1 to make sure the last document
       # is not excluded
-      stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
+      if isinstance(last_doc_id, ObjectId):
+        stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
+      elif isinstance(last_doc_id, int):
+        stop_position = last_doc_id + 1
+      elif isinstance(last_doc_id, str):
+        stop_position = last_doc_id + '\x00'
+
     return start_position, stop_position
 
   def _count_id_range(self, start_position, stop_position):
-    # Number of documents between start_position (inclusive)
-    # and stop_position (exclusive), respecting the custom filter if any.
+    """Number of documents between start_position (inclusive)
+    and stop_position (exclusive), respecting the custom filter if any.
+    """
     with MongoClient(self.uri, **self.spec) as client:
       return client[self.db][self.coll].count_documents(
           filter=self._merge_id_filter(start_position, stop_position))
 
 
-class _ObjectIdHelper(object):
+class _ObjectIdHelper:
   """A Utility class to manipulate bson object ids."""
   @classmethod
-  def id_to_int(cls, id):
+  def id_to_int(cls, _id: Union[int, ObjectId]) -> int:
     """
     Args:
-      id: ObjectId required for each MongoDB document _id field.
+      _id: ObjectId required for each MongoDB document _id field.
 
     Returns: Converted integer value of ObjectId's 12 bytes binary value.
-
     """
+    if isinstance(_id, int):
+      return _id
+
     # converts object id binary to integer
     # id object is bytes type with size of 12
-    ints = struct.unpack('>III', id.binary)
+    ints = struct.unpack(">III", _id.binary)
     return (ints[0] << 64) + (ints[1] << 32) + ints[2]
 
   @classmethod
@@ -391,61 +611,45 @@ class _ObjectIdHelper(object):
 
     Returns: The ObjectId that has the 12 bytes binary converted from the
       integer value.
-
     """
     # converts integer value to object id. Int value should be less than
     # (2 ^ 96) so it can be convert to 12 bytes required by object id.
     if number < 0 or number >= (1 << 96):
-      raise ValueError('number value must be within [0, %s)' % (1 << 96))
-    ints = [(number & 0xffffffff0000000000000000) >> 64,
-            (number & 0x00000000ffffffff00000000) >> 32,
-            number & 0x0000000000000000ffffffff]
+      raise ValueError("number value must be within [0, %s)" % (1 << 96))
+    ints = [
+        (number & 0xFFFFFFFF0000000000000000) >> 64,
+        (number & 0x00000000FFFFFFFF00000000) >> 32,
+        number & 0x0000000000000000FFFFFFFF,
+    ]
 
-    bytes = struct.pack('>III', *ints)
-    return objectid.ObjectId(bytes)
+    number_bytes = struct.pack(">III", *ints)
+    return ObjectId(number_bytes)
 
   @classmethod
-  def increment_id(cls, object_id, inc):
+  def increment_id(
+      cls,
+      _id: ObjectId,
+      inc: int,
+  ) -> ObjectId:
     """
+    Increment object_id binary value by inc value and return new object id.
+
     Args:
-      object_id: The ObjectId to change.
-      inc(int): The incremental int value to be added to ObjectId.
+      _id: The `_id` to change.
+      inc(int): The incremental int value to be added to `_id`.
 
     Returns:
-
+        `_id` incremented by `inc` value
     """
-    # increment object_id binary value by inc value and return new object id.
-    id_number = _ObjectIdHelper.id_to_int(object_id)
+    id_number = _ObjectIdHelper.id_to_int(_id)
     new_number = id_number + inc
     if new_number < 0 or new_number >= (1 << 96):
       raise ValueError(
-          'invalid incremental, inc value must be within ['
-          '%s, %s)' % (0 - id_number, 1 << 96 - id_number))
+          "invalid incremental, inc value must be within ["
+          "%s, %s)" % (0 - id_number, 1 << 96 - id_number))
     return _ObjectIdHelper.int_to_id(new_number)
 
 
-class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
-  """RangeTracker for tracking mongodb _id of bson ObjectId type."""
-  def position_to_fraction(self, pos, start, end):
-    pos_number = _ObjectIdHelper.id_to_int(pos)
-    start_number = _ObjectIdHelper.id_to_int(start)
-    end_number = _ObjectIdHelper.id_to_int(end)
-    return (pos_number - start_number) / (end_number - start_number)
-
-  def fraction_to_position(self, fraction, start, end):
-    start_number = _ObjectIdHelper.id_to_int(start)
-    end_number = _ObjectIdHelper.id_to_int(end)
-    total = end_number - start_number
-    pos = int(total * fraction + start_number)
-    # make sure split position is larger than start position and smaller than
-    # end position.
-    if pos <= start_number:
-      return _ObjectIdHelper.increment_id(start, 1)
-    if pos >= end_number:
-      return _ObjectIdHelper.increment_id(end, -1)
-    return _ObjectIdHelper.int_to_id(pos)
-
-
 @experimental()
 class WriteToMongoDB(PTransform):
   """WriteToMongoDB is a ``PTransform`` that writes a ``PCollection`` of
@@ -472,11 +676,12 @@ class WriteToMongoDB(PTransform):
   """
   def __init__(
       self,
-      uri='mongodb://localhost:27017',
+      uri="mongodb://localhost:27017",
       db=None,
       coll=None,
       batch_size=100,
-      extra_client_params=None):
+      extra_client_params=None,
+  ):
     """
 
     Args:
@@ -496,11 +701,10 @@ class WriteToMongoDB(PTransform):
     if extra_client_params is None:
       extra_client_params = {}
     if not isinstance(db, str):
-      raise ValueError('WriteToMongoDB db param must be specified as a string')
+      raise ValueError("WriteToMongoDB db param must be specified as a string")
     if not isinstance(coll, str):
       raise ValueError(
-          'WriteToMongoDB coll param must be specified as a '
-          'string')
+          "WriteToMongoDB coll param must be specified as a string")
     self._uri = uri
     self._db = db
     self._coll = coll
@@ -508,25 +712,27 @@ class WriteToMongoDB(PTransform):
     self._spec = extra_client_params
 
   def expand(self, pcoll):
-    return pcoll \
-           | beam.ParDo(_GenerateObjectIdFn()) \
-           | Reshuffle() \
-           | beam.ParDo(_WriteMongoFn(self._uri, self._db, self._coll,
-                                      self._batch_size, self._spec))
+    return (
+        pcoll
+        | beam.ParDo(_GenerateObjectIdFn())
+        | Reshuffle()
+        | beam.ParDo(
+            _WriteMongoFn(
+                self._uri, self._db, self._coll, self._batch_size, 
self._spec)))
 
 
 class _GenerateObjectIdFn(DoFn):
   def process(self, element, *args, **kwargs):
     # if _id field already exist we keep it as it is, otherwise the ptransform
     # generates a new _id field to achieve idempotent write to mongodb.
-    if '_id' not in element:
+    if "_id" not in element:
       # object.ObjectId() generates a unique identifier that follows mongodb
       # default format, if _id is not present in document, mongodb server
       # generates it with this same function upon write. However the
       # uniqueness of generated id may not be guaranteed if the work load are
       # distributed across too many processes. See more on the ObjectId format
       # https://docs.mongodb.com/manual/reference/bson-types/#objectid.
-      element['_id'] = objectid.ObjectId()
+      element["_id"] = objectid.ObjectId()
 
     yield element
 
@@ -560,13 +766,13 @@ class _WriteMongoFn(DoFn):
 
   def display_data(self):
     res = super(_WriteMongoFn, self).display_data()
-    res['database'] = self.db
-    res['collection'] = self.coll
-    res['batch_size'] = self.batch_size
+    res["database"] = self.db
+    res["collection"] = self.coll
+    res["batch_size"] = self.batch_size
     return res
 
 
-class _MongoSink(object):
+class _MongoSink:
   def __init__(self, uri=None, db=None, coll=None, extra_params=None):
     if extra_params is None:
       extra_params = {}
@@ -585,17 +791,18 @@ class _MongoSink(object):
       # insert new one, otherwise overwrite it.
       requests.append(
           ReplaceOne(
-              filter={'_id': doc.get('_id', None)},
+              filter={"_id": doc.get("_id", None)},
               replacement=doc,
               upsert=True))
     resp = self.client[self.db][self.coll].bulk_write(requests)
     _LOGGER.debug(
-        'BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, '
-        'nMatched:%d, Errors:%s' % (
+        "BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, "
+        "nMatched:%d, Errors:%s" % (
             resp.modified_count,
             resp.upserted_count,
             resp.matched_count,
-            resp.bulk_api_result.get('writeErrors')))
+            resp.bulk_api_result.get("writeErrors"),
+        ))
 
   def __enter__(self):
     if self.client is None:
diff --git a/sdks/python/apache_beam/io/mongodbio_test.py 
b/sdks/python/apache_beam/io/mongodbio_test.py
index f6f467f..150eac2 100644
--- a/sdks/python/apache_beam/io/mongodbio_test.py
+++ b/sdks/python/apache_beam/io/mongodbio_test.py
@@ -20,9 +20,11 @@ import datetime
 import logging
 import random
 import unittest
+from typing import Union
 from unittest import TestCase
 
 import mock
+from bson import ObjectId
 from bson import objectid
 from parameterized import parameterized_class
 from pymongo import ASCENDING
@@ -38,6 +40,8 @@ from apache_beam.io.mongodbio import _MongoSink
 from apache_beam.io.mongodbio import _ObjectIdHelper
 from apache_beam.io.mongodbio import _ObjectIdRangeTracker
 from apache_beam.io.mongodbio import _WriteMongoFn
+from apache_beam.io.range_trackers import LexicographicKeyRangeTracker
+from apache_beam.io.range_trackers import OffsetRangeTracker
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
@@ -120,7 +124,7 @@ class _MockMongoColl(object):
   def count_documents(self, filter):
     return len(self._filter(filter))
 
-  def aggregate(self, pipeline):
+  def aggregate(self, pipeline, **kwargs):
     # Simulate $bucketAuto aggregate pipeline.
     # Example splits doc count for the total of 5 docs:
     #   - 1 bucket:  [5]
@@ -175,7 +179,7 @@ class _MockMongoDb(object):
   def command(self, command, *args, **kwargs):
     if command == 'collstats':
       return {'size': 5 * 1024 * 1024, 'avgObjSize': 1 * 1024 * 1024}
-    elif command == 'splitVector':
+    if command == 'splitVector':
       return self.get_split_keys(command, *args, **kwargs)
 
   def get_split_keys(self, command, ns, min, max, maxChunkSize, **kwargs):
@@ -205,7 +209,7 @@ class _MockMongoDb(object):
     }
 
 
-class _MockMongoClient(object):
+class _MockMongoClient:
   def __init__(self, docs):
     self.docs = docs
 
@@ -219,15 +223,87 @@ class _MockMongoClient(object):
     pass
 
 
-@parameterized_class(('bucket_auto', ), [(None, ), (True, )])
+# Generate test data for MongoDB collections of different types
+OBJECT_IDS = [
+    objectid.ObjectId.from_datetime(
+        datetime.datetime(year=2020, month=i + 1, day=i + 1)) for i in range(5)
+]
+
+INT_IDS = list(range(5))  # [0, 1, 2, 3, 4]
+
+STR_IDS_1 = [str(n) for n in range(5)]  # ['0', '1', '2', '3', '4']
+
+# ['aaaaa', 'bbbbb', 'ccccc', 'ddddd', 'eeeee']
+STR_IDS_2 = [chr(97 + n) * 5 for n in range(5)]
+
+# ['AAAAAAAAAAAAAAAAAAAA', 'BBBBBBBBBBBBBBBBBBBB', ..., 'EEEEEEEEEEEEEEEEEEEE']
+STR_IDS_3 = [chr(65 + n) * 20 for n in range(5)]
+
+
+@parameterized_class(('bucket_auto', '_ids', 'min_id', 'max_id'),
+                     [
+                         (
+                             None,
+                             OBJECT_IDS,
+                             _ObjectIdHelper.int_to_id(0),
+                             _ObjectIdHelper.int_to_id(2**96 - 1)),
+                         (
+                             True,
+                             OBJECT_IDS,
+                             _ObjectIdHelper.int_to_id(0),
+                             _ObjectIdHelper.int_to_id(2**96 - 1)),
+                         (
+                             None,
+                             INT_IDS,
+                             0,
+                             2**96 - 1,
+                         ),
+                         (
+                             True,
+                             INT_IDS,
+                             0,
+                             2**96 - 1,
+                         ),
+                         (
+                             None,
+                             STR_IDS_1,
+                             chr(0),
+                             chr(0x10ffff),
+                         ),
+                         (
+                             True,
+                             STR_IDS_1,
+                             chr(0),
+                             chr(0x10ffff),
+                         ),
+                         (
+                             None,
+                             STR_IDS_2,
+                             chr(0),
+                             chr(0x10ffff),
+                         ),
+                         (
+                             True,
+                             STR_IDS_2,
+                             chr(0),
+                             chr(0x10ffff),
+                         ),
+                         (
+                             None,
+                             STR_IDS_3,
+                             chr(0),
+                             chr(0x10ffff),
+                         ),
+                         (
+                             True,
+                             STR_IDS_3,
+                             chr(0),
+                             chr(0x10ffff),
+                         ),
+                     ])
 class MongoSourceTest(unittest.TestCase):
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def setUp(self, mock_client):
-    self._ids = [
-        objectid.ObjectId.from_datetime(
-            datetime.datetime(year=2020, month=i + 1, day=i + 1))
-        for i in range(5)
-    ]
     self._docs = [{'_id': self._ids[i], 'x': i} for i in range(len(self._ids))]
     mock_client.return_value = _MockMongoClient(self._docs)
 
@@ -242,6 +318,27 @@ class MongoSourceTest(unittest.TestCase):
       kwargs['bucket_auto'] = bucket_auto
     return _BoundedMongoSource('mongodb://test', 'testdb', 'testcoll', **kwargs)
 
+  def _increment_id(
+      self,
+      _id: Union[ObjectId, int, str],
+      inc: int,
+  ) -> Union[ObjectId, int, str]:
+    """Shifts `_id` by `inc`; str ids step through `self._ids`, clamped."""
+    if isinstance(_id, ObjectId):
+      return _ObjectIdHelper.increment_id(_id, inc)
+
+    if isinstance(_id, int):
+      return _id + inc
+
+    if isinstance(_id, str):
+      index = self._ids.index(_id) + inc
+      if index <= 0:
+        return self._ids[0]
+      if index >= len(self._ids):
+        return self._ids[-1]
+      return self._ids[index]
+    raise TypeError('Unsupported _id type: %r' % type(_id))
+
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_estimate_size(self, mock_client):
     mock_client.return_value = _MockMongoClient(self._docs)
@@ -282,10 +379,12 @@ class MongoSourceTest(unittest.TestCase):
               start_position=None, stop_position=None,
               desired_bundle_size=size))
       self.assertEqual(len(splits), 1)
-      self.assertEqual(splits[0].start_position, self._docs[0]['_id'])
-      self.assertEqual(
-          splits[0].stop_position,
-          _ObjectIdHelper.increment_id(self._docs[0]['_id'], 1))
+      _id = self._docs[0]['_id']
+      self.assertEqual(_id, splits[0].start_position)
+      self.assertLessEqual(_id, splits[0].stop_position)
+      if isinstance(_id, (ObjectId, int)):
+        # We can unambiguously determine next `_id`
+        self.assertEqual(self._increment_id(_id, 1), splits[0].stop_position)
 
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_split_no_documents(self, mock_client):
@@ -357,8 +456,8 @@ class MongoSourceTest(unittest.TestCase):
       reference_info = (
           filtered_mongo_source,
           # range to match no documents:
-          _ObjectIdHelper.increment_id(self._docs[-1]['_id'], 1),
-          _ObjectIdHelper.increment_id(self._docs[-1]['_id'], 2),
+          self._increment_id(self._docs[-1]['_id'], 1),
+          self._increment_id(self._docs[-1]['_id'], 2),
       )
       sources_info = ([
           (split.source, split.start_position, split.stop_position)
@@ -379,8 +478,21 @@ class MongoSourceTest(unittest.TestCase):
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_get_range_tracker(self, mock_client):
     mock_client.return_value = _MockMongoClient(self._docs)
-    self.assertIsInstance(
-        self.mongo_source.get_range_tracker(None, None), _ObjectIdRangeTracker)
+    if self._ids == OBJECT_IDS:
+      self.assertIsInstance(
+          self.mongo_source.get_range_tracker(None, None),
+          _ObjectIdRangeTracker,
+      )
+    elif self._ids == INT_IDS:
+      self.assertIsInstance(
+          self.mongo_source.get_range_tracker(None, None),
+          OffsetRangeTracker,
+      )
+    elif self._ids in (STR_IDS_1, STR_IDS_2, STR_IDS_3):
+      self.assertIsInstance(
+          self.mongo_source.get_range_tracker(None, None),
+          LexicographicKeyRangeTracker,
+      )
 
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_read(self, mock_client):
@@ -394,26 +506,26 @@ class MongoSourceTest(unittest.TestCase):
         },
         {
             # range covers from the first to the third documents
-            'start': _ObjectIdHelper.int_to_id(0),  # smallest possible id
+            'start': self.min_id,  # smallest possible id
             'stop': self._ids[2],
             'expected': self._docs[0:2]
         },
         {
             # range covers from the third to last documents
             'start': self._ids[2],
-            'stop': _ObjectIdHelper.int_to_id(2**96 - 1),  # largest possible id
+            'stop': self.max_id,  # largest possible id
             'expected': self._docs[2:]
         },
         {
             # range covers all documents
-            'start': _ObjectIdHelper.int_to_id(0),
-            'stop': _ObjectIdHelper.int_to_id(2**96 - 1),
+            'start': self.min_id,
+            'stop': self.max_id,
             'expected': self._docs
         },
         {
             # range doesn't include any document
-            'start': _ObjectIdHelper.increment_id(self._ids[2], 1),
-            'stop': _ObjectIdHelper.increment_id(self._ids[3], -1),
+            'start': self._increment_id(self._ids[2], 1),
+            'stop': self._increment_id(self._ids[3], -1),
             'expected': []
         },
     ]
@@ -429,6 +541,32 @@ class MongoSourceTest(unittest.TestCase):
     self.assertTrue('database' in data)
     self.assertTrue('collection' in data)
 
+  def test_range_is_not_splittable(self):
+    self.assertTrue(
+        self.mongo_source._range_is_not_splittable(
+            _ObjectIdHelper.int_to_id(1),
+            _ObjectIdHelper.int_to_id(1),
+        ))
+    self.assertTrue(
+        self.mongo_source._range_is_not_splittable(
+            _ObjectIdHelper.int_to_id(1),
+            _ObjectIdHelper.int_to_id(2),
+        ))
+    self.assertFalse(
+        self.mongo_source._range_is_not_splittable(
+            _ObjectIdHelper.int_to_id(1),
+            _ObjectIdHelper.int_to_id(3),
+        ))
+
+    self.assertTrue(self.mongo_source._range_is_not_splittable(0, 0))
+    self.assertTrue(self.mongo_source._range_is_not_splittable(0, 1))
+    self.assertFalse(self.mongo_source._range_is_not_splittable(0, 2))
+
+    self.assertTrue(self.mongo_source._range_is_not_splittable("AAA", "AAA"))
+    self.assertFalse(
+        self.mongo_source._range_is_not_splittable("AAA", "AAA\x00"))
+    self.assertFalse(self.mongo_source._range_is_not_splittable("AAA", "AAB"))
+
 
 @parameterized_class(('bucket_auto', ), [(False, ), (True, )])
 class ReadFromMongoDBTest(unittest.TestCase):
@@ -500,9 +638,13 @@ class MongoSinkTest(unittest.TestCase):
 class WriteToMongoDBTest(unittest.TestCase):
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_write_to_mongodb_with_existing_id(self, mock_client):
-    id = objectid.ObjectId()
-    docs = [{'x': 1, '_id': id}]
-    expected_update = [ReplaceOne({'_id': id}, {'x': 1, '_id': id}, True, None)]
+    _id = objectid.ObjectId()
+    docs = [{'x': 1, '_id': _id}]
+    expected_update = [
+        ReplaceOne({'_id': _id}, {
+            'x': 1, '_id': _id
+        }, True, None)
+    ]
     with TestPipeline() as p:
       _ = (
           p | "Create" >> beam.Create(docs)
@@ -538,34 +680,36 @@ class ObjectIdHelperTest(TestCase):
         (objectid.ObjectId('00000000ffffffffffffffff'), 2**64 - 1),
         (objectid.ObjectId('ffffffffffffffffffffffff'), 2**96 - 1),
     ]
-    for (id, number) in test_cases:
-      self.assertEqual(id, _ObjectIdHelper.int_to_id(number))
-      self.assertEqual(number, _ObjectIdHelper.id_to_int(id))
+    for (_id, number) in test_cases:
+      self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
+      self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))
 
     # random tests
     for _ in range(100):
-      id = objectid.ObjectId()
-      number = int(id.binary.hex(), 16)
-      self.assertEqual(id, _ObjectIdHelper.int_to_id(number))
-      self.assertEqual(number, _ObjectIdHelper.id_to_int(id))
+      _id = objectid.ObjectId()
+      number = int(_id.binary.hex(), 16)
+      self.assertEqual(_id, _ObjectIdHelper.int_to_id(number))
+      self.assertEqual(number, _ObjectIdHelper.id_to_int(_id))
 
   def test_increment_id(self):
     test_cases = [
         (
-            objectid.ObjectId('000000000000000100000000'),
-            objectid.ObjectId('0000000000000000ffffffff')),
+            objectid.ObjectId("000000000000000100000000"),
+            objectid.ObjectId("0000000000000000ffffffff"),
+        ),
         (
-            objectid.ObjectId('000000010000000000000000'),
-            objectid.ObjectId('00000000ffffffffffffffff')),
+            objectid.ObjectId("000000010000000000000000"),
+            objectid.ObjectId("00000000ffffffffffffffff"),
+        ),
     ]
-    for (first, second) in test_cases:
+    for first, second in test_cases:
       self.assertEqual(second, _ObjectIdHelper.increment_id(first, -1))
       self.assertEqual(first, _ObjectIdHelper.increment_id(second, 1))
 
     for _ in range(100):
-      id = objectid.ObjectId()
-      self.assertLess(id, _ObjectIdHelper.increment_id(id, 1))
-      self.assertGreater(id, _ObjectIdHelper.increment_id(id, -1))
+      _id = objectid.ObjectId()
+      self.assertLess(_id, _ObjectIdHelper.increment_id(_id, 1))
+      self.assertGreater(_id, _ObjectIdHelper.increment_id(_id, -1))
 
 
 class ObjectRangeTrackerTest(TestCase):
@@ -578,10 +722,10 @@ class ObjectRangeTrackerTest(TestCase):
                   [random.randint(start_int, stop_int) for _ in range(100)])
     tracker = _ObjectIdRangeTracker()
     for pos in test_cases:
-      id = _ObjectIdHelper.int_to_id(pos - start_int)
+      _id = _ObjectIdHelper.int_to_id(pos - start_int)
       desired_fraction = (pos - start_int) / (stop_int - start_int)
       self.assertAlmostEqual(
-          tracker.position_to_fraction(id, start, stop),
+          tracker.position_to_fraction(_id, start, stop),
           desired_fraction,
           places=20)
 
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index a53ecb6..33b15d5 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -23,6 +23,7 @@ import codecs
 import logging
 import math
 import threading
+from typing import Union
 
 from apache_beam.io import iobase
 
@@ -265,16 +266,17 @@ class OrderedPositionRangeTracker(iobase.RangeTracker):
       if ((self._stop_position is not None and position >= self._stop_position)
           or (self._start_position is not None and
               position <= self._start_position)):
-        raise ValueError(
-            "Split at '%s' not in range %s" %
-            (position, [self._start_position, self._stop_position]))
+        _LOGGER.debug(
+            'Refusing to split %r at %r: proposed split position out of range',
+            self,
+            position)
+        return
+
       if self._last_claim is self.UNSTARTED or self._last_claim < position:
         fraction = self.position_to_fraction(
             position, start=self._start_position, end=self._stop_position)
         self._stop_position = position
         return position, fraction
-      else:
-        return None
 
   def fraction_consumed(self):
     if self._last_claim is self.UNSTARTED:
@@ -289,6 +291,12 @@ class OrderedPositionRangeTracker(iobase.RangeTracker):
     """
     raise NotImplementedError
 
+  def position_to_fraction(self, position, start, end):
+    """Returns the fraction of keys in the range [start, end) that
+    are less than the given key.
+    """
+    raise NotImplementedError
+
 
 class UnsplittableRangeTracker(iobase.RangeTracker):
   """A RangeTracker that always ignores split requests.
@@ -339,87 +347,112 @@ class UnsplittableRangeTracker(iobase.RangeTracker):
 
 
 class LexicographicKeyRangeTracker(OrderedPositionRangeTracker):
-  """
-  A range tracker that tracks progress through a lexicographically
+  """A range tracker that tracks progress through a lexicographically
   ordered keyspace of strings.
   """
   @classmethod
-  def fraction_to_position(cls, fraction, start=None, end=None):
-    """
-    Linearly interpolates a key that is lexicographically
+  def fraction_to_position(
+      cls,
+      fraction: float,
+      start: Union[bytes, str] = None,
+      end: Union[bytes, str] = None,
+  ) -> Union[bytes, str]:
+    """Linearly interpolates a key that is lexicographically
     fraction of the way between start and end.
     """
     assert 0 <= fraction <= 1, fraction
+
     if start is None:
       start = b''
+
+    if fraction == 0:
+      return start
     if fraction == 1:
       return end
-    elif fraction == 0:
-      return start
+    if not end:
+      common_prefix_len = (
+          len(start) - len(start.lstrip(b'\xFF'))
+          if isinstance(start, bytes) else 0)  # UTF-8 never has 0xFF
     else:
-      if not end:
-        common_prefix_len = len(start) - len(start.lstrip(b'\xFF'))
+      for ix, (s, e) in enumerate(zip(start, end)):
+        if s != e:
+          common_prefix_len = ix
+          break
       else:
-        for ix, (s, e) in enumerate(zip(start, end)):
-          if s != e:
-            common_prefix_len = ix
-            break
-        else:
-          common_prefix_len = min(len(start), len(end))
-      # Convert the relative precision of fraction (~53 bits) to an absolute
-      # precision needed to represent values between start and end distinctly.
-      prec = common_prefix_len + int(-math.log(fraction, 256)) + 7
-      istart = cls._bytestring_to_int(start, prec)
-      iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
-      ikey = istart + int((iend - istart) * fraction)
-      # Could be equal due to rounding.
-      # Adjust to ensure we never return the actual start and end
-      # unless fraction is exatly 0 or 1.
-      if ikey == istart:
-        ikey += 1
-      elif ikey == iend:
-        ikey -= 1
-      return cls._bytestring_from_int(ikey, prec).rstrip(b'\0')
+        common_prefix_len = min(len(start), len(end))
+
+    # Convert the relative precision of fraction (~53 bits) to an absolute
+    # precision needed to represent values between start and end distinctly.
+    prec = common_prefix_len + int(-math.log(fraction, 256)) + 7
+    istart = cls._bytestring_to_int(start, prec)
+    iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
+    ikey = istart + int((iend - istart) * fraction)
+
+    # Could be equal due to rounding.
+    # Adjust to ensure we never return the actual start and end
+    # unless fraction is exatly 0 or 1.
+    if ikey == istart:
+      ikey += 1
+    elif ikey == iend:
+      ikey -= 1
+
+    position: bytes = cls._bytestring_from_int(ikey, prec).rstrip(b'\0')
+
+    if isinstance(start, bytes):
+      return position
+
+    return position.decode(encoding='unicode_escape', errors='replace')
 
   @classmethod
-  def position_to_fraction(cls, key, start=None, end=None):
-    """
-    Returns the fraction of keys in the range [start, end) that
+  def position_to_fraction(
+      cls,
+      key: Union[bytes, str] = None,
+      start: Union[bytes, str] = None,
+      end: Union[bytes, str] = None,
+  ) -> float:
+    """Returns the fraction of keys in the range [start, end) that
     are less than the given key.
     """
     if not key:
       return 0
+
     if start is None:
-      start = b''
+      start = '' if isinstance(key, str) else b''
+
     prec = len(start) + 7
     if key.startswith(start):
       # Higher absolute precision needed for very small values of fixed
       # relative position.
-      prec = max(prec, len(key) - len(key[len(start):].strip(b'\0')) + 7)
+      trailing_symbol = '\0' if isinstance(key, str) else b'\0'
+      prec = max(
+          prec, len(key) - len(key[len(start):].strip(trailing_symbol)) + 7)
     istart = cls._bytestring_to_int(start, prec)
     ikey = cls._bytestring_to_int(key, prec)
     iend = cls._bytestring_to_int(end, prec) if end else 1 << (prec * 8)
     return float(ikey - istart) / (iend - istart)
 
   @staticmethod
-  def _bytestring_to_int(s, prec):
-    """
-    Returns int(256**prec * f) where f is the fraction
+  def _bytestring_to_int(s: Union[bytes, str], prec: int) -> int:
+    """Returns int(256**prec * f) where f is the fraction
     represented by interpreting '.' + s as a base-256
     floating point number.
     """
     if not s:
       return 0
-    elif len(s) < prec:
+
+    if isinstance(s, str):
+      s = s.encode()  # str -> bytes
+
+    if len(s) < prec:
       s += b'\0' * (prec - len(s))
     else:
       s = s[:prec]
-    return int(codecs.encode(s, 'hex'), 16)
+
+    h = codecs.encode(s, encoding='hex')
+    return int(h, base=16)
 
   @staticmethod
-  def _bytestring_from_int(i, prec):
-    """
-    Inverse of _bytestring_to_int.
-    """
+  def _bytestring_from_int(i: int, prec: int) -> bytes:
+    """Inverse of _bytestring_to_int."""
     h = '%x' % i
-    return codecs.decode('0' * (2 * prec - len(h)) + h, 'hex')
+    return codecs.decode('0' * (2 * prec - len(h)) + h, encoding='hex')
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index 02a1e4e..0bf3799 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -22,6 +22,8 @@ import copy
 import logging
 import math
 import unittest
+from typing import Optional
+from typing import Union
 
 from apache_beam.io import range_trackers
 
@@ -257,25 +259,26 @@ class OrderedPositionRangeTrackerTest(unittest.TestCase):
 
   def test_out_of_range(self):
     tracker = self.DoubleRangeTracker(10, 20)
+
     # Can't claim before range.
     with self.assertRaises(ValueError):
       tracker.try_claim(-5)
+
     # Can't split before range.
-    with self.assertRaises(ValueError):
-      tracker.try_split(-5)
+    self.assertFalse(tracker.try_split(-5))
+
     # Reject useless split at start position.
-    with self.assertRaises(ValueError):
-      tracker.try_split(10)
+    self.assertFalse(tracker.try_split(10))
+
     # Can't split after range.
-    with self.assertRaises(ValueError):
-      tracker.try_split(25)
+    self.assertFalse(tracker.try_split(25))
     tracker.try_split(15)
+
     # Can't split after modified range.
-    with self.assertRaises(ValueError):
-      tracker.try_split(17)
+    self.assertFalse(tracker.try_split(17))
+
     # Reject useless split at end position.
-    with self.assertRaises(ValueError):
-      tracker.try_split(15)
+    self.assertFalse(tracker.try_split(15))
     self.assertTrue(tracker.try_split(14))
 
 
@@ -303,16 +306,20 @@ class UnsplittableRangeTrackerTest(unittest.TestCase):
 
 
 class LexicographicKeyRangeTrackerTest(unittest.TestCase):
-  """
-  Tests of LexicographicKeyRangeTracker.
-  """
+  """Tests of LexicographicKeyRangeTracker."""
 
   key_to_fraction = (
       range_trackers.LexicographicKeyRangeTracker.position_to_fraction)
   fraction_to_key = (
       range_trackers.LexicographicKeyRangeTracker.fraction_to_position)
 
-  def _check(self, fraction=None, key=None, start=None, end=None, delta=0):
+  def _check(
+      self,
+      fraction: Optional[float] = None,
+      key: Union[bytes, str] = None,
+      start: Union[bytes, str] = None,
+      end: Union[bytes, str] = None,
+      delta: float = 0.0):
     assert key is not None or fraction is not None
     if fraction is None:
       fraction = self.key_to_fraction(key, start, end)
@@ -341,14 +348,42 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
     self._check(key=b'\x07', fraction=7 / 256.)
     self._check(key=b'\xFF', fraction=255 / 256.)
     self._check(key=b'\x01\x02\x03', fraction=(2**16 + 2**9 + 3) / (2.0**24))
+    self._check(key=b'UUUUUUT', fraction=1 / 3)
+    self._check(key=b'3333334', fraction=1 / 5)
+    self._check(key=b'$\x92I$\x92I$', fraction=1 / 7, delta=1e-3)
+    self._check(key=b'\x40', fraction=1 / 4)
 
   def test_key_to_fraction(self):
+    # test no key, no start
+    self._check(end=b'eeeeee', fraction=0.0)
+    self._check(end='eeeeee', fraction=0.0)
+
+    # test no fraction
+    self._check(key=b'bbbbbb', start=b'aaaaaa', end=b'eeeeee')
+    self._check(key='bbbbbb', start='aaaaaa', end='eeeeee')
+
+    # test no start
+    self._check(key=b'eeeeee', end=b'eeeeee', fraction=1.0)
+    self._check(key='eeeeee', end='eeeeee', fraction=1.0)
+    self._check(key=b'\x19YYYYY@', end=b'eeeeee', fraction=0.25)
+    self._check(key=b'2\xb2\xb2\xb2\xb2\xb2\x80', end=b'eeeeee', fraction=0.5)
+    self._check(key=b'L\x0c\x0c\x0c\x0c\x0b\xc0', end=b'eeeeee', fraction=0.75)
+
+    # test bytes keys
     self._check(key=b'\x87', start=b'\x80', fraction=7 / 128.)
     self._check(key=b'\x07', end=b'\x10', fraction=7 / 16.)
     self._check(key=b'\x47', start=b'\x40', end=b'\x80', fraction=7 / 64.)
     self._check(key=b'\x47\x80', start=b'\x40', end=b'\x80', fraction=15 / 128.)
 
+    # test string keys
+    self._check(key='aaaaaa', start='aaaaaa', end='eeeeee', fraction=0.0)
+    self._check(key='bbbbbb', start='aaaaaa', end='eeeeee', fraction=0.25)
+    self._check(key='cccccc', start='aaaaaa', end='eeeeee', fraction=0.5)
+    self._check(key='dddddd', start='aaaaaa', end='eeeeee', fraction=0.75)
+    self._check(key='eeeeee', start='aaaaaa', end='eeeeee', fraction=1.0)
+
   def test_key_to_fraction_common_prefix(self):
+    # test bytes keys
     self._check(
         key=b'a' * 100 + b'b',
         start=b'a' * 100 + b'a',
@@ -370,7 +405,35 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
         end=b'foob\x00\x00\x00\x00\x00\x00\x00\x00\x02',
         fraction=0.5)
 
+    # test string keys
+    self._check(
+        key='a' * 100 + 'a',
+        start='a' * 100 + 'a',
+        end='a' * 100 + 'e',
+        fraction=0.0)
+    self._check(
+        key='a' * 100 + 'b',
+        start='a' * 100 + 'a',
+        end='a' * 100 + 'e',
+        fraction=0.25)
+    self._check(
+        key='a' * 100 + 'c',
+        start='a' * 100 + 'a',
+        end='a' * 100 + 'e',
+        fraction=0.5)
+    self._check(
+        key='a' * 100 + 'd',
+        start='a' * 100 + 'a',
+        end='a' * 100 + 'e',
+        fraction=0.75)
+    self._check(
+        key='a' * 100 + 'e',
+        start='a' * 100 + 'a',
+        end='a' * 100 + 'e',
+        fraction=1.0)
+
   def test_tiny(self):
+    # test bytes keys
     self._check(fraction=.5**20, key=b'\0\0\x10')
     self._check(fraction=.5**20, start=b'a', end=b'b', key=b'a\0\0\x10')
     self._check(fraction=.5**20, start=b'a', end=b'c', key=b'a\0\0\x20')
@@ -386,6 +449,11 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
         delta=1e-15)
     self._check(fraction=.5**100, key=b'\0' * 12 + b'\x10')
 
+    # test string keys
+    self._check(fraction=.5**20, start='a', end='b', key='a\0\0\x10')
+    self._check(fraction=.5**20, start='a', end='c', key='a\0\0\x20')
+    self._check(fraction=.5**20, start='xy_a', end='xy_c', key='xy_a\0\0\x20')
+
   def test_lots(self):
     for fraction in (0, 1, .5, .75, 7. / 512, 1 - 7. / 4096):
       self._check(fraction)
@@ -418,6 +486,12 @@ class LexicographicKeyRangeTrackerTest(unittest.TestCase):
     # (beyond the common prefix of start and end).
     self._check(
         1 / math.e,
+        start='AAAAAAA',
+        end='zzzzzzz',
+        key='VNg/ot\x82',
+        delta=1e-14)
+    self._check(
+        1 / math.e,
         start=b'abc_abc',
         end=b'abc_xyz',
         key=b'abc_i\xe0\xf4\x84\x86\x99\x96',

Reply via email to