boyuanzz commented on a change in pull request #13154:
URL: https://github.com/apache/beam/pull/13154#discussion_r514457389
##########
File path: sdks/python/apache_beam/io/iobase.py
##########
@@ -1427,194 +1432,184 @@ def with_completed(self, completed):
fraction=self._fraction, remaining=self._remaining,
completed=completed)
-class _SDFBoundedSourceWrapper(ptransform.PTransform):
- """A ``PTransform`` that uses SDF to read from a ``BoundedSource``.
+class _SDFBoundedSourceRestriction(object):
+ """ A restriction wraps SourceBundle and RangeTracker. """
+ def __init__(self, source_bundle, range_tracker=None):
+ self._source_bundle = source_bundle
+ self._range_tracker = range_tracker
- NOTE: This transform can only be used with beam_fn_api enabled.
+ def __reduce__(self):
+ # The instance of RangeTracker shouldn't be serialized.
+ return (self.__class__, (self._source_bundle, ))
+
+ def range_tracker(self):
+ if not self._range_tracker:
+ self._range_tracker = self._source_bundle.source.get_range_tracker(
+ self._source_bundle.start_position, self._source_bundle.stop_position)
+ return self._range_tracker
+
+ def weight(self):
+ return self._source_bundle.weight
+
+ def source(self):
+ return self._source_bundle.source
+
+ def try_split(self, fraction_of_remainder):
+ consumed_fraction = self.range_tracker().fraction_consumed()
+ fraction = (
+ consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
+ position = self.range_tracker().position_at_fraction(fraction)
+ # Need to stash current stop_pos before splitting since
+ # range_tracker.split will update its stop_pos if splits
+ # successfully.
+ stop_pos = self._source_bundle.stop_position
+ split_result = self.range_tracker().try_split(position)
+ if split_result:
+ split_pos, split_fraction = split_result
+ primary_weight = self._source_bundle.weight * split_fraction
+ residual_weight = self._source_bundle.weight - primary_weight
+ # Update self to primary weight and end position.
+ self._source_bundle = SourceBundle(
+ primary_weight,
+ self._source_bundle.source,
+ self._source_bundle.start_position,
+ split_pos)
+ return (
+ self,
+ _SDFBoundedSourceRestriction(
+ SourceBundle(
+ residual_weight,
+ self._source_bundle.source,
+ split_pos,
+ stop_pos)))
+
+
+class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
+ """An `iobase.RestrictionTracker` implementations for wrapping BoundedSource
+ with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which
+ wraps SourceBundle and RangeTracker.
+
+ Delegated RangeTracker guarantees synchronization safety.
"""
- class _SDFBoundedSourceRestriction(object):
- """ A restriction wraps SourceBundle and RangeTracker. """
- def __init__(self, source_bundle, range_tracker=None):
- self._source_bundle = source_bundle
- self._range_tracker = range_tracker
-
- def __reduce__(self):
- # The instance of RangeTracker shouldn't be serialized.
- return (self.__class__, (self._source_bundle, ))
-
- def range_tracker(self):
- if not self._range_tracker:
- self._range_tracker = self._source_bundle.source.get_range_tracker(
- self._source_bundle.start_position,
- self._source_bundle.stop_position)
- return self._range_tracker
-
- def weight(self):
- return self._source_bundle.weight
-
- def source(self):
- return self._source_bundle.source
-
- def try_split(self, fraction_of_remainder):
- consumed_fraction = self.range_tracker().fraction_consumed()
- fraction = (
- consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
- position = self.range_tracker().position_at_fraction(fraction)
- # Need to stash current stop_pos before splitting since
- # range_tracker.split will update its stop_pos if splits
- # successfully.
- stop_pos = self._source_bundle.stop_position
- split_result = self.range_tracker().try_split(position)
- if split_result:
- split_pos, split_fraction = split_result
- primary_weight = self._source_bundle.weight * split_fraction
- residual_weight = self._source_bundle.weight - primary_weight
- # Update self to primary weight and end position.
- self._source_bundle = SourceBundle(
- primary_weight,
- self._source_bundle.source,
- self._source_bundle.start_position,
- split_pos)
- return (
- self,
- _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(
- SourceBundle(
- residual_weight,
- self._source_bundle.source,
- split_pos,
- stop_pos)))
-
- class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
- """An `iobase.RestrictionTracker` implementations for wrapping BoundedSource
- with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which
- wraps SourceBundle and RangeTracker.
-
- Delegated RangeTracker guarantees synchronization safety.
- """
- def __init__(self, restriction):
- if not isinstance(restriction,
- _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction):
- raise ValueError(
- 'Initializing SDFBoundedSourceRestrictionTracker'
- ' requires a _SDFBoundedSourceRestriction')
- self.restriction = restriction
-
- def current_progress(self):
- # type: () -> RestrictionProgress
- return RestrictionProgress(
- fraction=self.restriction.range_tracker().fraction_consumed())
-
- def current_restriction(self):
- self.restriction.range_tracker()
- return self.restriction
-
- def start_pos(self):
- return self.restriction.range_tracker().start_position()
-
- def stop_pos(self):
- return self.restriction.range_tracker().stop_position()
-
- def try_claim(self, position):
- return self.restriction.range_tracker().try_claim(position)
-
- def try_split(self, fraction_of_remainder):
- return self.restriction.try_split(fraction_of_remainder)
-
- def check_done(self):
- return self.restriction.range_tracker().fraction_consumed() >= 1.0
-
- def is_bounded(self):
- return True
-
- class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
- """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
- def __init__(self, source, desired_chunk_size=None):
- self._source = source
- self._desired_chunk_size = desired_chunk_size
-
- def initial_restriction(self, element):
- # Get initial range_tracker from source
- range_tracker = self._source.get_range_tracker(None, None)
- return _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(
- SourceBundle(
- None,
- self._source,
- range_tracker.start_position(),
- range_tracker.stop_position()))
-
- def create_tracker(self, restriction):
- return _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
- restriction)
-
- def split(self, element, restriction):
- if self._desired_chunk_size is None:
- try:
- estimated_size = self._source.estimate_size()
- except NotImplementedError:
- estimated_size = None
- self._desired_chunk_size = Read.get_desired_chunk_size(estimated_size)
-
- # Invoke source.split to get initial splitting results.
- source_bundles = self._source.split(self._desired_chunk_size)
- for source_bundle in source_bundles:
- yield _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(
- source_bundle)
-
- def restriction_size(self, element, restriction):
- return restriction.weight()
-
- def restriction_coder(self):
- return coders.DillCoder()
+ def __init__(self, restriction):
+ if not isinstance(restriction, _SDFBoundedSourceRestriction):
+ raise ValueError(
+ 'Initializing SDFBoundedSourceRestrictionTracker'
+ ' requires a _SDFBoundedSourceRestriction')
+ self.restriction = restriction
- def __init__(self, source):
- if not isinstance(source, BoundedSource):
- raise RuntimeError('SDFBoundedSourceWrapper can only wrap BoundedSource')
- super(_SDFBoundedSourceWrapper, self).__init__()
- self.source = source
+ def current_progress(self):
+ # type: () -> RestrictionProgress
+ return RestrictionProgress(
+ fraction=self.restriction.range_tracker().fraction_consumed())
- def _create_sdf_bounded_source_dofn(self):
- source = self.source
+ def current_restriction(self):
+ self.restriction.range_tracker()
+ return self.restriction
- class SDFBoundedSourceDoFn(core.DoFn):
- def __init__(self, read_source):
- self.source = read_source
+ def start_pos(self):
+ return self.restriction.range_tracker().start_position()
+
+ def stop_pos(self):
+ return self.restriction.range_tracker().stop_position()
+
+ def try_claim(self, position):
+ return self.restriction.range_tracker().try_claim(position)
- def display_data(self):
- return {
- 'source': DisplayDataItem(
- self.source.__class__, label='Read Source'),
- 'source_dd': self.source
- }
+ def try_split(self, fraction_of_remainder):
+ return self.restriction.try_split(fraction_of_remainder)
+
+ def check_done(self):
+ return self.restriction.range_tracker().fraction_consumed() >= 1.0
+
+ def is_bounded(self):
+ return True
+
+
+class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
+ """
+ A `RestrictionProvider` that is used by SDF for `BoundedSource`.
+
+ If source is provided, uses it for initializing restriction. Otherwise
+ initializes restriction based on input element that is expected to be of
+ BoundedSource type.
+ """
+ def __init__(self, desired_chunk_size=None):
+ self._desired_chunk_size = desired_chunk_size
+
+ def _check_source(self, src):
+ if src is not None and not isinstance(src, BoundedSource):
Review comment:
The `src` cannot be `None`, right?
##########
File path: sdks/python/apache_beam/io/iobase.py
##########
@@ -1427,194 +1432,184 @@ def with_completed(self, completed):
fraction=self._fraction, remaining=self._remaining,
completed=completed)
-class _SDFBoundedSourceWrapper(ptransform.PTransform):
- """A ``PTransform`` that uses SDF to read from a ``BoundedSource``.
+class _SDFBoundedSourceRestriction(object):
+ """ A restriction wraps SourceBundle and RangeTracker. """
+ def __init__(self, source_bundle, range_tracker=None):
+ self._source_bundle = source_bundle
+ self._range_tracker = range_tracker
- NOTE: This transform can only be used with beam_fn_api enabled.
+ def __reduce__(self):
+ # The instance of RangeTracker shouldn't be serialized.
+ return (self.__class__, (self._source_bundle, ))
+
+ def range_tracker(self):
+ if not self._range_tracker:
+ self._range_tracker = self._source_bundle.source.get_range_tracker(
+ self._source_bundle.start_position, self._source_bundle.stop_position)
+ return self._range_tracker
+
+ def weight(self):
+ return self._source_bundle.weight
+
+ def source(self):
+ return self._source_bundle.source
+
+ def try_split(self, fraction_of_remainder):
+ consumed_fraction = self.range_tracker().fraction_consumed()
+ fraction = (
+ consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
+ position = self.range_tracker().position_at_fraction(fraction)
+ # Need to stash current stop_pos before splitting since
+ # range_tracker.split will update its stop_pos if splits
+ # successfully.
+ stop_pos = self._source_bundle.stop_position
+ split_result = self.range_tracker().try_split(position)
+ if split_result:
+ split_pos, split_fraction = split_result
+ primary_weight = self._source_bundle.weight * split_fraction
+ residual_weight = self._source_bundle.weight - primary_weight
+ # Update self to primary weight and end position.
+ self._source_bundle = SourceBundle(
+ primary_weight,
+ self._source_bundle.source,
+ self._source_bundle.start_position,
+ split_pos)
+ return (
+ self,
+ _SDFBoundedSourceRestriction(
+ SourceBundle(
+ residual_weight,
+ self._source_bundle.source,
+ split_pos,
+ stop_pos)))
+
+
+class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
+ """An `iobase.RestrictionTracker` implementations for wrapping BoundedSource
+ with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which
+ wraps SourceBundle and RangeTracker.
+
+ Delegated RangeTracker guarantees synchronization safety.
"""
- class _SDFBoundedSourceRestriction(object):
- """ A restriction wraps SourceBundle and RangeTracker. """
- def __init__(self, source_bundle, range_tracker=None):
- self._source_bundle = source_bundle
- self._range_tracker = range_tracker
-
- def __reduce__(self):
- # The instance of RangeTracker shouldn't be serialized.
- return (self.__class__, (self._source_bundle, ))
-
- def range_tracker(self):
- if not self._range_tracker:
- self._range_tracker = self._source_bundle.source.get_range_tracker(
- self._source_bundle.start_position,
- self._source_bundle.stop_position)
- return self._range_tracker
-
- def weight(self):
- return self._source_bundle.weight
-
- def source(self):
- return self._source_bundle.source
-
- def try_split(self, fraction_of_remainder):
- consumed_fraction = self.range_tracker().fraction_consumed()
- fraction = (
- consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
- position = self.range_tracker().position_at_fraction(fraction)
- # Need to stash current stop_pos before splitting since
- # range_tracker.split will update its stop_pos if splits
- # successfully.
- stop_pos = self._source_bundle.stop_position
- split_result = self.range_tracker().try_split(position)
- if split_result:
- split_pos, split_fraction = split_result
- primary_weight = self._source_bundle.weight * split_fraction
- residual_weight = self._source_bundle.weight - primary_weight
- # Update self to primary weight and end position.
- self._source_bundle = SourceBundle(
- primary_weight,
- self._source_bundle.source,
- self._source_bundle.start_position,
- split_pos)
- return (
- self,
- _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(
- SourceBundle(
- residual_weight,
- self._source_bundle.source,
- split_pos,
- stop_pos)))
-
- class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
- """An `iobase.RestrictionTracker` implementations for wrapping BoundedSource
- with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which
- wraps SourceBundle and RangeTracker.
-
- Delegated RangeTracker guarantees synchronization safety.
- """
- def __init__(self, restriction):
- if not isinstance(restriction,
- _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction):
- raise ValueError(
- 'Initializing SDFBoundedSourceRestrictionTracker'
- ' requires a _SDFBoundedSourceRestriction')
- self.restriction = restriction
-
- def current_progress(self):
- # type: () -> RestrictionProgress
- return RestrictionProgress(
- fraction=self.restriction.range_tracker().fraction_consumed())
-
- def current_restriction(self):
- self.restriction.range_tracker()
- return self.restriction
-
- def start_pos(self):
- return self.restriction.range_tracker().start_position()
-
- def stop_pos(self):
- return self.restriction.range_tracker().stop_position()
-
- def try_claim(self, position):
- return self.restriction.range_tracker().try_claim(position)
-
- def try_split(self, fraction_of_remainder):
- return self.restriction.try_split(fraction_of_remainder)
-
- def check_done(self):
- return self.restriction.range_tracker().fraction_consumed() >= 1.0
-
- def is_bounded(self):
- return True
-
- class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
- """A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
- def __init__(self, source, desired_chunk_size=None):
- self._source = source
- self._desired_chunk_size = desired_chunk_size
-
- def initial_restriction(self, element):
- # Get initial range_tracker from source
- range_tracker = self._source.get_range_tracker(None, None)
- return _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(
- SourceBundle(
- None,
- self._source,
- range_tracker.start_position(),
- range_tracker.stop_position()))
-
- def create_tracker(self, restriction):
- return _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
- restriction)
-
- def split(self, element, restriction):
- if self._desired_chunk_size is None:
- try:
- estimated_size = self._source.estimate_size()
- except NotImplementedError:
- estimated_size = None
- self._desired_chunk_size = Read.get_desired_chunk_size(estimated_size)
-
- # Invoke source.split to get initial splitting results.
- source_bundles = self._source.split(self._desired_chunk_size)
- for source_bundle in source_bundles:
- yield _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(
- source_bundle)
-
- def restriction_size(self, element, restriction):
- return restriction.weight()
-
- def restriction_coder(self):
- return coders.DillCoder()
+ def __init__(self, restriction):
+ if not isinstance(restriction, _SDFBoundedSourceRestriction):
+ raise ValueError(
+ 'Initializing SDFBoundedSourceRestrictionTracker'
+ ' requires a _SDFBoundedSourceRestriction')
+ self.restriction = restriction
- def __init__(self, source):
- if not isinstance(source, BoundedSource):
- raise RuntimeError('SDFBoundedSourceWrapper can only wrap BoundedSource')
- super(_SDFBoundedSourceWrapper, self).__init__()
- self.source = source
+ def current_progress(self):
+ # type: () -> RestrictionProgress
+ return RestrictionProgress(
+ fraction=self.restriction.range_tracker().fraction_consumed())
- def _create_sdf_bounded_source_dofn(self):
- source = self.source
+ def current_restriction(self):
+ self.restriction.range_tracker()
+ return self.restriction
- class SDFBoundedSourceDoFn(core.DoFn):
- def __init__(self, read_source):
- self.source = read_source
+ def start_pos(self):
+ return self.restriction.range_tracker().start_position()
+
+ def stop_pos(self):
+ return self.restriction.range_tracker().stop_position()
+
+ def try_claim(self, position):
+ return self.restriction.range_tracker().try_claim(position)
- def display_data(self):
- return {
- 'source': DisplayDataItem(
- self.source.__class__, label='Read Source'),
- 'source_dd': self.source
- }
+ def try_split(self, fraction_of_remainder):
+ return self.restriction.try_split(fraction_of_remainder)
+
+ def check_done(self):
+ return self.restriction.range_tracker().fraction_consumed() >= 1.0
+
+ def is_bounded(self):
+ return True
+
+
+class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
+ """
+ A `RestrictionProvider` that is used by SDF for `BoundedSource`.
+
+ If source is provided, uses it for initializing restriction. Otherwise
Review comment:
It seems we need to update the pydoc here as well.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]