This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit fed2962d3a25179dd0fac508cf3eb29ed7b533e1
Author: Weijie Guo <[email protected]>
AuthorDate: Thu Jul 28 22:34:06 2022 +0800

    [FLINK-27908] Introduce HsResultPartition
---
 .../partition/hybrid/HsResultPartition.java        | 266 ++++++++++++
 .../partition/hybrid/HsResultPartitionTest.java    | 476 +++++++++++++++++++++
 2 files changed, 742 insertions(+)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
new file mode 100644
index 00000000000..100df9f7cff
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartition.java
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
+import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.util.function.SupplierWithException;
+
+import javax.annotation.Nullable;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.file.Path;
+import java.util.concurrent.Executor;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * {@link HsResultPartition} appends records and events to the {@link HsMemoryDataManager}. The
+ * shuffle data may be spilled to disk according to the {@link HsSpillingStrategy}, and the
+ * downstream can consume data from either memory or disk.
+ */
+public class HsResultPartition extends ResultPartition {
+    public static final String DATA_FILE_SUFFIX = ".hybrid.data";
+
+    private final HsFileDataIndex dataIndex;
+
+    private final HsFileDataManager fileDataManager;
+
+    private final Path dataFilePath;
+
+    private final int networkBufferSize;
+
+    private final HybridShuffleConfiguration hybridShuffleConfiguration;
+
+    private boolean hasNotifiedEndOfUserRecords;
+
+    @Nullable private HsMemoryDataManager memoryDataManager;
+
+    public HsResultPartition(
+            String owningTaskName,
+            int partitionIndex,
+            ResultPartitionID partitionId,
+            ResultPartitionType partitionType,
+            int numSubpartitions,
+            int numTargetKeyGroups,
+            BatchShuffleReadBufferPool readBufferPool,
+            Executor readIOExecutor,
+            ResultPartitionManager partitionManager,
+            String dataFileBasePath,
+            int networkBufferSize,
+            HybridShuffleConfiguration hybridShuffleConfiguration,
+            @Nullable BufferCompressor bufferCompressor,
+            SupplierWithException<BufferPool, IOException> bufferPoolFactory) {
+        super(
+                owningTaskName,
+                partitionIndex,
+                partitionId,
+                partitionType,
+                numSubpartitions,
+                numTargetKeyGroups,
+                partitionManager,
+                bufferCompressor,
+                bufferPoolFactory);
+        this.networkBufferSize = networkBufferSize;
+        this.dataIndex = new HsFileDataIndexImpl(numSubpartitions);
+        this.dataFilePath = new File(dataFileBasePath + DATA_FILE_SUFFIX).toPath();
+        this.hybridShuffleConfiguration = hybridShuffleConfiguration;
+        this.fileDataManager =
+                new HsFileDataManager(
+                        readBufferPool,
+                        readIOExecutor,
+                        dataIndex,
+                        dataFilePath,
+                        HsSubpartitionFileReaderImpl.Factory.INSTANCE,
+                        hybridShuffleConfiguration);
+    }
+
+    // Called by task thread.
+    @Override
+    protected void setupInternal() throws IOException {
+        if (isReleased()) {
+            throw new IOException("Result partition has been released.");
+        }
+        this.fileDataManager.setup();
+        this.memoryDataManager =
+                new HsMemoryDataManager(
+                        numSubpartitions,
+                        networkBufferSize,
+                        bufferPool,
+                        getSpillingStrategy(hybridShuffleConfiguration),
+                        dataIndex,
+                        dataFilePath);
+    }
+
+    @Override
+    public void emitRecord(ByteBuffer record, int targetSubpartition) throws IOException {
+        emit(record, targetSubpartition, Buffer.DataType.DATA_BUFFER);
+    }
+
+    @Override
+    public void broadcastRecord(ByteBuffer record) throws IOException {
+        broadcast(record, Buffer.DataType.DATA_BUFFER);
+    }
+
+    @Override
+    public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {
+        Buffer buffer = EventSerializer.toBuffer(event, isPriorityEvent);
+        try {
+            ByteBuffer serializedEvent = buffer.getNioBufferReadable();
+            broadcast(serializedEvent, buffer.getDataType());
+        } finally {
+            buffer.recycleBuffer();
+        }
+    }
+
+    private void broadcast(ByteBuffer record, Buffer.DataType dataType) throws IOException {
+        for (int i = 0; i < numSubpartitions; i++) {
+            emit(record.duplicate(), i, dataType);
+        }
+    }
+
+    private void emit(ByteBuffer record, int targetSubpartition, Buffer.DataType dataType)
+            throws IOException {
+        checkInProduceState();
+        checkNotNull(memoryDataManager).append(record, targetSubpartition, dataType);
+    }
+
+    @Override
+    public ResultSubpartitionView createSubpartitionView(
+            int subpartitionId, BufferAvailabilityListener availabilityListener)
+            throws IOException {
+        checkState(!isReleased(), "ResultPartition already released.");
+        HsSubpartitionView subpartitionView = new HsSubpartitionView(availabilityListener);
+        HsDataView diskDataView =
+                fileDataManager.registerNewSubpartition(subpartitionId, subpartitionView);
+
+        HsDataView memoryDataView =
+                checkNotNull(memoryDataManager)
+                        .registerSubpartitionView(subpartitionId, subpartitionView);
+
+        subpartitionView.setDiskDataView(diskDataView);
+        subpartitionView.setMemoryDataView(memoryDataView);
+        return subpartitionView;
+    }
+
+    @Override
+    public void alignedBarrierTimeout(long checkpointId) throws IOException {
+        // Nothing to do.
+    }
+
+    @Override
+    public void abortCheckpoint(long checkpointId, CheckpointException cause) {
+        // Nothing to do.
+    }
+
+    @Override
+    public void flushAll() {
+        // Nothing to do.
+    }
+
+    @Override
+    public void flush(int subpartitionIndex) {
+        // Nothing to do.
+    }
+
+    @Override
+    public void finish() throws IOException {
+        broadcastEvent(EndOfPartitionEvent.INSTANCE, false);
+
+        checkState(!isReleased(), "Result partition is already released.");
+
+        super.finish();
+    }
+
+    @Override
+    public void close() {
+        // Close is called when the task has finished or failed.
+        checkNotNull(memoryDataManager).close();
+        super.close();
+    }
+
+    @Override
+    protected void releaseInternal() {
+        // Release is triggered by the scheduler and happens later than close.
+        // Main work:
+        // 1. Release the read scheduler.
+        // 2. Delete the shuffle file.
+        // 3. Release all data held in memory.
+
+        fileDataManager.release();
+
+        checkNotNull(memoryDataManager).release();
+    }
+
+    @Override
+    public int getNumberOfQueuedBuffers() {
+        // Batch shuffle does not need to provide QueuedBuffers information
+        return 0;
+    }
+
+    @Override
+    public long getSizeOfQueuedBuffersUnsafe() {
+        // Batch shuffle does not need to provide QueuedBuffers information
+        return 0;
+    }
+
+    @Override
+    public int getNumberOfQueuedBuffers(int targetSubpartition) {
+        // Batch shuffle does not need to provide QueuedBuffers information
+        return 0;
+    }
+
+    @Override
+    public void notifyEndOfData(StopMode mode) throws IOException {
+        if (!hasNotifiedEndOfUserRecords) {
+            broadcastEvent(new EndOfData(mode), false);
+            hasNotifiedEndOfUserRecords = true;
+        }
+    }
+
+    private HsSpillingStrategy getSpillingStrategy(
+            HybridShuffleConfiguration hybridShuffleConfiguration) {
+        switch (hybridShuffleConfiguration.getSpillingStrategyType()) {
+            case FULL:
+                return new HsFullSpillingStrategy(hybridShuffleConfiguration);
+            case SELECTIVE:
+                return new HsSelectiveSpillingStrategy(hybridShuffleConfiguration);
+            default:
+                throw new IllegalConfigurationException("Illegal spilling strategy.");
+        }
+    }
+}
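
For orientation, here is a minimal producer-side sketch of how an HsResultPartition could be wired up and driven, modeled on the constructor above and on the createHsResultPartition factory method in the test below. The pool sizes, directory, task name, and class name are illustrative assumptions, not values taken from this commit.

import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
import org.apache.flink.runtime.io.disk.FileChannelManager;
import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.hybrid.HsResultPartition;
import org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleConfiguration;

import java.nio.ByteBuffer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/** Illustrative sketch only; it mirrors HsResultPartitionTest#createHsResultPartition below. */
public final class HsResultPartitionUsageSketch {

    public static void main(String[] args) throws Exception {
        int numSubpartitions = 2;
        int bufferSize = 1024;

        // Assumed pool sizes and spill directory; real deployments size these via the shuffle environment.
        NetworkBufferPool globalPool = new NetworkBufferPool(1000, bufferSize);
        BufferPool bufferPool = globalPool.createBufferPool(100, 100);
        BatchShuffleReadBufferPool readBufferPool =
                new BatchShuffleReadBufferPool(32 * 1024 * 1024, bufferSize);
        ExecutorService readIOExecutor = Executors.newFixedThreadPool(4);
        FileChannelManager fileChannelManager =
                new FileChannelManagerImpl(new String[] {"/tmp/hybrid-shuffle"}, "sketch");

        HsResultPartition partition =
                new HsResultPartition(
                        "SketchTask",
                        0,
                        new ResultPartitionID(),
                        ResultPartitionType.HYBRID,
                        numSubpartitions,
                        numSubpartitions,
                        readBufferPool,
                        readIOExecutor,
                        new ResultPartitionManager(),
                        fileChannelManager.createChannel().getPath(),
                        bufferSize,
                        HybridShuffleConfiguration.builder(
                                        numSubpartitions, readBufferPool.getNumBuffersPerRequest())
                                .build(),
                        null, // no compression
                        () -> bufferPool);

        partition.setup(); // creates the memory data manager and sets up the file data manager
        partition.emitRecord(ByteBuffer.wrap(new byte[128]), 0); // records are appended to memory first
        partition.finish(); // broadcasts EndOfPartitionEvent to every subpartition
        partition.close();
        partition.release(); // deletes the .hybrid.data file and frees in-memory data
    }
}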
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
new file mode 100644
index 00000000000..6c7d7b66b34
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsResultPartitionTest.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.hybrid;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.core.testutils.CheckedThread;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
+import org.apache.flink.runtime.io.disk.FileChannelManager;
+import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.NoOpBufferAvailablityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleConfiguration.SpillingStrategyType;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.file.Path;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Queue;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link HsResultPartition}. */
+class HsResultPartitionTest {
+
+    private static final int bufferSize = 1024;
+
+    private static final int totalBuffers = 1000;
+
+    private static final int totalBytes = 32 * 1024 * 1024;
+
+    private static final int numThreads = 4;
+
+    private FileChannelManager fileChannelManager;
+
+    private NetworkBufferPool globalPool;
+
+    private BatchShuffleReadBufferPool readBufferPool;
+
+    private ExecutorService readIOExecutor;
+
+    @TempDir public Path tempDataPath;
+
+    @BeforeEach
+    void before() {
+        fileChannelManager =
+                new FileChannelManagerImpl(new String[] {tempDataPath.toString()}, "testing");
+        globalPool = new NetworkBufferPool(totalBuffers, bufferSize);
+        readBufferPool = new BatchShuffleReadBufferPool(totalBytes, bufferSize);
+        readIOExecutor = Executors.newFixedThreadPool(numThreads);
+    }
+
+    @AfterEach
+    void after() throws Exception {
+        fileChannelManager.close();
+        globalPool.destroy();
+        readBufferPool.destroy();
+        readIOExecutor.shutdown();
+    }
+
+    @Test
+    void testEmit() throws Exception {
+        int numBuffers = 100;
+        int numSubpartitions = 10;
+        int numRecords = 1000;
+        Random random = new Random();
+
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
+
+        try (HsResultPartition partition = createHsResultPartition(numSubpartitions, bufferPool)) {
+            Queue<Tuple2<ByteBuffer, Buffer.DataType>>[] dataWritten = new Queue[numSubpartitions];
+            Queue<Buffer>[] buffersRead = new Queue[numSubpartitions];
+            for (int i = 0; i < numSubpartitions; ++i) {
+                dataWritten[i] = new ArrayDeque<>();
+                buffersRead[i] = new ArrayDeque<>();
+            }
+
+            int[] numBytesWritten = new int[numSubpartitions];
+            int[] numBytesRead = new int[numSubpartitions];
+            Arrays.fill(numBytesWritten, 0);
+            Arrays.fill(numBytesRead, 0);
+
+            for (int i = 0; i < numRecords; ++i) {
+                ByteBuffer record = generateRandomData(random.nextInt(2 * bufferSize) + 1, random);
+                boolean isBroadCast = random.nextBoolean();
+
+                if (isBroadCast) {
+                    partition.broadcastRecord(record);
+                    for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
+                        recordDataWritten(
+                                record,
+                                dataWritten,
+                                subpartition,
+                                numBytesWritten,
+                                Buffer.DataType.DATA_BUFFER);
+                    }
+                } else {
+                    int subpartition = random.nextInt(numSubpartitions);
+                    partition.emitRecord(record, subpartition);
+                    recordDataWritten(
+                            record,
+                            dataWritten,
+                            subpartition,
+                            numBytesWritten,
+                            Buffer.DataType.DATA_BUFFER);
+                }
+            }
+
+            partition.finish();
+
+            for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
+                ByteBuffer record = EventSerializer.toSerializedEvent(EndOfPartitionEvent.INSTANCE);
+                recordDataWritten(
+                        record,
+                        dataWritten,
+                        subpartition,
+                        numBytesWritten,
+                        Buffer.DataType.EVENT_BUFFER);
+            }
+
+            Tuple2<ResultSubpartitionView, TestingBufferAvailabilityListener>[] viewAndListeners =
+                    createSubpartitionViews(partition, numSubpartitions);
+            readData(
+                    viewAndListeners,
+                    (buffer, subpartitionId) -> {
+                        int numBytes = buffer.readableBytes();
+                        numBytesRead[subpartitionId] += numBytes;
+
+                        MemorySegment segment =
+                                MemorySegmentFactory.allocateUnpooledSegment(numBytes);
+                        segment.put(0, buffer.getNioBufferReadable(), numBytes);
+                        buffersRead[subpartitionId].add(
+                                new NetworkBuffer(
+                                        segment, (buf) -> {}, buffer.getDataType(), numBytes));
+                    });
+            checkWriteReadResult(
+                    numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
+        }
+    }
+
+    @Test
+    void testBroadcastEvent() throws Exception {
+        final int numBuffers = 1;
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
+        try (HsResultPartition resultPartition = createHsResultPartition(2, bufferPool)) {
+            resultPartition.broadcastEvent(EndOfPartitionEvent.INSTANCE, false);
+            // Broadcasting an event does not request a buffer from the pool.
+            assertThat(bufferPool.getNumberOfAvailableMemorySegments()).isEqualTo(1);
+
+            Tuple2[] viewAndListeners = createSubpartitionViews(resultPartition, 2);
+
+            boolean[] receivedEvent = new boolean[2];
+            readData(
+                    viewAndListeners,
+                    (buffer, subpartition) -> {
+                        assertThat(buffer.getDataType().isEvent()).isTrue();
+                        try {
+                            AbstractEvent event =
+                                    EventSerializer.fromSerializedEvent(
+                                            buffer.readOnlySlice().getNioBufferReadable(),
+                                            HsResultPartitionTest.class.getClassLoader());
+                            assertThat(event).isInstanceOf(EndOfPartitionEvent.class);
+                            receivedEvent[subpartition] = true;
+                        } catch (IOException e) {
+                            throw new RuntimeException(e);
+                        }
+                    });
+
+            assertThat(receivedEvent).containsExactly(true, true);
+        }
+    }
+
+    @Test
+    void testClose() throws Exception {
+        final int numBuffers = 1;
+
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
+        HsResultPartition partition = createHsResultPartition(1, bufferPool);
+
+        partition.close();
+        // Emitting a record to a closed partition should throw an exception.
+        assertThatThrownBy(() -> partition.emitRecord(ByteBuffer.allocate(bufferSize), 0));
+    }
+
+    @Test
+    @Timeout(30)
+    void testRelease() throws Exception {
+        final int numBuffers = 10;
+
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
+        HsResultPartition partition = createHsResultPartition(2, bufferPool);
+
+        partition.emitRecord(ByteBuffer.allocate(bufferSize * numBuffers), 1);
+        assertThat(bufferPool.bestEffortGetNumOfUsedBuffers()).isEqualTo(numBuffers);
+
+        partition.close();
+        assertThat(bufferPool.isDestroyed()).isTrue();
+
+        partition.release();
+
+        while (checkNotNull(fileChannelManager.getPaths()[0].listFiles()).length != 0) {
+            Thread.sleep(10);
+        }
+
+        assertThat(totalBuffers).isEqualTo(globalPool.getNumberOfAvailableMemorySegments());
+    }
+
+    @Test
+    void testCreateSubpartitionViewAfterRelease() throws Exception {
+        final int numBuffers = 10;
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
+        HsResultPartition resultPartition = createHsResultPartition(2, bufferPool);
+        resultPartition.release();
+        assertThatThrownBy(
+                        () ->
+                                resultPartition.createSubpartitionView(
+                                        0, new NoOpBufferAvailablityListener()))
+                .isInstanceOf(IllegalStateException.class);
+    }
+
+    @Test
+    void testAvailability() throws Exception {
+        final int numBuffers = 2;
+
+        BufferPool bufferPool = globalPool.createBufferPool(numBuffers, numBuffers);
+        HsResultPartition partition = createHsResultPartition(1, bufferPool);
+
+        partition.emitRecord(ByteBuffer.allocate(bufferSize * numBuffers), 0);
+        assertThat(partition.isAvailable()).isFalse();
+
+        // Release the partition to recycle its buffers.
+        partition.close();
+        partition.release();
+
+        assertThat(partition.isAvailable()).isTrue();
+    }
+
+    private static void recordDataWritten(
+            ByteBuffer record,
+            Queue<Tuple2<ByteBuffer, Buffer.DataType>>[] dataWritten,
+            int subpartition,
+            int[] numBytesWritten,
+            Buffer.DataType dataType) {
+        record.rewind();
+        dataWritten[subpartition].add(Tuple2.of(record, dataType));
+        numBytesWritten[subpartition] += record.remaining();
+    }
+
+    private long readData(
+            Tuple2<ResultSubpartitionView, TestingBufferAvailabilityListener>[] viewAndListeners,
+            BiConsumer<Buffer, Integer> bufferProcessor)
+            throws Exception {
+        AtomicInteger dataSize = new AtomicInteger(0);
+        AtomicInteger numEndOfPartitionEvents = new AtomicInteger(0);
+        CheckedThread[] subpartitionViewThreads = new CheckedThread[viewAndListeners.length];
+        for (int i = 0; i < viewAndListeners.length; i++) {
+            // start thread for each view.
+            final int subpartition = i;
+            CheckedThread subpartitionViewThread =
+                    new CheckedThread() {
+                        @Override
+                        public void go() throws Exception {
+                            ResultSubpartitionView view = viewAndListeners[subpartition].f0;
+                            while (true) {
+                                ResultSubpartition.BufferAndBacklog bufferAndBacklog =
+                                        view.getNextBuffer();
+                                if (bufferAndBacklog == null) {
+                                    viewAndListeners[subpartition].f1.waitForData();
+                                    continue;
+                                }
+                                Buffer buffer = bufferAndBacklog.buffer();
+                                bufferProcessor.accept(buffer, subpartition);
+                                dataSize.addAndGet(buffer.readableBytes());
+                                buffer.recycleBuffer();
+
+                                if (!buffer.isBuffer()) {
+                                    numEndOfPartitionEvents.incrementAndGet();
+                                    view.releaseAllResources();
+                                    break;
+                                }
+                                if (bufferAndBacklog.getNextDataType() == Buffer.DataType.NONE) {
+                                    viewAndListeners[subpartition].f1.waitForData();
+                                }
+                            }
+                        }
+                    };
+            subpartitionViewThreads[subpartition] = subpartitionViewThread;
+            subpartitionViewThread.start();
+        }
+        for (CheckedThread thread : subpartitionViewThreads) {
+            thread.sync();
+        }
+        return dataSize.get();
+    }
+
+    private static ByteBuffer generateRandomData(int dataSize, Random random) {
+        byte[] dataWritten = new byte[dataSize];
+        random.nextBytes(dataWritten);
+        return ByteBuffer.wrap(dataWritten);
+    }
+
+    private HsResultPartition createHsResultPartition(
+            int numSubpartitions, BufferPool bufferPool, int numBuffersTriggerSpilling)
+            throws IOException {
+        HsResultPartition hsResultPartition =
+                new HsResultPartition(
+                        "HsResultPartitionTest",
+                        0,
+                        new ResultPartitionID(),
+                        ResultPartitionType.HYBRID,
+                        numSubpartitions,
+                        numSubpartitions,
+                        readBufferPool,
+                        readIOExecutor,
+                        new ResultPartitionManager(),
+                        fileChannelManager.createChannel().getPath(),
+                        bufferSize,
+                        HybridShuffleConfiguration.builder(
+                                        numSubpartitions, readBufferPool.getNumBuffersPerRequest())
+                                .setSpillingStrategyType(SpillingStrategyType.FULL)
+                                .setFullStrategyNumBuffersTriggerSpilling(numBuffersTriggerSpilling)
+                                .build(),
+                        null,
+                        () -> bufferPool);
+        hsResultPartition.setup();
+        return hsResultPartition;
+    }
+
+    private HsResultPartition createHsResultPartition(int numSubpartitions, BufferPool bufferPool)
+            throws IOException {
+        HsResultPartition hsResultPartition =
+                new HsResultPartition(
+                        "HsResultPartitionTest",
+                        0,
+                        new ResultPartitionID(),
+                        ResultPartitionType.HYBRID,
+                        numSubpartitions,
+                        numSubpartitions,
+                        readBufferPool,
+                        readIOExecutor,
+                        new ResultPartitionManager(),
+                        fileChannelManager.createChannel().getPath(),
+                        bufferSize,
+                        HybridShuffleConfiguration.builder(
+                                        numSubpartitions, readBufferPool.getNumBuffersPerRequest())
+                                .build(),
+                        null,
+                        () -> bufferPool);
+        hsResultPartition.setup();
+        return hsResultPartition;
+    }
+
+    private static void checkWriteReadResult(
+            int numSubpartitions,
+            int[] numBytesWritten,
+            int[] numBytesRead,
+            Queue<Tuple2<ByteBuffer, Buffer.DataType>>[] dataWritten,
+            Queue<Buffer>[] buffersRead) {
+        for (int subpartitionIndex = 0; subpartitionIndex < numSubpartitions; ++subpartitionIndex) {
+            assertThat(numBytesWritten[subpartitionIndex])
+                    .isEqualTo(numBytesRead[subpartitionIndex]);
+
+            List<Tuple2<ByteBuffer, Buffer.DataType>> eventsWritten = new ArrayList<>();
+            List<Buffer> eventsRead = new ArrayList<>();
+
+            ByteBuffer subpartitionDataWritten =
+                    ByteBuffer.allocate(numBytesWritten[subpartitionIndex]);
+            for (Tuple2<ByteBuffer, Buffer.DataType> bufferDataTypeTuple :
+                    dataWritten[subpartitionIndex]) {
+                subpartitionDataWritten.put(bufferDataTypeTuple.f0);
+                bufferDataTypeTuple.f0.rewind();
+                if (bufferDataTypeTuple.f1.isEvent()) {
+                    eventsWritten.add(bufferDataTypeTuple);
+                }
+            }
+
+            ByteBuffer subpartitionDataRead = ByteBuffer.allocate(numBytesRead[subpartitionIndex]);
+            for (Buffer buffer : buffersRead[subpartitionIndex]) {
+                subpartitionDataRead.put(buffer.getNioBufferReadable());
+                if (!buffer.isBuffer()) {
+                    eventsRead.add(buffer);
+                }
+            }
+
+            subpartitionDataWritten.flip();
+            subpartitionDataRead.flip();
+            assertThat(subpartitionDataWritten).isEqualTo(subpartitionDataRead);
+
+            assertThat(eventsWritten.size()).isEqualTo(eventsRead.size());
+            for (int i = 0; i < eventsWritten.size(); i++) {
+                assertThat(eventsWritten.get(i).f1).isEqualTo(eventsRead.get(i).getDataType());
+                assertThat(eventsWritten.get(i).f0)
+                        .isEqualTo(eventsRead.get(i).getNioBufferReadable());
+            }
+        }
+    }
+
+    private Tuple2<ResultSubpartitionView, TestingBufferAvailabilityListener>[]
+            createSubpartitionViews(HsResultPartition partition, int numSubpartitions)
+                    throws Exception {
+        Tuple2<ResultSubpartitionView, TestingBufferAvailabilityListener>[] viewAndListeners =
+                new Tuple2[numSubpartitions];
+        for (int subpartition = 0; subpartition < numSubpartitions; ++subpartition) {
+            TestingBufferAvailabilityListener listener = new TestingBufferAvailabilityListener();
+            viewAndListeners[subpartition] =
+                    Tuple2.of(partition.createSubpartitionView(subpartition, listener), listener);
+        }
+        return viewAndListeners;
+    }
+
+    private static final class TestingBufferAvailabilityListener
+            implements BufferAvailabilityListener {
+
+        private int numNotifications;
+
+        @Override
+        public synchronized void notifyDataAvailable() {
+            if (numNotifications == 0) {
+                notifyAll();
+            }
+            ++numNotifications;
+        }
+
+        public synchronized void waitForData() throws InterruptedException {
+            if (numNotifications == 0) {
+                wait();
+            }
+            numNotifications = 0;
+        }
+    }
+}
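
For the read side, a compact sketch of how a downstream consumer might drain one subpartition view. It follows the same getNextBuffer pattern as the readData helper above, but uses simple polling instead of the test's blocking availability listener; the class name, helper name, and sleep interval are assumptions for illustration.

import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;

/** Illustrative helper, not part of this commit: drains one subpartition until the first event. */
final class SubpartitionDrainSketch {

    static long drain(ResultSubpartitionView view) throws Exception {
        long bytesRead = 0;
        while (true) {
            ResultSubpartition.BufferAndBacklog bufferAndBacklog = view.getNextBuffer();
            if (bufferAndBacklog == null) {
                // Nothing available yet; a real reader would wait for notifyDataAvailable() instead.
                Thread.sleep(10);
                continue;
            }
            Buffer buffer = bufferAndBacklog.buffer();
            bytesRead += buffer.readableBytes();
            boolean isEvent = !buffer.isBuffer();
            buffer.recycleBuffer();
            if (isEvent) {
                // Treat the first event as the end, as the test's readData helper does; finish()
                // broadcasts EndOfPartitionEvent, which terminates this loop.
                view.releaseAllResources();
                return bytesRead;
            }
        }
    }
}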
