This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 3661222d [CELEBORN-195] add implementation to MapPartitionFileWriter (#1141)
3661222d is described below

commit 3661222d98d68d71fa25a3330543fae8830a7cb3
Author: zhongqiangczq <[email protected]>
AuthorDate: Fri Jan 13 16:41:11 2023 +0800

    [CELEBORN-195] add implementation to MapPartitionFileWriter (#1141)
---
 .../worker/storage/MapPartitionFileWriter.java     | 215 ++++++++++++++++++++-
 .../storage/MapPartitionFileWriterSuiteJ.java      | 158 +++++++++++++++
 2 files changed, 367 insertions(+), 6 deletions(-)

diff --git a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java
index 84bf8503..c5d94c9c 100644
--- a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java
+++ b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java
@@ -20,10 +20,14 @@ package org.apache.celeborn.service.deploy.worker.storage;
 import java.io.FileOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.nio.channels.FileChannel;
+import java.util.Arrays;
 
+import io.netty.buffer.ByteBuf;
 import io.netty.buffer.CompositeByteBuf;
-import org.apache.hadoop.fs.FSDataOutputStream;
+import io.netty.buffer.Unpooled;
+import org.apache.hadoop.fs.Path;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -32,6 +36,9 @@ import org.apache.celeborn.common.meta.FileInfo;
 import org.apache.celeborn.common.metrics.source.AbstractSource;
 import org.apache.celeborn.common.protocol.PartitionSplitMode;
 import org.apache.celeborn.common.protocol.PartitionType;
+import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.Utils;
+import org.apache.celeborn.service.deploy.worker.WorkerSource;
 
 /*
  * map partition file writer, it will create index for each partition
@@ -49,7 +56,6 @@ public final class MapPartitionFileWriter extends FileWriter {
   private long regionStartingOffset;
   private long numDataRegions;
   private FileChannel channelIndex;
-  private FSDataOutputStream streamIndex;
   private CompositeByteBuf flushBufferIndex;
 
   public MapPartitionFileWriter(
@@ -75,13 +81,119 @@ public final class MapPartitionFileWriter extends FileWriter {
     if (!fileInfo.isHdfs()) {
       channelIndex = new 
FileOutputStream(fileInfo.getIndexPath()).getChannel();
     } else {
-      streamIndex = 
StorageManager.hadoopFs().create(fileInfo.getHdfsIndexPath(), true);
+      try {
+        StorageManager.hadoopFs().create(fileInfo.getHdfsIndexPath(), 
true).close();
+      } catch (IOException e) {
+        try {
+          // If create file failed, wait 10 ms and retry
+          Thread.sleep(10);
+        } catch (InterruptedException ex) {
+          throw new RuntimeException(ex);
+        }
+        StorageManager.hadoopFs().create(fileInfo.getHdfsIndexPath(), 
true).close();
+      }
     }
+    takeBufferIndex();
+  }
+
+  private void takeBufferIndex() {
+    // metrics start
+    String metricsName = null;
+    String fileAbsPath = null;
+    if (source.metricsCollectCriticalEnabled()) {
+      metricsName = WorkerSource.TakeBufferTimeIndex();
+      fileAbsPath = fileInfo.getIndexPath();
+      source.startTimer(metricsName, fileAbsPath);
+    }
+
+    // real action
+    flushBufferIndex = flusher.takeBuffer();
+
+    // metrics end
+    if (source.metricsCollectCriticalEnabled()) {
+      source.stopTimer(metricsName, fileAbsPath);
+    }
+
+    if (flushBufferIndex == null) {
+      IOException e =
+          new IOException(
+              "Take buffer index encounter error from Flusher: " + 
flusher.bufferQueueInfo());
+      notifier.setException(e);
+    }
+  }
+
+  public void write(ByteBuf data) throws IOException {
+    byte[] header = new byte[16];
+    data.markReaderIndex();
+    data.readBytes(header);
+    data.resetReaderIndex();
+    int partitionId = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET);
+    collectPartitionDataLength(partitionId, data);
+
+    super.write(data);
+  }
+
+  private void collectPartitionDataLength(int partitionId, ByteBuf data) 
throws IOException {
+    if (numReducePartitionBytes == null) {
+      numReducePartitionBytes = new long[numReducePartitions];
+    }
+    if (partitionId < currentReducePartition) {
+      throw new IOException(
+          "Must writing data in reduce partition index order, but now 
partitionId is "
+              + partitionId
+              + " and pre partitionId is "
+              + currentReducePartition);
+    }
+
+    if (partitionId > currentReducePartition) {
+      currentReducePartition = partitionId;
+    }
+    long length = data.readableBytes();
+    totalBytes += length;
+    numReducePartitionBytes[partitionId] += length;
   }
 
   @Override
-  public long close() throws IOException {
-    return 0;
+  public synchronized long close() throws IOException {
+    return super.close(
+        () -> {
+          if (flushBufferIndex.readableBytes() > 0) {
+            flushIndex();
+          }
+        },
+        () -> {
+          if 
(StorageManager.hadoopFs().exists(fileInfo.getHdfsPeerWriterSuccessPath())) {
+            StorageManager.hadoopFs().delete(fileInfo.getHdfsPath(), false);
+            deleted = true;
+          } else {
+            
StorageManager.hadoopFs().create(fileInfo.getHdfsWriterSuccessPath()).close();
+          }
+        },
+        () -> {
+          returnBufferIndex();
+          if (channelIndex != null) {
+            channelIndex.close();
+          }
+          if (fileInfo.isHdfs()) {
+            if (StorageManager.hadoopFs()
+                .exists(
+                    new Path(
+                        Utils.getWriteSuccessFilePath(
+                            Utils.getPeerPath(fileInfo.getIndexPath()))))) {
+              StorageManager.hadoopFs().delete(fileInfo.getHdfsIndexPath(), 
false);
+              deleted = true;
+            } else {
+              StorageManager.hadoopFs()
+                  .create(new 
Path(Utils.getWriteSuccessFilePath((fileInfo.getIndexPath()))))
+                  .close();
+            }
+          }
+        });
+  }
+
+  public synchronized void destroy(IOException ioException) {
+    destroyIndex();
+    super.destroy(ioException);
   }
 
   public void pushDataHandShake(int numReducePartitions, int bufferSize) {
@@ -90,9 +202,100 @@ public final class MapPartitionFileWriter extends FileWriter {
   }
 
   public void regionStart(int currentDataRegionIndex, boolean 
isBroadcastRegion) {
+    this.currentReducePartition = 0;
     this.currentDataRegionIndex = currentDataRegionIndex;
     this.isBroadcastRegion = isBroadcastRegion;
   }
 
-  public void regionFinish() throws IOException {}
+  public void regionFinish() throws IOException {
+    if (regionStartingOffset == totalBytes) {
+      return;
+    }
+
+    long fileOffset = regionStartingOffset;
+    if (indexBuffer == null) {
+      indexBuffer = allocateIndexBuffer(numReducePartitions);
+    }
+
+    // write the index information of the current data region
+    for (int partitionIndex = 0; partitionIndex < numReducePartitions; 
++partitionIndex) {
+      indexBuffer.putLong(fileOffset);
+      if (!isBroadcastRegion) {
+        indexBuffer.putLong(numReducePartitionBytes[partitionIndex]);
+        fileOffset += numReducePartitionBytes[partitionIndex];
+      } else {
+        indexBuffer.putLong(numReducePartitionBytes[0]);
+      }
+    }
+
+    if (!indexBuffer.hasRemaining()) {
+      flushIndex();
+      takeBufferIndex();
+    }
+
+    ++numDataRegions;
+    regionStartingOffset = totalBytes;
+    Arrays.fill(numReducePartitionBytes, 0);
+  }
+
+  private synchronized void returnBufferIndex() {
+    if (flushBufferIndex != null) {
+      flusher.returnBuffer(flushBufferIndex);
+      flushBufferIndex = null;
+    }
+  }
+
+  private synchronized void destroyIndex() {
+    returnBufferIndex();
+    try {
+      if (channelIndex != null) {
+        channelIndex.close();
+      }
+    } catch (IOException e) {
+      logger.warn(
+          "Close channel failed for file {} caused by {}.",
+          fileInfo.getIndexPath(),
+          e.getMessage());
+    }
+  }
+
+  private void flushIndex() throws IOException {
+    indexBuffer.flip();
+    notifier.checkException();
+    notifier.numPendingFlushes.incrementAndGet();
+    if (indexBuffer.hasRemaining()) {
+      FlushTask task = null;
+      if (channelIndex != null) {
+        Unpooled.wrappedBuffer(indexBuffer);
+        task = new LocalFlushTask(flushBufferIndex, channelIndex, notifier);
+      } else if (fileInfo.isHdfs()) {
+        task = new HdfsFlushTask(flushBufferIndex, 
fileInfo.getHdfsIndexPath(), notifier);
+      }
+      addTask(task);
+      flushBufferIndex = null;
+    }
+    indexBuffer.clear();
+  }
+
+  private ByteBuffer allocateIndexBuffer(int numPartitions) {
+
+    // the returned buffer size is no smaller than 4096 bytes to improve disk 
IO performance
+    int minBufferSize = 4096;
+
+    int indexRegionSize = numPartitions * (8 + 8);
+    if (indexRegionSize >= minBufferSize) {
+      ByteBuffer buffer = ByteBuffer.allocateDirect(indexRegionSize);
+      buffer.order(ByteOrder.BIG_ENDIAN);
+      return buffer;
+    }
+
+    int numRegions = minBufferSize / indexRegionSize;
+    if (minBufferSize % indexRegionSize != 0) {
+      ++numRegions;
+    }
+    ByteBuffer buffer = ByteBuffer.allocateDirect(numRegions * 
indexRegionSize);
+    buffer.order(ByteOrder.BIG_ENDIAN);
+
+    return buffer;
+  }
 }
diff --git a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriterSuiteJ.java b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriterSuiteJ.java
new file mode 100644
index 00000000..018e7c83
--- /dev/null
+++ b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriterSuiteJ.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicLong;
+
+import scala.Function0;
+import scala.collection.mutable.ListBuffer;
+
+import io.netty.buffer.Unpooled;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.mockito.Mockito;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.meta.FileInfo;
+import org.apache.celeborn.common.network.server.memory.MemoryManager;
+import org.apache.celeborn.common.network.util.JavaUtils;
+import org.apache.celeborn.common.protocol.PartitionSplitMode;
+import org.apache.celeborn.common.protocol.StorageInfo;
+import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.Utils;
+import org.apache.celeborn.service.deploy.worker.WorkerSource;
+
+public class MapPartitionFileWriterSuiteJ {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(MapPartitionFileWriterSuiteJ.class);
+
+  private static final CelebornConf CONF = new CelebornConf();
+  public static final Long SPLIT_THRESHOLD = 256 * 1024 * 1024L;
+  public static final PartitionSplitMode splitMode = PartitionSplitMode.HARD;
+
+  private static File tempDir = null;
+  private static LocalFlusher localFlusher = null;
+  private static WorkerSource source = null;
+
+  private final UserIdentifier userIdentifier = new 
UserIdentifier("mock-tenantId", "mock-name");
+
+  @BeforeClass
+  public static void beforeAll() {
+    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), 
"celeborn");
+
+    source = Mockito.mock(WorkerSource.class);
+    Mockito.doAnswer(
+            invocationOnMock -> {
+              Function0<?> function = (Function0<?>) 
invocationOnMock.getArguments()[2];
+              return function.apply();
+            })
+        .when(source)
+        .sample(Mockito.anyString(), Mockito.anyString(), 
Mockito.any(Function0.class));
+
+    ListBuffer<File> dirs = new ListBuffer<>();
+    dirs.$plus$eq(tempDir);
+    localFlusher =
+        new LocalFlusher(
+            source, DeviceMonitor$.MODULE$.EmptyMonitor(), 1, "disk1", 20, 1, 
StorageInfo.Type.HDD);
+    MemoryManager.initialize(0.8, 0.9, 0.5, 0.6, 0.1, 0.1, 10, 10);
+  }
+
+  @AfterClass
+  public static void afterAll() {
+    if (tempDir != null) {
+      try {
+        JavaUtils.deleteRecursively(tempDir);
+        tempDir = null;
+      } catch (IOException e) {
+        LOG.error("Failed to delete temp dir.", e);
+      }
+    }
+  }
+
+  @Test
+  public void testMultiThreadWrite() throws IOException, ExecutionException, 
InterruptedException {
+    File file = getTemporaryFile();
+    MapPartitionFileWriter fileWriter =
+        new MapPartitionFileWriter(
+            new FileInfo(file, userIdentifier),
+            localFlusher,
+            source,
+            CONF,
+            DeviceMonitor$.MODULE$.EmptyMonitor(),
+            SPLIT_THRESHOLD,
+            splitMode,
+            false);
+    fileWriter.pushDataHandShake(2, 32 * 1024);
+    fileWriter.regionStart(0, false);
+    byte[] partData0 = generateData(0);
+    byte[] partData1 = generateData(1);
+    AtomicLong length = new AtomicLong(0);
+    try {
+      fileWriter.write(Unpooled.wrappedBuffer(partData0));
+      length.addAndGet(partData0.length);
+      fileWriter.write(Unpooled.wrappedBuffer(partData1));
+      length.addAndGet(partData1.length);
+
+      fileWriter.regionFinish();
+    } catch (IOException e) {
+      LOG.error("Failed to write buffer.", e);
+    }
+    long bytesWritten = fileWriter.close();
+
+    assertEquals(length.get(), bytesWritten);
+    assertEquals(fileWriter.getFile().length(), bytesWritten);
+  }
+
+  private File getTemporaryFile() throws IOException {
+    String filename = UUID.randomUUID().toString();
+    File temporaryFile = new File(tempDir, filename);
+    temporaryFile.createNewFile();
+    return temporaryFile;
+  }
+
+  private byte[] generateData(int partitionId) {
+    ThreadLocalRandom rand = ThreadLocalRandom.current();
+    byte[] hello = "hello, world".getBytes(StandardCharsets.UTF_8);
+    int headerLength = 16;
+    int tempLen = rand.nextInt(256 * 1024) + 128 * 1024 - headerLength;
+    int len = (int) (Math.ceil(1.0 * tempLen / hello.length) * hello.length) + 
headerLength;
+
+    byte[] data = new byte[len];
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, partitionId);
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4, 0);
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 8, rand.nextInt());
+    Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 12, len);
+
+    for (int i = headerLength; i < len; i += hello.length) {
+      System.arraycopy(hello, 0, data, i, hello.length);
+    }
+    return data;
+  }
+}

Reply via email to