TanYuxin-tyx commented on code in PR #22342: URL: https://github.com/apache/flink/pull/22342#discussion_r1223761576
########## flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageNettyServiceImpl.java: ########## @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredResultPartition; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import 
java.util.function.Supplier; + +import static org.apache.flink.util.Preconditions.checkState; + +/** The default implementation of {@link TieredStorageNettyService}. */ +public class TieredStorageNettyServiceImpl implements TieredStorageNettyService { + + // ------------------------------------ + // For producer side + // ------------------------------------ + + private final Map<TieredStoragePartitionId, List<NettyServiceProducer>> + registeredServiceProducers = new ConcurrentHashMap<>(); + + private final Map<NettyConnectionId, BufferAvailabilityListener> + registeredAvailabilityListeners = new ConcurrentHashMap<>(); + + // ------------------------------------ + // For consumer side + // ------------------------------------ + + private final Map<TieredStoragePartitionId, Map<TieredStorageSubpartitionId, Integer>> + registeredChannelIndexes = new ConcurrentHashMap<>(); + + private final Map< + TieredStoragePartitionId, + Map<TieredStorageSubpartitionId, Supplier<InputChannel>>> + registeredInputChannelProviders = new ConcurrentHashMap<>(); + + private final Map< + TieredStoragePartitionId, + Map< + TieredStorageSubpartitionId, + NettyConnectionReaderAvailabilityAndPriorityHelper>> + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers = + new ConcurrentHashMap<>(); + + @Override + public void registerProducer( + TieredStoragePartitionId partitionId, NettyServiceProducer serviceProducer) { + List<NettyServiceProducer> serviceProducers = + registeredServiceProducers.getOrDefault(partitionId, new ArrayList<>()); + serviceProducers.add(serviceProducer); + registeredServiceProducers.put(partitionId, serviceProducers); + } + + @Override + public NettyConnectionReader registerConsumer( + TieredStoragePartitionId partitionId, TieredStorageSubpartitionId subpartitionId) { + Integer channelIndex = registeredChannelIndexes.get(partitionId).remove(subpartitionId); + if (registeredChannelIndexes.get(partitionId).isEmpty()) { + 
registeredChannelIndexes.remove(partitionId); + } + + Supplier<InputChannel> inputChannelProvider = + registeredInputChannelProviders.get(partitionId).remove(subpartitionId); + if (registeredInputChannelProviders.get(partitionId).isEmpty()) { + registeredInputChannelProviders.remove(partitionId); + } + + NettyConnectionReaderAvailabilityAndPriorityHelper helper = + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers + .get(partitionId) + .remove(subpartitionId); + if (registeredNettyConnectionReaderAvailabilityAndPriorityHelpers + .get(partitionId) + .isEmpty()) { + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers.remove(partitionId); + } + return new NettyConnectionReaderImpl(channelIndex, inputChannelProvider, helper); + } + + /** + * Create a {@link ResultSubpartitionView} for the netty server. + * + * @param partitionId partition id indicates the unique id of {@link TieredResultPartition}. + * @param subpartitionId subpartition id indicates the unique id of subpartition. + * @param availabilityListener listener is used to listen the available status of data. + * @return the {@link TieredStoreResultSubpartitionView}. 
+ */ + public ResultSubpartitionView createResultSubpartitionView( + TieredStoragePartitionId partitionId, + TieredStorageSubpartitionId subpartitionId, + BufferAvailabilityListener availabilityListener) { + List<NettyServiceProducer> serviceProducers = registeredServiceProducers.get(partitionId); + if (serviceProducers == null) { + return new TieredStoreResultSubpartitionView( + availabilityListener, new ArrayList<>(), new ArrayList<>(), new ArrayList<>()); + } + List<Queue<NettyPayload>> queues = new ArrayList<>(); + List<NettyConnectionId> nettyConnectionIds = new ArrayList<>(); + for (NettyServiceProducer serviceProducer : serviceProducers) { + LinkedBlockingQueue<NettyPayload> queue = new LinkedBlockingQueue<>(); + NettyConnectionWriterImpl writer = new NettyConnectionWriterImpl(queue); + serviceProducer.connectionEstablished(subpartitionId, writer); + nettyConnectionIds.add(writer.getNettyConnectionId()); + queues.add(queue); + registeredAvailabilityListeners.put( + writer.getNettyConnectionId(), availabilityListener); + } + return new TieredStoreResultSubpartitionView( + availabilityListener, + queues, + nettyConnectionIds, + registeredServiceProducers.get(partitionId)); + } + + /** + * Notify the {@link ResultSubpartitionView} to send buffer. + * + * @param connectionId connection id indicates the id of connection. + */ + public void notifyResultSubpartitionViewSendBuffer(NettyConnectionId connectionId) { + BufferAvailabilityListener listener = registeredAvailabilityListeners.get(connectionId); + if (listener != null) { + listener.notifyDataAvailable(); + } + } + + /** + * Set up input channels in {@link SingleInputGate}. + * + * @param partitionIds partition ids indicates the ids of {@link TieredResultPartition}. + * @param subpartitionIds subpartition ids indicates the ids of subpartition. 
+ */ + public void setUpInputChannels( + TieredStoragePartitionId[] partitionIds, + TieredStorageSubpartitionId[] subpartitionIds, + List<Supplier<InputChannel>> inputChannelProviders, + NettyConnectionReaderAvailabilityAndPriorityHelper helper) { + checkState(partitionIds.length == subpartitionIds.length); + checkState(subpartitionIds.length == inputChannelProviders.size()); + for (int index = 0; index < partitionIds.length; ++index) { + TieredStoragePartitionId partitionId = partitionIds[index]; + TieredStorageSubpartitionId subpartitionId = subpartitionIds[index]; + + Map<TieredStorageSubpartitionId, Integer> channelIndexes = + registeredChannelIndexes.getOrDefault(partitionId, new ConcurrentHashMap<>()); Review Comment: Maybe we can simplify this with the `registeredChannelIndexes.computeIfAbsent(partitionId, ...).put(...)` pattern, which also replaces the separate get-then-put on the concurrent map with a single atomic insertion of the inner map. ########## flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageNettyServiceImpl.java: ########## @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredResultPartition; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.function.Supplier; + +import static org.apache.flink.util.Preconditions.checkState; + +/** The default implementation of {@link TieredStorageNettyService}. 
*/ +public class TieredStorageNettyServiceImpl implements TieredStorageNettyService { + + // ------------------------------------ + // For producer side + // ------------------------------------ + + private final Map<TieredStoragePartitionId, List<NettyServiceProducer>> + registeredServiceProducers = new ConcurrentHashMap<>(); + + private final Map<NettyConnectionId, BufferAvailabilityListener> + registeredAvailabilityListeners = new ConcurrentHashMap<>(); + + // ------------------------------------ + // For consumer side + // ------------------------------------ + + private final Map<TieredStoragePartitionId, Map<TieredStorageSubpartitionId, Integer>> + registeredChannelIndexes = new ConcurrentHashMap<>(); + + private final Map< + TieredStoragePartitionId, + Map<TieredStorageSubpartitionId, Supplier<InputChannel>>> + registeredInputChannelProviders = new ConcurrentHashMap<>(); + + private final Map< + TieredStoragePartitionId, + Map< + TieredStorageSubpartitionId, + NettyConnectionReaderAvailabilityAndPriorityHelper>> + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers = + new ConcurrentHashMap<>(); + + @Override + public void registerProducer( + TieredStoragePartitionId partitionId, NettyServiceProducer serviceProducer) { + List<NettyServiceProducer> serviceProducers = + registeredServiceProducers.getOrDefault(partitionId, new ArrayList<>()); + serviceProducers.add(serviceProducer); + registeredServiceProducers.put(partitionId, serviceProducers); + } + + @Override + public NettyConnectionReader registerConsumer( + TieredStoragePartitionId partitionId, TieredStorageSubpartitionId subpartitionId) { + Integer channelIndex = registeredChannelIndexes.get(partitionId).remove(subpartitionId); + if (registeredChannelIndexes.get(partitionId).isEmpty()) { + registeredChannelIndexes.remove(partitionId); + } + + Supplier<InputChannel> inputChannelProvider = + registeredInputChannelProviders.get(partitionId).remove(subpartitionId); + if 
(registeredInputChannelProviders.get(partitionId).isEmpty()) { + registeredInputChannelProviders.remove(partitionId); + } + + NettyConnectionReaderAvailabilityAndPriorityHelper helper = + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers + .get(partitionId) + .remove(subpartitionId); + if (registeredNettyConnectionReaderAvailabilityAndPriorityHelpers + .get(partitionId) + .isEmpty()) { + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers.remove(partitionId); + } + return new NettyConnectionReaderImpl(channelIndex, inputChannelProvider, helper); + } + + /** + * Create a {@link ResultSubpartitionView} for the netty server. + * + * @param partitionId partition id indicates the unique id of {@link TieredResultPartition}. + * @param subpartitionId subpartition id indicates the unique id of subpartition. + * @param availabilityListener listener is used to listen the available status of data. + * @return the {@link TieredStoreResultSubpartitionView}. + */ + public ResultSubpartitionView createResultSubpartitionView( + TieredStoragePartitionId partitionId, + TieredStorageSubpartitionId subpartitionId, + BufferAvailabilityListener availabilityListener) { + List<NettyServiceProducer> serviceProducers = registeredServiceProducers.get(partitionId); + if (serviceProducers == null) { + return new TieredStoreResultSubpartitionView( + availabilityListener, new ArrayList<>(), new ArrayList<>(), new ArrayList<>()); + } + List<Queue<NettyPayload>> queues = new ArrayList<>(); + List<NettyConnectionId> nettyConnectionIds = new ArrayList<>(); + for (NettyServiceProducer serviceProducer : serviceProducers) { + LinkedBlockingQueue<NettyPayload> queue = new LinkedBlockingQueue<>(); + NettyConnectionWriterImpl writer = new NettyConnectionWriterImpl(queue); + serviceProducer.connectionEstablished(subpartitionId, writer); + nettyConnectionIds.add(writer.getNettyConnectionId()); + queues.add(queue); + registeredAvailabilityListeners.put( + writer.getNettyConnectionId(), 
availabilityListener); + } + return new TieredStoreResultSubpartitionView( + availabilityListener, + queues, + nettyConnectionIds, + registeredServiceProducers.get(partitionId)); + } + + /** + * Notify the {@link ResultSubpartitionView} to send buffer. + * + * @param connectionId connection id indicates the id of connection. + */ + public void notifyResultSubpartitionViewSendBuffer(NettyConnectionId connectionId) { + BufferAvailabilityListener listener = registeredAvailabilityListeners.get(connectionId); + if (listener != null) { + listener.notifyDataAvailable(); + } + } + + /** + * Set up input channels in {@link SingleInputGate}. + * + * @param partitionIds partition ids indicates the ids of {@link TieredResultPartition}. + * @param subpartitionIds subpartition ids indicates the ids of subpartition. + */ + public void setUpInputChannels( Review Comment: ```suggestion public void setupInputChannels( ``` ########## flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TestingNettyServiceProducer.java: ########## @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId; + +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +/** Test implementation for {@link NettyServiceProducer}. */ +public class TestingNettyServiceProducer implements NettyServiceProducer { + + private final BiConsumer<TieredStorageSubpartitionId, NettyConnectionWriter> + connectionEstablishConsumer; + + private final Consumer<NettyConnectionId> connectionBrokenConsumer; + + public TestingNettyServiceProducer( Review Comment: We should make this constructor private so that it cannot be called directly; test objects should be created only through the Builder. There are many similar cases, and we should change them as well. ########## flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStoreResultSubpartitionViewTest.java: ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link TieredStoreResultSubpartitionView}. 
*/ +public class TieredStoreResultSubpartitionViewTest { + + private static final int TIER_NUMBER = 2; + + private CompletableFuture<Void> availabilityListener; + + private List<Queue<NettyPayload>> nettyPayloadQueues; + + private List<CompletableFuture<NettyConnectionId>> connectionBrokenConsumers; + + private TieredStoreResultSubpartitionView tieredStoreResultSubpartitionView; + + @BeforeEach + void before() { + availabilityListener = new CompletableFuture<>(); + nettyPayloadQueues = createNettyPayloadQueues(); + connectionBrokenConsumers = + Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>()); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + } + + @Test + void testGetNextBuffer() throws IOException { + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener).isDone(); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + assertThat(tieredStoreResultSubpartitionView.getNextBuffer()).isNull(); + } + + @Test + void testGetNextBufferFailed() { + Throwable expectedError = new IOException(); + nettyPayloadQueues = createNettyPayloadQueuesWithError(expectedError); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + assertThatThrownBy(tieredStoreResultSubpartitionView::getNextBuffer) + .hasCause(expectedError); + 
assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + } + + @Test + void testGetAvailabilityAndBacklog() { + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog1 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(0); + assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog1.isAvailable()).isEqualTo(false); + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog2 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(2); + assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog2.isAvailable()).isEqualTo(true); + } + + @Test + void testNotifyRequiredSegmentId() { + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener.isDone()).isTrue(); + } + + @Test + void testReleaseAllResources() throws IOException { + tieredStoreResultSubpartitionView.releaseAllResources(); + assertThat(nettyPayloadQueues.get(0)).hasSize(0); + assertThat(nettyPayloadQueues.get(1)).hasSize(0); + assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + assertThat(connectionBrokenConsumers.get(1).isDone()).isTrue(); + assertThat(tieredStoreResultSubpartitionView.isReleased()).isTrue(); + } + + @Test + void testGetNumberOfQueuedBuffers() { + assertThat(tieredStoreResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(3); + assertThat(tieredStoreResultSubpartitionView.unsynchronizedGetNumberOfQueuedBuffers()) + .isEqualTo(3); + } + + private void checkBufferAndBacklog(BufferAndBacklog bufferAndBacklog, int backlog) { + assertThat(bufferAndBacklog).isNotNull(); + assertThat(bufferAndBacklog.buffer()).isNotNull(); + assertThat(bufferAndBacklog.buffersInBacklog()).isEqualTo(backlog); + } + + private BufferAvailabilityListener createBufferAvailabilityListener( + CompletableFuture<Void> notifier) { + return () -> notifier.complete(null); + } + + private List<Queue<NettyPayload>> 
createNettyPayloadQueues() { Review Comment: ```suggestion private static List<Queue<NettyPayload>> createNettyPayloadQueues() { ``` ########## flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStoreResultSubpartitionViewTest.java: ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link TieredStoreResultSubpartitionView}. 
*/ +public class TieredStoreResultSubpartitionViewTest { + + private static final int TIER_NUMBER = 2; + + private CompletableFuture<Void> availabilityListener; + + private List<Queue<NettyPayload>> nettyPayloadQueues; + + private List<CompletableFuture<NettyConnectionId>> connectionBrokenConsumers; + + private TieredStoreResultSubpartitionView tieredStoreResultSubpartitionView; + + @BeforeEach + void before() { + availabilityListener = new CompletableFuture<>(); + nettyPayloadQueues = createNettyPayloadQueues(); + connectionBrokenConsumers = + Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>()); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + } + + @Test + void testGetNextBuffer() throws IOException { + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener).isDone(); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + assertThat(tieredStoreResultSubpartitionView.getNextBuffer()).isNull(); + } + + @Test + void testGetNextBufferFailed() { + Throwable expectedError = new IOException(); + nettyPayloadQueues = createNettyPayloadQueuesWithError(expectedError); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + assertThatThrownBy(tieredStoreResultSubpartitionView::getNextBuffer) + .hasCause(expectedError); + 
assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + } + + @Test + void testGetAvailabilityAndBacklog() { + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog1 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(0); + assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog1.isAvailable()).isEqualTo(false); + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog2 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(2); + assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog2.isAvailable()).isEqualTo(true); + } + + @Test + void testNotifyRequiredSegmentId() { + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener.isDone()).isTrue(); + } + + @Test + void testReleaseAllResources() throws IOException { + tieredStoreResultSubpartitionView.releaseAllResources(); + assertThat(nettyPayloadQueues.get(0)).hasSize(0); + assertThat(nettyPayloadQueues.get(1)).hasSize(0); + assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + assertThat(connectionBrokenConsumers.get(1).isDone()).isTrue(); + assertThat(tieredStoreResultSubpartitionView.isReleased()).isTrue(); + } + + @Test + void testGetNumberOfQueuedBuffers() { + assertThat(tieredStoreResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(3); + assertThat(tieredStoreResultSubpartitionView.unsynchronizedGetNumberOfQueuedBuffers()) + .isEqualTo(3); + } + + private void checkBufferAndBacklog(BufferAndBacklog bufferAndBacklog, int backlog) { + assertThat(bufferAndBacklog).isNotNull(); + assertThat(bufferAndBacklog.buffer()).isNotNull(); + assertThat(bufferAndBacklog.buffersInBacklog()).isEqualTo(backlog); + } + + private BufferAvailabilityListener createBufferAvailabilityListener( + CompletableFuture<Void> notifier) { + return () -> notifier.complete(null); + } + + private List<Queue<NettyPayload>> 
createNettyPayloadQueues() { + List<Queue<NettyPayload>> nettyPayloadQueues = new ArrayList<>(); + for (int index = 0; index < TIER_NUMBER; ++index) { + Queue<NettyPayload> queue = new ArrayDeque<>(); + queue.add(NettyPayload.newSegment(index)); + queue.add( + NettyPayload.newBuffer( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(0), + BufferRecycler.DummyBufferRecycler.INSTANCE), + 0, + index)); + queue.add( + NettyPayload.newBuffer( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(0), + BufferRecycler.DummyBufferRecycler.INSTANCE, + END_OF_SEGMENT), + 1, + index)); + nettyPayloadQueues.add(queue); + } + return nettyPayloadQueues; + } + + private List<Queue<NettyPayload>> createNettyPayloadQueuesWithError(Throwable error) { + List<Queue<NettyPayload>> nettyPayloadQueues = new ArrayList<>(); + for (int index = 0; index < TIER_NUMBER; ++index) { + Queue<NettyPayload> queue = new ArrayDeque<>(); + queue.add(NettyPayload.newSegment(index)); + queue.add(NettyPayload.newError(error)); + nettyPayloadQueues.add(queue); + } + return nettyPayloadQueues; + } + + private List<NettyConnectionId> createNettyConnectionIds() { + List<NettyConnectionId> nettyConnectionIds = new ArrayList<>(); + for (int index = 0; index < TIER_NUMBER; ++index) { + nettyConnectionIds.add(NettyConnectionId.newId()); + } + return nettyConnectionIds; + } + + private List<NettyServiceProducer> createNettyServiceProducers( Review Comment: ```suggestion private static List<NettyServiceProducer> createNettyServiceProducers( ``` ########## flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStoreResultSubpartitionViewTest.java: ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link TieredStoreResultSubpartitionView}. 
*/ +public class TieredStoreResultSubpartitionViewTest { + + private static final int TIER_NUMBER = 2; + + private CompletableFuture<Void> availabilityListener; + + private List<Queue<NettyPayload>> nettyPayloadQueues; + + private List<CompletableFuture<NettyConnectionId>> connectionBrokenConsumers; + + private TieredStoreResultSubpartitionView tieredStoreResultSubpartitionView; + + @BeforeEach + void before() { + availabilityListener = new CompletableFuture<>(); + nettyPayloadQueues = createNettyPayloadQueues(); + connectionBrokenConsumers = + Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>()); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + } + + @Test + void testGetNextBuffer() throws IOException { + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener).isDone(); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + assertThat(tieredStoreResultSubpartitionView.getNextBuffer()).isNull(); + } + + @Test + void testGetNextBufferFailed() { + Throwable expectedError = new IOException(); + nettyPayloadQueues = createNettyPayloadQueuesWithError(expectedError); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + assertThatThrownBy(tieredStoreResultSubpartitionView::getNextBuffer) + .hasCause(expectedError); + 
assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + } + + @Test + void testGetAvailabilityAndBacklog() { + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog1 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(0); + assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog1.isAvailable()).isEqualTo(false); + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog2 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(2); + assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog2.isAvailable()).isEqualTo(true); + } + + @Test + void testNotifyRequiredSegmentId() { + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener.isDone()).isTrue(); + } + + @Test + void testReleaseAllResources() throws IOException { + tieredStoreResultSubpartitionView.releaseAllResources(); + assertThat(nettyPayloadQueues.get(0)).hasSize(0); + assertThat(nettyPayloadQueues.get(1)).hasSize(0); + assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + assertThat(connectionBrokenConsumers.get(1).isDone()).isTrue(); + assertThat(tieredStoreResultSubpartitionView.isReleased()).isTrue(); + } + + @Test + void testGetNumberOfQueuedBuffers() { + assertThat(tieredStoreResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(3); + assertThat(tieredStoreResultSubpartitionView.unsynchronizedGetNumberOfQueuedBuffers()) + .isEqualTo(3); + } + + private void checkBufferAndBacklog(BufferAndBacklog bufferAndBacklog, int backlog) { + assertThat(bufferAndBacklog).isNotNull(); + assertThat(bufferAndBacklog.buffer()).isNotNull(); + assertThat(bufferAndBacklog.buffersInBacklog()).isEqualTo(backlog); + } + + private BufferAvailabilityListener createBufferAvailabilityListener( Review Comment: ```suggestion private static BufferAvailabilityListener createBufferAvailabilityListener( ``` ########## 
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/NettyConnectionWriterTest.java: ########## @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayDeque; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link NettyConnectionWriter}. 
*/ +public class NettyConnectionWriterTest { + + private static final int SUBPARTITION_ID = 0; + + @Test + void testWriteBuffer() { + int bufferNumber = 10; + ArrayDeque<NettyPayload> nettyPayloadQueue = new ArrayDeque<>(); + NettyConnectionWriter nettyConnectionWriter = + new NettyConnectionWriterImpl(nettyPayloadQueue); + writeBufferTpWriter(bufferNumber, nettyConnectionWriter); + assertThat(nettyPayloadQueue.size()).isEqualTo(bufferNumber); + } + + @Test + void testGetNettyConnectionId() { + ArrayDeque<NettyPayload> nettyPayloadQueue = new ArrayDeque<>(); + NettyConnectionWriter nettyConnectionWriter = + new NettyConnectionWriterImpl(nettyPayloadQueue); + assertThat(nettyConnectionWriter.getNettyConnectionId()).isNotNull(); + } + + @Test + void testNumQueuedBuffers() { + int bufferNumber = 10; + ArrayDeque<NettyPayload> nettyPayloadQueue = new ArrayDeque<>(); + NettyConnectionWriter nettyConnectionWriter = + new NettyConnectionWriterImpl(nettyPayloadQueue); + writeBufferTpWriter(bufferNumber, nettyConnectionWriter); + assertThat(nettyConnectionWriter.numQueuedBuffers()).isEqualTo(bufferNumber); + } + + @Test + void testClose() { + int bufferNumber = 10; + ArrayDeque<NettyPayload> nettyPayloadQueue = new ArrayDeque<>(); + NettyConnectionWriter nettyConnectionWriter = + new NettyConnectionWriterImpl(nettyPayloadQueue); + writeBufferTpWriter(bufferNumber, nettyConnectionWriter); + nettyConnectionWriter.close(); + assertThat(nettyConnectionWriter.numQueuedBuffers()).isEqualTo(0); + } + + private void writeBufferTpWriter( + int bufferNumber, NettyConnectionWriter nettyConnectionWriter) { + for (int index = 0; index < bufferNumber; ++index) { + nettyConnectionWriter.writeBuffer( + NettyPayload.newBuffer( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(0), + BufferRecycler.DummyBufferRecycler.INSTANCE), Review Comment: Note: please check all the recyclers that are used like this one.
########## flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStoreResultSubpartitionViewTest.java: ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT; +import 
static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link TieredStoreResultSubpartitionView}. */ +public class TieredStoreResultSubpartitionViewTest { + + private static final int TIER_NUMBER = 2; + + private CompletableFuture<Void> availabilityListener; + + private List<Queue<NettyPayload>> nettyPayloadQueues; + + private List<CompletableFuture<NettyConnectionId>> connectionBrokenConsumers; + + private TieredStoreResultSubpartitionView tieredStoreResultSubpartitionView; + + @BeforeEach + void before() { + availabilityListener = new CompletableFuture<>(); + nettyPayloadQueues = createNettyPayloadQueues(); + connectionBrokenConsumers = + Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>()); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + } + + @Test + void testGetNextBuffer() throws IOException { + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener).isDone(); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + assertThat(tieredStoreResultSubpartitionView.getNextBuffer()).isNull(); + } + + @Test + void testGetNextBufferFailed() { + Throwable expectedError = new IOException(); + nettyPayloadQueues = createNettyPayloadQueuesWithError(expectedError); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + 
createNettyServiceProducers(connectionBrokenConsumers)); + assertThatThrownBy(tieredStoreResultSubpartitionView::getNextBuffer) + .hasCause(expectedError); + assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + } + + @Test + void testGetAvailabilityAndBacklog() { + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog1 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(0); + assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog1.isAvailable()).isEqualTo(false); + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog2 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(2); + assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog2.isAvailable()).isEqualTo(true); + } + + @Test + void testNotifyRequiredSegmentId() { + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener.isDone()).isTrue(); + } + + @Test + void testReleaseAllResources() throws IOException { + tieredStoreResultSubpartitionView.releaseAllResources(); + assertThat(nettyPayloadQueues.get(0)).hasSize(0); + assertThat(nettyPayloadQueues.get(1)).hasSize(0); + assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + assertThat(connectionBrokenConsumers.get(1).isDone()).isTrue(); + assertThat(tieredStoreResultSubpartitionView.isReleased()).isTrue(); + } + + @Test + void testGetNumberOfQueuedBuffers() { + assertThat(tieredStoreResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(3); + assertThat(tieredStoreResultSubpartitionView.unsynchronizedGetNumberOfQueuedBuffers()) + .isEqualTo(3); + } + + private void checkBufferAndBacklog(BufferAndBacklog bufferAndBacklog, int backlog) { Review Comment: ```suggestion private static void checkBufferAndBacklog(BufferAndBacklog bufferAndBacklog, int backlog) { ``` I checked all the private methods in the test classes: they should all be static.
########## flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStorageNettyServiceImpl.java: ########## @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannel; +import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStoragePartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.shuffle.TieredResultPartition; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import 
java.util.function.Supplier; + +import static org.apache.flink.util.Preconditions.checkState; + +/** The default implementation of {@link TieredStorageNettyService}. */ +public class TieredStorageNettyServiceImpl implements TieredStorageNettyService { + + // ------------------------------------ + // For producer side + // ------------------------------------ + + private final Map<TieredStoragePartitionId, List<NettyServiceProducer>> + registeredServiceProducers = new ConcurrentHashMap<>(); + + private final Map<NettyConnectionId, BufferAvailabilityListener> + registeredAvailabilityListeners = new ConcurrentHashMap<>(); + + // ------------------------------------ + // For consumer side + // ------------------------------------ + + private final Map<TieredStoragePartitionId, Map<TieredStorageSubpartitionId, Integer>> + registeredChannelIndexes = new ConcurrentHashMap<>(); + + private final Map< + TieredStoragePartitionId, + Map<TieredStorageSubpartitionId, Supplier<InputChannel>>> + registeredInputChannelProviders = new ConcurrentHashMap<>(); + + private final Map< + TieredStoragePartitionId, + Map< + TieredStorageSubpartitionId, + NettyConnectionReaderAvailabilityAndPriorityHelper>> + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers = + new ConcurrentHashMap<>(); + + @Override + public void registerProducer( + TieredStoragePartitionId partitionId, NettyServiceProducer serviceProducer) { + List<NettyServiceProducer> serviceProducers = + registeredServiceProducers.getOrDefault(partitionId, new ArrayList<>()); + serviceProducers.add(serviceProducer); + registeredServiceProducers.put(partitionId, serviceProducers); + } + + @Override + public NettyConnectionReader registerConsumer( + TieredStoragePartitionId partitionId, TieredStorageSubpartitionId subpartitionId) { + Integer channelIndex = registeredChannelIndexes.get(partitionId).remove(subpartitionId); + if (registeredChannelIndexes.get(partitionId).isEmpty()) { + 
registeredChannelIndexes.remove(partitionId); + } + + Supplier<InputChannel> inputChannelProvider = + registeredInputChannelProviders.get(partitionId).remove(subpartitionId); + if (registeredInputChannelProviders.get(partitionId).isEmpty()) { + registeredInputChannelProviders.remove(partitionId); + } + + NettyConnectionReaderAvailabilityAndPriorityHelper helper = + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers + .get(partitionId) + .remove(subpartitionId); + if (registeredNettyConnectionReaderAvailabilityAndPriorityHelpers + .get(partitionId) + .isEmpty()) { + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers.remove(partitionId); + } + return new NettyConnectionReaderImpl(channelIndex, inputChannelProvider, helper); + } + + /** + * Create a {@link ResultSubpartitionView} for the netty server. + * + * @param partitionId partition id indicates the unique id of {@link TieredResultPartition}. + * @param subpartitionId subpartition id indicates the unique id of subpartition. + * @param availabilityListener listener is used to listen the available status of data. + * @return the {@link TieredStoreResultSubpartitionView}. 
+ */ + public ResultSubpartitionView createResultSubpartitionView( + TieredStoragePartitionId partitionId, + TieredStorageSubpartitionId subpartitionId, + BufferAvailabilityListener availabilityListener) { + List<NettyServiceProducer> serviceProducers = registeredServiceProducers.get(partitionId); + if (serviceProducers == null) { + return new TieredStoreResultSubpartitionView( + availabilityListener, new ArrayList<>(), new ArrayList<>(), new ArrayList<>()); + } + List<Queue<NettyPayload>> queues = new ArrayList<>(); + List<NettyConnectionId> nettyConnectionIds = new ArrayList<>(); + for (NettyServiceProducer serviceProducer : serviceProducers) { + LinkedBlockingQueue<NettyPayload> queue = new LinkedBlockingQueue<>(); + NettyConnectionWriterImpl writer = new NettyConnectionWriterImpl(queue); + serviceProducer.connectionEstablished(subpartitionId, writer); + nettyConnectionIds.add(writer.getNettyConnectionId()); + queues.add(queue); + registeredAvailabilityListeners.put( + writer.getNettyConnectionId(), availabilityListener); + } + return new TieredStoreResultSubpartitionView( + availabilityListener, + queues, + nettyConnectionIds, + registeredServiceProducers.get(partitionId)); + } + + /** + * Notify the {@link ResultSubpartitionView} to send buffer. + * + * @param connectionId connection id indicates the id of connection. + */ + public void notifyResultSubpartitionViewSendBuffer(NettyConnectionId connectionId) { + BufferAvailabilityListener listener = registeredAvailabilityListeners.get(connectionId); + if (listener != null) { + listener.notifyDataAvailable(); + } + } + + /** + * Set up input channels in {@link SingleInputGate}. + * + * @param partitionIds partition ids indicates the ids of {@link TieredResultPartition}. + * @param subpartitionIds subpartition ids indicates the ids of subpartition. 
+ */ + public void setUpInputChannels( + TieredStoragePartitionId[] partitionIds, + TieredStorageSubpartitionId[] subpartitionIds, + List<Supplier<InputChannel>> inputChannelProviders, + NettyConnectionReaderAvailabilityAndPriorityHelper helper) { + checkState(partitionIds.length == subpartitionIds.length); + checkState(subpartitionIds.length == inputChannelProviders.size()); + for (int index = 0; index < partitionIds.length; ++index) { + TieredStoragePartitionId partitionId = partitionIds[index]; + TieredStorageSubpartitionId subpartitionId = subpartitionIds[index]; + + Map<TieredStorageSubpartitionId, Integer> channelIndexes = + registeredChannelIndexes.getOrDefault(partitionId, new ConcurrentHashMap<>()); + channelIndexes.put(subpartitionId, index); + registeredChannelIndexes.put(partitionId, channelIndexes); + + Map<TieredStorageSubpartitionId, Supplier<InputChannel>> providers = + registeredInputChannelProviders.getOrDefault( + partitionId, new ConcurrentHashMap<>()); + providers.put(subpartitionId, inputChannelProviders.get(index)); + registeredInputChannelProviders.put(partitionId, providers); + + Map<TieredStorageSubpartitionId, NettyConnectionReaderAvailabilityAndPriorityHelper> + helpers = + registeredNettyConnectionReaderAvailabilityAndPriorityHelpers + .getOrDefault(partitionId, new ConcurrentHashMap<>()); Review Comment: ditto ########## flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStoreResultSubpartitionViewTest.java: ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link TieredStoreResultSubpartitionView}. 
*/ +public class TieredStoreResultSubpartitionViewTest { + + private static final int TIER_NUMBER = 2; + + private CompletableFuture<Void> availabilityListener; + + private List<Queue<NettyPayload>> nettyPayloadQueues; + + private List<CompletableFuture<NettyConnectionId>> connectionBrokenConsumers; + + private TieredStoreResultSubpartitionView tieredStoreResultSubpartitionView; + + @BeforeEach + void before() { + availabilityListener = new CompletableFuture<>(); + nettyPayloadQueues = createNettyPayloadQueues(); + connectionBrokenConsumers = + Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>()); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + } + + @Test + void testGetNextBuffer() throws IOException { + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener).isDone(); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + assertThat(tieredStoreResultSubpartitionView.getNextBuffer()).isNull(); + } + + @Test + void testGetNextBufferFailed() { + Throwable expectedError = new IOException(); + nettyPayloadQueues = createNettyPayloadQueuesWithError(expectedError); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + assertThatThrownBy(tieredStoreResultSubpartitionView::getNextBuffer) + .hasCause(expectedError); + 
assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + } + + @Test + void testGetAvailabilityAndBacklog() { + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog1 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(0); + assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog1.isAvailable()).isEqualTo(false); + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog2 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(2); + assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog2.isAvailable()).isEqualTo(true); + } + + @Test + void testNotifyRequiredSegmentId() { + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener.isDone()).isTrue(); + } + + @Test + void testReleaseAllResources() throws IOException { + tieredStoreResultSubpartitionView.releaseAllResources(); + assertThat(nettyPayloadQueues.get(0)).hasSize(0); + assertThat(nettyPayloadQueues.get(1)).hasSize(0); + assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + assertThat(connectionBrokenConsumers.get(1).isDone()).isTrue(); + assertThat(tieredStoreResultSubpartitionView.isReleased()).isTrue(); + } + + @Test + void testGetNumberOfQueuedBuffers() { + assertThat(tieredStoreResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(3); + assertThat(tieredStoreResultSubpartitionView.unsynchronizedGetNumberOfQueuedBuffers()) + .isEqualTo(3); + } + + private void checkBufferAndBacklog(BufferAndBacklog bufferAndBacklog, int backlog) { + assertThat(bufferAndBacklog).isNotNull(); + assertThat(bufferAndBacklog.buffer()).isNotNull(); + assertThat(bufferAndBacklog.buffersInBacklog()).isEqualTo(backlog); + } + + private BufferAvailabilityListener createBufferAvailabilityListener( + CompletableFuture<Void> notifier) { + return () -> notifier.complete(null); + } + + private List<Queue<NettyPayload>> 
createNettyPayloadQueues() { + List<Queue<NettyPayload>> nettyPayloadQueues = new ArrayList<>(); + for (int index = 0; index < TIER_NUMBER; ++index) { + Queue<NettyPayload> queue = new ArrayDeque<>(); + queue.add(NettyPayload.newSegment(index)); + queue.add( + NettyPayload.newBuffer( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(0), + BufferRecycler.DummyBufferRecycler.INSTANCE), + 0, + index)); + queue.add( + NettyPayload.newBuffer( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(0), + BufferRecycler.DummyBufferRecycler.INSTANCE, + END_OF_SEGMENT), + 1, + index)); + nettyPayloadQueues.add(queue); + } + return nettyPayloadQueues; + } + + private List<Queue<NettyPayload>> createNettyPayloadQueuesWithError(Throwable error) { + List<Queue<NettyPayload>> nettyPayloadQueues = new ArrayList<>(); + for (int index = 0; index < TIER_NUMBER; ++index) { + Queue<NettyPayload> queue = new ArrayDeque<>(); + queue.add(NettyPayload.newSegment(index)); + queue.add(NettyPayload.newError(error)); + nettyPayloadQueues.add(queue); + } + return nettyPayloadQueues; + } + + private List<NettyConnectionId> createNettyConnectionIds() { Review Comment: ```suggestion private static List<NettyConnectionId> createNettyConnectionIds() { ``` ########## flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/netty/TieredStoreResultSubpartitionViewTest.java: ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid.tiered.netty; + +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener; +import org.apache.flink.runtime.io.network.partition.ResultSubpartition.BufferAndBacklog; +import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView; +import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.NettyPayload; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType.END_OF_SEGMENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link TieredStoreResultSubpartitionView}. 
*/ +public class TieredStoreResultSubpartitionViewTest { + + private static final int TIER_NUMBER = 2; + + private CompletableFuture<Void> availabilityListener; + + private List<Queue<NettyPayload>> nettyPayloadQueues; + + private List<CompletableFuture<NettyConnectionId>> connectionBrokenConsumers; + + private TieredStoreResultSubpartitionView tieredStoreResultSubpartitionView; + + @BeforeEach + void before() { + availabilityListener = new CompletableFuture<>(); + nettyPayloadQueues = createNettyPayloadQueues(); + connectionBrokenConsumers = + Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>()); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + } + + @Test + void testGetNextBuffer() throws IOException { + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener).isDone(); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 1); + checkBufferAndBacklog(tieredStoreResultSubpartitionView.getNextBuffer(), 0); + assertThat(tieredStoreResultSubpartitionView.getNextBuffer()).isNull(); + } + + @Test + void testGetNextBufferFailed() { + Throwable expectedError = new IOException(); + nettyPayloadQueues = createNettyPayloadQueuesWithError(expectedError); + tieredStoreResultSubpartitionView = + new TieredStoreResultSubpartitionView( + createBufferAvailabilityListener(availabilityListener), + nettyPayloadQueues, + createNettyConnectionIds(), + createNettyServiceProducers(connectionBrokenConsumers)); + assertThatThrownBy(tieredStoreResultSubpartitionView::getNextBuffer) + .hasCause(expectedError); + 
assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + } + + @Test + void testGetAvailabilityAndBacklog() { + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog1 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(0); + assertThat(availabilityAndBacklog1.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog1.isAvailable()).isEqualTo(false); + ResultSubpartitionView.AvailabilityWithBacklog availabilityAndBacklog2 = + tieredStoreResultSubpartitionView.getAvailabilityAndBacklog(2); + assertThat(availabilityAndBacklog2.getBacklog()).isEqualTo(3); + assertThat(availabilityAndBacklog2.isAvailable()).isEqualTo(true); + } + + @Test + void testNotifyRequiredSegmentId() { + tieredStoreResultSubpartitionView.notifyRequiredSegmentId(1); + assertThat(availabilityListener.isDone()).isTrue(); + } + + @Test + void testReleaseAllResources() throws IOException { + tieredStoreResultSubpartitionView.releaseAllResources(); + assertThat(nettyPayloadQueues.get(0)).hasSize(0); + assertThat(nettyPayloadQueues.get(1)).hasSize(0); + assertThat(connectionBrokenConsumers.get(0).isDone()).isTrue(); + assertThat(connectionBrokenConsumers.get(1).isDone()).isTrue(); + assertThat(tieredStoreResultSubpartitionView.isReleased()).isTrue(); + } + + @Test + void testGetNumberOfQueuedBuffers() { + assertThat(tieredStoreResultSubpartitionView.getNumberOfQueuedBuffers()).isEqualTo(3); + assertThat(tieredStoreResultSubpartitionView.unsynchronizedGetNumberOfQueuedBuffers()) + .isEqualTo(3); + } + + private void checkBufferAndBacklog(BufferAndBacklog bufferAndBacklog, int backlog) { + assertThat(bufferAndBacklog).isNotNull(); + assertThat(bufferAndBacklog.buffer()).isNotNull(); + assertThat(bufferAndBacklog.buffersInBacklog()).isEqualTo(backlog); + } + + private BufferAvailabilityListener createBufferAvailabilityListener( + CompletableFuture<Void> notifier) { + return () -> notifier.complete(null); + } + + private List<Queue<NettyPayload>> 
createNettyPayloadQueues() { + List<Queue<NettyPayload>> nettyPayloadQueues = new ArrayList<>(); + for (int index = 0; index < TIER_NUMBER; ++index) { + Queue<NettyPayload> queue = new ArrayDeque<>(); + queue.add(NettyPayload.newSegment(index)); + queue.add( + NettyPayload.newBuffer( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(0), + BufferRecycler.DummyBufferRecycler.INSTANCE), + 0, + index)); + queue.add( + NettyPayload.newBuffer( + new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(0), + BufferRecycler.DummyBufferRecycler.INSTANCE, + END_OF_SEGMENT), + 1, + index)); + nettyPayloadQueues.add(queue); + } + return nettyPayloadQueues; + } + + private List<Queue<NettyPayload>> createNettyPayloadQueuesWithError(Throwable error) { Review Comment: ```suggestion private static List<Queue<NettyPayload>> createNettyPayloadQueuesWithError(Throwable error) { ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
