This is an automated email from the ASF dual-hosted git repository.

weizhong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 044f8e1  [FLINK-21115][python] Add AggregatingState and corresponding 
StateDescriptor for Python DataStream API
044f8e1 is described below

commit 044f8e17237d0fdbff96ca973817c932a3c985ba
Author: Wei Zhong <[email protected]>
AuthorDate: Wed Mar 3 20:54:52 2021 +0800

    [FLINK-21115][python] Add AggregatingState and corresponding 
StateDescriptor for Python DataStream API
    
    This closes #15028.
---
 flink-python/pyflink/datastream/functions.py       |  92 ++++++++++++++++-
 flink-python/pyflink/datastream/state.py           |  42 ++++++++
 .../pyflink/datastream/tests/test_data_stream.py   |  49 ++++++++-
 flink-python/pyflink/fn_execution/operations.py    |   8 +-
 flink-python/pyflink/fn_execution/state_impl.py    | 109 +++++++++++++--------
 5 files changed, 254 insertions(+), 46 deletions(-)

diff --git a/flink-python/pyflink/datastream/functions.py 
b/flink-python/pyflink/datastream/functions.py
index 60e718f..b6952ce 100644
--- a/flink-python/pyflink/datastream/functions.py
+++ b/flink-python/pyflink/datastream/functions.py
@@ -23,7 +23,8 @@ from typing import Union, Any, Dict
 from py4j.java_gateway import JavaObject
 
 from pyflink.datastream.state import ValueState, ValueStateDescriptor, 
ListStateDescriptor, \
-    ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, 
ReducingState
+    ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, 
ReducingState, \
+    AggregatingStateDescriptor, AggregatingState
 from pyflink.datastream.time_domain import TimeDomain
 from pyflink.datastream.timerservice import TimerService
 from pyflink.java_gateway import get_gateway
