[
https://issues.apache.org/jira/browse/BEAM-6120?focusedWorklogId=173521&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-173521
]
ASF GitHub Bot logged work on BEAM-6120:
----------------------------------------
Author: ASF GitHub Bot
Created on: 10/Dec/18 11:32
Start Date: 10/Dec/18 11:32
Worklog Time Spent: 10m
Work Description: robertwb closed pull request #7127: [BEAM-6120] Support
retrieval of large gbk iterables over the state API.
URL: https://github.com/apache/beam/pull/7127
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index bbe2cfa91305..4a254ebdb324 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -549,6 +549,22 @@ message StandardCoders {
VARINT = 2 [(beam_urn) = "beam:coder:varint:v1"];
// Encodes an iterable of elements.
+ //
+ // The encoding for an iterable [e1...eN] of known length N is
+ //
+ // fixed32(N)
+ // encode(e1) encode(e2) encode(e3) ... encode(eN)
+ //
+ // If the length is unknown, it is batched up into groups of size b1..bM
+ // and encoded as
+ //
+ // fixed32(-1)
+ // varInt64(b1) encode(e1) encode(e2) ... encode(e_b1)
+ // varInt64(b2) encode(e_(b1+1)) encode(e_(b1+2)) ... encode(e_(b1+b2))
+ // ...
+ // varInt64(bM) encode(e_(N-bM+1)) encode(e_(N-bM+2)) ... encode(eN)
+ // varInt64(0)
+ //
// Components: Coder for a single element.
ITERABLE = 3 [(beam_urn) = "beam:coder:iterable:v1"];
@@ -588,6 +604,23 @@ message StandardCoders {
// of the element
// Components: The element coder and the window coder, in that order
WINDOWED_VALUE = 8 [(beam_urn) = "beam:coder:windowed_value:v1"];
+
+ // Encodes an iterable of elements, some of which may be stored elsewhere.
+ //
+ // The encoding for a state-backed iterable is the same as that for
+ // an iterable, but the final varInt64(0) terminating the set of batches
+ // may instead be replaced by
+ //
+ // varInt64(-1)
+ // varInt64(len(token))
+ // token
+ //
+ // where token is an opaque byte string that can be used to fetch the
+ // remainder of the iterable (e.g. over the state API).
+ //
+ // Components: Coder for a single element.
+ // Experimental.
+ STATE_BACKED_ITERABLE = 9 [(beam_urn) = "beam:coder:state_backed_iterable:v1"];
}
}
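For reference, a minimal standalone sketch of the ITERABLE wire format documented in the comment above (pure Python; the helper names are illustrative, not Beam's API):

    import struct

    def write_var_int64(out, v):
        # Beam-style varInt64: unsigned LEB128 over the two's-complement
        # bits, so 0 is a single 0x00 byte and -1 takes ten bytes.
        v &= (1 << 64) - 1
        while True:
            bits = v & 0x7F
            v >>= 7
            out.append(bits | (0x80 if v else 0))
            if not v:
                break

    def encode_iterable(elements, encode_elem, known_length=True, batch_size=2):
        out = bytearray()
        if known_length:
            out += struct.pack('>i', len(elements))  # fixed32(N), big-endian
            for e in elements:
                out += encode_elem(e)
        else:
            out += struct.pack('>i', -1)             # fixed32(-1): length unknown
            for i in range(0, len(elements), batch_size):
                batch = elements[i:i + batch_size]
                write_var_int64(out, len(batch))     # varInt64(b_k)
                for e in batch:
                    out += encode_elem(e)
            write_var_int64(out, 0)                  # varInt64(0) terminator
        return bytes(out)

For example, encode_iterable([1, 2, 3], lambda x: bytes([x]), known_length=False) produces the batched form that the state-backed variant below extends.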
diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd b/sdks/python/apache_beam/coders/coder_impl.pxd
index db724e652f83..9d5ac808c3c4 100644
--- a/sdks/python/apache_beam/coders/coder_impl.pxd
+++ b/sdks/python/apache_beam/coders/coder_impl.pxd
@@ -128,9 +128,14 @@ cdef class TupleCoderImpl(AbstractComponentCoderImpl):
cdef class SequenceCoderImpl(StreamCoderImpl):
cdef CoderImpl _elem_coder
+ cdef object _read_state
+ cdef object _write_state
+ cdef int _write_state_threshold
+
cpdef _construct_from_sequence(self, values)
+
@cython.locals(buffer=OutputStream, target_buffer_size=libc.stdint.int64_t,
- index=libc.stdint.int64_t)
+ index=libc.stdint.int64_t, prev_index=libc.stdint.int64_t)
cpdef encode_to_stream(self, value, OutputStream stream, bint nested)
diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py
index c3768dd10f4e..244b82cd9194 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -637,20 +637,33 @@ class SequenceCoderImpl(StreamCoderImpl):
countX element(0) element(1) ... element(countX - 1)
0
+ If writing to state is enabled, the final terminating 0 will instead be
+ replaced with::
+
+ varInt64(-1)
+ len(state_token)
+ state_token
+
+ where state_token is a bytes object used to retrieve the remainder of the
+ iterable via the state API.
"""
# Default buffer size of 64kB for handling iterables of unknown length.
_DEFAULT_BUFFER_SIZE = 64 * 1024
- def __init__(self, elem_coder):
+ def __init__(self, elem_coder,
+ read_state=None, write_state=None, write_state_threshold=0):
self._elem_coder = elem_coder
+ self._read_state = read_state
+ self._write_state = write_state
+ self._write_state_threshold = write_state_threshold
def _construct_from_sequence(self, values):
raise NotImplementedError
def encode_to_stream(self, value, out, nested):
# Compatible with Java's IterableLikeCoder.
- if hasattr(value, '__len__'):
+ if hasattr(value, '__len__') and self._write_state is None:
out.write_bigendian_int32(len(value))
for elem in value:
self._elem_coder.encode_to_stream(elem, out, True)
@@ -662,19 +675,36 @@ def encode_to_stream(self, value, out, nested):
# -1 to indicate that the length is not known.
out.write_bigendian_int32(-1)
buffer = create_OutputStream()
- target_buffer_size = self._DEFAULT_BUFFER_SIZE
+ if self._write_state is None:
+ target_buffer_size = self._DEFAULT_BUFFER_SIZE
+ else:
+ target_buffer_size = min(
+ self._DEFAULT_BUFFER_SIZE, self._write_state_threshold)
prev_index = index = -1
- for index, elem in enumerate(value):
+ # Don't want to miss out on fast list iteration optimization.
+ value_iter = value if isinstance(value, (list, tuple)) else iter(value)
+ start_size = out.size()
+ for elem in value_iter:
+ index += 1
self._elem_coder.encode_to_stream(elem, buffer, True)
if buffer.size() > target_buffer_size:
out.write_var_int64(index - prev_index)
out.write(buffer.get())
prev_index = index
buffer = create_OutputStream()
- if index > prev_index:
- out.write_var_int64(index - prev_index)
- out.write(buffer.get())
- out.write_var_int64(0)
+ if (self._write_state is not None
+ and out.size() - start_size > self._write_state_threshold):
+ tail = (value_iter[index + 1:] if isinstance(value, (list, tuple))
+ else value_iter)
+ state_token = self._write_state(tail, self._elem_coder)
+ out.write_var_int64(-1)
+ out.write(state_token, True)
+ break
+ else:
+ if index > prev_index:
+ out.write_var_int64(index - prev_index)
+ out.write(buffer.get())
+ out.write_var_int64(0)
def decode_from_stream(self, in_stream, nested):
size = in_stream.read_bigendian_int32()
@@ -690,6 +720,35 @@ def decode_from_stream(self, in_stream, nested):
elements.append(self._elem_coder.decode_from_stream(in_stream, True))
count = in_stream.read_var_int64()
+ if count == -1:
+ if self._read_state is None:
+ raise ValueError(
+ 'Cannot read state-written iterable without state reader.')
+
+ class FullIterable(object):
+ def __init__(self, head, tail):
+ self._head = head
+ self._tail = tail
+
+ def __iter__(self):
+ for elem in self._head:
+ yield elem
+ for elem in self._tail:
+ yield elem
+
+ def __eq__(self, other):
+ return list(self) == list(other)
+
+ def __hash__(self):
+ raise NotImplementedError
+
+ def __reduce__(self):
+ return list, (list(self),)
+
+ state_token = in_stream.read_all(True)
+ elements = FullIterable(
+ elements, self._read_state(state_token, self._elem_coder))
+
return self._construct_from_sequence(elements)
def estimate_size(self, value, nested=False):
@@ -719,6 +778,8 @@ def get_estimated_size_and_observables(self, value, nested=False):
# per block of data since we are not including the count prefix which
# occurs at most once per 64k of data and is up to 10 bytes long. The upper
# bound of the underestimate is 10 / 65536 ~= 0.0153% of the actual size.
+ # TODO: More efficient size estimation in the case of state-backed
+ # iterables.
return estimated_size, observables
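To complement the encoder sketch above, a hedged sketch of the decode side, including the varInt64(-1) state-token escape this PR introduces (fetch_state stands in for the state API callback; these are toy helpers, not Beam's classes):

    import io
    import struct

    def read_var_int64(stream):
        # Inverse of the varint writer above, reinterpreted as signed
        # 64-bit so the -1 sentinel round-trips.
        shift = result = 0
        while True:
            b = stream.read(1)[0]
            result |= (b & 0x7F) << shift
            shift += 7
            if not b & 0x80:
                break
        return result - (1 << 64) if result >= 1 << 63 else result

    def decode_iterable(data, decode_elem, fetch_state=None):
        stream = io.BytesIO(data)
        n = struct.unpack('>i', stream.read(4))[0]
        if n >= 0:                      # known length: N elements follow
            return [decode_elem(stream) for _ in range(n)]
        elements = []
        while True:
            count = read_var_int64(stream)
            if count == 0:              # ordinary terminator
                return elements
            if count == -1:             # state-backed tail: fetch by token
                token = stream.read(read_var_int64(stream))
                return elements + list(fetch_state(token))
            for _ in range(count):
                elements.append(decode_elem(stream))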
diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py
index f2a4b2ee724a..e84e08dc5133 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -263,8 +263,9 @@ def to_runner_api(self, context):
context.default_environment_id() if context else None),
spec=beam_runner_api_pb2.FunctionSpec(
urn=urn,
- payload=typed_param.SerializeToString()
- if typed_param is not None else None)),
+ payload=typed_param
+ if isinstance(typed_param, (bytes, type(None)))
+ else typed_param.SerializeToString())),
component_coder_ids=[context.coders.get_id(c) for c in components])
@classmethod
@@ -1046,3 +1047,54 @@ def __hash__(self):
Coder.register_structured_urn(
common_urns.coders.LENGTH_PREFIX.urn, LengthPrefixCoder)
+
+
+class StateBackedIterableCoder(FastCoder):
+ def __init__(
+ self,
+ element_coder,
+ read_state=None,
+ write_state=None,
+ write_state_threshold=1):
+ self._element_coder = element_coder
+ self._read_state = read_state
+ self._write_state = write_state
+ self._write_state_threshold = write_state_threshold
+
+ def _create_impl(self):
+ return coder_impl.IterableCoderImpl(
+ self._element_coder.get_impl(),
+ self._read_state,
+ self._write_state,
+ self._write_state_threshold)
+
+ def is_deterministic(self):
+ return False
+
+ def _get_component_coders(self):
+ return (self._element_coder,)
+
+ def __repr__(self):
+ return 'StateBackedIterableCoder[%r]' % self._element_coder
+
+ def __eq__(self, other):
+ return (type(self) == type(other)
+ and self._element_coder == other._element_coder
+ and self._write_state_threshold == other._write_state_threshold)
+
+ def __hash__(self):
+ return hash((type(self), self._element_coder, self._write_state_threshold))
+
+ def to_runner_api_parameter(self, context):
+ return (
+ common_urns.coders.STATE_BACKED_ITERABLE.urn,
+ str(self._write_state_threshold).encode('ascii'),
+ self._get_component_coders())
+
+ @Coder.register_urn(common_urns.coders.STATE_BACKED_ITERABLE.urn, bytes)
+ def from_runner_api_parameter(payload, components, context):
+ return StateBackedIterableCoder(
+ components[0],
+ read_state=context.iterable_state_read,
+ write_state=context.iterable_state_write,
+ write_state_threshold=int(payload))
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py
index 4707c32f8469..0ee38c27b140 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -86,18 +86,21 @@ def _observe_nested(cls, coder):
cls.seen_nested.add(type(c))
cls._observe_nested(c)
- def check_coder(self, coder, *values):
+ def check_coder(self, coder, *values, **kwargs):
+ context = kwargs.pop('context', pipeline_context.PipelineContext())
+ test_size_estimation = kwargs.pop('test_size_estimation', True)
+ assert not kwargs
self._observe(coder)
for v in values:
self.assertEqual(v, coder.decode(coder.encode(v)))
- self.assertEqual(coder.estimate_size(v),
- len(coder.encode(v)))
- self.assertEqual(coder.estimate_size(v),
- coder.get_impl().estimate_size(v))
- self.assertEqual(coder.get_impl().get_estimated_size_and_observables(v),
- (coder.get_impl().estimate_size(v), []))
- copy1 = dill.loads(dill.dumps(coder))
- context = pipeline_context.PipelineContext()
+ if test_size_estimation:
+ self.assertEqual(coder.estimate_size(v),
+ len(coder.encode(v)))
+ self.assertEqual(coder.estimate_size(v),
+ coder.get_impl().estimate_size(v))
+
+ self.assertEqual(coder.get_impl().get_estimated_size_and_observables(v),
+ (coder.get_impl().estimate_size(v), []))
+ copy1 = dill.loads(dill.dumps(coder))
copy2 = coders.Coder.from_runner_api(coder.to_runner_api(context), context)
for v in values:
self.assertEqual(v, copy1.decode(copy2.encode(v)))
@@ -445,6 +448,37 @@ def __iter__(self):
coder.get_impl().get_estimated_size_and_observables(value)[1],
[(observ, elem_coder.get_impl())])
+ def test_state_backed_iterable_coder(self):
+ # pylint: disable=global-variable-undefined
+ # required for pickling by reference
+ global state
+ state = {}
+
+ def iterable_state_write(values, element_coder_impl):
+ token = b'state_token_%d' % len(state)
+ state[token] = [element_coder_impl.encode(e) for e in values]
+ return token
+
+ def iterable_state_read(token, element_coder_impl):
+ return [element_coder_impl.decode(s) for s in state[token]]
+
+ coder = coders.StateBackedIterableCoder(
+ coders.VarIntCoder(),
+ read_state=iterable_state_read,
+ write_state=iterable_state_write,
+ write_state_threshold=1)
+ context = pipeline_context.PipelineContext(
+ iterable_state_read=iterable_state_read,
+ iterable_state_write=iterable_state_write)
+ self.check_coder(
+ coder, [1, 2, 3], context=context, test_size_estimation=False)
+ # Ensure that state was actually used.
+ self.assertNotEqual(state, {})
+ self.check_coder(coders.TupleCoder((coder, coder)),
+ ([1], [2, 3]),
+ context=context,
+ test_size_estimation=False)
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py
index 5f774f8db605..437fe0194994 100644
--- a/sdks/python/apache_beam/runners/pipeline_context.py
+++ b/sdks/python/apache_beam/runners/pipeline_context.py
@@ -111,7 +111,8 @@ class PipelineContext(object):
}
def __init__(
- self, proto=None, default_environment=None, use_fake_coders=False):
+ self, proto=None, default_environment=None, use_fake_coders=False,
+ iterable_state_read=None, iterable_state_write=None):
if isinstance(proto, beam_fn_api_pb2.ProcessBundleDescriptor):
proto = beam_runner_api_pb2.Components(
coders=dict(proto.coders.items()),
@@ -127,6 +128,8 @@ def __init__(
else:
self._default_environment_id = None
self.use_fake_coders = use_fake_coders
+ self.iterable_state_read = iterable_state_read
+ self.iterable_state_write = iterable_state_write
# If fake coders are requested, return a pickled version of the element type
# rather than an actual coder. The element type is required for some runners,
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 62a2d299983d..e06fe0b8c2b0 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -212,7 +212,8 @@ def encoded_items(self):
class FnApiRunner(runner.PipelineRunner):
- def __init__(self, use_grpc=False, sdk_harness_factory=None, bundle_repeat=0):
+ def __init__(self, use_grpc=False, sdk_harness_factory=None, bundle_repeat=0,
+ use_state_iterables=False):
"""Creates a new Fn API Runner.
Args:
@@ -222,6 +223,8 @@ def __init__(self, use_grpc=False, sdk_harness_factory=None, bundle_repeat=0):
typically not set by users
bundle_repeat: replay every bundle this many extra times, for profiling
and debugging
+ use_state_iterables: Intentionally split gbk iterables over the state API
+ (for testing)
"""
super(FnApiRunner, self).__init__()
self._last_uid = -1
@@ -232,6 +235,7 @@ def __init__(self, use_grpc=False, sdk_harness_factory=None, bundle_repeat=0):
self._bundle_repeat = bundle_repeat
self._progress_frequency = None
self._profiler_factory = None
+ self._use_state_iterables = use_state_iterables
def _next_uid(self):
self._last_uid += 1
@@ -313,6 +317,35 @@ def windowed_coder_id(coder_id, window_coder_id):
component_coder_ids=[coder_id, window_coder_id])
return add_or_get_coder_id(proto)
+ _with_state_iterables_cache = {}
+
+ def with_state_iterables(coder_id):
+ if coder_id not in _with_state_iterables_cache:
+ _with_state_iterables_cache[
+ coder_id] = create_with_state_iterables(coder_id)
+ return _with_state_iterables_cache[coder_id]
+
+ def create_with_state_iterables(coder_id):
+ coder = pipeline_components.coders[coder_id]
+ if coder.spec.spec.urn == common_urns.coders.ITERABLE.urn:
+ new_coder_id = unique_name(pipeline_components.coders, 'coder')
+ new_coder = pipeline_components.coders[new_coder_id]
+ new_coder.CopyFrom(coder)
+ new_coder.spec.spec.urn = common_urns.coders.STATE_BACKED_ITERABLE.urn
+ new_coder.spec.spec.payload = b'1'
+ return new_coder_id
+ else:
+ new_component_ids = [
+ with_state_iterables(c) for c in coder.component_coder_ids]
+ if new_component_ids == coder.component_coder_ids:
+ return coder_id
+ else:
+ new_coder_id = unique_name(pipeline_components.coders, 'coder')
+ new_coder = pipeline_components.coders[new_coder_id]
+ new_coder.CopyFrom(coder)
+ new_coder.component_coder_ids[:] = new_component_ids
+ return new_coder_id
+
safe_coders = {}
def length_prefix_unknown_coders(pcoll, pipeline_components):
@@ -572,6 +605,10 @@ def expand_gbk(stages):
length_prefix_unknown_coders(
pipeline_components.pcollections[pcoll_id],
pipeline_components)
for pcoll_id in transform.outputs.values():
+ if self._use_state_iterables:
+ pipeline_components.pcollections[
+ pcoll_id].coder_id = with_state_iterables(
+ pipeline_components.pcollections[pcoll_id].coder_id)
length_prefix_unknown_coders(
pipeline_components.pcollections[pcoll_id],
pipeline_components)
@@ -967,7 +1004,19 @@ def run_stages(self, pipeline_components, stages, safe_coders):
def run_stage(
self, controller, pipeline_components, stage, pcoll_buffers,
safe_coders):
- context = pipeline_context.PipelineContext(pipeline_components)
+ def iterable_state_write(values, element_coder_impl):
+ token = unique_name(None, 'iter').encode('ascii')
+ out = create_OutputStream()
+ for element in values:
+ element_coder_impl.encode_to_stream(element, out, True)
+ controller.state_handler.blocking_append(
+ beam_fn_api_pb2.StateKey(
+ runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
+ out.get())
+ return token
+
+ context = pipeline_context.PipelineContext(
+ pipeline_components, iterable_state_write=iterable_state_write)
data_api_service_descriptor = controller.data_api_service_descriptor()
def extract_endpoints(stage):
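A toy stand-in for the controller's state handler that iterable_state_write appends to above (blocking_append/blocking_get mirror the interface used here; this is an in-memory illustration, not the real class):

    class InMemoryStateHandler(object):
        # Maps a serialized StateKey to the concatenation of appended bytes.
        def __init__(self):
            self._state = {}

        def blocking_append(self, state_key, data):
            key = state_key.SerializeToString()
            self._state[key] = self._state.get(key, b'') + data

        def blocking_get(self, state_key):
            return self._state.get(state_key.SerializeToString(), b'')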
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
index 79418d99862c..3889beef3eea 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_transforms.py
@@ -152,8 +152,15 @@ def union(a, b):
return frozenset.union(a, b)
+_global_counter = 0
+
+
def unique_name(existing, prefix):
- if prefix in existing:
+ if existing is None:
+ global _global_counter
+ _global_counter += 1
+ return '%s_%d' % (prefix, _global_counter)
+ elif prefix in existing:
counter = 0
while True:
counter += 1
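The two unique_name paths after this change, illustrated (outputs assume a fresh interpreter, since the new fallback counter is global to the process):

    from apache_beam.runners.portability.fn_api_runner_transforms import (
        unique_name)

    unique_name(None, 'iter')   # -> 'iter_1'  (new global-counter path)
    unique_name(None, 'iter')   # -> 'iter_2'
    unique_name({'x'}, 'y')     # -> 'y'       (prefix free: returned as-is)
    unique_name({'x'}, 'x')     # -> 'x_1'     (collision: numeric suffix)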
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 92a0c0ddc7f7..cc32c0bcd964 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -33,6 +33,7 @@
from future.utils import itervalues
import apache_beam as beam
+from apache_beam import coders
from apache_beam.coders import WindowedValueCoder
from apache_beam.coders import coder_impl
from apache_beam.internal import pickler
@@ -126,10 +127,13 @@ def process_encoded(self, encoded_windowed_values):
class _StateBackedIterable(object):
- def __init__(self, state_handler, state_key, coder):
+ def __init__(self, state_handler, state_key, coder_or_impl):
self._state_handler = state_handler
self._state_key = state_key
- self._coder_impl = coder.get_impl()
+ if isinstance(coder_or_impl, coders.Coder):
+ self._coder_impl = coder_or_impl.get_impl()
+ else:
+ self._coder_impl = coder_or_impl
def __iter__(self):
# TODO(robertwb): Support pagination.
@@ -546,7 +550,14 @@ def __init__(self, descriptor, data_channel_factory, counter_factory,
self.counter_factory = counter_factory
self.state_sampler = state_sampler
self.state_handler = state_handler
- self.context = pipeline_context.PipelineContext(descriptor)
+ self.context = pipeline_context.PipelineContext(
+ descriptor,
+ iterable_state_read=lambda token, element_coder_impl:
+ _StateBackedIterable(
+ state_handler,
+ beam_fn_api_pb2.StateKey(
+ runner=beam_fn_api_pb2.StateKey.Runner(key=token)),
+ element_coder_impl))
_known_urns = {}
diff --git a/sdks/python/apache_beam/utils/proto_utils.py b/sdks/python/apache_beam/utils/proto_utils.py
index 5dceb174e5ab..9a76448b1068 100644
--- a/sdks/python/apache_beam/utils/proto_utils.py
+++ b/sdks/python/apache_beam/utils/proto_utils.py
@@ -48,14 +48,14 @@ def unpack_Any(any_msg, msg_class):
return msg
-def parse_Bytes(bytes, msg_class):
+def parse_Bytes(serialized_bytes, msg_class):
"""Parses the String of bytes into msg_class.
Returns the input bytes if msg_class is None."""
- if msg_class is None:
- return bytes
+ if msg_class is None or msg_class is bytes:
+ return serialized_bytes
msg = msg_class()
- msg.ParseFromString(bytes)
+ msg.ParseFromString(serialized_bytes)
return msg
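A quick check of the patched passthrough behavior, which is what lets StateBackedIterableCoder's raw ASCII threshold payload (e.g. b'1') survive the runner API round trip:

    from apache_beam.utils import proto_utils

    assert proto_utils.parse_Bytes(b'1', bytes) == b'1'  # new: bytes passes through
    assert proto_utils.parse_Bytes(b'1', None) == b'1'   # as before: None passes through
    assert int(proto_utils.parse_Bytes(b'1', bytes)) == 1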
Issue Time Tracking
-------------------
Worklog Id: (was: 173521)
Time Spent: 4h (was: 3h 50m)
> Support retrieval of large gbk iterables over the state API.
> ------------------------------------------------------------
>
> Key: BEAM-6120
> URL: https://issues.apache.org/jira/browse/BEAM-6120
> Project: Beam
> Issue Type: Improvement
> Components: sdk-py-harness
> Reporter: Robert Bradshaw
> Assignee: Robert Bradshaw
> Priority: Major
> Time Spent: 4h
> Remaining Estimate: 0h
>