This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch release-1.13
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.13 by this push:
new c623db4 [FLINK-19449][table-planner] LEAD/LAG cannot work correctly
in streaming mode
c623db4 is described below
commit c623db4dcf7549fed263ab6f92aac49a94a25897
Author: Jingsong Lee <[email protected]>
AuthorDate: Wed Apr 28 17:23:38 2021 +0800
[FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming
mode
This closes #15793
---
docs/data/sql_functions.yml | 6 +-
.../functions/aggfunctions/LagAggFunction.java | 163 +++++++++++++++++
.../stream/StreamExecGlobalWindowAggregate.java | 4 +-
.../stream/StreamExecLocalWindowAggregate.java | 2 +-
.../exec/stream/StreamExecWindowAggregate.java | 2 +-
.../plan/metadata/FlinkRelMdColumnInterval.scala | 25 ++-
.../StreamPhysicalGlobalWindowAggregate.scala | 2 +-
.../StreamPhysicalLocalWindowAggregate.scala | 2 +-
.../stream/StreamPhysicalWindowAggregate.scala | 2 +-
.../planner/plan/utils/AggFunctionFactory.scala | 35 +++-
.../table/planner/plan/utils/AggregateUtil.scala | 27 ++-
.../functions/aggfunctions/LagAggFunctionTest.java | 62 +++++++
.../plan/metadata/FlinkRelMdHandlerTestBase.scala | 9 +-
.../runtime/stream/sql/OverAggregateITCase.scala | 68 +++++++
.../runtime/typeutils/LinkedListSerializer.java | 203 +++++++++++++++++++++
.../typeutils/LinkedListSerializerTest.java | 72 ++++++++
16 files changed, 646 insertions(+), 38 deletions(-)
diff --git a/docs/data/sql_functions.yml b/docs/data/sql_functions.yml
index 6b7caa15..51df9d1 100644
--- a/docs/data/sql_functions.yml
+++ b/docs/data/sql_functions.yml
@@ -674,10 +674,10 @@ aggregate:
- sql: ROW_NUMBER()
description: Assigns a unique, sequential number to each row, starting
with one, according to the ordering of rows within the window partition.
ROW_NUMBER and RANK are similar. ROW_NUMBER numbers all rows sequentially (for
example 1, 2, 3, 4, 5). RANK provides the same numeric value for ties (for
example 1, 2, 2, 4, 5).
- sql: LEAD(expression [, offset] [, default])
- description: Returns the value of expression at the offsetth row before
the current row in the window. The default value of offset is 1 and the default
value of default is NULL.
- - sql: LAG(expression [, offset] [, default])
description: Returns the value of expression at the offsetth row after the
current row in the window. The default value of offset is 1 and the default
value of default is NULL.
- - sql: FIRST_VALUE(expression)
+ - sql: LAG(expression [, offset] [, default])
+ description: Returns the value of expression at the offsetth row before
the current row in the window. The default value of offset is 1 and the default
value of default is NULL.
+ - sql: FIRST_VALUE(expression)
description: Returns the first value in an ordered set of values.
- sql: LAST_VALUE(expression)
description: Returns the last value in an ordered set of values.
diff --git
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
new file mode 100644
index 0000000..2ad9b97
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunction.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.aggfunctions;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.functions.AggregateFunction;
+import
org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
+import org.apache.flink.table.runtime.typeutils.InternalSerializers;
+import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.LogicalTypeRoot;
+import org.apache.flink.table.types.utils.DataTypeUtils;
+
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Lag {@link AggregateFunction}.
+ *
+ * <p>The accumulator buffers the most recently accumulated values, at most {@code offset + 1} of
+ * them. Once that many values have been seen, {@link #getValue(LagAcc)} returns the oldest
+ * buffered one, i.e. the value accumulated {@code offset} calls before the latest; until then it
+ * returns the configured default value ({@code null} unless one was supplied).
+ *
+ * @param <T> internal type of the lagged value, which is also the result type
+ */
+public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {
+
+    /** Argument types in call order: value, then optionally offset, then optionally default. */
+    private final transient DataType[] valueDataTypes;
+
+    /**
+     * Creates the function for the given argument types.
+     *
+     * @param valueTypes logical types of the call arguments; index 0 is the value, index 2 (if
+     *     present) is the default value
+     * @throws TableException if a non-NULL default value type is given whose internal conversion
+     *     class differs from the one of the value type
+     */
+    public LagAggFunction(LogicalType[] valueTypes) {
+        this.valueDataTypes =
+                Arrays.stream(valueTypes)
+                        .map(DataTypeUtils::toInternalDataType)
+                        .toArray(DataType[]::new);
+        if (valueDataTypes.length == 3
+                && valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL
+                && valueDataTypes[0].getConversionClass()
+                        != valueDataTypes[2].getConversionClass()) {
+            // The default value (index 2) has to match the value type (index 0), so the value
+            // type is the cast target to report -- not the offset type at index 1.
+            throw new TableException(
+                    String.format(
+                            "Please explicitly cast default value %s to %s.",
+                            valueDataTypes[2], valueDataTypes[0]));
+        }
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Planning
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public List<DataType> getArgumentDataTypes() {
+        return Arrays.asList(valueDataTypes);
+    }
+
+    /** Describes {@link LagAcc} as a structured type with its three public fields. */
+    @Override
+    public DataType getAccumulatorDataType() {
+        return DataTypes.STRUCTURED(
+                LagAcc.class,
+                DataTypes.FIELD("offset", DataTypes.INT()),
+                DataTypes.FIELD("defaultValue", valueDataTypes[0]),
+                DataTypes.FIELD("buffer", getLinkedListType()));
+    }
+
+    /** RAW type for the buffer, serialized element-wise with the output type's serializer. */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private DataType getLinkedListType() {
+        TypeSerializer<T> serializer =
+                InternalSerializers.create(getOutputDataType().getLogicalType());
+        return DataTypes.RAW(
+                LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
+    }
+
+    /** The result has the same type as the value argument. */
+    @Override
+    public DataType getOutputDataType() {
+        return valueDataTypes[0];
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // Runtime
+    // --------------------------------------------------------------------------------------------
+
+    /**
+     * Appends {@code value} to the buffer and evicts the oldest entries so that at most {@code
+     * acc.offset + 1} values remain.
+     */
+    public void accumulate(LagAcc<T> acc, T value) throws Exception {
+        acc.buffer.add(value);
+        // Compare "size - 1 > offset" instead of "size > offset + 1": the latter overflows for
+        // offset == Integer.MAX_VALUE and would drain the buffer until removeFirst() throws.
+        while (acc.buffer.size() - 1 > acc.offset) {
+            acc.buffer.removeFirst();
+        }
+    }
+
+    /**
+     * Same as {@link #accumulate(LagAcc, Object)} but first stores {@code offset} in the
+     * accumulator.
+     *
+     * @throws TableException if {@code offset} is negative
+     */
+    public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
+        if (offset < 0) {
+            // Zero is accepted by this check, so the message must not demand a positive value.
+            throw new TableException(
+                    String.format("Offset(%d) should not be negative.", offset));
+        }
+
+        acc.offset = offset;
+        accumulate(acc, value);
+    }
+
+    /**
+     * Same as {@link #accumulate(LagAcc, Object, int)} but first stores {@code defaultValue} in
+     * the accumulator.
+     */
+    public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
+        acc.defaultValue = defaultValue;
+        accumulate(acc, value, offset);
+    }
+
+    /** Restores the initial state: offset 1, no default value, empty buffer. */
+    public void resetAccumulator(LagAcc<T> acc) throws Exception {
+        acc.offset = 1;
+        acc.defaultValue = null;
+        acc.buffer.clear();
+    }
+
+    /**
+     * Returns the oldest buffered value once {@code offset + 1} values are buffered, otherwise the
+     * default value.
+     *
+     * @throws TableException if the buffer holds more than {@code offset + 1} values
+     */
+    @Override
+    public T getValue(LagAcc<T> acc) {
+        // Index of the newest element; written this way to avoid overflowing "offset + 1".
+        final int lastIndex = acc.buffer.size() - 1;
+        if (lastIndex < acc.offset) {
+            return acc.defaultValue;
+        } else if (lastIndex == acc.offset) {
+            return acc.buffer.getFirst();
+        } else {
+            throw new TableException("Too more elements: " + acc);
+        }
+    }
+
+    @Override
+    public LagAcc<T> createAccumulator() {
+        return new LagAcc<>();
+    }
+
+    /** Accumulator for LAG. */
+    public static class LagAcc<T> {
+        public int offset = 1;
+        public T defaultValue = null;
+        public LinkedList<T> buffer = new LinkedList<>();
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            LagAcc<?> lagAcc = (LagAcc<?>) o;
+            return offset == lagAcc.offset
+                    && Objects.equals(defaultValue, lagAcc.defaultValue)
+                    && Objects.equals(buffer, lagAcc.buffer);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(offset, defaultValue, buffer);
+        }
+
+        /** Readable state dump; the accumulator is concatenated into an exception message. */
+        @Override
+        public String toString() {
+            return "LagAcc{offset="
+                    + offset
+                    + ", defaultValue="
+                    + defaultValue
+                    + ", buffer="
+                    + buffer
+                    + '}';
+        }
+    }
+}
diff --git
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
index 8df6f2a..41ab7a2 100644
---
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
+++
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java
@@ -145,14 +145,14 @@ public class StreamExecGlobalWindowAggregate extends
StreamExecWindowAggregateBa
final SliceAssigner sliceAssigner = createSliceAssigner(windowing,
shiftTimeZone);
final AggregateInfoList localAggInfoList =
- AggregateUtil.deriveWindowAggregateInfoList(
+ AggregateUtil.deriveStreamWindowAggregateInfoList(
localAggInputRowType, // should use original input here
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
false); // isStateBackendDataViews
final AggregateInfoList globalAggInfoList =
- AggregateUtil.deriveWindowAggregateInfoList(
+ AggregateUtil.deriveStreamWindowAggregateInfoList(
localAggInputRowType, // should use original input here
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
diff --git
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
index f333255..18f8a8d 100644
---
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
+++
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java
@@ -122,7 +122,7 @@ public class StreamExecLocalWindowAggregate extends
StreamExecWindowAggregateBas
final SliceAssigner sliceAssigner = createSliceAssigner(windowing,
shiftTimeZone);
final AggregateInfoList aggInfoList =
- AggregateUtil.deriveWindowAggregateInfoList(
+ AggregateUtil.deriveStreamWindowAggregateInfoList(
inputRowType,
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
diff --git
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
index 913abee..3229441 100644
---
a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
+++
b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java
@@ -143,7 +143,7 @@ public class StreamExecWindowAggregate extends
StreamExecWindowAggregateBase {
// Hopping window requires additional COUNT(*) to determine whether to
register next timer
// through whether the current fired window is empty, see
SliceSharedWindowAggProcessor.
final AggregateInfoList aggInfoList =
- AggregateUtil.deriveWindowAggregateInfoList(
+ AggregateUtil.deriveStreamWindowAggregateInfoList(
inputRowType,
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
windowing.getWindow(),
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
index 23bd99c..f7c4641 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala
@@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends
MetadataHandler[ColumnInterval] {
def getAggCallFromLocalAgg(
index: Int,
aggCalls: Seq[AggregateCall],
- inputType: RelDataType): AggregateCall = {
+ inputType: RelDataType,
+ isBounded: Boolean): AggregateCall = {
val outputIndexToAggCallIndexMap =
AggregateUtil.getOutputIndexToAggCallIndexMap(
- aggCalls, inputType)
+ aggCalls, inputType, isBounded)
if (outputIndexToAggCallIndexMap.containsKey(index)) {
val realIndex = outputIndexToAggCallIndexMap.get(index)
aggCalls(realIndex)
@@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends
MetadataHandler[ColumnInterval] {
def getAggCallIndexInLocalAgg(
index: Int,
globalAggCalls: Seq[AggregateCall],
- inputRowType: RelDataType): Integer = {
+ inputRowType: RelDataType,
+ isBounded: Boolean): Integer = {
val outputIndexToAggCallIndexMap =
AggregateUtil.getOutputIndexToAggCallIndexMap(
- globalAggCalls, inputRowType)
+ globalAggCalls, inputRowType, isBounded)
outputIndexToAggCallIndexMap.foreach {
case (k, v) => if (v == index) {
@@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends
MetadataHandler[ColumnInterval] {
case agg: StreamPhysicalGlobalGroupAggregate
if agg.aggCalls.length > aggCallIndex =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
- aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
+ aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded
= false)
if (aggCallIndexInLocalAgg != null) {
return fmq.getColumnInterval(agg.getInput, groupSet.length +
aggCallIndexInLocalAgg)
} else {
null
}
case agg: StreamPhysicalLocalGroupAggregate =>
- getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls,
agg.getInput.getRowType)
+ getAggCallFromLocalAgg(
+ aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded =
false)
case agg: StreamPhysicalIncrementalGroupAggregate
if agg.partialAggCalls.length > aggCallIndex =>
agg.partialAggCalls(aggCallIndex)
case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length
> aggCallIndex =>
agg.aggCalls(aggCallIndex)
case agg: BatchPhysicalLocalHashAggregate =>
- getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList,
agg.getInput.getRowType)
+ getAggCallFromLocalAgg(
+ aggCallIndex, agg.getAggCallList, agg.getInput.getRowType,
isBounded = true)
case agg: BatchPhysicalHashAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
- aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+ aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded
= true)
if (aggCallIndexInLocalAgg != null) {
return fmq.getColumnInterval(agg.getInput, groupSet.length +
aggCallIndexInLocalAgg)
} else {
null
}
case agg: BatchPhysicalLocalSortAggregate =>
- getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList,
agg.getInput.getRowType)
+ getAggCallFromLocalAgg(
+ aggCallIndex, agg.getAggCallList, agg.getInput.getRowType,
isBounded = true)
case agg: BatchPhysicalSortAggregate if agg.isMerge =>
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
- aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
+ aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded
= true)
if (aggCallIndexInLocalAgg != null) {
return fmq.getColumnInterval(agg.getInput, groupSet.length +
aggCallIndexInLocalAgg)
} else {
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
index bef2589..bdace61 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala
@@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate(
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel {
- private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+ private lazy val aggInfoList =
AggregateUtil.deriveStreamWindowAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg),
aggCalls,
windowing.getWindow,
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
index 518ccda..2823aab 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala
@@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate(
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel {
- private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
+ private lazy val aggInfoList =
AggregateUtil.deriveStreamWindowAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
aggCalls,
windowing.getWindow,
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
index 21a1f50..eaa70e2 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala
@@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate(
extends SingleRel(cluster, traitSet, inputRel)
with StreamPhysicalRel {
- lazy val aggInfoList: AggregateInfoList =
AggregateUtil.deriveWindowAggregateInfoList(
+ lazy val aggInfoList: AggregateInfoList =
AggregateUtil.deriveStreamWindowAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
aggCalls,
windowing.getWindow,
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
index e271a74..a2b795b 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
@@ -45,14 +45,16 @@ import scala.collection.JavaConversions._
* as subclasses of [[SqlAggFunction]] in Calcite but not as
[[BridgingSqlAggFunction]]. The factory
* returns [[DeclarativeAggregateFunction]] or [[BuiltInAggregateFunction]].
*
- * @param inputType the input rel data type
- * @param orderKeyIdx the indexes of order key (null when is not over agg)
- * @param needRetraction true if need retraction
+ * @param inputRowType the input row type
+ * @param orderKeyIndexes the indexes of order key (null when is not over agg)
+ * @param aggCallNeedRetractions true if need retraction
+ * @param isBounded true if the source is bounded source
*/
class AggFunctionFactory(
inputRowType: RowType,
orderKeyIndexes: Array[Int],
- aggCallNeedRetractions: Array[Boolean]) {
+ aggCallNeedRetractions: Array[Boolean],
+ isBounded: Boolean) {
/**
* The entry point to create an aggregate function from the given
[[AggregateCall]].
@@ -94,8 +96,12 @@ class AggFunctionFactory(
case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
createDenseRankAggFunction(argTypes)
- case _: SqlLeadLagAggFunction =>
- createLeadLagAggFunction(argTypes, index)
+ case func: SqlLeadLagAggFunction =>
+ if (isBounded) {
+ createBatchLeadLagAggFunction(argTypes, index)
+ } else {
+ createStreamLeadLagAggFunction(func, argTypes, index)
+ }
case _: SqlSingleValueAggFunction =>
createSingleValueAggFunction(argTypes)
@@ -328,7 +334,22 @@ class AggFunctionFactory(
}
}
- private def createLeadLagAggFunction(
+ private def createStreamLeadLagAggFunction(
+ func: SqlLeadLagAggFunction,
+ argTypes: Array[LogicalType],
+ index: Int): UserDefinedFunction = {
+ if (func.getKind == SqlKind.LEAD) {
+ throw new TableException("LEAD Function is not supported in stream
mode.")
+ }
+
+ if (aggCallNeedRetractions(index)) {
+ throw new TableException("LAG Function with retraction is not supported
in stream mode.")
+ }
+
+ new LagAggFunction(argTypes)
+ }
+
+ private def createBatchLeadLagAggFunction(
argTypes: Array[LogicalType], index: Int): UserDefinedFunction = {
argTypes(0).getTypeRoot match {
case TINYINT =>
diff --git
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
index 9bfcdeb..3125238 100644
---
a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
+++
b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala
@@ -153,6 +153,7 @@ object AggregateUtil extends Enumeration {
def getOutputIndexToAggCallIndexMap(
aggregateCalls: Seq[AggregateCall],
inputType: RelDataType,
+ isBounded: Boolean,
orderKeyIndexes: Array[Int] = null): util.Map[Integer, Integer] = {
val aggInfos = transformToAggregateInfoList(
FlinkTypeFactory.toLogicalRowType(inputType),
@@ -161,7 +162,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
- needDistinctInfo = false).aggInfos
+ needDistinctInfo = false,
+ isBounded).aggInfos
val map = new util.HashMap[Integer, Integer]()
var outputIndex = 0
@@ -248,7 +250,7 @@ object AggregateUtil extends Enumeration {
isStateBackendDataViews = true)
}
- def deriveWindowAggregateInfoList(
+ def deriveStreamWindowAggregateInfoList(
inputRowType: RowType,
aggCalls: Seq[AggregateCall],
windowSpec: WindowSpec,
@@ -271,7 +273,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes = null,
needInputCount,
isStateBackendDataViews,
- needDistinctInfo = true)
+ needDistinctInfo = true,
+ isBounded = false)
}
def transformToBatchAggregateFunctions(
@@ -287,7 +290,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
- needDistinctInfo = false).aggInfos
+ needDistinctInfo = false,
+ isBounded = true).aggInfos
val aggFields = aggInfos.map(_.argIndexes)
val bufferTypes = aggInfos.map(_.externalAccTypes)
@@ -315,7 +319,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes,
needInputCount = false,
isStateBackedDataViews = false,
- needDistinctInfo = false)
+ needDistinctInfo = false,
+ isBounded = true)
}
def transformToStreamAggregateInfoList(
@@ -332,7 +337,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes = null,
needInputCount,
isStateBackendDataViews,
- needDistinctInfo)
+ needDistinctInfo,
+ isBounded = false)
}
/**
@@ -355,7 +361,8 @@ object AggregateUtil extends Enumeration {
orderKeyIndexes: Array[Int],
needInputCount: Boolean,
isStateBackedDataViews: Boolean,
- needDistinctInfo: Boolean): AggregateInfoList = {
+ needDistinctInfo: Boolean,
+ isBounded: Boolean): AggregateInfoList = {
// Step-1:
// if need inputCount, find count1 in the existed aggregate calls first,
@@ -375,7 +382,11 @@ object AggregateUtil extends Enumeration {
// Step-3:
// create aggregate information
- val factory = new AggFunctionFactory(inputRowType, orderKeyIndexes,
aggCallNeedRetractions)
+ val factory = new AggFunctionFactory(
+ inputRowType,
+ orderKeyIndexes,
+ aggCallNeedRetractions,
+ isBounded)
val aggInfos = newAggCalls
.zipWithIndex
.map { case (call, index) =>
diff --git
a/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
new file mode 100644
index 0000000..e3553d8
--- /dev/null
+++
b/flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/functions/aggfunctions/LagAggFunctionTest.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions.aggfunctions;
+
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.functions.AggregateFunction;
+import org.apache.flink.table.types.logical.CharType;
+import org.apache.flink.table.types.logical.IntType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.VarCharType;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.table.data.StringData.fromString;
+
+/** Test for {@link LagAggFunction}. */
+public class LagAggFunctionTest
+ extends AggFunctionTestBase<StringData,
LagAggFunction.LagAcc<StringData>> {
+
+ @Override
+ protected List<List<StringData>> getInputValueSets() {
+ return Arrays.asList(
+ Collections.singletonList(fromString("1")),
+ Arrays.asList(fromString("1"), null),
+ Arrays.asList(null, null),
+ Arrays.asList(null, fromString("10")));
+ }
+
+ @Override
+ protected List<StringData> getExpectedResults() {
+ return Arrays.asList(null, fromString("1"), null, null);
+ }
+
+ @Override
+ protected AggregateFunction<StringData, LagAggFunction.LagAcc<StringData>>
getAggregator() {
+ return new LagAggFunction<>(
+ new LogicalType[] {new VarCharType(), new IntType(), new
CharType()});
+ }
+
+ @Override
+ protected Class<?> getAccClass() {
+ return LagAggFunction.LagAcc.class;
+ }
+}
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
index f0ecc24..1dbbb54 100644
---
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala
@@ -949,7 +949,8 @@ class FlinkRelMdHandlerTestBase {
val aggFunctionFactory = new AggFunctionFactory(
FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
- Array.fill(aggCalls.size())(false))
+ Array.fill(aggCalls.size())(false),
+ false)
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call,
index))
}
@@ -1157,7 +1158,8 @@ class FlinkRelMdHandlerTestBase {
val aggFunctionFactory = new AggFunctionFactory(
FlinkTypeFactory.toLogicalRowType(calcOnStudentScan.getRowType),
Array.empty[Int],
- Array.fill(aggCalls.size())(false))
+ Array.fill(aggCalls.size())(false),
+ false)
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call,
index))
}
@@ -1324,7 +1326,8 @@ class FlinkRelMdHandlerTestBase {
val aggFunctionFactory = new AggFunctionFactory(
FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType),
Array.empty[Int],
- Array.fill(aggCalls.size())(false))
+ Array.fill(aggCalls.size())(false),
+ false)
val aggCallToAggFunction = aggCalls.zipWithIndex.map {
case (call, index) => (call, aggFunctionFactory.createAggFunction(call,
index))
}
diff --git
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
index e81d115..208de4a 100644
---
a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
+++
b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/OverAggregateITCase.scala
@@ -56,6 +56,74 @@ class OverAggregateITCase(mode: StateBackendMode) extends
StreamingWithStateTest
}
@Test
+ def testLagFunction(): Unit = {
+ val sqlQuery = "SELECT a, b, c, " +
+ " LAG(b) OVER(PARTITION BY a ORDER BY rowtime)," +
+ " LAG(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
+ " LAG(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY
rowtime)" +
+ "FROM T1"
+
+ val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
+ Left(14000001L, (1, 1L, "Hi")),
+ Left(14000005L, (1, 2L, "Hi")),
+ Left(14000002L, (1, 3L, "Hello")),
+ Left(14000003L, (1, 4L, "Hello")),
+ Left(14000003L, (1, 5L, "Hello")),
+ Right(14000020L),
+ Left(14000021L, (1, 6L, "Hello world")),
+ Left(14000022L, (1, 7L, "Hello world")),
+ Right(14000030L))
+
+ val source = failingDataSource(data)
+ val t1 = source.transform("TimeAssigner", new
EventTimeProcessOperator[(Int, Long, String)])
+ .setParallelism(source.parallelism)
+ .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
+
+ tEnv.registerTable("T1", t1)
+
+ val sink = new TestingAppendSink
+ tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
+ env.execute()
+
+ val expected = List(
+ s"1,1,Hi,null,null,10086",
+ s"1,3,Hello,1,null,10086",
+ s"1,4,Hello,4,3,3",
+ s"1,5,Hello,4,3,3",
+ s"1,2,Hi,5,4,4",
+ s"1,6,Hello world,2,5,5",
+ s"1,7,Hello world,6,2,2")
+ assertEquals(expected.sorted, sink.getAppendResults.sorted)
+ }
+
+ @Test
+ def testLeadFunction(): Unit = {
+ expectedException.expectMessage("LEAD Function is not supported in stream
mode")
+
+ val sqlQuery = "SELECT a, b, c, " +
+ " LEAD(b) OVER(PARTITION BY a ORDER BY rowtime)," +
+ " LEAD(b, 2) OVER(PARTITION BY a ORDER BY rowtime)," +
+ " LEAD(b, 2, CAST(10086 AS BIGINT)) OVER(PARTITION BY a ORDER BY
rowtime)" +
+ "FROM T1"
+
+ val data: Seq[Either[(Long, (Int, Long, String)), Long]] = Seq(
+ Left(14000001L, (1, 1L, "Hi")),
+ Left(14000003L, (1, 5L, "Hello")),
+ Right(14000020L),
+ Left(14000021L, (1, 6L, "Hello world")),
+ Left(14000022L, (1, 7L, "Hello world")),
+ Right(14000030L))
+ val source = failingDataSource(data)
+ val t1 = source.transform("TimeAssigner", new
EventTimeProcessOperator[(Int, Long, String)])
+ .setParallelism(source.parallelism)
+ .toTable(tEnv, 'a, 'b, 'c, 'rowtime.rowtime)
+ tEnv.registerTable("T1", t1)
+ val sink = new TestingAppendSink
+ tEnv.sqlQuery(sqlQuery).toAppendStream[Row].addSink(sink)
+ env.execute()
+ }
+
+ @Test
def testRowNumberOnOver(): Unit = {
val t = failingDataSource(TestData.tupleData5)
.toTable(tEnv, 'a, 'b, 'c, 'd, 'e, 'proctime.proctime)
diff --git
a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
new file mode 100644
index 0000000..df97203
--- /dev/null
+++
b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializer.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.typeutils;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.CompositeTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+import java.util.LinkedList;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A serializer for {@link LinkedList}. The serializer relies on an element
serializer for the
+ * serialization of the list's elements.
+ *
+ * <p>Binary layout, as produced by {@link #serialize}: the list size as an {@code int}, followed
+ * by every element in iteration order, each written by the element serializer.
+ *
+ * <p>The {@code reuse} arguments of {@link #copy(LinkedList, LinkedList)} and {@link
+ * #deserialize(LinkedList, DataInputView)} are ignored; both always return a newly created list.
+ *
+ * @param <T> The type of element in the list.
+ */
+@Internal
+public final class LinkedListSerializer<T> extends
TypeSerializer<LinkedList<T>> {
+
+ private static final long serialVersionUID = 1L;
+
+ /** The serializer for the elements of the list. */
+ private final TypeSerializer<T> elementSerializer;
+
+ /**
+ * Creates a list serializer that uses the given serializer to serialize
the list's elements.
+ *
+ * @param elementSerializer The serializer for the elements of the list
+ * @throws NullPointerException if {@code elementSerializer} is null (via {@code checkNotNull})
+ */
+ public LinkedListSerializer(TypeSerializer<T> elementSerializer) {
+ this.elementSerializer = checkNotNull(elementSerializer);
+ }
+
+ // ------------------------------------------------------------------------
+ // LinkedListSerializer specific properties
+ // ------------------------------------------------------------------------
+
+ /**
+ * Gets the serializer for the elements of the list.
+ *
+ * @return The serializer for the elements of the list
+ */
+ public TypeSerializer<T> getElementSerializer() {
+ return elementSerializer;
+ }
+
+ // ------------------------------------------------------------------------
+ // Type Serializer implementation
+ // ------------------------------------------------------------------------
+
+ /** Always {@code false}: the serialized type is reported as mutable. */
+ @Override
+ public boolean isImmutableType() {
+ return false;
+ }
+
+ /**
+ * Duplicates the element serializer and wraps the duplicate in a new list serializer. If the
+ * element serializer returned itself (identity comparison), this instance is returned instead
+ * of allocating a new one.
+ */
+ @Override
+ public TypeSerializer<LinkedList<T>> duplicate() {
+ TypeSerializer<T> duplicateElement = elementSerializer.duplicate();
+ return duplicateElement == elementSerializer
+ ? this
+ : new LinkedListSerializer<>(duplicateElement);
+ }
+
+ /** Returns a new, empty list. */
+ @Override
+ public LinkedList<T> createInstance() {
+ return new LinkedList<>();
+ }
+
+ /**
+ * Builds a new list holding a copy of every element of {@code from}, in the same order. Each
+ * element is copied through the element serializer rather than shared by reference.
+ */
+ @Override
+ public LinkedList<T> copy(LinkedList<T> from) {
+ LinkedList<T> newList = new LinkedList<>();
+ for (T element : from) {
+ newList.add(elementSerializer.copy(element));
+ }
+ return newList;
+ }
+
+ /** Same as {@link #copy(LinkedList)}; {@code reuse} is not used. */
+ @Override
+ public LinkedList<T> copy(LinkedList<T> from, LinkedList<T> reuse) {
+ return copy(from);
+ }
+
+ /** Returns {@code -1} because the serialized size depends on the list contents. */
+ @Override
+ public int getLength() {
+ return -1; // var length
+ }
+
+ /** Writes the list size as an {@code int}, then each element in iteration order. */
+ @Override
+ public void serialize(LinkedList<T> list, DataOutputView target) throws
IOException {
+ target.writeInt(list.size());
+ for (T element : list) {
+ elementSerializer.serialize(element, target);
+ }
+ }
+
+ /**
+ * Reads the {@code int} size prefix and then that many elements, appending them to a new list
+ * in the order they are read.
+ */
+ @Override
+ public LinkedList<T> deserialize(DataInputView source) throws IOException {
+ final int size = source.readInt();
+ final LinkedList<T> list = new LinkedList<>();
+ for (int i = 0; i < size; i++) {
+ list.add(elementSerializer.deserialize(source));
+ }
+ return list;
+ }
+
+ /** Same as {@link #deserialize(DataInputView)}; {@code reuse} is not used. */
+ @Override
+ public LinkedList<T> deserialize(LinkedList<T> reuse, DataInputView
source) throws IOException {
+ return deserialize(source);
+ }
+
+ /**
+ * Copies one serialized list from {@code source} to {@code target} without building a list
+ * object: the size prefix is forwarded, then each element is copied by the element serializer.
+ */
+ @Override
+ public void copy(DataInputView source, DataOutputView target) throws
IOException {
+ // copy number of elements
+ final int num = source.readInt();
+ target.writeInt(num);
+ for (int i = 0; i < num; i++) {
+ elementSerializer.copy(source, target);
+ }
+ }
+
+ // --------------------------------------------------------------------
+
+ /**
+ * Two list serializers are equal when they are of exactly the same class and their element
+ * serializers are equal.
+ */
+ @Override
+ public boolean equals(Object obj) {
+ return obj == this
+ || (obj != null
+ && obj.getClass() == getClass()
+ && elementSerializer.equals(
+ ((LinkedListSerializer<?>)
obj).elementSerializer));
+ }
+
+ /** Delegates to the element serializer's hash code, the only field used by {@code equals}. */
+ @Override
+ public int hashCode() {
+ return elementSerializer.hashCode();
+ }
+
+ //
--------------------------------------------------------------------------------------------
+ // Serializer configuration snapshot & compatibility
+ //
--------------------------------------------------------------------------------------------
+
+ /** Creates a {@link LinkedListSerializerSnapshot} for this serializer. */
+ @Override
+ public TypeSerializerSnapshot<LinkedList<T>> snapshotConfiguration() {
+ return new LinkedListSerializerSnapshot<>(this);
+ }
+
+ /**
+ * Snapshot class for the {@link LinkedListSerializer}. The element serializer is its single
+ * nested serializer.
+ */
+ public static class LinkedListSerializerSnapshot<T>
+ extends CompositeTypeSerializerSnapshot<LinkedList<T>,
LinkedListSerializer<T>> {
+
+ /** Version returned by {@link #getCurrentOuterSnapshotVersion()}. */
+ private static final int CURRENT_VERSION = 1;
+
+ /** Constructor for read instantiation. */
+ public LinkedListSerializerSnapshot() {
+ super(LinkedListSerializer.class);
+ }
+
+ /** Constructor to create the snapshot for writing. */
+ public LinkedListSerializerSnapshot(LinkedListSerializer<T>
listSerializer) {
+ super(listSerializer);
+ }
+
+ @Override
+ public int getCurrentOuterSnapshotVersion() {
+ return CURRENT_VERSION;
+ }
+
+ /** Rebuilds the list serializer from the nested serializer at index 0. */
+ @Override
+ protected LinkedListSerializer<T>
createOuterSerializerWithNestedSerializers(
+ TypeSerializer<?>[] nestedSerializers) {
+ @SuppressWarnings("unchecked")
+ TypeSerializer<T> elementSerializer = (TypeSerializer<T>)
nestedSerializers[0];
+ return new LinkedListSerializer<>(elementSerializer);
+ }
+
+ /** Returns a one-element array holding the outer serializer's element serializer. */
+ @Override
+ protected TypeSerializer<?>[] getNestedSerializers(
+ LinkedListSerializer<T> outerSerializer) {
+ return new TypeSerializer<?>[]
{outerSerializer.getElementSerializer()};
+ }
+ }
+}
diff --git
a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializerTest.java
b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializerTest.java
new file mode 100644
index 0000000..eea1556
--- /dev/null
+++
b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/LinkedListSerializerTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.typeutils;
+
+import org.apache.flink.api.common.typeutils.SerializerTestBase;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+
+import java.util.LinkedList;
+import java.util.Random;
+
+/**
+ * A test for the {@link LinkedListSerializer}. Supplies a {@code LinkedListSerializer} over
+ * {@link LongSerializer} and four sample lists to the hooks of {@link SerializerTestBase}.
+ */
+public class LinkedListSerializerTest extends
SerializerTestBase<LinkedList<Long>> {
+
+ /** Serializer under test: a linked-list serializer with {@code Long} elements. */
+ @Override
+ protected TypeSerializer<LinkedList<Long>> createSerializer() {
+ return new LinkedListSerializer<>(LongSerializer.INSTANCE);
+ }
+
+ /** Expected serializer length: {@code -1}. */
+ @Override
+ protected int getLength() {
+ return -1;
+ }
+
+ /** Double cast because no {@code Class<LinkedList<Long>>} literal exists in Java. */
+ @SuppressWarnings("unchecked")
+ @Override
+ protected Class<LinkedList<Long>> getTypeClass() {
+ return (Class<LinkedList<Long>>) (Class<?>) LinkedList.class;
+ }
+
+ /**
+ * Builds the sample data: an empty list, a single-element list, and two lists of random longs.
+ * The random generator uses a fixed seed, so the same data is produced on every run.
+ */
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ @Override
+ protected LinkedList<Long>[] getTestData() {
+ final Random rnd = new Random(123654789);
+
+ // empty lists
+ final LinkedList<Long> list1 = new LinkedList<>();
+
+ // single element lists
+ final LinkedList<Long> list2 = new LinkedList<>();
+ list2.add(12345L);
+
+ // longer lists
+ // NOTE(review): rnd.nextInt(200) is evaluated again on every loop check, so the loop ends the
+ // first time a fresh draw is <= i; the length is not a single up-front draw from [0, 200).
+ final LinkedList<Long> list3 = new LinkedList<>();
+ for (int i = 0; i < rnd.nextInt(200); i++) {
+ list3.add(rnd.nextLong());
+ }
+
+ final LinkedList<Long> list4 = new LinkedList<>();
+ for (int i = 0; i < rnd.nextInt(200); i++) {
+ list4.add(rnd.nextLong());
+ }
+
+ // Generic arrays cannot be created directly, hence the raw array and the unchecked cast.
+ return (LinkedList<Long>[]) new LinkedList[] {list1, list2, list3,
list4};
+ }
+}