This is an automated email from the ASF dual-hosted git repository. dwysakowicz pushed a commit to branch release-1.9 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push: new 1de0053 [FLINK-12249][table] Fix type equivalence check problems for Window Aggregates 1de0053 is described below commit 1de005392022404adbce4cd8b9de90f157052bd0 Author: hequn8128 <chenghe...@gmail.com> AuthorDate: Fri Jul 19 16:53:17 2019 +0800 [FLINK-12249][table] Fix type equivalence check problems for Window Aggregates This closes #9141 --- .../flink/sql/tests/StreamSQLTestProgram.java | 4 +- .../logical/LogicalWindowAggregateRuleBase.scala | 100 ++++++++++++++++--- .../plan/batch/sql/agg/WindowAggregateTest.xml | 107 +++++++++++++++++++++ .../plan/stream/sql/agg/WindowAggregateTest.xml | 35 +++++++ .../plan/batch/sql/agg/WindowAggregateTest.scala | 22 +++++ .../plan/stream/sql/agg/WindowAggregateTest.scala | 21 ++++ .../rules/common/LogicalWindowAggregateRule.scala | 96 ++++++++++++++++-- .../table/api/batch/sql/GroupWindowTest.scala | 41 ++++++++ .../table/api/stream/sql/GroupWindowTest.scala | 43 +++++++++ 9 files changed, 448 insertions(+), 21 deletions(-) diff --git a/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java b/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java index cde040d..47bca8e 100644 --- a/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java +++ b/flink-end-to-end-tests/flink-stream-sql-test/src/main/java/org/apache/flink/sql/tests/StreamSQLTestProgram.java @@ -106,9 +106,7 @@ public class StreamSQLTestProgram { String tumbleQuery = String.format( "SELECT " + " key, " + - //TODO: The "WHEN -1 THEN NULL" part is a temporary workaround, to make the test pass, for - // https://issues.apache.org/jira/browse/FLINK-12249. We should remove it once the issue is fixed. - " CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 WHEN -1 THEN NULL ELSE 99 END AS correct, " + + " CASE SUM(cnt) / COUNT(*) WHEN 101 THEN 1 ELSE 99 END AS correct, " + " TUMBLE_START(rowtime, INTERVAL '%d' SECOND) AS wStart, " + " TUMBLE_ROWTIME(rowtime, INTERVAL '%d' SECOND) AS rowtime " + "FROM (%s) " + diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala index 6c24296..ee24adb 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/LogicalWindowAggregateRuleBase.scala @@ -33,8 +33,10 @@ import org.apache.calcite.plan._ import org.apache.calcite.plan.hep.HepRelVertex import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.core.Aggregate.Group +import org.apache.calcite.rel.core.{Aggregate, AggregateCall} import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject} import org.apache.calcite.rex._ +import org.apache.calcite.sql.`type`.SqlTypeUtil import org.apache.calcite.util.ImmutableBitSet import _root_.java.math.BigDecimal @@ -84,15 +86,31 @@ abstract class LogicalWindowAggregateRuleBase(description: String) .project(project.getChildExps.updated(windowExprIdx, inAggGroupExpression)) .build() + // Currently, this rule removes the window from GROUP BY operation which may lead to changes + // of AggCall's type which brings fails on type checks. + // To solve the problem, we change the types to the inferred types in the Aggregate and then + // cast back in the project after Aggregate. + val indexAndTypes = getIndexAndInferredTypesIfChanged(agg) + val finalCalls = adjustTypes(agg, indexAndTypes) + // we don't use the builder here because it uses RelMetadataQuery which affects the plan val newAgg = LogicalAggregate.create( newProject, agg.indicator, newGroupSet, ImmutableList.of(newGroupSet), - agg.getAggCallList) + finalCalls) + + val transformed = call.builder() + val windowAgg = LogicalWindowAggregate.create( + window, + Seq[PlannerNamedWindowProperty](), + newAgg) + transformed.push(windowAgg) - // create an additional project to conform with types + // The transformation adds an additional LogicalProject at the top to ensure + // that the types are equivalent. + // 1. ensure group key types, create an additional project to conform with types val outAggGroupExpression0 = getOutAggregateGroupExpression(rexBuilder, windowExpr) // fix up the nullability if it is changed. val outAggGroupExpression = if (windowExpr.getType.isNullable != @@ -103,20 +121,80 @@ abstract class LogicalWindowAggregateRuleBase(description: String) } else { outAggGroupExpression0 } - val transformed = call.builder() - val windowAgg = LogicalWindowAggregate.create( - window, - Seq[PlannerNamedWindowProperty](), - newAgg) - // The transformation adds an additional LogicalProject at the top to ensure - // that the types are equivalent. - transformed.push(windowAgg) - .project(transformed.fields().patch(windowExprIdx, Seq(outAggGroupExpression), 0)) + val projectsEnsureGroupKeyTypes = + transformed.fields.patch(windowExprIdx, Seq(outAggGroupExpression), 0) + // 2. ensure aggCall types + val projectsEnsureAggCallTypes = + projectsEnsureGroupKeyTypes.zipWithIndex.map { + case (aggCall, index) => + val aggCallIndex = index - agg.getGroupCount + if (indexAndTypes.containsKey(aggCallIndex)) { + rexBuilder.makeCast(agg.getAggCallList.get(aggCallIndex).`type`, aggCall, true) + } else { + aggCall + } + } + transformed.project(projectsEnsureAggCallTypes) val result = transformed.build() call.transformTo(result) } + /** + * Change the types of [[AggregateCall]] to the corresponding inferred types. + */ + private def adjustTypes( + agg: LogicalAggregate, + indexAndTypes: Map[Int, RelDataType]) = { + + agg.getAggCallList.zipWithIndex.map { + case (aggCall, index) => + if (indexAndTypes.containsKey(index)) { + AggregateCall.create( + aggCall.getAggregation, + aggCall.isDistinct, + aggCall.isApproximate, + aggCall.ignoreNulls(), + aggCall.getArgList, + aggCall.filterArg, + aggCall.collation, + agg.getGroupCount, + agg.getInput, + indexAndTypes(index), + aggCall.name) + } else { + aggCall + } + } + } + + /** + * Check if there are any types of [[AggregateCall]] that need to be changed. Return the + * [[AggregateCall]] indexes and the corresponding inferred types. + */ + private def getIndexAndInferredTypesIfChanged( + agg: LogicalAggregate) + : Map[Int, RelDataType] = { + + agg.getAggCallList.zipWithIndex.flatMap { + case (aggCall, index) => + val origType = aggCall.`type` + val aggCallBinding = new Aggregate.AggCallBinding( + agg.getCluster.getTypeFactory, + aggCall.getAggregation, + SqlTypeUtil.projectTypes(agg.getInput.getRowType, aggCall.getArgList), + 0, + aggCall.hasFilter) + val inferredType = aggCall.getAggregation.inferReturnType(aggCallBinding) + + if (origType != inferredType && agg.getGroupCount == 1) { + Some(index, inferredType) + } else { + None + } + }.toMap + } + private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = { val project = agg.getInput.asInstanceOf[HepRelVertex].getCurrentRel.asInstanceOf[LogicalProject] val groupKeys = agg.getGroupSet diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml index 88ac82a..6286349 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.xml @@ -1409,4 +1409,111 @@ Calc(select=[w$end AS EXPR$0]) ]]> </Resource> </TestCase> + <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=AUTO]"> + <Resource name="sql"> + <![CDATA[ +SELECT + SUM(correct) AS s, + AVG(correct) AS a, + TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart +FROM ( + SELECT CASE a + WHEN 1 THEN 1 + ELSE 99 + END AS correct, b + FROM MyTable +) +GROUP BY TUMBLE(b, INTERVAL '15' MINUTE) + ]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)]) ++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)]) + +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)]) + +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart]) ++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1]) + +- Exchange(distribution=[single]) + +- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1]) + +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1]) + +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d]) +]]> + </Resource> + </TestCase> + <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=ONE_PHASE]"> + <Resource name="sql"> + <![CDATA[ +SELECT + SUM(correct) AS s, + AVG(correct) AS a, + TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart +FROM ( + SELECT CASE a + WHEN 1 THEN 1 + ELSE 99 + END AS correct, b + FROM MyTable +) +GROUP BY TUMBLE(b, INTERVAL '15' MINUTE) + ]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)]) ++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)]) + +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)]) + +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart]) ++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1]) + +- Exchange(distribution=[single]) + +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1]) + +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d]) +]]> + </Resource> + </TestCase> + <TestCase name="testReturnTypeInferenceForWindowAgg[aggStrategy=TWO_PHASE]"> + <Resource name="sql"> + <![CDATA[ +SELECT + SUM(correct) AS s, + AVG(correct) AS a, + TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart +FROM ( + SELECT CASE a + WHEN 1 THEN 1 + ELSE 99 + END AS correct, b + FROM MyTable +) +GROUP BY TUMBLE(b, INTERVAL '15' MINUTE) + ]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)]) ++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)]) + +- LogicalProject($f0=[TUMBLE($1, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)]) + +- LogicalTableScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart]) ++- HashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Final_$SUM0(sum$0) AS s, Final_COUNT(count1$1) AS $f1]) + +- Exchange(distribution=[single]) + +- LocalHashWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime], select=[Partial_$SUM0($f1) AS sum$0, Partial_COUNT(*) AS count1$1]) + +- Calc(select=[b, CASE(=(a, 1), 1, 99) AS $f1]) + +- TableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d)]]], fields=[a, b, c, d]) +]]> + </Resource> + </TestCase> </Root> diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml index 5847055..9f10574e 100644 --- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml +++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.xml @@ -467,4 +467,39 @@ Calc(select=[EXPR$0, wAvg, w$start AS EXPR$2, w$end AS EXPR$3]) ]]> </Resource> </TestCase> + <TestCase name="testReturnTypeInferenceForWindowAgg"> + <Resource name="sql"> + <![CDATA[ +SELECT + SUM(correct) AS s, + AVG(correct) AS a, + TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart +FROM ( + SELECT CASE a + WHEN 1 THEN 1 + ELSE 99 + END AS correct, rowtime + FROM MyTable +) +GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE) + ]]> + </Resource> + <Resource name="planBefore"> + <![CDATA[ +LogicalProject(s=[$1], a=[$2], wStart=[TUMBLE_START($0)]) ++- LogicalAggregate(group=[{0}], s=[SUM($1)], a=[AVG($1)]) + +- LogicalProject($f0=[TUMBLE($4, 900000:INTERVAL MINUTE)], correct=[CASE(=($0, 1), 1, 99)]) + +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) +]]> + </Resource> + <Resource name="planAfter"> + <![CDATA[ +Calc(select=[CAST(CASE(=($f1, 0), null:INTEGER, s)) AS s, CAST(/(CAST(CASE(=($f1, 0), null:INTEGER, s)), $f1)) AS a, w$start AS wStart]) ++- GroupWindowAggregate(window=[TumblingGroupWindow], properties=[w$start, w$end, w$rowtime, w$proctime], select=[$SUM0($f1) AS s, COUNT(*) AS $f1, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime]) + +- Exchange(distribution=[single]) + +- Calc(select=[rowtime, CASE(=(a, 1), 1, 99) AS $f1]) + +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) +]]> + </Resource> + </TestCase> </Root> diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala index be1ad8b..0b021b7 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/batch/sql/agg/WindowAggregateTest.scala @@ -300,6 +300,28 @@ class WindowAggregateTest(aggStrategy: AggregatePhaseStrategy) extends TableTest """.stripMargin util.verifyPlan(sql) } + + @Test + def testReturnTypeInferenceForWindowAgg() = { + + val sql = + """ + |SELECT + | SUM(correct) AS s, + | AVG(correct) AS a, + | TUMBLE_START(b, INTERVAL '15' MINUTE) AS wStart + |FROM ( + | SELECT CASE a + | WHEN 1 THEN 1 + | ELSE 99 + | END AS correct, b + | FROM MyTable + |) + |GROUP BY TUMBLE(b, INTERVAL '15' MINUTE) + """.stripMargin + + util.verifyPlan(sql) + } } object WindowAggregateTest { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala index 414450e..3c773ca 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/WindowAggregateTest.scala @@ -295,4 +295,25 @@ class WindowAggregateTest extends TableTestBase { util.verifyPlan(sql) } + @Test + def testReturnTypeInferenceForWindowAgg() = { + + val sql = + """ + |SELECT + | SUM(correct) AS s, + | AVG(correct) AS a, + | TUMBLE_START(rowtime, INTERVAL '15' MINUTE) AS wStart + |FROM ( + | SELECT CASE a + | WHEN 1 THEN 1 + | ELSE 99 + | END AS correct, rowtime + | FROM MyTable + |) + |GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE) + """.stripMargin + + util.verifyPlan(sql) + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala index 8c0f0c0..431fe9e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/rules/common/LogicalWindowAggregateRule.scala @@ -21,8 +21,10 @@ import com.google.common.collect.ImmutableList import org.apache.calcite.plan._ import org.apache.calcite.plan.hep.HepRelVertex import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.{Aggregate, AggregateCall} import org.apache.calcite.rel.logical.{LogicalAggregate, LogicalProject} import org.apache.calcite.rex._ +import org.apache.calcite.sql.`type`.SqlTypeUtil import org.apache.calcite.util.ImmutableBitSet import org.apache.flink.table.api._ import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty @@ -78,24 +80,104 @@ abstract class LogicalWindowAggregateRule(ruleName: String) .project(project.getChildExps.updated(windowExprIdx, inAggGroupExpression)) .build() + // Currently, this rule removes the window from GROUP BY operation which may lead to changes + // of AggCall's type which brings fails on type checks. + // To solve the problem, we change the types to the inferred types in the Aggregate and then + // cast back in the project after Aggregate. + val indexAndTypes = getIndexAndInferredTypesIfChanged(agg) + val finalCalls = adjustTypes(agg, indexAndTypes) + // we don't use the builder here because it uses RelMetadataQuery which affects the plan val newAgg = LogicalAggregate.create( newProject, agg.indicator, newGroupSet, ImmutableList.of(newGroupSet), - agg.getAggCallList) + finalCalls) - // create an additional project to conform with types - val outAggGroupExpression = getOutAggregateGroupExpression(rexBuilder, windowExpr) val transformed = call.builder() - transformed.push(LogicalWindowAggregate.create( + val windowAgg = LogicalWindowAggregate.create( window, Seq[NamedWindowProperty](), - newAgg)) - .project(transformed.fields().patch(windowExprIdx, Seq(outAggGroupExpression), 0)) + newAgg) + transformed.push(windowAgg) - call.transformTo(transformed.build()) + // The transformation adds an additional LogicalProject at the top to ensure + // that the types are equivalent. + // 1. ensure group key types, create an additional project to conform with types + val outAggGroupExpression = getOutAggregateGroupExpression(rexBuilder, windowExpr) + val projectsEnsureGroupKeyTypes = + transformed.fields.patch(windowExprIdx, Seq(outAggGroupExpression), 0) + // 2. ensure aggCall types + val projectsEnsureAggCallTypes = + projectsEnsureGroupKeyTypes.zipWithIndex.map { + case (aggCall, index) => + val aggCallIndex = index - agg.getGroupCount + if (indexAndTypes.containsKey(aggCallIndex)) { + rexBuilder.makeCast(agg.getAggCallList.get(aggCallIndex).`type`, aggCall, true) + } else { + aggCall + } + } + transformed.project(projectsEnsureAggCallTypes) + + val result = transformed.build() + call.transformTo(result) + } + + /** + * Change the types of [[AggregateCall]] to the corresponding inferred types. + */ + private def adjustTypes( + agg: LogicalAggregate, + indexAndTypes: Map[Int, RelDataType]) = { + + agg.getAggCallList.zipWithIndex.map { + case (aggCall, index) => + if (indexAndTypes.containsKey(index)) { + AggregateCall.create( + aggCall.getAggregation, + aggCall.isDistinct, + aggCall.isApproximate, + aggCall.ignoreNulls(), + aggCall.getArgList, + aggCall.filterArg, + aggCall.collation, + agg.getGroupCount, + agg.getInput, + indexAndTypes(index), + aggCall.name) + } else { + aggCall + } + } + } + + /** + * Check if there are any types of [[AggregateCall]] that need to be changed. Return the + * [[AggregateCall]] indexes and the corresponding inferred types. + */ + private def getIndexAndInferredTypesIfChanged( + agg: LogicalAggregate) + : Map[Int, RelDataType] = { + + agg.getAggCallList.zipWithIndex.flatMap { + case (aggCall, index) => + val origType = aggCall.`type` + val aggCallBinding = new Aggregate.AggCallBinding( + agg.getCluster.getTypeFactory, + aggCall.getAggregation, + SqlTypeUtil.projectTypes(agg.getInput.getRowType, aggCall.getArgList), + 0, + aggCall.hasFilter) + val inferredType = aggCall.getAggregation.inferReturnType(aggCallBinding) + + if (origType != inferredType && agg.getGroupCount == 1) { + Some(index, inferredType) + } else { + None + } + }.toMap } private[table] def getWindowExpressions(agg: LogicalAggregate): Seq[(RexCall, Int)] = { diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala index 07c4067..b5091ee 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/batch/sql/GroupWindowTest.scala @@ -346,4 +346,45 @@ class GroupWindowTest extends TableTestBase { util.verifySql(sql, expected) } + + @Test + def testReturnTypeInferenceForWindowAgg(): Unit = { + val util = batchTestUtil() + val table = util.addTable[(Int, Long, String, Timestamp)]("MyTable", 'a, 'b, 'c, 'rowtime) + + val innerQuery = + """ + |SELECT + | CASE a WHEN 1 THEN 1 ELSE 99 END AS correct, + | rowtime + |FROM MyTable + """.stripMargin + + val sqlQuery = + "SELECT " + + " sum(correct) as s, " + + " avg(correct) as a, " + + " TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as wStart " + + s"FROM ($innerQuery) " + + "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetWindowAggregate", + unaryNode( + "DataSetCalc", + batchTableNode(table), + term("select", "CASE(=(a, 1), 1, 99) AS correct, rowtime") + ), + term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"), + term("select", "SUM(correct) AS s, AVG(correct) AS a, start('w$) AS w$start," + + " end('w$) AS w$end, rowtime('w$) AS w$rowtime") + ), + term("select", "CAST(s) AS s", "CAST(a) AS a", "CAST(w$start) AS wStart") + ) + + util.verifySql(sqlQuery, expected) + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala index a8c456f..5acef08 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/api/stream/sql/GroupWindowTest.scala @@ -301,4 +301,47 @@ class GroupWindowTest extends TableTestBase { ) streamUtil.verifySql(sql, expected) } + + @Test + def testReturnTypeInferenceForWindowAgg() = { + + val innerQuery = + """ + |SELECT + | CASE a WHEN 1 THEN 1 ELSE 99 END AS correct, + | rowtime + |FROM MyTable + """.stripMargin + + val sql = + "SELECT " + + " sum(correct) as s, " + + " avg(correct) as a, " + + " TUMBLE_START(rowtime, INTERVAL '15' MINUTE) as wStart " + + s"FROM ($innerQuery) " + + "GROUP BY TUMBLE(rowtime, INTERVAL '15' MINUTE)" + + val expected = + unaryNode( + "DataStreamCalc", + unaryNode( + "DataStreamGroupWindowAggregate", + unaryNode( + "DataStreamCalc", + streamTableNode(table), + term("select", "CASE(=(a, 1), 1, 99) AS correct", "rowtime") + ), + term("window", "TumblingGroupWindow('w$, 'rowtime, 900000.millis)"), + term("select", + "SUM(correct) AS s", + "AVG(correct) AS a", + "start('w$) AS w$start", + "end('w$) AS w$end", + "rowtime('w$) AS w$rowtime", + "proctime('w$) AS w$proctime") + ), + term("select", "CAST(s) AS s", "CAST(a) AS a", "w$start AS wStart") + ) + streamUtil.verifySql(sql, expected) + } }