This is an automated email from the ASF dual-hosted git repository.

guoweijie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit cfbd41ecaa16edad5e460572eea6b5cf660af99d
Author: Wencong Liu <[email protected]>
AuthorDate: Tue Mar 12 12:17:00 2024 +0800

    [FLINK-34543][datastream] Introduce the MapPartition API on 
PartitionWindowedStream
---
 .../datastream/KeyedPartitionWindowedStream.java   |  35 +++
 .../NonKeyedPartitionWindowedStream.java           |  19 ++
 .../api/datastream/PartitionWindowedStream.java    |  13 +-
 .../api/operators/MapPartitionIterator.java        | 217 ++++++++++++++++
 .../api/operators/MapPartitionOperator.java        |  77 ++++++
 .../api/operators/MapPartitionIteratorTest.java    | 278 +++++++++++++++++++++
 .../api/operators/MapPartitionOperatorTest.java    | 117 +++++++++
 .../KeyedPartitionWindowedStreamITCase.java        | 120 +++++++++
 .../NonKeyedPartitionWindowedStreamITCase.java     |  95 +++++++
 9 files changed, 970 insertions(+), 1 deletion(-)

diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java
index ecb9db8563b..6f2364ec2a8 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java
@@ -19,7 +19,16 @@
 package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.WindowFunction;
+import org.apache.flink.streaming.api.windowing.assigners.GlobalWindows;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.util.Collector;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
  * {@link KeyedPartitionWindowedStream} represents a data stream that collects 
all records with the
@@ -37,4 +46,30 @@ public class KeyedPartitionWindowedStream<T, KEY> implements 
PartitionWindowedSt
         this.environment = environment;
         this.input = input;
     }
+
+    @Override
+    public <R> SingleOutputStreamOperator<R> mapPartition(
+            MapPartitionFunction<T, R> mapPartitionFunction) {
+        checkNotNull(mapPartitionFunction, "The map partition function must 
not be null.");
+        mapPartitionFunction = environment.clean(mapPartitionFunction);
+        String opName = "MapPartition";
+        TypeInformation<R> resultType =
+                TypeExtractor.getMapPartitionReturnTypes(
+                        mapPartitionFunction, input.getType(), opName, true);
+        MapPartitionFunction<T, R> function = mapPartitionFunction;
+        return input.window(GlobalWindows.createWithEndOfStreamTrigger())
+                .apply(
+                        new WindowFunction<T, R, KEY, GlobalWindow>() {
+                            @Override
+                            public void apply(
+                                    KEY key,
+                                    GlobalWindow window,
+                                    Iterable<T> input,
+                                    Collector<R> out)
+                                    throws Exception {
+                                function.mapPartition(input, out);
+                            }
+                        },
+                        resultType);
+    }
 }
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java
index 9454f04267d..33b36111c67 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java
@@ -19,7 +19,11 @@
 package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.MapPartitionOperator;
 
 /**
  * {@link NonKeyedPartitionWindowedStream} represents a data stream that 
collects all records of
@@ -37,4 +41,19 @@ public class NonKeyedPartitionWindowedStream<T> implements 
PartitionWindowedStre
         this.environment = environment;
         this.input = input;
     }
+
+    @Override
+    public <R> SingleOutputStreamOperator<R> mapPartition(
+            MapPartitionFunction<T, R> mapPartitionFunction) {
+        if (mapPartitionFunction == null) {
+            throw new NullPointerException("The map partition function must 
not be null.");
+        }
+        mapPartitionFunction = environment.clean(mapPartitionFunction);
+        String opName = "MapPartition";
+        TypeInformation<R> resultType =
+                TypeExtractor.getMapPartitionReturnTypes(
+                        mapPartitionFunction, input.getType(), opName, true);
+        return input.transform(opName, resultType, new 
MapPartitionOperator<>(mapPartitionFunction))
+                .setParallelism(input.getParallelism());
+    }
 }
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java
index 8c3caf97982..2169dd6d71f 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java
@@ -19,6 +19,7 @@
 package org.apache.flink.streaming.api.datastream;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
 
 /**
  * {@link PartitionWindowedStream} represents a data stream that collects all 
records of each
@@ -29,4 +30,14 @@ import org.apache.flink.annotation.PublicEvolving;
  * @param <T> The type of the elements in this stream.
  */
 @PublicEvolving
