This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 9f8f5cd370312887bbabc2dad31d2a2d3f17330f Author: godfreyhe <[email protected]> AuthorDate: Tue Jan 5 18:47:41 2021 +0800 [FLINK-20738][table-planner-blink] Introduce BatchPhysicalLocalHashAggregate, and make BatchPhysicalLocalHashAggregate only extended from FlinkPhysicalRel This closes #14562 --- .../plan/metadata/FlinkRelMdColumnInterval.scala | 2 +- ...scala => BatchPhysicalLocalHashAggregate.scala} | 83 ++++++---------------- .../physical/batch/BatchExecJoinRuleBase.scala | 4 +- .../physical/batch/BatchPhysicalAggRuleBase.scala | 4 +- .../physical/batch/BatchPhysicalHashAggRule.scala | 4 +- .../batch/RemoveRedundantLocalHashAggRule.scala | 6 +- .../plan/metadata/FlinkRelMdHandlerTestBase.scala | 4 +- 7 files changed, 32 insertions(+), 75 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index e6f3f9c..82de446 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -613,7 +613,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { agg.partialAggCalls(aggCallIndex) case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > aggCallIndex => agg.aggCalls(aggCallIndex) - case agg: BatchExecLocalHashAggregate => + case agg: BatchPhysicalLocalHashAggregate => getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType) case agg: BatchPhysicalHashAggregate if agg.isMerge => val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala similarity index 62% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala index 5dad79f..3ea9492 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalLocalHashAggregate.scala @@ -18,22 +18,12 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch -import org.apache.flink.api.dag.Transformation -import org.apache.flink.configuration.MemorySize -import org.apache.flink.table.api.config.ExecutionConfigOptions -import org.apache.flink.table.data.RowData import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.codegen.CodeGeneratorContext -import org.apache.flink.table.planner.codegen.agg.batch.{AggWithoutKeysCodeGenerator, HashAggCodeGenerator} -import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} -import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil -import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode} -import org.apache.flink.table.planner.plan.utils.AggregateUtil.transformToBatchAggregateInfoList +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecHashAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} import org.apache.flink.table.planner.plan.utils.RelExplainUtil -import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} import org.apache.calcite.rel.RelDistribution.Type @@ -51,7 +41,7 @@ import scala.collection.JavaConversions._ * * @see [[BatchPhysicalGroupAggregateBase]] for more info. */ -class BatchExecLocalHashAggregate( +class BatchPhysicalLocalHashAggregate( cluster: RelOptCluster, traitSet: RelTraitSet, inputRel: RelNode, @@ -69,11 +59,10 @@ class BatchExecLocalHashAggregate( auxGrouping, aggCallToAggFunction, isMerge = false, - isFinal = false) - with LegacyBatchExecNode[RowData] { + isFinal = false) { override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { - new BatchExecLocalHashAggregate( + new BatchPhysicalLocalHashAggregate( cluster, traitSet, inputs.get(0), @@ -130,57 +119,25 @@ class BatchExecLocalHashAggregate( Some(copy(providedTraits, Seq(newInput))) } - //~ ExecNode methods ----------------------------------------------------------- - - override protected def translateToPlanInternal( - planner: BatchPlanner): Transformation[RowData] = { - val config = planner.getTableConfig - val input = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - val ctx = CodeGeneratorContext(config) - val outputType = FlinkTypeFactory.toLogicalRowType(getRowType) - val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) - - val aggInfos = transformToBatchAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(inputRowType), getAggCallList) - - var managedMemory: Long = 0L - val generatedOperator = if (grouping.isEmpty) { - AggWithoutKeysCodeGenerator.genWithoutKeys( - ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping") - } else { - managedMemory = MemorySize.parse(config.getConfiguration.getString( - ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)).getBytes - new HashAggCodeGenerator( - ctx, - planner.getRelBuilder, - aggInfos, - inputType, - outputType, - grouping, - auxGrouping, - isMerge, - isFinal - ).genWithKeys() - } - val operator = new CodeGenOperatorFactory[RowData](generatedOperator) - ExecNodeUtil.createOneInputTransformation( - input, - getRelDetailedDescription, - operator, - InternalTypeInfo.of(outputType), - input.getParallelism, - managedMemory) + override def translateToExecNode(): ExecNode[_] = { + new BatchExecHashAggregate( + grouping, + auxGrouping, + getAggCallList.toArray, + FlinkTypeFactory.toLogicalRowType(inputRowType), + false, // isMerge is always false + false, // isFinal is always false + getInputEdge, + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) } - override def getInputEdges: util.List[ExecEdge] = { + private def getInputEdge: ExecEdge = { if (grouping.length == 0) { - List( - ExecEdge.builder() - .damBehavior(ExecEdge.DamBehavior.END_INPUT) - .build()) + ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build() } else { - List(ExecEdge.DEFAULT) + ExecEdge.DEFAULT } } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala index 9dead54..c2a272a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala @@ -22,7 +22,7 @@ import org.apache.flink.annotation.Experimental import org.apache.flink.configuration.ConfigOption import org.apache.flink.configuration.ConfigOptions.key import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecLocalHashAggregate +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalLocalHashAggregate import org.apache.flink.table.planner.plan.utils.{FlinkRelMdUtil, FlinkRelOptUtil} import org.apache.calcite.plan.RelOptRule @@ -40,7 +40,7 @@ trait BatchExecJoinRuleBase { val newInput = RelOptRule.convert(node, localRequiredTraitSet) val providedTraitSet = localRequiredTraitSet - new BatchExecLocalHashAggregate( + new BatchPhysicalLocalHashAggregate( node.getCluster, providedTraitSet, newInput, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala index d0a9aa9..bc0aabb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala @@ -24,7 +24,7 @@ import org.apache.flink.table.planner.JArrayList import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._ -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalHashAggregate, BatchExecLocalSortAggregate, BatchPhysicalGroupAggregateBase} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalSortAggregate, BatchPhysicalGroupAggregateBase, BatchPhysicalLocalHashAggregate} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.flink.table.planner.utils.AggregatePhaseStrategy import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy @@ -209,7 +209,7 @@ trait BatchPhysicalAggRuleBase { aggBufferTypes.map(_.map(fromDataTypeToLogicalType))) if (isLocalHashAgg) { - new BatchExecLocalHashAggregate( + new BatchPhysicalLocalHashAggregate( cluster, traitSet, input, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala index 536516d..767e734 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalHashAggRule.scala @@ -39,7 +39,7 @@ import scala.collection.JavaConversions._ * {{{ * BatchPhysicalHashAggregate (global) * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) - * +- BatchExecLocalHashAggregate (local) + * +- BatchPhysicalLocalHashAggregate (local) * +- input of agg * }}} * when all aggregate functions are mergeable @@ -94,7 +94,7 @@ class BatchPhysicalHashAggRule // create two-phase agg if possible if (isTwoPhaseAggWorkable(aggFunctions, tableConfig)) { - // create BatchExecLocalHashAggregate + // create BatchPhysicalLocalHashAggregate val localRequiredTraitSet = input.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) val newInput = RelOptRule.convert(input, localRequiredTraitSet) val providedTraitSet = localRequiredTraitSet diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala index 540d1b2..4484f2e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.rules.physical.batch import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalHashAggregate, BatchExecLocalHashAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalHashAggregate, BatchPhysicalLocalHashAggregate} import org.apache.calcite.plan.RelOptRule._ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} @@ -31,13 +31,13 @@ import org.apache.calcite.rel.RelNode */ class RemoveRedundantLocalHashAggRule extends RelOptRule( operand(classOf[BatchPhysicalHashAggregate], - operand(classOf[BatchExecLocalHashAggregate], + operand(classOf[BatchPhysicalLocalHashAggregate], operand(classOf[RelNode], FlinkConventions.BATCH_PHYSICAL, any))), "RemoveRedundantLocalHashAggRule") { override def onMatch(call: RelOptRuleCall): Unit = { val globalAgg: BatchPhysicalHashAggregate = call.rel(0) - val localAgg: BatchExecLocalHashAggregate = call.rel(1) + val localAgg: BatchPhysicalLocalHashAggregate = call.rel(1) val inputOfLocalAgg = localAgg.getInput val newGlobalAgg = new BatchPhysicalHashAggregate( globalAgg.getCluster, diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index bbb123b..353d5de 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -979,7 +979,7 @@ class FlinkRelMdHandlerTestBase { val hash0 = FlinkRelDistribution.hash(Array(0), requireStrict = true) val hash3 = FlinkRelDistribution.hash(Array(3), requireStrict = true) - val batchLocalAgg = new BatchExecLocalHashAggregate( + val batchLocalAgg = new BatchPhysicalLocalHashAggregate( cluster, batchPhysicalTraits, studentBatchScan, @@ -1106,7 +1106,7 @@ class FlinkRelMdHandlerTestBase { .add("sum_score", doubleType) .add("cnt", longType).build() - val batchLocalAggWithAuxGroup = new BatchExecLocalHashAggregate( + val batchLocalAggWithAuxGroup = new BatchPhysicalLocalHashAggregate( cluster, batchPhysicalTraits, studentBatchScan,
