y1chi commented on a change in pull request #14460:
URL: https://github.com/apache/beam/pull/14460#discussion_r608821391



##########
File path: .gitignore
##########
@@ -108,3 +108,4 @@ website/www/site/code_samples
 website/www/site/_config_branch_repo.toml
 website/www/yarn-error.log
 !website/www/site/content
+/venv/

Review comment:
       Can this be cleaned up? The `/venv/` entry looks unrelated to this change.

##########
File path: sdks/python/apache_beam/io/mongodbio.py
##########
@@ -66,542 +66,771 @@
 
 # pytype: skip-file
 
+import codecs
 import itertools
 import json
 import logging
 import math
 import struct
+from typing import Union
+
+from bson import json_util
+from bson.objectid import ObjectId
 
 import apache_beam as beam
 from apache_beam.io import iobase
 from apache_beam.io.range_trackers import OrderedPositionRangeTracker
-from apache_beam.transforms import DoFn
-from apache_beam.transforms import PTransform
-from apache_beam.transforms import Reshuffle
+from apache_beam.transforms import DoFn, PTransform, Reshuffle
 from apache_beam.utils.annotations import experimental
 
 _LOGGER = logging.getLogger(__name__)
 
 try:
-  # Mongodb has its own bundled bson, which is not compatible with bson 
pakcage.
-  # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
-  # it fails because bson package is installed, MongoDB IO will not work but at
-  # least rest of the SDK will work.
-  from bson import objectid
-
-  # pymongo also internally depends on bson.
-  from pymongo import ASCENDING
-  from pymongo import DESCENDING
-  from pymongo import MongoClient
-  from pymongo import ReplaceOne
+    # Mongodb has its own bundled bson, which is not compatible with bson 
pakcage.

Review comment:
       Please reformat with yapf; this hunk re-indents the code from 2 to 4 spaces.

##########
File path: sdks/python/apache_beam/io/mongodbio.py
##########
@@ -66,542 +66,771 @@
 
 # pytype: skip-file
 
+import codecs
 import itertools
 import json
 import logging
 import math
 import struct
+from typing import Union
+
+from bson import json_util
+from bson.objectid import ObjectId
 
 import apache_beam as beam
 from apache_beam.io import iobase
 from apache_beam.io.range_trackers import OrderedPositionRangeTracker
-from apache_beam.transforms import DoFn
-from apache_beam.transforms import PTransform
-from apache_beam.transforms import Reshuffle
+from apache_beam.transforms import DoFn, PTransform, Reshuffle
 from apache_beam.utils.annotations import experimental
 
 _LOGGER = logging.getLogger(__name__)
 
 try:
-  # Mongodb has its own bundled bson, which is not compatible with bson 
pakcage.
-  # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
-  # it fails because bson package is installed, MongoDB IO will not work but at
-  # least rest of the SDK will work.
-  from bson import objectid
-
-  # pymongo also internally depends on bson.
-  from pymongo import ASCENDING
-  from pymongo import DESCENDING
-  from pymongo import MongoClient
-  from pymongo import ReplaceOne
+    # Mongodb has its own bundled bson, which is not compatible with bson 
pakcage.
+    # (https://github.com/py-bson/bson/issues/82). Try to import objectid and 
if
+    # it fails because bson package is installed, MongoDB IO will not work but 
at
+    # least rest of the SDK will work.
+    from bson import objectid
+
+    # pymongo also internally depends on bson.
+    from pymongo import ASCENDING
+    from pymongo import DESCENDING
+    from pymongo import MongoClient
+    from pymongo import ReplaceOne
 except ImportError:
-  objectid = None
-  _LOGGER.warning("Could not find a compatible bson package.")
+    objectid = None
+    _LOGGER.warning("Could not find a compatible bson package.")
 
-__all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
+__all__ = ["ReadFromMongoDB", "WriteToMongoDB"]
 
 
 @experimental()
 class ReadFromMongoDB(PTransform):