-public interface PartitionWindowedStream<T> {}
+public interface PartitionWindowedStream<T> {
+
+    /**
+     * Process the records of the window by {@link MapPartitionFunction}.
+     *
+     * @param mapPartitionFunction The map partition function.
+     * @param <R> The type of map partition result.
+     * @return The data stream with map partition result.
+     */
+    <R> SingleOutputStreamOperator<R> mapPartition(MapPartitionFunction<T, R> 
mapPartitionFunction);
+}
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/MapPartitionIterator.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/MapPartitionIterator.java
new file mode 100644
index 00000000000..badf374e1c8
--- /dev/null
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/MapPartitionIterator.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.concurrent.ExecutorThreadFactory;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.Queue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * The {@link MapPartitionIterator} is an iterator used in the {@link 
MapPartitionOperator}. The task
+ * main thread will add records to it. It will set itself as the input 
parameter of {@link
+ * MapPartitionFunction} and execute the function.
+ */
+@Internal
+public class MapPartitionIterator<IN> implements Iterator<IN> {
+
+    /**
+     * Max number of caches.
+     *
+     * <p>The constant defines the maximum number of caches that can be 
created. Its value is set to
+     * 100, which is considered sufficient for most parallel jobs. Each cache 
is a record and
+     * occupies a minimal amount of memory so the value is not excessively 
large.
+     */
+    public static final int DEFAULT_MAX_CACHE_NUM = 100;
+
+    /** The lock to ensure consistency between task main thread and udf 
executor. */
+    private final Lock lock = new ReentrantLock();
+
+    /** The queue to store record caches. */
+    @GuardedBy("lock")
+    private final Queue<IN> cacheQueue = new LinkedList<>();
+
+    /** The condition to indicate the cache queue is not empty. */
+    private final Condition cacheNotEmpty = lock.newCondition();
+
+    /** The condition to indicate the cache queue is not full. */
+    private final Condition cacheNotFull = lock.newCondition();
+
+    /** The condition to indicate the udf is finished. */
+    private final Condition udfFinish = lock.newCondition();
+
+    /** The task udf executor. */
+    private final ExecutorService udfExecutor;
+
+    /** The flag to represent the finished state of udf. */
+    @GuardedBy("lock")
+    private boolean udfFinished = false;
+
+    /** The flag to represent the closed state of this iterator. */
+    @GuardedBy("lock")
+    private boolean closed = false;
+
+    public MapPartitionIterator(Consumer<Iterator<IN>> udf) {
+        this.udfExecutor =
+                Executors.newSingleThreadExecutor(new 
ExecutorThreadFactory("TaskUDFExecutor"));
+        this.udfExecutor.execute(
+                () -> {
+                    udf.accept(this);
+                    runWithLock(
+                            () -> {
+                                udfFinished = true;
+                                udfFinish.signalAll();
+                                cacheNotFull.signalAll();
+                            });
+                });
+    }
+
+    @Override
+    public boolean hasNext() {
+        return supplyWithLock(
+                () -> {
+                    if (cacheQueue.size() > 0) {
+                        return true;
+                    } else if (closed) {
+                        return false;
+                    } else {
+                        waitCacheNotEmpty();
+                        return hasNext();
+                    }
+                });
+    }
+
+    @Override
+    public IN next() {
+        return supplyWithLock(
+                () -> {
+                    IN record;
+                    if (cacheQueue.size() > 0) {
+                        if (!closed && cacheQueue.size() == 
DEFAULT_MAX_CACHE_NUM) {
+                            cacheNotFull.signalAll();
+                        }
+                        record = cacheQueue.poll();
+                        return record;
+                    } else {
+                        if (closed) {
+                            return null;
+                        }
+                        waitCacheNotEmpty();
+                        return cacheQueue.poll();
+                    }
+                });
+    }
+
+    public void addRecord(IN record) {
+        runWithLock(
+                () -> {
+                    checkState(!closed);
+                    if (udfFinished) {
+                        return;
+                    }
+                    if (cacheQueue.size() < DEFAULT_MAX_CACHE_NUM) {
+                        cacheQueue.add(record);
+                        if (cacheQueue.size() == 1) {
+                            cacheNotEmpty.signalAll();
+                        }
+                    } else {
+                        waitCacheNotFull();
+                        addRecord(record);
+                    }
+                });
+    }
+
+    public void close() {
+        runWithLock(
+                () -> {
+                    closed = true;
+                    if (!udfFinished) {
+                        cacheNotEmpty.signalAll();
+                        waitUDFFinished();
+                    }
+                    udfExecutor.shutdown();
+                });
+    }
+
+    // ------------------------------------
+    //           Internal Method
+    // ------------------------------------
+
+    /** Wait until the cache is not empty. */
+    private void waitCacheNotEmpty() {
+        try {
+            cacheNotEmpty.await();
+        } catch (InterruptedException e) {
+            ExceptionUtils.rethrow(e);
+        }
+    }
+
+    /** Wait until the cache is not full. */
+    private void waitCacheNotFull() {
+        try {
+            cacheNotFull.await();
+        } catch (InterruptedException e) {
+            ExceptionUtils.rethrow(e);
+        }
+    }
+
+    /** Wait until the UDF is finished. */
+    private void waitUDFFinished() {
+        try {
+            udfFinish.await();
+        } catch (InterruptedException e) {
+            ExceptionUtils.rethrow(e);
+        }
+    }
+
+    private void runWithLock(Runnable runnable) {
+        try {
+            lock.lock();
+            runnable.run();
+        } finally {
+            lock.unlock();
+        }
+    }
+
+    private <ANY> ANY supplyWithLock(Supplier<ANY> supplier) {
+        ANY result;
+        try {
+            lock.lock();
+            result = supplier.get();
+        } finally {
+            lock.unlock();
+        }
+        return result;
+    }
+}
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/MapPartitionOperator.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/MapPartitionOperator.java
new file mode 100644
index 00000000000..df814912da3
--- /dev/null
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/MapPartitionOperator.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.ExceptionUtils;
+
+/**
+ * The {@link MapPartitionOperator} is used to process all records in each 
partition on a non-keyed
+ * stream. Each partition contains all records of a subtask.
+ */
+@Internal
+public class MapPartitionOperator<IN, OUT>
+        extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
+        implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
+
+    private final MapPartitionFunction<IN, OUT> function;
+
+    private transient MapPartitionIterator<IN> iterator;
+
+    public MapPartitionOperator(MapPartitionFunction<IN, OUT> function) {
+        super(function);
+        this.function = function;
+        // This operator is set to be non-chained as it doesn't use task main 
thread to write
+        // records to output, which may introduce risks to downstream chained 
operators.
+        this.chainingStrategy = ChainingStrategy.NEVER;
+    }
+
+    @Override
+    public void open() throws Exception {
+        super.open();
+        this.iterator =
+                new MapPartitionIterator<>(
+                        iterator -> {
+                            TimestampedCollector<OUT> outputCollector =
+                                    new TimestampedCollector<>(output);
+                            try {
+                                function.mapPartition(() -> iterator, 
outputCollector);
+                            } catch (Exception e) {
+                                ExceptionUtils.rethrow(e);
+                            }
+                        });
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> element) throws Exception {
+        iterator.addRecord(element.getValue());
+    }
+
+    @Override
+    public void endInput() throws Exception {
+        iterator.close();
+    }
+
+    @Override
+    public OperatorAttributes getOperatorAttributes() {
+        return new 
OperatorAttributesBuilder().setOutputOnlyAfterEndOfStream(true).build();
+    }
+}
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/MapPartitionIteratorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/MapPartitionIteratorTest.java
new file mode 100644
index 00000000000..fc971ed883a
--- /dev/null
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/MapPartitionIteratorTest.java
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.util.ExceptionUtils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import static 
org.apache.flink.streaming.api.operators.MapPartitionIterator.DEFAULT_MAX_CACHE_NUM;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Unit test for {@link MapPartitionIterator}. */
+class MapPartitionIteratorTest {
+
+    private static final String RECORD = "TEST";
+
+    private static final int RECORD_NUMBER = 3;
+
+    @Test
+    void testInitialize() throws ExecutionException, InterruptedException {
+        CompletableFuture<Object> result = new CompletableFuture<>();
+        MapPartitionIterator<String> iterator =
+                new MapPartitionIterator<>(stringIterator -> 
result.complete(null));
+        result.get();
+        assertThat(result).isCompleted();
+        iterator.close();
+    }
+
+    @Test
+    void testAddRecord() throws ExecutionException, InterruptedException {
+        CompletableFuture<List<String>> result = new CompletableFuture<>();
+        CompletableFuture<Object> udfFinishTrigger = new CompletableFuture<>();
+        MapPartitionIterator<String> iterator =
+                new MapPartitionIterator<>(
+                        inputIterator -> {
+                            List<String> strings = new ArrayList<>();
+                            for (int index = 0; index < RECORD_NUMBER; 
++index) {
+                                strings.add(inputIterator.next());
+                            }
+                            result.complete(strings);
+                            try {
+                                udfFinishTrigger.get();
+                            } catch (InterruptedException | ExecutionException 
e) {
+                                ExceptionUtils.rethrow(e);
+                            }
+                        });
+        // 1.Test addRecord() when the cache is empty in the 
MapPartitionIterator.
+        addRecordToIterator(RECORD_NUMBER, iterator);
+        List<String> results = result.get();
+        assertThat(results.size()).isEqualTo(RECORD_NUMBER);
+        assertThat(results.get(0)).isEqualTo(RECORD);
+        assertThat(results.get(1)).isEqualTo(RECORD);
+        assertThat(results.get(2)).isEqualTo(RECORD);
+        // 2.Test addRecord() when the cache is full in the 
MapPartitionIterator.
+        addRecordToIterator(DEFAULT_MAX_CACHE_NUM, iterator);
+        CompletableFuture<Object> mockedTaskThread1 = new 
CompletableFuture<>();
+        CompletableFuture<List<String>> addRecordFinishIdentifier1 = new 
CompletableFuture<>();
+        mockedTaskThread1.thenRunAsync(
+                () -> {
+                    iterator.addRecord(RECORD);
+                    addRecordFinishIdentifier1.complete(null);
+                });
+        mockedTaskThread1.complete(null);
+        assertThat(addRecordFinishIdentifier1).isNotCompleted();
+        iterator.next();
+        addRecordFinishIdentifier1.get();
+        assertThat(addRecordFinishIdentifier1).isCompleted();
+        // 3.Test addRecord() when the udf is finished in the 
MapPartitionIterator.
+        CompletableFuture<Object> mockedTaskThread2 = new 
CompletableFuture<>();
+        CompletableFuture<List<String>> addRecordFinishIdentifier2 = new 
CompletableFuture<>();
+        mockedTaskThread2.thenRunAsync(
+                () -> {
+                    iterator.addRecord(RECORD);
+                    addRecordFinishIdentifier2.complete(null);
+                });
+        mockedTaskThread2.complete(null);
+        assertThat(addRecordFinishIdentifier2).isNotCompleted();
+        udfFinishTrigger.complete(null);
+        addRecordFinishIdentifier2.get();
+        assertThat(addRecordFinishIdentifier2).isCompleted();
+        assertThat(udfFinishTrigger).isCompleted();
+        iterator.close();
+    }
+
+    @Test
+    void testHasNext() throws ExecutionException, InterruptedException {
+        CompletableFuture<Object> udfTrigger = new CompletableFuture<>();
+        CompletableFuture<Object> udfReadIteratorFinishIdentifier = new 
CompletableFuture<>();
+        CompletableFuture<Object> udfFinishTrigger = new CompletableFuture<>();
+        MapPartitionIterator<String> iterator =
+                new MapPartitionIterator<>(
+                        inputIterator -> {
+                            try {
+                                udfTrigger.get();
+                            } catch (InterruptedException | ExecutionException 
e) {
+                                ExceptionUtils.rethrow(e);
+                            }
+                            for (int index = 0; index < RECORD_NUMBER; 
++index) {
+                                inputIterator.next();
+                            }
+                            udfReadIteratorFinishIdentifier.complete(null);
+                            try {
+                                udfFinishTrigger.get();
+                            } catch (InterruptedException | ExecutionException 
e) {
+                                ExceptionUtils.rethrow(e);
+                            }
+                        });
+        // 1.Test hasNext() when the cache is not empty in the 
MapPartitionIterator.
+        addRecordToIterator(RECORD_NUMBER, iterator);
+        assertThat(iterator.hasNext()).isTrue();
+        // 2.Test hasNext() when the cache is empty in the 
MapPartitionIterator.
+        udfTrigger.complete(null);
+        udfReadIteratorFinishIdentifier.get();
+        assertThat(udfReadIteratorFinishIdentifier).isCompleted();
+        CompletableFuture<Object> mockedUDFThread1 = new CompletableFuture<>();
+        CompletableFuture<Boolean> hasNextFinishIdentifier1 = new 
CompletableFuture<>();
+        mockedUDFThread1.thenRunAsync(
+                () -> {
+                    boolean hasNext = iterator.hasNext();
+                    hasNextFinishIdentifier1.complete(hasNext);
+                });
+        mockedUDFThread1.complete(null);
+        assertThat(hasNextFinishIdentifier1).isNotCompleted();
+        iterator.addRecord(RECORD);
+        hasNextFinishIdentifier1.get();
+        assertThat(hasNextFinishIdentifier1).isCompletedWithValue(true);
+        iterator.next();
+        // 3.Test hasNext() when the MapPartitionIterator is closed.
+        CompletableFuture<Object> mockedUDFThread2 = new CompletableFuture<>();
+        CompletableFuture<Boolean> hasNextFinishIdentifier2 = new 
CompletableFuture<>();
+        mockedUDFThread2.thenRunAsync(
+                () -> {
+                    boolean hasNext = iterator.hasNext();
+                    hasNextFinishIdentifier2.complete(hasNext);
+                    udfFinishTrigger.complete(null);
+                });
+        mockedUDFThread2.complete(null);
+        assertThat(hasNextFinishIdentifier2).isNotCompleted();
+        iterator.close();
+        assertThat(hasNextFinishIdentifier2).isCompletedWithValue(false);
+        assertThat(udfFinishTrigger).isCompleted();
+    }
+
+    @Test
+    void testNext() throws ExecutionException, InterruptedException {
+        CompletableFuture<List<String>> result = new CompletableFuture<>();
+        CompletableFuture<Object> udfFinishTrigger = new CompletableFuture<>();
+        MapPartitionIterator<String> iterator =
+                new MapPartitionIterator<>(
+                        inputIterator -> {
+                            List<String> strings = new ArrayList<>();
+                            for (int index = 0; index < RECORD_NUMBER; 
++index) {
+                                strings.add(inputIterator.next());
+                            }
+                            result.complete(strings);
+                            try {
+                                udfFinishTrigger.get();
+                            } catch (InterruptedException | ExecutionException 
e) {
+                                ExceptionUtils.rethrow(e);
+                            }
+                        });
+        // 1.Test next() when the cache is not empty in the 
MapPartitionIterator.
+        addRecordToIterator(RECORD_NUMBER, iterator);
+        List<String> results = result.get();
+        assertThat(results.size()).isEqualTo(RECORD_NUMBER);
+        assertThat(results.get(0)).isEqualTo(RECORD);
+        assertThat(results.get(1)).isEqualTo(RECORD);
+        assertThat(results.get(2)).isEqualTo(RECORD);
+        // 2.Test next() when the cache is empty in the MapPartitionIterator.
+        CompletableFuture<Object> mockedUDFThread1 = new CompletableFuture<>();
+        CompletableFuture<String> nextFinishIdentifier1 = new 
CompletableFuture<>();
+        mockedUDFThread1.thenRunAsync(
+                () -> {
+                    String next = iterator.next();
+                    nextFinishIdentifier1.complete(next);
+                });
+        mockedUDFThread1.complete(null);
+        assertThat(nextFinishIdentifier1).isNotCompleted();
+        iterator.addRecord(RECORD);
+        nextFinishIdentifier1.get();
+        assertThat(nextFinishIdentifier1).isCompletedWithValue(RECORD);
+        // 3.Test next() when the MapPartitionIterator is closed.
+        CompletableFuture<Object> mockedUDFThread2 = new CompletableFuture<>();
+        CompletableFuture<String> nextFinishIdentifier2 = new 
CompletableFuture<>();
+        mockedUDFThread2.thenRunAsync(
+                () -> {
+                    String next = iterator.next();
+                    nextFinishIdentifier2.complete(next);
+                    udfFinishTrigger.complete(null);
+                });
+        mockedUDFThread2.complete(null);
+        assertThat(nextFinishIdentifier2).isNotCompleted();
+        iterator.close();
+        assertThat(nextFinishIdentifier2).isCompletedWithValue(null);
+        assertThat(udfFinishTrigger).isCompleted();
+    }
+
+    @Test
+    void testClose() throws ExecutionException, InterruptedException {
+        // 1.Test close() when the cache is not empty in the 
MapPartitionIterator.
+        CompletableFuture<?> udfFinishTrigger1 = new CompletableFuture<>();
+        MapPartitionIterator<String> iterator1 =
+                new MapPartitionIterator<>(
+                        ignored -> {
+                            try {
+                                udfFinishTrigger1.get();
+                            } catch (InterruptedException | ExecutionException 
e) {
+                                ExceptionUtils.rethrow(e);
+                            }
+                        });
+        iterator1.addRecord(RECORD);
+        CompletableFuture<Object> mockedTaskThread1 = new 
CompletableFuture<>();
+        CompletableFuture<Object> iteratorCloseIdentifier1 = new 
CompletableFuture<>();
+        mockedTaskThread1.thenRunAsync(
+                () -> {
+                    iterator1.close();
+                    iteratorCloseIdentifier1.complete(null);
+                });
+        mockedTaskThread1.complete(null);
+        assertThat(iteratorCloseIdentifier1).isNotCompleted();
+        udfFinishTrigger1.complete(null);
+        iteratorCloseIdentifier1.get();
+        assertThat(iteratorCloseIdentifier1).isCompleted();
+        // 2.Test close() when the cache is empty in the MapPartitionIterator.
+        CompletableFuture<?> udfFinishTrigger2 = new CompletableFuture<>();
+        MapPartitionIterator<String> iterator2 =
+                new MapPartitionIterator<>(
+                        ignored -> {
+                            try {
+                                udfFinishTrigger2.get();
+                            } catch (InterruptedException | ExecutionException 
e) {
+                                ExceptionUtils.rethrow(e);
+                            }
+                        });
+        CompletableFuture<Object> mockedTaskThread2 = new 
CompletableFuture<>();
+        CompletableFuture<Object> iteratorCloseIdentifier2 = new 
CompletableFuture<>();
+        mockedTaskThread1.thenRunAsync(
+                () -> {
+                    iterator2.close();
+                    iteratorCloseIdentifier2.complete(null);
+                });
+        mockedTaskThread2.complete(null);
+        assertThat(iteratorCloseIdentifier2).isNotCompleted();
+        udfFinishTrigger2.complete(null);
+        iteratorCloseIdentifier2.get();
+        assertThat(iteratorCloseIdentifier2).isCompleted();
+        // 3.Test close() when the udf is finished in the MapPartitionIterator.
+        MapPartitionIterator<String> iterator3 = new 
MapPartitionIterator<>(ignored -> {});
+        iterator3.close();
+    }
+
+    private void addRecordToIterator(int cacheNumber, 
MapPartitionIterator<String> iterator) {
+        for (int index = 0; index < cacheNumber; ++index) {
+            iterator.addRecord(RECORD);
+        }
+    }
+}
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/MapPartitionOperatorTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/MapPartitionOperatorTest.java
new file mode 100644
index 00000000000..ba9ab92420c
--- /dev/null
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/MapPartitionOperatorTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.TestHarnessUtil;
+import org.apache.flink.util.Collector;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Unit test for {@link MapPartitionOperator}. */
+class MapPartitionOperatorTest {
+
+    private static final String RECORD = "TEST";
+
+    @Test
+    void testMapPartition() throws Exception {
+        MapPartitionOperator<String, Integer> mapPartitionOperator =
+                new MapPartitionOperator<>(
+                        new MapPartition(new CompletableFuture<>(), new 
CompletableFuture<>()));
+        OneInputStreamOperatorTestHarness<String, Integer> testHarness =
+                new OneInputStreamOperatorTestHarness<>(mapPartitionOperator);
+        Queue<Object> expectedOutput = new LinkedList<>();
+        testHarness.open();
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.endInput();
+        expectedOutput.add(new StreamRecord<>(3));
+        TestHarnessUtil.assertOutputEquals(
+                "The result of map partition is not correct.",
+                expectedOutput,
+                testHarness.getOutput());
+        testHarness.close();
+    }
+
+    @Test
+    void testOpenClose() throws Exception {
+        CompletableFuture<Object> openIdentifier = new CompletableFuture<>();
+        CompletableFuture<Object> closeIdentifier = new CompletableFuture<>();
+        MapPartitionOperator<String, Integer> mapPartitionOperator =
+                new MapPartitionOperator<>(new MapPartition(openIdentifier, 
closeIdentifier));
+        OneInputStreamOperatorTestHarness<String, Integer> testHarness =
+                new OneInputStreamOperatorTestHarness<>(mapPartitionOperator);
+        testHarness.open();
+        testHarness.processElement(new StreamRecord<>(RECORD));
+        testHarness.endInput();
+        testHarness.close();
+        assertThat(openIdentifier).isCompleted();
+        assertThat(closeIdentifier).isCompleted();
+        assertThat(testHarness.getOutput()).isNotEmpty();
+    }
+
+    /** The test user implementation of {@link MapPartitionFunction}. */
+    private static class MapPartition extends RichMapPartitionFunction<String, 
Integer> {
+
+        private final CompletableFuture<Object> openIdentifier;
+
+        private final CompletableFuture<Object> closeIdentifier;
+
+        public MapPartition(
+                CompletableFuture<Object> openIdentifier,
+                CompletableFuture<Object> closeIdentifier) {
+            this.openIdentifier = openIdentifier;
+            this.closeIdentifier = closeIdentifier;
+        }
+
+        @Override
+        public void open(OpenContext openContext) throws Exception {
+            super.open(openContext);
+            openIdentifier.complete(null);
+        }
+
+        @Override
+        public void mapPartition(Iterable<String> values, Collector<Integer> 
out) {
+            List<String> result = new ArrayList<>();
+            for (String value : values) {
+                result.add(value);
+            }
+            out.collect(result.size());
+        }
+
+        @Override
+        public void close() throws Exception {
+            super.close();
+            closeIdentifier.complete(null);
+        }
+    }
+}
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java
new file mode 100644
index 00000000000..0c731210a82
--- /dev/null
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/KeyedPartitionWindowedStreamITCase.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.streaming.runtime;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
+import org.apache.flink.streaming.api.datastream.KeyedPartitionWindowedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.util.CloseableIterator;
+import org.apache.flink.util.Collector;
+
+import org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Integration tests for {@link KeyedPartitionWindowedStream}. */
+class KeyedPartitionWindowedStreamITCase {
+
+    private static final int EVENT_NUMBER = 100;
+
+    private static final String TEST_EVENT = "Test";
+
+    @Test
+    void testMapPartition() throws Exception {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        DataStreamSource<Tuple2<String, String>> source = 
env.fromData(createSource());
+        CloseableIterator<String> resultIterator =
+                source.map(
+                                new MapFunction<Tuple2<String, String>, 
Tuple2<String, String>>() {
+                                    @Override
+                                    public Tuple2<String, String> 
map(Tuple2<String, String> value)
+                                            throws Exception {
+                                        return value;
+                                    }
+                                })
+                        .setParallelism(2)
+                        .keyBy(
+                                new KeySelector<Tuple2<String, String>, 
String>() {
+                                    @Override
+                                    public String getKey(Tuple2<String, 
String> value)
+                                            throws Exception {
+                                        return value.f0;
+                                    }
+                                })
+                        .fullWindowPartition()
+                        .mapPartition(
+                                new MapPartitionFunction<Tuple2<String, 
String>, String>() {
+                                    @Override
+                                    public void mapPartition(
+                                            Iterable<Tuple2<String, String>> 
values,
+                                            Collector<String> out)
+                                            throws Exception {
+                                        StringBuilder sb = new StringBuilder();
+                                        for (Tuple2<String, String> value : 
values) {
+                                            sb.append(value.f0);
+                                            sb.append(value.f1);
+                                        }
+                                        out.collect(sb.toString());
+                                    }
+                                })
+                        .executeAndCollect();
+        expectInAnyOrder(
+                resultIterator,
+                createExpectedString(1),
+                createExpectedString(2),
+                createExpectedString(3));
+    }
+
+    private Collection<Tuple2<String, String>> createSource() {
+        List<Tuple2<String, String>> source = new ArrayList<>();
+        for (int index = 0; index < EVENT_NUMBER; ++index) {
+            source.add(Tuple2.of("k1", TEST_EVENT));
+            source.add(Tuple2.of("k2", TEST_EVENT));
+            source.add(Tuple2.of("k3", TEST_EVENT));
+        }
+        return source;
+    }
+
+    private String createExpectedString(int key) {
+        StringBuilder stringBuilder = new StringBuilder();
+        for (int index = 0; index < EVENT_NUMBER; ++index) {
+            stringBuilder.append("k").append(key).append(TEST_EVENT);
+        }
+        return stringBuilder.toString();
+    }
+
+    private void expectInAnyOrder(CloseableIterator<String> resultIterator, 
String... expected) {
+        List<String> listExpected = Lists.newArrayList(expected);
+        List<String> testResults = Lists.newArrayList(resultIterator);
+        Collections.sort(listExpected);
+        Collections.sort(testResults);
+        assertThat(testResults).isEqualTo(listExpected);
+    }
+}
diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java
 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java
