This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 14392cf4ff291d1dbe316f4687e79760591127f3 Author: Weijie Guo <[email protected]> AuthorDate: Mon Sep 26 17:02:22 2022 +0800 [FLINK-28889] HsBufferContext supports multiple consumer. --- .../network/partition/hybrid/HsBufferContext.java | 16 +++--- .../io/network/partition/hybrid/HsConsumerId.java | 63 ++++++++++++++++++++++ .../partition/hybrid/HsFullSpillingStrategy.java | 12 +++-- .../partition/hybrid/HsMemoryDataManager.java | 5 +- .../hybrid/HsSelectiveSpillingStrategy.java | 13 +++-- .../partition/hybrid/HsSpillingInfoProvider.java | 24 ++++++++- .../hybrid/HsSubpartitionMemoryDataManager.java | 22 ++++---- .../partition/hybrid/HsBufferContextTest.java | 27 +++++++--- .../network/partition/hybrid/HsConsumerIdTest.java | 43 +++++++++++++++ .../HsSubpartitionMemoryDataManagerTest.java | 28 ++++++---- .../hybrid/TestingSpillingInfoProvider.java | 8 +-- 11 files changed, 212 insertions(+), 49 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java index 4a5f5fd9651..06970c97e54 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java @@ -22,8 +22,11 @@ import org.apache.flink.runtime.io.network.buffer.Buffer; import javax.annotation.Nullable; +import java.util.Collections; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; import static org.apache.flink.util.Preconditions.checkState; @@ -37,7 +40,7 @@ import static org.apache.flink.util.Preconditions.checkState; * can no longer be spilled or consumed. * <li>{@link #spillStarted} indicates that spilling of the buffer has started, either completed * or not. - * <li>{@link #consumed} indicates that buffer has been consumed by the downstream. + * <li>{@link #consumed} indicates that buffer has been consumed by these consumers. * </ul> * * <p>Reference count of the buffer is maintained as follows: * @@ -64,7 +67,7 @@ public class HsBufferContext { private boolean spillStarted; - private boolean consumed; + private final Set<HsConsumerId> consumed = Collections.newSetFromMap(new ConcurrentHashMap<>()); @Nullable private CompletableFuture<Void> spilledFuture; @@ -89,8 +92,8 @@ public class HsBufferContext { return spillStarted; } - public boolean isConsumed() { - return consumed; + public boolean isConsumed(HsConsumerId consumerId) { + return consumed.contains(consumerId); } public Optional<CompletableFuture<Void>> getSpilledFuture() { @@ -127,10 +130,9 @@ public class HsBufferContext { return true; } - public void consumed() { + public void consumed(HsConsumerId consumerId) { checkState(!released, "Buffer is already released."); - checkState(!consumed, "Consume buffer repeatedly is unexpected."); - consumed = true; + checkState(consumed.add(consumerId), "Consume buffer repeatedly is unexpected."); // increase ref count when buffer is consumed, will be decreased when downstream finish // consuming. buffer.retainBuffer(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerId.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerId.java new file mode 100644 index 00000000000..132da40f6d7 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerId.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import javax.annotation.Nullable; + +import java.util.Objects; + +/** This class represents the identifier of hybrid shuffle's consumer. */ +public class HsConsumerId { + /** + * This consumer id is used in the scenarios that information related to specific consumer needs + * to be ignored. + */ + public static final HsConsumerId ANY = new HsConsumerId(-1); + + /** This consumer id is used for the first consumer of a single subpartition. */ + public static final HsConsumerId DEFAULT = new HsConsumerId(0); + + /** This is a unique field for each consumer of a single subpartition. */ + private final int id; + + private HsConsumerId(int id) { + this.id = id; + } + + public static HsConsumerId newId(@Nullable HsConsumerId lastId) { + return lastId == null ? DEFAULT : new HsConsumerId(lastId.id + 1); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + HsConsumerId that = (HsConsumerId) o; + return id == that.id; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java index ac2d86fcc30..94a067fbaeb 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java @@ -18,7 +18,7 @@ package org.apache.flink.runtime.io.network.partition.hybrid; -import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId; import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus; import java.util.ArrayDeque; @@ -88,12 +88,14 @@ public class HsFullSpillingStrategy implements HsSpillingStrategy { subpartitionId, // get all not start spilling buffers. spillingInfoProvider.getBuffersInOrder( - subpartitionId, SpillStatus.NOT_SPILL, ConsumeStatus.ALL)) + subpartitionId, + SpillStatus.NOT_SPILL, + ConsumeStatusWithId.ALL_ANY)) .addBufferToRelease( subpartitionId, // get all not released buffers. spillingInfoProvider.getBuffersInOrder( - subpartitionId, SpillStatus.ALL, ConsumeStatus.ALL)); + subpartitionId, SpillStatus.ALL, ConsumeStatusWithId.ALL_ANY)); } return builder.build(); } @@ -110,7 +112,7 @@ public class HsFullSpillingStrategy implements HsSpillingStrategy { builder.addBufferToSpill( i, spillingInfoProvider.getBuffersInOrder( - i, SpillStatus.NOT_SPILL, ConsumeStatus.ALL)); + i, SpillStatus.NOT_SPILL, ConsumeStatusWithId.ALL_ANY)); } } @@ -135,7 +137,7 @@ public class HsFullSpillingStrategy implements HsSpillingStrategy { for (int subpartitionId = 0; subpartitionId < numSubpartitions; subpartitionId++) { Deque<BufferIndexAndChannel> buffersInOrder = spillingInfoProvider.getBuffersInOrder( - subpartitionId, SpillStatus.SPILL, ConsumeStatus.ALL); + subpartitionId, SpillStatus.SPILL, ConsumeStatusWithId.ALL_ANY); // if the number of subpartition buffers less than survived buffers, reserved all of // them. int releaseNum = Math.max(0, buffersInOrder.size() - subpartitionSurvivedNum); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java index f19cb58a795..a8d7d89e7e4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java @@ -223,10 +223,11 @@ public class HsMemoryDataManager implements HsSpillingInfoProvider, HsMemoryData // Write lock should be acquired before invoke this method. @Override public Deque<BufferIndexAndChannel> getBuffersInOrder( - int subpartitionId, SpillStatus spillStatus, ConsumeStatus consumeStatus) { + int subpartitionId, SpillStatus spillStatus, ConsumeStatusWithId consumeStatusWithId) { HsSubpartitionMemoryDataManager targetSubpartitionDataManager = getSubpartitionMemoryDataManager(subpartitionId); - return targetSubpartitionDataManager.getBuffersSatisfyStatus(spillStatus, consumeStatus); + return targetSubpartitionDataManager.getBuffersSatisfyStatus( + spillStatus, consumeStatusWithId); } // Write lock should be acquired before invoke this method. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java index dc356cb20d5..8c7217a4b93 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.partition.hybrid; import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId; import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus; import java.util.Deque; @@ -85,7 +86,11 @@ public class HsSelectiveSpillingStrategy implements HsSpillingStrategy { subpartitionToBuffers.put( channel, spillingInfoProvider.getBuffersInOrder( - channel, SpillStatus.NOT_SPILL, ConsumeStatus.NOT_CONSUMED)); + channel, + SpillStatus.NOT_SPILL, + // selective spilling strategy does not support multiple consumer. + ConsumeStatusWithId.fromStatusAndConsumerId( + ConsumeStatus.NOT_CONSUMED, HsConsumerId.DEFAULT))); } TreeMap<Integer, List<BufferIndexAndChannel>> subpartitionToHighPriorityBuffers = @@ -113,12 +118,14 @@ public class HsSelectiveSpillingStrategy implements HsSpillingStrategy { subpartitionId, // get all not start spilling buffers. spillingInfoProvider.getBuffersInOrder( - subpartitionId, SpillStatus.NOT_SPILL, ConsumeStatus.ALL)) + subpartitionId, + SpillStatus.NOT_SPILL, + ConsumeStatusWithId.ALL_ANY)) .addBufferToRelease( subpartitionId, // get all not released buffers. spillingInfoProvider.getBuffersInOrder( - subpartitionId, SpillStatus.ALL, ConsumeStatus.ALL)); + subpartitionId, SpillStatus.ALL, ConsumeStatusWithId.ALL_ANY)); } return builder.build(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java index 6168c637de4..deb32e7058e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java @@ -43,13 +43,13 @@ public interface HsSpillingInfoProvider { * * @param subpartitionId target buffers belong to. * @param spillStatus expected buffer spill status. - * @param consumeStatus expected buffer consume status. + * @param consumeStatusWithId expected buffer consume status and consumer id. * @return all buffers satisfy specific status of this subpartition, This queue must be sorted * according to bufferIndex from small to large, in other words, head is the buffer with the * minimum bufferIndex in the current subpartition. */ Deque<BufferIndexAndChannel> getBuffersInOrder( - int subpartitionId, SpillStatus spillStatus, ConsumeStatus consumeStatus); + int subpartitionId, SpillStatus spillStatus, ConsumeStatusWithId consumeStatusWithId); /** Get total number of not decided to spill buffers. */ int getNumTotalUnSpillBuffers(); @@ -79,4 +79,24 @@ public interface HsSpillingInfoProvider { /** The buffer is either consumed or not consumed. */ ALL } + + /** This class represents a pair of {@link ConsumeStatus} and consumer id. */ + class ConsumeStatusWithId { + public static final ConsumeStatusWithId ALL_ANY = + new ConsumeStatusWithId(ConsumeStatus.ALL, HsConsumerId.ANY); + + ConsumeStatus status; + + HsConsumerId consumerId; + + private ConsumeStatusWithId(ConsumeStatus status, HsConsumerId consumerId) { + this.status = status; + this.consumerId = consumerId; + } + + public static ConsumeStatusWithId fromStatusAndConsumerId( + ConsumeStatus consumeStatus, HsConsumerId consumerId) { + return new ConsumeStatusWithId(consumeStatus, consumerId); + } + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java index 619bc01d05a..f6804a62b1b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; import org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog; import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId; import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus; import org.apache.flink.util.function.SupplierWithException; import org.apache.flink.util.function.ThrowingRunnable; @@ -144,7 +145,8 @@ public class HsSubpartitionMemoryDataManager implements HsDataView { HsBufferContext bufferContext = checkNotNull(unConsumedBuffers.pollFirst()); - bufferContext.consumed(); + // TODO move this logical to consumer and pass real consumerId. + bufferContext.consumed(HsConsumerId.DEFAULT); DataType nextDataType = peekNextToConsumeDataTypeInternal(toConsumeIndex + 1); return Optional.of(Tuple2.of(bufferContext, nextDataType)); @@ -200,14 +202,14 @@ public class HsSubpartitionMemoryDataManager implements HsDataView { * ConsumeStatus}. * * @param spillStatus the status of spilling expected. - * @param consumeStatus the status of consuming expected. + * @param consumeStatusWithId the status and consumerId expected. * @return buffers satisfy expected status in order. */ @SuppressWarnings("FieldAccessNotGuarded") // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and // subpartitionLock. public Deque<BufferIndexAndChannel> getBuffersSatisfyStatus( - SpillStatus spillStatus, ConsumeStatus consumeStatus) { + SpillStatus spillStatus, ConsumeStatusWithId consumeStatusWithId) { return callWithLock( () -> { // TODO return iterator to avoid completely traversing the queue for each call. @@ -216,7 +218,7 @@ public class HsSubpartitionMemoryDataManager implements HsDataView { allBuffers.forEach( (bufferContext -> { if (isBufferSatisfyStatus( - bufferContext, spillStatus, consumeStatus)) { + bufferContext, spillStatus, consumeStatusWithId)) { targetBuffers.add(bufferContext.getBufferIndexAndChannel()); } })); @@ -460,7 +462,7 @@ public class HsSubpartitionMemoryDataManager implements HsDataView { @GuardedBy("subpartitionLock") private void checkAndMarkBufferReadable(HsBufferContext bufferContext) { // only spill buffer needs to be marked as released. - if (isBufferSatisfyStatus(bufferContext, SpillStatus.SPILL, ConsumeStatus.ALL)) { + if (isBufferSatisfyStatus(bufferContext, SpillStatus.SPILL, ConsumeStatusWithId.ALL_ANY)) { bufferContext .getSpilledFuture() .orElseThrow( @@ -480,7 +482,9 @@ public class HsSubpartitionMemoryDataManager implements HsDataView { @GuardedBy("subpartitionLock") private boolean isBufferSatisfyStatus( - HsBufferContext bufferContext, SpillStatus spillStatus, ConsumeStatus consumeStatus) { + HsBufferContext bufferContext, + SpillStatus spillStatus, + ConsumeStatusWithId consumeStatusWithId) { // released buffer is not needed. if (bufferContext.isReleased()) { return false; @@ -494,12 +498,12 @@ public class HsSubpartitionMemoryDataManager implements HsDataView { match = bufferContext.isSpillStarted(); break; } - switch (consumeStatus) { + switch (consumeStatusWithId.status) { case NOT_CONSUMED: - match &= !bufferContext.isConsumed(); + match &= !bufferContext.isConsumed(consumeStatusWithId.consumerId); break; case CONSUMED: - match &= bufferContext.isConsumed(); + match &= bufferContext.isConsumed(consumeStatusWithId.consumerId); break; } return match; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java index 960357d4127..f0419d114df 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java @@ -81,25 +81,40 @@ class HsBufferContextTest { @Test void testBufferConsumed() { + final HsConsumerId consumerId = HsConsumerId.DEFAULT; Buffer buffer = bufferContext.getBuffer(); - bufferContext.consumed(); - assertThat(bufferContext.isConsumed()).isTrue(); + bufferContext.consumed(consumerId); + assertThat(bufferContext.isConsumed(consumerId)).isTrue(); assertThat(buffer.refCnt()).isEqualTo(2); } @Test void testBufferConsumedRepeatedly() { - bufferContext.consumed(); - assertThatThrownBy(() -> bufferContext.consumed()) + final HsConsumerId consumerId = HsConsumerId.DEFAULT; + bufferContext.consumed(consumerId); + assertThatThrownBy(() -> bufferContext.consumed(consumerId)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("Consume buffer repeatedly is unexpected."); } + @Test + void testBufferConsumedMultipleConsumer() { + HsConsumerId consumer0 = HsConsumerId.newId(null); + HsConsumerId consumer1 = HsConsumerId.newId(consumer0); + bufferContext.consumed(consumer0); + bufferContext.consumed(consumer1); + + assertThat(bufferContext.isConsumed(consumer0)).isTrue(); + assertThat(bufferContext.isConsumed(consumer1)).isTrue(); + + assertThat(bufferContext.isConsumed(HsConsumerId.newId(consumer1))).isFalse(); + } + @Test void testBufferStartSpillOrConsumedAfterReleased() { bufferContext.release(); assertThat(bufferContext.startSpilling(new CompletableFuture<>())).isFalse(); - assertThatThrownBy(() -> bufferContext.consumed()) + assertThatThrownBy(() -> bufferContext.consumed(HsConsumerId.DEFAULT)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("Buffer is already released."); } @@ -117,7 +132,7 @@ class HsBufferContextTest { @Test void testBufferConsumedThenRelease() { Buffer buffer = bufferContext.getBuffer(); - bufferContext.consumed(); + bufferContext.consumed(HsConsumerId.DEFAULT); bufferContext.release(); assertThat(buffer.refCnt()).isEqualTo(1); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerIdTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerIdTest.java new file mode 100644 index 00000000000..591d6b0ccfd --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerIdTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link HsConsumerId}. */ +class HsConsumerIdTest { + @Test + void testNewIdFromNull() { + HsConsumerId consumerId = HsConsumerId.newId(null); + assertThat(consumerId).isNotNull().isEqualTo(HsConsumerId.DEFAULT); + } + + @Test + void testConsumerIdEquals() { + HsConsumerId consumerId = HsConsumerId.newId(null); + HsConsumerId consumerId1 = HsConsumerId.newId(consumerId); + HsConsumerId consumerId2 = HsConsumerId.newId(consumerId); + assertThat(consumerId1.hashCode()).isEqualTo(consumerId2.hashCode()); + assertThat(consumerId1).isEqualTo(consumerId2); + + assertThat(HsConsumerId.newId(consumerId2)).isNotEqualTo(consumerId2); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java index ba53f2a9efb..174c021990d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java @@ -53,6 +53,8 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId.ALL_ANY; +import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId.fromStatusAndConsumerId; import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferBuilder; import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createTestingOutputMetrics; import static org.assertj.core.api.Assertions.assertThat; @@ -290,40 +292,44 @@ class HsSubpartitionMemoryDataManagerTest { subpartitionMemoryDataManager.consumeBuffer(1); checkBufferIndex( - subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.ALL, ConsumeStatus.ALL), + subpartitionMemoryDataManager.getBuffersSatisfyStatus(SpillStatus.ALL, ALL_ANY), Arrays.asList(0, 1, 2, 3)); checkBufferIndex( subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.ALL, ConsumeStatus.CONSUMED), + SpillStatus.ALL, + fromStatusAndConsumerId(ConsumeStatus.CONSUMED, HsConsumerId.DEFAULT)), Arrays.asList(0, 1)); checkBufferIndex( subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.ALL, ConsumeStatus.NOT_CONSUMED), + SpillStatus.ALL, + fromStatusAndConsumerId(ConsumeStatus.NOT_CONSUMED, HsConsumerId.DEFAULT)), Arrays.asList(2, 3)); checkBufferIndex( - subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.SPILL, ConsumeStatus.ALL), + subpartitionMemoryDataManager.getBuffersSatisfyStatus(SpillStatus.SPILL, ALL_ANY), Arrays.asList(1, 2)); checkBufferIndex( subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.NOT_SPILL, ConsumeStatus.ALL), + SpillStatus.NOT_SPILL, ALL_ANY), Arrays.asList(0, 3)); checkBufferIndex( subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.SPILL, ConsumeStatus.NOT_CONSUMED), + SpillStatus.SPILL, + fromStatusAndConsumerId(ConsumeStatus.NOT_CONSUMED, HsConsumerId.DEFAULT)), Collections.singletonList(2)); checkBufferIndex( subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.SPILL, ConsumeStatus.CONSUMED), + SpillStatus.SPILL, + fromStatusAndConsumerId(ConsumeStatus.CONSUMED, HsConsumerId.DEFAULT)), Collections.singletonList(1)); checkBufferIndex( subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.NOT_SPILL, ConsumeStatus.CONSUMED), + SpillStatus.NOT_SPILL, + fromStatusAndConsumerId(ConsumeStatus.CONSUMED, HsConsumerId.DEFAULT)), Collections.singletonList(0)); checkBufferIndex( subpartitionMemoryDataManager.getBuffersSatisfyStatus( - SpillStatus.NOT_SPILL, ConsumeStatus.NOT_CONSUMED), + SpillStatus.NOT_SPILL, + fromStatusAndConsumerId(ConsumeStatus.NOT_CONSUMED, HsConsumerId.DEFAULT)), Collections.singletonList(3)); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java index f8b303479f0..67b34d0e774 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java @@ -79,7 +79,7 @@ public class TestingSpillingInfoProvider implements HsSpillingInfoProvider { @Override public Deque<BufferIndexAndChannel> getBuffersInOrder( - int subpartitionId, SpillStatus spillStatus, ConsumeStatus consumeStatus) { + int subpartitionId, SpillStatus spillStatus, ConsumeStatusWithId consumeStatusWithId) { Deque<BufferIndexAndChannel> buffersInOrder = new ArrayDeque<>(); List<BufferIndexAndChannel> subpartitionBuffers = allBuffers.get(subpartitionId); @@ -90,7 +90,7 @@ public class TestingSpillingInfoProvider implements HsSpillingInfoProvider { for (int i = 0; i < subpartitionBuffers.size(); i++) { if (isBufferSatisfyStatus( spillStatus, - consumeStatus, + consumeStatusWithId, spillBufferIndexes .getOrDefault(subpartitionId, Collections.emptySet()) .contains(i), @@ -124,7 +124,7 @@ public class TestingSpillingInfoProvider implements HsSpillingInfoProvider { private static boolean isBufferSatisfyStatus( SpillStatus spillStatus, - ConsumeStatus consumeStatus, + ConsumeStatusWithId consumeStatusWithId, boolean isSpill, boolean isConsumed) { boolean isNeeded = true; @@ -136,7 +136,7 @@ public class TestingSpillingInfoProvider implements HsSpillingInfoProvider { isNeeded = isSpill; break; } - switch (consumeStatus) { + switch (consumeStatusWithId.status) { case NOT_CONSUMED: isNeeded &= !isConsumed; break;
