This is an automated email from the ASF dual-hosted git repository.

dimuthuupe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airavata-mft.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a3c3b7  Optimizing AWS agent termination logic
1a3c3b7 is described below

commit 1a3c3b73493ba96ae83d91cee66085b304db9872
Author: Dimuthu Wannipurage <[email protected]>
AuthorDate: Fri Jan 6 14:47:11 2023 -0500

    Optimizing AWS agent termination logic
---
 .../airavata/mft/agent/TransferOrchestrator.java   |  9 ++--
 .../mft/agent/ingress/ConsulIngressHandler.java    |  9 ++--
 .../apache/airavata/mft/admin/MFTConsulClient.java | 52 ++++++++++++++++++++++
 .../mft/controller/spawner/AgentOrchestrator.java  | 50 ++++++++-------------
 .../mft/controller/spawner/EC2AgentSpawner.java    | 27 +++++++----
 .../mft/controller/spawner/SSHProvider.java        | 21 ++++++++-
 6 files changed, 115 insertions(+), 53 deletions(-)

diff --git 
a/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java
 
b/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java
index 7ca6e3a..8b4fc3e 100644
--- 
a/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java
+++ 
b/agent/service/src/main/java/org/apache/airavata/mft/agent/TransferOrchestrator.java
@@ -79,7 +79,7 @@ public class TransferOrchestrator {
 
     public void submitTransferToProcess(String transferId, 
AgentTransferRequest request,
                                         BiConsumer<EndpointPaths, 
TransferState> updateStatus,
-                                        Consumer<Boolean> createTransferHook) {
+                                        BiConsumer<EndpointPaths, Boolean> 
createTransferHook) {
         long totalPending = 
totalPendingTransfers.addAndGet(request.getEndpointPathsCount());
         logger.info("Total pending files to transfer {}", totalPending);
         for (EndpointPaths endpointPath : request.getEndpointPathsList()) {
@@ -94,7 +94,8 @@ public class TransferOrchestrator {
 
     public void processTransfer(String transferId, String requestId, 
StorageWrapper sourceStorage, StorageWrapper destStorage,
                                 SecretWrapper sourceSecret,SecretWrapper 
destSecret, EndpointPaths endpointPath,
-                                BiConsumer<EndpointPaths, TransferState> 
updateStatus, Consumer<Boolean> createTransferHook) {
+                                BiConsumer<EndpointPaths, TransferState> 
updateStatus,
+                                BiConsumer<EndpointPaths, Boolean> 
createTransferHook) {
         try {
 
             long running = totalRunningTransfers.incrementAndGet();
@@ -139,13 +140,13 @@ public class TransferOrchestrator {
                     .setDescription("Started the transfer"));
 
             // Save transfer metadata in scheduled path to recover in case of 
an Agent failures. Recovery is done from controller
-            createTransferHook.accept(true);
+            createTransferHook.accept(endpointPath, true);
 
             mediator.transferSingleThread(transferId, srcCC, dstCC, 
updateStatus,
                     (id, transferSuccess) -> {
                         try {
                             // Delete scheduled key as the transfer completed 
/ failed if it was placed in current session
-                            createTransferHook.accept(false);
+                            createTransferHook.accept(endpointPath,false);
                             long pendingAfter = 
totalRunningTransfers.decrementAndGet();
                             logger.info("Removed transfer {} from queue with 
transfer success = {}. Total running {}",
                                     id, transferSuccess, pendingAfter);
diff --git 
a/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java
 
b/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java
index eeb262d..bc705d7 100644
--- 
a/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java
+++ 
b/agent/service/src/main/java/org/apache/airavata/mft/agent/ingress/ConsulIngressHandler.java
@@ -116,14 +116,11 @@ public class ConsulIngressHandler {
                             AgentUtil.throwingBiConsumerWrapper((endPointPath, 
st) -> {
                                 
mftConsulClient.submitFileTransferStateToProcess(transferId, 
request.getRequestId(), endPointPath,  agentId, st.setPublisher(agentId));
                             }),
-                            AgentUtil.throwingConsumerWrapper(create -> {
+                            AgentUtil.throwingBiConsumerWrapper((endpointPath, 
create) -> {
                                 if (create) {
-                                    
mftConsulClient.getKvClient().putValue(MFTConsulClient.AGENTS_SCHEDULED_PATH
-                                                    + agentId + "/" + session 
+ "/" + transferId + "/" + agentTransferRequestId,
-                                            reqBytes, 0L, PutOptions.BLANK);
+                                    
mftConsulClient.createEndpointHookForAgent(agentId, session, transferId, 
agentTransferRequestId, endpointPath);
                                 } else {
-                                    
mftConsulClient.getKvClient().deleteKey(MFTConsulClient.AGENTS_SCHEDULED_PATH
-                                            + agentId + "/" + session + "/" + 
transferId + "/" + agentTransferRequestId);
+                                   
mftConsulClient.deleteEndpointHookForAgent(agentId, session, transferId, 
agentTransferRequestId, endpointPath);
                                 }
                             }));
                 });
diff --git 
a/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java
 
b/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java
index bf4b7eb..ef2bd76 100644
--- 
a/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java
+++ 
b/common/common-clients/src/main/java/org/apache/airavata/mft/admin/MFTConsulClient.java
@@ -141,6 +141,22 @@ public class MFTConsulClient {
         }
     }
 
+    public List<String> listPendingAgentTransfers(String agentId) throws 
MFTConsulClientException  {
+        try {
+            try {
+                return kvClient.getKeys(AGENTS_TRANSFER_REQUEST_MESSAGE_PATH + 
agentId);
+            } catch (ConsulException e) {
+                if (e.getCode() == 404) {
+                    return Collections.emptyList();
+                } else {
+                    throw e;
+                }
+            }
+        } catch (Exception e) {
+            throw new MFTConsulClientException("Failed to list pending agent 
transfers for agent " + agentId, e);
+        }
+    }
+
     public void sendSyncRPCToAgent(String agentId, SyncRPCRequest rpcRequest) 
throws MFTConsulClientException {
         try {
             String asString = mapper.writeValueAsString(rpcRequest);
@@ -381,6 +397,42 @@ public class MFTConsulClient {
         return liveAgentIds.stream().map(id -> 
getAgentInfo(id).get()).collect(Collectors.toList());
     }
 
+    public void createEndpointHookForAgent(String agentId, String session, 
String transferId,
+                                           String agentTransferRequestId,
+                                           EndpointPaths endpointPaths) {
+        getKvClient().putValue(MFTConsulClient.AGENTS_SCHEDULED_PATH
+                        + agentId + "/" + session + "/" + transferId + "/" + 
agentTransferRequestId
+                        + "/" + getEndpointPathHash(endpointPaths),
+                endpointPaths.toByteArray(), 0L, PutOptions.BLANK);
+    }
+
+    public void deleteEndpointHookForAgent(String agentId, String session, 
String transferId,
+                                          String agentTransferRequestId,
+                                          EndpointPaths endpointPaths) {
+        getKvClient().deleteKey(MFTConsulClient.AGENTS_SCHEDULED_PATH
+                + agentId + "/" + session + "/" + transferId + "/" + 
agentTransferRequestId
+                + "/" + getEndpointPathHash(endpointPaths));
+    }
+
+    public int getEndpointHookCountForAgent(String agentId) throws 
MFTConsulClientException {
+        Optional<String> session = getKvClient().getSession(LIVE_AGENTS_PATH + 
agentId);
+
+        try {
+            try {
+                return session.map(s -> 
getKvClient().getKeys(MFTConsulClient.AGENTS_SCHEDULED_PATH
+                        + agentId + "/" + s).size()).orElse(0);
+            } catch (ConsulException e) {
+                if (e.getCode() == 404) {
+                    return 0;
+                } else {
+                    throw e;
+                }
+            }
+        } catch (Exception e) {
+            throw new MFTConsulClientException("Failed to fetch endpoint hook 
count for agent " + agentId, e);
+        }
+    }
+
     public KeyValueClient getKvClient() {
         return kvClient;
     }
diff --git 
a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java
 
b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java
index 0ccec2e..454acbe 100644
--- 
a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java
+++ 
b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/AgentOrchestrator.java
@@ -154,41 +154,27 @@ public class AgentOrchestrator {
 
                     if ((System.currentTimeMillis() - 
metadata.lastScannedTime) >  SPAWNER_MAX_IDLE_SECONDS * 1000) {
 
-                        long totalFiles = 0;
-                        long completedOrFailedFiles = 0;
-
-                        Map<String, TransferInfo> transferInfos = 
metadata.transferInfos;
-
-                        for (String agentTransferId: transferInfos.keySet()) {
-                            TransferInfo transferInfo = 
transferInfos.get(agentTransferId);
-
-                            try {
-                                totalFiles += 
transferInfo.agentTransferRequest.getEndpointPathsCount();
-
-                                List<TransferState> transferStates = 
this.transferDispatcher.getMftConsulClient()
-                                        
.getTransferStates(transferInfo.transferId, agentTransferId);
-
-                                completedOrFailedFiles += 
transferStates.stream()
-                                        .filter(transferState -> 
transferState.getState().equals("COMPLETED") ||
-                                                
transferState.getState().equals("FAILED")).count();
-
-
-
-                            } catch (Exception e) {
-                                logger.error("Failed to fetch transfer states 
for agent transfer id {}", agentTransferId, e);
-                            }
+                        if (metadata.transferInfos.size() > 0) {
+                            return;
                         }
 
-                        logger.info("Spawner with key {} has total {} files to 
be transferred and {} were completed or failed",
-                                key, totalFiles, completedOrFailedFiles);
-
-                        if (totalFiles == completedOrFailedFiles) {
-                            // TODO create a write lock with reusing agent 
logic
+                        logger.info("No transfer infos for spawner {}. 
Checking for termination", key);
+
+                        try {
+                            List<String> pendingAgentTransfers = 
transferDispatcher.getMftConsulClient().listPendingAgentTransfers(metadata.spawner.getLaunchState().get());
+                            if (pendingAgentTransfers.isEmpty()) {
+                                int totalFilesInProgress = 
transferDispatcher.getMftConsulClient().getEndpointHookCountForAgent(metadata.spawner.getLaunchState().get());
+                                if (totalFilesInProgress == 0) {
+                                    logger.info("Killing spawner with key {} 
as all files were transferred and the agent" +
+                                                    " is inactive for {} 
seconds",
+                                            key, SPAWNER_MAX_IDLE_SECONDS);
+                                    metadata.spawner.terminate();
+                                    launchedSpawnersMap.remove(key);
+                                }
+                            }
 
-                            logger.info("Killing spawner with key {} as all 
files were transferred and inactive for {} seconds",
-                                    key, SPAWNER_MAX_IDLE_SECONDS);
-                            metadata.spawner.terminate();
-                            launchedSpawnersMap.remove(key);
+                        } catch (Exception e) {
+                            logger.error("Failed while fetching the endpoint 
count for agent", e);
                         }
                     }
                 });
diff --git 
a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
 
b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
index a097f92..86a43e1 100644
--- 
a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
+++ 
b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/EC2AgentSpawner.java
@@ -51,6 +51,7 @@ public class EC2AgentSpawner extends AgentSpawner {
     private Future<String> launchFuture;
     private Future<Boolean> terminateFuture;
     private final Map<String, String> amiMap;
+    private SSHProvider sshProvider;
 
     public EC2AgentSpawner(StorageWrapper storageWrapper, SecretWrapper 
secretWrapper) {
         super(storageWrapper, secretWrapper);
@@ -219,35 +220,35 @@ public class EC2AgentSpawner extends AgentSpawner {
                     logger.info("Waiting 30 seconds until the ssh interface 
comes up in instance {}", instanceId);
                     Thread.sleep(30000);
                     if ("running".equals(instanceState.getName()) && 
publicIpAddress != null) {
-                        SSHProvider portForwardAgent = new SSHProvider();
-                        portForwardAgent.initConnection(publicIpAddress, 22,
+                        sshProvider = new SSHProvider();
+                        sshProvider.initConnection(publicIpAddress, 22,
                                 Path.of(mftKeyDir, 
keyName).toAbsolutePath().toString(), systemUser);
                         logger.info("Created SSH Connection. Installing 
dependencies...");
 
-                        int exeCode = portForwardAgent.runCommand("sudo apt 
install -y openjdk-11-jre-headless");
+                        int exeCode = sshProvider.runCommand("sudo apt install 
-y openjdk-11-jre-headless");
                         if (exeCode != 0)
                             throw new IOException("Failed to install jdk on 
new VM");
-                        exeCode = portForwardAgent.runCommand("sudo apt 
install -y unzip");
+                        exeCode = sshProvider.runCommand("sudo apt install -y 
unzip");
                         if (exeCode != 0)
                             throw new IOException("Failed to install unzip on 
new VM");
-                        exeCode = portForwardAgent.runCommand("wget 
https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip");
+                        exeCode = sshProvider.runCommand("wget 
https://github.com/apache/airavata-mft/releases/download/v0.0.1/MFT-Agent-0.01-bin.zip");
                         if (exeCode != 0)
                             throw new IOException("Failed to download mft 
distribution");
-                        exeCode = portForwardAgent.runCommand("unzip 
MFT-Agent-0.01-bin.zip");
+                        exeCode = sshProvider.runCommand("unzip 
MFT-Agent-0.01-bin.zip");
                         if (exeCode != 0)
                             throw new IOException("Failed to unzip mft 
distribution");
 
-                        exeCode = portForwardAgent.runCommand("sed -ir 
\"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\" 
/home/ubuntu/MFT-Agent-0.01/conf/application.properties");
+                        exeCode = sshProvider.runCommand("sed -ir 
\"s/^[#]*\\s*agent.id=.*/agent.id=" + agentId + "/\" 
/home/ubuntu/MFT-Agent-0.01/conf/application.properties");
                         if (exeCode != 0)
                             throw new IOException("Failed to update agent id 
in config file");
 
                         portForwardLock = new CountDownLatch(1);
-                        CountDownLatch portForwardPendingLock = 
portForwardAgent.createSshPortForward(8500, portForwardLock);
+                        CountDownLatch portForwardPendingLock = 
sshProvider.createSshPortForward(8500, portForwardLock);
 
                         logger.info("Waiting until the port forward is setup");
                         portForwardPendingLock.await();
 
-                        exeCode = portForwardAgent.runCommand("sh 
MFT-Agent-0.01/bin/agent-daemon.sh start");
+                        exeCode = sshProvider.runCommand("sh 
MFT-Agent-0.01/bin/agent-daemon.sh start");
                         if (exeCode != 0)
                             throw new IOException("Failed to start the MFT 
Agent");
 
@@ -278,6 +279,7 @@ public class EC2AgentSpawner extends AgentSpawner {
     public void terminate() {
 
         terminateFuture =  executor.submit(() -> {
+
             if (instanceId != null) {
                 String accessKey = secretWrapper.getS3().getAccessKey();
                 String secretKey = secretWrapper.getS3().getSecretKey();
@@ -294,6 +296,13 @@ public class EC2AgentSpawner extends AgentSpawner {
                     portForwardLock.countDown();
                 }
 
+                logger.info("Waiting 3 seconds until the port forward lock is 
released");
+                Thread.sleep(3000);
+
+                if (sshProvider != null) {
+                    sshProvider.closeConnection();
+                }
+
                 TerminateInstancesRequest terminateInstancesRequest = new 
TerminateInstancesRequest();
                 
terminateInstancesRequest.setInstanceIds(Collections.singleton(instanceId));
                 amazonEC2.terminateInstances(terminateInstancesRequest);
diff --git 
a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
 
b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
index bbd0166..6f59cae 100644
--- 
a/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
+++ 
b/controller/src/main/java/org/apache/airavata/mft/controller/spawner/SSHProvider.java
@@ -46,8 +46,10 @@ public class SSHProvider {
     private static final Logger logger = 
LoggerFactory.getLogger(SSHProvider.class);
 
     private SSHClient client;
+    private String hostName;
 
     public void initConnection(String hostName, int port, String keyPath, 
String user) throws IOException {
+        this.hostName = hostName;
         DefaultConfig defaultConfig = new DefaultConfig();
         defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);
 
@@ -94,6 +96,17 @@ public class SSHProvider {
         client.auth(user, am);
     }
 
+    public void closeConnection() {
+        try {
+            if (client != null) {
+                client.close();
+                logger.info("Closed the SSH connection to host {}", hostName);
+            }
+        } catch (Throwable e) {
+            logger.warn("Failed to close the SSH connection for host {}. 
Continuing ...", hostName, e);
+        }
+    }
+
     public int runCommand(String command) throws IOException {
         Session session = this.client.startSession();
         logger.info("Running command {}", command);
@@ -111,10 +124,14 @@ public class SSHProvider {
 
         CountDownLatch portForwardCompleteLock = new CountDownLatch(1);
         new Thread(() -> {
+
             String consulHost = "localhost";
+            RemotePortForwarder remotePortForwarder;
+            RemotePortForwarder.Forward portBind;
 
             try {
-                client.getRemotePortForwarder().bind(
+                remotePortForwarder = client.getRemotePortForwarder();
+                portBind = remotePortForwarder.bind(
                         new RemotePortForwarder.Forward(localPort),
                         new SocketForwardingConnectListener(new 
InetSocketAddress(consulHost, localPort)));
 
@@ -123,7 +140,7 @@ public class SSHProvider {
                 portForwardHoldLock.await();
 
                 logger.info("Releasing the remote port forward");
-                client.getRemotePortForwarder().cancel(new 
RemotePortForwarder.Forward(localPort));
+                remotePortForwarder.cancel(portBind);
 
             } catch (Exception e) {
                 logger.error("Failed to create the remote port forward for 
port {}", localPort, e);

Reply via email to