This is an automated email from the ASF dual-hosted git repository. snuyanzin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit f01de3c4303c4f9ec2b5781af41347c7ecb1630b Author: Sergey Nuyanzin <[email protected]> AuthorDate: Fri Mar 15 09:33:14 2024 +0100 [FLINK-34896][table] Migrate CorrelateSortToRankRule to java This closes #24545 --- .../rules/logical/CorrelateSortToRankRule.java | 256 +++++++++++++++++++++ .../rules/logical/CorrelateSortToRankRule.scala | 192 ---------------- 2 files changed, 256 insertions(+), 192 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CorrelateSortToRankRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CorrelateSortToRankRule.java new file mode 100644 index 00000000000..82c38712a39 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CorrelateSortToRankRule.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.planner.calcite.FlinkRelBuilder; +import org.apache.flink.table.planner.calcite.FlinkRelFactories; +import org.apache.flink.table.runtime.operators.rank.ConstantRankRange; +import org.apache.flink.table.runtime.operators.rank.RankType; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.rel.RelCollation; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Correlate; +import org.apache.calcite.rel.core.Filter; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.core.Sort; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexCorrelVariable; +import org.apache.calcite.rex.RexFieldAccess; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLiteral; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Planner rule that rewrites sort correlation to a Rank. 
Typically, the following plan + * + * <pre>{@code + * LogicalProject(state=[$0], name=[$1]) + * +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + * :- LogicalAggregate(group=[{0}]) + * : +- LogicalProject(state=[$1]) + * : +- LogicalTableScan(table=[[default_catalog, default_database, cities]]) + * +- LogicalSort(sort0=[$1], dir0=[DESC-nulls-last], fetch=[3]) + * +- LogicalProject(name=[$0], pop=[$2]) + * +- LogicalFilter(condition=[=($1, $cor0.state)]) + * +- LogicalTableScan(table=[[default_catalog, default_database, cities]]) + * }</pre> + * + * <p>would be transformed to + * + * <pre>{@code + * LogicalProject(state=[$0], name=[$1]) + * +- LogicalProject(state=[$1], name=[$0], pop=[$2]) + * +- LogicalRank(rankType=[ROW_NUMBER], rankRange=[rankStart=1, rankEnd=3], + * partitionBy=[$1], orderBy=[$2 DESC], select=[name=$0, state=$1, pop=$2]) + * +- LogicalTableScan(table=[[default_catalog, default_database, cities]]) + * }</pre> + * + * <p>To match the Correlate, the LHS needs to be a global Aggregate on a scan, the RHS should be a + * Sort with an equal Filter predicate whose keys are same with the LHS grouping keys. + * + * <p>This rule can only be used in {@link HepPlanner}. 
+ */ +@Value.Enclosing +public class CorrelateSortToRankRule + extends RelRule<CorrelateSortToRankRule.CorrelateSortToRankRuleConfig> { + + public static final CorrelateSortToRankRule INSTANCE = + CorrelateSortToRankRule.CorrelateSortToRankRuleConfig.DEFAULT.toRule(); + + protected CorrelateSortToRankRule(CorrelateSortToRankRuleConfig config) { + super(config); + } + + @Override + public boolean matches(RelOptRuleCall call) { + Correlate correlate = call.rel(0); + if (correlate.getJoinType() != JoinRelType.INNER) { + return false; + } + Aggregate agg = call.rel(1); + if (!agg.getAggCallList().isEmpty() || agg.getGroupSets().size() > 1) { + return false; + } + Project aggInput = call.rel(2); + if (!aggInput.isMapping()) { + return false; + } + Sort sort = call.rel(3); + if (sort.offset != null || sort.fetch == null) { + // 1. we can not describe the offset using rank + // 2. there is no need to transform to rank if no fetch limit + return false; + } + Project sortInput = call.rel(4); + if (!sortInput.isMapping()) { + return false; + } + Filter filter = call.rel(5); + + List<RexNode> cnfCond = RelOptUtil.conjunctions(filter.getCondition()); + if (cnfCond.stream().anyMatch(c -> !isValidCondition(c, correlate))) { + return false; + } + + return aggInput.getInput().getDigest().equals(filter.getInput().getDigest()); + } + + private boolean isValidCondition(RexNode condition, Correlate correlate) { + // must be equiv condition + if (condition.getKind() != SqlKind.EQUALS) { + return false; + } + Tuple2<RexInputRef, RexFieldAccess> tuple = resolveFilterCondition(condition); + if (tuple.f0 == null) { + return false; + } + RexCorrelVariable variable = (RexCorrelVariable) tuple.f1.getReferenceExpr(); + return variable.id.equals(correlate.getCorrelationId()); + } + + /** + * Resolves the filter condition with specific pattern: input ref and field access. 
+ * + * @param condition The join condition + * @return tuple of operands (RexInputRef, RexFieldAccess), or null if the pattern does not + * match + */ + private Tuple2<RexInputRef, RexFieldAccess> resolveFilterCondition(RexNode condition) { + RexCall condCall = (RexCall) condition; + RexNode operand0 = condCall.getOperands().get(0); + RexNode operand1 = condCall.getOperands().get(1); + if (operand0.isA(SqlKind.INPUT_REF) && operand1.isA(SqlKind.FIELD_ACCESS)) { + return Tuple2.of((RexInputRef) operand0, (RexFieldAccess) operand1); + } else if (operand0.isA(SqlKind.FIELD_ACCESS) && operand1.isA(SqlKind.INPUT_REF)) { + return Tuple2.of((RexInputRef) operand1, (RexFieldAccess) operand0); + } else { + return Tuple2.of(null, null); + } + } + + @Override + public void onMatch(RelOptRuleCall call) { + RelBuilder builder = call.builder(); + + Sort sort = call.rel(3); + Project sortInput = call.rel(4); + Filter filter = call.rel(5); + + List<RexNode> cnfCond = RelOptUtil.conjunctions(filter.getCondition()); + ImmutableBitSet partitionKey = + ImmutableBitSet.of( + cnfCond.stream() + .map(c -> resolveFilterCondition(c).f0.getIndex()) + .collect(Collectors.toList())); + + RelDataType baseType = sortInput.getInput().getRowType(); + List<RexNode> projects = new ArrayList<>(); + partitionKey.asList().forEach(k -> projects.add(RexInputRef.of(k, baseType))); + projects.addAll(sortInput.getProjects()); + + RelCollation oriCollation = sort.getCollation(); + List<RelFieldCollation> newFieldCollations = + oriCollation.getFieldCollations().stream() + .map( + fc -> { + int newFieldIdx = + ((RexInputRef) + sortInput + .getProjects() + .get(fc.getFieldIndex())) + .getIndex(); + return fc.withFieldIndex(newFieldIdx); + }) + .collect(Collectors.toList()); + RelCollation newCollation = RelCollations.of(newFieldCollations); + + RelNode newRel = + ((FlinkRelBuilder) (builder.push(filter.getInput()))) + .rank( + partitionKey, + newCollation, + RankType.ROW_NUMBER, + new ConstantRankRange( + 1, 
((RexLiteral) sort.fetch).getValueAs(Long.class)), + null, + false) + .project(projects) + .build(); + + call.transformTo(newRel); + } + + /** Rule configuration. */ + @Value.Immutable(singleton = false) + public interface CorrelateSortToRankRuleConfig extends RelRule.Config { + CorrelateSortToRankRule.CorrelateSortToRankRuleConfig DEFAULT = + ImmutableCorrelateSortToRankRule.CorrelateSortToRankRuleConfig.builder() + .operandSupplier( + b0 -> + b0.operand(Correlate.class) + .inputs( + b1 -> + b1.operand(Aggregate.class) + .oneInput( + b2 -> + b2.operand( + Project + .class) + .anyInputs()), + b2 -> + b2.operand(Sort.class) + .inputs( + b3 -> + b3.operand( + Project + .class) + .inputs( + b4 -> + b4.operand( + Filter + .class) + .anyInputs())))) + .relBuilderFactory(FlinkRelFactories.FLINK_REL_BUILDER()) + .description("CorrelateSortToRankRule") + .build(); + + @Override + default CorrelateSortToRankRule toRule() { + return new CorrelateSortToRankRule(this); + } + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/CorrelateSortToRankRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/CorrelateSortToRankRule.scala deleted file mode 100644 index 32efd087bbe..00000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/CorrelateSortToRankRule.scala +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.rules.logical - -import org.apache.flink.table.planner.calcite.{FlinkRelBuilder, FlinkRelFactories} -import org.apache.flink.table.runtime.operators.rank.{ConstantRankRange, RankType} - -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil} -import org.apache.calcite.plan.RelOptRule.{any, operand} -import org.apache.calcite.rel.`type`.RelDataType -import org.apache.calcite.rel.RelCollations -import org.apache.calcite.rel.core.{Aggregate, Correlate, Filter, JoinRelType, Project, Sort} -import org.apache.calcite.rex.{RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexLiteral, RexNode} -import org.apache.calcite.sql.SqlKind -import org.apache.calcite.util.ImmutableBitSet - -import java.util - -import scala.collection.JavaConversions._ - -/** - * Planner rule that rewrites sort correlation to a Rank. 
Typically, the following plan - * - * {{{ - * LogicalProject(state=[$0], name=[$1]) - * +- LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) - * :- LogicalAggregate(group=[{0}]) - * : +- LogicalProject(state=[$1]) - * : +- LogicalTableScan(table=[[default_catalog, default_database, cities]]) - * +- LogicalSort(sort0=[$1], dir0=[DESC-nulls-last], fetch=[3]) - * +- LogicalProject(name=[$0], pop=[$2]) - * +- LogicalFilter(condition=[=($1, $cor0.state)]) - * +- LogicalTableScan(table=[[default_catalog, default_database, cities]]) - * }}} - * - * <p>would be transformed to - * - * {{{ - * LogicalProject(state=[$0], name=[$1]) - * +- LogicalProject(state=[$1], name=[$0], pop=[$2]) - * +- LogicalRank(rankType=[ROW_NUMBER], rankRange=[rankStart=1, rankEnd=3], - * partitionBy=[$1], orderBy=[$2 DESC], select=[name=$0, state=$1, pop=$2]) - * +- LogicalTableScan(table=[[default_catalog, default_database, cities]]) - * }}} - * - * <p>To match the Correlate, the LHS needs to be a global Aggregate on a scan, the RHS should be a - * Sort with an equal Filter predicate whose keys are same with the LHS grouping keys. - * - * <p>This rule can only be used in HepPlanner. - */ -class CorrelateSortToRankRule - extends RelOptRule( - operand( - classOf[Correlate], - operand(classOf[Aggregate], operand(classOf[Project], any())), - operand(classOf[Sort], operand(classOf[Project], operand(classOf[Filter], any()))) - ), - FlinkRelFactories.FLINK_REL_BUILDER, - "CorrelateSortToRankRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val correlate: Correlate = call.rel(0) - if (correlate.getJoinType != JoinRelType.INNER) { - return false - } - val agg: Aggregate = call.rel(1) - if (agg.getAggCallList.size() > 0 || agg.getGroupSets.size() > 1) { - return false - } - val aggInput: Project = call.rel(2) - if (!aggInput.isMapping) { - return false - } - val sort: Sort = call.rel(3) - if (sort.offset != null || sort.fetch == null) { - // 1. 
we can not describe the offset using rank - // 2. there is no need to transform to rank if no fetch limit - return false - } - val sortInput: Project = call.rel(4) - if (!sortInput.isMapping) { - return false - } - val filter: Filter = call.rel(5) - - val cnfCond = RelOptUtil.conjunctions(filter.getCondition) - if (cnfCond.exists(c => !isValidCondition(c, correlate))) { - return false - } - - aggInput.getInput.getDigest.equals(filter.getInput.getDigest) - } - - private def isValidCondition(condition: RexNode, correlate: Correlate): Boolean = { - // must be equiv condition - if (condition.getKind != SqlKind.EQUALS) { - return false - } - val (inputRef, fieldAccess) = resolveFilterCondition(condition) - if (inputRef == null) { - return false - } - val variable = fieldAccess.getReferenceExpr.asInstanceOf[RexCorrelVariable] - variable.id.equals(correlate.getCorrelationId) - } - - /** - * Resolves the filter condition with specific pattern: input ref and field access. - * - * @param condition - * The join condition - * @return - * tuple of operands (RexInputRef, RexFieldAccess), or null if the pattern does not match - */ - private def resolveFilterCondition(condition: RexNode): (RexInputRef, RexFieldAccess) = { - val condCall = condition.asInstanceOf[RexCall] - val operand0 = condCall.getOperands.get(0) - val operand1 = condCall.getOperands.get(1) - if (operand0.isA(SqlKind.INPUT_REF) && operand1.isA(SqlKind.FIELD_ACCESS)) { - (operand0.asInstanceOf[RexInputRef], operand1.asInstanceOf[RexFieldAccess]) - } else if (operand0.isA(SqlKind.FIELD_ACCESS) && operand1.isA(SqlKind.INPUT_REF)) { - (operand1.asInstanceOf[RexInputRef], operand0.asInstanceOf[RexFieldAccess]) - } else { - (null, null) - } - } - - override def onMatch(call: RelOptRuleCall): Unit = { - val builder = call.builder() - - val sort: Sort = call.rel(3) - val sortInput: Project = call.rel(4) - val filter: Filter = call.rel(5) - - val cnfCond = RelOptUtil.conjunctions(filter.getCondition) - val partitionKey: 
ImmutableBitSet = - ImmutableBitSet.of(cnfCond.map(c => resolveFilterCondition(c)._1.getIndex): _*) - - val baseType: RelDataType = sortInput.getInput().getRowType - val projects = new util.ArrayList[RexNode]() - partitionKey.asList().foreach(k => projects.add(RexInputRef.of(k, baseType))) - projects.addAll(sortInput.getProjects) - - val oriCollation = sort.getCollation - val newFieldCollations = oriCollation.getFieldCollations.map { - fc => - val newFieldIdx = sortInput.getProjects - .get(fc.getFieldIndex) - .asInstanceOf[RexInputRef] - .getIndex - fc.withFieldIndex(newFieldIdx) - } - val newCollation = RelCollations.of(newFieldCollations) - - val newRel = builder - .push(filter.getInput()) - .asInstanceOf[FlinkRelBuilder] - .rank( - partitionKey, - newCollation, - RankType.ROW_NUMBER, - new ConstantRankRange( - 1, - sort.fetch.asInstanceOf[RexLiteral].getValueAs(classOf[java.lang.Long])), - null, - false - ) - .project(projects) - .build() - - call.transformTo(newRel) - } -} - -object CorrelateSortToRankRule { - val INSTANCE = new CorrelateSortToRankRule -}
