This is an automated email from the ASF dual-hosted git repository.

guoweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 46ec4b02ec2f514eb88e863b1d963e654dea1d32
Author: Wencong Liu <[email protected]>
AuthorDate: Tue Mar 12 13:49:56 2024 +0800

    [FLINK-34543][datastream] Introduce the Aggregate API on PartitionWindowedStream
---
 .../datastream/KeyedPartitionWindowedStream.java   |  10 ++
 .../NonKeyedPartitionWindowedStream.java           |  16 +++
 .../api/datastream/PartitionWindowedStream.java    |  12 ++
 .../api/operators/PartitionAggregateOperator.java  |  65 ++++++++++
 .../operators/PartitionAggregateOperatorTest.java  | 137 +++++++++++++++++++++
 .../KeyedPartitionWindowedStreamITCase.java        |  77 ++++++++++++
 .../NonKeyedPartitionWindowedStreamITCase.java     |  51 ++++++++
 7 files changed, 368 insertions(+)

diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java
index ccd9f02e197..5ccc05826bf 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -80,4 +81,13 @@ public class KeyedPartitionWindowedStream<T, KEY> implements PartitionWindowedSt
         reduceFunction = environment.clean(reduceFunction);
         return input.window(GlobalWindows.createWithEndOfStreamTrigger()).reduce(reduceFunction);
     }
+
+    @Override
+    public <ACC, R> SingleOutputStreamOperator<R> aggregate(
+            AggregateFunction<T, ACC, R> aggregateFunction) {
+        checkNotNull(aggregateFunction, "The aggregate function must not be null.");
+        aggregateFunction = environment.clean(aggregateFunction);
+        return input.window(GlobalWindows.createWithEndOfStreamTrigger())
+                .aggregate(aggregateFunction);
+    }
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java
index a1943a72f33..0f415aed876 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java
@@ -19,12 +19,14 @@
 package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.operators.MapPartitionOperator;
+import org.apache.flink.streaming.api.operators.PartitionAggregateOperator;
 import org.apache.flink.streaming.api.operators.PartitionReduceOperator;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -72,4 +74,18 @@ public class NonKeyedPartitionWindowedStream<T> implements PartitionWindowedStre
                         new PartitionReduceOperator<>(reduceFunction))
                 .setParallelism(input.getParallelism());
     }
