This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new c713f391 [#872][FOLLOWUP] feat(tez): Add the common and utils class
(#894)
c713f391 is described below
commit c713f3917b887657e2ffca30613d100e3673e8ee
Author: Qing <[email protected]>
AuthorDate: Wed May 31 22:32:34 2023 +0800
[#872][FOLLOWUP] feat(tez): Add the common and utils class (#894)
### What changes were proposed in this pull request?
Add the common classes that Tez shuffle read needs in order to support remote shuffle.
### Why are the changes needed?
Fix: #872
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
UT.
---
.../apache/tez/common/GetShuffleServerRequest.java | 96 ++++++++++++
.../tez/common/GetShuffleServerResponse.java | 77 ++++++++++
.../tez/common/ShuffleAssignmentsInfoWritable.java | 165 +++++++++++++++++++++
.../java/org/apache/tez/common/TezClassLoader.java | 82 ++++++++++
.../tez/common/GetShuffleServerRequestTest.java | 67 +++++++++
.../tez/common/GetShuffleServerResponseTest.java | 123 +++++++++++++++
.../common/ShuffleAssignmentsInfoWritableTest.java | 109 ++++++++++++++
7 files changed, 719 insertions(+)
diff --git
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
new file mode 100644
index 00000000..bdd8f00b
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+
+
+/**
+ * Writable request carrying a task attempt id, a start index, a partition count and a
+ * shuffle id.
+ *
+ * <p>Wire format: {@code int startIndex}, {@code int partitionNum}, {@code int shuffleId},
+ * a {@code boolean} telling whether a task attempt id follows, then the id itself if present.
+ */
+public class GetShuffleServerRequest implements Writable {
+  private TezTaskAttemptID currentTaskAttemptID;
+  private int startIndex;
+  private int partitionNum;
+  private int shuffleId;
+
+  /** Creates an empty request whose fields are meant to be populated by {@link #readFields}. */
+  public GetShuffleServerRequest() {
+  }
+
+  /**
+   * Creates a fully populated request.
+   *
+   * @param currentTaskAttemptID id of the task attempt issuing the request, may be null
+   * @param startIndex value returned by {@link #getStartIndex()}
+   * @param partitionNum value returned by {@link #getPartitionNum()}
+   * @param shuffleId value returned by {@link #getShuffleId()}
+   */
+  public GetShuffleServerRequest(TezTaskAttemptID currentTaskAttemptID, int startIndex,
+      int partitionNum, int shuffleId) {
+    this.currentTaskAttemptID = currentTaskAttemptID;
+    this.startIndex = startIndex;
+    this.partitionNum = partitionNum;
+    this.shuffleId = shuffleId;
+  }
+
+  /**
+   * Serializes this request.
+   *
+   * @param output sink that receives the encoded fields
+   * @throws IOException if writing to {@code output} fails
+   */
+  @Override
+  public void write(DataOutput output) throws IOException {
+    output.writeInt(startIndex);
+    output.writeInt(partitionNum);
+    output.writeInt(shuffleId);
+    // The id is optional, so a presence flag precedes it.
+    if (currentTaskAttemptID != null) {
+      output.writeBoolean(true);
+      currentTaskAttemptID.write(output);
+    } else {
+      output.writeBoolean(false);
+    }
+  }
+
+  /**
+   * Replaces every field of this request with the values read from {@code dataInput}.
+   *
+   * @param dataInput source positioned at a request written by {@link #write}
+   * @throws IOException if reading from {@code dataInput} fails
+   */
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    startIndex = dataInput.readInt();
+    partitionNum = dataInput.readInt();
+    shuffleId = dataInput.readInt();
+    boolean hasTaskAttemptID = dataInput.readBoolean();
+    if (hasTaskAttemptID) {
+      currentTaskAttemptID = new TezTaskAttemptID();
+      currentTaskAttemptID.readFields(dataInput);
+    } else {
+      // Drop any id left over from an earlier readFields call on this (reused) instance.
+      currentTaskAttemptID = null;
+    }
+  }
+
+  @Override
+  public String toString() {
+    return "GetShuffleServerRequest{"
+        + "currentTaskAttemptID="
+        + currentTaskAttemptID
+        + ", startIndex=" + startIndex
+        + ", partitionNum=" + partitionNum
+        + ", shuffleId=" + shuffleId
+        + '}';
+  }
+
+  /** Returns the task attempt id, or null when none was supplied or deserialized. */
+  public TezTaskAttemptID getCurrentTaskAttemptID() {
+    return currentTaskAttemptID;
+  }
+
+  public int getStartIndex() {
+    return startIndex;
+  }
+
+  public int getPartitionNum() {
+    return partitionNum;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+}
diff --git
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerResponse.java
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerResponse.java
new file mode 100644
index 00000000..0defeaa5
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerResponse.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Objects;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Writable response carrying a status code, a message and an optional
+ * {@link ShuffleAssignmentsInfoWritable} payload.
+ *
+ * <p>Wire format: {@code int status}, {@code UTF retMsg}, an {@code int} marker
+ * ({@code -1} when no payload follows, {@code 1} otherwise), then the payload if present.
+ */
+public class GetShuffleServerResponse implements Writable {
+  private int status;
+  private String retMsg;
+  private ShuffleAssignmentsInfoWritable shuffleAssignmentsInfoWritable;
+
+  public int getStatus() {
+    return status;
+  }
+
+  public void setStatus(int status) {
+    this.status = status;
+  }
+
+  public String getRetMsg() {
+    return retMsg;
+  }
+
+  public void setRetMsg(String retMsg) {
+    this.retMsg = retMsg;
+  }
+
+  public ShuffleAssignmentsInfoWritable getShuffleAssignmentsInfoWritable() {
+    return shuffleAssignmentsInfoWritable;
+  }
+
+  public void setShuffleAssignmentsInfoWritable(
+      ShuffleAssignmentsInfoWritable shuffleAssignmentsInfoWritable) {
+    this.shuffleAssignmentsInfoWritable = shuffleAssignmentsInfoWritable;
+  }
+
+  /**
+   * Serializes this response.
+   *
+   * <p>A null message is written as the empty string, so it is read back as {@code ""}.
+   *
+   * @param dataOutput sink that receives the encoded fields
+   * @throws IOException if writing to {@code dataOutput} fails
+   */
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    dataOutput.writeInt(status);
+    // DataOutput#writeUTF rejects null with a NullPointerException, and retMsg stays null
+    // until setRetMsg is called, so an unset message must be substituted before writing.
+    dataOutput.writeUTF(Objects.isNull(retMsg) ? "" : retMsg);
+    if (Objects.isNull(shuffleAssignmentsInfoWritable)) {
+      dataOutput.writeInt(-1);
+    } else {
+      dataOutput.writeInt(1);
+      shuffleAssignmentsInfoWritable.write(dataOutput);
+    }
+  }
+
+  /**
+   * Replaces the fields of this response with the values read from {@code dataInput}.
+   *
+   * <p>The payload wrapper is always non-null afterwards; when the stream carried no
+   * payload the wrapper is simply left empty.
+   *
+   * @param dataInput source positioned at a response written by {@link #write}
+   * @throws IOException if reading from {@code dataInput} fails
+   */
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    status = dataInput.readInt();
+    retMsg = dataInput.readUTF();
+    shuffleAssignmentsInfoWritable = new ShuffleAssignmentsInfoWritable();
+    if (dataInput.readInt() != -1) {
+      shuffleAssignmentsInfoWritable.readFields(dataInput);
+    }
+  }
+}
diff --git
a/client-tez/src/main/java/org/apache/tez/common/ShuffleAssignmentsInfoWritable.java
b/client-tez/src/main/java/org/apache/tez/common/ShuffleAssignmentsInfoWritable.java
new file mode 100644
index 00000000..71a32a68
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/common/ShuffleAssignmentsInfoWritable.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.MapUtils;
+import org.apache.hadoop.io.Writable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+/**
+ * Writable wrapper around a {@link ShuffleAssignmentsInfo}.
+ *
+ * <p>Wire format: an {@code int} presence marker ({@code -1} when the wrapped info is null,
+ * {@code 1} otherwise), then the partition-to-servers map followed by the
+ * server-to-partition-ranges map. Every map and list is prefixed with its size, and a null
+ * or empty one is written as the single value {@code -1}; both are read back as empty.
+ * For a server only the id, host and gRPC port are written.
+ */
+public class ShuffleAssignmentsInfoWritable implements Writable {
+  private static final Logger LOG = LoggerFactory.getLogger(ShuffleAssignmentsInfoWritable.class);
+  /** Written instead of a presence flag or size when the value is null or empty. */
+  private static final int ABSENT = -1;
+  /** Presence flag written before a non-null {@link ShuffleAssignmentsInfo}. */
+  private static final int PRESENT = 1;
+
+  private ShuffleAssignmentsInfo shuffleAssignmentsInfo;
+
+  /** Creates an empty wrapper, to be filled in by {@link #readFields}. */
+  public ShuffleAssignmentsInfoWritable() {
+  }
+
+  public ShuffleAssignmentsInfoWritable(ShuffleAssignmentsInfo shuffleAssignmentsInfo) {
+    this.shuffleAssignmentsInfo = shuffleAssignmentsInfo;
+  }
+
+  /**
+   * Serializes the wrapped assignments.
+   *
+   * @param dataOutput sink that receives the encoded assignments
+   * @throws IOException if writing to {@code dataOutput} fails
+   */
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    if (shuffleAssignmentsInfo == null) {
+      dataOutput.writeInt(ABSENT);
+      LOG.warn("shuffleAssignmentsInfo is null, no need write");
+      return;
+    }
+    dataOutput.writeInt(PRESENT);
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers =
+        shuffleAssignmentsInfo.getPartitionToServers();
+    if (MapUtils.isEmpty(partitionToServers)) {
+      dataOutput.writeInt(ABSENT);
+    } else {
+      dataOutput.writeInt(partitionToServers.size());
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : partitionToServers.entrySet()) {
+        dataOutput.writeInt(entry.getKey());
+        List<ShuffleServerInfo> servers = entry.getValue();
+        if (CollectionUtils.isEmpty(servers)) {
+          dataOutput.writeInt(ABSENT);
+        } else {
+          dataOutput.writeInt(servers.size());
+          for (ShuffleServerInfo serverInfo : servers) {
+            writeServerInfo(dataOutput, serverInfo);
+          }
+        }
+      }
+    }
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges =
+        shuffleAssignmentsInfo.getServerToPartitionRanges();
+    if (MapUtils.isEmpty(serverToPartitionRanges)) {
+      dataOutput.writeInt(ABSENT);
+    } else {
+      dataOutput.writeInt(serverToPartitionRanges.size());
+      for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> entry
+          : serverToPartitionRanges.entrySet()) {
+        writeServerInfo(dataOutput, entry.getKey());
+        List<PartitionRange> ranges = entry.getValue();
+        if (CollectionUtils.isEmpty(ranges)) {
+          dataOutput.writeInt(ABSENT);
+        } else {
+          dataOutput.writeInt(ranges.size());
+          for (PartitionRange range : ranges) {
+            dataOutput.writeInt(range.getStart());
+            dataOutput.writeInt(range.getEnd());
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Replaces the wrapped assignments with the ones read from {@code dataInput}. When the
+   * stream says that no assignments were written, the wrapped value becomes null.
+   *
+   * @param dataInput source positioned at data written by {@link #write}
+   * @throws IOException if reading from {@code dataInput} fails
+   */
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    if (dataInput.readInt() == ABSENT) {
+      LOG.warn("shuffleAssignmentsInfo is null, no need read");
+      // Do not keep assignments from an earlier readFields call on this (reused) instance.
+      shuffleAssignmentsInfo = null;
+      return;
+    }
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    int partitionToServersSize = dataInput.readInt();
+    if (partitionToServersSize != ABSENT) {
+      for (int i = 0; i < partitionToServersSize; i++) {
+        Integer partitionId = dataInput.readInt();
+        List<ShuffleServerInfo> servers = new ArrayList<>();
+        int serverCount = dataInput.readInt();
+        if (serverCount != ABSENT) {
+          for (int j = 0; j < serverCount; j++) {
+            servers.add(readServerInfo(dataInput));
+          }
+        }
+        partitionToServers.put(partitionId, servers);
+      }
+    }
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+    int serverToPartitionRangesSize = dataInput.readInt();
+    if (serverToPartitionRangesSize != ABSENT) {
+      for (int i = 0; i < serverToPartitionRangesSize; i++) {
+        ShuffleServerInfo serverInfo = readServerInfo(dataInput);
+        List<PartitionRange> ranges = new ArrayList<>();
+        int rangeCount = dataInput.readInt();
+        if (rangeCount != ABSENT) {
+          for (int j = 0; j < rangeCount; j++) {
+            int start = dataInput.readInt();
+            int end = dataInput.readInt();
+            ranges.add(new PartitionRange(start, end));
+          }
+        }
+        serverToPartitionRanges.put(serverInfo, ranges);
+      }
+    }
+
+    shuffleAssignmentsInfo =
+        new ShuffleAssignmentsInfo(partitionToServers, serverToPartitionRanges);
+  }
+
+  /** Returns the wrapped assignments, or null when none were supplied or deserialized. */
+  public ShuffleAssignmentsInfo getShuffleAssignmentsInfo() {
+    return shuffleAssignmentsInfo;
+  }
+
+  /** Writes the id, host and gRPC port of {@code serverInfo}, in that order. */
+  private static void writeServerInfo(DataOutput dataOutput, ShuffleServerInfo serverInfo)
+      throws IOException {
+    dataOutput.writeUTF(serverInfo.getId());
+    dataOutput.writeUTF(serverInfo.getHost());
+    dataOutput.writeInt(serverInfo.getGrpcPort());
+  }
+
+  /** Reads a server previously encoded by {@link #writeServerInfo}. */
+  private static ShuffleServerInfo readServerInfo(DataInput dataInput) throws IOException {
+    String id = dataInput.readUTF();
+    String host = dataInput.readUTF();
+    int port = dataInput.readInt();
+    return new ShuffleServerInfo(id, host, port);
+  }
+}
diff --git a/client-tez/src/main/java/org/apache/tez/common/TezClassLoader.java
b/client-tez/src/main/java/org/apache/tez/common/TezClassLoader.java
new file mode 100644
index 00000000..1952e2d2
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/common/TezClassLoader.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+
+import org.apache.hadoop.conf.Configuration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * Class loader that allows new URLs to be added to the classpath at runtime.
+ *
+ * <p>It is a URLClassLoader whose parent is the class loader that loaded this class, so
+ * lookups are delegated to the parent first and only then to the URLs that were added.
+ * A process opts in by invoking setupTezClassLoader(), which installs the global
+ * TezClassLoader as the context class loader of the current thread. Threads created
+ * afterwards inherit that context class loader and hence resolve classes and resources
+ * through TezClassLoader.
+ */
+
+public class TezClassLoader extends URLClassLoader {
+ // The single shared instance; assigned in the static block below.
+ private static final TezClassLoader INSTANCE;
+ // Declared before the static block on purpose: static initializers run in textual order
+ // and the constructor invoked from that block logs through LOG.
+ private static final Logger LOG =
LoggerFactory.getLogger(TezClassLoader.class);
+
+ // Creates the shared instance inside a privileged action.
+ static {
+ INSTANCE = AccessController.doPrivileged(new
PrivilegedAction<TezClassLoader>() {
+ @Override
+ public TezClassLoader run() {
+ return new TezClassLoader();
+ }
+ });
+ }
+
+ // Private: instances are only obtained through getInstance(). Starts with no URLs and
+ // uses the loader of this class as parent.
+ private TezClassLoader() {
+ super(new URL[] {}, TezClassLoader.class.getClassLoader());
+
+ LOG.info(
+ "Created TezClassLoader with parent classloader: {}, thread: {},
system classloader: {}",
+ TezClassLoader.class.getClassLoader(),
Thread.currentThread().getId(),
+ ClassLoader.getSystemClassLoader());
+ }
+
+ /**
+ * Appends {@code url} to the search path of this loader. Overridden only to make the
+ * protected URLClassLoader method public.
+ */
+ @Override
+ public void addURL(URL url) {
+ super.addURL(url);
+ }
+
+ /** Returns the shared TezClassLoader instance. */
+ public static TezClassLoader getInstance() {
+ return INSTANCE;
+ }
+
+ /** Installs the shared instance as the context class loader of the calling thread. */
+ public static void setupTezClassLoader() {
+ LOG.debug(
+ "Setting up TezClassLoader: thread: {}, current thread
classloader: {} system classloader: {}",
+ Thread.currentThread().getId(),
Thread.currentThread().getContextClassLoader(),
+ ClassLoader.getSystemClassLoader());
+ Thread.currentThread().setContextClassLoader(INSTANCE);
+ }
+
+ /** Makes {@code configuration} use the shared instance as its class loader. */
+ public static void setupForConfiguration(Configuration configuration) {
+ configuration.setClassLoader(INSTANCE);
+ }
+}
+
diff --git
a/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerRequestTest.java
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerRequestTest.java
new file mode 100644
index 00000000..d300844b
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerRequestTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class GetShuffleServerRequestTest {
+  /** Writes a fully populated request and checks that every field survives the round trip. */
+  @Test
+  public void testSerDe() throws IOException {
+    // Build the attempt id from the application id downwards.
+    TezDAGID dag = TezDAGID.getInstance(ApplicationId.newInstance(9999, 72), 1);
+    TezVertexID vertex = TezVertexID.getInstance(dag, 35);
+    TezTaskID task = TezTaskID.getInstance(vertex, 389);
+    TezTaskAttemptID attempt = TezTaskAttemptID.getInstance(task, 2);
+
+    GetShuffleServerRequest expected = new GetShuffleServerRequest(attempt, 1, 20, 1998);
+
+    ByteArrayOutputStream sink = new ByteArrayOutputStream();
+    DataOutput encoder = new DataOutputStream(sink);
+    expected.write(encoder);
+
+    DataInput decoder = new DataInputStream(new ByteArrayInputStream(sink.toByteArray()));
+    GetShuffleServerRequest actual = new GetShuffleServerRequest();
+    actual.readFields(decoder);
+
+    assertEquals(expected.getCurrentTaskAttemptID(), actual.getCurrentTaskAttemptID());
+    assertEquals(expected.getStartIndex(), actual.getStartIndex());
+    assertEquals(expected.getPartitionNum(), actual.getPartitionNum());
+    assertEquals(expected.getShuffleId(), actual.getShuffleId());
+  }
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
new file mode 100644
index 00000000..34419bfd
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class GetShuffleServerResponseTest {
+  /** Round-trips a populated response and checks status, message and both assignment maps. */
+  @Test
+  public void testSerDe() throws IOException {
+    // The servers must differ from each other: the assertions below rely on
+    // ShuffleServerInfo being compared by value, so four identical host/port pairs would
+    // collapse into a single key of serverToPartitionRanges.
+    ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work2 = new ShuffleServerInfo("host2", 9999);
+    ShuffleServerInfo work3 = new ShuffleServerInfo("host3", 9999);
+    ShuffleServerInfo work4 = new ShuffleServerInfo("host4", 9999);
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(0, new ArrayList<>(Arrays.asList(work1, work2, work3, work4)));
+    partitionToServers.put(1, new ArrayList<>(Arrays.asList(work1, work2, work3, work4)));
+    partitionToServers.put(2, new ArrayList<>(Arrays.asList(work1, work3)));
+    partitionToServers.put(3, new ArrayList<>(Arrays.asList(work3, work4)));
+    partitionToServers.put(4, new ArrayList<>(Arrays.asList(work2, work4)));
+
+    PartitionRange range0 = new PartitionRange(0, 0);
+    PartitionRange range1 = new PartitionRange(1, 1);
+    PartitionRange range2 = new PartitionRange(2, 2);
+    PartitionRange range3 = new PartitionRange(3, 3);
+    PartitionRange range4 = new PartitionRange(4, 4);
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+    serverToPartitionRanges.put(work1, Arrays.asList(range0, range1, range2));
+    serverToPartitionRanges.put(work2, Arrays.asList(range0, range1, range4));
+    serverToPartitionRanges.put(work3, Arrays.asList(range0, range1, range2, range3));
+    serverToPartitionRanges.put(work4, Arrays.asList(range0, range1, range3, range4));
+    // Guards the fixture itself: every server must be a separate key.
+    assertEquals(4, serverToPartitionRanges.size());
+
+    ShuffleAssignmentsInfo info =
+        new ShuffleAssignmentsInfo(partitionToServers, serverToPartitionRanges);
+
+    GetShuffleServerResponse response = new GetShuffleServerResponse();
+    response.setStatus(0);
+    response.setRetMsg("none");
+    response.setShuffleAssignmentsInfoWritable(new ShuffleAssignmentsInfoWritable(info));
+
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    DataOutput out = new DataOutputStream(bos);
+    response.write(out);
+
+    GetShuffleServerResponse deSerResponse = new GetShuffleServerResponse();
+    DataInput in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray()));
+    deSerResponse.readFields(in);
+
+    assertEquals(response.getStatus(), deSerResponse.getStatus());
+    assertEquals(response.getRetMsg(), deSerResponse.getRetMsg());
+
+    ShuffleAssignmentsInfo base =
+        response.getShuffleAssignmentsInfoWritable().getShuffleAssignmentsInfo();
+    ShuffleAssignmentsInfo deSer =
+        deSerResponse.getShuffleAssignmentsInfoWritable().getShuffleAssignmentsInfo();
+
+    Map<Integer, List<ShuffleServerInfo>> deSerPartitionToServers = deSer.getPartitionToServers();
+    assertEquals(base.getPartitionToServers().size(), deSerPartitionToServers.size());
+    for (Integer partitionId : base.getPartitionToServers().keySet()) {
+      assertEquals(
+          base.getPartitionToServers().get(partitionId),
+          deSerPartitionToServers.get(partitionId));
+    }
+
+    Map<ShuffleServerInfo, List<PartitionRange>> deSerServerToRanges =
+        deSer.getServerToPartitionRanges();
+    assertEquals(base.getServerToPartitionRanges().size(), deSerServerToRanges.size());
+    for (ShuffleServerInfo serverInfo : base.getServerToPartitionRanges().keySet()) {
+      assertEquals(
+          base.getServerToPartitionRanges().get(serverInfo),
+          deSerServerToRanges.get(serverInfo));
+    }
+  }
+}
diff --git
a/client-tez/src/test/java/org/apache/tez/common/ShuffleAssignmentsInfoWritableTest.java
b/client-tez/src/test/java/org/apache/tez/common/ShuffleAssignmentsInfoWritableTest.java
new file mode 100644
index 00000000..888c01cd
--- /dev/null
+++
b/client-tez/src/test/java/org/apache/tez/common/ShuffleAssignmentsInfoWritableTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class ShuffleAssignmentsInfoWritableTest {
+  /** Round-trips populated assignments and checks that both maps survive unchanged. */
+  @Test
+  public void testSerDe() throws IOException {
+    // The servers must differ from each other: the assertions below rely on
+    // ShuffleServerInfo being compared by value, so four identical host/port pairs would
+    // collapse into a single key of serverToPartitionRanges.
+    ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work2 = new ShuffleServerInfo("host2", 9999);
+    ShuffleServerInfo work3 = new ShuffleServerInfo("host3", 9999);
+    ShuffleServerInfo work4 = new ShuffleServerInfo("host4", 9999);
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(0, new ArrayList<>(Arrays.asList(work1, work2, work3, work4)));
+    partitionToServers.put(1, new ArrayList<>(Arrays.asList(work1, work2, work3, work4)));
+    partitionToServers.put(2, new ArrayList<>(Arrays.asList(work1, work3)));
+    partitionToServers.put(3, new ArrayList<>(Arrays.asList(work3, work4)));
+    partitionToServers.put(4, new ArrayList<>(Arrays.asList(work2, work4)));
+
+    PartitionRange range0 = new PartitionRange(0, 0);
+    PartitionRange range1 = new PartitionRange(1, 1);
+    PartitionRange range2 = new PartitionRange(2, 2);
+    PartitionRange range3 = new PartitionRange(3, 3);
+    PartitionRange range4 = new PartitionRange(4, 4);
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+    serverToPartitionRanges.put(work1, Arrays.asList(range0, range1, range2));
+    serverToPartitionRanges.put(work2, Arrays.asList(range0, range1, range4));
+    serverToPartitionRanges.put(work3, Arrays.asList(range0, range1, range2, range3));
+    serverToPartitionRanges.put(work4, Arrays.asList(range0, range1, range3, range4));
+    // Guards the fixture itself: every server must be a separate key.
+    assertEquals(4, serverToPartitionRanges.size());
+
+    ShuffleAssignmentsInfo info =
+        new ShuffleAssignmentsInfo(partitionToServers, serverToPartitionRanges);
+    ShuffleAssignmentsInfoWritable infoWritable = new ShuffleAssignmentsInfoWritable(info);
+
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    DataOutput out = new DataOutputStream(bos);
+    infoWritable.write(out);
+
+    ShuffleAssignmentsInfoWritable deSerInfoWritable = new ShuffleAssignmentsInfoWritable();
+    DataInput in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray()));
+    deSerInfoWritable.readFields(in);
+
+    ShuffleAssignmentsInfo base = infoWritable.getShuffleAssignmentsInfo();
+    ShuffleAssignmentsInfo deSer = deSerInfoWritable.getShuffleAssignmentsInfo();
+
+    Map<Integer, List<ShuffleServerInfo>> deSerPartitionToServers = deSer.getPartitionToServers();
+    assertEquals(base.getPartitionToServers().size(), deSerPartitionToServers.size());
+    for (Integer partitionId : base.getPartitionToServers().keySet()) {
+      assertEquals(
+          base.getPartitionToServers().get(partitionId),
+          deSerPartitionToServers.get(partitionId));
+    }
+
+    Map<ShuffleServerInfo, List<PartitionRange>> deSerServerToRanges =
+        deSer.getServerToPartitionRanges();
+    assertEquals(base.getServerToPartitionRanges().size(), deSerServerToRanges.size());
+    for (ShuffleServerInfo serverInfo : base.getServerToPartitionRanges().keySet()) {
+      assertEquals(
+          base.getServerToPartitionRanges().get(serverInfo),
+          deSerServerToRanges.get(serverInfo));
+    }
+  }
+}