This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new cfd927fa [#991] Improvement(tez): TezRemoteShuffleManager support secure cluster. (#1005)
cfd927fa is described below

commit cfd927fa62972010acb8e14d598a1fa61b94d35a
Author: zhengchenyu <[email protected]>
AuthorDate: Wed Jul 12 15:33:33 2023 +0800

    [#991] Improvement(tez): TezRemoteShuffleManager support secure cluster. (#1005)
    
    ### What changes were proposed in this pull request?
    
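    Support secure (Kerberos-enabled) clusters in TezRemoteShuffleManager:
    the umbilical RPC server now uses a JobTokenSecretManager seeded with the
    DAG session token, task-side clients attach that session token when
    connecting, and service authorization uses a new RssTezAMPolicyProvider
    that covers TezRemoteShuffleUmbilicalProtocol.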
    
    issue: #991
    
    ### How was this patch tested?
    
    Unit tests, and tested on a YARN cluster.
---
 .../java/org/apache/tez/common/UmbilicalUtils.java |  15 ++-
 .../org/apache/tez/dag/app/RssDAGAppMaster.java    |   8 +-
 .../tez/dag/app/TezRemoteShuffleManager.java       |   8 +-
 .../security/authorize/RssTezAMPolicyProvider.java |  37 +++++++
 .../output/RssOrderedPartitionedKVOutput.java      |  10 +-
 .../library/output/RssUnorderedKVOutput.java       |  10 +-
 .../output/RssUnorderedPartitionedKVOutput.java    |   9 ++
 .../tez/dag/app/TezRemoteShuffleManagerTest.java   | 112 ++++++++++++++++++++-
 8 files changed, 198 insertions(+), 11 deletions(-)

diff --git a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java 
b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
index 26e0dc31..2c211131 100644
--- a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
@@ -26,8 +26,13 @@ import java.util.Map;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.ipc.RPC;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.slf4j.Logger;
@@ -60,11 +65,15 @@ public class UmbilicalUtils {
             Configuration conf,
             TezTaskAttemptID taskAttemptId,
             int shuffleId) throws IOException, InterruptedException, 
TezException {
-    UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(applicationId.toString());
-
     String host = conf.get(RSS_AM_SHUFFLE_MANAGER_ADDRESS);
     int port = conf.getInt(RSS_AM_SHUFFLE_MANAGER_PORT, -1);
     final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, 
port);
+
+    UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(applicationId.toString());
+    Credentials credentials = 
UserGroupInformation.getCurrentUser().getCredentials();
+    Token<JobTokenIdentifier> jobToken = 
TokenCache.getSessionToken(credentials);
+    SecurityUtil.setTokenService(jobToken, address);
+    taskOwner.addToken(jobToken);
     TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
         .doAs(new 
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
           @Override
@@ -91,7 +100,7 @@ public class UmbilicalUtils {
           TezTaskAttemptID taskAttemptId,
           int shuffleId) {
     try {
-      return doRequestShuffleServer(applicationId, conf, 
taskAttemptId,shuffleId);
+      return doRequestShuffleServer(applicationId, conf, taskAttemptId, 
shuffleId);
     } catch (IOException | InterruptedException | TezException e) {
       LOG.error("Failed to requestShuffleServer, applicationId:{}, 
taskAttemptId:{}, shuffleId:{}, worker:{}",
           applicationId, taskAttemptId, shuffleId, e);
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index f7d90c65..6baefe5a 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -30,6 +30,7 @@ import com.google.common.annotations.VisibleForTesting;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.Credentials;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.util.ShutdownHookManager;
 import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
 import org.apache.hadoop.yarn.api.ApplicationConstants;
@@ -49,6 +50,8 @@ import org.apache.tez.common.TezCommonUtils;
 import org.apache.tez.common.TezUtils;
 import org.apache.tez.common.TezUtilsInternal;
 import org.apache.tez.common.VersionInfo;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.OutputDescriptor;
 import org.apache.tez.dag.api.TezConstants;
@@ -154,8 +157,9 @@ public class RssDAGAppMaster extends DAGAppMaster {
             heartbeatInterval,
             TimeUnit.MILLISECONDS);
 
-    appMaster.setTezRemoteShuffleManager(new 
TezRemoteShuffleManager(strAppAttemptId, null, conf,
-            strAppAttemptId, client));
+    Token<JobTokenIdentifier> sessionToken = 
TokenCache.getSessionToken(appMaster.getContext().getAppCredentials());
+    appMaster.setTezRemoteShuffleManager(
+        new TezRemoteShuffleManager(appMaster.getAppID().toString(), 
sessionToken, conf, strAppAttemptId, client));
     appMaster.getTezRemoteShuffleManager().initialize();
     appMaster.getTezRemoteShuffleManager().start();
 
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index 09f33957..840c9e87 100644
--- 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -44,10 +44,11 @@ import org.apache.tez.common.ServicePluginLifecycle;
 import org.apache.tez.common.ShuffleAssignmentsInfoWritable;
 import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
 import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.api.TezException;
 import org.apache.tez.dag.api.TezUncheckedException;
-import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
+import org.apache.tez.dag.app.security.authorize.RssTezAMPolicyProvider;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -247,6 +248,8 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
         rssAmRpcBindPort = 0;
       }
 
+      JobTokenSecretManager jobTokenSecretManager = new 
JobTokenSecretManager();
+      jobTokenSecretManager.addTokenForJob(tokenIdentifier, sessionToken);
       server = new RPC.Builder(conf)
               .setProtocol(TezRemoteShuffleUmbilicalProtocol.class)
               .setBindAddress(rssAmRpcBindAddress)
@@ -256,13 +259,14 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
                       
conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT,
                               
TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT))
               .setPortRangeConfig(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE)
+              .setSecretManager(jobTokenSecretManager)
               .build();
 
       // Enable service authorization?
       if (conf.getBoolean(
               CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION,
               false)) {
-        refreshServiceAcls(conf, new TezAMPolicyProvider());
+        refreshServiceAcls(conf, new RssTezAMPolicyProvider());
       }
 
       server.start();
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/security/authorize/RssTezAMPolicyProvider.java
 
b/client-tez/src/main/java/org/apache/tez/dag/app/security/authorize/RssTezAMPolicyProvider.java
new file mode 100644
index 00000000..ec99c858
--- /dev/null
+++ 
b/client-tez/src/main/java/org/apache/tez/dag/app/security/authorize/RssTezAMPolicyProvider.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.security.authorize;
+
+import org.apache.hadoop.security.authorize.PolicyProvider;
+import org.apache.hadoop.security.authorize.Service;
+import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.dag.api.TezConstants;
+
+public class RssTezAMPolicyProvider extends PolicyProvider {
+
+  private static final Service[] tezApplicationMasterServices =
+      new Service[] {
+          new 
Service(TezConstants.TEZ_AM_SECURITY_SERVICE_AUTHORIZATION_TASK_UMBILICAL,
+              TezRemoteShuffleUmbilicalProtocol.class)
+      };
+
+  @Override
+  public Service[] getServices() {
+    return tezApplicationMasterServices.clone();
+  }
+}
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
index 3b1bc77e..bf543587 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -35,7 +35,10 @@ import 
org.apache.hadoop.classification.InterfaceAudience.Public;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.ipc.RPC;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.tez.common.GetShuffleServerRequest;
 import org.apache.tez.common.GetShuffleServerResponse;
@@ -43,6 +46,8 @@ import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezCommonUtils;
 import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
 import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -133,7 +138,10 @@ public class RssOrderedPartitionedKVOutput extends 
AbstractLogicalOutput {
     final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, 
port);
 
     UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(this.applicationId.toString());
-
+    Credentials credentials = 
UserGroupInformation.getCurrentUser().getCredentials();
+    Token<JobTokenIdentifier> jobToken = 
TokenCache.getSessionToken(credentials);
+    SecurityUtil.setTokenService(jobToken, address);
+    taskOwner.addToken(jobToken);
     final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
         .doAs(new 
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
           @Override
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
index 76f133e7..86c1ef3b 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
@@ -35,7 +35,10 @@ import 
org.apache.hadoop.classification.InterfaceAudience.Public;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.ipc.RPC;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.tez.common.GetShuffleServerRequest;
 import org.apache.tez.common.GetShuffleServerResponse;
@@ -43,6 +46,8 @@ import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezCommonUtils;
 import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
 import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -136,7 +141,10 @@ public class RssUnorderedKVOutput extends 
AbstractLogicalOutput {
     final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, 
port);
 
     UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(this.applicationId.toString());
-
+    Credentials credentials = 
UserGroupInformation.getCurrentUser().getCredentials();
+    Token<JobTokenIdentifier> jobToken = 
TokenCache.getSessionToken(credentials);
+    SecurityUtil.setTokenService(jobToken, address);
+    taskOwner.addToken(jobToken);
     final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
         .doAs(new 
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
           @Override
diff --git 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
index 2b2262e8..5a97f4f9 100644
--- 
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
+++ 
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
@@ -35,7 +35,10 @@ import 
org.apache.hadoop.classification.InterfaceAudience.Public;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.ipc.RPC;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.tez.common.GetShuffleServerRequest;
 import org.apache.tez.common.GetShuffleServerResponse;
@@ -43,6 +46,8 @@ import org.apache.tez.common.RssTezUtils;
 import org.apache.tez.common.TezCommonUtils;
 import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
 import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -134,6 +139,10 @@ public class RssUnorderedPartitionedKVOutput extends 
AbstractLogicalOutput {
     final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, 
port);
 
     UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(this.applicationId.toString());
+    Credentials credentials = 
UserGroupInformation.getCurrentUser().getCredentials();
+    Token<JobTokenIdentifier> jobToken = 
TokenCache.getSessionToken(credentials);
+    SecurityUtil.setTokenService(jobToken, address);
+    taskOwner.addToken(jobToken);
     final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
         .doAs(new 
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
           @Override
diff --git 
a/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
index 3af94aa0..20f4cf2c 100644
--- 
a/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
@@ -26,13 +26,20 @@ import java.util.List;
 import java.util.Map;
 
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.ipc.RPC;
 import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.tez.common.GetShuffleServerRequest;
 import org.apache.tez.common.GetShuffleServerResponse;
 import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.common.security.TokenCache;
 import org.apache.tez.dag.records.TezDAGID;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
@@ -97,7 +104,12 @@ public class TezRemoteShuffleManagerTest {
       ApplicationId appId = ApplicationId.newInstance(9999, 72);
 
       Configuration conf = new Configuration();
-      TezRemoteShuffleManager tezRemoteShuffleManager = new 
TezRemoteShuffleManager(appId.toString(), null,
+      JobTokenIdentifier identifier = new JobTokenIdentifier(new 
Text(appId.toString()));
+      JobTokenSecretManager secretManager = new JobTokenSecretManager();
+      String tokenIdentifier = appId.toString();
+      Token<JobTokenIdentifier> sessionToken = new Token(identifier, 
secretManager);
+      secretManager.addTokenForJob(tokenIdentifier, sessionToken);
+      TezRemoteShuffleManager tezRemoteShuffleManager = new 
TezRemoteShuffleManager(appId.toString(), sessionToken,
               conf, appId.toString(), client);
       tezRemoteShuffleManager.initialize();
       tezRemoteShuffleManager.start();
@@ -106,7 +118,6 @@ public class TezRemoteShuffleManagerTest {
       int port = tezRemoteShuffleManager.getAddress().getPort();
       final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, 
port);
 
-      String tokenIdentifier = appId.toString();
       UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(tokenIdentifier);
 
       TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner.doAs(
@@ -140,4 +151,101 @@ public class TezRemoteShuffleManagerTest {
       fail();
     }
   }
+
+  @Test
+  public void testTezRemoteShuffleManagerSecure() {
+    try {
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers = new 
HashMap<>();
+      partitionToServers.put(0, new ArrayList<>());
+      partitionToServers.put(1, new ArrayList<>());
+      partitionToServers.put(2, new ArrayList<>());
+      partitionToServers.put(3, new ArrayList<>());
+      partitionToServers.put(4, new ArrayList<>());
+
+      ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
+      ShuffleServerInfo work2 = new ShuffleServerInfo("host2", 9999);
+      ShuffleServerInfo work3 = new ShuffleServerInfo("host3", 9999);
+      ShuffleServerInfo work4 = new ShuffleServerInfo("host4", 9999);
+
+      partitionToServers.get(0).addAll(Arrays.asList(work1, work2, work3, 
work4));
+      partitionToServers.get(1).addAll(Arrays.asList(work1, work2, work3, 
work4));
+      partitionToServers.get(2).addAll(Arrays.asList(work1, work3));
+      partitionToServers.get(3).addAll(Arrays.asList(work3, work4));
+      partitionToServers.get(4).addAll(Arrays.asList(work2, work4));
+
+      Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = 
new HashMap<>();
+      PartitionRange range0 = new PartitionRange(0, 0);
+      PartitionRange range1 = new PartitionRange(1, 1);
+      PartitionRange range2 = new PartitionRange(2, 2);
+      PartitionRange range3 = new PartitionRange(3, 3);
+      PartitionRange range4 = new PartitionRange(4, 4);
+
+      serverToPartitionRanges.put(work1, Arrays.asList(range0, range1, 
range2));
+      serverToPartitionRanges.put(work2, Arrays.asList(range0, range1, 
range4));
+      serverToPartitionRanges.put(work3, Arrays.asList(range0, range1, range2, 
range3));
+      serverToPartitionRanges.put(work4, Arrays.asList(range0, range1, range3, 
range4));
+
+      ShuffleAssignmentsInfo shuffleAssignmentsInfo = new 
ShuffleAssignmentsInfo(partitionToServers,
+          serverToPartitionRanges);
+
+      ShuffleWriteClient client = mock(ShuffleWriteClient.class);
+      when(client.getShuffleAssignments(anyString(), anyInt(), anyInt(), 
anyInt(), anySet(), anyInt(), anyInt()))
+          .thenReturn(shuffleAssignmentsInfo);
+
+      ApplicationId appId = ApplicationId.newInstance(9999, 72);
+
+      Configuration conf = new Configuration();
+      conf.set("hadoop.security.authentication", "kerberos");
+      JobTokenIdentifier identifier = new JobTokenIdentifier(new 
Text(appId.toString()));
+      JobTokenSecretManager secretManager = new JobTokenSecretManager();
+      String tokenIdentifier = appId.toString();
+      Token<JobTokenIdentifier> sessionToken = new Token(identifier, 
secretManager);
+      Credentials credentials = new Credentials();
+      TokenCache.setSessionToken(sessionToken, credentials);
+      secretManager.addTokenForJob(tokenIdentifier, sessionToken);
+      TezRemoteShuffleManager tezRemoteShuffleManager = new 
TezRemoteShuffleManager(appId.toString(), sessionToken,
+          conf, appId.toString(), client);
+      tezRemoteShuffleManager.initialize();
+      tezRemoteShuffleManager.start();
+
+      String host = tezRemoteShuffleManager.getAddress().getHostString();
+      int port = tezRemoteShuffleManager.getAddress().getPort();
+      final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, 
port);
+
+      // here we omit the procces of deliver the credentials, just use it 
directly
+      UserGroupInformation taskOwner = 
UserGroupInformation.createRemoteUser(tokenIdentifier);
+      Token<JobTokenIdentifier> jobToken = 
TokenCache.getSessionToken(credentials);
+      SecurityUtil.setTokenService(jobToken, address);
+      taskOwner.addToken(jobToken);
+      TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner.doAs(
+          new PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
+            @Override
+            public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
+              return RPC.getProxy(TezRemoteShuffleUmbilicalProtocol.class,
+                  TezRemoteShuffleUmbilicalProtocol.versionID, address, conf);
+            }
+          });
+
+      TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+      TezVertexID vId = TezVertexID.getInstance(dagId, 35);
+      TezTaskID tId = TezTaskID.getInstance(vId, 389);
+      TezTaskAttemptID taId = TezTaskAttemptID.getInstance(tId, 2);
+
+      int mapNum = 1;
+      int shuffleId = 10001;
+      int reduceNum = shuffleAssignmentsInfo.getPartitionToServers().size();
+
+      String errorMessage = "failed to get Shuffle Assignments";
+      GetShuffleServerRequest request = new GetShuffleServerRequest(taId, 
mapNum, reduceNum, shuffleId);
+      GetShuffleServerResponse response = 
umbilical.getShuffleAssignments(request);
+      assertEquals(0, response.getStatus(), errorMessage);
+      assertEquals(reduceNum, 
response.getShuffleAssignmentsInfoWritable().getShuffleAssignmentsInfo()
+          .getPartitionToServers().size());
+
+    } catch (Exception e) {
+      e.printStackTrace();
+      assertEquals("test", e.getMessage());
+      fail();
+    }
+  }
 }

Reply via email to