This is an automated email from the ASF dual-hosted git repository. leiyanfei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 4ae4b871c08 [FLINK-36957][Datastream] Implement asyc state version of stream flatmap (#25848) 4ae4b871c08 is described below commit 4ae4b871c08900fb31b6bc58ef4817187672e96f Author: Yanfei Lei <fredia...@gmail.com> AuthorDate: Tue Jan 14 11:11:12 2025 +0800 [FLINK-36957][Datastream] Implement asyc state version of stream flatmap (#25848) --- .../state/api/input/KeyedStateInputFormatTest.java | 86 +++++++++++++++------- .../operators/AsyncStreamFlatMap.java | 57 ++++++++++++++ .../streaming/api/datastream/KeyedStream.java | 16 ++++ 3 files changed, 133 insertions(+), 26 deletions(-) diff --git a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java index b6c5acaa126..2e6faa0f480 100644 --- a/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java +++ b/flink-libraries/flink-state-processing-api/src/test/java/org/apache/flink/state/api/input/KeyedStateInputFormatTest.java @@ -26,6 +26,7 @@ import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.asyncprocessing.operators.AsyncStreamFlatMap; import org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.jobgraph.OperatorID; @@ -39,12 +40,17 @@ import org.apache.flink.streaming.api.functions.KeyedProcessFunction; import org.apache.flink.streaming.api.operators.KeyedProcessOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import 
org.apache.flink.streaming.api.operators.StreamFlatMap; +import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness; import org.apache.flink.streaming.util.MockStreamingRuntimeContext; import org.apache.flink.util.Collector; import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import javax.annotation.Nonnull; @@ -55,17 +61,20 @@ import java.util.Comparator; import java.util.List; import java.util.Set; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + /** Tests for keyed state input format. */ -public class KeyedStateInputFormatTest { +@RunWith(Parameterized.class) /* NOTE(review): JUnit 4 runner on a JUnit 5 class (tests use @ParameterizedTest/@ValueSource). Jupiter does not honor @RunWith, and the JUnit 4 Parameterized runner would require a @Parameters method that this class does not have. Likely removable together with the RunWith/Parameterized imports - confirm. */ +class KeyedStateInputFormatTest { private static ValueStateDescriptor<Integer> stateDescriptor = new ValueStateDescriptor<>("state", Types.INT); - @Test - public void testCreatePartitionedInputSplits() throws Exception { + @ParameterizedTest(name = "Enable async state = {0}") + @ValueSource(booleans = {false, true}) + void testCreatePartitionedInputSplits(boolean asyncState) throws Exception { OperatorID operatorID = OperatorIDGenerator.fromUid("uid"); - OperatorSubtaskState state = - createOperatorSubtaskState(new StreamFlatMap<>(new StatefulFunction())); + OperatorSubtaskState state = createOperatorSubtaskState(createFlatMap(asyncState)); OperatorState operatorState = new OperatorState(null, null, operatorID, 1, 128); operatorState.putState(0, state); @@ -81,12 +90,12 @@ public class KeyedStateInputFormatTest { "Failed to properly 
partition operator state into input splits", 4, splits.length); } - @Test - public void testMaxParallelismRespected() throws Exception { + @ParameterizedTest(name = "Enable async state = {0}") + @ValueSource(booleans = {false, true}) + void
testMaxParallelismRespected(boolean asyncState) throws Exception { OperatorID operatorID = OperatorIDGenerator.fromUid("uid"); - OperatorSubtaskState state = - createOperatorSubtaskState(new StreamFlatMap<>(new StatefulFunction())); + OperatorSubtaskState state = createOperatorSubtaskState(createFlatMap(asyncState)); OperatorState operatorState = new OperatorState(null, null, operatorID, 1, 128); operatorState.putState(0, state); @@ -104,12 +113,12 @@ public class KeyedStateInputFormatTest { splits.length); } - @Test - public void testReadState() throws Exception { + @ParameterizedTest(name = "Enable async state = {0}") + @ValueSource(booleans = {false, true}) + void testReadState(boolean asyncState) throws Exception { OperatorID operatorID = OperatorIDGenerator.fromUid("uid"); - OperatorSubtaskState state = - createOperatorSubtaskState(new StreamFlatMap<>(new StatefulFunction())); + OperatorSubtaskState state = createOperatorSubtaskState(createFlatMap(asyncState)); OperatorState operatorState = new OperatorState(null, null, operatorID, 1, 128); operatorState.putState(0, state); @@ -129,12 +138,12 @@ public class KeyedStateInputFormatTest { Assert.assertEquals("Incorrect data read from input split", Arrays.asList(1, 2, 3), data); } - @Test - public void testReadMultipleOutputPerKey() throws Exception { + @ParameterizedTest(name = "Enable async state = {0}") + @ValueSource(booleans = {false, true}) + void testReadMultipleOutputPerKey(boolean asyncState) throws Exception { OperatorID operatorID = OperatorIDGenerator.fromUid("uid"); - OperatorSubtaskState state = - createOperatorSubtaskState(new StreamFlatMap<>(new StatefulFunction())); + OperatorSubtaskState state = createOperatorSubtaskState(createFlatMap(asyncState)); OperatorState operatorState = new OperatorState(null, null, operatorID, 1, 128); operatorState.putState(0, state); @@ -155,12 +164,12 @@ public class KeyedStateInputFormatTest { "Incorrect data read from input split", Arrays.asList(1, 1, 2, 2, 3, 3),
data); } - @Test(expected = IOException.class) - public void testInvalidProcessReaderFunctionFails() throws Exception { + @ParameterizedTest(name = "Enable async state = {0}") + @ValueSource(booleans = {false, true}) + void testInvalidProcessReaderFunctionFails(boolean asyncState) throws Exception { OperatorID operatorID = OperatorIDGenerator.fromUid("uid"); - OperatorSubtaskState state = - createOperatorSubtaskState(new StreamFlatMap<>(new StatefulFunction())); + OperatorSubtaskState state = createOperatorSubtaskState(createFlatMap(asyncState)); OperatorState operatorState = new OperatorState(null, null, operatorID, 1, 128); operatorState.putState(0, state); @@ -175,13 +184,12 @@ public class KeyedStateInputFormatTest { KeyedStateReaderFunction<Integer, Integer> userFunction = new InvalidReaderFunction(); - readInputSplit(split, userFunction); - - Assert.fail("KeyedStateReaderFunction did not fail on invalid RuntimeContext use"); + assertThatThrownBy(() -> readInputSplit(split, userFunction)) + .isInstanceOf(IOException.class); } @Test - public void testReadTime() throws Exception { + void testReadTime() throws Exception { OperatorID operatorID = OperatorIDGenerator.fromUid("uid"); OperatorSubtaskState state = @@ -237,6 +245,12 @@ public class KeyedStateInputFormatTest { return data; } + private OneInputStreamOperator<Integer, Void> createFlatMap(boolean asyncState) { + return asyncState + ?
new AsyncStreamFlatMap<>(new AsyncStatefulFunction()) + : new StreamFlatMap<>(new StatefulFunction()); + } + private OperatorSubtaskState createOperatorSubtaskState( OneInputStreamOperator<Integer, Void> operator) throws Exception { try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Void> testHarness = @@ -317,6 +331,26 @@ public class KeyedStateInputFormatTest { } } + static class AsyncStatefulFunction extends RichFlatMapFunction<Integer, Void> { + org.apache.flink.api.common.state.v2.ValueState<Integer> state; + org.apache.flink.runtime.state.v2.ValueStateDescriptor<Integer> asyncStateDescriptor; + + @Override + public void open(OpenContext openContext) { + asyncStateDescriptor = + new org.apache.flink.runtime.state.v2.ValueStateDescriptor<>( + "state", Types.INT); + state = + ((StreamingRuntimeContext) getRuntimeContext()) + .getValueState(asyncStateDescriptor); + } + + @Override + public void flatMap(Integer value, Collector<Void> out) throws Exception { + state.asyncUpdate(value); + } + } + static class StatefulFunctionWithTime extends KeyedProcessFunction<Integer, Integer, Void> { ValueState<Integer> state; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AsyncStreamFlatMap.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AsyncStreamFlatMap.java new file mode 100644 index 00000000000..f9aa5c90e42 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/operators/AsyncStreamFlatMap.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.asyncprocessing.operators; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.runtime.asyncprocessing.declare.DeclarationContext; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** + * An {@link AbstractAsyncStateUdfStreamOperator} that runs a {@link FlatMapFunction} per record and + * emits through a collector bound to this operator's {@link DeclarationContext}. + */ +@Internal +public class AsyncStreamFlatMap<IN, OUT> + extends AbstractAsyncStateUdfStreamOperator<OUT, FlatMapFunction<IN, OUT>> + implements OneInputStreamOperator<IN, OUT> { + + private static final long serialVersionUID = 1L; + + private transient DeclarationContext declarationContext; /* Created in open() from getDeclarationManager(); only used to build the collector. */ + + private transient TimestampedCollectorWithDeclaredVariable<OUT> collector; + + public AsyncStreamFlatMap(FlatMapFunction<IN, OUT> flatMapper) { + super(flatMapper); + } + + @Override + public void open() throws Exception { + super.open(); + declarationContext = new DeclarationContext(getDeclarationManager()); + collector = new TimestampedCollectorWithDeclaredVariable<>(output, declarationContext); + } + + @Override + public void processElement(StreamRecord<IN> element) throws Exception { + collector.setTimestamp(element); + userFunction.flatMap(element.getValue(), collector); + } +} diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java index 87a8a75671f..dd4818fe30f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/datastream/KeyedStream.java @@ -36,6 +36,7 @@ import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.runtime.asyncprocessing.operators.AsyncIntervalJoinOperator; +import org.apache.flink.runtime.asyncprocessing.operators.AsyncStreamFlatMap; import org.apache.flink.streaming.api.functions.KeyedProcessFunction; import org.apache.flink.streaming.api.functions.aggregation.AggregationFunction; import org.apache.flink.streaming.api.functions.aggregation.ComparableAggregator; @@ -46,6 +47,8 @@ import org.apache.flink.streaming.api.functions.query.QueryableValueStateOperato import org.apache.flink.streaming.api.functions.sink.legacy.SinkFunction; import org.apache.flink.streaming.api.graph.StreamGraphGenerator; import org.apache.flink.streaming.api.operators.KeyedProcessOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.StreamFlatMap; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.api.operators.co.IntervalJoinOperator; import org.apache.flink.streaming.api.transformations.OneInputTransformation; @@ -361,6 +364,19 @@ public class KeyedStream<T, KEY> extends DataStream<T> { return transform("KeyedProcess", outputType, operator); } + // ------------------------------------------------------------------------ + // Flat Map + // ------------------------------------------------------------------------ + @Override + public <R> SingleOutputStreamOperator<R> flatMap( + 
FlatMapFunction<T, R> flatMapper, TypeInformation<R>
outputType) { + OneInputStreamOperator<T, R> operator = /* typed (not raw) so transform(...) is checked without unchecked warnings */ + isEnableAsyncState() /* async-state flat map when async state is enabled, otherwise the classic StreamFlatMap */ + ? new AsyncStreamFlatMap<>(clean(flatMapper)) + : new StreamFlatMap<>(clean(flatMapper)); + return transform("Flat Map", outputType, operator); + } + // ------------------------------------------------------------------------ // Joining // ------------------------------------------------------------------------