This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 14392cf4ff291d1dbe316f4687e79760591127f3
Author: Weijie Guo <[email protected]>
AuthorDate: Mon Sep 26 17:02:22 2022 +0800

    [FLINK-28889] HsBufferContext supports multiple consumers.
---
 .../network/partition/hybrid/HsBufferContext.java  | 16 +++---
 .../io/network/partition/hybrid/HsConsumerId.java  | 63 ++++++++++++++++++++++
 .../partition/hybrid/HsFullSpillingStrategy.java   | 12 +++--
 .../partition/hybrid/HsMemoryDataManager.java      |  5 +-
 .../hybrid/HsSelectiveSpillingStrategy.java        | 13 +++--
 .../partition/hybrid/HsSpillingInfoProvider.java   | 24 ++++++++-
 .../hybrid/HsSubpartitionMemoryDataManager.java    | 22 ++++----
 .../partition/hybrid/HsBufferContextTest.java      | 27 +++++++---
 .../network/partition/hybrid/HsConsumerIdTest.java | 43 +++++++++++++++
 .../HsSubpartitionMemoryDataManagerTest.java       | 28 ++++++----
 .../hybrid/TestingSpillingInfoProvider.java        |  8 +--
 11 files changed, 212 insertions(+), 49 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java
index 4a5f5fd9651..06970c97e54 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java
@@ -22,8 +22,11 @@ import org.apache.flink.runtime.io.network.buffer.Buffer;
 
 import javax.annotation.Nullable;
 
