This is an automated email from the ASF dual-hosted git repository. guoweijie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit cfbd41ecaa16edad5e460572eea6b5cf660af99d Author: Wencong Liu <[email protected]> AuthorDate: Tue Mar 12 12:17:00 2024 +0800 [FLINK-34543][datastream] Introduce the MapPartition API on PartitionWindowedStream --- .../datastream/KeyedPartitionWindowedStream.java | 35 +++ .../NonKeyedPartitionWindowedStream.java | 19 ++ .../api/datastream/PartitionWindowedStream.java | 13 +- .../api/operators/MapPartitionIterator.java | 217 ++++++++++++++++ .../api/operators/MapPartitionOperator.java | 77 ++++++ .../api/operators/MapPartitionIteratorTest.java | 278 +++++++++++++++++++++ .../api/operators/MapPartitionOperatorTest.java | 117 +++++++++ .../KeyedPartitionWindowedStreamITCase.java | 120 +++++++++ .../NonKeyedPartitionWindowedStreamITCase.java | 95 +++++++ 9 files changed, 970 insertions(+), 1 deletion(-) diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java index ecb9db8563b..6f2364ec2a8 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/KeyedPartitionWindowedStream.java @@ -19,7 +19,16 @@ package org.apache.flink.streaming.api.datastream; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.WindowFunction; +import org.apache.flink.streaming.api.windowing.assigners.GlobalWindows; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.util.Collector; + +import static 
org.apache.flink.util.Preconditions.checkNotNull; /** * {@link KeyedPartitionWindowedStream} represents a data stream that collects all records with the @@ -37,4 +46,30 @@ public class KeyedPartitionWindowedStream<T, KEY> implements PartitionWindowedSt this.environment = environment; this.input = input; } + + @Override + public <R> SingleOutputStreamOperator<R> mapPartition( + MapPartitionFunction<T, R> mapPartitionFunction) { + checkNotNull(mapPartitionFunction, "The map partition function must not be null."); + mapPartitionFunction = environment.clean(mapPartitionFunction); + String opName = "MapPartition"; + TypeInformation<R> resultType = + TypeExtractor.getMapPartitionReturnTypes( + mapPartitionFunction, input.getType(), opName, true); + MapPartitionFunction<T, R> function = mapPartitionFunction; + return input.window(GlobalWindows.createWithEndOfStreamTrigger()) + .apply( + new WindowFunction<T, R, KEY, GlobalWindow>() { + @Override + public void apply( + KEY key, + GlobalWindow window, + Iterable<T> input, + Collector<R> out) + throws Exception { + function.mapPartition(input, out); + } + }, + resultType); + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java index 9454f04267d..33b36111c67 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/NonKeyedPartitionWindowedStream.java @@ -19,7 +19,11 @@ package org.apache.flink.streaming.api.datastream; import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; import 
org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.MapPartitionOperator; /** * {@link NonKeyedPartitionWindowedStream} represents a data stream that collects all records of @@ -37,4 +41,19 @@ public class NonKeyedPartitionWindowedStream<T> implements PartitionWindowedStre this.environment = environment; this.input = input; } + + @Override + public <R> SingleOutputStreamOperator<R> mapPartition( + MapPartitionFunction<T, R> mapPartitionFunction) { + if (mapPartitionFunction == null) { + throw new NullPointerException("The map partition function must not be null."); + } + mapPartitionFunction = environment.clean(mapPartitionFunction); + String opName = "MapPartition"; + TypeInformation<R> resultType = + TypeExtractor.getMapPartitionReturnTypes( + mapPartitionFunction, input.getType(), opName, true); + return input.transform(opName, resultType, new MapPartitionOperator<>(mapPartitionFunction)) + .setParallelism(input.getParallelism()); + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java index 8c3caf97982..2169dd6d71f 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/datastream/PartitionWindowedStream.java @@ -19,6 +19,7 @@ package org.apache.flink.streaming.api.datastream; import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.api.common.functions.MapPartitionFunction; /** * {@link PartitionWindowedStream} represents a data stream that collects all records of each @@ -29,4 +30,14 @@ import org.apache.flink.annotation.PublicEvolving; * @param <T> The type of the elements in this stream. 
/**
 * {@link PartitionWindowedStream} represents a data stream that collects all records of each
 * partition into a single full window before processing them.
 *
 * @param <T> The type of the elements in this stream.
 */
@PublicEvolving
public interface PartitionWindowedStream<T> {

