This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit a44285a4a1d0836e99e90244894c57637e81ab58 Author: godfreyhe <[email protected]> AuthorDate: Thu Dec 24 19:42:37 2020 +0800 [FLINK-20737][table-planner-blink] Introduce StreamPhysicalPythonGroupAggregate, and make StreamExecPythonGroupAggregate only extended from ExecNode This closes #14478 --- ...=> StreamPhysicalPythonGroupAggregateRule.java} | 22 ++-- .../common/CommonExecPythonAggregate.scala} | 8 +- .../stream/StreamExecPythonGroupAggregate.scala | 114 ++++++++------------- .../batch/BatchExecPythonGroupAggregate.scala | 6 +- .../BatchExecPythonGroupWindowAggregate.scala | 6 +- .../batch/BatchExecPythonOverAggregate.scala | 4 +- .../StreamExecPythonGroupTableAggregate.scala | 4 +- .../StreamExecPythonGroupWindowAggregate.scala | 4 +- .../stream/StreamExecPythonOverAggregate.scala | 13 ++- .../StreamPhysicalPythonGroupAggregate.scala | 92 +++++++++++++++++ .../FlinkChangelogModeInferenceProgram.scala | 4 +- .../planner/plan/rules/FlinkStreamRuleSets.scala | 2 +- .../plan/stream/table/PythonAggregateTest.scala | 5 +- 13 files changed, 171 insertions(+), 113 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecPythonGroupAggregateRule.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalPythonGroupAggregateRule.java similarity index 88% rename from flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecPythonGroupAggregateRule.java rename to flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalPythonGroupAggregateRule.java index cc2f87c..602bdef 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamExecPythonGroupAggregateRule.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/StreamPhysicalPythonGroupAggregateRule.java @@ -22,9 +22,10 @@ import org.apache.flink.table.api.TableException; import org.apache.flink.table.functions.python.PythonFunctionKind; import org.apache.flink.table.planner.plan.nodes.FlinkConventions; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalAggregate; -import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecPythonGroupAggregate; +import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalPythonGroupAggregate; import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution; import org.apache.flink.table.planner.plan.utils.PythonUtil; +import org.apache.flink.table.planner.utils.JavaScalaConversionUtil; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptRuleCall; @@ -36,21 +37,20 @@ import org.apache.calcite.rel.core.AggregateCall; import java.util.List; -import scala.collection.JavaConverters; - /** - * Rule to convert a {@link FlinkLogicalAggregate} into a {@link StreamExecPythonGroupAggregate}. + * Rule to convert a {@link FlinkLogicalAggregate} into a {@link + * StreamPhysicalPythonGroupAggregate}. 
*/ -public class StreamExecPythonGroupAggregateRule extends ConverterRule { +public class StreamPhysicalPythonGroupAggregateRule extends ConverterRule { - public static final RelOptRule INSTANCE = new StreamExecPythonGroupAggregateRule(); + public static final RelOptRule INSTANCE = new StreamPhysicalPythonGroupAggregateRule(); - public StreamExecPythonGroupAggregateRule() { + public StreamPhysicalPythonGroupAggregateRule() { super( FlinkLogicalAggregate.class, FlinkConventions.LOGICAL(), FlinkConventions.STREAM_PHYSICAL(), - "StreamExecPythonGroupAggregateRule"); + "StreamPhysicalPythonGroupAggregateRule"); } @Override @@ -111,14 +111,12 @@ public class StreamExecPythonGroupAggregateRule extends ConverterRule { rel.getTraitSet().replace(FlinkConventions.STREAM_PHYSICAL()); RelNode newInput = RelOptRule.convert(agg.getInput(), requiredTraitSet); - return new StreamExecPythonGroupAggregate( + return new StreamPhysicalPythonGroupAggregate( rel.getCluster(), providedTraitSet, newInput, rel.getRowType(), agg.getGroupSet().toArray(), - JavaConverters.asScalaIteratorConverter(agg.getAggCallList().iterator()) - .asScala() - .toSeq()); + JavaScalaConversionUtil.toScala(agg.getAggCallList())); } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonAggregate.scala similarity index 98% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonAggregate.scala index b06e21d..0685296 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/common/CommonPythonAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonAggregate.scala @@ -16,9 +16,8 @@ * limitations under the License. */ -package org.apache.flink.table.planner.plan.nodes.common +package org.apache.flink.table.planner.plan.nodes.exec.common -import org.apache.calcite.rel.core.AggregateCall import org.apache.flink.table.api.TableException import org.apache.flink.table.api.dataview.{DataView, ListView, MapView} import org.apache.flink.table.functions.UserDefinedFunction @@ -27,15 +26,18 @@ import org.apache.flink.table.planner.functions.aggfunctions.Sum0AggFunction._ import org.apache.flink.table.planner.functions.aggfunctions._ import org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction import org.apache.flink.table.planner.functions.utils.AggSqlFunction +import org.apache.flink.table.planner.plan.nodes.common.CommonPythonBase import org.apache.flink.table.planner.plan.utils.AggregateInfoList import org.apache.flink.table.planner.typeutils.DataViewUtils.{DataViewSpec, ListViewSpec, MapViewSpec} import org.apache.flink.table.types.logical.{RowType, StructuredType} import org.apache.flink.table.types.{DataType, FieldsDataType} +import org.apache.calcite.rel.core.AggregateCall + import scala.collection.JavaConversions._ import scala.collection.mutable -trait CommonPythonAggregate extends CommonPythonBase { +trait CommonExecPythonAggregate extends CommonPythonBase { /** * For batch execution we extract the PythonFunctionInfo from AggregateCall. 
diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.scala similarity index 65% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.scala index 3570ca1..eb36a22 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonGroupAggregate.scala @@ -16,7 +16,7 @@ * limitations under the License. */ -package org.apache.flink.table.planner.plan.nodes.physical.stream +package org.apache.flink.table.planner.plan.nodes.exec.stream import org.apache.flink.api.dag.Transformation import org.apache.flink.configuration.Configuration @@ -25,74 +25,41 @@ import org.apache.flink.streaming.api.operators.OneInputStreamOperator import org.apache.flink.streaming.api.transformations.OneInputTransformation import org.apache.flink.table.data.RowData import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo -import org.apache.flink.table.planner.calcite.FlinkTypeFactory -import org.apache.flink.table.planner.delegation.StreamPlanner -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate -import org.apache.flink.table.planner.plan.nodes.exec.LegacyStreamExecNode -import org.apache.flink.table.planner.plan.utils._ +import org.apache.flink.table.planner.delegation.PlannerBase +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode, ExecNodeBase} +import org.apache.flink.table.planner.plan.utils.{AggregateUtil, KeySelectorUtil} import org.apache.flink.table.planner.typeutils.DataViewUtils.DataViewSpec +import org.apache.flink.table.planner.utils.Logging import org.apache.flink.table.runtime.typeutils.InternalTypeInfo import org.apache.flink.table.types.logical.RowType -import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} -import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.rel.{RelNode, RelWriter} -import java.util +import java.util.Collections /** - * Stream physical RelNode for Python unbounded group aggregate. - * - * @see [[StreamPhysicalGroupAggregateBase]] for more info. - */ + * Stream [[ExecNode]] for Python unbounded group aggregate. + * + * <p>Note: This class can't be ported to Java, + * because java class can't extend scala interface with default implementation. + * FLINK-20750 will port this class to Java. 
+ */ class StreamExecPythonGroupAggregate( - cluster: RelOptCluster, - traitSet: RelTraitSet, - inputRel: RelNode, - outputRowType: RelDataType, - val grouping: Array[Int], - val aggCalls: Seq[AggregateCall]) - extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) - with LegacyStreamExecNode[RowData] - with CommonPythonAggregate { - - val aggInfoList: AggregateInfoList = AggregateUtil.deriveAggregateInfoList( - this, - grouping.length, - aggCalls) - - override def requireWatermark: Boolean = false - - override def deriveRowType(): RelDataType = outputRowType - - override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { - new StreamExecPythonGroupAggregate( - cluster, - traitSet, - inputs.get(0), - outputRowType, - grouping, - aggCalls) - } - - override def explainTerms(pw: RelWriter): RelWriter = { - val inputRowType = getInput.getRowType - super.explainTerms(pw) - .itemIf("groupBy", - RelExplainUtil.fieldToString(grouping, inputRowType), grouping.nonEmpty) - .item("select", RelExplainUtil.streamGroupAggregationToString( - inputRowType, - getRowType, - aggInfoList, - grouping)) - } - - //~ ExecNode methods ----------------------------------------------------------- - - override protected def translateToPlanInternal( - planner: StreamPlanner): Transformation[RowData] = { - + grouping: Array[Int], + aggCalls: Seq[AggregateCall], + aggCallNeedRetractions: Array[Boolean], + generateUpdateBefore: Boolean, + needRetraction: Boolean, + inputEdge: ExecEdge, + outputType: RowType, + description: String) + extends ExecNodeBase[RowData](Collections.singletonList(inputEdge), outputType, description) + with StreamExecNode[RowData] + with CommonExecPythonAggregate + with Logging { + + override protected def translateToPlanInternal(planner: PlannerBase): Transformation[RowData] = { val tableConfig = planner.getTableConfig if (grouping.length > 0 && tableConfig.getMinIdleStateRetentionTime < 0) { @@ -101,16 +68,17 @@ class StreamExecPythonGroupAggregate( "state size. 
You may specify a retention time of 0 to not clean up the state.") } - val inputTransformation = getInputNodes.get(0).translateToPlan(planner) - .asInstanceOf[Transformation[RowData]] - - val outRowType = FlinkTypeFactory.toLogicalRowType(outputRowType) - val inputRowType = FlinkTypeFactory.toLogicalRowType(getInput.getRowType) - - val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this) + val inputNode = getInputNodes.get(0).asInstanceOf[ExecNode[RowData]] + val inputTransformation = inputNode.translateToPlan(planner) + val inputRowType = inputNode.getOutputType.asInstanceOf[RowType] + val aggInfoList = AggregateUtil.transformToStreamAggregateInfoList( + inputRowType, + aggCalls, + aggCallNeedRetractions, + needRetraction, + isStateBackendDataViews = true) val inputCountIndex = aggInfoList.getIndexOfCountStar - val countStarInserted = aggInfoList.countStarInserted var (pythonFunctionInfos, dataViewSpecs) = @@ -123,7 +91,7 @@ class StreamExecPythonGroupAggregate( val operator = getPythonAggregateFunctionOperator( getConfig(planner.getExecEnv, tableConfig), inputRowType, - outRowType, + outputType, pythonFunctionInfos, dataViewSpecs, tableConfig.getMinIdleStateRetentionTime, @@ -133,16 +101,14 @@ class StreamExecPythonGroupAggregate( inputCountIndex, countStarInserted) - val selector = KeySelectorUtil.getRowDataSelector( - grouping, - InternalTypeInfo.of(inputRowType)) + val selector = KeySelectorUtil.getRowDataSelector(grouping, InternalTypeInfo.of(inputRowType)) // partitioned aggregation val ret = new OneInputTransformation( inputTransformation, - getRelDetailedDescription, + getDesc, operator, - InternalTypeInfo.of(outRowType), + InternalTypeInfo.of(outputType), inputTransformation.getParallelism) if (inputsContainSingleton()) { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala index f0d02c6..54ba618 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala @@ -29,8 +29,8 @@ import org.apache.flink.table.functions.python.PythonFunctionInfo import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.plan.`trait`.{FlinkRelDistribution, FlinkRelDistributionTraitDef} -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate -import org.apache.flink.table.planner.plan.nodes.exec.{LegacyBatchExecNode, ExecEdge} +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode} import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupAggregate.ARROW_PYTHON_AGGREGATE_FUNCTION_OPERATOR_NAME import org.apache.flink.table.planner.plan.rules.physical.batch.BatchExecJoinRuleBase import org.apache.flink.table.planner.plan.utils.{FlinkRelOptUtil, RelExplainUtil} @@ -75,7 +75,7 @@ class BatchExecPythonGroupAggregate( isMerge = false, isFinal = true) with LegacyBatchExecNode[RowData] - with CommonPythonAggregate { + with 
CommonExecPythonAggregate { override def explainTerms(pw: RelWriter): RelWriter = super.explainTerms(pw) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupWindowAggregate.scala index 989e053..ab3ce2c 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupWindowAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupWindowAggregate.scala @@ -34,8 +34,8 @@ import org.apache.flink.table.planner.delegation.BatchPlanner import org.apache.flink.table.planner.expressions.{PlannerRowtimeAttribute, PlannerWindowEnd, PlannerWindowStart} import org.apache.flink.table.planner.plan.cost.{FlinkCost, FlinkCostFactory} import org.apache.flink.table.planner.plan.logical.LogicalWindow -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate -import org.apache.flink.table.planner.plan.nodes.exec.{LegacyBatchExecNode, ExecEdge} +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, LegacyBatchExecNode} import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonGroupWindowAggregate.ARROW_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME import org.apache.flink.table.runtime.typeutils.InternalTypeInfo import org.apache.flink.table.types.logical.RowType @@ -84,7 +84,7 @@ class BatchExecPythonGroupWindowAggregate( false, true) with LegacyBatchExecNode[RowData] - with CommonPythonAggregate { + with CommonExecPythonAggregate { override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new BatchExecPythonGroupWindowAggregate( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonOverAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonOverAggregate.scala index dea0d33..9c62a8b 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonOverAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonOverAggregate.scala @@ -28,7 +28,7 @@ import org.apache.flink.table.functions.UserDefinedFunction import org.apache.flink.table.functions.python.PythonFunctionInfo import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.delegation.BatchPlanner -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecPythonOverAggregate.ARROW_PYTHON_OVER_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME import org.apache.flink.table.planner.plan.utils.OverAggregateUtil.getLongBoundary import org.apache.flink.table.runtime.typeutils.InternalTypeInfo @@ -73,7 +73,7 @@ class BatchExecPythonOverAggregate( nullIsLasts, windowGroupToAggCallToAggFunction, logicWindow) - with CommonPythonAggregate { + with 
CommonExecPythonAggregate { override def copy(traitSet: RelTraitSet, inputs: java.util.List[RelNode]): RelNode = { new BatchExecPythonOverAggregate( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala index c92359f..cab5a46 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupTableAggregate.scala @@ -26,7 +26,7 @@ import org.apache.flink.table.data.RowData import org.apache.flink.table.functions.python.PythonAggregateFunctionInfo import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.delegation.StreamPlanner -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate import org.apache.flink.table.planner.plan.utils.{ChangelogPlanUtils, KeySelectorUtil} import org.apache.flink.table.planner.typeutils.DataViewUtils.DataViewSpec import org.apache.flink.table.runtime.typeutils.InternalTypeInfo @@ -56,7 +56,7 @@ class StreamExecPythonGroupTableAggregate( outputRowType, grouping, aggCalls) - with CommonPythonAggregate { + with CommonExecPythonAggregate { override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new StreamExecPythonGroupTableAggregate( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupWindowAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupWindowAggregate.scala index 54479ba..ee54572 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupWindowAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonGroupWindowAggregate.scala @@ -31,7 +31,7 @@ import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.delegation.StreamPlanner import org.apache.flink.table.planner.expressions.{PlannerProctimeAttribute, PlannerRowtimeAttribute, PlannerWindowEnd, PlannerWindowStart} import org.apache.flink.table.planner.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow} -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecPythonGroupWindowAggregate.ARROW_STREAM_PYTHON_GROUP_WINDOW_AGGREGATE_FUNCTION_OPERATOR_NAME import org.apache.flink.table.planner.plan.utils.AggregateUtil._ import org.apache.flink.table.planner.plan.utils.{KeySelectorUtil, WindowEmitStrategy} @@ -73,7 +73,7 @@ class StreamExecPythonGroupWindowAggregate( inputTimeFieldIndex, emitStrategy, "PythonAggregate") - with CommonPythonAggregate { + with CommonExecPythonAggregate { override def copy(traitSet: RelTraitSet, inputs: 
java.util.List[RelNode]): RelNode = { new StreamExecPythonGroupWindowAggregate( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonOverAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonOverAggregate.scala index 930ea87..e7d9d32 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonOverAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamExecPythonOverAggregate.scala @@ -27,13 +27,12 @@ import org.apache.flink.table.data.RowData import org.apache.flink.table.functions.python.PythonFunctionInfo import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.delegation.StreamPlanner -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate -import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecPythonOverAggregate.{ +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate +import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamExecPythonOverAggregate +.{ARROW_PYTHON_OVER_WINDOW_RANGE_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME, ARROW_PYTHON_OVER_WINDOW_RANGE_ROW_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME, - ARROW_PYTHON_OVER_WINDOW_RANGE_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME, - ARROW_PYTHON_OVER_WINDOW_ROWS_ROW_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME, - ARROW_PYTHON_OVER_WINDOW_ROWS_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME -} + ARROW_PYTHON_OVER_WINDOW_ROWS_PROC_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME, + ARROW_PYTHON_OVER_WINDOW_ROWS_ROW_TIME_AGGREGATE_FUNCTION_OPERATOR_NAME} import org.apache.flink.table.planner.plan.utils.{KeySelectorUtil, OverAggregateUtil} import org.apache.flink.table.runtime.typeutils.InternalTypeInfo import org.apache.flink.table.types.logical.RowType @@ -66,7 +65,7 @@ class StreamExecPythonOverAggregate( outputRowType, inputRowType, logicWindow) - with CommonPythonAggregate { + with CommonExecPythonAggregate { override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new StreamExecPythonOverAggregate( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalPythonGroupAggregate.scala new file mode 100644 index 0000000..029779f --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalPythonGroupAggregate.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.nodes.physical.stream + +import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecPythonGroupAggregate +import org.apache.flink.table.planner.plan.nodes.exec.{ExecEdge, ExecNode} +import org.apache.flink.table.planner.plan.utils._ + +import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.AggregateCall +import org.apache.calcite.rel.{RelNode, RelWriter} + +import java.util + +/** + * Stream physical RelNode for Python unbounded group aggregate. + * + * @see [[StreamPhysicalGroupAggregateBase]] for more info. + */ +class StreamPhysicalPythonGroupAggregate( + cluster: RelOptCluster, + traitSet: RelTraitSet, + inputRel: RelNode, + outputRowType: RelDataType, + val grouping: Array[Int], + val aggCalls: Seq[AggregateCall]) + extends StreamPhysicalGroupAggregateBase(cluster, traitSet, inputRel) { + + private lazy val aggInfoList = + AggregateUtil.deriveAggregateInfoList(this, grouping.length, aggCalls) + + override def requireWatermark: Boolean = false + + override def deriveRowType(): RelDataType = outputRowType + + override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { + new StreamPhysicalPythonGroupAggregate( + cluster, + traitSet, + inputs.get(0), + outputRowType, + grouping, + aggCalls) + } + + override def explainTerms(pw: RelWriter): RelWriter = { + val inputRowType = getInput.getRowType + super.explainTerms(pw) + .itemIf("groupBy", + RelExplainUtil.fieldToString(grouping, inputRowType), grouping.nonEmpty) + .item("select", RelExplainUtil.streamGroupAggregationToString( + inputRowType, + getRowType, + aggInfoList, + grouping)) + } + + override def translateToExecNode(): ExecNode[_] = { + val aggCallNeedRetractions = + AggregateUtil.deriveAggCallNeedRetractions(this, grouping.length, aggCalls) + val generateUpdateBefore = ChangelogPlanUtils.generateUpdateBefore(this) + val needRetraction = !ChangelogPlanUtils.inputInsertOnly(this) + new StreamExecPythonGroupAggregate( + grouping, + aggCalls, + aggCallNeedRetractions, + generateUpdateBefore, + needRetraction, + ExecEdge.DEFAULT, + FlinkTypeFactory.toLogicalRowType(getRowType), + getRelDetailedDescription + ) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala index 8d276ca..2e862e1 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/optimize/program/FlinkChangelogModeInferenceProgram.scala @@ -185,7 +185,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti createNewNode( tagg, children, 
ModifyKindSetTrait.ALL_CHANGES, requiredTrait, requester) - case agg: StreamExecPythonGroupAggregate => + case agg: StreamPhysicalPythonGroupAggregate => // agg support all changes in input val children = visitChildren(agg, ModifyKindSetTrait.ALL_CHANGES) val inputModifyKindSet = getModifyKindSet(children.head) @@ -462,7 +462,7 @@ class FlinkChangelogModeInferenceProgram extends FlinkOptimizeProgram[StreamOpti visitSink(sink, sinkRequiredTraits) case _: StreamPhysicalGroupAggregate | _: StreamExecGroupTableAggregate | - _: StreamPhysicalLimit | _: StreamExecPythonGroupAggregate | + _: StreamPhysicalLimit | _: StreamPhysicalPythonGroupAggregate | _: StreamExecPythonGroupTableAggregate => // Aggregate, TableAggregate and Limit requires update_before if there are updates val requiredChildTrait = beforeAfterOrNone(getModifyKindSet(rel.getInput(0))) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala index d7c695d..4fc6e04 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala @@ -422,7 +422,7 @@ object FlinkStreamRuleSets { // group agg StreamPhysicalGroupAggregateRule.INSTANCE, StreamExecGroupTableAggregateRule.INSTANCE, - StreamExecPythonGroupAggregateRule.INSTANCE, + StreamPhysicalPythonGroupAggregateRule.INSTANCE, StreamExecPythonGroupTableAggregateRule.INSTANCE, // over agg StreamExecOverAggregateRule.INSTANCE, diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/PythonAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/PythonAggregateTest.scala index 4258824..4c26534 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/PythonAggregateTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/table/PythonAggregateTest.scala @@ -22,11 +22,12 @@ import org.apache.flink.api.java.tuple.Tuple1 import org.apache.flink.api.scala._ import org.apache.flink.table.api._ import org.apache.flink.table.api.dataview.{ListView, MapView} -import org.apache.flink.table.planner.plan.nodes.common.CommonPythonAggregate +import org.apache.flink.table.planner.plan.nodes.exec.common.CommonExecPythonAggregate import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.TestPythonAggregateFunction import org.apache.flink.table.planner.typeutils.DataViewUtils.{DataViewSpec, ListViewSpec, MapViewSpec} import org.apache.flink.table.planner.utils.TableTestBase import org.apache.flink.table.types.DataType + import org.junit.Assert.assertEquals import org.junit.Test @@ -115,7 +116,7 @@ class PythonAggregateTest extends TableTestBase { } } -object TestCommonPythonAggregate extends CommonPythonAggregate { +object TestCommonPythonAggregate extends CommonExecPythonAggregate { def extractDataViewSpecs(accType: DataType): Array[DataViewSpec] = { extractDataViewSpecs(0, accType) }
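
Illustrative note (not part of the commit): the change above follows the planner's split between physical RelNodes and ExecNodes — StreamPhysicalPythonGroupAggregate remains a Calcite RelNode used during optimization, derives the changelog-related flags, and hands plain parameters to StreamExecPythonGroupAggregate, which only implements ExecNode and produces the runtime transformation. The minimal Scala sketch below uses simplified stand-in types invented for illustration (the ExecNode trait here, the *Sketch classes, string "plans"), not the real Flink/Calcite interfaces, and the flag derivation is likewise simplified (the actual planner uses ChangelogPlanUtils and AggregateUtil).

// Stand-in for the exec layer: knows nothing about Calcite RelNodes.
trait ExecNode[T] {
  def translateToPlan(): T
}

// Exec node keeps only the plain values needed to build the runtime operator.
class StreamExecPythonGroupAggregateSketch(
    grouping: Array[Int],
    aggCallNames: Seq[String],
    generateUpdateBefore: Boolean,
    needRetraction: Boolean,
    description: String)
  extends ExecNode[String] {

  override def translateToPlan(): String =
    s"PythonGroupAggregate(groupBy=[${grouping.mkString(",")}], " +
      s"agg=[${aggCallNames.mkString(",")}], " +
      s"updateBefore=$generateUpdateBefore, retract=$needRetraction) // $description"
}

// Stand-in for the physical layer: participates in optimization and only
// converts itself into an ExecNode at the end of planning.
class StreamPhysicalPythonGroupAggregateSketch(
    grouping: Array[Int],
    aggCallNames: Seq[String],
    inputInsertOnly: Boolean) {

  def translateToExecNode(): ExecNode[String] = {
    // Changelog-related flags are derived here, in the optimizer layer, and
    // passed to the exec node as constructor arguments (simplified: the real
    // planner uses ChangelogPlanUtils.generateUpdateBefore and related utils).
    val needRetraction = !inputInsertOnly
    new StreamExecPythonGroupAggregateSketch(
      grouping,
      aggCallNames,
      generateUpdateBefore = needRetraction,
      needRetraction = needRetraction,
      description = "Python unbounded group aggregate (sketch)")
  }
}

object PhysicalToExecSketch {
  def main(args: Array[String]): Unit = {
    val physical = new StreamPhysicalPythonGroupAggregateSketch(
      Array(0), Seq("py_sum(b)"), inputInsertOnly = false)
    println(physical.translateToExecNode().translateToPlan())
  }
}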