This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch release-1.10 in repository https://gitbox.apache.org/repos/asf/flink.git
commit fefc92ace2e644f7bd0619bc5103095719f1a7bb Author: Jark Wu <[email protected]> AuthorDate: Thu Jun 4 11:57:58 2020 +0800 [FLINK-16451][table-planner-blink] Fix IndexOutOfBoundsException for DISTINCT AGG with constants This closes #12432 --- .../types/logical/utils/LogicalTypeChecks.java | 35 ++++++++++++++++++++++ .../codegen/agg/AggsHandlerCodeGenerator.scala | 1 + .../planner/codegen/agg/DistinctAggCodeGen.scala | 25 +++++++++++----- .../runtime/stream/sql/OverWindowITCase.scala | 8 +++-- 4 files changed, 60 insertions(+), 9 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java index c1d35a2..3437835 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeChecks.java @@ -30,7 +30,9 @@ import org.apache.flink.table.types.logical.LocalZonedTimestampType; import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.SmallIntType; +import org.apache.flink.table.types.logical.StructuredType; import org.apache.flink.table.types.logical.TimeType; import org.apache.flink.table.types.logical.TimestampKind; import org.apache.flink.table.types.logical.TimestampType; @@ -66,6 +68,8 @@ public final class LogicalTypeChecks { private static final SingleFieldIntervalExtractor SINGLE_FIELD_INTERVAL_EXTRACTOR = new SingleFieldIntervalExtractor(); + private static final FieldCountExtractor FIELD_COUNT_EXTRACTOR = new FieldCountExtractor(); + public static boolean hasRoot(LogicalType logicalType, LogicalTypeRoot typeRoot) { return logicalType.getTypeRoot() == typeRoot; } @@ -105,6 +109,13 @@ public final class LogicalTypeChecks { return logicalType.accept(LENGTH_EXTRACTOR); } + /** + * Returns the field count of row and structured types. + */ + public static int getFieldCount(LogicalType logicalType) { + return logicalType.accept(FIELD_COUNT_EXTRACTOR); + } + public static boolean hasLength(LogicalType logicalType, int length) { return getLength(logicalType) == length; } @@ -340,6 +351,30 @@ public final class LogicalTypeChecks { } } + private static class FieldCountExtractor extends Extractor<Integer> { + + @Override + public Integer visit(RowType rowType) { + return rowType.getFieldCount(); + } + + @Override + public Integer visit(StructuredType structuredType) { + int fieldCount = 0; + StructuredType currentType = structuredType; + while (currentType != null) { + fieldCount += currentType.getAttributes().size(); + currentType = currentType.getSuperType().orElse(null); + } + return fieldCount; + } + + @Override + public Integer visit(DistinctType distinctType) { + return distinctType.getSourceType().accept(this); + } + } + private static class SingleFieldIntervalExtractor extends Extractor<Boolean> { @Override diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala index ba81b14..a53b7c0 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala @@ -259,6 +259,7 @@ class AggsHandlerCodeGenerator( index, innerCodeGens, filterExpr.toArray, + constantExprs, mergedAccOffset, aggBufferOffset, aggBufferSize, diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala index 9f53438..58f069d 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala @@ -31,6 +31,7 @@ import org.apache.flink.table.planner.expressions.converter.ExpressionConverter import org.apache.flink.table.planner.plan.utils.DistinctInfo import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType import org.apache.flink.table.types.DataType +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks import org.apache.flink.table.types.logical.{LogicalType, RowType} import org.apache.flink.util.Preconditions import org.apache.flink.util.Preconditions.checkArgument @@ -63,6 +64,7 @@ class DistinctAggCodeGen( distinctIndex: Int, innerAggCodeGens: Array[AggCodeGen], filterExpressions: Array[Option[Expression]], + constantExpressions: Seq[GeneratedExpression], mergedAccOffset: Int, aggBufferOffset: Int, aggBufferSize: Int, @@ -371,13 +373,22 @@ class DistinctAggCodeGen( private def generateKeyExpression( ctx: CodeGeneratorContext, generator: ExprCodeGenerator): GeneratedExpression = { - val fieldExprs = distinctInfo.argIndexes.map(generateInputAccess( - ctx, - generator.input1Type, - generator.input1Term, - _, - nullableInput = false, - deepCopy = inputFieldCopy)) + val fieldExprs = distinctInfo.argIndexes.map(argIndex => { + val inputFieldCount = LogicalTypeChecks.getFieldCount(generator.input1Type) + if (argIndex >= inputFieldCount) { + // arg index to constant + constantExpressions(argIndex - inputFieldCount) + } else { + // arg index to input field + generateInputAccess( + ctx, + generator.input1Type, + generator.input1Term, + argIndex, + nullableInput = false, + deepCopy = inputFieldCopy) + } + }) // the key expression of MapView if (fieldExprs.length > 1) { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala index 2e8cf75..a0c13e1 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverWindowITCase.scala @@ -280,14 +280,18 @@ class OverWindowITCase(mode: StateBackendMode) extends StreamingWithStateTestBas tEnv.registerTable("T1", t1) val sqlQuery = "SELECT " + - "count(a) OVER (ORDER BY proctime ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW) " + + "listagg(distinct c, '|') " + + " OVER (ORDER BY proctime ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW), " + + "count(a) " + + " OVER (ORDER BY proctime ROWS BETWEEN UNBOUNDED preceding AND CURRENT ROW) " + "from T1" val sink = new TestingAppendSink tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink) env.execute() - val expected = List("1", "2", "3", "4", "5", "6", "7", "8", "9") + val expected = List("Hello,1", "Hello,2", "Hello,3", "Hello,4", "Hello,5", "Hello,6", + "Hello|Hello World,7", "Hello|Hello World,8", "Hello|Hello World,9") assertEquals(expected.sorted, sink.getAppendResults.sorted) }
