This is an automated email from the ASF dual-hosted git repository.
dimuthuupe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airavata-mft.git
The following commit(s) were added to refs/heads/master by this push:
new 1a3c3b7 Optimizing AWS agent termination logic
1a3c3b7 is described below
commit 1a3c3b73493ba96ae83d91cee66085b304db9872
Author: Dimuthu Wannipurage <[email protected]>
AuthorDate: Fri Jan 6 14:47:11 2023 -0500
Optimizing AWS agent termination logic
---
.../airavata/mft/agent/TransferOrchestrator.java | 9 ++--
.../mft/agent/ingress/ConsulIngressHandler.java | 9 ++--
.../apache/airavata/mft/admin/MFTConsulClient.java | 52 ++++++++++++++++++++++
.../mft/controller/spawner/AgentOrchestrator.java | 50 ++++++++-------------
.../mft/controller/spawner/EC2AgentSpawner.java | 27 +++++++----
.../mft/controller/spawner/SSHProvider.java | 21 ++++++++-
6 files changed, 115 insertions(+), 53 deletions(-)
diff --git a/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java b/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java
index 7ca6e3a..8b4fc3e 100644
--- a/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java
+++ b/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java
@@ -79,7 +79,7 @@ public class TransferOrchestrator {
public void submitTransferToProcess(String transferId,
AgentTransferRequest request,
BiConsumer<EndpointPaths,
TransferState> updateStatus,
- Consumer<Boolean> createTransferHook) {
+ BiConsumer<EndpointPaths, Boolean>
createTransferHook) {
long totalPending =
totalPendingTransfers.addAndGet(request.getEndpointPathsCount());
logger.info("Total pending files to transfer {}", totalPending);
for (EndpointPaths endpointPath : request.getEndpointPathsList()) {
@@ -94,7 +94,8 @@ public class TransferOrchestrator {
public void processTransfer(String transferId, String requestId,
StorageWrapper sourceStorage, StorageWrapper destStorage,
SecretWrapper sourceSecret,SecretWrapper
destSecret, EndpointPaths endpointPath,
- BiConsumer<EndpointPaths, TransferState>
updateStatus, Consumer<Boolean> createTransferHook) {
+ BiConsumer<EndpointPaths, TransferState>
updateStatus,
+ BiConsumer<EndpointPaths, Boolean>
createTransferHook) {
try {
long running = totalRunningTransfers.incrementAndGet();
@@ -139,13 +140,13 @@ public class TransferOrchestrator {
.setDescription("Started the transfer"));
// Save transfer metadata in scheduled path to recover in case of
an Agent failures. Recovery is done from controller
- createTransferHook.accept(true);
+ createTransferHook.accept(endpointPath, true);
mediator.transferSingleThread(transferId, srcCC, dstCC,
updateStatus,
(id, transferSuccess) -> {
try {
// Delete scheduled key as the transfer completed
/ failed if it was placed in current session
- createTransferHook.accept(false);
+ createTransferHook.accept(endpointPath,false);
long pendingAfter =
totalRunningTransfers.decrementAndGet();
logger.info("Removed transfer {} from queue with
transfer success = {}. Total running {}",
id, transferSuccess, pendingAfter);
diff --git a/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java b/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java
index eeb262d..bc705d7 100644
--- a/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java
+++ b/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java
@@ -116,14 +116,11 @@ public class ConsulIngressHandler {
AgentUtil.throwingBiConsumerWrapper((endPointPath,
st) -> {
mftConsulClient.submitFileTransferStateToProcess(transferId,
request.getRequestId(), endPointPath, agentId, st.setPublisher(agentId));
}),
- AgentUtil.throwingConsumerWrapper(create -> {
+ AgentUtil.throwingBiConsumerWrapper((endpointPath,
create) -> {
if (create) {
-
mftConsulClient.getKvClient().putValue(MFTConsulClient.AGENTS_SCHEDULED_PATH
- + agentId + "/" + session
+ "/" + transferId + "/" + agentTransferRequestId,
- reqBytes, 0L, PutOptions.BLANK);
+
mftConsulClient.createEndpointHookForAgent(agentId, session, transferId,
agentTransferRequestId, endpointPath);
} else {
-
mftConsulClient.getKvClient().deleteKey(MFTConsulClient.AGENTS_SCHEDULED_PATH
- + agentId + "/" + session + "/" +
transferId + "/" + agentTransferRequestId);
+
mftConsulClient.deleteEndpointHookForAgent(agentId, session, transferId,
agentTransferRequestId, endpointPath);
}
}));
});
diff --git a/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java b/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java
index bf4b7eb..ef2bd76 100644
--- a/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java
+++ b/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java
@@ -141,6 +141,22 @@ public class MFTConsulClient {
}
}
+ public List<String> listPendingAgentTransfers(String agentId) throws
MFTConsulClientException {
+ try {
+ try {
+ return kvClient.getKeys(AGENTS_TRANSFER_REQUEST_MESSAGE_PATH +
agentId);
+ } catch (ConsulException e) {
+ if (e.getCode() == 404) {
+ return Collections.emptyList();
+ } else {
+ throw e;
+ }
+ }
+ } catch (Exception e) {
+ throw new MFTConsulClientException("Failed to list pending agent
transfers for agent " + agentId, e);
+ }
+ }
+
public void sendSyncRPCToAgent(String agentId, SyncRPCRequest rpcRequest)
throws MFTConsulClientException {
try {
String asString = mapper.writeValueAsString(rpcRequest);
@@ -381,6 +397,42 @@ public class MFTConsulClient {
return liveAgentIds.stream().map(id ->
getAgentInfo(id).get()).collect(Collectors.toList());
}
+ public void createEndpointHookForAgent(String agentId, String session,
String transferId,
+ String agentTransferRequestId,
+ EndpointPaths endpointPaths) {
+ getKvClient().putValue(MFTConsulClient.AGENTS_SCHEDULED_PATH
+ + agentId + "/" + session + "/" + transferId + "/" +
agentTransferRequestId
+ + "/" + getEndpointPathHash(endpointPaths),
+ endpointPaths.toByteArray(), 0L, PutOptions.BLANK);
+ }
+
+ public void deleteEndpointHookForAgent(String agentId, String session,
String transferId,
+ String agentTransferRequestId,
+ EndpointPaths endpointPaths) {
+ getKvClient().deleteKey(MFTConsulClient.AGENTS_SCHEDULED_PATH
+ + agentId + "/" + session + "/" + transferId + "/" +
agentTransferRequestId
+ + "/" + getEndpointPathHash(endpointPaths));
+ }
+
+ public int getEndpointHookCountForAgent(String agentId) throws
MFTConsulClientException {
+ Optional<String> session = getKvClient().getSession(LIVE_AGENTS_PATH +
agentId);
+
+ try {
+ try {
+ return session.map(s ->
getKvClient().getKeys(MFTConsulClient.AGENTS_SCHEDULED_PATH
+ + agentId + "/" + s).size()).orElse(0);
+ } catch (ConsulException e) {
+ if (e.getCode() == 404) {
+ return 0;
+ } else {
+ throw e;
+ }
+ }
+ } catch (Exception e) {
+ throw new MFTConsulClientException("Failed to fetch endpoint hook
count for agent " + agentId, e);
+ }
+ }
+
public KeyValueClient getKvClient() {
return kvClient;
}
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java
index 0ccec2e..454acbe 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java
@@ -154,41 +154,27 @@ public class AgentOrchestrator {
if ((System.currentTimeMillis() -
metadata.lastScannedTime) > SPAWNER_MAX_IDLE_SECONDS * 1000) {
- long totalFiles = 0;
- long completedOrFailedFiles = 0;
-
- Map<String, TransferInfo> transferInfos =
metadata.transferInfos;
-
- for (String agentTransferId: transferInfos.keySet()) {
- TransferInfo transferInfo =
transferInfos.get(agentTransferId);
-
- try {
- totalFiles +=
transferInfo.agentTransferRequest.getEndpointPathsCount();
-
- List<TransferState> transferStates =
this.transferDispatcher.getMftConsulClient()
-
.getTransferStates(transferInfo.transferId, agentTransferId);
-
- completedOrFailedFiles +=
transferStates.stream()
- .filter(transferState ->
transferState.getState().equals("COMPLETED") ||
-
transferState.getState().equals("FAILED")).count();
-
-
-
- } catch (Exception e) {
- logger.error("Failed to fetch transfer states
for agent transfer id {}", agentTransferId, e);
- }
+ if (metadata.transferInfos.size() > 0) {
+ return;
}
- logger.info("Spawner with key {} has total {} files to
be transferred and {} were completed or failed",
- key, totalFiles, completedOrFailedFiles);
-
- if (totalFiles == completedOrFailedFiles) {
- // TODO create a write lock with reusing agent
logic
+ logger.info("No transfer infos for spawner {}.
Checking for termination", key);
+
+ try {
+ List<String> pendingAgentTransfers =
transferDispatcher.getMftConsulClient().listPendingAgentTransfers(metadata.spawner.getLaunchState().get());
+ if (pendingAgentTransfers.isEmpty()) {
+ int totalFilesInProgress =
transferDispatcher.getMftConsulClient().getEndpointHookCountForAgent(metadata.spawner.getLaunchState().get());
+ if (totalFilesInProgress == 0) {
+ logger.info("Killing spawner with key {}
as all files were transferred and the agent" +
+ " is inactive for {}
seconds",
+ key, SPAWNER_MAX_IDLE_SECONDS);
+ metadata.spawner.terminate();
+ launchedSpawnersMap.remove(key);
+ }
+ }
- logger.info("Killing spawner with key {} as all
files were transferred and inactive for {} seconds",
- key, SPAWNER_MAX_IDLE_SECONDS);
- metadata.spawner.terminate();
- launchedSpawnersMap.remove(key);
+ } catch (Exception e) {
+ logger.error("Failed while fetching the endpoint
count for agent", e);
}
}
});
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
index a097f92..86a43e1 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
@@ -51,6 +51,7 @@ public class EC2AgentSpawner extends AgentSpawner {
private Future<String> launchFuture;
private Future<Boolean> terminateFuture;
private final Map<String, String> amiMap;
+ private SSHProvider sshProvider;
public EC2AgentSpawner(StorageWrapper storageWrapper, SecretWrapper
secretWrapper) {
super(storageWrapper, secretWrapper);
@@ -219,35 +220,35 @@ public class EC2AgentSpawner extends AgentSpawner {
logger.info("Waiting 30 seconds until the ssh interface
comes up in instance {}", instanceId);
Thread.sleep(30000);
if ("running".equals(instanceState.getName()) &&
publicIpAddress != null) {
- SSHProvider portForwardAgent = new SSHProvider();
- portForwardAgent.initConnection(publicIpAddress, 22,
+ sshProvider = new SSHProvider();
+ sshProvider.initConnection(publicIpAddress, 22,
Path.of(mftKeyDir,
keyName).toAbsolutePath().toString(), systemUser);
logger.info("Created SSH Connection. Installing
dependencies...");
- int exeCode = portForwardAgent.runCommand("sudo apt
install -y openjdk-11-jre-headless");
+ int exeCode = sshProvider.runCommand("sudo apt install
-y openjdk-11-jre-headless");
if (exeCode != 0)
throw new IOException("Failed to install jdk on
new VM");
- exeCode = portForwardAgent.runCommand("sudo apt
install -y unzip");
+ exeCode = sshProvider.runCommand("sudo apt install -y
unzip");
if (exeCode != 0)
throw new IOException("Failed to install unzip on
new VM");
- exeCode = portForwardAgent.runCommand("wget
https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip");
+ exeCode = sshProvider.runCommand("wget
https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip");
if (exeCode != 0)
throw new IOException("Failed to download mft
distribution");
- exeCode = portForwardAgent.runCommand("unzip
MFT-Agent-0.01-bin.zip");
+ exeCode = sshProvider.runCommand("unzip
MFT-Agent-0.01-bin.zip");
if (exeCode != 0)
throw new IOException("Failed to unzip mft
distribution");
- exeCode = portForwardAgent.runCommand("sed -ir
\"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\"
/home/ubuntu/MFT-Agent-0.01/conf/application.properties");
+ exeCode = sshProvider.runCommand("sed -ir
\"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\"
/home/ubuntu/MFT-Agent-0.01/conf/application.properties");
if (exeCode != 0)
throw new IOException("Failed to update agent id
in config file");
portForwardLock = new CountDownLatch(1);
- CountDownLatch portForwardPendingLock =
portForwardAgent.createSshPortForward(8500, portForwardLock);
+ CountDownLatch portForwardPendingLock =
sshProvider.createSshPortForward(8500, portForwardLock);
logger.info("Waiting until the port forward is setup");
portForwardPendingLock.await();
- exeCode = portForwardAgent.runCommand("sh
MFT-Agent-0.01/bin/agent-daemon.sh start");
+ exeCode = sshProvider.runCommand("sh
MFT-Agent-0.01/bin/agent-daemon.sh start");
if (exeCode != 0)
throw new IOException("Failed to start the MFT
Agent");
@@ -278,6 +279,7 @@ public class EC2AgentSpawner extends AgentSpawner {
public void terminate() {
terminateFuture = executor.submit(() -> {
+
if (instanceId != null) {
String accessKey = secretWrapper.getS3().getAccessKey();
String secretKey = secretWrapper.getS3().getSecretKey();
@@ -294,6 +296,13 @@ public class EC2AgentSpawner extends AgentSpawner {
portForwardLock.countDown();
}
+ logger.info("Waiting 3 seconds until the port forward lock is
released");
+ Thread.sleep(3000);
+
+ if (sshProvider != null) {
+ sshProvider.closeConnection();
+ }
+
TerminateInstancesRequest terminateInstancesRequest = new
TerminateInstancesRequest();
terminateInstancesRequest.setInstanceIds(Collections.singleton(instanceId));
amazonEC2.terminateInstances(terminateInstancesRequest);
diff --git a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
index bbd0166..6f59cae 100644
--- a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
+++ b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
@@ -46,8 +46,10 @@ public class SSHProvider {
private static final Logger logger =
LoggerFactory.getLogger(SSHProvider.class);
private SSHClient client;
+ private String hostName;
public void initConnection(String hostName, int port, String keyPath,
String user) throws IOException {
+ this.hostName = hostName;
DefaultConfig defaultConfig = new DefaultConfig();
defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);
@@ -94,6 +96,17 @@ public class SSHProvider {
client.auth(user, am);
}
+ public void closeConnection() {
+ try {
+ if (client != null) {
+ client.close();
+ logger.info("Closed the SSH connection to host {}", hostName);
+ }
+ } catch (Throwable e) {
+ logger.warn("Failed to close the SSH connection for host {}.
Continuing ...", hostName, e);
+ }
+ }
+
public int runCommand(String command) throws IOException {
Session session = this.client.startSession();
logger.info("Running command {}", command);
@@ -111,10 +124,14 @@ public class SSHProvider {
CountDownLatch portForwardCompleteLock = new CountDownLatch(1);
new Thread(() -> {
+
String consulHost = "localhost";
+ RemotePortForwarder remotePortForwarder;
+ RemotePortForwarder.Forward portBind;
try {
- client.getRemotePortForwarder().bind(
+ remotePortForwarder = client.getRemotePortForwarder();
+ portBind = remotePortForwarder.bind(
new RemotePortForwarder.Forward(localPort),
new SocketForwardingConnectListener(new
InetSocketAddress(consulHost, localPort)));
@@ -123,7 +140,7 @@ public class SSHProvider {
portForwardHoldLock.await();
logger.info("Releasing the remote port forward");
- client.getRemotePortForwarder().cancel(new
RemotePortForwarder.Forward(localPort));
+ remotePortForwarder.cancel(portBind);
} catch (Exception e) {
logger.error("Failed to create the remote port forward for
port {}", localPort, e);