new file mode 100644
index 00000000000..19e6698b0c6
--- /dev/null
+++ 
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.streaming.runtime;
+
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
+import 
org.apache.flink.streaming.api.datastream.NonKeyedPartitionWindowedStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.util.CloseableIterator;
+import org.apache.flink.util.Collector;
+
+import org.apache.flink.shaded.guava31.com.google.common.collect.Lists;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Integration tests for {@link NonKeyedPartitionWindowedStream}. */
+class NonKeyedPartitionWindowedStreamITCase {
+
+    private static final int EVENT_NUMBER = 100;
+
+    private static final String TEST_EVENT = "Test";
+
+    @Test
+    void testMapPartition() throws Exception {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        DataStreamSource<String> source = env.fromData(createSource());
+        int parallelism = 2;
+        CloseableIterator<String> resultIterator =
+                source.map(v -> v)
+                        .setParallelism(parallelism)
+                        .fullWindowPartition()
+                        .mapPartition(
+                                new MapPartitionFunction<String, String>() {
+                                    @Override
+                                    public void mapPartition(
+                                            Iterable<String> values, 
Collector<String> out) {
+                                        StringBuilder sb = new StringBuilder();
+                                        for (String value : values) {
+                                            sb.append(value);
+                                        }
+                                        out.collect(sb.toString());
+                                    }
+                                })
+                        .executeAndCollect();
+        String expectedResult = createExpectedString(EVENT_NUMBER / 
parallelism);
+        expectInAnyOrder(resultIterator, expectedResult, expectedResult);
+    }
+
+    private void expectInAnyOrder(CloseableIterator<String> resultIterator, 
String... expected) {
+        List<String> listExpected = Lists.newArrayList(expected);
+        List<String> testResults = Lists.newArrayList(resultIterator);
+        Collections.sort(listExpected);
+        Collections.sort(testResults);
+        assertThat(testResults).isEqualTo(listExpected);
+    }
+
+    private Collection<String> createSource() {
+        ArrayList<String> source = new ArrayList<>();
+        for (int index = 0; index < EVENT_NUMBER; ++index) {
+            source.add(TEST_EVENT);
+        }
+        return source;
+    }
+
+    private String createExpectedString(int number) {
+        StringBuilder stringBuilder = new StringBuilder();
+        for (int index = 0; index < number; ++index) {
+            stringBuilder.append(TEST_EVENT);
+        }
+        return stringBuilder.toString();
+    }
+}


Reply via email to