This is an automated email from the ASF dual-hosted git repository.
godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 34d581b [FLINK-20895][table-planner] Support local aggregate push
down in table planner
34d581b is described below
commit 34d581b230712d3fd01ce9cfd822acd0df413a92
Author: iyupeng <[email protected]>
AuthorDate: Sun Oct 31 21:56:36 2021 +0800
[FLINK-20895][table-planner] Support local aggregate push down in table
planner
Co-authored-by: Sebastian Liu <[email protected]>
This closes #17344
---
.../generated/optimizer_config_configuration.html | 6 +
.../table/api/config/OptimizerConfigOptions.java | 9 +
.../abilities/SupportsAggregatePushDown.java | 2 -
.../table/expressions/AggregateExpression.java | 5 -
.../abilities/source/AggregatePushDownSpec.java | 207 +++++++++
.../abilities/source/SourceAbilityContext.java | 1 +
.../plan/abilities/source/SourceAbilitySpec.java | 3 +-
.../batch/PushLocalAggIntoScanRuleBase.java | 240 ++++++++++
.../batch/PushLocalHashAggIntoScanRule.java | 79 ++++
.../PushLocalHashAggWithCalcIntoScanRule.java | 92 ++++
.../batch/PushLocalSortAggIntoScanRule.java | 79 ++++
.../PushLocalSortAggWithCalcIntoScanRule.java | 92 ++++
...ushLocalSortAggWithSortAndCalcIntoScanRule.java | 101 +++++
.../PushLocalSortAggWithSortIntoScanRule.java | 87 ++++
.../batch/BatchPhysicalTableSourceScan.scala | 6 +
.../planner/plan/rules/FlinkBatchRuleSets.scala | 8 +-
.../planner/plan/schema/TableSourceTable.scala | 19 +
.../table/planner/plan/utils/AggregateUtil.scala | 21 +-
.../planner/factories/TestValuesTableFactory.java | 313 +++++++++++--
.../PushLocalAggIntoTableSourceScanRuleTest.java | 367 +++++++++++++++
.../sql/agg/LocalAggregatePushDownITCase.java | 318 +++++++++++++
.../table/planner/plan/batch/sql/RankTest.xml | 3 +-
.../planner/plan/batch/sql/TableSourceTest.xml | 3 +-
.../PushLocalAggIntoTableSourceScanRuleTest.xml | 499 +++++++++++++++++++++
.../table/planner/runtime/utils/TestData.scala | 22 +-
25 files changed, 2524 insertions(+), 58 deletions(-)
diff --git
a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
index f82b46a..61fc8ce 100644
--- a/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
+++ b/docs/layouts/shortcodes/generated/optimizer_config_configuration.html
@@ -60,6 +60,12 @@ ONE_PHASE: Enforce to use one stage aggregate which only has
CompleteGlobalAggre
<td>When it is true, the optimizer will try to find out duplicated
sub-plans and reuse them.</td>
</tr>
<tr>
+ <td><h5>table.optimizer.source.aggregate-pushdown-enabled</h5><br>
<span class="label label-primary">Batch</span></td>
+ <td style="word-wrap: break-word;">true</td>
+ <td>Boolean</td>
+ <td>When it is true, the optimizer will push down the local
aggregates into the TableSource which implements SupportsAggregatePushDown.</td>
+ </tr>
+ <tr>
<td><h5>table.optimizer.source.predicate-pushdown-enabled</h5><br>
<span class="label label-primary">Batch</span> <span class="label
label-primary">Streaming</span></td>
<td style="word-wrap: break-word;">true</td>
<td>Boolean</td>
diff --git
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java
index 685dd58..1e01256 100644
---
a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java
+++
b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/api/config/OptimizerConfigOptions.java
@@ -93,6 +93,15 @@ public class OptimizerConfigOptions {
+
TABLE_OPTIMIZER_REUSE_SUB_PLAN_ENABLED.key()
+ " is true.");
+ @Documentation.TableOption(execMode = Documentation.ExecMode.BATCH)
+ public static final ConfigOption<Boolean>
TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED =
+ key("table.optimizer.source.aggregate-pushdown-enabled")
+ .booleanType()
+ .defaultValue(true)
+ .withDescription(
+ "When it is true, the optimizer will push down the
local aggregates into "
+ + "the TableSource which implements
SupportsAggregatePushDown.");
+
@Documentation.TableOption(execMode =
Documentation.ExecMode.BATCH_STREAMING)
public static final ConfigOption<Boolean>
TABLE_OPTIMIZER_SOURCE_PREDICATE_PUSHDOWN_ENABLED =
key("table.optimizer.source.predicate-pushdown-enabled")
diff --git
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java
index 218c2a4..67f645a 100644
---
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java
+++
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/connector/source/abilities/SupportsAggregatePushDown.java
@@ -124,8 +124,6 @@ import java.util.List;
*
* <p>Regardless if this interface is implemented or not, a final aggregation
is always applied in a
* subsequent operation after the source.
- *
- * <p>Note: currently, the {@link SupportsAggregatePushDown} is not supported
by planner.
*/
@PublicEvolving
public interface SupportsAggregatePushDown {
diff --git
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java
index ce11179..897ab8a 100644
---
a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java
+++
b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/AggregateExpression.java
@@ -19,7 +19,6 @@
package org.apache.flink.table.expressions;
import org.apache.flink.annotation.PublicEvolving;
-import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.types.DataType;
import org.apache.flink.util.Preconditions;
@@ -47,9 +46,6 @@ import java.util.stream.Collectors;
* <li>{@code approximate} indicates whether this is a approximate aggregate
function.
* <li>{@code ignoreNulls} indicates whether this aggregate function ignore
null value.
* </ul>
- *
- * <p>Note: currently, the {@link AggregateExpression} is only used in {@link
- * SupportsAggregatePushDown}.
*/
@PublicEvolving
public class AggregateExpression implements ResolvedExpression {
@@ -107,7 +103,6 @@ public class AggregateExpression implements
ResolvedExpression {
return args;
}
- @Nullable
public Optional<CallExpression> getFilterExpression() {
return Optional.ofNullable(filterExpression);
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java
new file mode 100644
index 0000000..f8fe887
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/AggregatePushDownSpec.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.abilities.source;
+
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.connector.source.DynamicTableSource;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import org.apache.flink.table.expressions.AggregateExpression;
+import org.apache.flink.table.expressions.FieldReferenceExpression;
+import org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction;
+import org.apache.flink.table.planner.plan.utils.AggregateInfo;
+import org.apache.flink.table.planner.plan.utils.AggregateInfoList;
+import org.apache.flink.table.planner.plan.utils.AggregateUtil;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.table.types.utils.TypeConversions;
+
+import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
+import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
+import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeName;
+
+import org.apache.calcite.rel.core.AggregateCall;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import scala.Tuple2;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A sub-class of {@link SourceAbilitySpec} that can not only
serialize/deserialize the aggregation
+ * to/from JSON, but also can push the local aggregate into a {@link
SupportsAggregatePushDown}.
+ */
+@JsonTypeName("AggregatePushDown")
+public class AggregatePushDownSpec extends SourceAbilitySpecBase {
+
+ public static final String FIELD_NAME_INPUT_TYPE = "inputType";
+
+ public static final String FIELD_NAME_GROUPING_SETS = "groupingSets";
+
+ public static final String FIELD_NAME_AGGREGATE_CALLS = "aggregateCalls";
+
+ @JsonProperty(FIELD_NAME_INPUT_TYPE)
+ private final RowType inputType;
+
+ @JsonProperty(FIELD_NAME_GROUPING_SETS)
+ private final List<int[]> groupingSets;
+
+ @JsonProperty(FIELD_NAME_AGGREGATE_CALLS)
+ private final List<AggregateCall> aggregateCalls;
+
+ @JsonCreator
+ public AggregatePushDownSpec(
+ @JsonProperty(FIELD_NAME_INPUT_TYPE) RowType inputType,
+ @JsonProperty(FIELD_NAME_GROUPING_SETS) List<int[]> groupingSets,
+ @JsonProperty(FIELD_NAME_AGGREGATE_CALLS) List<AggregateCall>
aggregateCalls,
+ @JsonProperty(FIELD_NAME_PRODUCED_TYPE) RowType producedType) {
+ super(producedType);
+
+ this.inputType = inputType;
+ this.groupingSets = new ArrayList<>(checkNotNull(groupingSets));
+ this.aggregateCalls = aggregateCalls;
+ }
+
+ @Override
+ public void apply(DynamicTableSource tableSource, SourceAbilityContext
context) {
+ checkArgument(getProducedType().isPresent());
+ apply(
+ inputType,
+ groupingSets,
+ aggregateCalls,
+ getProducedType().get(),
+ tableSource,
+ context);
+ }
+
+ @Override
+ public String getDigests(SourceAbilityContext context) {
+ int[] grouping = groupingSets.get(0);
+ String groupingStr =
+ Arrays.stream(grouping)
+ .mapToObj(index ->
inputType.getFieldNames().get(index))
+ .collect(Collectors.joining(","));
+
+ List<AggregateExpression> aggregateExpressions =
+ buildAggregateExpressions(inputType, aggregateCalls);
+ String aggFunctionsStr =
+ aggregateExpressions.stream()
+ .map(AggregateExpression::asSummaryString)
+ .collect(Collectors.joining(","));
+
+ return "aggregates=[grouping=["
+ + groupingStr
+ + "], aggFunctions=["
+ + aggFunctionsStr
+ + "]]";
+ }
+
+ public static boolean apply(
+ RowType inputType,
+ List<int[]> groupingSets,
+ List<AggregateCall> aggregateCalls,
+ RowType producedType,
+ DynamicTableSource tableSource,
+ SourceAbilityContext context) {
+ assert context.isBatchMode() && groupingSets.size() == 1;
+
+ List<AggregateExpression> aggregateExpressions =
+ buildAggregateExpressions(inputType, aggregateCalls);
+
+ if (tableSource instanceof SupportsAggregatePushDown) {
+ DataType producedDataType =
TypeConversions.fromLogicalToDataType(producedType);
+ return ((SupportsAggregatePushDown) tableSource)
+ .applyAggregates(groupingSets, aggregateExpressions,
producedDataType);
+ } else {
+ throw new TableException(
+ String.format(
+ "%s does not support SupportsAggregatePushDown.",
+ tableSource.getClass().getName()));
+ }
+ }
+
+ private static List<AggregateExpression> buildAggregateExpressions(
+ RowType inputType, List<AggregateCall> aggregateCalls) {
+ AggregateInfoList aggInfoList =
+ AggregateUtil.transformToBatchAggregateInfoList(
+ inputType,
JavaScalaConversionUtil.toScala(aggregateCalls), null, null);
+ if (aggInfoList.aggInfos().length == 0) {
+ // no agg function need to be pushed down
+ return Collections.emptyList();
+ }
+
+ List<AggregateExpression> aggExpressions = new ArrayList<>();
+ for (AggregateInfo aggInfo : aggInfoList.aggInfos()) {
+ List<FieldReferenceExpression> arguments = new ArrayList<>(1);
+ for (int argIndex : aggInfo.argIndexes()) {
+ DataType argType =
+ TypeConversions.fromLogicalToDataType(
+ inputType.getFields().get(argIndex).getType());
+ FieldReferenceExpression field =
+ new FieldReferenceExpression(
+ inputType.getFieldNames().get(argIndex),
argType, 0, argIndex);
+ arguments.add(field);
+ }
+ if (aggInfo.function() instanceof AvgAggFunction) {
+ Tuple2<Sum0AggFunction, CountAggFunction> sum0AndCountFunction
=
+
AggregateUtil.deriveSumAndCountFromAvg((AvgAggFunction) aggInfo.function());
+ AggregateExpression sum0Expression =
+ new AggregateExpression(
+ sum0AndCountFunction._1(),
+ arguments,
+ null,
+ aggInfo.externalResultType(),
+ aggInfo.agg().isDistinct(),
+ aggInfo.agg().isApproximate(),
+ aggInfo.agg().ignoreNulls());
+ aggExpressions.add(sum0Expression);
+ AggregateExpression countExpression =
+ new AggregateExpression(
+ sum0AndCountFunction._2(),
+ arguments,
+ null,
+ aggInfo.externalResultType(),
+ aggInfo.agg().isDistinct(),
+ aggInfo.agg().isApproximate(),
+ aggInfo.agg().ignoreNulls());
+ aggExpressions.add(countExpression);
+ } else {
+ AggregateExpression aggregateExpression =
+ new AggregateExpression(
+ aggInfo.function(),
+ arguments,
+ null,
+ aggInfo.externalResultType(),
+ aggInfo.agg().isDistinct(),
+ aggInfo.agg().isApproximate(),
+ aggInfo.agg().ignoreNulls());
+ aggExpressions.add(aggregateExpression);
+ }
+ }
+ return aggExpressions;
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java
index 1fbb61a..e3431c7 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilityContext.java
@@ -40,6 +40,7 @@ import org.apache.calcite.rel.core.TableScan;
* <li>project push down (SupportsProjectionPushDown)
* <li>partition push down (SupportsPartitionPushDown)
* <li>watermark push down (SupportsWatermarkPushDown)
+ * <li>aggregate push down (SupportsAggregatePushDown)
* <li>reading metadata (SupportsReadingMetadata)
* </ul>
*/
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java
index 92326f0..453ee4c 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/SourceAbilitySpec.java
@@ -40,7 +40,8 @@ import java.util.Optional;
@JsonSubTypes.Type(value = ProjectPushDownSpec.class),
@JsonSubTypes.Type(value = ReadingMetadataSpec.class),
@JsonSubTypes.Type(value = WatermarkPushDownSpec.class),
- @JsonSubTypes.Type(value = SourceWatermarkSpec.class)
+ @JsonSubTypes.Type(value = SourceWatermarkSpec.class),
+ @JsonSubTypes.Type(value = AggregatePushDownSpec.class)
})
@Internal
public interface SourceAbilitySpec {
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java
new file mode 100644
index 0000000..fe9a858
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoScanRuleBase.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.connector.source.DynamicTableSource;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import
org.apache.flink.table.planner.plan.abilities.source.AggregatePushDownSpec;
+import
org.apache.flink.table.planner.plan.abilities.source.ProjectPushDownSpec;
+import
org.apache.flink.table.planner.plan.abilities.source.SourceAbilityContext;
+import org.apache.flink.table.planner.plan.abilities.source.SourceAbilitySpec;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+import org.apache.flink.table.planner.plan.stats.FlinkStatistic;
+import org.apache.flink.table.planner.plan.utils.RexNodeExtractor;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
+import org.apache.flink.table.types.logical.RowType;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptRuleOperand;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Planner rule that tries to push a local aggregator into an {@link
BatchPhysicalTableSourceScan}
+ * whose table is a {@link TableSourceTable} with a source supporting {@link
+ * SupportsAggregatePushDown}.
+ *
+ * <p>The aggregate push down does not support a number of more complex
statements at present:
+ *
+ * <ul>
+ * <li>complex grouping operations such as ROLLUP, CUBE, or GROUPING SETS.
+ * <li>expressions inside the aggregation function call: such as sum(a * b).
+ * <li>aggregations with ordering.
+ * <li>aggregations with filter.
+ * </ul>
+ */
+public abstract class PushLocalAggIntoScanRuleBase extends RelOptRule {
+
+ public PushLocalAggIntoScanRuleBase(RelOptRuleOperand operand, String
description) {
+ super(operand, description);
+ }
+
+ protected boolean canPushDown(
+ RelOptRuleCall call,
+ BatchPhysicalGroupAggregateBase aggregate,
+ BatchPhysicalTableSourceScan tableSourceScan) {
+ TableConfig tableConfig =
ShortcutUtils.unwrapContext(call.getPlanner()).getTableConfig();
+ if (!tableConfig
+ .getConfiguration()
+ .getBoolean(
+
OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED)) {
+ return false;
+ }
+
+ if (aggregate.isFinal() || aggregate.getAggCallList().isEmpty()) {
+ return false;
+ }
+ List<AggregateCall> aggCallList =
+ JavaScalaConversionUtil.toJava(aggregate.getAggCallList());
+ for (AggregateCall aggCall : aggCallList) {
+ if (aggCall.isDistinct()
+ || aggCall.isApproximate()
+ || aggCall.getArgList().size() > 1
+ || aggCall.hasFilter()
+ || !aggCall.getCollation().getFieldCollations().isEmpty())
{
+ return false;
+ }
+ }
+ TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable();
+ // we can not push aggregates twice
+ return tableSourceTable != null
+ && tableSourceTable.tableSource() instanceof
SupportsAggregatePushDown
+ && Arrays.stream(tableSourceTable.abilitySpecs())
+ .noneMatch(spec -> spec instanceof
AggregatePushDownSpec);
+ }
+
+ protected void pushLocalAggregateIntoScan(
+ RelOptRuleCall call,
+ BatchPhysicalGroupAggregateBase localAgg,
+ BatchPhysicalTableSourceScan oldScan) {
+ pushLocalAggregateIntoScan(call, localAgg, oldScan, null);
+ }
+
+ protected void pushLocalAggregateIntoScan(
+ RelOptRuleCall call,
+ BatchPhysicalGroupAggregateBase localAgg,
+ BatchPhysicalTableSourceScan oldScan,
+ int[] calcRefFields) {
+ RowType inputType =
FlinkTypeFactory.toLogicalRowType(oldScan.getRowType());
+ List<int[]> groupingSets =
+ Collections.singletonList(
+ ArrayUtils.addAll(localAgg.grouping(),
localAgg.auxGrouping()));
+ List<AggregateCall> aggCallList =
JavaScalaConversionUtil.toJava(localAgg.getAggCallList());
+
+ // map arg index in aggregate to field index in scan through referred
fields by calc.
+ if (calcRefFields != null) {
+ groupingSets = translateGroupingArgIndex(groupingSets,
calcRefFields);
+ aggCallList = translateAggCallArgIndex(aggCallList, calcRefFields);
+ }
+
+ RowType producedType =
FlinkTypeFactory.toLogicalRowType(localAgg.getRowType());
+
+ TableSourceTable oldTableSourceTable = oldScan.tableSourceTable();
+ DynamicTableSource newTableSource = oldScan.tableSource().copy();
+
+ boolean isPushDownSuccess =
+ AggregatePushDownSpec.apply(
+ inputType,
+ groupingSets,
+ aggCallList,
+ producedType,
+ newTableSource,
+ SourceAbilityContext.from(oldScan));
+
+ if (!isPushDownSuccess) {
+ // aggregate push down failed, just return without changing any
nodes.
+ return;
+ }
+
+ // create new source table with new spec and statistic.
+ AggregatePushDownSpec aggregatePushDownSpec =
+ new AggregatePushDownSpec(inputType, groupingSets,
aggCallList, producedType);
+
+ TableSourceTable newTableSourceTable =
+ oldTableSourceTable
+ .copy(
+ newTableSource,
+ localAgg.getRowType(),
+ new SourceAbilitySpec[]
{aggregatePushDownSpec})
+ .copy(FlinkStatistic.UNKNOWN());
+
+ // transform to new nodes.
+ BatchPhysicalTableSourceScan newScan =
+ oldScan.copy(oldScan.getTraitSet(), newTableSourceTable);
+ BatchPhysicalExchange oldExchange = call.rel(0);
+ BatchPhysicalExchange newExchange =
+ oldExchange.copy(oldExchange.getTraitSet(), newScan,
oldExchange.getDistribution());
+ call.transformTo(newExchange);
+ }
+
+ protected boolean isProjectionNotPushedDown(BatchPhysicalTableSourceScan
tableSourceScan) {
+ TableSourceTable tableSourceTable = tableSourceScan.tableSourceTable();
+ return tableSourceTable != null
+ && Arrays.stream(tableSourceTable.abilitySpecs())
+ .noneMatch(spec -> spec instanceof
ProjectPushDownSpec);
+ }
+
+ /**
+ * Currently, we only supports to push down aggregate above calc which has
input ref only.
+ *
+ * @param calc BatchPhysicalCalc
+ * @return true if OK to be pushed down
+ */
+ protected boolean isInputRefOnly(BatchPhysicalCalc calc) {
+ RexProgram program = calc.getProgram();
+
+ // check if condition exists. All filters should have been pushed down.
+ if (program.getCondition() != null) {
+ return false;
+ }
+
+ return !program.getProjectList().isEmpty()
+ && program.getProjectList().stream()
+ .map(calc.getProgram()::expandLocalRef)
+ .allMatch(RexInputRef.class::isInstance);
+ }
+
+ protected int[] getRefFiledIndex(BatchPhysicalCalc calc) {
+ List<RexNode> projects =
+ calc.getProgram().getProjectList().stream()
+ .map(calc.getProgram()::expandLocalRef)
+ .collect(Collectors.toList());
+
+ return RexNodeExtractor.extractRefInputFields(projects);
+ }
+
+ protected List<int[]> translateGroupingArgIndex(List<int[]> groupingSets,
int[] refFields) {
+ List<int[]> newGroupingSets = new ArrayList<>();
+ groupingSets.forEach(
+ grouping -> {
+ int[] newGrouping = new int[grouping.length];
+ for (int i = 0; i < grouping.length; i++) {
+ int argIndex = grouping[i];
+ newGrouping[i] = refFields[argIndex];
+ }
+ newGroupingSets.add(newGrouping);
+ });
+
+ return newGroupingSets;
+ }
+
+ protected List<AggregateCall> translateAggCallArgIndex(
+ List<AggregateCall> aggCallList, int[] refFields) {
+ List<AggregateCall> newAggCallList = new ArrayList<>();
+ aggCallList.forEach(
+ aggCall -> {
+ List<Integer> argList = new ArrayList<>();
+ for (int i = 0; i < aggCall.getArgList().size(); i++) {
+ int argIndex = aggCall.getArgList().get(i);
+ argList.add(refFields[argIndex]);
+ }
+ newAggCallList.add(aggCall.copy(argList,
aggCall.filterArg, aggCall.collation));
+ });
+
+ return newAggCallList;
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java
new file mode 100644
index 0000000..0678162
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggIntoScanRule.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalHashAggregate;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+
+/**
+ * Planner rule that tries to push a local hash aggregate which without sort
into a {@link
+ * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable}
with a source supporting
+ * {@link SupportsAggregatePushDown}. The {@link
+ * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED}
need to be true.
+ *
+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalLocalHashAggregate (local)
+ * +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
+ */
+public class PushLocalHashAggIntoScanRule extends PushLocalAggIntoScanRuleBase
{
+ public static final PushLocalHashAggIntoScanRule INSTANCE = new
PushLocalHashAggIntoScanRule();
+
+ public PushLocalHashAggIntoScanRule() {
+ super(
+ operand(
+ BatchPhysicalExchange.class,
+ operand(
+ BatchPhysicalLocalHashAggregate.class,
+ operand(BatchPhysicalTableSourceScan.class,
none()))),
+ "PushLocalHashAggIntoScanRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ BatchPhysicalLocalHashAggregate localAggregate = call.rel(1);
+ BatchPhysicalTableSourceScan tableSourceScan = call.rel(2);
+ return canPushDown(call, localAggregate, tableSourceScan);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ BatchPhysicalLocalHashAggregate localHashAgg = call.rel(1);
+ BatchPhysicalTableSourceScan oldScan = call.rel(2);
+ pushLocalAggregateIntoScan(call, localHashAgg, oldScan);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java
new file mode 100755
index 0000000..87f47c5
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalHashAggWithCalcIntoScanRule.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalHashAggregate;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+
+/**
+ * Planner rule that tries to push a local hash aggregate with a calc into
a {@link
+ * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable}
with a source supporting
+ * {@link SupportsAggregatePushDown}. The {@link
+ * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED}
needs to be true.
+ *
+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalLocalHashAggregate (local)
+ *       +- BatchPhysicalCalc (field projection only)
+ * +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalHashAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
+ */
+public class PushLocalHashAggWithCalcIntoScanRule extends
PushLocalAggIntoScanRuleBase {
+ public static final PushLocalHashAggWithCalcIntoScanRule INSTANCE =
+ new PushLocalHashAggWithCalcIntoScanRule();
+
+ public PushLocalHashAggWithCalcIntoScanRule() {
+ super(
+ operand(
+ BatchPhysicalExchange.class,
+ operand(
+ BatchPhysicalLocalHashAggregate.class,
+ operand(
+ BatchPhysicalCalc.class,
+
operand(BatchPhysicalTableSourceScan.class, none())))),
+ "PushLocalHashAggWithCalcIntoScanRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ BatchPhysicalLocalHashAggregate localHashAgg = call.rel(1);
+ BatchPhysicalCalc calc = call.rel(2);
+ BatchPhysicalTableSourceScan tableSourceScan = call.rel(3);
+
+ return isInputRefOnly(calc)
+ && isProjectionNotPushedDown(tableSourceScan)
+ && canPushDown(call, localHashAgg, tableSourceScan);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ BatchPhysicalLocalHashAggregate localHashAgg = call.rel(1);
+ BatchPhysicalCalc calc = call.rel(2);
+ BatchPhysicalTableSourceScan oldScan = call.rel(3);
+
+ int[] calcRefFields = getRefFiledIndex(calc);
+
+ pushLocalAggregateIntoScan(call, localHashAgg, oldScan, calcRefFields);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java
new file mode 100755
index 0000000..ca101ca
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggIntoScanRule.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+
+/**
+ * Planner rule that tries to push a local sort aggregate without a sort
into a {@link
+ * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable}
with a source supporting
+ * {@link SupportsAggregatePushDown}. The {@link
+ * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED}
needs to be true.
+ *
+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalLocalSortAggregate (local)
+ * +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
+ */
+public class PushLocalSortAggIntoScanRule extends PushLocalAggIntoScanRuleBase
{
+ public static final PushLocalSortAggIntoScanRule INSTANCE = new
PushLocalSortAggIntoScanRule();
+
+ public PushLocalSortAggIntoScanRule() {
+ super(
+ operand(
+ BatchPhysicalExchange.class,
+ operand(
+ BatchPhysicalLocalSortAggregate.class,
+ operand(BatchPhysicalTableSourceScan.class,
none()))),
+ "PushLocalSortAggIntoScanRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ BatchPhysicalLocalSortAggregate localAggregate = call.rel(1);
+ BatchPhysicalTableSourceScan tableSourceScan = call.rel(2);
+ return canPushDown(call, localAggregate, tableSourceScan);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ BatchPhysicalLocalSortAggregate localHashAgg = call.rel(1);
+ BatchPhysicalTableSourceScan oldScan = call.rel(2);
+ pushLocalAggregateIntoScan(call, localHashAgg, oldScan);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java
new file mode 100755
index 0000000..e56e3aa
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithCalcIntoScanRule.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+
+/**
+ * Planner rule that tries to push a local sort aggregate without a sort but
with a calc into a {@link
+ * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable}
with a source supporting
+ * {@link SupportsAggregatePushDown}. The {@link
+ * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED}
needs to be true.
+ *
+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalLocalSortAggregate (local)
+ *       +- BatchPhysicalCalc (field projection only)
+ * +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
+ */
+public class PushLocalSortAggWithCalcIntoScanRule extends
PushLocalAggIntoScanRuleBase {
+ public static final PushLocalSortAggWithCalcIntoScanRule INSTANCE =
+ new PushLocalSortAggWithCalcIntoScanRule();
+
+ public PushLocalSortAggWithCalcIntoScanRule() {
+ super(
+ operand(
+ BatchPhysicalExchange.class,
+ operand(
+ BatchPhysicalLocalSortAggregate.class,
+ operand(
+ BatchPhysicalCalc.class,
+
operand(BatchPhysicalTableSourceScan.class, none())))),
+ "PushLocalSortAggWithCalcIntoScanRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ BatchPhysicalLocalSortAggregate localAggregate = call.rel(1);
+ BatchPhysicalCalc calc = call.rel(2);
+ BatchPhysicalTableSourceScan tableSourceScan = call.rel(3);
+
+ return isInputRefOnly(calc)
+ && isProjectionNotPushedDown(tableSourceScan)
+ && canPushDown(call, localAggregate, tableSourceScan);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ BatchPhysicalLocalSortAggregate localHashAgg = call.rel(1);
+ BatchPhysicalCalc calc = call.rel(2);
+ BatchPhysicalTableSourceScan oldScan = call.rel(3);
+
+ int[] calcRefFields = getRefFiledIndex(calc);
+
+ pushLocalAggregateIntoScan(call, localHashAgg, oldScan, calcRefFields);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java
new file mode 100755
index 0000000..d9c340a
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortAndCalcIntoScanRule.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalCalc;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSort;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+
+/**
+ * Planner rule that tries to push a local sort aggregate with a sort and a
calc into a {@link
+ * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable}
with a source supporting
+ * {@link SupportsAggregatePushDown}. The {@link
+ * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED}
needs to be true.
+ *
+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalLocalSortAggregate (local)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ *             +- BatchPhysicalCalc (field projection only)
+ * +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
+ */
+public class PushLocalSortAggWithSortAndCalcIntoScanRule extends
PushLocalAggIntoScanRuleBase {
+ public static final PushLocalSortAggWithSortAndCalcIntoScanRule INSTANCE =
+ new PushLocalSortAggWithSortAndCalcIntoScanRule();
+
+ public PushLocalSortAggWithSortAndCalcIntoScanRule() {
+ super(
+ operand(
+ BatchPhysicalExchange.class,
+ operand(
+ BatchPhysicalLocalSortAggregate.class,
+ operand(
+ BatchPhysicalSort.class,
+ operand(
+ BatchPhysicalCalc.class,
+ operand(
+
BatchPhysicalTableSourceScan.class,
+ none()))))),
+ "PushLocalSortAggWithSortAndCalcIntoScanRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ BatchPhysicalGroupAggregateBase localAggregate = call.rel(1);
+ BatchPhysicalCalc calc = call.rel(3);
+ BatchPhysicalTableSourceScan tableSourceScan = call.rel(4);
+
+ return isInputRefOnly(calc)
+ && isProjectionNotPushedDown(tableSourceScan)
+ && canPushDown(call, localAggregate, tableSourceScan);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ BatchPhysicalGroupAggregateBase localSortAgg = call.rel(1);
+ BatchPhysicalCalc calc = call.rel(3);
+ BatchPhysicalTableSourceScan oldScan = call.rel(4);
+
+ int[] calcRefFields = getRefFiledIndex(calc);
+
+ pushLocalAggregateIntoScan(call, localSortAgg, oldScan, calcRefFields);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java
new file mode 100644
index 0000000..9d952b2
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalSortAggWithSortIntoScanRule.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalSortAggregate;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSort;
+import
org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalTableSourceScan;
+import org.apache.flink.table.planner.plan.schema.TableSourceTable;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+
+/**
+ * Planner rule that tries to push a local sort aggregate with a sort into
a {@link
+ * BatchPhysicalTableSourceScan} whose table is a {@link TableSourceTable}
with a source supporting
+ * {@link SupportsAggregatePushDown}. The {@link
+ * OptimizerConfigOptions#TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED}
needs to be true.
+ *
+ * <p>Suppose we have the original physical plan:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalLocalSortAggregate (local)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ * +- BatchPhysicalTableSourceScan
+ * }</pre>
+ *
+ * <p>This physical plan will be rewritten to:
+ *
+ * <pre>{@code
+ * BatchPhysicalSortAggregate (global)
+ * +- BatchPhysicalSort (exists if group keys are not empty)
+ * +- BatchPhysicalExchange (hash by group keys if group keys is not empty,
else singleton)
+ * +- BatchPhysicalTableSourceScan (with local aggregate pushed down)
+ * }</pre>
+ */
+public class PushLocalSortAggWithSortIntoScanRule extends
PushLocalAggIntoScanRuleBase {
+ public static final PushLocalSortAggWithSortIntoScanRule INSTANCE =
+ new PushLocalSortAggWithSortIntoScanRule();
+
+ public PushLocalSortAggWithSortIntoScanRule() {
+ super(
+ operand(
+ BatchPhysicalExchange.class,
+ operand(
+ BatchPhysicalLocalSortAggregate.class,
+ operand(
+ BatchPhysicalSort.class,
+
operand(BatchPhysicalTableSourceScan.class, none())))),
+ "PushLocalSortAggWithSortIntoScanRule");
+ }
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ BatchPhysicalGroupAggregateBase localAggregate = call.rel(1);
+ BatchPhysicalTableSourceScan tableSourceScan = call.rel(3);
+ return canPushDown(call, localAggregate, tableSourceScan);
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ BatchPhysicalGroupAggregateBase localSortAgg = call.rel(1);
+ BatchPhysicalTableSourceScan oldScan = call.rel(3);
+ pushLocalAggregateIntoScan(call, localSortAgg, oldScan);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala
index 3021002..0a3563d 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalTableSourceScan.scala
@@ -49,6 +49,12 @@ class BatchPhysicalTableSourceScan(
new BatchPhysicalTableSourceScan(cluster, traitSet, getHints,
tableSourceTable)
}
+ def copy(
+ traitSet: RelTraitSet,
+ tableSourceTable: TableSourceTable): BatchPhysicalTableSourceScan
= {
+ new BatchPhysicalTableSourceScan(cluster, traitSet, getHints,
tableSourceTable)
+ }
+
override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery):
RelOptCost = {
val rowCnt = mq.getRowCount(this)
if (rowCnt == null) {
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
index 83fa93b..28db5c8 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala
@@ -448,6 +448,12 @@ object FlinkBatchRuleSets {
*/
val PHYSICAL_REWRITE: RuleSet = RuleSets.ofList(
EnforceLocalHashAggRule.INSTANCE,
- EnforceLocalSortAggRule.INSTANCE
+ EnforceLocalSortAggRule.INSTANCE,
+ PushLocalHashAggIntoScanRule.INSTANCE,
+ PushLocalHashAggWithCalcIntoScanRule.INSTANCE,
+ PushLocalSortAggIntoScanRule.INSTANCE,
+ PushLocalSortAggWithSortIntoScanRule.INSTANCE,
+ PushLocalSortAggWithCalcIntoScanRule.INSTANCE,
+ PushLocalSortAggWithSortAndCalcIntoScanRule.INSTANCE
)
}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala
index cc96c10..2ef50b4 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TableSourceTable.scala
@@ -131,4 +131,23 @@ class TableSourceTable(
flinkContext,
abilitySpecs ++ newAbilitySpecs)
}
+
+ /**
+ * Creates a copy of this table, changing the statistic
+ *
+ * @param newStatistic new table statistic
+ * @return New TableSourceTable instance with new statistic
+ */
+ def copy(newStatistic: FlinkStatistic): TableSourceTable = {
+ new TableSourceTable(
+ relOptSchema,
+ tableIdentifier,
+ rowType,
+ newStatistic,
+ tableSource,
+ isStreamingMode,
+ catalogTable,
+ flinkContext,
+ abilitySpecs)
+ }
}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index 7294d24..37d2157 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -27,7 +27,9 @@ import org.apache.flink.table.planner.JLong
import org.apache.flink.table.planner.calcite.{FlinkTypeFactory,
FlinkTypeSystem}
import org.apache.flink.table.planner.delegation.PlannerBase
import org.apache.flink.table.planner.expressions._
-import
org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction
+import
org.apache.flink.table.planner.functions.aggfunctions.AvgAggFunction.{ByteAvgAggFunction,
DoubleAvgAggFunction, FloatAvgAggFunction, IntAvgAggFunction,
LongAvgAggFunction, ShortAvgAggFunction}
+import
org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction.{ByteSum0AggFunction,
DoubleSum0AggFunction, FloatSum0AggFunction, IntSum0AggFunction,
LongSum0AggFunction, ShortSum0AggFunction}
+import org.apache.flink.table.planner.functions.aggfunctions.{AvgAggFunction,
CountAggFunction, DeclarativeAggregateFunction, Sum0AggFunction}
import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction
import
org.apache.flink.table.planner.functions.inference.OperatorBindingCallContext
import org.apache.flink.table.planner.functions.sql.{FlinkSqlOperatorTable,
SqlFirstLastValueAggFunction, SqlListAggFunction}
@@ -48,7 +50,7 @@ import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.inference.TypeInferenceUtil
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical.utils.LogicalTypeChecks
-import org.apache.flink.table.types.logical.{LogicalTypeRoot, _}
+import org.apache.flink.table.types.logical._
import org.apache.flink.table.types.utils.DataTypeUtils
import org.apache.calcite.rel.`type`._
@@ -278,6 +280,21 @@ object AggregateUtil extends Enumeration {
isBounded = false)
}
+ def deriveSumAndCountFromAvg(
+ avgAggFunction: AvgAggFunction): (Sum0AggFunction, CountAggFunction) = {
+ avgAggFunction match {
+ case _: ByteAvgAggFunction => (new ByteSum0AggFunction, new
CountAggFunction)
+ case _: ShortAvgAggFunction => (new ShortSum0AggFunction, new
CountAggFunction)
+ case _: IntAvgAggFunction => (new IntSum0AggFunction, new
CountAggFunction)
+ case _: LongAvgAggFunction => (new LongSum0AggFunction, new
CountAggFunction)
+ case _: FloatAvgAggFunction => (new FloatSum0AggFunction, new
CountAggFunction)
+ case _: DoubleAvgAggFunction => (new DoubleSum0AggFunction, new
CountAggFunction)
+ case _ =>
+ throw new TableException(s"Avg aggregate function does not support:
''$avgAggFunction''" +
+ s"Please re-check the function or data type.")
+ }
+ }
+
def transformToBatchAggregateFunctions(
inputRowType: RowType,
aggregateCalls: Seq[AggregateCall],
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
index d717876..404d05a 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java
@@ -54,6 +54,7 @@ import
org.apache.flink.table.connector.source.LookupTableSource;
import org.apache.flink.table.connector.source.ScanTableSource;
import org.apache.flink.table.connector.source.SourceFunctionProvider;
import org.apache.flink.table.connector.source.TableFunctionProvider;
+import
org.apache.flink.table.connector.source.abilities.SupportsAggregatePushDown;
import
org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown;
import org.apache.flink.table.connector.source.abilities.SupportsLimitPushDown;
import
org.apache.flink.table.connector.source.abilities.SupportsPartitionPushDown;
@@ -62,11 +63,14 @@ import
org.apache.flink.table.connector.source.abilities.SupportsReadingMetadata
import
org.apache.flink.table.connector.source.abilities.SupportsSourceWatermark;
import
org.apache.flink.table.connector.source.abilities.SupportsWatermarkPushDown;
import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.expressions.AggregateExpression;
+import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.factories.DynamicTableSinkFactory;
import org.apache.flink.table.factories.DynamicTableSourceFactory;
import org.apache.flink.table.factories.FactoryUtil;
import org.apache.flink.table.functions.AsyncTableFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.TableFunction;
import
org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AppendingOutputFormat;
import
org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.AppendingSinkFunction;
@@ -74,10 +78,17 @@ import
org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.Async
import
org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.KeyedUpsertingSinkFunction;
import
org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.RetractingSinkFunction;
import
org.apache.flink.table.planner.factories.TestValuesRuntimeFunctions.TestValuesLookupFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.Count1AggFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.CountAggFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.MaxAggFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.MinAggFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction;
+import org.apache.flink.table.planner.functions.aggfunctions.SumAggFunction;
import org.apache.flink.table.planner.runtime.utils.FailingCollectionSource;
import org.apache.flink.table.planner.utils.FilterUtils;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.BigIntType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
import org.apache.flink.table.types.utils.DataTypeUtils;
@@ -95,6 +106,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -286,6 +298,9 @@ public final class TestValuesTableFactory
private static final ConfigOption<Integer> SINK_EXPECTED_MESSAGES_NUM =
ConfigOptions.key("sink-expected-messages-num").intType().defaultValue(-1);
+ private static final ConfigOption<Boolean> ENABLE_PROJECTION_PUSH_DOWN =
+
ConfigOptions.key("enable-projection-push-down").booleanType().defaultValue(true);
+
private static final ConfigOption<Boolean> NESTED_PROJECTION_SUPPORTED =
ConfigOptions.key("nested-projection-supported").booleanType().defaultValue(false);
@@ -361,6 +376,7 @@ public final class TestValuesTableFactory
boolean isAsync = helper.getOptions().get(ASYNC_ENABLED);
String lookupFunctionClass =
helper.getOptions().get(LOOKUP_FUNCTION_CLASS);
boolean disableLookup = helper.getOptions().get(DISABLE_LOOKUP);
+ boolean enableProjectionPushDown =
helper.getOptions().get(ENABLE_PROJECTION_PUSH_DOWN);
boolean nestedProjectionSupported =
helper.getOptions().get(NESTED_PROJECTION_SUPPORTED);
boolean enableWatermarkPushDown =
helper.getOptions().get(ENABLE_WATERMARK_PUSH_DOWN);
boolean failingSource = helper.getOptions().get(FAILING_SOURCE);
@@ -398,6 +414,25 @@ public final class TestValuesTableFactory
partition2Rows.put(Collections.emptyMap(), data);
}
+ if (!enableProjectionPushDown) {
+ return new TestValuesScanTableSourceWithoutProjectionPushDown(
+ producedDataType,
+ changelogMode,
+ isBounded,
+ runtimeSource,
+ failingSource,
+ partition2Rows,
+ nestedProjectionSupported,
+ null,
+ Collections.emptyList(),
+ filterableFieldsSet,
+ numElementToSkip,
+ Long.MAX_VALUE,
+ partitions,
+ readableMetadata,
+ null);
+ }
+
if (disableLookup) {
if (enableWatermarkPushDown) {
return new TestValuesScanTableSourceWithWatermarkPushDown(
@@ -541,6 +576,7 @@ public final class TestValuesTableFactory
SINK_INSERT_ONLY,
RUNTIME_SINK,
SINK_EXPECTED_MESSAGES_NUM,
+ ENABLE_PROJECTION_PUSH_DOWN,
NESTED_PROJECTION_SUPPORTED,
FILTERABLE_FIELDS,
PARTITION_LIST,
@@ -679,14 +715,14 @@ public final class TestValuesTableFactory
// Table sources
//
--------------------------------------------------------------------------------------------
- /** Values {@link ScanTableSource} for testing. */
- private static class TestValuesScanTableSource
+ /** Values {@link ScanTableSource} for testing that disables projection
push down. */
+ private static class TestValuesScanTableSourceWithoutProjectionPushDown
implements ScanTableSource,
- SupportsProjectionPushDown,
SupportsFilterPushDown,
SupportsLimitPushDown,
SupportsPartitionPushDown,
- SupportsReadingMetadata {
+ SupportsReadingMetadata,
+ SupportsAggregatePushDown {
protected DataType producedDataType;
protected final ChangelogMode changelogMode;
@@ -705,7 +741,10 @@ public final class TestValuesTableFactory
protected final Map<String, DataType> readableMetadata;
protected @Nullable int[] projectedMetadataFields;
- private TestValuesScanTableSource(
+ private @Nullable int[] groupingSet;
+ private List<AggregateExpression> aggregateExpressions;
+
+ private TestValuesScanTableSourceWithoutProjectionPushDown(
DataType producedDataType,
ChangelogMode changelogMode,
boolean bounded,
@@ -736,6 +775,8 @@ public final class TestValuesTableFactory
this.allPartitions = allPartitions;
this.readableMetadata = readableMetadata;
this.projectedMetadataFields = projectedMetadataFields;
+ this.groupingSet = null;
+ this.aggregateExpressions = Collections.emptyList();
}
@Override
@@ -803,17 +844,6 @@ public final class TestValuesTableFactory
}
@Override
- public boolean supportsNestedProjection() {
- return nestedProjectionSupported;
- }
-
- @Override
- public void applyProjection(int[][] projectedFields) {
- this.producedDataType = DataTypeUtils.projectRow(producedDataType,
projectedFields);
- this.projectedPhysicalFields = projectedFields;
- }
-
- @Override
public Result applyFilters(List<ResolvedExpression> filters) {
List<ResolvedExpression> acceptedFilters = new ArrayList<>();
List<ResolvedExpression> remainingFilters = new ArrayList<>();
@@ -838,7 +868,7 @@ public final class TestValuesTableFactory
@Override
public DynamicTableSource copy() {
- return new TestValuesScanTableSource(
+ return new TestValuesScanTableSourceWithoutProjectionPushDown(
producedDataType,
changelogMode,
bounded,
@@ -867,27 +897,129 @@ public final class TestValuesTableFactory
allPartitions.isEmpty()
? Collections.singletonList(Collections.emptyMap())
: allPartitions;
- int numRetained = 0;
+
+ int numSkipped = 0;
for (Map<String, String> partition : keys) {
- for (Row row : data.get(partition)) {
+ Collection<Row> rowsInPartition = data.get(partition);
+
+ // handle element skipping
+ int numToSkipInPartition = 0;
+ if (numSkipped < numElementToSkip) {
+ numToSkipInPartition =
+ Math.min(rowsInPartition.size(), numElementToSkip
- numSkipped);
+ }
+ numSkipped += numToSkipInPartition;
+
+ // handle predicates and projection
+ List<Row> rowsRetained =
+ rowsInPartition.stream()
+ .skip(numToSkipInPartition)
+ .filter(
+ row ->
+
FilterUtils.isRetainedAfterApplyingFilterPredicates(
+ filterPredicates,
getValueGetter(row)))
+ .map(
+ row -> {
+ Row projectedRow = projectRow(row);
+
projectedRow.setKind(row.getKind());
+ return projectedRow;
+ })
+ .collect(Collectors.toList());
+
+ // handle aggregates
+ if (!aggregateExpressions.isEmpty()) {
+ rowsRetained = applyAggregatesToRows(rowsRetained);
+ }
+
+ // handle row data
+ for (Row row : rowsRetained) {
+ final RowData rowData = (RowData)
converter.toInternal(row);
+ if (rowData != null) {
+ rowData.setRowKind(row.getKind());
+ result.add(rowData);
+ }
+
+ // handle limit. No aggregates will be pushed down when
there is a limit.
if (result.size() >= limit) {
return result;
}
- boolean isRetained =
-
FilterUtils.isRetainedAfterApplyingFilterPredicates(
- filterPredicates, getValueGetter(row));
- if (isRetained) {
- final Row projectedRow = projectRow(row);
- final RowData rowData = (RowData)
converter.toInternal(projectedRow);
- if (rowData != null) {
- if (numRetained >= numElementToSkip) {
- rowData.setRowKind(row.getKind());
- result.add(rowData);
- }
- numRetained++;
- }
+ }
+ }
+
+ return result;
+ }
+
+ private List<Row> applyAggregatesToRows(List<Row> rows) {
+ if (groupingSet != null && groupingSet.length > 0) {
+ // has group by, group firstly
+ Map<Row, List<Row>> buffer = new HashMap<>();
+ for (Row row : rows) {
+ Row bufferKey = new Row(groupingSet.length);
+ for (int i = 0; i < groupingSet.length; i++) {
+ bufferKey.setField(i, row.getField(groupingSet[i]));
+ }
+ if (buffer.containsKey(bufferKey)) {
+ buffer.get(bufferKey).add(row);
+ } else {
+ buffer.put(bufferKey, new
ArrayList<>(Collections.singletonList(row)));
}
}
+ List<Row> result = new ArrayList<>();
+ for (Map.Entry<Row, List<Row>> entry : buffer.entrySet()) {
+ result.add(Row.join(entry.getKey(),
accumulateRows(entry.getValue())));
+ }
+ return result;
+ } else {
+ return Collections.singletonList(accumulateRows(rows));
+ }
+ }
+
+        // can only apply min/max/sum/sum0/count functions on long type fields
for testing
+ private Row accumulateRows(List<Row> rows) {
+ Row result = new Row(aggregateExpressions.size());
+ for (int i = 0; i < aggregateExpressions.size(); i++) {
+ FunctionDefinition aggFunction =
+ aggregateExpressions.get(i).getFunctionDefinition();
+ List<FieldReferenceExpression> arguments =
aggregateExpressions.get(i).getArgs();
+ if (aggFunction instanceof MinAggFunction) {
+ int argIndex = arguments.get(0).getFieldIndex();
+ Row minRow =
+ rows.stream()
+ .min(Comparator.comparing(row ->
row.getFieldAs(argIndex)))
+ .orElse(null);
+ result.setField(i, minRow != null ?
minRow.getField(argIndex) : null);
+ } else if (aggFunction instanceof MaxAggFunction) {
+ int argIndex = arguments.get(0).getFieldIndex();
+ Row maxRow =
+ rows.stream()
+ .max(Comparator.comparing(row ->
row.getFieldAs(argIndex)))
+ .orElse(null);
+ result.setField(i, maxRow != null ?
maxRow.getField(argIndex) : null);
+ } else if (aggFunction instanceof SumAggFunction) {
+ int argIndex = arguments.get(0).getFieldIndex();
+ Object finalSum =
+ rows.stream()
+ .filter(row -> row.getField(argIndex) !=
null)
+ .mapToLong(row -> row.getFieldAs(argIndex))
+ .sum();
+
+ boolean allNull = rows.stream().noneMatch(r ->
r.getField(argIndex) != null);
+ result.setField(i, allNull ? null : finalSum);
+ } else if (aggFunction instanceof Sum0AggFunction) {
+ int argIndex = arguments.get(0).getFieldIndex();
+ Object finalSum0 =
+ rows.stream()
+ .filter(row -> row.getField(argIndex) !=
null)
+ .mapToLong(row -> row.getFieldAs(argIndex))
+ .sum();
+ result.setField(i, finalSum0);
+ } else if (aggFunction instanceof CountAggFunction) {
+ int argIndex = arguments.get(0).getFieldIndex();
+ long count = rows.stream().filter(r ->
r.getField(argIndex) != null).count();
+ result.setField(i, count);
+ } else if (aggFunction instanceof Count1AggFunction) {
+ result.setField(i, (long) rows.size());
+ }
}
return result;
}
@@ -954,6 +1086,52 @@ public final class TestValuesTableFactory
}
@Override
+ public boolean applyAggregates(
+ List<int[]> groupingSets,
+ List<AggregateExpression> aggregateExpressions,
+ DataType producedDataType) {
+            // This TestValuesScanTableSource only supports single group
aggregate at present.
+ if (groupingSets.size() > 1) {
+ return false;
+ }
+ List<AggregateExpression> aggExpressions = new ArrayList<>();
+ for (AggregateExpression aggExpression : aggregateExpressions) {
+ FunctionDefinition functionDefinition =
aggExpression.getFunctionDefinition();
+ if (!(functionDefinition instanceof MinAggFunction
+ || functionDefinition instanceof MaxAggFunction
+ || functionDefinition instanceof SumAggFunction
+ || functionDefinition instanceof Sum0AggFunction
+ || functionDefinition instanceof CountAggFunction
+ || functionDefinition instanceof Count1AggFunction)) {
+ return false;
+ }
+ if (aggExpression.getFilterExpression().isPresent()
+ || aggExpression.isApproximate()
+ || aggExpression.isDistinct()) {
+ return false;
+ }
+
+                // only Long data type is supported in this unit test except
count()
+ if (aggExpression.getArgs().stream()
+ .anyMatch(
+ field ->
+
!(field.getOutputDataType().getLogicalType()
+ instanceof BigIntType)
+ && !(functionDefinition
instanceof CountAggFunction
+ || functionDefinition
+ instanceof
Count1AggFunction))) {
+ return false;
+ }
+
+ aggExpressions.add(aggExpression);
+ }
+ this.groupingSet = groupingSets.get(0);
+ this.aggregateExpressions = aggExpressions;
+ this.producedDataType = producedDataType;
+ return true;
+ }
+
+ @Override
public void applyLimit(long limit) {
this.limit = limit;
}
@@ -973,6 +1151,77 @@ public final class TestValuesTableFactory
}
}
+ /** Values {@link ScanTableSource} for testing that supports projection
push down. */
+ private static class TestValuesScanTableSource
+ extends TestValuesScanTableSourceWithoutProjectionPushDown
+ implements SupportsProjectionPushDown {
+
+ private TestValuesScanTableSource(
+ DataType producedDataType,
+ ChangelogMode changelogMode,
+ boolean bounded,
+ String runtimeSource,
+ boolean failingSource,
+ Map<Map<String, String>, Collection<Row>> data,
+ boolean nestedProjectionSupported,
+ @Nullable int[][] projectedPhysicalFields,
+ List<ResolvedExpression> filterPredicates,
+ Set<String> filterableFields,
+ int numElementToSkip,
+ long limit,
+ List<Map<String, String>> allPartitions,
+ Map<String, DataType> readableMetadata,
+ @Nullable int[] projectedMetadataFields) {
+ super(
+ producedDataType,
+ changelogMode,
+ bounded,
+ runtimeSource,
+ failingSource,
+ data,
+ nestedProjectionSupported,
+ projectedPhysicalFields,
+ filterPredicates,
+ filterableFields,
+ numElementToSkip,
+ limit,
+ allPartitions,
+ readableMetadata,
+ projectedMetadataFields);
+ }
+
+ @Override
+ public DynamicTableSource copy() {
+ return new TestValuesScanTableSource(
+ producedDataType,
+ changelogMode,
+ bounded,
+ runtimeSource,
+ failingSource,
+ data,
+ nestedProjectionSupported,
+ projectedPhysicalFields,
+ filterPredicates,
+ filterableFields,
+ numElementToSkip,
+ limit,
+ allPartitions,
+ readableMetadata,
+ projectedMetadataFields);
+ }
+
+ @Override
+ public boolean supportsNestedProjection() {
+ return nestedProjectionSupported;
+ }
+
+ @Override
+ public void applyProjection(int[][] projectedFields) {
+ this.producedDataType = DataTypeUtils.projectRow(producedDataType,
projectedFields);
+ this.projectedPhysicalFields = projectedFields;
+ }
+ }
+
/** Values {@link ScanTableSource} for testing that supports watermark
push down. */
private static class TestValuesScanTableSourceWithWatermarkPushDown
extends TestValuesScanTableSource
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java
new file mode 100644
index 0000000..1312f8b
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.java
@@ -0,0 +1,367 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.physical.batch;
+
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import
org.apache.flink.table.planner.functions.aggfunctions.CollectAggFunction;
+import org.apache.flink.table.planner.utils.BatchTableTestUtil;
+import org.apache.flink.table.planner.utils.TableTestBase;
+
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Test for rules that extend {@link PushLocalAggIntoScanRuleBase} to push
down local aggregates
+ * into table source.
+ */
+public class PushLocalAggIntoTableSourceScanRuleTest extends TableTestBase {
+ protected BatchTableTestUtil util = batchTestUtil(new TableConfig());
+
+ @Before
+ public void setup() {
+ TableConfig tableConfig = util.tableEnv().getConfig();
+ tableConfig
+ .getConfiguration()
+ .setBoolean(
+
OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED,
+ true);
+ String ddl =
+ "CREATE TABLE inventory (\n"
+ + " id BIGINT,\n"
+ + " name STRING,\n"
+ + " amount BIGINT,\n"
+ + " price BIGINT,\n"
+ + " type STRING\n"
+ + ") WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'filterable-fields' = 'id;type',\n"
+ + " 'bounded' = 'true'\n"
+ + ")";
+ util.tableEnv().executeSql(ddl);
+
+ String ddl2 =
+ "CREATE TABLE inventory_meta (\n"
+ + " id BIGINT,\n"
+ + " name STRING,\n"
+ + " amount BIGINT,\n"
+ + " price BIGINT,\n"
+ + " type STRING,\n"
+ + " metadata_1 BIGINT METADATA,\n"
+ + " metadata_2 STRING METADATA,\n"
+ + " PRIMARY KEY (`id`) NOT ENFORCED\n"
+ + ") WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'filterable-fields' = 'id;type',\n"
+ + " 'readable-metadata' = 'metadata_1:BIGINT,
metadata_2:STRING',\n"
+ + " 'bounded' = 'true'\n"
+ + ")";
+ util.tableEnv().executeSql(ddl2);
+
+ // partitioned table
+ String ddl3 =
+ "CREATE TABLE inventory_part (\n"
+ + " id BIGINT,\n"
+ + " name STRING,\n"
+ + " amount BIGINT,\n"
+ + " price BIGINT,\n"
+ + " type STRING\n"
+ + ") PARTITIONED BY (type)\n"
+ + "WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'filterable-fields' = 'id;type',\n"
+ + " 'partition-list' = 'type:a;type:b',\n"
+ + " 'bounded' = 'true'\n"
+ + ")";
+ util.tableEnv().executeSql(ddl3);
+
+ // disable projection push down
+ String ddl4 =
+ "CREATE TABLE inventory_no_proj (\n"
+ + " id BIGINT,\n"
+ + " name STRING,\n"
+ + " amount BIGINT,\n"
+ + " price BIGINT,\n"
+ + " type STRING\n"
+ + ")\n"
+ + "WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'filterable-fields' = 'id;type',\n"
+ + " 'enable-projection-push-down' = 'false',\n"
+ + " 'bounded' = 'true'\n"
+ + ")";
+ util.tableEnv().executeSql(ddl4);
+ }
+
+ @Test
+ public void testCanPushDownLocalHashAggWithGroup() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM inventory\n"
+ + " group by name, type");
+ }
+
+ @Test
+ public void testDisablePushDownLocalAgg() {
+ // disable push down local agg
+ util.getTableEnv()
+ .getConfig()
+ .getConfiguration()
+ .setBoolean(
+
OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED,
+ false);
+
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM inventory\n"
+ + " group by name, type");
+
+ // reset config
+ util.getTableEnv()
+ .getConfig()
+ .getConfiguration()
+ .setBoolean(
+
OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED,
+ true);
+ }
+
+ @Test
+ public void testCanPushDownLocalHashAggWithoutGroup() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " min(id),\n"
+ + " max(amount),\n"
+ + " sum(price),\n"
+ + " avg(price),\n"
+ + " count(id)\n"
+ + "FROM inventory");
+ }
+
+ @Test
+ public void testCanPushDownLocalSortAggWithoutSort() {
+ // enable sort agg
+ util.getTableEnv()
+ .getConfig()
+ .getConfiguration()
+
.setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg");
+
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " min(id),\n"
+ + " max(amount),\n"
+ + " sum(price),\n"
+ + " avg(price),\n"
+ + " count(id)\n"
+ + "FROM inventory");
+
+ // reset config
+ util.getTableEnv()
+ .getConfig()
+ .getConfiguration()
+
.setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "");
+ }
+
+ @Test
+ public void testCanPushDownLocalSortAggWithSort() {
+ // enable sort agg
+ util.getTableEnv()
+ .getConfig()
+ .getConfiguration()
+
.setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg");
+
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM inventory\n"
+ + " group by name, type");
+
+ // reset config
+ util.getTableEnv()
+ .getConfig()
+ .getConfiguration()
+
.setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "");
+ }
+
+ @Test
+ public void testCanPushDownLocalAggAfterFilterPushDown() {
+
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM inventory\n"
+ + " where id = 123\n"
+ + " group by name, type");
+ }
+
+ @Test
+ public void testCanPushDownLocalAggWithMetadata() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " max(metadata_1),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM inventory_meta\n"
+ + " where id = 123\n"
+ + " group by name, type");
+ }
+
+ @Test
+ public void testCanPushDownLocalAggWithPartition() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " type,\n"
+ + " name\n"
+ + "FROM inventory_part\n"
+ + " where type in ('a', 'b') and id = 123\n"
+ + " group by type, name");
+ }
+
+ @Test
+ public void testCanPushDownLocalAggWithoutProjectionPushDown() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM inventory_no_proj\n"
+ + " where id = 123\n"
+ + " group by name, type");
+ }
+
+ @Test
+ public void testCanPushDownLocalAggWithAuxGrouping() {
+ // enable two-phase aggregate, otherwise there is no local aggregate
+ util.getTableEnv()
+ .getConfig()
+ .getConfiguration()
+
.setString(OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY,
"TWO_PHASE");
+
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " id, name, count(*)\n"
+ + "FROM inventory_meta\n"
+ + " group by id, name");
+ }
+
+ @Test
+ public void testCannotPushDownLocalAggAfterLimitPushDown() {
+
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " sum(amount),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM (\n"
+ + " SELECT\n"
+ + " *\n"
+ + " FROM inventory\n"
+ + " LIMIT 100\n"
+ + ") t\n"
+ + " group by name, type");
+ }
+
+ @Test
+ public void testCannotPushDownLocalAggWithUDAF() {
+ // add udf
+ util.addTemporarySystemFunction(
+ "udaf_collect", new
CollectAggFunction<>(DataTypes.BIGINT().getLogicalType()));
+
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " udaf_collect(amount),\n"
+ + " name,\n"
+ + " type\n"
+ + "FROM inventory\n"
+ + " group by name, type");
+ }
+
+ @Test
+ public void testCannotPushDownLocalAggWithUnsupportedDataTypes() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " max(name),\n"
+ + " type\n"
+ + "FROM inventory\n"
+ + " group by type");
+ }
+
+ @Test
+ public void testCannotPushDownWithColumnExpression() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " min(amount + price),\n"
+ + " max(amount),\n"
+ + " sum(price),\n"
+ + " count(id),\n"
+ + " name\n"
+ + "FROM inventory\n"
+ + " group by name");
+ }
+
+ @Test
+ public void testCannotPushDownWithUnsupportedAggFunction() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " min(id),\n"
+ + " max(amount),\n"
+ + " sum(price),\n"
+ + " count(distinct id),\n"
+ + " name\n"
+ + "FROM inventory\n"
+ + " group by name");
+ }
+
+ @Test
+ public void testCannotPushDownWithWindowAggFunction() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " id,\n"
+ + " amount,\n"
+ + " sum(price) over (partition by name),\n"
+ + " name\n"
+ + "FROM inventory");
+ }
+
+ @Test
+ public void testCannotPushDownWithArgFilter() {
+ util.verifyRelPlan(
+ "SELECT\n"
+ + " min(id),\n"
+ + " max(amount),\n"
+ + " sum(price),\n"
+ + " count(id) FILTER(WHERE id > 100),\n"
+ + " name\n"
+ + "FROM inventory\n"
+ + " group by name");
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java
new file mode 100755
index 0000000..b6522c0
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/batch/sql/agg/LocalAggregatePushDownITCase.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.runtime.batch.sql.agg;
+
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.table.api.config.OptimizerConfigOptions;
+import org.apache.flink.table.planner.factories.TestValuesTableFactory;
+import org.apache.flink.table.planner.runtime.utils.BatchTestBase;
+import org.apache.flink.table.planner.runtime.utils.TestData;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.types.Row;
+
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+/** Test for local aggregate push down. */
+public class LocalAggregatePushDownITCase extends BatchTestBase {
+
+ @Before
+ public void before() {
+ super.before();
+ env().setParallelism(1); // set sink parallelism to 1
+
+ String testDataId =
TestValuesTableFactory.registerData(TestData.personData());
+ String ddl =
+ "CREATE TABLE AggregatableTable (\n"
+ + " id int,\n"
+ + " age int,\n"
+ + " name string,\n"
+ + " height int,\n"
+ + " gender string,\n"
+ + " deposit bigint,\n"
+ + " points bigint,\n"
+ + " metadata_1 BIGINT METADATA,\n"
+ + " metadata_2 STRING METADATA,\n"
+ + " PRIMARY KEY (`id`) NOT ENFORCED\n"
+ + ") WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'data-id' = '"
+ + testDataId
+ + "',\n"
+ + " 'filterable-fields' = 'id;age',\n"
+ + " 'readable-metadata' = 'metadata_1:BIGINT,
metadata_2:STRING',\n"
+ + " 'bounded' = 'true'\n"
+ + ")";
+ tEnv().executeSql(ddl);
+
+ // partitioned table
+ String ddl2 =
+ "CREATE TABLE AggregatableTable_Part (\n"
+ + " id int,\n"
+ + " age int,\n"
+ + " name string,\n"
+ + " height int,\n"
+ + " gender string,\n"
+ + " deposit bigint,\n"
+ + " points bigint,\n"
+ + " distance BIGINT,\n"
+ + " type STRING\n"
+ + ") PARTITIONED BY (type)\n"
+ + "WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'data-id' = '"
+ + testDataId
+ + "',\n"
+ + " 'filterable-fields' = 'id;age',\n"
+ + " 'partition-list' =
'type:A;type:B;type:C;type:D',\n"
+ + " 'bounded' = 'true'\n"
+ + ")";
+ tEnv().executeSql(ddl2);
+
+        // table with projection push down disabled
+ String ddl3 =
+ "CREATE TABLE AggregatableTable_No_Proj (\n"
+ + " id int,\n"
+ + " age int,\n"
+ + " name string,\n"
+ + " height int,\n"
+ + " gender string,\n"
+ + " deposit bigint,\n"
+ + " points bigint,\n"
+ + " distance BIGINT,\n"
+ + " type STRING\n"
+ + ")\n"
+ + "WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'data-id' = '"
+ + testDataId
+ + "',\n"
+ + " 'filterable-fields' = 'id;age',\n"
+ + " 'enable-projection-push-down' = 'false',\n"
+ + " 'bounded' = 'true'\n"
+ + ")";
+ tEnv().executeSql(ddl3);
+ }
+
+ @Test
+ public void testPushDownLocalHashAggWithGroup() {
+ checkResult(
+ "SELECT\n"
+ + " avg(deposit) as avg_dep,\n"
+ + " sum(deposit),\n"
+ + " count(1),\n"
+ + " gender\n"
+ + "FROM\n"
+ + " AggregatableTable\n"
+ + "GROUP BY gender\n"
+ + "ORDER BY avg_dep",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(Row.of(126, 630, 5, "f"), Row.of(220,
1320, 6, "m"))),
+ false);
+ }
+
+ @Test
+ public void testDisablePushDownLocalAgg() {
+ // disable push down local agg
+ tEnv().getConfig()
+ .getConfiguration()
+ .setBoolean(
+
OptimizerConfigOptions.TABLE_OPTIMIZER_SOURCE_AGGREGATE_PUSHDOWN_ENABLED,
+ false);
+
+ checkResult(
+ "SELECT\n"
+ + " avg(deposit) as avg_dep,\n"
+ + " sum(deposit),\n"
+ + " count(1),\n"
+ + " gender\n"
+ + "FROM\n"
+ + " AggregatableTable\n"
+ + "GROUP BY gender\n"
+ + "ORDER BY avg_dep",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(Row.of(126, 630, 5, "f"), Row.of(220,
1320, 6, "m"))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalHashAggWithoutGroup() {
+ checkResult(
+ "SELECT\n"
+ + " avg(deposit),\n"
+ + " sum(deposit),\n"
+ + " count(*)\n"
+ + "FROM\n"
+ + " AggregatableTable",
+
JavaScalaConversionUtil.toScala(Collections.singletonList(Row.of(177, 1950,
11))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalSortAggWithoutSort() {
+ // enable sort agg
+ tEnv().getConfig()
+ .getConfiguration()
+
.setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg");
+
+ checkResult(
+ "SELECT\n"
+ + " avg(deposit),\n"
+ + " sum(deposit),\n"
+ + " count(*)\n"
+ + "FROM\n"
+ + " AggregatableTable",
+
JavaScalaConversionUtil.toScala(Collections.singletonList(Row.of(177, 1950,
11))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalSortAggWithSort() {
+ // enable sort agg
+ tEnv().getConfig()
+ .getConfiguration()
+
.setString(ExecutionConfigOptions.TABLE_EXEC_DISABLED_OPERATORS, "HashAgg");
+
+ checkResult(
+ "SELECT\n"
+ + " avg(deposit),\n"
+ + " sum(deposit),\n"
+ + " count(1),\n"
+ + " gender,\n"
+ + " age\n"
+ + "FROM\n"
+ + " AggregatableTable\n"
+ + "GROUP BY gender, age",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(
+ Row.of(50, 50, 1, "f", 19),
+ Row.of(200, 200, 1, "f", 20),
+ Row.of(250, 750, 3, "m", 23),
+ Row.of(126, 380, 3, "f", 25),
+ Row.of(300, 300, 1, "m", 27),
+ Row.of(170, 170, 1, "m", 28),
+ Row.of(100, 100, 1, "m", 34))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalAggAfterFilterPushDown() {
+ checkResult(
+ "SELECT\n"
+ + " avg(deposit),\n"
+ + " sum(deposit),\n"
+ + " count(1),\n"
+ + " gender,\n"
+ + " age\n"
+ + "FROM\n"
+ + " AggregatableTable\n"
+ + "WHERE age <= 20\n"
+ + "GROUP BY gender, age",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(Row.of(50, 50, 1, "f", 19), Row.of(200,
200, 1, "f", 20))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalAggWithMetadata() {
+ checkResult(
+ "SELECT\n"
+ + " sum(metadata_1),\n"
+ + " metadata_2\n"
+ + "FROM\n"
+ + " AggregatableTable\n"
+ + "GROUP BY metadata_2",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(
+ Row.of(156, 'C'),
+ Row.of(183, 'A'),
+ Row.of(51, 'D'),
+ Row.of(70, 'B'))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalAggWithPartition() {
+ checkResult(
+ "SELECT\n"
+ + " sum(deposit),\n"
+ + " count(1),\n"
+ + " type,\n"
+ + " name\n"
+ + "FROM\n"
+ + " AggregatableTable_Part\n"
+ + "WHERE type in ('A', 'C')"
+ + "GROUP BY type, name",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(
+ Row.of(150, 1, "C", "jack"),
+ Row.of(180, 1, "A", "emma"),
+ Row.of(200, 1, "A", "tom"),
+ Row.of(200, 1, "C", "eva"),
+ Row.of(300, 1, "C", "danny"),
+ Row.of(400, 1, "A", "tommas"),
+ Row.of(50, 1, "C", "olivia"))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalAggWithoutProjectionPushDown() {
+ checkResult(
+ "SELECT\n"
+ + " avg(deposit),\n"
+ + " sum(deposit),\n"
+ + " count(1),\n"
+ + " gender,\n"
+ + " age\n"
+ + "FROM\n"
+ + " AggregatableTable_No_Proj\n"
+ + "WHERE age <= 20\n"
+ + "GROUP BY gender, age",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(Row.of(50, 50, 1, "f", 19), Row.of(200,
200, 1, "f", 20))),
+ false);
+ }
+
+ @Test
+ public void testPushDownLocalAggWithoutAuxGrouping() {
+ // enable two-phase aggregate, otherwise there is no local aggregate
+ tEnv().getConfig()
+ .getConfiguration()
+
.setString(OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY,
"TWO_PHASE");
+
+ checkResult(
+ "SELECT\n"
+ + " id,\n"
+ + " name,\n"
+ + " count(*)\n"
+ + "FROM\n"
+ + " AggregatableTable\n"
+ + "WHERE id > 8\n"
+ + "GROUP BY id, name",
+ JavaScalaConversionUtil.toScala(
+ Arrays.asList(
+ Row.of(9, "emma", 1),
+ Row.of(10, "benji", 1),
+ Row.of(11, "eva", 1))),
+ false);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml
index c25de39..b587212 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/RankTest.xml
@@ -40,8 +40,7 @@ Sink(table=[default_catalog.default_database.sink],
fields=[name, eat, cnt])
+- Exchange(distribution=[hash[name]])
+- HashAggregate(isMerge=[true], groupBy=[name, eat],
select=[name, eat, Final_SUM(sum$0) AS cnt])
+- Exchange(distribution=[hash[name, eat]])
- +- LocalHashAggregate(groupBy=[name, eat], select=[name,
eat, Partial_SUM(age) AS sum$0])
- +- TableSourceScan(table=[[default_catalog,
default_database, test_source]], fields=[name, eat, age])
+ +- TableSourceScan(table=[[default_catalog,
default_database, test_source, aggregates=[grouping=[name,eat],
aggFunctions=[LongSumAggFunction(age)]]]], fields=[name, eat, sum$0])
]]>
</Resource>
</TestCase>
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml
index f28e546..794c354 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/TableSourceTest.xml
@@ -133,8 +133,7 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
<![CDATA[
HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS EXPR$0])
+- Exchange(distribution=[single])
- +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0])
- +- TableSourceScan(table=[[default_catalog, default_database,
ProjectableTable, project=[], metadata=[]]], fields=[])
+ +- TableSourceScan(table=[[default_catalog, default_database,
ProjectableTable, project=[], metadata=[], aggregates=[grouping=[],
aggFunctions=[Count1AggFunction()]]]], fields=[count1$0])
]]>
</Resource>
</TestCase>
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml
new file mode 100644
index 0000000..fc4e0d9
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml
@@ -0,0 +1,499 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testCanPushDownLocalHashAggWithGroup">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ name,
+ type
+FROM inventory
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, name, type])
++- HashAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_SUM(sum$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[name, type]])
+ +- TableSourceScan(table=[[default_catalog, default_database, inventory,
project=[name, type, amount], metadata=[], aggregates=[grouping=[name,type],
aggFunctions=[LongSumAggFunction(amount)]]]], fields=[name, type, sum$0])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDisablePushDownLocalAgg">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ name,
+ type
+FROM inventory
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, name, type])
++- HashAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_SUM(sum$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[name, type]])
+ +- LocalHashAggregate(groupBy=[name, type], select=[name, type,
Partial_SUM(amount) AS sum$0])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory, project=[name, type, amount], metadata=[]]], fields=[name, type,
amount])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalHashAggWithoutGroup">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ min(id),
+ max(amount),
+ sum(price),
+ avg(price),
+ count(id)
+FROM inventory]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[MIN($0)], EXPR$1=[MAX($1)],
EXPR$2=[SUM($2)], EXPR$3=[AVG($2)], EXPR$4=[COUNT($0)])
++- LogicalProject(id=[$0], amount=[$2], price=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database, inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+HashAggregate(isMerge=[true], select=[Final_MIN(min$0) AS EXPR$0,
Final_MAX(max$1) AS EXPR$1, Final_SUM(sum$2) AS EXPR$2, Final_AVG(sum$3,
count$4) AS EXPR$3, Final_COUNT(count$5) AS EXPR$4])
++- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[default_catalog, default_database, inventory,
project=[id, amount, price], metadata=[], aggregates=[grouping=[],
aggFunctions=[LongMinAggFunction(id),LongMaxAggFunction(amount),LongSumAggFunction(price),LongSum0AggFunction(price),CountAggFunction(price),CountAggFunction(id)]]]],
fields=[min$0, max$1, sum$2, sum$3, count$4, count$5])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalSortAggWithoutSort">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ min(id),
+ max(amount),
+ sum(price),
+ avg(price),
+ count(id)
+FROM inventory]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[MIN($0)], EXPR$1=[MAX($1)],
EXPR$2=[SUM($2)], EXPR$3=[AVG($2)], EXPR$4=[COUNT($0)])
++- LogicalProject(id=[$0], amount=[$2], price=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database, inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+SortAggregate(isMerge=[true], select=[Final_MIN(min$0) AS EXPR$0,
Final_MAX(max$1) AS EXPR$1, Final_SUM(sum$2) AS EXPR$2, Final_AVG(sum$3,
count$4) AS EXPR$3, Final_COUNT(count$5) AS EXPR$4])
++- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[default_catalog, default_database, inventory,
project=[id, amount, price], metadata=[], aggregates=[grouping=[],
aggFunctions=[LongMinAggFunction(id),LongMaxAggFunction(amount),LongSumAggFunction(price),LongSum0AggFunction(price),CountAggFunction(price),CountAggFunction(id)]]]],
fields=[min$0, max$1, sum$2, sum$3, count$4, count$5])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalSortAggWithSort">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ name,
+ type
+FROM inventory
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, name, type])
++- SortAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_SUM(sum$0) AS EXPR$0])
+ +- Sort(orderBy=[name ASC, type ASC])
+ +- Exchange(distribution=[hash[name, type]])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory, project=[name, type, amount], metadata=[],
aggregates=[grouping=[name,type], aggFunctions=[LongSumAggFunction(amount)]]]],
fields=[name, type, sum$0])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalAggWithAuxGrouping">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ id, name, count(*)
+FROM inventory_meta
+ group by id, name]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0, 1}], EXPR$2=[COUNT()])
++- LogicalProject(id=[$0], name=[$1])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory_meta]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+HashAggregate(isMerge=[true], groupBy=[id], auxGrouping=[name], select=[id,
name, Final_COUNT(count1$0) AS EXPR$2])
++- Exchange(distribution=[hash[id]])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory_meta, project=[id, name], metadata=[],
aggregates=[grouping=[id,name], aggFunctions=[Count1AggFunction()]]]],
fields=[id, name, count1$0])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalAggAfterFilterPushDown">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ name,
+ type
+FROM inventory
+ where id = 123
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2])
+ +- LogicalFilter(condition=[=($0, 123)])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, name, type])
++- HashAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_SUM(sum$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[name, type]])
+ +- TableSourceScan(table=[[default_catalog, default_database, inventory,
filter=[=(id, 123:BIGINT)], project=[name, type, amount], metadata=[],
aggregates=[grouping=[name,type], aggFunctions=[LongSumAggFunction(amount)]]]],
fields=[name, type, sum$0])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalAggWithMetadata">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ max(metadata_1),
+ name,
+ type
+FROM inventory_meta
+ where id = 123
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], EXPR$1=[$3], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)], EXPR$1=[MAX($3)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2], metadata_1=[$5])
+ +- LogicalFilter(condition=[=($0, 123)])
+ +- LogicalProject(id=[$0], name=[$1], amount=[$2], price=[$3],
type=[$4], metadata_1=[$5], metadata_2=[$6])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory_meta]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, EXPR$1, name, type])
++- HashAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_SUM(sum$0) AS EXPR$0, Final_MAX(max$1) AS EXPR$1])
+ +- Exchange(distribution=[hash[name, type]])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory_meta, filter=[=(id, 123:BIGINT)], project=[name, type, amount,
metadata_1], metadata=[metadata_1], aggregates=[grouping=[name,type],
aggFunctions=[LongSumAggFunction(amount),LongMaxAggFunction(metadata_1)]]]],
fields=[name, type, sum$0, max$1])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalAggWithPartition">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ type,
+ name
+FROM inventory_part
+ where type in ('a', 'b') and id = 123
+ group by type, name]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], type=[$0], name=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)])
+ +- LogicalProject(type=[$4], name=[$1], amount=[$2])
+ +- LogicalFilter(condition=[AND(OR(=($4, _UTF-16LE'a'), =($4,
_UTF-16LE'b')), =($0, 123))])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory_part]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, type, name])
++- HashAggregate(isMerge=[true], groupBy=[type, name], select=[type, name,
Final_SUM(sum$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[type, name]])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory_part, filter=[=(id, 123:BIGINT)], partitions=[{type=a}, {type=b}],
project=[type, name, amount], metadata=[], aggregates=[grouping=[type,name],
aggFunctions=[LongSumAggFunction(amount)]]]], fields=[type, name, sum$0])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCanPushDownLocalAggWithoutProjectionPushDown">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ name,
+ type
+FROM inventory_no_proj
+ where id = 123
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2])
+ +- LogicalFilter(condition=[=($0, 123)])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory_no_proj]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, name, type])
++- HashAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_SUM(sum$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[name, type]])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory_no_proj, filter=[=(id, 123:BIGINT)],
aggregates=[grouping=[name,type], aggFunctions=[LongSumAggFunction(amount)]]]],
fields=[name, type, sum$0])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCannotPushDownLocalAggAfterLimitPushDown">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ sum(amount),
+ name,
+ type
+FROM (
+ SELECT
+ *
+ FROM inventory
+ LIMIT 100
+) t
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[SUM($2)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2])
+ +- LogicalSort(fetch=[100])
+ +- LogicalProject(id=[$0], name=[$1], amount=[$2], price=[$3],
type=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, name, type])
++- HashAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_SUM(sum$0) AS EXPR$0])
+ +- Exchange(distribution=[hash[name, type]])
+ +- LocalHashAggregate(groupBy=[name, type], select=[name, type,
Partial_SUM(amount) AS sum$0])
+ +- Calc(select=[name, type, amount])
+ +- Limit(offset=[0], fetch=[100], global=[true])
+ +- Exchange(distribution=[single])
+ +- Limit(offset=[0], fetch=[100], global=[false])
+ +- TableSourceScan(table=[[default_catalog,
default_database, inventory, limit=[100]]], fields=[id, name, amount, price,
type])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCannotPushDownLocalAggWithUDAF">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ udaf_collect(amount),
+ name,
+ type
+FROM inventory
+ group by name, type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$2], name=[$0], type=[$1])
++- LogicalAggregate(group=[{0, 1}], EXPR$0=[udaf_collect($2)])
+ +- LogicalProject(name=[$1], type=[$4], amount=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, name, type])
++- SortAggregate(isMerge=[true], groupBy=[name, type], select=[name, type,
Final_udaf_collect(EXPR$0) AS EXPR$0])
+ +- Sort(orderBy=[name ASC, type ASC])
+ +- Exchange(distribution=[hash[name, type]])
+ +- LocalSortAggregate(groupBy=[name, type], select=[name, type,
Partial_udaf_collect(amount) AS EXPR$0])
+ +- Sort(orderBy=[name ASC, type ASC])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory, project=[name, type, amount], metadata=[]]], fields=[name, type,
amount])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCannotPushDownLocalAggWithUnsupportedDataTypes">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ max(name),
+ type
+FROM inventory
+ group by type]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$1], type=[$0])
++- LogicalAggregate(group=[{0}], EXPR$0=[MAX($1)])
+ +- LogicalProject(type=[$4], name=[$1])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, type])
++- SortAggregate(isMerge=[true], groupBy=[type], select=[type,
Final_MAX(max$0) AS EXPR$0])
+ +- Sort(orderBy=[type ASC])
+ +- Exchange(distribution=[hash[type]])
+ +- LocalSortAggregate(groupBy=[type], select=[type, Partial_MAX(name)
AS max$0])
+ +- Sort(orderBy=[type ASC])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory, project=[type, name], metadata=[]]], fields=[type, name])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCannotPushDownWithColumnExpression">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ min(amount + price),
+ max(amount),
+ sum(price),
+ count(id),
+ name
+FROM inventory
+ group by name]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], name=[$0])
++- LogicalAggregate(group=[{0}], EXPR$0=[MIN($1)], EXPR$1=[MAX($2)],
EXPR$2=[SUM($3)], EXPR$3=[COUNT($4)])
+ +- LogicalProject(name=[$1], $f1=[+($2, $3)], amount=[$2], price=[$3],
id=[$0])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, EXPR$1, EXPR$2, EXPR$3, name])
++- HashAggregate(isMerge=[true], groupBy=[name], select=[name,
Final_MIN(min$0) AS EXPR$0, Final_MAX(max$1) AS EXPR$1, Final_SUM(sum$2) AS
EXPR$2, Final_COUNT(count$3) AS EXPR$3])
+ +- Exchange(distribution=[hash[name]])
+ +- LocalHashAggregate(groupBy=[name], select=[name, Partial_MIN($f1) AS
min$0, Partial_MAX(amount) AS max$1, Partial_SUM(price) AS sum$2,
Partial_COUNT(id) AS count$3])
+ +- Calc(select=[name, +(amount, price) AS $f1, amount, price, id])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory, project=[name, amount, price, id], metadata=[]]], fields=[name,
amount, price, id])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCannotPushDownWithUnsupportedAggFunction">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ min(id),
+ max(amount),
+ sum(price),
+ count(distinct id),
+ name
+FROM inventory
+ group by name]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], name=[$0])
++- LogicalAggregate(group=[{0}], EXPR$0=[MIN($1)], EXPR$1=[MAX($2)],
EXPR$2=[SUM($3)], EXPR$3=[COUNT(DISTINCT $1)])
+ +- LogicalProject(name=[$1], id=[$0], amount=[$2], price=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, EXPR$1, EXPR$2, EXPR$3, name])
++- HashAggregate(isMerge=[true], groupBy=[name], select=[name,
Final_MIN(min$0) AS EXPR$0, Final_MIN(min$1) AS EXPR$1, Final_MIN(min$2) AS
EXPR$2, Final_COUNT(count$3) AS EXPR$3])
+ +- Exchange(distribution=[hash[name]])
+ +- LocalHashAggregate(groupBy=[name], select=[name, Partial_MIN(EXPR$0)
FILTER $g_1 AS min$0, Partial_MIN(EXPR$1) FILTER $g_1 AS min$1,
Partial_MIN(EXPR$2) FILTER $g_1 AS min$2, Partial_COUNT(id) FILTER $g_0 AS
count$3])
+ +- Calc(select=[name, id, EXPR$0, EXPR$1, EXPR$2, =(CASE(=($e,
0:BIGINT), 0:BIGINT, 1:BIGINT), 0) AS $g_0, =(CASE(=($e, 0:BIGINT), 0:BIGINT,
1:BIGINT), 1) AS $g_1])
+ +- HashAggregate(isMerge=[true], groupBy=[name, id, $e],
select=[name, id, $e, Final_MIN(min$0) AS EXPR$0, Final_MAX(max$1) AS EXPR$1,
Final_SUM(sum$2) AS EXPR$2])
+ +- Exchange(distribution=[hash[name, id, $e]])
+ +- LocalHashAggregate(groupBy=[name, id, $e], select=[name,
id, $e, Partial_MIN(id_0) AS min$0, Partial_MAX(amount) AS max$1,
Partial_SUM(price) AS sum$2])
+ +- Expand(projects=[{name, id, amount, price, 0 AS $e, id
AS id_0}, {name, null AS id, amount, price, 1 AS $e, id AS id_0}])
+ +- TableSourceScan(table=[[default_catalog,
default_database, inventory, project=[name, id, amount, price], metadata=[]]],
fields=[name, id, amount, price])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCannotPushDownWithWindowAggFunction">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ id,
+ amount,
+ sum(price) over (partition by name),
+ name
+FROM inventory]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(id=[$0], amount=[$2], EXPR$2=[CASE(>(COUNT($3) OVER (PARTITION
BY $1), 0), $SUM0($3) OVER (PARTITION BY $1), null:BIGINT)], name=[$1])
++- LogicalTableScan(table=[[default_catalog, default_database, inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[id, amount, CASE(>(w0$o0, 0:BIGINT), w0$o1, null:BIGINT) AS
EXPR$2, name])
++- OverAggregate(partitionBy=[name], window#0=[COUNT(price) AS w0$o0,
$SUM0(price) AS w0$o1 RANG BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED
FOLLOWING], select=[id, name, amount, price, w0$o0, w0$o1])
+ +- Sort(orderBy=[name ASC])
+ +- Exchange(distribution=[hash[name]])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory, project=[id, name, amount, price], metadata=[]]], fields=[id, name,
amount, price])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testCannotPushDownWithArgFilter">
+ <Resource name="sql">
+ <![CDATA[SELECT
+ min(id),
+ max(amount),
+ sum(price),
+ count(id) FILTER(WHERE id > 100),
+ name
+FROM inventory
+ group by name]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], name=[$0])
++- LogicalAggregate(group=[{0}], EXPR$0=[MIN($1)], EXPR$1=[MAX($2)],
EXPR$2=[SUM($3)], EXPR$3=[COUNT($1) FILTER $4])
+ +- LogicalProject(name=[$1], id=[$0], amount=[$2], price=[$3], $f4=[IS
TRUE(>($0, 100))])
+ +- LogicalTableScan(table=[[default_catalog, default_database,
inventory]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[EXPR$0, EXPR$1, EXPR$2, EXPR$3, name])
++- HashAggregate(isMerge=[true], groupBy=[name], select=[name,
Final_MIN(min$0) AS EXPR$0, Final_MAX(max$1) AS EXPR$1, Final_SUM(sum$2) AS
EXPR$2, Final_COUNT(count$3) AS EXPR$3])
+ +- Exchange(distribution=[hash[name]])
+ +- LocalHashAggregate(groupBy=[name], select=[name, Partial_MIN(id) AS
min$0, Partial_MAX(amount) AS max$1, Partial_SUM(price) AS sum$2,
Partial_COUNT(id) FILTER $f4 AS count$3])
+ +- Calc(select=[name, id, amount, price, IS TRUE(>(id, 100)) AS $f4])
+ +- TableSourceScan(table=[[default_catalog, default_database,
inventory, project=[name, id, amount, price], metadata=[]]], fields=[name, id,
amount, price])
+]]>
+ </Resource>
+ </TestCase>
+</Root>
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala
index 538c5a8..9fe60ba 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/utils/TestData.scala
@@ -414,17 +414,17 @@ object TestData {
// person test data
lazy val personData: Seq[Row] = Seq(
- row(1, 23, "tom", 172, "m"),
- row(2, 21, "mary", 161, "f"),
- row(3, 18, "jack", 182, "m"),
- row(4, 25, "rose", 165, "f"),
- row(5, 27, "danny", 175, "m"),
- row(6, 31, "tommas", 172, "m"),
- row(7, 19, "olivia", 172, "f"),
- row(8, 34, "stef", 170, "m"),
- row(9, 32, "emma", 171, "f"),
- row(10, 28, "benji", 165, "m"),
- row(11, 20, "eva", 180, "f")
+ row(1, 23, "tom", 172, "m", 200L, 1000L, 15L, "A"),
+ row(2, 25, "mary", 161, "f", 100L, 1000L, 25L, "B"),
+ row(3, 23, "jack", 182, "m", 150L, 1300L, 35L, "C"),
+ row(4, 25, "rose", 165, "f", 100L, 500L, 45L, "B"),
+ row(5, 27, "danny", 175, "m", 300L, 300L, 54L, "C"),
+ row(6, 23, "tommas", 172, "m", 400L, 4000L, 53L, "A"),
+ row(7, 19, "olivia", 172, "f", 50L, 9000L, 52L, "C"),
+ row(8, 34, "stef", 170, "m", 100L, 1900L, 51L, "D"),
+ row(9, 25, "emma", 171, "f", 180L, 800L, 115L, "A"),
+ row(10, 28, "benji", 165, "m", 170L, 11000L, 0L, "B"),
+ row(11, 20, "eva", 180, "f", 200L, 1000L, 15L, "C")
)
-  val nullablesOfPersonData = Array(true, true, true, true, true)
+  val nullablesOfPersonData = Array(true, true, true, true, true, true, true, true, true)