[
https://issues.apache.org/jira/browse/BEAM-2687?focusedWorklogId=149867&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-149867
]
ASF GitHub Bot logged work on BEAM-2687:
----------------------------------------
Author: ASF GitHub Bot
Created on: 30/Sep/18 23:27
Start Date: 30/Sep/18 23:27
Worklog Time Spent: 10m
Work Description: charlesccychen closed pull request #6433: [BEAM-2687]
Implement Timers over the Fn API.
URL: https://github.com/apache/beam/pull/6433
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/sdks/python/apache_beam/coders/coder_impl.py
b/sdks/python/apache_beam/coders/coder_impl.py
index 6fd9b169ed6..f82f86756a3 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -435,6 +435,23 @@ def estimate_size(self, unused_value, nested=False):
return 8
+class TimerCoderImpl(StreamCoderImpl):
+ """For internal use only; no backwards-compatibility guarantees."""
+ def __init__(self, payload_coder_impl):
+ self._timestamp_coder_impl = TimestampCoderImpl()
+ self._payload_coder_impl = payload_coder_impl
+
+ def encode_to_stream(self, value, out, nested):
+ self._timestamp_coder_impl.encode_to_stream(value['timestamp'], out, True)
+ self._payload_coder_impl.encode_to_stream(value.get('payload'), out, True)
+
+ def decode_from_stream(self, in_stream, nested):
+ return dict(
+ timestamp=self._timestamp_coder_impl.decode_from_stream(
+ in_stream, True),
+ payload=self._payload_coder_impl.decode_from_stream(in_stream, True))
+
+
small_ints = [chr(_).encode('latin-1') for _ in range(128)]
diff --git a/sdks/python/apache_beam/coders/coders.py
b/sdks/python/apache_beam/coders/coders.py
index ad4edbbb374..57a0d54c7f6 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -432,6 +432,34 @@ def __hash__(self):
return hash(type(self))
+class _TimerCoder(FastCoder):
+ """A coder used for timer values.
+
+ For internal use."""
+ def __init__(self, payload_coder):
+ self._payload_coder = payload_coder
+
+ def _get_component_coders(self):
+ return [self._payload_coder]
+
+ def _create_impl(self):
+ return coder_impl.TimerCoderImpl(self._payload_coder.get_impl())
+
+ def is_deterministic(self):
+ return self._payload_coder.is_deterministic()
+
+ def __eq__(self, other):
+ return (type(self) == type(other)
+ and self._payload_coder == other._payload_coder)
+
+ def __hash__(self):
+ return hash(type(self)) + hash(self._payload_coder)
+
+
+Coder.register_structured_urn(
+ common_urns.coders.TIMER.urn, _TimerCoder)
+
+
class SingletonCoder(FastCoder):
"""A coder that always encodes exactly one value."""
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py
b/sdks/python/apache_beam/coders/coders_test_common.py
index 969c1de10c4..607d3e25741 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -197,6 +197,15 @@ def test_timestamp_coder(self):
coders.TupleCoder((coders.TimestampCoder(), coders.BytesCoder())),
(timestamp.Timestamp.of(27), b'abc'))
+ def test_timer_coder(self):
+ self.check_coder(coders._TimerCoder(coders.BytesCoder()),
+ *[{'timestamp': timestamp.Timestamp(micros=x),
+ 'payload': b'xyz'}
+ for x in range(-3, 3)])
+ self.check_coder(
+ coders.TupleCoder((coders._TimerCoder(coders.VarIntCoder()),)),
+ ({'timestamp': timestamp.Timestamp.of(37), 'payload': 389},))
+
def test_tuple_coder(self):
kv_coder = coders.TupleCoder((coders.VarIntCoder(), coders.BytesCoder()))
# Verify cloud object representation
diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py
b/sdks/python/apache_beam/coders/standard_coders_test.py
index 031406f40ac..cb9d43b8775 100644
--- a/sdks/python/apache_beam/coders/standard_coders_test.py
+++ b/sdks/python/apache_beam/coders/standard_coders_test.py
@@ -63,7 +63,8 @@ class StandardCodersTest(unittest.TestCase):
'beam:coder:iterable:v1': lambda t: coders.IterableCoder(t),
'beam:coder:global_window:v1': coders.GlobalWindowCoder,
'beam:coder:windowed_value:v1':
- lambda v, w: coders.WindowedValueCoder(v, w)
+ lambda v, w: coders.WindowedValueCoder(v, w),
+ 'beam:coder:timer:v1': coders._TimerCoder,
}
_urn_to_json_value_parser = {
@@ -81,7 +82,11 @@ class StandardCodersTest(unittest.TestCase):
'beam:coder:windowed_value:v1':
lambda x, value_parser, window_parser: windowed_value.create(
value_parser(x['value']), x['timestamp'] * 1000,
- tuple([window_parser(w) for w in x['windows']]))
+ tuple([window_parser(w) for w in x['windows']])),
+ 'beam:coder:timer:v1':
+ lambda x, payload_parser: dict(
+ payload=payload_parser(x['payload']),
+ timestamp=Timestamp(micros=x['timestamp'])),
}
def test_standard_coders(self):
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index 2fc92e887bf..1b77a6e0944 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -252,6 +252,7 @@ def __init__(self, name, transforms,
self.transforms = transforms
self.downstream_side_inputs = downstream_side_inputs
self.must_follow = must_follow
+ self.timer_pcollections = []
def __repr__(self):
must_follow = ', '.join(prev.name for prev in self.must_follow)
@@ -848,6 +849,86 @@ def fuse(producer, consumer):
stage.deduplicate_read()
return final_stages
+ def inject_timer_pcollections(stages):
+ for stage in stages:
+ for transform in list(stage.transforms):
+ if transform.spec.urn == common_urns.primitives.PAR_DO.urn:
+ payload = proto_utils.parse_Bytes(
+ transform.spec.payload, beam_runner_api_pb2.ParDoPayload)
+ for tag in payload.timer_specs.keys():
+ input_pcoll = pipeline_components.pcollections[
+ next(iter(transform.inputs.values()))]
+ # Create the appropriate coder for the timer PCollection.
+ void_coder_id = add_or_get_coder_id(
+ beam.coders.SingletonCoder(None).to_runner_api(None))
+ timer_coder_id = add_or_get_coder_id(
+ beam_runner_api_pb2.Coder(
+ spec=beam_runner_api_pb2.SdkFunctionSpec(
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.coders.TIMER.urn)),
+ component_coder_ids=[void_coder_id]))
+ key_coder_id = input_pcoll.coder_id
+ if (pipeline_components.coders[key_coder_id].spec.spec.urn
+ == common_urns.coders.WINDOWED_VALUE.urn):
+ key_coder_id = pipeline_components.coders[
+ key_coder_id].component_coder_ids[0]
+ if (pipeline_components.coders[key_coder_id].spec.spec.urn
+ == common_urns.coders.KV.urn):
+ key_coder_id = pipeline_components.coders[
+ key_coder_id].component_coder_ids[0]
+ key_timer_coder_id = add_or_get_coder_id(
+ beam_runner_api_pb2.Coder(
+ spec=beam_runner_api_pb2.SdkFunctionSpec(
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=common_urns.coders.KV.urn)),
+ component_coder_ids=[key_coder_id, timer_coder_id]))
+ timer_pcoll_coder_id = windowed_coder_id(
+ key_timer_coder_id,
+ pipeline_components.windowing_strategies[
+ input_pcoll.windowing_strategy_id].window_coder_id)
+ # Inject the read and write pcollections.
+ timer_read_pcoll = unique_name(
+ pipeline_components.pcollections,
+ '%s_timers_to_read_%s' % (transform.unique_name, tag))
+ timer_write_pcoll = unique_name(
+ pipeline_components.pcollections,
+ '%s_timers_to_write_%s' % (transform.unique_name, tag))
+ pipeline_components.pcollections[timer_read_pcoll].CopyFrom(
+ beam_runner_api_pb2.PCollection(
+ unique_name=timer_read_pcoll,
+ coder_id=timer_pcoll_coder_id,
+ windowing_strategy_id=input_pcoll.windowing_strategy_id,
+ is_bounded=input_pcoll.is_bounded))
+ pipeline_components.pcollections[timer_write_pcoll].CopyFrom(
+ beam_runner_api_pb2.PCollection(
+ unique_name=timer_write_pcoll,
+ coder_id=timer_pcoll_coder_id,
+ windowing_strategy_id=input_pcoll.windowing_strategy_id,
+ is_bounded=input_pcoll.is_bounded))
+ stage.transforms.append(
+ beam_runner_api_pb2.PTransform(
+ unique_name=timer_read_pcoll + '/Read',
+ outputs={'out': timer_read_pcoll},
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_INPUT_URN,
+ payload=('timers:%s' % timer_read_pcoll).encode(
+ 'utf-8'))))
+ stage.transforms.append(
+ beam_runner_api_pb2.PTransform(
+ unique_name=timer_write_pcoll + '/Write',
+ inputs={'in': timer_write_pcoll},
+ spec=beam_runner_api_pb2.FunctionSpec(
+ urn=bundle_processor.DATA_OUTPUT_URN,
+ payload=('timers:%s' % timer_write_pcoll).encode(
+ 'utf-8'))))
+ assert tag not in transform.inputs
+ transform.inputs[tag] = timer_read_pcoll
+ assert tag not in transform.outputs
+ transform.outputs[tag] = timer_write_pcoll
+ stage.timer_pcollections.append(
+ (timer_read_pcoll + '/Read', timer_write_pcoll))
+ yield stage
+
def sort_stages(stages):
"""Order stages suitable for sequential execution.
"""
@@ -908,7 +989,7 @@ def leaf_transforms(root_ids):
for phase in [
annotate_downstream_side_inputs, fix_side_input_pcoll_coders,
lift_combiners, expand_gbk, sink_flattens, greedily_fuse,
- impulse_to_input, sort_stages]:
+ impulse_to_input, inject_timer_pcollections, sort_stages]:
logging.info('%s %s %s', '=' * 20, phase, '=' * 20)
stages = list(phase(stages))
logging.debug('Stages: %s', [str(s) for s in stages])
@@ -1016,7 +1097,8 @@ def extract_endpoints(stage):
controller.state_handler.blocking_append(state_key, elements_data)
def get_buffer(pcoll_id):
- if pcoll_id.startswith(b'materialize:'):
+ if (pcoll_id.startswith(b'materialize:')
+ or pcoll_id.startswith(b'timers:')):
if pcoll_id not in pcoll_buffers:
# Just store the data chunks for replay.
pcoll_buffers[pcoll_id] = list()
@@ -1043,10 +1125,51 @@ def get_buffer(pcoll_id):
raise NotImplementedError(pcoll_id)
return pcoll_buffers[pcoll_id]
- return BundleManager(
+ result = BundleManager(
controller, get_buffer, process_bundle_descriptor,
self._progress_frequency).process_bundle(data_input, data_output)
+ while True:
+ timer_inputs = {}
+ for transform_id, timer_writes in stage.timer_pcollections:
+ windowed_timer_coder_impl = context.coders[
+ pipeline_components.pcollections[timer_writes].coder_id].get_impl()
+ written_timers = get_buffer(b'timers:' + timer_writes.encode('utf-8'))
+ if written_timers:
+ # Keep only the "last" timer set per key and window.
+ timers_by_key_and_window = {}
+ for elements_data in written_timers:
+ input_stream = create_InputStream(elements_data)
+ while input_stream.size() > 0:
+ windowed_key_timer = windowed_timer_coder_impl.decode_from_stream(
+ input_stream, True)
+ key, _ = windowed_key_timer.value
+ # TODO: Explode and merge windows.
+ assert len(windowed_key_timer.windows) == 1
+ timers_by_key_and_window[
+ key, windowed_key_timer.windows[0]] = windowed_key_timer
+ out = create_OutputStream()
+ for windowed_key_timer in timers_by_key_and_window.values():
+ windowed_timer_coder_impl.encode_to_stream(
+ windowed_key_timer, out, True)
+ timer_inputs[transform_id, 'out'] = [out.get()]
+ written_timers[:] = []
+ if timer_inputs:
+ # The worker will be waiting on these inputs as well.
+ for other_input in data_input:
+ if other_input not in timer_inputs:
+ timer_inputs[other_input] = []
+ # TODO(robertwb): merge results
+ BundleManager(
+ controller,
+ get_buffer,
+ process_bundle_descriptor,
+ self._progress_frequency).process_bundle(timer_inputs, data_output)
+ else:
+ break
+
+ return result
+
# These classes are used to interact with the worker.
class StateServicer(beam_fn_api_pb2_grpc.BeamFnStateServicer):
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index cf08230a12b..3bb0feb74a9 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -256,6 +256,29 @@ def process(self, kv,
index=beam.DoFn.StateParam(index_state_spec)):
assert_that(p | beam.Create(inputs) | beam.ParDo(AddIndex()),
equal_to(expected))
+ def test_pardo_timers(self):
+ timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
+
+ class TimerDoFn(beam.DoFn):
+ def process(self, element, timer=beam.DoFn.TimerParam(timer_spec)):
+ unused_key, ts = element
+ timer.set(ts)
+ timer.set(2 * ts)
+
+ @userstate.on_timer(timer_spec)
+ def process_timer(self):
+ yield 'fired'
+
+ with self.create_pipeline() as p:
+ actual = (
+ p
+ | beam.Create([('k1', 10), ('k2', 100)])
+ | beam.ParDo(TimerDoFn())
+ | beam.Map(lambda x, ts=beam.DoFn.TimestampParam: (x, ts)))
+
+ expected = [('fired', ts) for ts in (20, 200)]
+ assert_that(actual, equal_to(expected))
+
def test_group_by_key(self):
with self.create_pipeline() as p:
res = (p
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py
b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index 6b633576020..70968206123 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -50,6 +50,7 @@
from apache_beam.transforms import userstate
from apache_beam.utils import counters
from apache_beam.utils import proto_utils
+from apache_beam.utils import timestamp
# This module is experimental. No backwards-compatibility guarantees.
@@ -255,15 +256,39 @@ def clear(self):
self._state_handler.blocking_clear(self._state_key)
+class OutputTimer(object):
+ def __init__(self, key, receiver):
+ self._key = key
+ self._receiver = receiver
+
+ def set(self, ts):
+ from apache_beam.transforms.window import GlobalWindows
+ self._receiver.receive(
+ GlobalWindows.windowed_value(
+ (self._key,
+ dict(timestamp=timestamp.Timestamp.of(ts)))))
+
+ def clear(self, timestamp):
+ self._receiver.receive((self._key, dict(clear=True)))
+
+
class FnApiUserStateContext(userstate.UserStateContext):
- def __init__(self, state_handler, transform_id, key_coder, window_coder):
+ def __init__(
+ self, state_handler, transform_id, key_coder, window_coder, timer_specs):
self._state_handler = state_handler
self._transform_id = transform_id
self._key_coder = key_coder
self._window_coder = window_coder
+ self._timer_specs = timer_specs
+ self._timer_receivers = None
+
+ def update_timer_receivers(self, receivers):
+ self._timer_receivers = {}
+ for tag in self._timer_specs:
+ self._timer_receivers[tag] = receivers.pop(tag)
def get_timer(self, timer_spec, key, window):
- raise NotImplementedError
+ return OutputTimer(key, self._timer_receivers[timer_spec.name])
def get_state(self, state_spec, key, window):
if isinstance(state_spec,
@@ -360,7 +385,6 @@ def topological_height(transform_id):
descriptor.transforms, key=topological_height, reverse=True)])
def process_bundle(self, instruction_id):
-
expected_inputs = []
for op in self.ops.values():
if isinstance(op, DataOutputOperation):
@@ -380,11 +404,19 @@ def process_bundle(self, instruction_id):
op.start()
# Inject inputs from data plane.
+ data_channels = collections.defaultdict(list)
+ input_op_by_target = {}
for input_op in expected_inputs:
- for data in input_op.data_channel.input_elements(
- instruction_id, [input_op.target]):
- # ignores input name
- input_op.process_encoded(data.data)
+ data_channels[input_op.data_channel].append(input_op.target)
+ # ignores input name
+ input_op_by_target[
+ input_op.target.primitive_transform_reference] = input_op
+ for data_channel, expected_targets in data_channels.items():
+ for data in data_channel.input_elements(
+ instruction_id, expected_targets):
+ input_op_by_target[
+ data.target.primitive_transform_reference
+ ].process_encoded(data.data)
# Finish all operations.
for op in self.ops.values():
@@ -499,9 +531,31 @@ def augment_oldstyle_op(op, step_name, consumers,
tag_list=None):
return op
+class TimerConsumer(operations.Operation):
+ def __init__(self, timer_tag, do_op):
+ self._timer_tag = timer_tag
+ self._do_op = do_op
+
+ def process(self, windowed_value):
+ self._do_op.process_timer(self._timer_tag, windowed_value)
+
+
@BeamTransformFactory.register_urn(
DATA_INPUT_URN, beam_fn_api_pb2.RemoteGrpcPort)
def create(factory, transform_id, transform_proto, grpc_port, consumers):
+ # Timers are the one special case where we don't want to call the
+ # (unlabeled) operation.process() method, which we detect here.
+ # TODO(robertwb): Consider generalizing if there are any more cases.
+ output_pcoll = only_element(transform_proto.outputs.values())
+ output_consumers = only_element(consumers.values())
+ if (len(output_consumers) == 1
+ and isinstance(only_element(output_consumers), operations.DoOperation)):
+ do_op = only_element(output_consumers)
+ for tag, pcoll_id in do_op.timer_inputs.items():
+ if pcoll_id == output_pcoll:
+ output_consumers[:] = [TimerConsumer(tag, do_op)]
+ break
+
target = beam_fn_api_pb2.Target(
primitive_transform_reference=transform_id,
name=only_element(list(transform_proto.outputs.keys())))
@@ -597,18 +651,18 @@ def create(factory, transform_id, transform_proto,
parameter, consumers):
serialized_fn = parameter.do_fn.spec.payload
return _create_pardo_operation(
factory, transform_id, transform_proto, consumers,
- serialized_fn, parameter.side_inputs)
+ serialized_fn, parameter)
def _create_pardo_operation(
factory, transform_id, transform_proto, consumers,
- serialized_fn, side_inputs_proto=None):
+ serialized_fn, pardo_proto=None):
- if side_inputs_proto:
+ if pardo_proto and pardo_proto.side_inputs:
input_tags_to_coders = factory.get_input_coders(transform_proto)
tagged_side_inputs = [
(tag, beam.pvalue.SideInputData.from_runner_api(si, factory.context))
- for tag, si in side_inputs_proto.items()]
+ for tag, si in pardo_proto.side_inputs.items()]
tagged_side_inputs.sort(
key=lambda tag_si: int(re.match('side([0-9]+)(-.*)?$',
tag_si[0]).group(1)))
@@ -638,22 +692,40 @@ def mutate_tag(tag):
dofn_data = pickler.loads(serialized_fn)
if not dofn_data[-1]:
# Windowing not set.
- side_input_tags = side_inputs_proto or ()
+ if pardo_proto:
+ other_input_tags = set.union(
+ set(pardo_proto.side_inputs), set(pardo_proto.timer_specs))
+ else:
+ other_input_tags = ()
pcoll_id, = [pcoll for tag, pcoll in transform_proto.inputs.items()
- if tag not in side_input_tags]
+ if tag not in other_input_tags]
windowing = factory.context.windowing_strategies.get_by_id(
factory.descriptor.pcollections[pcoll_id].windowing_strategy_id)
serialized_fn = pickler.dumps(dofn_data[:-1] + (windowing,))
- if userstate.is_stateful_dofn(dofn_data[0]):
- input_coder = factory.get_only_input_coder(transform_proto)
+ if pardo_proto and (pardo_proto.timer_specs or pardo_proto.state_specs):
+ main_input_coder = None
+ timer_inputs = {}
+ for tag, pcoll_id in transform_proto.inputs.items():
+ if tag in pardo_proto.timer_specs:
+ timer_inputs[tag] = pcoll_id
+ elif tag in pardo_proto.side_inputs:
+ pass
+ else:
+ # Must be the main input
+ assert main_input_coder is None
+ main_input_coder = factory.get_windowed_coder(pcoll_id)
+ assert main_input_coder is not None
+
user_state_context = FnApiUserStateContext(
factory.state_handler,
transform_id,
- input_coder.key_coder(),
- input_coder.window_coder)
+ main_input_coder.key_coder(),
+ main_input_coder.window_coder,
+ timer_specs=pardo_proto.timer_specs)
else:
user_state_context = None
+ timer_inputs = None
output_coders = factory.get_output_coders(transform_proto)
spec = operation_specs.WorkerDoFn(
@@ -670,7 +742,8 @@ def mutate_tag(tag):
factory.counter_factory,
factory.state_sampler,
side_input_maps,
- user_state_context),
+ user_state_context,
+ timer_inputs=timer_inputs),
transform_proto.unique_name,
consumers,
output_tags)
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py
b/sdks/python/apache_beam/runners/worker/data_plane.py
index f9c008171df..b280a0b28f0 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -139,7 +139,8 @@ def inverse(self):
def input_elements(self, instruction_id, unused_expected_targets=None):
for data in self._inputs:
if data.instruction_reference == instruction_id:
- yield data
+ if data.data:
+ yield data
def output_stream(self, instruction_id, target):
def add_to_inverse_output(data):
diff --git a/sdks/python/apache_beam/runners/worker/operations.pxd
b/sdks/python/apache_beam/runners/worker/operations.pxd
index ad82aae24ad..9cde9da75b2 100644
--- a/sdks/python/apache_beam/runners/worker/operations.pxd
+++ b/sdks/python/apache_beam/runners/worker/operations.pxd
@@ -82,6 +82,8 @@ cdef class DoOperation(Operation):
cdef object tagged_receivers
cdef object side_input_maps
cdef object user_state_context
+ cdef public dict timer_inputs
+ cdef dict timer_specs
cdef class CombineOperation(Operation):
diff --git a/sdks/python/apache_beam/runners/worker/operations.py
b/sdks/python/apache_beam/runners/worker/operations.py
index 41792cbfb38..b16a2c9486d 100644
--- a/sdks/python/apache_beam/runners/worker/operations.py
+++ b/sdks/python/apache_beam/runners/worker/operations.py
@@ -42,6 +42,7 @@
from apache_beam.transforms import sideinputs as apache_sideinputs
from apache_beam.transforms import combiners
from apache_beam.transforms import core
+from apache_beam.transforms import userstate
from apache_beam.transforms.combiners import PhasedCombineFnExecutor
from apache_beam.transforms.combiners import curry_combine_fn
from apache_beam.transforms.window import GlobalWindows
@@ -296,11 +297,13 @@ class DoOperation(Operation):
def __init__(
self, name, spec, counter_factory, sampler, side_input_maps=None,
- user_state_context=None):
+ user_state_context=None, timer_inputs=None):
super(DoOperation, self).__init__(name, spec, counter_factory, sampler)
self.side_input_maps = side_input_maps
self.user_state_context = user_state_context
self.tagged_receivers = None
+ # A mapping of timer tags to the input "PCollections" they come in on.
+ self.timer_inputs = timer_inputs or {}
def _read_side_inputs(self, tags_and_types):
"""Generator reading side inputs in the order prescribed by tags_and_types.
@@ -389,6 +392,13 @@ def start(self):
raise ValueError('Unexpected output name for operation: %s' % tag)
self.tagged_receivers[original_tag] = self.receivers[index]
+ if self.user_state_context:
+ self.user_state_context.update_timer_receivers(self.tagged_receivers)
+ self.timer_specs = {
+ spec.name: spec
+ for spec in userstate.get_dofn_specs(fn)[1]
+ }
+
if self.side_input_maps is None:
if tags_and_types:
self.side_input_maps = list(self._read_side_inputs(tags_and_types))
@@ -413,6 +423,12 @@ def process(self, o):
with self.scoped_process_state:
self.dofn_receiver.receive(o)
+ def process_timer(self, tag, windowed_timer):
+ key, timer_data = windowed_timer.value
+ timer_spec = self.timer_specs[tag]
+ self.dofn_receiver.process_user_timer(
+ timer_spec, key, windowed_timer.windows[0], timer_data['timestamp'])
+
def finish(self):
with self.scoped_finish_state:
self.dofn_runner.finish()
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 149867)
Time Spent: 5h 40m (was: 5.5h)
> Python SDK support for Stateful Processing
> ------------------------------------------
>
> Key: BEAM-2687
> URL: https://issues.apache.org/jira/browse/BEAM-2687
> Project: Beam
> Issue Type: New Feature
> Components: sdk-py-core
> Reporter: Ahmet Altay
> Assignee: Charles Chen
> Priority: Major
> Time Spent: 5h 40m
> Remaining Estimate: 0h
>
> Python SDK should support stateful processing
> (https://beam.apache.org/blog/2017/02/13/stateful-processing.html)
> In the meantime, runner capability matrix should be updated to show the lack
> of this feature
> (https://beam.apache.org/documentation/runners/capability-matrix/)
> Use this as an umbrella issue for all related issues.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)