This is an automated email from the ASF dual-hosted git repository.
dimuthuupe pushed a commit to branch airavata-v2-refactoring
in repository https://gitbox.apache.org/repos/asf/airavata.git
The following commit(s) were added to refs/heads/airavata-v2-refactoring by this push:
new bdfd0e3b68 Supporting inter-task communication. Passing task params through the workflow context
bdfd0e3b68 is described below
commit bdfd0e3b686a7cd6a1b13cddd7db0889ef622ee4
Author: DImuthuUpe <[email protected]>
AuthorDate: Tue Jun 27 08:15:45 2023 -0400
Supporting inter-task communication. Passing task params through the workflow context
---
modules/airavata-apis/airavata-apis-server/pom.xml | 28 +++
.../apis/db/entity/backend/EC2BackendEntity.java | 11 +
.../apis/scheduling/ExperimentLauncher.java | 48 ++++-
.../airavata/apis/workflow/WorkflowExecutor.java | 3 +-
.../apis/workflow/task/common/BaseTask.java | 32 ++-
.../apis/workflow/task/common/TaskUtil.java | 79 ++++---
.../workflow/task/ec2/CreateEC2InstanceTask.java | 233 +++++++++++++++++++++
.../workflow/task/ec2/DestroyEC2InstanceTask.java | 116 +++++++++-
.../src/main/proto/execution/experiment_stub.proto | 5 +-
9 files changed, 520 insertions(+), 35 deletions(-)
diff --git a/modules/airavata-apis/airavata-apis-server/pom.xml b/modules/airavata-apis/airavata-apis-server/pom.xml
index eff620301e..5014124e93 100644
--- a/modules/airavata-apis/airavata-apis-server/pom.xml
+++ b/modules/airavata-apis/airavata-apis-server/pom.xml
@@ -22,6 +22,28 @@
</exclusion>
</exclusions>
</dependency>
+ <dependency>
+ <groupId>org.apache.airavata</groupId>
+ <artifactId>mft-secret-service-client</artifactId>
+ <version>0.01-SNAPSHOT</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.airavata</groupId>
+ <artifactId>mft-resource-service-client</artifactId>
+ <version>0.01-SNAPSHOT</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
<dependency>
<groupId>org.apache.airavata</groupId>
<artifactId>mft-resource-service-server</artifactId>
@@ -129,8 +151,14 @@
</exclusion>
</exclusions>
</dependency>
+ <!-- The EC2 tasks only use the EC2 client (plus SDK core, pulled in transitively), so depend
+ on the EC2 module rather than the aws-java-sdk bundle, which brings every AWS service client. -->
+ <dependency>
+ <groupId>com.amazonaws</groupId>
+ <artifactId>aws-java-sdk-ec2</artifactId>
+ <version>${aws.sdk.version}</version>
+ </dependency>
</dependencies>
<properties>
+ <aws.sdk.version>1.12.372</aws.sdk.version>
<maven.compiler.source>18</maven.compiler.source>
<maven.compiler.target>18</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
diff --git a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/db/entity/backend/EC2BackendEntity.java b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/db/entity/backend/EC2BackendEntity.java
index cfd8d8b4d2..87a7ce5cc0 100644
--- a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/db/entity/backend/EC2BackendEntity.java
+++ b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/db/entity/backend/EC2BackendEntity.java
@@ -17,6 +17,9 @@ public class EC2BackendEntity extends ComputeBackendEntity {
@Column
String awsCredentialId;
+ // Login user name; presumably mapped to/from EC2Backend.loginUserName in the API stub.
+ @Column
+ String loginUserName;
+
public String getFlavor() {
return flavor;
}
@@ -48,4 +51,12 @@ public class EC2BackendEntity extends ComputeBackendEntity {
public void setImageId(String imageId) {
this.imageId = imageId;
}
+
+ /** Returns the stored login user name ({@code null} until one is set). */
+ public String getLoginUserName() {
+ return this.loginUserName;
+ }
+
+ /** Stores the login user name for this backend. */
+ public void setLoginUserName(String loginUserName) {
+ this.loginUserName = loginUserName;
+ }
}
diff --git a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/scheduling/ExperimentLauncher.java b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/scheduling/ExperimentLauncher.java
index 1d77d6ab56..e0c6472a15 100644
--- a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/scheduling/ExperimentLauncher.java
+++ b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/scheduling/ExperimentLauncher.java
@@ -3,6 +3,7 @@ package org.apache.airavata.apis.scheduling;
import org.apache.airavata.api.execution.ExperimentLaunchRequest;
import org.apache.airavata.api.execution.stubs.EC2Backend;
import org.apache.airavata.api.execution.stubs.Experiment;
+import org.apache.airavata.api.execution.stubs.RunConfiguration;
import org.apache.airavata.apis.service.ExecutionService;
import org.apache.airavata.apis.workflow.task.common.BaseTask;
import org.apache.airavata.apis.workflow.task.common.OutPort;
@@ -11,6 +12,11 @@ import org.apache.airavata.apis.workflow.task.common.annotation.TaskDef;
import org.apache.airavata.apis.workflow.task.common.annotation.TaskOutPort;
import org.apache.airavata.apis.workflow.task.data.DataMovementTask;
import org.apache.airavata.apis.workflow.task.ec2.CreateEC2InstanceTask;
+import org.apache.airavata.apis.workflow.task.ec2.DestroyEC2InstanceTask;
+import org.apache.airavata.mft.credential.stubs.s3.S3Secret;
+import org.apache.airavata.mft.credential.stubs.s3.S3SecretCreateRequest;
+import org.apache.airavata.mft.secret.client.SecretServiceClient;
+import org.apache.airavata.mft.secret.client.SecretServiceClientBuilder;
import org.apache.helix.HelixManager;
import org.apache.helix.HelixManagerFactory;
import org.apache.helix.InstanceType;
@@ -35,6 +41,7 @@ public class ExperimentLauncher {
@Autowired
private ExecutionService executionService;
+
public void launchExperiment(ExperimentLaunchRequest experimentLaunchRequest) throws Exception {
Optional<Experiment> experimentOp = executionService.getExperiment(experimentLaunchRequest.getExperimentId());
if (experimentOp.isEmpty()) {
@@ -44,10 +51,27 @@ public class ExperimentLauncher {
Experiment experiment = experimentOp.get();
}
+ /**
+ * Prepares a {@link CreateEC2InstanceTask} for the EC2 backend of {@code runConfiguration}.
+ *
+ * NOTE(review): work in progress - the task is configured but not yet added to a workflow or
+ * submitted, and no caller is visible in this change. The secret-service endpoint and user
+ * token are placeholders matching the values used in {@link #main(String[])}.
+ */
+ private void submitEC2Workflow(RunConfiguration runConfiguration) {
+ CreateEC2InstanceTask ec2InstanceTask = new CreateEC2InstanceTask();
+ ec2InstanceTask.setTaskId(UUID.randomUUID().toString());
+ ec2InstanceTask.setEc2Backend(runConfiguration.getEc2());
+ ec2InstanceTask.setSecretServiceHost("localhost");
+ // 7002 is the secret-service port used everywhere else in this class (was 7003).
+ ec2InstanceTask.setSecretServicePort(7002);
+ ec2InstanceTask.setUserToken("token");
+ }
+
public static void main(String args[]) throws Exception {
ExperimentLauncher launcher = new ExperimentLauncher();
launcher.init("airavata", "wm", "localhost:2181");
+ // Register demo AWS credentials; the client is closed so its connection is not leaked.
+ S3Secret s3Secret;
+ try (SecretServiceClient secretServiceClient = SecretServiceClientBuilder.buildClient("localhost", 7002)) {
+ s3Secret = secretServiceClient.s3().createS3Secret(S3SecretCreateRequest.newBuilder()
+ .setAccessKey("key").setSecretKey("sec").build());
+ }
+
+ logger.info("S3 Secret id : {}", s3Secret.getSecretId());
+
Map<String, BaseTask> taskMap = new HashMap<>();
DataMovementTask dataMovementTask = new DataMovementTask();
@@ -56,16 +80,34 @@ public class ExperimentLauncher {
taskMap.put(dataMovementTask.getTaskId(), dataMovementTask);
EC2Backend ec2Backend = EC2Backend.newBuilder()
- .setAwsCredentialId("SomeCred")
- .setFlavor("m2")
- .setRegion("us-west").build();
+ .setAwsCredentialId(s3Secret.getSecretId())
+ .setLoginUserName("ubuntu")
+ .setRegion("us-east-1")
+ .setFlavor("t2.micro")
+ .setImageId("ami-053b0d53c279acc90").build();
CreateEC2InstanceTask ec2InstanceTask = new CreateEC2InstanceTask();
ec2InstanceTask.setTaskId(UUID.randomUUID().toString());
ec2InstanceTask.setEc2Backend(ec2Backend);
+ ec2InstanceTask.setSecretServiceHost("localhost");
+ ec2InstanceTask.setSecretServicePort(7002);
+ ec2InstanceTask.setUserToken("token");
+
taskMap.put(ec2InstanceTask.getTaskId(), ec2InstanceTask);
+ DestroyEC2InstanceTask destroyEC2InstanceTask = new DestroyEC2InstanceTask();
+ destroyEC2InstanceTask.setTaskId(UUID.randomUUID().toString());
+ destroyEC2InstanceTask.setEc2Backend(ec2Backend);
+ destroyEC2InstanceTask.setSecretServiceHost("localhost");
+ destroyEC2InstanceTask.setSecretServicePort(7002);
+ destroyEC2InstanceTask.setUserToken("token");
+ // Placeholder only: replaced at run time by the EC2_INSTANCE_ID value that
+ // CreateEC2InstanceTask publishes to the workflow context.
+ destroyEC2InstanceTask.setInstanceId("");
+ destroyEC2InstanceTask.overrideParameterFromWorkflowContext("instanceId", CreateEC2InstanceTask.EC2_INSTANCE_ID);
+
+ taskMap.put(destroyEC2InstanceTask.getTaskId(), destroyEC2InstanceTask);
+
dataMovementTask.addOutPort(new OutPort().setNextTaskId(ec2InstanceTask.getTaskId()));
+ ec2InstanceTask.addOutPort(new OutPort().setNextTaskId(destroyEC2InstanceTask.getTaskId()));
String[] startTaskIds = {dataMovementTask.getTaskId()};
logger.info("Submitting workflow");
diff --git a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/WorkflowExecutor.java b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/WorkflowExecutor.java
index 29e67f51da..7b6c67be01 100644
--- a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/WorkflowExecutor.java
+++ b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/WorkflowExecutor.java
@@ -162,7 +162,8 @@ public class WorkflowExecutor implements CommandLineRunner {
String[] taskClasses = {
"org.apache.airavata.apis.workflow.task.data.DataMovementTask",
- "org.apache.airavata.apis.workflow.task.ec2.CreateEC2InstanceTask"};
+ "org.apache.airavata.apis.workflow.task.ec2.CreateEC2InstanceTask",
+ // Terminates the instance launched by CreateEC2InstanceTask.
+ "org.apache.airavata.apis.workflow.task.ec2.DestroyEC2InstanceTask"
+ };
Map<String, TaskFactory> taskMap = new HashMap<>();
diff --git a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/BaseTask.java b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/BaseTask.java
index 1b3e098ae2..7b24bc6265 100644
--- a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/BaseTask.java
+++ b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/BaseTask.java
@@ -9,8 +9,11 @@ import org.apache.helix.task.UserContentStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.lang.reflect.Field;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
@@ -30,6 +33,8 @@ public abstract class BaseTask extends UserContentStore implements Task {
@TaskParam(name = "retryCount")
private ThreadLocal<Integer> retryCount = ThreadLocal.withInitial(()-> 3);
+ /**
+ * Task-param name -> workflow-context variable. Each entry is serialized as a "$"-prefixed
+ * config key and resolved in run() after the regular task params are deserialized.
+ */
+ private final ThreadLocal<Map<String, String>> paramOverrideMap = ThreadLocal.withInitial(HashMap::new);
+
@Override
public TaskResult run() {
try {
@@ -43,7 +48,20 @@ public abstract class BaseTask extends UserContentStore implements Task {
this.callbackContext.set(cbc);
String helixTaskId = getCallbackContext().getTaskConfig().getId();
logger.info("Running task {}", helixTaskId);
- TaskUtil.deserializeTaskData(this, getCallbackContext().getTaskConfig().getConfigMap());
+
+ Map<String, String> configMap = getCallbackContext().getTaskConfig().getConfigMap();
+ TaskUtil.deserializeTaskData(this, configMap);
+
+ // A key "$<param>" maps a task param to a workflow-context variable (e.g. one published by
+ // CreateEC2InstanceTask via putUserContent); that variable's value overrides the param.
+ for (Map.Entry<String, String> entry : configMap.entrySet()) {
+ String key = entry.getKey();
+ if (!key.startsWith("$")) {
+ continue;
+ }
+ String paramName = key.substring(1);
+ String contextVariable = entry.getValue();
+ Field cf = TaskUtil.getClassFieldForParamName(this, paramName);
+ if (cf == null) {
+ throw new IllegalStateException("No @TaskParam named '" + paramName + "' on "
+ + getClass().getName() + " to override from workflow variable " + contextVariable);
+ }
+ String paramValue = getUserContent(contextVariable, Scope.WORKFLOW);
+ if (paramValue == null) {
+ throw new IllegalStateException("Workflow variable " + contextVariable
+ + " required by param '" + paramName + "' has not been set");
+ }
+ TaskUtil.deserializeField(this, cf, paramValue);
+ }
+
} catch (Exception e) {
logger.error("Failed at deserializing task data", e);
return new TaskResult(TaskResult.Status.FAILED, "Failed in deserializing task data");
@@ -66,6 +84,10 @@ public abstract class BaseTask extends UserContentStore implements Task {
}
}
+ /**
+ * Makes task param {@code paramName} take its value from workflow-context variable
+ * {@code contextVariable} when this task runs, instead of its serialized value.
+ */
+ public void overrideParameterFromWorkflowContext(String paramName, String contextVariable) {
+ getParamOverrideMap().put(paramName, contextVariable);
+ }
+
public abstract TaskResult onRun() throws Exception;
public abstract void onCancel() throws Exception;
@@ -98,6 +120,14 @@ public abstract class BaseTask extends UserContentStore implements Task {
this.taskId.set(taskId);
}
+ public Map<String, String> getParamOverrideMap() {
+ return paramOverrideMap.get();
+ }
+
+ public void setParamOverrideMap(Map<String, String> paramOverrideMap) {
+ this.paramOverrideMap.set(paramOverrideMap);
+ }
+
public void setCallbackContext(TaskCallbackContext callbackContext) {
logger.info("Setting callback context {}", callbackContext.getJobConfig().getId());
try {
diff --git a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/TaskUtil.java b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/TaskUtil.java
index 654c0b6156..8c27a0be72 100644
--- a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/TaskUtil.java
+++ b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/common/TaskUtil.java
@@ -1,6 +1,5 @@
package org.apache.airavata.apis.workflow.task.common;
-import com.google.protobuf.AbstractMessage;
import com.google.protobuf.GeneratedMessageV3;
import org.apache.airavata.apis.workflow.task.common.annotation.TaskParam;
import org.apache.commons.beanutils.PropertyUtils;
@@ -21,7 +20,36 @@ public class TaskUtil {
private final static Logger logger = LoggerFactory.getLogger(TaskUtil.class);
- public static <T extends BaseTask> void deserializeTaskData(T instance, Map<String, String> params) throws IllegalAccessException, NoSuchMethodException, InvocationTargetException, InstantiationException {
+ /**
+ * Converts {@code value} to the parameter type of the bean setter backing {@code classField}
+ * and invokes that setter on {@code instance}. Handles protobuf messages (via the static
+ * {@code parseFrom(byte[])}), String, int/Integer, long/Long, boolean/Boolean and
+ * {@link TaskParamType} implementations; any other type is logged and left unset.
+ *
+ * @throws IllegalStateException if {@code classField} has no bean setter on {@code instance}
+ */
+ public static <T extends BaseTask> void deserializeField(T instance, Field classField, String value) throws Exception {
+ classField.setAccessible(true);
+ PropertyDescriptor propertyDescriptor = PropertyUtils.getPropertyDescriptor(instance, classField.getName());
+ Method writeMethod = propertyDescriptor == null ? null : PropertyUtils.getWriteMethod(propertyDescriptor);
+ if (writeMethod == null) {
+ throw new IllegalStateException("No setter for task param field " + classField.getName()
+ + " in " + instance.getClass().getName());
+ }
+ Class<?> writeParameterType = writeMethod.getParameterTypes()[0];
+
+ if (GeneratedMessageV3.class.isAssignableFrom(writeParameterType)) { // Parsing protobuf messages
+ // NOTE(review): relies on the String form written by serializeTaskData round-tripping through
+ // getBytes(); protobuf binary is not generally valid text - confirm the encoding used there.
+ Method parseMethod = writeParameterType.getDeclaredMethod("parseFrom", byte[].class);
+ Object obj = parseMethod.invoke(null, value.getBytes()); // Calling static method
+ writeMethod.invoke(instance, obj);
+ } else if (writeParameterType.isAssignableFrom(String.class)) {
+ writeMethod.invoke(instance, value);
+ } else if (writeParameterType.isAssignableFrom(Integer.class) ||
+ writeParameterType.isAssignableFrom(Integer.TYPE)) {
+ writeMethod.invoke(instance, Integer.parseInt(value));
+ } else if (writeParameterType.isAssignableFrom(Long.class) ||
+ writeParameterType.isAssignableFrom(Long.TYPE)) {
+ writeMethod.invoke(instance, Long.parseLong(value));
+ } else if (writeParameterType.isAssignableFrom(Boolean.class) ||
+ writeParameterType.isAssignableFrom(Boolean.TYPE)) {
+ writeMethod.invoke(instance, Boolean.parseBoolean(value));
+ } else if (TaskParamType.class.isAssignableFrom(writeParameterType)) {
+ Constructor<?> ctor = writeParameterType.getConstructor();
+ Object obj = ctor.newInstance();
+ ((TaskParamType)obj).deserialize(value);
+ writeMethod.invoke(instance, obj);
+ } else {
+ logger.warn("Task param field {} has unsupported type {}; value not applied",
+ classField.getName(), writeParameterType.getName());
+ }
+ }
+ public static <T extends BaseTask> void deserializeTaskData(T instance, Map<String, String> params) throws Exception {
List<Field> allFields = new ArrayList<>();
Class genericClass = instance.getClass();
@@ -38,36 +66,28 @@ public class TaskUtil {
TaskParam param = classField.getAnnotation(TaskParam.class);
if (param != null) {
if (params.containsKey(param.name())) {
- classField.setAccessible(true);
- PropertyDescriptor propertyDescriptor = PropertyUtils.getPropertyDescriptor(instance, classField.getName());
- Method writeMethod = PropertyUtils.getWriteMethod(propertyDescriptor);
- Class<?>[] methodParamType = writeMethod.getParameterTypes();
- Class<?> writeParameterType = methodParamType[0];
+ deserializeField(instance, classField, params.get(param.name()));
+ }
+ }
+ }
+ }
- if (GeneratedMessageV3.class.isAssignableFrom(writeParameterType)) { // Parsing protobuf messages
- Method parseMethod = writeParameterType.getDeclaredMethod("parseFrom", byte[].class);
- Object obj = parseMethod.invoke(null, params.get(param.name()).getBytes()); // Calling static method
- writeMethod.invoke(instance, obj);
- } else if (writeParameterType.isAssignableFrom(String.class)) {
- writeMethod.invoke(instance, params.get(param.name()));
- } else if (writeParameterType.isAssignableFrom(Integer.class) ||
- writeParameterType.isAssignableFrom(Integer.TYPE)) {
- writeMethod.invoke(instance, Integer.parseInt(params.get(param.name())));
- } else if (writeParameterType.isAssignableFrom(Long.class) ||
- writeParameterType.isAssignableFrom(Long.TYPE)) {
- writeMethod.invoke(instance, Long.parseLong(params.get(param.name())));
- } else if (writeParameterType.isAssignableFrom(Boolean.class) ||
- writeParameterType.isAssignableFrom(Boolean.TYPE)) {
- writeMethod.invoke(instance, Boolean.parseBoolean(params.get(param.name())));
- } else if (TaskParamType.class.isAssignableFrom(writeParameterType)) {
- Constructor<?> ctor = writeParameterType.getConstructor();
- Object obj = ctor.newInstance();
- ((TaskParamType)obj).deserialize(params.get(param.name()));
- writeMethod.invoke(instance, obj);
+ /**
+ * Finds the field annotated with {@code @TaskParam(name = paramName)}, searching the class of
+ * {@code instance} and its superclasses as long as they are {@link BaseTask} types.
+ *
+ * @return the matching field, or {@code null} if no such task param is declared
+ */
+ public static <T extends BaseTask> Field getClassFieldForParamName(T instance, String paramName) {
+ Class<?> genericClass = instance.getClass();
+
+ while (BaseTask.class.isAssignableFrom(genericClass)) {
+ Field[] declaredFields = genericClass.getDeclaredFields();
+ for (Field declaredField : declaredFields) {
+ TaskParam param = declaredField.getAnnotation(TaskParam.class);
+ if (param != null) {
+ if (param.name().equals(paramName)) {
+ return declaredField;
}
}
}
+ genericClass = genericClass.getSuperclass();
}
+ return null;
}
public static <T extends BaseTask> Map<String, String> serializeTaskData(T data) throws IllegalAccessException, InvocationTargetException, NoSuchMethodException {
@@ -94,6 +114,11 @@ public class TaskUtil {
}
}
}
+
+ // Each workflow-context override travels as "$<param>" -> <context variable>;
+ // BaseTask.run() resolves these after the regular params are deserialized.
+ data.getParamOverrideMap().forEach((param, variable) -> result.put("$" + param, variable));
return result;
}
}
\ No newline at end of file
diff --git a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/CreateEC2InstanceTask.java b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/CreateEC2InstanceTask.java
index 95454e1d4f..bfac5405de 100644
--- a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/CreateEC2InstanceTask.java
+++ b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/CreateEC2InstanceTask.java
@@ -1,25 +1,234 @@
package org.apache.airavata.apis.workflow.task.ec2;
+import com.amazonaws.auth.AWSStaticCredentialsProvider;
+import com.amazonaws.auth.BasicAWSCredentials;
+import com.amazonaws.client.builder.AwsClientBuilder;
+import com.amazonaws.services.ec2.AmazonEC2;
+import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
+import com.amazonaws.services.ec2.model.*;
import org.apache.airavata.api.execution.stubs.EC2Backend;
import org.apache.airavata.apis.workflow.task.common.BaseTask;
import org.apache.airavata.apis.workflow.task.common.annotation.TaskDef;
import org.apache.airavata.apis.workflow.task.common.annotation.TaskParam;
+import org.apache.airavata.mft.common.AuthToken;
+import org.apache.airavata.mft.common.UserTokenAuth;
+import org.apache.airavata.mft.credential.stubs.s3.S3Secret;
+import org.apache.airavata.mft.credential.stubs.s3.S3SecretGetRequest;
+import org.apache.airavata.mft.credential.stubs.scp.SCPSecret;
+import org.apache.airavata.mft.credential.stubs.scp.SCPSecretCreateRequest;
+import org.apache.airavata.mft.credential.stubs.scp.SCPSecretGetRequest;
+import org.apache.airavata.mft.secret.client.SecretServiceClient;
+import org.apache.airavata.mft.secret.client.SecretServiceClientBuilder;
import org.apache.helix.task.TaskResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.io.File;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.*;
+import java.util.stream.Collectors;
+
@TaskDef(name = "CreateEC2InstanceTask")
public class CreateEC2InstanceTask extends BaseTask {
+ /** Workflow-context key under which the id of the instance's SCP (SSH) secret is published. */
+ public static final String EC2_INSTANCE_SECRET_ID = "EC2_INSTANCE_SECRET_ID";
+ /** Workflow-context key under which the id of the launched instance is published. */
+ public static final String EC2_INSTANCE_ID = "EC2_INSTANCE_ID";
+ /** Workflow-context key under which the public IP of the launched instance is published. */
+ public static final String EC2_INSTANCE_IP = "EC2_INSTANCE_IP";
+
private final static Logger logger = LoggerFactory.getLogger(CreateEC2InstanceTask.class);
@TaskParam(name = "ec2Backend")
private ThreadLocal<EC2Backend> ec2Backend = new ThreadLocal<>();
+ @TaskParam(name = "secretServiceHost")
+ private ThreadLocal<String> secretServiceHost = new ThreadLocal<>();
+
+ @TaskParam(name = "secretServicePort")
+ private ThreadLocal<Integer> secretServicePort = new ThreadLocal<>();
+
+ @TaskParam(name = "userToken")
+ private ThreadLocal<String> userToken = new ThreadLocal<>();
+
+ /**
+ * Launches one EC2 instance for the configured backend and publishes EC2_INSTANCE_SECRET_ID,
+ * EC2_INSTANCE_ID and EC2_INSTANCE_IP to the workflow context. Creates the "AiravataSecurityGroup"
+ * security group (SSH ingress) if missing, and reuses a key pair tagged AIRAVATA_SECRET_ID whose
+ * secret can still be loaded, otherwise creates one. If the instance does not become running with a
+ * public IP, it is terminated and the failure is rethrown.
+ */
@Override
public TaskResult onRun() throws Exception {
+
+ String keyNamePrefix = "airavata-aws-agent-key-";
+ String secGroupName = "AiravataSecurityGroup";
+
logger.info("Starting Create EC2 Instance Task {}", getTaskId());
logger.info("EC2 Backend {}", getEc2Backend().toString());
+
+ EC2Backend ec2BackendObj = getEc2Backend();
+ try (SecretServiceClient secretServiceClient = SecretServiceClientBuilder
+ .buildClient(getSecretServiceHost(), getSecretServicePort())) {
+
+ S3Secret s3Secret = secretServiceClient.s3().getS3Secret(S3SecretGetRequest.newBuilder()
+ .setAuthzToken(AuthToken.newBuilder()
+ .setUserTokenAuth(UserTokenAuth.newBuilder().setToken(getUserToken()).build())
+ .build())
+ .setSecretId(ec2BackendObj.getAwsCredentialId()).build());
+
+ BasicAWSCredentials awsCreds = new BasicAWSCredentials(s3Secret.getAccessKey(), s3Secret.getSecretKey());
+
+ AmazonEC2 amazonEC2 = AmazonEC2ClientBuilder.standard().withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(
+ "https://ec2." + ec2BackendObj.getRegion() + ".amazonaws.com", ec2BackendObj.getRegion()))
+ .withCredentials(new AWSStaticCredentialsProvider(awsCreds))
+ .build();
+
+ try {
+ DescribeSecurityGroupsRequest desSecGrp = new DescribeSecurityGroupsRequest();
+ DescribeSecurityGroupsResult describeSecurityGroupsResult = amazonEC2.describeSecurityGroups(desSecGrp);
+ List<SecurityGroup> securityGroups = describeSecurityGroupsResult.getSecurityGroups();
+ boolean hasSecGroup = securityGroups.stream().anyMatch(sg -> sg.getGroupName().equals(secGroupName));
+
+ if (!hasSecGroup) {
+ CreateSecurityGroupRequest csgr = new CreateSecurityGroupRequest();
+ csgr.withGroupName(secGroupName).withDescription("Airavata Security Group");
+ amazonEC2.createSecurityGroup(csgr);
+
+ // NOTE(review): opens SSH (tcp/22) to 0.0.0.0/0; consider restricting the CIDR range.
+ IpPermission ipPermission = new IpPermission();
+ IpRange ipRange1 = new IpRange().withCidrIp("0.0.0.0/0");
+ ipPermission.withIpv4Ranges(Collections.singletonList(ipRange1))
+ .withIpProtocol("tcp")
+ .withFromPort(22)
+ .withToPort(22);
+
+ AuthorizeSecurityGroupIngressRequest authorizeSecurityGroupIngressRequest =
+ new AuthorizeSecurityGroupIngressRequest();
+ authorizeSecurityGroupIngressRequest.withGroupName(secGroupName)
+ .withIpPermissions(ipPermission);
+ amazonEC2.authorizeSecurityGroupIngress(authorizeSecurityGroupIngressRequest);
+ }
+
+ // Reuse a key pair from an earlier run if its private key is still registered in the secret
+ // service; such key pairs carry the secret id in their AIRAVATA_SECRET_ID tag.
+ String keyName = null;
+ SCPSecret scpSecret = null;
+ for (KeyPairInfo keyPairInfo : amazonEC2.describeKeyPairs().getKeyPairs()) {
+ Optional<Tag> secretTag = keyPairInfo.getTags().stream()
+ .filter(t -> t.getKey().equals("AIRAVATA_SECRET_ID")).findFirst();
+ if (secretTag.isEmpty()) {
+ continue;
+ }
+ try {
+ SCPSecret secret = secretServiceClient.scp().getSCPSecret(SCPSecretGetRequest.newBuilder()
+ .setAuthzToken(AuthToken.newBuilder()
+ .setUserTokenAuth(UserTokenAuth.newBuilder()
+ .setToken(getUserToken()).build()).build())
+ .setSecretId(secretTag.get().getValue()).build());
+
+ if (secret != null) {
+ keyName = keyPairInfo.getKeyName();
+ scpSecret = secret;
+ logger.info("Found previously created Key Pair {} with Airavata secret id {}", keyName, secret.getSecretId());
+ break;
+ }
+ } catch (Exception e) {
+ // Best effort: an unreadable secret only means this key pair can't be reused.
+ logger.warn("Could not load secret {} for key pair {}; skipping it",
+ secretTag.get().getValue(), keyPairInfo.getKeyName(), e);
+ }
+ }
+
+ if (keyName == null) {
+ keyName = keyNamePrefix + UUID.randomUUID().toString();
+ logger.info("Creating EC2 key pair with name {}", keyName);
+ CreateKeyPairRequest createKeyPairRequest = new CreateKeyPairRequest();
+ createKeyPairRequest.withKeyName(keyName);
+
+ CreateKeyPairResult createKeyPairResult = amazonEC2.createKeyPair(createKeyPairRequest);
+ KeyPair keyPair = createKeyPairResult.getKeyPair();
+ String privateKey = keyPair.getKeyMaterial();
+
+ scpSecret = secretServiceClient.scp()
+ .createSCPSecret(SCPSecretCreateRequest.newBuilder()
+ .setAuthzToken(AuthToken.newBuilder()
+ .setUserTokenAuth(UserTokenAuth.newBuilder()
+ .setToken(getUserToken()).build()).build())
+ .setUser(ec2BackendObj.getLoginUserName())
+ .setPrivateKey(privateKey).build());
+
+ logger.info("Created SSH Secret {}", scpSecret.getSecretId());
+
+ // Tag the key pair with the secret id so later runs can find and reuse it.
+ CreateTagsRequest tagsRequest = new CreateTagsRequest();
+ tagsRequest.setResources(Collections.singletonList(keyPair.getKeyPairId()));
+ Tag secretIdTag = new Tag();
+ secretIdTag.setKey("AIRAVATA_SECRET_ID");
+ secretIdTag.setValue(scpSecret.getSecretId());
+ tagsRequest.setTags(Collections.singletonList(secretIdTag));
+ amazonEC2.createTags(tagsRequest);
+ logger.info("Created tag on SSH keypair with secret id {}", scpSecret.getSecretId());
+ }
+ putUserContent(EC2_INSTANCE_SECRET_ID, scpSecret.getSecretId(), Scope.WORKFLOW);
+
+ // Use the backend's flavor (t1.micro used to be hard-coded); fall back to t1.micro only
+ // when no flavor is configured.
+ String flavor = ec2BackendObj.getFlavor();
+ String instanceType = flavor.isBlank() ? InstanceType.T1Micro.toString() : flavor;
+
+ RunInstancesRequest runInstancesRequest = new RunInstancesRequest();
+
+ runInstancesRequest.withImageId(ec2BackendObj.getImageId())
+ .withInstanceType(instanceType)
+ .withMinCount(1)
+ .withMaxCount(1)
+ .withKeyName(keyName)
+ .withTagSpecifications(
+ new TagSpecification().withResourceType(ResourceType.Instance)
+ .withTags(new Tag().withKey("Type").withValue("Airavata"),
+ new Tag().withKey("Task").withValue(getTaskId()),
+ new Tag().withKey("Name").withValue("Airavata Application VM")))
+ .withSecurityGroups(secGroupName);
+
+ logger.info("Launching the EC2 VM");
+ RunInstancesResult result = amazonEC2.runInstances(runInstancesRequest);
+
+ String instanceId = result.getReservation().getInstances().get(0).getInstanceId();
+ putUserContent(EC2_INSTANCE_ID, instanceId, Scope.WORKFLOW);
+
+ try {
+ Thread.sleep(5000); // Waiting 5 seconds until instance details to be consistent in amazon side
+
+ DescribeInstancesRequest describeInstancesRequest = new DescribeInstancesRequest();
+ describeInstancesRequest.setInstanceIds(Collections.singletonList(instanceId));
+
+ InstanceState instanceState = null;
+ String publicIpAddress = null;
+
+ logger.info("Waiting until instance {} is ready", instanceId);
+
+ for (int i = 0; i < 30; i++) {
+ DescribeInstancesResult describeInstancesResult = amazonEC2.describeInstances(describeInstancesRequest);
+ Instance instance = describeInstancesResult.getReservations().get(0).getInstances().get(0);
+ instanceState = instance.getState();
+ publicIpAddress = instance.getPublicIpAddress();
+
+ logger.info("Instance state {}, public ip {}", instanceState.getName(), publicIpAddress);
+
+ if (instanceState.getName().equals("running") && publicIpAddress != null) {
+ break;
+ }
+ Thread.sleep(2000);
+ }
+
+ // Fail (the catch below terminates the instance) instead of publishing a null IP.
+ if (!"running".equals(instanceState.getName()) || publicIpAddress == null) {
+ throw new IllegalStateException("Instance " + instanceId + " was not running with a public ip after polling"
+ + " (last state " + instanceState.getName() + ", public ip " + publicIpAddress + ")");
+ }
+
+ putUserContent(EC2_INSTANCE_IP, publicIpAddress, Scope.WORKFLOW);
+
+ logger.info("Waiting 30 seconds until the ssh interface comes up in instance {}", instanceId);
+ Thread.sleep(30000);
+ logger.info("EC2 Instance is running...");
+
+ } catch (Exception e) {
+ logger.error("Failed preparing instance {}. Deleting the instance", instanceId, e);
+ TerminateInstancesRequest terminateInstancesRequest = new TerminateInstancesRequest();
+ terminateInstancesRequest.setInstanceIds(Collections.singleton(instanceId));
+ try {
+ amazonEC2.terminateInstances(terminateInstancesRequest);
+ } catch (Exception terminateError) {
+ // Keep the original failure as the primary error.
+ e.addSuppressed(terminateError);
+ }
+ throw e;
+ }
+ } finally {
+ // The client is built per run; release its connection pool.
+ amazonEC2.shutdown();
+ }
+ }
+
return new TaskResult(TaskResult.Status.COMPLETED, "Completed");
}
@@ -35,4 +244,28 @@ public class CreateEC2InstanceTask extends BaseTask {
public void setEc2Backend(EC2Backend ec2Backend) {
this.ec2Backend.set(ec2Backend);
}
+
+ public String getSecretServiceHost() {
+ return secretServiceHost.get();
+ }
+
+ public void setSecretServiceHost(String secretServiceHost) {
+ this.secretServiceHost.set(secretServiceHost);
+ }
+
+ public Integer getSecretServicePort() {
+ return secretServicePort.get();
+ }
+
+ public void setSecretServicePort(Integer secretServicePort) {
+ this.secretServicePort.set(secretServicePort);
+ }
+
+ public String getUserToken() {
+ return userToken.get();
+ }
+
+ public void setUserToken(String userToken) {
+ this.userToken.set(userToken);
+ }
}
diff --git a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/DestroyEC2InstanceTask.java b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/DestroyEC2InstanceTask.java
index 7c4f39674a..57ba1f0c7a 100644
--- a/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/DestroyEC2InstanceTask.java
+++ b/modules/airavata-apis/airavata-apis-server/src/main/java/org/apache/airavata/apis/workflow/task/ec2/DestroyEC2InstanceTask.java
@@ -1,4 +1,118 @@
package org.apache.airavata.apis.workflow.task.ec2;
-public class DestroyEC2InstanceTask {
+import com.amazonaws.auth.AWSStaticCredentialsProvider;
+import com.amazonaws.auth.BasicAWSCredentials;
+import com.amazonaws.client.builder.AwsClientBuilder;
+import com.amazonaws.services.ec2.AmazonEC2;
+import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
+import com.amazonaws.services.ec2.model.TerminateInstancesRequest;
+import org.apache.airavata.api.execution.stubs.EC2Backend;
+import org.apache.airavata.apis.workflow.task.common.BaseTask;
+import org.apache.airavata.apis.workflow.task.common.annotation.TaskDef;
+import org.apache.airavata.apis.workflow.task.common.annotation.TaskParam;
+import org.apache.airavata.mft.common.AuthToken;
+import org.apache.airavata.mft.common.UserTokenAuth;
+import org.apache.airavata.mft.credential.stubs.s3.S3Secret;
+import org.apache.airavata.mft.credential.stubs.s3.S3SecretGetRequest;
+import org.apache.airavata.mft.secret.client.SecretServiceClient;
+import org.apache.airavata.mft.secret.client.SecretServiceClientBuilder;
+import org.apache.helix.task.TaskResult;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collections;
+
+/**
+ * Terminates the EC2 instance named by the {@code instanceId} task param, using the AWS
+ * credentials stored in the secret service under the backend's awsCredentialId.
+ */
+@TaskDef(name = "DestroyEC2InstanceTask")
+public class DestroyEC2InstanceTask extends BaseTask {
+
+ private final static Logger logger = LoggerFactory.getLogger(DestroyEC2InstanceTask.class);
+
+ @TaskParam(name = "ec2Backend")
+ private ThreadLocal<EC2Backend> ec2Backend = new ThreadLocal<>();
+
+ @TaskParam(name = "secretServiceHost")
+ private ThreadLocal<String> secretServiceHost = new ThreadLocal<>();
+
+ @TaskParam(name = "secretServicePort")
+ private ThreadLocal<Integer> secretServicePort = new ThreadLocal<>();
+
+ @TaskParam(name = "userToken")
+ private ThreadLocal<String> userToken = new ThreadLocal<>();
+
+ @TaskParam(name = "instanceId")
+ private ThreadLocal<String> instanceId = new ThreadLocal<>();
+
+ @Override
+ public TaskResult onRun() throws Exception {
+
+ String targetInstanceId = getInstanceId();
+ if (targetInstanceId == null || targetInstanceId.isBlank()) {
+ // e.g. the "" placeholder was never replaced from the workflow context
+ logger.error("No EC2 instance id was provided to task {}", getTaskId());
+ return new TaskResult(TaskResult.Status.FAILED, "No EC2 instance id to destroy");
+ }
+
+ logger.info("Destroying the instance {}", targetInstanceId);
+ EC2Backend ec2BackendObj = getEc2Backend();
+ try (SecretServiceClient secretServiceClient = SecretServiceClientBuilder
+ .buildClient(getSecretServiceHost(), getSecretServicePort())) {
+
+ S3Secret s3Secret = secretServiceClient.s3().getS3Secret(S3SecretGetRequest.newBuilder()
+ .setAuthzToken(AuthToken.newBuilder()
+ .setUserTokenAuth(UserTokenAuth.newBuilder().setToken(getUserToken()).build())
+ .build())
+ .setSecretId(ec2BackendObj.getAwsCredentialId()).build());
+
+ BasicAWSCredentials awsCreds = new BasicAWSCredentials(s3Secret.getAccessKey(), s3Secret.getSecretKey());
+
+ AmazonEC2 amazonEC2 = AmazonEC2ClientBuilder.standard().withEndpointConfiguration(new AwsClientBuilder.EndpointConfiguration(
+ "https://ec2." + ec2BackendObj.getRegion() + ".amazonaws.com", ec2BackendObj.getRegion()))
+ .withCredentials(new AWSStaticCredentialsProvider(awsCreds))
+ .build();
+
+ try {
+ TerminateInstancesRequest terminateInstancesRequest = new TerminateInstancesRequest();
+ terminateInstancesRequest.setInstanceIds(Collections.singleton(targetInstanceId));
+ // TerminateInstances starts termination; it does not wait for the instance to shut down.
+ amazonEC2.terminateInstances(terminateInstancesRequest);
+ logger.info("EC2 instance {} was successfully destroyed", targetInstanceId);
+ } finally {
+ // The client is built per run; release its connection pool.
+ amazonEC2.shutdown();
+ }
+ }
+ return new TaskResult(TaskResult.Status.COMPLETED, "Completed");
+ }
+
+ @Override
+ public void onCancel() throws Exception {
+ // No cancellation handling: onRun issues a single terminate request.
+ }
+
+ public EC2Backend getEc2Backend() {
+ return ec2Backend.get();
+ }
+
+ public void setEc2Backend(EC2Backend ec2Backend) {
+ this.ec2Backend.set(ec2Backend);
+ }
+
+ public String getSecretServiceHost() {
+ return secretServiceHost.get();
+ }
+
+ public void setSecretServiceHost(String secretServiceHost) {
+ this.secretServiceHost.set(secretServiceHost);
+ }
+
+ public Integer getSecretServicePort() {
+ return secretServicePort.get();
+ }
+
+ public void setSecretServicePort(Integer secretServicePort) {
+ this.secretServicePort.set(secretServicePort);
+ }
+
+ public String getUserToken() {
+ return userToken.get();
+ }
+
+ public void setUserToken(String userToken) {
+ this.userToken.set(userToken);
+ }
+
+ public String getInstanceId() {
+ return instanceId.get();
+ }
+
+ public void setInstanceId(String instanceId) {
+ this.instanceId.set(instanceId);
+ }
}
diff --git a/modules/airavata-apis/airavata-apis-stub/src/main/proto/execution/experiment_stub.proto b/modules/airavata-apis/airavata-apis-stub/src/main/proto/execution/experiment_stub.proto
index 15457f14fe..8c0f178eeb 100644
--- a/modules/airavata-apis/airavata-apis-stub/src/main/proto/execution/experiment_stub.proto
+++ b/modules/airavata-apis/airavata-apis-stub/src/main/proto/execution/experiment_stub.proto
@@ -138,8 +138,9 @@ message ServerBackend {
message EC2Backend {
string flavor = 1;
string imageId = 2;
string region = 3;
string aws_credential_id = 4;
+ // New fields take new tag numbers: renumbering region/aws_credential_id would make
+ // already-serialized EC2Backend messages decode into the wrong fields.
+ string loginUserName = 5;
}
message LocalBackend {