    /**
     * Process the records of the window by {@link MapPartitionFunction}.
     *
     * <p>The function is invoked once per partition with an {@link Iterable} over all records of
     * that partition.
     *
     * @param mapPartitionFunction The map partition function.
     * @param <R> The type of map partition result.
     * @return The data stream with map partition result.
     */
    <R> SingleOutputStreamOperator<R> mapPartition(MapPartitionFunction<T, R> mapPartitionFunction);
}
package org.apache.flink.streaming.api.operators;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.concurrent.ExecutorThreadFactory;

import javax.annotation.concurrent.GuardedBy;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Supplier;

import static org.apache.flink.util.Preconditions.checkState;

/**
 * The {@link MapPartitionIterator} is an iterator used in the {@link MapPartitionOperator}. The
 * task main thread will add records to it. It will set itself as the input parameter of {@link
 * MapPartitionFunction} and execute the function.
 *
 * <p>Threading model: the task main thread is the producer ({@link #addRecord}, {@link #close});
 * a dedicated single-thread executor running the udf is the consumer ({@link #hasNext}, {@link
 * #next}). A bounded queue and three {@link Condition}s on one lock coordinate the two sides.
 */
@Internal
public class MapPartitionIterator<IN> implements Iterator<IN> {

    /**
     * Max number of caches.
     *
     * <p>The constant defines the maximum number of caches that can be created. Its value is set to
     * 100, which is considered sufficient for most parallel jobs. Each cache is a record and
     * occupies a minimal amount of memory so the value is not excessively large.
     */
    public static final int DEFAULT_MAX_CACHE_NUM = 100;

    /** The lock to ensure consistency between task main thread and udf executor. */
    private final Lock lock = new ReentrantLock();

    /** The queue to store record caches, bounded by {@link #DEFAULT_MAX_CACHE_NUM}. */
    @GuardedBy("lock")
    private final Queue<IN> cacheQueue = new LinkedList<>();

    /** The condition to indicate the cache queue is not empty (consumer side waits on it). */
    private final Condition cacheNotEmpty = lock.newCondition();

    /** The condition to indicate the cache queue is not full (producer side waits on it). */
    private final Condition cacheNotFull = lock.newCondition();

    /** The condition to indicate the udf is finished ({@link #close} waits on it). */
    private final Condition udfFinish = lock.newCondition();

    /** The task udf executor; a single thread that runs the user function. */
    private final ExecutorService udfExecutor;

    /** The flag to represent the finished state of udf. */
    @GuardedBy("lock")
    private boolean udfFinished = false;

    /** The flag to represent the closed state of this iterator. */
    @GuardedBy("lock")
    private boolean closed = false;

    /**
     * Creates the iterator and immediately starts the udf on a dedicated thread, handing this
     * iterator to it as the record source. When the udf returns, the finish flag is set and both
     * waiters ({@link #close} and a possibly blocked {@link #addRecord}) are released.
     */
    public MapPartitionIterator(Consumer<Iterator<IN>> udf) {
        this.udfExecutor =
                Executors.newSingleThreadExecutor(new ExecutorThreadFactory("TaskUDFExecutor"));
        this.udfExecutor.execute(
                () -> {
                    udf.accept(this);
                    runWithLock(
                            () -> {
                                udfFinished = true;
                                udfFinish.signalAll();
                                cacheNotFull.signalAll();
                            });
                });
    }

    /**
     * Returns whether more records will become available. Blocks while the queue is empty and the
     * iterator is not yet closed; re-checks via recursion after each wakeup, which covers spurious
     * wakeups (NOTE(review): recursion instead of the conventional await-in-a-loop — depth grows
     * with the number of wakeups; presumably bounded in practice, verify).
     */
    @Override
    public boolean hasNext() {
        return supplyWithLock(
                () -> {
                    if (cacheQueue.size() > 0) {
                        return true;
                    } else if (closed) {
                        return false;
                    } else {
                        waitCacheNotEmpty();
                        return hasNext();
                    }
                });
    }

    /**
     * Returns the next record, or {@code null} once the iterator is closed and drained. When a
     * record is taken from a full queue, the producer is signalled that space is available again.
     */
    @Override
    public IN next() {
        return supplyWithLock(
                () -> {
                    IN record;
                    if (cacheQueue.size() > 0) {
                        // Taking from a full queue frees a slot: wake a blocked addRecord().
                        if (!closed && cacheQueue.size() == DEFAULT_MAX_CACHE_NUM) {
                            cacheNotFull.signalAll();
                        }
                        record = cacheQueue.poll();
                        return record;
                    } else {
                        if (closed) {
                            return null;
                        }
                        waitCacheNotEmpty();
                        return cacheQueue.poll();
                    }
                });
    }