+
+    @Override
+    public <ACC, R> SingleOutputStreamOperator<R> aggregate(
+            AggregateFunction<T, ACC, R> aggregateFunction) {
+        checkNotNull(aggregateFunction, "The aggregate function must not be null.");
+        aggregateFunction = environment.clean(aggregateFunction);
+        String opName = "PartitionAggregate";
+        TypeInformation<R> resultType =
+                TypeExtractor.getAggregateFunctionReturnType(
+                        aggregateFunction, input.getType(), opName, true);
+        return input.transform(
+                        opName, resultType, new PartitionAggregateOperator<>(aggregateFunction))
+                .setParallelism(input.getParallelism());
+    }
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java
index a9f0bcba28a..19ebb0a7530 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 
@@ -49,4 +50,15 @@ public interface PartitionWindowedStream<T> {
      * @return The data stream with final reduced result.
      */
     SingleOutputStreamOperator<T> reduce(ReduceFunction<T> reduceFunction);
+
+    /**
+     * Applies an aggregate transformation on the records of the window.
+     *
+     * @param aggregateFunction The aggregate function.
+     * @param <ACC> The type of accumulator in aggregate function.
+     * @param <R> The type of aggregate function result.
+     * @return The data stream with final aggregated result.
+     */
+    <ACC, R> SingleOutputStreamOperator<R> aggregate(
+            AggregateFunction<T, ACC, R> aggregateFunction);
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperator.java
new file mode 100644
index 00000000000..c42184a12ed
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperator.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * The {@link PartitionAggregateOperator} is used to apply the aggregate transformation on all
+ * records of each partition. Each partition contains all records of a subtask.
+ */
+@Internal
+public class PartitionAggregateOperator<IN, ACC, OUT>
+        extends AbstractUdfStreamOperator<OUT, AggregateFunction<IN, ACC, OUT>>
+        implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
+
+    private final AggregateFunction<IN, ACC, OUT> aggregateFunction;
+
+    private ACC currentAccumulator = null;
+
+    public PartitionAggregateOperator(AggregateFunction<IN, ACC, OUT> aggregateFunction) {
+        super(aggregateFunction);
+        this.aggregateFunction = aggregateFunction;
+    }
+
+    @Override
+    public void open() throws Exception {
+        super.open();
+        this.currentAccumulator = checkNotNull(aggregateFunction.createAccumulator());
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> element) throws Exception {
+        aggregateFunction.add(element.getValue(), currentAccumulator);
+    }
+
+    @Override
+    public void endInput() throws Exception {
+        output.collect(new StreamRecord<>(aggregateFunction.getResult(currentAccumulator)));
+    }
+
+    @Override
+    public OperatorAttributes getOperatorAttributes() {
+        return new OperatorAttributesBuilder().setOutputOnlyAfterEndOfStream(true).build();
+    }
+}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperatorTest.java
new file mode 100644
index 00000000000..fc30f6095c9
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/PartitionAggregateOperatorTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.functions.RichAggregateFunction;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.LinkedList;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Unit test for {@link PartitionAggregateOperator}. */
+class PartitionAggregateOperatorTest {
+
+    private static final int RECORD = 1;
+
+    @Test
+    void testAggregate() throws Exception {
+        PartitionAggregateOperator<Integer, TestAccumulator, String> partitionAggregateOperator =
+                new PartitionAggregateOperator<>(
+                        new Aggregate(new CompletableFuture<>(), new CompletableFuture<>()));
+        OneInputStreamOperatorTestHarness<Integer, String> testHarness =
+                new OneInputStreamOperatorTestHarness<>(partitionAggregateOperator);
+        Queue<Object> expectedOutput = new LinkedList<>();
+        testHarness.open();
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.endInput();
+        expectedOutput.add(new StreamRecord<>("303"));
+        TestHarnessUtil.assertOutputEquals(
+                "The aggregate result is not correct.", expectedOutput, testHarness.getOutput());
+        testHarness.close();
+    }
+
+    @Test
+    void testOpenClose() throws Exception {
+        CompletableFuture<Object> openIdentifier = new CompletableFuture<>();
+        CompletableFuture<Object> closeIdentifier = new CompletableFuture<>();
+        PartitionAggregateOperator<Integer, TestAccumulator, String> partitionAggregateOperator =
+                new PartitionAggregateOperator<>(new Aggregate(openIdentifier, closeIdentifier));
+        OneInputStreamOperatorTestHarness<Integer, String> testHarness =
+                new OneInputStreamOperatorTestHarness<>(partitionAggregateOperator);
+        testHarness.open();
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.endInput();
+        testHarness.close();
+        assertThat(openIdentifier).isCompleted();
+        assertThat(closeIdentifier).isCompleted();
+        assertThat(testHarness.getOutput()).isNotEmpty();
+    }
+
+    /** The test user implementation of {@link AggregateFunction}. */
+    private static class Aggregate extends RichAggregateFunction<Integer, TestAccumulator, String> {
+
+        private final CompletableFuture<Object> openIdentifier;
+
+        private final CompletableFuture<Object> closeIdentifier;
+
+        public Aggregate(
+                CompletableFuture<Object> openIdentifier,
+                CompletableFuture<Object> closeIdentifier) {
+            this.openIdentifier = openIdentifier;
+            this.closeIdentifier = closeIdentifier;
+        }
+
+        @Override
+        public void open(OpenContext openContext) throws Exception {
+            super.open(openContext);
+            openIdentifier.complete(null);
+        }
+
+        @Override
+        public TestAccumulator createAccumulator() {
+            return new TestAccumulator();
+        }
+
+        @Override
+        public TestAccumulator add(Integer value, TestAccumulator accumulator) {
+            accumulator.addNumber(value);
+            return accumulator;
+        }
+
+        @Override
+        public String getResult(TestAccumulator accumulator) {
+            return accumulator.getResult();
+        }
+
+        @Override
+        public TestAccumulator merge(TestAccumulator a, TestAccumulator b) {
+            return null;
+        }
+
+        @Override
+        public void close() throws Exception {
+            super.close();
+            closeIdentifier.complete(null);
+        }
+    }
+
+    /** The test accumulator. */
+    private static class TestAccumulator {
+        private Integer result = 0;
+
+        public void addNumber(Integer number) {
+            result = result + number + 100;
+        }
+
+        public String getResult() {
+            return String.valueOf(result);
+        }
+    }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java
index 4c2951bda6f..c51de02d9fb 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.test.streaming.runtime;
 
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
@@ -146,6 +147,69 @@ class KeyedPartitionWindowedStreamITCase {
         expectInAnyOrder(resultIterator, "key11000", "key21000", "key31000");
     }
 
+    @Test
+    void testAggregate() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        DataStreamSource<Tuple2<String, Integer>> source =
+                env.fromData(
+                        Tuple2.of("Key1", 1),
+                        Tuple2.of("Key1", 2),
+                        Tuple2.of("Key2", 2),
+                        Tuple2.of("Key2", 1),
+                        Tuple2.of("Key3", 1),
+                        Tuple2.of("Key3", 1),
+                        Tuple2.of("Key3", 1));
+        CloseableIterator<String> resultIterator =
+                source.map(
+                                new MapFunction<
+                                        Tuple2<String, Integer>, Tuple2<String, Integer>>() {
+                                    @Override
+                                    public Tuple2<String, Integer> map(
+                                            Tuple2<String, Integer> value) throws Exception {
+                                        return value;
+                                    }
+                                })
+                        .setParallelism(2)
+                        .keyBy(
+                                new KeySelector<Tuple2<String, Integer>, String>() {
+                                    @Override
+                                    public String getKey(Tuple2<String, Integer> value)
+                                            throws Exception {
+                                        return value.f0;
+                                    }
+                                })
+                        .fullWindowPartition()
+                        .aggregate(
+                                new AggregateFunction<
+                                        Tuple2<String, Integer>, TestAccumulator, String>() {
+                                    @Override
+                                    public TestAccumulator createAccumulator() {
+                                        return new TestAccumulator();
+                                    }
+
+                                    @Override
+                                    public TestAccumulator add(
+                                            Tuple2<String, Integer> value,
+                                            TestAccumulator accumulator) {
+                                        accumulator.addTestField(value.f1);
+                                        return accumulator;
+                                    }
+
+                                    @Override
+                                    public String getResult(TestAccumulator accumulator) {
+                                        return accumulator.getTestField();
+                                    }
+
+                                    @Override
+                                    public TestAccumulator merge(
+                                            TestAccumulator a, TestAccumulator b) {
+                                        throw new RuntimeException();
+                                    }
+                                })
+                        .executeAndCollect();
+        expectInAnyOrder(resultIterator, "97", "97", "97");
+    }
+
     private Collection<Tuple2<String, String>> createSource() {
         List<Tuple2<String, String>> source = new ArrayList<>();
         for (int index = 0; index < EVENT_NUMBER; ++index) {
@@ -171,4 +235,17 @@ class KeyedPartitionWindowedStreamITCase {
         Collections.sort(testResults);
         assertThat(testResults).isEqualTo(listExpected);
     }
+
+    /** The test accumulator. */
+    private static class TestAccumulator {
+        private Integer testField = 100;
+
+        public void addTestField(Integer number) {
+            testField = testField - number;
+        }
+
+        public String getTestField() {
+            return String.valueOf(testField);
+        }
+    }
 }
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java
index 3b8559e4711..b522664c9d3 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.test.streaming.runtime;
 
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
@@ -98,6 +99,43 @@ class NonKeyedPartitionWindowedStreamITCase {
         expectInAnyOrder(resultIterator, "1000", "1000");
     }
 
+    @Test
+    void testAggregate() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        DataStreamSource<Integer> source = env.fromData(1, 1, 2, 2, 3, 3);
+        CloseableIterator<String> resultIterator =
+                source.map(v -> v)
+                        .setParallelism(2)
+                        .fullWindowPartition()
+                        .aggregate(
+                                new AggregateFunction<Integer, TestAccumulator, String>() {
+                                    @Override
+                                    public TestAccumulator createAccumulator() {
+                                        return new TestAccumulator();
+                                    }
+
+                                    @Override
+                                    public TestAccumulator add(
+                                            Integer value, TestAccumulator accumulator) {
+                                        accumulator.addTestField(value);
+                                        return accumulator;
+                                    }
+
+                                    @Override
+                                    public String getResult(TestAccumulator accumulator) {
+                                        return accumulator.getTestField();
+                                    }
+
+                                    @Override
+                                    public TestAccumulator merge(
+                                            TestAccumulator a, TestAccumulator b) {
+                                        throw new RuntimeException();
+                                    }
+                                })
+                        .executeAndCollect();
+        expectInAnyOrder(resultIterator, "94", "94");
+    }
+
     private void expectInAnyOrder(CloseableIterator<String> resultIterator, String... expected) {
         List<String> listExpected = Lists.newArrayList(expected);
         List<String> testResults = Lists.newArrayList(resultIterator);
@@ -121,4 +159,17 @@ class NonKeyedPartitionWindowedStreamITCase {
         }
         return stringBuilder.toString();
     }
+
+    /** The test accumulator. */
+    private static class TestAccumulator {
+        private Integer testField = 100;
+
+        public void addTestField(Integer number) {
+            testField = testField - number;
+        }
+
+        public String getTestField() {
+            return String.valueOf(testField);
+        }
+    }
 }

Reply via email to