-  """A ``PTransform`` to read MongoDB documents into a ``PCollection``.
-  """
-  def __init__(
-      self,
-      uri='mongodb://localhost:27017',
-      db=None,
-      coll=None,
-      filter=None,
-      projection=None,
-      extra_client_params=None,
-      bucket_auto=False):
-    """Initialize a :class:`ReadFromMongoDB`
-
-    Args:
-      uri (str): The MongoDB connection string following the URI format.
-      db (str): The MongoDB database name.
-      coll (str): The MongoDB collection name.
-      filter: A `bson.SON
-        <https://api.mongodb.com/python/current/api/bson/son.html>`_ object
-        specifying elements which must be present for a document to be included
-        in the result set.
-      projection: A list of field names that should be returned in the result
-        set or a dict specifying the fields to include or exclude.
-      extra_client_params(dict): Optional `MongoClient
-        
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
-        parameters.
-      bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto` 
aggregation
-        to split collection into bundles instead of `splitVector` command,
-        which does not work with MongoDB Atlas.
-        If :data:`False` (the default), use `splitVector` command for bundling.
-
-    Returns:
-      :class:`~apache_beam.transforms.ptransform.PTransform`
-
-    """
-    if extra_client_params is None:
-      extra_client_params = {}
-    if not isinstance(db, str):
-      raise ValueError('ReadFromMongDB db param must be specified as a string')
-    if not isinstance(coll, str):
-      raise ValueError(
-          'ReadFromMongDB coll param must be specified as a '
-          'string')
-    self._mongo_source = _BoundedMongoSource(
-        uri=uri,
-        db=db,
-        coll=coll,
-        filter=filter,
-        projection=projection,
-        extra_client_params=extra_client_params,
-        bucket_auto=bucket_auto)
-
-  def expand(self, pcoll):
-    return pcoll | iobase.Read(self._mongo_source)
+    """A ``PTransform`` to read MongoDB documents into a ``PCollection``."""
+
+    def __init__(
+        self,
+        uri="mongodb://localhost:27017",
+        db=None,
+        coll=None,
+        filter=None,
+        projection=None,
+        extra_client_params=None,
+        bucket_auto=False,
+    ):
+        """Initialize a :class:`ReadFromMongoDB`
+
+        Args:
+          uri (str): The MongoDB connection string following the URI format.
+          db (str): The MongoDB database name.
+          coll (str): The MongoDB collection name.
+          filter: A `bson.SON
+            <https://api.mongodb.com/python/current/api/bson/son.html>`_ object
+            specifying elements which must be present for a document to be 
included
+            in the result set.
+          projection: A list of field names that should be returned in the 
result
+            set or a dict specifying the fields to include or exclude.
+          extra_client_params(dict): Optional `MongoClient
+            
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
+            parameters.
+          bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto` 
aggregation
+            to split collection into bundles instead of `splitVector` command,
+            which does not work with MongoDB Atlas.
+            If :data:`False` (the default), use `splitVector` command for 
bundling.
+
+        Returns:
+          :class:`~apache_beam.transforms.ptransform.PTransform`
+
+        """
+        if extra_client_params is None:
+            extra_client_params = {}
+        if not isinstance(db, str):
+            raise ValueError("ReadFromMongDB db param must be specified as a 
string")
+        if not isinstance(coll, str):
+            raise ValueError("ReadFromMongDB coll param must be specified as a 
string")
+        self._mongo_source = _BoundedMongoSource(
+            uri=uri,
+            db=db,
+            coll=coll,
+            filter=filter,
+            projection=projection,
+            extra_client_params=extra_client_params,
+            bucket_auto=bucket_auto,
+        )
+
+    def expand(self, pcoll):
+        return pcoll | iobase.Read(self._mongo_source)
 
 
 class _BoundedMongoSource(iobase.BoundedSource):
-  def __init__(
-      self,
-      uri=None,
-      db=None,
-      coll=None,
-      filter=None,
-      projection=None,
-      extra_client_params=None,
-      bucket_auto=False):
-    if extra_client_params is None:
-      extra_client_params = {}
-    if filter is None:
-      filter = {}
-    self.uri = uri
-    self.db = db
-    self.coll = coll
-    self.filter = filter
-    self.projection = projection
-    self.spec = extra_client_params
-    self.bucket_auto = bucket_auto
-
-  def estimate_size(self):
-    with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db].command('collstats', self.coll).get('size')
-
-  def _estimate_average_document_size(self):
-    with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db].command('collstats', self.coll).get('avgObjSize')
-
-  def split(self, desired_bundle_size, start_position=None, 
stop_position=None):
-    desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
-
-    # for desired bundle size, if desired chunk size smaller than 1mb, use
-    # MongoDB default split size of 1mb.
-    if desired_bundle_size_in_mb < 1:
-      desired_bundle_size_in_mb = 1
-
-    is_initial_split = start_position is None and stop_position is None
-    start_position, stop_position = self._replace_none_positions(
-      start_position, stop_position)
-
-    if self.bucket_auto:
-      # Use $bucketAuto for bundling
-      split_keys = []
-      weights = []
-      for bucket in self._get_auto_buckets(desired_bundle_size_in_mb,
-                                           start_position,
-                                           stop_position,
-                                           is_initial_split):
-        split_keys.append({'_id': bucket['_id']['max']})
-        weights.append(bucket['count'])
-    else:
-      # Use splitVector for bundling
-      split_keys = self._get_split_keys(
-          desired_bundle_size_in_mb, start_position, stop_position)
-      weights = itertools.cycle((desired_bundle_size_in_mb, ))
-
-    bundle_start = start_position
-    for split_key_id, weight in zip(split_keys, weights):
-      if bundle_start >= stop_position:
-        break
-      bundle_end = min(stop_position, split_key_id['_id'])
-      yield iobase.SourceBundle(
-          weight=weight,
-          source=self,
-          start_position=bundle_start,
-          stop_position=bundle_end)
-      bundle_start = bundle_end
-    # add range of last split_key to stop_position
-    if bundle_start < stop_position:
-      # bucket_auto mode can come here if not split due to single document
-      weight = 1 if self.bucket_auto else desired_bundle_size_in_mb
-      yield iobase.SourceBundle(
-          weight=weight,
-          source=self,
-          start_position=bundle_start,
-          stop_position=stop_position)
-
-  def get_range_tracker(self, start_position, stop_position):
-    start_position, stop_position = self._replace_none_positions(
-        start_position, stop_position)
-    return _ObjectIdRangeTracker(start_position, stop_position)
-
-  def read(self, range_tracker):
-    with MongoClient(self.uri, **self.spec) as client:
-      all_filters = self._merge_id_filter(
-          range_tracker.start_position(), range_tracker.stop_position())
-      docs_cursor = client[self.db][self.coll].find(
-          filter=all_filters,
-          projection=self.projection).sort([('_id', ASCENDING)])
-      for doc in docs_cursor:
-        if not range_tracker.try_claim(doc['_id']):
-          return
-        yield doc
-
-  def display_data(self):
-    res = super(_BoundedMongoSource, self).display_data()
-    res['database'] = self.db
-    res['collection'] = self.coll
-    res['filter'] = json.dumps(
-        self.filter, default=lambda x: 'not_serializable(%s)' % str(x))
-    res['projection'] = str(self.projection)
-    res['bucket_auto'] = self.bucket_auto
-    return res
-
-  def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
-    # calls mongodb splitVector command to get document ids at split position
-    if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
-      # single document not splittable
-      return []
-    with MongoClient(self.uri, **self.spec) as client:
-      name_space = '%s.%s' % (self.db, self.coll)
-      return (
-          client[self.db].command(
-              'splitVector',
-              name_space,
-              keyPattern={'_id': 1},  # Ascending index
-              min={'_id': start_pos},
-              max={'_id': end_pos},
-              maxChunkSize=desired_chunk_size_in_mb)['splitKeys'])
-
-  def _get_auto_buckets(
-      self, desired_chunk_size_in_mb, start_pos, end_pos, is_initial_split):
-
-    if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
-      # single document not splittable
-      return []
-
-    if is_initial_split and not self.filter:
-      # total collection size
-      size_in_mb = self.estimate_size() / float(1 << 20)
-    else:
-      # size of documents within start/end id range and possibly filtered
-      documents_count = self._count_id_range(start_pos, end_pos)
-      avg_document_size = self._estimate_average_document_size()
-      size_in_mb = documents_count * avg_document_size / float(1 << 20)
-
-    if size_in_mb == 0:
-      # no documents not splittable (maybe a result of filtering)
-      return []
-
-    bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb)
-    with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client:
-      pipeline = [
-          {
-              # filter by positions and by the custom filter if any
-              '$match': self._merge_id_filter(start_pos, end_pos)
-          },
-          {
-              '$bucketAuto': {
-                  'groupBy': '$_id', 'buckets': bucket_count
-              }
-          }
-      ]
-      buckets = list(client[self.db][self.coll].aggregate(pipeline))
-      if buckets:
-        buckets[-1]['_id']['max'] = end_pos
-
-      return buckets
-
-  def _merge_id_filter(self, start_position, stop_position):
-    # Merge the default filter (if any) with refined _id field range
-    # of range_tracker.
-    # $gte specifies start position (inclusive)
-    # and $lt specifies the end position (exclusive),
-    # see more at
-    # https://docs.mongodb.com/manual/reference/operator/query/gte/ and
-    # https://docs.mongodb.com/manual/reference/operator/query/lt/
-    id_filter = {'_id': {'$gte': start_position, '$lt': stop_position}}
-    if self.filter:
-      all_filters = {
-          # see more at
-          # https://docs.mongodb.com/manual/reference/operator/query/and/
-          '$and': [self.filter.copy(), id_filter]
-      }
-    else:
-      all_filters = id_filter
-
-    return all_filters
-
-  def _get_head_document_id(self, sort_order):
-    with MongoClient(self.uri, **self.spec) as client:
-      cursor = client[self.db][self.coll].find(
-          filter={}, projection=[]).sort([('_id', sort_order)]).limit(1)
-      try:
-        return cursor[0]['_id']
-      except IndexError:
-        raise ValueError('Empty Mongodb collection')
-
-  def _replace_none_positions(self, start_position, stop_position):
-    if start_position is None:
-      start_position = self._get_head_document_id(ASCENDING)
-    if stop_position is None:
-      last_doc_id = self._get_head_document_id(DESCENDING)
-      # increment last doc id binary value by 1 to make sure the last document
-      # is not excluded
-      stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
-    return start_position, stop_position
-
-  def _count_id_range(self, start_position, stop_position):
-    # Number of documents between start_position (inclusive)
-    # and stop_position (exclusive), respecting the custom filter if any.
-    with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db][self.coll].count_documents(
-          filter=self._merge_id_filter(start_position, stop_position))
-
-
-class _ObjectIdHelper(object):
-  """A Utility class to manipulate bson object ids."""
-  @classmethod
-  def id_to_int(cls, id):
+    """A MongoDB source that reads a finite amount of input records.
+
+    This class defines following operations which can be used to read the 
source
+    efficiently.
+
+    * Size estimation - method ``estimate_size()`` may return an accurate
+      estimation in bytes for the size of the source.
+    * Splitting into bundles of a given size - method ``split()`` can be used 
to
+      split the source into a set of sub-sources (bundles) based on a desired
+      bundle size.
+    * Getting a MongoDBRangeTracker - method ``get_range_tracker()`` should 
return a
+      ``MongoDBRangeTracker`` object for a given position range for the 
position type
+      of the records returned by the source.
+    * Reading the data - method ``read()`` can be used to read data from the
+      source while respecting the boundaries defined by a given
+      ``MongoDBRangeTracker``.
+
+    A runner will perform reading the source in two steps.
+
+    (1) Method ``get_range_tracker()`` will be invoked with start and end
+        positions to obtain a ``MongoDBRangeTracker`` for the range of 
positions the
+        runner intends to read. Source must define a default initial start and 
end
+        position range. These positions must be used if the start and/or end
+        positions passed to the method ``get_range_tracker()`` are ``None``
+    (2) Method read() will be invoked with the ``MongoDBRangeTracker`` 
obtained in the
+        previous step.
+
+    **Mutability**
+
+    A ``_BoundedMongoSource`` object should not be mutated while
+    its methods (for example, ``read()``) are being invoked by a runner. Runner
+    implementations may invoke methods of ``_BoundedMongoSource`` objects 
through
+    multi-threaded and/or reentrant execution modes.
     """