@@ -35,6 +36,7 @@ __all__ = [
     'FlatMapFunction',
     'CoFlatMapFunction',
     'ReduceFunction',
+    'AggregateFunction',
     'KeySelector',
     'FilterFunction',
     'Partitioner',
@@ -156,6 +158,17 @@ class RuntimeContext(object):
         """
         pass
 
+    def get_aggregating_state(
+            self, state_descriptor: AggregatingStateDescriptor) -> 
AggregatingState:
+        """
+        Gets a handle to the system's key/value aggregating state. This state 
is similar to the
+        state accessed via get_state(ValueStateDescriptor), but is optimized 
for state that
+        aggregates values with different types.
+
+        This state is only accessible if the function is executed on a 
KeyedStream.
+        """
+        pass
+
 
 class Function(abc.ABC):
     """
@@ -342,6 +355,83 @@ class ReduceFunction(Function):
         pass
 
 
+class AggregateFunction(Function):
+    """
+    The AggregateFunction is a flexible aggregation function, characterized by 
the following
+    features:
+
+        - The aggregates may use different types for input values, 
intermediate aggregates, and
+          result type, to support a wide range of aggregation types.
+        - Support for distributive aggregations: Different intermediate 
aggregates can be merged
+          together, to allow for pre-aggregation/final-aggregation 
optimizations.
+
+    The AggregateFunction's intermediate aggregate (in-progress aggregation 
state) is called the
+    `accumulator`. Values are added to the accumulator, and final aggregates 
are obtained by
+    finalizing the accumulator state. This supports aggregation functions 
where the intermediate
+    state needs to be different than the aggregated values and the final 
result type, such as for
+    example average (which typically keeps a count and sum). Merging 
intermediate aggregates
+    (partial aggregates) means merging the accumulators.
+
+    The AggregateFunction itself is stateless. To allow a single 
AggregateFunction instance to
+    maintain multiple aggregates (such as one aggregate per key), the 
AggregateFunction creates a
+    new accumulator whenever a new aggregation is started.
+    """
+
+    @abc.abstractmethod
+    def create_accumulator(self):
+        """
+        Creates a new accumulator, starting a new aggregate.
+
+        The new accumulator is typically meaningless unless a value is added 
via
+        :func:`~AggregateFunction.add`.
+
+        The accumulator is the state of a running aggregation. When a program 
has multiple
+        aggregates in progress (such as per key and window), the state (per 
key and window) is the
+        size of the accumulator.
+
+        :return: A new accumulator, corresponding to an empty aggregate.
+        """
+        pass
+
+    @abc.abstractmethod
+    def add(self, value, accumulator):
+        """
+        Adds the given input value to the given accumulator, returning the new 
accumulator value.
+
+        For efficiency, the input accumulator may be modified and returned.
+
+        :param value: The value to add.
+        :param accumulator: The accumulator to add the value to.
+        :return: The accumulator with the updated state.
+        """
+        pass
+
+    @abc.abstractmethod
+    def get_result(self, accumulator):
+        """
+        Gets the result of the aggregation from the accumulator.
+
+        :param accumulator: The accumulator of the aggregation.
+        :return: The final aggregation result.
+        """
+        pass
+
+    @abc.abstractmethod
+    def merge(self, acc_a, acc_b):
+        """
+        Merges two accumulators, returning an accumulator with the merged 
state.
+
+        This function may reuse any of the given accumulators as the target 
for the merge and
+        return that. The assumption is that the given accumulators will not be 
used any more after
+        having been passed to this function.
+
+        :param acc_a: An accumulator to merge.
+        :param acc_b: Another accumulator to merge.
+        :return: The accumulator with the merged state.
+        """
+        pass
+
+
 class KeySelector(Function):
     """
     The KeySelector allows to use deterministic objects for operations such as 
reduce, reduceGroup,
diff --git a/flink-python/pyflink/datastream/state.py 
b/flink-python/pyflink/datastream/state.py
index 1e3cde9..4c11b29 100644
--- a/flink-python/pyflink/datastream/state.py
+++ b/flink-python/pyflink/datastream/state.py
@@ -134,6 +134,26 @@ class ReducingState(MergingState[T, T]):
     pass
 
 
+class AggregatingState(MergingState[IN, OUT]):
+    """
+    :class:`State` interface for aggregating state, based on an
+    :class:`~pyflink.datastream.functions.AggregateFunction`. Elements that 
are added to this type
+    of state will be eagerly pre-aggregated using a given AggregateFunction.
+
+    The state always holds internally the accumulator type of the 
AggregateFunction. When
+    accessing the result of the state, the function's
+    :func:`~pyflink.datastream.functions.AggregateFunction.get_result` method is 
applied to the current accumulator.
+
+    The state is accessed and modified by user functions, and checkpointed 
consistently by the
+    system as part of the distributed snapshots.
+
+    The state is only accessible by functions applied on a KeyedStream. The 
key is automatically
+    supplied by the system, so the function always sees the value mapped to 
the key of the current
+    element. That way, the system can handle stream and state partitioning 
consistently together.
+    """
+    pass
+
+
 class ListState(MergingState[T, Iterable[T]]):
     """
     :class:`State` interface for partitioned list state in Operations.
@@ -358,3 +378,25 @@ class ReducingStateDescriptor(StateDescriptor):
 
     def get_reduce_function(self):
         return self._reduce_function
+
+
+class AggregatingStateDescriptor(StateDescriptor):
+    """
+    A StateDescriptor for AggregatingState.
+
+    The type internally stored in the state is the type of the Accumulator of 
the
+    :func:`~pyflink.datastream.functions.AggregateFunction`.
+    """
+
+    def __init__(self,
+                 name: str,
+                 agg_function,
+                 state_type_info):
+        super(AggregatingStateDescriptor, self).__init__(name, state_type_info)
+        from pyflink.datastream.functions import AggregateFunction
+        if not isinstance(agg_function, AggregateFunction):
+            raise TypeError("The input must be a AggregateFunction!")
+        self._agg_function = agg_function
+
+    def get_agg_function(self):
+        return self._agg_function
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py 
b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 852bf8a..3d7a0da 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -25,12 +25,13 @@ from pyflink.common.typeinfo import Types
 from pyflink.common.watermark_strategy import WatermarkStrategy, 
TimestampAssigner
 from pyflink.datastream import StreamExecutionEnvironment, TimeCharacteristic, 
RuntimeContext
 from pyflink.datastream.data_stream import DataStream
-from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction
+from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction, 
AggregateFunction
 from pyflink.datastream.functions import FilterFunction, ProcessFunction, 
KeyedProcessFunction
 from pyflink.datastream.functions import KeySelector
 from pyflink.datastream.functions import MapFunction, FlatMapFunction
 from pyflink.datastream.state import ValueStateDescriptor, 
ListStateDescriptor, \
-    MapStateDescriptor, ReducingStateDescriptor, ReducingState
+    MapStateDescriptor, ReducingStateDescriptor, ReducingState, 
AggregatingState, \
+    AggregatingStateDescriptor
 from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
 from pyflink.java_gateway import get_gateway
 from pyflink.testing.test_case_utils import PyFlinkTestCase, 
invoke_java_object_method
@@ -874,6 +875,50 @@ class DataStreamTests(PyFlinkTestCase):
         expected_result.sort()
         self.assertEqual(expected_result, result)
 
+    def test_aggregating_state(self):
+        self.env.set_parallelism(2)
+        data_stream = self.env.from_collection([
+            (1, 'hi'), (2, 'hello'), (3, 'hi'), (4, 'hello'), (5, 'hi'), (6, 
'hello')],
+            type_info=Types.TUPLE([Types.INT(), Types.STRING()]))
+
+        class MyAggregateFunction(AggregateFunction):
+
+            def create_accumulator(self):
+                return 0
+
+            def add(self, value, accumulator):
+                return value + accumulator
+
+            def get_result(self, accumulator):
+                return accumulator
+
+            def merge(self, acc_a, acc_b):
+                return acc_a + acc_b
+
+        class MyProcessFunction(KeyedProcessFunction):
+
+            def __init__(self):
+                self.aggregating_state = None  # type: AggregatingState
+
+            def open(self, runtime_context: RuntimeContext):
+                self.aggregating_state = runtime_context.get_aggregating_state(
+                    AggregatingStateDescriptor(
+                        'aggregating_state', MyAggregateFunction(), 
Types.INT()))
+
+            def process_element(self, value, ctx):
+                self.aggregating_state.add(value[0])
+                yield Row(self.aggregating_state.get(), value[1])
+
+        data_stream.key_by(lambda x: x[1], key_type_info=Types.STRING()) \
+            .process(MyProcessFunction(), 
output_type=Types.TUPLE([Types.INT(), Types.STRING()])) \
+            .add_sink(self.test_sink)
+        self.env.execute('test_aggregating_state')
+        result = self.test_sink.get_results()
+        expected_result = ['(1,hi)', '(2,hello)', '(4,hi)', '(6,hello)', 
'(9,hi)', '(12,hello)']
+        result.sort()
+        expected_result.sort()
+        self.assertEqual(expected_result, result)
+
     def tearDown(self) -> None:
         self.test_sink.clear()
 
diff --git a/flink-python/pyflink/fn_execution/operations.py 
b/flink-python/pyflink/fn_execution/operations.py
index 93c623e..f28c430 100644
--- a/flink-python/pyflink/fn_execution/operations.py
+++ b/flink-python/pyflink/fn_execution/operations.py
@@ -24,7 +24,8 @@ from typing import List, Tuple, Any, Dict
 from apache_beam.coders import PickleCoder
 
 from pyflink.datastream.state import ValueStateDescriptor, ValueState, 
ListStateDescriptor, \
-    ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, 
ReducingState
+    ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, 
ReducingState, \
+    AggregatingStateDescriptor, AggregatingState
 from pyflink.datastream import TimeDomain, TimerService
 from pyflink.datastream.functions import RuntimeContext, ProcessFunction, 
KeyedProcessFunction
 from pyflink.fn_execution import flink_fn_execution_pb2, operation_utils
@@ -461,6 +462,11 @@ class InternalRuntimeContext(RuntimeContext):
         return self._keyed_state_backend.get_reducing_state(
             state_descriptor.get_name(), PickleCoder(), 
state_descriptor.get_reduce_function())
 
+    def get_aggregating_state(
+            self, state_descriptor: AggregatingStateDescriptor) -> 
AggregatingState:
+        return self._keyed_state_backend.get_aggregating_state(
+            state_descriptor.get_name(), PickleCoder(), 
state_descriptor.get_agg_function())
+
 
 class ProcessFunctionOperation(DataStreamStatelessFunctionOperation):
 
diff --git a/flink-python/pyflink/fn_execution/state_impl.py 
b/flink-python/pyflink/fn_execution/state_impl.py
index d596df3..41c5972 100644
--- a/flink-python/pyflink/fn_execution/state_impl.py
+++ b/flink-python/pyflink/fn_execution/state_impl.py
@@ -17,6 +17,7 @@
 
################################################################################
 import collections
 from enum import Enum
+from functools import partial
 
 from apache_beam.coders import coder_impl
 from apache_beam.portability.api import beam_fn_api_pb2
@@ -25,6 +26,7 @@ from apache_beam.transforms import userstate
 from typing import List, Tuple, Any
 
 from pyflink.datastream import ReduceFunction
+from pyflink.datastream.functions import AggregateFunction
 from pyflink.datastream.state import ValueState, ListState, MapState, 
ReducingState
 
 
@@ -126,7 +128,7 @@ class SynchronousListRuntimeState(ListState):
 
 class SynchronousReducingRuntimeState(ReducingState):
     """
-    The runtime ListState implementation backed by a 
:class:`SynchronousBagRuntimeState`.
+    The runtime ReducingState implementation backed by a 
:class:`SynchronousBagRuntimeState`.
     """
 
     def __init__(self, internal_state: SynchronousBagRuntimeState, 
reduce_function: ReduceFunction):
@@ -150,6 +152,42 @@ class SynchronousReducingRuntimeState(ReducingState):
         self._internal_state.clear()
 
 
+class SynchronousAggregatingRuntimeState(ReducingState):
+    """
+    The runtime AggregatingState implementation backed by a 
:class:`SynchronousBagRuntimeState`.
+    """
+
+    def __init__(self, internal_state: SynchronousBagRuntimeState, 
agg_function: AggregateFunction):
+        self._internal_state = internal_state
+        self._agg_function = agg_function
+
+    def add(self, v):
+        if v is None:
+            self.clear()
+            return
+        accumulator = self._get_accumulator()
+        if accumulator is None:
+            accumulator = self._agg_function.create_accumulator()
+        accumulator = self._agg_function.add(v, accumulator)
+        self._internal_state.clear()
+        self._internal_state.add(accumulator)
+
+    def get(self):
+        accumulator = self._get_accumulator()
+        if accumulator is None:
+            return None
+        else:
+            return self._agg_function.get_result(accumulator)
+
+    def _get_accumulator(self):
+        for i in self._internal_state.read():
+            return i
+        return None
+
+    def clear(self):
+        self._internal_state.clear()
+
+
 class CachedMapState(LRUCache):
 
     def __init__(self, max_entries):
@@ -775,22 +813,12 @@ class RemoteKeyedStateBackend(object):
                 key=self._encoded_current_key))
 
     def get_list_state(self, name, element_coder):
