This is an automated email from the ASF dual-hosted git repository. leiyanfei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 6d139a19c80 [FLINK-35024][Runtime/State] Implement the record buffer of AsyncExecutionController (#24633) 6d139a19c80 is described below commit 6d139a19c809f787317c5afa4e56e1c544125e5f Author: Yanfei Lei <fredia...@gmail.com> AuthorDate: Mon Apr 15 13:56:53 2024 +0800 [FLINK-35024][Runtime/State] Implement the record buffer of AsyncExecutionController (#24633) --- .../asyncprocessing/AsyncExecutionController.java | 98 +++++++- .../asyncprocessing/StateRequestBuffer.java | 125 ++++++++++ .../AsyncExecutionControllerTest.java | 259 ++++++++++++++------- 3 files changed, 385 insertions(+), 97 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java index ffecc5f9687..cf8304a71ea 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionController.java @@ -21,12 +21,16 @@ package org.apache.flink.runtime.asyncprocessing; import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.state.v2.State; import org.apache.flink.core.state.InternalStateFuture; +import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + /** * The Async Execution Controller (AEC) receives processing requests from operators, and put them * into execution according to some strategies. @@ -45,11 +49,26 @@ public class AsyncExecutionController<R, K> { private static final Logger LOG = LoggerFactory.getLogger(AsyncExecutionController.class); + public static final int DEFAULT_BATCH_SIZE = 1000; public static final int DEFAULT_MAX_IN_FLIGHT_RECORD_NUM = 6000; + /** + * The batch size. When the number of state requests in the active buffer exceeds the batch + * size, a batched state execution would be triggered. + */ + private final int batchSize; + /** The max allowed number of in-flight records. */ private final int maxInFlightRecordNum; + /** + * The mailbox executor borrowed from {@code StreamTask}. Keeping the reference of + * mailboxExecutor here is to restrict the number of in-flight records, when the number of + * in-flight records > {@link #maxInFlightRecordNum}, the newly entering records would be + * blocked. + */ + private final MailboxExecutor mailboxExecutor; + /** The key accounting unit which is used to detect the key conflict. */ final KeyAccountingUnit<R, K> keyAccountingUnit; @@ -65,17 +84,35 @@ public class AsyncExecutionController<R, K> { /** The corresponding context that currently runs in task thread. */ RecordContext<R, K> currentContext; + /** The buffer to store the state requests to execute in batch. */ + StateRequestBuffer<R, K> stateRequestsBuffer; + + /** + * The number of in-flight records. Including the records in active buffer and blocking buffer. + */ + final AtomicInteger inFlightRecordNum; + public AsyncExecutionController(MailboxExecutor mailboxExecutor, StateExecutor stateExecutor) { - this(mailboxExecutor, stateExecutor, DEFAULT_MAX_IN_FLIGHT_RECORD_NUM); + this(mailboxExecutor, stateExecutor, DEFAULT_BATCH_SIZE, DEFAULT_MAX_IN_FLIGHT_RECORD_NUM); } public AsyncExecutionController( - MailboxExecutor mailboxExecutor, StateExecutor stateExecutor, int maxInFlightRecords) { + MailboxExecutor mailboxExecutor, + StateExecutor stateExecutor, + int batchSize, + int maxInFlightRecords) { this.keyAccountingUnit = new KeyAccountingUnit<>(maxInFlightRecords); + this.mailboxExecutor = mailboxExecutor; this.stateFutureFactory = new StateFutureFactory<>(this, mailboxExecutor); this.stateExecutor = stateExecutor; + this.batchSize = batchSize; this.maxInFlightRecordNum = maxInFlightRecords; - LOG.info("Create AsyncExecutionController: maxInFlightRecordsNum {}", maxInFlightRecords); + this.stateRequestsBuffer = new StateRequestBuffer<>(); + this.inFlightRecordNum = new AtomicInteger(0); + LOG.info( + "Create AsyncExecutionController: batchSize {}, maxInFlightRecordsNum {}", + batchSize, + maxInFlightRecords); } /** @@ -107,6 +144,14 @@ public class AsyncExecutionController<R, K> { */ public void disposeContext(RecordContext<R, K> toDispose) { keyAccountingUnit.release(toDispose.getRecord(), toDispose.getKey()); + inFlightRecordNum.decrementAndGet(); + RecordContext<R, K> nextRecordCtx = + stateRequestsBuffer.tryActivateOneByKey(toDispose.getKey()); + if (nextRecordCtx != null) { + Preconditions.checkState( + tryOccupyKey(nextRecordCtx), + String.format("key(%s) is already occupied.", nextRecordCtx.getKey())); + } } /** @@ -140,23 +185,28 @@ public class AsyncExecutionController<R, K> { InternalStateFuture<OUT> stateFuture = stateFutureFactory.create(currentContext); StateRequest<K, IN, OUT> request = new StateRequest<>(state, type, payload, stateFuture, currentContext); - // Step 2: try to occupy the key and place it into right buffer. + + // Step 2: try to seize the capacity, if the current in-flight records exceeds the limit, + // block the current state request from entering until some buffered requests are processed. + seizeCapacity(); + + // Step 3: try to occupy the key and place it into right buffer. if (tryOccupyKey(currentContext)) { insertActiveBuffer(request); } else { insertBlockingBuffer(request); } - // Step 3: trigger the (active) buffer if needed. + // Step 4: trigger the (active) buffer if needed. triggerIfNeeded(false); return stateFuture; } <IN, OUT> void insertActiveBuffer(StateRequest<K, IN, OUT> request) { - // TODO: implement the active buffer. + stateRequestsBuffer.enqueueToActive(request); } <IN, OUT> void insertBlockingBuffer(StateRequest<K, IN, OUT> request) { - // TODO: implement the blocking buffer. + stateRequestsBuffer.enqueueToBlocking(request); } /** @@ -165,6 +215,38 @@ public class AsyncExecutionController<R, K> { * @param force whether to trigger requests in force. */ void triggerIfNeeded(boolean force) { - // TODO: implement the trigger logic. + // TODO: introduce a timeout mechanism for triggering. + if (!force && stateRequestsBuffer.activeQueueSize() < batchSize) { + return; + } + List<StateRequest<?, ?, ?>> toRun = stateRequestsBuffer.popActive(batchSize); + stateExecutor.executeBatchRequests(toRun); + } + + private void seizeCapacity() { + // 1. Check if the record is already in buffer. If yes, this indicates that it is a state + // request resulting from a callback statement, otherwise, it signifies the initial state + // request for a newly entered record. + if (currentContext.isKeyOccupied()) { + return; + } + RecordContext<R, K> storedContext = currentContext; + // 2. If the state request is for a newly entered record, the in-flight record number should + // be less than the max in-flight record number. + // Note: the currentContext may be updated by {@code StateFutureFactory#build}. + try { + while (inFlightRecordNum.get() > maxInFlightRecordNum) { + if (!mailboxExecutor.tryYield()) { + triggerIfNeeded(true); + Thread.sleep(1); + } + } + } catch (InterruptedException ignored) { + // ignore the interrupted exception to avoid throwing fatal error when the task cancel + // or exit. + } + // 3. Ensure the currentContext is restored. + setCurrentContext(storedContext); + inFlightRecordNum.incrementAndGet(); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestBuffer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestBuffer.java new file mode 100644 index 00000000000..9a0cde21936 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/asyncprocessing/StateRequestBuffer.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.asyncprocessing; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +/** + * A buffer to hold state requests to execute state requests in batch, which can only be manipulated + * within task thread. + * + * @param <R> the type of the record + * @param <K> the type of the key + */ +@NotThreadSafe +public class StateRequestBuffer<R, K> { + /** + * The state requests in this buffer could be executed when the buffer is full or configured + * batch size is reached. All operations on this buffer must be invoked in task thread. + */ + final LinkedList<StateRequest<K, ?, ?>> activeQueue; + + /** + * The requests in that should wait until all preceding records with identical key finishing its + * execution. After which the queueing requests will move into the active buffer. All operations + * on this buffer must be invoked in task thread. + */ + final Map<K, Deque<StateRequest<K, ?, ?>>> blockingQueue; + + /** The number of state requests in blocking queue. */ + int blockingQueueSize; + + public StateRequestBuffer() { + this.activeQueue = new LinkedList<>(); + this.blockingQueue = new HashMap<>(); + this.blockingQueueSize = 0; + } + + void enqueueToActive(StateRequest<K, ?, ?> request) { + activeQueue.add(request); + } + + void enqueueToBlocking(StateRequest<K, ?, ?> request) { + blockingQueue + .computeIfAbsent(request.getRecordContext().getKey(), k -> new LinkedList<>()) + .add(request); + blockingQueueSize++; + } + + /** + * Try to pull one state request with specific key from blocking queue to active queue. + * + * @param key The key to release, the other records with this key is no longer blocking. + * @return The first record context with the same key in blocking queue, null if no such record. + */ + @Nullable + @SuppressWarnings("rawtypes") + RecordContext<R, K> tryActivateOneByKey(K key) { + if (!blockingQueue.containsKey(key)) { + return null; + } + + StateRequest<K, ?, ?> stateRequest = blockingQueue.get(key).removeFirst(); + activeQueue.add(stateRequest); + if (blockingQueue.get(key).isEmpty()) { + blockingQueue.remove(key); + } + blockingQueueSize--; + return (RecordContext<R, K>) stateRequest.getRecordContext(); + } + + /** + * Get the number of state requests of blocking queue in constant-time. + * + * @return the number of state requests of blocking queue. + */ + int blockingQueueSize() { + return blockingQueueSize; + } + + /** + * Get the number of state requests of active queue in constant-time. + * + * @return the number of state requests of active queue. + */ + int activeQueueSize() { + return activeQueue.size(); + } + + /** + * Try to pop state requests from active queue, if the size of active queue is less than N, + * return all the requests in active queue. + * + * @param n the number of state requests to pop. + * @return A list of state requests. + */ + List<StateRequest<?, ?, ?>> popActive(int n) { + LinkedList<StateRequest<?, ?, ?>> ret = + new LinkedList<>(activeQueue.subList(0, Math.min(activeQueue.size(), n))); + activeQueue.removeAll(ret); + return ret; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java index 41c2031414d..acaba0c0225 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/asyncprocessing/AsyncExecutionControllerTest.java @@ -18,18 +18,16 @@ package org.apache.flink.runtime.asyncprocessing; -import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.state.v2.StateFuture; import org.apache.flink.api.common.state.v2.ValueState; import org.apache.flink.core.state.StateFutureUtils; import org.apache.flink.runtime.mailbox.SyncMailboxExecutor; import org.apache.flink.util.Preconditions; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.util.HashMap; -import java.util.Iterator; -import java.util.LinkedList; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; @@ -37,34 +35,36 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; /** Test for {@link AsyncExecutionController}. */ class AsyncExecutionControllerTest { + AsyncExecutionController aec; + TestUnderlyingState underlyingState; + AtomicInteger output; + TestValueState valueState; + + final Runnable userCode = + () -> { + valueState + .asyncValue() + .thenCompose( + val -> { + int updated = (val == null ? 1 : (val + 1)); + return valueState + .asyncUpdate(updated) + .thenCompose( + o -> StateFutureUtils.completedFuture(updated)); + }) + .thenAccept(val -> output.set(val)); + }; + + @BeforeEach + void setup() { + aec = new AsyncExecutionController<>(new SyncMailboxExecutor(), new TestStateExecutor()); + underlyingState = new TestUnderlyingState(); + valueState = new TestValueState(aec, underlyingState); + output = new AtomicInteger(); + } - // TODO: this test is not well completed, cause buffering in AEC is not implemented. - // Yet, just for illustrating the interaction between AEC and Async state API. @Test void testBasicRun() { - TestAsyncExecutionController<String, String> aec = - new TestAsyncExecutionController<>( - new SyncMailboxExecutor(), new TestStateExecutor()); - TestUnderlyingState underlyingState = new TestUnderlyingState(); - TestValueState valueState = new TestValueState(aec, underlyingState); - AtomicInteger output = new AtomicInteger(); - Runnable userCode = - () -> { - valueState - .asyncValue() - .thenCompose( - val -> { - int updated = (val == null ? 1 : (val + 1)); - return valueState - .asyncUpdate(updated) - .thenCompose( - o -> - StateFutureUtils.completedFuture( - updated)); - }) - .thenAccept(val -> output.set(val)); - }; - // ============================ element1 ============================ String record1 = "key1-r1"; String key1 = "key1"; @@ -77,18 +77,20 @@ class AsyncExecutionControllerTest { // Single-step run. // Firstly, the user code generates value get in active buffer. - assertThat(aec.activeBuffer.size()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(1); aec.triggerIfNeeded(true); // After running, the value update is in active buffer. - assertThat(aec.activeBuffer.size()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); aec.triggerIfNeeded(true); // Value update finishes. - assertThat(aec.activeBuffer.size()).isEqualTo(0); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(0); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(0); assertThat(output.get()).isEqualTo(1); assertThat(recordContext1.getReferenceCount()).isEqualTo(0); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(0); // ============================ element 2 & 3 ============================ String record2 = "key1-r2"; @@ -108,37 +110,37 @@ class AsyncExecutionControllerTest { // Single-step run. // Firstly, the user code for record2 generates value get in active buffer, // while user code for record3 generates value get in blocking buffer. - assertThat(aec.activeBuffer.size()).isEqualTo(1); - assertThat(aec.blockingBuffer.size()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(1); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(2); aec.triggerIfNeeded(true); // After running, the value update for record2 is in active buffer. - assertThat(aec.activeBuffer.size()).isEqualTo(1); - assertThat(aec.blockingBuffer.size()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(1); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(2); aec.triggerIfNeeded(true); - // Value update for record2 finishes. The value get for record3 is still in blocking status. - assertThat(aec.activeBuffer.size()).isEqualTo(0); - assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(0); + // Value update for record2 finishes. The value get for record3 is migrated from blocking + // buffer to active buffer actively. + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); + assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(1); assertThat(output.get()).isEqualTo(2); assertThat(recordContext2.getReferenceCount()).isEqualTo(0); - assertThat(aec.blockingBuffer.size()).isEqualTo(1); - - aec.migrateBlockingToActive(); - // Value get for record3 is ready for run. + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(0); - assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); - assertThat(aec.activeBuffer.size()).isEqualTo(1); - assertThat(aec.blockingBuffer.size()).isEqualTo(0); - assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + // Let value get for record3 to run. aec.triggerIfNeeded(true); // After running, the value update for record3 is in active buffer. - assertThat(aec.activeBuffer.size()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(1); aec.triggerIfNeeded(true); // Value update for record3 finishes. - assertThat(aec.activeBuffer.size()).isEqualTo(0); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(0); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(0); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(0); assertThat(output.get()).isEqualTo(3); assertThat(recordContext3.getReferenceCount()).isEqualTo(0); @@ -152,66 +154,145 @@ class AsyncExecutionControllerTest { // Single-step run for another key. // Firstly, the user code generates value get in active buffer. - assertThat(aec.activeBuffer.size()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(1); aec.triggerIfNeeded(true); // After running, the value update is in active buffer. - assertThat(aec.activeBuffer.size()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(1); aec.triggerIfNeeded(true); // Value update finishes. - assertThat(aec.activeBuffer.size()).isEqualTo(0); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(0); assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(0); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(0); assertThat(output.get()).isEqualTo(1); assertThat(recordContext4.getReferenceCount()).isEqualTo(0); } - /** - * An AsyncExecutionController for testing purpose, which integrates with basic buffer - * mechanism. - */ - static class TestAsyncExecutionController<R, K> extends AsyncExecutionController<R, K> { + @Test + void testRecordsRunInOrder() { + // Record1 and record3 have the same key, record2 has a different key. + // Record2 should be processed before record3. - LinkedList<StateRequest<K, ?, ?>> activeBuffer; + String record1 = "key1-r1"; + String key1 = "key1"; + RecordContext<String, String> recordContext1 = aec.buildContext(record1, key1); + aec.setCurrentContext(recordContext1); + // execute user code + userCode.run(); - LinkedList<StateRequest<K, ?, ?>> blockingBuffer; + String record2 = "key2-r1"; + String key2 = "key2"; + RecordContext<String, String> recordContext2 = aec.buildContext(record2, key2); + aec.setCurrentContext(recordContext2); + // execute user code + userCode.run(); - public TestAsyncExecutionController( - MailboxExecutor mailboxExecutor, StateExecutor stateExecutor) { - super(mailboxExecutor, stateExecutor); - activeBuffer = new LinkedList<>(); - blockingBuffer = new LinkedList<>(); - } + String record3 = "key1-r2"; + String key3 = "key1"; + RecordContext<String, String> recordContext3 = aec.buildContext(record3, key3); + aec.setCurrentContext(recordContext3); + // execute user code + userCode.run(); - @Override - <IN, OUT> void insertActiveBuffer(StateRequest<K, IN, OUT> request) { - activeBuffer.push(request); - } + // Record1's value get and record2's value get are in active buffer + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(2); + assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(2); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(3); + // Record3's value get is in blocking buffer + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(1); + aec.triggerIfNeeded(true); + // After running, record1's value update and record2's value update are in active buffer. + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(2); + assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(2); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(3); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(1); + aec.triggerIfNeeded(true); + // Record1's value update and record2's value update finish, record3's value get migrates to + // active buffer when record1's refCount reach 0. + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); + assertThat(aec.keyAccountingUnit.occupiedCount()).isEqualTo(1); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(0); + assertThat(output.get()).isEqualTo(1); + assertThat(recordContext1.getReferenceCount()).isEqualTo(0); + assertThat(recordContext2.getReferenceCount()).isEqualTo(0); + aec.triggerIfNeeded(true); + // After running, record3's value update is added to active buffer. + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(0); + aec.triggerIfNeeded(true); + assertThat(output.get()).isEqualTo(2); + assertThat(recordContext3.getReferenceCount()).isEqualTo(0); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(0); + } - <IN, OUT> void insertBlockingBuffer(StateRequest<K, IN, OUT> request) { - blockingBuffer.push(request); - } + @Test + void testInFlightRecordControl() { + final int batchSize = 5; + final int maxInFlight = 10; + aec = + new AsyncExecutionController<>( + new SyncMailboxExecutor(), new TestStateExecutor(), batchSize, maxInFlight); + valueState = new TestValueState(aec, underlyingState); - void triggerIfNeeded(boolean force) { - if (!force) { - // Disable normal trigger, to perform single-step debugging and check. - return; - } - LinkedList<StateRequest<?, ?, ?>> toRun = new LinkedList<>(activeBuffer); - activeBuffer.clear(); - stateExecutor.executeBatchRequests(toRun); - } + AtomicInteger output = new AtomicInteger(); + Runnable userCode = + () -> { + valueState + .asyncValue() + .thenCompose( + val -> { + int updated = (val == null ? 1 : (val + 1)); + return valueState + .asyncUpdate(updated) + .thenCompose( + o -> + StateFutureUtils.completedFuture( + updated)); + }) + .thenAccept(val -> output.set(val)); + }; - @SuppressWarnings("unchecked") - void migrateBlockingToActive() { - Iterator<StateRequest<K, ?, ?>> blockingIter = blockingBuffer.iterator(); - while (blockingIter.hasNext()) { - StateRequest<K, ?, ?> request = blockingIter.next(); - if (tryOccupyKey((RecordContext<R, K>) request.getRecordContext())) { - insertActiveBuffer(request); - blockingIter.remove(); - } + // For records with different keys, the in-flight records is controlled by batch size. + for (int round = 0; round < 10; round++) { + for (int i = 0; i < batchSize; i++) { + String record = + String.format("key%d-r%d", round * batchSize + i, round * batchSize + i); + String key = String.format("key%d", round * batchSize + i); + RecordContext<String, String> recordContext = aec.buildContext(record, key); + aec.setCurrentContext(recordContext); + userCode.run(); } + assertThat(aec.inFlightRecordNum.get()).isEqualTo(0); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(0); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(0); + } + // For records with the same key, the in-flight records is controlled by max in-flight + // records number. + for (int i = 0; i < maxInFlight; i++) { + String record = String.format("sameKey-r%d", i, i); + String key = "sameKey"; + RecordContext<String, String> recordContext = aec.buildContext(record, key); + aec.setCurrentContext(recordContext); + userCode.run(); + } + assertThat(aec.inFlightRecordNum.get()).isEqualTo(maxInFlight); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(maxInFlight - 1); + // In the following example, the batch size will degrade to 1, meaning that + // each batch only have 1 state request. + for (int i = maxInFlight; i < 10 * maxInFlight; i++) { + String record = String.format("sameKey-r%d", i, i); + String key = "sameKey"; + RecordContext<String, String> recordContext = aec.buildContext(record, key); + aec.setCurrentContext(recordContext); + userCode.run(); + assertThat(aec.inFlightRecordNum.get()).isEqualTo(maxInFlight + 1); + assertThat(aec.stateRequestsBuffer.activeQueueSize()).isEqualTo(1); + assertThat(aec.stateRequestsBuffer.blockingQueueSize()).isEqualTo(maxInFlight); } }