dawidwys commented on code in PR #23411: URL: https://github.com/apache/flink/pull/23411#discussion_r1477969162
########## flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/ArrayAggFunctionITCase.java: ########## @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.functions; + +import org.apache.flink.table.functions.BuiltInFunctionDefinitions; +import org.apache.flink.types.Row; + +import java.util.Arrays; +import java.util.Collections; +import java.util.stream.Stream; + +import static org.apache.flink.table.api.DataTypes.ARRAY; +import static org.apache.flink.table.api.DataTypes.INT; +import static org.apache.flink.table.api.DataTypes.ROW; +import static org.apache.flink.table.api.DataTypes.STRING; +import static org.apache.flink.table.api.Expressions.$; +import static org.apache.flink.types.RowKind.DELETE; +import static org.apache.flink.types.RowKind.INSERT; +import static org.apache.flink.types.RowKind.UPDATE_AFTER; +import static org.apache.flink.types.RowKind.UPDATE_BEFORE; + +/** Tests for built-in ARRAY_AGG aggregation functions. 
*/ +class ArrayAggFunctionITCase extends BuiltInAggregateFunctionTestBase { + + @Override + Stream<TestSpec> getTestCaseSpecs() { + return Stream.of( + TestSpec.forFunction(BuiltInFunctionDefinitions.ARRAY_AGG) + .withDescription("ARRAY changelog stream aggregation") + .withSource( + ROW(STRING(), INT()), + Arrays.asList( + Row.ofKind(INSERT, "A", 1), + Row.ofKind(INSERT, "A", 2), + Row.ofKind(INSERT, "B", 2), + Row.ofKind(INSERT, "B", 2), + Row.ofKind(INSERT, "B", 3), + Row.ofKind(INSERT, "C", 3), + Row.ofKind(INSERT, "C", null), + Row.ofKind(INSERT, "D", null), + Row.ofKind(INSERT, "E", 4), + Row.ofKind(INSERT, "E", 5), + Row.ofKind(DELETE, "E", 5), + Row.ofKind(UPDATE_BEFORE, "E", 4), + Row.ofKind(UPDATE_AFTER, "E", 6))) + .testResult( + source -> + "SELECT f0, array_agg(f1) FROM " + source + " GROUP BY f0", + TableApiAggSpec.groupBySelect( + Collections.singletonList($("f0")), + $("f0"), + $("f1").arrayAgg()), + ROW(STRING(), ARRAY(INT())), + ROW(STRING(), ARRAY(INT())), + Arrays.asList( + Row.of("A", new Integer[] {1, 2}), + Row.of("B", new Integer[] {2, 2, 3}), + Row.of("C", new Integer[] {3}), Review Comment: The downside of the solution is that the array must fit into memory at all times. The difference with the `LagAggFunction` is that `LAG` keeps at most `n` elements where `n` is controlled by the user. Still I am reasonably good with the `LinkedList` approach because it anyhow needs to fit into memory when we emit it at the end as a single record. Writing this down for awareness. ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.aggregate; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.typeutils.LinkedListSerializer; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; + +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; + +/** Built-in ARRAY_AGG aggregate function. 
*/ +@Internal +public final class ArrayAggFunction<T> + extends BuiltInAggregateFunction<ArrayData, ArrayAggFunction.ArrayAggAccumulator<T>> { + + private static final long serialVersionUID = -5860934997657147836L; + + private final transient DataType elementDataType; + + private final boolean ignoreNulls; + + public ArrayAggFunction(LogicalType elementType, boolean ignoreNulls) { + this.elementDataType = toInternalDataType(elementType); + this.ignoreNulls = ignoreNulls; + } + + // -------------------------------------------------------------------------------------------- + // Planning + // -------------------------------------------------------------------------------------------- + + @Override + public List<DataType> getArgumentDataTypes() { + return Collections.singletonList(elementDataType); + } + + @Override + public DataType getAccumulatorDataType() { + DataType linkedListType = getLinkedListType(); + return DataTypes.STRUCTURED( + ArrayAggAccumulator.class, + DataTypes.FIELD("list", linkedListType), + DataTypes.FIELD("retractList", linkedListType)); + } + + @Override + public DataType getOutputDataType() { + return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private DataType getLinkedListType() { + TypeSerializer<T> serializer = InternalSerializers.create(elementDataType.getLogicalType()); + return DataTypes.RAW( + LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer)); + } + + // -------------------------------------------------------------------------------------------- + // Runtime + // -------------------------------------------------------------------------------------------- + + /** Accumulator for ARRAY_AGG with retraction. 
*/ + public static class ArrayAggAccumulator<T> { + public LinkedList<T> list; + public LinkedList<T> retractList; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayAggAccumulator<?> that = (ArrayAggAccumulator<?>) o; + return Objects.equals(list, that.list) && Objects.equals(retractList, that.retractList); + } + + @Override + public int hashCode() { + return Objects.hash(list, retractList); + } + } + + @Override + public ArrayAggAccumulator<T> createAccumulator() { + final ArrayAggAccumulator<T> acc = new ArrayAggAccumulator<>(); + acc.list = new LinkedList<>(); + acc.retractList = new LinkedList<>(); + return acc; + } + + public void accumulate(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value == null) { + if (!ignoreNulls) { + acc.list.add(null); + } + } else { + acc.list.add(value); + } + } + + public void retract(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value != null) { Review Comment: would be super nice to have a test case for that ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.runtime.functions.aggregate; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.typeutils.LinkedListSerializer; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; + +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; + +/** Built-in ARRAY_AGG aggregate function. 
*/ +@Internal +public final class ArrayAggFunction<T> + extends BuiltInAggregateFunction<ArrayData, ArrayAggFunction.ArrayAggAccumulator<T>> { + + private static final long serialVersionUID = -5860934997657147836L; + + private final transient DataType elementDataType; + + private final boolean ignoreNulls; + + public ArrayAggFunction(LogicalType elementType, boolean ignoreNulls) { + this.elementDataType = toInternalDataType(elementType); + this.ignoreNulls = ignoreNulls; + } + + // -------------------------------------------------------------------------------------------- + // Planning + // -------------------------------------------------------------------------------------------- + + @Override + public List<DataType> getArgumentDataTypes() { + return Collections.singletonList(elementDataType); + } + + @Override + public DataType getAccumulatorDataType() { + DataType linkedListType = getLinkedListType(); + return DataTypes.STRUCTURED( + ArrayAggAccumulator.class, + DataTypes.FIELD("list", linkedListType), + DataTypes.FIELD("retractList", linkedListType)); + } + + @Override + public DataType getOutputDataType() { + return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private DataType getLinkedListType() { + TypeSerializer<T> serializer = InternalSerializers.create(elementDataType.getLogicalType()); + return DataTypes.RAW( + LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer)); + } + + // -------------------------------------------------------------------------------------------- + // Runtime + // -------------------------------------------------------------------------------------------- + + /** Accumulator for ARRAY_AGG with retraction. 
*/ + public static class ArrayAggAccumulator<T> { + public LinkedList<T> list; + public LinkedList<T> retractList; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayAggAccumulator<?> that = (ArrayAggAccumulator<?>) o; + return Objects.equals(list, that.list) && Objects.equals(retractList, that.retractList); + } + + @Override + public int hashCode() { + return Objects.hash(list, retractList); + } + } + + @Override + public ArrayAggAccumulator<T> createAccumulator() { + final ArrayAggAccumulator<T> acc = new ArrayAggAccumulator<>(); + acc.list = new LinkedList<>(); + acc.retractList = new LinkedList<>(); + return acc; + } + + public void accumulate(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value == null) { + if (!ignoreNulls) { + acc.list.add(null); + } + } else { + acc.list.add(value); + } + } + + public void retract(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value != null) { + if (!acc.list.remove(value)) { + acc.retractList.add(value); + } + } + } + + public void merge(ArrayAggAccumulator<T> acc, Iterable<ArrayAggAccumulator<T>> its) + throws Exception { + for (ArrayAggAccumulator<T> otherAcc : its) { + // merge list of acc and other + List<T> buffer = new ArrayList<>(); + for (T element : acc.list) { + buffer.add(element); + } + for (T element : otherAcc.list) { + buffer.add(element); + } + // merge retract list of acc and other + List<T> retractBuffer = new ArrayList<>(); + for (T element : acc.retractList) { + retractBuffer.add(element); + } + for (T element : otherAcc.retractList) { + retractBuffer.add(element); + } + + // merge list & retract list + List<T> newRetractBuffer = new ArrayList<>(); + for (T element : retractBuffer) { + if (!buffer.remove(element)) { + newRetractBuffer.add(element); + } + } + + // update to acc + acc.list.clear(); + acc.list.addAll(buffer); Review Comment: can't we populate 
`acc.list` in a single go? The current approach does make sense with a `ListView` but it does not with a `LinkedList` kept in memory. I believe we don't need the intermediate `buffer` and `newRetractBuffer` ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.functions.aggregate; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.typeutils.LinkedListSerializer; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; + +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; + +/** Built-in ARRAY_AGG aggregate function. */ +@Internal +public final class ArrayAggFunction<T> + extends BuiltInAggregateFunction<ArrayData, ArrayAggFunction.ArrayAggAccumulator<T>> { + + private static final long serialVersionUID = -5860934997657147836L; + + private final transient DataType elementDataType; + + private final boolean ignoreNulls; + + public ArrayAggFunction(LogicalType elementType, boolean ignoreNulls) { + this.elementDataType = toInternalDataType(elementType); + this.ignoreNulls = ignoreNulls; + } + + // -------------------------------------------------------------------------------------------- + // Planning + // -------------------------------------------------------------------------------------------- + + @Override + public List<DataType> getArgumentDataTypes() { + return Collections.singletonList(elementDataType); + } + + @Override + public DataType getAccumulatorDataType() { + DataType linkedListType = getLinkedListType(); + return DataTypes.STRUCTURED( + ArrayAggAccumulator.class, + DataTypes.FIELD("list", linkedListType), + DataTypes.FIELD("retractList", linkedListType)); + } + + @Override + public 
DataType getOutputDataType() { + return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private DataType getLinkedListType() { + TypeSerializer<T> serializer = InternalSerializers.create(elementDataType.getLogicalType()); + return DataTypes.RAW( + LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer)); + } + + // -------------------------------------------------------------------------------------------- + // Runtime + // -------------------------------------------------------------------------------------------- + + /** Accumulator for ARRAY_AGG with retraction. */ + public static class ArrayAggAccumulator<T> { + public LinkedList<T> list; + public LinkedList<T> retractList; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayAggAccumulator<?> that = (ArrayAggAccumulator<?>) o; + return Objects.equals(list, that.list) && Objects.equals(retractList, that.retractList); + } + + @Override + public int hashCode() { + return Objects.hash(list, retractList); + } + } + + @Override + public ArrayAggAccumulator<T> createAccumulator() { + final ArrayAggAccumulator<T> acc = new ArrayAggAccumulator<>(); + acc.list = new LinkedList<>(); + acc.retractList = new LinkedList<>(); + return acc; + } + + public void accumulate(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value == null) { + if (!ignoreNulls) { + acc.list.add(null); + } + } else { + acc.list.add(value); + } + } + + public void retract(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value != null) { Review Comment: what about retracting nulls? 
########## flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java: ########## @@ -1139,6 +1141,22 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() { public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT; + /** + * Use the definitions in Flink instead of {@link SqlLibraryOperators#ARRAY_AGG}, because we + * ignore nulls and returns nullable ARRAY type. Order by clause like <code> Review Comment: is the comment still correct after the last changes? ########## flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/ArrayAggFunction.java: ########## @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime.functions.aggregate; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.runtime.typeutils.InternalSerializers; +import org.apache.flink.table.runtime.typeutils.LinkedListSerializer; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.util.FlinkRuntimeException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; + +import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType; + +/** Built-in ARRAY_AGG aggregate function. */ +@Internal +public final class ArrayAggFunction<T> + extends BuiltInAggregateFunction<ArrayData, ArrayAggFunction.ArrayAggAccumulator<T>> { + + private static final long serialVersionUID = -5860934997657147836L; + + private final transient DataType elementDataType; + + private final boolean ignoreNulls; + + public ArrayAggFunction(LogicalType elementType, boolean ignoreNulls) { + this.elementDataType = toInternalDataType(elementType); + this.ignoreNulls = ignoreNulls; + } + + // -------------------------------------------------------------------------------------------- + // Planning + // -------------------------------------------------------------------------------------------- + + @Override + public List<DataType> getArgumentDataTypes() { + return Collections.singletonList(elementDataType); + } + + @Override + public DataType getAccumulatorDataType() { + DataType linkedListType = getLinkedListType(); + return DataTypes.STRUCTURED( + ArrayAggAccumulator.class, + DataTypes.FIELD("list", linkedListType), + DataTypes.FIELD("retractList", linkedListType)); + } + + @Override + public 
DataType getOutputDataType() { + return DataTypes.ARRAY(elementDataType).bridgedTo(ArrayData.class); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private DataType getLinkedListType() { + TypeSerializer<T> serializer = InternalSerializers.create(elementDataType.getLogicalType()); + return DataTypes.RAW( + LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer)); + } + + // -------------------------------------------------------------------------------------------- + // Runtime + // -------------------------------------------------------------------------------------------- + + /** Accumulator for ARRAY_AGG with retraction. */ + public static class ArrayAggAccumulator<T> { + public LinkedList<T> list; + public LinkedList<T> retractList; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrayAggAccumulator<?> that = (ArrayAggAccumulator<?>) o; + return Objects.equals(list, that.list) && Objects.equals(retractList, that.retractList); + } + + @Override + public int hashCode() { + return Objects.hash(list, retractList); + } + } + + @Override + public ArrayAggAccumulator<T> createAccumulator() { + final ArrayAggAccumulator<T> acc = new ArrayAggAccumulator<>(); + acc.list = new LinkedList<>(); + acc.retractList = new LinkedList<>(); + return acc; + } + + public void accumulate(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value == null) { + if (!ignoreNulls) { + acc.list.add(null); + } + } else { + acc.list.add(value); + } + } + + public void retract(ArrayAggAccumulator<T> acc, T value) throws Exception { + if (value != null) { + if (!acc.list.remove(value)) { + acc.retractList.add(value); + } + } + } + + public void merge(ArrayAggAccumulator<T> acc, Iterable<ArrayAggAccumulator<T>> its) + throws Exception { + for (ArrayAggAccumulator<T> otherAcc : its) { + // merge list of acc and other + List<T> buffer = new 
ArrayList<>(); + for (T element : acc.list) { + buffer.add(element); + } + for (T element : otherAcc.list) { + buffer.add(element); + } + // merge retract list of acc and other + List<T> retractBuffer = new ArrayList<>(); + for (T element : acc.retractList) { + retractBuffer.add(element); + } + for (T element : otherAcc.retractList) { + retractBuffer.add(element); + } Review Comment: Do we need the `retractBuffer`? Can't we just iterate over both `retractList`s and create only the final `newRetractBuffer`? It seems we create quite a few unnecessary objects and potentially trigger list resizing here. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