    /**
     * Adds a record from the task main thread. Must not be called after {@link #close}. Records
     * arriving after the udf already finished are silently dropped (the consumer will never read
     * them). Blocks while the queue is full, retrying via recursion after each wakeup.
     */
    public void addRecord(IN record) {
        runWithLock(
                () -> {
                    checkState(!closed);
                    if (udfFinished) {
                        return;
                    }
                    if (cacheQueue.size() < DEFAULT_MAX_CACHE_NUM) {
                        cacheQueue.add(record);
                        // Transition empty -> non-empty: wake a blocked hasNext()/next().
                        if (cacheQueue.size() == 1) {
                            cacheNotEmpty.signalAll();
                        }
                    } else {
                        waitCacheNotFull();
                        addRecord(record);
                    }
                });
    }

    /**
     * Marks the iterator closed, wakes a consumer blocked on an empty queue (it will observe
     * {@code closed} and stop), waits for the udf to finish, then shuts down the executor.
     * Called from the task main thread at end of input.
     */
    public void close() {
        runWithLock(
                () -> {
                    closed = true;
                    if (!udfFinished) {
                        cacheNotEmpty.signalAll();
                        // await() releases the lock while waiting, so the udf thread can proceed.
                        waitUDFFinished();
                    }
                    udfExecutor.shutdown();
                });
    }

    // ------------------------------------
    //          Internal Method
    // ------------------------------------

    /**
     * Wait until the cache is not empty. NOTE(review): InterruptedException is rethrown unchecked
     * without restoring the interrupt flag — TODO confirm this is intended.
     */
    private void waitCacheNotEmpty() {
        try {
            cacheNotEmpty.await();
        } catch (InterruptedException e) {
            ExceptionUtils.rethrow(e);
        }
    }

    /** Wait until the cache is not full. Same interruption handling as above. */
    private void waitCacheNotFull() {
        try {
            cacheNotFull.await();
        } catch (InterruptedException e) {
            ExceptionUtils.rethrow(e);
        }
    }

    /** Wait until the UDF is finished. Same interruption handling as above. */
    private void waitUDFFinished() {
        try {
            udfFinish.await();
        } catch (InterruptedException e) {
            ExceptionUtils.rethrow(e);
        }
    }

    /** Runs the given action while holding {@link #lock}. */
    private void runWithLock(Runnable runnable) {
        try {
            // NOTE(review): lock() sits inside the try; if it ever threw, unlock() would run on
            // an unheld lock. ReentrantLock.lock() does not throw, so this is benign as written.
            lock.lock();
            runnable.run();
        } finally {
            lock.unlock();
        }
    }

    /** Evaluates the given supplier while holding {@link #lock} and returns its result. */
    private <ANY> ANY supplyWithLock(Supplier<ANY> supplier) {
        ANY result;
        try {
            lock.lock();
            result = supplier.get();
        } finally {
            lock.unlock();
        }
        return result;
    }
}
package org.apache.flink.streaming.api.operators;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.ExceptionUtils;

/**
 * The {@link MapPartitionOperator} is used to process all records in each partition on non-keyed
 * stream. Each partition contains all records of a subtask.
 *
 * <p>Records are handed to a {@link MapPartitionIterator}, whose internal thread runs the user
 * function; the main thread only feeds records in {@link #processElement} and waits for the
 * function to drain them in {@link #endInput}.
 */
