This is an automated email from the ASF dual-hosted git repository.
yasith pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airavata.git
The following commit(s) were added to refs/heads/master by this push:
new bb0e14c4a8 Support two storage resources to stage input data and output data separately (#570)
bb0e14c4a8 is described below
commit bb0e14c4a82ab96029ec0869f68a63ab3d7c2493
Author: Lahiru Jayathilake <[email protected]>
AuthorDate: Mon Nov 10 20:04:51 2025 -0500
Support two storage resources to stage input data and output data separately (#570)
* Add support for an optional output storage resource for experiments
* Regenerate Thrift ttypes for the Python SDK
* Update the SDK to handle input and output storage IDs, using the input storage ID as the default value for the output storage ID
* Stage input data and output data through the input and output storages, respectively
* Rename storageId to inputStorageResourceId in the Python SDK
* Regenerate the Python Thrift stubs to align with the new input/output storage resource types
* Apply spotless styles
* Add an example Python script to create and launch an experiment
* Fix the resource IDs and storage IDs in the example
* Add a script to launch 10 jobs with 5 in parallel
* Improve error handling for failed experiment launches
* Apply spotless styles
---------
Co-authored-by: yasithdev <[email protected]>
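
For orientation, a minimal sketch (not part of this commit) of how the two
storage resources surface in the regenerated Python SDK model; the storage
resource IDs and paths below are placeholders:

    from airavata.model.experiment.ttypes import UserConfigurationDataModel

    config = UserConfigurationDataModel(
        inputStorageResourceId="storage-a.example.org_aaaa-bbbb",   # inputs staged from here
        outputStorageResourceId="storage-b.example.org_cccc-dddd",  # outputs archived here
        experimentDataDir="/data/projects/demo/exp-001",
    )
    # When outputStorageResourceId is unset, server-side staging falls back to
    # the input/default storage, mirroring the SQL backfill in this commit.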
---
.../init/04-expcatalog-migrations.sql | 23 +-
.../airavata/helix/impl/task/TaskContext.java | 90 +++-
.../helix/impl/task/staging/ArchiveTask.java | 8 +-
.../helix/impl/task/staging/DataStagingTask.java | 122 +++++-
.../impl/task/staging/InputDataStagingTask.java | 4 +-
.../impl/task/staging/OutputDataStagingTask.java | 13 +-
.../airavata/model/util/ExperimentModelUtil.java | 3 +-
.../core/entities/expcatalog/ProcessEntity.java | 23 +-
.../expcatalog/UserConfigurationDataEntity.java | 23 +-
.../expcatalog/ExperimentRepositoryTest.java | 6 +-
.../airavata/model/experiment/ttypes.py | 50 ++-
.../airavata/model/process/ttypes.py | 68 +--
.../airavata_experiments/airavata.py | 86 +++-
.../airavata_experiments/md/applications.py | 53 +++
.../clients/utils/data_model_creation_util.py | 6 +-
.../clients/utils/experiment_handler_util.py | 34 +-
.../samples/create_launch_echo_experiment.py | 3 +-
.../samples/create_launch_gaussian_experiment.py | 3 +-
dev-tools/batch_launch_experiments.py | 479 +++++++++++++++++++++
dev-tools/create_launch_experiment_with_storage.py | 267 ++++++++++++
.../service/handlers/AgentManagementHandler.java | 6 +-
.../service/models/AgentLaunchRequest.java | 19 +
.../data-models/experiment_model.thrift | 11 +-
.../data-models/process_model.thrift | 17 +-
24 files changed, 1271 insertions(+), 146 deletions(-)
diff --git a/.devcontainer/database_scripts/init/04-expcatalog-migrations.sql b/.devcontainer/database_scripts/init/04-expcatalog-migrations.sql
index e3ebde0753..bb32d3f438 100644
--- a/.devcontainer/database_scripts/init/04-expcatalog-migrations.sql
+++ b/.devcontainer/database_scripts/init/04-expcatalog-migrations.sql
@@ -30,4 +30,25 @@ CREATE TABLE IF NOT EXISTS COMPUTE_RESOURCE_SCHEDULING (
PARALLEL_GROUP_COUNT INT,
PRIMARY KEY (EXPERIMENT_ID,RESOURCE_HOST_ID,QUEUE_NAME),
    FOREIGN KEY (EXPERIMENT_ID) REFERENCES EXPERIMENT(EXPERIMENT_ID) ON DELETE CASCADE
-)ENGINE=InnoDB DEFAULT CHARSET=latin1;
\ No newline at end of file
+)ENGINE=InnoDB DEFAULT CHARSET=latin1;
+
+-- Rename storage resource ID to input storage resource ID
+ALTER TABLE USER_CONFIGURATION_DATA CHANGE COLUMN STORAGE_RESOURCE_ID INPUT_STORAGE_RESOURCE_ID VARCHAR(255) DEFAULT NULL;
+ALTER TABLE PROCESS CHANGE COLUMN STORAGE_RESOURCE_ID INPUT_STORAGE_RESOURCE_ID VARCHAR(255) DEFAULT NULL;
+
+-- Add output storage resource ID columns
+ALTER TABLE USER_CONFIGURATION_DATA ADD COLUMN IF NOT EXISTS OUTPUT_STORAGE_RESOURCE_ID VARCHAR(255) DEFAULT NULL;
+ALTER TABLE PROCESS ADD COLUMN IF NOT EXISTS OUTPUT_STORAGE_RESOURCE_ID VARCHAR(255) DEFAULT NULL;
+
+-- Update OUTPUT_STORAGE_RESOURCE_ID with INPUT_STORAGE_RESOURCE_ID when missing
+UPDATE USER_CONFIGURATION_DATA
+SET OUTPUT_STORAGE_RESOURCE_ID = INPUT_STORAGE_RESOURCE_ID
+WHERE (OUTPUT_STORAGE_RESOURCE_ID IS NULL OR OUTPUT_STORAGE_RESOURCE_ID = '')
+ AND INPUT_STORAGE_RESOURCE_ID IS NOT NULL
+ AND INPUT_STORAGE_RESOURCE_ID != '';
+
+UPDATE PROCESS
+SET OUTPUT_STORAGE_RESOURCE_ID = INPUT_STORAGE_RESOURCE_ID
+WHERE (OUTPUT_STORAGE_RESOURCE_ID IS NULL OR OUTPUT_STORAGE_RESOURCE_ID = '')
+ AND INPUT_STORAGE_RESOURCE_ID IS NOT NULL
+ AND INPUT_STORAGE_RESOURCE_ID != '';
\ No newline at end of file
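
A quick way to sanity-check the backfill above, as a hedged sketch (not part of
the commit) assuming mysql-connector-python and placeholder credentials:

    import mysql.connector

    conn = mysql.connector.connect(host="localhost", user="airavata",
                                   password="secret", database="experiment_catalog")
    cur = conn.cursor()
    # After the migration, no row should have an input storage ID set while the
    # output storage ID is still missing.
    cur.execute("""
        SELECT COUNT(*) FROM PROCESS
        WHERE (OUTPUT_STORAGE_RESOURCE_ID IS NULL OR OUTPUT_STORAGE_RESOURCE_ID = '')
          AND INPUT_STORAGE_RESOURCE_ID IS NOT NULL AND INPUT_STORAGE_RESOURCE_ID != ''
    """)
    print("unbackfilled PROCESS rows:", cur.fetchone()[0])  # expected: 0
    conn.close()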
diff --git a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/TaskContext.java b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/TaskContext.java
index a5491f3f53..5bab41ae8f 100644
--- a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/TaskContext.java
+++ b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/TaskContext.java
@@ -327,17 +327,45 @@ public class TaskContext {
this.userStoragePreference = userStoragePreference;
}
+    /**
+     * Returns the default storage preference for the gateway.
+     * Prefers gateway-specific storage (ID starting with gatewayId), otherwise uses the first available preference.
+     *
+     * @deprecated Use {@link #getInputGatewayStorageResourcePreference()} for input staging operations
+     *             or {@link #getOutputGatewayStorageResourcePreference()} for output staging operations.
+     */
+ @Deprecated
    public StoragePreference getGatewayStorageResourcePreference() throws Exception {
if (this.gatewayStorageResourcePreference == null) {
try {
-                this.gatewayStorageResourcePreference =
-                        registryClient.getGatewayStoragePreference(gatewayId, processModel.getStorageResourceId());
+                GatewayResourceProfile gatewayProfile = getGatewayResourceProfile();
+                List<StoragePreference> storagePreferences = gatewayProfile.getStoragePreferences();
+
+                if (storagePreferences == null || storagePreferences.isEmpty()) {
+                    throw new Exception("No storage preferences found for gateway " + gatewayId);
+                }
+
+                String gatewayPrefix = gatewayId + "_";
+                this.gatewayStorageResourcePreference = storagePreferences.stream()
+                        .filter(pref -> {
+                            String id = pref.getStorageResourceId();
+                            return id != null && id.startsWith(gatewayPrefix);
+                        })
+                        .findFirst()
+                        .orElseGet(() -> {
+                            logger.debug(
+                                    "No gateway-specific storage found, using first available: {}",
+                                    storagePreferences.get(0).getStorageResourceId());
+                            return storagePreferences.get(0);
+                        });
+
+                if (this.gatewayStorageResourcePreference.getStorageResourceId().startsWith(gatewayPrefix)) {
+                    logger.debug(
+                            "Using gateway-specific storage preference: {}",
+                            this.gatewayStorageResourcePreference.getStorageResourceId());
+                }
            } catch (TException e) {
-                logger.error(
-                        "Failed to fetch gateway storage preference for gateway {} and storage {}",
-                        gatewayId,
-                        processModel.getStorageResourceId(),
-                        e);
+                logger.error("Failed to fetch gateway storage preference for gateway {}", gatewayId, e);
                throw e;
            }
        }
@@ -739,6 +767,54 @@ public class TaskContext {
return getGatewayStorageResourcePreference().getStorageResourceId();
}
+    public String getInputStorageResourceId() throws Exception {
+        if (processModel.getInputStorageResourceId() != null
+                && !processModel.getInputStorageResourceId().trim().isEmpty()) {
+            return processModel.getInputStorageResourceId();
+        }
+        return getStorageResourceId();
+    }
+
+    public StoragePreference getInputGatewayStorageResourcePreference() throws Exception {
+        String inputStorageId = getInputStorageResourceId();
+        try {
+            return registryClient.getGatewayStoragePreference(gatewayId, inputStorageId);
+        } catch (TException e) {
+            logger.error(
+                    "Failed to fetch gateway storage preference for input storage {} in gateway {}",
+                    inputStorageId,
+                    gatewayId,
+                    e);
+            throw e;
+        }
+    }
+
+    public String getOutputStorageResourceId() throws Exception {
+        if (processModel.getOutputStorageResourceId() != null
+                && !processModel.getOutputStorageResourceId().trim().isEmpty()) {
+            return processModel.getOutputStorageResourceId();
+        }
+        return getStorageResourceId();
+    }
+
+    public StoragePreference getOutputGatewayStorageResourcePreference() throws Exception {
+        String outputStorageId = getOutputStorageResourceId();
+        try {
+            return registryClient.getGatewayStoragePreference(gatewayId, outputStorageId);
+        } catch (TException e) {
+            logger.error(
+                    "Failed to fetch gateway storage preference for output storage {} in gateway {}",
+                    outputStorageId,
+                    gatewayId,
+                    e);
+            throw e;
+        }
+    }
+
+    public StorageResourceDescription getOutputStorageResourceDescription() throws Exception {
+        return registryClient.getStorageResource(getOutputStorageResourceId());
+    }
+
private ComputationalResourceSchedulingModel getProcessCRSchedule() {
if (getProcessModel() != null) {
return getProcessModel().getProcessResourceSchedule();
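
The getters above share one resolution rule; a hedged Python sketch (not part
of the commit) of that rule, with hypothetical names:

    def resolve_storage_id(specific_id: str | None, default_id: str) -> str:
        """Return the per-direction storage ID if set and non-blank, else the default."""
        if specific_id is not None and specific_id.strip():
            return specific_id
        return default_id

    assert resolve_storage_id(None, "default-sr") == "default-sr"    # unset -> default
    assert resolve_storage_id("   ", "default-sr") == "default-sr"   # blank -> default
    assert resolve_storage_id("input-sr", "default-sr") == "input-sr"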
diff --git a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/ArchiveTask.java b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/ArchiveTask.java
index e4e3336831..46ed806e6f 100644
--- a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/ArchiveTask.java
+++ b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/ArchiveTask.java
@@ -68,7 +68,9 @@ public class ArchiveTask extends DataStagingTask {
tarDirPath = sourceURI.getPath();
}
-            String inputPath = getTaskContext().getStorageFileSystemRootLocation();
+            String inputPath = getTaskContext()
+                    .getOutputGatewayStorageResourcePreference()
+                    .getFileSystemRootLocation();
            destFilePath = buildDestinationFilePath(inputPath, archiveFileName);
            tarCreationAbsPath = tarDirPath + File.separator + archiveFileName;
@@ -77,8 +79,8 @@ public class ArchiveTask extends DataStagingTask {
"Failed to obtain source URI for archival staging task
" + getTaskId(), true, e);
}
- // Fetch and validate storage adaptor
- StorageResourceAdaptor storageResourceAdaptor =
getStorageAdaptor(taskHelper.getAdaptorSupport());
+ // Fetch and validate storage adaptor (uses output storage if
configured, otherwise default)
+ StorageResourceAdaptor storageResourceAdaptor =
getOutputStorageAdaptor(taskHelper.getAdaptorSupport());
// Fetch and validate compute resource adaptor
AgentAdaptor adaptor =
getComputeResourceAdaptor(taskHelper.getAdaptorSupport());
diff --git a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/DataStagingTask.java b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/DataStagingTask.java
index 6295ddb7bf..4320f72463 100644
--- a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/DataStagingTask.java
+++ b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/DataStagingTask.java
@@ -36,6 +36,7 @@ import org.apache.airavata.common.utils.ServerSettings;
import org.apache.airavata.helix.impl.task.AiravataTask;
import org.apache.airavata.helix.impl.task.TaskOnFailException;
import org.apache.airavata.helix.task.api.support.AdaptorSupport;
+import org.apache.airavata.model.appcatalog.gatewayprofile.StoragePreference;
import org.apache.airavata.model.appcatalog.storageresource.StorageResourceDescription;
import org.apache.airavata.model.task.DataStagingTaskModel;
import org.apache.airavata.patform.monitoring.CountMonitor;
@@ -77,28 +78,88 @@ public abstract class DataStagingTask extends AiravataTask {
return storageResource;
}
- @SuppressWarnings("WeakerAccess")
+    /**
+     * Gets the default storage adaptor configured for the gateway.
+     * This is the fallback storage used when input/output storage resources are not specifically configured.
+     */
    protected StorageResourceAdaptor getStorageAdaptor(AdaptorSupport adaptorSupport) throws TaskOnFailException {
String storageId = null;
try {
storageId = getTaskContext().getStorageResourceId();
-            StorageResourceAdaptor storageResourceAdaptor = adaptorSupport.fetchStorageAdaptor(
-                    getGatewayId(),
-                    getTaskContext().getStorageResourceId(),
-                    getTaskContext().getDataMovementProtocol(),
-                    getTaskContext().getStorageResourceCredentialToken(),
-                    getTaskContext().getStorageResourceLoginUserName());
+            StoragePreference gatewayStoragePref = getTaskContext().getGatewayStorageResourcePreference();
+            return createStorageAdaptorFromPreference(adaptorSupport, storageId, gatewayStoragePref, "Default");

-            if (storageResourceAdaptor == null) {
-                throw new TaskOnFailException(
-                        "Storage resource adaptor for " + getTaskContext().getStorageResourceId() + " can not be null",
-                        true,
-                        null);
+        } catch (Exception e) {
+            logger.error(
+                    "Failed to obtain adaptor for default storage resource {} in task {}", storageId, getTaskId(), e);
+            throw new TaskOnFailException(
+                    "Failed to obtain adaptor for default storage resource " + storageId + " in task " + getTaskId(),
+                    false,
+                    e);
+        }
+    }
+
+    /**
+     * Gets the input storage adaptor.
+     * Uses the input storage resource if configured; otherwise, falls back to the default gateway storage.
+     */
+    protected StorageResourceAdaptor getInputStorageAdaptor(AdaptorSupport adaptorSupport) throws TaskOnFailException {
+ String storageId = null;
+ try {
+ storageId = getTaskContext().getInputStorageResourceId();
+
+            if (getTaskContext().getProcessModel().getInputStorageResourceId() != null
+                    && !getTaskContext()
+                            .getProcessModel()
+                            .getInputStorageResourceId()
+                            .trim()
+                            .isEmpty()) {
+
+                StoragePreference inputStoragePref = getTaskContext().getInputGatewayStorageResourcePreference();
+                return createStorageAdaptorFromPreference(adaptorSupport, storageId, inputStoragePref, "Input");
+            } else {
+                // Fall back to the default storage resource configured
+                return getStorageAdaptor(adaptorSupport);
            }
-            return storageResourceAdaptor;
        } catch (Exception e) {
+            logger.error(
+                    "Failed to obtain adaptor for input storage resource {} in task {}", storageId, getTaskId(), e);
            throw new TaskOnFailException(
-                    "Failed to obtain adaptor for storage resource " + storageId + " in task " + getTaskId(), false, e);
+                    "Failed to obtain adaptor for input storage resource " + storageId + " in task " + getTaskId(),
+                    false,
+                    e);
+        }
+    }
+
+    /**
+     * Gets the output storage adaptor.
+     * Uses the output storage resource if configured; otherwise, falls back to the default gateway storage.
+     */
+    protected StorageResourceAdaptor getOutputStorageAdaptor(AdaptorSupport adaptorSupport) throws TaskOnFailException {
+ String storageId = null;
+ try {
+ storageId = getTaskContext().getOutputStorageResourceId();
+
+            if (getTaskContext().getProcessModel().getOutputStorageResourceId() != null
+                    && !getTaskContext()
+                            .getProcessModel()
+                            .getOutputStorageResourceId()
+                            .trim()
+                            .isEmpty()) {
+
+                StoragePreference outputStoragePref = getTaskContext().getOutputGatewayStorageResourcePreference();
+                return createStorageAdaptorFromPreference(adaptorSupport, storageId, outputStoragePref, "Output");
+            } else {
+                // Fall back to the default storage resource configured
+                return getStorageAdaptor(adaptorSupport);
+            }
+        } catch (Exception e) {
+            logger.error(
+                    "Failed to obtain adaptor for output storage resource {} in task {}", storageId, getTaskId(), e);
+            throw new TaskOnFailException(
+                    "Failed to obtain adaptor for output storage resource " + storageId + " in task " + getTaskId(),
+                    false,
+                    e);
}
}
@@ -414,4 +475,37 @@ public abstract class DataStagingTask extends AiravataTask {
logger.warn("Failed to delete temporary file " + filePath);
}
}
+
+ /**
+ * Common method to create StorageResourceAdaptor from a StoragePreference.
+ */
+    private StorageResourceAdaptor createStorageAdaptorFromPreference(
+            AdaptorSupport adaptorSupport, String storageId, StoragePreference storagePreference, String adaptorType)
+            throws TaskOnFailException {
+        try {
+            String credentialToken = storagePreference.getResourceSpecificCredentialStoreToken() != null
+                    ? storagePreference.getResourceSpecificCredentialStoreToken()
+                    : getTaskContext().getGatewayResourceProfile().getCredentialStoreToken();
+
+            StorageResourceAdaptor storageResourceAdaptor = adaptorSupport.fetchStorageAdaptor(
+                    getGatewayId(),
+                    storageId,
+                    getTaskContext().getDataMovementProtocol(),
+                    credentialToken,
+                    storagePreference.getLoginUserName());
+
+            if (storageResourceAdaptor == null) {
+                throw new TaskOnFailException(
+                        adaptorType + " storage resource adaptor for " + storageId + " can not be null", true, null);
+            }
+            return storageResourceAdaptor;
+
+        } catch (Exception e) {
+            throw new TaskOnFailException(
+                    "Failed to obtain adaptor for " + adaptorType.toLowerCase() + " storage resource " + storageId
+                            + " in task " + getTaskId(),
+                    false,
+                    e);
+        }
+    }
}
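
The credential handling in createStorageAdaptorFromPreference follows a simple
precedence; a hedged Python sketch (not part of the commit), with hypothetical
attribute names mirroring the Java getters:

    def pick_credential_token(storage_pref, gateway_profile) -> str:
        # Prefer a resource-specific credential store token; otherwise fall back
        # to the gateway-wide token from the gateway resource profile.
        token = storage_pref.resourceSpecificCredentialStoreToken
        return token if token is not None else gateway_profile.credentialStoreToken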
diff --git a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/InputDataStagingTask.java b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/InputDataStagingTask.java
index 9840ffdb15..65660152ca 100644
--- a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/InputDataStagingTask.java
+++ b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/InputDataStagingTask.java
@@ -82,8 +82,8 @@ public class InputDataStagingTask extends DataStagingTask {
                sourceUrls = new String[] {dataStagingTaskModel.getSource()};
            }
-            // Fetch and validate storage adaptor
-            StorageResourceAdaptor storageResourceAdaptor = getStorageAdaptor(taskHelper.getAdaptorSupport());
+            // Fetch and validate storage adaptor (uses input storage if configured, otherwise default)
+            StorageResourceAdaptor storageResourceAdaptor = getInputStorageAdaptor(taskHelper.getAdaptorSupport());
            // Fetch and validate compute resource adaptor
            AgentAdaptor adaptor = getComputeResourceAdaptor(taskHelper.getAdaptorSupport());
diff --git a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/OutputDataStagingTask.java b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/OutputDataStagingTask.java
index 19bb9f45e7..ba1e69e0e0 100644
--- a/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/OutputDataStagingTask.java
+++ b/airavata-api/src/main/java/org/apache/airavata/helix/impl/task/staging/OutputDataStagingTask.java
@@ -32,6 +32,7 @@ import org.apache.airavata.helix.impl.task.TaskContext;
import org.apache.airavata.helix.impl.task.TaskOnFailException;
import org.apache.airavata.helix.task.api.TaskHelper;
import org.apache.airavata.helix.task.api.annotation.TaskDef;
+import org.apache.airavata.model.appcatalog.gatewayprofile.StoragePreference;
import org.apache.airavata.model.appcatalog.storageresource.StorageResourceDescription;
import org.apache.airavata.model.application.io.DataType;
import org.apache.airavata.model.application.io.OutputDataObjectType;
@@ -74,8 +75,8 @@ public class OutputDataStagingTask extends DataStagingTask {
throw new TaskOnFailException(message, true, null);
}
-        // Fetch and validate storage resource
-        StorageResourceDescription storageResource = getStorageResource();
+        // Use output storage resource if specified, otherwise fall back to default
+        StorageResourceDescription storageResource = getTaskContext().getOutputStorageResourceDescription();
// Fetch and validate source and destination URLS
URI sourceURI;
@@ -90,13 +91,13 @@ public class OutputDataStagingTask extends DataStagingTask {
sourceURI.getPath().length());
            if (dataStagingTaskModel.getDestination().startsWith("dummy")) {
-
-                String inputPath = getTaskContext().getStorageFileSystemRootLocation();
+                StoragePreference outputStoragePref = getTaskContext().getOutputGatewayStorageResourcePreference();
+                String inputPath = outputStoragePref.getFileSystemRootLocation();
                String destFilePath = buildDestinationFilePath(inputPath, sourceFileName);
                destinationURI = new URI(
                        "file",
-                        getTaskContext().getStorageResourceLoginUserName(),
+                        outputStoragePref.getLoginUserName(),
                        storageResource.getHostName(),
                        22,
                        destFilePath,
@@ -117,7 +118,7 @@ public class OutputDataStagingTask extends DataStagingTask {
}
        // Fetch and validate storage adaptor
-        StorageResourceAdaptor storageResourceAdaptor = getStorageAdaptor(taskHelper.getAdaptorSupport());
+        StorageResourceAdaptor storageResourceAdaptor = getOutputStorageAdaptor(taskHelper.getAdaptorSupport());
        // Fetch and validate compute resource adaptor
        AgentAdaptor adaptor = getComputeResourceAdaptor(taskHelper.getAdaptorSupport());
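
For the "dummy" destination case above, the rewritten URI is anchored on the
output storage resource; a hedged Python sketch (not part of the commit) of the
equivalent construction, with placeholder values:

    def build_dummy_destination(login_user: str, host: str, dest_path: str) -> str:
        # Mirrors the Java `new URI("file", user, host, 22, path, ...)` call,
        # now using the output storage preference's login user and host.
        return f"file://{login_user}@{host}:22{dest_path}"

    print(build_dummy_destination("gatewayuser", "storage-b.example.org",
                                  "/data/exp-001/out.dat"))
    # file://gatewayuser@storage-b.example.org:22/data/exp-001/out.dat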
diff --git a/airavata-api/src/main/java/org/apache/airavata/model/util/ExperimentModelUtil.java b/airavata-api/src/main/java/org/apache/airavata/model/util/ExperimentModelUtil.java
index 73555e6c5b..8d4e3cbf68 100644
--- a/airavata-api/src/main/java/org/apache/airavata/model/util/ExperimentModelUtil.java
+++ b/airavata-api/src/main/java/org/apache/airavata/model/util/ExperimentModelUtil.java
@@ -92,7 +92,8 @@ public class ExperimentModelUtil {
        UserConfigurationDataModel configData = experiment.getUserConfigurationData();
        if (configData != null) {
-            processModel.setStorageResourceId(configData.getStorageId());
+            processModel.setInputStorageResourceId(configData.getInputStorageResourceId());
+            processModel.setOutputStorageResourceId(configData.getOutputStorageResourceId());
processModel.setExperimentDataDir(configData.getExperimentDataDir());
processModel.setGenerateCert(configData.isGenerateCert());
processModel.setUserDn(configData.getUserDN());
diff --git a/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/ProcessEntity.java b/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/ProcessEntity.java
index 4e7a458cb4..2ec14bb6c0 100644
--- a/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/ProcessEntity.java
+++ b/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/ProcessEntity.java
@@ -73,8 +73,11 @@ public class ProcessEntity implements Serializable {
@Column(name = "EMAIL_ADDRESSES")
private String emailAddresses;
- @Column(name = "STORAGE_RESOURCE_ID")
- private String storageResourceId;
+ @Column(name = "INPUT_STORAGE_RESOURCE_ID")
+ private String inputStorageResourceId;
+
+ @Column(name = "OUTPUT_STORAGE_RESOURCE_ID")
+ private String outputStorageResourceId;
@Column(name = "USER_DN")
private String userDn;
@@ -246,12 +249,20 @@ public class ProcessEntity implements Serializable {
this.emailAddresses = emailAddresses;
}
- public String getStorageResourceId() {
- return storageResourceId;
+ public String getInputStorageResourceId() {
+ return inputStorageResourceId;
+ }
+
+ public void setInputStorageResourceId(String inputStorageResourceId) {
+ this.inputStorageResourceId = inputStorageResourceId;
+ }
+
+ public String getOutputStorageResourceId() {
+ return outputStorageResourceId;
}
- public void setStorageResourceId(String storageResourceId) {
- this.storageResourceId = storageResourceId;
+ public void setOutputStorageResourceId(String outputStorageResourceId) {
+ this.outputStorageResourceId = outputStorageResourceId;
}
public String getUserDn() {
diff --git a/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/UserConfigurationDataEntity.java b/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/UserConfigurationDataEntity.java
index 1e68da78df..e1e3f8ca24 100644
--- a/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/UserConfigurationDataEntity.java
+++ b/airavata-api/src/main/java/org/apache/airavata/registry/core/entities/expcatalog/UserConfigurationDataEntity.java
@@ -86,8 +86,11 @@ public class UserConfigurationDataEntity implements Serializable {
@Column(name = "OVERRIDE_ALLOCATION_PROJECT_NUMBER")
private String overrideAllocationProjectNumber;
- @Column(name = "STORAGE_RESOURCE_ID")
- private String storageId;
+ @Column(name = "INPUT_STORAGE_RESOURCE_ID")
+ private String inputStorageResourceId;
+
+ @Column(name = "OUTPUT_STORAGE_RESOURCE_ID")
+ private String outputStorageResourceId;
@Column(name = "EXPERIMENT_DATA_DIR", length = 512)
private String experimentDataDir;
@@ -255,12 +258,20 @@ public class UserConfigurationDataEntity implements Serializable {
this.overrideAllocationProjectNumber = overrideAllocationProjectNumber;
}
- public String getStorageId() {
- return storageId;
+ public String getInputStorageResourceId() {
+ return inputStorageResourceId;
+ }
+
+ public void setInputStorageResourceId(String inputStorageResourceId) {
+ this.inputStorageResourceId = inputStorageResourceId;
+ }
+
+ public String getOutputStorageResourceId() {
+ return outputStorageResourceId;
}
- public void setStorageId(String storageId) {
- this.storageId = storageId;
+ public void setOutputStorageResourceId(String outputStorageResourceId) {
+ this.outputStorageResourceId = outputStorageResourceId;
}
public String getExperimentDataDir() {
diff --git a/airavata-api/src/test/java/org/apache/airavata/registry/core/repositories/expcatalog/ExperimentRepositoryTest.java b/airavata-api/src/test/java/org/apache/airavata/registry/core/repositories/expcatalog/ExperimentRepositoryTest.java
index 23a9b19fe8..33499575a6 100644
--- a/airavata-api/src/test/java/org/apache/airavata/registry/core/repositories/expcatalog/ExperimentRepositoryTest.java
+++ b/airavata-api/src/test/java/org/apache/airavata/registry/core/repositories/expcatalog/ExperimentRepositoryTest.java
@@ -129,12 +129,14 @@ public class ExperimentRepositoryTest extends TestBase {
        assertEquals(
                experimentId,
                experimentRepository.addUserConfigurationData(userConfigurationDataModel, experimentId));
-        userConfigurationDataModel.setStorageId("storage2");
+        userConfigurationDataModel.setInputStorageResourceId("storage2");
+        userConfigurationDataModel.setOutputStorageResourceId("storage2");
        experimentRepository.updateUserConfigurationData(userConfigurationDataModel, experimentId);
        final UserConfigurationDataModel retrievedUserConfigurationDataModel =
                experimentRepository.getUserConfigurationData(experimentId);
-        assertEquals("storage2", retrievedUserConfigurationDataModel.getStorageId());
+        assertEquals("storage2", retrievedUserConfigurationDataModel.getInputStorageResourceId());
+        assertEquals("storage2", retrievedUserConfigurationDataModel.getOutputStorageResourceId());
        final ComputationalResourceSchedulingModel retrievedComputationalResourceScheduling =
                retrievedUserConfigurationDataModel.getComputationalResourceScheduling();
        assertNotNull(retrievedComputationalResourceScheduling);
assertNotNull(retrievedComputationalResourceScheduling);
diff --git a/dev-tools/airavata-python-sdk/airavata/model/experiment/ttypes.py b/dev-tools/airavata-python-sdk/airavata/model/experiment/ttypes.py
index 99ccd9229e..abbcfe59b7 100644
--- a/dev-tools/airavata-python-sdk/airavata/model/experiment/ttypes.py
+++ b/dev-tools/airavata-python-sdk/airavata/model/experiment/ttypes.py
@@ -65,7 +65,8 @@ class UserConfigurationDataModel(object):
- throttleResources
- userDN
- generateCert
- - storageId
+ - inputStorageResourceId
+ - outputStorageResourceId
- experimentDataDir
- useUserCRPref
- groupResourceProfileId
@@ -75,7 +76,7 @@ class UserConfigurationDataModel(object):
thrift_spec: typing.Any = None
-    def __init__(self, airavataAutoSchedule: bool = False, overrideManualScheduledParams: bool = False, shareExperimentPublicly: typing.Optional[bool] = False, computationalResourceScheduling: typing.Optional[airavata.model.scheduling.ttypes.ComputationalResourceSchedulingModel] = None, throttleResources: typing.Optional[bool] = False, userDN: typing.Optional[str] = None, generateCert: typing.Optional[bool] = False, storageId: typing.Optional[str] = None, experimentDataDir: typing.Option [...]
+    def __init__(self, airavataAutoSchedule: bool = False, overrideManualScheduledParams: bool = False, shareExperimentPublicly: typing.Optional[bool] = False, computationalResourceScheduling: typing.Optional[airavata.model.scheduling.ttypes.ComputationalResourceSchedulingModel] = None, throttleResources: typing.Optional[bool] = False, userDN: typing.Optional[str] = None, generateCert: typing.Optional[bool] = False, inputStorageResourceId: typing.Optional[str] = None, outputStorageResour [...]
self.airavataAutoSchedule: bool = airavataAutoSchedule
        self.overrideManualScheduledParams: bool = overrideManualScheduledParams
        self.shareExperimentPublicly: typing.Optional[bool] = shareExperimentPublicly
@@ -83,7 +84,8 @@ class UserConfigurationDataModel(object):
self.throttleResources: typing.Optional[bool] = throttleResources
self.userDN: typing.Optional[str] = userDN
self.generateCert: typing.Optional[bool] = generateCert
- self.storageId: typing.Optional[str] = storageId
+        self.inputStorageResourceId: typing.Optional[str] = inputStorageResourceId
+        self.outputStorageResourceId: typing.Optional[str] = outputStorageResourceId
        self.experimentDataDir: typing.Optional[str] = experimentDataDir
        self.useUserCRPref: typing.Optional[bool] = useUserCRPref
        self.groupResourceProfileId: typing.Optional[str] = groupResourceProfileId
@@ -136,25 +138,30 @@ class UserConfigurationDataModel(object):
iprot.skip(ftype)
elif fid == 8:
if ftype == TType.STRING:
-                    self.storageId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
+                    self.inputStorageResourceId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
elif fid == 9:
if ftype == TType.STRING:
-                    self.experimentDataDir = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
+                    self.outputStorageResourceId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
elif fid == 10:
+ if ftype == TType.STRING:
+                    self.experimentDataDir = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
+ else:
+ iprot.skip(ftype)
+ elif fid == 11:
if ftype == TType.BOOL:
self.useUserCRPref = iprot.readBool()
else:
iprot.skip(ftype)
- elif fid == 11:
+ elif fid == 12:
if ftype == TType.STRING:
                    self.groupResourceProfileId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
- elif fid == 12:
+ elif fid == 13:
if ftype == TType.LIST:
self.autoScheduledCompResourceSchedulingList = []
(_etype3, _size0) = iprot.readListBegin()
@@ -204,24 +211,28 @@ class UserConfigurationDataModel(object):
oprot.writeFieldBegin('generateCert', TType.BOOL, 7)
oprot.writeBool(self.generateCert)
oprot.writeFieldEnd()
-        if self.storageId is not None:
-            oprot.writeFieldBegin('storageId', TType.STRING, 8)
-            oprot.writeString(self.storageId.encode('utf-8') if sys.version_info[0] == 2 else self.storageId)
+        if self.inputStorageResourceId is not None:
+            oprot.writeFieldBegin('inputStorageResourceId', TType.STRING, 8)
+            oprot.writeString(self.inputStorageResourceId.encode('utf-8') if sys.version_info[0] == 2 else self.inputStorageResourceId)
+            oprot.writeFieldEnd()
+        if self.outputStorageResourceId is not None:
+            oprot.writeFieldBegin('outputStorageResourceId', TType.STRING, 9)
+            oprot.writeString(self.outputStorageResourceId.encode('utf-8') if sys.version_info[0] == 2 else self.outputStorageResourceId)
+            oprot.writeFieldEnd()
if self.experimentDataDir is not None:
-            oprot.writeFieldBegin('experimentDataDir', TType.STRING, 9)
+            oprot.writeFieldBegin('experimentDataDir', TType.STRING, 10)
            oprot.writeString(self.experimentDataDir.encode('utf-8') if sys.version_info[0] == 2 else self.experimentDataDir)
oprot.writeFieldEnd()
if self.useUserCRPref is not None:
- oprot.writeFieldBegin('useUserCRPref', TType.BOOL, 10)
+ oprot.writeFieldBegin('useUserCRPref', TType.BOOL, 11)
oprot.writeBool(self.useUserCRPref)
oprot.writeFieldEnd()
if self.groupResourceProfileId is not None:
- oprot.writeFieldBegin('groupResourceProfileId', TType.STRING, 11)
+ oprot.writeFieldBegin('groupResourceProfileId', TType.STRING, 12)
            oprot.writeString(self.groupResourceProfileId.encode('utf-8') if sys.version_info[0] == 2 else self.groupResourceProfileId)
oprot.writeFieldEnd()
if self.autoScheduledCompResourceSchedulingList is not None:
-            oprot.writeFieldBegin('autoScheduledCompResourceSchedulingList', TType.LIST, 12)
+            oprot.writeFieldBegin('autoScheduledCompResourceSchedulingList', TType.LIST, 13)
            oprot.writeListBegin(TType.STRUCT, len(self.autoScheduledCompResourceSchedulingList))
for iter6 in self.autoScheduledCompResourceSchedulingList:
iter6.write(oprot)
@@ -1039,11 +1050,12 @@ UserConfigurationDataModel.thrift_spec = (
(5, TType.BOOL, 'throttleResources', None, False, ), # 5
(6, TType.STRING, 'userDN', 'UTF8', None, ), # 6
(7, TType.BOOL, 'generateCert', None, False, ), # 7
- (8, TType.STRING, 'storageId', 'UTF8', None, ), # 8
- (9, TType.STRING, 'experimentDataDir', 'UTF8', None, ), # 9
- (10, TType.BOOL, 'useUserCRPref', None, None, ), # 10
- (11, TType.STRING, 'groupResourceProfileId', 'UTF8', None, ), # 11
-    (12, TType.LIST, 'autoScheduledCompResourceSchedulingList', (TType.STRUCT, [airavata.model.scheduling.ttypes.ComputationalResourceSchedulingModel, None], False), None, ),  # 12
+    (8, TType.STRING, 'inputStorageResourceId', 'UTF8', None, ),  # 8
+    (9, TType.STRING, 'outputStorageResourceId', 'UTF8', None, ),  # 9
+    (10, TType.STRING, 'experimentDataDir', 'UTF8', None, ),  # 10
+    (11, TType.BOOL, 'useUserCRPref', None, None, ),  # 11
+    (12, TType.STRING, 'groupResourceProfileId', 'UTF8', None, ),  # 12
+    (13, TType.LIST, 'autoScheduledCompResourceSchedulingList', (TType.STRUCT, [airavata.model.scheduling.ttypes.ComputationalResourceSchedulingModel, None], False), None, ),  # 13
)
all_structs.append(ExperimentModel)
ExperimentModel.thrift_spec = (
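
Since fields 9 through 12 shifted by one in the spec above, payloads written
with the old stubs are not field-compatible with the new ones. A hedged
round-trip sketch (not part of the commit) using the standard Thrift Python
runtime to exercise the regenerated struct:

    from thrift.protocol import TBinaryProtocol
    from thrift.transport import TTransport
    from airavata.model.experiment.ttypes import UserConfigurationDataModel

    buf = TTransport.TMemoryBuffer()
    UserConfigurationDataModel(inputStorageResourceId="sr-in",
                               outputStorageResourceId="sr-out").write(
        TBinaryProtocol.TBinaryProtocol(buf))

    rbuf = TTransport.TMemoryBuffer(buf.getvalue())
    decoded = UserConfigurationDataModel()
    decoded.read(TBinaryProtocol.TBinaryProtocol(rbuf))
    assert decoded.inputStorageResourceId == "sr-in"     # field 8
    assert decoded.outputStorageResourceId == "sr-out"   # new field 9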
diff --git a/dev-tools/airavata-python-sdk/airavata/model/process/ttypes.py b/dev-tools/airavata-python-sdk/airavata/model/process/ttypes.py
index 1f6a0f8272..4b82d5abdd 100644
--- a/dev-tools/airavata-python-sdk/airavata/model/process/ttypes.py
+++ b/dev-tools/airavata-python-sdk/airavata/model/process/ttypes.py
@@ -150,7 +150,8 @@ class ProcessModel(object):
- gatewayExecutionId
- enableEmailNotification
- emailAddresses
- - storageResourceId
+ - inputStorageResourceId
+ - outputStorageResourceId
- userDn
- generateCert
- experimentDataDir
@@ -163,7 +164,7 @@ class ProcessModel(object):
thrift_spec: typing.Any = None
-    def __init__(self, processId: str = "DO_NOT_SET_AT_CLIENTS", experimentId: str = None, creationTime: typing.Optional[int] = None, lastUpdateTime: typing.Optional[int] = None, processStatuses: typing.Optional[list[airavata.model.status.ttypes.ProcessStatus]] = None, processDetail: typing.Optional[str] = None, applicationInterfaceId: typing.Optional[str] = None, applicationDeploymentId: typing.Optional[str] = None, computeResourceId: typing.Optional[str] = None, processInputs: typing.O [...]
+    def __init__(self, processId: str = "DO_NOT_SET_AT_CLIENTS", experimentId: str = None, creationTime: typing.Optional[int] = None, lastUpdateTime: typing.Optional[int] = None, processStatuses: typing.Optional[list[airavata.model.status.ttypes.ProcessStatus]] = None, processDetail: typing.Optional[str] = None, applicationInterfaceId: typing.Optional[str] = None, applicationDeploymentId: typing.Optional[str] = None, computeResourceId: typing.Optional[str] = None, processInputs: typing.O [...]
self.processId: str = processId
self.experimentId: str = experimentId
self.creationTime: typing.Optional[int] = creationTime
@@ -182,7 +183,8 @@ class ProcessModel(object):
        self.gatewayExecutionId: typing.Optional[str] = gatewayExecutionId
        self.enableEmailNotification: typing.Optional[bool] = enableEmailNotification
        self.emailAddresses: typing.Optional[list[str]] = emailAddresses
-        self.storageResourceId: typing.Optional[str] = storageResourceId
+        self.inputStorageResourceId: typing.Optional[str] = inputStorageResourceId
+        self.outputStorageResourceId: typing.Optional[str] = outputStorageResourceId
self.userDn: typing.Optional[str] = userDn
self.generateCert: typing.Optional[bool] = generateCert
self.experimentDataDir: typing.Optional[str] = experimentDataDir
@@ -328,40 +330,45 @@ class ProcessModel(object):
iprot.skip(ftype)
elif fid == 19:
if ftype == TType.STRING:
-                    self.storageResourceId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
+                    self.inputStorageResourceId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
elif fid == 20:
if ftype == TType.STRING:
-                    self.userDn = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
+                    self.outputStorageResourceId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
elif fid == 21:
+ if ftype == TType.STRING:
+                    self.userDn = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
+ else:
+ iprot.skip(ftype)
+ elif fid == 22:
if ftype == TType.BOOL:
self.generateCert = iprot.readBool()
else:
iprot.skip(ftype)
- elif fid == 22:
+ elif fid == 23:
if ftype == TType.STRING:
                    self.experimentDataDir = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
- elif fid == 23:
+ elif fid == 24:
if ftype == TType.STRING:
                    self.userName = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
- elif fid == 24:
+ elif fid == 25:
if ftype == TType.BOOL:
self.useUserCRPref = iprot.readBool()
else:
iprot.skip(ftype)
- elif fid == 25:
+ elif fid == 26:
if ftype == TType.STRING:
                    self.groupResourceProfileId = iprot.readString().decode('utf-8', errors='replace') if sys.version_info[0] == 2 else iprot.readString()
else:
iprot.skip(ftype)
- elif fid == 26:
+ elif fid == 27:
if ftype == TType.LIST:
self.processWorkflows = []
(_etype39, _size36) = iprot.readListBegin()
@@ -473,36 +480,40 @@ class ProcessModel(object):
                oprot.writeString(iter47.encode('utf-8') if sys.version_info[0] == 2 else iter47)
oprot.writeListEnd()
oprot.writeFieldEnd()
-        if self.storageResourceId is not None:
-            oprot.writeFieldBegin('storageResourceId', TType.STRING, 19)
-            oprot.writeString(self.storageResourceId.encode('utf-8') if sys.version_info[0] == 2 else self.storageResourceId)
+        if self.inputStorageResourceId is not None:
+            oprot.writeFieldBegin('inputStorageResourceId', TType.STRING, 19)
+            oprot.writeString(self.inputStorageResourceId.encode('utf-8') if sys.version_info[0] == 2 else self.inputStorageResourceId)
+            oprot.writeFieldEnd()
+        if self.outputStorageResourceId is not None:
+            oprot.writeFieldBegin('outputStorageResourceId', TType.STRING, 20)
+            oprot.writeString(self.outputStorageResourceId.encode('utf-8') if sys.version_info[0] == 2 else self.outputStorageResourceId)
            oprot.writeFieldEnd()
if self.userDn is not None:
- oprot.writeFieldBegin('userDn', TType.STRING, 20)
+ oprot.writeFieldBegin('userDn', TType.STRING, 21)
            oprot.writeString(self.userDn.encode('utf-8') if sys.version_info[0] == 2 else self.userDn)
oprot.writeFieldEnd()
if self.generateCert is not None:
- oprot.writeFieldBegin('generateCert', TType.BOOL, 21)
+ oprot.writeFieldBegin('generateCert', TType.BOOL, 22)
oprot.writeBool(self.generateCert)
oprot.writeFieldEnd()
if self.experimentDataDir is not None:
- oprot.writeFieldBegin('experimentDataDir', TType.STRING, 22)
+ oprot.writeFieldBegin('experimentDataDir', TType.STRING, 23)
            oprot.writeString(self.experimentDataDir.encode('utf-8') if sys.version_info[0] == 2 else self.experimentDataDir)
oprot.writeFieldEnd()
if self.userName is not None:
- oprot.writeFieldBegin('userName', TType.STRING, 23)
+ oprot.writeFieldBegin('userName', TType.STRING, 24)
            oprot.writeString(self.userName.encode('utf-8') if sys.version_info[0] == 2 else self.userName)
oprot.writeFieldEnd()
if self.useUserCRPref is not None:
- oprot.writeFieldBegin('useUserCRPref', TType.BOOL, 24)
+ oprot.writeFieldBegin('useUserCRPref', TType.BOOL, 25)
oprot.writeBool(self.useUserCRPref)
oprot.writeFieldEnd()
if self.groupResourceProfileId is not None:
- oprot.writeFieldBegin('groupResourceProfileId', TType.STRING, 25)
+ oprot.writeFieldBegin('groupResourceProfileId', TType.STRING, 26)
            oprot.writeString(self.groupResourceProfileId.encode('utf-8') if sys.version_info[0] == 2 else self.groupResourceProfileId)
oprot.writeFieldEnd()
if self.processWorkflows is not None:
- oprot.writeFieldBegin('processWorkflows', TType.LIST, 26)
+ oprot.writeFieldBegin('processWorkflows', TType.LIST, 27)
oprot.writeListBegin(TType.STRUCT, len(self.processWorkflows))
for iter48 in self.processWorkflows:
iter48.write(oprot)
@@ -557,14 +568,15 @@ ProcessModel.thrift_spec = (
(16, TType.STRING, 'gatewayExecutionId', 'UTF8', None, ), # 16
(17, TType.BOOL, 'enableEmailNotification', None, None, ), # 17
    (18, TType.LIST, 'emailAddresses', (TType.STRING, 'UTF8', False), None, ),  # 18
- (19, TType.STRING, 'storageResourceId', 'UTF8', None, ), # 19
- (20, TType.STRING, 'userDn', 'UTF8', None, ), # 20
- (21, TType.BOOL, 'generateCert', None, False, ), # 21
- (22, TType.STRING, 'experimentDataDir', 'UTF8', None, ), # 22
- (23, TType.STRING, 'userName', 'UTF8', None, ), # 23
- (24, TType.BOOL, 'useUserCRPref', None, None, ), # 24
- (25, TType.STRING, 'groupResourceProfileId', 'UTF8', None, ), # 25
-    (26, TType.LIST, 'processWorkflows', (TType.STRUCT, [ProcessWorkflow, None], False), None, ),  # 26
+ (19, TType.STRING, 'inputStorageResourceId', 'UTF8', None, ), # 19
+ (20, TType.STRING, 'outputStorageResourceId', 'UTF8', None, ), # 20
+ (21, TType.STRING, 'userDn', 'UTF8', None, ), # 21
+ (22, TType.BOOL, 'generateCert', None, False, ), # 22
+ (23, TType.STRING, 'experimentDataDir', 'UTF8', None, ), # 23
+ (24, TType.STRING, 'userName', 'UTF8', None, ), # 24
+ (25, TType.BOOL, 'useUserCRPref', None, None, ), # 25
+ (26, TType.STRING, 'groupResourceProfileId', 'UTF8', None, ), # 26
+    (27, TType.LIST, 'processWorkflows', (TType.STRUCT, [ProcessWorkflow, None], False), None, ),  # 27
)
fix_spec(all_structs)
del all_structs
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
index f5ca757de1..24961dcfc1 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/airavata.py
@@ -112,7 +112,8 @@ class AiravataOperator:
experiment_model: ExperimentModel,
computation_resource_name: str,
group: str,
- storageId: str,
+ inputStorageId: str,
+ outputStorageId: str,
node_count: int,
total_cpu_count: int,
queue_name: str,
@@ -133,7 +134,8 @@ class AiravataOperator:
userConfigData.computationalResourceScheduling = computRes
userConfigData.groupResourceProfileId = groupResourceProfileId
- userConfigData.storageId = storageId
+ userConfigData.inputStorageResourceId = inputStorageId
+ userConfigData.outputStorageResourceId = outputStorageId
userConfigData.experimentDataDir = experiment_dir_path
userConfigData.airavataAutoSchedule = auto_schedule
@@ -535,7 +537,8 @@ class AiravataOperator:
group: str = "Default",
*,
gateway_id: str | None = None,
- sr_host: str | None = None,
+ input_sr_host: str | None = None,
+ output_sr_host: str | None = None,
auto_schedule: bool = False,
) -> LaunchState:
"""
@@ -545,7 +548,8 @@ class AiravataOperator:
# preprocess args (str)
print("[AV] Preprocessing args...")
gateway_id = str(gateway_id or self.default_gateway_id())
- sr_host = str(sr_host or self.default_sr_hostname())
+ input_sr_host = str(input_sr_host or self.default_sr_hostname())
+    output_sr_host = str(output_sr_host or input_sr_host or self.default_sr_hostname())
mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id
server_url = urlparse(self.connection_svc_url()).netloc
@@ -558,7 +562,8 @@ class AiravataOperator:
assert len(gateway_id) > 0, f"Invalid gateway_id: {gateway_id}"
assert len(queue_name) > 0, f"Invalid queue_name: {queue_name}"
assert len(group) > 0, f"Invalid group name: {group}"
- assert len(sr_host) > 0, f"Invalid sr_host: {sr_host}"
+ assert len(input_sr_host) > 0, f"Invalid input_sr_host: {input_sr_host}"
+ assert len(output_sr_host) > 0, f"Invalid output_sr_host: {output_sr_host}"
assert len(project) > 0, f"Invalid project_name: {project}"
    assert len(mount_point.as_posix()) > 0, f"Invalid mount_point: {mount_point}"
@@ -585,10 +590,14 @@ class AiravataOperator:
    data_inputs.update({"agent_id": data_inputs.get("agent_id", str(uuid.uuid4()))})
data_inputs.update({"server_url": server_url})
- # setup runtime params
- print("[AV] Setting up runtime params...")
- storage = self.get_storage(sr_host)
- sr_id = storage.storageResourceId
+ # setup storage
+ print("[AV] Setting up storage...")
+ input_storage = self.get_storage(input_sr_host)
+ output_storage = self.get_storage(output_sr_host)
+    assert input_storage is not None, f"Invalid input_storage: {input_storage}"
+    assert output_storage is not None, f"Invalid output_storage: {output_storage}"
+ input_sr_id = input_storage.storageResourceId
+ output_sr_id = output_storage.storageResourceId
# setup application interface
print("[AV] Setting up application interface...")
@@ -607,7 +616,7 @@ class AiravataOperator:
# setup experiment directory
print("[AV] Setting up experiment directory...")
exp_dir = self.make_experiment_dir(
- sr_host=storage.hostName,
+ sr_host=input_storage.hostName,
project_name=project,
experiment_name=experiment_name,
)
@@ -620,7 +629,8 @@ class AiravataOperator:
experiment_model=experiment,
computation_resource_name=computation_resource_name,
group=group,
- storageId=sr_id,
+ inputStorageId=input_sr_id,
+ outputStorageId=output_sr_id,
node_count=node_count,
total_cpu_count=cpu_count,
wall_time_limit=walltime,
@@ -630,7 +640,7 @@ class AiravataOperator:
)
def register_input_file(file: Path) -> str:
-        return str(self.register_input_file(file.name, sr_host, sr_id, gateway_id, file.name, abs_path))
+        return str(self.register_input_file(file.name, input_sr_host, input_sr_id, gateway_id, file.name, abs_path))
# set up experiment inputs
print("[AV] Setting up experiment inputs...")
@@ -671,7 +681,7 @@ class AiravataOperator:
# upload file inputs for experiment
print(f"[AV] Uploading {len(files_to_upload)} file inputs for
experiment...")
- self.upload_files(None, None, storage.hostName, files_to_upload, exp_dir)
+ self.upload_files(None, None, input_storage.hostName, files_to_upload,
exp_dir)
# create experiment
print(f"[AV] Creating experiment...")
@@ -693,25 +703,61 @@ class AiravataOperator:
# wait until experiment begins, then get process id
print(f"[AV] Experiment {experiment_name} WAITING until experiment
begins...")
process_id = None
-    while process_id is None:
+    max_wait_process = 300  # 10 minutes max wait for process
+    wait_count_process = 0
+    while process_id is None and wait_count_process < max_wait_process:
+      # Check experiment status - if failed, raise error
+      try:
+        status = self.get_experiment_status(ex_id)
+        if status == ExperimentState.FAILED:
+          raise Exception(f"[AV] Experiment {experiment_name} FAILED while waiting for process to begin")
+        if status in [ExperimentState.COMPLETED, ExperimentState.CANCELED]:
+          raise Exception(f"[AV] Experiment {experiment_name} reached terminal state {status.name} while waiting for process")
+      except Exception as status_err:
+        if "FAILED" in str(status_err) or "terminal state" in str(status_err):
+          raise status_err
+
      try:
        process_id = self.get_process_id(ex_id)
      except:
+        pass
+
+      if process_id is None:
        time.sleep(2)
-      else:
-        time.sleep(2)
+        wait_count_process += 1
+
+    if process_id is None:
+      raise Exception(f"[AV] Experiment {experiment_name} timeout waiting for process to begin")
    print(f"[AV] Experiment {experiment_name} EXECUTING with pid: {process_id}")
# wait until task begins, then get job id
print(f"[AV] Experiment {experiment_name} WAITING until task begins...")
job_id = job_state = None
-    while job_id in [None, "N/A"]:
+    max_wait_task = 300  # 10 minutes max wait for task
+    wait_count_task = 0
+    while job_id in [None, "N/A"] and wait_count_task < max_wait_task:
+      # Check experiment status - if failed, raise error
+      try:
+        status = self.get_experiment_status(ex_id)
+        if status == ExperimentState.FAILED:
+          raise Exception(f"[AV] Experiment {experiment_name} FAILED while waiting for task to begin")
+        if status in [ExperimentState.COMPLETED, ExperimentState.CANCELED]:
+          raise Exception(f"[AV] Experiment {experiment_name} reached terminal state {status.name} while waiting for task")
+      except Exception as status_err:
+        if "FAILED" in str(status_err) or "terminal state" in str(status_err):
+          raise status_err
+
      try:
        job_id, job_state = self.get_task_status(ex_id)
      except:
+        pass
+
+      if job_id in [None, "N/A"]:
        time.sleep(2)
-      else:
-        time.sleep(2)
+        wait_count_task += 1
+
+    if job_id in [None, "N/A"]:
+      raise Exception(f"[AV] Experiment {experiment_name} timeout waiting for task to begin")
    assert job_state is not None, f"Job state is None for job id: {job_id}"
    print(f"[AV] Experiment {experiment_name} - Task {job_state.name} with id: {job_id}")
@@ -721,7 +767,7 @@ class AiravataOperator:
process_id=process_id,
mount_point=mount_point,
experiment_dir=exp_dir,
- sr_host=storage.hostName,
+ sr_host=input_storage.hostName,
)
def get_experiment_status(self, experiment_id: str) -> ExperimentState:
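
A hedged usage sketch (not part of the commit) of the updated launch path; the
constructor arguments and the method name are assumptions, while input_sr_host
and output_sr_host follow the hunks above:

    from airavata_experiments.airavata import AiravataOperator

    op = AiravataOperator(access_token="<token>")   # constructor args assumed
    state = op.launch_experiment(                   # method name assumed
        experiment_name="demo-exp",
        project="Default Project",
        computation_resource_name="cluster.example.org",
        queue_name="shared",
        input_sr_host="storage-a.example.org",
        output_sr_host="storage-b.example.org",     # defaults to input_sr_host when omitted
    )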
diff --git a/dev-tools/airavata-python-sdk/airavata_experiments/md/applications.py b/dev-tools/airavata-python-sdk/airavata_experiments/md/applications.py
index 19e67d6ea0..76bc11b179 100644
--- a/dev-tools/airavata-python-sdk/airavata_experiments/md/applications.py
+++ b/dev-tools/airavata-python-sdk/airavata_experiments/md/applications.py
@@ -154,6 +154,59 @@ class AlphaFold2(ExperimentApp):
obj.tasks = []
return obj
+class VizFold_MSA(ExperimentApp):
+ """
+ VizFold lets you compute the 3D structure of a protein (using OpenFold),
+ and visualize its residue-to-residue attention scores using arc diagrams.
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ super().__init__(app_id="VizFold-MSA")
+
+ @classmethod
+ def initialize(
+ cls,
+ name: str,
+ protein: str,
+ ) -> Experiment[ExperimentApp]:
+ app = cls()
+ obj = Experiment[ExperimentApp](name, app).with_inputs(
+ protein=protein,
+ )
+ obj.input_mapping = {
+ "Protein": ("protein", "str"),
+ }
+ obj.tasks = []
+ return obj
+
+class VizFold_Fold(ExperimentApp):
+ """
+ VizFold lets you compute the 3D structure of a protein (using OpenFold),
+ and visualize its residue-to-residue attention scores using arc diagrams.
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ super().__init__(app_id="VizFold-Fold")
+
+ @classmethod
+ def initialize(
+ cls,
+ name: str,
+ protein: str,
+ ) -> Experiment[ExperimentApp]:
+ app = cls()
+ obj = Experiment[ExperimentApp](name, app).with_inputs(
+ protein=protein,
+ )
+ obj.input_mapping = {
+ "Protein": ("protein", "str"),
+ }
+ obj.tasks = []
+ return obj
class AMBER(ExperimentApp):
"""
diff --git a/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/data_model_creation_util.py b/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/data_model_creation_util.py
index 65dfa7df07..a68b25127c 100644
--- a/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/data_model_creation_util.py
+++ b/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/data_model_creation_util.py
@@ -76,7 +76,8 @@ class DataModelCreationUtil(object):
experiment_model: ExperimentModel,
computation_resource_name: str,
group_resource_profile_name: str,
- storageId: str,
+ inputStorageId: str,
+ outputStorageId: str,
node_count: int,
total_cpu_count: int,
queue_name: str,
@@ -97,7 +98,8 @@ class DataModelCreationUtil(object):
userConfigData.computationalResourceScheduling = computRes
userConfigData.groupResourceProfileId = groupResourceProfileId
- userConfigData.storageId = storageId
+ userConfigData.inputStorageResourceId = inputStorageId
+ userConfigData.outputStorageResourceId = outputStorageId
userConfigData.experimentDataDir = experiment_dir_path
userConfigData.airavataAutoSchedule = auto_schedule
diff --git a/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py b/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py
index 442d3e90ad..9557da3db8 100644
--- a/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py
+++ b/dev-tools/airavata-python-sdk/airavata_sdk/clients/utils/experiment_handler_util.py
@@ -83,6 +83,7 @@ class ExperimentHandlerUtil(object):
group_name: str = "Default",
application_name: str = "Default Application",
project_name: str = "Default Project",
+ output_storage_host: str | None = None,
):
execution_id = self.airavata_util.get_execution_id(application_name)
assert execution_id is not None
@@ -91,14 +92,20 @@ class ExperimentHandlerUtil(object):
        resource_host_id = self.airavata_util.get_resource_host_id(computation_resource_name)
        group_resource_profile_id = self.airavata_util.get_group_resource_profile_id(group_name)
- storage_host = self.settings.STORAGE_RESOURCE_HOST
- assert storage_host is not None
+ input_storage_host = self.settings.STORAGE_RESOURCE_HOST
+ assert input_storage_host is not None
sftp_port = self.settings.SFTP_PORT
assert sftp_port is not None
- storage_id = self.airavata_util.get_storage_resource_id(storage_host)
- assert storage_id is not None
+        input_storage_id = self.airavata_util.get_storage_resource_id(input_storage_host)
+ assert input_storage_id is not None
+
+ if output_storage_host is not None:
+            output_storage_id = self.airavata_util.get_storage_resource_id(output_storage_host)
+ else:
+ output_storage_id = input_storage_id
+ assert output_storage_id is not None
assert project_name is not None
assert application_name is not None
@@ -112,8 +119,8 @@ class ExperimentHandlerUtil(object):
description=description,
)
- logger.info("connnecting to file upload endpoint %s : %s",
storage_host, sftp_port)
- sftp_connector = SFTPConnector(host=storage_host,
+ logger.info("connnecting to file upload endpoint %s : %s",
input_storage_host, sftp_port)
+ sftp_connector = SFTPConnector(host=input_storage_host,
port=sftp_port,
username=self.user_id,
password=self.access_token)
@@ -136,7 +143,8 @@ class ExperimentHandlerUtil(object):
            experiment = self.data_model_client.configure_computation_resource_scheduling(experiment_model=experiment,
                                                           computation_resource_name=computation_resource_name,
                                                           group_resource_profile_name=group_name,
-                                                          storageId=storage_id,
+                                                          inputStorageId=input_storage_id,
+                                                          outputStorageId=output_storage_id,
                                                           node_count=int(node_count),
                                                           total_cpu_count=int(cpu_count),
                                                           wall_time_limit=int(walltime),
@@ -151,8 +159,8 @@ class ExperimentHandlerUtil(object):
data_uris = []
for x in input_file_mapping[key]:
                    data_uri = self.data_model_client.register_input_file(file_identifier=x,
-                                                          storage_name=storage_host,
-                                                          storageId=storage_id,
+                                                          storage_name=input_storage_host,
+                                                          storageId=input_storage_id,
                                                           input_file_name=x,
                                                           uploaded_storage_path=path)
data_uris.append(data_uri)
@@ -160,8 +168,8 @@ class ExperimentHandlerUtil(object):
else:
x = input_file_mapping[key]
                data_uri = self.data_model_client.register_input_file(file_identifier=x,
-                                                          storage_name=storage_host,
-                                                          storageId=storage_id,
+                                                          storage_name=input_storage_host,
+                                                          storageId=input_storage_id,
                                                           input_file_name=x,
                                                           uploaded_storage_path=path)
new_file_mapping[key] = data_uri
@@ -177,8 +185,8 @@ class ExperimentHandlerUtil(object):
data_uris = []
for x in input_files:
                data_uri = self.data_model_client.register_input_file(file_identifier=x,
-                                                          storage_name=storage_host,
-                                                          storageId=storage_id,
+                                                          storage_name=input_storage_host,
+                                                          storageId=input_storage_id,
                                                           input_file_name=x,
                                                           uploaded_storage_path=path)
data_uris.append(data_uri)
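
A hedged usage sketch (not part of the commit) of the new output_storage_host
parameter; the constructor arguments and method name are assumptions based on
the hunks above:

    from airavata_sdk.clients.utils.experiment_handler_util import ExperimentHandlerUtil

    handler = ExperimentHandlerUtil()      # constructor args assumed
    handler.launch_experiment(             # method name assumed
        experiment_name="demo-exp",
        application_name="Echo",
        computation_resource_name="cluster.example.org",
        queue_name="shared",
        output_storage_host="storage-b.example.org",  # omit to reuse the input storage
    )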
diff --git a/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_echo_experiment.py b/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_echo_experiment.py
index 53801a74c9..a85a1d8c85 100644
--- a/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_echo_experiment.py
+++ b/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_echo_experiment.py
@@ -71,7 +71,8 @@ path = Settings().GATEWAY_DATA_STORE_DIR + path_suffix
experiment = data_model_client.configure_computation_resource_scheduling(experiment_model=experiment,
                                                          computation_resource_name="karst.uits.iu.edu",
                                                          group_resource_profile_name="Default Gateway Profile",
-                                                         storageId="pgadev.scigap.org",
+                                                         inputStorageId="pgadev.scigap.org",
+                                                         outputStorageId="pgadev.scigap.org",
node_count=1,
total_cpu_count=16,
wall_time_limit=15,
diff --git a/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_gaussian_experiment.py b/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_gaussian_experiment.py
index 92b7c3b2e8..4640aed630 100644
--- a/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_gaussian_experiment.py
+++ b/dev-tools/airavata-python-sdk/airavata_sdk/samples/create_launch_gaussian_experiment.py
@@ -82,7 +82,8 @@ path = fb.upload_files(api_server_client,
credential_store_client, token, gatewa
experiment =
data_model_client.configure_computation_resource_scheduling(experiment_model=experiment,
computation_resource_name="karst.uits.iu.edu",
group_resource_profile_name="Default Gateway Profile",
-
storageId="pgadev.scigap.org",
+
inputStorageId="pgadev.scigap.org",
+
outputStorageId="pgadev.scigap.org",
node_count=1,
total_cpu_count=16,
wall_time_limit=15,
diff --git a/dev-tools/batch_launch_experiments.py
b/dev-tools/batch_launch_experiments.py
new file mode 100644
index 0000000000..8edaa47a7d
--- /dev/null
+++ b/dev-tools/batch_launch_experiments.py
@@ -0,0 +1,479 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import time
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from typing import List, Dict, Any, Union
+import pydantic
+from rich.progress import Progress, BarColumn, TextColumn, SpinnerColumn,
TimeElapsedColumn
+from rich.console import Console
+from rich.live import Live
+from rich.panel import Panel
+from rich.layout import Layout
+from collections import deque
+
+os.environ['AUTH_SERVER_URL'] = "https://auth.dev.cybershuttle.org"
+os.environ['API_SERVER_HOSTNAME'] = "api.dev.cybershuttle.org"
+os.environ['GATEWAY_URL'] = "https://gateway.dev.cybershuttle.org"
+os.environ['STORAGE_RESOURCE_HOST'] = "gateway.dev.cybershuttle.org"
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from create_launch_experiment_with_storage import create_and_launch_experiment
+from airavata_experiments.airavata import AiravataOperator
+from airavata.model.status.ttypes import ExperimentState
+
+
+class ExperimentLaunchResult(pydantic.BaseModel):
+ """Result from creating and launching an experiment."""
+ experiment_id: str
+ process_id: str
+ experiment_dir: str
+ storage_host: str
+ mount_point: str
+
+
+class JobConfig(pydantic.BaseModel):
+ """Configuration for a batch job submission."""
+ experiment_name: str
+ project_name: str
+ application_name: str
+ computation_resource_name: str
+ queue_name: str
+ node_count: int
+ cpu_count: int
+ walltime: int
+ group_name: str = "Default"
+ input_storage_host: str | None = None
+ output_storage_host: str | None = None
+ input_files: Dict[str, Union[str, List[str]]] | None = None
+ data_inputs: Dict[str, Union[str, int, float]] | None = None
+ gateway_id: str | None = None
+ auto_schedule: bool = False
+
+
+class JobResult(pydantic.BaseModel):
+ """Result from submitting and monitoring a single job."""
+ job_index: int
+ experiment_id: str | None
+ status: str
+ result: ExperimentLaunchResult | None = None
+ success: bool
+ error: str | None = None
+
+
+def get_experiment_state_value(status) -> tuple[int | None, str, ExperimentState]:
+    """Extract state value, name, and enum from status; value is None when it cannot be determined."""
+ if isinstance(status, ExperimentState):
+ return status.value, status.name, status
+
+ # Handle ExperimentStatus object
+ if hasattr(status, 'state'):
+ state = status.state
+ if isinstance(state, ExperimentState):
+ return state.value, state.name, state
+ elif hasattr(state, 'value'):
+ return state.value, state.name if hasattr(state, 'name') else
str(state), state
+
+ # Handle direct value/name access
+ status_value = status.value if hasattr(status, 'value') else (status if
isinstance(status, int) else None)
+ status_name = status.name if hasattr(status, 'name') else str(status)
+
+ # Convert to ExperimentState enum
+ if status_value is not None:
+ try:
+ enum_state = ExperimentState(status_value)
+ return status_value, status_name, enum_state
+ except (ValueError, TypeError):
+ pass
+
+ # Fallback
+ return None, status_name, ExperimentState.FAILED
+
+
+def monitor_experiment_silent(operator: AiravataOperator, experiment_id: str,
check_interval: int = 30) -> ExperimentState:
+ """Monitor experiment silently until completion. Returns final status."""
+ logger = logging.getLogger(__name__)
+    max_checks = 3600  # Maximum number of checks (~30 hours at the default 30s interval)
+ check_count = 0
+
+ # Use shorter interval initially, then increase
+ initial_interval = min(check_interval, 5) # Check every 5 seconds
initially
+
+ while check_count < max_checks:
+ try:
+ status = operator.get_experiment_status(experiment_id)
+
+ # Extract state information
+ status_value, status_name, status_enum =
get_experiment_state_value(status)
+
+ # Log status periodically for debugging
+            if check_count % 12 == 0:  # Log every 12th status check for debugging
+ logger.debug(f"Experiment {experiment_id} status check
{check_count}: value={status_value}, name={status_name}")
+
+ # Check terminal states: COMPLETED (7), CANCELED (6), FAILED (8)
+ if status_value is not None:
+ is_terminal = status_value in [
+ ExperimentState.COMPLETED.value, # 7
+ ExperimentState.CANCELED.value, # 6
+ ExperimentState.FAILED.value # 8
+ ]
+ else:
+ is_terminal = status_name in ['COMPLETED', 'CANCELED',
'FAILED']
+
+ if is_terminal:
+ logger.info(f"Experiment {experiment_id} reached terminal
state: {status_name} (value: {status_value})")
+ return status_enum
+
+ except Exception as e:
+ # If we can't get status, log but continue monitoring
+ logger.warning(f"Error checking experiment {experiment_id} status
(check {check_count}): {e}")
+ import traceback
+ logger.debug(traceback.format_exc())
+            if check_count > 10:  # Give up on any error past the first 10 checks
+ logger.error(f"Multiple status check failures for
{experiment_id}, assuming FAILED")
+ return ExperimentState.FAILED
+
+ # Sleep before next check
+ sleep_time = initial_interval if check_count < 6 else check_interval
+ time.sleep(sleep_time)
+ check_count += 1
+
+ # If we've exceeded max checks, assume failed
+ logger.error(f"Experiment {experiment_id} monitoring timeout after
{check_count} checks, assuming FAILED")
+ return ExperimentState.FAILED
+
+
+def submit_and_monitor_job(
+ job_index: int,
+ job_config: JobConfig | Dict[str, Any],
+ access_token: str,
+) -> JobResult:
+ """Submit and monitor a single job. Returns job result with status."""
+ # Convert dict to JobConfig if needed
+ if isinstance(job_config, dict):
+ job_config = JobConfig(**job_config)
+
+ try:
+ # Make experiment name unique for each job to avoid directory conflicts
+ # Using job_index ensures uniqueness and makes it easy to track
+ unique_experiment_name = f"{job_config.experiment_name}-job{job_index}"
+
+ # Handle input_files and data_inputs same way as working version
+ input_files = job_config.input_files if job_config.input_files else
None
+ data_inputs = job_config.data_inputs if job_config.data_inputs else
None
+
+ result_dict = create_and_launch_experiment(
+ access_token=access_token,
+ experiment_name=unique_experiment_name,
+ project_name=job_config.project_name,
+ application_name=job_config.application_name,
+ computation_resource_name=job_config.computation_resource_name,
+ queue_name=job_config.queue_name,
+ node_count=job_config.node_count,
+ cpu_count=job_config.cpu_count,
+ walltime=job_config.walltime,
+ group_name=job_config.group_name,
+ input_storage_host=job_config.input_storage_host,
+ output_storage_host=job_config.output_storage_host,
+ input_files=input_files,
+ data_inputs=data_inputs,
+ gateway_id=job_config.gateway_id,
+ auto_schedule=job_config.auto_schedule,
+ monitor=False,
+ )
+
+ operator = AiravataOperator(access_token=access_token)
+ experiment_id = result_dict['experiment_id']
+
+ # Check status immediately after submission to catch early failures
+ try:
+ initial_status = operator.get_experiment_status(experiment_id)
+ status_value, status_name, status_enum =
get_experiment_state_value(initial_status)
+
+ # Check if already in terminal state
+ if status_value is not None and status_value in [
+ ExperimentState.COMPLETED.value,
+ ExperimentState.CANCELED.value,
+ ExperimentState.FAILED.value
+ ]:
+ # Already in terminal state
+ final_status = status_enum
+ else:
+ # Monitor until completion
+ final_status = monitor_experiment_silent(operator,
experiment_id)
+ except Exception as e:
+ # If we can't check status, log and assume failed
+ logger = logging.getLogger(__name__)
+ logger.error(f"Error monitoring experiment {experiment_id}: {e}")
+ import traceback
+ logger.debug(traceback.format_exc())
+ final_status = ExperimentState.FAILED
+
+ result = ExperimentLaunchResult(**result_dict)
+
+ return JobResult(
+ job_index=job_index,
+ experiment_id=result.experiment_id,
+ status=final_status.name,
+ result=result,
+ success=final_status == ExperimentState.COMPLETED,
+ )
+ except Exception as e:
+ # Log the error for debugging
+ import traceback
+ error_msg = f"{str(e)}\n{traceback.format_exc()}"
+ logger = logging.getLogger(__name__)
+ logger.error(f"Job {job_index} failed: {error_msg}")
+
+ return JobResult(
+ job_index=job_index,
+ experiment_id=None,
+ status='ERROR',
+ result=None,
+ success=False,
+ error=str(e),
+ )
+
+
+def batch_submit_jobs(
+ job_config: JobConfig | Dict[str, Any],
+ num_copies: int = 10,
+ max_concurrent: int = 5,
+ access_token: str | None = None,
+) -> List[JobResult]:
+ """Submit multiple job copies in batches with progress bar."""
+ if access_token is None:
+ from airavata_auth.device_auth import AuthContext
+ access_token = AuthContext.get_access_token()
+
+ console = Console()
+ results = []
+ log_buffer = deque(maxlen=50) # Keep last 50 log lines for display
+
+ # Custom handler to capture logs to buffer
+ class ListHandler(logging.Handler):
+ def __init__(self, buffer):
+ super().__init__()
+ self.buffer = buffer
+
+ def emit(self, record):
+ msg = self.format(record)
+ self.buffer.append(msg)
+
+ log_handler = ListHandler(log_buffer)
+ log_handler.setLevel(logging.INFO)
+ log_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s -
%(levelname)s - %(message)s'))
+
+ # Add to root logger and module logger
+ logging.root.addHandler(log_handler)
+ logger = logging.getLogger('create_launch_experiment_with_storage')
+ logger.addHandler(log_handler)
+
+ # Configure progress bar
+ progress = Progress(
+ SpinnerColumn(),
+ TextColumn("[progress.description]{task.description}"),
+ BarColumn(),
+ TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+ TextColumn("•"),
+ TextColumn("{task.completed}/{task.total}"),
+ TimeElapsedColumn(),
+ console=console,
+ )
+
+ task = progress.add_task(
+ f"{num_copies} total, 0 running, 0 completed, 0 failed",
+ total=num_copies
+ )
+
+ # Create layout with logs above and progress below
+ layout = Layout()
+ layout.split_column(
+ Layout(name="logs", size=None),
+ Layout(progress, name="progress", size=3)
+ )
+
+ def make_display():
+ # Get logs from buffer - always show the latest logs (they're added to
end of deque)
+ log_lines = list(log_buffer) if log_buffer else ["No logs yet..."]
+ # Show last 20 lines to keep display manageable and scrolled to bottom
+ display_lines = log_lines[-20:] if len(log_lines) > 20 else log_lines
+ log_text = '\n'.join(display_lines)
+ log_panel = Panel(
+ log_text,
+ title="Logs (latest)",
+ border_style="blue",
+ height=None,
+ expand=False
+ )
+ layout["logs"].update(log_panel)
+ return layout
+
+ try:
+ # Use Live to keep layout fixed, progress at bottom
+ with Live(make_display(), console=console, refresh_per_second=4,
screen=True) as live:
+ with ThreadPoolExecutor(max_workers=max_concurrent) as executor:
+ active_futures = {}
+ next_job_index = 0
+
+ # Submit initial batch
+ while next_job_index < min(max_concurrent, num_copies):
+ future = executor.submit(submit_and_monitor_job,
next_job_index, job_config, access_token)
+ active_futures[future] = next_job_index
+ next_job_index += 1
+
+ # Process completed jobs and submit new ones
+ # Continue until all jobs are submitted AND all active futures
are done
+ while active_futures or next_job_index < num_copies:
+ completed_futures = [f for f in active_futures if f.done()]
+
+ for future in completed_futures:
+ job_idx = active_futures.pop(future)
+
+ try:
+ result = future.result()
+ results.append(result)
+ except Exception as e:
+ # Handle unexpected exceptions
+ results.append(JobResult(
+ job_index=job_idx,
+ experiment_id=None,
+ status='ERROR',
+ result=None,
+ success=False,
+ error=str(e),
+ ))
+
+ # Submit next jobs if available and we have capacity
+ while next_job_index < num_copies and len(active_futures)
< max_concurrent:
+ try:
+ new_future =
executor.submit(submit_and_monitor_job, next_job_index, job_config,
access_token)
+ active_futures[new_future] = next_job_index
+ next_job_index += 1
+ except Exception as e:
+ # If submission itself fails, mark as error and
continue
+ results.append(JobResult(
+ job_index=next_job_index,
+ experiment_id=None,
+ status='ERROR',
+ result=None,
+ success=False,
+ error=f"Submission failed: {str(e)}",
+ ))
+ next_job_index += 1
+
+ # Update progress bar with counts
+ completed_count = len(results)
+ running_count = len(active_futures)
+ submitted_count = next_job_index
+ successful_count = sum(1 for r in results if r.success)
+ failed_count = completed_count - successful_count
+
+ # Show submitted count if not all jobs submitted yet
+ if submitted_count < num_copies:
+ status_desc = f"{num_copies} total, {submitted_count}
submitted, {running_count} running, {completed_count} completed, {failed_count}
failed"
+ else:
+ status_desc = f"{num_copies} total, {running_count}
running, {completed_count} completed, {failed_count} failed"
+
+ progress.update(
+ task,
+ completed=completed_count,
+ description=status_desc
+ )
+ live.update(make_display())
+
+                    if not completed_futures:
+                        # Nothing finished this pass; sleep briefly to avoid busy-waiting
+                        time.sleep(1)
+
+ # Sort results by job_index
+ results.sort(key=lambda x: x.job_index)
+ return results
+ finally:
+ # Clean up log handlers
+ logging.root.removeHandler(log_handler)
+ if log_handler in logger.handlers:
+ logger.removeHandler(log_handler)
+
+
+def main():
+ """Main function that sets up job configuration and runs batch
submission."""
+ from airavata_auth.device_auth import AuthContext
+
+ access_token = AuthContext.get_access_token()
+
+ # Job configuration - matching create_launch_experiment_with_storage.py
exactly
+ EXPERIMENT_NAME = "Test"
+ PROJECT_NAME = "Default Project"
+ APPLICATION_NAME = "NAMD-test"
+ GATEWAY_ID = None
+ COMPUTATION_RESOURCE_NAME = "NeuroData25VC2"
+ QUEUE_NAME = "cloud"
+ NODE_COUNT = 1
+ CPU_COUNT = 1
+ WALLTIME = 5
+ GROUP_NAME = "Default"
+ INPUT_STORAGE_HOST = "gateway.dev.cybershuttle.org"
+ OUTPUT_STORAGE_HOST = "149.165.169.12"
+ INPUT_FILES = {}
+ DATA_INPUTS = {}
+ AUTO_SCHEDULE = False
+
+ job_config = JobConfig(
+ experiment_name=EXPERIMENT_NAME,
+ project_name=PROJECT_NAME,
+ application_name=APPLICATION_NAME,
+ computation_resource_name=COMPUTATION_RESOURCE_NAME,
+ queue_name=QUEUE_NAME,
+ node_count=NODE_COUNT,
+ cpu_count=CPU_COUNT,
+ walltime=WALLTIME,
+ group_name=GROUP_NAME,
+ input_storage_host=INPUT_STORAGE_HOST,
+ output_storage_host=OUTPUT_STORAGE_HOST,
+ input_files=INPUT_FILES if INPUT_FILES else None,
+ data_inputs=DATA_INPUTS if DATA_INPUTS else None,
+ gateway_id=GATEWAY_ID,
+ auto_schedule=AUTO_SCHEDULE,
+ )
+
+ num_copies = 10
+
+ try:
+ results = batch_submit_jobs(
+ job_config=job_config,
+ num_copies=num_copies,
+ max_concurrent=5,
+ access_token=access_token,
+ )
+
+ # Print summary
+ print("\n" + "="*60)
+ print(f"Batch submission complete: {num_copies} jobs")
+ print("="*60)
+ successful = sum(1 for r in results if r.success)
+ print(f"Successful: {successful}/{num_copies}")
+ print(f"Failed: {num_copies - successful}/{num_copies}")
+ print("\nJob Results:")
+ for result in results:
+ status_symbol = "✓" if result.success else "✗"
+ exp_id = result.experiment_id or 'N/A'
+ print(f" {status_symbol} Job {result.job_index}: {result.status} "
+ f"(ID: {exp_id})")
+ print("="*60)
+
+ return results
+
+ except Exception as e:
+ print(f"Failed to run batch submission: {repr(e)}", file=sys.stderr)
+ import traceback
+ traceback.print_exc()
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
+
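
The batch driver above keeps at most max_concurrent futures in flight and
refills the pool as earlier jobs reach a terminal state, so when job runtimes
are similar, num_copies experiments complete in roughly
ceil(num_copies / max_concurrent) waves. It can also be driven
programmatically; a sketch using only names defined in this file (values
mirror main(); the token is resolved via AuthContext when omitted):

    from batch_launch_experiments import JobConfig, batch_submit_jobs

    config = JobConfig(
        experiment_name="Test",
        project_name="Default Project",
        application_name="NAMD-test",
        computation_resource_name="NeuroData25VC2",
        queue_name="cloud",
        node_count=1,
        cpu_count=1,
        walltime=5,
        input_storage_host="gateway.dev.cybershuttle.org",
        output_storage_host="149.165.169.12",
    )
    results = batch_submit_jobs(job_config=config, num_copies=20, max_concurrent=10)
    failed = [r.job_index for r in results if not r.success]
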
diff --git a/dev-tools/create_launch_experiment_with_storage.py
b/dev-tools/create_launch_experiment_with_storage.py
new file mode 100755
index 0000000000..7d333e1234
--- /dev/null
+++ b/dev-tools/create_launch_experiment_with_storage.py
@@ -0,0 +1,267 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+import time
+import logging
+from pathlib import Path
+
+os.environ['AUTH_SERVER_URL'] = "https://auth.dev.cybershuttle.org"
+os.environ['API_SERVER_HOSTNAME'] = "api.dev.cybershuttle.org"
+os.environ['GATEWAY_URL'] = "https://gateway.dev.cybershuttle.org"
+os.environ['STORAGE_RESOURCE_HOST'] = "gateway.dev.cybershuttle.org"
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+from airavata_experiments.airavata import AiravataOperator
+from airavata.model.status.ttypes import ExperimentState
+from airavata_auth.device_auth import AuthContext
+
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+def list_storage_resources(access_token: str, gateway_id: str | None = None):
+ operator = AiravataOperator(access_token=access_token)
+ sr_names =
operator.api_server_client.get_all_storage_resource_names(operator.airavata_token)
+ logger.info("Available storage resources:")
+ for sr_id, hostname in sr_names.items():
+ logger.info(f" ID: {sr_id}, Hostname: {hostname}")
+ return sr_names
+
+
+def get_storage_hostname_by_id(access_token: str, storage_resource_id: str) ->
str | None:
+ operator = AiravataOperator(access_token=access_token)
+ sr_names =
operator.api_server_client.get_all_storage_resource_names(operator.airavata_token)
+ hostname = sr_names.get(storage_resource_id)
+ if hostname:
+ logger.info(f"Storage ID {storage_resource_id} maps to hostname:
{hostname}")
+ else:
+ logger.warning(f"Storage ID {storage_resource_id} not found in
available resources")
+ return hostname
+
+
+def create_and_launch_experiment(
+ access_token: str,
+ experiment_name: str,
+ project_name: str,
+ application_name: str,
+ computation_resource_name: str,
+ queue_name: str,
+ node_count: int,
+ cpu_count: int,
+ walltime: int,
+ group_name: str = "Default",
+ input_storage_host: str | None = None,
+ output_storage_host: str | None = None,
+ input_files: dict[str, str | list[str]] | None = None,
+ data_inputs: dict[str, str | int | float] | None = None,
+ gateway_id: str | None = None,
+ auto_schedule: bool = False,
+ monitor: bool = True,
+) -> dict:
+ operator = AiravataOperator(access_token=access_token)
+
+ experiment_inputs = {}
+
+ if input_files:
+ for input_name, file_paths in input_files.items():
+ if isinstance(file_paths, list):
+ experiment_inputs[input_name] = {
+ "type": "uri[]",
+ "value": [str(Path(fp).resolve()) for fp in file_paths]
+ }
+ logger.info(f"Added file array input '{input_name}':
{file_paths}")
+ else:
+ experiment_inputs[input_name] = {
+ "type": "uri",
+ "value": str(Path(file_paths).resolve())
+ }
+ logger.info(f"Added file input '{input_name}': {file_paths}")
+
+ if data_inputs:
+ for input_name, value in data_inputs.items():
+ if isinstance(value, int):
+ experiment_inputs[input_name] = {"type": "int", "value": value}
+ elif isinstance(value, float):
+ experiment_inputs[input_name] = {"type": "float", "value":
value}
+ else:
+ experiment_inputs[input_name] = {"type": "string", "value":
str(value)}
+ logger.info(f"Added data input '{input_name}': {value}")
+
+ if not experiment_inputs:
+ logger.info("No inputs provided. Adding dummy input for applications
that don't require inputs...")
+ experiment_inputs = {"__no_inputs__": {"type": "string", "value": ""}}
+
+ logger.info(f"Launching experiment '{experiment_name}'...")
+ logger.info(f" Project: {project_name}")
+ logger.info(f" Application: {application_name}")
+ logger.info(f" Compute Resource: {computation_resource_name}")
+ logger.info(f" Input Storage: {input_storage_host or 'default'}")
+ logger.info(f" Output Storage: {output_storage_host or input_storage_host
or 'default'}")
+
+ launch_state = operator.launch_experiment(
+ experiment_name=experiment_name,
+ project=project_name,
+ app_name=application_name,
+ inputs=experiment_inputs,
+ computation_resource_name=computation_resource_name,
+ queue_name=queue_name,
+ node_count=node_count,
+ cpu_count=cpu_count,
+ walltime=walltime,
+ group=group_name,
+ gateway_id=gateway_id,
+ input_sr_host=input_storage_host,
+ output_sr_host=output_storage_host,
+ auto_schedule=auto_schedule,
+ )
+
+ logger.info(f"Experiment launched successfully!")
+ logger.info(f" Experiment ID: {launch_state.experiment_id}")
+ logger.info(f" Process ID: {launch_state.process_id}")
+ logger.info(f" Experiment Directory: {launch_state.experiment_dir}")
+ logger.info(f" Storage Host: {launch_state.sr_host}")
+
+ result = {
+ "experiment_id": launch_state.experiment_id,
+ "process_id": launch_state.process_id,
+ "experiment_dir": launch_state.experiment_dir,
+ "storage_host": launch_state.sr_host,
+ "mount_point": str(launch_state.mount_point),
+ }
+
+ if monitor:
+ logger.info("Monitoring experiment status...")
+ monitor_experiment(operator, launch_state.experiment_id)
+
+ return result
+
+
+def get_experiment_state_value(status) -> tuple[int | None, str, ExperimentState]:
+    """Extract state value, name, and enum from status; value is None when it cannot be determined."""
+ if isinstance(status, ExperimentState):
+ return status.value, status.name, status
+
+ # Handle ExperimentStatus object
+ if hasattr(status, 'state'):
+ state = status.state
+ if isinstance(state, ExperimentState):
+ return state.value, state.name, state
+ elif hasattr(state, 'value'):
+ return state.value, state.name if hasattr(state, 'name') else
str(state), state
+
+ # Handle direct value/name access
+ status_value = status.value if hasattr(status, 'value') else (status if
isinstance(status, int) else None)
+ status_name = status.name if hasattr(status, 'name') else str(status)
+
+ # Convert to ExperimentState enum
+ if status_value is not None:
+ try:
+ enum_state = ExperimentState(status_value)
+ return status_value, status_name, enum_state
+ except (ValueError, TypeError):
+ pass
+
+ # Fallback
+ return None, status_name, ExperimentState.FAILED
+
+
+def monitor_experiment(operator: AiravataOperator, experiment_id: str,
check_interval: int = 30):
+ logger.info(f"Monitoring experiment {experiment_id}...")
+
+ while True:
+ try:
+ status = operator.get_experiment_status(experiment_id)
+ status_value, status_name, status_enum =
get_experiment_state_value(status)
+ logger.info(f"Experiment status: {status_name} (value:
{status_value})")
+
+ # Check terminal states: COMPLETED (7), CANCELED (6), FAILED (8)
+ if status_value is not None:
+ is_terminal = status_value in [
+ ExperimentState.COMPLETED.value, # 7
+ ExperimentState.CANCELED.value, # 6
+ ExperimentState.FAILED.value # 8
+ ]
+ else:
+ is_terminal = status_name in ['COMPLETED', 'CANCELED',
'FAILED']
+
+ if is_terminal:
+ logger.info(f"Experiment finished with state: {status_name}")
+ break
+ except Exception as e:
+ logger.error(f"Error checking experiment {experiment_id} status:
{e}")
+ import traceback
+ logger.debug(traceback.format_exc())
+ # Continue monitoring despite errors
+
+ time.sleep(check_interval)
+
+
+def main():
+ logger.info("Authenticating...")
+ ACCESS_TOKEN = AuthContext.get_access_token()
+
+ EXPERIMENT_NAME = "Test"
+ PROJECT_NAME = "Default Project"
+ APPLICATION_NAME = "NAMD-test"
+ GATEWAY_ID = None
+
+ COMPUTATION_RESOURCE_NAME = "NeuroData25VC2"
+ QUEUE_NAME = "cloud"
+ NODE_COUNT = 1
+ CPU_COUNT = 1
+ WALLTIME = 5
+ GROUP_NAME = "Default"
+
+ INPUT_STORAGE_HOST = "gateway.dev.cybershuttle.org"
+ OUTPUT_STORAGE_HOST = "149.165.169.12"
+
+ INPUT_FILES = {}
+ DATA_INPUTS = {}
+
+ AUTO_SCHEDULE = False
+ MONITOR = True
+
+ try:
+ result = create_and_launch_experiment(
+ access_token=ACCESS_TOKEN,
+ experiment_name=EXPERIMENT_NAME,
+ project_name=PROJECT_NAME,
+ application_name=APPLICATION_NAME,
+ computation_resource_name=COMPUTATION_RESOURCE_NAME,
+ queue_name=QUEUE_NAME,
+ node_count=NODE_COUNT,
+ cpu_count=CPU_COUNT,
+ walltime=WALLTIME,
+ group_name=GROUP_NAME,
+ input_storage_host=INPUT_STORAGE_HOST,
+ output_storage_host=OUTPUT_STORAGE_HOST,
+ input_files=INPUT_FILES if INPUT_FILES else None,
+ data_inputs=DATA_INPUTS if DATA_INPUTS else None,
+ gateway_id=GATEWAY_ID,
+ auto_schedule=AUTO_SCHEDULE,
+ monitor=MONITOR,
+ )
+
+ logger.info("\n" + "="*60)
+ logger.info("Experiment created and launched successfully!")
+ logger.info("="*60)
+ logger.info(f"Experiment ID: {result['experiment_id']}")
+ logger.info(f"Process ID: {result['process_id']}")
+ logger.info(f"Experiment Directory: {result['experiment_dir']}")
+ logger.info(f"Storage Host: {result['storage_host']}")
+ logger.info("="*60)
+
+ return result
+
+ except Exception as e:
+ logger.error(f"Failed to create/launch experiment: {repr(e)}",
exc_info=True)
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ main()
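
Because monitoring is optional, create_and_launch_experiment doubles as a
library entry point (this is exactly how batch_launch_experiments.py consumes
it). A sketch mirroring the values in main():

    from airavata_auth.device_auth import AuthContext
    from create_launch_experiment_with_storage import create_and_launch_experiment

    token = AuthContext.get_access_token()
    result = create_and_launch_experiment(
        access_token=token,
        experiment_name="Test",
        project_name="Default Project",
        application_name="NAMD-test",
        computation_resource_name="NeuroData25VC2",
        queue_name="cloud",
        node_count=1,
        cpu_count=1,
        walltime=5,
        input_storage_host="gateway.dev.cybershuttle.org",
        output_storage_host="149.165.169.12",
        monitor=False,  # return immediately instead of polling to a terminal state
    )
    print(result["experiment_id"], result["storage_host"])
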
diff --git
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
index 409511c0db..a976a7db8e 100644
---
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
+++
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
@@ -42,6 +42,7 @@ import
org.apache.airavata.model.experiment.UserConfigurationDataModel;
import org.apache.airavata.model.process.ProcessModel;
import
org.apache.airavata.model.scheduling.ComputationalResourceSchedulingModel;
import org.apache.airavata.model.security.AuthzToken;
+import org.apache.commons.lang3.StringUtils;
import org.apache.thrift.TException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -245,7 +246,10 @@ public class AgentManagementHandler {
userConfigurationDataModel.setComputationalResourceScheduling(computationalResourceSchedulingModel);
userConfigurationDataModel.setAiravataAutoSchedule(false);
userConfigurationDataModel.setOverrideManualScheduledParams(false);
- userConfigurationDataModel.setStorageId(storageResourceId);
+ userConfigurationDataModel.setInputStorageResourceId(
+ StringUtils.isNotBlank(req.getInputStorageId()) ?
req.getInputStorageId() : storageResourceId);
+        userConfigurationDataModel.setOutputStorageResourceId(
+                StringUtils.isNotBlank(req.getOutputStorageId()) ?
req.getOutputStorageId() : storageResourceId);
String experimentDataDir = Paths.get(storagePath, gatewayId, userName,
projectDir, experimentName)
.toString();
userConfigurationDataModel.setExperimentDataDir(experimentDataDir);
diff --git
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
index 81b6c26631..0e655c07f1 100644
---
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
+++
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
@@ -39,6 +39,9 @@ public class AgentLaunchRequest {
private int nodeCount = 1;
private int memory = 2048;
+ private String inputStorageId;
+ private String outputStorageId;
+
public String getExperimentName() {
return experimentName;
}
@@ -138,4 +141,20 @@ public class AgentLaunchRequest {
public void setMounts(List<String> mounts) {
this.mounts = mounts;
}
+
+ public String getInputStorageId() {
+ return inputStorageId;
+ }
+
+ public void setInputStorageId(String inputStorageId) {
+ this.inputStorageId = inputStorageId;
+ }
+
+ public String getOutputStorageId() {
+ return outputStorageId;
+ }
+
+ public void setOutputStorageId(String outputStorageId) {
+ this.outputStorageId = outputStorageId;
+ }
}
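
AgentLaunchRequest now carries the two optional storage IDs end to end; per
the handler change above, a blank value falls back to the gateway's default
storage resource. A hypothetical request body, expressed as a Python dict
(field names come from the model above; the endpoint and transport binding
are outside this diff):

    # Sketch of a JSON-bound agent launch payload.
    launch_request = {
        "experimentName": "agent-test",
        "nodeCount": 1,
        "memory": 2048,
        "inputStorageId": "gateway.dev.cybershuttle.org",
        "outputStorageId": "149.165.169.12",  # omit or leave blank to use the fallback
    }
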
diff --git a/thrift-interface-descriptions/data-models/experiment_model.thrift
b/thrift-interface-descriptions/data-models/experiment_model.thrift
index 0d31bb033d..07ec759fe1 100644
--- a/thrift-interface-descriptions/data-models/experiment_model.thrift
+++ b/thrift-interface-descriptions/data-models/experiment_model.thrift
@@ -65,11 +65,12 @@ struct UserConfigurationDataModel {
5: optional bool throttleResources = 0,
6: optional string userDN,
7: optional bool generateCert = 0,
- 8: optional string storageId;
- 9: optional string experimentDataDir;
- 10: optional bool useUserCRPref;
- 11: optional string groupResourceProfileId
- 12: optional list<scheduling_model.ComputationalResourceSchedulingModel>
autoScheduledCompResourceSchedulingList,
+ 8: optional string inputStorageResourceId;
+ 9: optional string outputStorageResourceId;
+ 10: optional string experimentDataDir;
+ 11: optional bool useUserCRPref;
+ 12: optional string groupResourceProfileId
+ 13: optional list<scheduling_model.ComputationalResourceSchedulingModel>
autoScheduledCompResourceSchedulingList,
}
/**
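
On the regenerated Python stubs, the renamed fields appear as constructor
kwargs on UserConfigurationDataModel. A minimal sketch, assuming the usual
thrift-generated keyword-argument constructor:

    from airavata.model.experiment.ttypes import UserConfigurationDataModel

    user_config = UserConfigurationDataModel(
        airavataAutoSchedule=False,
        overrideManualScheduledParams=False,
        inputStorageResourceId="gateway.dev.cybershuttle.org",
        # Leave unset to let callers default it to the input resource.
        outputStorageResourceId="149.165.169.12",
    )
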
diff --git a/thrift-interface-descriptions/data-models/process_model.thrift
b/thrift-interface-descriptions/data-models/process_model.thrift
index a47e07aae8..67e45735e9 100644
--- a/thrift-interface-descriptions/data-models/process_model.thrift
+++ b/thrift-interface-descriptions/data-models/process_model.thrift
@@ -65,12 +65,13 @@ struct ProcessModel {
16: optional string gatewayExecutionId,
17: optional bool enableEmailNotification,
18: optional list<string> emailAddresses,
- 19: optional string storageResourceId,
- 20: optional string userDn,
- 21: optional bool generateCert = 0,
- 22: optional string experimentDataDir,
- 23: optional string userName,
- 24: optional bool useUserCRPref,
- 25: optional string groupResourceProfileId;
- 26: optional list<ProcessWorkflow> processWorkflows;
+ 19: optional string inputStorageResourceId,
+ 20: optional string outputStorageResourceId,
+ 21: optional string userDn,
+ 22: optional bool generateCert = 0,
+ 23: optional string experimentDataDir,
+ 24: optional string userName,
+ 25: optional bool useUserCRPref,
+ 26: optional string groupResourceProfileId;
+ 27: optional list<ProcessWorkflow> processWorkflows;
}