-    Args:
-      id: ObjectId required for each MongoDB document _id field.
-
-    Returns: Converted integer value of ObjectId's 12 bytes binary value.
-
+    def __init__(
+        self,
+        uri=None,
+        db=None,
+        coll=None,
+        filter=None,
+        projection=None,
+        extra_client_params=None,
+        bucket_auto=False,
+    ):
+        if extra_client_params is None:
+            extra_client_params = {}
+        if filter is None:
+            filter = {}
+        self.uri = uri
+        self.db = db
+        self.coll = coll
+        self.filter = filter
+        self.projection = projection
+        self.spec = extra_client_params
+        self.bucket_auto = bucket_auto
+
+    def estimate_size(self):
+        with MongoClient(self.uri, **self.spec) as client:
+            return client[self.db].command("collstats", self.coll).get("size")
+
+    def _estimate_average_document_size(self):
+        with MongoClient(self.uri, **self.spec) as client:
+            return client[self.db].command("collstats", 
self.coll).get("avgObjSize")
+
+    def split(self, desired_bundle_size, start_position=None, 
stop_position=None):
+        desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
+
+        # for desired bundle size, if desired chunk size smaller than 1mb, use
+        # MongoDB default split size of 1mb.
+        if desired_bundle_size_in_mb < 1:
+            desired_bundle_size_in_mb = 1
+
+        is_initial_split = start_position is None and stop_position is None
+        start_position, stop_position = self._replace_none_positions(
+            start_position, stop_position
+        )
+
+        if self.bucket_auto:
+            # Use $bucketAuto for bundling
+            split_keys = []
+            weights = []
+            for bucket in self._get_auto_buckets(
+                desired_bundle_size_in_mb,
+                start_position,
+                stop_position,
+                is_initial_split,
+            ):
+                split_keys.append({"_id": bucket["_id"]["max"]})
+                weights.append(bucket["count"])
+        else:
+            # Use splitVector for bundling
+            split_keys = self._get_split_keys(
+                desired_bundle_size_in_mb, start_position, stop_position
+            )
+            weights = itertools.cycle((desired_bundle_size_in_mb,))
+
+        bundle_start = start_position
+        for split_key_id, weight in zip(split_keys, weights):
+            if bundle_start >= stop_position:
+                break
+            bundle_end = min(stop_position, split_key_id["_id"])
+            yield iobase.SourceBundle(
+                weight=weight,
+                source=self,
+                start_position=bundle_start,
+                stop_position=bundle_end,
+            )
+            bundle_start = bundle_end
+        # add range of last split_key to stop_position
+        if bundle_start < stop_position:
+            # bucket_auto mode can come here if not split due to single 
document
+            weight = 1 if self.bucket_auto else desired_bundle_size_in_mb
+            yield iobase.SourceBundle(
+                weight=weight,
+                source=self,
+                start_position=bundle_start,
+                stop_position=stop_position,
+            )
+
+    def get_range_tracker(
+        self,
+        start_position: Union[int, str, ObjectId] = None,
+        stop_position: Union[int, str, ObjectId] = None,
+    ):
+        """Returns a MongoDBRangeTracker for a given position range.
+
+        Framework may invoke ``read()`` method with the MongoDBRangeTracker 
object
+        returned here to read data from the source.
+
+        Args:
+          start_position: starting position of the range. If 'None' default 
start
+                          position of the source must be used.
+          stop_position:  ending position of the range. If 'None' default stop
+                          position of the source must be used.
+        Returns:
+          a ``MongoDBRangeTracker`` for the given position range.
+        """
+        start_position, stop_position = self._replace_none_positions(
+            start_position, stop_position
+        )
+        return MongoDBRangeTracker(start_position, stop_position)