@Internal
public class MapPartitionOperator<IN, OUT>
        extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
        implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {

    /** The user function; also registered with the udf lifecycle via the super constructor. */
    private final MapPartitionFunction<IN, OUT> function;

    /** Bridge between the main thread (producer) and the udf thread (consumer). */
    private transient MapPartitionIterator<IN> iterator;

    public MapPartitionOperator(MapPartitionFunction<IN, OUT> function) {
        super(function);
        this.function = function;
        // This operator is set to be non-chained as it doesn't use task main thread to write
        // records to output, which may introduce risks to downstream chained operators.
        this.chainingStrategy = ChainingStrategy.NEVER;
    }

    /**
     * Starts the udf on the iterator's dedicated thread. The iterator itself is wrapped in a
     * single-use {@link Iterable} ({@code () -> iterator}) — each call returns the same iterator,
     * which is fine for the single pass the udf performs.
     */
    @Override
    public void open() throws Exception {
        super.open();
        this.iterator =
                new MapPartitionIterator<>(
                        iterator -> {
                            TimestampedCollector<OUT> outputCollector =
                                    new TimestampedCollector<>(output);
                            try {
                                function.mapPartition(() -> iterator, outputCollector);
                            } catch (Exception e) {
                                ExceptionUtils.rethrow(e);
                            }
                        });
    }

    /** Feeds one record to the udf thread; may block if the iterator's cache is full. */
    @Override
    public void processElement(StreamRecord<IN> element) throws Exception {
        iterator.addRecord(element.getValue());
    }

    /**
     * Signals end of input and blocks until the udf has finished and its executor is shut down.
     *
     * <p>NOTE(review): this is the only place the iterator is closed — if the task fails or is
     * cancelled before endInput, the udf executor thread appears to be left running; TODO confirm
     * whether an operator close/dispose hook should also close the iterator.
     */
    @Override
    public void endInput() throws Exception {
        iterator.close();
    }

    /** Declares that this operator emits records only after all input has been consumed. */
    @Override
    public OperatorAttributes getOperatorAttributes() {
        return new OperatorAttributesBuilder().setOutputOnlyAfterEndOfStream(true).build();
    }
}
package org.apache.flink.streaming.api.operators;

import org.apache.flink.util.ExceptionUtils;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import static org.apache.flink.streaming.api.operators.MapPartitionIterator.DEFAULT_MAX_CACHE_NUM;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Unit test for {@link MapPartitionIterator}.
 *
 * <p>Each test simulates the two threads of the iterator explicitly: the "task thread" (producer,
 * calling addRecord/close) and the "udf thread" (consumer, calling hasNext/next). Blocking calls
 * are issued via {@code thenRunAsync} so the test thread can assert they have not finished yet.
 */
class MapPartitionIteratorTest {

    private static final String RECORD = "TEST";

    private static final int RECORD_NUMBER = 3;

    @Test
    void testInitialize() throws ExecutionException, InterruptedException {
        CompletableFuture<Object> result = new CompletableFuture<>();
        MapPartitionIterator<String> iterator =
                new MapPartitionIterator<>(stringIterator -> result.complete(null));
        // The udf must have been started eagerly by the constructor.
        result.get();
        assertThat(result).isCompleted();
        iterator.close();
    }

    @Test
    void testAddRecord() throws ExecutionException, InterruptedException {
        CompletableFuture<List<String>> result = new CompletableFuture<>();
        CompletableFuture<Object> udfFinishTrigger = new CompletableFuture<>();
        MapPartitionIterator<String> iterator =
                new MapPartitionIterator<>(
                        inputIterator -> {
                            List<String> strings = new ArrayList<>();
                            for (int index = 0; index < RECORD_NUMBER; ++index) {
                                strings.add(inputIterator.next());
                            }
                            result.complete(strings);
                            try {
                                udfFinishTrigger.get();
                            } catch (InterruptedException | ExecutionException e) {
                                ExceptionUtils.rethrow(e);
                            }
                        });
        // 1. Test addRecord() when the cache is empty in the MapPartitionIterator.
        addRecordToIterator(RECORD_NUMBER, iterator);
        List<String> results = result.get();
        assertThat(results.size()).isEqualTo(RECORD_NUMBER);
        assertThat(results.get(0)).isEqualTo(RECORD);
        assertThat(results.get(1)).isEqualTo(RECORD);
        assertThat(results.get(2)).isEqualTo(RECORD);
        // 2. Test addRecord() when the cache is full in the MapPartitionIterator.
        addRecordToIterator(DEFAULT_MAX_CACHE_NUM, iterator);
        CompletableFuture<Object> mockedTaskThread1 = new CompletableFuture<>();
        CompletableFuture<List<String>> addRecordFinishIdentifier1 = new CompletableFuture<>();
        mockedTaskThread1.thenRunAsync(
                () -> {
                    iterator.addRecord(RECORD);
                    addRecordFinishIdentifier1.complete(null);
                });
        mockedTaskThread1.complete(null);
        assertThat(addRecordFinishIdentifier1).isNotCompleted();
        // Draining one record frees a slot and unblocks the pending addRecord().
        iterator.next();
        addRecordFinishIdentifier1.get();
        assertThat(addRecordFinishIdentifier1).isCompleted();
        // 3. Test addRecord() when the udf is finished in the MapPartitionIterator.
        CompletableFuture<Object> mockedTaskThread2 = new CompletableFuture<>();
        CompletableFuture<List<String>> addRecordFinishIdentifier2 = new CompletableFuture<>();
        mockedTaskThread2.thenRunAsync(
                () -> {
                    iterator.addRecord(RECORD);
                    addRecordFinishIdentifier2.complete(null);
                });
        mockedTaskThread2.complete(null);
        assertThat(addRecordFinishIdentifier2).isNotCompleted();
        udfFinishTrigger.complete(null);
        addRecordFinishIdentifier2.get();
        assertThat(addRecordFinishIdentifier2).isCompleted();
        assertThat(udfFinishTrigger).isCompleted();
        iterator.close();
    }

