This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 3661222d [CELEBORN-195] add implementation to MapPartitionFileWriter
(#1141)
3661222d is described below
commit 3661222d98d68d71fa25a3330543fae8830a7cb3
Author: zhongqiangczq <[email protected]>
AuthorDate: Fri Jan 13 16:41:11 2023 +0800
[CELEBORN-195] add implementation to MapPartitionFileWriter (#1141)
---
.../worker/storage/MapPartitionFileWriter.java | 215 ++++++++++++++++++++-
.../storage/MapPartitionFileWriterSuiteJ.java | 158 +++++++++++++++
2 files changed, 367 insertions(+), 6 deletions(-)
diff --git
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java
index 84bf8503..c5d94c9c 100644
---
a/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java
+++
b/worker/src/main/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriter.java
@@ -20,10 +20,14 @@ package org.apache.celeborn.service.deploy.worker.storage;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
+import java.util.Arrays;
+import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
-import org.apache.hadoop.fs.FSDataOutputStream;
+import io.netty.buffer.Unpooled;
+import org.apache.hadoop.fs.Path;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -32,6 +36,9 @@ import org.apache.celeborn.common.meta.FileInfo;
import org.apache.celeborn.common.metrics.source.AbstractSource;
import org.apache.celeborn.common.protocol.PartitionSplitMode;
import org.apache.celeborn.common.protocol.PartitionType;
+import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.Utils;
+import org.apache.celeborn.service.deploy.worker.WorkerSource;
/*
* map partition file writer, it will create index for each partition
@@ -49,7 +56,6 @@ public final class MapPartitionFileWriter extends FileWriter {
private long regionStartingOffset;
private long numDataRegions;
private FileChannel channelIndex;
- private FSDataOutputStream streamIndex;
private CompositeByteBuf flushBufferIndex;
public MapPartitionFileWriter(
@@ -75,13 +81,119 @@ public final class MapPartitionFileWriter extends
FileWriter {
if (!fileInfo.isHdfs()) {
channelIndex = new
FileOutputStream(fileInfo.getIndexPath()).getChannel();
} else {
- streamIndex =
StorageManager.hadoopFs().create(fileInfo.getHdfsIndexPath(), true);
+ try {
+ StorageManager.hadoopFs().create(fileInfo.getHdfsIndexPath(),
true).close();
+ } catch (IOException e) {
+ try {
+ // If create file failed, wait 10 ms and retry
+ Thread.sleep(10);
+ } catch (InterruptedException ex) {
+ throw new RuntimeException(ex);
+ }
+ StorageManager.hadoopFs().create(fileInfo.getHdfsIndexPath(),
true).close();
+ }
}
+ takeBufferIndex();
+ }
+
+ private void takeBufferIndex() {
+ // metrics start
+ String metricsName = null;
+ String fileAbsPath = null;
+ if (source.metricsCollectCriticalEnabled()) {
+ metricsName = WorkerSource.TakeBufferTimeIndex();
+ fileAbsPath = fileInfo.getIndexPath();
+ source.startTimer(metricsName, fileAbsPath);
+ }
+
+ // real action
+ flushBufferIndex = flusher.takeBuffer();
+
+ // metrics end
+ if (source.metricsCollectCriticalEnabled()) {
+ source.stopTimer(metricsName, fileAbsPath);
+ }
+
+ if (flushBufferIndex == null) {
+ IOException e =
+ new IOException(
+ "Take buffer index encounter error from Flusher: " +
flusher.bufferQueueInfo());
+ notifier.setException(e);
+ }
+ }
+
+ public void write(ByteBuf data) throws IOException {
+ byte[] header = new byte[16];
+ data.markReaderIndex();
+ data.readBytes(header);
+ data.resetReaderIndex();
+ int partitionId = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET);
+ collectPartitionDataLength(partitionId, data);
+
+ super.write(data);
+ }
+
+ private void collectPartitionDataLength(int partitionId, ByteBuf data)
throws IOException {
+ if (numReducePartitionBytes == null) {
+ numReducePartitionBytes = new long[numReducePartitions];
+ }
+ if (partitionId < currentReducePartition) {
+ throw new IOException(
+ "Must writing data in reduce partition index order, but now
partitionId is "
+ + partitionId
+ + " and pre partitionId is "
+ + currentReducePartition);
+ }
+
+ if (partitionId > currentReducePartition) {
+ currentReducePartition = partitionId;
+ }
+ long length = data.readableBytes();
+ totalBytes += length;
+ numReducePartitionBytes[partitionId] += length;
}
@Override
- public long close() throws IOException {
- return 0;
+ public synchronized long close() throws IOException {
+ return super.close(
+ () -> {
+ if (flushBufferIndex.readableBytes() > 0) {
+ flushIndex();
+ }
+ },
+ () -> {
+ if
(StorageManager.hadoopFs().exists(fileInfo.getHdfsPeerWriterSuccessPath())) {
+ StorageManager.hadoopFs().delete(fileInfo.getHdfsPath(), false);
+ deleted = true;
+ } else {
+
StorageManager.hadoopFs().create(fileInfo.getHdfsWriterSuccessPath()).close();
+ }
+ },
+ () -> {
+ returnBufferIndex();
+ if (channelIndex != null) {
+ channelIndex.close();
+ }
+ if (fileInfo.isHdfs()) {
+ if (StorageManager.hadoopFs()
+ .exists(
+ new Path(
+ Utils.getWriteSuccessFilePath(
+ Utils.getPeerPath(fileInfo.getIndexPath()))))) {
+ StorageManager.hadoopFs().delete(fileInfo.getHdfsIndexPath(),
false);
+ deleted = true;
+ } else {
+ StorageManager.hadoopFs()
+ .create(new
Path(Utils.getWriteSuccessFilePath((fileInfo.getIndexPath()))))
+ .close();
+ }
+ }
+ });
+ }
+
+ public synchronized void destroy(IOException ioException) {
+ destroyIndex();
+ super.destroy(ioException);
}
public void pushDataHandShake(int numReducePartitions, int bufferSize) {
@@ -90,9 +202,100 @@ public final class MapPartitionFileWriter extends
FileWriter {
}
/**
 * Begins a new data region.
 *
 * <p>Remembers the region index and whether it is a broadcast region. It also resets the
 * partition-order cursor, so the new region's data may again start from reduce partition 0.
 *
 * @param currentDataRegionIndex index of the region being started
 * @param isBroadcastRegion whether the region is a broadcast region
 */
public void regionStart(int currentDataRegionIndex, boolean isBroadcastRegion) {
  this.currentDataRegionIndex = currentDataRegionIndex;
  this.isBroadcastRegion = isBroadcastRegion;
  // Ordering is only enforced within a region.
  this.currentReducePartition = 0;
}
- public void regionFinish() throws IOException {}
+ public void regionFinish() throws IOException {
+ if (regionStartingOffset == totalBytes) {
+ return;
+ }
+
+ long fileOffset = regionStartingOffset;
+ if (indexBuffer == null) {
+ indexBuffer = allocateIndexBuffer(numReducePartitions);
+ }
+
+ // write the index information of the current data region
+ for (int partitionIndex = 0; partitionIndex < numReducePartitions;
++partitionIndex) {
+ indexBuffer.putLong(fileOffset);
+ if (!isBroadcastRegion) {
+ indexBuffer.putLong(numReducePartitionBytes[partitionIndex]);
+ fileOffset += numReducePartitionBytes[partitionIndex];
+ } else {
+ indexBuffer.putLong(numReducePartitionBytes[0]);
+ }
+ }
+
+ if (!indexBuffer.hasRemaining()) {
+ flushIndex();
+ takeBufferIndex();
+ }
+
+ ++numDataRegions;
+ regionStartingOffset = totalBytes;
+ Arrays.fill(numReducePartitionBytes, 0);
+ }
+
+ private synchronized void returnBufferIndex() {
+ if (flushBufferIndex != null) {
+ flusher.returnBuffer(flushBufferIndex);
+ flushBufferIndex = null;
+ }
+ }
+
+ private synchronized void destroyIndex() {
+ returnBufferIndex();
+ try {
+ if (channelIndex != null) {
+ channelIndex.close();
+ }
+ } catch (IOException e) {
+ logger.warn(
+ "Close channel failed for file {} caused by {}.",
+ fileInfo.getIndexPath(),
+ e.getMessage());
+ }
+ }
+
+ private void flushIndex() throws IOException {
+ indexBuffer.flip();
+ notifier.checkException();
+ notifier.numPendingFlushes.incrementAndGet();
+ if (indexBuffer.hasRemaining()) {
+ FlushTask task = null;
+ if (channelIndex != null) {
+ Unpooled.wrappedBuffer(indexBuffer);
+ task = new LocalFlushTask(flushBufferIndex, channelIndex, notifier);
+ } else if (fileInfo.isHdfs()) {
+ task = new HdfsFlushTask(flushBufferIndex,
fileInfo.getHdfsIndexPath(), notifier);
+ }
+ addTask(task);
+ flushBufferIndex = null;
+ }
+ indexBuffer.clear();
+ }
+
+ private ByteBuffer allocateIndexBuffer(int numPartitions) {
+
+ // the returned buffer size is no smaller than 4096 bytes to improve disk
IO performance
+ int minBufferSize = 4096;
+
+ int indexRegionSize = numPartitions * (8 + 8);
+ if (indexRegionSize >= minBufferSize) {
+ ByteBuffer buffer = ByteBuffer.allocateDirect(indexRegionSize);
+ buffer.order(ByteOrder.BIG_ENDIAN);
+ return buffer;
+ }
+
+ int numRegions = minBufferSize / indexRegionSize;
+ if (minBufferSize % indexRegionSize != 0) {
+ ++numRegions;
+ }
+ ByteBuffer buffer = ByteBuffer.allocateDirect(numRegions *
indexRegionSize);
+ buffer.order(ByteOrder.BIG_ENDIAN);
+
+ return buffer;
+ }
}
diff --git
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriterSuiteJ.java
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriterSuiteJ.java
new file mode 100644
index 00000000..018e7c83
--- /dev/null
+++
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/MapPartitionFileWriterSuiteJ.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.worker.storage;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.atomic.AtomicLong;
+
+import scala.Function0;
+import scala.collection.mutable.ListBuffer;
+
+import io.netty.buffer.Unpooled;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.mockito.Mockito;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.meta.FileInfo;
+import org.apache.celeborn.common.network.server.memory.MemoryManager;
+import org.apache.celeborn.common.network.util.JavaUtils;
+import org.apache.celeborn.common.protocol.PartitionSplitMode;
+import org.apache.celeborn.common.protocol.StorageInfo;
+import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.Utils;
+import org.apache.celeborn.service.deploy.worker.WorkerSource;
+
+public class MapPartitionFileWriterSuiteJ {
+
+ private static final Logger LOG =
LoggerFactory.getLogger(MapPartitionFileWriterSuiteJ.class);
+
+ private static final CelebornConf CONF = new CelebornConf();
+ public static final Long SPLIT_THRESHOLD = 256 * 1024 * 1024L;
+ public static final PartitionSplitMode splitMode = PartitionSplitMode.HARD;
+
+ private static File tempDir = null;
+ private static LocalFlusher localFlusher = null;
+ private static WorkerSource source = null;
+
+ private final UserIdentifier userIdentifier = new
UserIdentifier("mock-tenantId", "mock-name");
+
+ @BeforeClass
+ public static void beforeAll() {
+ tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
"celeborn");
+
+ source = Mockito.mock(WorkerSource.class);
+ Mockito.doAnswer(
+ invocationOnMock -> {
+ Function0<?> function = (Function0<?>)
invocationOnMock.getArguments()[2];
+ return function.apply();
+ })
+ .when(source)
+ .sample(Mockito.anyString(), Mockito.anyString(),
Mockito.any(Function0.class));
+
+ ListBuffer<File> dirs = new ListBuffer<>();
+ dirs.$plus$eq(tempDir);
+ localFlusher =
+ new LocalFlusher(
+ source, DeviceMonitor$.MODULE$.EmptyMonitor(), 1, "disk1", 20, 1,
StorageInfo.Type.HDD);
+ MemoryManager.initialize(0.8, 0.9, 0.5, 0.6, 0.1, 0.1, 10, 10);
+ }
+
+ @AfterClass
+ public static void afterAll() {
+ if (tempDir != null) {
+ try {
+ JavaUtils.deleteRecursively(tempDir);
+ tempDir = null;
+ } catch (IOException e) {
+ LOG.error("Failed to delete temp dir.", e);
+ }
+ }
+ }
+
+ @Test
+ public void testMultiThreadWrite() throws IOException, ExecutionException,
InterruptedException {
+ File file = getTemporaryFile();
+ MapPartitionFileWriter fileWriter =
+ new MapPartitionFileWriter(
+ new FileInfo(file, userIdentifier),
+ localFlusher,
+ source,
+ CONF,
+ DeviceMonitor$.MODULE$.EmptyMonitor(),
+ SPLIT_THRESHOLD,
+ splitMode,
+ false);
+ fileWriter.pushDataHandShake(2, 32 * 1024);
+ fileWriter.regionStart(0, false);
+ byte[] partData0 = generateData(0);
+ byte[] partData1 = generateData(1);
+ AtomicLong length = new AtomicLong(0);
+ try {
+ fileWriter.write(Unpooled.wrappedBuffer(partData0));
+ length.addAndGet(partData0.length);
+ fileWriter.write(Unpooled.wrappedBuffer(partData1));
+ length.addAndGet(partData1.length);
+
+ fileWriter.regionFinish();
+ } catch (IOException e) {
+ LOG.error("Failed to write buffer.", e);
+ }
+ long bytesWritten = fileWriter.close();
+
+ assertEquals(length.get(), bytesWritten);
+ assertEquals(fileWriter.getFile().length(), bytesWritten);
+ }
+
+ private File getTemporaryFile() throws IOException {
+ String filename = UUID.randomUUID().toString();
+ File temporaryFile = new File(tempDir, filename);
+ temporaryFile.createNewFile();
+ return temporaryFile;
+ }
+
+ private byte[] generateData(int partitionId) {
+ ThreadLocalRandom rand = ThreadLocalRandom.current();
+ byte[] hello = "hello, world".getBytes(StandardCharsets.UTF_8);
+ int headerLength = 16;
+ int tempLen = rand.nextInt(256 * 1024) + 128 * 1024 - headerLength;
+ int len = (int) (Math.ceil(1.0 * tempLen / hello.length) * hello.length) +
headerLength;
+
+ byte[] data = new byte[len];
+ Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, partitionId);
+ Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 4, 0);
+ Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 8, rand.nextInt());
+ Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET + 12, len);
+
+ for (int i = headerLength; i < len; i += hello.length) {
+ System.arraycopy(hello, 0, data, i, hello.length);
+ }
+ return data;
+ }
+}