This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 9c486d1 [FLINK-20528][python] Support table aggregation for group aggregation in streaming mode
9c486d1 is described below
commit 9c486d1c1d245263630b2e9dc0f3d6a95e130a4d
Author: HuangXingBo <[email protected]>
AuthorDate: Thu Dec 17 10:00:11 2020 +0800
[FLINK-20528][python] Support table aggregation for group aggregation in streaming mode
This closes #14389.
---
flink-python/pyflink/fn_execution/aggregate.py | 207 +++++++++++++++++----
.../pyflink/fn_execution/beam/beam_operations.py | 10 +
.../pyflink/fn_execution/operation_utils.py | 4 +-
flink-python/pyflink/fn_execution/operations.py | 87 +++++++--
flink-python/pyflink/table/__init__.py | 6 +-
flink-python/pyflink/table/table.py | 76 ++++++++
flink-python/pyflink/table/table_environment.py | 12 +-
.../table/tests/test_row_based_operation.py | 109 ++++++++++-
flink-python/pyflink/table/udf.py | 164 +++++++++++++---
.../AbstractPythonStreamAggregateOperator.java | 25 +++
.../PythonStreamGroupAggregateOperator.java | 25 +--
.../PythonStreamGroupTableAggregateOperator.java | 23 +--
...ythonStreamGroupTableAggregateOperatorTest.java | 17 +-
.../python/PythonTableAggregateFunction.java | 127 +++++++++++++
.../plan/nodes/common/CommonPythonAggregate.scala | 9 +-
.../StreamExecPythonGroupTableAggregate.scala | 123 +++++++++++-
16 files changed, 881 insertions(+), 143 deletions(-)
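
For reviewers skimming the patch: the new user-facing surface is TableAggregateFunction, the udtaf() helper and Table.flat_aggregate / GroupedTable.flat_aggregate, backed by a new GroupTableAggFunction on the worker side and a PythonStreamGroupTableAggregateOperator / StreamExecPythonGroupTableAggregate pair on the planner side. A minimal end-to-end sketch, assembled from the tests and docstrings in this commit (the environment setup is illustrative and not part of the patch):

    from pyflink.common import Row
    from pyflink.table import DataTypes, EnvironmentSettings, TableEnvironment
    from pyflink.table.udf import udtaf, TableAggregateFunction

    class Top2(TableAggregateFunction):
        """Emits the two largest values seen so far for each group."""

        def emit_value(self, accumulator):
            yield Row(accumulator[0])
            yield Row(accumulator[1])

        def create_accumulator(self):
            return [None, None]

        def accumulate(self, accumulator, *args):
            if args[0] is not None:
                if accumulator[0] is None or args[0] > accumulator[0]:
                    accumulator[1] = accumulator[0]
                    accumulator[0] = args[0]
                elif accumulator[1] is None or args[0] > accumulator[1]:
                    accumulator[1] = args[0]

        def get_accumulator_type(self):
            return DataTypes.ARRAY(DataTypes.BIGINT())

        def get_result_type(self):
            return DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT())])

    settings = EnvironmentSettings.new_instance().in_streaming_mode().use_blink_planner().build()
    t_env = TableEnvironment.create(settings)
    top2 = udtaf(Top2())
    t = t_env.from_elements([(1, 'Hello'), (3, 'hi'), (5, 'hi'), (7, 'Hello')], ['a', 'c'])
    # one input table row can produce several output rows per group
    result = t.group_by(t.c).flat_aggregate(top2(t.a).alias("a")).select("c, a")
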
diff --git a/flink-python/pyflink/fn_execution/aggregate.py b/flink-python/pyflink/fn_execution/aggregate.py
index 1971623..b719d0e 100644
--- a/flink-python/pyflink/fn_execution/aggregate.py
+++ b/flink-python/pyflink/fn_execution/aggregate.py
@@ -16,7 +16,7 @@
# limitations under the License.
################################################################################
from abc import ABC, abstractmethod
-from typing import List, Dict
+from typing import List, Dict, Iterable
from apache_beam.coders import PickleCoder, Coder
@@ -25,8 +25,9 @@ from pyflink.common.state import ListState, MapState
from pyflink.fn_execution.coders import from_proto
from pyflink.fn_execution.operation_utils import is_built_in_function, load_aggregate_function
from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
-from pyflink.table import AggregateFunction, FunctionContext
+from pyflink.table import AggregateFunction, FunctionContext, TableAggregateFunction
from pyflink.table.data_view import ListView, MapView
+from pyflink.table.udf import ImperativeAggregateFunction
def join_row(left: Row, right: Row):
@@ -219,9 +220,9 @@ class StateDataViewStore(object):
self._keyed_state_backend.get_map_state(state_name, key_coder, value_coder))
-class AggsHandleFunction(ABC):
+class AggsHandleFunctionBase(ABC):
"""
- The base class for handling aggregate functions.
+ The base class for handling aggregate or table aggregate functions.
"""
@abstractmethod
@@ -301,6 +302,20 @@ class AggsHandleFunction(ABC):
pass
@abstractmethod
+ def close(self):
+ """
+ Tear-down method for this function. It can be used for clean up work.
+ By default, this method does nothing.
+ """
+ pass
+
+
+class AggsHandleFunction(AggsHandleFunctionBase):
+ """
+ The base class for handling aggregate functions.
+ """
+
+ @abstractmethod
def get_value(self) -> Row:
"""
Gets the result of the aggregation from the current accumulators.
@@ -309,25 +324,28 @@ class AggsHandleFunction(ABC):
"""
pass
+
+class TableAggsHandleFunction(AggsHandleFunctionBase):
+ """
+ The base class for handling table aggregate functions.
+ """
+
@abstractmethod
- def close(self):
+ def emit_value(self, current_key: Row, is_retract: bool) -> Iterable[Row]:
"""
- Tear-down method for this function. It can be used for clean up work.
- By default, this method does nothing.
+ Emits the result of the table aggregation.
"""
pass
-class SimpleAggsHandleFunction(AggsHandleFunction):
+class SimpleAggsHandleFunctionBase(AggsHandleFunctionBase):
"""
- A simple AggsHandleFunction implementation which provides the basic functionality.
+ A simple AggsHandleFunctionBase implementation which provides the basic functionality.
"""
def __init__(self,
- udfs: List[AggregateFunction],
+ udfs: List[ImperativeAggregateFunction],
input_extractors: List,
- index_of_count_star: int,
- count_star_inserted: bool,
udf_data_view_specs: List[List[DataViewSpec]],
filter_args: List[int],
distinct_indexes: List[int],
@@ -335,10 +353,6 @@ class SimpleAggsHandleFunction(AggsHandleFunction):
self._udfs = udfs
self._input_extractors = input_extractors
self._accumulators = None # type: Row
- self._get_value_indexes = [i for i in range(len(udfs))]
- if index_of_count_star >= 0 and count_star_inserted:
- # The record count is used internally, should be ignored by the get_value method.
- self._get_value_indexes.remove(index_of_count_star)
self._udf_data_view_specs = udf_data_view_specs
self._udf_data_views = []
self._filter_args = filter_args
@@ -451,13 +465,64 @@ class SimpleAggsHandleFunction(AggsHandleFunction):
for data_view in self._udf_data_views[i].values():
data_view.clear()
+ def close(self):
+ for udf in self._udfs:
+ udf.close()
+
+
+class SimpleAggsHandleFunction(SimpleAggsHandleFunctionBase, AggsHandleFunction):
+ """
+ A simple AggsHandleFunction implementation which provides the basic functionality.
+ """
+
+ def __init__(self,
+ udfs: List[AggregateFunction],
+ input_extractors: List,
+ index_of_count_star: int,
+ count_star_inserted: bool,
+ udf_data_view_specs: List[List[DataViewSpec]],
+ filter_args: List[int],
+ distinct_indexes: List[int],
+ distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
+ super(SimpleAggsHandleFunction, self).__init__(
+ udfs, input_extractors, udf_data_view_specs, filter_args, distinct_indexes,
+ distinct_view_descriptors)
+ self._get_value_indexes = [i for i in range(len(udfs))]
+ if index_of_count_star >= 0 and count_star_inserted:
+ # The record count is used internally, should be ignored by the get_value method.
+ self._get_value_indexes.remove(index_of_count_star)
+
def get_value(self):
return Row(*[self._udfs[i].get_value(self._accumulators[i])
for i in self._get_value_indexes])
- def close(self):
- for udf in self._udfs:
- udf.close()
+
+class SimpleTableAggsHandleFunction(SimpleAggsHandleFunctionBase, TableAggsHandleFunction):
+ """
+ A simple TableAggsHandleFunction implementation which provides the basic functionality.
+ """
+
+ def __init__(self,
+ udfs: List[TableAggregateFunction],
+ input_extractors: List,
+ udf_data_view_specs: List[List[DataViewSpec]],
+ filter_args: List[int],
+ distinct_indexes: List[int],
+ distinct_view_descriptors: Dict[int, DistinctViewDescriptor]):
+ super(SimpleTableAggsHandleFunction, self).__init__(
+ udfs, input_extractors, udf_data_view_specs, filter_args, distinct_indexes,
+ distinct_view_descriptors)
+
+ def emit_value(self, current_key: Row, is_retract: bool):
+ udf = self._udfs[0] # type: TableAggregateFunction
+ results = udf.emit_value(self._accumulators[0])
+ for x in results:
+ result = join_row(current_key, x)
+ if is_retract:
+ result.set_row_kind(RowKind.DELETE)
+ else:
+ result.set_row_kind(RowKind.INSERT)
+ yield result
class RecordCounter(ABC):
@@ -494,10 +559,10 @@ class RetractionRecordCounter(RecordCounter):
return acc is None or acc[self._index_of_count_star][0] == 0
-class GroupAggFunction(object):
+class GroupAggFunctionBase(object):
def __init__(self,
- aggs_handle: AggsHandleFunction,
+ aggs_handle: AggsHandleFunctionBase,
key_selector: RowKeySelector,
state_backend: RemoteKeyedStateBackend,
state_value_coder: Coder,
@@ -518,6 +583,41 @@ class GroupAggFunction(object):
def close(self):
self.aggs_handle.close()
+ def on_timer(self, key):
+ if self.state_cleaning_enabled:
+ self.state_backend.set_current_key(key)
+ accumulator_state = self.state_backend.get_value_state(
+ "accumulators", self.state_value_coder)
+ accumulator_state.clear()
+ self.aggs_handle.cleanup()
+
+ @staticmethod
+ def is_retract_msg(data: Row):
+ return data.get_row_kind() == RowKind.UPDATE_BEFORE or data.get_row_kind() == RowKind.DELETE
+
+ @staticmethod
+ def is_accumulate_msg(data: Row):
+ return data.get_row_kind() == RowKind.UPDATE_AFTER or data.get_row_kind() == RowKind.INSERT
+
+ @abstractmethod
+ def process_element(self, input_data: Row):
+ pass
+
+
+class GroupAggFunction(GroupAggFunctionBase):
+
+ def __init__(self,
+ aggs_handle: AggsHandleFunction,
+ key_selector: RowKeySelector,
+ state_backend: RemoteKeyedStateBackend,
+ state_value_coder: Coder,
+ generate_update_before: bool,
+ state_cleaning_enabled: bool,
+ index_of_count_star: int):
+ super(GroupAggFunction, self).__init__(
+ aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
+ state_cleaning_enabled, index_of_count_star)
+
def process_element(self, input_data: Row):
key = self.key_selector.get_key(input_data)
self.state_backend.set_current_key(key)
@@ -597,20 +697,55 @@ class GroupAggFunction(object):
# cleanup dataview under current key
self.aggs_handle.cleanup()
- def on_timer(self, key):
- if self.state_cleaning_enabled:
- self.state_backend.set_current_key(key)
- accumulator_state = self.state_backend.get_value_state(
- "accumulators", self.state_value_coder)
- accumulator_state.clear()
- self.aggs_handle.cleanup()
- @staticmethod
- def is_retract_msg(data: Row):
- return data.get_row_kind() == RowKind.UPDATE_BEFORE \
- or data.get_row_kind() == RowKind.DELETE
+class GroupTableAggFunction(GroupAggFunctionBase):
+ def __init__(self,
+ aggs_handle: TableAggsHandleFunction,
+ key_selector: RowKeySelector,
+ state_backend: RemoteKeyedStateBackend,
+ state_value_coder: Coder,
+ generate_update_before: bool,
+ state_cleaning_enabled: bool,
+ index_of_count_star: int):
+ super(GroupTableAggFunction, self).__init__(
+ aggs_handle, key_selector, state_backend, state_value_coder, generate_update_before,
+ state_cleaning_enabled, index_of_count_star)
- @staticmethod
- def is_accumulate_msg(data: Row):
- return data.get_row_kind() == RowKind.UPDATE_AFTER \
- or data.get_row_kind() == RowKind.INSERT
+ def process_element(self, input_data: Row):
+ key = self.key_selector.get_key(input_data)
+ self.state_backend.set_current_key(key)
+ self.state_backend.clear_cached_iterators()
+ accumulator_state = self.state_backend.get_value_state(
+ "accumulators", self.state_value_coder)
+ accumulators = accumulator_state.value()
+ if accumulators is None:
+ first_row = True
+ accumulators = self.aggs_handle.create_accumulators()
+ else:
+ first_row = False
+
+ # set accumulators to handler first
+ self.aggs_handle.set_accumulators(accumulators)
+
+ if not first_row and self.generate_update_before:
+ yield from self.aggs_handle.emit_value(key, True)
+
+ # update aggregate result and set to the newRow
+ if self.is_accumulate_msg(input_data):
+ # accumulate input
+ self.aggs_handle.accumulate(input_data)
+ else:
+ # retract input
+ self.aggs_handle.retract(input_data)
+
+ # get accumulator
+ accumulators = self.aggs_handle.get_accumulators()
+
+ if not self.record_counter.record_count_is_zero(accumulators):
+ yield from self.aggs_handle.emit_value(key, False)
+ accumulator_state.update(accumulators)
+ else:
+ # and clear all state
+ accumulator_state.clear()
+ # cleanup dataview under current key
+ self.aggs_handle.cleanup()
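
The changelog behaviour of GroupTableAggFunction.process_element above is the heart of this file: with generate_update_before enabled, every incoming row first retracts the rows emitted for the previous accumulator state (RowKind.DELETE) and then emits the fresh result (RowKind.INSERT). A self-contained toy trace of that protocol, using a Top2-style aggregate and ignoring retraction input, record counting and state cleaning (none of these helpers are pyflink APIs):

    from enum import Enum

    class Kind(Enum):
        INSERT = "+I"
        DELETE = "-D"

    def top2_emit(acc):
        yield acc[0]
        yield acc[1]

    def top2_accumulate(acc, value):
        if acc[0] is None or value > acc[0]:
            acc[0], acc[1] = value, acc[0]
        elif acc[1] is None or value > acc[1]:
            acc[1] = value

    def process_element(store, key, value, generate_update_before=True):
        # mirrors the accumulate path of GroupTableAggFunction.process_element
        acc = store.get(key)
        first_row = acc is None
        if first_row:
            acc = [None, None]
        if not first_row and generate_update_before:
            # retract the rows emitted for the previous accumulator state
            yield from ((Kind.DELETE, row) for row in top2_emit(acc))
        top2_accumulate(acc, value)
        store[key] = acc
        yield from ((Kind.INSERT, row) for row in top2_emit(acc))

    store = {}
    for v in (3, 5):
        print([(k.value, row) for k, row in process_element(store, "key", v)])
    # [('+I', 3), ('+I', None)]
    # [('-D', 3), ('-D', None), ('+I', 5), ('+I', 3)]
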
diff --git a/flink-python/pyflink/fn_execution/beam/beam_operations.py b/flink-python/pyflink/fn_execution/beam/beam_operations.py
index 3b95201..c36515d 100644
--- a/flink-python/pyflink/fn_execution/beam/beam_operations.py
+++ b/flink-python/pyflink/fn_execution/beam/beam_operations.py
@@ -110,6 +110,16 @@ def create_data_stream_keyed_process_function(factory, transform_id, transform_p
operations.KeyedProcessFunctionOperation)
+@bundle_processor.BeamTransformFactory.register_urn(
+ operations.STREAM_GROUP_TABLE_AGGREGATE_URN,
+ flink_fn_execution_pb2.UserDefinedAggregateFunctions)
+def create_table_aggregate_function(factory, transform_id, transform_proto, parameter, consumers):
+ return _create_user_defined_function_operation(
+ factory, transform_proto, consumers, parameter,
+ beam_operations.StatefulFunctionOperation,
+ operations.StreamGroupTableAggregateOperation)
+
+
def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
                                            beam_operation_cls, internal_operation_cls):
output_tags = list(transform_proto.outputs.keys())
diff --git a/flink-python/pyflink/fn_execution/operation_utils.py b/flink-python/pyflink/fn_execution/operation_utils.py
index 5e8e768..2dbd8e8 100644
--- a/flink-python/pyflink/fn_execution/operation_utils.py
+++ b/flink-python/pyflink/fn_execution/operation_utils.py
@@ -26,7 +26,7 @@ from pyflink.fn_execution import flink_fn_execution_pb2, pickle
from pyflink.serializers import PickleSerializer
from pyflink.table import functions
from pyflink.table.udf import DelegationTableFunction, DelegatingScalarFunction, \
- AggregateFunction, PandasAggregateFunctionWrapper
+ ImperativeAggregateFunction, PandasAggregateFunctionWrapper
_func_num = 0
_constant_num = 0
@@ -147,7 +147,7 @@ def extract_user_defined_aggregate_function(
user_defined_function_proto,
        distinct_info_dict: Dict[Tuple[List[str]], Tuple[List[int], List[int]]]):
    user_defined_agg = load_aggregate_function(user_defined_function_proto.payload)
- assert isinstance(user_defined_agg, AggregateFunction)
+ assert isinstance(user_defined_agg, ImperativeAggregateFunction)
args_str = []
local_variable_dict = {}
for arg in user_defined_function_proto.inputs:
diff --git a/flink-python/pyflink/fn_execution/operations.py b/flink-python/pyflink/fn_execution/operations.py
index dd4c7cf..9cfa1da 100644
--- a/flink-python/pyflink/fn_execution/operations.py
+++ b/flink-python/pyflink/fn_execution/operations.py
@@ -30,7 +30,8 @@ from pyflink.fn_execution import flink_fn_execution_pb2, operation_utils
from pyflink.fn_execution.beam.beam_coders import DataViewFilterCoder
from pyflink.fn_execution.operation_utils import extract_user_defined_aggregate_function
from pyflink.fn_execution.aggregate import RowKeySelector, SimpleAggsHandleFunction, \
- GroupAggFunction, extract_data_view_specs, DistinctViewDescriptor
+ GroupAggFunction, extract_data_view_specs, DistinctViewDescriptor, \
+ SimpleTableAggsHandleFunction, GroupTableAggFunction
from pyflink.metrics.metricbase import GenericMetricGroup
from pyflink.table import FunctionContext, Row
@@ -39,6 +40,7 @@ from pyflink.table import FunctionContext, Row
SCALAR_FUNCTION_URN = "flink:transform:scalar_function:v1"
TABLE_FUNCTION_URN = "flink:transform:table_function:v1"
STREAM_GROUP_AGGREGATE_URN = "flink:transform:stream_group_aggregate:v1"
+STREAM_GROUP_TABLE_AGGREGATE_URN = "flink:transform:stream_group_table_aggregate:v1"
PANDAS_AGGREGATE_FUNCTION_URN = "flink:transform:aggregate_function:arrow:v1"
PANDAS_BATCH_OVER_WINDOW_AGGREGATE_FUNCTION_URN = \
"flink:transform:batch_over_window_aggregate_function:arrow:v1"
@@ -262,7 +264,7 @@ class StatefulFunctionOperation(Operation):
TRIGGER_TIMER = 1
-class StreamGroupAggregateOperation(StatefulFunctionOperation):
+class AbstractStreamGroupAggregateOperation(StatefulFunctionOperation):
def __init__(self, spec, keyed_state_backend):
self.generate_update_before = spec.serialized_fn.generate_update_before
@@ -276,7 +278,7 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
self.state_cache_size = spec.serialized_fn.state_cache_size
self.state_cleaning_enabled = spec.serialized_fn.state_cleaning_enabled
self.data_view_specs = extract_data_view_specs(spec.serialized_fn.udfs)
- super(StreamGroupAggregateOperation, self).__init__(spec, keyed_state_backend)
+ super(AbstractStreamGroupAggregateOperation, self).__init__(spec, keyed_state_backend)
def open(self):
self.group_agg_function.open(FunctionContext(self.base_metric_group))
@@ -310,6 +312,44 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
# use the agg index of the first function as the key of shared distinct view
distinct_view_descriptors[agg_index_list[0]] = DistinctViewDescriptor(
    input_extractors[agg_index_list[0]], filter_arg_list)
+
+ key_selector = RowKeySelector(self.grouping)
+ if len(self.data_view_specs) > 0:
+ state_value_coder = DataViewFilterCoder(self.data_view_specs)
+ else:
+ state_value_coder = PickleCoder()
+
+ self.group_agg_function = self.create_process_function(
+ user_defined_aggs, input_extractors, filter_args, distinct_indexes,
+ distinct_view_descriptors, key_selector, state_value_coder)
+
+ return self.process_element_or_timer, []
+
+ def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
+ # the structure of the input data:
+ # [element_type, element(for process_element), timestamp(for timer), key(for timer)]
+ # all the fields are nullable except the "element_type"
+ if input_data[0] != TRIGGER_TIMER:
+ return self.group_agg_function.process_element(input_data[1])
+ else:
+ self.group_agg_function.on_timer(input_data[3])
+ return []
+
+ @abc.abstractmethod
+ def create_process_function(self, user_defined_aggs, input_extractors, filter_args,
+ distinct_indexes, distinct_view_descriptors, key_selector,
+ state_value_coder):
+ pass
+
+
+class StreamGroupAggregateOperation(AbstractStreamGroupAggregateOperation):
+
+ def __init__(self, spec, keyed_state_backend):
+ super(StreamGroupAggregateOperation, self).__init__(spec, keyed_state_backend)
+
+ def create_process_function(self, user_defined_aggs, input_extractors, filter_args,
+ distinct_indexes, distinct_view_descriptors, key_selector,
+ state_value_coder):
aggs_handler_function = SimpleAggsHandleFunction(
user_defined_aggs,
input_extractors,
@@ -319,12 +359,8 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
filter_args,
distinct_indexes,
distinct_view_descriptors)
- key_selector = RowKeySelector(self.grouping)
- if len(self.data_view_specs) > 0:
- state_value_coder = DataViewFilterCoder(self.data_view_specs)
- else:
- state_value_coder = PickleCoder()
- self.group_agg_function = GroupAggFunction(
+
+ return GroupAggFunction(
aggs_handler_function,
key_selector,
self.keyed_state_backend,
@@ -332,17 +368,30 @@ class StreamGroupAggregateOperation(StatefulFunctionOperation):
self.generate_update_before,
self.state_cleaning_enabled,
self.index_of_count_star)
- return self.process_element_or_timer, []
- def process_element_or_timer(self, input_data: Tuple[int, Row, int, Row]):
- # the structure of the input data:
- # [element_type, element(for process_element), timestamp(for timer), key(for timer)]
- # all the fields are nullable except the "element_type"
- if input_data[0] != TRIGGER_TIMER:
- return self.group_agg_function.process_element(input_data[1])
- else:
- self.group_agg_function.on_timer(input_data[3])
- return []
+
+class StreamGroupTableAggregateOperation(AbstractStreamGroupAggregateOperation):
+ def __init__(self, spec, keyed_state_backend):
+ super(StreamGroupTableAggregateOperation, self).__init__(spec, keyed_state_backend)
+
+ def create_process_function(self, user_defined_aggs, input_extractors, filter_args,
+ distinct_indexes, distinct_view_descriptors, key_selector,
+ state_value_coder):
+ aggs_handler_function = SimpleTableAggsHandleFunction(
+ user_defined_aggs,
+ input_extractors,
+ self.data_view_specs,
+ filter_args,
+ distinct_indexes,
+ distinct_view_descriptors)
+ return GroupTableAggFunction(
+ aggs_handler_function,
+ key_selector,
+ self.keyed_state_backend,
+ state_value_coder,
+ self.generate_update_before,
+ self.state_cleaning_enabled,
+ self.index_of_count_star)
class DataStreamStatelessFunctionOperation(Operation):
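
Both operations now share the element/timer framing handled in the base class: the Java operator sends the Python worker tuples of (element_type, element, timestamp, key), where only element_type is guaranteed to be set. A toy stand-in for the dispatch (not the real classes, which also wire up state backends and metrics):

    TRIGGER_TIMER = 1  # mirrors StatefulFunctionOperation.TRIGGER_TIMER

    class ToyGroupFunction:
        def process_element(self, element):
            yield ("processed", element)

        def on_timer(self, key):
            print("state cleanup timer fired for key", key)

    def process_element_or_timer(func, input_data):
        # input_data: (element_type, element, timestamp, key);
        # all fields but element_type may be None
        if input_data[0] != TRIGGER_TIMER:
            return func.process_element(input_data[1])
        func.on_timer(input_data[3])
        return []

    f = ToyGroupFunction()
    print(list(process_element_or_timer(f, (0, ("Hello", 1), None, None))))  # a data record
    process_element_or_timer(f, (TRIGGER_TIMER, None, 42, ("Hello",)))       # a cleanup timer
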
diff --git a/flink-python/pyflink/table/__init__.py b/flink-python/pyflink/table/__init__.py
index 7f0f3b2..b082331 100644
--- a/flink-python/pyflink/table/__init__.py
+++ b/flink-python/pyflink/table/__init__.py
@@ -61,6 +61,8 @@ Important classes of Flink Table API:
Base interface for user-defined table function.
- :class:`pyflink.table.AggregateFunction`
Base interface for user-defined aggregate function.
+ - :class:`pyflink.table.TableAggregateFunction`
+ Base interface for user-defined table aggregate function.
- :class:`pyflink.table.StatementSet`
Base interface accepts DML statements or Tables.
"""
@@ -84,7 +86,8 @@ from pyflink.table.table_environment import (TableEnvironment, StreamTableEnviro
from pyflink.table.table_result import TableResult
from pyflink.table.table_schema import TableSchema
from pyflink.table.types import DataTypes, UserDefinedType, Row, RowKind
-from pyflink.table.udf import FunctionContext, ScalarFunction, TableFunction, AggregateFunction
+from pyflink.table.udf import FunctionContext, ScalarFunction, TableFunction, AggregateFunction, \
+ TableAggregateFunction
__all__ = [
'AggregateFunction',
@@ -118,6 +121,7 @@ __all__ = [
'TableSchema',
'TableSink',
'TableSource',
+ 'TableAggregateFunction',
'UserDefinedType',
'WindowGroupedTable',
'WriteMode'
diff --git a/flink-python/pyflink/table/table.py b/flink-python/pyflink/table/table.py
index 962a087..7d07e28 100644
--- a/flink-python/pyflink/table/table.py
+++ b/flink-python/pyflink/table/table.py
@@ -828,6 +828,28 @@ class Table(object):
else:
return AggregatedTable(self._j_table.aggregate(func._j_expr), self._t_env)
+ def flat_aggregate(self, func: Union[str, Expression]) -> 'FlatAggregateTable':
+ """
+ Performs a global flat_aggregate without group_by. flat_aggregate takes a
+ :class:`~pyflink.table.TableAggregateFunction` which returns multiple rows. Use a selection
+ after the flat_aggregate.
+
+ Example:
+ ::
+
+ >>> table_agg = udtaf(MyTableAggregateFunction())
+ >>> tab.flat_aggregate(table_agg(tab.a).alias("a", "b")).select("a, b")
+
+ :param func: user-defined table aggregate function.
+ :return: The result table.
+
+ .. versionadded:: 1.13.0
+ """
+ if isinstance(func, str):
+ return FlatAggregateTable(self._j_table.flatAggregate(func), self._t_env)
+ else:
+ return FlatAggregateTable(self._j_table.flatAggregate(func._j_expr), self._t_env)
+
def insert_into(self, table_path: str):
"""
Writes the :class:`~pyflink.table.Table` to a :class:`~pyflink.table.TableSink` that was
@@ -1018,6 +1040,28 @@ class GroupedTable(object):
else:
return AggregatedTable(self._j_table.aggregate(func._j_expr), self._t_env)
+ def flat_aggregate(self, func: Union[str, Expression]) -> 'FlatAggregateTable':
+ """
+ Performs a flat_aggregate operation on a grouped table. flat_aggregate takes a
+ :class:`~pyflink.table.TableAggregateFunction` which returns multiple rows. Use a selection
+ after the flat_aggregate.
+
+ Example:
+ ::
+
+ >>> table_agg = udtaf(MyTableAggregateFunction())
+ >>> tab.group_by(tab.c).flat_aggregate(table_agg(tab.a).alias("a")).select("c, a")
+
+ :param func: user-defined table aggregate function.
+ :return: The result table.
+
+ .. versionadded:: 1.13.0
+ """
+ if isinstance(func, str):
+ return FlatAggregateTable(self._j_table.flatAggregate(func), self._t_env)
+ else:
+ return FlatAggregateTable(self._j_table.flatAggregate(func._j_expr), self._t_env)
+
class GroupWindowedTable(object):
"""
@@ -1196,3 +1240,35 @@ class AggregatedTable(object):
assert len(fields) == 1
assert isinstance(fields[0], str)
return Table(self._j_table.select(fields[0]), self._t_env)
+
+
+class FlatAggregateTable(object):
+ """
+ A table that performs flatAggregate on a :class:`~pyflink.table.Table`, a
+ :class:`~pyflink.table.GroupedTable` or a :class:`~pyflink.table.WindowGroupedTable`
+ """
+
+ def __init__(self, java_table, t_env):
+ self._j_table = java_table
+ self._t_env = t_env
+
+ def select(self, *fields: Union[str, Expression]) -> 'Table':
+ """
+ Performs a selection operation on a FlatAggregateTable. Similar to a SQL SELECT statement.
+ The field expressions can contain complex expressions.
+
+ Example:
+ ::
+
+ >>> table_agg = udtaf(MyTableAggregateFunction())
+ >>> tab.flat_aggregate(table_agg(tab.a).alias("a", "b")).select("a, b")
+
+ :param fields: Expression string.
+ :return: The result table.
+ """
+ if all(isinstance(f, Expression) for f in fields):
+ return Table(self._j_table.select(to_expression_jarray(fields)), self._t_env)
+ else:
+ assert len(fields) == 1
+ assert isinstance(fields[0], str)
+ return Table(self._j_table.select(fields[0]), self._t_env)
diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py
index 3348ed2..d47712d 100644
--- a/flink-python/pyflink/table/table_environment.py
+++ b/flink-python/pyflink/table/table_environment.py
@@ -46,7 +46,7 @@ from pyflink.table.types import _to_java_type, _create_type_verifier, RowType, D
    _infer_schema_from_data, _create_converter, from_arrow_type, RowField, create_arrow_schema, \
_to_java_data_type
from pyflink.table.udf import UserDefinedFunctionWrapper, AggregateFunction, udaf, \
- UserDefinedAggregateFunctionWrapper
+ UserDefinedAggregateFunctionWrapper, udtaf, TableAggregateFunction
from pyflink.table.utils import to_expression_jarray
from pyflink.util import utils
from pyflink.util.utils import get_j_env_configuration, is_local_deployment, load_java_class, \
@@ -1574,14 +1574,20 @@ class TableEnvironment(object, metaclass=ABCMeta):
self._add_jars_to_j_env_config(classpaths_key)
def _wrap_aggregate_function_if_needed(self, function) -> UserDefinedFunctionWrapper:
- if isinstance(function, (AggregateFunction, UserDefinedAggregateFunctionWrapper)):
+ if isinstance(function, (AggregateFunction, TableAggregateFunction,
+ UserDefinedAggregateFunctionWrapper)):
if not self._is_blink_planner:
- raise Exception("Python UDAF is only supported in blink
planner")
+ raise Exception("Python UDAF and UDTAF are only supported in
blink planner")
if isinstance(function, AggregateFunction):
function = udaf(function,
result_type=function.get_result_type(),
accumulator_type=function.get_accumulator_type(),
name=str(function.__class__.__name__))
+ elif isinstance(function, TableAggregateFunction):
+ function = udtaf(function,
+ result_type=function.get_result_type(),
+ accumulator_type=function.get_accumulator_type(),
+ name=str(function.__class__.__name__))
return function
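
The effect of _wrap_aggregate_function_if_needed is that a bare TableAggregateFunction instance can be handed to the Table API directly: it is wrapped into a udtaf(...) call using the function's own get_result_type() and get_accumulator_type(). With Top2 as in the tests of this commit, both spellings below are equivalent (t and t_env as in the sketch near the top of this mail):

    # explicit wrapping
    top2 = udtaf(Top2())
    result = t.group_by(t.c).flat_aggregate(top2(t.a).alias("a")).select("c, a")

    # implicit wrapping: register_function() funnels through
    # _wrap_aggregate_function_if_needed
    t_env.register_function("mytop", Top2())
    result = t.group_by("c").flat_aggregate("mytop(a)").select("c, a")
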
diff --git a/flink-python/pyflink/table/tests/test_row_based_operation.py b/flink-python/pyflink/table/tests/test_row_based_operation.py
index 26c8e4c..076b7d0 100644
--- a/flink-python/pyflink/table/tests/test_row_based_operation.py
+++ b/flink-python/pyflink/table/tests/test_row_based_operation.py
@@ -18,9 +18,9 @@
from pandas.util.testing import assert_frame_equal
from pyflink.common import Row
-from pyflink.table import expressions as expr
+from pyflink.table import expressions as expr, ListView
from pyflink.table.types import DataTypes
-from pyflink.table.udf import udf, udtf, udaf, AggregateFunction
+from pyflink.table.udf import udf, udtf, udaf, AggregateFunction, TableAggregateFunction, udtaf
from pyflink.testing import source_sink_utils
from pyflink.testing.test_case_utils import PyFlinkBlinkBatchTableTestCase, \
PyFlinkBlinkStreamTableTestCase
@@ -223,6 +223,52 @@ class StreamRowBasedOperationITTests(RowBasedOperationTests, PyFlinkBlinkStreamT
.to_pandas()
assert_frame_equal(result, pd.DataFrame([[1, 3, 15], [2, 2, 4]], columns=['a', 'c', 'd']))
+ def test_flat_aggregate(self):
+ import pandas as pd
+ self.t_env.register_function("mytop", Top2())
+ t = self.t_env.from_elements([(1, 'Hi', 'Hello'),
+ (3, 'Hi', 'hi'),
+ (5, 'Hi2', 'hi'),
+ (7, 'Hi', 'Hello'),
+ (2, 'Hi', 'Hello')], ['a', 'b', 'c'])
+ result = t.group_by("c") \
+ .flat_aggregate("mytop(a)") \
+ .select("c, a") \
+ .flat_aggregate("mytop(a)") \
+ .select("a") \
+ .to_pandas()
+
+ assert_frame_equal(result, pd.DataFrame([[7], [5]], columns=['a']))
+
+ def test_flat_aggregate_list_view(self):
+ import pandas as pd
+ my_concat = udtaf(ListViewConcatTableAggregateFunction())
+ self.t_env.get_config().get_configuration().set_string(
+ "python.fn-execution.bundle.size", "2")
+ # trigger the cache eviction in a bundle.
+ self.t_env.get_config().get_configuration().set_string(
+ "python.state.cache-size", "2")
+ t = self.t_env.from_elements([(1, 'Hi', 'Hello'),
+ (3, 'Hi', 'hi'),
+ (3, 'Hi2', 'hi'),
+ (3, 'Hi', 'hi'),
+ (2, 'Hi', 'Hello'),
+ (1, 'Hi2', 'Hello'),
+ (3, 'Hi3', 'hi'),
+ (3, 'Hi2', 'Hello'),
+ (3, 'Hi3', 'hi'),
+ (2, 'Hi3', 'Hello')], ['a', 'b', 'c'])
+ result = t.group_by(t.c) \
+ .flat_aggregate(my_concat(t.b, ',').alias("b")) \
+ .select(t.b, t.c) \
+ .alias("a, c")
+ assert_frame_equal(result.to_pandas(),
+ pd.DataFrame([["Hi,Hi2,Hi,Hi3,Hi3", "hi"],
+ ["Hi,Hi2,Hi,Hi3,Hi3", "hi"],
+ ["Hi,Hi,Hi2,Hi2,Hi3", "Hello"],
+ ["Hi,Hi,Hi2,Hi2,Hi3", "Hello"]],
+ columns=['a', 'c']))
+
class CountAndSumAggregateFunction(AggregateFunction):
@@ -258,6 +304,65 @@ class CountAndSumAggregateFunction(AggregateFunction):
DataTypes.FIELD("b", DataTypes.BIGINT())])
+class Top2(TableAggregateFunction):
+
+ def emit_value(self, accumulator):
+ yield Row(accumulator[0])
+ yield Row(accumulator[1])
+
+ def create_accumulator(self):
+ return [None, None]
+
+ def accumulate(self, accumulator, *args):
+ if args[0] is not None:
+ if accumulator[0] is None or args[0] > accumulator[0]:
+ accumulator[1] = accumulator[0]
+ accumulator[0] = args[0]
+ elif accumulator[1] is None or args[0] > accumulator[1]:
+ accumulator[1] = args[0]
+
+ def retract(self, accumulator, *args):
+ accumulator[0] = accumulator[0] - 1
+
+ def merge(self, accumulator, accumulators):
+ for other_acc in accumulators:
+ self.accumulate(accumulator, other_acc[0])
+ self.accumulate(accumulator, other_acc[1])
+
+ def get_accumulator_type(self):
+ return DataTypes.ARRAY(DataTypes.BIGINT())
+
+ def get_result_type(self):
+ return DataTypes.ROW(
+ [DataTypes.FIELD("a", DataTypes.BIGINT())])
+
+
+class ListViewConcatTableAggregateFunction(TableAggregateFunction):
+
+ def emit_value(self, accumulator):
+ result = accumulator[1].join(accumulator[0])
+ yield Row(result)
+ yield Row(result)
+
+ def create_accumulator(self):
+ return Row(ListView(), '')
+
+ def accumulate(self, accumulator, *args):
+ accumulator[1] = args[1]
+ accumulator[0].add(args[0])
+
+ def retract(self, accumulator, *args):
+ raise NotImplementedError
+
+ def get_accumulator_type(self):
+ return DataTypes.ROW([
+ DataTypes.FIELD("f0", DataTypes.LIST_VIEW(DataTypes.STRING())),
+ DataTypes.FIELD("f1", DataTypes.BIGINT())])
+
+ def get_result_type(self):
+ return DataTypes.ROW([DataTypes.FIELD("a", DataTypes.STRING())])
+
+
if __name__ == '__main__':
import unittest
diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py
index ddba3d9..5a625d0 100644
--- a/flink-python/pyflink/table/udf.py
+++ b/flink-python/pyflink/table/udf.py
@@ -19,7 +19,7 @@ import abc
import collections
import functools
import inspect
-from typing import Union, List, Type, Callable, TypeVar, Generic
+from typing import Union, List, Type, Callable, TypeVar, Generic, Iterable
from pyflink.java_gateway import get_gateway
from pyflink.metrics import MetricGroup
@@ -28,7 +28,7 @@ from pyflink.table.types import DataType, _to_java_type, _to_java_data_type
from pyflink.util import utils
__all__ = ['FunctionContext', 'AggregateFunction', 'ScalarFunction', 'TableFunction',
- 'udf', 'udtf', 'udaf']
+ 'TableAggregateFunction', 'udf', 'udtf', 'udaf', 'udtaf']
class FunctionContext(object):
@@ -126,25 +126,16 @@ T = TypeVar('T')
ACC = TypeVar('ACC')
-class AggregateFunction(UserDefinedFunction, Generic[T, ACC]):
- """
- Base interface for user-defined aggregate function. A user-defined aggregate function maps
- scalar values of multiple rows to a new scalar value.
-
- .. versionadded:: 1.12.0
+class ImperativeAggregateFunction(UserDefinedFunction, Generic[T, ACC]):
"""
+ Base interface for user-defined aggregate function and table aggregate function.
- @abc.abstractmethod
- def get_value(self, accumulator: ACC) -> T:
- """
- Called every time when an aggregation result should be materialized. The returned value
- could be either an early and incomplete result (periodically emitted as data arrives) or
- the final result of the aggregation.
+ This class is used for unified handling of imperative aggregating functions. Concrete
+ implementations should extend from :class:`~pyflink.table.AggregateFunction` or
+ :class:`~pyflink.table.TableAggregateFunction`.
- :param accumulator: the accumulator which contains the current intermediate results
- :return: the aggregation result
- """
- pass
+ .. versionadded:: 1.13.0
+ """
@abc.abstractmethod
def create_accumulator(self) -> ACC:
@@ -208,6 +199,50 @@ class AggregateFunction(UserDefinedFunction, Generic[T, ACC]):
pass
+class AggregateFunction(ImperativeAggregateFunction):
+ """
+ Base interface for user-defined aggregate function. A user-defined aggregate function maps
+ scalar values of multiple rows to a new scalar value.
+
+ .. versionadded:: 1.12.0
+ """
+
+ @abc.abstractmethod
+ def get_value(self, accumulator: ACC) -> T:
+ """
+ Called every time when an aggregation result should be materialized. The returned value
+ could be either an early and incomplete result (periodically emitted as data arrives) or
+ the final result of the aggregation.
+
+ :param accumulator: the accumulator which contains the current intermediate results
+ :return: the aggregation result
+ """
+ pass
+
+
+class TableAggregateFunction(ImperativeAggregateFunction):
+ """
+ Base class for a user-defined table aggregate function. A user-defined table aggregate function
+ maps scalar values of multiple rows to zero, one, or multiple rows (or structured types). If an
+ output record consists of only one field, the structured record can be omitted, and a scalar
+ value can be emitted that will be implicitly wrapped into a row by the runtime.
+
+ .. versionadded:: 1.13.0
+ """
+
+ @abc.abstractmethod
+ def emit_value(self, accumulator: ACC) -> Iterable[T]:
+ """
+ Called every time when an aggregation result should be materialized. The returned value
+ could be either an early and incomplete result (periodically emitted as data arrives) or the
+ final result of the aggregation.
+
+ :param accumulator: the accumulator which contains the current aggregated results.
+ :return: multiple aggregated results
+ """
+ pass
+
+
class DelegatingScalarFunction(ScalarFunction):
"""
Helper scalar function implementation for lambda expression and python function. It's for
@@ -434,10 +469,10 @@ class UserDefinedTableFunctionWrapper(UserDefinedFunctionWrapper):
class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
"""
- Wrapper for Python user-defined aggregate function.
+ Wrapper for Python user-defined aggregate function or user-defined table aggregate function.
"""
def __init__(self, func, input_types, result_type, accumulator_type, func_type,
- deterministic, name):
+ deterministic, name, is_table_aggregate=False):
super(UserDefinedAggregateFunctionWrapper, self).__init__(
func, input_types, func_type, deterministic, name)
@@ -459,6 +494,7 @@ class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
accumulator_type))
self._result_type = result_type
self._accumulator_type = accumulator_type
+ self._is_table_aggregate = is_table_aggregate
def _create_judf(self, serialized_func, j_input_types, j_function_kind):
if self._func_type == "pandas":
@@ -473,8 +509,12 @@ class UserDefinedAggregateFunctionWrapper(UserDefinedFunctionWrapper):
j_accumulator_type = _to_java_data_type(self._accumulator_type)
gateway = get_gateway()
- PythonAggregateFunction = gateway.jvm \
- .org.apache.flink.table.functions.python.PythonAggregateFunction
+ if self._is_table_aggregate:
+ PythonAggregateFunction = gateway.jvm \
+ .org.apache.flink.table.functions.python.PythonTableAggregateFunction
+ else:
+ PythonAggregateFunction = gateway.jvm \
+ .org.apache.flink.table.functions.python.PythonAggregateFunction
j_aggregate_function = PythonAggregateFunction(
self._name,
bytearray(serialized_func),
@@ -512,7 +552,12 @@ def _create_udaf(f, input_types, result_type, accumulator_type, func_type, deter
        f, input_types, result_type, accumulator_type, func_type, deterministic, name)
-def udf(f: Union[Callable, UserDefinedFunction, Type] = None,
+def _create_udtaf(f, input_types, result_type, accumulator_type, func_type, deterministic, name):
+ return UserDefinedAggregateFunctionWrapper(
+ f, input_types, result_type, accumulator_type, func_type, deterministic, name, True)
+
+
+def udf(f: Union[Callable, ScalarFunction, Type] = None,
        input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
        deterministic: bool = None, name: str = None, func_type: str = "general",
        udf_type: str = None) -> Union[UserDefinedScalarFunctionWrapper, Callable]:
@@ -567,7 +612,7 @@ def udf(f: Union[Callable, UserDefinedFunction, Type] = None,
    return _create_udf(f, input_types, result_type, func_type, deterministic, name)
-def udtf(f: Union[Callable, UserDefinedFunction, Type] = None,
+def udtf(f: Union[Callable, TableFunction, Type] = None,
input_types: Union[List[DataType], DataType] = None,
        result_types: Union[List[DataType], DataType] = None, deterministic: bool = None,
name: str = None) -> Union[UserDefinedTableFunctionWrapper, Callable]:
@@ -607,7 +652,7 @@ def udtf(f: Union[Callable, UserDefinedFunction, Type] = None,
return _create_udtf(f, input_types, result_types, deterministic, name)
-def udaf(f: Union[Callable, UserDefinedFunction, Type] = None,
+def udaf(f: Union[Callable, AggregateFunction, Type] = None,
        input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
        accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
        func_type: str = "general") -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
@@ -647,3 +692,72 @@ def udaf(f: Union[Callable, UserDefinedFunction, Type] = None,
else:
        return _create_udaf(f, input_types, result_type, accumulator_type, func_type,
                            deterministic, name)
+
+
+def udtaf(f: Union[Callable, TableAggregateFunction, Type] = None,
+ input_types: Union[List[DataType], DataType] = None, result_type: DataType = None,
+ accumulator_type: DataType = None, deterministic: bool = None, name: str = None,
+ func_type: str = 'general') -> Union[UserDefinedAggregateFunctionWrapper, Callable]:
+ """
+ Helper method for creating a user-defined table aggregate function.
+
+ Example:
+ ::
+
+ >>> # The input_types is optional.
+ >>> class Top2(TableAggregateFunction):
+ ... def emit_value(self, accumulator):
+ ... yield Row(accumulator[0])
+ ... yield Row(accumulator[1])
+ ...
+ ... def create_accumulator(self):
+ ... return [None, None]
+ ...
+ ... def accumulate(self, accumulator, *args):
+ ... if args[0] is not None:
+ ... if accumulator[0] is None or args[0] > accumulator[0]:
+ ... accumulator[1] = accumulator[0]
+ ... accumulator[0] = args[0]
+ ... elif accumulator[1] is None or args[0] > accumulator[1]:
+ ... accumulator[1] = args[0]
+ ...
+ ... def retract(self, accumulator, *args):
+ ... accumulator[0] = accumulator[0] - 1
+ ...
+ ... def merge(self, accumulator, accumulators):
+ ... for other_acc in accumulators:
+ ... self.accumulate(accumulator, other_acc[0])
+ ... self.accumulate(accumulator, other_acc[1])
+ ...
+ ... def get_accumulator_type(self):
+ ... return DataTypes.ARRAY(DataTypes.BIGINT())
+ ...
+ ... def get_result_type(self):
+ ... return DataTypes.ROW(
+ ... [DataTypes.FIELD("a", DataTypes.BIGINT())])
+ >>> top2 = udtaf(Top2())
+
+ :param f: user-defined table aggregate function.
+ :param input_types: optional, the input data types.
+ :param result_type: the result data type.
+ :param accumulator_type: optional, the accumulator data type.
+ :param deterministic: the determinism of the function's results. True if and only if a call to
+ this function is guaranteed to always return the same result given the
+ same parameters. (default True)
+ :param name: the function name.
+ :param func_type: the type of the python function, available value: general
+ (default: general)
+ :return: UserDefinedAggregateFunctionWrapper or function.
+
+ .. versionadded:: 1.13.0
+ """
+ if func_type != 'general':
+ raise ValueError("The func_type must be 'general', got %s."
+ % func_type)
+ if f is None:
+ return functools.partial(_create_udtaf, input_types=input_types, result_type=result_type,
+ accumulator_type=accumulator_type, func_type=func_type,
+ deterministic=deterministic, name=name)
+ else:
+ return _create_udtaf(f, input_types, result_type, accumulator_type, func_type,
+ deterministic, name)
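
Like udf(), udtf() and udaf(), the new udtaf() also supports being called without the function: with f=None it returns a functools.partial, so options can be pre-bound and the function supplied later. A small sketch (Top2 as above; only func_type='general' is accepted for now):

    # pre-bind the declared types, apply to the function later
    make_top2 = udtaf(
        result_type=DataTypes.ROW([DataTypes.FIELD("a", DataTypes.BIGINT())]),
        accumulator_type=DataTypes.ARRAY(DataTypes.BIGINT()))
    top2 = make_top2(Top2())

    # func_type is validated eagerly
    udtaf(Top2(), func_type='pandas')  # raises ValueError
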
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java
index 25168e4..145447f 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/AbstractPythonStreamAggregateOperator.java
@@ -41,11 +41,15 @@ import org.apache.flink.streaming.api.TimerService;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.Triggerable;
import org.apache.flink.streaming.api.operators.python.AbstractOneInputPythonFunctionOperator;
+import org.apache.flink.streaming.api.utils.PythonOperatorUtils;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.UpdatableRowData;
+import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
+import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.planner.plan.utils.KeySelectorUtil;
+import org.apache.flink.table.planner.typeutils.DataViewUtils;
import org.apache.flink.table.runtime.functions.CleanupState;
import org.apache.flink.table.runtime.keyselector.RowDataKeySelector;
import org.apache.flink.table.runtime.operators.python.utils.StreamRecordRowDataWrappingCollector;
@@ -85,6 +89,10 @@ public abstract class AbstractPythonStreamAggregateOperator
private static final byte TRIGGER_TIMER = 1;
+ private final PythonAggregateFunctionInfo[] aggregateFunctions;
+
+ private final DataViewUtils.DataViewSpec[][] dataViewSpecs;
+
/**
* The input logical type.
*/
@@ -196,6 +204,8 @@ public abstract class AbstractPythonStreamAggregateOperator
Configuration config,
RowType inputType,
RowType outputType,
+ PythonAggregateFunctionInfo[] aggregateFunctions,
+ DataViewUtils.DataViewSpec[][] dataViewSpecs,
int[] grouping,
int indexOfCountStar,
boolean generateUpdateBefore,
@@ -204,6 +214,8 @@ public abstract class AbstractPythonStreamAggregateOperator
super(config);
this.inputType = Preconditions.checkNotNull(inputType);
this.outputType = Preconditions.checkNotNull(outputType);
+ this.aggregateFunctions = aggregateFunctions;
+ this.dataViewSpecs = dataViewSpecs;
this.jobOptions = buildJobOptions(config);
this.grouping = grouping;
this.indexOfCountStar = indexOfCountStar;
@@ -321,6 +333,11 @@ public abstract class AbstractPythonStreamAggregateOperator
return keyForTimerService;
}
+ @Override
+ public PythonEnv getPythonEnv() {
+ return aggregateFunctions[0].getPythonFunction().getPythonEnv();
+ }
+
@VisibleForTesting
TypeSerializer getKeySerializer() {
return PythonTypeUtils.toBlinkTypeSerializer(getKeyType());
@@ -348,6 +365,14 @@ public abstract class AbstractPythonStreamAggregateOperator
builder.setStateCacheSize(stateCacheSize);
builder.setMapStateReadCacheSize(mapStateReadCacheSize);
builder.setMapStateWriteCacheSize(mapStateWriteCacheSize);
+ for (int i = 0; i < aggregateFunctions.length; i++) {
+ DataViewUtils.DataViewSpec[] specs = null;
+ if (i < dataViewSpecs.length) {
+ specs = dataViewSpecs[i];
+ }
+ builder.addUdfs(
+ PythonOperatorUtils.getUserDefinedAggregateFunctionProto(aggregateFunctions[i], specs));
+ }
return builder.build();
}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java
index 670e8da..4f1c748 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupAggregateOperator.java
@@ -22,9 +22,7 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
-import org.apache.flink.streaming.api.utils.PythonOperatorUtils;
import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
-import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.planner.typeutils.DataViewUtils;
import org.apache.flink.table.types.logical.RowType;
@@ -39,10 +37,6 @@ public class PythonStreamGroupAggregateOperator extends AbstractPythonStreamAggr
@VisibleForTesting
    protected static final String STREAM_GROUP_AGGREGATE_URN = "flink:transform:stream_group_aggregate:v1";
- private final PythonAggregateFunctionInfo[] aggregateFunctions;
-
- private final DataViewUtils.DataViewSpec[][] dataViewSpecs;
-
/**
* True if the count(*) agg is inserted by the planner.
*/
@@ -60,18 +54,11 @@ public class PythonStreamGroupAggregateOperator extends AbstractPythonStreamAggr
boolean generateUpdateBefore,
long minRetentionTime,
long maxRetentionTime) {
- super(config, inputType, outputType, grouping, indexOfCountStar, generateUpdateBefore,
- minRetentionTime, maxRetentionTime);
- this.aggregateFunctions = aggregateFunctions;
- this.dataViewSpecs = dataViewSpecs;
+ super(config, inputType, outputType, aggregateFunctions, dataViewSpecs, grouping,
+ indexOfCountStar, generateUpdateBefore, minRetentionTime, maxRetentionTime);
this.countStarInserted = countStarInserted;
}
- @Override
- public PythonEnv getPythonEnv() {
- return aggregateFunctions[0].getPythonFunction().getPythonEnv();
- }
-
/**
* Gets the proto representation of the Python user-defined aggregate functions to be executed.
*/
@@ -79,14 +66,6 @@ public class PythonStreamGroupAggregateOperator extends AbstractPythonStreamAggr
    public FlinkFnApi.UserDefinedAggregateFunctions getUserDefinedFunctionsProto() {
FlinkFnApi.UserDefinedAggregateFunctions.Builder builder =
super.getUserDefinedFunctionsProto().toBuilder();
- for (int i = 0; i < aggregateFunctions.length; i++) {
- DataViewUtils.DataViewSpec[] specs = null;
- if (i < dataViewSpecs.length) {
- specs = dataViewSpecs[i];
- }
- builder.addUdfs(
- PythonOperatorUtils.getUserDefinedAggregateFunctionProto(aggregateFunctions[i], specs));
- }
builder.setCountStarInserted(countStarInserted);
return builder.build();
}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java
index a61ad91..fd67de1 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperator.java
@@ -22,9 +22,7 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.fnexecution.v1.FlinkFnApi;
-import org.apache.flink.streaming.api.utils.PythonOperatorUtils;
import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo;
-import org.apache.flink.table.functions.python.PythonEnv;
import org.apache.flink.table.planner.typeutils.DataViewUtils;
import org.apache.flink.table.types.logical.RowType;
@@ -39,30 +37,19 @@ public class PythonStreamGroupTableAggregateOperator extends AbstractPythonStrea
@VisibleForTesting
    protected static final String STREAM_GROUP_TABLE_AGGREGATE_URN = "flink:transform:stream_group_table_aggregate:v1";
- private final PythonAggregateFunctionInfo aggregateFunction;
-
- private final DataViewUtils.DataViewSpec[] dataViewSpecs;
-
public PythonStreamGroupTableAggregateOperator(
Configuration config,
RowType inputType,
RowType outputType,
- PythonAggregateFunctionInfo aggregateFunction,
- DataViewUtils.DataViewSpec[] dataViewSpec,
+ PythonAggregateFunctionInfo[] aggregateFunctions,
+ DataViewUtils.DataViewSpec[][] dataViewSpecs,
int[] grouping,
int indexOfCountStar,
boolean generateUpdateBefore,
long minRetentionTime,
long maxRetentionTime) {
- super(config, inputType, outputType, grouping, indexOfCountStar, generateUpdateBefore,
- minRetentionTime, maxRetentionTime);
- this.aggregateFunction = aggregateFunction;
- this.dataViewSpecs = dataViewSpec;
- }
-
- @Override
- public PythonEnv getPythonEnv() {
- return aggregateFunction.getPythonFunction().getPythonEnv();
+ super(config, inputType, outputType, aggregateFunctions, dataViewSpecs, grouping,
+ indexOfCountStar, generateUpdateBefore, minRetentionTime, maxRetentionTime);
}
/**
@@ -72,8 +59,6 @@ public class PythonStreamGroupTableAggregateOperator extends AbstractPythonStrea
    public FlinkFnApi.UserDefinedAggregateFunctions getUserDefinedFunctionsProto() {
FlinkFnApi.UserDefinedAggregateFunctions.Builder builder =
super.getUserDefinedFunctionsProto().toBuilder();
- builder.addUdfs(PythonOperatorUtils.getUserDefinedAggregateFunctionProto(
- aggregateFunction, dataViewSpecs));
return builder.build();
}
diff --git a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java
index eee3b5a..d839fcb 100644
--- a/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java
+++ b/flink-python/src/test/java/org/apache/flink/table/runtime/operators/python/aggregate/PythonStreamGroupTableAggregateOperatorTest.java
@@ -224,11 +224,12 @@ public class PythonStreamGroupTableAggregateOperatorTest extends AbstractPythonS
config,
getInputType(),
getOutputType(),
- new PythonAggregateFunctionInfo(
- PythonScalarFunctionOperatorTestBase.DummyPythonFunction.INSTANCE,
- new Integer[]{0},
- -1,
- false),
+ new PythonAggregateFunctionInfo[]{
+ new PythonAggregateFunctionInfo(
+ PythonScalarFunctionOperatorTestBase.DummyPythonFunction.INSTANCE,
+ new Integer[]{0},
+ -1,
+ false)},
getGrouping(),
-1,
false,
@@ -243,7 +244,7 @@ public class PythonStreamGroupTableAggregateOperatorTest extends AbstractPythonS
Configuration config,
RowType inputType,
RowType outputType,
- PythonAggregateFunctionInfo aggregateFunction,
+ PythonAggregateFunctionInfo[] aggregateFunctions,
int[] grouping,
int indexOfCountStar,
boolean generateUpdateBefore,
@@ -253,8 +254,8 @@ public class PythonStreamGroupTableAggregateOperatorTest extends AbstractPythonS
config,
inputType,
outputType,
- aggregateFunction,
- new DataViewUtils.DataViewSpec[0],
+ aggregateFunctions,
+ new DataViewUtils.DataViewSpec[0][0],
grouping,
indexOfCountStar,
generateUpdateBefore,
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java
new file mode 100644
index 0000000..4486f69
--- /dev/null
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/PythonTableAggregateFunction.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.functions.python;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.table.catalog.DataTypeFactory;
+import org.apache.flink.table.functions.TableAggregateFunction;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.inference.TypeInference;
+import org.apache.flink.table.types.inference.TypeStrategies;
+import org.apache.flink.table.types.utils.TypeConversions;
+
+/**
+ * The wrapper of a user-defined Python table aggregate function.
+ */
+@Internal
+public class PythonTableAggregateFunction extends TableAggregateFunction implements PythonFunction {
+
+ private static final long serialVersionUID = 1L;
+
+ private final String name;
+ private final byte[] serializedTableAggregateFunction;
+ private final DataType[] inputTypes;
+ private final DataType resultType;
+ private final DataType accumulatorType;
+ private final PythonFunctionKind pythonFunctionKind;
+ private final boolean deterministic;
+ private final PythonEnv pythonEnv;
+
+ public PythonTableAggregateFunction(
+ String name,
+ byte[] serializedTableAggregateFunction,
+ DataType[] inputTypes,
+ DataType resultType,
+ DataType accumulatorType,
+ PythonFunctionKind pythonFunctionKind,
+ boolean deterministic,
+ PythonEnv pythonEnv) {
+ this.name = name;
+ this.serializedTableAggregateFunction = serializedTableAggregateFunction;
+ this.inputTypes = inputTypes;
+ this.resultType = resultType;
+ this.accumulatorType = accumulatorType;
+ this.pythonFunctionKind = pythonFunctionKind;
+ this.deterministic = deterministic;
+ this.pythonEnv = pythonEnv;
+ }
+
+ public void accumulate(Object accumulator, Object... args) {
+ throw new UnsupportedOperationException(
+ "This method is a placeholder and should not be
called.");
+ }
+
+ public void emitValue(Object accumulator, Object out) {
+ throw new UnsupportedOperationException(
+ "This method is a placeholder and should not be
called.");
+ }
+
+ @Override
+ public Object createAccumulator() {
+ return null;
+ }
+
+ @Override
+ public byte[] getSerializedPythonFunction() {
+ return serializedTableAggregateFunction;
+ }
+
+ @Override
+ public PythonEnv getPythonEnv() {
+ return pythonEnv;
+ }
+
+ @Override
+ public PythonFunctionKind getPythonFunctionKind() {
+ return pythonFunctionKind;
+ }
+
+ @Override
+ public boolean isDeterministic() {
+ return deterministic;
+ }
+
+ @Override
+ public TypeInformation getResultType() {
+ return TypeConversions.fromDataTypeToLegacyInfo(resultType);
+ }
+
+ @Override
+ public TypeInformation getAccumulatorType() {
+ return TypeConversions.fromDataTypeToLegacyInfo(accumulatorType);
+ }
+
+ @Override
+ public TypeInference getTypeInference(DataTypeFactory typeFactory) {
+ TypeInference.Builder builder = TypeInference.newBuilder();
+ if (inputTypes != null) {
+ builder.typedArguments(inputTypes);
+ }
+ return builder
+ .outputTypeStrategy(TypeStrategies.explicit(resultType))
+ .accumulatorTypeStrategy(TypeStrategies.explicit(accumulatorType))
+ .build();
+ }
+
+ @Override
+ public String toString() {
+ return name;
+ }
+}
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
index 652786c..b06e21d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala
@@ -85,11 +85,16 @@ trait CommonPythonAggregate extends CommonPythonBase {
pythonAggregateInfoList.aggInfos(i).argIndexes.map(_.asInstanceOf[AnyRef]),
aggCalls(i).filterArg,
aggCalls(i).isDistinct))
+ val typeInference = function match {
+ case aggregateFunction: PythonAggregateFunction =>
+ aggregateFunction.getTypeInference(null)
+ case tableAggregateFunction: PythonTableAggregateFunction =>
+ tableAggregateFunction.getTypeInference(null)
+ }
dataViewSpecList.add(
extractDataViewSpecs(
i,
- function.asInstanceOf[PythonAggregateFunction].getTypeInference(null)
- .getAccumulatorTypeStrategy.get().inferType(null).get()))
+ typeInference.getAccumulatorTypeStrategy.get().inferType(null).get()))
case function: UserDefinedFunction =>
var filterArg = -1
var distinct = false
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala
index dcf9b7b..c92359f 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala
@@ -18,9 +18,19 @@
package org.apache.flink.table.planner.plan.nodes.physical.stream
import org.apache.flink.api.dag.Transformation
-import org.apache.flink.table.api.TableException
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.core.memory.ManagedMemoryUseCase
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator
+import org.apache.flink.streaming.api.transformations.OneInputTransformation
import org.apache.flink.table.data.RowData
+import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.delegation.StreamPlanner
+import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate
+import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, KeySelectorUtil}
+import org.apache.flink.table.planner.typeutils.DataViewUtils.DataViewSpec
+import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
+import org.apache.flink.table.types.logical.RowType
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
@@ -45,7 +55,8 @@ class StreamExecPythonGroupTableAggregate(
inputRel,
outputRowType,
grouping,
- aggCalls) {
+ aggCalls)
+ with CommonPythonAggregate {
override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
new StreamExecPythonGroupTableAggregate(
@@ -59,6 +70,112 @@ class StreamExecPythonGroupTableAggregate(
override protected def translateToPlanInternal(
planner: StreamPlanner): Transformation[RowData] = {
- throw new TableException("The implementation will be in FLINK-20528.")
+ val tableConfig = planner.getTableConfig
+
+ if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) {
+ LOG.warn("No state retention interval configured for a query which accumulates state. " +
+ "Please provide a query configuration with valid retention interval to prevent excessive " +
+ "state size. You may specify a retention time of 0 to not clean up the state.")
+ }
+
+ val inputTransformation = getInputNodes.get(0).translateToPlan(planner)
+ .asInstanceOf[Transformation[RowData]]
+
+ val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType)
+ val inputRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType)
+
+ val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this)
+
+ val inputCountIndex = aggInfoList.getIndexOfCountStar
+
+ var (pythonFunctionInfos, dataViewSpecs) =
+ extractPythonAggregateFunctionInfos(aggInfoList, aggCalls)
+
+ if (dataViewSpecs.forall(_.isEmpty)) {
+ dataViewSpecs = Array(Array())
+ }
+
+ val operator = getPythonTableAggregateFunctionOperator(
+ getConfig(planner.getExecEnv, tableConfig),
+ inputRowType,
+ outRowType,
+ pythonFunctionInfos,
+ dataViewSpecs,
+ tableConfig.getMinIdleStateRetentionTime,
+ tableConfig.getMaxIdleStateRetentionTime,
+ grouping,
+ generateUpdateBefore,
+ inputCountIndex)
+
+ val selector = KeySelectorUtil.getRowDataSelector(
+ grouping,
+ InternalTypeInfo.of(inputRowType))
+
+ // partitioned aggregation
+ val ret = new OneInputTransformation(
+ inputTransformation,
+ getRelDetailedDescription,
+ operator,
+ InternalTypeInfo.of(outRowType),
+ inputTransformation.getParallelism)
+
+ if (inputsContainSingleton()) {
+ ret.setParallelism(1)
+ ret.setMaxParallelism(1)
+ }
+
+ if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) {
+ ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON)
+ }
+
+ // set KeyType and Selector for state
+ ret.setStateKeySelector(selector)
+ ret.setStateKeyType(selector.getProducedType)
+ ret
}
+
+ private[this] def getPythonTableAggregateFunctionOperator(
+ config: Configuration,
+ inputType: RowType,
+ outputType: RowType,
+ aggregateFunctions: Array[PythonAggregateFunctionInfo],
+ dataViewSpecs: Array[Array[DataViewSpec]],
+ minIdleStateRetentionTime: Long,
+ maxIdleStateRetentionTime: Long,
+ grouping: Array[Int],
+ generateUpdateBefore: Boolean,
+ indexOfCountStar: Int): OneInputStreamOperator[RowData, RowData] = {
+
+ val clazz = loadClass(
+ StreamExecPythonGroupTableAggregate.PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME)
+ val ctor = clazz.getConstructor(
+ classOf[Configuration],
+ classOf[RowType],
+ classOf[RowType],
+ classOf[Array[PythonAggregateFunctionInfo]],
+ classOf[Array[Array[DataViewSpec]]],
+ classOf[Array[Int]],
+ classOf[Int],
+ classOf[Boolean],
+ classOf[Long],
+ classOf[Long])
+ ctor.newInstance(
+ config.asInstanceOf[AnyRef],
+ inputType.asInstanceOf[AnyRef],
+ outputType.asInstanceOf[AnyRef],
+ aggregateFunctions.asInstanceOf[AnyRef],
+ dataViewSpecs.asInstanceOf[AnyRef],
+ grouping.asInstanceOf[AnyRef],
+ indexOfCountStar.asInstanceOf[AnyRef],
+ generateUpdateBefore.asInstanceOf[AnyRef],
+ minIdleStateRetentionTime.asInstanceOf[AnyRef],
+ maxIdleStateRetentionTime.asInstanceOf[AnyRef])
+ .asInstanceOf[OneInputStreamOperator[RowData, RowData]]
+ }
+}
+
+object StreamExecPythonGroupTableAggregate {
+ val PYTHON_STREAM_TABLE_AGGREGATE_OPERATOR_NAME: String =
+ "org.apache.flink.table.runtime.operators.python.aggregate." +
+ "PythonStreamGroupTableAggregateOperator"
}