This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 75875de Make SDFBoundedSource wrapper work with dynamic splitting (#8944)
75875de is described below
commit 75875def7098dec8fcab89941ee398a34fbf5fb1
Author: Boyuan Zhang <[email protected]>
AuthorDate: Tue Jul 23 04:47:45 2019 -0700
Make SDFBoundedSource wrapper work with dynamic splitting (#8944)
---
sdks/python/apache_beam/io/iobase.py | 57 ++++++++++++++++++++--------
sdks/python/apache_beam/io/iobase_test.py | 30 +++++++++------
sdks/python/apache_beam/io/range_trackers.py | 19 ++++------
3 files changed, 68 insertions(+), 38 deletions(-)
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index bb7c03c..6763c57 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -1346,15 +1346,27 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
Delegated RangeTracker guarantees synchronization safety.
"""
- def __init__(self, range_tracker):
- if not isinstance(range_tracker, RangeTracker):
+ def __init__(self, restriction):
+ if not isinstance(restriction, SourceBundle):
raise ValueError('Initializing SDFBoundedSourceRestrictionTracker'
- 'requires a RangeTracker')
- self._delegate_range_tracker = range_tracker
+ 'requires a SourceBundle')
+ self._delegate_range_tracker = restriction.source.get_range_tracker(
+ restriction.start_position, restriction.stop_position)
+ self._source = restriction.source
+ self._weight = restriction.weight
+
+ def current_progress(self):
+ return RestrictionProgress(
+ fraction=self._delegate_range_tracker.fraction_consumed())
def current_restriction(self):
- return (self._delegate_range_tracker.start_position(),
- self._delegate_range_tracker.stop_position())
+ start_pos = self._delegate_range_tracker.start_position()
+ stop_pos = self._delegate_range_tracker.stop_position()
+ return SourceBundle(
+ self._weight,
+ self._source,
+ start_pos,
+ stop_pos)
def start_pos(self):
return self._delegate_range_tracker.start_position()
@@ -1373,15 +1385,32 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
# Need to stash current stop_pos before splitting since
# range_tracker.split will update its stop_pos if splits
# successfully.
+ start_pos = self.start_pos()
stop_pos = self.stop_pos()
- split_pos, _ = self._delegate_range_tracker.try_split(position)
- if split_pos:
- return ((self._delegate_range_tracker.start_position(), split_pos),
- (split_pos, stop_pos))
+ split_result = self._delegate_range_tracker.try_split(position)
+ if split_result:
+ split_pos, split_fraction = split_result
+ primary_weight = self._weight * split_fraction
+ residual_weight = self._weight - primary_weight
+ # Update self._weight to primary weight
+ self._weight = primary_weight
+ return (SourceBundle(primary_weight, self._source, start_pos,
+ split_pos),
+ SourceBundle(residual_weight, self._source, split_pos,
+ stop_pos))
def deferred_status(self):
return None
+ def current_watermark(self):
+ return None
+
+ def get_delegate_range_tracker(self):
+ return self._delegate_range_tracker
+
+ def get_tracking_source(self):
+ return self._source
+
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
"""A `RestrictionProvider` that is used by SDF for `BoundedSource`."""
@@ -1399,8 +1428,7 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
def create_tracker(self, restriction):
return _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
- restriction.source.get_range_tracker(restriction.start_position,
- restriction.stop_position))
+ restriction)
def split(self, element, restriction):
# Invoke source.split to get initial splitting results.
@@ -1431,9 +1459,8 @@ class _SDFBoundedSourceWrapper(ptransform.PTransform):
restriction_tracker=core.DoFn.RestrictionParam(
_SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
source, chunk_size))):
- start_pos, end_pos = restriction_tracker.current_restriction()
- range_tracker = self.source.get_range_tracker(start_pos, end_pos)
- return self.source.read(range_tracker)
+ return restriction_tracker.get_tracking_source().read(
+ restriction_tracker.get_delegate_range_tracker())
return SDFBoundedSourceDoFn(self.source)
diff --git a/sdks/python/apache_beam/io/iobase_test.py b/sdks/python/apache_beam/io/iobase_test.py
index 65fc89c..c7d1656 100644
--- a/sdks/python/apache_beam/io/iobase_test.py
+++ b/sdks/python/apache_beam/io/iobase_test.py
@@ -25,7 +25,6 @@ from apache_beam.io.concat_source import ConcatSource
from apache_beam.io.concat_source_test import RangeSource
from apache_beam.io import iobase
from apache_beam.io.iobase import SourceBundle
-from apache_beam.io.range_trackers import OffsetRangeTracker
class SDFBoundedSourceRestrictionProviderTest(unittest.TestCase):
@@ -115,14 +114,17 @@ class SDFBoundedSourceRestrictionTrackerTest(unittest.TestCase):
def setUp(self):
self.initial_start_pos = 0
self.initial_stop_pos = 4
- self.range_tracker = OffsetRangeTracker(self.initial_start_pos,
- self.initial_stop_pos)
+ source_bundle = SourceBundle(
+ self.initial_stop_pos - self.initial_start_pos,
+ RangeSource(self.initial_start_pos, self.initial_stop_pos),
+ self.initial_start_pos,
+ self.initial_stop_pos)
self.sdf_restriction_tracker = (
iobase._SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
- self.range_tracker))
+ source_bundle))
def test_current_restriction_before_split(self):
- actual_start, actual_stop = (
+ _, _, actual_start, actual_stop = (
self.sdf_restriction_tracker.current_restriction())
self.assertEqual(self.initial_start_pos, actual_start)
self.assertEqual(self.initial_stop_pos, actual_stop)
@@ -136,14 +138,20 @@ class SDFBoundedSourceRestrictionTrackerTest(unittest.TestCase):
self.sdf_restriction_tracker.current_restriction())
def test_try_split_at_remainder(self):
- fraction_of_remainder = 0.5
- expected_primary = (0, 3)
- expected_residual = (3, 4)
- self.sdf_restriction_tracker.try_claim(1)
+ fraction_of_remainder = 0.4
+ expected_primary = (0, 2, 2.0)
+ expected_residual = (2, 4, 2.0)
+ self.sdf_restriction_tracker.try_claim(0)
actual_primary, actual_residual = (
self.sdf_restriction_tracker.try_split(fraction_of_remainder))
- self.assertEqual(expected_primary, actual_primary)
- self.assertEqual(expected_residual, actual_residual)
+ self.assertEqual(expected_primary, (actual_primary.start_position,
+ actual_primary.stop_position,
+ actual_primary.weight))
+ self.assertEqual(expected_residual, (actual_residual.start_position,
+ actual_residual.stop_position,
+ actual_residual.weight))
+ self.assertEqual(actual_primary.weight,
+ self.sdf_restriction_tracker._weight)
if __name__ == '__main__':
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index 5bf4898..c46f801 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -167,17 +167,19 @@ class OffsetRangeTracker(iobase.RangeTracker):
def fraction_consumed(self):
with self._lock:
- fraction = ((1.0 * (self._last_record_start - self.start_position()) /
- (self.stop_position() - self.start_position())) if
- self.stop_position() != self.start_position() else 0.0)
-
# self.last_record_start may become larger than self.end_offset when
# reading the records since any record that starts before the first 'split
# point' at or after the defined 'stop offset' is considered to be within
# the range of the OffsetRangeTracker. Hence fraction could be > 1.
# self.last_record_start is initialized to -1, hence fraction may be < 0.
# Bounding the to range [0, 1].
- return max(0.0, min(1.0, fraction))
+ return self.position_to_fraction(self._last_record_start,
+ self.start_position(),
+ self.stop_position())
+
+ def position_to_fraction(self, pos, start, stop):
+ fraction = 1.0 * (pos - start) / (stop - start) if start != stop else 0.0
+ return max(0.0, min(1.0, fraction))
def position_at_fraction(self, fraction):
if self.stop_position() == OffsetRangeTracker.OFFSET_INFINITY:
@@ -271,13 +273,6 @@ class OrderedPositionRangeTracker(iobase.RangeTracker):
return self.position_to_fraction(
self._last_claim, self._start_position, self._stop_position)
- def position_to_fraction(self, pos, start, end):
- """
- Converts a position `pos` betweeen `start` and `end` (inclusive) to a
- fraction between 0 and 1.
- """
- raise NotImplementedError
-
def fraction_to_position(self, fraction, start, end):
"""
Converts a fraction between 0 and 1 to a position between start and end.