-        if name in self._all_states:
-            self.validate_list_state(name, element_coder)
-            return self._all_states[name]
-        internal_bag_state = self._get_internal_bag_state(name, element_coder)
-        list_state = SynchronousListRuntimeState(internal_bag_state)
-        self._all_states[name] = list_state
-        return list_state
+        return self._wrap_internal_bag_state(
+            name, element_coder, SynchronousListRuntimeState, 
SynchronousListRuntimeState)
 
     def get_value_state(self, name, value_coder):
-        if name in self._all_states:
-            self.validate_value_state(name, value_coder)
-            return self._all_states[name]
-        internal_bag_state = self._get_internal_bag_state(name, value_coder)
-        value_state = SynchronousValueRuntimeState(internal_bag_state)
-        self._all_states[name] = value_state
-        return value_state
+        return self._wrap_internal_bag_state(
+            name, value_coder, SynchronousValueRuntimeState, 
SynchronousValueRuntimeState)
 
     def get_map_state(self, name, map_key_coder, map_value_coder):
         if name in self._all_states:
@@ -802,29 +830,21 @@ class RemoteKeyedStateBackend(object):
         return map_state
 
     def get_reducing_state(self, name, coder, reduce_function):
-        if name in self._all_states:
-            self.validate_reducing_state(name, coder)
-            return self._all_states[name]
-        internal_bag_state = self._get_internal_bag_state(name, coder)
-        reducing_state = SynchronousReducingRuntimeState(internal_bag_state, 
reduce_function)
-        self._all_states[name] = reducing_state
-        return reducing_state
+        return self._wrap_internal_bag_state(
+            name, coder, SynchronousReducingRuntimeState,
+            partial(SynchronousReducingRuntimeState, 
reduce_function=reduce_function))
 
