This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 3dd810cd9 [CELEBORN-1612] Add a basic reader writer class to Tez
3dd810cd9 is described below
commit 3dd810cd9bf1ae09d7c4b024876d94aac066e3c9
Author: hongguangwei <[email protected]>
AuthorDate: Tue Dec 3 14:51:52 2024 +0800
[CELEBORN-1612] Add a basic reader writer class to Tez
### What changes were proposed in this pull request?
1. Add a basic reader writer class to Tez
2. Copy sort pusher from MR to Tez project and add unsort implementation
### Why are the changes needed?
Add basic utilities to support Tez client.
### Does this PR introduce _any_ user-facing change?
NO.
### How was this patch tested?
Cluster tests.
Closes #2969 from GH-Gloway/1610.
Authored-by: hongguangwei <[email protected]>
Signed-off-by: mingji <[email protected]>
---
.../apache/celeborn/client/CelebornTezReader.java | 94 ++++++
.../apache/celeborn/client/CelebornTezWriter.java | 128 ++++++++
.../library/sort/CelebornSortBasedPusher.java | 347 +++++++++++++++++++++
3 files changed, 569 insertions(+)
diff --git
a/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezReader.java
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezReader.java
new file mode 100644
index 000000000..47af6e465
--- /dev/null
+++
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezReader.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import java.io.IOException;
+
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.read.CelebornInputStream;
+import org.apache.celeborn.client.read.MetricsCallback;
+import org.apache.celeborn.common.exception.CelebornIOException;
+import org.apache.celeborn.common.unsafe.Platform;
+
+public class CelebornTezReader {
+ private static final org.slf4j.Logger logger =
LoggerFactory.getLogger(CelebornTezReader.class);
+
+ private int shuffleId;
+ private int partitionId;
+ private int attemptNumber;
+ private ShuffleClient shuffleClient;
+ private long inputShuffleSize;
+ private CelebornInputStream celebornInputStream;
+
+ public CelebornTezReader(
+ ShuffleClient shuffleClient, int shuffleId, int partitionId, int
attemptNumber) {
+ this.shuffleClient = shuffleClient;
+ this.partitionId = partitionId;
+ this.attemptNumber = attemptNumber;
+ this.shuffleId = shuffleId;
+ }
+
+ public void init() throws IOException {
+ MetricsCallback metricsCallback =
+ new MetricsCallback() {
+ @Override
+ public void incBytesRead(long bytesRead) {}
+
+ @Override
+ public void incReadTime(long time) {}
+ };
+ celebornInputStream =
+ shuffleClient.readPartition(
+ shuffleId, partitionId, attemptNumber, 0, Integer.MAX_VALUE,
metricsCallback);
+ }
+
+ public byte[] getShuffleBlock() throws IOException {
+ // get len
+ byte[] header = new byte[4];
+ int count = celebornInputStream.read(header);
+ if (count == -1) {
+ return null;
+ }
+ while (count != header.length) {
+ count += celebornInputStream.read(header, count, 4 - count);
+ }
+
+ // get data
+ int blockLen = Platform.getInt(header, Platform.BYTE_ARRAY_OFFSET);
+ inputShuffleSize += blockLen;
+ byte[] shuffleData = new byte[blockLen];
+ count = celebornInputStream.read(shuffleData);
+ while (count != shuffleData.length) {
+ count += celebornInputStream.read(shuffleData, count, blockLen - count);
+ if (count == -1) {
+ // read shuffle is done.
+ throw new CelebornIOException("Read mr shuffle failed.");
+ }
+ }
+ return shuffleData;
+ }
+
+ public void close() throws IOException {
+ celebornInputStream.close();
+ }
+
+ public int getPartitionId() {
+ return partitionId;
+ }
+}
diff --git
a/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java
new file mode 100644
index 000000000..e80de0504
--- /dev/null
+++
b/client-tez/tez/src/main/java/org/apache/celeborn/client/CelebornTezWriter.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.client;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.LongAdder;
+
+import org.apache.tez.runtime.library.api.IOInterruptedException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.write.DataPusher;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.identity.UserIdentifier;
+
+/**
+ * Writes the map output of one Tez task attempt to Celeborn.
+ *
+ * <p>Data is either queued on a {@code DataPusher} ({@link #pushData}) or handed to
+ * {@code ShuffleClient.mergeData} ({@link #mergeData}). {@link #close()} drains the pusher,
+ * pushes the merged data and reports the end of the mapper.
+ */
+public class CelebornTezWriter {
+  private static final Logger logger = LoggerFactory.getLogger(CelebornTezWriter.class);
+
+  private final ShuffleClient shuffleClient;
+  private final DataPusher dataPusher;
+  private final int shuffleId;
+  private final int mapId;
+  private final int attemptNumber;
+  private final int numMappers;
+  private final int numPartitions;
+
+  /**
+   * Obtains a {@code ShuffleClient} for the given lifecycle manager and creates the
+   * {@code DataPusher} used by {@link #pushData}.
+   *
+   * @param shuffleId id of the shuffle being written
+   * @param mapId id of the map task
+   * @param attemptNumber attempt number of the map task
+   * @param taskAttemptId task attempt id, forwarded to the {@code DataPusher}
+   * @param numMappers total number of mappers of the shuffle
+   * @param numPartitions total number of partitions of the shuffle
+   * @param conf Celeborn configuration
+   * @param appUniqueId application id used to look up the shuffle client
+   * @param lifecycleManagerHost host of the lifecycle manager
+   * @param lifecycleManagerPort port of the lifecycle manager
+   * @param userIdentifier identity used to look up the shuffle client
+   * @throws RuntimeException if the thread is interrupted while the pusher is being created
+   */
+  public CelebornTezWriter(
+      int shuffleId,
+      int mapId,
+      int attemptNumber,
+      long taskAttemptId,
+      int numMappers,
+      int numPartitions,
+      CelebornConf conf,
+      String appUniqueId,
+      String lifecycleManagerHost,
+      int lifecycleManagerPort,
+      UserIdentifier userIdentifier) {
+    shuffleClient =
+        ShuffleClient.get(
+            appUniqueId, lifecycleManagerHost, lifecycleManagerPort, conf, userIdentifier, null);
+    // TEZ_SHUFFLE_ID
+    this.shuffleId = shuffleId;
+    this.mapId = mapId;
+    this.attemptNumber = attemptNumber;
+    this.numMappers = numMappers;
+    this.numPartitions = numPartitions;
+
+    // One length accumulator per partition, required by the DataPusher constructor.
+    LongAdder[] mapStatusLengths = new LongAdder[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      mapStatusLengths[i] = new LongAdder();
+    }
+    try {
+      dataPusher =
+          new DataPusher(
+              shuffleId,
+              mapId,
+              attemptNumber,
+              taskAttemptId,
+              numMappers,
+              numPartitions,
+              conf,
+              shuffleClient,
+              null,
+              integer -> {},
+              mapStatusLengths);
+    } catch (InterruptedException e) {
+      // Restore the interrupt flag so callers further up can still observe the interruption.
+      Thread.currentThread().interrupt();
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Queues {@code size} bytes of {@code dataBuf} on the data pusher for the given partition.
+   *
+   * @throws IOException as {@code IOInterruptedException} if the thread is interrupted while
+   *     the task is being added
+   */
+  public void pushData(int partitionId, byte[] dataBuf, int size) throws IOException {
+    try {
+      dataPusher.addTask(partitionId, dataBuf, size);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new IOInterruptedException(e);
+    }
+  }
+
+  /**
+   * Hands the first {@code size} bytes of {@code dataBuf} to {@code ShuffleClient.mergeData} for
+   * the given partition.
+   *
+   * @throws IOException if the shuffle client fails
+   */
+  public void mergeData(int partitionId, byte[] dataBuf, int size) throws IOException {
+    shuffleClient.mergeData(
+        shuffleId, mapId, attemptNumber, partitionId, dataBuf, 0, size, numMappers, numPartitions);
+  }
+
+  /** Returns the number of partitions this writer was created with. */
+  public int getNumPartitions() {
+    return numPartitions;
+  }
+
+  /**
+   * Finishes the map output: waits for the data pusher to terminate, pushes the merged data and
+   * then signals mapper end to the shuffle client.
+   *
+   * @throws IOException if the shuffle client fails, or as {@code IOInterruptedException} if the
+   *     thread is interrupted while waiting
+   */
+  public void close() throws IOException {
+    logger.info(
+        "Call mapper end shuffleId:{} mapId:{} attemptId:{} numMappers:{}",
+        shuffleId,
+        mapId,
+        attemptNumber,
+        numMappers);
+    try {
+      dataPusher.waitOnTermination();
+      shuffleClient.pushMergedData(shuffleId, mapId, attemptNumber);
+      shuffleClient.mapperEnd(shuffleId, mapId, attemptNumber, numMappers);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new IOInterruptedException(e);
+    }
+  }
+}
diff --git
a/client-tez/tez/src/main/java/org/apache/tez/runtime/library/sort/CelebornSortBasedPusher.java
b/client-tez/tez/src/main/java/org/apache/tez/runtime/library/sort/CelebornSortBasedPusher.java
new file mode 100644
index 000000000..e7d491233
--- /dev/null
+++
b/client-tez/tez/src/main/java/org/apache/tez/runtime/library/sort/CelebornSortBasedPusher.java
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.runtime.library.sort;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hadoop.io.serializer.Serializer;
+import org.apache.tez.common.counters.TezCounter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.client.CelebornTezWriter;
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.unsafe.Platform;
+import org.apache.celeborn.common.util.Utils;
+
+/**
+ * Buffers serialized key/value records in memory, grouped by partition, and ships them to
+ * Celeborn through a {@link CelebornTezWriter}.
+ *
+ * <p>The key and value serializers are opened on this stream, so the bytes they emit are appended
+ * to the internal buffer through {@link #write(int)}. Once the buffer reaches
+ * {@code spillIOBufferSize}, the next {@link #insert} first sorts the buffered records of every
+ * partition by key (only when {@code needSort} is set) and sends them. Failures are recorded
+ * rather than thrown; callers retrieve them with {@link #checkException()}.
+ */
+public class CelebornSortBasedPusher<K, V> extends OutputStream {
+  private static final Logger logger = LoggerFactory.getLogger(CelebornSortBasedPusher.class);
+  private final CelebornTezWriter celebornTezWriter;
+  private final int maxIOBufferSize;
+  private final int spillIOBufferSize;
+  private final Serializer<K> kSer;
+  private final Serializer<V> vSer;
+  private final RawComparator<K> comparator;
+  // First failure seen by insert/flush/close; surfaced by checkException().
+  private final AtomicReference<Exception> exception = new AtomicReference<>();
+  private final int numOutputs;
+  private final TezCounter mapOutputByteCounter;
+  private final TezCounter mapOutputRecordCounter;
+  // Partition id -> locations of that partition's records inside serializedKV.
+  private final Map<Integer, List<SerializedKV>> partitionedKVs;
+  // Next free position in serializedKV.
+  private int writePos;
+  private byte[] serializedKV;
+  private final int maxPushDataSize;
+  private final Map<Integer, AtomicInteger> recordsPerPartition = new HashMap<>();
+  private final Map<Integer, AtomicLong> bytesPerPartition = new HashMap<>();
+  private final boolean needSort;
+
+  /**
+   * @param kSer key serializer; opened on this stream
+   * @param vSer value serializer; opened on this stream
+   * @param maxIOBufferSize capacity in bytes of the in-memory record buffer
+   * @param spillIOBufferSize buffer fill level at which {@link #insert} sends buffered data
+   * @param comparator raw comparator applied to serialized keys when sorting
+   * @param mapOutputByteCounter counter increased by the serialized size of every record
+   * @param mapOutputRecordCounter counter increased by one for every record
+   * @param celebornTezWriter writer that receives the encoded buffers
+   * @param celebornConf configuration supplying the maximum push data size
+   * @param needSort whether buffered records are sorted by key before being sent
+   */
+  public CelebornSortBasedPusher(
+      Serializer<K> kSer,
+      Serializer<V> vSer,
+      int maxIOBufferSize,
+      int spillIOBufferSize,
+      RawComparator<K> comparator,
+      TezCounter mapOutputByteCounter,
+      TezCounter mapOutputRecordCounter,
+      CelebornTezWriter celebornTezWriter,
+      CelebornConf celebornConf,
+      boolean needSort) {
+    this.kSer = kSer;
+    this.vSer = vSer;
+    this.maxIOBufferSize = maxIOBufferSize;
+    this.spillIOBufferSize = spillIOBufferSize;
+    this.mapOutputByteCounter = mapOutputByteCounter;
+    this.mapOutputRecordCounter = mapOutputRecordCounter;
+    this.comparator = comparator;
+    this.celebornTezWriter = celebornTezWriter;
+    this.needSort = needSort;
+    this.numOutputs = celebornTezWriter.getNumPartitions();
+    partitionedKVs = new HashMap<>();
+    serializedKV = new byte[maxIOBufferSize];
+    maxPushDataSize = (int) celebornConf.clientMrMaxPushData();
+    try {
+      kSer.open(this);
+      vSer.open(this);
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+  }
+
+  /**
+   * Serializes one record into the buffer and updates the per-partition statistics and the
+   * output counters. If the buffer has reached {@code spillIOBufferSize}, the buffered records
+   * are sent first. An {@link IOException} is recorded for {@link #checkException()} instead of
+   * being thrown.
+   *
+   * @param key record key
+   * @param value record value
+   * @param partition target partition of the record
+   */
+  public void insert(K key, V value, int partition) {
+    try {
+      if (writePos >= spillIOBufferSize) {
+        // needs to sort and flush data
+        if (logger.isDebugEnabled()) {
+          logger.debug(
+              "Data is large enough {}/{}/{}, trigger sort and flush",
+              Utils.bytesToString(writePos),
+              Utils.bytesToString(spillIOBufferSize),
+              Utils.bytesToString(maxIOBufferSize));
+        }
+        if (needSort) {
+          sortKVs();
+        }
+        sendKVAndUpdateWritePos();
+      }
+      int dataLen = insertRecordInternal(key, value, partition);
+      // With a single unsorted output all statistics are accounted to partition 0.
+      int statsPartition = (numOutputs == 1 && !needSort) ? 0 : partition;
+      recordsPerPartition
+          .computeIfAbsent(statsPartition, p -> new AtomicInteger())
+          .incrementAndGet();
+      // Account the serialized size of the record, not just "one more record".
+      bytesPerPartition.computeIfAbsent(statsPartition, p -> new AtomicLong()).addAndGet(dataLen);
+      if (logger.isDebugEnabled()) {
+        logger.debug(
+            "Sort based pusher insert into partition:{} with {} bytes", partition, dataLen);
+      }
+      mapOutputRecordCounter.increment(1);
+      mapOutputByteCounter.increment(dataLen);
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+  }
+
+  /**
+   * Sends every buffered record, partition by partition, and resets the buffer.
+   *
+   * <p>Within a partition, records are accumulated until their total size exceeds
+   * {@code maxPushDataSize}; each such batch goes out through {@code pushData}. The remaining
+   * tail of the partition goes out through {@code mergeData}.
+   */
+  private void sendKVAndUpdateWritePos() throws IOException {
+    Iterator<Map.Entry<Integer, List<SerializedKV>>> entryIter =
+        partitionedKVs.entrySet().iterator();
+    while (entryIter.hasNext()) {
+      Map.Entry<Integer, List<SerializedKV>> entry = entryIter.next();
+      // Drop the entry before sending so that a failed send never leads to a re-send later.
+      entryIter.remove();
+      int partition = entry.getKey();
+      List<SerializedKV> kvs = entry.getValue();
+      List<SerializedKV> localKVs = new ArrayList<>();
+      int partitionKVTotalLen = 0;
+      // process buffers for specific partition
+      for (SerializedKV kv : kvs) {
+        partitionKVTotalLen += kv.kLen + kv.vLen;
+        localKVs.add(kv);
+        if (partitionKVTotalLen > maxPushDataSize) {
+          // limit max size of pushdata to avoid possible memory issue in Celeborn worker
+          // data layout
+          // pushdata header (16) + pushDataLen(4) +
+          // [varKeyLen+varValLen+serializedRecord(x)][...]
+          sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen, false);
+          localKVs.clear();
+          partitionKVTotalLen = 0;
+        }
+      }
+      if (!localKVs.isEmpty()) {
+        sendSortedBuffersPartition(partition, localKVs, partitionKVTotalLen, true);
+      }
+      kvs.clear();
+    }
+    // all data sent
+    partitionedKVs.clear();
+    writePos = 0;
+  }
+
+  /**
+   * Encodes the given records into one byte array and hands it to the writer.
+   *
+   * <p>Layout: a 4-byte int holding the size of everything after it, then for each record
+   * {@code [vint keyLen][vint valueLen][key bytes][value bytes]}, then two vint {@code -1}
+   * end markers.
+   *
+   * @param partition target partition
+   * @param localKVs records to encode, in output order
+   * @param partitionKVTotalLen sum of key and value lengths of {@code localKVs}
+   * @param isMerge {@code true} to send through {@code mergeData}, {@code false} to send through
+   *     {@code pushData}
+   */
+  private void sendSortedBuffersPartition(
+      int partition, List<SerializedKV> localKVs, int partitionKVTotalLen, boolean isMerge)
+      throws IOException {
+    int extraSize = 0;
+    for (SerializedKV localKV : localKVs) {
+      extraSize += WritableUtils.getVIntSize(localKV.kLen);
+      extraSize += WritableUtils.getVIntSize(localKV.vLen);
+    }
+    // copied from hadoop logic
+    extraSize += WritableUtils.getVIntSize(-1);
+    extraSize += WritableUtils.getVIntSize(-1);
+    // whole buffer's size + [(keyLen+valueLen)+(serializedKey+serializedValue)]
+    int length = 4 + extraSize + partitionKVTotalLen;
+    byte[] pkvs = new byte[length];
+    int pkvsPos = 4;
+    Platform.putInt(pkvs, Platform.BYTE_ARRAY_OFFSET, partitionKVTotalLen + extraSize);
+    for (SerializedKV kv : localKVs) {
+      int recordLen = kv.kLen + kv.vLen;
+      // write key len
+      pkvsPos = writeVLong(pkvs, pkvsPos, kv.kLen);
+      // write value len
+      pkvsPos = writeVLong(pkvs, pkvsPos, kv.vLen);
+      // write serialized record
+      System.arraycopy(serializedKV, kv.offset, pkvs, pkvsPos, recordLen);
+      pkvsPos += recordLen;
+    }
+    // finally write -1 two times
+    pkvsPos = writeVLong(pkvs, pkvsPos, -1);
+    writeVLong(pkvs, pkvsPos, -1);
+    if (isMerge) {
+      celebornTezWriter.mergeData(partition, pkvs, length);
+    } else {
+      celebornTezWriter.pushData(partition, pkvs, length);
+    }
+  }
+
+  /**
+   * Write variable length int to array Modified from
+   * org.apache.hadoop.io.WritableUtils#writeVLong(java.io.DataOutput, long)
+   *
+   * @param data destination array
+   * @param offset position of the first byte to write
+   * @param dataInt value to encode
+   * @return the position right after the last byte written
+   */
+  private int writeVLong(byte[] data, int offset, long dataInt) {
+    // Values in [-112, 127] are stored as a single byte.
+    if (dataInt >= -112L && dataInt <= 127L) {
+      data[offset++] = (byte) ((int) dataInt);
+      return offset;
+    }
+
+    int len = -112;
+    if (dataInt < 0L) {
+      // Store the one's complement; the -120 base marks the value as negative.
+      dataInt ^= -1L;
+      len = -120;
+    }
+
+    // Count the significant bytes, encoding the count in the marker byte.
+    long tmp = dataInt;
+    while (tmp != 0) {
+      tmp = tmp >> 8;
+      len--;
+    }
+
+    data[offset++] = (byte) len;
+
+    len = len < -120 ? -(len + 120) : -(len + 112);
+
+    // Emit the significant bytes, most significant first.
+    for (int idx = len; idx != 0; --idx) {
+      int shiftBits = (idx - 1) * 8;
+      long mask = 0xFFL << shiftBits;
+      data[offset++] = ((byte) ((int) ((dataInt & mask) >> shiftBits)));
+    }
+    return offset;
+  }
+
+  /** Sorts the buffered records of every partition by their serialized keys. */
+  private void sortKVs() {
+    for (Map.Entry<Integer, List<SerializedKV>> partitionKVEntry : partitionedKVs.entrySet()) {
+      partitionKVEntry
+          .getValue()
+          .sort(
+              (o1, o2) ->
+                  comparator.compare(
+                      serializedKV, o1.offset, o1.kLen, serializedKV, o2.offset, o2.kLen));
+    }
+  }
+
+  /**
+   * Serializes key and value into the buffer and registers the record under its partition.
+   *
+   * @return serialized key length plus serialized value length
+   * @throws IOException if serialization fails, e.g. because the buffer is full; the buffer is
+   *     rolled back to its state before the call
+   */
+  private int insertRecordInternal(K key, V value, int partition) throws IOException {
+    int offset = writePos;
+    int keyLen;
+    int valLen;
+    try {
+      kSer.serialize(key);
+      keyLen = writePos - offset;
+      vSer.serialize(value);
+      valLen = writePos - keyLen - offset;
+    } catch (IOException e) {
+      // Discard the partially written record so no unreferenced bytes stay in the buffer.
+      writePos = offset;
+      throw e;
+    }
+    List<SerializedKV> serializedKVs =
+        partitionedKVs.computeIfAbsent(partition, v -> new ArrayList<>());
+    serializedKVs.add(new SerializedKV(offset, keyLen, valLen));
+    if (logger.isDebugEnabled()) {
+      logger.debug(
+          "Pusher insert into buffer partition:{} offset:{} keyLen:{} valueLen:{} size:{}",
+          partition,
+          offset,
+          keyLen,
+          valLen,
+          partitionedKVs.size());
+    }
+    return keyLen + valLen;
+  }
+
+  /**
+   * Rethrows the first failure recorded by {@link #insert}, {@link #flush()}, {@link #close()}
+   * or the constructor.
+   *
+   * @throws IOException wrapping the recorded failure, if there is one
+   */
+  public void checkException() throws IOException {
+    if (exception.get() != null) {
+      throw new IOException("Write data to celeborn failed", exception.get());
+    }
+  }
+
+  /**
+   * Appends one byte to the record buffer.
+   *
+   * @throws IOException if the buffer already holds {@code maxIOBufferSize} bytes
+   */
+  @Override
+  public void write(int b) throws IOException {
+    if (writePos < maxIOBufferSize) {
+      serializedKV[writePos] = (byte) b;
+      writePos++;
+    } else {
+      logger.warn("Sort push memory high, write pos {} max size {}", writePos, maxIOBufferSize);
+      throw new IOException("Sort pusher memory exhausted.");
+    }
+  }
+
+  /**
+   * Sorts (when {@code needSort} is set) and sends all buffered records. An {@link IOException}
+   * is recorded for {@link #checkException()} instead of being thrown.
+   */
+  @Override
+  public void flush() {
+    logger.info("Sort based pusher called flush");
+    try {
+      if (needSort) {
+        sortKVs();
+      }
+      sendKVAndUpdateWritePos();
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+  }
+
+  /**
+   * Flushes the remaining records, closes the writer and releases the buffer. An
+   * {@link IOException} from the writer is recorded for {@link #checkException()}.
+   */
+  @Override
+  public void close() {
+    flush();
+    try {
+      celebornTezWriter.close();
+    } catch (IOException e) {
+      exception.compareAndSet(null, e);
+    }
+    partitionedKVs.clear();
+    serializedKV = null;
+  }
+
+  /**
+   * Returns the number of records inserted per partition, indexed by partition id; partitions
+   * without records report 0.
+   */
+  public int[] getRecordsPerPartition() {
+    int[] values = new int[numOutputs];
+    for (int i = 0; i < numOutputs; i++) {
+      AtomicInteger records = recordsPerPartition.get(i);
+      values[i] = records != null ? records.get() : 0;
+    }
+    return values;
+  }
+
+  /**
+   * Returns the serialized bytes inserted per partition, indexed by partition id; partitions
+   * without records report 0.
+   */
+  public long[] getBytesPerPartition() {
+    long[] values = new long[numOutputs];
+    for (int i = 0; i < numOutputs; i++) {
+      AtomicLong bytes = bytesPerPartition.get(i);
+      values[i] = bytes != null ? bytes.get() : 0;
+    }
+    return values;
+  }
+
+  /** Location of one serialized record inside the shared buffer: key bytes, then value bytes. */
+  static class SerializedKV {
+    final int offset;
+    final int kLen;
+    final int vLen;
+
+    public SerializedKV(int offset, int kLen, int vLen) {
+      this.offset = offset;
+      this.kLen = kLen;
+      this.vLen = vLen;
+    }
+  }
+}