This is an automated email from the ASF dual-hosted git repository.
weizhong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 044f8e1 [FLINK-21115][python] Add AggregatingState and corresponding
StateDescriptor for Python DataStream API
044f8e1 is described below
commit 044f8e17237d0fdbff96ca973817c932a3c985ba
Author: Wei Zhong <[email protected]>
AuthorDate: Wed Mar 3 20:54:52 2021 +0800
[FLINK-21115][python] Add AggregatingState and corresponding
StateDescriptor for Python DataStream API
This closes #15028.
---
flink-python/pyflink/datastream/functions.py | 92 ++++++++++++++++-
flink-python/pyflink/datastream/state.py | 42 ++++++++
.../pyflink/datastream/tests/test_data_stream.py | 49 ++++++++-
flink-python/pyflink/fn_execution/operations.py | 8 +-
flink-python/pyflink/fn_execution/state_impl.py | 109 +++++++++++++--------
5 files changed, 254 insertions(+), 46 deletions(-)
diff --git a/flink-python/pyflink/datastream/functions.py
b/flink-python/pyflink/datastream/functions.py
index 60e718f..b6952ce 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -23,7 +23,8 @@ from typing import Union, Any, Dict
from py4j.java_gateway import JavaObject
from pyflink.datastream.state import ValueState, ValueStateDescriptor,
ListStateDescriptor, \
- ListState, MapStateDescriptor, MapState, ReducingStateDescriptor,
ReducingState
+ ListState, MapStateDescriptor, MapState, ReducingStateDescriptor,
ReducingState, \
+ AggregatingStateDescriptor, AggregatingState
from pyflink.datastream.time_domain import TimeDomain
from pyflink.datastream.timerservice import TimerService
from pyflink.java_gateway import get_gateway
@@ -35,6 +36,7 @@ __all__ = [
'FlatMapFunction',
'CoFlatMapFunction',
'ReduceFunction',
+ 'AggregateFunction',
'KeySelector',
'FilterFunction',
'Partitioner',
@@ -156,6 +158,17 @@ class RuntimeContext(object):
"""
pass
+ def get_aggregating_state(
+ self, state_descriptor: AggregatingStateDescriptor) ->
AggregatingState:
+ """
+ Gets a handle to the system's key/value aggregating state. This state
is similar to the
+ state accessed via get_state(ValueStateDescriptor), but is optimized
for state that
+ aggregates values with different types.
+
+ This state is only accessible if the function is executed on a
KeyedStream.
+ """
+ pass
+
class Function(abc.ABC):
"""
@@ -342,6 +355,83 @@ class ReduceFunction(Function):
pass
+class AggregateFunction(Function):
+ """
+ The AggregateFunction is a flexible aggregation function, characterized by
the following
+ features:
+
+ - The aggregates may use different types for input values,
intermediate aggregates, and
+ result type, to support a wide range of aggregation types.
+ - Support for distributive aggregations: Different intermediate
aggregates can be merged
+ together, to allow for pre-aggregation/final-aggregation
optimizations.
+
+ The AggregateFunction's intermediate aggregate (in-progress aggregation
state) is called the
+ `accumulator`. Values are added to the accumulator, and final aggregates
are obtained by
+ finalizing the accumulator state. This supports aggregation functions
where the intermediate
+ state needs to be different than the aggregated values and the final
result type, such as for
+ example average (which typically keeps a count and sum). Merging
intermediate aggregates
+ (partial aggregates) means merging the accumulators.
+
+ The AggregateFunction itself is stateless. To allow a single
AggregateFunction instance to
+ maintain multiple aggregates (such as one aggregate per key), the
AggregateFunction creates a
+ new accumulator whenever a new aggregation is started.
+ """
+
+ @abc.abstractmethod
+ def create_accumulator(self):
+ """
+ Creates a new accumulator, starting a new aggregate.
+
+ The new accumulator is typically meaningless unless a value is added
via
+ :func:`~AggregateFunction.add`.
+
+ The accumulator is the state of a running aggregation. When a program
has multiple
+ aggregates in progress (such as per key and window), the state (per
key and window) is the
+ size of the accumulator.
+
+ :return: A new accumulator, corresponding to an empty aggregate.
+ """
+ pass
+
+ @abc.abstractmethod
+ def add(self, value, accumulator):
+ """
+ Adds the given input value to the given accumulator, returning the new
accumulator value.
+
+ For efficiency, the input accumulator may be modified and returned.
+
+ :param value: The value to add.
+ :param accumulator: The accumulator to add the value to.
+ :return: The accumulator with the updated state.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_result(self, accumulator):
+ """
+ Gets the result of the aggregation from the accumulator.
+
+ :param accumulator: The accumulator of the aggregation.
+ :return: The final aggregation result.
+ """
+ pass
+
+ @abc.abstractmethod
+ def merge(self, acc_a, acc_b):
+ """
+ Merges two accumulators, returning an accumulator with the merged
state.
+
+ This function may reuse any of the given accumulators as the target
for the merge and
+ return that. The assumption is that the given accumulators will not be
used any more after
+ having been passed to this function.
+
+ :param acc_a: An accumulator to merge.
+ :param acc_b: Another accumulator to merge.
+ :return: The accumulator with the merged state.
+ """
+ pass
+
+
class KeySelector(Function):
"""
The KeySelector allows to use deterministic objects for operations such as
reduce, reduceGroup,
diff --git a/flink-python/pyflink/datastream/state.py
b/flink-python/pyflink/datastream/state.py
index 1e3cde9..4c11b29 100644
--- a/flink-python/pyflink/datastream/state.py
+++ b/flink-python/pyflink/datastream/state.py
@@ -134,6 +134,26 @@ class ReducingState(MergingState[T, T]):
pass
+class AggregatingState(MergingState[IN, OUT]):
+ """
+ :class:`State` interface for aggregating state, based on an
+ :class:`~pyflink.datastream.functions.AggregateFunction`. Elements that
are added to this type
+ of state will be eagerly pre-aggregated using a given AggregateFunction.
+
+ The state always internally holds the accumulator type of the
AggregateFunction. The result
+ of the state is obtained by calling the function's
+ :func:`~pyflink.datastream.functions.AggregateFunction.get_result` method.
+
+ The state is accessed and modified by user functions, and checkpointed
consistently by the
+ system as part of the distributed snapshots.
+
+ The state is only accessible by functions applied on a KeyedStream. The
key is automatically
+ supplied by the system, so the function always sees the value mapped to
the key of the current
+ element. That way, the system can handle stream and state partitioning
consistently together.
+ """
+ pass
+
+
class ListState(MergingState[T, Iterable[T]]):
"""
:class:`State` interface for partitioned list state in Operations.
@@ -358,3 +378,25 @@ class ReducingStateDescriptor(StateDescriptor):
def get_reduce_function(self):
return self._reduce_function
+
+
+class AggregatingStateDescriptor(StateDescriptor):
+ """
+ A StateDescriptor for AggregatingState.
+
+ The type internally stored in the state is the type of the accumulator of
the
+ :class:`~pyflink.datastream.functions.AggregateFunction`.
+ """
+
+ def __init__(self,
+ name: str,
+ agg_function,
+ state_type_info):
+ super(AggregatingStateDescriptor, self).__init__(name, state_type_info)
+ from pyflink.datastream.functions import AggregateFunction
+ if not isinstance(agg_function, AggregateFunction):
+ raise TypeError("The input must be a AggregateFunction!")
+ self._agg_function = agg_function
+
+ def get_agg_function(self):
+ return self._agg_function
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 852bf8a..3d7a0da 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -25,12 +25,13 @@ from pyflink.common.typeinfo import Types
from pyflink.common.watermark_strategy import WatermarkStrategy,
TimestampAssigner
from pyflink.datastream import StreamExecutionEnvironment, TimeCharacteristic,
RuntimeContext
from pyflink.datastream.data_stream import DataStream
-from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction
+from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction,
AggregateFunction
from pyflink.datastream.functions import FilterFunction, ProcessFunction,
KeyedProcessFunction
from pyflink.datastream.functions import KeySelector
from pyflink.datastream.functions import MapFunction, FlatMapFunction
from pyflink.datastream.state import ValueStateDescriptor,
ListStateDescriptor, \
- MapStateDescriptor, ReducingStateDescriptor, ReducingState
+ MapStateDescriptor, ReducingStateDescriptor, ReducingState,
AggregatingState, \
+ AggregatingStateDescriptor
from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
from pyflink.java_gateway import get_gateway
from pyflink.testing.test_case_utils import PyFlinkTestCase,
invoke_java_object_method
@@ -874,6 +875,50 @@ class DataStreamTests(PyFlinkTestCase):
expected_result.sort()
self.assertEqual(expected_result, result)
+ def test_aggregating_state(self):
+ self.env.set_parallelism(2)
+ data_stream = self.env.from_collection([
+ (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6,
'hello')],
+ type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
+
+ class MyAggregateFunction(AggregateFunction):
+
+ def create_accumulator(self):
+ return 0
+
+ def add(self, value, accumulator):
+ return value + accumulator
+
+ def get_result(self, accumulator):
+ return accumulator
+
+ def merge(self, acc_a, acc_b):
+ return acc_a + acc_b
+
+ class MyProcessFunction(KeyedProcessFunction):
+
+ def __init__(self):
+ self.aggregating_state = None # type: AggregatingState
+
+ def open(self, runtime_context: RuntimeContext):
+ self.aggregating_state = runtime_context.get_aggregating_state(
+ AggregatingStateDescriptor(
+ 'aggregating_state', MyAggregateFunction(),
Types.INT()))
+
+ def process_element(self, value, ctx):
+ self.aggregating_state.add(value[0])
+ yield Row(self.aggregating_state.get(), value[1])
+
+ data_stream.key_by(lambda x: x[1], key_type_info=Types.STRING()) \
+ .process(MyProcessFunction(),
output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
+ .add_sink(self.test_sink)
+ self.env.execute('test_aggregating_state')
+ result = self.test_sink.get_results()
+ expected_result = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)',
'(9,hi)', '(12,hello)']
+ result.sort()
+ expected_result.sort()
+ self.assertEqual(expected_result, result)
+
def tearDown(self) -> None:
self.test_sink.clear()
diff --git a/flink-python/pyflink/fn_execution/operations.py
b/flink-python/pyflink/fn_execution/operations.py
index 93c623e..f28c430 100644
--- a/flink-python/pyflink/fn_execution/operations.py
+++ b/flink-python/pyflink/fn_execution/operations.py
@@ -24,7 +24,8 @@ from typing import List, Tuple, Any, Dict
from apache_beam.coders import PickleCoder
from pyflink.datastream.state import ValueStateDescriptor, ValueState,
ListStateDescriptor, \
- ListState, MapStateDescriptor, MapState, ReducingStateDescriptor,
ReducingState
+ ListState, MapStateDescriptor, MapState, ReducingStateDescriptor,
ReducingState, \
+ AggregatingStateDescriptor, AggregatingState
from pyflink.datastream import TimeDomain, TimerService
from pyflink.datastream.functions import RuntimeContext, ProcessFunction,
KeyedProcessFunction
from pyflink.fn_execution import flink_fn_execution_pb2, operation_utils
@@ -461,6 +462,11 @@ class InternalRuntimeContext(RuntimeContext):
return self._keyed_state_backend.get_reducing_state(
state_descriptor.get_name(), PickleCoder(),
state_descriptor.get_reduce_function())
+ def get_aggregating_state(
+ self, state_descriptor: AggregatingStateDescriptor) ->
AggregatingState:
+ return self._keyed_state_backend.get_aggregating_state(
+ state_descriptor.get_name(), PickleCoder(),
state_descriptor.get_agg_function())
+
class ProcessFunctionOperation(DataStreamStatelessFunctionOperation):
diff --git a/flink-python/pyflink/fn_execution/state_impl.py
b/flink-python/pyflink/fn_execution/state_impl.py
index d596df3..41c5972 100644
--- a/flink-python/pyflink/fn_execution/state_impl.py
+++ b/flink-python/pyflink/fn_execution/state_impl.py
@@ -17,6 +17,7 @@
################################################################################
import collections
from enum import Enum
+from functools import partial
from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_fn_api_pb2
@@ -25,6 +26,7 @@ from apache_beam.transforms import userstate
from typing import List, Tuple, Any
from pyflink.datastream import ReduceFunction
+from pyflink.datastream.functions import AggregateFunction
from pyflink.datastream.state import ValueState, ListState, MapState,
ReducingState
@@ -126,7 +128,7 @@ class SynchronousListRuntimeState(ListState):
class SynchronousReducingRuntimeState(ReducingState):
"""
- The runtime ListState implementation backed by a
:class:`SynchronousBagRuntimeState`.
+ The runtime ReducingState implementation backed by a
:class:`SynchronousBagRuntimeState`.
"""
def __init__(self, internal_state: SynchronousBagRuntimeState,
reduce_function: ReduceFunction):
@@ -150,6 +152,42 @@ class SynchronousReducingRuntimeState(ReducingState):
self._internal_state.clear()
+class SynchronousAggregatingRuntimeState(ReducingState):
+ """
+ The runtime AggregatingState implementation backed by a
:class:`SynchronousBagRuntimeState`.
+ """
+
+ def __init__(self, internal_state: SynchronousBagRuntimeState,
agg_function: AggregateFunction):
+ self._internal_state = internal_state
+ self._agg_function = agg_function
+
+ def add(self, v):
+ if v is None:
+ self.clear()
+ return
+ accumulator = self._get_accumulator()
+ if accumulator is None:
+ accumulator = self._agg_function.create_accumulator()
+ accumulator = self._agg_function.add(v, accumulator)
+ self._internal_state.clear()
+ self._internal_state.add(accumulator)
+
+ def get(self):
+ accumulator = self._get_accumulator()
+ if accumulator is None:
+ return None
+ else:
+ return self._agg_function.get_result(accumulator)
+
+ def _get_accumulator(self):
+ for i in self._internal_state.read():
+ return i
+ return None
+
+ def clear(self):
+ self._internal_state.clear()
+
+
class CachedMapState(LRUCache):
def __init__(self, max_entries):
@@ -775,22 +813,12 @@ class RemoteKeyedStateBackend(object):
key=self._encoded_current_key))
def get_list_state(self, name, element_coder):
- if name in self._all_states:
- self.validate_list_state(name, element_coder)
- return self._all_states[name]
- internal_bag_state = self._get_internal_bag_state(name, element_coder)
- list_state = SynchronousListRuntimeState(internal_bag_state)
- self._all_states[name] = list_state
- return list_state
+ return self._wrap_internal_bag_state(
+ name, element_coder, SynchronousListRuntimeState,
SynchronousListRuntimeState)
def get_value_state(self, name, value_coder):
- if name in self._all_states:
- self.validate_value_state(name, value_coder)
- return self._all_states[name]
- internal_bag_state = self._get_internal_bag_state(name, value_coder)
- value_state = SynchronousValueRuntimeState(internal_bag_state)
- self._all_states[name] = value_state
- return value_state
+ return self._wrap_internal_bag_state(
+ name, value_coder, SynchronousValueRuntimeState,
SynchronousValueRuntimeState)
def get_map_state(self, name, map_key_coder, map_value_coder):
if name in self._all_states:
@@ -802,29 +830,21 @@ class RemoteKeyedStateBackend(object):
return map_state
def get_reducing_state(self, name, coder, reduce_function):
- if name in self._all_states:
- self.validate_reducing_state(name, coder)
- return self._all_states[name]
- internal_bag_state = self._get_internal_bag_state(name, coder)
- reducing_state = SynchronousReducingRuntimeState(internal_bag_state,
reduce_function)
- self._all_states[name] = reducing_state
- return reducing_state
+ return self._wrap_internal_bag_state(
+ name, coder, SynchronousReducingRuntimeState,
+ partial(SynchronousReducingRuntimeState,
reduce_function=reduce_function))
- def validate_value_state(self, name, coder):
- if name in self._all_states:
- state = self._all_states[name]
- if not isinstance(state, SynchronousValueRuntimeState):
- raise Exception("The state name '%s' is already in use and not
a value state."
- % name)
- if state._internal_state._value_coder != coder:
- raise Exception("State name corrupted: %s" % name)
+ def get_aggregating_state(self, name, coder, agg_function):
+ return self._wrap_internal_bag_state(
+ name, coder, SynchronousAggregatingRuntimeState,
+ partial(SynchronousAggregatingRuntimeState,
agg_function=agg_function))
- def validate_list_state(self, name, coder):
+ def validate_state(self, name, coder, expected_type):
if name in self._all_states:
state = self._all_states[name]
- if not isinstance(state, SynchronousListRuntimeState):
- raise Exception("The state name '%s' is already in use and not
a list state."
- % name)
+ if not isinstance(state, expected_type):
+ raise Exception("The state name '%s' is already in use and not
a %s."
+ % (name, expected_type))
if state._internal_state._value_coder != coder:
raise Exception("State name corrupted: %s" % name)
@@ -838,19 +858,23 @@ class RemoteKeyedStateBackend(object):
state._internal_state._map_value_coder != map_value_coder:
raise Exception("State name corrupted: %s" % name)
- def validate_reducing_state(self, name, coder):
+ def _wrap_internal_bag_state(self, name, element_coder, wrapper_type,
wrap_method):
if name in self._all_states:
- state = self._all_states[name]
- if not isinstance(state, SynchronousReducingRuntimeState):
- raise Exception("The state name '%s' is already in use and not
a reducing state."
- % name)
- if state._internal_state._value_coder != coder:
- raise Exception("State name corrupted: %s" % name)
+ self.validate_state(name, element_coder, wrapper_type)
+ return self._all_states[name]
+ internal_state = self._get_internal_bag_state(name, element_coder)
+ wrapped_state = wrap_method(internal_state)
+ self._all_states[name] = wrapped_state
+ return wrapped_state
def _get_internal_bag_state(self, name, element_coder):
cached_state = self._internal_state_cache.get((name,
self._encoded_current_key))
if cached_state is not None:
return cached_state
+ # The newly created internal state is not put into the internal state
cache
+ # immediately. The internal state cache is only updated when the current
key changes.
+ # The reason is that the state cache size may be smaller than the
count of active
+ # states (i.e. the states with the current key).
state_spec = userstate.BagStateSpec(name, element_coder)
internal_state = self._create_bag_state(state_spec)
return internal_state
@@ -906,7 +930,8 @@ class RemoteKeyedStateBackend(object):
if isinstance(state_obj,
(SynchronousValueRuntimeState,
SynchronousListRuntimeState,
- SynchronousReducingRuntimeState)):
+ SynchronousReducingRuntimeState,
+ SynchronousAggregatingRuntimeState)):
state_obj._internal_state = self._get_internal_bag_state(
state_name, state_obj._internal_state._value_coder)
elif isinstance(state_obj, SynchronousMapRuntimeState):