godfreyhe commented on a change in pull request #17344:
URL: https://github.com/apache/flink/pull/17344#discussion_r732920524



##########
File path: 
flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
##########
@@ -867,27 +885,118 @@ public String asSummaryString() {
                     allPartitions.isEmpty()
                             ? Collections.singletonList(Collections.emptyMap())
                             : allPartitions;
+
             int numRetained = 0;
             for (Map<String, String> partition : keys) {
-                for (Row row : data.get(partition)) {
+                Collection<Row> rowsInPartition = data.get(partition);
+
+                // handle predicates and projection
+                List<Row> rowsRetained =
+                        rowsInPartition.stream()
+                                .filter(
+                                        row ->
+                                                
FilterUtils.isRetainedAfterApplyingFilterPredicates(
+                                                        filterPredicates, 
getValueGetter(row)))
+                                .map(
+                                        row -> {
+                                            Row projectedRow = projectRow(row);
+                                            
projectedRow.setKind(row.getKind());
+                                            return projectedRow;
+                                        })
+                                .collect(Collectors.toList());
+
+                // handle aggregates
+                if (!aggregateExpressions.isEmpty()) {
+                    rowsRetained = applyAggregatesToRows(rowsRetained);
+                }
+
+                // handle row data
+                for (Row row : rowsRetained) {
+                    final RowData rowData = (RowData) 
converter.toInternal(row);
+                    if (rowData != null) {
+                        if (numRetained >= numElementToSkip) {
+                            rowData.setRowKind(row.getKind());
+                            result.add(rowData);
+                        }
+                        numRetained++;
+                    }
+
+                    // handle limit. No aggregates will be pushed down when 
there is a limit.
                     if (result.size() >= limit) {
                         return result;
                     }
-                    boolean isRetained =
-                            
FilterUtils.isRetainedAfterApplyingFilterPredicates(
-                                    filterPredicates, getValueGetter(row));
-                    if (isRetained) {
-                        final Row projectedRow = projectRow(row);
-                        final RowData rowData = (RowData) 
converter.toInternal(projectedRow);
-                        if (rowData != null) {
-                            if (numRetained >= numElementToSkip) {
-                                rowData.setRowKind(row.getKind());
-                                result.add(rowData);
-                            }
-                            numRetained++;
-                        }
+                }
+            }
+
+            return result;
+        }
+
+        private List<Row> applyAggregatesToRows(List<Row> rows) {
+            if (groupingSet != null && groupingSet.length > 0) {
+                // has group by, group firstly
+                Map<Row, List<Row>> buffer = new HashMap<>();
+                for (Row row : rows) {
+                    Row bufferKey = new Row(groupingSet.length);
+                    for (int i = 0; i < groupingSet.length; i++) {
+                        bufferKey.setField(i, row.getField(groupingSet[i]));
+                    }
+                    if (buffer.containsKey(bufferKey)) {
+                        buffer.get(bufferKey).add(row);
+                    } else {
+                        buffer.put(bufferKey, new 
ArrayList<>(Collections.singletonList(row)));
                     }
                 }
+                List<Row> result = new ArrayList<>();
+                for (Map.Entry<Row, List<Row>> entry : buffer.entrySet()) {
+                    result.add(Row.join(entry.getKey(), 
accumulateRows(entry.getValue())));
+                }
+                return result;
+            } else {
+                return Collections.singletonList(accumulateRows(rows));
+            }
+        }
+
+        // can only apply sum/sum0/avg function for long type fields for 
testing
+        private Row accumulateRows(List<Row> rows) {
+            Row result = new Row(aggregateExpressions.size());
+            for (int i = 0; i < aggregateExpressions.size(); i++) {
+                FunctionDefinition aggFunction =
+                        aggregateExpressions.get(i).getFunctionDefinition();
+                List<FieldReferenceExpression> arguments = 
aggregateExpressions.get(i).getArgs();
+                if (aggFunction instanceof MinAggFunction) {
+                    int argIndex = arguments.get(0).getFieldIndex();
+                    Row minRow =
+                            rows.stream()
+                                    .min(Comparator.comparing(row -> 
row.getFieldAs(argIndex)))
+                                    .get();
+                    result.setField(i, minRow.getField(argIndex));
+                } else if (aggFunction instanceof MaxAggFunction) {
+                    int argIndex = arguments.get(0).getFieldIndex();
+                    Row maxRow =
+                            rows.stream()
+                                    .max(Comparator.comparing(row -> 
row.getFieldAs(argIndex)))
+                                    .get();
+                    result.setField(i, maxRow.getField(argIndex));
+                } else if (aggFunction instanceof SumAggFunction) {
+                    int argIndex = arguments.get(0).getFieldIndex();
+                    Object finalSum =
+                            rows.stream()
+                                    .filter(row -> row.getField(argIndex) != 
null)
+                                    .mapToLong(row -> row.getFieldAs(argIndex))
+                                    .sum();
+                    result.setField(i, finalSum);
+                } else if (aggFunction instanceof Sum0AggFunction) {
+                    int argIndex = arguments.get(0).getFieldIndex();
+                    Object finalSum0 =
+                            rows.stream()
+                                    .filter(row -> row.getField(argIndex) != 
null)
+                                    .mapToLong(row -> row.getFieldAs(argIndex))
+                                    .sum();
+                    result.setField(i, finalSum0);
+                } else if (aggFunction instanceof CountAggFunction

Review comment:
       If all inputs are null, CountAggFunction should return 0.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to