    @Test
    void testHasNext() throws ExecutionException, InterruptedException {
        CompletableFuture<Object> udfTrigger = new CompletableFuture<>();
        CompletableFuture<Object> udfReadIteratorFinishIdentifier = new CompletableFuture<>();
        CompletableFuture<Object> udfFinishTrigger = new CompletableFuture<>();
        MapPartitionIterator<String> iterator =
                new MapPartitionIterator<>(
                        inputIterator -> {
                            try {
                                udfTrigger.get();
                            } catch (InterruptedException | ExecutionException e) {
                                ExceptionUtils.rethrow(e);
                            }
                            for (int index = 0; index < RECORD_NUMBER; ++index) {
                                inputIterator.next();
                            }
                            udfReadIteratorFinishIdentifier.complete(null);
                            try {
                                udfFinishTrigger.get();
                            } catch (InterruptedException | ExecutionException e) {
                                ExceptionUtils.rethrow(e);
                            }
                        });
        // 1. Test hasNext() when the cache is not empty in the MapPartitionIterator.
        addRecordToIterator(RECORD_NUMBER, iterator);
        assertThat(iterator.hasNext()).isTrue();
        // 2. Test hasNext() when the cache is empty in the MapPartitionIterator.
        udfTrigger.complete(null);
        udfReadIteratorFinishIdentifier.get();
        assertThat(udfReadIteratorFinishIdentifier).isCompleted();
        CompletableFuture<Object> mockedUDFThread1 = new CompletableFuture<>();
        CompletableFuture<Boolean> hasNextFinishIdentifier1 = new CompletableFuture<>();
        mockedUDFThread1.thenRunAsync(
                () -> {
                    boolean hasNext = iterator.hasNext();
                    hasNextFinishIdentifier1.complete(hasNext);
                });
        mockedUDFThread1.complete(null);
        assertThat(hasNextFinishIdentifier1).isNotCompleted();
        iterator.addRecord(RECORD);
        hasNextFinishIdentifier1.get();
        assertThat(hasNextFinishIdentifier1).isCompletedWithValue(true);
        iterator.next();
        // 3. Test hasNext() when the MapPartitionIterator is closed.
        CompletableFuture<Object> mockedUDFThread2 = new CompletableFuture<>();
        CompletableFuture<Boolean> hasNextFinishIdentifier2 = new CompletableFuture<>();
        mockedUDFThread2.thenRunAsync(
                () -> {
                    boolean hasNext = iterator.hasNext();
                    hasNextFinishIdentifier2.complete(hasNext);
                    udfFinishTrigger.complete(null);
                });
        mockedUDFThread2.complete(null);
        assertThat(hasNextFinishIdentifier2).isNotCompleted();
        iterator.close();
        assertThat(hasNextFinishIdentifier2).isCompletedWithValue(false);
        assertThat(udfFinishTrigger).isCompleted();
    }