+import java.util.Collections;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -37,7 +40,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  *       can no longer be spilled or consumed.
  *   <li>{@link #spillStarted} indicates that spilling of the buffer has 
started, either completed
  *       or not.
- *   <li>{@link #consumed} indicates that buffer has been consumed by the 
downstream.
+ *   <li>{@link #consumed} indicates that buffer has been consumed by these 
consumers.
  * </ul>
  *
  * <p>Reference count of the buffer is maintained as follows: *
@@ -64,7 +67,7 @@ public class HsBufferContext {
 
     private boolean spillStarted;
 
-    private boolean consumed;
+    private final Set<HsConsumerId> consumed = Collections.newSetFromMap(new 
ConcurrentHashMap<>());
 
     @Nullable private CompletableFuture<Void> spilledFuture;
 
@@ -89,8 +92,8 @@ public class HsBufferContext {
         return spillStarted;
     }
 
-    public boolean isConsumed() {
-        return consumed;
+    public boolean isConsumed(HsConsumerId consumerId) {
+        return consumed.contains(consumerId);
     }
 
     public Optional<CompletableFuture<Void>> getSpilledFuture() {
@@ -127,10 +130,9 @@ public class HsBufferContext {
         return true;
     }
 
-    public void consumed() {
+    public void consumed(HsConsumerId consumerId) {
         checkState(!released, "Buffer is already released.");
-        checkState(!consumed, "Consume buffer repeatedly is unexpected.");
-        consumed = true;
+        checkState(consumed.add(consumerId), "Consume buffer repeatedly is 
unexpected.");
         // increase ref count when buffer is consumed, will be decreased when 
downstream finish
         // consuming.
         buffer.retainBuffer();
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerId.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerId.java
new file mode 100644
index 00000000000..132da40f6d7
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerId.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import javax.annotation.Nullable;
+
+import java.util.Objects;
+
+/** This class represents the identifier of hybrid shuffle's consumer. */
+public class HsConsumerId {
+    /**
+     * This consumer id is used in the scenarios that information related to 
specific consumer needs
+     * to be ignored.
+     */
+    public static final HsConsumerId ANY = new HsConsumerId(-1);
+
+    /** This consumer id is used for the first consumer of a single 
subpartition. */
+    public static final HsConsumerId DEFAULT = new HsConsumerId(0);
+
+    /** This is a unique field for each consumer of a single subpartition. */
+    private final int id;
+
+    private HsConsumerId(int id) {
+        this.id = id;
+    }
+
+    public static HsConsumerId newId(@Nullable HsConsumerId lastId) {
+        return lastId == null ? DEFAULT : new HsConsumerId(lastId.id + 1);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        HsConsumerId that = (HsConsumerId) o;
+        return id == that.id;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(id);
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java
index ac2d86fcc30..94a067fbaeb 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategy.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition.hybrid;
 
-import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus;
 
 import java.util.ArrayDeque;
@@ -88,12 +88,14 @@ public class HsFullSpillingStrategy implements 
HsSpillingStrategy {
                             subpartitionId,
                             // get all not start spilling buffers.
                             spillingInfoProvider.getBuffersInOrder(
-                                    subpartitionId, SpillStatus.NOT_SPILL, 
ConsumeStatus.ALL))
+                                    subpartitionId,
+                                    SpillStatus.NOT_SPILL,
+                                    ConsumeStatusWithId.ALL_ANY))
                     .addBufferToRelease(
                             subpartitionId,
                             // get all not released buffers.
                             spillingInfoProvider.getBuffersInOrder(
-                                    subpartitionId, SpillStatus.ALL, 
ConsumeStatus.ALL));
+                                    subpartitionId, SpillStatus.ALL, 
ConsumeStatusWithId.ALL_ANY));
         }
         return builder.build();
     }
@@ -110,7 +112,7 @@ public class HsFullSpillingStrategy implements 
HsSpillingStrategy {
             builder.addBufferToSpill(
                     i,
                     spillingInfoProvider.getBuffersInOrder(
-                            i, SpillStatus.NOT_SPILL, ConsumeStatus.ALL));
+                            i, SpillStatus.NOT_SPILL, 
ConsumeStatusWithId.ALL_ANY));
         }
     }
 
@@ -135,7 +137,7 @@ public class HsFullSpillingStrategy implements 
HsSpillingStrategy {
         for (int subpartitionId = 0; subpartitionId < numSubpartitions; 
subpartitionId++) {
             Deque<BufferIndexAndChannel> buffersInOrder =
                     spillingInfoProvider.getBuffersInOrder(
-                            subpartitionId, SpillStatus.SPILL, 
ConsumeStatus.ALL);
+                            subpartitionId, SpillStatus.SPILL, 
ConsumeStatusWithId.ALL_ANY);
             // if the number of subpartition buffers less than survived 
buffers, reserved all of
             // them.
             int releaseNum = Math.max(0, buffersInOrder.size() - 
subpartitionSurvivedNum);
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java
index f19cb58a795..a8d7d89e7e4 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java
@@ -223,10 +223,11 @@ public class HsMemoryDataManager implements 
HsSpillingInfoProvider, HsMemoryData
     // Write lock should be acquired before invoke this method.
     @Override
     public Deque<BufferIndexAndChannel> getBuffersInOrder(
-            int subpartitionId, SpillStatus spillStatus, ConsumeStatus 
consumeStatus) {
+            int subpartitionId, SpillStatus spillStatus, ConsumeStatusWithId 
consumeStatusWithId) {
         HsSubpartitionMemoryDataManager targetSubpartitionDataManager =
                 getSubpartitionMemoryDataManager(subpartitionId);
-        return 
targetSubpartitionDataManager.getBuffersSatisfyStatus(spillStatus, 
consumeStatus);
+        return targetSubpartitionDataManager.getBuffersSatisfyStatus(
+                spillStatus, consumeStatusWithId);
     }
 
     // Write lock should be acquired before invoke this method.
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java
index dc356cb20d5..8c7217a4b93 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategy.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition.hybrid;
 
 import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus;
 
 import java.util.Deque;
@@ -85,7 +86,11 @@ public class HsSelectiveSpillingStrategy implements 
HsSpillingStrategy {
             subpartitionToBuffers.put(
                     channel,
                     spillingInfoProvider.getBuffersInOrder(
-                            channel, SpillStatus.NOT_SPILL, 
ConsumeStatus.NOT_CONSUMED));
+                            channel,
+                            SpillStatus.NOT_SPILL,
+                            // selective spilling strategy does not support 
multiple consumer.
+                            ConsumeStatusWithId.fromStatusAndConsumerId(
+                                    ConsumeStatus.NOT_CONSUMED, 
HsConsumerId.DEFAULT)));
         }
 
         TreeMap<Integer, List<BufferIndexAndChannel>> 
subpartitionToHighPriorityBuffers =
@@ -113,12 +118,14 @@ public class HsSelectiveSpillingStrategy implements 
HsSpillingStrategy {
                             subpartitionId,
                             // get all not start spilling buffers.
                             spillingInfoProvider.getBuffersInOrder(
-                                    subpartitionId, SpillStatus.NOT_SPILL, 
ConsumeStatus.ALL))
+                                    subpartitionId,
+                                    SpillStatus.NOT_SPILL,
+                                    ConsumeStatusWithId.ALL_ANY))
                     .addBufferToRelease(
                             subpartitionId,
                             // get all not released buffers.
                             spillingInfoProvider.getBuffersInOrder(
-                                    subpartitionId, SpillStatus.ALL, 
ConsumeStatus.ALL));
+                                    subpartitionId, SpillStatus.ALL, 
ConsumeStatusWithId.ALL_ANY));
         }
         return builder.build();
     }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java
