This is an automated email from the ASF dual-hosted git repository.
hxb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 8debdd0 [FLINK-24317][python][tests] Optimize the implementation of
Top2 in test_flat_aggregate
8debdd0 is described below
commit 8debdd06be0e917610c50a77893f7ade45cee98f
Author: huangxingbo <[email protected]>
AuthorDate: Fri Sep 17 16:12:21 2021 +0800
[FLINK-24317][python][tests] Optimize the implementation of Top2 in
test_flat_aggregate
This closes #17309.
---
.../table/tests/test_row_based_operation.py | 27 +++++++++-------------
1 file changed, 11 insertions(+), 16 deletions(-)
diff --git a/flink-python/pyflink/table/tests/test_row_based_operation.py
b/flink-python/pyflink/table/tests/test_row_based_operation.py
index 61387f5..03b99b4 100644
--- a/flink-python/pyflink/table/tests/test_row_based_operation.py
+++ b/flink-python/pyflink/table/tests/test_row_based_operation.py
@@ -344,27 +344,22 @@ class CountAndSumAggregateFunction(AggregateFunction):
class Top2(TableAggregateFunction):
def emit_value(self, accumulator):
- yield accumulator[0]
- yield accumulator[1]
+ accumulator.sort()
+ accumulator.reverse()
+ size = len(accumulator)
+ if size > 1:
+ yield accumulator[0]
+ if size > 2:
+ yield accumulator[1]
def create_accumulator(self):
- return [None, None]
+ return []
def accumulate(self, accumulator, row: Row):
- if row.a is not None:
- if accumulator[0] is None or row.a > accumulator[0]:
- accumulator[1] = accumulator[0]
- accumulator[0] = row.a
- elif accumulator[1] is None or row.a > accumulator[1]:
- accumulator[1] = row.a
+ accumulator.append(row.a)
- def retract(self, accumulator, *args):
- accumulator[0] = accumulator[0] - 1
-
- def merge(self, accumulator, accumulators):
- for other_acc in accumulators:
- self.accumulate(accumulator, other_acc[0])
- self.accumulate(accumulator, other_acc[1])
+ def retract(self, accumulator, row: Row):
+ accumulator.remove(row.a)
def get_accumulator_type(self):
return DataTypes.ARRAY(DataTypes.BIGINT())