-    def validate_value_state(self, name, coder):
-        if name in self._all_states:
-            state = self._all_states[name]
-            if not isinstance(state, SynchronousValueRuntimeState):
-                raise Exception("The state name '%s' is already in use and not 
a value state."
-                                % name)
-            if state._internal_state._value_coder != coder:
-                raise Exception("State name corrupted: %s" % name)
+    def get_aggregating_state(self, name, coder, agg_function):
+        return self._wrap_internal_bag_state(
+            name, coder, SynchronousAggregatingRuntimeState,
+            partial(SynchronousAggregatingRuntimeState, 
agg_function=agg_function))
 
-    def validate_list_state(self, name, coder):
+    def validate_state(self, name, coder, expected_type):
         if name in self._all_states:
             state = self._all_states[name]
-            if not isinstance(state, SynchronousListRuntimeState):
-                raise Exception("The state name '%s' is already in use and not 
a list state."
-                                % name)
+            if not isinstance(state, expected_type):
+                raise Exception("The state name '%s' is already in use and not 
a %s."
+                                % (name, expected_type))
             if state._internal_state._value_coder != coder:
                 raise Exception("State name corrupted: %s" % name)
 
@@ -838,19 +858,23 @@ class RemoteKeyedStateBackend(object):
                     state._internal_state._map_value_coder != map_value_coder:
                 raise Exception("State name corrupted: %s" % name)
 
