This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 29c81fe1c0803e6649d128aa08759b2c1384cfed Author: godfreyhe <[email protected]> AuthorDate: Tue Jan 5 19:52:36 2021 +0800 [FLINK-20738][table-planner-blink] Introduce BatchPhysicalSortAggregate, and make BatchExecSortAggregate only extended from ExecNode This closes #14562 --- .../nodes/exec/batch/BatchExecSortAggregate.java | 125 +++++++++++++++++++++ .../codegen/agg/batch/SortAggCodeGenerator.scala | 2 +- .../plan/metadata/FlinkRelMdColumnInterval.scala | 2 +- .../batch/BatchExecLocalSortAggregate.scala | 55 ++++++++- ...gate.scala => BatchPhysicalSortAggregate.scala} | 35 +++--- ....scala => BatchPhysicalSortAggregateBase.scala} | 56 +-------- .../planner/plan/rules/FlinkBatchRuleSets.scala | 2 +- ...ggRule.scala => BatchPhysicalSortAggRule.scala} | 20 ++-- .../physical/batch/EnforceLocalAggRuleBase.scala | 8 +- .../physical/batch/EnforceLocalSortAggRule.scala | 10 +- .../batch/RemoveRedundantLocalSortAggRule.scala | 20 ++-- .../batch/EnforceLocalSortAggRuleTest.scala | 4 +- 12 files changed, 234 insertions(+), 105 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortAggregate.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortAggregate.java new file mode 100644 index 0000000..17c15a0 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecSortAggregate.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.batch; + +import org.apache.flink.api.dag.Transformation; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.transformations.OneInputTransformation; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.planner.codegen.CodeGeneratorContext; +import org.apache.flink.table.planner.codegen.agg.batch.AggWithoutKeysCodeGenerator; +import org.apache.flink.table.planner.codegen.agg.batch.SortAggCodeGenerator; +import org.apache.flink.table.planner.delegation.PlannerBase; +import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNode; +import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase; +import org.apache.flink.table.planner.plan.utils.AggregateInfoList; +import org.apache.flink.table.planner.plan.utils.AggregateUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; +import org.apache.flink.table.runtime.generated.GeneratedOperator; +import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory; +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo; +import org.apache.flink.table.types.logical.RowType; + +import org.apache.calcite.rel.core.AggregateCall; + +import java.util.Arrays; +import java.util.Collections; + +/** Batch {@link ExecNode} for (global) sort-based aggregate operator. */ +public class BatchExecSortAggregate extends ExecNodeBase<RowData> + implements BatchExecNode<RowData> { + + private final int[] grouping; + private final int[] auxGrouping; + private final AggregateCall[] aggCalls; + private final RowType aggInputRowType; + private final boolean isMerge; + private final boolean isFinal; + + public BatchExecSortAggregate( + int[] grouping, + int[] auxGrouping, + AggregateCall[] aggCalls, + RowType aggInputRowType, + boolean isMerge, + boolean isFinal, + ExecEdge inputEdge, + RowType outputType, + String description) { + super(Collections.singletonList(inputEdge), outputType, description); + this.grouping = grouping; + this.auxGrouping = auxGrouping; + this.aggCalls = aggCalls; + this.aggInputRowType = aggInputRowType; + this.isMerge = isMerge; + this.isFinal = isFinal; + } + + @SuppressWarnings("unchecked") + @Override + protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) { + final ExecNode<RowData> inputNode = (ExecNode<RowData>) getInputNodes().get(0); + final Transformation<RowData> inputTransform = inputNode.translateToPlan(planner); + + final RowType inputRowType = (RowType) inputNode.getOutputType(); + final RowType outputRowType = (RowType) getOutputType(); + + final CodeGeneratorContext ctx = new CodeGeneratorContext(planner.getTableConfig()); + final AggregateInfoList aggInfos = + AggregateUtil.transformToBatchAggregateInfoList( + aggInputRowType, + JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)), + null, + null); + + final GeneratedOperator<OneInputStreamOperator<RowData, RowData>> generatedOperator; + if (grouping.length == 0) { + generatedOperator = + AggWithoutKeysCodeGenerator.genWithoutKeys( + ctx, + planner.getRelBuilder(), + aggInfos, + inputRowType, + outputRowType, + isMerge, + isFinal, + "NoGrouping"); + } else { + generatedOperator = + SortAggCodeGenerator.genWithKeys( + ctx, + planner.getRelBuilder(), + aggInfos, + inputRowType, + outputRowType, + grouping, + auxGrouping, + isMerge, + isFinal); + } + + return new OneInputTransformation<>( + inputTransform, + getDesc(), + new CodeGenOperatorFactory<>(generatedOperator), + InternalTypeInfo.of(outputRowType), + inputTransform.getParallelism()); + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala index c62ac4a..901ebf9 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/SortAggCodeGenerator.scala @@ -37,7 +37,7 @@ import org.apache.flink.table.types.logical.RowType */ object SortAggCodeGenerator { - private[flink] def genWithKeys( + def genWithKeys( ctx: CodeGeneratorContext, builder: RelBuilder, aggInfoList: AggregateInfoList, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index 82de446..72f6891 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -625,7 +625,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } case agg: BatchExecLocalSortAggregate => getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType) - case agg: BatchExecSortAggregate if agg.isMerge => + case agg: BatchPhysicalSortAggregate if agg.isMerge => val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( aggCallIndex, agg.getAggCallList, agg.aggInputRowType) if (aggCallIndexInLocalAgg != null) { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala index 07ff2d0..ed4d362 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala @@ -18,10 +18,20 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch +import org.apache.flink.api.dag.Transformation +import org.apache.flink.table.data.RowData import org.apache.flink.table.functions.UserDefinedFunction +import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.codegen.CodeGeneratorContext +import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, SortAggCodeGenerator} +import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} -import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge +import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode} +import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil} +import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} import org.apache.calcite.rel.RelDistribution.Type @@ -48,18 +58,17 @@ class BatchExecLocalSortAggregate( grouping: Array[Int], auxGrouping: Array[Int], aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)]) - extends BatchExecSortAggregateBase( + extends BatchPhysicalSortAggregateBase( cluster, traitSet, inputRel, outputRowType, - inputRowType, - inputRowType, grouping, auxGrouping, aggCallToAggFunction, isMerge = false, - isFinal = false) { + isFinal = false) + with LegacyBatchExecNode[RowData] { override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new BatchExecLocalSortAggregate( @@ -129,6 +138,42 @@ class BatchExecLocalSortAggregate( //~ ExecNode methods ----------------------------------------------------------- + override protected def translateToPlanInternal( + planner: BatchPlanner): Transformation[RowData] = { + val input = getInputNodes.get(0).translateToPlan(planner) + .asInstanceOf[Transformation[RowData]] + val ctx = CodeGeneratorContext(planner.getTableConfig) + val outputType = FlinkTypeFactory.toLogicalRowType(getRowType) + val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) + + val aggInfos = transformToBatchAggregateInfoList( + FlinkTypeFactory.toLogicalRowType(inputRowType), getAggCallList) + + val generatedOperator = if (grouping.isEmpty) { + AggWithoutKeysCodeGenerator.genWithoutKeys( + ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping") + } else { + SortAggCodeGenerator.genWithKeys( + ctx, + planner.getRelBuilder, + aggInfos, + inputType, + outputType, + grouping, + auxGrouping, + isMerge, + isFinal) + } + val operator = new CodeGenOperatorFactory[RowData](generatedOperator) + ExecNodeUtil.createOneInputTransformation( + input, + getRelDetailedDescription, + operator, + InternalTypeInfo.of(outputType), + input.getParallelism, + 0) + } + override def getInputEdges: util.List[ExecEdge] = { if (grouping.length == 0) { List( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalSortAggregate.scala similarity index 87% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregate.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalSortAggregate.scala index 58f61c5..cb0fd1f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalSortAggregate.scala @@ -19,8 +19,10 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch import org.apache.flink.table.functions.UserDefinedFunction +import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} -import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecSortAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} import org.apache.flink.table.planner.plan.rules.physical.batch.BatchExecJoinRuleBase import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil} @@ -40,7 +42,7 @@ import scala.collection.JavaConversions._ * * @see [[BatchPhysicalGroupAggregateBase]] for more info. */ -class BatchExecSortAggregate( +class BatchPhysicalSortAggregate( cluster: RelOptCluster, traitSet: RelTraitSet, inputRel: RelNode, @@ -51,13 +53,11 @@ class BatchExecSortAggregate( auxGrouping: Array[Int], aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], isMerge: Boolean) - extends BatchExecSortAggregateBase( + extends BatchPhysicalSortAggregateBase( cluster, traitSet, inputRel, outputRowType, - inputRowType, - aggInputRowType, grouping, auxGrouping, aggCallToAggFunction, @@ -65,7 +65,7 @@ class BatchExecSortAggregate( isFinal = true) { override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { - new BatchExecSortAggregate( + new BatchPhysicalSortAggregate( cluster, traitSet, inputs.get(0), @@ -153,16 +153,25 @@ class BatchExecSortAggregate( Some(copy(newProvidedTraitSet, Seq(newInput))) } - //~ ExecNode methods ----------------------------------------------------------- + override def translateToExecNode(): ExecNode[_] = { + new BatchExecSortAggregate( + grouping, + auxGrouping, + getAggCallList.toArray, + FlinkTypeFactory.toLogicalRowType(aggInputRowType), + isMerge, + true, // isFinal is always true + getInputEdge, + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) + } - override def getInputEdges: util.List[ExecEdge] = { + private def getInputEdge: ExecEdge = { if (grouping.length == 0) { - List( - ExecEdge.builder() - .damBehavior(ExecEdge.DamBehavior.END_INPUT) - .build()) + ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build() } else { - List(ExecEdge.DEFAULT) + ExecEdge.DEFAULT } } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalSortAggregateBase.scala similarity index 52% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalSortAggregateBase.scala index 3528d8f..a3166bf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalSortAggregateBase.scala @@ -17,19 +17,8 @@ */ package org.apache.flink.table.planner.plan.nodes.physical.batch -import org.apache.flink.api.dag.Transformation -import org.apache.flink.table.data.RowData import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.codegen.CodeGeneratorContext -import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, SortAggCodeGenerator} -import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.plan.cost.{FlinkCost, FlinkCostFactory} -import org.apache.flink.table.planner.plan.nodes.exec.LegacyBatchExecNode -import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil -import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList -import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelOptCost, RelOptPlanner, RelTraitSet} import org.apache.calcite.rel.RelNode @@ -42,13 +31,11 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery * * @see [[BatchPhysicalGroupAggregateBase]] for more info. */ -abstract class BatchExecSortAggregateBase( +abstract class BatchPhysicalSortAggregateBase( cluster: RelOptCluster, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, - inputRowType: RelDataType, - aggInputRowType: RelDataType, grouping: Array[Int], auxGrouping: Array[Int], aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], @@ -63,8 +50,7 @@ abstract class BatchExecSortAggregateBase( auxGrouping, aggCallToAggFunction, isMerge, - isFinal) - with LegacyBatchExecNode[RowData]{ + isFinal) { override def computeSelfCost(planner: RelOptPlanner, mq: RelMetadataQuery): RelOptCost = { val inputRows = mq.getRowCount(getInput()) @@ -79,42 +65,4 @@ abstract class BatchExecSortAggregateBase( val costFactory = planner.getCostFactory.asInstanceOf[FlinkCostFactory] costFactory.makeCost(rowCount, cpuCost, 0, 0, memCost) } - - //~ ExecNode methods ----------------------------------------------------------- - - override protected def translateToPlanInternal( - planner: BatchPlanner): Transformation[RowData] = { - val input = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - val ctx = CodeGeneratorContext(planner.getTableConfig) - val outputType = FlinkTypeFactory.toLogicalRowType(getRowType) - val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) - - val aggInfos = transformToBatchAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(aggInputRowType), getAggCallList) - - val generatedOperator = if (grouping.isEmpty) { - AggWithoutKeysCodeGenerator.genWithoutKeys( - ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping") - } else { - SortAggCodeGenerator.genWithKeys( - ctx, - planner.getRelBuilder, - aggInfos, - inputType, - outputType, - grouping, - auxGrouping, - isMerge, - isFinal) - } - val operator = new CodeGenOperatorFactory[RowData](generatedOperator) - ExecNodeUtil.createOneInputTransformation( - input, - getRelDetailedDescription, - operator, - InternalTypeInfo.of(outputType), - input.getParallelism, - 0) - } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 587ed0b..5e26819 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -415,7 +415,7 @@ object FlinkBatchRuleSets { BatchPhysicalExpandRule.INSTANCE, // group agg BatchPhysicalHashAggRule.INSTANCE, - BatchExecSortAggRule.INSTANCE, + BatchPhysicalSortAggRule.INSTANCE, RemoveRedundantLocalSortAggRule.WITHOUT_SORT, RemoveRedundantLocalSortAggRule.WITH_SORT, RemoveRedundantLocalHashAggRule.INSTANCE, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala similarity index 95% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala index 6bb4437..f3a931e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalSortAggRule.scala @@ -22,7 +22,7 @@ import org.apache.flink.table.planner.calcite.{FlinkContext, FlinkTypeFactory} import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecSortAggregate +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSortAggregate import org.apache.flink.table.planner.plan.utils.PythonUtil.isPythonAggregate import org.apache.flink.table.planner.plan.utils.{AggregateUtil, OperatorType} import org.apache.flink.table.planner.utils.TableConfigUtils.isOperatorDisabled @@ -36,7 +36,7 @@ import scala.collection.JavaConversions._ /** * Rule that converts [[FlinkLogicalAggregate]] to * {{{ - * BatchExecSortAggregate (global) + * BatchPhysicalSortAggregate (global) * +- Sort (exists if group keys are not empty) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) * +- BatchExecLocalSortAggregate (local) @@ -46,7 +46,7 @@ import scala.collection.JavaConversions._ * when all aggregate functions are mergeable * and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or * {{{ - * BatchExecSortAggregate + * BatchPhysicalSortAggregate * +- Sort (exists if group keys are not empty) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) * +- input of agg @@ -57,11 +57,11 @@ import scala.collection.JavaConversions._ * Notes: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE, * this rule will try to create two possibilities above, and chooses the best one based on cost. */ -class BatchExecSortAggRule +class BatchPhysicalSortAggRule extends RelOptRule( operand(classOf[FlinkLogicalAggregate], operand(classOf[RelNode], any)), - "BatchExecSortAggRule") + "BatchPhysicalSortAggRule") with BatchPhysicalAggRuleBase { override def matches(call: RelOptRuleCall): Boolean = { @@ -108,7 +108,7 @@ class BatchExecSortAggRule aggCallToAggFunction, isLocalHashAgg = false) - // create global BatchExecSortAggregate + // create global BatchPhysicalSortAggregate val (globalGroupSet, globalAuxGroupSet) = getGlobalAggGroupSetPair(groupSet, auxGroupSet) val (globalDistributions, globalCollation) = if (agg.getGroupCount != 0) { // global agg should use groupSet's indices as distribution fields @@ -139,7 +139,7 @@ class BatchExecSortAggRule .replace(globalCollation) val newInputForFinalAgg = RelOptRule.convert(localSortAgg, requiredTraitSet) - val globalSortAgg = new BatchExecSortAggregate( + val globalSortAgg = new BatchPhysicalSortAggregate( agg.getCluster, aggProvidedTraitSet, newInputForFinalAgg, @@ -173,7 +173,7 @@ class BatchExecSortAggRule requiredTraitSet = requiredTraitSet.replace(sortCollation) } val newInput = RelOptRule.convert(input, requiredTraitSet) - val sortAgg = new BatchExecSortAggregate( + val sortAgg = new BatchPhysicalSortAggregate( agg.getCluster, aggProvidedTraitSet, newInput, @@ -191,6 +191,6 @@ class BatchExecSortAggRule } } -object BatchExecSortAggRule { - val INSTANCE = new BatchExecSortAggRule +object BatchPhysicalSortAggRule { + val INSTANCE = new BatchPhysicalSortAggRule } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala index d4b5a26..a1c380f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala @@ -22,7 +22,7 @@ import org.apache.flink.table.api.TableException import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalGroupAggregateBase, BatchPhysicalHashAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalGroupAggregateBase, BatchPhysicalHashAggregate} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand} @@ -82,7 +82,7 @@ abstract class EnforceLocalAggRuleBase( val isLocalHashAgg = completeAgg match { case _: BatchPhysicalHashAggregate => true - case _: BatchExecSortAggregate => false + case _: BatchPhysicalSortAggregate => false case _ => throw new TableException(s"Unsupported aggregate: ${completeAgg.getClass.getSimpleName}") } @@ -143,8 +143,8 @@ abstract class EnforceLocalAggRuleBase( newAuxGrouping, aggCallToAggFunction, isMerge = true) - case _: BatchExecSortAggregate => - new BatchExecSortAggregate( + case _: BatchPhysicalSortAggregate => + new BatchPhysicalSortAggregate( completeAgg.getCluster, completeAgg.getTraitSet, input, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala index 65b14b5..e1036cf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala @@ -19,14 +19,14 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalSort, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalSort, BatchPhysicalSortAggregate} import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.RelOptRuleCall import org.apache.calcite.rel.{RelCollationTraitDef, RelNode} /** - * An [[EnforceLocalAggRuleBase]] that matches [[BatchExecSortAggregate]] + * An [[EnforceLocalAggRuleBase]] that matches [[BatchPhysicalSortAggregate]] * * for example: select count(*) from t group by rollup (a, b) * The physical plan @@ -54,14 +54,14 @@ import org.apache.calcite.rel.{RelCollationTraitDef, RelNode} * }}} */ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase( - operand(classOf[BatchExecSortAggregate], + operand(classOf[BatchPhysicalSortAggregate], operand(classOf[BatchPhysicalSort], operand(classOf[BatchPhysicalExchange], operand(classOf[BatchPhysicalExpand], any)))), "EnforceLocalSortAggRule") { override def matches(call: RelOptRuleCall): Boolean = { - val agg: BatchExecSortAggregate = call.rel(0) + val agg: BatchPhysicalSortAggregate = call.rel(0) val expand: BatchPhysicalExpand = call.rel(3) val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg) @@ -73,7 +73,7 @@ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase( } override def onMatch(call: RelOptRuleCall): Unit = { - val agg: BatchExecSortAggregate = call.rel(0) + val agg: BatchPhysicalSortAggregate = call.rel(0) val expand: BatchPhysicalExpand = call.rel(3) val localGrouping = agg.grouping diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala index 0d8bee0..ae8d666 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalSortAggregate, BatchPhysicalSort, BatchExecSortAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalSortAggregate, BatchPhysicalSort, BatchPhysicalSortAggregate} import org.apache.calcite.plan.RelOptRule._ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptRuleOperand} @@ -38,7 +38,7 @@ abstract class RemoveRedundantLocalSortAggRule( val globalAgg = getOriginalGlobalAgg(call) val localAgg = getOriginalLocalAgg(call) val inputOfLocalAgg = getOriginalInputOfLocalAgg(call) - val newGlobalAgg = new BatchExecSortAggregate( + val newGlobalAgg = new BatchPhysicalSortAggregate( globalAgg.getCluster, globalAgg.getTraitSet, inputOfLocalAgg, @@ -54,7 +54,7 @@ abstract class RemoveRedundantLocalSortAggRule( call.transformTo(newGlobalAgg) } - private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchExecSortAggregate + private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchPhysicalSortAggregate private[table] def getOriginalLocalAgg(call: RelOptRuleCall): BatchExecLocalSortAggregate @@ -63,13 +63,14 @@ abstract class RemoveRedundantLocalSortAggRule( } class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSortAggRule( - operand(classOf[BatchExecSortAggregate], + operand(classOf[BatchPhysicalSortAggregate], operand(classOf[BatchExecLocalSortAggregate], operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))), "RemoveRedundantLocalSortAggWithoutSortRule") { - override private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchExecSortAggregate = { - call.rels(0).asInstanceOf[BatchExecSortAggregate] + override private[table] def getOriginalGlobalAgg( + call: RelOptRuleCall): BatchPhysicalSortAggregate = { + call.rels(0).asInstanceOf[BatchPhysicalSortAggregate] } override private[table] def getOriginalLocalAgg( @@ -84,14 +85,15 @@ class RemoveRedundantLocalSortAggWithoutSortRule extends RemoveRedundantLocalSor } class RemoveRedundantLocalSortAggWithSortRule extends RemoveRedundantLocalSortAggRule( - operand(classOf[BatchExecSortAggregate], + operand(classOf[BatchPhysicalSortAggregate], operand(classOf[BatchPhysicalSort], operand(classOf[BatchExecLocalSortAggregate], operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any)))), "RemoveRedundantLocalSortAggWithSortRule") { - override private[table] def getOriginalGlobalAgg(call: RelOptRuleCall): BatchExecSortAggregate = { - call.rels(0).asInstanceOf[BatchExecSortAggregate] + override private[table] def getOriginalGlobalAgg( + call: RelOptRuleCall): BatchPhysicalSortAggregate = { + call.rels(0).asInstanceOf[BatchPhysicalSortAggregate] } override private[table] def getOriginalLocalAgg( diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRuleTest.scala index 346f31a..e0eeca2 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRuleTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRuleTest.scala @@ -46,7 +46,7 @@ class EnforceLocalSortAggRuleTest extends EnforceLocalAggRuleTestBase { // remove the original BatchExecSortAggRule and add BatchExecSortAggRuleForOnePhase // to let the physical phase generate one phase aggregate program.getFlinkRuleSetProgram(FlinkBatchProgram.PHYSICAL) - .get.remove(RuleSets.ofList(BatchExecSortAggRule.INSTANCE)) + .get.remove(RuleSets.ofList(BatchPhysicalSortAggRule.INSTANCE)) program.getFlinkRuleSetProgram(FlinkBatchProgram.PHYSICAL) .get.add(RuleSets.ofList(BatchExecSortAggRuleForOnePhase.INSTANCE)) @@ -82,7 +82,7 @@ class EnforceLocalSortAggRuleTest extends EnforceLocalAggRuleTestBase { * value, and only enable one phase aggregate. * This rule only used for test. */ -class BatchExecSortAggRuleForOnePhase extends BatchExecSortAggRule { +class BatchExecSortAggRuleForOnePhase extends BatchPhysicalSortAggRule { override protected def isTwoPhaseAggWorkable( aggFunctions: Array[UserDefinedFunction], tableConfig: TableConfig): Boolean = false