Review comment:
       Can you try returning different range trackers for different types? I 
think you may be able to reuse OffsetRangeTracker and 
LexicographicKeyRangeTracker. You'd also need to make sure that the comparison 
algorithm is the same as the one used by MongoDB field sorting.

##########
File path: sdks/python/apache_beam/io/mongodbio.py
##########
@@ -66,542 +66,771 @@
 
 # pytype: skip-file
 
+import codecs
 import itertools
 import json
 import logging
 import math
 import struct
+from typing import Union
+
+from bson import json_util
+from bson.objectid import ObjectId
 
 import apache_beam as beam
 from apache_beam.io import iobase
 from apache_beam.io.range_trackers import OrderedPositionRangeTracker
-from apache_beam.transforms import DoFn
-from apache_beam.transforms import PTransform
-from apache_beam.transforms import Reshuffle
+from apache_beam.transforms import DoFn, PTransform, Reshuffle
 from apache_beam.utils.annotations import experimental
 
 _LOGGER = logging.getLogger(__name__)
 
 try:
-  # Mongodb has its own bundled bson, which is not compatible with bson 
pakcage.
-  # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
-  # it fails because bson package is installed, MongoDB IO will not work but at
-  # least rest of the SDK will work.
-  from bson import objectid
-
-  # pymongo also internally depends on bson.
-  from pymongo import ASCENDING
-  from pymongo import DESCENDING
-  from pymongo import MongoClient
-  from pymongo import ReplaceOne
+    # Mongodb has its own bundled bson, which is not compatible with bson 
pakcage.
+    # (https://github.com/py-bson/bson/issues/82). Try to import objectid and 
if
+    # it fails because bson package is installed, MongoDB IO will not work but 
at
+    # least rest of the SDK will work.
+    from bson import objectid
+
+    # pymongo also internally depends on bson.
+    from pymongo import ASCENDING
+    from pymongo import DESCENDING
+    from pymongo import MongoClient
+    from pymongo import ReplaceOne
 except ImportError:
