This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 907364db [CELEBORN-156] add remoteShuffleResultPartition in 
flink-plugin (#1103)
907364db is described below

commit 907364dbf2236230a591406fe4bb2d62fa3b2e47
Author: zhongqiangczq <[email protected]>
AuthorDate: Wed Dec 21 12:22:17 2022 +0800

    [CELEBORN-156] add remoteShuffleResultPartition in flink-plugin (#1103)
---
 .../plugin/flink/RemoteShuffleResultPartition.java | 380 ++++++++++++++++++
 .../flink/RemoteShuffleResultPartitionSuiteJ.java  |  93 +++++
 .../plugin/flink/buffer/PartitionSortedBuffer.java | 440 +++++++++++++++++++++
 .../celeborn/plugin/flink/buffer/SortBuffer.java   |  92 +++++
 4 files changed, 1005 insertions(+)

diff --git 
a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
 
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
new file mode 100644
index 00000000..6be7c665
--- /dev/null
+++ 
b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartition.java
@@ -0,0 +1,380 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.CompletableFuture;
+
+import javax.annotation.Nullable;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import 
org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.plugin.flink.buffer.PartitionSortedBuffer;
+import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+import org.apache.celeborn.plugin.flink.utils.BufferUtils;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+/**
+ * A {@link ResultPartition} which appends records and events to {@link 
SortBuffer} and after the
+ * {@link SortBuffer} is full, all data in the {@link SortBuffer} will be 
copied and spilled to the
+ * remote shuffle service in subpartition index order sequentially. Large 
records that can not be
+ * appended to an empty {@link 
org.apache.flink.runtime.io.network.partition.SortBuffer} will be
+ * spilled directly.
+ */
+public class RemoteShuffleResultPartition extends ResultPartition {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(RemoteShuffleResultPartition.class);
+
+  /** Size of network buffer and write buffer. */
+  private final int networkBufferSize;
+
+  /** {@link SortBuffer} for records sent by {@link 
#broadcastRecord(ByteBuffer)}. */
+  private SortBuffer broadcastSortBuffer;
+
+  /** {@link SortBuffer} for records sent by {@link #emitRecord(ByteBuffer, 
int)}. */
+  private SortBuffer unicastSortBuffer;
+
+  /** Utility to spill data to shuffle workers. */
+  private final RemoteShuffleOutputGate outputGate;
+
+  public RemoteShuffleResultPartition(
+      String owningTaskName,
+      int partitionIndex,
+      ResultPartitionID partitionId,
+      ResultPartitionType partitionType,
+      int numSubpartitions,
+      int numTargetKeyGroups,
+      int networkBufferSize,
+      ResultPartitionManager partitionManager,
+      @Nullable BufferCompressor bufferCompressor,
+      SupplierWithException<BufferPool, IOException> bufferPoolFactory,
+      RemoteShuffleOutputGate outputGate) {
+
+    super(
+        owningTaskName,
+        partitionIndex,
+        partitionId,
+        partitionType,
+        numSubpartitions,
+        numTargetKeyGroups,
+        partitionManager,
+        bufferCompressor,
+        bufferPoolFactory);
+
+    this.networkBufferSize = networkBufferSize;
+    this.outputGate = outputGate;
+  }
+
+  @Override
+  public void setup() throws IOException {
+    LOG.info("Setup {}", this);
+    super.setup();
+    BufferUtils.reserveNumRequiredBuffers(bufferPool, 1);
+    try {
+      outputGate.setup();
+    } catch (Throwable throwable) {
+      LOG.error("Failed to setup remote output gate.", throwable);
+      Utils.rethrowAsRuntimeException(throwable);
+    }
+  }
+
+  @Override
+  public void emitRecord(ByteBuffer record, int targetSubpartition) throws 
IOException {
+    emit(record, targetSubpartition, DataType.DATA_BUFFER, false);
+  }
+
+  @Override
+  public void broadcastRecord(ByteBuffer record) throws IOException {
+    broadcast(record, DataType.DATA_BUFFER);
+  }
+
+  @Override
+  public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) 
throws IOException {
+    Buffer buffer = EventSerializer.toBuffer(event, isPriorityEvent);
+    try {
+      ByteBuffer serializedEvent = buffer.getNioBufferReadable();
+      broadcast(serializedEvent, buffer.getDataType());
+    } finally {
+      buffer.recycleBuffer();
+    }
+  }
+
+  private void broadcast(ByteBuffer record, DataType dataType) throws 
IOException {
+    emit(record, 0, dataType, true);
+  }
+
+  private void emit(
+      ByteBuffer record, int targetSubpartition, DataType dataType, boolean 
isBroadcast)
+      throws IOException {
+
+    checkInProduceState();
+    if (isBroadcast) {
+      Preconditions.checkState(
+          targetSubpartition == 0, "Target subpartition index can only be 0 
when broadcast.");
+    }
+
+    SortBuffer sortBuffer = isBroadcast ? getBroadcastSortBuffer() : 
getUnicastSortBuffer();
+    if (sortBuffer.append(record, targetSubpartition, dataType)) {
+      return;
+    }
+
+    try {
+      if (!sortBuffer.hasRemaining()) {
+        // the record can not be appended to the free sort buffer because it 
is too large
+        sortBuffer.finish();
+        sortBuffer.release();
+        writeLargeRecord(record, targetSubpartition, dataType, isBroadcast);
+        return;
+      }
+      flushSortBuffer(sortBuffer, isBroadcast);
+    } catch (InterruptedException e) {
+      LOG.error("Failed to flush the sort buffer.", e);
+      Utils.rethrowAsRuntimeException(e);
+    }
+    emit(record, targetSubpartition, dataType, isBroadcast);
+  }
+
+  private void releaseSortBuffer(SortBuffer sortBuffer) {
+    if (sortBuffer != null) {
+      sortBuffer.release();
+    }
+  }
+
+  @VisibleForTesting
+  SortBuffer getUnicastSortBuffer() throws IOException {
+    flushBroadcastSortBuffer();
+
+    if (unicastSortBuffer != null && !unicastSortBuffer.isFinished()) {
+      return unicastSortBuffer;
+    }
+
+    unicastSortBuffer =
+        new PartitionSortedBuffer(bufferPool, numSubpartitions, 
networkBufferSize, null);
+    return unicastSortBuffer;
+  }
+
+  private SortBuffer getBroadcastSortBuffer() throws IOException {
+    flushUnicastSortBuffer();
+
+    if (broadcastSortBuffer != null && !broadcastSortBuffer.isFinished()) {
+      return broadcastSortBuffer;
+    }
+
+    broadcastSortBuffer =
+        new PartitionSortedBuffer(bufferPool, numSubpartitions, 
networkBufferSize, null);
+    return broadcastSortBuffer;
+  }
+
+  private void flushBroadcastSortBuffer() throws IOException {
+    flushSortBuffer(broadcastSortBuffer, true);
+  }
+
+  private void flushUnicastSortBuffer() throws IOException {
+    flushSortBuffer(unicastSortBuffer, false);
+  }
+
+  @VisibleForTesting
+  void flushSortBuffer(SortBuffer sortBuffer, boolean isBroadcast) throws 
IOException {
+    if (sortBuffer == null || sortBuffer.isReleased()) {
+      return;
+    }
+    sortBuffer.finish();
+    if (sortBuffer.hasRemaining()) {
+      try {
+        outputGate.regionStart(isBroadcast);
+        while (sortBuffer.hasRemaining()) {
+          MemorySegment segment = 
outputGate.getBufferPool().requestMemorySegmentBlocking();
+          SortBuffer.BufferWithChannel bufferWithChannel;
+          try {
+            bufferWithChannel =
+                sortBuffer.copyIntoSegment(
+                    segment, outputGate.getBufferPool(), 
BufferUtils.HEADER_LENGTH);
+          } catch (Throwable t) {
+            outputGate.getBufferPool().recycle(segment);
+            throw new FlinkRuntimeException("Shuffle write failure.", t);
+          }
+
+          Buffer buffer = bufferWithChannel.getBuffer();
+          int subpartitionIndex = bufferWithChannel.getChannelIndex();
+          updateStatistics(bufferWithChannel.getBuffer());
+          writeCompressedBufferIfPossible(buffer, subpartitionIndex);
+        }
+        outputGate.regionFinish();
+      } catch (InterruptedException e) {
+        throw new IOException("Failed to flush the sort buffer, broadcast=" + 
isBroadcast, e);
+      }
+    }
+    releaseSortBuffer(sortBuffer);
+  }
+
+  private void writeCompressedBufferIfPossible(Buffer buffer, int 
targetSubpartition)
+      throws InterruptedException {
+    Buffer compressedBuffer = null;
+    try {
+      if (canBeCompressed(buffer)) {
+        Buffer dataBuffer =
+            buffer.readOnlySlice(
+                BufferUtils.HEADER_LENGTH, buffer.getSize() - 
BufferUtils.HEADER_LENGTH);
+        compressedBuffer =
+            
Utils.checkNotNull(bufferCompressor).compressToIntermediateBuffer(dataBuffer);
+      }
+      BufferUtils.setCompressedDataWithHeader(buffer, compressedBuffer);
+    } catch (Throwable throwable) {
+      buffer.recycleBuffer();
+      throw new RuntimeException("Shuffle write failure.", throwable);
+    } finally {
+      if (compressedBuffer != null && compressedBuffer.isCompressed()) {
+        compressedBuffer.setReaderIndex(0);
+        compressedBuffer.recycleBuffer();
+      }
+    }
+    outputGate.write(buffer, targetSubpartition);
+  }
+
+  private void updateStatistics(Buffer buffer) {
+    numBuffersOut.inc();
+    numBytesOut.inc(buffer.readableBytes() - BufferUtils.HEADER_LENGTH);
+  }
+
+  /** Spills the large record into {@link RemoteShuffleOutputGate}. */
+  private void writeLargeRecord(
+      ByteBuffer record, int targetSubpartition, DataType dataType, boolean 
isBroadcast)
+      throws InterruptedException {
+
+    outputGate.regionStart(isBroadcast);
+    while (record.hasRemaining()) {
+      MemorySegment writeBuffer = 
outputGate.getBufferPool().requestMemorySegmentBlocking();
+      int toCopy = Math.min(record.remaining(), writeBuffer.size() - 
BufferUtils.HEADER_LENGTH);
+      writeBuffer.put(BufferUtils.HEADER_LENGTH, record, toCopy);
+      NetworkBuffer buffer =
+          new NetworkBuffer(
+              writeBuffer,
+              outputGate.getBufferPool(),
+              dataType,
+              toCopy + BufferUtils.HEADER_LENGTH);
+
+      updateStatistics(buffer);
+      writeCompressedBufferIfPossible(buffer, targetSubpartition);
+    }
+    outputGate.regionFinish();
+  }
+
+  @Override
+  public void finish() throws IOException {
+    Utils.checkState(!isReleased(), "Result partition is already released.");
+    broadcastEvent(EndOfPartitionEvent.INSTANCE, false);
+    Utils.checkState(
+        unicastSortBuffer == null || unicastSortBuffer.isReleased(),
+        "The unicast sort buffer should be either null or released.");
+    flushBroadcastSortBuffer();
+    try {
+      outputGate.finish();
+    } catch (InterruptedException e) {
+      throw new IOException("Output gate fails to finish.", e);
+    }
+    super.finish();
+  }
+
+  @Override
+  public synchronized void close() {
+    releaseSortBuffer(unicastSortBuffer);
+    releaseSortBuffer(broadcastSortBuffer);
+    super.close();
+    try {
+      outputGate.close();
+    } catch (Exception e) {
+      Utils.rethrowAsRuntimeException(e);
+    }
+  }
+
+  @Override
+  protected void releaseInternal() {
+    // no-op
+  }
+
+  @Override
+  public void flushAll() {
+    try {
+      flushUnicastSortBuffer();
+      flushBroadcastSortBuffer();
+    } catch (Throwable t) {
+      LOG.error("Failed to flush the current sort buffer.", t);
+      Utils.rethrowAsRuntimeException(t);
+    }
+  }
+
+  @Override
+  public void flush(int subpartitionIndex) {
+    flushAll();
+  }
+
+  @Override
+  public CompletableFuture<?> getAvailableFuture() {
+    return AVAILABLE;
+  }
+
+  @Override
+  public int getNumberOfQueuedBuffers() {
+    return 0;
+  }
+
+  @Override
+  public int getNumberOfQueuedBuffers(int targetSubpartition) {
+    return 0;
+  }
+
+  @Override
+  public ResultSubpartitionView createSubpartitionView(
+      int index, BufferAvailabilityListener availabilityListener) {
+    throw new UnsupportedOperationException("Not supported.");
+  }
+
+  @Override
+  public String toString() {
+    return "ResultPartition "
+        + partitionId.toString()
+        + " ["
+        + partitionType
+        + ", "
+        + numSubpartitions
+        + " subpartitions, shuffle-descriptor: "
+        + outputGate.getShuffleDesc()
+        + "]";
+  }
+}
diff --git 
a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
 
b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
new file mode 100644
index 00000000..a0b620a5
--- /dev/null
+++ 
b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionSuiteJ.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import org.apache.celeborn.plugin.flink.buffer.SortBuffer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.util.function.SupplierWithException;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.mockito.Matchers.anyBoolean;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class RemoteShuffleResultPartitionSuiteJ {
+    private BufferCompressor bufferCompressor =
+            new BufferCompressor(32 * 1024, "lz4");
+    private RemoteShuffleOutputGate remoteShuffleOutputGate = 
mock(RemoteShuffleOutputGate.class);
+
+    @Before
+    public void setup() {
+
+    }
+
+    @Test
+    public void tesSimpleFlush() throws IOException, InterruptedException {
+        List<SupplierWithException<BufferPool, IOException>> bufferPool = 
createBufferPoolFactory();
+        RemoteShuffleResultPartition remoteShuffleResultPartition = new 
RemoteShuffleResultPartition("test",
+                0,
+                new ResultPartitionID(),
+                ResultPartitionType.BLOCKING,
+                2,
+                2,
+                32 * 1024,
+                new ResultPartitionManager(),
+                bufferCompressor,
+                bufferPool.get(0),
+                remoteShuffleOutputGate);
+        remoteShuffleResultPartition.setup();
+        doNothing().when(remoteShuffleOutputGate).regionStart(anyBoolean());
+        doNothing().when(remoteShuffleOutputGate).regionFinish();
+        
when(remoteShuffleOutputGate.getBufferPool()).thenReturn(bufferPool.get(1).get());
+        SortBuffer sortBuffer = 
remoteShuffleResultPartition.getUnicastSortBuffer();
+        ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[] {1, 2, 3});
+        sortBuffer.append(byteBuffer, 0, Buffer.DataType.DATA_BUFFER);
+        remoteShuffleResultPartition.flushSortBuffer(sortBuffer, true);
+    }
+
+    private List<SupplierWithException<BufferPool, IOException>> 
createBufferPoolFactory() {
+        NetworkBufferPool networkBufferPool =
+                new NetworkBufferPool(256 * 8, 32 * 1024, 
Duration.ofMillis(1000));
+
+        int numBuffersPerPartition = 64 * 1024 / 32;
+        int numForResultPartition = numBuffersPerPartition * 7 / 8;
+        int numForOutputGate = numBuffersPerPartition - numForResultPartition;
+
+        List<SupplierWithException<BufferPool, IOException>> factories = new 
ArrayList<>();
+        factories.add(
+                () -> 
networkBufferPool.createBufferPool(numForResultPartition, 
numForResultPartition));
+        factories.add(() -> 
networkBufferPool.createBufferPool(numForOutputGate, numForOutputGate));
+        return factories;
+    }
+
+
+}
diff --git 
a/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/PartitionSortedBuffer.java
 
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/PartitionSortedBuffer.java
new file mode 100644
index 00000000..a4184cac
--- /dev/null
+++ 
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/PartitionSortedBuffer.java
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.buffer;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkArgument;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+import static org.apache.flink.runtime.io.network.buffer.Buffer.DataType;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.NotThreadSafe;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.util.FlinkRuntimeException;
+
+/**
+ * A {@link SortBuffer} implementation which sorts all appended records only 
by subpartition index.
+ * Records of the same subpartition keep the appended order.
+ *
+ * <p>It maintains a list of {@link MemorySegment}s as a joint buffer. Data 
will be appended to the
+ * joint buffer sequentially. When writing a record, an index entry will be 
appended first. An index
+ * entry consists of 4 fields: 4 bytes for record length, 4 bytes for {@link 
DataType} and 8 bytes
+ * for address pointing to the next index entry of the same channel which will 
be used to index the
+ * next record to read when coping data from this {@link SortBuffer}. For 
simplicity, no index entry
+ * can span multiple segments. The corresponding record data is seated right 
after its index entry
+ * and different from the index entry, records have variable length thus may 
span multiple segments.
+ */
+@NotThreadSafe
+public class PartitionSortedBuffer implements SortBuffer {
+
+  /**
+   * Size of an index entry: 4 bytes for record length, 4 bytes for data type 
and 8 bytes for
+   * pointer to next entry.
+   */
+  private static final int INDEX_ENTRY_SIZE = 4 + 4 + 8;
+
+  private final Object lock;
+  /** A buffer pool to request memory segments from. */
+  private final BufferPool bufferPool;
+
+  /** A segment list as a joint buffer which stores all records and index 
entries. */
+  @GuardedBy("lock")
+  private final ArrayList<MemorySegment> buffers = new ArrayList<>();
+
+  /** Addresses of the first record's index entry for each subpartition. */
+  private final long[] firstIndexEntryAddresses;
+
+  /** Addresses of the last record's index entry for each subpartition. */
+  private final long[] lastIndexEntryAddresses;
+  /** Size of buffers requested from buffer pool. All buffers must be of the 
same size. */
+  private final int bufferSize;
+  /** Data of different subpartitions in this sort buffer will be read in this 
order. */
+  private final int[] subpartitionReadOrder;
+
+  // 
---------------------------------------------------------------------------------------------
+  // Statistics and states
+  // 
---------------------------------------------------------------------------------------------
+  /** Total number of bytes already appended to this sort buffer. */
+  private long numTotalBytes;
+  /** Total number of records already appended to this sort buffer. */
+  private long numTotalRecords;
+  /** Total number of bytes already read from this sort buffer. */
+  private long numTotalBytesRead;
+  /** Whether this sort buffer is finished. One can only read a finished sort 
buffer. */
+  private boolean isFinished;
+
+  // 
---------------------------------------------------------------------------------------------
+  // For writing
+  // 
---------------------------------------------------------------------------------------------
+  /** Whether this sort buffer is released. A released sort buffer can not be 
used. */
+  @GuardedBy("lock")
+  private boolean isReleased;
+  /** Array index in the segment list of the current available buffer for 
writing. */
+  private int writeSegmentIndex;
+
+  // 
---------------------------------------------------------------------------------------------
+  // For reading
+  // 
---------------------------------------------------------------------------------------------
+  /** Next position in the current available buffer for writing. */
+  private int writeSegmentOffset;
+  /** Index entry address of the current record or event to be read. */
+  private long readIndexEntryAddress;
+
+  /** Record bytes remaining after last copy, which must be read first in next 
copy. */
+  private int recordRemainingBytes;
+
+  /** Used to index the current available channel to read data from. */
+  private int readOrderIndex = -1;
+
+  public PartitionSortedBuffer(
+      BufferPool bufferPool,
+      int numSubpartitions,
+      int bufferSize,
+      @Nullable int[] customReadOrder) {
+    checkArgument(bufferSize > INDEX_ENTRY_SIZE, "Buffer size is too small.");
+
+    this.lock = new Object();
+    this.bufferPool = checkNotNull(bufferPool);
+    this.bufferSize = bufferSize;
+    this.firstIndexEntryAddresses = new long[numSubpartitions];
+    this.lastIndexEntryAddresses = new long[numSubpartitions];
+
+    // initialized with -1 means the corresponding channel has no data.
+    Arrays.fill(firstIndexEntryAddresses, -1L);
+    Arrays.fill(lastIndexEntryAddresses, -1L);
+
+    this.subpartitionReadOrder = new int[numSubpartitions];
+    if (customReadOrder != null) {
+      checkArgument(customReadOrder.length == numSubpartitions, "Illegal data 
read order.");
+      System.arraycopy(customReadOrder, 0, this.subpartitionReadOrder, 0, 
numSubpartitions);
+    } else {
+      for (int channel = 0; channel < numSubpartitions; ++channel) {
+        this.subpartitionReadOrder[channel] = channel;
+      }
+    }
+  }
+
+  @Override
+  public boolean append(ByteBuffer source, int targetChannel, DataType 
dataType)
+      throws IOException {
+    checkArgument(source.hasRemaining(), "Cannot append empty data.");
+    checkState(!isFinished, "Sort buffer is already finished.");
+    checkState(!isReleased, "Sort buffer is already released.");
+
+    int totalBytes = source.remaining();
+
+    // return false directly if it can not allocate enough buffers for the 
given record
+    if (!allocateBuffersForRecord(totalBytes)) {
+      return false;
+    }
+
+    // write the index entry and record or event data
+    writeIndex(targetChannel, totalBytes, dataType);
+    writeRecord(source);
+
+    ++numTotalRecords;
+    numTotalBytes += totalBytes;
+
+    return true;
+  }
+
+  private void writeIndex(int channelIndex, int numRecordBytes, DataType 
dataType) {
+    MemorySegment segment = buffers.get(writeSegmentIndex);
+
+    // record length takes the high 32 bits and data type takes the low 32 bits
+    segment.putLong(writeSegmentOffset, ((long) numRecordBytes << 32) | 
dataType.ordinal());
+
+    // segment index takes the high 32 bits and segment offset takes the low 
32 bits
+    long indexEntryAddress = ((long) writeSegmentIndex << 32) | 
writeSegmentOffset;
+
+    long lastIndexEntryAddress = lastIndexEntryAddresses[channelIndex];
+    lastIndexEntryAddresses[channelIndex] = indexEntryAddress;
+
+    if (lastIndexEntryAddress >= 0) {
+      // link the previous index entry of the given channel to the new index 
entry
+      segment = buffers.get(getSegmentIndexFromPointer(lastIndexEntryAddress));
+      segment.putLong(getSegmentOffsetFromPointer(lastIndexEntryAddress) + 8, 
indexEntryAddress);
+    } else {
+      firstIndexEntryAddresses[channelIndex] = indexEntryAddress;
+    }
+
+    // move the writer position forward to write the corresponding record
+    updateWriteSegmentIndexAndOffset(INDEX_ENTRY_SIZE);
+  }
+
+  private void writeRecord(ByteBuffer source) {
+    while (source.hasRemaining()) {
+      MemorySegment segment = buffers.get(writeSegmentIndex);
+      int toCopy = Math.min(bufferSize - writeSegmentOffset, 
source.remaining());
+      segment.put(writeSegmentOffset, source, toCopy);
+
+      // move the writer position forward to write the remaining bytes or next 
record
+      updateWriteSegmentIndexAndOffset(toCopy);
+    }
+  }
+
+  private boolean allocateBuffersForRecord(int numRecordBytes) throws 
IOException {
+    int numBytesRequired = INDEX_ENTRY_SIZE + numRecordBytes;
+    int availableBytes = writeSegmentIndex == buffers.size() ? 0 : bufferSize 
- writeSegmentOffset;
+
+    // return directly if current available bytes is adequate
+    if (availableBytes >= numBytesRequired) {
+      return true;
+    }
+
+    // skip the remaining free space if the available bytes is not enough for 
an index entry
+    if (availableBytes < INDEX_ENTRY_SIZE) {
+      updateWriteSegmentIndexAndOffset(availableBytes);
+      availableBytes = 0;
+    }
+
+    // allocate exactly enough buffers for the appended record
+    do {
+      MemorySegment segment = requestBufferFromPool();
+      if (segment == null) {
+        // return false if we can not allocate enough buffers for the appended 
record
+        return false;
+      }
+
+      availableBytes += bufferSize;
+      addBuffer(segment);
+    } while (availableBytes < numBytesRequired);
+
+    return true;
+  }
+
+  private void addBuffer(MemorySegment segment) {
+    synchronized (lock) {
+      if (segment.size() != bufferSize) {
+        bufferPool.recycle(segment);
+        throw new IllegalStateException("Illegal memory segment size.");
+      }
+
+      if (isReleased) {
+        bufferPool.recycle(segment);
+        throw new IllegalStateException("Sort buffer is already released.");
+      }
+
+      buffers.add(segment);
+    }
+  }
+
+  private MemorySegment requestBufferFromPool() throws IOException {
+    try {
+      // blocking request buffers if there is still guaranteed memory
+      if (buffers.size() < bufferPool.getNumberOfRequiredMemorySegments()) {
+        return bufferPool.requestMemorySegmentBlocking();
+      }
+    } catch (InterruptedException e) {
+      throw new IOException("Interrupted while requesting buffer.");
+    }
+
+    return bufferPool.requestMemorySegment();
+  }
+
+  private void updateWriteSegmentIndexAndOffset(int numBytes) {
+    writeSegmentOffset += numBytes;
+
+    // using the next available free buffer if the current is full
+    if (writeSegmentOffset == bufferSize) {
+      ++writeSegmentIndex;
+      writeSegmentOffset = 0;
+    }
+  }
+
  /**
   * Copies as much buffered data as fits into {@code target}, starting at {@code offset},
   * following the per-channel read order prepared by {@code finish()}. Returns the filled buffer
   * wrapped as a {@link BufferWithChannel} together with the channel index the data belongs to.
   *
   * <p>An event is never appended to record bytes already copied into the same output buffer:
   * if the next entry is an event and {@code numBytesCopied > 0}, the loop stops first so the
   * event goes out in its own buffer.
   */
  @Override
  public BufferWithChannel copyIntoSegment(
      MemorySegment target, BufferRecycler recycler, int offset) {
    synchronized (lock) {
      checkState(hasRemaining(), "No data remaining.");
      checkState(isFinished, "Should finish the sort buffer first before coping any data.");
      checkState(!isReleased, "Sort buffer is already released.");

      int numBytesCopied = 0;
      DataType bufferDataType = DataType.DATA_BUFFER;
      int channelIndex = subpartitionReadOrder[readOrderIndex];

      do {
        // decode the read pointer: high 32 bits = segment index, low 32 bits = segment offset
        int sourceSegmentIndex = getSegmentIndexFromPointer(readIndexEntryAddress);
        int sourceSegmentOffset = getSegmentOffsetFromPointer(readIndexEntryAddress);
        MemorySegment sourceSegment = buffers.get(sourceSegmentIndex);

        // the index entry begins with a long packing (record length, data type ordinal)
        long lengthAndDataType = sourceSegment.getLong(sourceSegmentOffset);
        int length = getSegmentIndexFromPointer(lengthAndDataType);
        DataType dataType = DataType.values()[getSegmentOffsetFromPointer(lengthAndDataType)];

        // return the data read directly if the next to read is an event
        if (dataType.isEvent() && numBytesCopied > 0) {
          break;
        }
        bufferDataType = dataType;

        // get the next index entry address and move the read position forward
        long nextReadIndexEntryAddress = sourceSegment.getLong(sourceSegmentOffset + 8);
        sourceSegmentOffset += INDEX_ENTRY_SIZE;

        // throws if the event is too big to be accommodated by a buffer.
        if (bufferDataType.isEvent() && target.size() < length) {
          throw new FlinkRuntimeException("Event is too big to be accommodated by a buffer");
        }

        numBytesCopied +=
            copyRecordOrEvent(
                target, numBytesCopied + offset, sourceSegmentIndex, sourceSegmentOffset, length);

        // recordRemainingBytes == 0 means the current record/event was copied completely;
        // otherwise the same index entry is revisited on the next call to resume the copy
        if (recordRemainingBytes == 0) {
          // move to next channel if the current channel has been finished
          if (readIndexEntryAddress == lastIndexEntryAddresses[channelIndex]) {
            updateReadChannelAndIndexEntryAddress();
            break;
          }
          readIndexEntryAddress = nextReadIndexEntryAddress;
        }
      } while (numBytesCopied < target.size() - offset && bufferDataType.isBuffer());

      numTotalBytesRead += numBytesCopied;
      Buffer buffer = new NetworkBuffer(target, recycler, bufferDataType, numBytesCopied + offset);
      return new BufferWithChannel(buffer, channelIndex);
    }
  }
+
  /**
   * Copies (part of) one record or event from the internal data buffers into
   * {@code targetSegment}, resuming a partially copied record if a previous call left
   * {@code recordRemainingBytes > 0}.
   *
   * @param targetSegment segment to copy into
   * @param targetSegmentOffset write offset within {@code targetSegment}
   * @param sourceSegmentIndex index into {@code buffers} where the record's data starts
   * @param sourceSegmentOffset read offset within that source segment (already past the index
   *     entry)
   * @param recordLength total length in bytes of the record being copied
   * @return the number of bytes actually copied by this call
   */
  private int copyRecordOrEvent(
      MemorySegment targetSegment,
      int targetSegmentOffset,
      int sourceSegmentIndex,
      int sourceSegmentOffset,
      int recordLength) {
    if (recordRemainingBytes > 0) {
      // skip the data already read if there is remaining partial record after the previous
      // copy; long arithmetic avoids int overflow when the record spans many buffers
      long position = (long) sourceSegmentOffset + (recordLength - recordRemainingBytes);
      sourceSegmentIndex += (position / bufferSize);
      sourceSegmentOffset = (int) (position % bufferSize);
    } else {
      recordRemainingBytes = recordLength;
    }

    int targetSegmentSize = targetSegment.size();
    // total bytes this call will copy: bounded by target space and remaining record bytes
    int numBytesToCopy = Math.min(targetSegmentSize - targetSegmentOffset, recordRemainingBytes);
    do {
      // move to next data buffer if all data of the current buffer has been copied
      if (sourceSegmentOffset == bufferSize) {
        ++sourceSegmentIndex;
        sourceSegmentOffset = 0;
      }

      // per-iteration chunk: limited by the current source buffer and the target space
      int sourceRemainingBytes = Math.min(bufferSize - sourceSegmentOffset, recordRemainingBytes);
      int numBytes = Math.min(targetSegmentSize - targetSegmentOffset, sourceRemainingBytes);
      MemorySegment sourceSegment = buffers.get(sourceSegmentIndex);
      sourceSegment.copyTo(sourceSegmentOffset, targetSegment, targetSegmentOffset, numBytes);

      recordRemainingBytes -= numBytes;
      targetSegmentOffset += numBytes;
      sourceSegmentOffset += numBytes;
    } while (recordRemainingBytes > 0 && targetSegmentOffset < targetSegmentSize);

    return numBytesToCopy;
  }
+
+  private void updateReadChannelAndIndexEntryAddress() {
+    // skip the channels without any data
+    while (++readOrderIndex < firstIndexEntryAddresses.length) {
+      int channelIndex = subpartitionReadOrder[readOrderIndex];
+      if ((readIndexEntryAddress = firstIndexEntryAddresses[channelIndex]) >= 
0) {
+        break;
+      }
+    }
+  }
+
+  private int getSegmentIndexFromPointer(long value) {
+    return (int) (value >>> 32);
+  }
+
+  private int getSegmentOffsetFromPointer(long value) {
+    return (int) (value);
+  }
+
  /** Returns the total number of records appended to this sort buffer. */
  @Override
  public long numRecords() {
    return numTotalRecords;
  }
+
  /** Returns the total number of bytes appended to this sort buffer. */
  @Override
  public long numBytes() {
    return numTotalBytes;
  }
+
+  @Override
+  public boolean hasRemaining() {
+    return numTotalBytesRead < numTotalBytes;
+  }
+
+  @Override
+  public void finish() {
+    checkState(
+        !isFinished, "com.alibaba.flink.shuffle.plugin.transfer.SortBuffer is 
already finished.");
+
+    isFinished = true;
+
+    // prepare for reading
+    updateReadChannelAndIndexEntryAddress();
+  }
+
  /** Whether {@link #finish()} has been called on this sort buffer. */
  @Override
  public boolean isFinished() {
    return isFinished;
  }
+
+  @Override
+  public void release() {
+    // the sort buffer can be released by other threads
+    synchronized (lock) {
+      if (isReleased) {
+        return;
+      }
+
+      isReleased = true;
+
+      for (MemorySegment segment : buffers) {
+        bufferPool.recycle(segment);
+      }
+      buffers.clear();
+
+      numTotalBytes = 0;
+      numTotalRecords = 0;
+    }
+  }
+
  /** Whether {@link #release()} has run; read under the lock for cross-thread visibility. */
  @Override
  public boolean isReleased() {
    synchronized (lock) {
      return isReleased;
    }
  }
+}
diff --git 
a/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBuffer.java
 
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBuffer.java
new file mode 100644
index 00000000..7dd43f81
--- /dev/null
+++ 
b/client-flink/flink-common/src/main/java/org/apache/celeborn/plugin/flink/buffer/SortBuffer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink.buffer;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+
/**
 * Data of different channels can be appended to a {@link SortBuffer}; after appending is
 * finished, data can be copied from it in channel index order.
 */
public interface SortBuffer {

  /**
   * Appends data of the specified channel to this {@link SortBuffer} and returns true if all
   * bytes of the source buffer are copied to this {@link SortBuffer} successfully; otherwise, if
   * it returns false, nothing will have been copied.
   */
  boolean append(ByteBuffer source, int targetChannel, Buffer.DataType dataType) throws IOException;

  /**
   * Copies data from this {@link SortBuffer} to the target {@link MemorySegment} in channel
   * index order and returns a {@link BufferWithChannel} which contains the copied data and the
   * corresponding channel index.
   */
  BufferWithChannel copyIntoSegment(MemorySegment target, BufferRecycler recycler, int offset);

  /** Returns the number of records written to this {@link SortBuffer}. */
  long numRecords();

  /** Returns the number of bytes written to this {@link SortBuffer}. */
  long numBytes();

  /** Returns true if there is still data that can be consumed from this {@link SortBuffer}. */
  boolean hasRemaining();

  /** Finishes this {@link SortBuffer}, after which no record can be appended any more. */
  void finish();

  /** Whether this {@link SortBuffer} is finished or not. */
  boolean isFinished();

  /** Releases this {@link SortBuffer}, freeing all of its resources. */
  void release();

  /** Whether this {@link SortBuffer} is released or not. */
  boolean isReleased();

  /** A {@link Buffer} and the corresponding channel index, returned to the reader. */
  class BufferWithChannel {

    private final Buffer buffer;

    private final int channelIndex;

    BufferWithChannel(Buffer buffer, int channelIndex) {
      this.buffer = checkNotNull(buffer);
      this.channelIndex = channelIndex;
    }

    /** Get the wrapped {@link Buffer}. */
    public Buffer getBuffer() {
      return buffer;
    }

    /** Get the channel index the buffer belongs to. */
    public int getChannelIndex() {
      return channelIndex;
    }
  }
}


Reply via email to