    @Test
    void testNext() throws ExecutionException, InterruptedException {
        CompletableFuture<List<String>> result = new CompletableFuture<>();
        CompletableFuture<Object> udfFinishTrigger = new CompletableFuture<>();
        MapPartitionIterator<String> iterator =
                new MapPartitionIterator<>(
                        inputIterator -> {
                            List<String> strings = new ArrayList<>();
                            for (int index = 0; index < RECORD_NUMBER; ++index) {
                                strings.add(inputIterator.next());
                            }
                            result.complete(strings);
                            try {
                                udfFinishTrigger.get();
                            } catch (InterruptedException | ExecutionException e) {
                                ExceptionUtils.rethrow(e);
                            }
                        });
        // 1. Test next() when the cache is not empty in the MapPartitionIterator.
        addRecordToIterator(RECORD_NUMBER, iterator);
        List<String> results = result.get();
        assertThat(results.size()).isEqualTo(RECORD_NUMBER);
        assertThat(results.get(0)).isEqualTo(RECORD);
        assertThat(results.get(1)).isEqualTo(RECORD);
        assertThat(results.get(2)).isEqualTo(RECORD);
        // 2. Test next() when the cache is empty in the MapPartitionIterator.
        CompletableFuture<Object> mockedUDFThread1 = new CompletableFuture<>();
        CompletableFuture<String> nextFinishIdentifier1 = new CompletableFuture<>();
        mockedUDFThread1.thenRunAsync(
                () -> {
                    String next = iterator.next();
                    nextFinishIdentifier1.complete(next);
                });
        mockedUDFThread1.complete(null);
        assertThat(nextFinishIdentifier1).isNotCompleted();
        iterator.addRecord(RECORD);
        nextFinishIdentifier1.get();
        assertThat(nextFinishIdentifier1).isCompletedWithValue(RECORD);
        // 3. Test next() when the MapPartitionIterator is closed.
        CompletableFuture<Object> mockedUDFThread2 = new CompletableFuture<>();
        CompletableFuture<String> nextFinishIdentifier2 = new CompletableFuture<>();
        mockedUDFThread2.thenRunAsync(
                () -> {
                    String next = iterator.next();
                    nextFinishIdentifier2.complete(next);
                    udfFinishTrigger.complete(null);
                });
        mockedUDFThread2.complete(null);
        assertThat(nextFinishIdentifier2).isNotCompleted();
        iterator.close();
        assertThat(nextFinishIdentifier2).isCompletedWithValue(null);
        assertThat(udfFinishTrigger).isCompleted();
    }

    @Test
    void testClose() throws ExecutionException, InterruptedException {
        // 1. Test close() when the cache is not empty in the MapPartitionIterator.
        CompletableFuture<?> udfFinishTrigger1 = new CompletableFuture<>();
        MapPartitionIterator<String> iterator1 =
                new MapPartitionIterator<>(
                        ignored -> {
                            try {
                                udfFinishTrigger1.get();
                            } catch (InterruptedException | ExecutionException e) {
                                ExceptionUtils.rethrow(e);
                            }
                        });
        iterator1.addRecord(RECORD);
        CompletableFuture<Object> mockedTaskThread1 = new CompletableFuture<>();
        CompletableFuture<Object> iteratorCloseIdentifier1 = new CompletableFuture<>();
        mockedTaskThread1.thenRunAsync(
                () -> {
                    iterator1.close();
                    iteratorCloseIdentifier1.complete(null);
                });
        mockedTaskThread1.complete(null);
        assertThat(iteratorCloseIdentifier1).isNotCompleted();
        udfFinishTrigger1.complete(null);
        iteratorCloseIdentifier1.get();
        assertThat(iteratorCloseIdentifier1).isCompleted();
        // 2. Test close() when the cache is empty in the MapPartitionIterator.
        CompletableFuture<?> udfFinishTrigger2 = new CompletableFuture<>();
        MapPartitionIterator<String> iterator2 =
                new MapPartitionIterator<>(
                        ignored -> {
                            try {
                                udfFinishTrigger2.get();
                            } catch (InterruptedException | ExecutionException e) {
                                ExceptionUtils.rethrow(e);
                            }
                        });
        CompletableFuture<Object> mockedTaskThread2 = new CompletableFuture<>();
        CompletableFuture<Object> iteratorCloseIdentifier2 = new CompletableFuture<>();
        // Fixed copy-paste: this runnable must be gated on mockedTaskThread2 (previously it was
        // chained off the already-completed mockedTaskThread1, so completing mockedTaskThread2
        // below had no effect and the isNotCompleted() assertion held only by accident).
        mockedTaskThread2.thenRunAsync(
                () -> {
                    iterator2.close();
                    iteratorCloseIdentifier2.complete(null);
                });
        mockedTaskThread2.complete(null);
        assertThat(iteratorCloseIdentifier2).isNotCompleted();
        udfFinishTrigger2.complete(null);
        iteratorCloseIdentifier2.get();
        assertThat(iteratorCloseIdentifier2).isCompleted();
        // 3. Test close() when the udf is finished in the MapPartitionIterator.
        MapPartitionIterator<String> iterator3 = new MapPartitionIterator<>(ignored -> {});
        iterator3.close();
    }

