This is an automated email from the ASF dual-hosted git repository. leiyanfei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new d8b3243eaa5 [FLINK-35667][state/forst] Implement Reducing Async State API for ForStStateBackend (#25308) d8b3243eaa5 is described below commit d8b3243eaa5475eafb813c23419bbb55a4bebd23 Author: Yanfei Lei <fredia...@gmail.com> AuthorDate: Sat Sep 14 10:10:20 2024 +0800 [FLINK-35667][state/forst] Implement Reducing Async State API for ForStStateBackend (#25308) --- .../runtime/asyncprocessing/StateRequestType.java | 3 + .../runtime/state/v2/InternalMergingState.java | 41 ++++++ .../runtime/state/v2/InternalReducingState.java | 117 +++++++++++++++- .../state/v2/InternalAggregatingStateTest.java | 6 +- .../state/v2/InternalKeyedStateTestBase.java | 7 +- .../runtime/state/v2/InternalListStateTest.java | 20 +-- .../runtime/state/v2/InternalMapStateTest.java | 42 +++--- .../state/v2/InternalReducingStateTest.java | 142 +++++++++++++++++++- .../runtime/state/v2/InternalValueStateTest.java | 12 +- .../flink/state/forst/ForStKeyedStateBackend.java | 12 ++ .../flink/state/forst/ForStReducingState.java | 149 +++++++++++++++++++++ .../state/forst/ForStStateRequestClassifier.java | 2 + 12 files changed, 502 insertions(+), 51 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java index 504115a48fa..a3a95cf9f69 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestType.java @@ -108,6 +108,9 @@ public enum StateRequestType { /** Add element into reducing state, {@link ReducingState#asyncAdd(Object)}. */ REDUCING_ADD, + /** Remove element from reducing state. */ + REDUCING_REMOVE, + /** Get value from aggregating state by {@link AggregatingState#asyncGet()}. */ AGGREGATING_GET, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalMergingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalMergingState.java new file mode 100644 index 00000000000..3f5713bdb58 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalMergingState.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.v2; + +import org.apache.flink.api.common.state.v2.StateFuture; + +import java.util.Collection; + +/** + * This class defines the internal interface for merging state. + * + * @param <N> The type of the namespace + */ +public interface InternalMergingState<N> extends InternalPartitionedState<N> { + + /** + * Merges the state of the current key for the given source namespaces into the state of the + * target namespace. + * + * @param target The target namespace where the merged state should be stored. + * @param sources The source namespaces whose state should be merged. + */ + StateFuture<Void> asyncMergeNamespaces(N target, Collection<N> sources); + + void mergeNamespaces(N target, Collection<N> sources); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java index b49962f9f5d..0c3b124a778 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/InternalReducingState.java @@ -20,9 +20,15 @@ package org.apache.flink.runtime.state.v2; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.state.v2.ReducingState; import org.apache.flink.api.common.state.v2.StateFuture; +import org.apache.flink.core.state.StateFutureUtils; import org.apache.flink.runtime.asyncprocessing.StateRequestHandler; import org.apache.flink.runtime.asyncprocessing.StateRequestType; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; + /** * A default implementation of {@link ReducingState} which delegates all async requests to {@link * StateRequestHandler}. @@ -31,7 +37,7 @@ import org.apache.flink.runtime.asyncprocessing.StateRequestType; * @param <V> The type of values kept internally in state. */ public class InternalReducingState<K, N, V> extends InternalKeyedState<K, N, V> - implements ReducingState<V> { + implements ReducingState<V>, InternalMergingState<N> { protected final ReduceFunction<V> reduceFunction; @@ -48,7 +54,15 @@ public class InternalReducingState<K, N, V> extends InternalKeyedState<K, N, V> @Override public StateFuture<Void> asyncAdd(V value) { - return handleRequest(StateRequestType.REDUCING_ADD, value); + return handleRequest(StateRequestType.REDUCING_GET, null) + .thenAccept( + oldValue -> { + V newValue = + oldValue == null + ? value + : reduceFunction.reduce((V) oldValue, value); + handleRequest(StateRequestType.REDUCING_ADD, newValue); + }); } @Override @@ -58,6 +72,103 @@ public class InternalReducingState<K, N, V> extends InternalKeyedState<K, N, V> @Override public void add(V value) { - handleRequestSync(StateRequestType.REDUCING_ADD, value); + V oldValue = handleRequestSync(StateRequestType.REDUCING_GET, null); + try { + V newValue = oldValue == null ? value : reduceFunction.reduce(oldValue, value); + handleRequestSync(StateRequestType.REDUCING_ADD, newValue); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public StateFuture<Void> asyncMergeNamespaces(N target, Collection<N> sources) { + if (sources == null || sources.isEmpty()) { + return StateFutureUtils.completedVoidFuture(); + } + // phase 1: read from the sources and target + List<StateFuture<V>> futures = new ArrayList<>(sources.size() + 1); + for (N source : sources) { + if (source != null) { + setCurrentNamespace(source); + futures.add(handleRequest(StateRequestType.REDUCING_GET, null)); + } + } + setCurrentNamespace(target); + futures.add(handleRequest(StateRequestType.REDUCING_GET, null)); + // phase 2: merge the sources to the target + return StateFutureUtils.combineAll(futures) + .thenCompose( + values -> { + List<StateFuture<V>> updateFutures = + new ArrayList<>(sources.size() + 1); + V current = null; + Iterator<V> valueIterator = values.iterator(); + for (N source : sources) { + V value = valueIterator.next(); + if (value != null) { + setCurrentNamespace(source); + updateFutures.add( + handleRequest(StateRequestType.REDUCING_REMOVE, null)); + if (current != null) { + current = reduceFunction.reduce(current, value); + } else { + current = value; + } + } + } + V targetValue = valueIterator.next(); + if (current != null) { + if (targetValue != null) { + current = reduceFunction.reduce(current, targetValue); + } + setCurrentNamespace(target); + updateFutures.add( + handleRequest(StateRequestType.REDUCING_ADD, current)); + } + return StateFutureUtils.combineAll(updateFutures) + .thenAccept(ignores -> {}); + }); + } + + @Override + public void mergeNamespaces(N target, Collection<N> sources) { + if (sources == null || sources.isEmpty()) { + return; + } + try { + V current = null; + // merge the sources to the target + for (N source : sources) { + if (source != null) { + setCurrentNamespace(source); + V oldValue = handleRequestSync(StateRequestType.REDUCING_GET, null); + + if (oldValue != null) { + handleRequestSync(StateRequestType.REDUCING_REMOVE, null); + + if (current != null) { + current = reduceFunction.reduce(current, oldValue); + } else { + current = oldValue; + } + } + } + } + + // if something came out of merging the sources, merge it or write it to the target + if (current != null) { + // create the target full-binary-key + setCurrentNamespace(target); + V targetValue = handleRequestSync(StateRequestType.REDUCING_GET, null); + + if (targetValue != null) { + current = reduceFunction.reduce(current, targetValue); + } + handleRequestSync(StateRequestType.REDUCING_ADD, current); + } + } catch (Exception e) { + throw new RuntimeException("merge namespace fail.", e); + } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java index 97410748189..da671c78a83 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalAggregatingStateTest.java @@ -61,12 +61,12 @@ class InternalAggregatingStateTest extends InternalKeyedStateTestBase { aec.setCurrentContext(aec.buildContext("test", "test")); state.asyncClear(); - validateRequestRun(state, StateRequestType.CLEAR, null); + validateRequestRun(state, StateRequestType.CLEAR, null, 0); state.asyncGet(); - validateRequestRun(state, StateRequestType.AGGREGATING_GET, null); + validateRequestRun(state, StateRequestType.AGGREGATING_GET, null, 0); state.asyncAdd(1); - validateRequestRun(state, StateRequestType.AGGREGATING_ADD, 1); + validateRequestRun(state, StateRequestType.AGGREGATING_ADD, 1, 0); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java index 91923ff7705..9e645cf5a6b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalKeyedStateTestBase.java @@ -86,10 +86,13 @@ public class InternalKeyedStateTestBase { } <IN> void validateRequestRun( - @Nullable State state, StateRequestType type, @Nullable IN payload) { + @Nullable State state, + StateRequestType type, + @Nullable IN payload, + int remainingRequests) { aec.triggerIfNeeded(true); testStateExecutor.validate(state, type, payload); - assertThat(testStateExecutor.receivedRequest.isEmpty()).isTrue(); + assertThat(testStateExecutor.receivedRequest.size()).isEqualTo(remainingRequests); } /** diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java index ef0de969fcf..7f6679d2399 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalListStateTest.java @@ -39,37 +39,37 @@ public class InternalListStateTest extends InternalKeyedStateTestBase { aec.setCurrentContext(aec.buildContext("test", "test")); listState.asyncClear(); - validateRequestRun(listState, StateRequestType.CLEAR, null); + validateRequestRun(listState, StateRequestType.CLEAR, null, 0); listState.asyncGet(); - validateRequestRun(listState, StateRequestType.LIST_GET, null); + validateRequestRun(listState, StateRequestType.LIST_GET, null, 0); listState.asyncAdd(1); - validateRequestRun(listState, StateRequestType.LIST_ADD, 1); + validateRequestRun(listState, StateRequestType.LIST_ADD, 1, 0); List<Integer> list = new ArrayList<>(); listState.asyncUpdate(list); - validateRequestRun(listState, StateRequestType.LIST_UPDATE, list); + validateRequestRun(listState, StateRequestType.LIST_UPDATE, list, 0); list = new ArrayList<>(); listState.asyncAddAll(list); - validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list); + validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list, 0); listState.clear(); - validateRequestRun(listState, StateRequestType.CLEAR, null); + validateRequestRun(listState, StateRequestType.CLEAR, null, 0); listState.get().iterator(); - validateRequestRun(listState, StateRequestType.LIST_GET, null); + validateRequestRun(listState, StateRequestType.LIST_GET, null, 0); listState.add(1); - validateRequestRun(listState, StateRequestType.LIST_ADD, 1); + validateRequestRun(listState, StateRequestType.LIST_ADD, 1, 0); list = new ArrayList<>(); listState.update(list); - validateRequestRun(listState, StateRequestType.LIST_UPDATE, list); + validateRequestRun(listState, StateRequestType.LIST_UPDATE, list, 0); list = new ArrayList<>(); listState.addAll(list); - validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list); + validateRequestRun(listState, StateRequestType.LIST_ADD_ALL, list, 0); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java index ae988326caa..52dd9b1c3ef 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalMapStateTest.java @@ -41,67 +41,67 @@ public class InternalMapStateTest extends InternalKeyedStateTestBase { aec.setCurrentContext(aec.buildContext("test", "test")); mapState.asyncClear(); - validateRequestRun(mapState, StateRequestType.CLEAR, null); + validateRequestRun(mapState, StateRequestType.CLEAR, null, 0); mapState.asyncGet("key1"); - validateRequestRun(mapState, StateRequestType.MAP_GET, "key1"); + validateRequestRun(mapState, StateRequestType.MAP_GET, "key1", 0); mapState.asyncPut("key2", 2); - validateRequestRun(mapState, StateRequestType.MAP_PUT, Tuple2.of("key2", 2)); + validateRequestRun(mapState, StateRequestType.MAP_PUT, Tuple2.of("key2", 2), 0); Map<String, Integer> map = new HashMap<>(); mapState.asyncPutAll(map); - validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map); + validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map, 0); mapState.asyncRemove("key3"); - validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3"); + validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3", 0); mapState.asyncContains("key4"); - validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4"); + validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4", 0); mapState.asyncEntries(); - validateRequestRun(mapState, StateRequestType.MAP_ITER, null); + validateRequestRun(mapState, StateRequestType.MAP_ITER, null, 0); mapState.asyncKeys(); - validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null); + validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null, 0); mapState.asyncValues(); - validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null); + validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null, 0); mapState.asyncIsEmpty(); - validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null); + validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null, 0); mapState.clear(); - validateRequestRun(mapState, StateRequestType.CLEAR, null); + validateRequestRun(mapState, StateRequestType.CLEAR, null, 0); mapState.get("key1"); - validateRequestRun(mapState, StateRequestType.MAP_GET, "key1"); + validateRequestRun(mapState, StateRequestType.MAP_GET, "key1", 0); mapState.put("key2", 2); - validateRequestRun(mapState, StateRequestType.MAP_PUT, Tuple2.of("key2", 2)); + validateRequestRun(mapState, StateRequestType.MAP_PUT, Tuple2.of("key2", 2), 0); mapState.putAll(map); - validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map); + validateRequestRun(mapState, StateRequestType.MAP_PUT_ALL, map, 0); mapState.remove("key3"); - validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3"); + validateRequestRun(mapState, StateRequestType.MAP_REMOVE, "key3", 0); mapState.contains("key4"); - validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4"); + validateRequestRun(mapState, StateRequestType.MAP_CONTAINS, "key4", 0); mapState.iterator(); - validateRequestRun(mapState, StateRequestType.MAP_ITER, null); + validateRequestRun(mapState, StateRequestType.MAP_ITER, null, 0); mapState.entries().iterator(); - validateRequestRun(mapState, StateRequestType.MAP_ITER, null); + validateRequestRun(mapState, StateRequestType.MAP_ITER, null, 0); mapState.keys().iterator(); - validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null); + validateRequestRun(mapState, StateRequestType.MAP_ITER_KEY, null, 0); mapState.values().iterator(); - validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null); + validateRequestRun(mapState, StateRequestType.MAP_ITER_VALUE, null, 0); mapState.isEmpty(); - validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null); + validateRequestRun(mapState, StateRequestType.MAP_IS_EMPTY, null, 0); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java index cf4657ef626..f923dd52128 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalReducingStateTest.java @@ -20,10 +20,26 @@ package org.apache.flink.runtime.state.v2; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController; +import org.apache.flink.runtime.asyncprocessing.MockStateRequestContainer; +import org.apache.flink.runtime.asyncprocessing.StateExecutor; +import org.apache.flink.runtime.asyncprocessing.StateRequest; +import org.apache.flink.runtime.asyncprocessing.StateRequestContainer; import org.apache.flink.runtime.asyncprocessing.StateRequestType; +import org.apache.flink.runtime.mailbox.SyncMailboxExecutor; +import org.apache.flink.util.Preconditions; import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + /** Tests for {@link InternalReducingState}. */ public class InternalReducingStateTest extends InternalKeyedStateTestBase { @@ -38,21 +54,135 @@ public class InternalReducingStateTest extends InternalKeyedStateTestBase { aec.setCurrentContext(aec.buildContext("test", "test")); reducingState.asyncClear(); - validateRequestRun(reducingState, StateRequestType.CLEAR, null); + validateRequestRun(reducingState, StateRequestType.CLEAR, null, 0); reducingState.asyncGet(); - validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null); + validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 0); reducingState.asyncAdd(1); - validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1); + validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 1); + validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1, 0); reducingState.clear(); - validateRequestRun(reducingState, StateRequestType.CLEAR, null); + validateRequestRun(reducingState, StateRequestType.CLEAR, null, 0); reducingState.get(); - validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null); + validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 0); reducingState.add(1); - validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1); + validateRequestRun(reducingState, StateRequestType.REDUCING_GET, null, 1); + validateRequestRun(reducingState, StateRequestType.REDUCING_ADD, 1, 0); + } + + @Test + public void testMergeNamespace() throws Exception { + ReduceFunction<Integer> reducer = Integer::sum; + ReducingStateDescriptor<Integer> descriptor = + new ReducingStateDescriptor<>("testState", reducer, BasicTypeInfo.INT_TYPE_INFO); + AsyncExecutionController aec = + new AsyncExecutionController( + new SyncMailboxExecutor(), + (a, b) -> {}, + new ReducingStateExecutor(), + 1, + 100, + 10000, + 1); + InternalReducingState<String, String, Integer> reducingState = + new InternalReducingState<>(aec, descriptor); + aec.setCurrentContext(aec.buildContext("test", "test")); + aec.setCurrentNamespaceForState(reducingState, "1"); + reducingState.asyncAdd(1); + aec.drainInflightRecords(0); + assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(1); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "1"))).isEqualTo(1); + aec.setCurrentNamespaceForState(reducingState, "2"); + reducingState.asyncAdd(2); + aec.drainInflightRecords(0); + assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(2); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "1"))).isEqualTo(1); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "2"))).isEqualTo(2); + aec.setCurrentNamespaceForState(reducingState, "3"); + reducingState.asyncAdd(3); + aec.drainInflightRecords(0); + assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(3); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "1"))).isEqualTo(1); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "2"))).isEqualTo(2); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "3"))).isEqualTo(3); + + List<String> sources = new ArrayList<>(Arrays.asList("1", "2", "3")); + reducingState.asyncMergeNamespaces("0", sources); + aec.drainInflightRecords(0); + assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(1); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "0"))).isEqualTo(6); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "1"))).isNull(); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "2"))).isNull(); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "3"))).isNull(); + + aec.setCurrentNamespaceForState(reducingState, "4"); + reducingState.asyncAdd(4); + aec.drainInflightRecords(0); + assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(2); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "0"))).isEqualTo(6); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "4"))).isEqualTo(4); + + List<String> sources1 = new ArrayList<>(Arrays.asList("4")); + reducingState.asyncMergeNamespaces("0", sources1); + aec.drainInflightRecords(0); + + assertThat(ReducingStateExecutor.hashMap.size()).isEqualTo(1); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "0"))).isEqualTo(10); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "1"))).isNull(); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "2"))).isNull(); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "3"))).isNull(); + assertThat(ReducingStateExecutor.hashMap.get(Tuple2.of("test", "4"))).isNull(); + } + + static class ReducingStateExecutor implements StateExecutor { + + private static final HashMap<Tuple2<String, String>, Integer> hashMap = new HashMap<>(); + + public ReducingStateExecutor() { + hashMap.clear(); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public CompletableFuture<Void> executeBatchRequests( + StateRequestContainer stateRequestContainer) { + Preconditions.checkArgument(stateRequestContainer instanceof MockStateRequestContainer); + CompletableFuture<Void> future = new CompletableFuture<>(); + for (StateRequest request : + ((MockStateRequestContainer) stateRequestContainer).getStateRequestList()) { + if (request.getRequestType() == StateRequestType.REDUCING_GET) { + String key = (String) request.getRecordContext().getKey(); + String namespace = (String) request.getNamespace(); + Integer val = hashMap.get(Tuple2.of(key, namespace)); + request.getFuture().complete(val); + } else if (request.getRequestType() == StateRequestType.REDUCING_ADD) { + String key = (String) request.getRecordContext().getKey(); + String namespace = (String) request.getNamespace(); + hashMap.put(Tuple2.of(key, namespace), (Integer) request.getPayload()); + request.getFuture().complete(null); + } else if (request.getRequestType() == StateRequestType.REDUCING_REMOVE) { + String key = (String) request.getRecordContext().getKey(); + String namespace = (String) request.getNamespace(); + hashMap.remove(Tuple2.of(key, namespace)); + request.getFuture().complete(null); + } else { + throw new UnsupportedOperationException("Unsupported request type"); + } + } + future.complete(null); + return future; + } + + @Override + public StateRequestContainer createStateRequestContainer() { + return new MockStateRequestContainer(); + } + + @Override + public void shutdown() {} } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java index cb79b4ab7b2..58278ecbd39 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/InternalValueStateTest.java @@ -36,21 +36,21 @@ public class InternalValueStateTest extends InternalKeyedStateTestBase { aec.setCurrentContext(aec.buildContext("test", "test")); valueState.asyncClear(); - validateRequestRun(valueState, StateRequestType.CLEAR, null); + validateRequestRun(valueState, StateRequestType.CLEAR, null, 0); valueState.asyncValue(); - validateRequestRun(valueState, StateRequestType.VALUE_GET, null); + validateRequestRun(valueState, StateRequestType.VALUE_GET, null, 0); valueState.asyncUpdate(1); - validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1); + validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1, 0); valueState.clear(); - validateRequestRun(valueState, StateRequestType.CLEAR, null); + validateRequestRun(valueState, StateRequestType.CLEAR, null, 0); valueState.value(); - validateRequestRun(valueState, StateRequestType.VALUE_GET, null); + validateRequestRun(valueState, StateRequestType.VALUE_GET, null, 0); valueState.update(1); - validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1); + validateRequestRun(valueState, StateRequestType.VALUE_UPDATE, 1, 0); } } diff --git a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java index 7ec2f60c854..a04c7f2b5f4 100644 --- a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java +++ b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.asyncprocessing.StateRequestHandler; import org.apache.flink.runtime.state.AsyncKeyedStateBackend; import org.apache.flink.runtime.state.SerializedCompositeKeyBuilder; import org.apache.flink.runtime.state.v2.ListStateDescriptor; +import org.apache.flink.runtime.state.v2.ReducingStateDescriptor; import org.apache.flink.runtime.state.v2.StateDescriptor; import org.apache.flink.runtime.state.v2.ValueStateDescriptor; import org.apache.flink.util.FlinkRuntimeException; @@ -191,6 +192,17 @@ public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend { keyDeserializerView, valueDeserializerView, keyGroupPrefixBytes); + case REDUCING: + return (S) + new ForStReducingState<>( + stateRequestHandler, + columnFamilyHandle, + (ReducingStateDescriptor<SV>) stateDesc, + serializedKeyBuilder, + defaultNamespace, + namespaceSerializer::duplicate, + valueSerializerView, + valueDeserializerView); default: throw new UnsupportedOperationException( String.format("Unsupported state type: %s", stateDesc.getType())); diff --git a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java new file mode 100644 index 00000000000..3fe2329bd7b --- /dev/null +++ b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStReducingState.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.state.forst; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputDeserializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.core.state.InternalStateFuture; +import org.apache.flink.runtime.asyncprocessing.RecordContext; +import org.apache.flink.runtime.asyncprocessing.StateRequest; +import org.apache.flink.runtime.asyncprocessing.StateRequestHandler; +import org.apache.flink.runtime.asyncprocessing.StateRequestType; +import org.apache.flink.runtime.state.SerializedCompositeKeyBuilder; +import org.apache.flink.runtime.state.v2.InternalReducingState; +import org.apache.flink.runtime.state.v2.ReducingStateDescriptor; +import org.apache.flink.util.Preconditions; + +import org.rocksdb.ColumnFamilyHandle; + +import java.io.IOException; +import java.util.function.Supplier; + +/** + * The {@link InternalReducingState} implement for ForStDB. + * + * @param <K> The type of the key. + * @param <N> The type of the namespace. + * @param <V> The type of the value. + */ +public class ForStReducingState<K, N, V> extends InternalReducingState<K, N, V> + implements ForStInnerTable<K, N, V> { + + /** The column family which this internal value state belongs to. */ + private final ColumnFamilyHandle columnFamilyHandle; + + /** The serialized key builder which should be thread-safe. */ + private final ThreadLocal<SerializedCompositeKeyBuilder<K>> serializedKeyBuilder; + + /** The default namespace if not set. * */ + private final N defaultNamespace; + + /** The serializer for namespace. * */ + private final ThreadLocal<TypeSerializer<N>> namespaceSerializer; + + /** The data outputStream used for value serializer, which should be thread-safe. */ + private final ThreadLocal<DataOutputSerializer> valueSerializerView; + + /** The data inputStream used for value deserializer, which should be thread-safe. */ + private final ThreadLocal<DataInputDeserializer> valueDeserializerView; + + public ForStReducingState( + StateRequestHandler stateRequestHandler, + ColumnFamilyHandle columnFamily, + ReducingStateDescriptor<V> reducingStateDescriptor, + Supplier<SerializedCompositeKeyBuilder<K>> serializedKeyBuilderInitializer, + N defaultNamespace, + Supplier<TypeSerializer<N>> namespaceSerializerInitializer, + Supplier<DataOutputSerializer> valueSerializerViewInitializer, + Supplier<DataInputDeserializer> valueDeserializerViewInitializer) { + super(stateRequestHandler, reducingStateDescriptor); + this.columnFamilyHandle = columnFamily; + this.serializedKeyBuilder = ThreadLocal.withInitial(serializedKeyBuilderInitializer); + this.defaultNamespace = defaultNamespace; + this.namespaceSerializer = ThreadLocal.withInitial(namespaceSerializerInitializer); + this.valueSerializerView = ThreadLocal.withInitial(valueSerializerViewInitializer); + this.valueDeserializerView = ThreadLocal.withInitial(valueDeserializerViewInitializer); + } + + @Override + public ColumnFamilyHandle getColumnFamilyHandle() { + return columnFamilyHandle; + } + + @Override + public byte[] serializeKey(ContextKey<K, N> contextKey) throws IOException { + return contextKey.getOrCreateSerializedKey( + ctxKey -> { + SerializedCompositeKeyBuilder<K> builder = serializedKeyBuilder.get(); + builder.setKeyAndKeyGroup(ctxKey.getRawKey(), ctxKey.getKeyGroup()); + N namespace = ctxKey.getNamespace(); + return builder.buildCompositeKeyNamespace( + namespace == null ? defaultNamespace : namespace, + namespaceSerializer.get()); + }); + } + + @Override + public byte[] serializeValue(V value) throws IOException { + DataOutputSerializer outputView = valueSerializerView.get(); + outputView.clear(); + getValueSerializer().serialize(value, outputView); + return outputView.getCopyOfBuffer(); + } + + @Override + public V deserializeValue(byte[] valueBytes) throws IOException { + DataInputDeserializer inputView = valueDeserializerView.get(); + inputView.setBuffer(valueBytes); + return getValueSerializer().deserialize(inputView); + } + + @SuppressWarnings("unchecked") + @Override + public ForStDBGetRequest<K, N, V, V> buildDBGetRequest(StateRequest<?, ?, ?, ?> stateRequest) { + Preconditions.checkArgument(stateRequest.getRequestType() == StateRequestType.REDUCING_GET); + ContextKey<K, N> contextKey = + new ContextKey<>( + (RecordContext<K>) stateRequest.getRecordContext(), + (N) stateRequest.getNamespace()); + return new ForStDBSingleGetRequest<>( + contextKey, this, (InternalStateFuture<V>) stateRequest.getFuture()); + } + + @SuppressWarnings("unchecked") + @Override + public ForStDBPutRequest<K, N, V> buildDBPutRequest(StateRequest<?, ?, ?, ?> stateRequest) { + Preconditions.checkArgument( + stateRequest.getRequestType() == StateRequestType.REDUCING_ADD + || stateRequest.getRequestType() == StateRequestType.REDUCING_REMOVE + || stateRequest.getRequestType() == StateRequestType.CLEAR); + ContextKey<K, N> contextKey = + new ContextKey<>( + (RecordContext<K>) stateRequest.getRecordContext(), + (N) stateRequest.getNamespace()); + V value = + (stateRequest.getRequestType() == StateRequestType.REDUCING_REMOVE + || stateRequest.getRequestType() == StateRequestType.CLEAR) + ? null // "Delete(key)" is equivalent to "Put(key, null)" + : (V) stateRequest.getPayload(); + return ForStDBPutRequest.of( + contextKey, value, this, (InternalStateFuture<Void>) stateRequest.getFuture()); + } +} diff --git a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java index ea6ede88acf..c7205f0791c 100644 --- a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java +++ b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStStateRequestClassifier.java @@ -62,6 +62,7 @@ public class ForStStateRequestClassifier implements StateRequestContainer { case MAP_GET: case MAP_IS_EMPTY: case MAP_CONTAINS: + case REDUCING_GET: { ForStInnerTable<?, ?, ?> innerTable = (ForStInnerTable<?, ?, ?>) stateRequest.getState(); @@ -74,6 +75,7 @@ public class ForStStateRequestClassifier implements StateRequestContainer { case LIST_ADD_ALL: case MAP_PUT: case MAP_REMOVE: + case REDUCING_ADD: { ForStInnerTable<?, ?, ?> innerTable = (ForStInnerTable<?, ?, ?>) stateRequest.getState();