This is an automated email from the ASF dual-hosted git repository.

leiyanfei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ae4b871c08 [FLINK-36957][Datastream] Implement async state version of 
stream flatmap (#25848)
4ae4b871c08 is described below

commit 4ae4b871c08900fb31b6bc58ef4817187672e96f
Author: Yanfei Lei <fredia...@gmail.com>
AuthorDate: Tue Jan 14 11:11:12 2025 +0800

    [FLINK-36957][Datastream] Implement async state version of stream flatmap 
(#25848)
---
 .../state/api/input/KeyedStateInputFormatTest.java | 86 +++++++++++++++-------
 .../operators/AsyncStreamFlatMap.java              | 57 ++++++++++++++
 .../streaming/api/datastream/KeyedStream.java      | 16 ++++
 3 files changed, 133 insertions(+), 26 deletions(-)

diff --git 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
index b6c5acaa126..2e6faa0f480 100644
--- 
a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
+++ 
b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.common.typeutils.base.VoidSerializer;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.asyncprocessing.operators.AsyncStreamFlatMap;
 import org.apache.flink.runtime.checkpoint.OperatorState;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -39,12 +40,17 @@ import 
org.apache.flink.streaming.api.functions.KeyedProcessFunction;
 import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamFlatMap;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
 import org.apache.flink.util.Collector;
 
 import org.junit.Assert;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import javax.annotation.Nonnull;
 
@@ -55,17 +61,20 @@ import java.util.Comparator;
 import java.util.List;
 import java.util.Set;
 
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
 /** Tests for keyed state input format. */
-public class KeyedStateInputFormatTest {
+@RunWith(Parameterized.class)
+class KeyedStateInputFormatTest {
     private static ValueStateDescriptor<Integer> stateDescriptor =
             new ValueStateDescriptor<>("state", Types.INT);
 
-    @Test
-    public void testCreatePartitionedInputSplits() throws Exception {
+    @ParameterizedTest(name = "Enable async state = {0}")
+    @ValueSource(booleans = {false, true})
+    void testCreatePartitionedInputSplits(boolean asyncState) throws Exception 
{
         OperatorID operatorID = OperatorIDGenerator.fromUid("uid");
 
-        OperatorSubtaskState state =
-                createOperatorSubtaskState(new StreamFlatMap<>(new 
StatefulFunction()));
+        OperatorSubtaskState state = 
createOperatorSubtaskState(createFlatMap(asyncState));
         OperatorState operatorState = new OperatorState(null, null, 
operatorID, 1, 128);
         operatorState.putState(0, state);
 
@@ -81,12 +90,12 @@ public class KeyedStateInputFormatTest {
                 "Failed to properly partition operator state into input 
splits", 4, splits.length);
     }
 
-    @Test
-    public void testMaxParallelismRespected() throws Exception {
+    @ParameterizedTest(name = "Enable async state = {0}")
+    @ValueSource(booleans = {false, true})
+    void testMaxParallelismRespected(boolean asyncState) throws Exception {
         OperatorID operatorID = OperatorIDGenerator.fromUid("uid");
 
-        OperatorSubtaskState state =
-                createOperatorSubtaskState(new StreamFlatMap<>(new 
StatefulFunction()));
+        OperatorSubtaskState state = 
createOperatorSubtaskState(createFlatMap(asyncState));
         OperatorState operatorState = new OperatorState(null, null, 
operatorID, 1, 128);
         operatorState.putState(0, state);
 
@@ -104,12 +113,12 @@ public class KeyedStateInputFormatTest {
                 splits.length);
     }
 
-    @Test
-    public void testReadState() throws Exception {
+    @ParameterizedTest(name = "Enable async state = {0}")
+    @ValueSource(booleans = {false, true})
+    void testReadState(boolean asyncState) throws Exception {
         OperatorID operatorID = OperatorIDGenerator.fromUid("uid");
 
-        OperatorSubtaskState state =
-                createOperatorSubtaskState(new StreamFlatMap<>(new 
StatefulFunction()));
+        OperatorSubtaskState state = 
createOperatorSubtaskState(createFlatMap(asyncState));
         OperatorState operatorState = new OperatorState(null, null, 
operatorID, 1, 128);
         operatorState.putState(0, state);
 
@@ -129,12 +138,12 @@ public class KeyedStateInputFormatTest {
         Assert.assertEquals("Incorrect data read from input split", 
Arrays.asList(1, 2, 3), data);
     }
 
-    @Test
-    public void testReadMultipleOutputPerKey() throws Exception {
+    @ParameterizedTest(name = "Enable async state = {0}")
+    @ValueSource(booleans = {false, true})
+    void testReadMultipleOutputPerKey(boolean asyncState) throws Exception {
         OperatorID operatorID = OperatorIDGenerator.fromUid("uid");
 
-        OperatorSubtaskState state =
-                createOperatorSubtaskState(new StreamFlatMap<>(new 
StatefulFunction()));
+        OperatorSubtaskState state = 
createOperatorSubtaskState(createFlatMap(asyncState));
         OperatorState operatorState = new OperatorState(null, null, 
operatorID, 1, 128);
         operatorState.putState(0, state);
 
@@ -155,12 +164,12 @@ public class KeyedStateInputFormatTest {
                 "Incorrect data read from input split", Arrays.asList(1, 1, 2, 
2, 3, 3), data);
     }
 
-    @Test(expected = IOException.class)
-    public void testInvalidProcessReaderFunctionFails() throws Exception {
+    @ParameterizedTest(name = "Enable async state = {0}")
+    @ValueSource(booleans = {false, true})
+    void testInvalidProcessReaderFunctionFails(boolean asyncState) throws 
Exception {
         OperatorID operatorID = OperatorIDGenerator.fromUid("uid");
 
-        OperatorSubtaskState state =
-                createOperatorSubtaskState(new StreamFlatMap<>(new 
StatefulFunction()));
+        OperatorSubtaskState state = 
createOperatorSubtaskState(createFlatMap(asyncState));
         OperatorState operatorState = new OperatorState(null, null, 
operatorID, 1, 128);
         operatorState.putState(0, state);
 
@@ -175,13 +184,12 @@ public class KeyedStateInputFormatTest {
 
         KeyedStateReaderFunction<Integer, Integer> userFunction = new 
InvalidReaderFunction();
 
-        readInputSplit(split, userFunction);
-
-        Assert.fail("KeyedStateReaderFunction did not fail on invalid 
RuntimeContext use");
+        assertThatThrownBy(() -> readInputSplit(split, userFunction))
+                .isInstanceOf(IOException.class);
     }
 
     @Test
-    public void testReadTime() throws Exception {
+    void testReadTime() throws Exception {
         OperatorID operatorID = OperatorIDGenerator.fromUid("uid");
 
         OperatorSubtaskState state =
@@ -237,6 +245,12 @@ public class KeyedStateInputFormatTest {
         return data;
     }
 
+    private OneInputStreamOperator<Integer, Void> createFlatMap(boolean 
asyncState) {
+        return asyncState
+                ? new AsyncStreamFlatMap<>(new AsyncStatefulFunction())
+                : new StreamFlatMap<>(new StatefulFunction());
+    }
+
     private OperatorSubtaskState createOperatorSubtaskState(
             OneInputStreamOperator<Integer, Void> operator) throws Exception {
         try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Void> 
testHarness =
@@ -317,6 +331,26 @@ public class KeyedStateInputFormatTest {
         }
     }
 
+    static class AsyncStatefulFunction extends RichFlatMapFunction<Integer, 
Void> {
+        org.apache.flink.api.common.state.v2.ValueState<Integer> state;
+        org.apache.flink.runtime.state.v2.ValueStateDescriptor<Integer> 
asyncStateDescriptor;
+
+        @Override
+        public void open(OpenContext openContext) {
+            asyncStateDescriptor =
+                    new 
org.apache.flink.runtime.state.v2.ValueStateDescriptor<>(
+                            "state", Types.INT);
+            state =
+                    ((StreamingRuntimeContext) getRuntimeContext())
+                            .getValueState(asyncStateDescriptor);
+        }
+
+        @Override
+        public void flatMap(Integer value, Collector<Void> out) throws 
Exception {
+            state.asyncUpdate(value);
+        }
+    }
+
     static class StatefulFunctionWithTime extends 
KeyedProcessFunction<Integer, Integer, Void> {
         ValueState<Integer> state;
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AsyncStreamFlatMap.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AsyncStreamFlatMap.java
new file mode 100644
index 00000000000..f9aa5c90e42
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AsyncStreamFlatMap.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.asyncprocessing.operators;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.runtime.asyncprocessing.declare.DeclarationContext;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/**
+ * An {@link AbstractAsyncStateStreamOperator} for executing {@link 
FlatMapFunction
+ * FlatMapFunctions}.
+ */
+@Internal
+public class AsyncStreamFlatMap<IN, OUT>
+        extends AbstractAsyncStateUdfStreamOperator<OUT, FlatMapFunction<IN, 
OUT>>
+        implements OneInputStreamOperator<IN, OUT> {
+
+    private static final long serialVersionUID = 1L;
+
+    private transient DeclarationContext declarationContext;
+
+    private transient TimestampedCollectorWithDeclaredVariable<OUT> collector;
+
+    public AsyncStreamFlatMap(FlatMapFunction<IN, OUT> flatMapper) {
+        super(flatMapper);
+    }
+
+    @Override
+    public void open() throws Exception {
+        super.open();
+        declarationContext = new DeclarationContext(getDeclarationManager());
+        collector = new TimestampedCollectorWithDeclaredVariable<>(output, 
declarationContext);
+    }
+
+    @Override
+    public void processElement(StreamRecord<IN> element) throws Exception {
+        collector.setTimestamp(element);
+        userFunction.flatMap(element.getValue(), collector);
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
index 87a8a75671f..dd4818fe30f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
@@ -36,6 +36,7 @@ import org.apache.flink.api.java.typeutils.PojoTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import 
org.apache.flink.runtime.asyncprocessing.operators.AsyncIntervalJoinOperator;
+import org.apache.flink.runtime.asyncprocessing.operators.AsyncStreamFlatMap;
 import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
 import 
org.apache.flink.streaming.api.functions.aggregation.AggregationFunction;
 import 
org.apache.flink.streaming.api.functions.aggregation.ComparableAggregator;
@@ -46,6 +47,8 @@ import 
org.apache.flink.streaming.api.functions.query.QueryableValueStateOperato
 import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction;
 import org.apache.flink.streaming.api.graph.StreamGraphGenerator;
 import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.StreamFlatMap;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.co.IntervalJoinOperator;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
@@ -361,6 +364,19 @@ public class KeyedStream<T, KEY> extends DataStream<T> {
         return transform("KeyedProcess", outputType, operator);
     }
 
+    // ------------------------------------------------------------------------
+    // Flat Map
+    // ------------------------------------------------------------------------
+    @Override
+    public <R> SingleOutputStreamOperator<R> flatMap(
+            FlatMapFunction<T, R> flatMapper, TypeInformation<R> outputType) {
+        OneInputStreamOperator operator =
+                isEnableAsyncState()
+                        ? new AsyncStreamFlatMap(clean(flatMapper))
+                        : new StreamFlatMap<>(clean(flatMapper));
+        return transform("Flat Map", outputType, operator);
+    }
+
     // ------------------------------------------------------------------------
     //  Joining
     // ------------------------------------------------------------------------

Reply via email to