index 6168c637de4..deb32e7058e 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingInfoProvider.java
@@ -43,13 +43,13 @@ public interface HsSpillingInfoProvider {
      *
      * @param subpartitionId target buffers belong to.
      * @param spillStatus expected buffer spill status.
-     * @param consumeStatus expected buffer consume status.
+     * @param consumeStatusWithId expected buffer consume status and consumer 
id.
      * @return all buffers satisfy specific status of this subpartition, This 
queue must be sorted
      *     according to bufferIndex from small to large, in other words, head 
is the buffer with the
      *     minimum bufferIndex in the current subpartition.
      */
     Deque<BufferIndexAndChannel> getBuffersInOrder(
-            int subpartitionId, SpillStatus spillStatus, ConsumeStatus 
consumeStatus);
+            int subpartitionId, SpillStatus spillStatus, ConsumeStatusWithId 
consumeStatusWithId);
 
     /** Get total number of not decided to spill buffers. */
     int getNumTotalUnSpillBuffers();
@@ -79,4 +79,24 @@ public interface HsSpillingInfoProvider {
         /** The buffer is either consumed or not consumed. */
         ALL
     }
+
+    /** This class represents a pair of {@link ConsumeStatus} and consumer id. 
*/
+    class ConsumeStatusWithId {
+        public static final ConsumeStatusWithId ALL_ANY =
+                new ConsumeStatusWithId(ConsumeStatus.ALL, HsConsumerId.ANY);
+
+        ConsumeStatus status;
+
+        HsConsumerId consumerId;
+
+        private ConsumeStatusWithId(ConsumeStatus status, HsConsumerId 
consumerId) {
+            this.status = status;
+            this.consumerId = consumerId;
+        }
+
+        public static ConsumeStatusWithId fromStatusAndConsumerId(
+                ConsumeStatus consumeStatus, HsConsumerId consumerId) {
+            return new ConsumeStatusWithId(consumeStatus, consumerId);
+        }
+    }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java
index 619bc01d05a..f6804a62b1b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java
@@ -30,6 +30,7 @@ import 
org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import 
org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus;
+import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId;
 import 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus;
 import org.apache.flink.util.function.SupplierWithException;
 import org.apache.flink.util.function.ThrowingRunnable;