-    def validate_reducing_state(self, name, coder):
+    def _wrap_internal_bag_state(self, name, element_coder, wrapper_type, 
wrap_method):
         if name in self._all_states:
-            state = self._all_states[name]
-            if not isinstance(state, SynchronousReducingRuntimeState):
-                raise Exception("The state name '%s' is already in use and not 
a reducing state."
-                                % name)
-            if state._internal_state._value_coder != coder:
-                raise Exception("State name corrupted: %s" % name)
+            self.validate_state(name, element_coder, wrapper_type)
+            return self._all_states[name]
+        internal_state = self._get_internal_bag_state(name, element_coder)
+        wrapped_state = wrap_method(internal_state)
+        self._all_states[name] = wrapped_state
+        return wrapped_state
 
     def _get_internal_bag_state(self, name, element_coder):
         cached_state = self._internal_state_cache.get((name, 
self._encoded_current_key))
         if cached_state is not None:
             return cached_state
+        # The created internal state would not be put into the internal state 
cache
+        # at once. The internal state cache is only updated when the current 
key changes.
+        # The reason is that the state cache size may be smaller than the 
count of activated
+        # state (i.e. the state with current key).
         state_spec = userstate.BagStateSpec(name, element_coder)
         internal_state = self._create_bag_state(state_spec)
         return internal_state
@@ -906,7 +930,8 @@ class RemoteKeyedStateBackend(object):
             if isinstance(state_obj,
                           (SynchronousValueRuntimeState,
                            SynchronousListRuntimeState,
-                           SynchronousReducingRuntimeState)):
+                           SynchronousReducingRuntimeState,
+                           SynchronousAggregatingRuntimeState)):
                 state_obj._internal_state = self._get_internal_bag_state(
                     state_name, state_obj._internal_state._value_coder)
             elif isinstance(state_obj, SynchronousMapRuntimeState):

Reply via email to