    /** Adds {@code cacheNumber} copies of {@link #RECORD} to the iterator from the test thread. */
    private void addRecordToIterator(int cacheNumber, MapPartitionIterator<String> iterator) {
        for (int index = 0; index < cacheNumber; ++index) {
            iterator.addRecord(RECORD);
        }
    }
}
package org.apache.flink.streaming.api.operators;

import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.TestHarnessUtil;
import org.apache.flink.util.Collector;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;

import static org.assertj.core.api.Assertions.assertThat;

/** Unit test for {@link MapPartitionOperator}. */
class MapPartitionOperatorTest {

    private static final String RECORD = "TEST";

    /**
     * Verifies that the operator invokes the map partition function once with all records of the
     * partition: three inputs must yield the single count record {@code 3} after endInput.
     */
    @Test
    void testMapPartition() throws Exception {
        MapPartitionOperator<String, Integer> mapPartitionOperator =
                new MapPartitionOperator<>(
                        new MapPartition(new CompletableFuture<>(), new CompletableFuture<>()));
        OneInputStreamOperatorTestHarness<String, Integer> testHarness =
                new OneInputStreamOperatorTestHarness<>(mapPartitionOperator);
        Queue<Object> expectedOutput = new LinkedList<>();
        testHarness.open();
        testHarness.processElement(new StreamRecord<>(RECORD));
        testHarness.processElement(new StreamRecord<>(RECORD));
        testHarness.processElement(new StreamRecord<>(RECORD));
        testHarness.endInput();
        expectedOutput.add(new StreamRecord<>(3));
        TestHarnessUtil.assertOutputEquals(
                "The result of map partition is not correct.",
                expectedOutput,
                testHarness.getOutput());
        testHarness.close();
    }

    /**
     * Verifies that the rich function's open() and close() lifecycle hooks are both invoked
     * across a full open/process/endInput/close cycle, and that output was produced.
     */
    @Test
    void testOpenClose() throws Exception {
        CompletableFuture<Object> openIdentifier = new CompletableFuture<>();
        CompletableFuture<Object> closeIdentifier = new CompletableFuture<>();
        MapPartitionOperator<String, Integer> mapPartitionOperator =
                new MapPartitionOperator<>(new MapPartition(openIdentifier, closeIdentifier));
        OneInputStreamOperatorTestHarness<String, Integer> testHarness =
                new OneInputStreamOperatorTestHarness<>(mapPartitionOperator);
        testHarness.open();
        testHarness.processElement(new StreamRecord<>(RECORD));
        testHarness.endInput();
        testHarness.close();
        assertThat(openIdentifier).isCompleted();
        assertThat(closeIdentifier).isCompleted();
        assertThat(testHarness.getOutput()).isNotEmpty();
    }

    /**
     * The test user implementation of {@link MapPartitionFunction}: emits the number of records
     * in the partition and records its open/close lifecycle calls in the given futures.
     */
    private static class MapPartition extends RichMapPartitionFunction<String, Integer> {

        /** Completed when open() is called. */
        private final CompletableFuture<Object> openIdentifier;

        /** Completed when close() is called. */
        private final CompletableFuture<Object> closeIdentifier;

        public MapPartition(
                CompletableFuture<Object> openIdentifier,
                CompletableFuture<Object> closeIdentifier) {
            this.openIdentifier = openIdentifier;
            this.closeIdentifier = closeIdentifier;
        }

        @Override
        public void open(OpenContext openContext) throws Exception {
            super.open(openContext);
            openIdentifier.complete(null);
        }

        @Override
        public void mapPartition(Iterable<String> values, Collector<Integer> out) {
            List<String> result = new ArrayList<>();
            for (String value : values) {
                result.add(value);
            }
            out.collect(result.size());
        }