-  objectid = None
-  _LOGGER.warning("Could not find a compatible bson package.")
+    objectid = None
+    _LOGGER.warning("Could not find a compatible bson package.")
 
-__all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
+__all__ = ["ReadFromMongoDB", "WriteToMongoDB"]
 
 
 @experimental()
 class ReadFromMongoDB(PTransform):
-  """A ``PTransform`` to read MongoDB documents into a ``PCollection``.
-  """
-  def __init__(
-      self,
-      uri='mongodb://localhost:27017',
-      db=None,
-      coll=None,
-      filter=None,
-      projection=None,
-      extra_client_params=None,
-      bucket_auto=False):
-    """Initialize a :class:`ReadFromMongoDB`
-
-    Args:
-      uri (str): The MongoDB connection string following the URI format.
-      db (str): The MongoDB database name.
-      coll (str): The MongoDB collection name.
-      filter: A `bson.SON
-        <https://api.mongodb.com/python/current/api/bson/son.html>`_ object
-        specifying elements which must be present for a document to be included
-        in the result set.
-      projection: A list of field names that should be returned in the result
-        set or a dict specifying the fields to include or exclude.
-      extra_client_params(dict): Optional `MongoClient
-        
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
-        parameters.
-      bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto` 
aggregation
-        to split collection into bundles instead of `splitVector` command,
-        which does not work with MongoDB Atlas.
-        If :data:`False` (the default), use `splitVector` command for bundling.
-
-    Returns:
-      :class:`~apache_beam.transforms.ptransform.PTransform`
-
-    """
-    if extra_client_params is None:
-      extra_client_params = {}
-    if not isinstance(db, str):
-      raise ValueError('ReadFromMongDB db param must be specified as a string')
-    if not isinstance(coll, str):
-      raise ValueError(
-          'ReadFromMongDB coll param must be specified as a '
-          'string')
-    self._mongo_source = _BoundedMongoSource(
-        uri=uri,
-        db=db,
-        coll=coll,
-        filter=filter,
-        projection=projection,
-        extra_client_params=extra_client_params,
-        bucket_auto=bucket_auto)
-
-  def expand(self, pcoll):
-    return pcoll | iobase.Read(self._mongo_source)
+    """A ``PTransform`` to read MongoDB documents into a ``PCollection``."""
+
+    def __init__(
+        self,
+        uri="mongodb://localhost:27017",
+        db=None,
+        coll=None,
+        filter=None,
+        projection=None,
+        extra_client_params=None,
+        bucket_auto=False,
+    ):
+        """Initialize a :class:`ReadFromMongoDB`
+
+        Args:
+          uri (str): The MongoDB connection string following the URI format.
+          db (str): The MongoDB database name.
+          coll (str): The MongoDB collection name.
+          filter: A `bson.SON
+            <https://api.mongodb.com/python/current/api/bson/son.html>`_ object
+            specifying elements which must be present for a document to be 
included
+            in the result set.
+          projection: A list of field names that should be returned in the 
result
+            set or a dict specifying the fields to include or exclude.
+          extra_client_params(dict): Optional `MongoClient
+            
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
+            parameters.
+          bucket_auto (bool): If :data:`True`, use MongoDB `$bucketAuto` 
aggregation
+            to split collection into bundles instead of `splitVector` command,
+            which does not work with MongoDB Atlas.
+            If :data:`False` (the default), use `splitVector` command for 
bundling.
+
+        Returns:
+          :class:`~apache_beam.transforms.ptransform.PTransform`
+
+        """
+        if extra_client_params is None:
+            extra_client_params = {}
+        if not isinstance(db, str):
+            raise ValueError("ReadFromMongDB db param must be specified as a 
string")
+        if not isinstance(coll, str):
+            raise ValueError("ReadFromMongDB coll param must be specified as a 
string")
+        self._mongo_source = _BoundedMongoSource(
+            uri=uri,
+            db=db,
+            coll=coll,
+            filter=filter,
+            projection=projection,
+            extra_client_params=extra_client_params,
+            bucket_auto=bucket_auto,
+        )
+
+    def expand(self, pcoll):
+        return pcoll | iobase.Read(self._mongo_source)
 
 
 class _BoundedMongoSource(iobase.BoundedSource):
-  def __init__(
-      self,
-      uri=None,
-      db=None,
-      coll=None,
-      filter=None,
-      projection=None,
-      extra_client_params=None,
-      bucket_auto=False):
-    if extra_client_params is None:
-      extra_client_params = {}
-    if filter is None:
-      filter = {}
-    self.uri = uri
-    self.db = db
-    self.coll = coll
-    self.filter = filter
-    self.projection = projection
-    self.spec = extra_client_params
-    self.bucket_auto = bucket_auto
-
-  def estimate_size(self):
-    with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db].command('collstats', self.coll).get('size')
-
-  def _estimate_average_document_size(self):
-    with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db].command('collstats', self.coll).get('avgObjSize')
-
-  def split(self, desired_bundle_size, start_position=None, 
stop_position=None):
-    desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
-
-    # for desired bundle size, if desired chunk size smaller than 1mb, use
-    # MongoDB default split size of 1mb.
-    if desired_bundle_size_in_mb < 1:
-      desired_bundle_size_in_mb = 1
-
-    is_initial_split = start_position is None and stop_position is None
-    start_position, stop_position = self._replace_none_positions(
-      start_position, stop_position)
-
-    if self.bucket_auto:
-      # Use $bucketAuto for bundling
-      split_keys = []
-      weights = []
-      for bucket in self._get_auto_buckets(desired_bundle_size_in_mb,
-                                           start_position,
-                                           stop_position,
-                                           is_initial_split):
-        split_keys.append({'_id': bucket['_id']['max']})
-        weights.append(bucket['count'])
-    else:
-      # Use splitVector for bundling
-      split_keys = self._get_split_keys(
-          desired_bundle_size_in_mb, start_position, stop_position)
-      weights = itertools.cycle((desired_bundle_size_in_mb, ))
-
-    bundle_start = start_position
-    for split_key_id, weight in zip(split_keys, weights):
-      if bundle_start >= stop_position:
-        break
-      bundle_end = min(stop_position, split_key_id['_id'])
-      yield iobase.SourceBundle(
-          weight=weight,
-          source=self,
-          start_position=bundle_start,
-          stop_position=bundle_end)
-      bundle_start = bundle_end
-    # add range of last split_key to stop_position
-    if bundle_start < stop_position:
-      # bucket_auto mode can come here if not split due to single document
-      weight = 1 if self.bucket_auto else desired_bundle_size_in_mb
-      yield iobase.SourceBundle(
-          weight=weight,
-          source=self,
-          start_position=bundle_start,
-          stop_position=stop_position)
-
-  def get_range_tracker(self, start_position, stop_position):
-    start_position, stop_position = self._replace_none_positions(
-        start_position, stop_position)
-    return _ObjectIdRangeTracker(start_position, stop_position)
-
-  def read(self, range_tracker):
-    with MongoClient(self.uri, **self.spec) as client:
-      all_filters = self._merge_id_filter(
-          range_tracker.start_position(), range_tracker.stop_position())
-      docs_cursor = client[self.db][self.coll].find(
-          filter=all_filters,
-          projection=self.projection).sort([('_id', ASCENDING)])
-      for doc in docs_cursor:
-        if not range_tracker.try_claim(doc['_id']):
-          return
-        yield doc
-
-  def display_data(self):
-    res = super(_BoundedMongoSource, self).display_data()
-    res['database'] = self.db
-    res['collection'] = self.coll
-    res['filter'] = json.dumps(
-        self.filter, default=lambda x: 'not_serializable(%s)' % str(x))
-    res['projection'] = str(self.projection)
-    res['bucket_auto'] = self.bucket_auto
-    return res
-
-  def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
-    # calls mongodb splitVector command to get document ids at split position
-    if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
-      # single document not splittable
-      return []
-    with MongoClient(self.uri, **self.spec) as client:
-      name_space = '%s.%s' % (self.db, self.coll)
-      return (
-          client[self.db].command(
-              'splitVector',
-              name_space,
-              keyPattern={'_id': 1},  # Ascending index
-              min={'_id': start_pos},
-              max={'_id': end_pos},
-              maxChunkSize=desired_chunk_size_in_mb)['splitKeys'])
-
-  def _get_auto_buckets(
-      self, desired_chunk_size_in_mb, start_pos, end_pos, is_initial_split):
-
-    if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
-      # single document not splittable
-      return []
-
-    if is_initial_split and not self.filter:
-      # total collection size
-      size_in_mb = self.estimate_size() / float(1 << 20)
-    else:
-      # size of documents within start/end id range and possibly filtered
-      documents_count = self._count_id_range(start_pos, end_pos)
-      avg_document_size = self._estimate_average_document_size()
-      size_in_mb = documents_count * avg_document_size / float(1 << 20)
-
-    if size_in_mb == 0:
-      # no documents not splittable (maybe a result of filtering)
-      return []
-
-    bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb)
-    with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client:
-      pipeline = [
-          {
-              # filter by positions and by the custom filter if any
-              '$match': self._merge_id_filter(start_pos, end_pos)
-          },
-          {
-              '$bucketAuto': {
-                  'groupBy': '$_id', 'buckets': bucket_count
-              }
-          }
-      ]
-      buckets = list(client[self.db][self.coll].aggregate(pipeline))
-      if buckets:
-        buckets[-1]['_id']['max'] = end_pos
-
-      return buckets
-
-  def _merge_id_filter(self, start_position, stop_position):
-    # Merge the default filter (if any) with refined _id field range
-    # of range_tracker.
-    # $gte specifies start position (inclusive)
-    # and $lt specifies the end position (exclusive),
-    # see more at
-    # https://docs.mongodb.com/manual/reference/operator/query/gte/ and
-    # https://docs.mongodb.com/manual/reference/operator/query/lt/
-    id_filter = {'_id': {'$gte': start_position, '$lt': stop_position}}
-    if self.filter:
-      all_filters = {
-          # see more at
-          # https://docs.mongodb.com/manual/reference/operator/query/and/
-          '$and': [self.filter.copy(), id_filter]
-      }
-    else:
-      all_filters = id_filter
-
-    return all_filters
-
-  def _get_head_document_id(self, sort_order):
-    with MongoClient(self.uri, **self.spec) as client:
-      cursor = client[self.db][self.coll].find(
-          filter={}, projection=[]).sort([('_id', sort_order)]).limit(1)
-      try:
-        return cursor[0]['_id']
-      except IndexError:
-        raise ValueError('Empty Mongodb collection')
-
-  def _replace_none_positions(self, start_position, stop_position):
-    if start_position is None:
-      start_position = self._get_head_document_id(ASCENDING)
-    if stop_position is None:
-      last_doc_id = self._get_head_document_id(DESCENDING)
-      # increment last doc id binary value by 1 to make sure the last document
-      # is not excluded
-      stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
-    return start_position, stop_position
-
-  def _count_id_range(self, start_position, stop_position):
-    # Number of documents between start_position (inclusive)
-    # and stop_position (exclusive), respecting the custom filter if any.
-    with MongoClient(self.uri, **self.spec) as client:
-      return client[self.db][self.coll].count_documents(
-          filter=self._merge_id_filter(start_position, stop_position))
-
-
-class _ObjectIdHelper(object):
-  """A Utility class to manipulate bson object ids."""
-  @classmethod
-  def id_to_int(cls, id):
+    """A MongoDB source that reads a finite amount of input records.
+
+    This class defines following operations which can be used to read the 
source
+    efficiently.
+
+    * Size estimation - method ``estimate_size()`` may return an accurate
+      estimation in bytes for the size of the source.
+    * Splitting into bundles of a given size - method ``split()`` can be used 
to
+      split the source into a set of sub-sources (bundles) based on a desired
+      bundle size.
+    * Getting a MongoDBRangeTracker - method ``get_range_tracker()`` should 
return a
+      ``MongoDBRangeTracker`` object for a given position range for the 
position type
+      of the records returned by the source.
+    * Reading the data - method ``read()`` can be used to read data from the
+      source while respecting the boundaries defined by a given
+      ``MongoDBRangeTracker``.
+
+    A runner will perform reading the source in two steps.
+
+    (1) Method ``get_range_tracker()`` will be invoked with start and end
+        positions to obtain a ``MongoDBRangeTracker`` for the range of 
positions the
+        runner intends to read. Source must define a default initial start and 
end
+        position range. These positions must be used if the start and/or end
+        positions passed to the method ``get_range_tracker()`` are ``None``
+    (2) Method read() will be invoked with the ``MongoDBRangeTracker`` 
obtained in the
+        previous step.
+
+    **Mutability**
+
+    A ``_BoundedMongoSource`` object should not be mutated while
+    its methods (for example, ``read()``) are being invoked by a runner. Runner
+    implementations may invoke methods of ``_BoundedMongoSource`` objects 
through
+    multi-threaded and/or reentrant execution modes.
     """
