This is an automated email from the ASF dual-hosted git repository. xtsong pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 0d466603a981acfee8d7136220461f63849fbe0f Author: Weijie Guo <[email protected]> AuthorDate: Mon Jul 18 16:52:00 2022 +0800 [FLINK-27904][runtime] Introduce HsMemoryDataManager to manage in-memory data of hybrid shuffle mode This closes #20293 --- .../network/partition/hybrid/HsBufferContext.java | 128 ++++++ .../partition/hybrid/HsMemoryDataManager.java | 286 +++++++++++++ .../hybrid/HsMemoryDataManagerOperation.java | 52 +++ .../partition/hybrid/HsMemoryDataSpiller.java | 1 - .../hybrid/HsSubpartitionMemoryDataManager.java | 471 +++++++++++++++++++++ .../partition/hybrid/HsBufferContextTest.java | 131 ++++++ .../hybrid/HsFullSpillingStrategyTest.java | 2 +- .../partition/hybrid/HsMemoryDataManagerTest.java | 214 ++++++++++ .../hybrid/HsSelectiveSpillingStrategyTest.java | 2 +- .../hybrid/HsSpillingStrategyUtilsTest.java | 4 +- .../HsSubpartitionMemoryDataManagerTest.java | 427 +++++++++++++++++++ ...yTestUtils.java => HybridShuffleTestUtils.java} | 21 +- .../partition/hybrid/TestingFileDataIndex.java | 96 +++++ .../hybrid/TestingMemoryDataManagerOperation.java | 119 ++++++ .../partition/hybrid/TestingSpillingStrategy.java | 119 ++++++ 15 files changed, 2066 insertions(+), 7 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java new file mode 100644 index 00000000000..8feb6fd0c41 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContext.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import javax.annotation.Nullable; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.util.Preconditions.checkState; + +/** + * This class maintains the buffer's reference count and its status for hybrid shuffle mode. + * + * <p>Each buffer has three status: {@link #released}, {@link #spillStarted}, {@link #consumed}. + * + * <ul> + * <li>{@link #released} indicates that buffer has been released from the memory data manager, and + * can no longer be spilled or consumed. + * <li>{@link #spillStarted} indicates that spilling of the buffer has started, either completed + * or not. + * <li>{@link #consumed} indicates that buffer has been consumed by the downstream. + * </ul> + * + * <p>Reference count of the buffer is maintained as follows: * + * + * <ul> + * <li>+1 when the buffer is obtained by memory data manager (from the buffer pool), and -1 when + * it is released from memory data manager. + * <li>+1 when spilling of the buffer is tarted, and -1 when it is completed. + * <li>+1 when the buffer is being consumed, and -1 when consuming is completed (by the + * downstream). + * </ul> + * + * <p>Note: This class is not thread-safe. 
+ */ +public class HsBufferContext { + private final Buffer buffer; + + private final BufferIndexAndChannel bufferIndexAndChannel; + + // -------------------------- + // Buffer Status + // -------------------------- + private boolean released; + + private boolean spillStarted; + + private boolean consumed; + + @Nullable private CompletableFuture<Void> spilledFuture; + + public HsBufferContext(Buffer buffer, int bufferIndex, int subpartitionId) { + this.bufferIndexAndChannel = new BufferIndexAndChannel(bufferIndex, subpartitionId); + this.buffer = buffer; + } + + public Buffer getBuffer() { + return buffer; + } + + public BufferIndexAndChannel getBufferIndexAndChannel() { + return bufferIndexAndChannel; + } + + public boolean isReleased() { + return released; + } + + public boolean isSpillStarted() { + return spillStarted; + } + + public boolean isConsumed() { + return consumed; + } + + public Optional<CompletableFuture<Void>> getSpilledFuture() { + return Optional.ofNullable(spilledFuture); + } + + public void release() { + checkState(!released, "Release buffer repeatedly is unexpected."); + released = true; + // decrease ref count when buffer is released from memory. + buffer.recycleBuffer(); + } + + public void startSpilling(CompletableFuture<Void> spilledFuture) { + checkState(!released, "Buffer is already released."); + checkState( + !spillStarted && this.spilledFuture == null, + "Spill buffer repeatedly is unexpected."); + spillStarted = true; + this.spilledFuture = spilledFuture; + // increase ref count when buffer is decided to spill. + buffer.retainBuffer(); + // decrease ref count when buffer spilling is finished. + spilledFuture.thenRun(buffer::recycleBuffer); + } + + public void consumed() { + checkState(!released, "Buffer is already released."); + checkState(!consumed, "Consume buffer repeatedly is unexpected."); + consumed = true; + // increase ref count when buffer is consumed, will be decreased when downstream finish + // consuming. 
+ buffer.retainBuffer(); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java new file mode 100644 index 00000000000..0a7c60cf9b2 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManager.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategy.Decision; +import org.apache.flink.util.function.SupplierWithException; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** This class is responsible for managing data in memory. 
*/ +public class HsMemoryDataManager implements HsSpillingInfoProvider, HsMemoryDataManagerOperation { + + private final int numSubpartitions; + + private final HsSubpartitionMemoryDataManager[] subpartitionMemoryDataManagers; + + private final HsMemoryDataSpiller spiller; + + private final HsSpillingStrategy spillStrategy; + + private final HsFileDataIndex fileDataIndex; + + private final BufferPool bufferPool; + + private final Lock lock; + + private final AtomicInteger numRequestedBuffers = new AtomicInteger(0); + + private final AtomicInteger numUnSpillBuffers = new AtomicInteger(0); + + public HsMemoryDataManager( + int numSubpartitions, + int bufferSize, + BufferPool bufferPool, + HsSpillingStrategy spillStrategy, + HsFileDataIndex fileDataIndex, + FileChannel dataFileChannel) { + this.numSubpartitions = numSubpartitions; + this.bufferPool = bufferPool; + this.spiller = new HsMemoryDataSpiller(dataFileChannel); + this.spillStrategy = spillStrategy; + this.fileDataIndex = fileDataIndex; + this.subpartitionMemoryDataManagers = new HsSubpartitionMemoryDataManager[numSubpartitions]; + + ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true); + this.lock = readWriteLock.writeLock(); + + for (int subpartitionId = 0; subpartitionId < numSubpartitions; ++subpartitionId) { + subpartitionMemoryDataManagers[subpartitionId] = + new HsSubpartitionMemoryDataManager( + subpartitionId, bufferSize, readWriteLock.readLock(), this); + } + } + + // ------------------------------------ + // For ResultPartition + // ------------------------------------ + + /** + * Append record to {@link HsMemoryDataManager}, It will be managed by {@link + * HsSubpartitionMemoryDataManager} witch it belongs to. + * + * @param record to be managed by this class. + * @param targetChannel target subpartition of this record. + * @param dataType the type of this record. In other words, is it data or event. 
+ */ + public void append(ByteBuffer record, int targetChannel, Buffer.DataType dataType) + throws IOException { + try { + getSubpartitionMemoryDataManager(targetChannel).append(record, dataType); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + // ------------------------------------ + // For Spilling Strategy + // ------------------------------------ + + @Override + public int getPoolSize() { + return bufferPool.getNumBuffers(); + } + + @Override + public int getNumSubpartitions() { + return numSubpartitions; + } + + @Override + public int getNumTotalRequestedBuffers() { + return numRequestedBuffers.get(); + } + + @Override + public int getNumTotalUnSpillBuffers() { + return numUnSpillBuffers.get(); + } + + // Write lock should be acquired before invoke this method. + @Override + public Deque<BufferIndexAndChannel> getBuffersInOrder( + int subpartitionId, SpillStatus spillStatus, ConsumeStatus consumeStatus) { + HsSubpartitionMemoryDataManager targetSubpartitionDataManager = + getSubpartitionMemoryDataManager(subpartitionId); + return targetSubpartitionDataManager.getBuffersSatisfyStatus(spillStatus, consumeStatus); + } + + // Write lock should be acquired before invoke this method. + @Override + public List<Integer> getNextBufferIndexToConsume() { + // TODO implements this logical when subpartition view is implemented. 
+ return Collections.emptyList(); + } + + // ------------------------------------ + // Callback for subpartition + // ------------------------------------ + + @Override + public void markBufferReadableFromFile(int subpartitionId, int bufferIndex) { + fileDataIndex.markBufferReadable(subpartitionId, bufferIndex); + } + + @Override + public BufferBuilder requestBufferFromPool() throws InterruptedException { + MemorySegment segment = bufferPool.requestMemorySegmentBlocking(); + Optional<Decision> decisionOpt = + spillStrategy.onMemoryUsageChanged( + numRequestedBuffers.incrementAndGet(), getPoolSize()); + + handleDecision(decisionOpt); + return new BufferBuilder(segment, this::recycleBuffer); + } + + @Override + public void onBufferConsumed(BufferIndexAndChannel consumedBuffer) { + Optional<Decision> decision = spillStrategy.onBufferConsumed(consumedBuffer); + handleDecision(decision); + } + + @Override + public void onBufferFinished() { + Optional<Decision> decision = + spillStrategy.onBufferFinished(numUnSpillBuffers.incrementAndGet()); + handleDecision(decision); + } + + // ------------------------------------ + // Internal Method + // ------------------------------------ + + // Attention: Do not call this method within the read lock and subpartition lock, otherwise + // deadlock may occur as this method maybe acquire write lock and other subpartition's lock + // inside. + private void handleDecision( + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + Optional<Decision> decisionOpt) { + Decision decision = + decisionOpt.orElseGet( + () -> callWithLock(() -> spillStrategy.decideActionWithGlobalInfo(this))); + + if (!decision.getBufferToSpill().isEmpty()) { + spillBuffers(decision.getBufferToSpill()); + } + if (!decision.getBufferToRelease().isEmpty()) { + releaseBuffers(decision.getBufferToRelease()); + } + } + + /** + * Spill buffers for each subpartition in a decision. 
+ * + * <p>Note that: The method should not be locked, it is the responsibility of each subpartition + * to maintain thread safety itself. + * + * @param toSpill All buffers that need to be spilled in a decision. + */ + private void spillBuffers(Map<Integer, List<BufferIndexAndChannel>> toSpill) { + CompletableFuture<Void> spillingCompleteFuture = new CompletableFuture<>(); + List<BufferWithIdentity> bufferWithIdentities = new ArrayList<>(); + toSpill.forEach( + (subpartitionId, bufferIndexAndChannels) -> { + HsSubpartitionMemoryDataManager subpartitionDataManager = + getSubpartitionMemoryDataManager(subpartitionId); + bufferWithIdentities.addAll( + subpartitionDataManager.spillSubpartitionBuffers( + bufferIndexAndChannels, spillingCompleteFuture)); + // decrease numUnSpillBuffers as this subpartition's buffer is spill. + numUnSpillBuffers.getAndAdd(-bufferIndexAndChannels.size()); + }); + + spiller.spillAsync(bufferWithIdentities) + .thenAccept( + spilledBuffers -> { + fileDataIndex.addBuffers(spilledBuffers); + spillingCompleteFuture.complete(null); + }); + } + + /** + * Release buffers for each subpartition in a decision. + * + * <p>Note that: The method should not be locked, it is the responsibility of each subpartition + * to maintain thread safety itself. + * + * @param toRelease All buffers that need to be released in a decision. 
+ */ + private void releaseBuffers(Map<Integer, List<BufferIndexAndChannel>> toRelease) { + toRelease.forEach( + (subpartitionId, subpartitionBuffers) -> + getSubpartitionMemoryDataManager(subpartitionId) + .releaseSubpartitionBuffers(subpartitionBuffers)); + } + + private HsSubpartitionMemoryDataManager getSubpartitionMemoryDataManager(int targetChannel) { + return subpartitionMemoryDataManagers[targetChannel]; + } + + private void recycleBuffer(MemorySegment buffer) { + numRequestedBuffers.decrementAndGet(); + bufferPool.recycle(buffer); + } + + public <T, R extends Exception> T callWithLock(SupplierWithException<T, R> callable) throws R { + try { + lock.lock(); + return callable.get(); + } finally { + lock.unlock(); + } + } + + /** Integrate the buffer and dataType of next buffer. */ + public static class BufferAndNextDataType { + private final Buffer buffer; + + private final Buffer.DataType nextDataType; + + public BufferAndNextDataType(Buffer buffer, Buffer.DataType nextDataType) { + this.buffer = buffer; + this.nextDataType = nextDataType; + } + + public Buffer getBuffer() { + return buffer; + } + + public Buffer.DataType getNextDataType() { + return nextDataType; + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerOperation.java new file mode 100644 index 00000000000..d34b251a700 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerOperation.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; + +/** + * This interface is used by {@link HsSubpartitionMemoryDataManager} to operate {@link + * HsMemoryDataManager}. Spilling decision may be made and handled inside these operations. + */ +public interface HsMemoryDataManagerOperation { + /** + * Request buffer from buffer pool. + * + * @return requested buffer. + */ + BufferBuilder requestBufferFromPool() throws InterruptedException; + + /** + * This method is called when buffer should mark as readable in {@link HsFileDataIndex}. + * + * @param subpartitionId the subpartition that target buffer belong to. + * @param bufferIndex index of buffer to mark as readable. + */ + void markBufferReadableFromFile(int subpartitionId, int bufferIndex); + + /** + * This method is called when buffer is consumed. + * + * @param consumedBuffer target buffer to mark as consumed. + */ + void onBufferConsumed(BufferIndexAndChannel consumedBuffer); + + /** This method is called when buffer is finished. 
*/ + void onBufferFinished(); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java index dd225ba6b27..be376b5e241 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataSpiller.java @@ -89,7 +89,6 @@ public class HsMemoryDataSpiller implements AutoCloseable { // complete spill future when buffers are written to disk successfully. // note that the ownership of these buffers is transferred to the MemoryDataManager, // which controls data's life cycle. - // TODO update file data index and handle buffers release in future ticket. spilledFuture.complete(spilledBuffers); } catch (IOException exception) { // if spilling is failed, throw exception directly to uncaughtExceptionHandler. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java new file mode 100644 index 00000000000..56814024911 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManager.java @@ -0,0 +1,471 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.Buffer.DataType; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; +import org.apache.flink.runtime.io.network.buffer.BufferConsumer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus; +import org.apache.flink.util.function.SupplierWithException; +import org.apache.flink.util.function.ThrowingRunnable; + +import javax.annotation.concurrent.GuardedBy; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.locks.Lock; +import java.util.stream.Collectors; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * This class is responsible for managing the data in a single 
subpartition. One {@link + * HsMemoryDataManager} will hold multiple {@link HsSubpartitionMemoryDataManager}. + */ +public class HsSubpartitionMemoryDataManager { + private final int targetChannel; + + private final int bufferSize; + + private final HsMemoryDataManagerOperation memoryDataManagerOperation; + + // Not guarded by lock because it is expected only accessed from task's main thread. + private final Queue<BufferBuilder> unfinishedBuffers = new LinkedList<>(); + + // Not guarded by lock because it is expected only accessed from task's main thread. + private int finishedBufferIndex; + + @GuardedBy("subpartitionLock") + private final Deque<HsBufferContext> allBuffers = new LinkedList<>(); + + @GuardedBy("subpartitionLock") + private final Deque<HsBufferContext> unConsumedBuffers = new LinkedList<>(); + + @GuardedBy("subpartitionLock") + private final Map<Integer, HsBufferContext> bufferIndexToContexts = new HashMap<>(); + + /** DO NOT USE DIRECTLY. Use {@link #runWithLock} or {@link #callWithLock} instead. */ + private final Lock resultPartitionLock; + + /** DO NOT USE DIRECTLY. Use {@link #runWithLock} or {@link #callWithLock} instead. */ + private final Object subpartitionLock = new Object(); + + HsSubpartitionMemoryDataManager( + int targetChannel, + int bufferSize, + Lock resultPartitionLock, + HsMemoryDataManagerOperation memoryDataManagerOperation) { + this.targetChannel = targetChannel; + this.bufferSize = bufferSize; + this.resultPartitionLock = resultPartitionLock; + this.memoryDataManagerOperation = memoryDataManagerOperation; + } + + // ------------------------------------------------------------------------ + // Called by Consumer + // ------------------------------------------------------------------------ + + /** + * Check whether the head of {@link #unConsumedBuffers} is the buffer to be consumed next time. + * If so, return the next buffer's data type. + * + * @param nextToConsumeIndex index of the buffer to be consumed next time. 
+ * @return If the head of {@link #unConsumedBuffers} is target, return the buffer's data type. + * Otherwise, return {@link DataType#NONE}. + */ + @SuppressWarnings("FieldAccessNotGuarded") + // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and + // subpartitionLock. + public DataType peekNextToConsumeDataType(int nextToConsumeIndex) { + return callWithLock(() -> peekNextToConsumeDataTypeInternal(nextToConsumeIndex)); + } + + /** + * Check whether the head of {@link #unConsumedBuffers} is the buffer to be consumed. If so, + * return the buffer and next data type. + * + * @param toConsumeIndex index of buffer to be consumed. + * @return If the head of {@link #unConsumedBuffers} is target, return optional of the buffer + * and next data type. Otherwise, return {@link Optional#empty()}. + */ + @SuppressWarnings("FieldAccessNotGuarded") + // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and + // subpartitionLock. + public Optional<HsMemoryDataManager.BufferAndNextDataType> consumeBuffer(int toConsumeIndex) { + Optional<Tuple2<HsBufferContext, DataType>> bufferAndNextDataType = + callWithLock( + () -> { + if (!checkFirstUnConsumedBufferIndex(toConsumeIndex)) { + return Optional.empty(); + } + + HsBufferContext bufferContext = + checkNotNull(unConsumedBuffers.pollFirst()); + bufferContext.consumed(); + DataType nextDataType = + peekNextToConsumeDataTypeInternal(toConsumeIndex + 1); + return Optional.of(Tuple2.of(bufferContext, nextDataType)); + }); + + bufferAndNextDataType.ifPresent( + tuple -> + memoryDataManagerOperation.onBufferConsumed( + tuple.f0.getBufferIndexAndChannel())); + return bufferAndNextDataType.map( + tuple -> + new HsMemoryDataManager.BufferAndNextDataType( + tuple.f0.getBuffer(), tuple.f1)); + } + + // ------------------------------------------------------------------------ + // Called by MemoryDataManager + // 
------------------------------------------------------------------------ + + /** + * Append record to {@link HsSubpartitionMemoryDataManager}. + * + * @param record to be managed by this class. + * @param dataType the type of this record. In other words, is it data or event. + */ + public void append(ByteBuffer record, DataType dataType) throws InterruptedException { + if (dataType.isEvent()) { + writeEvent(record, dataType); + } else { + writeRecord(record, dataType); + } + } + + /** + * Get buffers in {@link #allBuffers} that satisfy expected {@link SpillStatus} and {@link + * ConsumeStatus}. + * + * @param spillStatus the status of spilling expected. + * @param consumeStatus the status of consuming expected. + * @return buffers satisfy expected status in order. + */ + @SuppressWarnings("FieldAccessNotGuarded") + // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and + // subpartitionLock. + public Deque<BufferIndexAndChannel> getBuffersSatisfyStatus( + SpillStatus spillStatus, ConsumeStatus consumeStatus) { + return callWithLock( + () -> { + // TODO return iterator to avoid completely traversing the queue for each call. + Deque<BufferIndexAndChannel> targetBuffers = new ArrayDeque<>(); + // traverse buffers in order. + allBuffers.forEach( + (bufferContext -> { + if (isBufferSatisfyStatus( + bufferContext, spillStatus, consumeStatus)) { + targetBuffers.add(bufferContext.getBufferIndexAndChannel()); + } + })); + return targetBuffers; + }); + } + + /** + * Spill this subpartition's buffers in a decision. + * + * @param toSpill All buffers that need to be spilled belong to this subpartition in a decision. + * @param spillDoneFuture completed when spill is finished. + * @return {@link BufferWithIdentity}s about these spill buffers. + */ + @SuppressWarnings("FieldAccessNotGuarded") + // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and + // subpartitionLock. 
+ public List<BufferWithIdentity> spillSubpartitionBuffers( + List<BufferIndexAndChannel> toSpill, CompletableFuture<Void> spillDoneFuture) { + return callWithLock( + () -> + toSpill.stream() + .map( + indexAndChannel -> { + int bufferIndex = indexAndChannel.getBufferIndex(); + HsBufferContext bufferContext = + startSpillingBuffer( + bufferIndex, spillDoneFuture); + return new BufferWithIdentity( + bufferContext.getBuffer(), + bufferIndex, + targetChannel); + }) + .collect(Collectors.toList())); + } + + /** + * Release this subpartition's buffers in a decision. + * + * @param toRelease All buffers that need to be released belong to this subpartition in a + * decision. + */ + @SuppressWarnings("FieldAccessNotGuarded") + // Note that: runWithLock ensure that code block guarded by resultPartitionReadLock and + // subpartitionLock. + public void releaseSubpartitionBuffers(List<BufferIndexAndChannel> toRelease) { + runWithLock( + () -> + toRelease.forEach( + (indexAndChannel) -> { + int bufferIndex = indexAndChannel.getBufferIndex(); + HsBufferContext bufferContext = + checkNotNull(bufferIndexToContexts.get(bufferIndex)); + checkAndMarkBufferReadable(bufferContext); + releaseBuffer(bufferIndex); + })); + } + + // ------------------------------------------------------------------------ + // Internal Methods + // ------------------------------------------------------------------------ + + private void writeEvent(ByteBuffer event, DataType dataType) { + checkArgument(dataType.isEvent()); + + // each Event must take an exclusive buffer + finishCurrentWritingBufferIfNotEmpty(); + + // store Events in adhoc heap segments, for network memory efficiency + MemorySegment data = MemorySegmentFactory.wrap(event.array()); + Buffer buffer = + new NetworkBuffer(data, FreeingBufferRecycler.INSTANCE, dataType, data.size()); + + HsBufferContext bufferContext = + new HsBufferContext(buffer, finishedBufferIndex, targetChannel); + addFinishedBuffer(bufferContext); + 
memoryDataManagerOperation.onBufferFinished(); + } + + private void writeRecord(ByteBuffer record, DataType dataType) throws InterruptedException { + checkArgument(!dataType.isEvent()); + + ensureCapacityForRecord(record); + + writeRecord(record); + } + + private void ensureCapacityForRecord(ByteBuffer record) throws InterruptedException { + final int numRecordBytes = record.remaining(); + int availableBytes = + Optional.ofNullable(unfinishedBuffers.peek()) + .map( + currentWritingBuffer -> + currentWritingBuffer.getWritableBytes() + + bufferSize * (unfinishedBuffers.size() - 1)) + .orElse(0); + + while (availableBytes < numRecordBytes) { + // request unfinished buffer. + BufferBuilder bufferBuilder = memoryDataManagerOperation.requestBufferFromPool(); + unfinishedBuffers.add(bufferBuilder); + availableBytes += bufferSize; + } + } + + private void writeRecord(ByteBuffer record) { + while (record.hasRemaining()) { + BufferBuilder currentWritingBuffer = + checkNotNull( + unfinishedBuffers.peek(), "Expect enough capacity for the record."); + currentWritingBuffer.append(record); + + if (currentWritingBuffer.isFull()) { + finishCurrentWritingBuffer(); + } + } + } + + private void finishCurrentWritingBufferIfNotEmpty() { + BufferBuilder currentWritingBuffer = unfinishedBuffers.peek(); + if (currentWritingBuffer == null || currentWritingBuffer.getWritableBytes() == bufferSize) { + return; + } + + finishCurrentWritingBuffer(); + } + + private void finishCurrentWritingBuffer() { + BufferBuilder currentWritingBuffer = unfinishedBuffers.poll(); + + if (currentWritingBuffer == null) { + return; + } + + currentWritingBuffer.finish(); + BufferConsumer bufferConsumer = currentWritingBuffer.createBufferConsumerFromBeginning(); + Buffer buffer = bufferConsumer.build(); + currentWritingBuffer.close(); + bufferConsumer.close(); + + HsBufferContext bufferContext = + new HsBufferContext(buffer, finishedBufferIndex, targetChannel); + addFinishedBuffer(bufferContext); + 
memoryDataManagerOperation.onBufferFinished(); + } + + @SuppressWarnings("FieldAccessNotGuarded") + // Note that: callWithLock ensure that code block guarded by resultPartitionReadLock and + // subpartitionLock. + private void addFinishedBuffer(HsBufferContext bufferContext) { + finishedBufferIndex++; + boolean needNotify = + callWithLock( + () -> { + allBuffers.add(bufferContext); + unConsumedBuffers.add(bufferContext); + bufferIndexToContexts.put( + bufferContext.getBufferIndexAndChannel().getBufferIndex(), + bufferContext); + trimHeadingReleasedBuffers(unConsumedBuffers); + return unConsumedBuffers.isEmpty(); + }); + if (needNotify) { + // TODO notify data available, the notification mechanism may need further + // consideration. + } + } + + @GuardedBy("subpartitionLock") + private DataType peekNextToConsumeDataTypeInternal(int nextToConsumeIndex) { + return checkFirstUnConsumedBufferIndex(nextToConsumeIndex) + ? checkNotNull(unConsumedBuffers.peekFirst()).getBuffer().getDataType() + : DataType.NONE; + } + + @GuardedBy("subpartitionLock") + private boolean checkFirstUnConsumedBufferIndex(int expectedBufferIndex) { + trimHeadingReleasedBuffers(unConsumedBuffers); + return !unConsumedBuffers.isEmpty() + && unConsumedBuffers.peekFirst().getBufferIndexAndChannel().getBufferIndex() + == expectedBufferIndex; + } + + /** + * Remove all released buffer from head of queue until buffer queue is empty or meet un-released + * buffer. + */ + @GuardedBy("subpartitionLock") + private void trimHeadingReleasedBuffers(Deque<HsBufferContext> bufferQueue) { + while (!bufferQueue.isEmpty() && bufferQueue.peekFirst().isReleased()) { + bufferQueue.removeFirst(); + } + } + + @GuardedBy("subpartitionLock") + private void releaseBuffer(int bufferIndex) { + HsBufferContext bufferContext = checkNotNull(bufferIndexToContexts.remove(bufferIndex)); + bufferContext.release(); + // remove released buffers from head lazy. 
+ trimHeadingReleasedBuffers(allBuffers); + } + + @GuardedBy("subpartitionLock") + private HsBufferContext startSpillingBuffer( + int bufferIndex, CompletableFuture<Void> spillFuture) { + HsBufferContext bufferContext = checkNotNull(bufferIndexToContexts.get(bufferIndex)); + bufferContext.startSpilling(spillFuture); + return bufferContext; + } + + @GuardedBy("subpartitionLock") + private void checkAndMarkBufferReadable(HsBufferContext bufferContext) { + // only spill and not consumed buffer needs to be marked as readable. + if (isBufferSatisfyStatus(bufferContext, SpillStatus.SPILL, ConsumeStatus.NOT_CONSUMED)) { + bufferContext + .getSpilledFuture() + .orElseThrow( + () -> + new IllegalStateException( + "Buffer in spill status should already set spilled future.")) + .thenRun( + () -> { + BufferIndexAndChannel bufferIndexAndChannel = + bufferContext.getBufferIndexAndChannel(); + memoryDataManagerOperation.markBufferReadableFromFile( + bufferIndexAndChannel.getChannel(), + bufferIndexAndChannel.getBufferIndex()); + }); + } + } + + @GuardedBy("subpartitionLock") + private boolean isBufferSatisfyStatus( + HsBufferContext bufferContext, SpillStatus spillStatus, ConsumeStatus consumeStatus) { + // released buffer is not needed. 
+ if (bufferContext.isReleased()) { + return false; + } + boolean match = true; + switch (spillStatus) { + case NOT_SPILL: + match = !bufferContext.isSpillStarted(); + break; + case SPILL: + match = bufferContext.isSpillStarted(); + break; + } + switch (consumeStatus) { + case NOT_CONSUMED: + match &= !bufferContext.isConsumed(); + break; + case CONSUMED: + match &= bufferContext.isConsumed(); + break; + } + return match; + } + + private <E extends Exception> void runWithLock(ThrowingRunnable<E> runnable) throws E { + try { + resultPartitionLock.lock(); + synchronized (subpartitionLock) { + runnable.run(); + } + } finally { + resultPartitionLock.unlock(); + } + } + + private <R, E extends Exception> R callWithLock(SupplierWithException<R, E> callable) throws E { + try { + resultPartitionLock.lock(); + synchronized (subpartitionLock) { + return callable.get(); + } + } finally { + resultPartitionLock.unlock(); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java new file mode 100644 index 00000000000..a16b811b68f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsBufferContextTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.runtime.io.network.buffer.Buffer; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; + +import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBuffer; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link HsBufferContext}. */ +class HsBufferContextTest { + private static final int BUFFER_SIZE = 16; + + private static final int SUBPARTITION_ID = 0; + + private static final int BUFFER_INDEX = 0; + + private HsBufferContext bufferContext; + + @BeforeEach + void before() { + bufferContext = createBufferContext(); + } + + @Test + void testBufferStartSpillingRefCount() { + Buffer buffer = bufferContext.getBuffer(); + CompletableFuture<Void> spilledFuture = new CompletableFuture<>(); + bufferContext.startSpilling(spilledFuture); + assertThat(bufferContext.isSpillStarted()).isTrue(); + assertThat(buffer.refCnt()).isEqualTo(2); + spilledFuture.complete(null); + assertThat(buffer.refCnt()).isEqualTo(1); + } + + @Test + void testBufferStartSpillingRepeatedly() { + bufferContext.startSpilling(new CompletableFuture<>()); + assertThatThrownBy(() -> bufferContext.startSpilling(new CompletableFuture<>())) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Spill buffer repeatedly is unexpected."); + } + + @Test + void 
testBufferReleaseRefCount() { + Buffer buffer = bufferContext.getBuffer(); + assertThat(buffer.refCnt()).isEqualTo(1); + bufferContext.release(); + assertThat(bufferContext.isReleased()).isTrue(); + assertThat(buffer.isRecycled()).isTrue(); + } + + @Test + void testBufferReleaseRepeatedly() { + bufferContext.release(); + assertThatThrownBy(() -> bufferContext.release()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Release buffer repeatedly is unexpected."); + } + + @Test + void testBufferConsumed() { + Buffer buffer = bufferContext.getBuffer(); + bufferContext.consumed(); + assertThat(bufferContext.isConsumed()).isTrue(); + assertThat(buffer.refCnt()).isEqualTo(2); + } + + @Test + void testBufferConsumedRepeatedly() { + bufferContext.consumed(); + assertThatThrownBy(() -> bufferContext.consumed()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Consume buffer repeatedly is unexpected."); + } + + @Test + void testBufferStartSpillOrConsumedAfterReleased() { + bufferContext.release(); + assertThatThrownBy(() -> bufferContext.startSpilling(new CompletableFuture<>())) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Buffer is already released."); + assertThatThrownBy(() -> bufferContext.consumed()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Buffer is already released."); + } + + @Test + void testBufferStartSpillingThenRelease() { + Buffer buffer = bufferContext.getBuffer(); + CompletableFuture<Void> spilledFuture = new CompletableFuture<>(); + bufferContext.startSpilling(spilledFuture); + bufferContext.release(); + spilledFuture.complete(null); + assertThat(buffer.isRecycled()).isTrue(); + } + + @Test + void testBufferConsumedThenRelease() { + Buffer buffer = bufferContext.getBuffer(); + bufferContext.consumed(); + bufferContext.release(); + assertThat(buffer.refCnt()).isEqualTo(1); + } + + private static HsBufferContext createBufferContext() { + return new 
HsBufferContext(createBuffer(BUFFER_SIZE, false), BUFFER_INDEX, SUBPARTITION_ID); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java index 2f7674aa4af..b8d1289efb8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsFullSpillingStrategyTest.java @@ -29,7 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsList; +import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsList; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.entry; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerTest.java new file mode 100644 index 00000000000..6608be9c500 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsMemoryDataManagerTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.hybrid.HsFileDataIndex.SpilledBuffer; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategy.Decision; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link HsMemoryDataManager}. 
*/ +class HsMemoryDataManagerTest { + private static final int NUM_BUFFERS = 10; + + private static final int NUM_SUBPARTITIONS = 3; + + private int poolSize = 10; + + private int bufferSize = Integer.BYTES; + + private FileChannel dataFileChannel; + + @BeforeEach + void before(@TempDir Path tempDir) throws Exception { + Path dataPath = Files.createFile(tempDir.resolve(".data")); + dataFileChannel = FileChannel.open(dataPath, StandardOpenOption.WRITE); + } + + @Test + void testAppendMarkBufferFinished() throws Exception { + AtomicInteger finishedBuffers = new AtomicInteger(0); + HsSpillingStrategy spillingStrategy = + TestingSpillingStrategy.builder() + .setOnBufferFinishedFunction( + (numTotalUnSpillBuffers) -> { + finishedBuffers.incrementAndGet(); + return Optional.of(Decision.NO_ACTION); + }) + .build(); + bufferSize = Integer.BYTES * 3; + HsMemoryDataManager memoryDataManager = createMemoryDataManager(spillingStrategy); + + memoryDataManager.append(createRecord(0), 0, Buffer.DataType.DATA_BUFFER); + memoryDataManager.append(createRecord(1), 0, Buffer.DataType.DATA_BUFFER); + assertThat(finishedBuffers).hasValue(0); + + memoryDataManager.append(createRecord(2), 0, Buffer.DataType.DATA_BUFFER); + assertThat(finishedBuffers).hasValue(1); + + memoryDataManager.append(createRecord(3), 0, Buffer.DataType.DATA_BUFFER); + memoryDataManager.append(createRecord(4), 0, Buffer.DataType.EVENT_BUFFER); + assertThat(finishedBuffers).hasValue(3); + } + + @Test + void testAppendRequestBuffer() throws Exception { + poolSize = 3; + List<Tuple2<Integer, Integer>> numFinishedBufferAndPoolSize = new ArrayList<>(); + HsSpillingStrategy spillingStrategy = + TestingSpillingStrategy.builder() + .setOnMemoryUsageChangedFunction( + (finishedBuffer, poolSize) -> { + numFinishedBufferAndPoolSize.add( + Tuple2.of(finishedBuffer, poolSize)); + return Optional.of(Decision.NO_ACTION); + }) + .build(); + HsMemoryDataManager memoryDataManager = createMemoryDataManager(spillingStrategy); + 
memoryDataManager.append(createRecord(0), 0, Buffer.DataType.DATA_BUFFER); + memoryDataManager.append(createRecord(1), 1, Buffer.DataType.DATA_BUFFER); + memoryDataManager.append(createRecord(2), 2, Buffer.DataType.DATA_BUFFER); + assertThat(memoryDataManager.getNumTotalRequestedBuffers()).isEqualTo(3); + List<Tuple2<Integer, Integer>> expectedFinishedBufferAndPoolSize = + Arrays.asList(Tuple2.of(1, 3), Tuple2.of(2, 3), Tuple2.of(3, 3)); + assertThat(numFinishedBufferAndPoolSize).isEqualTo(expectedFinishedBufferAndPoolSize); + } + + @Test + void testHandleDecision() throws Exception { + final int targetSubpartition = 0; + final int numFinishedBufferToTriggerDecision = 4; + List<BufferIndexAndChannel> toSpill = + HybridShuffleTestUtils.createBufferIndexAndChannelsList( + targetSubpartition, 0, 1, 2); + List<BufferIndexAndChannel> toRelease = + HybridShuffleTestUtils.createBufferIndexAndChannelsList(targetSubpartition, 2, 3); + HsSpillingStrategy spillingStrategy = + TestingSpillingStrategy.builder() + .setOnBufferFinishedFunction( + (numFinishedBuffers) -> { + if (numFinishedBuffers < numFinishedBufferToTriggerDecision) { + return Optional.of(Decision.NO_ACTION); + } + return Optional.of( + Decision.builder() + .addBufferToSpill(targetSubpartition, toSpill) + .addBufferToRelease( + targetSubpartition, toRelease) + .build()); + }) + .build(); + CompletableFuture<List<SpilledBuffer>> spilledFuture = new CompletableFuture<>(); + CompletableFuture<Integer> readableFuture = new CompletableFuture<>(); + TestingFileDataIndex dataIndex = + TestingFileDataIndex.builder() + .setAddBuffersConsumer(spilledFuture::complete) + .setMarkBufferReadableConsumer( + (subpartitionId, bufferIndex) -> + readableFuture.complete(bufferIndex)) + .build(); + HsMemoryDataManager memoryDataManager = + createMemoryDataManager(spillingStrategy, dataIndex); + for (int i = 0; i < 4; i++) { + memoryDataManager.append( + createRecord(i), targetSubpartition, Buffer.DataType.DATA_BUFFER); + } + + 
assertThat(spilledFuture).succeedsWithin(10, TimeUnit.SECONDS); + assertThat(readableFuture).succeedsWithin(10, TimeUnit.SECONDS); + assertThat(readableFuture).isCompletedWithValue(2); + assertThat(memoryDataManager.getNumTotalUnSpillBuffers()).isEqualTo(1); + } + + @Test + void testHandleEmptyDecision() throws Exception { + CompletableFuture<Void> globalDecisionFuture = new CompletableFuture<>(); + HsSpillingStrategy spillingStrategy = + TestingSpillingStrategy.builder() + .setOnBufferFinishedFunction( + (finishedBuffer) -> { + // return empty optional to trigger global decision. + return Optional.empty(); + }) + .setDecideActionWithGlobalInfoFunction( + (provider) -> { + globalDecisionFuture.complete(null); + return Decision.NO_ACTION; + }) + .build(); + HsMemoryDataManager memoryDataManager = createMemoryDataManager(spillingStrategy); + // trigger an empty decision. + memoryDataManager.onBufferFinished(); + assertThat(globalDecisionFuture).isCompleted(); + } + + private HsMemoryDataManager createMemoryDataManager(HsSpillingStrategy spillStrategy) + throws Exception { + NetworkBufferPool networkBufferPool = new NetworkBufferPool(NUM_BUFFERS, bufferSize); + BufferPool bufferPool = networkBufferPool.createBufferPool(poolSize, poolSize); + return new HsMemoryDataManager( + NUM_SUBPARTITIONS, + bufferSize, + bufferPool, + spillStrategy, + new HsFileDataIndexImpl(NUM_SUBPARTITIONS), + dataFileChannel); + } + + private HsMemoryDataManager createMemoryDataManager( + HsSpillingStrategy spillStrategy, HsFileDataIndex fileDataIndex) throws Exception { + NetworkBufferPool networkBufferPool = new NetworkBufferPool(NUM_BUFFERS, bufferSize); + BufferPool bufferPool = networkBufferPool.createBufferPool(poolSize, poolSize); + return new HsMemoryDataManager( + NUM_SUBPARTITIONS, + bufferSize, + bufferPool, + spillStrategy, + fileDataIndex, + dataFileChannel); + } + + private static ByteBuffer createRecord(int value) { + ByteBuffer byteBuffer = ByteBuffer.allocate(Integer.BYTES); 
+ byteBuffer.putInt(value); + byteBuffer.flip(); + return byteBuffer; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java index 9e7d7254dba..6862c4c6e4a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSelectiveSpillingStrategyTest.java @@ -29,7 +29,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; -import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsList; +import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsList; import static org.assertj.core.api.Assertions.assertThat; /** Tests for {@link HsSelectiveSpillingStrategy}. 
*/ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java index 98d3ab7907f..b5849a48644 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyUtilsTest.java @@ -26,8 +26,8 @@ import java.util.Deque; import java.util.List; import java.util.TreeMap; -import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsDeque; -import static org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingStrategyTestUtils.createBufferIndexAndChannelsList; +import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsDeque; +import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferIndexAndChannelsList; import static org.assertj.core.api.Assertions.assertThat; /** Tests for {@link HsSpillingStrategyUtils}. */ diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java new file mode 100644 index 00000000000..a15dfddc9f2 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSubpartitionMemoryDataManagerTest.java @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.Buffer.DataType; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; +import org.apache.flink.runtime.io.network.partition.hybrid.HsMemoryDataManager.BufferAndNextDataType; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.ConsumeStatus; +import org.apache.flink.runtime.io.network.partition.hybrid.HsSpillingInfoProvider.SpillStatus; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.stream.Collectors; + +import static org.apache.flink.runtime.io.network.partition.hybrid.HybridShuffleTestUtils.createBufferBuilder; +import static org.apache.flink.util.Preconditions.checkArgument; +import static 
org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link HsSubpartitionMemoryDataManager}. */ +class HsSubpartitionMemoryDataManagerTest { + private static final int SUBPARTITION_ID = 0; + + private static final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + + private static final int RECORD_SIZE = Integer.BYTES; + + private int bufferSize = RECORD_SIZE; + + @Test + void testAppendDataRequestBuffer() throws Exception { + CompletableFuture<Void> requestBufferFuture = new CompletableFuture<>(); + HsMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier( + () -> { + requestBufferFuture.complete(null); + return createBufferBuilder(bufferSize); + }) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER); + assertThat(requestBufferFuture).isCompleted(); + } + + @Test + void testAppendEventNotRequestBuffer() throws Exception { + CompletableFuture<Void> requestBufferFuture = new CompletableFuture<>(); + HsMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier( + () -> { + requestBufferFuture.complete(null); + return null; + }) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + subpartitionMemoryDataManager.append(createRecord(0), DataType.EVENT_BUFFER); + assertThat(requestBufferFuture).isNotDone(); + } + + @Test + void testAppendEventFinishCurrentBuffer() throws Exception { + bufferSize = RECORD_SIZE * 3; + AtomicInteger finishedBuffers = new AtomicInteger(0); + HsMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> 
createBufferBuilder(bufferSize)) + .setOnBufferFinishedRunnable(finishedBuffers::incrementAndGet) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER); + subpartitionMemoryDataManager.append(createRecord(1), DataType.DATA_BUFFER); + assertThat(finishedBuffers).hasValue(0); + subpartitionMemoryDataManager.append(createRecord(2), DataType.EVENT_BUFFER); + assertThat(finishedBuffers).hasValue(2); + } + + @Test + void testPeekNextToConsumeDataTypeNotMeetBufferIndexToConsume() throws Exception { + TestingMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE)) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + + subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER); + + assertThat(subpartitionMemoryDataManager.peekNextToConsumeDataType(1)) + .isEqualTo(DataType.NONE); + } + + @Test + void testPeekNextToConsumeDataTypeTrimHeadingReleasedBuffers() throws Exception { + TestingMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE)) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + + subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER); + subpartitionMemoryDataManager.append(createRecord(1), DataType.DATA_BUFFER); + subpartitionMemoryDataManager.append(createRecord(2), DataType.EVENT_BUFFER); + + List<BufferIndexAndChannel> toRelease = + HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1); + 
subpartitionMemoryDataManager.releaseSubpartitionBuffers(toRelease); + + assertThat(subpartitionMemoryDataManager.peekNextToConsumeDataType(2)) + .isEqualTo(DataType.EVENT_BUFFER); + } + + @Test + void testConsumeBufferFirstUnConsumedBufferIndexNotMeetNextToConsume() throws Exception { + TestingMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE)) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + + subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER); + + assertThat(subpartitionMemoryDataManager.consumeBuffer(1)).isNotPresent(); + } + + @Test + void testConsumeBufferTrimHeadingReleasedBuffers() throws Exception { + TestingMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE)) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + + subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER); + subpartitionMemoryDataManager.append(createRecord(1), DataType.DATA_BUFFER); + subpartitionMemoryDataManager.append(createRecord(2), DataType.EVENT_BUFFER); + + List<BufferIndexAndChannel> toRelease = + HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1); + subpartitionMemoryDataManager.releaseSubpartitionBuffers(toRelease); + + assertThat(subpartitionMemoryDataManager.consumeBuffer(2)).isPresent(); + } + + @Test + void testConsumeBuffer() throws Exception { + List<BufferIndexAndChannel> consumedBufferIndexAndChannel = new ArrayList<>(); + TestingMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> 
createBufferBuilder(RECORD_SIZE)) + .setOnBufferConsumedConsumer(consumedBufferIndexAndChannel::add) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + + subpartitionMemoryDataManager.append(createRecord(0), DataType.DATA_BUFFER); + subpartitionMemoryDataManager.append(createRecord(1), DataType.DATA_BUFFER); + subpartitionMemoryDataManager.append(createRecord(2), DataType.EVENT_BUFFER); + + List<Tuple2<Integer, Buffer.DataType>> expectedRecords = new ArrayList<>(); + expectedRecords.add(Tuple2.of(0, Buffer.DataType.DATA_BUFFER)); + expectedRecords.add(Tuple2.of(1, Buffer.DataType.DATA_BUFFER)); + expectedRecords.add(Tuple2.of(2, DataType.EVENT_BUFFER)); + checkConsumedBufferAndNextDataType( + expectedRecords, + Arrays.asList( + subpartitionMemoryDataManager.consumeBuffer(0), + subpartitionMemoryDataManager.consumeBuffer(1), + subpartitionMemoryDataManager.consumeBuffer(2))); + + List<BufferIndexAndChannel> expectedBufferIndexAndChannel = + HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1, 2); + assertThat(consumedBufferIndexAndChannel) + .zipSatisfy( + expectedBufferIndexAndChannel, + (consumed, expected) -> { + assertThat(consumed.getChannel()).isEqualTo(expected.getChannel()); + assertThat(consumed.getBufferIndex()) + .isEqualTo(expected.getBufferIndex()); + }); + } + + @Test + void testGetBuffersSatisfyStatus() throws Exception { + TestingMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE)) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + final int numBuffers = 4; + for (int i = 0; i < numBuffers; i++) { + subpartitionMemoryDataManager.append(createRecord(i), DataType.DATA_BUFFER); + } + + // spill buffer 1 and 2 + 
List<BufferIndexAndChannel> toStartSpilling = + HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 1, 2); + CompletableFuture<Void> spilledDoneFuture = new CompletableFuture<>(); + subpartitionMemoryDataManager.spillSubpartitionBuffers(toStartSpilling, spilledDoneFuture); + + // consume buffer 0, 1 + subpartitionMemoryDataManager.consumeBuffer(0); + subpartitionMemoryDataManager.consumeBuffer(1); + + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.ALL, ConsumeStatus.ALL), + Arrays.asList(0, 1, 2, 3)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.ALL, ConsumeStatus.CONSUMED), + Arrays.asList(0, 1)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.ALL, ConsumeStatus.NOT_CONSUMED), + Arrays.asList(2, 3)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.SPILL, ConsumeStatus.ALL), + Arrays.asList(1, 2)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.NOT_SPILL, ConsumeStatus.ALL), + Arrays.asList(0, 3)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.SPILL, ConsumeStatus.NOT_CONSUMED), + Collections.singletonList(2)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.SPILL, ConsumeStatus.CONSUMED), + Collections.singletonList(1)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.NOT_SPILL, ConsumeStatus.CONSUMED), + Collections.singletonList(0)); + checkBufferIndex( + subpartitionMemoryDataManager.getBuffersSatisfyStatus( + SpillStatus.NOT_SPILL, ConsumeStatus.NOT_CONSUMED), + Collections.singletonList(3)); + } + + @Test + void testSpillSubpartitionBuffers() throws Exception { + CompletableFuture<Void> spilledDoneFuture = new CompletableFuture<>(); + TestingMemoryDataManagerOperation memoryDataManagerOperation = + 
TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier(() -> createBufferBuilder(RECORD_SIZE)) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + final int numBuffers = 3; + for (int i = 0; i < numBuffers; i++) { + subpartitionMemoryDataManager.append(createRecord(i), DataType.DATA_BUFFER); + } + + List<BufferIndexAndChannel> toStartSpilling = + HybridShuffleTestUtils.createBufferIndexAndChannelsList(0, 0, 1, 2); + List<BufferWithIdentity> buffers = + subpartitionMemoryDataManager.spillSubpartitionBuffers( + toStartSpilling, spilledDoneFuture); + assertThat(toStartSpilling) + .zipSatisfy( + buffers, + (expected, spilled) -> { + assertThat(expected.getBufferIndex()) + .isEqualTo(spilled.getBufferIndex()); + assertThat(expected.getChannel()).isEqualTo(spilled.getChannelIndex()); + }); + List<Integer> expectedValues = Arrays.asList(0, 1, 2); + checkBuffersRefCountAndValue(buffers, Arrays.asList(2, 2, 2), expectedValues); + spilledDoneFuture.complete(null); + checkBuffersRefCountAndValue(buffers, Arrays.asList(1, 1, 1), expectedValues); + } + + @Test + void testReleaseAndMarkReadableSubpartitionBuffers() throws Exception { + int targetChannel = 0; + List<Integer> readableBufferIndex = new ArrayList<>(); + List<MemorySegment> recycledBuffers = new ArrayList<>(); + TestingMemoryDataManagerOperation memoryDataManagerOperation = + TestingMemoryDataManagerOperation.builder() + .setRequestBufferFromPoolSupplier( + () -> + new BufferBuilder( + MemorySegmentFactory.allocateUnpooledSegment( + bufferSize), + recycledBuffers::add)) + .setMarkBufferReadableConsumer( + (channel, bufferIndex) -> { + assertThat(channel).isEqualTo(targetChannel); + readableBufferIndex.add(bufferIndex); + }) + .build(); + HsSubpartitionMemoryDataManager subpartitionMemoryDataManager = + createSubpartitionMemoryDataManager(memoryDataManagerOperation); + // append data + final int 
numBuffers = 3; + for (int i = 0; i < numBuffers; i++) { + subpartitionMemoryDataManager.append(createRecord(i), DataType.DATA_BUFFER); + } + // spill the last buffer and release all buffers. + List<BufferIndexAndChannel> toRelease = + HybridShuffleTestUtils.createBufferIndexAndChannelsList(targetChannel, 0, 1, 2); + CompletableFuture<Void> spilledFuture = new CompletableFuture<>(); + subpartitionMemoryDataManager.spillSubpartitionBuffers( + toRelease.subList(numBuffers - 1, numBuffers), spilledFuture); + subpartitionMemoryDataManager.releaseSubpartitionBuffers(toRelease); + assertThat(readableBufferIndex).isEmpty(); + // buffers that have not started spilling should be recycled after release. + checkMemorySegmentValue(recycledBuffers, Arrays.asList(0, 1)); + + // after spilling has finished, buffers that need to be marked readable should trigger the notify. + spilledFuture.complete(null); + assertThat(readableBufferIndex).containsExactly(2); + checkMemorySegmentValue(recycledBuffers, Arrays.asList(0, 1, 2)); + } + + private static void checkBufferIndex( + Deque<BufferIndexAndChannel> bufferWithIdentities, List<Integer> expectedIndexes) { + List<Integer> bufferIndexes = + bufferWithIdentities.stream() + .map(BufferIndexAndChannel::getBufferIndex) + .collect(Collectors.toList()); + assertThat(bufferIndexes).isEqualTo(expectedIndexes); + } + + private static void checkMemorySegmentValue( + List<MemorySegment> memorySegments, List<Integer> expectedValues) { + for (int i = 0; i < memorySegments.size(); i++) { + assertThat(memorySegments.get(i).getInt(0)).isEqualTo(expectedValues.get(i)); + } + } + + private static void checkConsumedBufferAndNextDataType( + List<Tuple2<Integer, Buffer.DataType>> expectedRecords, + List<Optional<BufferAndNextDataType>> bufferAndNextDataTypesOpt) { + checkArgument(expectedRecords.size() == bufferAndNextDataTypesOpt.size()); + for (int i = 0; i < bufferAndNextDataTypesOpt.size(); i++) { + final int index = i; + assertThat(bufferAndNextDataTypesOpt.get(index)) + 
.hasValueSatisfying( + (bufferAndNextDataType -> { + Buffer buffer = bufferAndNextDataType.getBuffer(); + int value = + buffer.getNioBufferReadable() + .order(ByteOrder.LITTLE_ENDIAN) + .getInt(); + Buffer.DataType dataType = buffer.getDataType(); + assertThat(value).isEqualTo(expectedRecords.get(index).f0); + assertThat(dataType).isEqualTo(expectedRecords.get(index).f1); + if (index != bufferAndNextDataTypesOpt.size() - 1) { + assertThat(bufferAndNextDataType.getNextDataType()) + .isEqualTo(expectedRecords.get(index + 1).f1); + } else { + assertThat(bufferAndNextDataType.getNextDataType()) + .isEqualTo(Buffer.DataType.NONE); + } + })); + } + } + + private static void checkBuffersRefCountAndValue( + List<BufferWithIdentity> bufferWithIdentities, + List<Integer> expectedRefCounts, + List<Integer> expectedValues) { + for (int i = 0; i < bufferWithIdentities.size(); i++) { + BufferWithIdentity bufferWithIdentity = bufferWithIdentities.get(i); + Buffer buffer = bufferWithIdentity.getBuffer(); + assertThat(buffer.getNioBufferReadable().order(ByteOrder.LITTLE_ENDIAN).getInt()) + .isEqualTo(expectedValues.get(i)); + assertThat(buffer.refCnt()).isEqualTo(expectedRefCounts.get(i)); + } + } + + private HsSubpartitionMemoryDataManager createSubpartitionMemoryDataManager( + HsMemoryDataManagerOperation memoryDataManagerOperation) { + return new HsSubpartitionMemoryDataManager( + SUBPARTITION_ID, bufferSize, lock.readLock(), memoryDataManagerOperation); + } + + private static ByteBuffer createRecord(int value) { + ByteBuffer byteBuffer = ByteBuffer.allocate(RECORD_SIZE); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + byteBuffer.putInt(value); + byteBuffer.flip(); + return byteBuffer; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HybridShuffleTestUtils.java similarity index 72% rename from 
flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyTestUtils.java rename to flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HybridShuffleTestUtils.java index e959c150f63..0d7ec874ba9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HsSpillingStrategyTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/HybridShuffleTestUtils.java @@ -20,6 +20,9 @@ package org.apache.flink.runtime.io.network.partition.hybrid; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; import java.util.ArrayDeque; @@ -27,8 +30,8 @@ import java.util.ArrayList; import java.util.Deque; import java.util.List; -/** Test utils for {@link HsSpillingStrategy}. */ -public class HsSpillingStrategyTestUtils { +/** Test utils for hybrid shuffle mode. */ +public class HybridShuffleTestUtils { public static final int MEMORY_SEGMENT_SIZE = 128; public static List<BufferIndexAndChannel> createBufferIndexAndChannelsList( @@ -51,4 +54,18 @@ public class HsSpillingStrategyTestUtils { } return bufferIndexAndChannels; } + + public static Buffer createBuffer(int bufferSize, boolean isEvent) { + return new NetworkBuffer( + MemorySegmentFactory.allocateUnpooledSegment(bufferSize), + FreeingBufferRecycler.INSTANCE, + isEvent ? 
Buffer.DataType.EVENT_BUFFER : Buffer.DataType.DATA_BUFFER, + bufferSize); + } + + public static BufferBuilder createBufferBuilder(int bufferSize) { + return new BufferBuilder( + MemorySegmentFactory.allocateUnpooledSegment(bufferSize), + FreeingBufferRecycler.INSTANCE); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingFileDataIndex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingFileDataIndex.java new file mode 100644 index 00000000000..db5663bb42f --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingFileDataIndex.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import java.util.List; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +/** Mock {@link HsFileDataIndex} for testing. 
*/ +public class TestingFileDataIndex implements HsFileDataIndex { + private final BiFunction<Integer, Integer, Optional<ReadableRegion>> getReadableRegionFunction; + + private final Consumer<List<SpilledBuffer>> addBuffersConsumer; + + private final BiConsumer<Integer, Integer> markBufferReadableConsumer; + + private TestingFileDataIndex( + BiFunction<Integer, Integer, Optional<ReadableRegion>> getReadableRegionFunction, + Consumer<List<SpilledBuffer>> addBuffersConsumer, + BiConsumer<Integer, Integer> markBufferReadableConsumer) { + this.getReadableRegionFunction = getReadableRegionFunction; + this.addBuffersConsumer = addBuffersConsumer; + this.markBufferReadableConsumer = markBufferReadableConsumer; + } + + @Override + public Optional<ReadableRegion> getReadableRegion(int subpartitionId, int bufferIndex) { + return getReadableRegionFunction.apply(subpartitionId, bufferIndex); + } + + @Override + public void addBuffers(List<SpilledBuffer> spilledBuffers) { + addBuffersConsumer.accept(spilledBuffers); + } + + @Override + public void markBufferReadable(int subpartitionId, int bufferIndex) { + markBufferReadableConsumer.accept(subpartitionId, bufferIndex); + } + + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link TestingFileDataIndex}. 
*/ + public static class Builder { + private BiFunction<Integer, Integer, Optional<ReadableRegion>> getReadableRegionFunction = + (ignore1, ignore2) -> Optional.empty(); + + private Consumer<List<SpilledBuffer>> addBuffersConsumer = (ignore) -> {}; + + private BiConsumer<Integer, Integer> markBufferReadableConsumer = (ignore1, ignore2) -> {}; + + private Builder() {} + + public Builder setGetReadableRegionFunction( + BiFunction<Integer, Integer, Optional<ReadableRegion>> getReadableRegionFunction) { + this.getReadableRegionFunction = getReadableRegionFunction; + return this; + } + + public Builder setAddBuffersConsumer(Consumer<List<SpilledBuffer>> addBuffersConsumer) { + this.addBuffersConsumer = addBuffersConsumer; + return this; + } + + public Builder setMarkBufferReadableConsumer( + BiConsumer<Integer, Integer> markBufferReadableConsumer) { + this.markBufferReadableConsumer = markBufferReadableConsumer; + return this; + } + + public TestingFileDataIndex build() { + return new TestingFileDataIndex( + getReadableRegionFunction, addBuffersConsumer, markBufferReadableConsumer); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingMemoryDataManagerOperation.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingMemoryDataManagerOperation.java new file mode 100644 index 00000000000..f78774ca674 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingMemoryDataManagerOperation.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import org.apache.flink.runtime.io.network.buffer.BufferBuilder; +import org.apache.flink.util.function.SupplierWithException; + +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +/** Mock {@link HsMemoryDataManagerOperation} for testing. */ +public class TestingMemoryDataManagerOperation implements HsMemoryDataManagerOperation { + private final SupplierWithException<BufferBuilder, InterruptedException> + requestBufferFromPoolSupplier; + + private final BiConsumer<Integer, Integer> markBufferReadableConsumer; + + private final Consumer<BufferIndexAndChannel> onBufferConsumedConsumer; + + private final Runnable onBufferFinishedRunnable; + + private TestingMemoryDataManagerOperation( + SupplierWithException<BufferBuilder, InterruptedException> + requestBufferFromPoolSupplier, + BiConsumer<Integer, Integer> markBufferReadableConsumer, + Consumer<BufferIndexAndChannel> onBufferConsumedConsumer, + Runnable onBufferFinishedRunnable) { + this.requestBufferFromPoolSupplier = requestBufferFromPoolSupplier; + this.markBufferReadableConsumer = markBufferReadableConsumer; + this.onBufferConsumedConsumer = onBufferConsumedConsumer; + this.onBufferFinishedRunnable = onBufferFinishedRunnable; + } + + @Override + public BufferBuilder requestBufferFromPool() throws InterruptedException { + return requestBufferFromPoolSupplier.get(); + } + + @Override + public void markBufferReadableFromFile(int subpartitionId, int bufferIndex) { + 
markBufferReadableConsumer.accept(subpartitionId, bufferIndex); + } + + @Override + public void onBufferConsumed(BufferIndexAndChannel consumedBuffer) { + onBufferConsumedConsumer.accept(consumedBuffer); + } + + @Override + public void onBufferFinished() { + onBufferFinishedRunnable.run(); + } + + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link TestingMemoryDataManagerOperation}. */ + public static class Builder { + private SupplierWithException<BufferBuilder, InterruptedException> + requestBufferFromPoolSupplier = () -> null; + + private BiConsumer<Integer, Integer> markBufferReadableConsumer = (ignore1, ignore2) -> {}; + + private Consumer<BufferIndexAndChannel> onBufferConsumedConsumer = (ignore1) -> {}; + + private Runnable onBufferFinishedRunnable = () -> {}; + + public Builder setRequestBufferFromPoolSupplier( + SupplierWithException<BufferBuilder, InterruptedException> + requestBufferFromPoolSupplier) { + this.requestBufferFromPoolSupplier = requestBufferFromPoolSupplier; + return this; + } + + public Builder setMarkBufferReadableConsumer( + BiConsumer<Integer, Integer> markBufferReadableConsumer) { + this.markBufferReadableConsumer = markBufferReadableConsumer; + return this; + } + + public Builder setOnBufferConsumedConsumer( + Consumer<BufferIndexAndChannel> onBufferConsumedConsumer) { + this.onBufferConsumedConsumer = onBufferConsumedConsumer; + return this; + } + + public Builder setOnBufferFinishedRunnable(Runnable onBufferFinishedRunnable) { + this.onBufferFinishedRunnable = onBufferFinishedRunnable; + return this; + } + + private Builder() {} + + public TestingMemoryDataManagerOperation build() { + return new TestingMemoryDataManagerOperation( + requestBufferFromPoolSupplier, + markBufferReadableConsumer, + onBufferConsumedConsumer, + onBufferFinishedRunnable); + } + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingStrategy.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingStrategy.java new file mode 100644 index 00000000000..4cce53db0a9 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/hybrid/TestingSpillingStrategy.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.partition.hybrid; + +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** Mock {@link HsSpillingStrategy} for testing. 
*/ +public class TestingSpillingStrategy implements HsSpillingStrategy { + private final BiFunction<Integer, Integer, Optional<Decision>> onMemoryUsageChangedFunction; + + private final Function<Integer, Optional<Decision>> onBufferFinishedFunction; + + private final Function<BufferIndexAndChannel, Optional<Decision>> onBufferConsumedFunction; + + private final Function<HsSpillingInfoProvider, Decision> decideActionWithGlobalInfoFunction; + + private TestingSpillingStrategy( + BiFunction<Integer, Integer, Optional<Decision>> onMemoryUsageChangedFunction, + Function<Integer, Optional<Decision>> onBufferFinishedFunction, + Function<BufferIndexAndChannel, Optional<Decision>> onBufferConsumedFunction, + Function<HsSpillingInfoProvider, Decision> decideActionWithGlobalInfoFunction) { + this.onMemoryUsageChangedFunction = onMemoryUsageChangedFunction; + this.onBufferFinishedFunction = onBufferFinishedFunction; + this.onBufferConsumedFunction = onBufferConsumedFunction; + this.decideActionWithGlobalInfoFunction = decideActionWithGlobalInfoFunction; + } + + @Override + public Optional<Decision> onMemoryUsageChanged( + int numTotalRequestedBuffers, int currentPoolSize) { + return onMemoryUsageChangedFunction.apply(numTotalRequestedBuffers, currentPoolSize); + } + + @Override + public Optional<Decision> onBufferFinished(int numTotalUnSpillBuffers) { + return onBufferFinishedFunction.apply(numTotalUnSpillBuffers); + } + + @Override + public Optional<Decision> onBufferConsumed(BufferIndexAndChannel consumedBuffer) { + return onBufferConsumedFunction.apply(consumedBuffer); + } + + @Override + public Decision decideActionWithGlobalInfo(HsSpillingInfoProvider spillingInfoProvider) { + return decideActionWithGlobalInfoFunction.apply(spillingInfoProvider); + } + + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link TestingSpillingStrategy}. 
*/ + public static class Builder { + private BiFunction<Integer, Integer, Optional<Decision>> onMemoryUsageChangedFunction = + (ignore1, ignore2) -> Optional.of(Decision.NO_ACTION); + + private Function<Integer, Optional<Decision>> onBufferFinishedFunction = + (ignore) -> Optional.of(Decision.NO_ACTION); + + private Function<BufferIndexAndChannel, Optional<Decision>> onBufferConsumedFunction = + (ignore) -> Optional.of(Decision.NO_ACTION); + + private Function<HsSpillingInfoProvider, Decision> decideActionWithGlobalInfoFunction = + (ignore) -> Decision.NO_ACTION; + + private Builder() {} + + public Builder setOnMemoryUsageChangedFunction( + BiFunction<Integer, Integer, Optional<Decision>> onMemoryUsageChangedFunction) { + this.onMemoryUsageChangedFunction = onMemoryUsageChangedFunction; + return this; + } + + public Builder setOnBufferFinishedFunction( + Function<Integer, Optional<Decision>> onBufferFinishedFunction) { + this.onBufferFinishedFunction = onBufferFinishedFunction; + return this; + } + + public Builder setOnBufferConsumedFunction( + Function<BufferIndexAndChannel, Optional<Decision>> onBufferConsumedFunction) { + this.onBufferConsumedFunction = onBufferConsumedFunction; + return this; + } + + public Builder setDecideActionWithGlobalInfoFunction( + Function<HsSpillingInfoProvider, Decision> decideActionWithGlobalInfoFunction) { + this.decideActionWithGlobalInfoFunction = decideActionWithGlobalInfoFunction; + return this; + } + + public TestingSpillingStrategy build() { + return new TestingSpillingStrategy( + onMemoryUsageChangedFunction, + onBufferFinishedFunction, + onBufferConsumedFunction, + decideActionWithGlobalInfoFunction); + } + } +}
