This is an automated email from the ASF dual-hosted git repository. guoweijie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 46ec4b02ec2f514eb88e863b1d963e654dea1d32 Author: Wencong Liu <[email protected]> AuthorDate: Tue Mar 12 13:49:56 2024 +0800 [FLINK-34543][datastream] Introduce the Aggregate API on PartitionWindowedStream --- .../datastream/KeyedPartitionWindowedStream.java | 10 ++ .../NonKeyedPartitionWindowedStream.java | 16 +++ .../api/datastream/PartitionWindowedStream.java | 12 ++ .../api/operators/PartitionAggregateOperator.java | 65 ++++++++++ .../operators/PartitionAggregateOperatorTest.java | 137 +++++++++++++++++++++ .../KeyedPartitionWindowedStreamITCase.java | 77 ++++++++++++ .../NonKeyedPartitionWindowedStreamITCase.java | 51 ++++++++ 7 files changed, 368 insertions(+) diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java index ccd9f02e197..5ccc05826bf 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.api.datastream; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; @@ -80,4 +81,13 @@ public class KeyedPartitionWindowedStream<T, KEY> implements PartitionWindowedSt reduceFunction = environment.clean(reduceFunction); return input.window(GlobalWindows.createWithEndOfStreamTrigger()).reduce(reduceFunction); } + + @Override + public <ACC, R> SingleOutputStreamOperator<R> aggregate( + AggregateFunction<T, ACC, R> aggregateFunction) { + checkNotNull(aggregateFunction, "The aggregate function must not be 
null."); + aggregateFunction = environment.clean(aggregateFunction); + return input.window(GlobalWindows.createWithEndOfStreamTrigger()) + .aggregate(aggregateFunction); + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java index a1943a72f33..0f415aed876 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java @@ -19,12 +19,14 @@ package org.apache.flink.streaming.api.datastream; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.operators.MapPartitionOperator; +import org.apache.flink.streaming.api.operators.PartitionAggregateOperator; import org.apache.flink.streaming.api.operators.PartitionReduceOperator; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -72,4 +74,18 @@ public class NonKeyedPartitionWindowedStream<T> implements PartitionWindowedStre new PartitionReduceOperator<>(reduceFunction)) .setParallelism(input.getParallelism()); } + + @Override + public <ACC, R> SingleOutputStreamOperator<R> aggregate( + AggregateFunction<T, ACC, R> aggregateFunction) { + checkNotNull(aggregateFunction, "The aggregate function must not be null."); + aggregateFunction = environment.clean(aggregateFunction); + String opName = "PartitionAggregate"; + 
TypeInformation<R> resultType = + TypeExtractor.getAggregateFunctionReturnType( + aggregateFunction, input.getType(), opName, true); + return input.transform( + opName, resultType, new PartitionAggregateOperator<>(aggregateFunction)) + .setParallelism(input.getParallelism()); + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java index a9f0bcba28a..19ebb0a7530 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.api.datastream; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.ReduceFunction; @@ -49,4 +50,15 @@ public interface PartitionWindowedStream<T> { * @return The data stream with final reduced result. */ SingleOutputStreamOperator<T> reduce(ReduceFunction<T> reduceFunction); + + /** + * Applies an aggregate transformation on the records of the window. + * + * @param aggregateFunction The aggregate function. + * @param <ACC> The type of accumulator in aggregate function. + * @param <R> The type of aggregate function result. + * @return The data stream with final aggregated result. 
+ */ + <ACC, R> SingleOutputStreamOperator<R> aggregate( + AggregateFunction<T, ACC, R> aggregateFunction); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperator.java new file mode 100644 index 00000000000..c42184a12ed --- /dev/null +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperator.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * The {@link PartitionAggregateOperator} is used to apply the aggregate transformation on all + * records of each partition. Each partition contains all records of a subtask. 
+ */ +@Internal +public class PartitionAggregateOperator<IN, ACC, OUT> + extends AbstractUdfStreamOperator<OUT, AggregateFunction<IN, ACC, OUT>> + implements OneInputStreamOperator<IN, OUT>, BoundedOneInput { + + private final AggregateFunction<IN, ACC, OUT> aggregateFunction; + + private ACC currentAccumulator = null; + + public PartitionAggregateOperator(AggregateFunction<IN, ACC, OUT> aggregateFunction) { + super(aggregateFunction); + this.aggregateFunction = aggregateFunction; + } + + @Override + public void open() throws Exception { + super.open(); + this.currentAccumulator = checkNotNull(aggregateFunction.createAccumulator()); + } + + @Override + public void processElement(StreamRecord<IN> element) throws Exception { + currentAccumulator = aggregateFunction.add(element.getValue(), currentAccumulator); + } + + @Override + public void endInput() throws Exception { + output.collect(new StreamRecord<>(aggregateFunction.getResult(currentAccumulator))); + } + + @Override + public OperatorAttributes getOperatorAttributes() { + return new OperatorAttributesBuilder().setOutputOnlyAfterEndOfStream(true).build(); + } +} diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperatorTest.java new file mode 100644 index 00000000000..fc30f6095c9 --- /dev/null +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperatorTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.api.operators; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.OpenContext; +import org.apache.flink.api.common.functions.RichAggregateFunction; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness; +import org.apache.flink.streaming.util.TestHarnessUtil; + +import org.junit.jupiter.api.Test; + +import java.util.LinkedList; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Unit test for {@link PartitionAggregateOperator}. 
*/ +class PartitionAggregateOperatorTest { + + private static final int RECORD = 1; + + @Test + void testAggregate() throws Exception { + PartitionAggregateOperator<Integer, TestAccumulator, String> partitionAggregateOperator = + new PartitionAggregateOperator<>( + new Aggregate(new CompletableFuture<>(), new CompletableFuture<>())); + OneInputStreamOperatorTestHarness<Integer, String> testHarness = + new OneInputStreamOperatorTestHarness<>(partitionAggregateOperator); + Queue<Object> expectedOutput = new LinkedList<>(); + testHarness.open(); + testHarness.processElement(new StreamRecord<>(RECORD)); + testHarness.processElement(new StreamRecord<>(RECORD)); + testHarness.processElement(new StreamRecord<>(RECORD)); + testHarness.endInput(); + expectedOutput.add(new StreamRecord<>("303")); + TestHarnessUtil.assertOutputEquals( + "The aggregate result is not correct.", expectedOutput, testHarness.getOutput()); + testHarness.close(); + } + + @Test + void testOpenClose() throws Exception { + CompletableFuture<Object> openIdentifier = new CompletableFuture<>(); + CompletableFuture<Object> closeIdentifier = new CompletableFuture<>(); + PartitionAggregateOperator<Integer, TestAccumulator, String> partitionAggregateOperator = + new PartitionAggregateOperator<>(new Aggregate(openIdentifier, closeIdentifier)); + OneInputStreamOperatorTestHarness<Integer, String> testHarness = + new OneInputStreamOperatorTestHarness<>(partitionAggregateOperator); + testHarness.open(); + testHarness.processElement(new StreamRecord<>(RECORD)); + testHarness.endInput(); + testHarness.close(); + assertThat(openIdentifier).isCompleted(); + assertThat(closeIdentifier).isCompleted(); + assertThat(testHarness.getOutput()).isNotEmpty(); + } + + /** The test user implementation of {@link AggregateFunction}. 
*/ + private static class Aggregate extends RichAggregateFunction<Integer, TestAccumulator, String> { + + private final CompletableFuture<Object> openIdentifier; + + private final CompletableFuture<Object> closeIdentifier; + + public Aggregate( + CompletableFuture<Object> openIdentifier, + CompletableFuture<Object> closeIdentifier) { + this.openIdentifier = openIdentifier; + this.closeIdentifier = closeIdentifier; + } + + @Override + public void open(OpenContext openContext) throws Exception { + super.open(openContext); + openIdentifier.complete(null); + } + + @Override + public TestAccumulator createAccumulator() { + return new TestAccumulator(); + } + + @Override + public TestAccumulator add(Integer value, TestAccumulator accumulator) { + accumulator.addNumber(value); + return accumulator; + } + + @Override + public String getResult(TestAccumulator accumulator) { + return accumulator.getResult(); + } + + @Override + public TestAccumulator merge(TestAccumulator a, TestAccumulator b) { + return null; + } + + @Override + public void close() throws Exception { + super.close(); + closeIdentifier.complete(null); + } + } + + /** The test accumulator. 
*/ + private static class TestAccumulator { + private Integer result = 0; + + public void addNumber(Integer number) { + result = result + number + 100; + } + + public String getResult() { + return String.valueOf(result); + } + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java index 4c2951bda6f..c51de02d9fb 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java @@ -18,6 +18,7 @@ package org.apache.flink.test.streaming.runtime; +import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.ReduceFunction; @@ -146,6 +147,69 @@ class KeyedPartitionWindowedStreamITCase { expectInAnyOrder(resultIterator, "key11000", "key21000", "key31000"); } + @Test + void testAggregate() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStreamSource<Tuple2<String, Integer>> source = + env.fromData( + Tuple2.of("Key1", 1), + Tuple2.of("Key1", 2), + Tuple2.of("Key2", 2), + Tuple2.of("Key2", 1), + Tuple2.of("Key3", 1), + Tuple2.of("Key3", 1), + Tuple2.of("Key3", 1)); + CloseableIterator<String> resultIterator = + source.map( + new MapFunction< + Tuple2<String, Integer>, Tuple2<String, Integer>>() { + @Override + public Tuple2<String, Integer> map( + Tuple2<String, Integer> value) throws Exception { + return value; + } + }) + .setParallelism(2) + .keyBy( + new KeySelector<Tuple2<String, Integer>, String>() { + @Override + public String getKey(Tuple2<String, Integer> value) + throws Exception { + return 
value.f0; + } + }) + .fullWindowPartition() + .aggregate( + new AggregateFunction< + Tuple2<String, Integer>, TestAccumulator, String>() { + @Override + public TestAccumulator createAccumulator() { + return new TestAccumulator(); + } + + @Override + public TestAccumulator add( + Tuple2<String, Integer> value, + TestAccumulator accumulator) { + accumulator.addTestField(value.f1); + return accumulator; + } + + @Override + public String getResult(TestAccumulator accumulator) { + return accumulator.getTestField(); + } + + @Override + public TestAccumulator merge( + TestAccumulator a, TestAccumulator b) { + throw new RuntimeException(); + } + }) + .executeAndCollect(); + expectInAnyOrder(resultIterator, "97", "97", "97"); + } + private Collection<Tuple2<String, String>> createSource() { List<Tuple2<String, String>> source = new ArrayList<>(); for (int index = 0; index < EVENT_NUMBER; ++index) { @@ -171,4 +235,17 @@ class KeyedPartitionWindowedStreamITCase { Collections.sort(testResults); assertThat(testResults).isEqualTo(listExpected); } + + /** The test accumulator. 
*/ + private static class TestAccumulator { + private Integer testField = 100; + + public void addTestField(Integer number) { + testField = testField - number; + } + + public String getTestField() { + return String.valueOf(testField); + } + } } diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java index 3b8559e4711..b522664c9d3 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java @@ -18,6 +18,7 @@ package org.apache.flink.test.streaming.runtime; +import org.apache.flink.api.common.functions.AggregateFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.ReduceFunction; @@ -98,6 +99,43 @@ class NonKeyedPartitionWindowedStreamITCase { expectInAnyOrder(resultIterator, "1000", "1000"); } + @Test + void testAggregate() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStreamSource<Integer> source = env.fromData(1, 1, 2, 2, 3, 3); + CloseableIterator<String> resultIterator = + source.map(v -> v) + .setParallelism(2) + .fullWindowPartition() + .aggregate( + new AggregateFunction<Integer, TestAccumulator, String>() { + @Override + public TestAccumulator createAccumulator() { + return new TestAccumulator(); + } + + @Override + public TestAccumulator add( + Integer value, TestAccumulator accumulator) { + accumulator.addTestField(value); + return accumulator; + } + + @Override + public String getResult(TestAccumulator accumulator) { + return accumulator.getTestField(); + } + + @Override + public TestAccumulator merge( + 
TestAccumulator a, TestAccumulator b) { + throw new RuntimeException(); + } + }) + .executeAndCollect(); + expectInAnyOrder(resultIterator, "94", "94"); + } + private void expectInAnyOrder(CloseableIterator<String> resultIterator, String... expected) { List<String> listExpected = Lists.newArrayList(expected); List<String> testResults = Lists.newArrayList(resultIterator); @@ -121,4 +159,17 @@ class NonKeyedPartitionWindowedStreamITCase { } return stringBuilder.toString(); } + + /** The test accumulator. */ + private static class TestAccumulator { + private Integer testField = 100; + + public void addTestField(Integer number) { + testField = testField - number; + } + + public String getTestField() { + return String.valueOf(testField); + } + } }