-    Args:
-      id: ObjectId required for each MongoDB document _id field.
-
-    Returns: Converted integer value of ObjectId's 12 bytes binary value.
-
+    def __init__(
+        self,
+        uri=None,
+        db=None,
+        coll=None,
+        filter=None,
+        projection=None,
+        extra_client_params=None,
+        bucket_auto=False,
+    ):
+        if extra_client_params is None:
+            extra_client_params = {}
+        if filter is None:
+            filter = {}
+        self.uri = uri
+        self.db = db
+        self.coll = coll
+        self.filter = filter
+        self.projection = projection
+        self.spec = extra_client_params
+        self.bucket_auto = bucket_auto
+
+    def estimate_size(self):
+        with MongoClient(self.uri, **self.spec) as client:
+            return client[self.db].command("collstats", self.coll).get("size")
+
+    def _estimate_average_document_size(self):
+        with MongoClient(self.uri, **self.spec) as client:
+            return client[self.db].command("collstats", 
self.coll).get("avgObjSize")
+
+    def split(self, desired_bundle_size, start_position=None, 
stop_position=None):
+        desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
+
+        # for desired bundle size, if desired chunk size smaller than 1mb, use
+        # MongoDB default split size of 1mb.
+        if desired_bundle_size_in_mb < 1:
+            desired_bundle_size_in_mb = 1
+
+        is_initial_split = start_position is None and stop_position is None
+        start_position, stop_position = self._replace_none_positions(
+            start_position, stop_position
+        )
+
+        if self.bucket_auto:
+            # Use $bucketAuto for bundling
+            split_keys = []
+            weights = []
+            for bucket in self._get_auto_buckets(
+                desired_bundle_size_in_mb,
+                start_position,
+                stop_position,
+                is_initial_split,
+            ):
+                split_keys.append({"_id": bucket["_id"]["max"]})
+                weights.append(bucket["count"])
+        else:
+            # Use splitVector for bundling
+            split_keys = self._get_split_keys(
+                desired_bundle_size_in_mb, start_position, stop_position
+            )
+            weights = itertools.cycle((desired_bundle_size_in_mb,))
+
+        bundle_start = start_position
+        for split_key_id, weight in zip(split_keys, weights):
+            if bundle_start >= stop_position:
+                break
+            bundle_end = min(stop_position, split_key_id["_id"])
+            yield iobase.SourceBundle(
+                weight=weight,
+                source=self,
+                start_position=bundle_start,
+                stop_position=bundle_end,
+            )
+            bundle_start = bundle_end
+        # add range of last split_key to stop_position
+        if bundle_start < stop_position:
+            # bucket_auto mode can come here if not split due to single 
document
+            weight = 1 if self.bucket_auto else desired_bundle_size_in_mb
+            yield iobase.SourceBundle(
+                weight=weight,
+                source=self,
+                start_position=bundle_start,
+                stop_position=stop_position,
+            )
+
+    def get_range_tracker(
+        self,
+        start_position: Union[int, str, ObjectId] = None,
+        stop_position: Union[int, str, ObjectId] = None,
+    ):
+        """Returns a MongoDBRangeTracker for a given position range.
+
+        Framework may invoke ``read()`` method with the MongoDBRangeTracker 
object
+        returned here to read data from the source.
+
+        Args:
+          start_position: starting position of the range. If 'None' default 
start
+                          position of the source must be used.
+          stop_position:  ending position of the range. If 'None' default stop
+                          position of the source must be used.
+        Returns:
+          a ``MongoDBRangeTracker`` for the given position range.
+        """
+        start_position, stop_position = self._replace_none_positions(
+            start_position, stop_position
+        )
+        return MongoDBRangeTracker(start_position, stop_position)
+
+    def read(self, range_tracker):
+        with MongoClient(self.uri, **self.spec) as client:
+            all_filters = self._merge_id_filter(
+                range_tracker.start_position(), range_tracker.stop_position()
+            )
+            docs_cursor = (
+                client[self.db][self.coll]
+                .find(filter=all_filters, projection=self.projection)
+                .sort([("_id", ASCENDING)])
+            )
+            for doc in docs_cursor:
+                if not range_tracker.try_claim(doc["_id"]):
+                    return
+                yield doc
+
+    def display_data(self):
+        res = super().display_data()
+        res["database"] = self.db
+        res["collection"] = self.coll
+        res["filter"] = json.dumps(self.filter, default=json_util.default)
+        res["projection"] = str(self.projection)
+        res["bucket_auto"] = self.bucket_auto
+        return res
+
+    def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
+        # calls mongodb splitVector command to get document ids at split 
position
+        if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
+            # single document not splittable
+            return []
+        with MongoClient(self.uri, **self.spec) as client:
+            name_space = "%s.%s" % (self.db, self.coll)
+            return client[self.db].command(
+                "splitVector",
+                name_space,
+                keyPattern={"_id": 1},  # Ascending index
+                min={"_id": start_pos},
+                max={"_id": end_pos},
+                maxChunkSize=desired_chunk_size_in_mb,
+            )["splitKeys"]
+
+    def _get_auto_buckets(
+        self,
+        desired_chunk_size_in_mb,
+        start_pos: ObjectId,
+        end_pos: ObjectId,
+        is_initial_split: bool,
+    ) -> list:
+        if start_pos >= _ObjectIdHelper.increment_id(end_pos, -1):
+            # single document not splittable
+            return []
+
+        if is_initial_split and not self.filter:
+            # total collection size in MB
+            size_in_mb = self.estimate_size() / float(1 << 20)
+        else:
+            # size of documents within start/end id range and possibly filtered
+            documents_count = self._count_id_range(start_pos, end_pos)
+            avg_document_size = self._estimate_average_document_size()
+            size_in_mb = documents_count * avg_document_size / float(1 << 20)
+
+        if size_in_mb == 0:
+            # no documents not splittable (maybe a result of filtering)
+            return []
+
+        bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb)
+        with beam.io.mongodbio.MongoClient(self.uri, **self.spec) as client:
+            pipeline = [
+                {
+                    # filter by positions and by the custom filter if any
+                    "$match": self._merge_id_filter(start_pos, end_pos)
+                },
+                {"$bucketAuto": {"groupBy": "$_id", "buckets": bucket_count}},
+            ]
+            buckets = list(
+                # Use `allowDiskUse` option to avoid aggregation limit of 100 
Mb RAM
+                client[self.db][self.coll].aggregate(pipeline, 
allowDiskUse=True)
+            )
+            if buckets:
+                buckets[-1]["_id"]["max"] = end_pos
+
+            return buckets
+
+    def _merge_id_filter(
+        self,
+        start_position: Union[int, str, ObjectId],
+        stop_position: Union[int, str, ObjectId] = None,
+    ) -> dict:
+        """Merge the default filter (if any) with refined _id field range
+        of range_tracker.
+        $gte specifies start position (inclusive)
+        and $lt specifies the end position (exclusive),
+        see more at
+        https://docs.mongodb.com/manual/reference/operator/query/gte/ and
+        https://docs.mongodb.com/manual/reference/operator/query/lt/
+        """
+        if stop_position is None:
+            id_filter = {"_id": {"$gte": start_position}}
+        else:
+            id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}}
+
+        if self.filter:
+            all_filters = {
+                # see more at
+                # https://docs.mongodb.com/manual/reference/operator/query/and/
+                "$and": [self.filter.copy(), id_filter]
+            }
+        else:
+            all_filters = id_filter
+
+        return all_filters
+
+    def _get_head_document_id(self, sort_order):
+        with MongoClient(self.uri, **self.spec) as client:
+            cursor = (
+                client[self.db][self.coll]
+                .find(filter={}, projection=[])
+                .sort([("_id", sort_order)])
+                .limit(1)
+            )
+            try:
+                return cursor[0]["_id"]
+
+            except IndexError:
+                raise ValueError("Empty Mongodb collection")
+
+    def _replace_none_positions(self, start_position, stop_position):
+
+        if start_position is None:
+            start_position = self._get_head_document_id(ASCENDING)
+        if stop_position is None:
+            last_doc_id = self._get_head_document_id(DESCENDING)
+            # increment last doc id binary value by 1 to make sure the last 
document
+            # is not excluded
+            stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
+
+        return start_position, stop_position
+
+    def _count_id_range(self, start_position, stop_position):
+        """Number of documents between start_position (inclusive)
+        and stop_position (exclusive), respecting the custom filter if any.
+        """
+        with MongoClient(self.uri, **self.spec) as client:
+            return client[self.db][self.coll].count_documents(
+                filter=self._merge_id_filter(start_position, stop_position)
+            )
+
+
+class MongoDBRangeTracker(OrderedPositionRangeTracker):
+    """RangeTracker for tracking MongoDB `_id` of following types:
+    - int
+    - bytes

Review comment:
       Can you add unit tests and integration tests accordingly to make sure they 
work as intended?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to