This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 049a61a578c8ba7ed65eca69eb543be096ecffac Author: godfreyhe <[email protected]> AuthorDate: Tue Jan 5 20:17:13 2021 +0800 [FLINK-20738][table-planner-blink] Introduce BatchPhysicalPythonGroupAggregate, and make BatchExecPythonGroupAggregate only extended from ExecNode This closes #14562 --- ....java => BatchPhysicalPythonAggregateRule.java} | 14 +-- .../exec/batch/BatchExecPythonGroupAggregate.scala | 130 +++++++++++++++++++++ ...ala => BatchPhysicalPythonGroupAggregate.scala} | 108 +++-------------- .../planner/plan/rules/FlinkBatchRuleSets.scala | 2 +- 4 files changed, 151 insertions(+), 103 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonAggregateRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonAggregateRule.java similarity index 93% rename from flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonAggregateRule.java rename to flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonAggregateRule.java index f1c2ab5..a5cad3f 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecPythonAggregateRule.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalPythonAggregateRule.java @@ -24,7 +24,7 @@ import org.apache.flink.table.functions.python.PythonFunctionKind; import org.apache.flink.table.planner.calcite.FlinkTypeFactory; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate; -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalPythonGroupAggregate; import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; import org.apache.flink.table.planner.plan.utils.AggregateUtil; import org.apache.flink.table.planner.plan.utils.FlinkRelOptUtil; @@ -50,18 +50,18 @@ import scala.collection.Seq; /** * The physical rule which is responsible for converting {@link FlinkLogicalAggregate} to {@link - * BatchExecPythonGroupAggregate}. + * BatchPhysicalPythonGroupAggregate}. */ -public class BatchExecPythonAggregateRule extends ConverterRule { +public class BatchPhysicalPythonAggregateRule extends ConverterRule { - public static final RelOptRule INSTANCE = new BatchExecPythonAggregateRule(); + public static final RelOptRule INSTANCE = new BatchPhysicalPythonAggregateRule(); - private BatchExecPythonAggregateRule() { + private BatchPhysicalPythonAggregateRule() { super( FlinkLogicalAggregate.class, FlinkConventions.LOGICAL(), FlinkConventions.BATCH_PHYSICAL(), - "BatchExecPythonAggregateRule"); + "BatchPhysicalPythonAggregateRule"); } @Override @@ -124,7 +124,7 @@ public class BatchExecPythonAggregateRule extends ConverterRule { } RelNode convInput = RelOptRule.convert(input, requiredTraitSet); - return new BatchExecPythonGroupAggregate( + return new BatchPhysicalPythonGroupAggregate( relNode.getCluster(), traitSet, convInput, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.scala new file mode 100644 index 0000000..e2c5195 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonGroupAggregate.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.exec.batch + +import org.apache.flink.api.dag.Transformation +import org.apache.flink.configuration.Configuration +import org.apache.flink.core.memory.ManagedMemoryUseCase +import org.apache.flink.streaming.api.operators.OneInputStreamOperator +import org.apache.flink.streaming.api.transformations.OneInputTransformation +import org.apache.flink.table.data.RowData +import org.apache.flink.table.functions.python.PythonFunctionInfo +import org.apache.flink.table.planner.delegation.PlannerBase +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonGroupAggregate.ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode, ExecNodeBase} +import org.apache.flink.table.planner.utils.Logging +import org.apache.flink.table.runtime.typeutils.InternalTypeInfo +import org.apache.flink.table.types.logical.RowType + +import org.apache.calcite.rel.core.AggregateCall + +import java.util.Collections + +/** + * Batch [[ExecNode]] for aggregate (Python user defined aggregate function). + * + * <p>Note: This class can't be ported to Java, + * because java class can't extend scala interface with default implementation. + * FLINK-20751 will port this class to Java. + */ +class BatchExecPythonGroupAggregate( + grouping: Array[Int], + auxGrouping: Array[Int], + aggCalls: Seq[AggregateCall], + inputEdge: ExecEdge, + outputType: RowType, + description: String) + extends ExecNodeBase[RowData](Collections.singletonList(inputEdge), outputType, description) + with BatchExecNode[RowData] + with CommonExecPythonAggregate + with Logging { + + override protected def translateToPlanInternal( + planner: PlannerBase): Transformation[RowData] = { + val inputNode = getInputNodes.get(0).asInstanceOf[ExecNode[RowData]] + val inputTransform = inputNode.translateToPlan(planner) + + val ret = createPythonOneInputTransformation( + inputTransform, + inputNode.getOutputType.asInstanceOf[RowType], + outputType, + getConfig(planner.getExecEnv, planner.getTableConfig)) + if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) { + ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON) + } + ret + } + + private[this] def createPythonOneInputTransformation( + inputTransform: Transformation[RowData], + inputRowType: RowType, + outputRowType: RowType, + config: Configuration): OneInputTransformation[RowData, RowData] = { + + val (pythonUdafInputOffsets, pythonFunctionInfos) = + extractPythonAggregateFunctionInfosFromAggregateCall(aggCalls) + + val pythonOperator = getPythonAggregateFunctionOperator( + config, + inputRowType, + outputRowType, + pythonUdafInputOffsets, + pythonFunctionInfos) + + new OneInputTransformation( + inputTransform, + "BatchExecPythonGroupAggregate", + pythonOperator, + InternalTypeInfo.of(outputRowType), + inputTransform.getParallelism) + } + + private[this] def getPythonAggregateFunctionOperator( + config: Configuration, + inputRowType: RowType, + outputRowType: RowType, + udafInputOffsets: Array[Int], + pythonFunctionInfos: Array[PythonFunctionInfo]): OneInputStreamOperator[RowData, RowData] = { + val clazz = loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME) + + val ctor = clazz.getConstructor( + classOf[Configuration], + classOf[Array[PythonFunctionInfo]], + classOf[RowType], + classOf[RowType], + classOf[Array[Int]], + classOf[Array[Int]], + classOf[Array[Int]]) + ctor.newInstance( + config, + pythonFunctionInfos, + inputRowType, + outputRowType, + grouping, + grouping ++ auxGrouping, + udafInputOffsets).asInstanceOf[OneInputStreamOperator[RowData, RowData]] + } +} + +object BatchExecPythonGroupAggregate { + val ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME: String = + "org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch." + + "BatchArrowPythonGroupAggregateFunctionOperator" +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalPythonGroupAggregate.scala similarity index 62% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalPythonGroupAggregate.scala index ab90b90..cfb9cb7 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalPythonGroupAggregate.scala @@ -18,24 +18,13 @@ package org.apache.flink.table.planner.plan.nodes.physical.batch -import org.apache.flink.api.dag.Transformation -import org.apache.flink.configuration.Configuration -import org.apache.flink.core.memory.ManagedMemoryUseCase -import org.apache.flink.streaming.api.operators.OneInputStreamOperator -import org.apache.flink.streaming.api.transformations.OneInputTransformation -import org.apache.flink.table.data.RowData import org.apache.flink.table.functions.UserDefinedFunction -import org.apache.flink.table.functions.python.PythonFunctionInfo import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} -import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate -import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode} -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate.ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} +import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonGroupAggregate import org.apache.flink.table.planner.plan.rules.physical.batch.BatchExecJoinRuleBase import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil} -import org.apache.flink.table.runtime.typeutils.InternalTypeInfo -import org.apache.flink.table.types.logical.RowType import org.apache.calcite.plan.{RelOptCluster, RelOptRule, RelTraitSet} import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON} @@ -51,7 +40,7 @@ import scala.collection.JavaConversions._ /** * Batch physical RelNode for aggregate (Python user defined aggregate function). */ -class BatchExecPythonGroupAggregate( +class BatchPhysicalPythonGroupAggregate( cluster: RelOptCluster, traitSet: RelTraitSet, inputRel: RelNode, @@ -71,9 +60,7 @@ class BatchExecPythonGroupAggregate( auxGrouping, aggCalls.zip(aggFunctions), isMerge = false, - isFinal = true) - with LegacyBatchExecNode[RowData] - with CommonExecPythonAggregate { + isFinal = true) { override def explainTerms(pw: RelWriter): RelWriter = super.explainTerms(pw) @@ -149,7 +136,7 @@ class BatchExecPythonGroupAggregate( } override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { - new BatchExecPythonGroupAggregate( + new BatchPhysicalPythonGroupAggregate( cluster, traitSet, inputs.get(0), @@ -162,84 +149,15 @@ class BatchExecPythonGroupAggregate( aggFunctions) } - //~ ExecNode methods ----------------------------------------------------------- - - override def getInputEdges: util.List[ExecEdge] = List( - ExecEdge.builder() - .damBehavior(ExecEdge.DamBehavior.END_INPUT) - .build()) - - override protected def translateToPlanInternal( - planner: BatchPlanner): Transformation[RowData] = { - val input = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - val outputType = FlinkTypeFactory.toLogicalRowType(getRowType) - val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) - - val ret = createPythonOneInputTransformation( - input, - inputType, - outputType, - getConfig(planner.getExecEnv, planner.getTableConfig)) - if (isPythonWorkerUsingManagedMemory(planner.getTableConfig.getConfiguration)) { - ret.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON) - } - ret - } - - private[this] def createPythonOneInputTransformation( - inputTransform: Transformation[RowData], - inputRowType: RowType, - outputRowType: RowType, - config: Configuration): OneInputTransformation[RowData, RowData] = { - - val (pythonUdafInputOffsets, pythonFunctionInfos) = - extractPythonAggregateFunctionInfosFromAggregateCall(aggCalls) - - val pythonOperator = getPythonAggregateFunctionOperator( - config, - inputRowType, - outputRowType, - pythonUdafInputOffsets, - pythonFunctionInfos) - - new OneInputTransformation( - inputTransform, - "BatchExecPythonGroupAggregate", - pythonOperator, - InternalTypeInfo.of(outputRowType), - inputTransform.getParallelism) - } - - private[this] def getPythonAggregateFunctionOperator( - config: Configuration, - inputRowType: RowType, - outputRowType: RowType, - udafInputOffsets: Array[Int], - pythonFunctionInfos: Array[PythonFunctionInfo]): OneInputStreamOperator[RowData, RowData] = { - val clazz = loadClass(ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME) - - val ctor = clazz.getConstructor( - classOf[Configuration], - classOf[Array[PythonFunctionInfo]], - classOf[RowType], - classOf[RowType], - classOf[Array[Int]], - classOf[Array[Int]], - classOf[Array[Int]]) - ctor.newInstance( - config, - pythonFunctionInfos, - inputRowType, - outputRowType, + override def translateToExecNode(): ExecNode[_] = { + new BatchExecPythonGroupAggregate( grouping, - grouping ++ auxGrouping, - udafInputOffsets).asInstanceOf[OneInputStreamOperator[RowData, RowData]] + auxGrouping, + aggCalls, + ExecEdge.builder().damBehavior(ExecEdge.DamBehavior.END_INPUT).build(), + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) } } -object BatchExecPythonGroupAggregate { - val ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME: String = - "org.apache.flink.table.runtime.operators.python.aggregate.arrow.batch." + - "BatchArrowPythonGroupAggregateFunctionOperator" -} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala index 5e26819..1f5809e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkBatchRuleSets.scala @@ -419,7 +419,7 @@ object FlinkBatchRuleSets { RemoveRedundantLocalSortAggRule.WITHOUT_SORT, RemoveRedundantLocalSortAggRule.WITH_SORT, RemoveRedundantLocalHashAggRule.INSTANCE, - BatchExecPythonAggregateRule.INSTANCE, + BatchPhysicalPythonAggregateRule.INSTANCE, // over agg BatchExecOverAggregateRule.INSTANCE, // window agg
