godfreyhe commented on a change in pull request #15997: URL: https://github.com/apache/flink/pull/15997#discussion_r642741007
########## File path: flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/stream/WatermarkAssignerChangelogNormalizeTransposeRuleTest.xml ########## @@ -0,0 +1,191 @@ +<?xml version="1.0" ?> +<!-- +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to you under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+--> +<Root> + <TestCase name="testGroupKeyIsComputedColumn"> + <Resource name="sql"> + <![CDATA[ +SELECT + currency2, + COUNT(1) AS cnt, + TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start, + TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end +FROM src_with_computed_column2 +GROUP BY currency2, TUMBLE(currency_time, INTERVAL '5' SECOND) +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(currency2=[$0], cnt=[$2], w_start=[TUMBLE_START($1)], w_end=[TUMBLE_END($1)]) ++- LogicalAggregate(group=[{0, 1}], cnt=[COUNT()]) + +- LogicalProject(currency2=[$1], $f1=[$TUMBLE($5, 5000:INTERVAL SECOND)]) + +- LogicalWatermarkAssigner(rowtime=[currency_time], watermark=[-($5, 5000:INTERVAL SECOND)]) + +- LogicalProject(currency=[$0], currency2=[+($0, 2)], currency_no=[$1], rate=[$2], c=[$3], currency_time=[TO_TIMESTAMP($3)]) + +- LogicalTableScan(table=[[default_catalog, default_database, src_with_computed_column2]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[currency2, cnt, w$start AS w_start, w$end AS w_end], changelogMode=[I]) ++- GroupWindowAggregate(groupBy=[currency2], window=[TumblingGroupWindow('w$, currency_time, 5000)], properties=[w$start, w$end, w$rowtime, w$proctime], select=[currency2, COUNT(*) AS cnt, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime], changelogMode=[I]) + +- Exchange(distribution=[hash[currency2]], changelogMode=[I,UB,UA,D]) + +- Calc(select=[currency2, currency_time], changelogMode=[I,UB,UA,D]) + +- ChangelogNormalize(key=[currency2], changelogMode=[I,UB,UA,D]) + +- Exchange(distribution=[hash[$0]], changelogMode=[UA,D]) + +- WatermarkAssigner(rowtime=[currency_time], watermark=[-(currency_time, 5000:INTERVAL SECOND)], changelogMode=[UA,D]) + +- Calc(select=[+(currency, 2) AS currency2, TO_TIMESTAMP(c) AS currency_time, currency AS $0], changelogMode=[UA,D]) + +- TableSourceScan(table=[[default_catalog, default_database, 
src_with_computed_column2, project=[currency, c]]], fields=[currency, c], changelogMode=[UA,D]) +]]> + </Resource> + </TestCase> + <TestCase name="testPushdownCalcAndWatermarkAssignerWithCalc"> + <Resource name="sql"> + <![CDATA[ +SELECT + currency, + COUNT(1) AS cnt, + TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start, + TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end +FROM src_with_computed_column +GROUP BY currency, TUMBLE(currency_time, INTERVAL '5' SECOND) +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(currency=[$0], cnt=[$2], w_start=[TUMBLE_START($1)], w_end=[TUMBLE_END($1)]) ++- LogicalAggregate(group=[{0, 1}], cnt=[COUNT()]) + +- LogicalProject(currency=[$0], $f1=[$TUMBLE($4, 5000:INTERVAL SECOND)]) + +- LogicalWatermarkAssigner(rowtime=[currency_time], watermark=[-($4, 5000:INTERVAL SECOND)]) + +- LogicalProject(currency=[$0], currency_no=[$1], rate=[$2], c=[$3], currency_time=[TO_TIMESTAMP($3)]) + +- LogicalTableScan(table=[[default_catalog, default_database, src_with_computed_column]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[currency, cnt, w$start AS w_start, w$end AS w_end], changelogMode=[I]) ++- GroupWindowAggregate(groupBy=[currency], window=[TumblingGroupWindow('w$, currency_time, 5000)], properties=[w$start, w$end, w$rowtime, w$proctime], select=[currency, COUNT(*) AS cnt, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime], changelogMode=[I]) + +- Exchange(distribution=[hash[currency]], changelogMode=[I,UB,UA,D]) + +- ChangelogNormalize(key=[currency], changelogMode=[I,UB,UA,D]) + +- Exchange(distribution=[hash[currency]], changelogMode=[UA,D]) + +- WatermarkAssigner(rowtime=[currency_time], watermark=[-(currency_time, 5000:INTERVAL SECOND)], changelogMode=[UA,D]) + +- Calc(select=[currency, TO_TIMESTAMP(c) AS currency_time], changelogMode=[UA,D]) + +- TableSourceScan(table=[[default_catalog, default_database, 
src_with_computed_column, project=[currency, c]]], fields=[currency, c], changelogMode=[UA,D]) +]]> + </Resource> + </TestCase> + <TestCase name="testPushdownNewCalcAndWatermarkAssignerWithCalc"> + <Resource name="sql"> + <![CDATA[ +SELECT + TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start, + TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end, + MAX(rate) AS max_rate +FROM src_with_computed_column +GROUP BY TUMBLE(currency_time, INTERVAL '5' SECOND) +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(w_start=[TUMBLE_START($0)], w_end=[TUMBLE_END($0)], max_rate=[$1]) ++- LogicalAggregate(group=[{0}], max_rate=[MAX($1)]) + +- LogicalProject($f0=[$TUMBLE($4, 5000:INTERVAL SECOND)], rate=[$2]) + +- LogicalWatermarkAssigner(rowtime=[currency_time], watermark=[-($4, 5000:INTERVAL SECOND)]) + +- LogicalProject(currency=[$0], currency_no=[$1], rate=[$2], c=[$3], currency_time=[TO_TIMESTAMP($3)]) + +- LogicalTableScan(table=[[default_catalog, default_database, src_with_computed_column]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[w$start AS w_start, w$end AS w_end, max_rate], changelogMode=[I]) ++- GroupWindowAggregate(window=[TumblingGroupWindow('w$, currency_time, 5000)], properties=[w$start, w$end, w$rowtime, w$proctime], select=[MAX(rate) AS max_rate, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime], changelogMode=[I]) + +- Exchange(distribution=[single], changelogMode=[I,UB,UA,D]) + +- Calc(select=[currency_time, rate], changelogMode=[I,UB,UA,D]) + +- ChangelogNormalize(key=[$0], changelogMode=[I,UB,UA,D]) + +- Exchange(distribution=[hash[$0]], changelogMode=[UA,D]) Review comment: the field name is dropped ? 
see `$0` ########## File path: flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/stream/WatermarkAssignerChangelogNormalizeTransposeRuleTest.xml ########## @@ -0,0 +1,191 @@ +<?xml version="1.0" ?> +<!-- +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to you under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+--> +<Root> + <TestCase name="testGroupKeyIsComputedColumn"> + <Resource name="sql"> + <![CDATA[ +SELECT + currency2, + COUNT(1) AS cnt, + TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start, + TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end +FROM src_with_computed_column2 +GROUP BY currency2, TUMBLE(currency_time, INTERVAL '5' SECOND) +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(currency2=[$0], cnt=[$2], w_start=[TUMBLE_START($1)], w_end=[TUMBLE_END($1)]) ++- LogicalAggregate(group=[{0, 1}], cnt=[COUNT()]) + +- LogicalProject(currency2=[$1], $f1=[$TUMBLE($5, 5000:INTERVAL SECOND)]) + +- LogicalWatermarkAssigner(rowtime=[currency_time], watermark=[-($5, 5000:INTERVAL SECOND)]) + +- LogicalProject(currency=[$0], currency2=[+($0, 2)], currency_no=[$1], rate=[$2], c=[$3], currency_time=[TO_TIMESTAMP($3)]) + +- LogicalTableScan(table=[[default_catalog, default_database, src_with_computed_column2]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[currency2, cnt, w$start AS w_start, w$end AS w_end], changelogMode=[I]) ++- GroupWindowAggregate(groupBy=[currency2], window=[TumblingGroupWindow('w$, currency_time, 5000)], properties=[w$start, w$end, w$rowtime, w$proctime], select=[currency2, COUNT(*) AS cnt, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime], changelogMode=[I]) + +- Exchange(distribution=[hash[currency2]], changelogMode=[I,UB,UA,D]) + +- Calc(select=[currency2, currency_time], changelogMode=[I,UB,UA,D]) + +- ChangelogNormalize(key=[currency2], changelogMode=[I,UB,UA,D]) + +- Exchange(distribution=[hash[$0]], changelogMode=[UA,D]) + +- WatermarkAssigner(rowtime=[currency_time], watermark=[-(currency_time, 5000:INTERVAL SECOND)], changelogMode=[UA,D]) + +- Calc(select=[+(currency, 2) AS currency2, TO_TIMESTAMP(c) AS currency_time, currency AS $0], changelogMode=[UA,D]) + +- TableSourceScan(table=[[default_catalog, default_database, 
src_with_computed_column2, project=[currency, c]]], fields=[currency, c], changelogMode=[UA,D]) +]]> + </Resource> + </TestCase> + <TestCase name="testPushdownCalcAndWatermarkAssignerWithCalc"> + <Resource name="sql"> + <![CDATA[ +SELECT + currency, + COUNT(1) AS cnt, + TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start, + TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end +FROM src_with_computed_column +GROUP BY currency, TUMBLE(currency_time, INTERVAL '5' SECOND) +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(currency=[$0], cnt=[$2], w_start=[TUMBLE_START($1)], w_end=[TUMBLE_END($1)]) ++- LogicalAggregate(group=[{0, 1}], cnt=[COUNT()]) + +- LogicalProject(currency=[$0], $f1=[$TUMBLE($4, 5000:INTERVAL SECOND)]) + +- LogicalWatermarkAssigner(rowtime=[currency_time], watermark=[-($4, 5000:INTERVAL SECOND)]) + +- LogicalProject(currency=[$0], currency_no=[$1], rate=[$2], c=[$3], currency_time=[TO_TIMESTAMP($3)]) + +- LogicalTableScan(table=[[default_catalog, default_database, src_with_computed_column]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[currency, cnt, w$start AS w_start, w$end AS w_end], changelogMode=[I]) ++- GroupWindowAggregate(groupBy=[currency], window=[TumblingGroupWindow('w$, currency_time, 5000)], properties=[w$start, w$end, w$rowtime, w$proctime], select=[currency, COUNT(*) AS cnt, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime], changelogMode=[I]) + +- Exchange(distribution=[hash[currency]], changelogMode=[I,UB,UA,D]) + +- ChangelogNormalize(key=[currency], changelogMode=[I,UB,UA,D]) + +- Exchange(distribution=[hash[currency]], changelogMode=[UA,D]) + +- WatermarkAssigner(rowtime=[currency_time], watermark=[-(currency_time, 5000:INTERVAL SECOND)], changelogMode=[UA,D]) + +- Calc(select=[currency, TO_TIMESTAMP(c) AS currency_time], changelogMode=[UA,D]) + +- TableSourceScan(table=[[default_catalog, default_database, 
src_with_computed_column, project=[currency, c]]], fields=[currency, c], changelogMode=[UA,D]) +]]> + </Resource> + </TestCase> + <TestCase name="testPushdownNewCalcAndWatermarkAssignerWithCalc"> + <Resource name="sql"> + <![CDATA[ +SELECT + TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start, + TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end, + MAX(rate) AS max_rate +FROM src_with_computed_column +GROUP BY TUMBLE(currency_time, INTERVAL '5' SECOND) +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(w_start=[TUMBLE_START($0)], w_end=[TUMBLE_END($0)], max_rate=[$1]) ++- LogicalAggregate(group=[{0}], max_rate=[MAX($1)]) + +- LogicalProject($f0=[$TUMBLE($4, 5000:INTERVAL SECOND)], rate=[$2]) + +- LogicalWatermarkAssigner(rowtime=[currency_time], watermark=[-($4, 5000:INTERVAL SECOND)]) + +- LogicalProject(currency=[$0], currency_no=[$1], rate=[$2], c=[$3], currency_time=[TO_TIMESTAMP($3)]) + +- LogicalTableScan(table=[[default_catalog, default_database, src_with_computed_column]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[w$start AS w_start, w$end AS w_end, max_rate], changelogMode=[I]) ++- GroupWindowAggregate(window=[TumblingGroupWindow('w$, currency_time, 5000)], properties=[w$start, w$end, w$rowtime, w$proctime], select=[MAX(rate) AS max_rate, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime], changelogMode=[I]) + +- Exchange(distribution=[single], changelogMode=[I,UB,UA,D]) + +- Calc(select=[currency_time, rate], changelogMode=[I,UB,UA,D]) + +- ChangelogNormalize(key=[$0], changelogMode=[I,UB,UA,D]) + +- Exchange(distribution=[hash[$0]], changelogMode=[UA,D]) + +- WatermarkAssigner(rowtime=[currency_time], watermark=[-(currency_time, 5000:INTERVAL SECOND)], changelogMode=[UA,D]) + +- Calc(select=[TO_TIMESTAMP(c) AS currency_time, rate, currency AS $0], changelogMode=[UA,D]) + +- TableSourceScan(table=[[default_catalog, default_database, 
src_with_computed_column, project=[c, rate, currency]]], fields=[c, rate, currency], changelogMode=[UA,D]) +]]> + </Resource> + </TestCase> + <TestCase name="testPushdownWatermarkAssignerWithCalc"> + <Resource name="sql"> + <![CDATA[ +SELECT + TUMBLE_START(currency_time, INTERVAL '5' SECOND) as w_start, + TUMBLE_END(currency_time, INTERVAL '5' SECOND) as w_end, + MAX(rate) AS max_rate +FROM simple_src +GROUP BY TUMBLE(currency_time, INTERVAL '5' SECOND) +]]> + </Resource> + <Resource name="ast"> + <![CDATA[ +LogicalProject(w_start=[TUMBLE_START($0)], w_end=[TUMBLE_END($0)], max_rate=[$1]) ++- LogicalAggregate(group=[{0}], max_rate=[MAX($1)]) + +- LogicalProject($f0=[$TUMBLE($3, 5000:INTERVAL SECOND)], rate=[$2]) + +- LogicalWatermarkAssigner(rowtime=[currency_time], watermark=[-($3, 5000:INTERVAL SECOND)]) + +- LogicalTableScan(table=[[default_catalog, default_database, simple_src]]) +]]> + </Resource> + <Resource name="optimized rel plan"> + <![CDATA[ +Calc(select=[w$start AS w_start, w$end AS w_end, max_rate], changelogMode=[I]) ++- GroupWindowAggregate(window=[TumblingGroupWindow('w$, currency_time, 5000)], properties=[w$start, w$end, w$rowtime, w$proctime], select=[MAX(rate) AS max_rate, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime], changelogMode=[I]) + +- Exchange(distribution=[single], changelogMode=[I,UB,UA,D]) + +- Calc(select=[currency_time, rate], changelogMode=[I,UB,UA,D]) Review comment: The test case name and the result do not match, and the calc should be pushed down then less data will be shuffled before `ChangelogNormalize` ########## File path: flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/WatermarkAssignerChangelogNormalizeTransposeRule.java ########## @@ -62,43 +82,276 @@ public WatermarkAssignerChangelogNormalizeTransposeRule(Config config) { public void onMatch(RelOptRuleCall call) { final StreamPhysicalWatermarkAssigner 
watermark = call.rel(0); final RelNode node = call.rel(1); + RelNode newTree; if (node instanceof StreamPhysicalCalc) { // with calc final StreamPhysicalCalc calc = call.rel(1); final StreamPhysicalChangelogNormalize changelogNormalize = call.rel(2); final StreamPhysicalExchange exchange = call.rel(3); - - final RelNode newTree = - buildTreeInOrder( - changelogNormalize, exchange, watermark, calc, exchange.getInput()); - call.transformTo(newTree); + final Mappings.TargetMapping calcMapping = buildMapping(calc.getProgram()); + final RelDistribution exchangeDistribution = exchange.getDistribution(); + final RelDistribution newExchangeDistribution = exchangeDistribution.apply(calcMapping); + // Pushes down WatermarkAssigner/Calc as a whole if shuffle keys of + // Exchange are all kept by Calc + final boolean shuffleKeysAreKeptByCalc = + newExchangeDistribution.getType() == exchangeDistribution.getType() + && newExchangeDistribution.getKeys().size() + == exchangeDistribution.getKeys().size(); + if (shuffleKeysAreKeptByCalc) { + newTree = + pushDownWatermarkAndCalc( + watermark, + calc, + changelogNormalize, + exchange, + newExchangeDistribution); + } else { + final List<Integer> projectedOutShuffleKeys = + deriveProjectedOutShuffleKeys(exchangeDistribution.getKeys(), calcMapping); + final RexBuilder rexBuilder = call.builder().getRexBuilder(); + // Creates a new Program which contains all shuffle keys + final RexProgram newPushDownProgram = + createNewProgramWithAllShuffleKeys( + calc.getProgram(), projectedOutShuffleKeys, rexBuilder); + if (newPushDownProgram.isPermutation()) { + // Pushes down WatermarkAssigner alone if new pushDown program is permutation + newTree = + pushDownWatermarkAlone( + watermark, + calc, + changelogNormalize, + exchange, + calcMapping, + rexBuilder); + } else { + // Pushes down new WatermarkAssigner/Calc, adds a top Calc to remove new added + // shuffle keys + newTree = + pushDownNewWatermarkAndCalc( + newPushDownProgram, + watermark, + 
calc, + changelogNormalize, + exchange, + rexBuilder); + } + } } else if (node instanceof StreamPhysicalChangelogNormalize) { // without calc final StreamPhysicalChangelogNormalize changelogNormalize = call.rel(1); final StreamPhysicalExchange exchange = call.rel(2); - - final RelNode newTree = - buildTreeInOrder(changelogNormalize, exchange, watermark, exchange.getInput()); - call.transformTo(newTree); + newTree = + buildTreeInOrder( + exchange.getInput(), + // Clears distribution on new WatermarkAssigner + Tuple2.of( + watermark, + watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + Tuple2.of(exchange, exchange.getTraitSet()), + Tuple2.of(changelogNormalize, changelogNormalize.getTraitSet())); } else { throw new IllegalStateException( this.getClass().getName() + " matches a wrong relation tree: " + RelOptUtil.toString(watermark)); } + call.transformTo(newTree); + } + + private RelNode pushDownWatermarkAndCalc( + StreamPhysicalWatermarkAssigner watermark, + StreamPhysicalCalc calc, + StreamPhysicalChangelogNormalize changelogNormalize, + StreamPhysicalExchange exchange, + RelDistribution newExchangeDistribution) { + return buildTreeInOrder( + exchange.getInput(), + // clears distribution on new Calc/WatermarkAssigner + Tuple2.of(calc, calc.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + Tuple2.of(watermark, watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + // updates distribution on new Exchange/Normalize based on field + // mapping of Calc + Tuple2.of(exchange, exchange.getTraitSet().plus(newExchangeDistribution)), + Tuple2.of( + changelogNormalize, + changelogNormalize.getTraitSet().plus(newExchangeDistribution))); + } + + private RelNode pushDownNewWatermarkAndCalc( + RexProgram newPushDownProgram, + StreamPhysicalWatermarkAssigner watermark, + StreamPhysicalCalc calc, + StreamPhysicalChangelogNormalize changelogNormalize, + StreamPhysicalExchange exchange, + RexBuilder rexBuilder) { + final RelNode pushDownCalc = + calc.copy( 
+ // clears distribution on new Calc + calc.getTraitSet().plus(FlinkRelDistribution.DEFAULT()), + exchange.getInput(), + newPushDownProgram); + final Mappings.TargetMapping mappingOfPushDownCalc = buildMapping(newPushDownProgram); + final RelDistribution newDistribution = + exchange.getDistribution().apply(mappingOfPushDownCalc); + final RelNode newChangelogNormalize = + buildTreeInOrder( + pushDownCalc, + Tuple2.of( + watermark, + watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + // updates distribution on new Exchange/Normalize based on field + // mapping of Calc + Tuple2.of(exchange, exchange.getTraitSet().plus(newDistribution)), + Tuple2.of( + changelogNormalize, + changelogNormalize.getTraitSet().plus(newDistribution))); + final List<String> newInputFieldNames = newChangelogNormalize.getRowType().getFieldNames(); + final RexProgramBuilder topProgramBuilder = + new RexProgramBuilder(newChangelogNormalize.getRowType(), rexBuilder); + for (int fieldIdx = 0; fieldIdx < calc.getRowType().getFieldCount(); fieldIdx++) { + topProgramBuilder.addProject( + RexInputRef.of(fieldIdx, newChangelogNormalize.getRowType()), + newInputFieldNames.get(fieldIdx)); + } + final RexProgram topProgram = topProgramBuilder.getProgram(); + return calc.copy(calc.getTraitSet(), newChangelogNormalize, topProgram); + } + + private RelNode pushDownWatermarkAlone( + StreamPhysicalWatermarkAssigner watermark, + StreamPhysicalCalc calc, + StreamPhysicalChangelogNormalize changelogNormalize, + StreamPhysicalExchange exchange, + Mappings.TargetMapping calcMapping, + RexBuilder rexBuilder) { + Mappings.TargetMapping inversedMapping = calcMapping.inverse(); + final int newRowTimeFieldIndex = + inversedMapping.getTargetOpt(watermark.rowtimeFieldIndex()); + // Updates watermark properties after push down before Calc + // 1. rewrites watermark expression + // 2. clears distribution + // 3. 
updates row time field index + RexNode newWatermarkExpr = watermark.watermarkExpr(); + if (watermark.watermarkExpr() != null) { + newWatermarkExpr = RexUtil.apply(inversedMapping, watermark.watermarkExpr()); + } + final RelNode newWatermark = + watermark.copy( + watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT()), + exchange.getInput(), + newRowTimeFieldIndex, + newWatermarkExpr); + final RelNode newChangelogNormalize = + buildTreeInOrder( + newWatermark, + Tuple2.of(exchange, exchange.getTraitSet()), + Tuple2.of(changelogNormalize, changelogNormalize.getTraitSet())); + // Rewrites Calc program because the field type of row time + // field is changed after watermark pushed down + final RexProgram oldProgram = calc.getProgram(); + final RexProgramBuilder programBuilder = + new RexProgramBuilder(newChangelogNormalize.getRowType(), rexBuilder); + final Function<RexNode, RexNode> rexShuttle = + e -> + e.accept( + new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + if (inputRef.getIndex() == newRowTimeFieldIndex) { + return RexInputRef.of( + newRowTimeFieldIndex, + newChangelogNormalize.getRowType()); + } else { + return inputRef; + } + } + }); + oldProgram + .getNamedProjects() + .forEach( + pair -> + programBuilder.addProject( + rexShuttle.apply(oldProgram.expandLocalRef(pair.left)), + pair.right)); + if (oldProgram.getCondition() != null) { + programBuilder.addCondition( + rexShuttle.apply(oldProgram.expandLocalRef(oldProgram.getCondition()))); + } + final RexProgram newProgram = programBuilder.getProgram(); + return calc.copy(calc.getTraitSet(), newChangelogNormalize, newProgram); + } + + private List<Integer> deriveProjectedOutShuffleKeys( + List<Integer> allShuffleKeys, Mappings.TargetMapping calcMapping) { + List<Integer> projectsOutShuffleKeys = new ArrayList<>(); + for (Integer key : allShuffleKeys) { + int targetIdx = calcMapping.getTargetOpt(key); + if (targetIdx < 0) { + projectsOutShuffleKeys.add(key); + } + } + 
return projectsOutShuffleKeys; + } + + private RexProgram createNewProgramWithAllShuffleKeys( + RexProgram program, List<Integer> projectsOutShuffleKeys, RexBuilder rexBuilder) { + RelDataType oldInputRowType = program.getInputRowType(); + RexProgramBuilder newProgramBuilder = new RexProgramBuilder(oldInputRowType, rexBuilder); + program.getNamedProjects() + .forEach( + pair -> + newProgramBuilder.addProject( + program.expandLocalRef(pair.left), pair.right)); + for (Integer projectsOutShuffleKey : projectsOutShuffleKeys) { + newProgramBuilder.addProject( + RexInputRef.of(projectsOutShuffleKey, oldInputRowType), null); + } + if (program.getCondition() != null) { + newProgramBuilder.addCondition(program.expandLocalRef(program.getCondition())); + } + return newProgramBuilder.getProgram(); + } + + private Mappings.TargetMapping buildMapping(RexProgram program) { + final Map<Integer, Integer> mapInToOutPos = new HashMap<>(); + final List<RexLocalRef> projects = program.getProjectList(); + for (int idx = 0; idx < projects.size(); idx++) { + RexNode rexNode = program.expandLocalRef(projects.get(idx)); + if (rexNode instanceof RexInputRef) { + mapInToOutPos.put(((RexInputRef) rexNode).getIndex(), idx); + } + } + return Mappings.target( + mapInToOutPos, + program.getInputRowType().getFieldCount(), + program.getOutputRowType().getFieldCount()); } /** - * Build a new {@link RelNode} tree in the given nodes order which is in root-down direction. + * Build a new {@link RelNode} tree in the given nodes order which is in bottom-up direction. */ - private RelNode buildTreeInOrder(RelNode... 
nodes) { - checkArgument(nodes.length >= 2); - RelNode root = nodes[nodes.length - 1]; - for (int i = nodes.length - 2; i >= 0; i--) { - RelNode node = nodes[i]; - root = node.copy(node.getTraitSet(), Collections.singletonList(root)); + private RelNode buildTreeInOrder( Review comment: nit: add `@SafeVarargs` to make idea happy ########## File path: flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/physical/stream/WatermarkAssignerChangelogNormalizeTransposeRule.java ########## @@ -62,43 +82,276 @@ public WatermarkAssignerChangelogNormalizeTransposeRule(Config config) { public void onMatch(RelOptRuleCall call) { final StreamPhysicalWatermarkAssigner watermark = call.rel(0); final RelNode node = call.rel(1); + RelNode newTree; if (node instanceof StreamPhysicalCalc) { // with calc final StreamPhysicalCalc calc = call.rel(1); final StreamPhysicalChangelogNormalize changelogNormalize = call.rel(2); final StreamPhysicalExchange exchange = call.rel(3); - - final RelNode newTree = - buildTreeInOrder( - changelogNormalize, exchange, watermark, calc, exchange.getInput()); - call.transformTo(newTree); + final Mappings.TargetMapping calcMapping = buildMapping(calc.getProgram()); + final RelDistribution exchangeDistribution = exchange.getDistribution(); + final RelDistribution newExchangeDistribution = exchangeDistribution.apply(calcMapping); + // Pushes down WatermarkAssigner/Calc as a whole if shuffle keys of + // Exchange are all kept by Calc + final boolean shuffleKeysAreKeptByCalc = + newExchangeDistribution.getType() == exchangeDistribution.getType() + && newExchangeDistribution.getKeys().size() + == exchangeDistribution.getKeys().size(); + if (shuffleKeysAreKeptByCalc) { + newTree = + pushDownWatermarkAndCalc( + watermark, + calc, + changelogNormalize, + exchange, + newExchangeDistribution); + } else { + final List<Integer> projectedOutShuffleKeys = + deriveProjectedOutShuffleKeys(exchangeDistribution.getKeys(), calcMapping); 
+ final RexBuilder rexBuilder = call.builder().getRexBuilder(); + // Creates a new Program which contains all shuffle keys + final RexProgram newPushDownProgram = + createNewProgramWithAllShuffleKeys( + calc.getProgram(), projectedOutShuffleKeys, rexBuilder); + if (newPushDownProgram.isPermutation()) { + // Pushes down WatermarkAssigner alone if new pushDown program is permutation + newTree = + pushDownWatermarkAlone( + watermark, + calc, + changelogNormalize, + exchange, + calcMapping, + rexBuilder); + } else { + // Pushes down new WatermarkAssigner/Calc, adds a top Calc to remove new added + // shuffle keys + newTree = + pushDownNewWatermarkAndCalc( + newPushDownProgram, + watermark, + calc, + changelogNormalize, + exchange, + rexBuilder); + } + } } else if (node instanceof StreamPhysicalChangelogNormalize) { // without calc final StreamPhysicalChangelogNormalize changelogNormalize = call.rel(1); final StreamPhysicalExchange exchange = call.rel(2); - - final RelNode newTree = - buildTreeInOrder(changelogNormalize, exchange, watermark, exchange.getInput()); - call.transformTo(newTree); + newTree = + buildTreeInOrder( + exchange.getInput(), + // Clears distribution on new WatermarkAssigner + Tuple2.of( + watermark, + watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + Tuple2.of(exchange, exchange.getTraitSet()), + Tuple2.of(changelogNormalize, changelogNormalize.getTraitSet())); } else { throw new IllegalStateException( this.getClass().getName() + " matches a wrong relation tree: " + RelOptUtil.toString(watermark)); } + call.transformTo(newTree); + } + + private RelNode pushDownWatermarkAndCalc( + StreamPhysicalWatermarkAssigner watermark, + StreamPhysicalCalc calc, + StreamPhysicalChangelogNormalize changelogNormalize, + StreamPhysicalExchange exchange, + RelDistribution newExchangeDistribution) { + return buildTreeInOrder( + exchange.getInput(), + // clears distribution on new Calc/WatermarkAssigner + Tuple2.of(calc, 
calc.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + Tuple2.of(watermark, watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + // updates distribution on new Exchange/Normalize based on field + // mapping of Calc + Tuple2.of(exchange, exchange.getTraitSet().plus(newExchangeDistribution)), + Tuple2.of( + changelogNormalize, + changelogNormalize.getTraitSet().plus(newExchangeDistribution))); + } + + private RelNode pushDownNewWatermarkAndCalc( + RexProgram newPushDownProgram, + StreamPhysicalWatermarkAssigner watermark, + StreamPhysicalCalc calc, + StreamPhysicalChangelogNormalize changelogNormalize, + StreamPhysicalExchange exchange, + RexBuilder rexBuilder) { + final RelNode pushDownCalc = + calc.copy( + // clears distribution on new Calc + calc.getTraitSet().plus(FlinkRelDistribution.DEFAULT()), + exchange.getInput(), + newPushDownProgram); + final Mappings.TargetMapping mappingOfPushDownCalc = buildMapping(newPushDownProgram); + final RelDistribution newDistribution = + exchange.getDistribution().apply(mappingOfPushDownCalc); + final RelNode newChangelogNormalize = + buildTreeInOrder( + pushDownCalc, + Tuple2.of( + watermark, + watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT())), + // updates distribution on new Exchange/Normalize based on field + // mapping of Calc + Tuple2.of(exchange, exchange.getTraitSet().plus(newDistribution)), + Tuple2.of( + changelogNormalize, + changelogNormalize.getTraitSet().plus(newDistribution))); + final List<String> newInputFieldNames = newChangelogNormalize.getRowType().getFieldNames(); + final RexProgramBuilder topProgramBuilder = + new RexProgramBuilder(newChangelogNormalize.getRowType(), rexBuilder); + for (int fieldIdx = 0; fieldIdx < calc.getRowType().getFieldCount(); fieldIdx++) { + topProgramBuilder.addProject( + RexInputRef.of(fieldIdx, newChangelogNormalize.getRowType()), + newInputFieldNames.get(fieldIdx)); + } + final RexProgram topProgram = topProgramBuilder.getProgram(); + return 
calc.copy(calc.getTraitSet(), newChangelogNormalize, topProgram); + } + + private RelNode pushDownWatermarkAlone( + StreamPhysicalWatermarkAssigner watermark, + StreamPhysicalCalc calc, + StreamPhysicalChangelogNormalize changelogNormalize, + StreamPhysicalExchange exchange, + Mappings.TargetMapping calcMapping, + RexBuilder rexBuilder) { + Mappings.TargetMapping inversedMapping = calcMapping.inverse(); + final int newRowTimeFieldIndex = + inversedMapping.getTargetOpt(watermark.rowtimeFieldIndex()); + // Updates watermark properties after push down before Calc + // 1. rewrites watermark expression + // 2. clears distribution + // 3. updates row time field index + RexNode newWatermarkExpr = watermark.watermarkExpr(); + if (watermark.watermarkExpr() != null) { + newWatermarkExpr = RexUtil.apply(inversedMapping, watermark.watermarkExpr()); + } + final RelNode newWatermark = + watermark.copy( + watermark.getTraitSet().plus(FlinkRelDistribution.DEFAULT()), + exchange.getInput(), + newRowTimeFieldIndex, + newWatermarkExpr); + final RelNode newChangelogNormalize = + buildTreeInOrder( + newWatermark, + Tuple2.of(exchange, exchange.getTraitSet()), + Tuple2.of(changelogNormalize, changelogNormalize.getTraitSet())); + // Rewrites Calc program because the field type of row time + // field is changed after watermark pushed down + final RexProgram oldProgram = calc.getProgram(); + final RexProgramBuilder programBuilder = + new RexProgramBuilder(newChangelogNormalize.getRowType(), rexBuilder); + final Function<RexNode, RexNode> rexShuttle = + e -> + e.accept( + new RexShuttle() { + @Override + public RexNode visitInputRef(RexInputRef inputRef) { + if (inputRef.getIndex() == newRowTimeFieldIndex) { + return RexInputRef.of( + newRowTimeFieldIndex, + newChangelogNormalize.getRowType()); + } else { + return inputRef; + } + } + }); + oldProgram + .getNamedProjects() + .forEach( + pair -> + programBuilder.addProject( + rexShuttle.apply(oldProgram.expandLocalRef(pair.left)), + 
pair.right)); + if (oldProgram.getCondition() != null) { + programBuilder.addCondition( + rexShuttle.apply(oldProgram.expandLocalRef(oldProgram.getCondition()))); + } + final RexProgram newProgram = programBuilder.getProgram(); + return calc.copy(calc.getTraitSet(), newChangelogNormalize, newProgram); + } + + private List<Integer> deriveProjectedOutShuffleKeys( + List<Integer> allShuffleKeys, Mappings.TargetMapping calcMapping) { + List<Integer> projectsOutShuffleKeys = new ArrayList<>(); + for (Integer key : allShuffleKeys) { + int targetIdx = calcMapping.getTargetOpt(key); + if (targetIdx < 0) { + projectsOutShuffleKeys.add(key); + } + } + return projectsOutShuffleKeys; + } + + private RexProgram createNewProgramWithAllShuffleKeys( + RexProgram program, List<Integer> projectsOutShuffleKeys, RexBuilder rexBuilder) { + RelDataType oldInputRowType = program.getInputRowType(); + RexProgramBuilder newProgramBuilder = new RexProgramBuilder(oldInputRowType, rexBuilder); + program.getNamedProjects() + .forEach( + pair -> + newProgramBuilder.addProject( + program.expandLocalRef(pair.left), pair.right)); + for (Integer projectsOutShuffleKey : projectsOutShuffleKeys) { + newProgramBuilder.addProject( + RexInputRef.of(projectsOutShuffleKey, oldInputRowType), null); + } + if (program.getCondition() != null) { + newProgramBuilder.addCondition(program.expandLocalRef(program.getCondition())); + } + return newProgramBuilder.getProgram(); + } + + private Mappings.TargetMapping buildMapping(RexProgram program) { + final Map<Integer, Integer> mapInToOutPos = new HashMap<>(); + final List<RexLocalRef> projects = program.getProjectList(); + for (int idx = 0; idx < projects.size(); idx++) { + RexNode rexNode = program.expandLocalRef(projects.get(idx)); + if (rexNode instanceof RexInputRef) { + mapInToOutPos.put(((RexInputRef) rexNode).getIndex(), idx); + } + } + return Mappings.target( + mapInToOutPos, + program.getInputRowType().getFieldCount(), + 
program.getOutputRowType().getFieldCount()); } /** - * Build a new {@link RelNode} tree in the given nodes order which is in root-down direction. + * Build a new {@link RelNode} tree in the given nodes order which is in bottom-up direction. */ - private RelNode buildTreeInOrder(RelNode... nodes) { - checkArgument(nodes.length >= 2); - RelNode root = nodes[nodes.length - 1]; - for (int i = nodes.length - 2; i >= 0; i--) { - RelNode node = nodes[i]; - root = node.copy(node.getTraitSet(), Collections.singletonList(root)); + private RelNode buildTreeInOrder( + RelNode leafNode, Tuple2<RelNode, RelTraitSet>... nodeAndTraits) { + checkArgument(nodeAndTraits.length >= 1); + RelNode inputNode = leafNode; + RelNode currentNode = null; + for (Tuple2<RelNode, RelTraitSet> nodeAndPair : nodeAndTraits) { Review comment: nit: nodeAndPair -> nodeAndTrait -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
