This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new c713f391 [#872][FOLLOWUP] feat(tez): Add the common and utils class 
(#894)
c713f391 is described below

commit c713f3917b887657e2ffca30613d100e3673e8ee
Author: Qing <[email protected]>
AuthorDate: Wed May 31 22:32:34 2023 +0800

    [#872][FOLLOWUP] feat(tez): Add the common and utils class (#894)
    
    ### What changes were proposed in this pull request?
    
    Add the common classes that the Tez shuffle reader needs to support remote shuffle.
    
    ### Why are the changes needed?
    
    Fix:  #872
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    UT.
---
 .../apache/tez/common/GetShuffleServerRequest.java |  96 ++++++++++++
 .../tez/common/GetShuffleServerResponse.java       |  77 ++++++++++
 .../tez/common/ShuffleAssignmentsInfoWritable.java | 165 +++++++++++++++++++++
 .../java/org/apache/tez/common/TezClassLoader.java |  82 ++++++++++
 .../tez/common/GetShuffleServerRequestTest.java    |  67 +++++++++
 .../tez/common/GetShuffleServerResponseTest.java   | 123 +++++++++++++++
 .../common/ShuffleAssignmentsInfoWritableTest.java | 109 ++++++++++++++
 7 files changed, 719 insertions(+)

diff --git 
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java 
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
new file mode 100644
index 00000000..bdd8f00b
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerRequest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+
+
+public class GetShuffleServerRequest implements Writable {
+  private TezTaskAttemptID currentTaskAttemptID;
+  private int startIndex;
+  private int partitionNum;
+  private int shuffleId;
+
+  public GetShuffleServerRequest() {
+  }
+
+  public GetShuffleServerRequest(TezTaskAttemptID currentTaskAttemptID, int 
startIndex,
+                                 int partitionNum, int shuffleId) {
+    this.currentTaskAttemptID = currentTaskAttemptID;
+    this.startIndex = startIndex;
+    this.partitionNum = partitionNum;
+    this.shuffleId = shuffleId;
+  }
+
+  @Override
+  public void write(DataOutput output) throws IOException {
+    output.writeInt(startIndex);
+    output.writeInt(partitionNum);
+    output.writeInt(shuffleId);
+    if (currentTaskAttemptID != null) {
+      output.writeBoolean(true);
+      currentTaskAttemptID.write(output);
+    } else {
+      output.writeBoolean(false);
+    }
+  }
+
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    startIndex = dataInput.readInt();
+    partitionNum = dataInput.readInt();
+    shuffleId = dataInput.readInt();
+    boolean hasTaskTaskAttemptID = dataInput.readBoolean();
+    if (hasTaskTaskAttemptID) {
+      currentTaskAttemptID = new TezTaskAttemptID();
+      currentTaskAttemptID.readFields(dataInput);
+    }
+  }
+
+  @Override
+  public String toString() {
+    return "GetShuffleServerRequest{"
+            + "currentTaskAttemptID="
+            + currentTaskAttemptID
+            + ", startIndex=" + startIndex
+            + ", partitionNum=" + partitionNum
+            + ", shuffleId=" + shuffleId
+            + '}';
+  }
+
+  public TezTaskAttemptID getCurrentTaskAttemptID() {
+    return currentTaskAttemptID;
+  }
+
+  public int getStartIndex() {
+    return startIndex;
+  }
+
+  public int getPartitionNum() {
+    return partitionNum;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerResponse.java 
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerResponse.java
new file mode 100644
index 00000000..0defeaa5
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/common/GetShuffleServerResponse.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Objects;
+
+import org.apache.hadoop.io.Writable;
+
+public class GetShuffleServerResponse implements Writable {
+  private int status;
+  private String retMsg;
+  private ShuffleAssignmentsInfoWritable shuffleAssignmentsInfoWritable;
+
+  public int getStatus() {
+    return status;
+  }
+
+  public void setStatus(int status) {
+    this.status = status;
+  }
+
+  public String getRetMsg() {
+    return retMsg;
+  }
+
+  public void setRetMsg(String retMsg) {
+    this.retMsg = retMsg;
+  }
+
+  public ShuffleAssignmentsInfoWritable getShuffleAssignmentsInfoWritable() {
+    return shuffleAssignmentsInfoWritable;
+  }
+
+  public void setShuffleAssignmentsInfoWritable(ShuffleAssignmentsInfoWritable 
shuffleAssignmentsInfoWritable) {
+    this.shuffleAssignmentsInfoWritable = shuffleAssignmentsInfoWritable;
+  }
+
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    dataOutput.writeInt(status);
+    dataOutput.writeUTF(retMsg);
+    if (Objects.isNull(shuffleAssignmentsInfoWritable)) {
+      dataOutput.writeInt(-1);
+    } else {
+      dataOutput.writeInt(1);
+      shuffleAssignmentsInfoWritable.write(dataOutput);
+    }
+  }
+
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    status = dataInput.readInt();
+    retMsg = dataInput.readUTF();
+    shuffleAssignmentsInfoWritable = new ShuffleAssignmentsInfoWritable();
+    if (dataInput.readInt() != -1) {
+      shuffleAssignmentsInfoWritable.readFields(dataInput);
+    }
+  }
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/common/ShuffleAssignmentsInfoWritable.java
 
b/client-tez/src/main/java/org/apache/tez/common/ShuffleAssignmentsInfoWritable.java
new file mode 100644
index 00000000..71a32a68
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/common/ShuffleAssignmentsInfoWritable.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.MapUtils;
+import org.apache.hadoop.io.Writable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+public class ShuffleAssignmentsInfoWritable implements Writable {
+  private ShuffleAssignmentsInfo shuffleAssignmentsInfo;
+  private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleAssignmentsInfoWritable.class);
+
+
+  public ShuffleAssignmentsInfoWritable() {
+
+  }
+
+  public ShuffleAssignmentsInfoWritable(ShuffleAssignmentsInfo 
shuffleAssignmentsInfo) {
+    this.shuffleAssignmentsInfo = shuffleAssignmentsInfo;
+  }
+
+
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    if (shuffleAssignmentsInfo == null) {
+      dataOutput.writeInt(-1);
+      LOG.warn("shuffleAssignmentsInfo is null, no need write");
+      return;
+    } else {
+      dataOutput.writeInt(1);
+    }
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = 
shuffleAssignmentsInfo.getPartitionToServers();
+    if (MapUtils.isEmpty(partitionToServers)) {
+      dataOutput.writeInt(-1);
+    } else {
+      dataOutput.writeInt(partitionToServers.size());
+      for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : 
partitionToServers.entrySet()) {
+        dataOutput.writeInt(entry.getKey());
+        if (CollectionUtils.isEmpty(entry.getValue())) {
+          dataOutput.writeInt(-1);
+        } else {
+          dataOutput.writeInt(entry.getValue().size());
+          for (ShuffleServerInfo serverInfo : entry.getValue()) {
+            dataOutput.writeUTF(serverInfo.getId());
+            dataOutput.writeUTF(serverInfo.getHost());
+            dataOutput.writeInt(serverInfo.getGrpcPort());
+          }
+        }
+      }
+    }
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = 
shuffleAssignmentsInfo
+            .getServerToPartitionRanges();
+    if (MapUtils.isEmpty(serverToPartitionRanges)) {
+      dataOutput.writeInt(-1);
+    } else {
+      dataOutput.writeInt(serverToPartitionRanges.size());
+      for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> entry : 
serverToPartitionRanges.entrySet()) {
+        dataOutput.writeUTF(entry.getKey().getId());
+        dataOutput.writeUTF(entry.getKey().getHost());
+        dataOutput.writeInt(entry.getKey().getGrpcPort());
+        if (CollectionUtils.isEmpty(entry.getValue())) {
+          dataOutput.writeInt(-1);
+        } else {
+          dataOutput.writeInt(entry.getValue().size());
+          for (PartitionRange range : entry.getValue()) {
+            dataOutput.writeInt(range.getStart());
+            dataOutput.writeInt(range.getEnd());
+          }
+        }
+      }
+    }
+  }
+
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    if (dataInput.readInt() == -1) {
+      LOG.warn("shuffleAssignmentsInfo is null, no need read");
+      return;
+    }
+
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    int partitionToServersSize = dataInput.readInt();
+    if (partitionToServersSize != -1) {
+      Integer partitionId;
+      for (int i = 0; i < partitionToServersSize; i++) {
+        partitionId = dataInput.readInt();
+        List<ShuffleServerInfo> shuffleServerInfoList = new ArrayList<>();
+        int shuffleServerInfoListSize = dataInput.readInt();
+        if (shuffleServerInfoListSize != -1) {
+          for (int i1 = 0; i1 < shuffleServerInfoListSize; i1++) {
+            String id = dataInput.readUTF();
+            String host = dataInput.readUTF();
+            int port = dataInput.readInt();
+            ShuffleServerInfo shuffleServerInfo = new ShuffleServerInfo(id, 
host, port);
+            shuffleServerInfoList.add(shuffleServerInfo);
+          }
+        }
+
+        partitionToServers.put(partitionId, shuffleServerInfoList);
+      }
+    }
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new 
HashMap<>();
+    int serverToPartitionRangesSize = dataInput.readInt();
+    if (serverToPartitionRangesSize != -1) {
+      for (int i = 0; i < serverToPartitionRangesSize; i++) {
+        ShuffleServerInfo shuffleServerInfo;
+        List<PartitionRange> partitionRangeList = new ArrayList<>();
+
+        String id = dataInput.readUTF();
+        String host = dataInput.readUTF();
+        int port = dataInput.readInt();
+        shuffleServerInfo = new ShuffleServerInfo(id, host, port);
+
+        int partitionRangeListSize = dataInput.readInt();
+        if (partitionRangeListSize != -1) {
+          for (int i1 = 0; i1 < partitionRangeListSize; i1++) {
+            int start = dataInput.readInt();
+            int end = dataInput.readInt();
+            PartitionRange partitionRange = new PartitionRange(start, end);
+            partitionRangeList.add(partitionRange);
+          }
+        }
+        serverToPartitionRanges.put(shuffleServerInfo, partitionRangeList);
+      }
+    }
+
+    shuffleAssignmentsInfo = new ShuffleAssignmentsInfo(partitionToServers, 
serverToPartitionRanges);
+  }
+
+  public ShuffleAssignmentsInfo getShuffleAssignmentsInfo() {
+    return shuffleAssignmentsInfo;
+  }
+}
diff --git a/client-tez/src/main/java/org/apache/tez/common/TezClassLoader.java 
b/client-tez/src/main/java/org/apache/tez/common/TezClassLoader.java
new file mode 100644
index 00000000..1952e2d2
--- /dev/null
+++ b/client-tez/src/main/java/org/apache/tez/common/TezClassLoader.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+
+import org.apache.hadoop.conf.Configuration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * ClassLoader to allow addition of new paths to classpath in the runtime.
+ *
+ * It uses URLClassLoader with this class' classloader as parent classloader.
+ * And hence first delegates the resource loading to parent and then to the 
URLs
+ * added. The process must be setup to use by invoking setupTezClassLoader() 
which sets
+ * the global TezClassLoader as current thread context class loader. All 
threads
+ * created will inherit the classloader and hence will resolve the 
class/resource
+ * from TezClassLoader.
+ */
+
+public class TezClassLoader extends URLClassLoader {
+  private static final TezClassLoader INSTANCE;
+  private static final Logger LOG = 
LoggerFactory.getLogger(TezClassLoader.class);
+
+  static {
+    INSTANCE = AccessController.doPrivileged(new 
PrivilegedAction<TezClassLoader>() {
+      @Override
+      public TezClassLoader run() {
+        return new TezClassLoader();
+      }
+    });
+  }
+
+  private TezClassLoader() {
+    super(new URL[] {}, TezClassLoader.class.getClassLoader());
+
+    LOG.info(
+            "Created TezClassLoader with parent classloader: {}, thread: {}, 
system classloader: {}",
+            TezClassLoader.class.getClassLoader(), 
Thread.currentThread().getId(),
+            ClassLoader.getSystemClassLoader());
+  }
+
+  @Override
+  public void addURL(URL url) {
+    super.addURL(url);
+  }
+
+  public static TezClassLoader getInstance() {
+    return INSTANCE;
+  }
+
+  public static void setupTezClassLoader() {
+    LOG.debug(
+            "Setting up TezClassLoader: thread: {}, current thread 
classloader: {} system classloader: {}",
+            Thread.currentThread().getId(), 
Thread.currentThread().getContextClassLoader(),
+            ClassLoader.getSystemClassLoader());
+    Thread.currentThread().setContextClassLoader(INSTANCE);
+  }
+
+  public static void setupForConfiguration(Configuration configuration) {
+    configuration.setClassLoader(INSTANCE);
+  }
+}
+
diff --git 
a/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerRequestTest.java
 
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerRequestTest.java
new file mode 100644
index 00000000..d300844b
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerRequestTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.dag.records.TezDAGID;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+import org.apache.tez.dag.records.TezTaskID;
+import org.apache.tez.dag.records.TezVertexID;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class GetShuffleServerRequestTest {
+  @Test
+  public void testSerDe() throws IOException {
+    ApplicationId appId = ApplicationId.newInstance(9999, 72);
+    TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+    TezVertexID vId = TezVertexID.getInstance(dagId, 35);
+    TezTaskID tId = TezTaskID.getInstance(vId, 389);
+
+    TezTaskAttemptID tezTaskAttemptID = TezTaskAttemptID.getInstance(tId, 2);
+    int startIndex = 1;
+    int partitionNum = 20;
+    int shuffleId = 1998;
+
+    GetShuffleServerRequest request = new 
GetShuffleServerRequest(tezTaskAttemptID, startIndex,
+            partitionNum, shuffleId);
+
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    DataOutput out = new DataOutputStream(bos);
+    request.write(out);
+
+    GetShuffleServerRequest deSerRequest = new GetShuffleServerRequest();
+    ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+    DataInput in = new DataInputStream(bis);
+    deSerRequest.readFields(in);
+
+    assertEquals(request.getCurrentTaskAttemptID(), 
deSerRequest.getCurrentTaskAttemptID());
+    assertEquals(request.getStartIndex(), deSerRequest.getStartIndex());
+    assertEquals(request.getPartitionNum(), deSerRequest.getPartitionNum());
+    assertEquals(request.getShuffleId(), deSerRequest.getShuffleId());
+  }
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
 
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
new file mode 100644
index 00000000..34419bfd
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/common/GetShuffleServerResponseTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class GetShuffleServerResponseTest {
+  @Test
+  public void testSerDe() throws IOException {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(0, new ArrayList<>());
+    partitionToServers.put(1, new ArrayList<>());
+    partitionToServers.put(2, new ArrayList<>());
+    partitionToServers.put(3, new ArrayList<>());
+    partitionToServers.put(4, new ArrayList<>());
+
+    ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work2 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work3 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work4 = new ShuffleServerInfo("host1", 9999);
+
+    partitionToServers.get(0).addAll(Arrays.asList(work1, work2, work3, 
work4));
+    partitionToServers.get(1).addAll(Arrays.asList(work1, work2, work3, 
work4));
+    partitionToServers.get(2).addAll(Arrays.asList(work1, work3));
+    partitionToServers.get(3).addAll(Arrays.asList(work3, work4));
+    partitionToServers.get(4).addAll(Arrays.asList(work2, work4));
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new 
HashMap<>();
+    PartitionRange range0 = new PartitionRange(0, 0);
+    PartitionRange range1 = new PartitionRange(1, 1);
+    PartitionRange range2 = new PartitionRange(2, 2);
+    PartitionRange range3 = new PartitionRange(3, 3);
+    PartitionRange range4 = new PartitionRange(4, 4);
+
+    serverToPartitionRanges.put(work1, Arrays.asList(range0, range1, range2));
+    serverToPartitionRanges.put(work2, Arrays.asList(range0, range1, range4));
+    serverToPartitionRanges.put(work3, Arrays.asList(range0, range1, range2, 
range3));
+    serverToPartitionRanges.put(work4, Arrays.asList(range0, range1, range3, 
range4));
+
+    ShuffleAssignmentsInfo info = new 
ShuffleAssignmentsInfo(partitionToServers, serverToPartitionRanges);
+
+    int status = 0;
+    String retMsg = "none";
+    ShuffleAssignmentsInfoWritable infoWritable = new 
ShuffleAssignmentsInfoWritable(info);
+
+    GetShuffleServerResponse response = new GetShuffleServerResponse();
+    response.setStatus(status);
+    response.setRetMsg(retMsg);
+    response.setShuffleAssignmentsInfoWritable(infoWritable);
+
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    DataOutput out = new DataOutputStream(bos);
+    response.write(out);
+
+    GetShuffleServerResponse deSerResponse = new GetShuffleServerResponse();
+    ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+    DataInput in = new DataInputStream(bis);
+    deSerResponse.readFields(in);
+
+    assertEquals(response.getStatus(), deSerResponse.getStatus());
+    assertEquals(response.getRetMsg(), deSerResponse.getRetMsg());
+    {
+      Map<Integer, List<ShuffleServerInfo>> base = 
response.getShuffleAssignmentsInfoWritable()
+              .getShuffleAssignmentsInfo()
+              .getPartitionToServers();
+      Map<Integer, List<ShuffleServerInfo>> deSer = 
deSerResponse.getShuffleAssignmentsInfoWritable()
+              .getShuffleAssignmentsInfo()
+              .getPartitionToServers();
+
+      assertEquals(base.size(), deSer.size());
+      for (Integer partitionId : base.keySet()) {
+        assertEquals(base.get(partitionId), deSer.get(partitionId));
+      }
+    }
+    {
+      Map<ShuffleServerInfo, List<PartitionRange>> base = 
response.getShuffleAssignmentsInfoWritable()
+              .getShuffleAssignmentsInfo()
+              .getServerToPartitionRanges();
+      Map<ShuffleServerInfo, List<PartitionRange>> deSer = 
deSerResponse.getShuffleAssignmentsInfoWritable()
+              .getShuffleAssignmentsInfo()
+              .getServerToPartitionRanges();
+
+      assertEquals(base.size(), deSer.size());
+      for (ShuffleServerInfo serverInfo : base.keySet()) {
+        assertEquals(base.get(serverInfo), deSer.get(serverInfo));
+      }
+    }
+  }
+}
diff --git 
a/client-tez/src/test/java/org/apache/tez/common/ShuffleAssignmentsInfoWritableTest.java
 
b/client-tez/src/test/java/org/apache/tez/common/ShuffleAssignmentsInfoWritableTest.java
new file mode 100644
index 00000000..888c01cd
--- /dev/null
+++ 
b/client-tez/src/test/java/org/apache/tez/common/ShuffleAssignmentsInfoWritableTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.PartitionRange;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class ShuffleAssignmentsInfoWritableTest {
+  @Test
+  public void testSerDe() throws IOException {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(0, new ArrayList<>());
+    partitionToServers.put(1, new ArrayList<>());
+    partitionToServers.put(2, new ArrayList<>());
+    partitionToServers.put(3, new ArrayList<>());
+    partitionToServers.put(4, new ArrayList<>());
+
+    ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work2 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work3 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work4 = new ShuffleServerInfo("host1", 9999);
+
+    partitionToServers.get(0).addAll(Arrays.asList(work1, work2, work3, 
work4));
+    partitionToServers.get(1).addAll(Arrays.asList(work1, work2, work3, 
work4));
+    partitionToServers.get(2).addAll(Arrays.asList(work1, work3));
+    partitionToServers.get(3).addAll(Arrays.asList(work3, work4));
+    partitionToServers.get(4).addAll(Arrays.asList(work2, work4));
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new 
HashMap<>();
+    PartitionRange range0 = new PartitionRange(0, 0);
+    PartitionRange range1 = new PartitionRange(1, 1);
+    PartitionRange range2 = new PartitionRange(2, 2);
+    PartitionRange range3 = new PartitionRange(3, 3);
+    PartitionRange range4 = new PartitionRange(4, 4);
+
+    serverToPartitionRanges.put(work1, Arrays.asList(range0, range1, range2));
+    serverToPartitionRanges.put(work2, Arrays.asList(range0, range1, range4));
+    serverToPartitionRanges.put(work3, Arrays.asList(range0, range1, range2, 
range3));
+    serverToPartitionRanges.put(work4, Arrays.asList(range0, range1, range3, 
range4));
+
+
+    ShuffleAssignmentsInfo info = new 
ShuffleAssignmentsInfo(partitionToServers, serverToPartitionRanges);
+    ShuffleAssignmentsInfoWritable infoWritable = new 
ShuffleAssignmentsInfoWritable(info);
+
+    ByteArrayOutputStream bos = new ByteArrayOutputStream();
+    DataOutput out = new DataOutputStream(bos);
+    infoWritable.write(out);
+
+    ShuffleAssignmentsInfoWritable deSerInfoWritable = new 
ShuffleAssignmentsInfoWritable();
+    ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+    DataInput in = new DataInputStream(bis);
+    deSerInfoWritable.readFields(in);
+
+    {
+      Map<Integer, List<ShuffleServerInfo>> base = 
infoWritable.getShuffleAssignmentsInfo().getPartitionToServers();
+      Map<Integer, List<ShuffleServerInfo>> deSer = 
deSerInfoWritable.getShuffleAssignmentsInfo()
+              .getPartitionToServers();
+
+      assertEquals(base.size(), deSer.size());
+      for (Integer partitionId : base.keySet()) {
+        assertEquals(base.get(partitionId), deSer.get(partitionId));
+      }
+    }
+    {
+      Map<ShuffleServerInfo, List<PartitionRange>> base = 
infoWritable.getShuffleAssignmentsInfo()
+              .getServerToPartitionRanges();
+      Map<ShuffleServerInfo, List<PartitionRange>> deSer = 
deSerInfoWritable.getShuffleAssignmentsInfo()
+              .getServerToPartitionRanges();
+
+      assertEquals(base.size(), deSer.size());
+      for (ShuffleServerInfo serverInfo : base.keySet()) {
+        assertEquals(base.get(serverInfo), deSer.get(serverInfo));
+      }
+    }
+  }
+}

Reply via email to