This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c486d1  [FLINK-20528][python] Support table aggregation for group aggregation in streaming mode
9c486d1 is described below

commit 9c486d1c1d245263630b2e9dc0f3d6a95e130a4d
Author: HuangXingBo <[email protected]>
AuthorDate: Thu Dec 17 10:00:11 2020 +0800

    [FLINK-20528][python] Support table aggregation for group aggregation in streaming mode
    
    This closes #14389.
---
 flink-python/pyflink/fn_execution/aggregate.py     | 207 +++++++++++++++++----
 .../pyflink/fn_execution/beam/beam_operations.py   |  10 +
 .../pyflink/fn_execution/operation_utils.py        |   4 +-
 flink-python/pyflink/fn_execution/operations.py    |  87 +++++++--
 flink-python/pyflink/table/__init__.py             |   6 +-
 flink-python/pyflink/table/table.py                |  76 ++++++++
 flink-python/pyflink/table/table_environment.py    |  12 +-
 .../table/tests/test_row_based_operation.py        | 109 ++++++++++-
 flink-python/pyflink/table/udf.py                  | 164 +++++++++++++---
 .../AbstractPythonStreamAggregateOperator.java     |  25 +++
 .../PythonStreamGroupAggregateOperator.java        |  25 +--
 .../PythonStreamGroupTableAggregateOperator.java   |  23 +--
 ...ythonStreamGroupTableAggregateOperatorTest.java |  17 +-
 .../python/PythonTableAggregateFunction.java       | 127 +++++++++++++
 .../plan/nodes/common/CommonPythonAggregate.scala  |   9 +-
 .../StreamExecPythonGroupTableAggregate.scala      | 123 +++++++++++-
 16 files changed, 881 insertions(+), 143 deletions(-)
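
For reference, the user-facing API added by this patch can be exercised as follows. This is a minimal sketch adapted from the Top2 function in the tests included below; the StreamTableEnvironment t_env (created with the blink planner) is assumed to already exist and is not part of the patch.

    from pyflink.common import Row
    from pyflink.table import DataTypes, TableAggregateFunction
    from pyflink.table.udf import udtaf

    class Top2(TableAggregateFunction):

        def create_accumulator(self):
            # the accumulator keeps the two largest values seen so far
            return [None, None]

        def accumulate(self, accumulator, *args):
            if args[0] is not None:
                if accumulator[0] is None or args[0] > accumulator[0]:
                    accumulator[1] = accumulator[0]
                    accumulator[0] = args[0]
                elif accumulator[1] is None or args[0] > accumulator[1]:
                    accumulator[1] = args[0]

        def emit_value(self, accumulator):
            # unlike AggregateFunction.get_value, emit_value yields multiple rows
            yield Row(accumulator[0])
            yield Row(accumulator[1])

        def get_accumulator_type(self):
            return DataTypes.ARRAY(DataTypes.BIGINT())

        def get_result_type(self):
            return DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT())])

    top2 = udtaf(Top2())
    # t_env is assumed; the table and column names are illustrative only
    t = t_env.from_elements([(1, 'Hello'), (3, 'Hello'), (5, 'hi')], ['a', 'c'])
    result = t.group_by(t.c).flat_aggregate(top2(t.a).alias("a")).select("c, a")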

diff --git a/flink-python/pyflink/fn_execution/aggregate.py b/flink-python/pyflink/fn_execution/aggregate.py
index 1971623..b719d0e 100644
--- a/flink-python/pyflink/fn_execution/aggregate.py
+++ b/flink-python/pyflink/fn_execution/aggregate.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 ################################################################################
 from abc import ABC, abstractmethod
-from typing import List, Dict
+from typing import List, Dict, Iterable
 
 from apache_beam.coders import PickleCoder, Coder
 
@@ -25,8 +25,9 @@ from pyflink.common.state import ListState, MapState
 from pyflink.fn_execution.coders import from_proto
 from pyflink.fn_execution.operation_utils import is_built_in_function, load_aggregate_function
 from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
-from pyflink.table import AggregateFunction, FunctionContext
+from pyflink.table import AggregateFunction, FunctionContext, TableAggregateFunction
 from pyflink.table.data_view import ListView, MapView
+from pyflink.table.udf import ImperativeAggregateFunction
 
 
 def join_row(left: Row, right: Row):
@@ -219,9 +220,9 @@ class StateDataViewStore(object):
             self._keyed_state_backend.get_map_state(state_name, key_coder, value_coder))
 
 
-class AggsHandleFunction(ABC):
+class AggsHandleFunctionBase(ABC):
     """
-    The base class for handling aggregate functions.
+    The base class for handling aggregate or table aggregate functions.
     """
 
     @abstractmethod
@@ -301,6 +302,20 @@ class AggsHandleFunction(ABC):
         pass
 
     @abstractmethod
+    def close(self):
+        """
+        Tear-down method for this function. It can be used for clean up work.
+        By default, this method does nothing.
+        """
+        pass
+
+
+class AggsHandleFunction(AggsHandleFunctionBase):
+    """
+    The base class for handling aggregate functions.
+    """
+
+    @abstractmethod
     def get_value(self) -> Row:
         """
         Gets the result of the aggregation from the current accumulators.
@@ -309,25 +324,28 @@ class AggsHandleFunction(ABC):
         """
         pass
 
+
+class TableAggsHandleFunction(AggsHandleFunctionBase):
+    """
+    The base class for handling table aggregate functions.
+    """
+
     @abstractmethod
-    def close(self):
+    def emit_value(self, current_key: Row, is_retract: bool) -> Iterable[Row]:
         """
-        Tear-down method for this function. It can be used for clean up work.
-        By default, this method does nothing.
+        Emit the result of the table aggregation.
         """
         pass
 
 
-class SimpleAggsHandleFunction(AggsHandleFunction):
+class SimpleAggsHandleFunctionBase(AggsHandleFunctionBase):
     """
-    A simple AggsHandleFunction implementation which provides the basic functionality.
+    A simple AggsHandleFunctionBase implementation which provides the basic functionality.
     """
 
     def __init__(self,
-                 udfs: List[AggregateFunction],
+                 udfs: List[ImperativeAggregateFunction],
                  input_extractors: List,
-                 index_of_count_star: int,
-                 count_star_inserted: bool,
                  udf_data_view_specs: List[List[DataViewSpec]],
                  filter_args: List[int],
                  distinct_indexes: List[int],
@@ -335,10 +353,6 @@ class SimpleAggsHandleFunction(AggsHandleFunction):
         self._udfs = udfs
         self._input_extractors = input_extractors
         self._accumulators = None  # type: Row
-        self._get_value_indexes = [i for i in range(len(udfs))]
-        if index_of_count_star >= 0 and count_star_inserted:
-            # The record count is used internally, should be ignored by the get_value method.
-            self._get_value_indexes.remove(index_of_count_star)
         self._udf_data_view_specs = udf_data_view_specs
         self._udf_data_views = []
         self._filter_args = filter_args
@@ -451,13 +465,64 @@ class SimpleAggsHandleFunction(AggsHandleFunction):
             for data_view in self._udf_data_views[i].values():
                 data_view.clear()
 
+    def close(self):
+        for udf in self._udfs:
+            udf.close()
+
+
+class SimpleAggsHandleFunction(SimpleAggsHandleFunctionBase, AggsHandleFunction):
+    """
+    A simple AggsHandleFunction implementation which provides the basic functionality.
+    """
+
+    def __init__(self,
+                 udfs: List[AggregateFunction],
+                 input_extractors: List,
+                 index_of_count_star: int,
+                 count_star_inserted: bool,
+                 udf_data_view_specs: List[List[DataViewSpec]],
+                 filter_args: List[int],
+                 distinct_indexes: List[int],
+                 distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
+        super(SimpleAggsHandleFunction, self).__init__(
+            udfs, input_extractors, udf_data_view_specs, filter_args, distinct_indexes,
+            distinct_view_descriptors)
+        self._get_value_indexes = [i for i in range(len(udfs))]
+        if index_of_count_star >= 0 and count_star_inserted:
+            # The record count is used internally, should be ignored by the get_value method.
+            self._get_value_indexes.remove(index_of_count_star)
+
     def get_value(self):
         return Row(*[self._udfs[i].get_value(self._accumulators[i])
                      for i in self._get_value_indexes])
 
-    def close(self):
-        for udf in self._udfs:
-            udf.close()
+
+class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase, TableAggsHandleFunction):
+    """
+    A simple TableAggsHandleFunction implementation which provides the basic functionality.
+    """
+
+    def __init__(self,
+                 udfs: List[TableAggregateFunction],
+                 input_extractors: List,
+                 udf_data_view_specs: List[List[DataViewSpec]],
+                 filter_args: List[int],
+                 distinct_indexes: List[int],
+                 distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
+        super(SimpleTableAggsHandleFunction, self).__init__(
+            udfs, input_extractors, udf_data_view_specs, filter_args, distinct_indexes,
+            distinct_view_descriptors)
+
+    def emit_value(self, current_key: Row, is_retract: bool):
+        udf = self._udfs[0]  # type: TableAggregateFunction
+        results = udf.emit_value(self._accumulators[0])
+        for x in results:
+            result = join_row(current_key, x)
+            if is_retract:
+                result.set_row_kind(RowKind.DELETE)
+            else:
+                result.set_row_kind(RowKind.INSERT)
+            yield result
 
 
 class RecordCounter(ABC):
@@ -494,10 +559,10 @@ class RetractionRecordCounter(RecordCounter):
         return acc is None or acc[self._index_of_count_star][0] == 0
 
 
-class GroupAggFunction(object):
+class GroupAggFunctionBase(object):
 
     def __init__(self,
-                 aggs_handle: AggsHandleFunction,
+                 aggs_handle: AggsHandleFunctionBase,
                  key_selector: RowKeySelector,
                  state_backend: RemoteKeyedStateBackend,
                  state_value_coder: Coder,
@@ -518,6 +583,41 @@ class GroupAggFunction(object):
     def close(self):
         self.aggs_handle.close()
 
+    def on_timer(self, key):
+        if self.state_cleaning_enabled:
+            self.state_backend.set_current_key(key)
+            accumulator_state = self.state_backend.get_value_state(
+                "accumulators", self.state_value_coder)
+            accumulator_state.clear()
+            self.aggs_handle.cleanup()
+
+    @staticmethod
+    def is_retract_msg(data: Row):
+        return data.get_row_kind() == RowKind.UPDATE_BEFORE or data.get_row_kind() == RowKind.DELETE
+
+    @staticmethod
+    def is_accumulate_msg(data: Row):
+        return data.get_row_kind() == RowKind.UPDATE_AFTER or data.get_row_kind() == RowKind.INSERT
+
+    @abstractmethod
+    def process_element(self, input_data: Row):
+        pass
+
+
+class GroupAggFunction(GroupAggFunctionBase):
+
+    def __init__(self,
+                 aggs_handle: AggsHandleFunction,
+                 key_selector: RowKeySelector,
+                 state_backend: RemoteKeyedStateBackend,
+                 state_value_coder: Coder,
+                 generate_update_before: bool,
+                 state_cleaning_enabled: bool,
+                 index_of_count_star: int):
+        super(GroupAggFunction, self).__init__(
+            aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
+            state_cleaning_enabled, index_of_count_star)
+
     def process_element(self, input_data: Row):
         key = self.key_selector.get_key(input_data)
         self.state_backend.set_current_key(key)
@@ -597,20 +697,55 @@ class GroupAggFunction(object):
             # cleanup dataview under current key
             self.aggs_handle.cleanup()
 
-    def on_timer(self, key):
-        if self.state_cleaning_enabled:
-            self.state_backend.set_current_key(key)
-            accumulator_state = self.state_backend.get_value_state(
-                "accumulators", self.state_value_coder)
-            accumulator_state.clear()
-            self.aggs_handle.cleanup()
 
-    @staticmethod
-    def is_retract_msg(data: Row):
-        return data.get_row_kind() == RowKind.UPDATE_BEFORE \
-            or data.get_row_kind() == RowKind.DELETE
+class GroupTableAggFunction(GroupAggFunctionBase):
+    def __init__(self,
+                 aggs_handle: TableAggsHandleFunction,
+                 key_selector: RowKeySelector,
+                 state_backend: RemoteKeyedStateBackend,
+                 state_value_coder: Coder,
+                 generate_update_before: bool,
+                 state_cleaning_enabled: bool,
+                 index_of_count_star: int):
+        super(GroupTableAggFunction, self).__init__(
+            aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
+            state_cleaning_enabled, index_of_count_star)
 
-    @staticmethod
-    def is_accumulate_msg(data: Row):
-        return data.get_row_kind() == RowKind.UPDATE_AFTER \
-            or data.get_row_kind() == RowKind.INSERT
+    def process_element(self, input_data: Row):
+        key = self.key_selector.get_key(input_data)
+        self.state_backend.set_current_key(key)
+        self.state_backend.clear_cached_iterators()
+        accumulator_state = self.state_backend.get_value_state(
+            "accumulators", self.state_value_coder)
+        accumulators = accumulator_state.value()
+        if accumulators is None:
+            first_row = True
+            accumulators = self.aggs_handle.create_accumulators()
+        else:
+            first_row = False
+
+        # set accumulators to handler first
+        self.aggs_handle.set_accumulators(accumulators)
+
+        if not first_row and self.generate_update_before:
+            yield from self.aggs_handle.emit_value(key, True)
+
+        # update aggregate result and set to the newRow
+        if self.is_accumulate_msg(input_data):
+            # accumulate input
+            self.aggs_handle.accumulate(input_data)
+        else:
+            # retract input
+            self.aggs_handle.retract(input_data)
+
+        # get accumulator
+        accumulators = self.aggs_handle.get_accumulators()
+
+        if not self.record_counter.record_count_is_zero(accumulators):
+            yield from self.aggs_handle.emit_value(key, False)
+            accumulator_state.update(accumulators)
+        else:
+            # and clear all state
+            accumulator_state.clear()
+            # cleanup dataview under current key
+            self.aggs_handle.cleanup()
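
To make the changelog semantics of GroupTableAggFunction.process_element above concrete, here is a minimal, Flink-free sketch of the retract-then-emit protocol: when generate_update_before is enabled, the rows emitted for the previous accumulator state are first retracted (RowKind.DELETE), then the rows for the updated state are emitted (RowKind.INSERT). The function and names below are illustrative only; a Top2-style accumulator is assumed.

    def process(acc, new_value, first_row, generate_update_before):
        events = []
        if not first_row and generate_update_before:
            # retract the rows emitted for the previous accumulator state
            events += [("-D", v) for v in acc if v is not None]
        # accumulate: keep the two largest values seen so far
        if acc[0] is None or new_value > acc[0]:
            acc[0], acc[1] = new_value, acc[0]
        elif acc[1] is None or new_value > acc[1]:
            acc[1] = new_value
        # emit the rows for the new accumulator state
        events += [("+I", v) for v in acc if v is not None]
        return events

    acc = [None, None]
    print(process(acc, 1, True, True))   # [('+I', 1)]
    print(process(acc, 3, False, True))  # [('-D', 1), ('+I', 3), ('+I', 1)]
    print(process(acc, 5, False, True))  # [('-D', 3), ('-D', 1), ('+I', 5), ('+I', 3)]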
diff --git a/flink-python/pyflink/fn_execution/beam/beam_operations.py b/flink-python/pyflink/fn_execution/beam/beam_operations.py
index 3b95201..c36515d 100644
--- a/flink-python/pyflink/fn_execution/beam/beam_operations.py
+++ b/flink-python/pyflink/fn_execution/beam/beam_operations.py
@@ -110,6 +110,16 @@ def create_data_stream_keyed_process_function(factory, transform_id, transform_p
         operations.KeyedProcessFunctionOperation)
 
 
+@bundle_processor.BeamTransformFactory.register_urn(
+    operations.STREAM_GROUP_TABLE_AGGREGATE_URN,
+    flink_fn_execution_pb2.UserDefinedAggregateFunctions)
+def create_table_aggregate_function(factory, transform_id, transform_proto, parameter, consumers):
+    return _create_user_defined_function_operation(
+        factory, transform_proto, consumers, parameter,
+        beam_operations.StatefulFunctionOperation,
+        operations.StreamGroupTableAggregateOperation)
+
+
 def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
                                             beam_operation_cls, internal_operation_cls):
     output_tags = list(transform_proto.outputs.keys())
diff --git a/flink-python/pyflink/fn_execution/operation_utils.py b/flink-python/pyflink/fn_execution/operation_utils.py
index 5e8e768..2dbd8e8 100644
--- a/flink-python/pyflink/fn_execution/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/operation_utils.py
@@ -26,7 +26,7 @@ from pyflink.fn_execution import flink_fn_execution_pb2, pickle
 from pyflink.serializers import PickleSerializer
 from pyflink.table import functions
 from pyflink.table.udf import DelegationTableFunction, DelegatingScalarFunction, \
-    AggregateFunction, PandasAggregateFunctionWrapper
+    ImperativeAggregateFunction, PandasAggregateFunctionWrapper
 
 _func_num = 0
 _constant_num = 0
@@ -147,7 +147,7 @@ def extract_user_defined_aggregate_function(
         user_defined_function_proto,
         distinct_info_dict: Dict[Tuple[List[str]], Tuple[List[int], List[int]]]):
     user_defined_agg = load_aggregate_function(user_defined_function_proto.payload)
-    assert isinstance(user_defined_agg, AggregateFunction)
+    assert isinstance(user_defined_agg, ImperativeAggregateFunction)
     args_str = []
     local_variable_dict = {}
     for arg in user_defined_function_proto.inputs:
diff --git a/flink-python/pyflink/fn_execution/operations.py b/flink-python/pyflink/fn_execution/operations.py
index dd4c7cf..9cfa1da 100644
--- a/flink-python/pyflink/fn_execution/operations.py
+++ b/flink-python/pyflink/fn_execution/operations.py
@@ -30,7 +30,8 @@ from pyflink.fn_execution import flink_fn_execution_pb2, operation_utils
 from pyflink.fn_execution.beam.beam_coders import DataViewFilterCoder
 from pyflink.fn_execution.operation_utils import extract_user_defined_aggregate_function
 from pyflink.fn_execution.aggregate import RowKeySelector, SimpleAggsHandleFunction, \
-    GroupAggFunction, extract_data_view_specs, DistinctViewDescriptor
+    GroupAggFunction, extract_data_view_specs, DistinctViewDescriptor, \
+    SimpleTableAggsHandleFunction, GroupTableAggFunction
 from pyflink.metrics.metricbase import GenericMetricGroup
 from pyflink.table import FunctionContext, Row
 
@@ -39,6 +40,7 @@ from pyflink.table import FunctionContext, Row
 SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1"
 TABLE_FUNCTION_URN = "flink:transform:table_function:v1"
 STREAM_GROUP_AGGREGATE_URN = "flink:transform:stream_group_aggregate:v1"
+STREAM_GROUP_TABLE_AGGREGATE_URN = "flink:transform:stream_group_table_aggregate:v1"
 PANDAS_AGGREGATE_FUNCTION_URN = "flink:transform:aggregate_function:arrow:v1"
 PANDAS_BATCH_OVER_WINDOW_AGGREGATE_FUNCTION_URN = \
     "flink:transform:batch_over_window_aggregate_function:arrow:v1"
@@ -262,7 +264,7 @@ class StatefulFunctionOperation(Operation):
 TRIGGER_TIMER = 1
 
 
-class StreamGroupAggregateOperation(StatefulFunctionOperation):
+class AbstractStreamGroupAggregateOperation(StatefulFunctionOperation):
 
     def __init__(self, spec, keyed_state_backend):
         self.generate_update_before = spec.serialized_fn.generate_update_before
@@ -276,7 +278,7 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
         self.state_cache_size = spec.serialized_fn.state_cache_size
         self.state_cleaning_enabled = spec.serialized_fn.state_cleaning_enabled
         self.data_view_specs = extract_data_view_specs(spec.serialized_fn.udfs)
-        super(StreamGroupAggregateOperation, self).__init__(spec, keyed_state_backend)
+        super(AbstractStreamGroupAggregateOperation, self).__init__(spec, keyed_state_backend)
 
     def open(self):
         self.group_agg_function.open(FunctionContext(self.base_metric_group))
@@ -310,6 +312,44 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
             # use the agg index of the first function as the key of shared distinct view
             distinct_view_descriptors[agg_index_list[0]] = DistinctViewDescriptor(
                 input_extractors[agg_index_list[0]], filter_arg_list)
+
+        key_selector = RowKeySelector(self.grouping)
+        if len(self.data_view_specs) > 0:
+            state_value_coder = DataViewFilterCoder(self.data_view_specs)
+        else:
+            state_value_coder = PickleCoder()
+
+        self.group_agg_function = self.create_process_function(
+            user_defined_aggs, input_extractors, filter_args, distinct_indexes,
+            distinct_view_descriptors, key_selector, state_value_coder)
+
+        return self.process_element_or_timer, []
+
+    def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
+        # the structure of the input data:
+        # [element_type, element(for process_element), timestamp(for timer), key(for timer)]
+        # all the fields are nullable except the "element_type"
+        if input_data[0] != TRIGGER_TIMER:
+            return self.group_agg_function.process_element(input_data[1])
+        else:
+            self.group_agg_function.on_timer(input_data[3])
+            return []
+
+    @abc.abstractmethod
+    def create_process_function(self, user_defined_aggs, input_extractors, filter_args,
+                                distinct_indexes, distinct_view_descriptors, key_selector,
+                                state_value_coder):
+        pass
+
+
+class StreamGroupAggregateOperation(AbstractStreamGroupAggregateOperation):
+
+    def __init__(self, spec, keyed_state_backend):
+        super(StreamGroupAggregateOperation, self).__init__(spec, keyed_state_backend)
+
+    def create_process_function(self, user_defined_aggs, input_extractors, filter_args,
+                                distinct_indexes, distinct_view_descriptors, key_selector,
+                                state_value_coder):
         aggs_handler_function = SimpleAggsHandleFunction(
             user_defined_aggs,
             input_extractors,
@@ -319,12 +359,8 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
             filter_args,
             distinct_indexes,
             distinct_view_descriptors)
-        key_selector = RowKeySelector(self.grouping)
-        if len(self.data_view_specs) > 0:
-            state_value_coder = DataViewFilterCoder(self.data_view_specs)
-        else:
-            state_value_coder = PickleCoder()
-        self.group_agg_function = GroupAggFunction(
+
+        return GroupAggFunction(
             aggs_handler_function,
             key_selector,
             self.keyed_state_backend,
@@ -332,17 +368,30 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
             self.generate_update_before,
             self.state_cleaning_enabled,
             self.index_of_count_star)
-        return self.process_element_or_timer, []
 
-    def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
-        # the structure of the input data:
-        # [element_type, element(for process_element), timestamp(for timer), key(for timer)]
-        # all the fields are nullable except the "element_type"
-        if input_data[0] != TRIGGER_TIMER:
-            return self.group_agg_function.process_element(input_data[1])
-        else:
-            self.group_agg_function.on_timer(input_data[3])
-            return []
+
+class StreamGroupTableAggregateOperation(AbstractStreamGroupAggregateOperation):
+    def __init__(self, spec, keyed_state_backend):
+        super(StreamGroupTableAggregateOperation, self).__init__(spec, keyed_state_backend)
+
+    def create_process_function(self, user_defined_aggs, input_extractors, filter_args,
+                                distinct_indexes, distinct_view_descriptors, key_selector,
+                                state_value_coder):
+        aggs_handler_function = SimpleTableAggsHandleFunction(
+            user_defined_aggs,
+            input_extractors,
+            self.data_view_specs,
+            filter_args,
+            distinct_indexes,
+            distinct_view_descriptors)
+        return GroupTableAggFunction(
+            aggs_handler_function,
+            key_selector,
+            self.keyed_state_backend,
+            state_value_coder,
+            self.generate_update_before,
+            self.state_cleaning_enabled,
+            self.index_of_count_star)
 
 
 class DataStreamStatelessFunctionOperation(Operation):
diff --git a/flink-python/pyflink/table/__init__.py b/flink-python/pyflink/table/__init__.py
index 7f0f3b2..b082331 100644
--- a/flink-python/pyflink/table/__init__.py
+++ b/flink-python/pyflink/table/__init__.py
@@ -61,6 +61,8 @@ Important classes of Flink Table API:
       Base interface for user-defined table function.
     - :class:`pyflink.table.AggregateFunction`
       Base interface for user-defined aggregate function.
+    - :class:`pyflink.table.TableAggregateFunction`
+      Base interface for user-defined table aggregate function.
     - :class:`pyflink.table.StatementSet`
       Base interface accepts DML statements or Tables.
 """
@@ -84,7 +86,8 @@ from pyflink.table.table_environment import (TableEnvironment, StreamTableEnviro
 from pyflink.table.table_result import TableResult
 from pyflink.table.table_schema import TableSchema
 from pyflink.table.types import DataTypes, UserDefinedType, Row, RowKind
-from pyflink.table.udf import FunctionContext, ScalarFunction, TableFunction, AggregateFunction
+from pyflink.table.udf import FunctionContext, ScalarFunction, TableFunction, AggregateFunction, \
+    TableAggregateFunction
 
 __all__ = [
     'AggregateFunction',
@@ -118,6 +121,7 @@ __all__ = [
     'TableSchema',
     'TableSink',
     'TableSource',
+    'TableAggregateFunction',
     'UserDefinedType',
     'WindowGroupedTable',
     'WriteMode'
diff --git a/flink-python/pyflink/table/table.py b/flink-python/pyflink/table/table.py
index 962a087..7d07e28 100644
--- a/flink-python/pyflink/table/table.py
+++ b/flink-python/pyflink/table/table.py
@@ -828,6 +828,28 @@ class Table(object):
         else:
             return AggregatedTable(self._j_table.aggregate(func._j_expr), self._t_env)
 
+    def flat_aggregate(self, func: Union[str, Expression]) -> 'FlatAggregateTable':
+        """
+        Perform a global flat_aggregate without group_by. flat_aggregate takes a
+        :class:`~pyflink.table.TableAggregateFunction` which returns multiple rows. Use a selection
+        after the flat_aggregate.
+
+        Example:
+        ::
+
+            >>> table_agg = udtaf(MyTableAggregateFunction())
+            >>> tab.flat_aggregate(table_agg(tab.a).alias("a", "b")).select("a, b")
+
+        :param func: user-defined table aggregate function.
+        :return: The result table.
+
+        .. versionadded:: 1.13.0
+        """
+        if isinstance(func, str):
+            return FlatAggregateTable(self._j_table.flatAggregate(func), self._t_env)
+        else:
+            return FlatAggregateTable(self._j_table.flatAggregate(func._j_expr), self._t_env)
+
     def insert_into(self, table_path: str):
         """
        Writes the :class:`~pyflink.table.Table` to a :class:`~pyflink.table.TableSink` that was
@@ -1018,6 +1040,28 @@ class GroupedTable(object):
         else:
             return AggregatedTable(self._j_table.aggregate(func._j_expr), self._t_env)
 
+    def flat_aggregate(self, func: Union[str, Expression]) -> 'FlatAggregateTable':
+        """
+        Performs a flat_aggregate operation on a grouped table. flat_aggregate takes a
+        :class:`~pyflink.table.TableAggregateFunction` which returns multiple rows. Use a selection
+        after flatAggregate.
+
+        Example:
+        ::
+
+            >>> table_agg = udtaf(MyTableAggregateFunction())
+            >>> tab.group_by(tab.c).flat_aggregate(table_agg(tab.a).alias("a")).select("c, a")
+
+        :param func: user-defined table aggregate function.
+        :return: The result table.
+
+        .. versionadded:: 1.13.0
+        """
+        if isinstance(func, str):
+            return FlatAggregateTable(self._j_table.flatAggregate(func), self._t_env)
+        else:
+            return FlatAggregateTable(self._j_table.flatAggregate(func._j_expr), self._t_env)
+
 
 class GroupWindowedTable(object):
     """
@@ -1196,3 +1240,35 @@ class AggregatedTable(object):
             assert len(fields) == 1
             assert isinstance(fields[0], str)
             return Table(self._j_table.select(fields[0]), self._t_env)
+
+
+class FlatAggregateTable(object):
+    """
+    A table that performs flatAggregate on a :class:`~pyflink.table.Table`, a
+    :class:`~pyflink.table.GroupedTable` or a :class:`~pyflink.table.WindowGroupedTable`
+    """
+
+    def __init__(self, java_table, t_env):
+        self._j_table = java_table
+        self._t_env = t_env
+
+    def select(self, *fields: Union[str, Expression]) -> 'Table':
+        """
+        Performs a selection operation on a FlatAggregateTable. Similar to a SQL SELECT statement.
+        The field expressions can contain complex expressions.
+
+        Example:
+        ::
+
+            >>> table_agg = udtaf(MyTableAggregateFunction())
+            >>> tab.flat_aggregate(table_agg(tab.a).alias("a", "b")).select("a, b")
+
+        :param fields: Expression string.
+        :return: The result table.
+        """
+        if all(isinstance(f, Expression) for f in fields):
+            return Table(self._j_table.select(to_expression_jarray(fields)), self._t_env)
+        else:
+            assert len(fields) == 1
+            assert isinstance(fields[0], str)
+            return Table(self._j_table.select(fields[0]), self._t_env)
diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index 3348ed2..d47712d 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -46,7 +46,7 @@ from pyflink.table.types import _to_java_type, _create_type_verifier, RowType, D
     _infer_schema_from_data, _create_converter, from_arrow_type, RowField, create_arrow_schema, \
     _to_java_data_type
 from pyflink.table.udf import UserDefinedFunctionWrapper, AggregateFunction, udaf, \
-    UserDefinedAggregateFunctionWrapper
+    UserDefinedAggregateFunctionWrapper, udtaf, TableAggregateFunction
 from pyflink.table.utils import to_expression_jarray
 from pyflink.util import utils
 from pyflink.util.utils import get_j_env_configuration, is_local_deployment, load_java_class, \
@@ -1574,14 +1574,20 @@ class TableEnvironment(object, metaclass=ABCMeta):
         self._add_jars_to_j_env_config(classpaths_key)
 
     def _wrap_aggregate_function_if_needed(self, function) -> UserDefinedFunctionWrapper:
-        if isinstance(function, (AggregateFunction, UserDefinedAggregateFunctionWrapper)):
+        if isinstance(function, (AggregateFunction, TableAggregateFunction,
+                                 UserDefinedAggregateFunctionWrapper)):
             if not self._is_blink_planner:
-                raise Exception("Python UDAF is only supported in blink 
planner")
+                raise Exception("Python UDAF and UDTAF are only supported in 
blink planner")
         if isinstance(function, AggregateFunction):
             function = udaf(function,
                             result_type=function.get_result_type(),
                             accumulator_type=function.get_accumulator_type(),
                             name=str(function.__class__.__name__))
+        elif isinstance(function, TableAggregateFunction):
+            function = udtaf(function,
+                             result_type=function.get_result_type(),
+                             accumulator_type=function.get_accumulator_type(),
+                             name=str(function.__class__.__name__))
         return function
 
 
diff --git a/flink-python/pyflink/table/tests/test_row_based_operation.py b/flink-python/pyflink/table/tests/test_row_based_operation.py
index 26c8e4c..076b7d0 100644
--- a/flink-python/pyflink/table/tests/test_row_based_operation.py
+++ b/flink-python/pyflink/table/tests/test_row_based_operation.py
@@ -18,9 +18,9 @@
 from pandas.util.testing import assert_frame_equal
 
 from pyflink.common import Row
-from pyflink.table import expressions as expr
+from pyflink.table import expressions as expr, ListView
 from pyflink.table.types import DataTypes
-from pyflink.table.udf import udf, udtf, udaf, AggregateFunction
+from pyflink.table.udf import udf, udtf, udaf, AggregateFunction, TableAggregateFunction, udtaf
 from pyflink.testing import source_sink_utils
 from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase, \
     PyFlinkBlinkStreamTableTestCase
@@ -223,6 +223,52 @@ class StreamRowBasedOperationITTests(RowBasedOperationTests, PyFlinkBlinkStreamT
             .to_pandas()
         assert_frame_equal(result, pd.DataFrame([[1, 3, 15], [2, 2, 4]], columns=['a', 'c', 'd']))
 
+    def test_flat_aggregate(self):
+        import pandas as pd
+        self.t_env.register_function("mytop", Top2())
+        t = self.t_env.from_elements([(1, 'Hi', 'Hello'),
+                                      (3, 'Hi', 'hi'),
+                                      (5, 'Hi2', 'hi'),
+                                      (7, 'Hi', 'Hello'),
+                                      (2, 'Hi', 'Hello')], ['a', 'b', 'c'])
+        result = t.group_by("c") \
+            .flat_aggregate("mytop(a)") \
+            .select("c, a") \
+            .flat_aggregate("mytop(a)") \
+            .select("a") \
+            .to_pandas()
+
+        assert_frame_equal(result, pd.DataFrame([[7], [5]], columns=['a']))
+
+    def test_flat_aggregate_list_view(self):
+        import pandas as pd
+        my_concat = udtaf(ListViewConcatTableAggregateFunction())
+        self.t_env.get_config().get_configuration().set_string(
+            "python.fn-execution.bundle.size", "2")
+        # trigger the cache eviction in a bundle.
+        self.t_env.get_config().get_configuration().set_string(
+            "python.state.cache-size", "2")
+        t = self.t_env.from_elements([(1, 'Hi', 'Hello'),
+                                      (3, 'Hi', 'hi'),
+                                      (3, 'Hi2', 'hi'),
+                                      (3, 'Hi', 'hi'),
+                                      (2, 'Hi', 'Hello'),
+                                      (1, 'Hi2', 'Hello'),
+                                      (3, 'Hi3', 'hi'),
+                                      (3, 'Hi2', 'Hello'),
+                                      (3, 'Hi3', 'hi'),
+                                      (2, 'Hi3', 'Hello')], ['a', 'b', 'c'])
+        result = t.group_by(t.c) \
+            .flat_aggregate(my_concat(t.b, ',').alias("b")) \
+            .select(t.b, t.c) \
+            .alias("a, c")
+        assert_frame_equal(result.to_pandas(),
+                           pd.DataFrame([["Hi,Hi2,Hi,Hi3,Hi3", "hi"],
+                                         ["Hi,Hi2,Hi,Hi3,Hi3", "hi"],
+                                         ["Hi,Hi,Hi2,Hi2,Hi3", "Hello"],
+                                         ["Hi,Hi,Hi2,Hi2,Hi3", "Hello"]],
+                                        columns=['a', 'c']))
+
 
 class CountAndSumAggregateFunction(AggregateFunction):
 
@@ -258,6 +304,65 @@ class CountAndSumAggregateFunction(AggregateFunction):
              DataTypes.FIELD("b", DataTypes.BIGINT())])
 
 
+class Top2(TableAggregateFunction):
+
+    def emit_value(self, accumulator):
+        yield Row(accumulator[0])
+        yield Row(accumulator[1])
+
+    def create_accumulator(self):
+        return [None, None]
+
+    def accumulate(self, accumulator, *args):
+        if args[0] is not None:
+            if accumulator[0] is None or args[0] > accumulator[0]:
+                accumulator[1] = accumulator[0]
+                accumulator[0] = args[0]
+            elif accumulator[1] is None or args[0] > accumulator[1]:
+                accumulator[1] = args[0]
+
+    def retract(self, accumulator, *args):
+        accumulator[0] = accumulator[0] - 1
+
+    def merge(self, accumulator, accumulators):
+        for other_acc in accumulators:
+            self.accumulate(accumulator, other_acc[0])
+            self.accumulate(accumulator, other_acc[1])
+
+    def get_accumulator_type(self):
+        return DataTypes.ARRAY(DataTypes.BIGINT())
+
+    def get_result_type(self):
+        return DataTypes.ROW(
+            [DataTypes.FIELD("a", DataTypes.BIGINT())])
+
+
+class ListViewConcatTableAggregateFunction(TableAggregateFunction):
+
+    def emit_value(self, accumulator):
+        result = accumulator[1].join(accumulator[0])
+        yield Row(result)
+        yield Row(result)
+
+    def create_accumulator(self):
+        return Row(ListView(), '')
+
+    def accumulate(self, accumulator, *args):
+        accumulator[1] = args[1]
+        accumulator[0].add(args[0])
+
+    def retract(self, accumulator, *args):
+        raise NotImplementedError
+
+    def get_accumulator_type(self):
+        return DataTypes.ROW([
+            DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING())),
+            DataTypes.FIELD("f1", DataTypes.BIGINT())])
+
+    def get_result_type(self):
+        return DataTypes.ROW([DataTypes.FIELD("a", DataTypes.STRING())])
+
+
 if __name__ == '__main__':
     import unittest
 
diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py
index ddba3d9..5a625d0 100644
--- a/flink-python/pyflink/table/udf.py
+++ b/flink-python/pyflink/table/udf.py
@@ -19,7 +19,7 @@ import abc
 import collections
 import functools
 import inspect
-from typing import Union, List, Type, Callable, TypeVar, Generic
+from typing import Union, List, Type, Callable, TypeVar, Generic, Iterable
 
 from pyflink.java_gateway import get_gateway
 from pyflink.metrics import MetricGroup
@@ -28,7 +28,7 @@ from pyflink.table.types import DataType, _to_java_type, _to_java_data_type
 from pyflink.util import utils
 
 __all__ = ['FunctionContext', 'AggregateFunction', 'ScalarFunction', 'TableFunction',
-           'udf', 'udtf', 'udaf']
+           'TableAggregateFunction', 'udf', 'udtf', 'udaf', 'udtaf']
 
 
 class FunctionContext(object):
@@ -126,25 +126,16 @@ T = TypeVar('T')
 ACC = TypeVar('ACC')
 
 
-class AggregateFunction(UserDefinedFunction, Generic[T, ACC]):
-    """
-    Base interface for user-defined aggregate function. A user-defined aggregate function maps
-    scalar values of multiple rows to a new scalar value.
-
-    .. versionadded:: 1.12.0
+class ImperativeAggregateFunction(UserDefinedFunction, Generic[T, ACC]):
     """
+    Base interface for user-defined aggregate function and table aggregate function.
 
-    @abc.abstractmethod
-    def get_value(self, accumulator: ACC) -> T:
-        """
-        Called every time when an aggregation result should be materialized. The returned value
-        could be either an early and incomplete result (periodically emitted as data arrives) or
-        the final result of the aggregation.
+    This class is used for unified handling of imperative aggregating functions. Concrete
+    implementations should extend from :class:`~pyflink.table.AggregateFunction` or
+    :class:`~pyflink.table.TableAggregateFunction`.
 
-        :param accumulator: the accumulator which contains the current intermediate results
-        :return: the aggregation result
-        """
-        pass
+    .. versionadded:: 1.13.0
+    """
 
     @abc.abstractmethod
     def create_accumulator(self) -> ACC:
@@ -208,6 +199,50 @@ class AggregateFunction(UserDefinedFunction, Generic[T, ACC]):
         pass
 
 
+class AggregateFunction(ImperativeAggregateFunction):
+    """
+    Base interface for user-defined aggregate function. A user-defined aggregate function maps
+    scalar values of multiple rows to a new scalar value.
+
+    .. versionadded:: 1.12.0
+    """
+
+    @abc.abstractmethod
+    def get_value(self, accumulator: ACC) -> T:
+        """
+        Called every time when an aggregation result should be materialized. The returned value
+        could be either an early and incomplete result (periodically emitted as data arrives) or
+        the final result of the aggregation.
+
+        :param accumulator: the accumulator which contains the current intermediate results
+        :return: the aggregation result
+        """
+        pass
+
+
+class TableAggregateFunction(ImperativeAggregateFunction):
+    """
+    Base class for a user-defined table aggregate function. A user-defined table aggregate function
+    maps scalar values of multiple rows to zero, one, or multiple rows (or structured types). If an
+    output record consists of only one field, the structured record can be omitted, and a scalar
+    value can be emitted that will be implicitly wrapped into a row by the runtime.
+
+    .. versionadded:: 1.13.0
+    """
+
+    @abc.abstractmethod
+    def emit_value(self, accumulator: ACC) -> Iterable[T]:
+        """
+        Called every time when an aggregation result should be materialized. The returned value
+        could be either an early and incomplete result (periodically emitted as data arrives) or the
+        final result of the aggregation.
+
+        :param accumulator: the accumulator which contains the current aggregated results.
+        :return: multiple aggregated result
+        """
+        pass
+
+
 class DelegatingScalarFunction(ScalarFunction):
     """
     Helper scalar function implementation for lambda expression and python function. It's for
@@ -434,10 +469,10 @@ class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
 
 class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
     """
-    Wrapper for Python user-defined aggregate function.
+    Wrapper for Python user-defined aggregate function or user-defined table aggregate function.
     """
     def __init__(self, func, input_types, result_type, accumulator_type, func_type,
-                 deterministic, name):
+                 deterministic, name, is_table_aggregate=False):
         super(UserDefinedAggregateFunctionWrapper, self).__init__(
             func, input_types, func_type, deterministic, name)
 
@@ -459,6 +494,7 @@ class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
                     accumulator_type))
         self._result_type = result_type
         self._accumulator_type = accumulator_type
+        self._is_table_aggregate = is_table_aggregate
 
     def _create_judf(self, serialized_func, j_input_types, j_function_kind):
         if self._func_type == "pandas":
@@ -473,8 +509,12 @@ class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
         j_accumulator_type = _to_java_data_type(self._accumulator_type)
 
         gateway = get_gateway()
-        PythonAggregateFunction = gateway.jvm \
-            .org.apache.flink.table.functions.python.PythonAggregateFunction
+        if self._is_table_aggregate:
+            PythonAggregateFunction = gateway.jvm \
+                .org.apache.flink.table.functions.python.PythonTableAggregateFunction
+        else:
+            PythonAggregateFunction = gateway.jvm \
+                .org.apache.flink.table.functions.python.PythonAggregateFunction
         j_aggregate_function = PythonAggregateFunction(
             self._name,
             bytearray(serialized_func),
@@ -512,7 +552,12 @@ def _create_udaf(f, input_types, result_type, accumulator_type, func_type, deter
         f, input_types, result_type, accumulator_type, func_type, deterministic, name)
 
 
-def udf(f: Union[Callable, UserDefinedFunction, Type] = None,
+def _create_udtaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
+    return UserDefinedAggregateFunctionWrapper(
+        f, input_types, result_type, accumulator_type, func_type, deterministic, name, True)
+
+
+def udf(f: Union[Callable, ScalarFunction, Type] = None,
         input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
         deterministic: bool = None, name: str = None, func_type: str = "general",
         udf_type: str = None) -> Union[UserDefinedScalarFunctionWrapper, Callable]:
@@ -567,7 +612,7 @@ def udf(f: Union[Callable, UserDefinedFunction, Type] = None,
         return _create_udf(f, input_types, result_type, func_type, deterministic, name)
 
 
-def udtf(f: Union[Callable, UserDefinedFunction, Type] = None,
+def udtf(f: Union[Callable, TableFunction, Type] = None,
          input_types: Union[List[DataType], DataType] = None,
          result_types: Union[List[DataType], DataType] = None, deterministic: bool = None,
          name: str = None) -> Union[UserDefinedTableFunctionWrapper, Callable]:
@@ -607,7 +652,7 @@ def udtf(f: Union[Callable, UserDefinedFunction, Type] = None,
         return _create_udtf(f, input_types, result_types, deterministic, name)
 
 
-def udaf(f: Union[Callable, UserDefinedFunction, Type] = None,
+def udaf(f: Union[Callable, AggregateFunction, Type] = None,
          input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
          accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
          func_type: str = "general") -> 
Union[UserDefinedAggregateFunctionWrapper, Callable]:
@@ -647,3 +692,72 @@ def udaf(f: Union[Callable, UserDefinedFunction, Type] = None,
     else:
         return _create_udaf(f, input_types, result_type, accumulator_type, func_type,
                             deterministic, name)
+
+
+def udtaf(f: Union[Callable, TableAggregateFunction, Type] = None,
+          input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
+          accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
+          func_type: str = 'general') -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
+    """
+    Helper method for creating a user-defined table aggregate function.
+
+    Example:
+    ::
+
+        >>> # The input_types is optional.
+        >>> class Top2(TableAggregateFunction):
+        ...     def emit_value(self, accumulator):
+        ...         yield Row(accumulator[0])
+        ...         yield Row(accumulator[1])
+        ...
+        ...     def create_accumulator(self):
+        ...         return [None, None]
+        ...
+        ...     def accumulate(self, accumulator, *args):
+        ...         if args[0] is not None:
+        ...             if accumulator[0] is None or args[0] > accumulator[0]:
+        ...                 accumulator[1] = accumulator[0]
+        ...                 accumulator[0] = args[0]
+        ...             elif accumulator[1] is None or args[0] > accumulator[1]:
+        ...                 accumulator[1] = args[0]
+        ...
+        ...     def retract(self, accumulator, *args):
+        ...         accumulator[0] = accumulator[0] - 1
+        ...
+        ...     def merge(self, accumulator, accumulators):
+        ...         for other_acc in accumulators:
+        ...             self.accumulate(accumulator, other_acc[0])
+        ...             self.accumulate(accumulator, other_acc[1])
+        ...
+        ...     def get_accumulator_type(self):
+        ...         return DataTypes.ARRAY(DataTypes.BIGINT())
+        ...
+        ...     def get_result_type(self):
+        ...         return DataTypes.ROW(
+        ...             [DataTypes.FIELD("a", DataTypes.BIGINT())])
+        >>> top2 = udtaf(Top2())
+
+    :param f: user-defined table aggregate function.
+    :param input_types: optional, the input data types.
+    :param result_type: the result data type.
+    :param accumulator_type: optional, the accumulator data type.
+    :param deterministic: the determinism of the function's results. True if and only if a call to
+                          this function is guaranteed to always return the same result given the
+                          same parameters. (default True)
+    :param name: the function name.
+    :param func_type: the type of the python function, available value: general
+                     (default: general)
+    :return: UserDefinedAggregateFunctionWrapper or function.
+
+    .. versionadded:: 1.13.0
+    """
+    if func_type != 'general':
+        raise ValueError("The func_type must be 'general', got %s."
+                         % func_type)
+    if f is None:
+        return functools.partial(_create_udtaf, input_types=input_types, result_type=result_type,
+                                 accumulator_type=accumulator_type, func_type=func_type,
+                                 deterministic=deterministic, name=name)
+    else:
+        return _create_udtaf(f, input_types, result_type, accumulator_type, func_type,
+                             deterministic, name)
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java
index 25168e4..145447f 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java
@@ -41,11 +41,15 @@ import org.apache.flink.streaming.api.TimerService;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.api.operators.python.AbstractOneInputPythonFunctionOperator;
+import org.apache.flink.streaming.api.utils.PythonOperatorUtils;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.UpdatableRowData;
+import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
+import org.apache.flink.table.functions.python.PythonEnv;
 import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
+import org.apache.flink.table.planner.typeutils.DataViewUtils;
 import org.apache.flink.table.runtime.functions.CleanupState;
 import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
 import org.apache.flink.table.runtime.operators.python.utils.StreamRecordRowDataWrappingCollector;
@@ -85,6 +89,10 @@ public abstract class AbstractPythonStreamAggregateOperator
 
        private static final byte TRIGGER_TIMER = 1;
 
+       private final PythonAggregateFunctionInfo[] aggregateFunctions;
+
+       private final DataViewUtils.DataViewSpec[][] dataViewSpecs;
+
        /**
         * The input logical type.
         */
@@ -196,6 +204,8 @@ public abstract class AbstractPythonStreamAggregateOperator
                Configuration config,
                RowType inputType,
                RowType outputType,
+               PythonAggregateFunctionInfo[] aggregateFunctions,
+               DataViewUtils.DataViewSpec[][] dataViewSpecs,
                int[] grouping,
                int indexOfCountStar,
                boolean generateUpdateBefore,
@@ -204,6 +214,8 @@ public abstract class AbstractPythonStreamAggregateOperator
                super(config);
                this.inputType = Preconditions.checkNotNull(inputType);
                this.outputType = Preconditions.checkNotNull(outputType);
+               this.aggregateFunctions = aggregateFunctions;
+               this.dataViewSpecs = dataViewSpecs;
                this.jobOptions = buildJobOptions(config);
                this.grouping = grouping;
                this.indexOfCountStar = indexOfCountStar;
@@ -321,6 +333,11 @@ public abstract class AbstractPythonStreamAggregateOperator
                return keyForTimerService;
        }
 
+       @Override
+       public PythonEnv getPythonEnv() {
+               return aggregateFunctions[0].getPythonFunction().getPythonEnv();
+       }
+
        @VisibleForTesting
        TypeSerializer getKeySerializer() {
                return PythonTypeUtils.toBlinkTypeSerializer(getKeyType());
@@ -348,6 +365,14 @@ public abstract class AbstractPythonStreamAggregateOperator
                builder.setStateCacheSize(stateCacheSize);
                builder.setMapStateReadCacheSize(mapStateReadCacheSize);
                builder.setMapStateWriteCacheSize(mapStateWriteCacheSize);
+               for (int i = 0; i < aggregateFunctions.length; i++) {
+                       DataViewUtils.DataViewSpec[] specs = null;
+                       if (i < dataViewSpecs.length) {
+                               specs = dataViewSpecs[i];
+                       }
+                       builder.addUdfs(
+                               PythonOperatorUtils.getUserDefinedAggregateFunctionProto(aggregateFunctions[i], specs));
+               }
                return builder.build();
        }
 
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java
index 670e8da..4f1c748 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java
@@ -22,9 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.fnexecution.v1.FlinkFnApi;
-import org.apache.flink.streaming.api.utils.PythonOperatorUtils;
 import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
-import org.apache.flink.table.functions.python.PythonEnv;
 import org.apache.flink.table.planner.typeutils.DataViewUtils;
 import org.apache.flink.table.types.logical.RowType;
 
@@ -39,10 +37,6 @@ public class PythonStreamGroupAggregateOperator extends AbstractPythonStreamAggr
        @VisibleForTesting
        protected static final String STREAM_GROUP_AGGREGATE_URN = "flink:transform:stream_group_aggregate:v1";
 
-       private final PythonAggregateFunctionInfo[] aggregateFunctions;
-
-       private final DataViewUtils.DataViewSpec[][] dataViewSpecs;
-
        /**
         * True if the count(*) agg is inserted by the planner.
         */
@@ -60,18 +54,11 @@ public class PythonStreamGroupAggregateOperator extends AbstractPythonStreamAggr
                boolean generateUpdateBefore,
                long minRetentionTime,
                long maxRetentionTime) {
-               super(config, inputType, outputType, grouping, indexOfCountStar, generateUpdateBefore,
-                       minRetentionTime, maxRetentionTime);
-               this.aggregateFunctions = aggregateFunctions;
-               this.dataViewSpecs = dataViewSpecs;
+               super(config, inputType, outputType, aggregateFunctions, dataViewSpecs, grouping,
+                       indexOfCountStar, generateUpdateBefore, minRetentionTime, maxRetentionTime);
                this.countStarInserted = countStarInserted;
        }
 
-       @Override
-       public PythonEnv getPythonEnv() {
-               return aggregateFunctions[0].getPythonFunction().getPythonEnv();
-       }
-
        /**
         * Gets the proto representation of the Python user-defined aggregate functions to be executed.
         */
@@ -79,14 +66,6 @@ public class PythonStreamGroupAggregateOperator extends AbstractPythonStreamAggr
        public FlinkFnApi.UserDefinedAggregateFunctions getUserDefinedFunctionsProto() {
                FlinkFnApi.UserDefinedAggregateFunctions.Builder builder =
                        super.getUserDefinedFunctionsProto().toBuilder();
-               for (int i = 0; i < aggregateFunctions.length; i++) {
-                       DataViewUtils.DataViewSpec[] specs = null;
-                       if (i < dataViewSpecs.length) {
-                               specs = dataViewSpecs[i];
-                       }
-                       builder.addUdfs(
-                               PythonOperatorUtils.getUserDefinedAggregateFunctionProto(aggregateFunctions[i], specs));
-               }
                builder.setCountStarInserted(countStarInserted);
                return builder.build();
        }
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java
index a61ad91..fd67de1 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java
@@ -22,9 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.fnexecution.v1.FlinkFnApi;
-import org.apache.flink.streaming.api.utils.PythonOperatorUtils;
 import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
-import org.apache.flink.table.functions.python.PythonEnv;
 import org.apache.flink.table.planner.typeutils.DataViewUtils;
 import org.apache.flink.table.types.logical.RowType;
 
@@ -39,30 +37,19 @@ public class PythonStreamGroupTableAggregateOperator extends AbstractPythonStrea
        @VisibleForTesting
        protected static final String STREAM_GROUP_TABLE_AGGREGATE_URN = "flink:transform:stream_group_table_aggregate:v1";
 
-       private final PythonAggregateFunctionInfo aggregateFunction;
-
-       private final DataViewUtils.DataViewSpec[] dataViewSpecs;
-
        public PythonStreamGroupTableAggregateOperator(
                Configuration config,
                RowType inputType,
                RowType outputType,
-               PythonAggregateFunctionInfo aggregateFunction,
-               DataViewUtils.DataViewSpec[] dataViewSpec,
+               PythonAggregateFunctionInfo[] aggregateFunctions,
+               DataViewUtils.DataViewSpec[][] dataViewSpecs,
                int[] grouping,
                int indexOfCountStar,
                boolean generateUpdateBefore,
                long minRetentionTime,
                long maxRetentionTime) {
-               super(config, inputType, outputType, grouping, indexOfCountStar, generateUpdateBefore,
-                       minRetentionTime, maxRetentionTime);
-               this.aggregateFunction = aggregateFunction;
-               this.dataViewSpecs = dataViewSpec;
-       }
-
-       @Override
-       public PythonEnv getPythonEnv() {
-               return aggregateFunction.getPythonFunction().getPythonEnv();
+               super(config, inputType, outputType, aggregateFunctions, dataViewSpecs, grouping,
+                       indexOfCountStar, generateUpdateBefore, minRetentionTime, maxRetentionTime);
        }
 
        /**
@@ -72,8 +59,6 @@ public class PythonStreamGroupTableAggregateOperator extends AbstractPythonStrea
        public FlinkFnApi.UserDefinedAggregateFunctions getUserDefinedFunctionsProto() {
                FlinkFnApi.UserDefinedAggregateFunctions.Builder builder =
                        super.getUserDefinedFunctionsProto().toBuilder();
-               builder.addUdfs(PythonOperatorUtils.getUserDefinedAggregateFunctionProto(
-                       aggregateFunction, dataViewSpecs));
                return builder.build();
        }
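
On the Python harness side (this commit also touches beam_operations.py), the
new URN above presumably gets wired up through Beam's factory-registration
convention, roughly as sketched below. The decorator and factory signature are
Beam's BeamTransformFactory API; the helper name and body are illustrative, not
the commit's actual code:

    # Sketch: registering an operation for the table-aggregate URN in the
    # Beam-based Python harness.
    from apache_beam.runners.worker import bundle_processor
    from pyflink.fn_execution import flink_fn_execution_pb2

    STREAM_GROUP_TABLE_AGGREGATE_URN = "flink:transform:stream_group_table_aggregate:v1"

    @bundle_processor.BeamTransformFactory.register_urn(
        STREAM_GROUP_TABLE_AGGREGATE_URN,
        flink_fn_execution_pb2.UserDefinedAggregateFunctions)
    def create_table_aggregate_operation(
            factory, transform_id, transform_proto, parameter, consumers):
        # 'parameter' is the deserialized UserDefinedAggregateFunctions proto;
        # a real implementation would construct the table-aggregate operation here.
        raise NotImplementedError("illustrative sketch only")
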
 
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java
index eee3b5a..d839fcb 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java
@@ -224,11 +224,12 @@ public class PythonStreamGroupTableAggregateOperatorTest extends AbstractPythonS
                        config,
                        getInputType(),
                        getOutputType(),
-                       new PythonAggregateFunctionInfo(
-                               PythonScalarFunctionOperatorTestBase.DummyPythonFunction.INSTANCE,
-                               new Integer[]{0},
-                               -1,
-                               false),
+                       new PythonAggregateFunctionInfo[]{
+                               new PythonAggregateFunctionInfo(
+                                       PythonScalarFunctionOperatorTestBase.DummyPythonFunction.INSTANCE,
+                                       new Integer[]{0},
+                                       -1,
+                                       false)},
                        getGrouping(),
                        -1,
                        false,
@@ -243,7 +244,7 @@ public class PythonStreamGroupTableAggregateOperatorTest extends AbstractPythonS
                        Configuration config,
                        RowType inputType,
                        RowType outputType,
-                       PythonAggregateFunctionInfo aggregateFunction,
+                       PythonAggregateFunctionInfo[] aggregateFunctions,
                        int[] grouping,
                        int indexOfCountStar,
                        boolean generateUpdateBefore,
@@ -253,8 +254,8 @@ public class PythonStreamGroupTableAggregateOperatorTest extends AbstractPythonS
                                config,
                                inputType,
                                outputType,
-                               aggregateFunction,
-                               new DataViewUtils.DataViewSpec[0],
+                               aggregateFunctions,
+                               new DataViewUtils.DataViewSpec[0][0],
                                grouping,
                                indexOfCountStar,
                                generateUpdateBefore,
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java
new file mode 100644
index 0000000..4486f69
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.table.types.inference.TypeStrategies;
+import org.apache.flink.table.types.utils.TypeConversions;
+
+/**
+ * The wrapper of a user-defined Python table aggregate function.
+ */
+@Internal
+public class PythonTableAggregateFunction extends TableAggregateFunction implements PythonFunction {
+
+       private static final long serialVersionUID = 1L;
+
+       private final String name;
+       private final byte[] serializedTableAggregateFunction;
+       private final DataType[] inputTypes;
+       private final DataType resultType;
+       private final DataType accumulatorType;
+       private final PythonFunctionKind pythonFunctionKind;
+       private final boolean deterministic;
+       private final PythonEnv pythonEnv;
+
+       public PythonTableAggregateFunction(
+               String name,
+               byte[] serializedTableAggregateFunction,
+               DataType[] inputTypes,
+               DataType resultType,
+               DataType accumulatorType,
+               PythonFunctionKind pythonFunctionKind,
+               boolean deterministic,
+               PythonEnv pythonEnv) {
+               this.name = name;
+               this.serializedTableAggregateFunction = serializedTableAggregateFunction;
+               this.inputTypes = inputTypes;
+               this.resultType = resultType;
+               this.accumulatorType = accumulatorType;
+               this.pythonFunctionKind = pythonFunctionKind;
+               this.deterministic = deterministic;
+               this.pythonEnv = pythonEnv;
+       }
+
+       public void accumulate(Object accumulator, Object... args) {
+               throw new UnsupportedOperationException(
+                       "This method is a placeholder and should not be 
called.");
+       }
+
+       public void emitValue(Object accumulator, Object out) {
+               throw new UnsupportedOperationException(
+                       "This method is a placeholder and should not be 
called.");
+       }
+
+       @Override
+       public Object createAccumulator() {
+               return null;
+       }
+
+       @Override
+       public byte[] getSerializedPythonFunction() {
+               return serializedTableAggregateFunction;
+       }
+
+       @Override
+       public PythonEnv getPythonEnv() {
+               return pythonEnv;
+       }
+
+       @Override
+       public PythonFunctionKind getPythonFunctionKind() {
+               return pythonFunctionKind;
+       }
+
+       @Override
+       public boolean isDeterministic() {
+               return deterministic;
+       }
+
+       @Override
+       public TypeInformation getResultType() {
+               return TypeConversions.fromDataTypeToLegacyInfo(resultType);
+       }
+
+       @Override
+       public TypeInformation getAccumulatorType() {
+               return TypeConversions.fromDataTypeToLegacyInfo(accumulatorType);
+       }
+
+       @Override
+       public TypeInference getTypeInference(DataTypeFactory typeFactory) {
+               TypeInference.Builder builder = TypeInference.newBuilder();
+               if (inputTypes != null) {
+                       builder.typedArguments(inputTypes);
+               }
+               return builder
+                       .outputTypeStrategy(TypeStrategies.explicit(resultType))
+                       .accumulatorTypeStrategy(TypeStrategies.explicit(accumulatorType))
+                       .build();
+       }
+
+       @Override
+       public String toString() {
+               return name;
+       }
+}
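
The accumulate/emitValue placeholders above mirror the methods of the Python
class whose pickled bytes this wrapper carries in serializedTableAggregateFunction.
Below is a minimal sketch of such a function, assuming the accumulate/emit_value
naming those placeholders suggest; the Top2 logic and field names are illustrative:

    # Sketch: a Python table aggregate function tracking the two largest
    # values per group and emitting up to two rows on emit_value.
    from pyflink.common import Row
    from pyflink.table import DataTypes, TableAggregateFunction

    class Top2(TableAggregateFunction):

        def create_accumulator(self):
            return [None, None]        # the two largest values seen so far

        def accumulate(self, accumulator, row):
            v = row.v                  # assumes an input column named 'v'
            if v is None:
                return
            if accumulator[0] is None or v > accumulator[0]:
                accumulator[1] = accumulator[0]
                accumulator[0] = v
            elif accumulator[1] is None or v > accumulator[1]:
                accumulator[1] = v

        def emit_value(self, accumulator):
            # a table aggregate may emit multiple rows per group
            if accumulator[0] is not None:
                yield Row(accumulator[0])
            if accumulator[1] is not None:
                yield Row(accumulator[1])

        def get_accumulator_type(self):
            return DataTypes.ARRAY(DataTypes.BIGINT())

        def get_result_type(self):
            return DataTypes.ROW([DataTypes.FIELD("v", DataTypes.BIGINT())])
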
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
index 652786c..b06e21d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
@@ -85,11 +85,16 @@ trait CommonPythonAggregate extends CommonPythonBase {
             pythonAggregateInfoList.aggInfos(i).argIndexes.map(_.asInstanceOf[AnyRef]),
             aggCalls(i).filterArg,
             aggCalls(i).isDistinct))
+          val typeInference = function match {
+            case aggregateFunction: PythonAggregateFunction =>
+              aggregateFunction.getTypeInference(null)
+            case tableAggregateFunction: PythonTableAggregateFunction =>
+              tableAggregateFunction.getTypeInference(null)
+          }
           dataViewSpecList.add(
             extractDataViewSpecs(
               i,
-              function.asInstanceOf[PythonAggregateFunction].getTypeInference(null)
-              .getAccumulatorTypeStrategy.get().inferType(null).get()))
+              typeInference.getAccumulatorTypeStrategy.get().inferType(null).get()))
         case function: UserDefinedFunction =>
           var filterArg = -1
           var distinct = false
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala
index dcf9b7b..c92359f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala
@@ -18,9 +18,19 @@
 package org.apache.flink.table.planner.plan.nodes.physical.stream
 
 import org.apache.flink.api.dag.Transformation
-import org.apache.flink.table.api.TableException
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.memory.ManagedMemoryUseCase
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.transformations.OneInputTransformation
 import org.apache.flink.table.data.RowData
+import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
 import org.apache.flink.table.planner.delegation.StreamPlanner
+import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate
+import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, KeySelectorUtil}
+import org.apache.flink.table.planner.typeutils.DataViewUtils.DataViewSpec
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
+import org.apache.flink.table.types.logical.RowType
 
 import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.RelNode
@@ -45,7 +55,8 @@ class StreamExecPythonGroupTableAggregate(
     inputRel,
     outputRowType,
     grouping,
-    aggCalls) {
+    aggCalls)
+  with CommonPythonAggregate {
 
  override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
     new StreamExecPythonGroupTableAggregate(
@@ -59,6 +70,112 @@ class StreamExecPythonGroupTableAggregate(
 
   override protected def translateToPlanInternal(
       planner: StreamPlanner): Transformation[RowData] = {
-    throw new TableException("The implementation will be in FLINK-20528.")
+    val tableConfig = planner.getTableConfig
+
+    if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) {
+      LOG.warn("No state retention interval configured for a query which 
accumulates state. " +
+        "Please provide a query configuration with valid retention interval to 
prevent excessive " +
+        "state size. You may specify a retention time of 0 to not clean up the 
state.")
+    }
+
+    val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
+      .asInstanceOf[Transformation[RowData]]
+
+    val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType)
+    val inputRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType)
+
+    val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this)
+
+    val inputCountIndex = aggInfoList.getIndexOfCountStar
+
+    var (pythonFunctionInfos, dataViewSpecs) =
+      extractPythonAggregateFunctionInfos(aggInfoList, aggCalls)
+
+    if (dataViewSpecs.forall(_.isEmpty)) {
+      dataViewSpecs = Array(Array())
+    }
+
+    val operator = getPythonTableAggregateFunctionOperator(
+      getConfig(planner.getExecEnv, tableConfig),
+      inputRowType,
+      outRowType,
+      pythonFunctionInfos,
+      dataViewSpecs,
+      tableConfig.getMinIdleStateRetentionTime,
+      tableConfig.getMaxIdleStateRetentionTime,
+      grouping,
+      generateUpdateBefore,
+      inputCountIndex)
+
+    val selector = KeySelectorUtil.getRowDataSelector(
+      grouping,
+      InternalTypeInfo.of(inputRowType))
+
+    // partitioned aggregation
+    val ret = new OneInputTransformation(
+      inputTransformation,
+      getRelDetailedDescription,
+      operator,
+      InternalTypeInfo.of(outRowType),
+      inputTransformation.getParallelism)
+
+    if (inputsContainSingleton()) {
+      ret.setParallelism(1)
+      ret.setMaxParallelism(1)
+    }
+
+    if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+      ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON)
+    }
+
+    // set KeyType and Selector for state
+    ret.setStateKeySelector(selector)
+    ret.setStateKeyType(selector.getProducedType)
+    ret
   }
+
+  private[this] def getPythonTableAggregateFunctionOperator(
+      config: Configuration,
+      inputType: RowType,
+      outputType: RowType,
+      aggregateFunctions: Array[PythonAggregateFunctionInfo],
+      dataViewSpecs: Array[Array[DataViewSpec]],
+      minIdleStateRetentionTime: Long,
+      maxIdleStateRetentionTime: Long,
+      grouping: Array[Int],
+      generateUpdateBefore: Boolean,
+      indexOfCountStar: Int): OneInputStreamOperator[RowData, RowData] = {
+
+    val clazz = loadClass(
+      StreamExecPythonGroupTableAggregate.PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME)
+    val ctor = clazz.getConstructor(
+      classOf[Configuration],
+      classOf[RowType],
+      classOf[RowType],
+      classOf[Array[PythonAggregateFunctionInfo]],
+      classOf[Array[Array[DataViewSpec]]],
+      classOf[Array[Int]],
+      classOf[Int],
+      classOf[Boolean],
+      classOf[Long],
+      classOf[Long])
+    ctor.newInstance(
+      config.asInstanceOf[AnyRef],
+      inputType.asInstanceOf[AnyRef],
+      outputType.asInstanceOf[AnyRef],
+      aggregateFunctions.asInstanceOf[AnyRef],
+      dataViewSpecs.asInstanceOf[AnyRef],
+      grouping.asInstanceOf[AnyRef],
+      indexOfCountStar.asInstanceOf[AnyRef],
+      generateUpdateBefore.asInstanceOf[AnyRef],
+      minIdleStateRetentionTime.asInstanceOf[AnyRef],
+      maxIdleStateRetentionTime.asInstanceOf[AnyRef])
+      .asInstanceOf[OneInputStreamOperator[RowData, RowData]]
+  }
+}
+
+object StreamExecPythonGroupTableAggregate {
+  val PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME: String =
+    "org.apache.flink.table.runtime.operators.python.aggregate." +
+      "PythonStreamGroupTableAggregateOperator"
 }
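
Taken together, the changes above enable the user-facing flow below in streaming
mode. This is a hedged end-to-end sketch assuming the udtaf wrapper and
Table.flat_aggregate API this commit adds on the Python side; the data, column
names, and the Top2 class from the earlier sketch are illustrative:

    # Sketch: wrap a TableAggregateFunction with udtaf and apply it per group.
    from pyflink.table import EnvironmentSettings, TableEnvironment
    from pyflink.table.udf import udtaf

    env_settings = EnvironmentSettings.new_instance().in_streaming_mode().build()
    t_env = TableEnvironment.create(env_settings)

    top2 = udtaf(Top2())  # Top2 as sketched earlier

    t = t_env.from_elements([(1, 'A'), (3, 'A'), (2, 'B')], ['v', 'k'])
    # each group emits up to two rows (its two largest 'v' values)
    t.group_by(t.k).flat_aggregate(top2).execute().print()
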
