This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
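The patch below renames BatchExecGroupAggregateBase to BatchPhysicalGroupAggregateBase and swaps its JavaBean-style accessors (getGrouping/getAuxGrouping) for Scala val accessors (grouping/auxGrouping). A minimal sketch of that recurring call-site pattern, with hypothetical, heavily simplified signatures (the real classes also carry traits, row types, and aggregate calls):

    // Sketch only -- simplified stand-ins for the real planner classes.
    // Before the rename: JavaBean-style accessors.
    abstract class BatchExecGroupAggregateBase(
        grouping: Array[Int],
        auxGrouping: Array[Int]) {
      def getGrouping: Array[Int] = grouping
      def getAuxGrouping: Array[Int] = auxGrouping
    }

    // After the rename: plain val accessors on the renamed class.
    abstract class BatchPhysicalGroupAggregateBase(
        val grouping: Array[Int],
        val auxGrouping: Array[Int])

    object RenameSketch {
      // Metadata handlers match on the new type and compute the full
      // grouping as grouping ++ auxGrouping, as seen throughout the diff.
      def fullGrouping(agg: BatchPhysicalGroupAggregateBase): Array[Int] =
        agg.grouping ++ agg.auxGrouping
    }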
commit 146c68df81f8e3507fdc158a523d6a9ea6fa0bd6
Author: godfreyhe <[email protected]>
AuthorDate: Tue Jan 5 18:25:37 2021 +0800

    [FLINK-20738][table-planner-blink] Rename BatchExecGroupAggregateBase to BatchPhysicalGroupAggregateBase and do some refactoring

    This closes #14562
---
 .../metadata/AggCallSelectivityEstimator.scala | 132 +++---
 .../plan/metadata/FlinkRelMdColumnInterval.scala | 428 ++++++++++-----------
 .../plan/metadata/FlinkRelMdColumnUniqueness.scala | 4 +-
 .../plan/metadata/FlinkRelMdDistinctRowCount.scala | 4 +-
 .../FlinkRelMdFilteredColumnInterval.scala | 4 +-
 .../metadata/FlinkRelMdModifiedMonotonicity.scala | 32 +-
 .../FlinkRelMdPercentageOriginalRows.scala | 6 +-
 .../plan/metadata/FlinkRelMdPopulationSize.scala | 2 +-
 .../planner/plan/metadata/FlinkRelMdRowCount.scala | 6 +-
 .../plan/metadata/FlinkRelMdSelectivity.scala | 6 +-
 .../planner/plan/metadata/FlinkRelMdSize.scala | 6 +-
 .../plan/metadata/FlinkRelMdUniqueGroups.scala | 4 +-
 .../plan/metadata/FlinkRelMdUniqueKeys.scala | 10 +-
 .../physical/batch/BatchExecHashAggregate.scala | 6 +-
 .../batch/BatchExecHashAggregateBase.scala | 22 +-
 .../batch/BatchExecLocalHashAggregate.scala | 12 +-
 .../batch/BatchExecLocalSortAggregate.scala | 12 +-
 .../batch/BatchExecPythonGroupAggregate.scala | 8 +-
 .../physical/batch/BatchExecSortAggregate.scala | 12 +-
 .../batch/BatchExecSortAggregateBase.scala | 28 +-
 ...scala => BatchPhysicalGroupAggregateBase.scala} | 41 +-
 .../physical/batch/BatchExecHashAggRule.scala | 5 +-
 .../physical/batch/BatchExecHashJoinRule.scala | 2 +-
 .../physical/batch/BatchExecJoinRuleBase.scala | 5 +-
 .../batch/BatchExecNestedLoopJoinRule.scala | 2 +-
 .../physical/batch/BatchExecSortAggRule.scala | 51 ++-
 .../batch/BatchExecWindowAggregateRule.scala | 66 ++--
 ...leBase.scala => BatchPhysicalAggRuleBase.scala} | 10 +-
 .../physical/batch/EnforceLocalAggRuleBase.scala | 46 +--
 .../physical/batch/EnforceLocalHashAggRule.scala | 56 +--
 .../physical/batch/EnforceLocalSortAggRule.scala | 8 +-
 .../batch/RemoveRedundantLocalHashAggRule.scala | 15 +-
 .../batch/RemoveRedundantLocalSortAggRule.scala | 5 +-
 .../table/planner/plan/utils/FlinkRelMdUtil.scala | 323 ++++++++--------
 .../plan/metadata/FlinkRelMdHandlerTestBase.scala | 69 +---
 .../metadata/MetadataHandlerConsistencyTest.scala | 12 +-
 .../batch/EnforceLocalHashAggRuleTest.scala | 12 +-
 37 files changed, 694 insertions(+), 778 deletions(-)

diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/AggCallSelectivityEstimator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/AggCallSelectivityEstimator.scala
index 39fd4d6..d7b5418 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/AggCallSelectivityEstimator.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/AggCallSelectivityEstimator.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.metadata

 import org.apache.flink.table.planner.JDouble
-import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase}
+import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase, BatchPhysicalGroupAggregateBase}
 import
org.apache.flink.table.planner.plan.stats._ import org.apache.flink.table.planner.plan.utils.AggregateUtil @@ -34,16 +34,16 @@ import org.apache.calcite.sql.{SqlKind, SqlOperator} import scala.collection.JavaConversions._ /** - * Estimates selectivity of rows meeting an agg-call predicate on an Aggregate. - * - * A filter predicate on an Aggregate may contain two parts: - * one is on group by columns, another is on aggregate call's result. - * The first part is handled by [[SelectivityEstimator]], - * the second part is handled by this Estimator. - * - * @param agg aggregate node - * @param mq Metadata query - */ + * Estimates selectivity of rows meeting an agg-call predicate on an Aggregate. + * + * A filter predicate on an Aggregate may contain two parts: + * one is on group by columns, another is on aggregate call's result. + * The first part is handled by [[SelectivityEstimator]], + * the second part is handled by this Estimator. + * + * @param agg aggregate node + * @param mq Metadata query + */ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) extends RexVisitorImpl[Option[Double]](true) { @@ -53,15 +53,15 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) private[flink] val defaultAggCallSelectivity = Some(0.01d) /** - * Gets AggregateCall from aggregate node - */ + * Gets AggregateCall from aggregate node + */ def getSupportedAggCall(outputIdx: Int): Option[AggregateCall] = { val (fullGrouping, aggCalls) = agg match { case rel: Aggregate => val (auxGroupSet, otherAggCalls) = AggregateUtil.checkAndSplitAggCalls(rel) (rel.getGroupSet.toArray ++ auxGroupSet, otherAggCalls) - case rel: BatchExecGroupAggregateBase => - (rel.getGrouping ++ rel.getAuxGrouping, rel.getAggCallList) + case rel: BatchPhysicalGroupAggregateBase => + (rel.grouping ++ rel.auxGrouping, rel.getAggCallList) case rel: BatchExecLocalHashWindowAggregate => val fullGrouping = rel.getGrouping ++ Array(rel.inputTimeFieldIndex) ++ rel.getAuxGrouping (fullGrouping, rel.getAggCallList) @@ -79,9 +79,9 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Returns whether the given aggCall is supported now - * TODO supports more - */ + * Returns whether the given aggCall is supported now + * TODO supports more + */ def isSupportedAggCall(aggCall: AggregateCall): Boolean = { aggCall.getAggregation.getKind match { case SqlKind.SUM | SqlKind.MAX | SqlKind.MIN | SqlKind.AVG => true @@ -91,8 +91,8 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Gets aggCall's interval through its argument's interval. - */ + * Gets aggCall's interval through its argument's interval. + */ def getAggCallInterval(aggCall: AggregateCall): ValueInterval = { val aggInput = agg.getInput(0) @@ -159,12 +159,12 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Returns a percentage of rows meeting a filter predicate on aggregate. - * - * @param predicate predicate whose selectivity is to be estimated against aggregate calls. - * @return estimated selectivity (between 0.0 and 1.0), - * or None if no reliable estimate can be determined. - */ + * Returns a percentage of rows meeting a filter predicate on aggregate. + * + * @param predicate predicate whose selectivity is to be estimated against aggregate calls. + * @return estimated selectivity (between 0.0 and 1.0), + * or None if no reliable estimate can be determined. 
+ */ def evaluate(predicate: RexNode): Option[Double] = { try { if (predicate == null) { @@ -213,12 +213,12 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Returns a percentage of rows meeting a single condition in Filter node. - * - * @param singlePredicate predicate whose selectivity is to be estimated against aggregate calls. - * @return an optional double value to show the percentage of rows meeting a given condition. - * It returns None if the condition is not supported. - */ + * Returns a percentage of rows meeting a single condition in Filter node. + * + * @param singlePredicate predicate whose selectivity is to be estimated against aggregate calls. + * @return an optional double value to show the percentage of rows meeting a given condition. + * It returns None if the condition is not supported. + */ private def estimateSinglePredicate(singlePredicate: RexCall): Option[Double] = { val operands = singlePredicate.getOperands singlePredicate.getOperator match { @@ -250,14 +250,14 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Returns a percentage of rows meeting a binary comparison expression containing two columns. - * - * @param op a binary comparison operator, including =, <=>, <, <=, >, >= - * @param left the left RexInputRef - * @param right the right RexInputRef - * @return an optional double value to show the percentage of rows meeting a given condition. - * It returns None if no statistics collected for a given column. - */ + * Returns a percentage of rows meeting a binary comparison expression containing two columns. + * + * @param op a binary comparison operator, including =, <=>, <, <=, >, >= + * @param left the left RexInputRef + * @param right the right RexInputRef + * @return an optional double value to show the percentage of rows meeting a given condition. + * It returns None if no statistics collected for a given column. + */ private def estimateComparison(op: SqlOperator, left: RexNode, right: RexNode): Option[Double] = { // if we can't handle some cases, uses SelectivityEstimator's default value // (consistent with normal case). @@ -302,14 +302,14 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Returns a percentage of rows meeting an equality (=) expression. - * e.g. count(a) = 10 - * - * @param inputRef a RexInputRef - * @param literal a literal value (or constant) - * @return an optional double value to show the percentage of rows meeting a given condition. - * It returns None if no statistics collected for a given column. - */ + * Returns a percentage of rows meeting an equality (=) expression. + * e.g. count(a) = 10 + * + * @param inputRef a RexInputRef + * @param literal a literal value (or constant) + * @return an optional double value to show the percentage of rows meeting a given condition. + * It returns None if no statistics collected for a given column. + */ private def estimateEquals(inputRef: RexInputRef, literal: RexLiteral): Option[Double] = { if (literal.isNull) { return se.defaultIsNullSelectivity @@ -345,15 +345,15 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Returns a percentage of rows meeting a binary comparison expression. - * e.g. 
sum(a) > 10 - * - * @param op a binary comparison operator, including <, <=, >, >= - * @param inputRef a RexInputRef - * @param literal a literal value (or constant) - * @return an optional double value to show the percentage of rows meeting a given condition. - * It returns None if no statistics collected for a given column. - */ + * Returns a percentage of rows meeting a binary comparison expression. + * e.g. sum(a) > 10 + * + * @param op a binary comparison operator, including <, <=, >, >= + * @param inputRef a RexInputRef + * @param literal a literal value (or constant) + * @return an optional double value to show the percentage of rows meeting a given condition. + * It returns None if no statistics collected for a given column. + */ private def estimateComparison( op: SqlOperator, inputRef: RexInputRef, @@ -372,15 +372,15 @@ class AggCallSelectivityEstimator(agg: RelNode, mq: FlinkRelMetadataQuery) } /** - * Returns a percentage of rows meeting a binary numeric comparison expression. - * This method evaluate expression for Numeric/Boolean/Date/Time/Timestamp columns. - * - * @param op a binary comparison operator, including <, <=, >, >= - * @param aggCall an AggregateCall - * @param literal a literal value (or constant) - * @return an optional double value to show the percentage of rows meeting a given condition. - * It returns None if no statistics collected for a given column. - */ + * Returns a percentage of rows meeting a binary numeric comparison expression. + * This method evaluate expression for Numeric/Boolean/Date/Time/Timestamp columns. + * + * @param op a binary comparison operator, including <, <=, >, >= + * @param aggCall an AggregateCall + * @param literal a literal value (or constant) + * @return an optional double value to show the percentage of rows meeting a given condition. + * It returns None if no statistics collected for a given column. + */ private def estimateNumericComparison( op: SqlOperator, aggCall: AggregateCall, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala index 3bb4139..03f6f6f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala @@ -45,21 +45,21 @@ import java.math.{BigDecimal => JBigDecimal} import scala.collection.JavaConversions._ /** - * FlinkRelMdColumnInterval supplies a default implementation of - * [[FlinkRelMetadataQuery.getColumnInterval]] for the standard logical algebra. - */ + * FlinkRelMdColumnInterval supplies a default implementation of + * [[FlinkRelMetadataQuery.getColumnInterval]] for the standard logical algebra. + */ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { override def getDef: MetadataDef[ColumnInterval] = FlinkMetadata.ColumnInterval.DEF /** - * Gets interval of the given column on TableScan. - * - * @param ts TableScan RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on TableScan - */ + * Gets interval of the given column on TableScan. 
+ * + * @param ts TableScan RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on TableScan + */ def getColumnInterval(ts: TableScan, mq: RelMetadataQuery, index: Int): ValueInterval = { val relOptTable = ts.getTable.asInstanceOf[FlinkPreparingTableBase] val fieldNames = relOptTable.getRowType.getFieldNames @@ -105,13 +105,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Values. - * - * @param values Values RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Values - */ + * Gets interval of the given column on Values. + * + * @param values Values RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Values + */ def getColumnInterval(values: Values, mq: RelMetadataQuery, index: Int): ValueInterval = { val tuples = values.tuples if (tuples.isEmpty) { @@ -129,25 +129,25 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Snapshot. - * - * @param snapshot Snapshot RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Snapshot. - */ + * Gets interval of the given column on Snapshot. + * + * @param snapshot Snapshot RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Snapshot. + */ def getColumnInterval(snapshot: Snapshot, mq: RelMetadataQuery, index: Int): ValueInterval = null /** - * Gets interval of the given column on Project. - * - * Note: Only support the simple RexNode, e.g RexInputRef. - * - * @param project Project RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Project - */ + * Gets interval of the given column on Project. + * + * Note: Only support the simple RexNode, e.g RexInputRef. + * + * @param project Project RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Project + */ def getColumnInterval(project: Project, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val projects = project.getProjects @@ -168,13 +168,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Filter. - * - * @param filter Filter RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Filter - */ + * Gets interval of the given column on Filter. + * + * @param filter Filter RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Filter + */ def getColumnInterval(filter: Filter, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val inputValueInterval = fmq.getColumnInterval(filter.getInput, index) @@ -189,13 +189,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Calc. 
- * - * @param calc Filter RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Calc - */ + * Gets interval of the given column on Calc. + * + * @param calc Filter RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Calc + */ def getColumnInterval(calc: Calc, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val rexProgram = calc.getProgram @@ -204,10 +204,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Calculate interval of column which results from the given rex node in calc. - * Note that this function is called by function above, and is reclusive in case - * of "AS" rex call, and is private, too. - */ + * Calculate interval of column which results from the given rex node in calc. + * Note that this function is called by function above, and is reclusive in case + * of "AS" rex call, and is private, too. + */ private def getColumnIntervalOfCalc( calc: Calc, mq: RelMetadataQuery, @@ -295,39 +295,39 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Exchange. - * - * @param exchange Exchange RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Exchange - */ + * Gets interval of the given column on Exchange. + * + * @param exchange Exchange RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Exchange + */ def getColumnInterval(exchange: Exchange, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) fmq.getColumnInterval(exchange.getInput, index) } /** - * Gets interval of the given column on Sort. - * - * @param sort Sort RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Sort - */ + * Gets interval of the given column on Sort. + * + * @param sort Sort RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Sort + */ def getColumnInterval(sort: Sort, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) fmq.getColumnInterval(sort.getInput, index) } /** - * Gets interval of the given column of Expand. - * - * @param expand expand RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column in batch sort - */ + * Gets interval of the given column of Expand. + * + * @param expand expand RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column in batch sort + */ def getColumnInterval( expand: Expand, mq: RelMetadataQuery, @@ -355,13 +355,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Rank. - * - * @param rank [[Rank]] instance to analyze - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Rank - */ + * Gets interval of the given column on Rank. 
+ * + * @param rank [[Rank]] instance to analyze + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Rank + */ def getColumnInterval( rank: Rank, mq: RelMetadataQuery, @@ -387,24 +387,24 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Aggregates. - * - * @param aggregate Aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Aggregate - */ + * Gets interval of the given column on Aggregates. + * + * @param aggregate Aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Aggregate + */ def getColumnInterval(aggregate: Aggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) /** - * Gets interval of the given column on TableAggregates. - * - * @param aggregate TableAggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on TableAggregate - */ + * Gets interval of the given column on TableAggregates. + * + * @param aggregate TableAggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on TableAggregate + */ def getColumnInterval( aggregate: TableAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = @@ -412,121 +412,121 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { estimateColumnIntervalOfAggregate(aggregate, mq, index) /** - * Gets interval of the given column on batch group aggregate. - * - * @param aggregate batch group aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on batch group aggregate - */ + * Gets interval of the given column on batch group aggregate. + * + * @param aggregate batch group aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on batch group aggregate + */ def getColumnInterval( - aggregate: BatchExecGroupAggregateBase, + aggregate: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) /** - * Gets interval of the given column on stream group aggregate. - * - * @param aggregate stream group aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on stream group Aggregate - */ + * Gets interval of the given column on stream group aggregate. + * + * @param aggregate stream group aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on stream group Aggregate + */ def getColumnInterval( aggregate: StreamPhysicalGroupAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) /** - * Gets interval of the given column on stream group table aggregate. 
- * - * @param aggregate stream group table aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on stream group TableAggregate - */ + * Gets interval of the given column on stream group table aggregate. + * + * @param aggregate stream group table aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on stream group TableAggregate + */ def getColumnInterval( - aggregate: StreamPhysicalGroupTableAggregate, - mq: RelMetadataQuery, - index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) + aggregate: StreamPhysicalGroupTableAggregate, + mq: RelMetadataQuery, + index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) /** - * Gets interval of the given column on stream local group aggregate. - * - * @param aggregate stream local group aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on stream local group Aggregate - */ + * Gets interval of the given column on stream local group aggregate. + * + * @param aggregate stream local group aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on stream local group Aggregate + */ def getColumnInterval( aggregate: StreamPhysicalLocalGroupAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) /** - * Gets interval of the given column on stream global group aggregate. - * - * @param aggregate stream global group aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on stream global group Aggregate - */ + * Gets interval of the given column on stream global group aggregate. + * + * @param aggregate stream global group aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on stream global group Aggregate + */ def getColumnInterval( aggregate: StreamPhysicalGlobalGroupAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index) /** - * Gets interval of the given column on window aggregate. - * - * @param agg window aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on window Aggregate - */ + * Gets interval of the given column on window aggregate. + * + * @param agg window aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on window Aggregate + */ def getColumnInterval( agg: WindowAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index) /** - * Gets interval of the given column on batch window aggregate. - * - * @param agg batch window aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on batch window Aggregate - */ + * Gets interval of the given column on batch window aggregate. 
+ * + * @param agg batch window aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on batch window Aggregate + */ def getColumnInterval( agg: BatchExecWindowAggregateBase, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index) /** - * Gets interval of the given column on stream window aggregate. - * - * @param agg stream window aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on stream window Aggregate - */ + * Gets interval of the given column on stream window aggregate. + * + * @param agg stream window aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on stream window Aggregate + */ def getColumnInterval( agg: StreamExecGroupWindowAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index) /** - * Gets interval of the given column on stream window table aggregate. - * - * @param agg stream window table aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on stream window Aggregate - */ + * Gets interval of the given column on stream window table aggregate. + * + * @param agg stream window table aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on stream window Aggregate + */ def getColumnInterval( - agg: StreamExecGroupWindowTableAggregate, - mq: RelMetadataQuery, - index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index) + agg: StreamExecGroupWindowTableAggregate, + mq: RelMetadataQuery, + index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index) private def estimateColumnIntervalOfAggregate( aggregate: SingleRel, @@ -540,7 +540,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { case agg: StreamPhysicalGlobalGroupAggregate => agg.grouping case agg: StreamPhysicalIncrementalGroupAggregate => agg.partialAggGrouping case agg: StreamExecGroupWindowAggregate => agg.getGrouping - case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping + case agg: BatchPhysicalGroupAggregateBase => agg.grouping ++ agg.auxGrouping case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg) case agg: BatchExecLocalSortWindowAggregate => // grouping + assignTs + auxGrouping @@ -633,7 +633,7 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } else { null } - case agg: BatchExecGroupAggregateBase if agg.getAggCallList.length > aggCallIndex => + case agg: BatchPhysicalGroupAggregateBase if agg.getAggCallList.length > aggCallIndex => agg.getAggCallList(aggCallIndex) case agg: Aggregate => val (_, aggCalls) = AggregateUtil.checkAndSplitAggCalls(agg) @@ -683,13 +683,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on calcite window. - * - * @param window Window RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on window - */ + * Gets interval of the given column on calcite window. 
+ * + * @param window Window RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on window + */ def getColumnInterval( window: Window, mq: RelMetadataQuery, @@ -698,26 +698,26 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on batch over aggregate. - * - * @param agg batch over aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index he index of the given column - * @return interval of the given column on batch over aggregate. - */ + * Gets interval of the given column on batch over aggregate. + * + * @param agg batch over aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index he index of the given column + * @return interval of the given column on batch over aggregate. + */ def getColumnInterval( agg: BatchExecOverAggregate, mq: RelMetadataQuery, index: Int): ValueInterval = getColumnIntervalOfOverAgg(agg, mq, index) /** - * Gets interval of the given column on stream over aggregate. - * - * @param agg stream over aggregate RelNode - * @param mq RelMetadataQuery instance - * @param index he index of the given column - * @return interval of the given column on stream over aggregate. - */ + * Gets interval of the given column on stream over aggregate. + * + * @param agg stream over aggregate RelNode + * @param mq RelMetadataQuery instance + * @param index he index of the given column + * @return interval of the given column on stream over aggregate. + */ def getColumnInterval( agg: StreamExecOverAggregate, mq: RelMetadataQuery, @@ -739,13 +739,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Join. - * - * @param join Join RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Join - */ + * Gets interval of the given column on Join. + * + * @param join Join RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Join + */ def getColumnInterval(join: Join, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val joinCondition = join.getCondition @@ -771,13 +771,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on Union. - * - * @param union Union RelNode - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return interval of the given column on Union - */ + * Gets interval of the given column on Union. + * + * @param union Union RelNode + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return interval of the given column on Union + */ def getColumnInterval(union: Union, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val subIntervals = union @@ -787,13 +787,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Gets interval of the given column on RelSubset. - * - * @param subset RelSubset to analyze - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return If exist best relNode, then transmit to it, else transmit to the original relNode - */ + * Gets interval of the given column on RelSubset. 
+ * + * @param subset RelSubset to analyze + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return If exist best relNode, then transmit to it, else transmit to the original relNode + */ def getColumnInterval(subset: RelSubset, mq: RelMetadataQuery, index: Int): ValueInterval = { val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq) val rel = Util.first(subset.getBest, subset.getOriginal) @@ -801,13 +801,13 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] { } /** - * Catches-all rule when none of the others apply. - * - * @param rel RelNode to analyze - * @param mq RelMetadataQuery instance - * @param index the index of the given column - * @return Always returns null - */ + * Catches-all rule when none of the others apply. + * + * @param rel RelNode to analyze + * @param mq RelMetadataQuery instance + * @param index the index of the given column + * @return Always returns null + */ def getColumnInterval(rel: RelNode, mq: RelMetadataQuery, index: Int): ValueInterval = null } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala index a111e79..6e5dd3a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniqueness.scala @@ -329,12 +329,12 @@ class FlinkRelMdColumnUniqueness private extends MetadataHandler[BuiltInMetadata } def areColumnsUnique( - rel: BatchExecGroupAggregateBase, + rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, columns: ImmutableBitSet, ignoreNulls: Boolean): JBoolean = { if (rel.isFinal) { - areColumnsUniqueOnAggregate(rel.getGrouping, mq, columns, ignoreNulls) + areColumnsUniqueOnAggregate(rel.grouping, mq, columns, ignoreNulls) } else { null } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCount.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCount.scala index 99e290c..ee52362 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCount.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCount.scala @@ -309,7 +309,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata } def getDistinctRowCount( - rel: BatchExecGroupAggregateBase, + rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, groupKey: ImmutableBitSet, predicate: RexNode): JDouble = { @@ -397,7 +397,7 @@ class FlinkRelMdDistinctRowCount private extends MetadataHandler[BuiltInMetadata predicate: RexNode): (Option[RexNode], Option[RexNode]) = agg match { case rel: Aggregate => FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate) - case rel: BatchExecGroupAggregateBase => + case rel: BatchPhysicalGroupAggregateBase => FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate) case rel: BatchExecWindowAggregateBase => FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate) diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala index b3821cb..64d7efb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnInterval.scala @@ -19,7 +19,7 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.FilteredColumnInterval import org.apache.flink.table.planner.plan.nodes.calcite.TableAggregate -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase import org.apache.flink.table.planner.plan.nodes.physical.stream.{StreamExecGroupWindowAggregate, StreamExecGroupWindowTableAggregate, StreamPhysicalGlobalGroupAggregate, StreamPhysicalGroupAggregate, StreamPhysicalGroupTableAggregate, StreamPhysicalLocalGroupAggregate} import org.apache.flink.table.planner.plan.stats.ValueInterval import org.apache.flink.table.planner.plan.utils.ColumnIntervalUtil @@ -176,7 +176,7 @@ class FlinkRelMdFilteredColumnInterval private extends MetadataHandler[FilteredC } def getFilteredColumnInterval( - aggregate: BatchExecGroupAggregateBase, + aggregate: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, columnIndex: Int, filterArg: Int): ValueInterval = { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala index 7f128ec..fb1493a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicity.scala @@ -25,7 +25,7 @@ import org.apache.flink.table.planner.plan.`trait`.RelModifiedMonotonicity import org.apache.flink.table.planner.plan.metadata.FlinkMetadata.ModifiedMonotonicity import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, TableAggregate, WindowAggregate, WindowTableAggregate} import org.apache.flink.table.planner.plan.nodes.logical._ -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchPhysicalCorrelate} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalCorrelate, BatchPhysicalGroupAggregateBase} import org.apache.flink.table.planner.plan.nodes.physical.stream._ import org.apache.flink.table.planner.plan.schema.{FlinkPreparingTableBase, TableSourceTable} import org.apache.flink.table.planner.plan.stats.{WithLower, WithUpper} @@ -51,9 +51,9 @@ import java.util.Collections import scala.collection.JavaConversions._ /** - * FlinkRelMdModifiedMonotonicity supplies a default implementation of - * [[FlinkRelMetadataQuery#getRelModifiedMonotonicity]] for logical algebra. - */ + * FlinkRelMdModifiedMonotonicity supplies a default implementation of + * [[FlinkRelMetadataQuery#getRelModifiedMonotonicity]] for logical algebra. 
+ */ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMonotonicity] { override def getDef: MetadataDef[ModifiedMonotonicity] = FlinkMetadata.ModifiedMonotonicity.DEF @@ -239,8 +239,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity( - rel: StreamPhysicalMiniBatchAssigner, - mq: RelMetadataQuery): RelModifiedMonotonicity = { + rel: StreamPhysicalMiniBatchAssigner, + mq: RelMetadataQuery): RelModifiedMonotonicity = { getMonotonicity(rel.getInput, mq, rel.getRowType.getFieldCount) } @@ -256,8 +256,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity( - rel: WindowTableAggregate, - mq: RelMetadataQuery): RelModifiedMonotonicity = { + rel: WindowTableAggregate, + mq: RelMetadataQuery): RelModifiedMonotonicity = { if (allAppend(mq, rel.getInput)) { constants(rel.getRowType.getFieldCount) } else { @@ -272,7 +272,7 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity( - rel: BatchExecGroupAggregateBase, + rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery): RelModifiedMonotonicity = null def getRelModifiedMonotonicity( @@ -324,8 +324,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } def getRelModifiedMonotonicity( - rel: StreamExecGroupWindowTableAggregate, - mq: RelMetadataQuery): RelModifiedMonotonicity = { + rel: StreamExecGroupWindowTableAggregate, + mq: RelMetadataQuery): RelModifiedMonotonicity = { if (allAppend(mq, rel.getInput)) { constants(rel.getRowType.getFieldCount) } else { @@ -546,9 +546,9 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon def getRelModifiedMonotonicity(rel: RelNode, mq: RelMetadataQuery): RelModifiedMonotonicity = null /** - * Utility to create a RelModifiedMonotonicity which all fields is modified constant which - * means all the field's value will not be modified. - */ + * Utility to create a RelModifiedMonotonicity which all fields is modified constant which + * means all the field's value will not be modified. 
+ */ def constants(fieldCount: Int): RelModifiedMonotonicity = { new RelModifiedMonotonicity(Array.fill(fieldCount)(CONSTANT)) } @@ -558,8 +558,8 @@ class FlinkRelMdModifiedMonotonicity private extends MetadataHandler[ModifiedMon } /** - * These operator won't generate update itself - */ + * These operator won't generate update itself + */ def getMonotonicity( input: RelNode, mq: RelMetadataQuery, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPercentageOriginalRows.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPercentageOriginalRows.scala index ae5ae6c..65a97cf 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPercentageOriginalRows.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPercentageOriginalRows.scala @@ -20,7 +20,7 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.JDouble import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank} -import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchExecGroupAggregateBase +import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase import org.apache.calcite.plan.volcano.RelSubset import org.apache.calcite.rel.RelNode @@ -51,7 +51,9 @@ class FlinkRelMdPercentageOriginalRows private def getPercentageOriginalRows(rel: Aggregate, mq: RelMetadataQuery): JDouble = mq.getPercentageOriginalRows(rel.getInput) - def getPercentageOriginalRows(rel: BatchExecGroupAggregateBase, mq: RelMetadataQuery): JDouble = { + def getPercentageOriginalRows( + rel: BatchPhysicalGroupAggregateBase, + mq: RelMetadataQuery): JDouble = { mq.getPercentageOriginalRows(rel.getInput) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSize.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSize.scala index e61c94b2..ac37076 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSize.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSize.scala @@ -218,7 +218,7 @@ class FlinkRelMdPopulationSize private extends MetadataHandler[BuiltInMetadata.P } def getPopulationSize( - rel: BatchExecGroupAggregateBase, + rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, groupKey: ImmutableBitSet): JDouble = { // for global agg which has inner local agg, it passes the parameters to input directly diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCount.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCount.scala index 07c0668..32328c5 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCount.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCount.scala @@ -138,15 +138,15 @@ class FlinkRelMdRowCount private extends MetadataHandler[BuiltInMetadata.RowCoun } } - def getRowCount(rel: BatchExecGroupAggregateBase, mq: 
RelMetadataQuery): JDouble = { + def getRowCount(rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery): JDouble = { getRowCountOfBatchExecAgg(rel, mq) } private def getRowCountOfBatchExecAgg(rel: SingleRel, mq: RelMetadataQuery): JDouble = { val input = rel.getInput val (grouping, isFinal, isMerge) = rel match { - case agg: BatchExecGroupAggregateBase => - (ImmutableBitSet.of(agg.getGrouping: _*), agg.isFinal, agg.isMerge) + case agg: BatchPhysicalGroupAggregateBase => + (ImmutableBitSet.of(agg.grouping: _*), agg.isFinal, agg.isMerge) case windowAgg: BatchExecWindowAggregateBase => (ImmutableBitSet.of(windowAgg.getGrouping: _*), windowAgg.isFinal, windowAgg.isMerge) case _ => throw new IllegalArgumentException(s"Unknown aggregate type ${rel.getRelTypeName}!") diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivity.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivity.scala index ccd729f..bc30bd0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivity.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivity.scala @@ -97,7 +97,7 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele predicate: RexNode): JDouble = getSelectivityOfAgg(rel, mq, predicate) def getSelectivity( - rel: BatchExecGroupAggregateBase, + rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, predicate: RexNode): JDouble = getSelectivityOfAgg(rel, mq, predicate) @@ -130,7 +130,7 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele } else { val hasLocalAgg = agg match { case _: Aggregate => false - case rel: BatchExecGroupAggregateBase => rel.isFinal && rel.isMerge + case rel: BatchPhysicalGroupAggregateBase => rel.isFinal && rel.isMerge case rel: BatchExecWindowAggregateBase => rel.isFinal && rel.isMerge case _ => throw new IllegalArgumentException(s"Cannot handle ${agg.getRelTypeName}!") } @@ -147,7 +147,7 @@ class FlinkRelMdSelectivity private extends MetadataHandler[BuiltInMetadata.Sele val (childPred, restPred) = agg match { case rel: Aggregate => FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate) - case rel: BatchExecGroupAggregateBase => + case rel: BatchPhysicalGroupAggregateBase => FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate) case rel: BatchExecWindowAggregateBase => FlinkRelMdUtil.splitPredicateOnAggregate(rel, predicate) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSize.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSize.scala index c4debe0..024497f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSize.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSize.scala @@ -187,11 +187,13 @@ class FlinkRelMdSize private extends MetadataHandler[BuiltInMetadata.Size] { sizesBuilder.build } - def averageColumnSizes(rel: BatchExecGroupAggregateBase, mq: RelMetadataQuery): JList[JDouble] = { + def averageColumnSizes( + rel: BatchPhysicalGroupAggregateBase, + mq: RelMetadataQuery): JList[JDouble] = { // note: the logical to estimate column sizes of 
AggregateBatchExecBase is different from // Calcite Aggregate because AggregateBatchExecBase's rowTypes is not composed by // grouping columns + aggFunctionCall results - val mapInputToOutput = (rel.getGrouping ++ rel.getAuxGrouping).zipWithIndex.toMap + val mapInputToOutput = (rel.grouping ++ rel.auxGrouping).zipWithIndex.toMap getColumnSizesFromInputOrType(rel, mq, mapInputToOutput) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroups.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroups.scala index 8a06bbb..8733d17 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroups.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroups.scala @@ -211,10 +211,10 @@ class FlinkRelMdUniqueGroups private extends MetadataHandler[UniqueGroups] { } def getUniqueGroups( - agg: BatchExecGroupAggregateBase, + agg: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, columns: ImmutableBitSet): ImmutableBitSet = { - val grouping = agg.getGrouping + val grouping = agg.grouping getUniqueGroupsOfAggregate(agg.getRowType.getFieldCount, grouping, agg.getInput, mq, columns) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala index 5962901..ffcad6d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeys.scala @@ -84,7 +84,7 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu columns.indexOf(c) } val builder = ImmutableSet.builder[ImmutableBitSet]() - builder.add(ImmutableBitSet.of(columnIndices:_*)) + builder.add(ImmutableBitSet.of(columnIndices: _*)) val uniqueSet = sourceTable.uniqueKeysSet().orElse(null) if (uniqueSet != null) { builder.addAll(uniqueSet) @@ -198,8 +198,8 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu } /** - * Whether the [[RexCall]] is a cast that doesn't lose any information. - */ + * Whether the [[RexCall]] is a cast that doesn't lose any information. 
+ */ private def isFidelityCast(call: RexCall): Boolean = { if (call.getKind != SqlKind.CAST) { return false @@ -334,11 +334,11 @@ class FlinkRelMdUniqueKeys private extends MetadataHandler[BuiltInMetadata.Uniqu } def getUniqueKeys( - rel: BatchExecGroupAggregateBase, + rel: BatchPhysicalGroupAggregateBase, mq: RelMetadataQuery, ignoreNulls: Boolean): JSet[ImmutableBitSet] = { if (rel.isFinal) { - getUniqueKeysOnAggregate(rel.getGrouping, mq, ignoreNulls) + getUniqueKeysOnAggregate(rel.grouping, mq, ignoreNulls) } else { null } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregate.scala index 4fc9277..29e755e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregate.scala @@ -29,7 +29,6 @@ import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.{RelNode, RelWriter} -import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.{ImmutableIntList, Util} import java.util @@ -39,11 +38,10 @@ import scala.collection.JavaConversions._ /** * Batch physical RelNode for (global) hash-based aggregate operator. * - * @see [[BatchExecGroupAggregateBase]] for more info. + * @see [[BatchPhysicalGroupAggregateBase]] for more info. */ class BatchExecHashAggregate( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, @@ -55,7 +53,6 @@ class BatchExecHashAggregate( isMerge: Boolean) extends BatchExecHashAggregateBase( cluster, - relBuilder, traitSet, inputRel, outputRowType, @@ -70,7 +67,6 @@ class BatchExecHashAggregate( override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new BatchExecHashAggregate( cluster, - relBuilder, traitSet, inputs.get(0), outputRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala index 9400dbd..b5b0960 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecHashAggregateBase.scala @@ -40,17 +40,15 @@ import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.Util /** * Batch physical RelNode for hash-based aggregate operator. * - * @see [[BatchExecGroupAggregateBase]] for more info. + * @see [[BatchPhysicalGroupAggregateBase]] for more info. 
*/ abstract class BatchExecHashAggregateBase( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, @@ -61,13 +59,11 @@ abstract class BatchExecHashAggregateBase( aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], isMerge: Boolean, isFinal: Boolean) - extends BatchExecGroupAggregateBase( + extends BatchPhysicalGroupAggregateBase( cluster, - relBuilder, traitSet, inputRel, outputRowType, - inputRowType, grouping, auxGrouping, aggCallToAggFunction, @@ -114,17 +110,25 @@ abstract class BatchExecHashAggregateBase( val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val aggInfos = transformToBatchAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1)) + FlinkTypeFactory.toLogicalRowType(aggInputRowType), getAggCallList) var managedMemory: Long = 0L val generatedOperator = if (grouping.isEmpty) { AggWithoutKeysCodeGenerator.genWithoutKeys( - ctx, relBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping") + ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping") } else { managedMemory = MemorySize.parse(config.getConfiguration.getString( ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY)).getBytes new HashAggCodeGenerator( - ctx, relBuilder, aggInfos, inputType, outputType, grouping, auxGrouping, isMerge, isFinal + ctx, + planner.getRelBuilder, + aggInfos, + inputType, + outputType, + grouping, + auxGrouping, + isMerge, + isFinal ).genWithKeys() } val operator = new CodeGenOperatorFactory[RowData](generatedOperator) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala index 7b55ae6..dc13264 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalHashAggregate.scala @@ -28,7 +28,6 @@ import org.apache.calcite.rel.RelDistribution.Type import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.{RelNode, RelWriter} -import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.ImmutableIntList import java.util @@ -36,13 +35,12 @@ import java.util import scala.collection.JavaConversions._ /** - * Batch physical RelNode for local hash-based aggregate operator. - * - * @see [[BatchExecGroupAggregateBase]] for more info. - */ + * Batch physical RelNode for local hash-based aggregate operator. + * + * @see [[BatchPhysicalGroupAggregateBase]] for more info. 
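The hash variant above reserves managed memory sized by ExecutionConfigOptions.TABLE_EXEC_RESOURCE_HASH_AGG_MEMORY. A sketch of setting that option from user code; the key string and value below are illustrative and should be checked against ExecutionConfigOptions of the actual release:

import org.apache.flink.table.api.{EnvironmentSettings, TableEnvironment}

object HashAggMemorySetup {
  def main(args: Array[String]): Unit = {
    val settings =
      EnvironmentSettings.newInstance().useBlinkPlanner().inBatchMode().build()
    val tEnv = TableEnvironment.create(settings)
    // Managed memory granted to each hash aggregate operator; the string is
    // parsed with MemorySize.parse, as in the hunk above.
    tEnv.getConfig.getConfiguration
      .setString("table.exec.resource.hash-agg.memory", "128 mb")
  }
}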
+ */ class BatchExecLocalHashAggregate( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, @@ -52,7 +50,6 @@ class BatchExecLocalHashAggregate( aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)]) extends BatchExecHashAggregateBase( cluster, - relBuilder, traitSet, inputRel, outputRowType, @@ -67,7 +64,6 @@ class BatchExecLocalHashAggregate( override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new BatchExecLocalHashAggregate( cluster, - relBuilder, traitSet, inputs.get(0), outputRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala index 4c6acc2..07ff2d0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecLocalSortAggregate.scala @@ -28,7 +28,6 @@ import org.apache.calcite.rel.RelDistribution.Type import org.apache.calcite.rel._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.ImmutableIntList import java.util @@ -36,13 +35,12 @@ import java.util import scala.collection.JavaConversions._ /** - * Batch physical RelNode for local sort-based aggregate operator. - * - * @see [[BatchExecGroupAggregateBase]] for more info. - */ + * Batch physical RelNode for local sort-based aggregate operator. + * + * @see [[BatchPhysicalGroupAggregateBase]] for more info. + */ class BatchExecLocalSortAggregate( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, @@ -52,7 +50,6 @@ class BatchExecLocalSortAggregate( aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)]) extends BatchExecSortAggregateBase( cluster, - relBuilder, traitSet, inputRel, outputRowType, @@ -67,7 +64,6 @@ class BatchExecLocalSortAggregate( override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new BatchExecLocalSortAggregate( cluster, - relBuilder, traitSet, inputs.get(0), outputRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala index 54ba618..ab90b90 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecPythonGroupAggregate.scala @@ -49,8 +49,8 @@ import java.util import scala.collection.JavaConversions._ /** - * Batch physical RelNode for aggregate (Python user defined aggregate function). - */ + * Batch physical RelNode for aggregate (Python user defined aggregate function). 
+ */ class BatchExecPythonGroupAggregate( cluster: RelOptCluster, traitSet: RelTraitSet, @@ -62,13 +62,11 @@ class BatchExecPythonGroupAggregate( auxGrouping: Array[Int], aggCalls: Seq[AggregateCall], aggFunctions: Array[UserDefinedFunction]) - extends BatchExecGroupAggregateBase( + extends BatchPhysicalGroupAggregateBase( cluster, - null, traitSet, inputRel, outputRowType, - inputRowType, grouping, auxGrouping, aggCalls.zip(aggFunctions), diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregate.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregate.scala index d4ff27d..58f61c5 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregate.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregate.scala @@ -29,7 +29,6 @@ import org.apache.calcite.rel.RelDistribution.Type.{HASH_DISTRIBUTED, SINGLETON} import org.apache.calcite.rel._ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall -import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.{ImmutableIntList, Util} import java.util @@ -37,13 +36,12 @@ import java.util import scala.collection.JavaConversions._ /** - * Batch physical RelNode for (global) sort-based aggregate operator. - * - * @see [[BatchExecGroupAggregateBase]] for more info. - */ + * Batch physical RelNode for (global) sort-based aggregate operator. + * + * @see [[BatchPhysicalGroupAggregateBase]] for more info. + */ class BatchExecSortAggregate( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, @@ -55,7 +53,6 @@ class BatchExecSortAggregate( isMerge: Boolean) extends BatchExecSortAggregateBase( cluster, - relBuilder, traitSet, inputRel, outputRowType, @@ -70,7 +67,6 @@ class BatchExecSortAggregate( override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = { new BatchExecSortAggregate( cluster, - relBuilder, traitSet, inputs.get(0), outputRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala index 12a617f..3528d8f 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecSortAggregateBase.scala @@ -36,16 +36,14 @@ import org.apache.calcite.rel.RelNode import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rel.metadata.RelMetadataQuery -import org.apache.calcite.tools.RelBuilder /** - * Batch physical RelNode for sort-based aggregate operator. - * - * @see [[BatchExecGroupAggregateBase]] for more info. - */ + * Batch physical RelNode for sort-based aggregate operator. + * + * @see [[BatchPhysicalGroupAggregateBase]] for more info. 
+ */ abstract class BatchExecSortAggregateBase( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, @@ -56,13 +54,11 @@ abstract class BatchExecSortAggregateBase( aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], isMerge: Boolean, isFinal: Boolean) - extends BatchExecGroupAggregateBase( + extends BatchPhysicalGroupAggregateBase( cluster, - relBuilder, traitSet, inputRel, outputRowType, - inputRowType, grouping, auxGrouping, aggCallToAggFunction, @@ -95,14 +91,22 @@ abstract class BatchExecSortAggregateBase( val inputType = FlinkTypeFactory.toLogicalRowType(inputRowType) val aggInfos = transformToBatchAggregateInfoList( - FlinkTypeFactory.toLogicalRowType(aggInputRowType), aggCallToAggFunction.map(_._1)) + FlinkTypeFactory.toLogicalRowType(aggInputRowType), getAggCallList) val generatedOperator = if (grouping.isEmpty) { AggWithoutKeysCodeGenerator.genWithoutKeys( - ctx, relBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping") + ctx, planner.getRelBuilder, aggInfos, inputType, outputType, isMerge, isFinal, "NoGrouping") } else { SortAggCodeGenerator.genWithKeys( - ctx, relBuilder, aggInfos, inputType, outputType, grouping, auxGrouping, isMerge, isFinal) + ctx, + planner.getRelBuilder, + aggInfos, + inputType, + outputType, + grouping, + auxGrouping, + isMerge, + isFinal) } val operator = new CodeGenOperatorFactory[RowData](generatedOperator) ExecNodeUtil.createOneInputTransformation( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecGroupAggregateBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalGroupAggregateBase.scala similarity index 67% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecGroupAggregateBase.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalGroupAggregateBase.scala index 9bed773..33cf683 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchExecGroupAggregateBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalGroupAggregateBase.scala @@ -28,32 +28,29 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{Aggregate, AggregateCall} import org.apache.calcite.rel.{RelNode, SingleRel} -import org.apache.calcite.tools.RelBuilder /** - * Batch physical RelNode for aggregate. - * - * <P>There are two differences between this node and [[Aggregate]]: - * 1. This node supports two-stage aggregation to reduce data-shuffling: - * local-aggregation and global-aggregation. - * local-aggregation produces a partial result for each group before shuffle in stage 1, - * and then the partially aggregated results are shuffled to global-aggregation - * which produces the final result in stage 2. - * Two-stage aggregation is enabled only if all aggregate functions are mergeable. - * (e.g. SUM, AVG, MAX) - * 2. This node supports auxiliary group keys which will not be computed as key and - * does not also affect the correctness of the final result. 
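A concrete reading of the auxiliary-group-key point in this comment, assuming the usual provenance of auxGrouping (group keys functionally determined by a unique key among the grouping columns); the schema is illustrative:

object AuxGroupingSketch {
  def main(args: Array[String]): Unit = {
    // If `id` is unique, then for
    //   SELECT id, name, SUM(amount) FROM orders GROUP BY id, name
    // the planner can keep grouping = [id] and auxGrouping = [name]:
    // `name` is emitted with each group result but is never hashed,
    // sorted, or compared as part of the key.
    val grouping = Array(0)    // field index of id
    val auxGrouping = Array(1) // field index of name, carried through
    println(s"key fields: ${grouping.mkString(",")}; auxiliary: ${auxGrouping.mkString(",")}")
  }
}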
[[Aggregate]] does not distinguish - * group keys and auxiliary group keys, and combines them as a complete `groupSet`. - */ -abstract class BatchExecGroupAggregateBase( + * Batch physical RelNode for aggregate. + * + * <P>There are two differences between this node and [[Aggregate]]: + * 1. This node supports two-stage aggregation to reduce data-shuffling: + * local-aggregation and global-aggregation. + * local-aggregation produces a partial result for each group before shuffle in stage 1, + * and then the partially aggregated results are shuffled to global-aggregation + * which produces the final result in stage 2. + * Two-stage aggregation is enabled only if all aggregate functions are mergeable. + * (e.g. SUM, AVG, MAX) + * 2. This node supports auxiliary group keys which will not be computed as key and + * does not also affect the correctness of the final result. [[Aggregate]] does not distinguish + * group keys and auxiliary group keys, and combines them as a complete `groupSet`. + */ +abstract class BatchPhysicalGroupAggregateBase( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, inputRel: RelNode, outputRowType: RelDataType, - inputRowType: RelDataType, - grouping: Array[Int], - auxGrouping: Array[Int], + val grouping: Array[Int], + val auxGrouping: Array[Int], aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], val isMerge: Boolean, val isFinal: Boolean) @@ -66,10 +63,6 @@ abstract class BatchExecGroupAggregateBase( override def deriveRowType(): RelDataType = outputRowType - def getGrouping: Array[Int] = grouping - - def getAuxGrouping: Array[Int] = auxGrouping - def getAggCallList: Seq[AggregateCall] = aggCallToAggFunction.map(_._1) def getAggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)] = aggCallToAggFunction diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala index 1a28774..e5c9b97 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashAggRule.scala @@ -60,7 +60,7 @@ class BatchExecHashAggRule operand(classOf[FlinkLogicalAggregate], operand(classOf[RelNode], any)), "BatchExecHashAggRule") - with BatchExecAggRuleBase { + with BatchPhysicalAggRuleBase { override def matches(call: RelOptRuleCall): Boolean = { val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig @@ -100,7 +100,6 @@ class BatchExecHashAggRule val providedTraitSet = localRequiredTraitSet val localHashAgg = createLocalAgg( agg.getCluster, - call.builder(), providedTraitSet, newInput, agg.getRowType, @@ -136,7 +135,6 @@ class BatchExecHashAggRule val newLocalHashAgg = RelOptRule.convert(localHashAgg, requiredTraitSet) val globalHashAgg = new BatchExecHashAggregate( agg.getCluster, - call.builder(), aggProvidedTraitSet, newLocalHashAgg, agg.getRowType, @@ -167,7 +165,6 @@ class BatchExecHashAggRule val newInput = RelOptRule.convert(input, requiredTraitSet) val hashAgg = new BatchExecHashAggregate( agg.getCluster, - call.builder(), aggProvidedTraitSet, newInput, agg.getRowType, diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashJoinRule.scala index a5f0a33..7de1162 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashJoinRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecHashJoinRule.scala @@ -85,7 +85,7 @@ class BatchExecHashJoinRule val distinctKeys = 0 until join.getRight.getRowType.getFieldCount val useBuildDistinct = chooseSemiBuildDistinct(join.getRight, distinctKeys) if (useBuildDistinct) { - (addLocalDistinctAgg(join.getRight, distinctKeys, call.builder()), true) + (addLocalDistinctAgg(join.getRight, distinctKeys), true) } else { (join.getRight, false) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala index 636978b..9dead54 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecJoinRuleBase.scala @@ -27,7 +27,6 @@ import org.apache.flink.table.planner.plan.utils.{FlinkRelMdUtil, FlinkRelOptUti import org.apache.calcite.plan.RelOptRule import org.apache.calcite.rel.RelNode -import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.ImmutableBitSet import java.lang.{Boolean => JBoolean, Double => JDouble} @@ -36,15 +35,13 @@ trait BatchExecJoinRuleBase { def addLocalDistinctAgg( node: RelNode, - distinctKeys: Seq[Int], - relBuilder: RelBuilder): RelNode = { + distinctKeys: Seq[Int]): RelNode = { val localRequiredTraitSet = node.getTraitSet.replace(FlinkConventions.BATCH_PHYSICAL) val newInput = RelOptRule.convert(node, localRequiredTraitSet) val providedTraitSet = localRequiredTraitSet new BatchExecLocalHashAggregate( node.getCluster, - relBuilder, providedTraitSet, newInput, node.getRowType, // output row type diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala index 8b1fb5d..3e3a37e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecNestedLoopJoinRule.scala @@ -54,7 +54,7 @@ class BatchExecNestedLoopJoinRule val distinctKeys = 0 until join.getRight.getRowType.getFieldCount val useBuildDistinct = chooseSemiBuildDistinct(join.getRight, distinctKeys) if (useBuildDistinct) { - addLocalDistinctAgg(join.getRight, distinctKeys, call.builder()) + addLocalDistinctAgg(join.getRight, distinctKeys) } else { join.getRight } diff --git 
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala index 7bf6d16..6bb4437 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecSortAggRule.scala @@ -34,35 +34,35 @@ import org.apache.calcite.rel._ import scala.collection.JavaConversions._ /** - * Rule that converts [[FlinkLogicalAggregate]] to - * {{{ - * BatchExecSortAggregate (global) - * +- Sort (exists if group keys are not empty) - * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) - * +- BatchExecLocalSortAggregate (local) - * +- Sort (exists if group keys are not empty) - * +- input of agg - * }}} - * when all aggregate functions are mergeable - * and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or - * {{{ - * BatchExecSortAggregate - * +- Sort (exists if group keys are not empty) - * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) - * +- input of agg - * }}} - * when some aggregate functions are not mergeable - * or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE. - * - * Notes: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE, - * this rule will try to create two possibilities above, and chooses the best one based on cost. - */ + * Rule that converts [[FlinkLogicalAggregate]] to + * {{{ + * BatchExecSortAggregate (global) + * +- Sort (exists if group keys are not empty) + * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) + * +- BatchExecLocalSortAggregate (local) + * +- Sort (exists if group keys are not empty) + * +- input of agg + * }}} + * when all aggregate functions are mergeable + * and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or + * {{{ + * BatchExecSortAggregate + * +- Sort (exists if group keys are not empty) + * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) + * +- input of agg + * }}} + * when some aggregate functions are not mergeable + * or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE. + * + * Notes: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE, + * this rule will try to create two possibilities above, and chooses the best one based on cost. 
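The TWO_PHASE shape documented above hinges on every aggregate function being mergeable: the local stage emits partial results, and the global stage only merges partials after the shuffle. A minimal sketch of that contract for AVG, using a plain sum/count pair rather than the planner's generated accumulators:

object TwoPhaseAvgSketch {
  // Partial result produced by the local aggregate before the shuffle.
  final case class Partial(sum: Long, count: Long)

  def local(rows: Seq[Long]): Partial =
    Partial(rows.sum, rows.size.toLong)

  // The global aggregate only merges partials; it never re-reads raw rows.
  def global(partials: Seq[Partial]): Double = {
    val merged = partials.reduce((a, b) => Partial(a.sum + b.sum, a.count + b.count))
    merged.sum.toDouble / merged.count
  }

  def main(args: Array[String]): Unit = {
    // Two "workers" pre-aggregate their slices, then ship tiny partials.
    val p1 = local(Seq(1L, 2L, 3L))
    val p2 = local(Seq(10L, 20L))
    println(global(Seq(p1, p2))) // 7.2 = (6 + 30) / 5
  }
}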
+ */ class BatchExecSortAggRule extends RelOptRule( operand(classOf[FlinkLogicalAggregate], operand(classOf[RelNode], any)), "BatchExecSortAggRule") - with BatchExecAggRuleBase { + with BatchPhysicalAggRuleBase { override def matches(call: RelOptRuleCall): Boolean = { val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig @@ -99,7 +99,6 @@ class BatchExecSortAggRule val localSortAgg = createLocalAgg( agg.getCluster, - call.builder(), providedLocalTraitSet, newLocalInput, agg.getRowType, @@ -142,7 +141,6 @@ class BatchExecSortAggRule val newInputForFinalAgg = RelOptRule.convert(localSortAgg, requiredTraitSet) val globalSortAgg = new BatchExecSortAggregate( agg.getCluster, - call.builder(), aggProvidedTraitSet, newInputForFinalAgg, agg.getRowType, @@ -177,7 +175,6 @@ class BatchExecSortAggRule val newInput = RelOptRule.convert(input, requiredTraitSet) val sortAgg = new BatchExecSortAggregate( agg.getCluster, - call.builder(), aggProvidedTraitSet, newInput, agg.getRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala index 81151f2..1c7b2bb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecWindowAggregateRule.scala @@ -46,34 +46,34 @@ import org.apache.commons.math3.util.ArithmeticUtils import scala.collection.JavaConversions._ /** - * Rule to convert a [[FlinkLogicalWindowAggregate]] into a - * {{{ - * BatchExecHash(or Sort)WindowAggregate (global) - * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) - * +- BatchExecLocalHash(or Sort)WindowAggregate (local) - * +- input of window agg - * }}} - * when all aggregate functions are mergeable - * and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or - * {{{ - * BatchExecHash(or Sort)WindowAggregate - * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) - * +- input of window agg - * }}} - * when some aggregate functions are not mergeable - * or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE. - * - * Notes: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE, - * this rule will try to create two possibilities above, and chooses the best one based on cost. - * if all aggregate function buffer are fix length, the rule will choose hash window agg. 
- */ + * Rule to convert a [[FlinkLogicalWindowAggregate]] into a + * {{{ + * BatchExecHash(or Sort)WindowAggregate (global) + * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) + * +- BatchExecLocalHash(or Sort)WindowAggregate (local) + * +- input of window agg + * }}} + * when all aggregate functions are mergeable + * and [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is TWO_PHASE, or + * {{{ + * BatchExecHash(or Sort)WindowAggregate + * +- BatchPhysicalExchange (hash by group keys if group keys is not empty, else singleton) + * +- input of window agg + * }}} + * when some aggregate functions are not mergeable + * or [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is ONE_PHASE. + * + * Notes: if [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]] is NONE, + * this rule will try to create two possibilities above, and chooses the best one based on cost. + * if all aggregate function buffer are fix length, the rule will choose hash window agg. + */ class BatchExecWindowAggregateRule extends RelOptRule( operand(classOf[FlinkLogicalWindowAggregate], operand(classOf[RelNode], any)), FlinkRelFactories.LOGICAL_BUILDER_WITHOUT_AGG_INPUT_PRUNE, "BatchExecWindowAggregateRule") - with BatchExecAggRuleBase { + with BatchPhysicalAggRuleBase { override def matches(call: RelOptRuleCall): Boolean = { val agg: FlinkLogicalWindowAggregate = call.rel(0) @@ -346,11 +346,11 @@ class BatchExecWindowAggregateRule } /** - * Return true when sliding window with slideSize < windowSize && gcd(windowSize, slideSize) > 1. - * Otherwise return false, including the cases of tumbling window, - * sliding window with slideSize >= windowSize and - * sliding window with slideSize < windowSize but gcd(windowSize, slideSize) == 1. - */ + * Return true when sliding window with slideSize < windowSize && gcd(windowSize, slideSize) > 1. + * Otherwise return false, including the cases of tumbling window, + * sliding window with slideSize >= windowSize and + * sliding window with slideSize < windowSize but gcd(windowSize, slideSize) == 1. + */ private def useAssignPane( aggregateList: Array[UserDefinedFunction], windowSize: Long, @@ -360,12 +360,12 @@ class BatchExecWindowAggregateRule } /** - * In the case of sliding window without the optimization of assigning pane which means - * slideSize < windowSize && ArithmeticUtils.gcd(windowSize, slideSize) == 1, we will disable the - * local aggregate. - * Otherwise, we use the same way as the group aggregate to make the decision whether - * to use a local aggregate or not. - */ + * In the case of sliding window without the optimization of assigning pane which means + * slideSize < windowSize && ArithmeticUtils.gcd(windowSize, slideSize) == 1, we will disable the + * local aggregate. + * Otherwise, we use the same way as the group aggregate to make the decision whether + * to use a local aggregate or not. 
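The useAssignPane condition above is plain arithmetic: pre-aggregating per pane pays off only when the panes are coarser than one unit, i.e. when gcd(windowSize, slideSize) > 1. A sketch with abstract size units (the real rule compares millisecond sizes via ArithmeticUtils.gcd):

object PaneSizeSketch {
  @annotation.tailrec
  def gcd(a: Long, b: Long): Long = if (b == 0) a else gcd(b, a % b)

  def main(args: Array[String]): Unit = {
    // window = 10 units, slide = 4 units: pane size = 2, so each pane's
    // partial result is shared by overlapping windows -> assign-pane pays off.
    println(gcd(10L, 4L)) // 2
    // window = 10 units, slide = 3 units: gcd = 1, no common pane coarser
    // than one unit exists, so the pane optimization (and local agg) is skipped.
    println(gcd(10L, 3L)) // 1
  }
}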
+ */ private def supportLocalWindowAgg( call: RelOptRuleCall, tableConfig: TableConfig, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecAggRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala similarity index 96% rename from flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecAggRuleBase.scala rename to flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala index 7d3a190..d0a9aa9 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchExecAggRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/BatchPhysicalAggRuleBase.scala @@ -24,7 +24,7 @@ import org.apache.flink.table.planner.JArrayList import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._ -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashAggregate, BatchExecLocalSortAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalHashAggregate, BatchExecLocalSortAggregate, BatchPhysicalGroupAggregateBase} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.flink.table.planner.utils.AggregatePhaseStrategy import org.apache.flink.table.planner.utils.TableConfigUtils.getAggPhaseStrategy @@ -36,12 +36,11 @@ import org.apache.calcite.plan.{RelOptCluster, RelTraitSet} import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.{Aggregate, AggregateCall} import org.apache.calcite.rel.{RelCollation, RelCollations, RelFieldCollation, RelNode} -import org.apache.calcite.tools.RelBuilder import org.apache.calcite.util.Util import scala.collection.JavaConversions._ -trait BatchExecAggRuleBase { +trait BatchPhysicalAggRuleBase { protected def inferLocalAggType( inputRowType: RelDataType, @@ -185,7 +184,6 @@ trait BatchExecAggRuleBase { protected def createLocalAgg( cluster: RelOptCluster, - relBuilder: RelBuilder, traitSet: RelTraitSet, input: RelNode, originalAggRowType: RelDataType, @@ -193,7 +191,7 @@ trait BatchExecAggRuleBase { auxGrouping: Array[Int], aggBufferTypes: Array[Array[DataType]], aggCallToAggFunction: Seq[(AggregateCall, UserDefinedFunction)], - isLocalHashAgg: Boolean): BatchExecGroupAggregateBase = { + isLocalHashAgg: Boolean): BatchPhysicalGroupAggregateBase = { val inputRowType = input.getRowType val aggFunctions = aggCallToAggFunction.map(_._2).toArray @@ -213,7 +211,6 @@ trait BatchExecAggRuleBase { if (isLocalHashAgg) { new BatchExecLocalHashAggregate( cluster, - relBuilder, traitSet, input, localAggRowType, @@ -224,7 +221,6 @@ trait BatchExecAggRuleBase { } else { new BatchExecLocalSortAggregate( cluster, - relBuilder, traitSet, input, localAggRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala index 34a4de2..840374d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalAggRuleBase.scala @@ -22,31 +22,30 @@ import org.apache.flink.table.api.TableException import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.`trait`.FlinkRelDistribution import org.apache.flink.table.planner.plan.nodes.FlinkConventions -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchExecSortAggregate, BatchPhysicalExchange, BatchPhysicalExpand, BatchPhysicalGroupAggregateBase} import org.apache.flink.table.planner.plan.utils.{AggregateUtil, FlinkRelOptUtil} import org.apache.calcite.plan.{RelOptRule, RelOptRuleOperand} import org.apache.calcite.rel.RelNode import org.apache.calcite.rex.RexUtil -import org.apache.calcite.tools.RelBuilder import scala.collection.JavaConversions._ /** - * Planner rule that writes one phase aggregate to two phase aggregate, - * when the following conditions are met: - * 1. there is no local aggregate, - * 2. the aggregate has non-empty grouping and two phase aggregate strategy is enabled, - * 3. the input is [[BatchPhysicalExpand]] and there is at least one expand row - * which the columns for grouping are all constant. - */ + * Planner rule that rewrites a one-phase aggregate into a two-phase aggregate + * when the following conditions are met: + * 1. there is no local aggregate, + * 2. the aggregate has non-empty grouping and two phase aggregate strategy is enabled, + * 3. the input is [[BatchPhysicalExpand]] and there is at least one expand row + * in which the grouping columns are all constant.
+ */ abstract class EnforceLocalAggRuleBase( operand: RelOptRuleOperand, description: String) extends RelOptRule(operand, description) - with BatchExecAggRuleBase { + with BatchPhysicalAggRuleBase { - protected def isTwoPhaseAggEnabled(agg: BatchExecGroupAggregateBase): Boolean = { + protected def isTwoPhaseAggEnabled(agg: BatchPhysicalGroupAggregateBase): Boolean = { val tableConfig = FlinkRelOptUtil.getTableConfigFromContext(agg) val aggFunctions = agg.getAggCallToAggFunction.map(_._2).toArray isTwoPhaseAggWorkable(aggFunctions, tableConfig) @@ -64,14 +63,13 @@ abstract class EnforceLocalAggRuleBase( } protected def createLocalAgg( - completeAgg: BatchExecGroupAggregateBase, - input: RelNode, - relBuilder: RelBuilder): BatchExecGroupAggregateBase = { + completeAgg: BatchPhysicalGroupAggregateBase, + input: RelNode): BatchPhysicalGroupAggregateBase = { val cluster = completeAgg.getCluster val inputRowType = input.getRowType - val grouping = completeAgg.getGrouping - val auxGrouping = completeAgg.getAuxGrouping + val grouping = completeAgg.grouping + val auxGrouping = completeAgg.auxGrouping val aggCalls = completeAgg.getAggCallList val aggCallToAggFunction = completeAgg.getAggCallToAggFunction @@ -91,7 +89,6 @@ abstract class EnforceLocalAggRuleBase( createLocalAgg( cluster, - relBuilder, traitSet, input, completeAgg.getRowType, @@ -104,10 +101,10 @@ abstract class EnforceLocalAggRuleBase( } protected def createExchange( - completeAgg: BatchExecGroupAggregateBase, + completeAgg: BatchPhysicalGroupAggregateBase, input: RelNode): BatchPhysicalExchange = { val cluster = completeAgg.getCluster - val grouping = completeAgg.getGrouping + val grouping = completeAgg.grouping // local aggregate outputs group fields first, and then agg calls val distributionFields = grouping.indices.map(Integer.valueOf) @@ -121,11 +118,10 @@ abstract class EnforceLocalAggRuleBase( } protected def createGlobalAgg( - completeAgg: BatchExecGroupAggregateBase, - input: RelNode, - relBuilder: RelBuilder): BatchExecGroupAggregateBase = { - val grouping = completeAgg.getGrouping - val auxGrouping = completeAgg.getAuxGrouping + completeAgg: BatchPhysicalGroupAggregateBase, + input: RelNode): BatchPhysicalGroupAggregateBase = { + val grouping = completeAgg.grouping + val auxGrouping = completeAgg.auxGrouping val aggCallToAggFunction = completeAgg.getAggCallToAggFunction val (newGrouping, newAuxGrouping) = getGlobalAggGroupSetPair(grouping, auxGrouping) @@ -138,7 +134,6 @@ abstract class EnforceLocalAggRuleBase( case _: BatchExecHashAggregate => new BatchExecHashAggregate( completeAgg.getCluster, - relBuilder, completeAgg.getTraitSet, input, aggRowType, @@ -151,7 +146,6 @@ abstract class EnforceLocalAggRuleBase( case _: BatchExecSortAggregate => new BatchExecSortAggregate( completeAgg.getCluster, - relBuilder, completeAgg.getTraitSet, input, aggRowType, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRule.scala index d750336..7da39bb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRule.scala @@ -18,36 +18,36 @@ package 
org.apache.flink.table.planner.plan.rules.physical.batch -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalExchange, BatchPhysicalExpand, BatchExecHashAggregate} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecHashAggregate, BatchPhysicalExchange, BatchPhysicalExpand} import org.apache.calcite.plan.RelOptRule.{any, operand} import org.apache.calcite.plan.RelOptRuleCall /** - * An [[EnforceLocalAggRuleBase]] that matches [[BatchExecHashAggregate]] - * - * for example: select count(*) from t group by rollup (a, b) - * The physical plan - * - * {{{ - * HashAggregate(isMerge=[false], groupBy=[a, b, $e], select=[a, b, $e, COUNT(*)]) - * +- Exchange(distribution=[hash[a, b, $e]]) - * +- Expand(projects=[{a=[$0], b=[$1], $e=[0]}, - * {a=[$0], b=[null], $e=[1]}, - * {a=[null], b=[null], $e=[3]}]) - * }}} - * - * will be rewritten to - * - * {{{ - * HashAggregate(isMerge=[true], groupBy=[a, b, $e], select=[a, b, $e, Final_COUNT(count1$0)]) - * +- Exchange(distribution=[hash[a, b, $e]]) - * +- LocalHashAggregate(groupBy=[a, b, $e], select=[a, b, $e, Partial_COUNT(*) AS count1$0] - * +- Expand(projects=[{a=[$0], b=[$1], $e=[0]}, - * {a=[$0], b=[null], $e=[1]}, - * {a=[null], b=[null], $e=[3]}]) - * }}} - */ + * An [[EnforceLocalAggRuleBase]] that matches [[BatchExecHashAggregate]] + * + * for example: select count(*) from t group by rollup (a, b) + * The physical plan + * + * {{{ + * HashAggregate(isMerge=[false], groupBy=[a, b, $e], select=[a, b, $e, COUNT(*)]) + * +- Exchange(distribution=[hash[a, b, $e]]) + * +- Expand(projects=[{a=[$0], b=[$1], $e=[0]}, + * {a=[$0], b=[null], $e=[1]}, + * {a=[null], b=[null], $e=[3]}]) + * }}} + * + * will be rewritten to + * + * {{{ + * HashAggregate(isMerge=[true], groupBy=[a, b, $e], select=[a, b, $e, Final_COUNT(count1$0)]) + * +- Exchange(distribution=[hash[a, b, $e]]) + * +- LocalHashAggregate(groupBy=[a, b, $e], select=[a, b, $e, Partial_COUNT(*) AS count1$0] + * +- Expand(projects=[{a=[$0], b=[$1], $e=[0]}, + * {a=[$0], b=[null], $e=[1]}, + * {a=[null], b=[null], $e=[3]}]) + * }}} + */ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase( operand(classOf[BatchExecHashAggregate], operand(classOf[BatchPhysicalExchange], @@ -60,7 +60,7 @@ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase( val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg) - val grouping = agg.getGrouping + val grouping = agg.grouping val constantShuffleKey = hasConstantShuffleKey(grouping, expand) grouping.nonEmpty && enableTwoPhaseAgg && constantShuffleKey @@ -70,9 +70,9 @@ class EnforceLocalHashAggRule extends EnforceLocalAggRuleBase( val agg: BatchExecHashAggregate = call.rel(0) val expand: BatchPhysicalExpand = call.rel(2) - val localAgg = createLocalAgg(agg, expand, call.builder) + val localAgg = createLocalAgg(agg, expand) val exchange = createExchange(agg, localAgg) - val globalAgg = createGlobalAgg(agg, exchange, call.builder) + val globalAgg = createGlobalAgg(agg, exchange) call.transformTo(globalAgg) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala index 752f195..65b14b5 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala +++ 
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalSortAggRule.scala @@ -66,7 +66,7 @@ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase( val enableTwoPhaseAgg = isTwoPhaseAggEnabled(agg) - val grouping = agg.getGrouping + val grouping = agg.grouping val constantShuffleKey = hasConstantShuffleKey(grouping, expand) grouping.nonEmpty && enableTwoPhaseAgg && constantShuffleKey @@ -76,17 +76,17 @@ class EnforceLocalSortAggRule extends EnforceLocalAggRuleBase( val agg: BatchExecSortAggregate = call.rel(0) val expand: BatchPhysicalExpand = call.rel(3) - val localGrouping = agg.getGrouping + val localGrouping = agg.grouping // create local sort val localSort = createSort(expand, localGrouping) - val localAgg = createLocalAgg(agg, localSort, call.builder) + val localAgg = createLocalAgg(agg, localSort) val exchange = createExchange(agg, localAgg) // create global sort val globalGrouping = localGrouping.indices.toArray val globalSort = createSort(exchange, globalGrouping) - val globalAgg = createGlobalAgg(agg, globalSort, call.builder) + val globalAgg = createGlobalAgg(agg, globalSort) call.transformTo(globalAgg) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala index 33a2bb2..9dde609 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalHashAggRule.scala @@ -26,9 +26,9 @@ import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.rel.RelNode /** - * There maybe exist a subTree like localHashAggregate -> globalHashAggregate which the middle - * shuffle is removed. The rule could remove redundant localHashAggregate node. - */ + * A sub-tree of localHashAggregate -> globalHashAggregate may exist after the shuffle between + * them has been removed. This rule removes the now-redundant localHashAggregate node. + */ class RemoveRedundantLocalHashAggRule extends RelOptRule( operand(classOf[BatchExecHashAggregate], operand(classOf[BatchExecLocalHashAggregate], @@ -36,19 +36,18 @@ class RemoveRedundantLocalHashAggRule extends RelOptRule( "RemoveRedundantLocalHashAggRule") { override def onMatch(call: RelOptRuleCall): Unit = { - val globalAgg = call.rels(0).asInstanceOf[BatchExecHashAggregate] - val localAgg = call.rels(1).asInstanceOf[BatchExecLocalHashAggregate] + val globalAgg: BatchExecHashAggregate = call.rel(0) + val localAgg: BatchExecLocalHashAggregate = call.rel(1) + val inputOfLocalAgg = localAgg.getInput val newGlobalAgg = new BatchExecHashAggregate( globalAgg.getCluster, - call.builder(), globalAgg.getTraitSet, inputOfLocalAgg, globalAgg.getRowType, inputOfLocalAgg.getRowType, inputOfLocalAgg.getRowType, - localAgg.getGrouping, - localAgg.getAuxGrouping, + localAgg.grouping, + localAgg.auxGrouping, // Use the localAgg agg calls because the global agg call filters was removed, // see BatchExecHashAggRule for details.
localAgg.getAggCallToAggFunction, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala index 787860e..0d8bee0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/physical/batch/RemoveRedundantLocalSortAggRule.scala @@ -40,14 +40,13 @@ abstract class RemoveRedundantLocalSortAggRule( val inputOfLocalAgg = getOriginalInputOfLocalAgg(call) val newGlobalAgg = new BatchExecSortAggregate( globalAgg.getCluster, - call.builder(), globalAgg.getTraitSet, inputOfLocalAgg, globalAgg.getRowType, inputOfLocalAgg.getRowType, inputOfLocalAgg.getRowType, - localAgg.getGrouping, - localAgg.getAuxGrouping, + localAgg.grouping, + localAgg.auxGrouping, // Use the localAgg agg calls because the global agg call filters was removed, // see BatchExecSortAggRule for details. localAgg.getAggCallToAggFunction, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala index 6012ffa..321abcd 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala @@ -23,7 +23,7 @@ import org.apache.flink.table.planner.JDouble import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindowProperty import org.apache.flink.table.planner.calcite.FlinkTypeFactory import org.apache.flink.table.planner.plan.nodes.calcite.{Expand, Rank, WindowAggregate} -import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase} +import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecLocalHashWindowAggregate, BatchExecLocalSortWindowAggregate, BatchExecWindowAggregateBase, BatchPhysicalGroupAggregateBase} import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankRange} import org.apache.flink.table.runtime.operators.sort.BinaryIndexedSortable import org.apache.flink.table.runtime.typeutils.BinaryRowDataSerializer.LENGTH_SIZE_IN_BYTES @@ -46,8 +46,8 @@ import scala.collection.JavaConversions._ import scala.collection.mutable /** - * FlinkRelMdUtil provides utility methods used by the metadata provider methods. - */ + * FlinkRelMdUtil provides utility methods used by the metadata provider methods. + */ object FlinkRelMdUtil { /** Returns an estimate of the number of rows returned by a SEMI/ANTI [[Join]]. */ @@ -65,15 +65,15 @@ object FlinkRelMdUtil { } /** - * Creates a RexNode that stores a selectivity value corresponding to the - * selectivity of a semi-join/anti-join. This can be added to a filter to simulate the - * effect of the semi-join/anti-join during costing, but should never appear in a real - * plan since it has no physical implementation. 
- * - * @param mq instance of metadata query - * @param rel the SEMI/ANTI join of interest - * @return constructed rexNode - */ + * Creates a RexNode that stores a selectivity value corresponding to the + * selectivity of a semi-join/anti-join. This can be added to a filter to simulate the + * effect of the semi-join/anti-join during costing, but should never appear in a real + * plan since it has no physical implementation. + * + * @param mq instance of metadata query + * @param rel the SEMI/ANTI join of interest + * @return constructed rexNode + */ def makeSemiAntiJoinSelectivityRexNode(mq: RelMetadataQuery, rel: Join): RexNode = { require(rel.getJoinType == JoinRelType.SEMI || rel.getJoinType == JoinRelType.ANTI) val joinInfo = rel.analyzeCondition() @@ -116,15 +116,15 @@ object FlinkRelMdUtil { } /** - * Estimates new distinctRowCount of currentNode after it applies a condition. - * The estimation based on one assumption: - * even distribution of all distinct data - * - * @param rowCount rowcount of node. - * @param distinctRowCount distinct rowcount of node. - * @param selectivity selectivity of condition expression. - * @return new distinctRowCount - */ + * Estimates new distinctRowCount of currentNode after it applies a condition. + * The estimation based on one assumption: + * even distribution of all distinct data + * + * @param rowCount rowcount of node. + * @param distinctRowCount distinct rowcount of node. + * @param selectivity selectivity of condition expression. + * @return new distinctRowCount + */ def adaptNdvBasedOnSelectivity( rowCount: JDouble, distinctRowCount: JDouble, @@ -134,29 +134,29 @@ object FlinkRelMdUtil { } /** - * Estimates ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable. - * - * the value of `1.0 - math.exp(-0.1 * groupCount)` increases with groupCount - * from 0.095 until close to 1.0. when groupCount is 1, the formula result is 0.095, - * when groupCount is 2, the formula result is 0.18, - * when groupCount is 3, the formula result is 0.25. - * ... - * - * @param groupingLength grouping keys length of aggregate - * @return the ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable. - */ + * Estimates ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable. + * + * the value of `1.0 - math.exp(-0.1 * groupCount)` increases with groupCount + * from 0.095 until close to 1.0. when groupCount is 1, the formula result is 0.095, + * when groupCount is 2, the formula result is 0.18, + * when groupCount is 3, the formula result is 0.25. + * ... + * + * @param groupingLength grouping keys length of aggregate + * @return the ratio outputRowCount/ inputRowCount of agg when ndv of groupKeys is unavailable. + */ def getAggregationRatioIfNdvUnavailable(groupingLength: Int): JDouble = 1.0 - math.exp(-0.1 * groupingLength) /** - * Creates a RexNode that stores a selectivity value corresponding to the - * selectivity of a NamedProperties predicate. - * - * @param winAgg window aggregate node - * @param predicate a RexNode - * @return constructed rexNode including non-NamedProperties predicates and - * a predicate that stores NamedProperties predicate's selectivity - */ + * Creates a RexNode that stores a selectivity value corresponding to the + * selectivity of a NamedProperties predicate. 
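The ratio formula documented above for getAggregationRatioIfNdvUnavailable is self-contained enough to check numerically; a small sketch reproducing the quoted values:

object AggRatioSketch {
  // ratio = 1 - exp(-0.1 * groupingLength), as documented above
  def ratio(groupingLength: Int): Double = 1.0 - math.exp(-0.1 * groupingLength)

  def main(args: Array[String]): Unit =
    (1 to 3).foreach { n =>
      // prints 0.095, 0.181, 0.259 -- rising toward 1.0 as keys are added
      println(f"groupCount=$n ratio=${ratio(n)}%.3f")
    }
}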
+ * + * @param winAgg window aggregate node + * @param predicate a RexNode + * @return constructed rexNode including non-NamedProperties predicates and + * a predicate that stores NamedProperties predicate's selectivity + */ def makeNamePropertiesSelectivityRexNode( winAgg: WindowAggregate, predicate: RexNode): RexNode = { @@ -165,14 +165,14 @@ object FlinkRelMdUtil { } /** - * Creates a RexNode that stores a selectivity value corresponding to the - * selectivity of a NamedProperties predicate. - * - * @param globalWinAgg global window aggregate node - * @param predicate a RexNode - * @return constructed rexNode including non-NamedProperties predicates and - * a predicate that stores NamedProperties predicate's selectivity - */ + * Creates a RexNode that stores a selectivity value corresponding to the + * selectivity of a NamedProperties predicate. + * + * @param globalWinAgg global window aggregate node + * @param predicate a RexNode + * @return constructed rexNode including non-NamedProperties predicates and + * a predicate that stores NamedProperties predicate's selectivity + */ def makeNamePropertiesSelectivityRexNode( globalWinAgg: BatchExecWindowAggregateBase, predicate: RexNode): RexNode = { @@ -183,16 +183,16 @@ object FlinkRelMdUtil { } /** - * Creates a RexNode that stores a selectivity value corresponding to the - * selectivity of a NamedProperties predicate. - * - * @param winAgg window aggregate node - * @param fullGrouping full groupSets - * @param namedProperties NamedWindowProperty list - * @param predicate a RexNode - * @return constructed rexNode including non-NamedProperties predicates and - * a predicate that stores NamedProperties predicate's selectivity - */ + * Creates a RexNode that stores a selectivity value corresponding to the + * selectivity of a NamedProperties predicate. + * + * @param winAgg window aggregate node + * @param fullGrouping full groupSets + * @param namedProperties NamedWindowProperty list + * @param predicate a RexNode + * @return constructed rexNode including non-NamedProperties predicates and + * a predicate that stores NamedProperties predicate's selectivity + */ def makeNamePropertiesSelectivityRexNode( winAgg: SingleRel, fullGrouping: Array[Int], @@ -249,17 +249,17 @@ object FlinkRelMdUtil { } /** - * Estimates outputRowCount of local aggregate. - * - * output rowcount of local agg is (1 - pow((1 - 1/x) , n/m)) * m * x, based on two assumption: - * 1. even distribution of all distinct data - * 2. even distribution of all data in each concurrent local agg worker - * - * @param parallelism number of concurrent worker of local aggregate - * @param inputRowCount rowcount of input node of aggregate. - * @param globalAggRowCount rowcount of output of global aggregate. - * @return outputRowCount of local aggregate. - */ + * Estimates outputRowCount of local aggregate. + * + * output rowcount of local agg is (1 - pow((1 - 1/x) , n/m)) * m * x, based on two assumption: + * 1. even distribution of all distinct data + * 2. even distribution of all data in each concurrent local agg worker + * + * @param parallelism number of concurrent worker of local aggregate + * @param inputRowCount rowcount of input node of aggregate. + * @param globalAggRowCount rowcount of output of global aggregate. + * @return outputRowCount of local aggregate. 
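The local-agg formula quoted above follows from the two evenness assumptions: each of the m workers sees n/m rows, so one of the x distinct keys reaches a given worker with probability 1 - (1 - 1/x)^(n/m), and each worker emits about x times that many groups. A sketch, including the cap at the input row count applied in the implementation:

object LocalAggRowCountSketch {
  // (1 - (1 - 1/x)^(n/m)) * m * x, capped at n, where x = global agg row
  // count (ndv), n = input rows, m = parallelism.
  def localAggRows(m: Int, n: Double, x: Double): Double =
    math.min((1.0 - math.pow(1.0 - 1.0 / x, n / m)) * m * x, n)

  def main(args: Array[String]): Unit = {
    // 1M input rows, 100 keys, 16 workers: nearly every worker sees every
    // key, so the local stage emits roughly m * x = 1600 rows.
    println(localAggRows(16, 1e6, 100))
    // 1M input rows, 1M keys: ~9.7e5, close to n -- pre-aggregation barely helps.
    println(localAggRows(16, 1e6, 1e6))
  }
}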
+ */ def getRowCountOfLocalAgg( parallelism: Int, inputRowCount: JDouble, @@ -268,13 +268,12 @@ object FlinkRelMdUtil { * globalAggRowCount * parallelism, inputRowCount) /** - * Takes a bitmap representing a set of input references and extracts the - * ones that reference the group by columns in an aggregate. - * - * - * @param groupKey the original bitmap - * @param aggRel the aggregate - */ + * Takes a bitmap representing a set of input references and extracts the + * ones that reference the group by columns in an aggregate. + * + * @param groupKey the original bitmap + * @param aggRel the aggregate + */ def setAggChildKeys( groupKey: ImmutableBitSet, aggRel: Aggregate): (ImmutableBitSet, Array[AggregateCall]) = { @@ -300,26 +299,26 @@ object FlinkRelMdUtil { } /** - * Takes a bitmap representing a set of input references and extracts the - * ones that reference the group by columns in an aggregate. - * - * @param groupKey the original bitmap - * @param aggRel the aggregate - */ + * Takes a bitmap representing a set of input references and extracts the + * ones that reference the group by columns in an aggregate. + * + * @param groupKey the original bitmap + * @param aggRel the aggregate + */ def setAggChildKeys( groupKey: ImmutableBitSet, - aggRel: BatchExecGroupAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = { + aggRel: BatchPhysicalGroupAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = { require(!aggRel.isFinal || !aggRel.isMerge, "Cannot handle global agg which has local agg!") setChildKeysOfAgg(groupKey, aggRel) } /** - * Takes a bitmap representing a set of input references and extracts the - * ones that reference the group by columns in an aggregate. - * - * @param groupKey the original bitmap - * @param aggRel the aggregate - */ + * Takes a bitmap representing a set of input references and extracts the + * ones that reference the group by columns in an aggregate. + * + * @param groupKey the original bitmap + * @param aggRel the aggregate + */ def setAggChildKeys( groupKey: ImmutableBitSet, aggRel: BatchExecWindowAggregateBase): (ImmutableBitSet, Array[AggregateCall]) = { @@ -341,8 +340,8 @@ object FlinkRelMdUtil { agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping) case agg: BatchExecWindowAggregateBase => (agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping) - case agg: BatchExecGroupAggregateBase => - (agg.getAggCallList, agg.getGrouping ++ agg.getAuxGrouping) + case agg: BatchPhysicalGroupAggregateBase => + (agg.getAggCallList, agg.grouping ++ agg.auxGrouping) case _ => throw new IllegalArgumentException(s"Unknown aggregate: ${agg.getRelTypeName}") } // does not need to take keys in aggregate call into consideration if groupKey contains all @@ -362,16 +361,16 @@ object FlinkRelMdUtil { } /** - * Takes a bitmap representing a set of local window aggregate references. - * - * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties - * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls - * - * Skips `assignTs` when mapping `groupKey` to `childKey`. - * - * @param groupKey the original bitmap - * @param globalWinAgg the global window aggregate - */ + * Takes a bitmap representing a set of local window aggregate references. + * + * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties + * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls + * + * Skips `assignTs` when mapping `groupKey` to `childKey`. 
@@ -362,16 +361,16 @@ object FlinkRelMdUtil {
   }
 
   /**
-    * Takes a bitmap representing a set of local window aggregate references.
-    *
-    * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
-    * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
-    *
-    * Skips `assignTs` when mapping `groupKey` to `childKey`.
-    *
-    * @param groupKey the original bitmap
-    * @param globalWinAgg the global window aggregate
-    */
+   * Takes a bitmap representing a set of local window aggregate references.
+   *
+   * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
+   * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
+   *
+   * Skips `assignTs` when mapping `groupKey` to `childKey`.
+   *
+   * @param groupKey the original bitmap
+   * @param globalWinAgg the global window aggregate
+   */
   def setChildKeysOfWinAgg(
       groupKey: ImmutableBitSet,
       globalWinAgg: BatchExecWindowAggregateBase): ImmutableBitSet = {
@@ -389,12 +388,12 @@
   }
 
   /**
-    * Split groupKeys on Aggregate/ BatchExecGroupAggregateBase/ BatchExecWindowAggregateBase
-    * into keys on aggregate's groupKey and aggregate's aggregateCalls.
-    *
-    * @param agg the aggregate
-    * @param groupKey the original bitmap
-    */
+   * Splits groupKeys on Aggregate / BatchPhysicalGroupAggregateBase /
+   * BatchExecWindowAggregateBase into keys on the aggregate's groupKey and keys on the
+   * aggregate's aggregateCalls.
+   *
+   * @param agg the aggregate
+   * @param groupKey the original bitmap
+   */
   def splitGroupKeysOnAggregate(
       agg: SingleRel,
       groupKey: ImmutableBitSet): (ImmutableBitSet, Array[AggregateCall]) = {
@@ -418,10 +417,10 @@
       val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
       val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGroupSet.toArray, auxGroupSet)
       (childKeyExcludeAuxKey, aggCalls)
-    case rel: BatchExecGroupAggregateBase =>
+    case rel: BatchPhysicalGroupAggregateBase =>
       // set the bits as they correspond to the child input
       val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
-      val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.getGrouping, rel.getAuxGrouping)
+      val childKeyExcludeAuxKey = removeAuxKey(childKeys, rel.grouping, rel.auxGrouping)
       (childKeyExcludeAuxKey, aggCalls)
     case rel: BatchExecWindowAggregateBase =>
       val (childKeys, aggCalls) = setAggChildKeys(groupKey, rel)
@@ -432,14 +431,14 @@ object FlinkRelMdUtil {
   }
 
   /**
-    * Split a predicate on Aggregate into two parts, the first one is pushable part,
-    * the second one is rest part.
-    *
-    * @param agg Aggregate which to analyze
-    * @param predicate Predicate which to analyze
-    * @return a tuple, first element is pushable part, second element is rest part.
-    *         Note, pushable condition will be converted based on the input field position.
-    */
+   * Splits a predicate on an Aggregate into two parts: the pushable part and the rest.
+   *
+   * @param agg the Aggregate to analyze
+   * @param predicate the predicate to analyze
+   * @return a tuple whose first element is the pushable part and whose second element is
+   *         the rest. Note that the pushable condition is converted based on the input
+   *         field positions.
+   */
   def splitPredicateOnAggregate(
       agg: Aggregate,
       predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
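The split itself is conjunct-by-conjunct: a conjunct may be pushed below the aggregate only if every column it references is a group key, while a filter on an aggregated value must stay above. A hedged sketch of that decision, with conjuncts modelled as (text, referenced output refs) pairs purely for illustration:

    // Illustrative only: partition conjuncts into pushable and rest.
    def splitPushable(
        conjuncts: Seq[(String, Set[Int])],
        groupKeyCount: Int): (Seq[String], Seq[String]) = {
      val (pushable, rest) = conjuncts.partition {
        case (_, refs) => refs.forall(_ < groupKeyCount)
      }
      (pushable.map(_._1), rest.map(_._1))
    }

    // SELECT a, COUNT(*) AS cnt ... GROUP BY a HAVING a > 1 AND cnt > 10:
    // splitPushable(Seq(("a > 1", Set(0)), ("cnt > 10", Set(1))), 1)
    //   == (Seq("a > 1"), Seq("cnt > 10"))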
@@ -448,29 +447,29 @@ object FlinkRelMdUtil {
   }
 
   /**
-    * Split a predicate on BatchExecGroupAggregateBase into two parts,
-    * the first one is pushable part, the second one is rest part.
-    *
-    * @param agg Aggregate which to analyze
-    * @param predicate Predicate which to analyze
-    * @return a tuple, first element is pushable part, second element is rest part.
-    *         Note, pushable condition will be converted based on the input field position.
-    */
+   * Splits a predicate on a BatchPhysicalGroupAggregateBase into two parts:
+   * the pushable part and the rest.
+   *
+   * @param agg the aggregate to analyze
+   * @param predicate the predicate to analyze
+   * @return a tuple whose first element is the pushable part and whose second element is
+   *         the rest. Note that the pushable condition is converted based on the input
+   *         field positions.
+   */
   def splitPredicateOnAggregate(
-      agg: BatchExecGroupAggregateBase,
+      agg: BatchPhysicalGroupAggregateBase,
       predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
-    splitPredicateOnAgg(agg.getGrouping ++ agg.getAuxGrouping, agg, predicate)
+    splitPredicateOnAgg(agg.grouping ++ agg.auxGrouping, agg, predicate)
   }
 
   /**
-    * Split a predicate on WindowAggregateBatchExecBase into two parts,
-    * the first one is pushable part, the second one is rest part.
-    *
-    * @param agg Aggregate which to analyze
-    * @param predicate Predicate which to analyze
-    * @return a tuple, first element is pushable part, second element is rest part.
-    *         Note, pushable condition will be converted based on the input field position.
-    */
+   * Splits a predicate on a BatchExecWindowAggregateBase into two parts:
+   * the pushable part and the rest.
+   *
+   * @param agg the aggregate to analyze
+   * @param predicate the predicate to analyze
+   * @return a tuple whose first element is the pushable part and whose second element is
+   *         the rest. Note that the pushable condition is converted based on the input
+   *         field positions.
+   */
   def splitPredicateOnAggregate(
       agg: BatchExecWindowAggregateBase,
       predicate: RexNode): (Option[RexNode], Option[RexNode]) = {
@@ -478,15 +477,15 @@
   }
 
   /**
-    * Shifts every [[RexInputRef]] in an expression higher than length of full grouping
-    * (for skips `assignTs`).
-    *
-    * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
-    * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
-    *
-    * @param predicate a RexNode
-    * @param globalWinAgg the global window aggregate
-    */
+   * Shifts every [[RexInputRef]] in an expression whose index is higher than the length
+   * of the full grouping (in order to skip `assignTs`).
+   *
+   * global win-agg output type: groupSet + auxGroupSet + aggCall + namedProperties
+   * local win-agg output type: groupSet + assignTs + auxGroupSet + aggCalls
+   *
+   * @param predicate a RexNode
+   * @param globalWinAgg the global window aggregate
+   */
   def setChildPredicateOfWinAgg(
       predicate: RexNode,
       globalWinAgg: BatchExecWindowAggregateBase): RexNode = {
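The ref shift follows directly from the two layouts in the comment above: the local win-agg output inserts assignTs right after the group keys, so every ref at or beyond the full grouping moves up by one slot when a predicate is pushed from the global aggregate onto its input. A one-line sketch of the remapping, with plain Int refs standing in for RexInputRef:

    // Illustrative only: skip the assignTs slot in the local win-agg output.
    def shiftPastAssignTs(refs: Seq[Int], fullGroupingLength: Int): Seq[Int] =
      refs.map(ref => if (ref >= fullGroupingLength) ref + 1 else ref)

    // With two group keys: ref 1 stays a group key; ref 2 (the first column
    // after the keys on the global side) becomes ref 3 on the local side.
    // shiftPastAssignTs(Seq(1, 2), 2) == Seq(1, 3)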
@@ -614,9 +613,9 @@
   }
 
   /**
-    * Returns [[RexInputRef]] index set of projects corresponding to the given column index.
-    * The index will be set as -1 if the given column in project is not a [[RexInputRef]].
-    */
+   * Returns the [[RexInputRef]] index set of projects corresponding to the given column index.
+   * The index will be set to -1 if the given column in a project is not a [[RexInputRef]].
+   */
   def getInputRefIndices(index: Int, expand: Expand): util.Set[Int] = {
     val inputRefs = new util.HashSet[Int]()
     for (project <- expand.projects) {
@@ -641,26 +640,26 @@ object FlinkRelMdUtil {
   }
 
   /**
-    * Computes the cardinality of a particular expression from the projection
-    * list.
-    *
-    * @param mq metadata query instance
-    * @param calc calc RelNode
-    * @param expr projection expression
-    * @return cardinality
-    */
+   * Computes the cardinality of a particular expression from the projection
+   * list.
+   *
+   * @param mq metadata query instance
+   * @param calc calc RelNode
+   * @param expr projection expression
+   * @return cardinality
+   */
   def cardOfCalcExpr(mq: RelMetadataQuery, calc: Calc, expr: RexNode): JDouble = {
     expr.accept(new CardOfCalcExpr(mq, calc))
   }
 
   /**
-    * Visitor that walks over a scalar expression and computes the
-    * cardinality of its result.
-    * The code is borrowed from RelMdUtil
-    *
-    * @param mq metadata query instance
-    * @param calc calc relnode
-    */
+   * Visitor that walks over a scalar expression and computes the
+   * cardinality of its result.
+   * The code is borrowed from RelMdUtil.
+   *
+   * @param mq metadata query instance
+   * @param calc the Calc RelNode
+   */
   private class CardOfCalcExpr(
       mq: RelMetadataQuery,
       calc: Calc)
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index bb9e873..d351595 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -88,17 +88,17 @@ class FlinkRelMdHandlerTestBase {
   // TODO batch RelNode and stream RelNode should have different PlannerContext
   // and RelOptCluster due to they have different trait definitions.
   val plannerContext: PlannerContext =
-    new PlannerContext(
-      tableConfig,
-      new FunctionCatalog(tableConfig, catalogManager, moduleManager),
-      catalogManager,
-      CalciteSchema.from(rootSchema),
-      util.Arrays.asList(
-        ConventionTraitDef.INSTANCE,
-        FlinkRelDistributionTraitDef.INSTANCE,
-        RelCollationTraitDef.INSTANCE
-      )
+  new PlannerContext(
+    tableConfig,
+    new FunctionCatalog(tableConfig, catalogManager, moduleManager),
+    catalogManager,
+    CalciteSchema.from(rootSchema),
+    util.Arrays.asList(
+      ConventionTraitDef.INSTANCE,
+      FlinkRelDistributionTraitDef.INSTANCE,
+      RelCollationTraitDef.INSTANCE
     )
+  )
 
   val typeFactory: FlinkTypeFactory = plannerContext.getTypeFactory
   val mq: FlinkRelMetadataQuery = FlinkRelMetadataQuery.instance()
@@ -981,7 +981,6 @@ class FlinkRelMdHandlerTestBase {
 
     val batchLocalAgg = new BatchExecLocalHashAggregate(
       cluster,
-      relBuilder,
       batchPhysicalTraits,
       studentBatchScan,
       rowTypeOfLocalAgg,
@@ -994,7 +993,6 @@
       cluster, batchLocalAgg.getTraitSet.replace(hash0), batchLocalAgg, hash0)
     val batchGlobalAgg = new BatchExecHashAggregate(
       cluster,
-      relBuilder,
       batchPhysicalTraits,
       batchExchange1,
       rowTypeOfGlobalAgg,
@@ -1009,7 +1007,6 @@
       studentBatchScan.getTraitSet.replace(hash3), studentBatchScan, hash3)
     val batchGlobalAggWithoutLocal = new BatchExecHashAggregate(
       cluster,
-      relBuilder,
       batchPhysicalTraits,
       batchExchange2,
       rowTypeOfGlobalAgg,
@@ -1111,7 +1108,6 @@
 
     val batchLocalAggWithAuxGroup = new BatchExecLocalHashAggregate(
       cluster,
-      relBuilder,
       batchPhysicalTraits,
       studentBatchScan,
       rowTypeOfLocalAgg,
@@ -1133,7 +1129,6 @@
       .add("cnt", longType).build()
     val batchGlobalAggWithAuxGroup = new BatchExecHashAggregate(
       cluster,
-      relBuilder,
       batchPhysicalTraits,
       batchExchange,
       rowTypeOfGlobalAgg,
@@ -1148,7 +1143,6 @@
       studentBatchScan.getTraitSet.replace(hash0), studentBatchScan, hash0)
     val batchGlobalAggWithoutLocalWithAuxGroup = new BatchExecHashAggregate(
       cluster,
-      relBuilder,
       batchPhysicalTraits,
       batchExchange2,
       rowTypeOfGlobalAgg,
@@ -2416,49 +2410,6 @@ class FlinkRelMdHandlerTestBase {
       .scan("MyTable2")
       .minus(false).build()
 
-  private def createGlobalAgg(
-      table: String, groupBy: String, sum: String): BatchExecHashAggregate = {
-    val scan: BatchPhysicalBoundedStreamScan =
-      createDataStreamScan(ImmutableList.of(table), batchPhysicalTraits)
-    relBuilder.push(scan)
-    val groupByField = relBuilder.field(groupBy)
-    val sumField = relBuilder.field(sum)
-    val hash = FlinkRelDistribution.hash(Array(groupByField.getIndex), requireStrict = true)
-
-    val exchange = new BatchPhysicalExchange(cluster, batchPhysicalTraits.replace(hash), scan, hash)
-    relBuilder.push(exchange)
-
-    val logicalAgg = relBuilder.aggregate(
-      relBuilder.groupKey(groupBy),
-      relBuilder.aggregateCall(SqlStdOperatorTable.SUM, relBuilder.field(sum))
-    ).build().asInstanceOf[LogicalAggregate]
-    val aggCalls = logicalAgg.getAggCallList
-    val aggFunctionFactory = new AggFunctionFactory(
-      FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
-      Array.empty[Int],
-      Array.fill(aggCalls.size())(false))
-    val aggCallToAggFunction = aggCalls.zipWithIndex.map {
-      case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index))
-    }
-
-    val rowTypeOfGlobalAgg = typeFactory.builder
-      .add(groupByField.getName, groupByField.getType)
-      .add(sumField.getName, sumField.getType).build()
-
-    new BatchExecHashAggregate(
-      cluster,
-      relBuilder,
-      batchPhysicalTraits,
-      exchange,
-      rowTypeOfGlobalAgg,
-      exchange.getRowType,
-      exchange.getRowType,
-      Array(groupByField.getIndex),
-      auxGrouping = Array(),
-      aggCallToAggFunction,
-      isMerge = false)
-  }
-
   protected def createDataStreamScan[T](
       tableNames: util.List[String], traitSet: RelTraitSet): T = {
     val table = relBuilder
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/MetadataHandlerConsistencyTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/MetadataHandlerConsistencyTest.scala
index 6eaadac..327aebb 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/MetadataHandlerConsistencyTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/MetadataHandlerConsistencyTest.scala
@@ -18,7 +18,7 @@
 
 package org.apache.flink.table.planner.plan.metadata
 
-import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchExecGroupAggregateBase, BatchPhysicalCorrelate}
+import org.apache.flink.table.planner.plan.nodes.physical.batch.{BatchPhysicalCorrelate, BatchPhysicalGroupAggregateBase}
 
 import org.apache.calcite.rel.RelNode
 import org.apache.calcite.rel.core.{Aggregate, Correlate}
@@ -42,11 +42,11 @@ import scala.collection.mutable
  * for Aggregate and Correlate.
  * This test ensures two points.
  * 1. all subclasses of [[MetadataHandler]] have explicit metadata estimation
- * for [[Aggregate]] and [[BatchExecGroupAggregateBase]] or have no metadata estimation for
- * [[Aggregate]] and [[BatchExecGroupAggregateBase]] either.
+ * for [[Aggregate]] and [[BatchPhysicalGroupAggregateBase]], or have no metadata estimation
+ * for either of them.
 * 2. all subclasses of [[MetadataHandler]] have explicit metadata estimation
- * for [[Correlate]] and [[BatchExecGroupAggregateBase]] or have no metadata estimation for
- * [[Correlate]] and [[BatchExecGroupAggregateBase]] either.
+ * for [[Correlate]] and [[BatchPhysicalGroupAggregateBase]], or have no metadata estimation
+ * for either of them.
 * Be cautious that if logical Aggregate and physical Aggregate or logical Correlate and physical
 * Correlate both are present in a MetadataHandler class, their metadata estimation should be
 * the same. This test does not check this point because every MetadataHandler could have different
@@ -144,7 +144,7 @@
   @Parameterized.Parameters(name = "logicalNodeClass={0}, physicalNodeClass={1}")
   def parameters(): util.Collection[Array[Any]] = {
     Seq[Array[Any]](
-      Array(classOf[Aggregate], classOf[BatchExecGroupAggregateBase]),
+      Array(classOf[Aggregate], classOf[BatchPhysicalGroupAggregateBase]),
       Array(classOf[Correlate], classOf[BatchPhysicalCorrelate]))
   }
 }
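The invariant this test enforces can be phrased compactly: for a given handler class, an explicit estimation method exists for the logical node class if and only if one exists for the physical node class. A hedged, reflection-based sketch of that check, simplified from the real test, which walks all MetadataHandler subclasses:

    // Illustrative only: a handler is consistent if it declares an estimation
    // method for both node classes or for neither.
    def isConsistent(
        handlerClass: Class[_],
        logicalNodeClass: Class[_],
        physicalNodeClass: Class[_]): Boolean = {
      def hasExplicitEstimation(nodeClass: Class[_]): Boolean =
        handlerClass.getDeclaredMethods.exists(
          _.getParameterTypes.headOption.contains(nodeClass))
      hasExplicitEstimation(logicalNodeClass) == hasExplicitEstimation(physicalNodeClass)
    }

This is why the parameters() method below pairs each logical class with its physical counterpart: the same check is run once per pair.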
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRuleTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRuleTest.scala
index 61666d5..69cbba9 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRuleTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/rules/physical/batch/EnforceLocalHashAggRuleTest.scala
@@ -31,8 +31,8 @@
 import org.junit.Before
 
 /**
-  * Test for [[EnforceLocalHashAggRule]].
-  */
+ * Test for [[EnforceLocalHashAggRule]].
+ */
 class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase {
 
   @Before
@@ -60,10 +60,10 @@ class EnforceLocalHashAggRuleTest extends EnforceLocalAggRuleTestBase {
 }
 
 /**
-  * Planner rule that ignore the [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]]
-  * value, and only enable one phase aggregate.
-  * This rule only used for test.
-  */
+ * Planner rule that ignores the [[OptimizerConfigOptions.TABLE_OPTIMIZER_AGG_PHASE_STRATEGY]]
+ * value and enables only one-phase aggregates.
+ * This rule is only used for tests.
+ */
 class BatchExecHashAggRuleForOnePhase extends BatchExecHashAggRule {
   override protected def isTwoPhaseAggWorkable(
       aggFunctions: Array[UserDefinedFunction], tableConfig: TableConfig): Boolean = false
