This is an automated email from the ASF dual-hosted git repository.
robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new c9e5ea8 [BEAM-7741] Implement SetState for Python SDK (#9090)
c9e5ea8 is described below
commit c9e5ea843841ac4898d0104e536bd4b2fc297d33
Author: Rakesh Kumar <[email protected]>
AuthorDate: Thu Aug 8 07:17:26 2019 -0700
[BEAM-7741] Implement SetState for Python SDK (#9090)
---
.../apache_beam/runners/direct/direct_userstate.py | 9 ++
.../apache_beam/runners/worker/bundle_processor.py | 70 +++++++++
sdks/python/apache_beam/transforms/trigger.py | 16 +++
sdks/python/apache_beam/transforms/userstate.py | 51 +++++++
.../apache_beam/transforms/userstate_test.py | 156 ++++++++++++++++++++-
5 files changed, 301 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/runners/direct/direct_userstate.py b/sdks/python/apache_beam/runners/direct/direct_userstate.py
index f0fd9b8..b764ea4 100644
--- a/sdks/python/apache_beam/runners/direct/direct_userstate.py
+++ b/sdks/python/apache_beam/runners/direct/direct_userstate.py
@@ -20,6 +20,7 @@ from __future__ import absolute_import
from apache_beam.transforms import userstate
from apache_beam.transforms.trigger import _ListStateTag
+from apache_beam.transforms.trigger import _SetStateTag
class DirectUserStateContext(userstate.UserStateContext):
@@ -43,6 +44,8 @@ class DirectUserStateContext(userstate.UserStateContext):
state_tag = _ListStateTag(state_key)
elif isinstance(state_spec, userstate.CombiningValueStateSpec):
state_tag = _ListStateTag(state_key)
+ elif isinstance(state_spec, userstate.SetStateSpec):
+ state_tag = _SetStateTag(state_key)
else:
raise ValueError('Invalid state spec: %s' % state_spec)
self.state_tags[state_spec] = state_tag
@@ -93,6 +96,12 @@ class DirectUserStateContext(userstate.UserStateContext):
state.add_state(
window, state_tag,
state_spec.coder.encode(runtime_state._current_accumulator))
+ elif isinstance(state_spec, userstate.SetStateSpec):
+ if runtime_state.is_modified():
+ state.clear_state(window, state_tag)
+ for new_value in runtime_state._current_accumulator:
+ state.add_state(
+ window, state_tag, state_spec.coder.encode(new_value))
else:
raise ValueError('Invalid state spec: %s' % state_spec)
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index a85e4da..72a41b7 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -208,6 +208,7 @@ class _StateBackedIterable(object):
self._coder_impl = coder_or_impl
def __iter__(self):
+ # The continuation token returned here might be useful.
data, continuation_token = self._state_handler.blocking_get(self._state_key)
while True:
input_stream = coder_impl.create_InputStream(data)
@@ -379,6 +380,65 @@ class SynchronousBagRuntimeState(userstate.RuntimeState):
self._state_handler.blocking_append(self._state_key, out.get())
+# TODO(BEAM-5428): Implement cross-bundle state caching.
+class SynchronousSetRuntimeState(userstate.RuntimeState):
+
+ def __init__(self, state_handler, state_key, value_coder):
+ self._state_handler = state_handler
+ self._state_key = state_key
+ self._value_coder = value_coder
+ self._cleared = False
+ self._added_elements = set()
+
+ def _compact_data(self, rewrite=True):
+ accumulator = set(_ConcatIterable(
+ set() if self._cleared else _StateBackedIterable(
+ self._state_handler, self._state_key, self._value_coder),
+ self._added_elements))
+
+ if rewrite and accumulator:
+ self._state_handler.blocking_clear(self._state_key)
+
+ value_coder_impl = self._value_coder.get_impl()
+ out = coder_impl.create_OutputStream()
+ for element in accumulator:
+ value_coder_impl.encode_to_stream(element, out, True)
+ self._state_handler.blocking_append(self._state_key, out.get())
+
+ # Since everything is already committed, we can safely reinitialize
+ # added_elements here.
+ self._added_elements = set()
+
+ return accumulator
+
+ def read(self):
+ return self._compact_data(rewrite=False)
+
+ def add(self, value):
+ if self._cleared:
+ # This is a good time to explicitly clear.
+ self._state_handler.blocking_clear(self._state_key)
+ self._cleared = False
+
+ self._added_elements.add(value)
+ if random.random() > 0.5:
+ self._compact_data()
+
+ def clear(self):
+ self._cleared = True
+ self._added_elements = set()
+
+ def _commit(self):
+ if self._cleared:
+ self._state_handler.blocking_clear(self._state_key)
+ if self._added_elements:
+ value_coder_impl = self._value_coder.get_impl()
+ out = coder_impl.create_OutputStream()
+ for element in self._added_elements:
+ value_coder_impl.encode_to_stream(element, out, True)
+ self._state_handler.blocking_append(self._state_key, out.get())
+
+
class OutputTimer(object):
def __init__(self, key, window, receiver):
self._key = key
@@ -454,6 +514,16 @@ class FnApiUserStateContext(userstate.UserStateContext):
return bag_state
else:
return CombiningValueRuntimeState(bag_state, state_spec.combine_fn)
+ elif isinstance(state_spec, userstate.SetStateSpec):
+ return SynchronousSetRuntimeState(
+ self._state_handler,
+ state_key=beam_fn_api_pb2.StateKey(
+ bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
+ ptransform_id=self._transform_id,
+ user_state_id=state_spec.name,
+ window=self._window_coder.encode(window),
+ key=self._key_coder.encode(key))),
+ value_coder=state_spec.coder)
else:
raise NotImplementedError(state_spec)
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index f0d2a1c..dbda3cc 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -93,6 +93,16 @@ class _ValueStateTag(_StateTag):
return _ValueStateTag(prefix + self.tag)
+class _SetStateTag(_StateTag):
+ """StateTag pointing to an element."""
+
+ def __repr__(self):
+ return 'SetStateTag({tag})'.format(tag=self.tag)
+
+ def with_prefix(self, prefix):
+ return _SetStateTag(prefix + self.tag)
+
+
class _CombiningValueStateTag(_StateTag):
"""StateTag pointing to an element, accumulated with a combiner.
@@ -865,6 +875,8 @@ class MergeableStateAdapter(SimpleState):
original_tag.combine_fn.merge_accumulators(values))
elif isinstance(tag, _ListStateTag):
return [v for vs in values for v in vs]
+ elif isinstance(tag, _SetStateTag):
+ return {v for vs in values for v in vs}
elif isinstance(tag, _WatermarkHoldStateTag):
return tag.timestamp_combiner_impl.combine_all(values)
else:
@@ -1226,6 +1238,8 @@ class InMemoryUnmergedState(UnmergedState):
self.state[window][tag.tag].append(value)
elif isinstance(tag, _ListStateTag):
self.state[window][tag.tag].append(value)
+ elif isinstance(tag, _SetStateTag):
+ self.state[window][tag.tag].append(value)
elif isinstance(tag, _WatermarkHoldStateTag):
self.state[window][tag.tag].append(value)
else:
@@ -1239,6 +1253,8 @@ class InMemoryUnmergedState(UnmergedState):
return tag.combine_fn.apply(values)
elif isinstance(tag, _ListStateTag):
return values
+ elif isinstance(tag, _SetStateTag):
+ return values
elif isinstance(tag, _WatermarkHoldStateTag):
return tag.timestamp_combiner_impl.combine_all(values)
else:
diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py
index aa4e866..4662d13 100644
--- a/sdks/python/apache_beam/transforms/userstate.py
+++ b/sdks/python/apache_beam/transforms/userstate.py
@@ -60,6 +60,23 @@ class BagStateSpec(StateSpec):
element_coder_id=context.coders.get_id(self.coder)))
+class SetStateSpec(StateSpec):
+ """Specification for a user DoFn Set State cell"""
+
+ def __init__(self, name, coder):
+ if not isinstance(name, str):
+ raise TypeError("SetState name is not a string")
+ if not isinstance(coder, Coder):
+ raise TypeError("SetState coder is not of type Coder")
+ self.name = name
+ self.coder = coder
+
+ def to_runner_api(self, context):
+ return beam_runner_api_pb2.StateSpec(
+ set_spec=beam_runner_api_pb2.SetStateSpec(
+ element_coder_id=context.coders.get_id(self.coder)))
+
+
class CombiningValueStateSpec(StateSpec):
"""Specification for a user DoFn combining value state cell."""
@@ -264,6 +281,8 @@ class RuntimeState(object):
elif isinstance(state_spec, CombiningValueStateSpec):
return CombiningValueRuntimeState(state_spec, state_tag,
current_value_accessor)
+ elif isinstance(state_spec, SetStateSpec):
+ return SetRuntimeState(state_spec, state_tag, current_value_accessor)
else:
raise ValueError('Invalid state spec: %s' % state_spec)
@@ -310,6 +329,38 @@ class BagRuntimeState(RuntimeState):
self._new_values = []
+class SetRuntimeState(RuntimeState):
+ """Set state interface object passed to user code."""
+
+ def __init__(self, state_spec, state_tag, current_value_accessor):
+ super(SetRuntimeState, self).__init__(
+ state_spec, state_tag, current_value_accessor)
+ self._current_accumulator = UNREAD_VALUE
+ self._modified = False
+
+ def _read_initial_value(self):
+ if self._current_accumulator is UNREAD_VALUE:
+ self._current_accumulator = {
+ self._decode(a) for a in self._current_value_accessor()
+ }
+
+ def read(self):
+ self._read_initial_value()
+ return self._current_accumulator
+
+ def add(self, value):
+ self._read_initial_value()
+ self._modified = True
+ self._current_accumulator.add(value)
+
+ def clear(self):
+ self._current_accumulator = set()
+ self._modified = True
+
+ def is_modified(self):
+ return self._modified
+
+
class CombiningValueRuntimeState(RuntimeState):
"""Combining value state interface object passed to user code."""
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index 0d98337..8e55cee 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -31,6 +31,7 @@ from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.common import DoFnSignature
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
+from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import trigger
from apache_beam.transforms import userstate
@@ -41,6 +42,7 @@ from apache_beam.transforms.core import DoFn
from apache_beam.transforms.timeutil import TimeDomain
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.userstate import CombiningValueStateSpec
+from apache_beam.transforms.userstate import SetStateSpec
from apache_beam.transforms.userstate import TimerSpec
from apache_beam.transforms.userstate import get_dofn_specs
from apache_beam.transforms.userstate import is_stateful_dofn
@@ -114,7 +116,13 @@ class InterfaceTest(unittest.TestCase):
CombiningValueStateSpec(123, VarIntCoder(), TopCombineFn(10))
with self.assertRaises(TypeError):
CombiningValueStateSpec('statename', VarIntCoder(), object())
- # BagStateSpec('bag', )
+ SetStateSpec('setstatename', VarIntCoder())
+
+ with self.assertRaises(TypeError):
+ SetStateSpec(123, VarIntCoder())
+ with self.assertRaises(TypeError):
+ SetStateSpec('setstatename', object())
+
# TODO: add more spec tests
with self.assertRaises(ValueError):
DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))
@@ -415,6 +423,152 @@ class StatefulDoFnOnDirectRunnerTest(unittest.TestCase):
['extra'],
StatefulDoFnOnDirectRunnerTest.all_records)
+ def test_simple_set_stateful_dofn(self):
+ class SimpleTestSetStatefulDoFn(DoFn):
+ BUFFER_STATE = SetStateSpec('buffer', VarIntCoder())
+ EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
+
+ def process(self, element, buffer=DoFn.StateParam(BUFFER_STATE),
+ timer1=DoFn.TimerParam(EXPIRY_TIMER)):
+ unused_key, value = element
+ buffer.add(value)
+ timer1.set(20)
+
+ @on_timer(EXPIRY_TIMER)
+ def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)):
+ yield sorted(buffer.read())
+
+ with TestPipeline() as p:
+ test_stream = (TestStream()
+ .advance_watermark_to(10)
+ .add_elements([1, 2, 3])
+ .add_elements([2])
+ .advance_watermark_to(24))
+ (p
+ | test_stream
+ | beam.Map(lambda x: ('mykey', x))
+ | beam.ParDo(SimpleTestSetStatefulDoFn())
+ | beam.ParDo(self.record_dofn()))
+
+ # A single firing should occur: the timer set at time 20 fires once the
+ # watermark advances to 24. The duplicate element 2 is collapsed by the
+ # set state, so the emitted value is [1, 2, 3].
+ self.assertEqual(
+ [[1, 2, 3]],
+ StatefulDoFnOnDirectRunnerTest.all_records)
+
+ def test_clearing_set_state(self):
+ class SetStateClearingStatefulDoFn(beam.DoFn):
+
+ SET_STATE = SetStateSpec('buffer', StrUtf8Coder())
+ EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
+ CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)
+
+ def process(self,
+ element,
+ set_state=beam.DoFn.StateParam(SET_STATE),
+ emit_timer=beam.DoFn.TimerParam(EMIT_TIMER),
+ clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)):
+ value = element[1]
+ set_state.add(value)
+ clear_timer.set(100)
+ emit_timer.set(1000)
+
+ @on_timer(EMIT_TIMER)
+ def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
+ for value in set_state.read():
+ yield value
+
+ @on_timer(CLEAR_TIMER)
+ def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
+ set_state.clear()
+ set_state.add('different-value')
+
+ with TestPipeline() as p:
+ test_stream = (TestStream()
+ .advance_watermark_to(0)
+ .add_elements([('key1', 'value1')])
+ .advance_watermark_to(100))
+
+ _ = (p
+ | test_stream
+ | beam.ParDo(SetStateClearingStatefulDoFn())
+ | beam.ParDo(self.record_dofn()))
+
+ self.assertEqual(['different-value'],
+ StatefulDoFnOnDirectRunnerTest.all_records)
+
+ def test_stateful_set_state_portably(self):
+
+ class SetStatefulDoFn(beam.DoFn):
+
+ SET_STATE = SetStateSpec('buffer', VarIntCoder())
+
+ def process(self,
+ element,
+ set_state=beam.DoFn.StateParam(SET_STATE)):
+ _, value = element
+ aggregated_value = 0
+ set_state.add(value)
+ for saved_value in set_state.read():
+ aggregated_value += saved_value
+ yield aggregated_value
+
+ p = TestPipeline()
+ values = p | beam.Create([('key', 1),
+ ('key', 2),
+ ('key', 3),
+ ('key', 4),
+ ('key', 3)])
+ actual_values = (values
+ | beam.ParDo(SetStatefulDoFn()))
+
+ assert_that(actual_values, equal_to([1, 3, 6, 10, 10]))
+
+ result = p.run()
+ result.wait_until_finish()
+
+ def test_stateful_set_state_clean_portably(self):
+
+ class SetStateClearingStatefulDoFn(beam.DoFn):
+
+ SET_STATE = SetStateSpec('buffer', VarIntCoder())
+ EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
+
+ def process(self,
+ element,
+ set_state=beam.DoFn.StateParam(SET_STATE),
+ emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
+ _, value = element
+ set_state.add(value)
+
+ all_elements = [element for element in set_state.read()]
+
+ if len(all_elements) == 5:
+ set_state.clear()
+ set_state.add(100)
+ emit_timer.set(1)
+
+ @on_timer(EMIT_TIMER)
+ def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
+ yield sorted(set_state.read())
+
+ p = TestPipeline()
+ values = p | beam.Create([('key', 1),
+ ('key', 2),
+ ('key', 3),
+ ('key', 4),
+ ('key', 5)])
+ actual_values = (values
+ | beam.Map(lambda t: window.TimestampedValue(t, 1))
+ | beam.WindowInto(window.FixedWindows(1))
+ | beam.ParDo(SetStateClearingStatefulDoFn()))
+
+ assert_that(actual_values, equal_to([[100]]))
+
+ result = p.run()
+ result.wait_until_finish()
+
def test_stateful_dofn_nonkeyed_input(self):
p = TestPipeline()
values = p | beam.Create([1, 2, 3])