reswqa commented on code in PR #22652:
URL: https://github.com/apache/flink/pull/22652#discussion_r1212863225


##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClient.java:
##########
@@ -80,13 +107,23 @@ public void write(
 
         if (isBroadcast && !isBroadcastOnly) {
             for (int i = 0; i < numSubpartitions; ++i) {
-                bufferAccumulator.receive(record.duplicate(), subpartitionId, dataType);
+                // As the tiered storage subpartition ID is created only for broadcast records,
+                // which are fewer than normal records, the performance impact of generating new
+                // TieredStorageSubpartitionId objects is expected to be manageable. If the
+                // performance is significantly affected, this logic will be optimized accordingly.
+                bufferAccumulator.receive(
+                        record.duplicate(), new TieredStorageSubpartitionId(i), dataType);
             }
         } else {
             bufferAccumulator.receive(record, subpartitionId, dataType);
         }
     }
 
+    public void setMetricStatisticsUpdater(
+            Consumer<TieredStorageProducerMetricUpdate> metricStatisticsUpdater) {
+        this.metricStatisticsUpdater = metricStatisticsUpdater;

Review Comment:
   ```suggestion
         this.metricStatisticsUpdater = checkNotNull(metricStatisticsUpdater);
   ```
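   For context, the suggested `checkNotNull` makes a missing updater fail at registration time rather than later on the write path. A minimal self-contained sketch (the `checkNotNull` helper here stands in for Flink's `Preconditions.checkNotNull`, and the field type is simplified for illustration):

   ```java
   import java.util.function.Consumer;

   public class CheckNotNullDemo {
       // Stand-in for org.apache.flink.util.Preconditions.checkNotNull
       static <T> T checkNotNull(T reference) {
           if (reference == null) {
               throw new NullPointerException();
           }
           return reference;
       }

       // Simplified field type; the real field is Consumer<TieredStorageProducerMetricUpdate>
       private Consumer<Integer> metricStatisticsUpdater;

       public void setMetricStatisticsUpdater(Consumer<Integer> metricStatisticsUpdater) {
           // Fail fast here instead of with an NPE on the hot write path later
           this.metricStatisticsUpdater = checkNotNull(metricStatisticsUpdater);
       }

       public static void main(String[] args) {
           CheckNotNullDemo demo = new CheckNotNullDemo();
           demo.setMetricStatisticsUpdater(x -> {});
           try {
               demo.setMetricStatisticsUpdater(null);
               throw new AssertionError("expected NullPointerException");
           } catch (NullPointerException expected) {
               // rejected at registration time, as intended
           }
       }
   }
   ```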



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierProducerAgent.java:
##########
@@ -21,9 +21,15 @@
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
 
-import java.io.IOException;
-
-/** The producer-side agent of a Tier. */
+/**
+ * The producer-side agent of a Tier.
+ *
+ * <p>Note that when writing a buffer to a tier, the {@link TierProducerAgent} should first call
+ * {@code tryStartNewSegment} to start a new segment. The agent can then continue writing the buffer
+ * to the tier as long as the return value of {@code write} is true. If the return value of {@code
+ * write} is false, it indicates that the current segment can no longer store the buffer, and the
+ * agent should try to start a new segment before writing the buffer.
+ */
 public interface TierProducerAgent {

Review Comment:
   I'd prefer to let this extend `AutoCloseable`.
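   For reference, extending `AutoCloseable` would let callers manage an agent's lifecycle with try-with-resources. A minimal sketch with an illustrative stand-in interface (not Flink's actual `TierProducerAgent`):

   ```java
   public class AutoCloseableAgentDemo {
       // Illustrative stand-in; the real TierProducerAgent has a richer API.
       interface ProducerAgent extends AutoCloseable {
           boolean tryStartNewSegment(int subpartitionId, int segmentId);

           // Narrowed signature: close() declares no checked exception,
           // so callers need no try/catch around it.
           @Override
           void close();
       }

       static boolean closeCalled = false;

       static void useAgent() {
           // try-with-resources guarantees close() runs, even if the body throws
           try (ProducerAgent agent = new ProducerAgent() {
               @Override
               public boolean tryStartNewSegment(int subpartitionId, int segmentId) {
                   return true;
               }

               @Override
               public void close() {
                   closeCalled = true;
               }
           }) {
               agent.tryStartNewSegment(0, 0);
           }
       }

       public static void main(String[] args) {
           useAgent();
           if (!closeCalled) {
               throw new AssertionError("close() was not invoked");
           }
       }
   }
   ```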



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierProducerAgent.java:
##########
@@ -35,9 +41,21 @@ public interface TierProducerAgent {
      */
     boolean tryStartNewSegment(TieredStorageSubpartitionId subpartitionId, int segmentId);
 
-    /** Writes the finished {@link Buffer} to the consumer. */
-    boolean write(TieredStorageSubpartitionId subpartitionId, Buffer finishedBuffer)
-            throws IOException;
+    /**
+     * Writes the finished {@link Buffer} to the consumer.
+     *
+     * <p>Note that the tier must ensure that the buffer is written successfully without any
+     * exceptions, in order to guarantee that the buffer will be recycled. If this method throws an
+     * exception in the subsequent modifications, the caller should make sure that the buffer is
+     * recycled finally.
+     *
+     * @param subpartitionId the subpartition id that the buffer is writing to
+     * @param finishedBuffer the writing buffer
+     * @return return true if the buffer is written successfully, return false if the current
+     *     segment can not store this buffer and the current segment is finished. When returning
+     *     false, the agent should try start a new segment before writing the buffer.
+     */
+    boolean write(TieredStorageSubpartitionId subpartitionId, Buffer finishedBuffer);

Review Comment:
   How about renaming this to `tryWrite`?
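   For context, the caller-side protocol described in the javadoc above (whatever the method ends up being named) can be sketched as follows; `Agent` is a minimal stand-in, not Flink's interface:

   ```java
   import java.util.ArrayList;
   import java.util.List;

   public class SegmentWriteDemo {
       interface Agent {
           boolean tryStartNewSegment(int subpartition, int segmentId);

           boolean write(int subpartition, String buffer); // false => current segment is full
       }

       /** Writes all buffers, opening a new segment whenever write() returns false. */
       static int writeAll(Agent agent, int subpartition, List<String> buffers) {
           int segmentId = 0;
           agent.tryStartNewSegment(subpartition, segmentId);
           for (String buffer : buffers) {
               if (!agent.write(subpartition, buffer)) {
                   // Segment is finished: start a new one, then retry the same buffer.
                   segmentId++;
                   agent.tryStartNewSegment(subpartition, segmentId);
                   if (!agent.write(subpartition, buffer)) {
                       throw new IllegalStateException("first buffer of a new segment must fit");
                   }
               }
           }
           return segmentId + 1; // number of segments used
       }

       public static void main(String[] args) {
           // A toy agent whose segments hold at most 2 buffers: 5 buffers => 3 segments.
           Agent agent = new Agent() {
               int buffersInSegment;

               public boolean tryStartNewSegment(int sub, int seg) {
                   buffersInSegment = 0;
                   return true;
               }

               public boolean write(int sub, String b) {
                   if (buffersInSegment == 2) {
                       return false;
                   }
                   buffersInSegment++;
                   return true;
               }
           };
           List<String> buffers = new ArrayList<>();
           for (int i = 0; i < 5; i++) {
               buffers.add("buf" + i);
           }
           int segments = writeAll(agent, 0, buffers);
           if (segments != 3) {
               throw new AssertionError("expected 3 segments, got " + segments);
           }
       }
   }
   ```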



##########
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java:
##########
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link TieredStorageProducerClient}. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class TieredStorageProducerClientTest {
+
+    private static final int NUM_TOTAL_BUFFERS = 1000;
+
+    private static final int NETWORK_BUFFER_SIZE = 1024;
+
+    private static final float NUM_BUFFERS_TRIGGER_FLUSH_RATIO = 0.6f;
+
+    private static final int NUM_BUFFERS_IN_A_SEGMENT = 5;
+
+    @Parameter public boolean isBroadcast;
+
+    private NetworkBufferPool globalPool;
+
+    @Parameters(name = "isBroadcast={0}")
+    public static Collection<Boolean> parameters() {
+        return Arrays.asList(false, true);
+    }
+
+    @BeforeEach
+    void before() {
+        globalPool = new NetworkBufferPool(NUM_TOTAL_BUFFERS, NETWORK_BUFFER_SIZE);
+    }
+
+    @AfterEach
+    void after() {
+        globalPool.destroy();
+    }
+
+    @TestTemplate
+    void testWriteRecordsToEmptyStorageTiers() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions, false, bufferAccumulator, null, Collections.emptyList());
+
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Failed to choose a storage tier");
+    }
+
+    @TestTemplate
+    void testEmptyMetricUpdater() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);
+        TestingTierProducerAgent testingTierProducerAgent =
+                new TestingTierProducerAgent(false, numSubpartitions);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions,
+                        false,
+                        bufferAccumulator,
+                        null,
+                        Collections.singletonList(testingTierProducerAgent));
+
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(NullPointerException.class);
+    }
+
+    @TestTemplate
+    void testWriteRecordsToMultipleStorageTiers() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int numSegmentCapacityInLimitedTier = 12;
+        int numMaxToSendRecords = 200;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);
+        List<TierProducerAgent> tierProducerAgents = new ArrayList<>();
+        TestingTierProducerAgent limitedSegmentsTierProducerAgent =
+                new TestingTierProducerAgent(
+                        true, numSubpartitions, numSegmentCapacityInLimitedTier);
+        TestingTierProducerAgent unLimitedTierProducerAgent =
+                new TestingTierProducerAgent(false, numSubpartitions);
+        tierProducerAgents.add(limitedSegmentsTierProducerAgent);
+        tierProducerAgents.add(unLimitedTierProducerAgent);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions, false, bufferAccumulator, null, tierProducerAgents);
+
+        AtomicInteger numWriteBuffers = new AtomicInteger();
+        AtomicInteger numWriteBytes = new AtomicInteger();
+        tieredStorageProducerClient.setMetricStatisticsUpdater(
+                metricStatistics -> {
+                    numWriteBuffers.set(
+                            numWriteBuffers.get() + metricStatistics.numWriteBuffersDelta());
+                    numWriteBytes.set(numWriteBytes.get() + metricStatistics.numWriteBytesDelta());
+                });
+
+        int[] numSentRecords = new int[numSubpartitions];
+        for (int i = 0; i < numSubpartitions; i++) {
+            int numToSend = random.nextInt(numMaxToSendRecords - 1) + 1;
+            numSentRecords[i] = numToSend;
+            TieredStorageSubpartitionId subpartitionId = new TieredStorageSubpartitionId(i);
+            for (int j = 0; j < numToSend; j++) {
+                tieredStorageProducerClient.write(
+                        generateRandomData(bufferSize, random),
+                        subpartitionId,
+                        Buffer.DataType.DATA_BUFFER,
+                        isBroadcast);
+            }
+        }
+
+        checkReceivedBuffersAndSegments(
+                numSubpartitions,
+                bufferSize,
+                numSegmentCapacityInLimitedTier,
+                limitedSegmentsTierProducerAgent,
+                unLimitedTierProducerAgent,
+                numSentRecords,
+                numWriteBuffers,
+                numWriteBytes);
+
+        assertThat(limitedSegmentsTierProducerAgent.isClosed).isFalse();
+        assertThat(unLimitedTierProducerAgent.isClosed).isFalse();
+        tieredStorageProducerClient.close();
+        assertThat(limitedSegmentsTierProducerAgent.isClosed).isTrue();
+        assertThat(unLimitedTierProducerAgent.isClosed).isTrue();
+        storageMemoryManager.release();
+    }
+
+    private void checkReceivedBuffersAndSegments(
+            int numSubpartitions,
+            int bufferSize,
+            int numSegmentCapacityInLimitedTier,
+            TestingTierProducerAgent limitedSegmentsTierProducerAgent,
+            TestingTierProducerAgent unLimitedTierProducerAgent,
+            int[] numSentRecords,
+            AtomicInteger numWriteBuffers,
+            AtomicInteger numWriteBytes) {
+        int[] numExpectedBuffersInLimitedTier = new int[numSubpartitions];
+        int[] numExpectedBuffersInUnLimitedTier = new int[numSubpartitions];
+        int[] numExpectedSegmentInLimitedTier = new int[numSubpartitions];
+        int[] numExpectedSegmentInUnLimitedTier = new int[numSubpartitions];
+
+        calculateNumExpectedBuffersAndSegments(
+                numSubpartitions,
+                numSegmentCapacityInLimitedTier,
+                numSentRecords,
+                numExpectedBuffersInLimitedTier,
+                numExpectedBuffersInUnLimitedTier,
+                numExpectedSegmentInLimitedTier,
+                numExpectedSegmentInUnLimitedTier);
+
+        for (int i = 0; i < numSubpartitions; i++) {
+            assertThat(limitedSegmentsTierProducerAgent.numTotalReceivedBuffers[i])
+                    .isEqualTo(numExpectedBuffersInLimitedTier[i]);
+            assertThat(unLimitedTierProducerAgent.numTotalReceivedBuffers[i])
+                    .isEqualTo(numExpectedBuffersInUnLimitedTier[i]);
+            assertThat(limitedSegmentsTierProducerAgent.numTotalReceivedBuffers[i] > 0).isTrue();
+            assertThat(limitedSegmentsTierProducerAgent.numReceivedSegments[i])
+                    .isEqualTo(numExpectedSegmentInLimitedTier[i]);
+            assertThat(unLimitedTierProducerAgent.numReceivedSegments[i])
+                    .isEqualTo(numExpectedSegmentInUnLimitedTier[i]);
+        }
+
+        int numTotalRecords = Arrays.stream(numSentRecords).sum();
+        int numExpectedBuffers = isBroadcast ? numTotalRecords * numSubpartitions : numTotalRecords;
+        assertThat(numWriteBuffers.get()).isEqualTo(numExpectedBuffers);
+        assertThat(numWriteBytes.get()).isEqualTo(numExpectedBuffers * bufferSize);
+    }
+
+    private void calculateNumExpectedBuffersAndSegments(
+            int numSubpartitions,
+            int numSegmentCapacityInLimitedTier,
+            int[] numSentRecords,
+            int[] numExpectedBuffersInLimitedTier,
+            int[] numExpectedBuffersInUnLimitedTier,
+            int[] numExpectedSegmentInLimitedTier,
+            int[] numExpectedSegmentInUnLimitedTier) {
+        int numTotalRecords = Arrays.stream(numSentRecords).map(Integer::new).sum();
+        for (int i = 0; i < numSubpartitions; i++) {
+            int totalSentRecords = isBroadcast ? numTotalRecords : numSentRecords[i];
+            int maxBuffersInLimitedTier =
+                    numSegmentCapacityInLimitedTier * NUM_BUFFERS_IN_A_SEGMENT;
+
+            numExpectedBuffersInLimitedTier[i] =
+                    Math.min(totalSentRecords, maxBuffersInLimitedTier);
+            numExpectedSegmentInLimitedTier[i] =
+                    numExpectedBuffersInLimitedTier[i] / NUM_BUFFERS_IN_A_SEGMENT
+                            + (numExpectedBuffersInLimitedTier[i] % NUM_BUFFERS_IN_A_SEGMENT == 0
+                                    ? 0
+                                    : 1);
+            numExpectedBuffersInUnLimitedTier[i] =
+                    totalSentRecords > maxBuffersInLimitedTier
+                            ? totalSentRecords - maxBuffersInLimitedTier
+                            : 0;
+            numExpectedSegmentInUnLimitedTier[i] =
+                    numExpectedBuffersInUnLimitedTier[i] / NUM_BUFFERS_IN_A_SEGMENT
+                            + (numExpectedBuffersInUnLimitedTier[i] % NUM_BUFFERS_IN_A_SEGMENT == 0
+                                    ? 0
+                                    : 1);
+        }
+    }
+
+    private static class TestingTierProducerAgent implements TierProducerAgent {
+        private final boolean isCapacityLimited;
+        private final int numSegmentCapacity;
+        private final int[] numReceivedSegments;
+        private final int[] numTotalReceivedBuffers;
+        private final int[] numReceivedBuffersInCurrentSegment;
+        private boolean isClosed;
+
+        TestingTierProducerAgent(boolean isCapacityLimited, int numSubpartitions) {
+            this(isCapacityLimited, numSubpartitions, Integer.MAX_VALUE);
+        }
+
+        TestingTierProducerAgent(
+                boolean isCapacityLimited, int numSubpartitions, int numSegmentCapacity) {
+            this.isCapacityLimited = isCapacityLimited;
+            this.numSegmentCapacity = numSegmentCapacity;
+            this.numReceivedSegments = new int[numSubpartitions];
+            this.numTotalReceivedBuffers = new int[numSubpartitions];
+            this.numReceivedBuffersInCurrentSegment = new int[numSubpartitions];
+
+            Arrays.fill(numReceivedSegments, 0);
+            Arrays.fill(numTotalReceivedBuffers, 0);
+            Arrays.fill(numReceivedBuffersInCurrentSegment, 0);
+        }
+
+        @Override
+        public boolean tryStartNewSegment(
+                TieredStorageSubpartitionId subpartitionId, int segmentId) {
+            boolean canStartNewSegment =
+                    !isCapacityLimited
+                            || numReceivedSegments[subpartitionId.getSubpartitionId()]
+                                    < numSegmentCapacity;
+            if (canStartNewSegment) {
+                numReceivedSegments[subpartitionId.getSubpartitionId()]++;
+            }
+            return canStartNewSegment;
+        }
+
+        @Override
+        public boolean write(TieredStorageSubpartitionId subpartitionId, Buffer finishedBuffer) {
+            if (numReceivedBuffersInCurrentSegment[subpartitionId.getSubpartitionId()] + 1
+                    > NUM_BUFFERS_IN_A_SEGMENT) {
+                numReceivedBuffersInCurrentSegment[subpartitionId.getSubpartitionId()] = 0;
+                return false;
+            }
+            numReceivedBuffersInCurrentSegment[subpartitionId.getSubpartitionId()]++;
+            numTotalReceivedBuffers[subpartitionId.getSubpartitionId()]++;
+            finishedBuffer.recycleBuffer();
+            return true;
+        }
+
+        @Override
+        public void close() {
+            this.isClosed = true;
+        }
+    }
+
+    private TieredStorageMemoryManagerImpl createStorageMemoryManager(int numBuffersInBufferPool)
+            throws IOException {
+        BufferPool bufferPool =
+                globalPool.createBufferPool(numBuffersInBufferPool, numBuffersInBufferPool);
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                new TieredStorageMemoryManagerImpl(NUM_BUFFERS_TRIGGER_FLUSH_RATIO, true);
+        storageMemoryManager.setup(
+                bufferPool, Collections.singletonList(new TieredStorageMemorySpec(this, 1)));
+        return storageMemoryManager;
+    }
+
+    private static ByteBuffer generateRandomData(int dataSize, Random random) {
+        byte[] dataWritten = new byte[dataSize];
+        random.nextBytes(dataWritten);
+        return ByteBuffer.wrap(dataWritten);
+    }

Review Comment:
   These two methods seem exactly the same as those in `HashBufferAccumulatorTest`, so why not extract them all into a common utils class called `TieredStorageTestUtils`?
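   If extracted as suggested, the `generateRandomData` helper could look like the sketch below (the class name comes from the review; `createStorageMemoryManager` needs a buffer pool and Flink's memory-manager classes, so it is omitted here):

   ```java
   import java.nio.ByteBuffer;
   import java.util.Random;

   /** Hypothetical shared test utilities, per the suggestion above. */
   public final class TieredStorageTestUtils {

       private TieredStorageTestUtils() {}

       /** Returns a ByteBuffer filled with {@code dataSize} random bytes. */
       public static ByteBuffer generateRandomData(int dataSize, Random random) {
           byte[] dataWritten = new byte[dataSize];
           random.nextBytes(dataWritten);
           return ByteBuffer.wrap(dataWritten);
       }
   }
   ```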



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClient.java:
##########
@@ -100,26 +137,95 @@ public void close() {
      */
     private void writeAccumulatedBuffers(
             TieredStorageSubpartitionId subpartitionId, List<Buffer> accumulatedBuffers) {
-        try {
-            for (Buffer finishedBuffer : accumulatedBuffers) {
-                writeAccumulatedBuffer(subpartitionId, finishedBuffer);
+        Iterator<Buffer> bufferIterator = accumulatedBuffers.iterator();
+
+        int numWriteBytes = 0;
+        int numWriteBuffers = 0;
+        while (bufferIterator.hasNext()) {
+            Buffer buffer = bufferIterator.next();
+            try {
+                writeAccumulatedBuffer(subpartitionId, buffer);
+            } catch (IOException ioe) {
+                buffer.recycleBuffer();
+                while (bufferIterator.hasNext()) {
+                    bufferIterator.next().recycleBuffer();
+                }
+                ExceptionUtils.rethrow(ioe);
             }
-        } catch (IOException e) {
-            ExceptionUtils.rethrow(e);
+            numWriteBuffers++;
+            numWriteBytes += buffer.readableBytes();
         }
+        updateMetricStatistics(numWriteBuffers, numWriteBytes);
     }
 
     /**
      * Write the accumulated buffer of this subpartitionId to an appropriate tier. After the tier is
      * decided, the buffer will be written to the selected tier.
      *
+     * <p>Note that the method only throws an exception when choosing a storage tier, so the caller
+     * should ensure that the buffer is recycled when throwing an exception.
+     *
      * @param subpartitionId the subpartition identifier
      * @param accumulatedBuffer one accumulated buffer of this subpartition
      */
     private void writeAccumulatedBuffer(
             TieredStorageSubpartitionId subpartitionId, Buffer accumulatedBuffer)
             throws IOException {
-        // TODO, Try to write the accumulated buffer to the appropriate tier. After the tier is
-        // decided, then write the accumulated buffer to the tier.
+        Buffer compressedBuffer = compressBufferIfPossible(accumulatedBuffer);
+
+        if (currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()] == null) {
+            chooseStorageTierToStartSegment(subpartitionId);
+        }
+
+        boolean isSuccess =
+                currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()].write(
+                        subpartitionId, compressedBuffer);
+        if (!isSuccess) {
+            chooseStorageTierToStartSegment(subpartitionId);
+            isSuccess =
+                    currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()].write(
+                            subpartitionId, compressedBuffer);
+            checkState(isSuccess, "Failed to write the first buffer to the new segment");
+        }

Review Comment:
   ```suggestion
           if (!currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()].write(
                   subpartitionId, compressedBuffer)) {
               chooseStorageTierToStartSegment(subpartitionId);

               checkState(
                       currentSubpartitionTierAgent[subpartitionId.getSubpartitionId()].write(
                               subpartitionId, compressedBuffer),
                       "Failed to write the first buffer to the new segment");
           }
   ```



##########
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java:
##########
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link TieredStorageProducerClient}. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class TieredStorageProducerClientTest {
+
+    private static final int NUM_TOTAL_BUFFERS = 1000;
+
+    private static final int NETWORK_BUFFER_SIZE = 1024;
+
+    private static final float NUM_BUFFERS_TRIGGER_FLUSH_RATIO = 0.6f;
+
+    private static final int NUM_BUFFERS_IN_A_SEGMENT = 5;
+
+    @Parameter public boolean isBroadcast;
+
+    private NetworkBufferPool globalPool;
+
+    @Parameters(name = "isBroadcast={0}")
+    public static Collection<Boolean> parameters() {
+        return Arrays.asList(false, true);
+    }
+
+    @BeforeEach
+    void before() {
+        globalPool = new NetworkBufferPool(NUM_TOTAL_BUFFERS, NETWORK_BUFFER_SIZE);
+    }
+
+    @AfterEach
+    void after() {
+        globalPool.destroy();
+    }
+
+    @TestTemplate
+    void testWriteRecordsToEmptyStorageTiers() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions, false, bufferAccumulator, null, Collections.emptyList());
+
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Failed to choose a storage tier");
+    }
+
+    @TestTemplate
+    void testEmptyMetricUpdater() throws IOException {

Review Comment:
   TBH, I don't think this is a meaningful test case. If we don't set a value for an uninitialized field, it will be `null`, which doesn't seem to require special testing.



##########
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java:
##########
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link TieredStorageProducerClient}. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class TieredStorageProducerClientTest {
+
+    private static final int NUM_TOTAL_BUFFERS = 1000;
+
+    private static final int NETWORK_BUFFER_SIZE = 1024;
+
+    private static final float NUM_BUFFERS_TRIGGER_FLUSH_RATIO = 0.6f;
+
+    private static final int NUM_BUFFERS_IN_A_SEGMENT = 5;

Review Comment:
   ```suggestion
       private static final int NUM_BUFFERS_PER_SEGMENT = 5;
   ```



##########
flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/tier/TierProducerAgent.java:
##########
@@ -35,9 +41,21 @@ public interface TierProducerAgent {
      */
     boolean tryStartNewSegment(TieredStorageSubpartitionId subpartitionId, int segmentId);
 
-    /** Writes the finished {@link Buffer} to the consumer. */
-    boolean write(TieredStorageSubpartitionId subpartitionId, Buffer finishedBuffer)
-            throws IOException;
+    /**
+     * Writes the finished {@link Buffer} to the consumer.
+     *
+     * <p>Note that the tier must ensure that the buffer is written successfully without any
+     * exceptions, in order to guarantee that the buffer will be recycled. If this method throws an
+     * exception in the subsequent modifications, the caller should make sure that the buffer is
+     * recycled finally.

Review Comment:
   > If this method throws an exception in the subsequent modifications, the caller should make sure that the buffer is recycled finally.
   
   The readability of this paragraph is not good, and it caused me some confusion. IIUC, what we want to express is that if the method executes successfully (without throwing any exception), the buffer should be released by the caller; otherwise it is the responsibility of the tier.



##########
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java:
##########
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage;
+
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.common.TieredStorageSubpartitionId;
+import org.apache.flink.runtime.io.network.partition.hybrid.tiered.tier.TierProducerAgent;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
+import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
+import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link TieredStorageProducerClient}. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class TieredStorageProducerClientTest {
+
+    private static final int NUM_TOTAL_BUFFERS = 1000;
+
+    private static final int NETWORK_BUFFER_SIZE = 1024;
+
+    private static final float NUM_BUFFERS_TRIGGER_FLUSH_RATIO = 0.6f;
+
+    private static final int NUM_BUFFERS_IN_A_SEGMENT = 5;
+
+    @Parameter public boolean isBroadcast;
+
+    private NetworkBufferPool globalPool;
+
+    @Parameters(name = "isBroadcast={0}")
+    public static Collection<Boolean> parameters() {
+        return Arrays.asList(false, true);
+    }
+
+    @BeforeEach
+    void before() {
+        globalPool = new NetworkBufferPool(NUM_TOTAL_BUFFERS, NETWORK_BUFFER_SIZE);
+    }
+
+    @AfterEach
+    void after() {
+        globalPool.destroy();
+    }
+
+    @TestTemplate
+    void testWriteRecordsToEmptyStorageTiers() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions, false, bufferAccumulator, null, Collections.emptyList());
+
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Failed to choose a storage tier");

Review Comment:
   It is difficult to see the connection between this call (i.e. `tieredStorageProducerClient.write`) and the expected exception; I had to debug the code to understand the logic. This test also relies too heavily on the implementation of `BufferAccumulator`: if we change its code in the future, it is likely to break this test.
   
   I'd suggest using a mock `TestingBufferAccumulator` to ensure that this call flushes a finished buffer.
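   A minimal sketch of what such a mock could look like — `TestingBufferAccumulator` is a hypothetical stand-in with heavily simplified types, not an existing Flink test utility:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;

class TestingAccumulatorDemo {

    /** Simplified stand-in for the accumulator interface. */
    interface BufferAccumulator {
        void receive(String record, int subpartitionId);
    }

    /** Hypothetical mock: flushes every record straight through as a "finished buffer". */
    static class TestingBufferAccumulator implements BufferAccumulator {
        private final BiConsumer<Integer, String> flusher;

        TestingBufferAccumulator(BiConsumer<Integer, String> flusher) {
            this.flusher = flusher;
        }

        @Override
        public void receive(String record, int subpartitionId) {
            // No accumulation logic to depend on: the flush happens immediately,
            // so the test sees the write -> flush connection directly.
            flusher.accept(subpartitionId, record);
        }
    }

    /** Drives the mock once and returns what it flushed. */
    static List<String> flushedRecords() {
        List<String> flushed = new ArrayList<>();
        BufferAccumulator accumulator =
                new TestingBufferAccumulator((sub, rec) -> flushed.add(sub + ":" + rec));
        accumulator.receive("payload", 0);
        return flushed;
    }
}
```

   With a mock like this, the test asserts on what the accumulator was asked to flush, instead of depending on `HashBufferAccumulator` internals.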



##########
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java:
##########
@@ -0,0 +1,352 @@
+/** Tests for {@link TieredStorageProducerClient}. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class TieredStorageProducerClientTest {
+
+    private static final int NUM_TOTAL_BUFFERS = 1000;
+
+    private static final int NETWORK_BUFFER_SIZE = 1024;
+
+    private static final float NUM_BUFFERS_TRIGGER_FLUSH_RATIO = 0.6f;
+
+    private static final int NUM_BUFFERS_IN_A_SEGMENT = 5;
+
+    @Parameter public boolean isBroadcast;
+
+    private NetworkBufferPool globalPool;
+
+    @Parameters(name = "isBroadcast={0}")
+    public static Collection<Boolean> parameters() {
+        return Arrays.asList(false, true);
+    }
+
+    @BeforeEach
+    void before() {
+        globalPool = new NetworkBufferPool(NUM_TOTAL_BUFFERS, NETWORK_BUFFER_SIZE);
+    }
+
+    @AfterEach
+    void after() {
+        globalPool.destroy();
+    }
+
+    @TestTemplate
+    void testWriteRecordsToEmptyStorageTiers() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);

Review Comment:
   These two fields can be promoted to class members and initialized in `before()`.



##########
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/tiered/storage/TieredStorageProducerClientTest.java:
##########
@@ -0,0 +1,352 @@
+/** Tests for {@link TieredStorageProducerClient}. */
+@ExtendWith(ParameterizedTestExtension.class)
+public class TieredStorageProducerClientTest {
+
+    private static final int NUM_TOTAL_BUFFERS = 1000;
+
+    private static final int NETWORK_BUFFER_SIZE = 1024;
+
+    private static final float NUM_BUFFERS_TRIGGER_FLUSH_RATIO = 0.6f;
+
+    private static final int NUM_BUFFERS_IN_A_SEGMENT = 5;
+
+    @Parameter public boolean isBroadcast;
+
+    private NetworkBufferPool globalPool;
+
+    @Parameters(name = "isBroadcast={0}")
+    public static Collection<Boolean> parameters() {
+        return Arrays.asList(false, true);
+    }
+
+    @BeforeEach
+    void before() {
+        globalPool = new NetworkBufferPool(NUM_TOTAL_BUFFERS, NETWORK_BUFFER_SIZE);
+    }
+
+    @AfterEach
+    void after() {
+        globalPool.destroy();
+    }
+
+    @TestTemplate
+    void testWriteRecordsToEmptyStorageTiers() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions, false, bufferAccumulator, null, Collections.emptyList());
+
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(RuntimeException.class)
+                .hasMessageContaining("Failed to choose a storage tier");
+    }
+
+    @TestTemplate
+    void testEmptyMetricUpdater() throws IOException {
+        int numSubpartitions = 10;
+        int numBuffersInPool = 10;
+        int bufferSize = 1024;
+        Random random = new Random();
+
+        TieredStorageMemoryManagerImpl storageMemoryManager =
+                createStorageMemoryManager(numBuffersInPool);
+        BufferAccumulator bufferAccumulator =
+                new HashBufferAccumulator(numSubpartitions, bufferSize, storageMemoryManager);
+        TestingTierProducerAgent testingTierProducerAgent =
+                new TestingTierProducerAgent(false, numSubpartitions);
+
+        TieredStorageProducerClient tieredStorageProducerClient =
+                new TieredStorageProducerClient(
+                        numSubpartitions,
+                        false,
+                        bufferAccumulator,
+                        null,
+                        Collections.singletonList(testingTierProducerAgent));
+
+        assertThatThrownBy(
+                        () ->
+                                tieredStorageProducerClient.write(
+                                        generateRandomData(bufferSize, random),
+                                        new TieredStorageSubpartitionId(0),
+                                        Buffer.DataType.DATA_BUFFER,
+                                        isBroadcast))
+                .isInstanceOf(NullPointerException.class);
+    }
+
+    @TestTemplate
+    void testWriteRecordsToMultipleStorageTiers() throws IOException {

Review Comment:
   This test is too complex, and some of the test logic does not belong to `TieredStorageProducerClient`. What we need to test is only the logic that belongs to `TieredStorageProducerClient`; other concerns, such as whether the buffer is actually written to the tier, are not its responsibility.
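   One way to narrow the test, sketched with simplified stand-in types (the real `TieredStorageProducerClient`/`TierProducerAgent` signatures differ): a stub tier agent counts the writes it receives, so the test exercises only the client's tier-selection logic.

```java
import java.util.concurrent.atomic.AtomicInteger;

class FocusedRoutingDemo {

    /** Simplified stand-in for a tier agent; returns whether it accepted the buffer. */
    interface TierAgent {
        boolean tryWrite(int subpartitionId, String buffer);
    }

    /** Simplified stand-in client: forwards each buffer to the first accepting tier. */
    static class ProducerClient {
        private final TierAgent[] tiers;

        ProducerClient(TierAgent... tiers) {
            this.tiers = tiers;
        }

        void write(int subpartitionId, String buffer) {
            for (TierAgent tier : tiers) {
                if (tier.tryWrite(subpartitionId, buffer)) {
                    return;
                }
            }
            throw new RuntimeException("Failed to choose a storage tier");
        }
    }

    /** The stub, not a real tier, observes the writes; the test checks only routing. */
    static int writesSeenByStub() {
        AtomicInteger writes = new AtomicInteger();
        TierAgent rejecting = (sub, buf) -> false; // always declines the buffer
        TierAgent counting = (sub, buf) -> {       // accepts the buffer and counts it
            writes.incrementAndGet();
            return true;
        };
        ProducerClient client = new ProducerClient(rejecting, counting);
        client.write(0, "a");
        client.write(1, "b");
        return writes.get();
    }
}
```

   Asserting on the stub's counter keeps the test independent of any real tier's storage behavior.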



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

