This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 3dd810cd9 [CELEBORN-1612] Add a basic reader writer class to Tez
3dd810cd9 is described below

commit 3dd810cd9bf1ae09d7c4b024876d94aac066e3c9
Author: hongguangwei <[email protected]>
AuthorDate: Tue Dec 3 14:51:52 2024 +0800

    [CELEBORN-1612] Add a basic reader writer class to Tez
    
    ### What changes were proposed in this pull request?
    1. Add a basic reader writer class to Tez
    2. Copy the sort-based pusher from the MR client into the Tez project and add an unsorted implementation
    
    ### Why are the changes needed?
    Add basic utilities to support Tez client.
    
    ### Does this PR introduce _any_ user-facing change?
    NO.
    
    ### How was this patch tested?
    Cluster tests.
    
    Closes #2969 from GH-Gloway/1610.
    
    Authored-by: hongguangwei <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 .../apache/celeborn/client/CelebornTezReader.java  |  94 ++++++
 .../apache/celeborn/client/CelebornTezWriter.java  | 128 ++++++++
 .../library/sort/CelebornSortBasedPusher.java      | 347 +++++++++++++++++++++
 3 files changed, 569 insertions(+)

diff --git 
a/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezReader.java
 
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezReader.java
new file mode 100644
index 000000000..47af6e465
--- /dev/null
+++ 
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezReader.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import java.io.IOException;
+
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.read.CelebornInputStream;
+import org.apache.celeborn.client.read.MetricsCallback;
+import org.apache.celeborn.common.exception.CelebornIOException;
+import org.apache.celeborn.common.unsafe.Platform;
+
+public class CelebornTezReader {
+  private static final org.slf4j.Logger logger = 
LoggerFactory.getLogger(CelebornTezReader.class);
+
+  private int shuffleId;
+  private int partitionId;
+  private int attemptNumber;
+  private ShuffleClient shuffleClient;
+  private long inputShuffleSize;
+  private CelebornInputStream celebornInputStream;
+
+  public CelebornTezReader(
+      ShuffleClient shuffleClient, int shuffleId, int partitionId, int 
attemptNumber) {
+    this.shuffleClient = shuffleClient;
+    this.partitionId = partitionId;
+    this.attemptNumber = attemptNumber;
+    this.shuffleId = shuffleId;
+  }
+
+  public void init() throws IOException {
+    MetricsCallback metricsCallback =
+        new MetricsCallback() {
+          @Override
+          public void incBytesRead(long bytesRead) {}
+
+          @Override
+          public void incReadTime(long time) {}
+        };
+    celebornInputStream =
+        shuffleClient.readPartition(
+            shuffleId, partitionId, attemptNumber, 0, Integer.MAX_VALUE, 
metricsCallback);
+  }
+
+  public byte[] getShuffleBlock() throws IOException {
+    // get len
+    byte[] header = new byte[4];
+    int count = celebornInputStream.read(header);
+    if (count == -1) {
+      return null;
+    }
+    while (count != header.length) {
+      count += celebornInputStream.read(header, count, 4 - count);
+    }
+
+    // get data
+    int blockLen = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET);
+    inputShuffleSize += blockLen;
+    byte[] shuffleData = new byte[blockLen];
+    count = celebornInputStream.read(shuffleData);
+    while (count != shuffleData.length) {
+      count += celebornInputStream.read(shuffleData, count, blockLen - count);
+      if (count == -1) {
+        // read shuffle is done.
+        throw new CelebornIOException("Read mr shuffle failed.");
+      }
+    }
+    return shuffleData;
+  }
+
+  public void close() throws IOException {
+    celebornInputStream.close();
+  }
+
+  public int getPartitionId() {
+    return partitionId;
+  }
+}
diff --git 
a/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java
 
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java
new file mode 100644
index 000000000..e80de0504
--- /dev/null
+++ 
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.LongAdder;
+
+import org.apache.tez.runtime.library.api.IOInterruptedException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.write.DataPusher;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+
+public class CelebornTezWriter {
+  private final Logger logger = 
LoggerFactory.getLogger(CelebornTezWriter.class);
+
+  private final ShuffleClient shuffleClient;
+  private DataPusher dataPusher;
+  private final int shuffleId;
+  private final int mapId;
+  private final int attemptNumber;
+  private final int numMappers;
+  private final int numPartitions;
+
+  public CelebornTezWriter(
+      int shuffleId,
+      int mapId,
+      int attemptNumber,
+      long taskAttemptId,
+      int numMappers,
+      int numPartitions,
+      CelebornConf conf,
+      String appUniqueId,
+      String lifecycleManagerHost,
+      int lifecycleManagerPort,
+      UserIdentifier userIdentifier) {
+    shuffleClient =
+        ShuffleClient.get(
+            appUniqueId, lifecycleManagerHost, lifecycleManagerPort, conf, 
userIdentifier, null);
+    // TEZ_SHUFFLE_ID
+    this.shuffleId = shuffleId;
+    this.mapId = mapId;
+    this.attemptNumber = attemptNumber;
+    this.numMappers = numMappers;
+    this.numPartitions = numPartitions;
+
+    LongAdder[] mapStatusLengths = new LongAdder[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      mapStatusLengths[i] = new LongAdder();
+    }
+    try {
+      dataPusher =
+          new DataPusher(
+              shuffleId,
+              mapId,
+              attemptNumber,
+              taskAttemptId,
+              numMappers,
+              numPartitions,
+              conf,
+              shuffleClient,
+              null,
+              integer -> {},
+              mapStatusLengths);
+    } catch (InterruptedException e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  public void pushData(int partitionId, byte[] dataBuf, int size) throws 
IOException {
+    try {
+      dataPusher.addTask(partitionId, dataBuf, size);
+    } catch (InterruptedException e) {
+      throw new IOInterruptedException(e);
+    }
+  }
+
+  public void mergeData(int partitionId, byte[] dataBuf, int size) throws 
IOException {
+    int bytesWritten =
+        shuffleClient.mergeData(
+            shuffleId,
+            mapId,
+            attemptNumber,
+            partitionId,
+            dataBuf,
+            0,
+            size,
+            numMappers,
+            numPartitions);
+  }
+
+  public int getNumPartitions() {
+    return numPartitions;
+  }
+
+  public void close() throws IOException {
+    logger.info(
+        "Call mapper end shuffleId:{} mapId:{} attemptId:{} numMappers:{}",
+        0,
+        mapId,
+        attemptNumber,
+        numMappers);
+    try {
+      dataPusher.waitOnTermination();
+      shuffleClient.pushMergedData(shuffleId, mapId, attemptNumber);
+      shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers);
+    } catch (InterruptedException e) {
+      throw new IOInterruptedException(e);
+    }
+  }
+}
diff --git 
a/client-tez/tez/src/main/java/org/apache/tez/runtime/library/sort/CelebornSortBasedPusher.java
 
b/client-tez/tez/src/main/java/org/apache/tez/runtime/library/sort/CelebornSortBasedPusher.java
new file mode 100644
index 000000000..e7d491233
--- /dev/null
+++ 
b/client-tez/tez/src/main/java/org/apache/tez/runtime/library/sort/CelebornSortBasedPusher.java
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.sort;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.tez.common.counters.TezCounter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.CelebornTezWriter;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.Utils;
+
+public class CelebornSortBasedPusher<K, V> extends OutputStream {
+  private final Logger logger = 
LoggerFactory.getLogger(CelebornSortBasedPusher.class);
+  private final CelebornTezWriter celebornTezWriter;
+  private final int maxIOBufferSize;
+  private final int spillIOBufferSize;
+  private final Serializer<K> kSer;
+  private final Serializer<V> vSer;
+  private final RawComparator<K> comparator;
+  private final AtomicReference<Exception> exception = new AtomicReference<>();
+  private final int numOutputs;
+  private final TezCounter mapOutputByteCounter;
+  private final TezCounter mapOutputRecordCounter;
+  private final Map<Integer, List<SerializedKV>> partitionedKVs;
+  private int writePos;
+  private byte[] serializedKV;
+  private final int maxPushDataSize;
+  private Map<Integer, AtomicInteger> recordsPerPartition = new HashMap<>();
+  private Map<Integer, AtomicLong> bytesPerPartition = new HashMap<>();
+  private final boolean needSort;
+
+  public CelebornSortBasedPusher(
+      Serializer<K> kSer,
+      Serializer<V> vSer,
+      int maxIOBufferSize,
+      int spillIOBufferSize,
+      RawComparator<K> comparator,
+      TezCounter mapOutputByteCounter,
+      TezCounter mapOutputRecordCounter,
+      CelebornTezWriter celebornTezWriter,
+      CelebornConf celebornConf,
+      boolean needSort) {
+    this.kSer = kSer;
+    this.vSer = vSer;
+    this.maxIOBufferSize = maxIOBufferSize;
+    this.spillIOBufferSize = spillIOBufferSize;
+    this.mapOutputByteCounter = mapOutputByteCounter;
+    this.mapOutputRecordCounter = mapOutputRecordCounter;
+    this.comparator = comparator;
+    this.celebornTezWriter = celebornTezWriter;
+    this.needSort = needSort;
+    this.numOutputs = celebornTezWriter.getNumPartitions();
+    partitionedKVs = new HashMap<>();
+    serializedKV = new byte[maxIOBufferSize];
+    maxPushDataSize = (int) celebornConf.clientMrMaxPushData();
+    try {
+      kSer.open(this);
+      vSer.open(this);
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+  }
+
+  public void insert(K key, V value, int partition) {
+    try {
+      if (writePos >= spillIOBufferSize) {
+        // needs to sort and flush data
+        if (logger.isDebugEnabled()) {
+          logger.debug(
+              "Data is large enough {}/{}/{}, trigger sort and flush",
+              Utils.bytesToString(writePos),
+              Utils.bytesToString(spillIOBufferSize),
+              Utils.bytesToString(maxIOBufferSize));
+        }
+        if (needSort) {
+          sortKVs();
+        }
+        sendKVAndUpdateWritePos();
+      }
+      int dataLen = insertRecordInternal(key, value, partition);
+      if (numOutputs == 1 && !needSort) {
+        recordsPerPartition.putIfAbsent(0, new AtomicInteger());
+        bytesPerPartition.putIfAbsent(0, new AtomicLong());
+        recordsPerPartition.get(0).incrementAndGet();
+        bytesPerPartition.get(0).incrementAndGet();
+      } else {
+        recordsPerPartition.computeIfAbsent(partition, p -> new 
AtomicInteger());
+        bytesPerPartition.computeIfAbsent(partition, p -> new AtomicLong());
+        recordsPerPartition.get(partition).incrementAndGet();
+        bytesPerPartition.get(partition).incrementAndGet();
+      }
+      if (logger.isDebugEnabled()) {
+        logger.debug(
+            "Sort based pusher insert into partition:{} with {} bytes", 
partition, dataLen);
+      }
+      mapOutputRecordCounter.increment(1);
+      mapOutputByteCounter.increment(dataLen);
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+  }
+
+  private void sendKVAndUpdateWritePos() throws IOException {
+    Iterator<Map.Entry<Integer, List<SerializedKV>>> entryIter =
+        partitionedKVs.entrySet().iterator();
+    while (entryIter.hasNext()) {
+      Map.Entry<Integer, List<SerializedKV>> entry = entryIter.next();
+      entryIter.remove();
+      int partition = entry.getKey();
+      List<SerializedKV> kvs = entry.getValue();
+      List<SerializedKV> localKVs = new ArrayList<>();
+      int partitionKVTotalLen = 0;
+      // process buffers for specific partition
+      for (SerializedKV kv : kvs) {
+        partitionKVTotalLen += kv.kLen + kv.vLen;
+        localKVs.add(kv);
+        if (partitionKVTotalLen > maxPushDataSize) {
+          // limit max size of pushdata to avoid possible memory issue in 
Celeborn worker
+          // data layout
+          // pushdata header (16) + pushDataLen(4) +
+          // [varKeyLen+varValLen+serializedRecord(x)][...]
+          sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen, 
false);
+          localKVs.clear();
+          partitionKVTotalLen = 0;
+        }
+      }
+      if (!localKVs.isEmpty()) {
+        sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen, 
true);
+      }
+      kvs.clear();
+    }
+    // all data sent
+    partitionedKVs.clear();
+    writePos = 0;
+  }
+
+  private void sendSortedBuffersPartition(
+      int partition, List<SerializedKV> localKVs, int partitionKVTotalLen, 
boolean isMerge)
+      throws IOException {
+    int extraSize = 0;
+    for (SerializedKV localKV : localKVs) {
+      extraSize += WritableUtils.getVIntSize(localKV.kLen);
+      extraSize += WritableUtils.getVIntSize(localKV.vLen);
+    }
+    // copied from hadoop logic
+    extraSize += WritableUtils.getVIntSize(-1);
+    extraSize += WritableUtils.getVIntSize(-1);
+    // whole buffer's size + 
[(keyLen+valueLen)+(serializedKey+serializedValue)]
+    int length = 4 + extraSize + partitionKVTotalLen;
+    byte[] pkvs = new byte[length];
+    int pkvsPos = 4;
+    Platform.putInt(pkvs, Platform.BYTE_ARRAY_OFFSET, partitionKVTotalLen + 
extraSize);
+    for (SerializedKV kv : localKVs) {
+      int recordLen = kv.kLen + kv.vLen;
+      // write key len
+      pkvsPos = writeVLong(pkvs, pkvsPos, kv.kLen);
+      // write value len
+      pkvsPos = writeVLong(pkvs, pkvsPos, kv.vLen);
+      // write serialized record
+      System.arraycopy(serializedKV, kv.offset, pkvs, pkvsPos, recordLen);
+      pkvsPos += recordLen;
+    }
+    // finally write -1 two times
+    pkvsPos = writeVLong(pkvs, pkvsPos, -1);
+    writeVLong(pkvs, pkvsPos, -1);
+    if (isMerge) {
+      celebornTezWriter.mergeData(partition, pkvs, length);
+    } else {
+      celebornTezWriter.pushData(partition, pkvs, length);
+    }
+  }
+
+  /**
+   * Write variable length int to array Modified from
+   * org.apache.hadoop.io.WritableUtils#writeVLong(java.io.DataOutput, long)
+   */
+  private int writeVLong(byte[] data, int offset, long dataInt) {
+    if (dataInt >= -112L && dataInt <= 127L) {
+      data[offset++] = (byte) ((int) dataInt);
+      return offset;
+    }
+
+    int len = -112;
+    if (dataInt < 0L) {
+      dataInt ^= -1L;
+      len = -120;
+    }
+
+    long tmp = dataInt;
+    while (tmp != 0) {
+      tmp = tmp >> 8;
+      len--;
+    }
+
+    data[offset++] = (byte) len;
+
+    len = len < -120 ? -(len + 120) : -(len + 112);
+
+    for (int idx = len; idx != 0; --idx) {
+      int shiftBits = (idx - 1) * 8;
+      long mask = 0xFFL << shiftBits;
+      data[offset++] = ((byte) ((int) ((dataInt & mask) >> shiftBits)));
+    }
+    return offset;
+  }
+
+  private void sortKVs() {
+    for (Map.Entry<Integer, List<SerializedKV>> partitionKVEntry : 
partitionedKVs.entrySet()) {
+      partitionKVEntry
+          .getValue()
+          .sort(
+              (o1, o2) ->
+                  comparator.compare(
+                      serializedKV, o1.offset, o1.kLen, serializedKV, 
o2.offset, o2.kLen));
+    }
+  }
+
+  private int insertRecordInternal(K key, V value, int partition) throws 
IOException {
+    int offset = writePos;
+    int keyLen;
+    int valLen;
+    kSer.serialize(key);
+    keyLen = writePos - offset;
+    vSer.serialize(value);
+    valLen = writePos - keyLen - offset;
+    List<SerializedKV> serializedKVs =
+        partitionedKVs.computeIfAbsent(partition, v -> new ArrayList<>());
+    serializedKVs.add(new SerializedKV(offset, keyLen, valLen));
+    if (logger.isDebugEnabled()) {
+      logger.debug(
+          "Pusher insert into buffer partition:{} offset:{} keyLen:{} 
valueLen:{} size:{}",
+          partition,
+          offset,
+          keyLen,
+          valLen,
+          partitionedKVs.size());
+    }
+    return keyLen + valLen;
+  }
+
+  public void checkException() throws IOException {
+    if (exception.get() != null) {
+      throw new IOException("Write data to celeborn failed", exception.get());
+    }
+  }
+
+  @Override
+  public void write(int b) throws IOException {
+    if (writePos < maxIOBufferSize) {
+      serializedKV[writePos] = (byte) b;
+      writePos++;
+    } else {
+      logger.warn("Sort push memory high, write pos {} max size {}", writePos, 
maxIOBufferSize);
+      throw new IOException("Sort pusher memory exhausted.");
+    }
+  }
+
+  @Override
+  public void flush() {
+    logger.info("Sort based pusher called flush");
+    try {
+      if (needSort) {
+        sortKVs();
+      }
+      sendKVAndUpdateWritePos();
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+  }
+
+  @Override
+  public void close() {
+    flush();
+    try {
+      celebornTezWriter.close();
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+    partitionedKVs.clear();
+    serializedKV = null;
+  }
+
+  public int[] getRecordsPerPartition() {
+    int[] values = new int[numOutputs];
+    for (int i = 0; i < numOutputs; i++) {
+      AtomicInteger records = recordsPerPartition.get(i);
+      if (records != null) {
+        values[i] = recordsPerPartition.get(i).get();
+      } else {
+        values[i] = 0;
+      }
+    }
+    return values;
+  }
+
+  public long[] getBytesPerPartition() {
+    long[] values = new long[numOutputs];
+    for (int i = 0; i < numOutputs; i++) {
+      AtomicLong bytes = bytesPerPartition.get(i);
+      if (bytes != null) {
+        values[i] = bytes.get();
+      } else {
+        values[i] = 0;
+      }
+    }
+    return values;
+  }
+
+  static class SerializedKV {
+    final int offset;
+    final int kLen;
+    final int vLen;
+
+    public SerializedKV(int offset, int kLen, int vLen) {
+      this.offset = offset;
+      this.kLen = kLen;
+      this.vLen = vLen;
+    }
+  }
+}

Reply via email to