This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit d11940c4a78c71548b5a06af50da2e5f9cb68918
Author: Weijie Guo <[email protected]>
AuthorDate: Mon Sep 26 16:46:34 2022 +0800

    [FLINK-28889] HsResultPartition support broadcast optimize
    
    This closes #21122
---
 .../ResultPartitionDeploymentDescriptor.java       |  5 ++
 .../network/partition/ResultPartitionFactory.java  | 12 +++
 .../partition/hybrid/HsResultPartition.java        | 22 +++++-
 .../network/partition/ResultPartitionBuilder.java  |  8 ++
 .../partition/ResultPartitionFactoryTest.java      | 18 +++++
 .../partition/hybrid/HsResultPartitionTest.java    | 86 ++++++++++++++++++++++
 .../shuffle/PartitionDescriptorBuilder.java        |  9 ++-
 7 files changed, 155 insertions(+), 5 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/ResultPartitionDeploymentDescriptor.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/ResultPartitionDeploymentDescriptor.java
index 3bf2c802817..f510bb08acb 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/ResultPartitionDeploymentDescriptor.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/ResultPartitionDeploymentDescriptor.java
@@ -63,6 +63,11 @@ public class ResultPartitionDeploymentDescriptor implements 
Serializable {
         return partitionDescriptor.getPartitionId();
     }
 
+    /** Whether the resultPartition is a broadcast edge. */
+    public boolean isBroadcast() {
+        return partitionDescriptor.isBroadcast();
+    }
+
     public ResultPartitionType getPartitionType() {
         return partitionDescriptor.getPartitionType();
     }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java
index a95f5e76770..7858f4c6709 100755
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactory.java
@@ -126,6 +126,7 @@ public class ResultPartitionFactory {
                 desc.getPartitionType(),
                 desc.getNumberOfSubpartitions(),
                 desc.getMaxParallelism(),
+                desc.isBroadcast(),
                 createBufferPoolFactory(desc.getNumberOfSubpartitions(), 
desc.getPartitionType()));
     }
 