        @Override
        public void close() throws Exception {
            super.close();
            closeIdentifier.complete(null);
        }
    }
}
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.streaming.runtime; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.streaming.api.datastream.DataStreamSource; +import org.apache.flink.streaming.api.datastream.KeyedPartitionWindowedStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.Collector; + +import org.apache.flink.shaded.guava31.com.google.common.collect.Lists; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Integration tests for {@link KeyedPartitionWindowedStream}. 
*/ +class KeyedPartitionWindowedStreamITCase { + + private static final int EVENT_NUMBER = 100; + + private static final String TEST_EVENT = "Test"; + + @Test + void testMapPartition() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStreamSource<Tuple2<String, String>> source = env.fromData(createSource()); + CloseableIterator<String> resultIterator = + source.map( + new MapFunction<Tuple2<String, String>, Tuple2<String, String>>() { + @Override + public Tuple2<String, String> map(Tuple2<String, String> value) + throws Exception { + return value; + } + }) + .setParallelism(2) + .keyBy( + new KeySelector<Tuple2<String, String>, String>() { + @Override + public String getKey(Tuple2<String, String> value) + throws Exception { + return value.f0; + } + }) + .fullWindowPartition() + .mapPartition( + new MapPartitionFunction<Tuple2<String, String>, String>() { + @Override + public void mapPartition( + Iterable<Tuple2<String, String>> values, + Collector<String> out) + throws Exception { + StringBuilder sb = new StringBuilder(); + for (Tuple2<String, String> value : values) { + sb.append(value.f0); + sb.append(value.f1); + } + out.collect(sb.toString()); + } + }) + .executeAndCollect(); + expectInAnyOrder( + resultIterator, + createExpectedString(1), + createExpectedString(2), + createExpectedString(3)); + } + + private Collection<Tuple2<String, String>> createSource() { + List<Tuple2<String, String>> source = new ArrayList<>(); + for (int index = 0; index < EVENT_NUMBER; ++index) { + source.add(Tuple2.of("k1", TEST_EVENT)); + source.add(Tuple2.of("k2", TEST_EVENT)); + source.add(Tuple2.of("k3", TEST_EVENT)); + } + return source; + } + + private String createExpectedString(int key) { + StringBuilder stringBuilder = new StringBuilder(); + for (int index = 0; index < EVENT_NUMBER; ++index) { + stringBuilder.append("k").append(key).append(TEST_EVENT); + } + return stringBuilder.toString(); + } + + private void 
expectInAnyOrder(CloseableIterator<String> resultIterator, String... expected) { + List<String> listExpected = Lists.newArrayList(expected); + List<String> testResults = Lists.newArrayList(resultIterator); + Collections.sort(listExpected); + Collections.sort(testResults); + assertThat(testResults).isEqualTo(listExpected); + } +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java new file mode 100644 index 00000000000..19e6698b0c6 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/NonKeyedPartitionWindowedStreamITCase.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.test.streaming.runtime; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.streaming.api.datastream.DataStreamSource; +import org.apache.flink.streaming.api.datastream.NonKeyedPartitionWindowedStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.util.CloseableIterator; +import org.apache.flink.util.Collector; + +import org.apache.flink.shaded.guava31.com.google.common.collect.Lists; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Integration tests for {@link NonKeyedPartitionWindowedStream}. */ +class NonKeyedPartitionWindowedStreamITCase { + + private static final int EVENT_NUMBER = 100; + + private static final String TEST_EVENT = "Test"; + + @Test + void testMapPartition() throws Exception { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + DataStreamSource<String> source = env.fromData(createSource()); + int parallelism = 2; + CloseableIterator<String> resultIterator = + source.map(v -> v) + .setParallelism(parallelism) + .fullWindowPartition() + .mapPartition( + new MapPartitionFunction<String, String>() { + @Override + public void mapPartition( + Iterable<String> values, Collector<String> out) { + StringBuilder sb = new StringBuilder(); + for (String value : values) { + sb.append(value); + } + out.collect(sb.toString()); + } + }) + .executeAndCollect(); + String expectedResult = createExpectedString(EVENT_NUMBER / parallelism); + expectInAnyOrder(resultIterator, expectedResult, expectedResult); + } + + private void expectInAnyOrder(CloseableIterator<String> resultIterator, String... 
expected) { + List<String> listExpected = Lists.newArrayList(expected); + List<String> testResults = Lists.newArrayList(resultIterator); + Collections.sort(listExpected); + Collections.sort(testResults); + assertThat(testResults).isEqualTo(listExpected); + } + + private Collection<String> createSource() { + ArrayList<String> source = new ArrayList<>(); + for (int index = 0; index < EVENT_NUMBER; ++index) { + source.add(TEST_EVENT); + } + return source; + } + + private String createExpectedString(int number) { + StringBuilder stringBuilder = new StringBuilder(); + for (int index = 0; index < number; ++index) { + stringBuilder.append(TEST_EVENT); + } + return stringBuilder.toString(); + } +}
