This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new cfd927fa [#991] Improvement(tez): TezRemoteShuffleManager support
secure cluster. (#1005)
cfd927fa is described below
commit cfd927fa62972010acb8e14d598a1fa61b94d35a
Author: zhengchenyu <[email protected]>
AuthorDate: Wed Jul 12 15:33:33 2023 +0800
[#991] Improvement(tez): TezRemoteShuffleManager support secure cluster.
(#1005)
### What changes were proposed in this pull request?
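Support secure clusters in TezRemoteShuffleManager: its RPC server now authenticates incoming task requests with the Tez session (job) token, and tasks attach that token when they connect to it.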
issue: #991
### How was this patch tested?
Unit tests, plus manual testing on a YARN cluster.
---
.../java/org/apache/tez/common/UmbilicalUtils.java | 15 ++-
.../org/apache/tez/dag/app/RssDAGAppMaster.java | 8 +-
.../tez/dag/app/TezRemoteShuffleManager.java | 8 +-
.../security/authorize/RssTezAMPolicyProvider.java | 37 +++++++
.../output/RssOrderedPartitionedKVOutput.java | 10 +-
.../library/output/RssUnorderedKVOutput.java | 10 +-
.../output/RssUnorderedPartitionedKVOutput.java | 9 ++
.../tez/dag/app/TezRemoteShuffleManagerTest.java | 112 ++++++++++++++++++++-
8 files changed, 198 insertions(+), 11 deletions(-)
diff --git a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
index 26e0dc31..2c211131 100644
--- a/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
+++ b/client-tez/src/main/java/org/apache/tez/common/UmbilicalUtils.java
@@ -26,8 +26,13 @@ import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.slf4j.Logger;
@@ -60,11 +65,15 @@ public class UmbilicalUtils {
Configuration conf,
TezTaskAttemptID taskAttemptId,
int shuffleId) throws IOException, InterruptedException,
TezException {
- UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(applicationId.toString());
-
String host = conf.get(RSS_AM_SHUFFLE_MANAGER_ADDRESS);
int port = conf.getInt(RSS_AM_SHUFFLE_MANAGER_PORT, -1);
final InetSocketAddress address = NetUtils.createSocketAddrForHost(host,
port);
+
+ UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(applicationId.toString());
+ Credentials credentials =
UserGroupInformation.getCurrentUser().getCredentials();
+ Token<JobTokenIdentifier> jobToken =
TokenCache.getSessionToken(credentials);
+ SecurityUtil.setTokenService(jobToken, address);
+ taskOwner.addToken(jobToken);
TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
.doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
@Override
@@ -91,7 +100,7 @@ public class UmbilicalUtils {
TezTaskAttemptID taskAttemptId,
int shuffleId) {
try {
- return doRequestShuffleServer(applicationId, conf,
taskAttemptId,shuffleId);
+ return doRequestShuffleServer(applicationId, conf, taskAttemptId,
shuffleId);
} catch (IOException | InterruptedException | TezException e) {
LOG.error("Failed to requestShuffleServer, applicationId:{},
taskAttemptId:{}, shuffleId:{}, worker:{}",
applicationId, taskAttemptId, shuffleId, e);
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
index f7d90c65..6baefe5a 100644
--- a/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
+++ b/client-tez/src/main/java/org/apache/tez/dag/app/RssDAGAppMaster.java
@@ -30,6 +30,7 @@ import com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.Credentials;
import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.util.ShutdownHookManager;
import org.apache.hadoop.yarn.YarnUncaughtExceptionHandler;
import org.apache.hadoop.yarn.api.ApplicationConstants;
@@ -49,6 +50,8 @@ import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.TezUtilsInternal;
import org.apache.tez.common.VersionInfo;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.TezConstants;
@@ -154,8 +157,9 @@ public class RssDAGAppMaster extends DAGAppMaster {
heartbeatInterval,
TimeUnit.MILLISECONDS);
- appMaster.setTezRemoteShuffleManager(new
TezRemoteShuffleManager(strAppAttemptId, null, conf,
- strAppAttemptId, client));
+ Token<JobTokenIdentifier> sessionToken =
TokenCache.getSessionToken(appMaster.getContext().getAppCredentials());
+ appMaster.setTezRemoteShuffleManager(
+ new TezRemoteShuffleManager(appMaster.getAppID().toString(),
sessionToken, conf, strAppAttemptId, client));
appMaster.getTezRemoteShuffleManager().initialize();
appMaster.getTezRemoteShuffleManager().start();
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index 09f33957..840c9e87 100644
---
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -44,10 +44,11 @@ import org.apache.tez.common.ServicePluginLifecycle;
import org.apache.tez.common.ShuffleAssignmentsInfoWritable;
import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.TezUncheckedException;
-import org.apache.tez.dag.app.security.authorize.TezAMPolicyProvider;
+import org.apache.tez.dag.app.security.authorize.RssTezAMPolicyProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -247,6 +248,8 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
rssAmRpcBindPort = 0;
}
+ JobTokenSecretManager jobTokenSecretManager = new
JobTokenSecretManager();
+ jobTokenSecretManager.addTokenForJob(tokenIdentifier, sessionToken);
server = new RPC.Builder(conf)
.setProtocol(TezRemoteShuffleUmbilicalProtocol.class)
.setBindAddress(rssAmRpcBindAddress)
@@ -256,13 +259,14 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
conf.getInt(TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT,
TezConfiguration.TEZ_AM_TASK_LISTENER_THREAD_COUNT_DEFAULT))
.setPortRangeConfig(TezConfiguration.TEZ_AM_TASK_AM_PORT_RANGE)
+ .setSecretManager(jobTokenSecretManager)
.build();
// Enable service authorization?
if (conf.getBoolean(
CommonConfigurationKeysPublic.HADOOP_SECURITY_AUTHORIZATION,
false)) {
- refreshServiceAcls(conf, new TezAMPolicyProvider());
+ refreshServiceAcls(conf, new RssTezAMPolicyProvider());
}
server.start();
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/security/authorize/RssTezAMPolicyProvider.java
b/client-tez/src/main/java/org/apache/tez/dag/app/security/authorize/RssTezAMPolicyProvider.java
new file mode 100644
index 00000000..ec99c858
--- /dev/null
+++
b/client-tez/src/main/java/org/apache/tez/dag/app/security/authorize/RssTezAMPolicyProvider.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.security.authorize;
+
+import org.apache.hadoop.security.authorize.PolicyProvider;
+import org.apache.hadoop.security.authorize.Service;
+import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.dag.api.TezConstants;
+
+/**
+ * Service-level authorization policy for the RSS remote shuffle manager running in the Tez AM.
+ *
+ * <p>Exposes a single protocol, {@link TezRemoteShuffleUmbilicalProtocol}, keyed by the Tez
+ * task-umbilical ACL setting. It is installed by {@code TezRemoteShuffleManager} when
+ * {@code hadoop.security.authorization} is enabled.
+ */
+public class RssTezAMPolicyProvider extends PolicyProvider {
+
+  /** The protocols this AM endpoint serves, each paired with the ACL config key that guards it. */
+  private static final Service[] SERVICES = {
+      new Service(TezConstants.TEZ_AM_SECURITY_SERVICE_AUTHORIZATION_TASK_UMBILICAL,
+          TezRemoteShuffleUmbilicalProtocol.class)
+  };
+
+  /**
+   * Returns the protected services.
+   *
+   * @return a fresh copy of the service table, so callers cannot mutate the shared array
+   */
+  @Override
+  public Service[] getServices() {
+    return SERVICES.clone();
+  }
+}
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
index 3b1bc77e..bf543587 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssOrderedPartitionedKVOutput.java
@@ -35,7 +35,10 @@ import
org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.GetShuffleServerRequest;
import org.apache.tez.common.GetShuffleServerResponse;
@@ -43,6 +46,8 @@ import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -133,7 +138,10 @@ public class RssOrderedPartitionedKVOutput extends
AbstractLogicalOutput {
final InetSocketAddress address = NetUtils.createSocketAddrForHost(host,
port);
UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(this.applicationId.toString());
-
+ Credentials credentials =
UserGroupInformation.getCurrentUser().getCredentials();
+ Token<JobTokenIdentifier> jobToken =
TokenCache.getSessionToken(credentials);
+ SecurityUtil.setTokenService(jobToken, address);
+ taskOwner.addToken(jobToken);
final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
.doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
@Override
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
index 76f133e7..86c1ef3b 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedKVOutput.java
@@ -35,7 +35,10 @@ import
org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.GetShuffleServerRequest;
import org.apache.tez.common.GetShuffleServerResponse;
@@ -43,6 +46,8 @@ import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -136,7 +141,10 @@ public class RssUnorderedKVOutput extends
AbstractLogicalOutput {
final InetSocketAddress address = NetUtils.createSocketAddrForHost(host,
port);
UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(this.applicationId.toString());
-
+ Credentials credentials =
UserGroupInformation.getCurrentUser().getCredentials();
+ Token<JobTokenIdentifier> jobToken =
TokenCache.getSessionToken(credentials);
+ SecurityUtil.setTokenService(jobToken, address);
+ taskOwner.addToken(jobToken);
final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
.doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
@Override
diff --git
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
index 2b2262e8..5a97f4f9 100644
---
a/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
+++
b/client-tez/src/main/java/org/apache/tez/runtime/library/output/RssUnorderedPartitionedKVOutput.java
@@ -35,7 +35,10 @@ import
org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.GetShuffleServerRequest;
import org.apache.tez.common.GetShuffleServerResponse;
@@ -43,6 +46,8 @@ import org.apache.tez.common.RssTezUtils;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
import org.apache.tez.common.TezUtils;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
@@ -134,6 +139,10 @@ public class RssUnorderedPartitionedKVOutput extends
AbstractLogicalOutput {
final InetSocketAddress address = NetUtils.createSocketAddrForHost(host,
port);
UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(this.applicationId.toString());
+ Credentials credentials =
UserGroupInformation.getCurrentUser().getCredentials();
+ Token<JobTokenIdentifier> jobToken =
TokenCache.getSessionToken(credentials);
+ SecurityUtil.setTokenService(jobToken, address);
+ taskOwner.addToken(jobToken);
final TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner
.doAs(new
PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
@Override
diff --git
a/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
b/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
index 3af94aa0..20f4cf2c 100644
---
a/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/dag/app/TezRemoteShuffleManagerTest.java
@@ -26,13 +26,20 @@ import java.util.List;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.GetShuffleServerRequest;
import org.apache.tez.common.GetShuffleServerResponse;
import org.apache.tez.common.TezRemoteShuffleUmbilicalProtocol;
+import org.apache.tez.common.security.JobTokenIdentifier;
+import org.apache.tez.common.security.JobTokenSecretManager;
+import org.apache.tez.common.security.TokenCache;
import org.apache.tez.dag.records.TezDAGID;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
@@ -97,7 +104,12 @@ public class TezRemoteShuffleManagerTest {
ApplicationId appId = ApplicationId.newInstance(9999, 72);
Configuration conf = new Configuration();
- TezRemoteShuffleManager tezRemoteShuffleManager = new
TezRemoteShuffleManager(appId.toString(), null,
+ JobTokenIdentifier identifier = new JobTokenIdentifier(new
Text(appId.toString()));
+ JobTokenSecretManager secretManager = new JobTokenSecretManager();
+ String tokenIdentifier = appId.toString();
+ Token<JobTokenIdentifier> sessionToken = new Token(identifier,
secretManager);
+ secretManager.addTokenForJob(tokenIdentifier, sessionToken);
+ TezRemoteShuffleManager tezRemoteShuffleManager = new
TezRemoteShuffleManager(appId.toString(), sessionToken,
conf, appId.toString(), client);
tezRemoteShuffleManager.initialize();
tezRemoteShuffleManager.start();
@@ -106,7 +118,6 @@ public class TezRemoteShuffleManagerTest {
int port = tezRemoteShuffleManager.getAddress().getPort();
final InetSocketAddress address = NetUtils.createSocketAddrForHost(host,
port);
- String tokenIdentifier = appId.toString();
UserGroupInformation taskOwner =
UserGroupInformation.createRemoteUser(tokenIdentifier);
TezRemoteShuffleUmbilicalProtocol umbilical = taskOwner.doAs(
@@ -140,4 +151,101 @@ public class TezRemoteShuffleManagerTest {
fail();
}
}
+
+  /**
+   * Verifies that a task holding the session (job) token can call the remote shuffle manager's
+   * RPC server when the configuration requests Kerberos authentication.
+   *
+   * <p>Exceptions propagate so JUnit reports the real cause of a failure. The RPC proxy and the
+   * shuffle manager are always released, so the started server does not leak into other tests.
+   */
+  @Test
+  public void testTezRemoteShuffleManagerSecure() throws Exception {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
+    partitionToServers.put(0, new ArrayList<>());
+    partitionToServers.put(1, new ArrayList<>());
+    partitionToServers.put(2, new ArrayList<>());
+    partitionToServers.put(3, new ArrayList<>());
+    partitionToServers.put(4, new ArrayList<>());
+
+    ShuffleServerInfo work1 = new ShuffleServerInfo("host1", 9999);
+    ShuffleServerInfo work2 = new ShuffleServerInfo("host2", 9999);
+    ShuffleServerInfo work3 = new ShuffleServerInfo("host3", 9999);
+    ShuffleServerInfo work4 = new ShuffleServerInfo("host4", 9999);
+
+    partitionToServers.get(0).addAll(Arrays.asList(work1, work2, work3, work4));
+    partitionToServers.get(1).addAll(Arrays.asList(work1, work2, work3, work4));
+    partitionToServers.get(2).addAll(Arrays.asList(work1, work3));
+    partitionToServers.get(3).addAll(Arrays.asList(work3, work4));
+    partitionToServers.get(4).addAll(Arrays.asList(work2, work4));
+
+    Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = new HashMap<>();
+    PartitionRange range0 = new PartitionRange(0, 0);
+    PartitionRange range1 = new PartitionRange(1, 1);
+    PartitionRange range2 = new PartitionRange(2, 2);
+    PartitionRange range3 = new PartitionRange(3, 3);
+    PartitionRange range4 = new PartitionRange(4, 4);
+
+    serverToPartitionRanges.put(work1, Arrays.asList(range0, range1, range2));
+    serverToPartitionRanges.put(work2, Arrays.asList(range0, range1, range4));
+    serverToPartitionRanges.put(work3, Arrays.asList(range0, range1, range2, range3));
+    serverToPartitionRanges.put(work4, Arrays.asList(range0, range1, range3, range4));
+
+    ShuffleAssignmentsInfo shuffleAssignmentsInfo = new ShuffleAssignmentsInfo(partitionToServers,
+        serverToPartitionRanges);
+
+    ShuffleWriteClient client = mock(ShuffleWriteClient.class);
+    when(client.getShuffleAssignments(anyString(), anyInt(), anyInt(), anyInt(), anySet(), anyInt(), anyInt()))
+        .thenReturn(shuffleAssignmentsInfo);
+
+    ApplicationId appId = ApplicationId.newInstance(9999, 72);
+
+    // Simulate a secure cluster.
+    Configuration conf = new Configuration();
+    conf.set("hadoop.security.authentication", "kerberos");
+
+    // Build the session token the AM hands to the shuffle manager.
+    String tokenIdentifier = appId.toString();
+    JobTokenIdentifier identifier = new JobTokenIdentifier(new Text(tokenIdentifier));
+    JobTokenSecretManager secretManager = new JobTokenSecretManager();
+    Token<JobTokenIdentifier> sessionToken = new Token<>(identifier, secretManager);
+    Credentials credentials = new Credentials();
+    TokenCache.setSessionToken(sessionToken, credentials);
+    secretManager.addTokenForJob(tokenIdentifier, sessionToken);
+
+    TezRemoteShuffleManager tezRemoteShuffleManager = new TezRemoteShuffleManager(appId.toString(), sessionToken,
+        conf, appId.toString(), client);
+    tezRemoteShuffleManager.initialize();
+    tezRemoteShuffleManager.start();
+
+    TezRemoteShuffleUmbilicalProtocol umbilical = null;
+    try {
+      String host = tezRemoteShuffleManager.getAddress().getHostString();
+      int port = tezRemoteShuffleManager.getAddress().getPort();
+      final InetSocketAddress address = NetUtils.createSocketAddrForHost(host, port);
+
+      // Here we skip the process of delivering the credentials to the task and use them directly.
+      UserGroupInformation taskOwner = UserGroupInformation.createRemoteUser(tokenIdentifier);
+      Token<JobTokenIdentifier> jobToken = TokenCache.getSessionToken(credentials);
+      // Bind the token to the server address so the RPC client selects it for this connection.
+      SecurityUtil.setTokenService(jobToken, address);
+      taskOwner.addToken(jobToken);
+      umbilical = taskOwner.doAs(
+          new PrivilegedExceptionAction<TezRemoteShuffleUmbilicalProtocol>() {
+            @Override
+            public TezRemoteShuffleUmbilicalProtocol run() throws Exception {
+              return RPC.getProxy(TezRemoteShuffleUmbilicalProtocol.class,
+                  TezRemoteShuffleUmbilicalProtocol.versionID, address, conf);
+            }
+          });
+
+      TezDAGID dagId = TezDAGID.getInstance(appId, 1);
+      TezVertexID vId = TezVertexID.getInstance(dagId, 35);
+      TezTaskID tId = TezTaskID.getInstance(vId, 389);
+      TezTaskAttemptID taId = TezTaskAttemptID.getInstance(tId, 2);
+
+      int mapNum = 1;
+      int shuffleId = 10001;
+      int reduceNum = shuffleAssignmentsInfo.getPartitionToServers().size();
+
+      String errorMessage = "failed to get Shuffle Assignments";
+      GetShuffleServerRequest request = new GetShuffleServerRequest(taId, mapNum, reduceNum, shuffleId);
+      GetShuffleServerResponse response = umbilical.getShuffleAssignments(request);
+      assertEquals(0, response.getStatus(), errorMessage);
+      assertEquals(reduceNum, response.getShuffleAssignmentsInfoWritable().getShuffleAssignmentsInfo()
+          .getPartitionToServers().size());
+    } finally {
+      // Release the client proxy and the RPC server started by start(), even if an assertion fails.
+      if (umbilical != null) {
+        RPC.stopProxy(umbilical);
+      }
+      tezRemoteShuffleManager.shutdown();
+    }
+  }
}