@@ -144,7 +145,8 @@ public class HsSubpartitionMemoryDataManager implements 
HsDataView {
 
                             HsBufferContext bufferContext =
                                     
checkNotNull(unConsumedBuffers.pollFirst());
-                            bufferContext.consumed();
+                            // TODO move this logical to consumer and pass 
real consumerId.
+                            bufferContext.consumed(HsConsumerId.DEFAULT);
                             DataType nextDataType =
                                     
peekNextToConsumeDataTypeInternal(toConsumeIndex + 1);
                             return Optional.of(Tuple2.of(bufferContext, 
nextDataType));
@@ -200,14 +202,14 @@ public class HsSubpartitionMemoryDataManager implements 
HsDataView {
      * ConsumeStatus}.
      *
      * @param spillStatus the status of spilling expected.
-     * @param consumeStatus the status of consuming expected.
+     * @param consumeStatusWithId the status and consumerId expected.
      * @return buffers satisfy expected status in order.
      */
     @SuppressWarnings("FieldAccessNotGuarded")
     // Note that: callWithLock ensure that code block guarded by 
resultPartitionReadLock and
     // subpartitionLock.
     public Deque<BufferIndexAndChannel> getBuffersSatisfyStatus(
-            SpillStatus spillStatus, ConsumeStatus consumeStatus) {
+            SpillStatus spillStatus, ConsumeStatusWithId consumeStatusWithId) {
         return callWithLock(
                 () -> {
                     // TODO return iterator to avoid completely traversing the 
queue for each call.
@@ -216,7 +218,7 @@ public class HsSubpartitionMemoryDataManager implements 
HsDataView {
                     allBuffers.forEach(
                             (bufferContext -> {
                                 if (isBufferSatisfyStatus(
-                                        bufferContext, spillStatus, 
consumeStatus)) {
+                                        bufferContext, spillStatus, 
consumeStatusWithId)) {
                                     
targetBuffers.add(bufferContext.getBufferIndexAndChannel());
                                 }
                             }));
@@ -460,7 +462,7 @@ public class HsSubpartitionMemoryDataManager implements 
HsDataView {
     @GuardedBy("subpartitionLock")
     private void checkAndMarkBufferReadable(HsBufferContext bufferContext) {
         // only spill buffer needs to be marked as released.
-        if (isBufferSatisfyStatus(bufferContext, SpillStatus.SPILL, 
ConsumeStatus.ALL)) {
+        if (isBufferSatisfyStatus(bufferContext, SpillStatus.SPILL, 
ConsumeStatusWithId.ALL_ANY)) {
             bufferContext
                     .getSpilledFuture()
                     .orElseThrow(
@@ -480,7 +482,9 @@ public class HsSubpartitionMemoryDataManager implements 
HsDataView {
 
     @GuardedBy("subpartitionLock")
     private boolean isBufferSatisfyStatus(
-            HsBufferContext bufferContext, SpillStatus spillStatus, 
ConsumeStatus consumeStatus) {
+            HsBufferContext bufferContext,
+            SpillStatus spillStatus,
+            ConsumeStatusWithId consumeStatusWithId) {
         // released buffer is not needed.
         if (bufferContext.isReleased()) {
             return false;
@@ -494,12 +498,12 @@ public class HsSubpartitionMemoryDataManager implements 
HsDataView {
                 match = bufferContext.isSpillStarted();
                 break;
         }
-        switch (consumeStatus) {
+        switch (consumeStatusWithId.status) {
             case NOT_CONSUMED:
-                match &= !bufferContext.isConsumed();
+                match &= 
!bufferContext.isConsumed(consumeStatusWithId.consumerId);
                 break;
             case CONSUMED:
-                match &= bufferContext.isConsumed();
+                match &= 
bufferContext.isConsumed(consumeStatusWithId.consumerId);
                 break;
         }
         return match;
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java
index 960357d4127..f0419d114df 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java
@@ -81,25 +81,40 @@ class HsBufferContextTest {
 
     @Test
     void testBufferConsumed() {
+        final HsConsumerId consumerId = HsConsumerId.DEFAULT;
         Buffer buffer = bufferContext.getBuffer();
-        bufferContext.consumed();
-        assertThat(bufferContext.isConsumed()).isTrue();
+        bufferContext.consumed(consumerId);
+        assertThat(bufferContext.isConsumed(consumerId)).isTrue();
         assertThat(buffer.refCnt()).isEqualTo(2);
     }
 
     @Test
     void testBufferConsumedRepeatedly() {
-        bufferContext.consumed();
-        assertThatThrownBy(() -> bufferContext.consumed())
+        final HsConsumerId consumerId = HsConsumerId.DEFAULT;
+        bufferContext.consumed(consumerId);
+        assertThatThrownBy(() -> bufferContext.consumed(consumerId))
                 .isInstanceOf(IllegalStateException.class)
                 .hasMessageContaining("Consume buffer repeatedly is 
unexpected.");
     }
 
+    @Test
+    void testBufferConsumedMultipleConsumer() {
+        HsConsumerId consumer0 = HsConsumerId.newId(null);
+        HsConsumerId consumer1 = HsConsumerId.newId(consumer0);
+        bufferContext.consumed(consumer0);
+        bufferContext.consumed(consumer1);
+
+        assertThat(bufferContext.isConsumed(consumer0)).isTrue();
+        assertThat(bufferContext.isConsumed(consumer1)).isTrue();
+
+        
assertThat(bufferContext.isConsumed(HsConsumerId.newId(consumer1))).isFalse();
+    }
+
     @Test
     void testBufferStartSpillOrConsumedAfterReleased() {
         bufferContext.release();
         assertThat(bufferContext.startSpilling(new 
CompletableFuture<>())).isFalse();
-        assertThatThrownBy(() -> bufferContext.consumed())
+        assertThatThrownBy(() -> bufferContext.consumed(HsConsumerId.DEFAULT))
                 .isInstanceOf(IllegalStateException.class)
                 .hasMessageContaining("Buffer is already released.");
     }
@@ -117,7 +132,7 @@ class HsBufferContextTest {
     @Test
     void testBufferConsumedThenRelease() {
         Buffer buffer = bufferContext.getBuffer();
-        bufferContext.consumed();
+        bufferContext.consumed(HsConsumerId.DEFAULT);
         bufferContext.release();
         assertThat(buffer.refCnt()).isEqualTo(1);
     }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerIdTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerIdTest.java
new file mode 100644
index 00000000000..591d6b0ccfd
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsConsumerIdTest.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link HsConsumerId}. */
+class HsConsumerIdTest {
+    @Test
+    void testNewIdFromNull() {
+        HsConsumerId consumerId = HsConsumerId.newId(null);
+        assertThat(consumerId).isNotNull().isEqualTo(HsConsumerId.DEFAULT);
+    }
+
+    @Test
+    void testConsumerIdEquals() {
+        HsConsumerId consumerId = HsConsumerId.newId(null);
+        HsConsumerId consumerId1 = HsConsumerId.newId(consumerId);
+        HsConsumerId consumerId2 = HsConsumerId.newId(consumerId);
+        assertThat(consumerId1.hashCode()).isEqualTo(consumerId2.hashCode());
+        assertThat(consumerId1).isEqualTo(consumerId2);
+
+        assertThat(HsConsumerId.newId(consumerId2)).isNotEqualTo(consumerId2);
+    }
+}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java
index ba53f2a9efb..174c021990d 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java
@@ -53,6 +53,8 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import static 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId.ALL_ANY;
+import static 
org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatusWithId.fromStatusAndConsumerId;
 import static 
org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferBuilder;
 import static 
org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createTestingOutputMetrics;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -290,40 +292,44 @@ class HsSubpartitionMemoryDataManagerTest {
         subpartitionMemoryDataManager.consumeBuffer(1);
 
         checkBufferIndex(
-                subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.ALL, ConsumeStatus.ALL),
+                
subpartitionMemoryDataManager.getBuffersSatisfyStatus(SpillStatus.ALL, ALL_ANY),
                 Arrays.asList(0, 1, 2, 3));
         checkBufferIndex(
                 subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.ALL, ConsumeStatus.CONSUMED),
+                        SpillStatus.ALL,
+                        fromStatusAndConsumerId(ConsumeStatus.CONSUMED, 
HsConsumerId.DEFAULT)),
                 Arrays.asList(0, 1));
         checkBufferIndex(
                 subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.ALL, ConsumeStatus.NOT_CONSUMED),
+                        SpillStatus.ALL,
+                        fromStatusAndConsumerId(ConsumeStatus.NOT_CONSUMED, 
HsConsumerId.DEFAULT)),
                 Arrays.asList(2, 3));
         checkBufferIndex(
-                subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.SPILL, ConsumeStatus.ALL),
+                
subpartitionMemoryDataManager.getBuffersSatisfyStatus(SpillStatus.SPILL, 
ALL_ANY),
                 Arrays.asList(1, 2));
         checkBufferIndex(
                 subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.NOT_SPILL, ConsumeStatus.ALL),
+                        SpillStatus.NOT_SPILL, ALL_ANY),
                 Arrays.asList(0, 3));
         checkBufferIndex(
                 subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.SPILL, ConsumeStatus.NOT_CONSUMED),
+                        SpillStatus.SPILL,
+                        fromStatusAndConsumerId(ConsumeStatus.NOT_CONSUMED, 
HsConsumerId.DEFAULT)),
                 Collections.singletonList(2));
         checkBufferIndex(
                 subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.SPILL, ConsumeStatus.CONSUMED),
+                        SpillStatus.SPILL,
+                        fromStatusAndConsumerId(ConsumeStatus.CONSUMED, 
HsConsumerId.DEFAULT)),
                 Collections.singletonList(1));
         checkBufferIndex(
                 subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.NOT_SPILL, ConsumeStatus.CONSUMED),
+                        SpillStatus.NOT_SPILL,
+                        fromStatusAndConsumerId(ConsumeStatus.CONSUMED, 
HsConsumerId.DEFAULT)),
                 Collections.singletonList(0));
         checkBufferIndex(
                 subpartitionMemoryDataManager.getBuffersSatisfyStatus(
-                        SpillStatus.NOT_SPILL, ConsumeStatus.NOT_CONSUMED),
+                        SpillStatus.NOT_SPILL,
+                        fromStatusAndConsumerId(ConsumeStatus.NOT_CONSUMED, 
HsConsumerId.DEFAULT)),
                 Collections.singletonList(3));
     }
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java
index f8b303479f0..67b34d0e774 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingInfoProvider.java
@@ -79,7 +79,7 @@ public class TestingSpillingInfoProvider implements 
HsSpillingInfoProvider {
 
     @Override
     public Deque<BufferIndexAndChannel> getBuffersInOrder(
-            int subpartitionId, SpillStatus spillStatus, ConsumeStatus 
consumeStatus) {
+            int subpartitionId, SpillStatus spillStatus, ConsumeStatusWithId 
consumeStatusWithId) {
         Deque<BufferIndexAndChannel> buffersInOrder = new ArrayDeque<>();
 
         List<BufferIndexAndChannel> subpartitionBuffers = 
allBuffers.get(subpartitionId);
@@ -90,7 +90,7 @@ public class TestingSpillingInfoProvider implements 
HsSpillingInfoProvider {
         for (int i = 0; i < subpartitionBuffers.size(); i++) {
             if (isBufferSatisfyStatus(
                     spillStatus,
-                    consumeStatus,
+                    consumeStatusWithId,
                     spillBufferIndexes
                             .getOrDefault(subpartitionId, 
Collections.emptySet())
                             .contains(i),
@@ -124,7 +124,7 @@ public class TestingSpillingInfoProvider implements 
HsSpillingInfoProvider {
 
     private static boolean isBufferSatisfyStatus(
             SpillStatus spillStatus,
-            ConsumeStatus consumeStatus,
+            ConsumeStatusWithId consumeStatusWithId,
             boolean isSpill,
             boolean isConsumed) {
         boolean isNeeded = true;
@@ -136,7 +136,7 @@ public class TestingSpillingInfoProvider implements 
HsSpillingInfoProvider {
                 isNeeded = isSpill;
                 break;
         }
-        switch (consumeStatus) {
+        switch (consumeStatusWithId.status) {
             case NOT_CONSUMED:
                 isNeeded &= !isConsumed;
                 break;

Reply via email to