@@ -137,6 +138,7 @@ public class ResultPartitionFactory {
             ResultPartitionType type,
             int numberOfSubpartitions,
             int maxParallelism,
+            boolean isBroadcast,
             SupplierWithException<BufferPool, IOException> bufferPoolFactory) {
         BufferCompressor bufferCompressor = null;
         if (type.supportCompression() && batchShuffleCompressionEnabled) {
@@ -216,6 +218,15 @@ public class ResultPartitionFactory {
             }
         } else if (type == ResultPartitionType.HYBRID_FULL
                 || type == ResultPartitionType.HYBRID_SELECTIVE) {
+            if (type == ResultPartitionType.HYBRID_SELECTIVE && isBroadcast) {
+                // For a broadcast result partition, it can be optimized to
+                // always use the full spilling strategy to significantly
+                // reduce shuffle data writing cost.
+                LOG.info(
+                        "{} result partition has been replaced by {} result partition to reduce shuffle data writing cost.",
+                        type,
+                        ResultPartitionType.HYBRID_FULL);
+                type = ResultPartitionType.HYBRID_FULL;
+            }
             partition =
                     new HsResultPartition(
                             taskNameWithSubtaskAndId,
@@ -240,6 +251,7 @@ public class ResultPartitionFactory {
                                                             
.SpillingStrategyType.SELECTIVE)
                                     .build(),
                             bufferCompressor,
+                            isBroadcast,
                             bufferPoolFactory);
         } else {
             throw new IllegalArgumentException("Unrecognized 
ResultPartitionType: " + type);
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
index 26aee9450ad..d4f2f632aff 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
@@ -60,6 +60,8 @@ import static org.apache.flink.util.Preconditions.checkState;
 public class HsResultPartition extends ResultPartition {
     public static final String DATA_FILE_SUFFIX = ".hybrid.data";
 
+    public static final int BROADCAST_CHANNEL = 0;
+
     private final HsFileDataIndex dataIndex;
 
     private final HsFileDataManager fileDataManager;
@@ -77,6 +79,9 @@ public class HsResultPartition extends ResultPartition {
 
     @Nullable private HsMemoryDataManager memoryDataManager;
 
+    /** Whether this result partition broadcasts all data and events. */
+    private final boolean isBroadcastOnly;
+
     public HsResultPartition(
             String owningTaskName,
             int partitionIndex,
@@ -91,6 +96,7 @@ public class HsResultPartition extends ResultPartition {
             int networkBufferSize,
             HybridShuffleConfiguration hybridShuffleConfiguration,
             @Nullable BufferCompressor bufferCompressor,
+            boolean isBroadcastOnly,
             SupplierWithException<BufferPool, IOException> bufferPoolFactory) {
         super(
                 owningTaskName,
@@ -103,9 +109,10 @@ public class HsResultPartition extends ResultPartition {
                 bufferCompressor,
                 bufferPoolFactory);
         this.networkBufferSize = networkBufferSize;
-        this.dataIndex = new HsFileDataIndexImpl(numSubpartitions);
+        this.dataIndex = new HsFileDataIndexImpl(isBroadcastOnly ? 1 : 
numSubpartitions);
         this.dataFilePath = new File(dataFileBashPath + 
DATA_FILE_SUFFIX).toPath();
         this.hybridShuffleConfiguration = hybridShuffleConfiguration;
+        this.isBroadcastOnly = isBroadcastOnly;
         this.fileDataManager =
                 new HsFileDataManager(
                         readBufferPool,
@@ -126,7 +133,7 @@ public class HsResultPartition extends ResultPartition {
         this.fileDataManager.setup();
         this.memoryDataManager =
                 new HsMemoryDataManager(
-                        numSubpartitions,
+                        isBroadcastOnly ? 1 : numSubpartitions,
                         networkBufferSize,
                         bufferPool,
                         getSpillingStrategy(hybridShuffleConfiguration),
@@ -167,8 +174,12 @@ public class HsResultPartition extends ResultPartition {
 
     private void broadcast(ByteBuffer record, Buffer.DataType dataType) throws 
IOException {
         numBytesProduced.inc(record.remaining());
-        for (int i = 0; i < numSubpartitions; i++) {
-            emit(record.duplicate(), i, dataType);
+        if (isBroadcastOnly) {
+            emit(record, BROADCAST_CHANNEL, dataType);
+        } else {
+            for (int i = 0; i < numSubpartitions; i++) {
+                emit(record.duplicate(), i, dataType);
+            }
         }
     }
 
@@ -190,6 +201,9 @@ public class HsResultPartition extends ResultPartition {
         if (!Files.isReadable(dataFilePath)) {
             throw new PartitionNotFoundException(getPartitionId());
         }
+        // If broadcast optimization is enabled, map every subpartitionId to
+        // the special broadcast channel.
+        subpartitionId = isBroadcastOnly ? BROADCAST_CHANNEL : subpartitionId;
 
         HsSubpartitionConsumer subpartitionConsumer =
                 new HsSubpartitionConsumer(availabilityListener);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java
index 5ef48002b03..7ad6cd08b84 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionBuilder.java
@@ -47,6 +47,8 @@ public class ResultPartitionBuilder {
 
     private int numTargetKeyGroups = 1;
 
+    private boolean isBroadcast = false;
+
     private ResultPartitionManager partitionManager = new 
ResultPartitionManager();
 
     private FileChannelManager channelManager = 
NoOpFileChannelManager.INSTANCE;
@@ -211,6 +213,11 @@ public class ResultPartitionBuilder {
         return this;
     }
 
+    public ResultPartitionBuilder setBroadcast(boolean broadcast) {
+        isBroadcast = broadcast;
+        return this;
+    }
+
     public ResultPartition build() {
         ResultPartitionFactory resultPartitionFactory =
                 new ResultPartitionFactory(
@@ -244,6 +251,7 @@ public class ResultPartitionBuilder {
                 partitionType,
                 numberOfSubpartitions,
                 numTargetKeyGroups,
+                isBroadcast,
                 factory);
     }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactoryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactoryTest.java
index 62420fabdcd..79a0aa89463 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactoryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionFactoryTest.java
@@ -132,12 +132,29 @@ class ResultPartitionFactoryTest {
         assertThat(resultPartition.isReleased()).isFalse();
     }
 
+    @Test
+    public void testHybridBroadcastEdgeAlwaysUseFullResultPartition() {
+        ResultPartition resultPartition =
+                createResultPartition(
+                        ResultPartitionType.HYBRID_SELECTIVE, 
Integer.MAX_VALUE, true);
+        
assertThat(resultPartition.partitionType).isEqualTo(ResultPartitionType.HYBRID_FULL);
+
+        resultPartition =
+                createResultPartition(ResultPartitionType.HYBRID_FULL, 
Integer.MAX_VALUE, true);
+        
assertThat(resultPartition.partitionType).isEqualTo(ResultPartitionType.HYBRID_FULL);
+    }
+
     private static ResultPartition createResultPartition(ResultPartitionType 
partitionType) {
         return createResultPartition(partitionType, Integer.MAX_VALUE);
     }
 
     private static ResultPartition createResultPartition(
             ResultPartitionType partitionType, int sortShuffleMinParallelism) {
+        return createResultPartition(partitionType, sortShuffleMinParallelism, 
false);
+    }
+
+    private static ResultPartition createResultPartition(
+            ResultPartitionType partitionType, int sortShuffleMinParallelism, 
boolean isBroadcast) {
         final ResultPartitionManager manager = new ResultPartitionManager();
 
         final ResultPartitionFactory factory =
@@ -163,6 +180,7 @@ class ResultPartitionFactoryTest {
                 new ResultPartitionDeploymentDescriptor(
                         PartitionDescriptorBuilder.newBuilder()
                                 .setPartitionType(partitionType)
+                                .setIsBroadcast(isBroadcast)
                                 .build(),
                         
NettyShuffleDescriptorBuilder.newBuilder().buildLocal(),
                         1);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
index 296834db22b..4d5fb55c203 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
@@ -279,6 +279,60 @@ class HsResultPartitionTest {
         }
     }
 
+    @Test
+    void testBroadcastResultPartition() throws Exception {
+        final int numBuffers = 10;
+        final int numRecords = 10;
+        final int numConsumers = 2;
+        final Random random = new Random();
+
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, 
numBuffers);
+        try (HsResultPartition resultPartition = createHsResultPartition(2, 
bufferPool, true)) {
+            List<ByteBuffer> dataWritten = new ArrayList<>();
+            for (int i = 0; i < numRecords; i++) {
+                ByteBuffer record = generateRandomData(bufferSize, random);
+                resultPartition.broadcastRecord(record);
+                dataWritten.add(record);
+            }
+            resultPartition.finish();
+
+            Tuple2[] viewAndListeners = 
createSubpartitionViews(resultPartition, 2);
+
+            List<List<Buffer>> dataRead = new ArrayList<>();
+            for (int i = 0; i < numConsumers; i++) {
+                dataRead.add(new ArrayList<>());
+            }
+            readData(
+                    viewAndListeners,
+                    (buffer, subpartition) -> {
+                        int numBytes = buffer.readableBytes();
+                        if (buffer.isBuffer()) {
+                            MemorySegment segment =
+                                    
MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+                            segment.put(0, buffer.getNioBufferReadable(), 
numBytes);
+                            dataRead.get(subpartition)
+                                    .add(
+                                            new NetworkBuffer(
+                                                    segment,
+                                                    (buf) -> {},
+                                                    buffer.getDataType(),
+                                                    numBytes));
+                        }
+                    });
+
+            for (int i = 0; i < numConsumers; i++) {
+                assertThat(dataWritten).hasSameSizeAs(dataRead.get(i));
+                List<Buffer> readBufferList = dataRead.get(i);
+                for (int j = 0; j < dataWritten.size(); j++) {
+                    ByteBuffer bufferWritten = dataWritten.get(j);
+                    bufferWritten.rewind();
+                    Buffer bufferRead = readBufferList.get(j);
+                    
assertThat(bufferRead.getNioBufferReadable()).isEqualTo(bufferWritten);
+                }
+            }
+        }
+    }
+
     @Test
     void testClose() throws Exception {
         final int numBuffers = 1;
@@ -427,6 +481,20 @@ class HsResultPartitionTest {
         }
     }
 
+    @Test
+    void testMetricsUpdateForBroadcastOnlyResultPartition() throws Exception {
+        BufferPool bufferPool = globalPool.createBufferPool(3, 3);
+        try (HsResultPartition partition = createHsResultPartition(2, 
bufferPool, true)) {
+            partition.broadcastRecord(ByteBuffer.allocate(bufferSize));
+            
assertThat(taskIOMetricGroup.getNumBuffersOutCounter().getCount()).isEqualTo(1);
+            
assertThat(taskIOMetricGroup.getNumBytesOutCounter().getCount()).isEqualTo(bufferSize);
+            IOMetrics ioMetrics = taskIOMetricGroup.createSnapshot();
+            assertThat(ioMetrics.getNumBytesProducedOfPartitions())
+                    .hasSize(1)
+                    .containsValue((long) bufferSize);
+        }
+    }
+
     private static void recordDataWritten(
             ByteBuffer record,
             Queue<Tuple2<ByteBuffer, Buffer.DataType>>[] dataWritten,
@@ -493,9 +561,25 @@ class HsResultPartitionTest {
 
     private HsResultPartition createHsResultPartition(int numSubpartitions, 
BufferPool bufferPool)
             throws IOException {
+        return createHsResultPartition(numSubpartitions, bufferPool, false);
+    }
+
+    private HsResultPartition createHsResultPartition(
+            int numSubpartitions,
+            BufferPool bufferPool,
+            HybridShuffleConfiguration hybridShuffleConfiguration)
+            throws IOException {
+        return createHsResultPartition(
+                numSubpartitions, bufferPool, false, 
hybridShuffleConfiguration);
+    }
+
+    private HsResultPartition createHsResultPartition(
+            int numSubpartitions, BufferPool bufferPool, boolean 
isBroadcastOnly)
+            throws IOException {
         return createHsResultPartition(
                 numSubpartitions,
                 bufferPool,
+                isBroadcastOnly,
                 HybridShuffleConfiguration.builder(
                                 numSubpartitions, 
readBufferPool.getNumBuffersPerRequest())
                         .build());
@@ -504,6 +588,7 @@ class HsResultPartitionTest {
     private HsResultPartition createHsResultPartition(
             int numSubpartitions,
             BufferPool bufferPool,
+            boolean isBroadcastOnly,
             HybridShuffleConfiguration hybridShuffleConfiguration)
             throws IOException {
         HsResultPartition hsResultPartition =
@@ -521,6 +606,7 @@ class HsResultPartitionTest {
                         bufferSize,
                         hybridShuffleConfiguration,
                         null,
+                        isBroadcastOnly,
                         () -> bufferPool);
         taskIOMetricGroup =
                 
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup();
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/PartitionDescriptorBuilder.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/PartitionDescriptorBuilder.java
index de85cd10adc..d2305eb1b31 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/PartitionDescriptorBuilder.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/PartitionDescriptorBuilder.java
@@ -28,6 +28,8 @@ public class PartitionDescriptorBuilder {
     private ResultPartitionType partitionType;
     private int totalNumberOfPartitions = 1;
 
+    private boolean isBroadcast = false;
+
     private PartitionDescriptorBuilder() {
         this.partitionId = new IntermediateResultPartitionID();
         this.partitionType = ResultPartitionType.PIPELINED;
@@ -48,6 +50,11 @@ public class PartitionDescriptorBuilder {
         return this;
     }
 
+    public PartitionDescriptorBuilder setIsBroadcast(boolean isBroadcast) {
+        this.isBroadcast = isBroadcast;
+        return this;
+    }
+
     public PartitionDescriptor build() {
         return new PartitionDescriptor(
                 new IntermediateDataSetID(),
@@ -56,7 +63,7 @@ public class PartitionDescriptorBuilder {
                 partitionType,
                 1,
                 0,
-                false,
+                isBroadcast,
                 true);
     }
 

Reply via email to