This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch cybershuttle-dev in repository https://gitbox.apache.org/repos/asf/airavata.git
commit 1dcb49a8e7c79b33c40b3c49fb481ebd81789b97 Author: yasith <[email protected]> AuthorDate: Sun Dec 8 03:06:19 2024 -0600 improve auth flow, add python code invocation, add plan crud apis and file ul/dl apis, reduce settings.ini deps, fix bugs --- .../airavata-python-sdk/.gitignore | 1 + .../airavata_experiments/__init__.py | 20 +- .../airavata_experiments/airavata.py | 345 ++++++++++++++---- .../airavata_experiments/auth/device_auth.py | 216 +++++++----- .../airavata_experiments/plan.py | 114 +++++- .../airavata_experiments/runtime.py | 153 +++++++- .../airavata_experiments/scripter.py | 144 ++++++++ .../airavata_experiments/sftp.py | 112 +++--- .../airavata_experiments/task.py | 34 +- .../airavata_sdk/clients/api_server_client.py | 9 +- .../airavata-python-sdk/pyproject.toml | 4 +- .../airavata-python-sdk/samples/annotations.py | 70 ---- .../airavata-python-sdk/samples/poc.ipynb | 385 ++++++++++++--------- 13 files changed, 1088 insertions(+), 519 deletions(-) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore b/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore index 11b22924b5..2fb5c82ffc 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/.gitignore @@ -9,6 +9,7 @@ __pycache__/ .ipynb_checkpoints *.egg-info/ data/ +results/ plan.json settings*.ini auth.state \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py index 298a46ad94..0f770f2205 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py @@ -17,23 +17,7 @@ from __future__ import annotations from . import base, md, plan -from .runtime import list_runtimes from .auth import login, logout +from .runtime import list_runtimes - -def load_plan(path: str) -> plan.Plan: - return plan.Plan.load_json(path) - - -def task_context(task: base.Task): - def inner(func): - # take the function into the task's location - # and execute it there. then fetch the result - result = func(**task.inputs) - # and return it to the caller. - return result - - return inner - - -__all__ = ["login", "logout", "list_runtimes", "md", "task_context"] +__all__ = ["login", "logout", "list_runtimes", "base", "md", "plan"] diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py index 1dda44c55a..b6556f1770 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py @@ -16,36 +16,158 @@ import logging from pathlib import Path +from typing import Literal, NamedTuple from .sftp import SFTPConnector import jwt from airavata.model.security.ttypes import AuthzToken +from airavata.model.experiment.ttypes import ExperimentModel, ExperimentType, UserConfigurationDataModel +from airavata.model.scheduling.ttypes import ComputationalResourceSchedulingModel +from airavata.model.data.replica.ttypes import DataProductModel, DataProductType, DataReplicaLocationModel, ReplicaLocationCategory + from airavata_sdk.clients.api_server_client import APIServerClient -from airavata_sdk.clients.utils.api_server_client_util import APIServerClientUtil -from airavata_sdk.clients.utils.data_model_creation_util import DataModelCreationUtil -from airavata_sdk.transport.settings import ExperimentSettings, GatewaySettings +from airavata_sdk.transport.settings import ExperimentSettings, GatewaySettings, APIServerClientSettings logger = logging.getLogger("airavata_sdk.clients") logger.setLevel(logging.INFO) +LaunchState = NamedTuple("LaunchState", [ + ("experiment_id", str), + ("mount_point", Path), + ("experiment_dir", str), + ("sr_host", str), +]) class AiravataOperator: + def register_input_file( + self, + file_identifier: str, + storage_name: str, + storageId: str, + gateway_id: str, + input_file_name: str, + uploaded_storage_path: str, + ) -> str: + + dataProductModel = DataProductModel( + gatewayId=gateway_id, + ownerName=self.user_id, + productName=file_identifier, + dataProductType=DataProductType.FILE, + replicaLocations=[ + DataReplicaLocationModel( + replicaName="{} gateway data store copy".format(input_file_name), + replicaLocationCategory=ReplicaLocationCategory.GATEWAY_DATA_STORE, + storageResourceId=storageId, + filePath="file://{}:{}".format(storage_name, uploaded_storage_path + input_file_name), + )], + ) + + return self.api_server_client.register_data_product(self.airavata_token, dataProductModel) # type: ignore + + def create_experiment_model( + self, + project_name: str, + application_name: str, + experiment_name: str, + description: str, + gateway_id: str, + ) -> ExperimentModel: + + execution_id = self.get_app_interface_id(application_name) + project_id = self.get_project_id(project_name) + return ExperimentModel( + experimentName=experiment_name, + gatewayId=gateway_id, + userName=self.user_id, + description=description, + projectId=project_id, + experimentType=ExperimentType.SINGLE_APPLICATION, + executionId=execution_id + ) + + def get_resource_host_id(self, resource_name): + resources: dict = self.api_server_client.get_all_compute_resource_names(self.airavata_token) # type: ignore + return next((str(k) for k, v in resources.items() if v == resource_name)) + + def configure_computation_resource_scheduling( + self, + experiment_model: ExperimentModel, + computation_resource_name: str, + group_resource_profile_name: str, + storageId: str, + node_count: int, + total_cpu_count: int, + queue_name: str, + wall_time_limit: int, + experiment_dir_path: str, + auto_schedule=False, + ) -> ExperimentModel: + resource_host_id = self.get_resource_host_id(computation_resource_name) + groupResourceProfileId = self.get_group_resource_profile_id(group_resource_profile_name) + computRes = ComputationalResourceSchedulingModel() + computRes.resourceHostId = resource_host_id + computRes.nodeCount = node_count + computRes.totalCPUCount = total_cpu_count + computRes.queueName = queue_name + computRes.wallTimeLimit = wall_time_limit + + userConfigData = UserConfigurationDataModel() + userConfigData.computationalResourceScheduling = computRes + + userConfigData.groupResourceProfileId = groupResourceProfileId + userConfigData.storageId = storageId + + userConfigData.experimentDataDir = experiment_dir_path + userConfigData.airavataAutoSchedule = auto_schedule + experiment_model.userConfigurationData = userConfigData + + return experiment_model + def __init__(self, access_token: str, config_file: str = "settings.ini"): # store variables - self.config_file = config_file self.access_token = access_token + self.api_settings = APIServerClientSettings(config_file) + self.gateway_settings = GatewaySettings(config_file) + self.experiment_settings = ExperimentSettings(config_file) # load api server settings and create client - self.api_server_client = APIServerClient(self.config_file) + self.api_server_client = APIServerClient(api_server_settings=self.api_settings) # load gateway settings - self.gateway_conf = GatewaySettings(self.config_file) - gateway_id = self.gateway_conf.GATEWAY_ID - # load experiment settings - self.experiment_conf = ExperimentSettings(self.config_file) + gateway_id = self.default_gateway_id() self.airavata_token = self.__airavata_token__(self.access_token, gateway_id) - self.api_util = APIServerClientUtil(self.config_file, username=self.user_id, password="", gateway_id=gateway_id, access_token=self.access_token) - def __airavata_token__(self, access_token, gateway_id): + def default_gateway_id(self): + return self.gateway_settings.GATEWAY_ID + + def default_gateway_data_store_dir(self): + return self.gateway_settings.GATEWAY_DATA_STORE_DIR + + def default_sftp_port(self): + return self.experiment_settings.SFTP_PORT + + def default_experiment_queue(self): + return self.experiment_settings.QUEUE_NAME + + def default_grp_name(self): + return self.experiment_settings.GROUP_RESOURCE_PROFILE_NAME + + def default_sr_hostname(self): + return self.experiment_settings.STORAGE_RESOURCE_HOST + + def default_project_name(self): + return self.experiment_settings.PROJECT_NAME + + def default_node_count(self): + return self.experiment_settings.NODE_COUNT + + def default_cpu_count(self): + return self.experiment_settings.TOTAL_CPU_COUNT + + def default_walltime(self): + return self.experiment_settings.WALL_TIME_LIMIT + + def __airavata_token__(self, access_token: str, gateway_id: str): """ Decode access token (string) and create AuthzToken (object) @@ -55,6 +177,7 @@ class AiravataOperator: claimsMap = {"userName": self.user_id, "gatewayID": gateway_id} return AuthzToken(accessToken=self.access_token, claimsMap=claimsMap) + def get_experiment(self, experiment_id: str): """ Get experiment by id @@ -62,28 +185,32 @@ class AiravataOperator: """ return self.api_server_client.get_experiment(self.airavata_token, experiment_id) + def get_accessible_apps(self, gateway_id: str | None = None): """ Get all applications available in the gateway """ # use defaults for missing values - gateway_id = gateway_id or self.gateway_conf.GATEWAY_ID + gateway_id = gateway_id or self.default_gateway_id() # logic app_interfaces = self.api_server_client.get_all_application_interfaces(self.airavata_token, gateway_id) return app_interfaces - def get_preferred_storage(self, gateway_id: str | None = None, storage_name: str | None = None): + + def get_preferred_storage(self, gateway_id: str | None = None, sr_hostname: str | None = None): """ Get preferred storage resource """ # use defaults for missing values - gateway_id = gateway_id or self.gateway_conf.GATEWAY_ID - storage_name = storage_name or self.experiment_conf.STORAGE_RESOURCE_HOST + gateway_id = gateway_id or self.default_gateway_id() + sr_hostname = sr_hostname or self.default_sr_hostname() # logic - storage_id = self.api_util.get_storage_resource_id(storage_name) - return self.api_server_client.get_gateway_storage_preference(self.airavata_token, gateway_id, storage_id) + sr_names: dict[str, str] = self.api_server_client.get_all_storage_resource_names(self.airavata_token) # type: ignore + sr_id = next((str(k) for k, v in sr_names.items() if v == sr_hostname)) + return self.api_server_client.get_gateway_storage_preference(self.airavata_token, gateway_id, sr_id) + def get_storage(self, storage_name: str | None = None) -> any: # type: ignore """ @@ -91,43 +218,64 @@ class AiravataOperator: """ # use defaults for missing values - storage_name = storage_name or self.experiment_conf.STORAGE_RESOURCE_HOST + storage_name = storage_name or self.default_sr_hostname() # logic - storage_id = self.api_util.get_storage_resource_id(storage_name) - storage = self.api_util.api_server_client.get_storage_resource(self.airavata_token, storage_id) + sr_names: dict[str, str] = self.api_server_client.get_all_storage_resource_names(self.airavata_token) # type: ignore + sr_id = next((str(k) for k, v in sr_names.items() if v == storage_name)) + storage = self.api_server_client.get_storage_resource(self.airavata_token, sr_id) return storage + + + - def get_group_resource_profile(self, grp_name: str | None = None): + def get_group_resource_profile_id(self, grp_name: str | None = None) -> str: """ - Get group resource profile by name + Get group resource profile id by name """ # use defaults for missing values - grp_name = grp_name or self.experiment_conf.GROUP_RESOURCE_PROFILE_NAME + grp_name = grp_name or self.default_grp_name() # logic - grp_id = self.api_util.get_group_resource_profile_id(grp_name) - grp = self.api_util.api_server_client.get_group_resource_profile(self.airavata_token, grp_id) + grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) + return str(grp_id) + + def get_group_resource_profile(self, grp_id: str): + grp: any = self.api_server_client.get_group_resource_profile(self.airavata_token, grp_id) # type: ignore return grp + def get_compatible_deployments(self, app_interface_id: str, grp_name: str | None = None): """ Get compatible deployments for an application interface and group resource profile """ # use defaults for missing values - grp_name = grp_name or self.experiment_conf.GROUP_RESOURCE_PROFILE_NAME + grp_name = grp_name or self.default_grp_name() # logic - grp_id = self.api_util.get_group_resource_profile_id(grp_name) + grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) deployments = self.api_server_client.get_application_deployments_for_app_module_and_group_resource_profile(self.airavata_token, app_interface_id, grp_id) return deployments + def get_app_interface_id(self, app_name: str, gateway_id: str | None = None): """ Get application interface id by name """ - self.api_util.gateway_id = str(gateway_id or self.gateway_conf.GATEWAY_ID) - return self.api_util.get_execution_id(app_name) + gateway_id = str(gateway_id or self.default_gateway_id()) + apps: list = self.api_server_client.get_all_application_interfaces(self.airavata_token, gateway_id) # type: ignore + app_id = next((app.applicationInterfaceId for app in apps if app.applicationName == app_name)) + return str(app_id) + + + def get_project_id(self, project_name: str, gateway_id: str | None = None): + gateway_id = str(gateway_id or self.default_gateway_id()) + projects: list = self.api_server_client.get_user_projects(self.airavata_token, gateway_id, self.user_id, 10, 0) # type: ignore + project_id = next((p.projectID for p in projects if p.name == project_name and p.owner == self.user_id)) + return str(project_id) + def get_application_inputs(self, app_interface_id: str) -> list: """ @@ -136,6 +284,7 @@ class AiravataOperator: """ return list(self.api_server_client.get_application_inputs(self.airavata_token, app_interface_id)) # type: ignore + def get_compute_resources_by_ids(self, resource_ids: list[str]): """ Get compute resources by ids @@ -143,32 +292,77 @@ class AiravataOperator: """ return [self.api_server_client.get_compute_resource(self.airavata_token, resource_id) for resource_id in resource_ids] - def make_experiment_dir(self, storage_resource, project_name: str, experiment_name: str) -> str: + + def make_experiment_dir(self, sr_host: str, project_name: str, experiment_name: str) -> str: """ Make experiment directory on storage resource, and return the remote path Return Path: /{project_name}/{experiment_name} """ - host = storage_resource.hostName - port = self.experiment_conf.SFTP_PORT - sftp_connector = SFTPConnector(host=host, port=port, username=self.user_id, password=self.access_token) - remote_path = sftp_connector.make_experiment_dir(project_name, experiment_name) + host = sr_host + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=host, port=int(port), username=self.user_id, password=self.access_token) + remote_path = sftp_connector.mkdir(project_name, experiment_name) logger.info("Experiment directory created at %s", remote_path) return remote_path - def upload_files(self, storage_resource, files: list[Path], exp_dir: str) -> None: + + def upload_files(self, sr_host: str, local_files: list[Path], remote_dir: str) -> list[str]: + """ + Upload local files to a remote directory of a storage resource + + Return Path: /{project_name}/{experiment_name} + """ - Upload input files to storage resource, and return the remote path + host = sr_host + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=host, port=int(port), username=self.user_id, password=self.access_token) + paths = sftp_connector.put(local_files, remote_dir) + logger.info(f"{len(paths)} Local files uploaded to remote dir: %s", remote_dir) + return paths + + + def list_files(self, sr_host: str, remote_dir: str) -> list[str]: + """ + List files in a remote directory of a storage resource Return Path: /{project_name}/{experiment_name} """ - host = storage_resource.hostName - port = self.experiment_conf.SFTP_PORT - sftp_connector = SFTPConnector(host=host, port=port, username=self.user_id, password=self.access_token) - sftp_connector.upload_files(files, exp_dir) - logger.info("Input files uploaded to %s", exp_dir) + host = sr_host + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=host, port=int(port), username=self.user_id, password=self.access_token) + return sftp_connector.ls(remote_dir) + + + def download_file(self, sr_host: str, remote_file: str, local_dir: str) -> str: + """ + Download files from a remote directory of a storage resource to a local directory + + Return Path: /{project_name}/{experiment_name} + + """ + host = sr_host + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=host, port=int(port), username=self.user_id, password=self.access_token) + path = sftp_connector.get(remote_file, local_dir) + logger.info("Remote files downlaoded to local dir: %s", local_dir) + return path + + def cat_file(self, sr_host: str, remote_file: str) -> bytes: + """ + Download files from a remote directory of a storage resource to a local directory + + Return Path: /{project_name}/{experiment_name} + + """ + host = sr_host + port = self.default_sftp_port() + sftp_connector = SFTPConnector(host=host, port=int(port), username=self.user_id, password=self.access_token) + data = sftp_connector.cat(remote_file) + logger.info("Remote files downlaoded to local dir: %s bytes", len(data)) + return data def launch_experiment( self, @@ -186,24 +380,25 @@ class AiravataOperator: cpu_count: int | None = None, walltime: int | None = None, auto_schedule: bool = False, - ) -> str: + ) -> LaunchState: """ Launch an experiment and return its id """ # preprocess args (str) print("[AV] Preprocessing args...") - gateway_id = str(gateway_id or self.gateway_conf.GATEWAY_ID) - queue_name = str(queue_name or self.experiment_conf.QUEUE_NAME) - grp_name = str(grp_name or self.experiment_conf.GROUP_RESOURCE_PROFILE_NAME) - sr_host = str(sr_host or self.experiment_conf.STORAGE_RESOURCE_HOST) - project_name = str(project_name or self.experiment_conf.PROJECT_NAME) - mount_point = Path(self.gateway_conf.GATEWAY_DATA_STORE_DIR) / self.user_id + gateway_id = str(gateway_id or self.default_gateway_id()) + mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id + + queue_name = str(queue_name or self.default_experiment_queue()) + grp_name = str(grp_name or self.default_grp_name()) + sr_host = str(sr_host or self.default_sr_hostname()) + project_name = str(project_name or self.default_project_name()) # preprocess args (int) - node_count = int(node_count or self.experiment_conf.NODE_COUNT or "1") - cpu_count = int(cpu_count or self.experiment_conf.TOTAL_CPU_COUNT or "1") - walltime = int(walltime or self.experiment_conf.WALL_TIME_LIMIT or "30") + node_count = int(node_count or self.default_node_count() or "1") + cpu_count = int(cpu_count or self.default_cpu_count() or "1") + walltime = int(walltime or self.default_walltime() or "30") # validate args (str) print("[AV] Validating args...") @@ -226,10 +421,10 @@ class AiravataOperator: # setup runtime params print("[AV] Setting up runtime params...") storage = self.get_storage(sr_host) - queue_name = queue_name or self.experiment_conf.QUEUE_NAME - node_count = int(node_count or self.experiment_conf.NODE_COUNT or "1") - cpu_count = int(cpu_count or self.experiment_conf.TOTAL_CPU_COUNT or "1") - walltime = int(walltime or self.experiment_conf.WALL_TIME_LIMIT or "01:00:00") + queue_name = queue_name or self.default_experiment_queue() + node_count = int(node_count or self.default_node_count() or "1") + cpu_count = int(cpu_count or self.default_cpu_count() or "1") + walltime = int(walltime or self.default_walltime() or "30") sr_id = storage.storageResourceId # setup application interface @@ -239,24 +434,17 @@ class AiravataOperator: # setup experiment print("[AV] Setting up experiment...") - data_model_util = DataModelCreationUtil( - self.config_file, - username=self.user_id, - password=None, - gateway_id=gateway_id, - access_token=self.access_token, - ) - experiment = data_model_util.get_experiment_data_model_for_single_application( + experiment = self.create_experiment_model( experiment_name=experiment_name, application_name=app_name, project_name=project_name, description=experiment_name, + gateway_id=gateway_id, ) - # setup experiment directory print("[AV] Setting up experiment directory...") exp_dir = self.make_experiment_dir( - storage_resource=storage, + sr_host=storage.hostName, project_name=project_name, experiment_name=experiment_name, ) @@ -264,7 +452,7 @@ class AiravataOperator: print("[AV] exp_dir:", exp_dir) print("[AV] abs_path:", abs_path) - experiment = data_model_util.configure_computation_resource_scheduling( + experiment = self.configure_computation_resource_scheduling( experiment_model=experiment, computation_resource_name=computation_resource_name, group_resource_profile_name=grp_name, @@ -281,7 +469,7 @@ class AiravataOperator: print("[AV] Setting up file inputs...") def register_input_file(file: Path) -> str: - return str(data_model_util.register_input_file(file.name, sr_host, sr_id, file.name, abs_path)) + return str(self.register_input_file(file.name, sr_host, sr_id, gateway_id, file.name, abs_path)) # setup experiment inputs files_to_upload = list[Path]() @@ -303,8 +491,8 @@ class AiravataOperator: data_inputs[key] = value # configure file inputs for experiment - print("[AV] Uploading file inputs for experiment...") - self.upload_files(storage, files_to_upload, exp_dir) + print(f"[AV] Uploading {len(files_to_upload)} file inputs for experiment...") + self.upload_files(storage.hostName, files_to_upload, exp_dir) # configure experiment inputs experiment_inputs = [] @@ -333,14 +521,21 @@ class AiravataOperator: # launch experiment self.api_server_client.launch_experiment(self.airavata_token, ex_id, gateway_id) - return str(ex_id) + return LaunchState( + experiment_id=str(ex_id), + mount_point=mount_point, + experiment_dir=exp_dir, + sr_host=str(storage.hostName), + ) - def get_experiment_status(self, experiment_id): - status = self.api_server_client.get_experiment_status( - self.airavata_token, experiment_id) - return status + + def get_experiment_status(self, experiment_id) -> Literal["CREATED", "VALIDATED", "SCHEDULED", "LAUNCHED", "EXECUTING", "CANCELING", "CANCELED", "COMPLETED", "FAILED"]: + states = ["CREATED", "VALIDATED", "SCHEDULED", "LAUNCHED", "EXECUTING", "CANCELING", "CANCELED", "COMPLETED", "FAILED"] + status: any = self.api_server_client.get_experiment_status(self.airavata_token, experiment_id) # type: ignore + return states[status.state] + def stop_experiment(self, experiment_id): status = self.api_server_client.terminate_experiment( - self.airavata_token, experiment_id, self.gateway_conf.GATEWAY_ID) + self.airavata_token, experiment_id, self.default_gateway_id()) return status diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py index 11aada5ffa..165d1a7175 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py @@ -14,8 +14,13 @@ # limitations under the License. # +import datetime +import json +import os import time +import webbrowser +import jwt import requests @@ -24,14 +29,40 @@ class DeviceFlowAuthenticator: idp_url: str realm: str client_id: str - device_code: str | None interval: int - access_token: str | None - refresh_token: str | None + device_code: str | None + _access_token: str | None + _refresh_token: str | None + + def __has_expired__(self, token: str) -> bool: + try: + decoded = jwt.decode(token, options={"verify_signature": False}) + tA = datetime.datetime.now(datetime.timezone.utc).timestamp() + tB = int(decoded.get("exp", 0)) + return tA >= tB + except: + return True @property - def logged_in(self) -> bool: - return self.access_token is not None + def access_token(self) -> str: + if self._access_token and not self.__has_expired__(self._access_token): + return self._access_token + elif self._refresh_token and not self.__has_expired__(self._refresh_token): + self.refresh() + else: + self.login() + assert self._access_token + return self._access_token + + @property + def refresh_token(self) -> str: + if self._refresh_token and not self.__has_expired__(self._refresh_token): + return self._refresh_token + else: + self.login() + assert self._refresh_token + return self._refresh_token + def __init__( self, @@ -47,88 +78,109 @@ class DeviceFlowAuthenticator: raise ValueError( "Missing required environment variables for client ID, realm, or auth server URL") + self.interval = 5 self.device_code = None - self.interval = -1 - self.access_token = None + self._access_token = None + self._refresh_token = None - def login(self, interactive: bool = True): - # Step 0: Check if we have a saved token - if self.__load_saved_token__(): - print("Using saved token") - return - - # Step 1: Request device and user code - auth_device_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/auth/device" + def refresh(self) -> None: + auth_device_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/token" response = requests.post(auth_device_url, data={ - "client_id": self.client_id, "scope": "openid"}) - + "client_id": self.client_id, + "grant_type": "refresh_token", + "scope": "openid", + "refresh_token": self._refresh_token + }) if response.status_code != 200: - print(f"Error in device authorization request: {response.status_code} - {response.text}") - return - + raise Exception(f"Error in token refresh request: {response.status_code} - {response.text}") data = response.json() - self.device_code = data.get("device_code") - self.interval = data.get("interval", 5) - - print(f"User code: {data.get('user_code')}") - print(f"Please authenticate by visiting: {data.get('verification_uri_complete')}") - - if interactive: - import webbrowser - - webbrowser.open(data.get("verification_uri_complete")) - - # Step 2: Poll for the token - self.__poll_for_token__() - - def logout(self): - self.access_token = None - self.refresh_token = None - - def __poll_for_token__(self): - token_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/token" - print("Waiting for authorization...") - while True: - response = requests.post( - token_url, - data={ - "client_id": self.client_id, - "grant_type": "urn:ietf:params:oauth:grant-type:device_code", - "device_code": self.device_code, - }, - ) - if response.status_code == 200: - data = response.json() - self.refresh_token = data.get("refresh_token") - self.access_token = data.get("access_token") - print("Authorization successful!") - self.__persist_token__() - return - elif response.status_code == 400 and response.json().get("error") == "authorization_pending": - time.sleep(self.interval) + self._refresh_token = data["refresh_token"] + self._access_token = data["access_token"] + assert self._access_token is not None + assert self._refresh_token is not None + self.__persist_token__(self._refresh_token, self._access_token) + + def login(self, interactive: bool = True) -> None: + + try: + # [Flow A] Reuse saved token + if os.path.exists("auth.state"): + try: + # [A1] Load token from file + with open("auth.state", "r") as f: + data = json.load(f) + self._refresh_token = str(data["refresh_token"]) + self._access_token = str(data["access_token"]) + except: + print("Failed to load auth.state file!") else: - print(f"Authorization error: {response.status_code} - {response.text}") - break - - def __persist_token__(self): + # [A2] Check if access token is valid, if so, return + if not self.__has_expired__(self._access_token): + print("Authenticated via saved access token!") + return None + else: + print("Access token is invalid!") + # [A3] Check if refresh token is valid. if so, refresh + try: + if not self.__has_expired__(self._refresh_token): + self.refresh() + print("Authenticated via saved refresh token!") + return None + else: + print("Refresh token is invalid!") + except Exception as e: + print(*e.args) + + # [Flow B] Request device and user code + + # [B1] Initiate device auth flow + auth_device_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/auth/device" + response = requests.post(auth_device_url, data={ + "client_id": self.client_id, + "scope": "openid", + }) + if response.status_code != 200: + raise Exception(f"Error in device authorization request: {response.status_code} - {response.text}") + data = response.json() + self.device_code = data.get("device_code", self.device_code) + self.interval = data.get("interval", self.interval) + url = data['verification_uri_complete'] + print(f"Please authenticate by visiting: {url}") + if interactive: + webbrowser.open(url) + + # [B2] Poll until token is received + token_url = f"{self.idp_url}/realms/{self.realm}/protocol/openid-connect/token" + print("Waiting for authorization...") + while True: + response = requests.post( + token_url, + data={ + "client_id": self.client_id, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": self.device_code, + }, + ) + if response.status_code == 200: + data = response.json() + self.__persist_token__(data["refresh_token"], data["access_token"]) + print("Authenticated via device auth!") + return + elif response.status_code == 400 and response.json().get("error") == "authorization_pending": + time.sleep(self.interval) + else: + raise Exception(f"Authorization error: {response.status_code} - {response.text}") + + except Exception as e: + print("login() failed!", e) + + def logout(self) -> None: + self._access_token = None + self._refresh_token = None + + def __persist_token__(self, refresh_token: str, access_token: str) -> None: + self._access_token = access_token + self._refresh_token = refresh_token import json with open("auth.state", "w") as f: - json.dump({"refresh_token": self.refresh_token, - "access_token": self.access_token}, f) - - def __load_saved_token__(self): - import json - import jwt - import datetime - try: - with open("auth.state", "r") as f: - data = json.load(f) - self.refresh_token = str(data["refresh_token"]) - self.access_token = str(data["access_token"]) - decoded = jwt.decode(self.access_token, options={"verify_signature": False}) - tA = datetime.datetime.now(datetime.timezone.utc).timestamp() - tB = int(decoded.get("exp", 0)) - return tA < tB - except (FileNotFoundError, KeyError, ValueError, StopIteration) as e: - print(e) - return False + json.dump({"refresh_token": self._refresh_token, "access_token": self._access_token}, f) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py index 4d098215b3..12e49a7d48 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py @@ -18,15 +18,20 @@ from __future__ import annotations import json import time +import os import pydantic from rich.progress import Progress -from .runtime import Runtime +from .runtime import is_terminal_state from .task import Task +import uuid +from .airavata import AiravataOperator +from .auth import context class Plan(pydantic.BaseModel): + id: str | None = pydantic.Field(default=None) tasks: list[Task] = [] @pydantic.field_validator("tasks", mode="before") @@ -36,6 +41,7 @@ class Plan(pydantic.BaseModel): return v def describe(self) -> None: + print(f"Plan(id={self.id}): {len(self.tasks)} tasks") for task in self.tasks: print(task) @@ -72,19 +78,22 @@ class Plan(pydantic.BaseModel): task.stop() print("Task(s) stopped.") - def __stage_fetch__(self) -> list[list[str]]: + def __stage_fetch__(self, local_dir: str) -> list[list[str]]: print("Fetching results...") fps = list[list[str]]() for task in self.tasks: runtime = task.runtime ref = task.ref + task_dir = os.path.join(local_dir, task.name) + os.makedirs(task_dir, exist_ok=True) fps_task = list[str]() assert ref is not None - for remote_fp in task.files(): - fp = runtime.download(ref, remote_fp) + for remote_fp in task.ls(): + fp = runtime.download(remote_fp, task_dir, task) fps_task.append(fp) fps.append(fps_task) print("Results fetched.") + self.save_json(os.path.join(local_dir, "plan.json")) return fps def launch(self, silent: bool = False) -> None: @@ -92,13 +101,17 @@ class Plan(pydantic.BaseModel): self.__stage_prepare__() self.__stage_confirm__(silent) self.__stage_launch_task__() + self.save() except Exception as e: print(*e.args, sep="\n") + def status(self) -> None: + statuses = self.__stage_status__() + for task, status in zip(self.tasks, statuses): + print(f"{task.name}: {status}") + def join(self, check_every_n_mins: float = 0.1) -> None: n = len(self.tasks) - states = ["CREATED", "VALIDATED", "SCHEDULED", "LAUNCHED", "EXECUTING", "CANCELING", "CANCELED", "COMPLETED", "FAILED"] - def is_terminal_state(x): return x in ["CANCELED", "COMPLETED", "FAILED"] try: with Progress() as progress: pbars = [progress.add_task(f"{task.name} ({i+1}/{n})", total=None) for i, task in enumerate(self.tasks)] @@ -106,11 +119,9 @@ class Plan(pydantic.BaseModel): while not all(completed): statuses = self.__stage_status__() for i, (task, status) in enumerate(zip(self.tasks, statuses)): - state = status.state - state_text = states[state] pbar = pbars[i] - progress.update(pbar, description=f"{task.name} ({i+1}/{n}): {state_text}") - if is_terminal_state(state_text): + progress.update(pbar, description=f"{task.name} ({i+1}/{n}): {status}") + if is_terminal_state(status): completed[i] = True progress.update(pbar, completed=True) sleep_time = check_every_n_mins * 60 @@ -119,18 +130,85 @@ class Plan(pydantic.BaseModel): except KeyboardInterrupt: print("Interrupted by user.") + def download(self, local_dir: str): + assert os.path.isdir(local_dir) + self.__stage_fetch__(local_dir) + def stop(self) -> None: self.__stage_stop__() + self.save() def save_json(self, filename: str) -> None: with open(filename, "w") as f: json.dump(self.model_dump(), f, indent=2) - @staticmethod - def load_json(filename: str) -> Plan: - with open(filename, "r") as f: - model = json.load(f) - return Plan(**model) - - def collect_results(self, runtime: Runtime) -> list[list[str]]: - return self.__stage_fetch__() + def save(self) -> None: + av = AiravataOperator(context.access_token) + az = av.__airavata_token__(av.access_token, av.default_gateway_id()) + assert az.accessToken is not None + assert az.claimsMap is not None + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + az.accessToken, + 'X-Claims': json.dumps(az.claimsMap) + } + import requests + if self.id is None: + self.id = str(uuid.uuid4()) + response = requests.post("https://api.gateway.cybershuttle.org/api/v1/plan", headers=headers, json=self.model_dump()) + print(f"Plan saved: {self.id}") + else: + response = requests.put(f"https://api.gateway.cybershuttle.org/api/v1/plan/{self.id}", headers=headers, json=self.model_dump()) + print(f"Plan updated: {self.id}") + + if response.status_code == 200: + body = response.json() + plan = json.loads(body["data"]) + assert plan["id"] == self.id + else: + raise Exception(response) + +def load_json(filename: str) -> Plan: + with open(filename, "r") as f: + model = json.load(f) + return Plan(**model) + +def load(id: str) -> Plan: + av = AiravataOperator(context.access_token) + az = av.__airavata_token__(av.access_token, av.default_gateway_id()) + assert az.accessToken is not None + assert az.claimsMap is not None + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + az.accessToken, + 'X-Claims': json.dumps(az.claimsMap) + } + import requests + response = requests.get(f"https://api.gateway.cybershuttle.org/api/v1/plan/{id}", headers=headers) + + if response.status_code == 200: + body = response.json() + plan = json.loads(body["data"]) + return Plan(**plan) + else: + raise Exception(response) + +def query() -> list[Plan]: + av = AiravataOperator(context.access_token) + az = av.__airavata_token__(av.access_token, av.default_gateway_id()) + assert az.accessToken is not None + assert az.claimsMap is not None + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + az.accessToken, + 'X-Claims': json.dumps(az.claimsMap) + } + import requests + response = requests.get(f"https://api.gateway.cybershuttle.org/api/v1/plan/user", headers=headers) + + if response.status_code == 200: + items: list = response.json() + plans = [json.loads(item["data"]) for item in items] + return [Plan(**plan) for plan in plans] + else: + raise Exception(response) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py index 23347b46e4..40ee36196f 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py @@ -13,18 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations from .auth import context import abc from typing import Any +from pathlib import Path import pydantic import requests import uuid import time +# from .task import Task Task = Any +def is_terminal_state(x): return x in ["CANCELED", "COMPLETED", "FAILED"] + conn_svc_url = "api.gateway.cybershuttle.org" @@ -37,6 +41,9 @@ class Runtime(abc.ABC, pydantic.BaseModel): @abc.abstractmethod def execute(self, task: Task) -> None: ... + @abc.abstractmethod + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: ... + @abc.abstractmethod def status(self, task: Task) -> str: ... @@ -47,7 +54,13 @@ class Runtime(abc.ABC, pydantic.BaseModel): def ls(self, task: Task) -> list[str]: ... @abc.abstractmethod - def download(self, file: str, task: Task) -> str: ... + def upload(self, file: Path, task: Task) -> str: ... + + @abc.abstractmethod + def download(self, file: str, local_dir: str, task: Task) -> str: ... + + @abc.abstractmethod + def cat(self, file: str, task: Task) -> bytes: ... def __str__(self) -> str: return f"{self.__class__.__name__}(args={self.args})" @@ -87,6 +100,9 @@ class Mock(Runtime): task.agent_ref = str(uuid.uuid4()) task.ref = str(uuid.uuid4()) + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: + pass + def status(self, task: Task) -> str: import random @@ -99,11 +115,17 @@ class Mock(Runtime): pass def ls(self, task: Task) -> list[str]: - return [] + return [""] - def download(self, file: str, task: Task) -> str: + def upload(self, file: Path, task: Task) -> str: return "" + def download(self, file: str, local_dir: str, task: Task) -> str: + return "" + + def cat(self, file: str, task: Task) -> bytes: + return b"" + @staticmethod def default(): return Mock() @@ -115,7 +137,6 @@ class Remote(Runtime): super().__init__(id="remote", args=kwargs) def execute(self, task: Task) -> None: - assert context.access_token is not None assert task.ref is None assert task.agent_ref is None @@ -124,16 +145,45 @@ class Remote(Runtime): print(f"[Remote] Experiment Created: name={task.name}") assert "cluster" in self.args task.agent_ref = str(uuid.uuid4()) - task.ref = av.launch_experiment( + launch_state = av.launch_experiment( experiment_name=task.name, app_name=task.app_id, computation_resource_name=str(self.args["cluster"]), inputs={**task.inputs, "agent_id": task.agent_ref, "server_url": conn_svc_url} ) + task.ref = launch_state.experiment_id + task.workdir = launch_state.experiment_dir + task.sr_host = launch_state.sr_host print(f"[Remote] Experiment Launched: id={task.ref}") + def execute_py(self, libraries: list[str], code: str, task: Task) -> None: + print(f"* Packages: {libraries}") + print(f"* Code:\n{code}") + try: + res = requests.post(f"https://{conn_svc_url}/api/v1/agent/executepythonrequest", json={ + "libraries": libraries, + "code": code, + "pythonVersion": "3.11", # TODO verify + "keepAlive": False, # TODO verify + "parentExperimentId": task.ref, + "agentId": task.agent_ref, + }) + data = res.json() + if data["error"] is not None: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"https://{conn_svc_url}/api/v1/agent/executepythonresponse/{exc_id}") + data = res.json() + if data["available"]: + files = data["responseString"].split("\n") + return files + time.sleep(1) + except Exception as e: + print(f"\nRemote execution failed! {e}") + def status(self, task: Task): - assert context.access_token is not None assert task.ref is not None assert task.agent_ref is not None @@ -143,18 +193,21 @@ class Remote(Runtime): return status def signal(self, signal: str, task: Task) -> None: - assert context.access_token is not None assert task.ref is not None assert task.agent_ref is not None from .airavata import AiravataOperator av = AiravataOperator(context.access_token) - status = av.stop_experiment(task.ref) + av.stop_experiment(task.ref) def ls(self, task: Task) -> list[str]: - assert context.access_token is not None assert task.ref is not None assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None + + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) res = requests.post(f"https://{conn_svc_url}/api/v1/agent/executecommandrequest", json={ "agentId": task.agent_ref, @@ -164,8 +217,7 @@ class Remote(Runtime): data = res.json() if data["error"] is not None: if str(data["error"]) == "Agent not found": - print("Experiment is initializing...") - return [] + return av.list_files(task.sr_host, task.workdir) else: raise Exception(data["error"]) else: @@ -178,12 +230,15 @@ class Remote(Runtime): return files time.sleep(1) - def download(self, file: str, task: Task) -> str: - assert context.access_token is not None + def upload(self, file: Path, task: Task) -> str: assert task.ref is not None assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None import os + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) res = requests.post(f"https://{conn_svc_url}/api/v1/agent/executecommandrequest", json={ "agentId": task.agent_ref, @@ -192,7 +247,10 @@ class Remote(Runtime): }) data = res.json() if data["error"] is not None: - raise Exception(data["error"]) + if str(data["error"]) == "Agent not found": + return av.upload_files(task.sr_host, [file], task.workdir).pop() + else: + raise Exception(data["error"]) else: exc_id = data["executionId"] while True: @@ -203,6 +261,71 @@ class Remote(Runtime): return files time.sleep(1) + def download(self, file: str, local_dir: str, task: Task) -> str: + assert task.ref is not None + assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None + + import os + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + + res = requests.post(f"https://{conn_svc_url}/api/v1/agent/executecommandrequest", json={ + "agentId": task.agent_ref, + "workingDir": ".", + "arguments": ["cat", os.path.join("/data", file)] + }) + data = res.json() + if data["error"] is not None: + if str(data["error"]) == "Agent not found": + return av.download_file(task.sr_host, os.path.join(task.workdir, file), local_dir) + else: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"https://{conn_svc_url}/api/v1/agent/executecommandresponse/{exc_id}") + data = res.json() + if data["available"]: + content = data["responseString"] + path = Path(local_dir) / Path(file).name + with open(path, "w") as f: + f.write(content) + return path.as_posix() + time.sleep(1) + + def cat(self, file: str, task: Task) -> bytes: + assert task.ref is not None + assert task.agent_ref is not None + assert task.sr_host is not None + assert task.workdir is not None + + import os + from .airavata import AiravataOperator + av = AiravataOperator(context.access_token) + + res = requests.post(f"https://{conn_svc_url}/api/v1/agent/executecommandrequest", json={ + "agentId": task.agent_ref, + "workingDir": ".", + "arguments": ["cat", os.path.join("/data", file)] + }) + data = res.json() + if data["error"] is not None: + if str(data["error"]) == "Agent not found": + return av.cat_file(task.sr_host, os.path.join(task.workdir, file)) + else: + raise Exception(data["error"]) + else: + exc_id = data["executionId"] + while True: + res = requests.get(f"https://{conn_svc_url}/api/v1/agent/executecommandresponse/{exc_id}") + data = res.json() + if data["available"]: + content = str(data["responseString"]).encode() + return content + time.sleep(1) + @staticmethod def default(): return Remote( diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/scripter.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/scripter.py new file mode 100644 index 0000000000..76aa874ddb --- /dev/null +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/scripter.py @@ -0,0 +1,144 @@ +import inspect +import ast +import textwrap +import sys + + +def scriptize(func): + # Get the source code of the decorated function + source_code = textwrap.dedent(inspect.getsource(func)) + func_tree = ast.parse(source_code) + + # Retrieve the module where the function is defined + module_name = func.__module__ + if module_name in sys.modules: + module = sys.modules[module_name] + else: + raise RuntimeError(f"Cannot find module {module_name} for function {func.__name__}") + + # Attempt to get the module source. + # If this fails (e.g., in a Jupyter notebook), fallback to an empty module tree. + try: + module_source = textwrap.dedent(inspect.getsource(module)) + module_tree = ast.parse(module_source) + except (TypeError, OSError): + # In Jupyter (or certain environments), we can't get the module source this way. + # Use an empty module tree as a fallback. + module_tree = ast.parse("") + + # Find the function definition node + func_def = next( + (node for node in func_tree.body if isinstance(node, ast.FunctionDef)), None) + if not func_def: + raise ValueError("No function definition found in func_tree.") + + # ---- NEW: Identify used names in the function body ---- + # We'll walk the function body to collect all names used. + class NameCollector(ast.NodeVisitor): + def __init__(self): + self.used_names = set() + + def visit_Name(self, node): + self.used_names.add(node.id) + self.generic_visit(node) + + def visit_Attribute(self, node): + # This accounts for usage like time.sleep (attribute access) + # We add 'time' if we see something like time.sleep + # The top-level name is usually in node.value + if isinstance(node.value, ast.Name): + self.used_names.add(node.value.id) + self.generic_visit(node) + + name_collector = NameCollector() + name_collector.visit(func_def) + used_names = name_collector.used_names + + # For imports, we need to consider a few cases: + # - `import module` + # - `import module as alias` + # - `from module import name` + # We'll keep an import if it introduces at least one name or module referenced by the function. + def is_import_used(import_node): + + if isinstance(import_node, ast.Import): + # import something [as alias] + for alias in import_node.names: + # If we have something like `import time` and "time" is used, + # or `import pandas as pd` and "pd" is used, keep it. + if alias.asname and alias.asname in used_names: + return True + if alias.name.split('.')[0] in used_names: + return True + return False + elif isinstance(import_node, ast.ImportFrom): + # from module import name(s) + # Keep if any of the imported names or their asnames are used + for alias in import_node.names: + # Special case: if we have `from module import task_context`, ignore it + if alias.name == "task_context": + return False + # If from module import x as y, check y; else check x + if alias.asname and alias.asname in used_names: + return True + if alias.name in used_names: + return True + # Another subtlety: if we have `from time import sleep` + # and we call `time.sleep()` is that detected? + # Actually, we already caught attribute usage above, which would add "time" to used_names + # but not "sleep". If the code does `sleep(n)` directly, then "sleep" is in used_names. + return False + return False + + # For other functions, include only if their name is referenced. + def is_function_used(func_node): + return func_node.name in used_names + + def wrapper(*args, **kwargs): + # Bind arguments + func_signature = inspect.signature(func) + bound_args = func_signature.bind(*args, **kwargs) + bound_args.apply_defaults() + + # Convert the original function body to source + body_source_lines = [ast.unparse(stmt) for stmt in func_def.body] + body_source_code = "\n".join(body_source_lines) + + # Collect relevant code blocks: + relevant_code_blocks = [] + for node in module_tree.body: + if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom): + # Include only used imports + if is_import_used(node): + relevant_code_blocks.append(ast.unparse(node).strip()) + elif isinstance(node, ast.FunctionDef): + # Include only used functions, excluding the decorator itself and the decorated function + if node.name not in ('task_context', func.__name__) and is_function_used(node): + func_code = ast.unparse(node).strip() + relevant_code_blocks.append(func_code) + + # Prepare argument assignments + arg_assignments = [] + for arg_name, arg_value in bound_args.arguments.items(): + # Stringify arguments as before + if isinstance(arg_value, str): + arg_assignments.append(f"{arg_name} = {arg_value!r}") + else: + arg_assignments.append(f"{arg_name} = {repr(arg_value)}") + + # Combine everything + combined_code_parts = [] + if relevant_code_blocks: + combined_code_parts.append("\n\n".join(relevant_code_blocks)) + if arg_assignments: + if combined_code_parts: + combined_code_parts.append("") # blank line before args + combined_code_parts.extend(arg_assignments) + if arg_assignments: + combined_code_parts.append("") # blank line before body + combined_code_parts.append(body_source_code) + + combined_code = "\n".join(combined_code_parts).strip() + return combined_code + + return wrapper diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/sftp.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/sftp.py index 15d66d1790..18e72a1d15 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/sftp.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/sftp.py @@ -20,7 +20,6 @@ from datetime import datetime from rich.progress import Progress import paramiko -from paramiko import SFTPClient, Transport from scp import SCPClient logger = logging.getLogger(__name__) @@ -36,61 +35,28 @@ def create_pkey(pkey_path): class SFTPConnector(object): - def __init__(self, host, port, username, password=None, pkey=None): + def __init__(self, host: str, port: int, username: str, password: str | None = None, pkey: str | None = None): self.host = host self.port = port self.username = username self.password = password - self.pkey = pkey ssh = paramiko.SSHClient() self.ssh = ssh - # self.sftp = paramiko.SFTPClient() - # Trust all key policy on remote host - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - def upload_files(self, localpaths: list[Path], remote_path: str): - transport = Transport(sock=(self.host, int(self.port))) - if self.pkey is not None: - transport.connect(username=self.username, pkey=create_pkey(self.pkey)) - else: - transport.connect(username=self.username, password=self.password) - try: - with Progress() as progress: - task = progress.add_task("Uploading...", total=len(localpaths)-1) - for file in localpaths: - connection = SFTPClient.from_transport(transport) - assert connection is not None - try: - connection.lstat(remote_path) # Test if remote_path exists - except IOError: - connection.mkdir(remote_path) - remote_fpath = remote_path + "/" + file.name - connection.put(file, remote_fpath) - progress.update(task, advance=1, description=f"Uploading: {file.name}") - progress.update(task, completed=True) - finally: - transport.close() - return remote_path - - def make_experiment_dir(self, project_name: str, exprement_id, remote_base=""): + def mkdir(self, project_name: str, exprement_id: str): project_name = project_name.replace(" ", "_") time = datetime.now().strftime("%Y-%m-%d %H:%M:%S").replace(" ", "_") time = time.replace(":", "_") time = time.replace("-", "_") exprement_id = exprement_id + "_" + time - base_path = remote_base + "/" + project_name + base_path = "/" + project_name remote_path = base_path + "/" + exprement_id - transport = Transport(sock=(self.host, int(self.port))) - if self.pkey is not None: - transport.connect(username=self.username, - pkey=create_pkey(self.pkey)) - else: - transport.connect(username=self.username, password=self.password) - + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) try: - connection = SFTPClient.from_transport(transport) + connection = paramiko.SFTPClient.from_transport(transport) assert connection is not None try: connection.lstat(base_path) # Test if remote_path exists @@ -102,25 +68,55 @@ class SFTPConnector(object): connection.mkdir(remote_path) finally: transport.close() - return remote_path - def download_files(self, local_path, remote_path): - if self.pkey is not None: - self.ssh.connect(self.host, self.port, self.username, - pkey=create_pkey(self.pkey)) - else: - self.ssh.connect(self.host, self.port, - self.username, password=self.password) - - transport = self.ssh.get_transport() - assert transport is not None + def put(self, local_paths: list[Path], remote_path: str) -> list[str]: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + remote_paths = [] + try: + with Progress() as progress: + task = progress.add_task("Uploading...", total=len(local_paths)-1) + for file in local_paths: + connection = paramiko.SFTPClient.from_transport(transport) + assert connection is not None + try: + connection.lstat(remote_path) # Test if remote_path exists + except IOError: + connection.mkdir(remote_path) + remote_fpath = remote_path + "/" + file.name + connection.put(file, remote_fpath) + remote_paths.append(remote_fpath) + progress.update(task, advance=1, description=f"Uploading: {file.name}") + progress.update(task, completed=True) + finally: + transport.close() + return remote_paths + + def ls(self, remote_path: str) -> list[str]: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + try: + connection = paramiko.SFTPClient.from_transport(transport) + assert connection is not None + files = connection.listdir(remote_path) + finally: + transport.close() + return files + + def get(self, remote_path: str, local_path: str) -> str: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) with SCPClient(transport) as conn: - conn.get(remote_path=remote_path, - local_path=local_path, recursive=True) + conn.get(remote_path, local_path, recursive=True) self.ssh.close() - - @staticmethod - def uploading_info(uploaded_file_size, total_file_size): - logging.info("uploaded_file_size : {} total_file_size : {}".format( - uploaded_file_size, total_file_size)) + return (Path(local_path) / Path(remote_path).name).as_posix() + + def cat(self, remote_path: str) -> bytes: + transport = paramiko.Transport(sock=(self.host, int(self.port))) + transport.connect(username=self.username, password=self.password) + sftp = paramiko.SFTPClient.from_transport(transport) + assert sftp is not None + with sftp.open(remote_path, "r") as f: + content = f.read() + return content diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py index d4290b91ee..6d67c2dcf3 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from __future__ import annotations from typing import Any - import pydantic - from .runtime import Runtime class Task(pydantic.BaseModel): @@ -28,6 +26,8 @@ class Task(pydantic.BaseModel): runtime: Runtime ref: str | None = pydantic.Field(default=None) agent_ref: str | None = pydantic.Field(default=None) + workdir: str | None = pydantic.Field(default=None) + sr_host: str | None = pydantic.Field(default=None) @pydantic.field_validator("runtime", mode="before") def set_runtime(cls, v): @@ -38,7 +38,7 @@ class Task(pydantic.BaseModel): return v def __str__(self) -> str: - return f"Task(\nname={self.name}\napp_id={self.app_id}\ninputs={self.inputs}\nruntime={self.runtime}\n)" + return f"Task(\nname={self.name}\napp_id={self.app_id}\ninputs={self.inputs}\nruntime={self.runtime}\nref={self.ref}\nagent_ref={self.agent_ref}\nfile_path={self.sr_host}:{self.workdir}\n)" def launch(self) -> None: assert self.ref is None @@ -49,14 +49,34 @@ class Task(pydantic.BaseModel): assert self.ref is not None return self.runtime.status(self) - def files(self) -> list[str]: + def ls(self) -> list[str]: assert self.ref is not None return self.runtime.ls(self) - def cat(self, file: str) -> str: + def upload(self, file: str) -> str: + assert self.ref is not None + from pathlib import Path + return self.runtime.upload(Path(file), self) + + def download(self, file: str, local_dir: str) -> str: + assert self.ref is not None + from pathlib import Path + Path(local_dir).mkdir(parents=True, exist_ok=True) + return self.runtime.download(file, local_dir, self) + + def cat(self, file: str) -> bytes: assert self.ref is not None - return self.runtime.download(file, self) + return self.runtime.cat(file, self) def stop(self) -> None: assert self.ref is not None return self.runtime.signal("SIGTERM", self) + + def context(self, packages: list[str]) -> Any: + def decorator(func): + def wrapper(*args, **kwargs): + from .scripter import scriptize + make_script = scriptize(func) + return self.runtime.execute_py(packages, make_script(*args, **kwargs), self) + return wrapper + return decorator diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py index c6a10132fd..1c120c10d3 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py @@ -29,9 +29,12 @@ logger.setLevel(logging.DEBUG) class APIServerClient(object): - def __init__(self, configuration_file_location=None): - self.api_server_settings = APIServerClientSettings(configuration_file_location) - self._load_settings(configuration_file_location) + def __init__(self, configuration_file_location=None, api_server_settings=None): + if configuration_file_location is not None: + self.api_server_settings = APIServerClientSettings(configuration_file_location) + self._load_settings(configuration_file_location) + elif api_server_settings is not None: + self.api_server_settings = api_server_settings self.api_server_client_pool = utils.initialize_api_client_pool(self.api_server_settings.API_SERVER_HOST, self.api_server_settings.API_SERVER_PORT, self.api_server_settings.API_SERVER_SECURE) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml index 79ba8eedbc..5a9c418151 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "airavata-python-sdk-test" -version = "0.0.3" +version = "0.0.5.post1" description = "Apache Airavata Python SDK" readme = "README.md" license = { text = "Apache License 2.0" } @@ -14,7 +14,7 @@ dependencies = [ "oauthlib", "requests", "requests-oauthlib", - "thrift~=0.16.0", + "thrift~=0.21.0", "thrift_connector", "paramiko", "scp", diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/annotations.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/annotations.py deleted file mode 100644 index 2c368b69f3..0000000000 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/annotations.py +++ /dev/null @@ -1,70 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import time -import inspect - -# Define the decorator factory -def analyze_replica(plan_id = 0, replica_id=1): - def decorator(func): - def wrapper(*args, **kwargs): - # Filter loaded functions and classes - loaded_functions = {name: obj for name, obj in globals().items() if inspect.isfunction(obj)} - loaded_classes = {name: obj for name, obj in globals().items() if inspect.isclass(obj)} - - print("Plan id ", plan_id, " Replica id ", replica_id) - print("Passed function") - print(inspect.getsource(func)) - - print("Functions loaded in session:") - for name, f in loaded_functions.items(): - # Skip the wrapper itself and the decorator function - if name in ['execution_timer', 'decorator', 'wrapper']: - continue - print(f"- {name}:") - print(inspect.getsource(f)) - - print("\nClasses loaded in session:") - for name, cls in loaded_classes.items(): - print(f"- {name}:") - print(inspect.getsource(cls)) - - # Find the - # Call the original function - result = func(*args, **kwargs) - return result # Return the result of the original function - return wrapper - return decorator - - -# Example standalone function -def print_something(): - print("PRINTING SOMETHING") - -def print_some_int(integer= 10): - print("PRINTING SOMETHING ", integer) - -# Apply the decorator with a parameter -@analyze_replica(plan_id = 100, replica_id=10110) -def example_function(n): - time.sleep(n) # Simulate a delay - print_something() - return f"Function ran for {n} seconds." - - -# Call the decorated function -print(example_function(2)) \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb b/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb index eb01e821b6..a44d17507a 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb @@ -23,80 +23,11 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Obtaining file:///Users/yasith/projects/artisan/airavata/airavata-api/airavata-client-sdks/airavata-python-sdk\n", - " Installing build dependencies ... \u001b[?25ldone\n", - "\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n", - "\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n", - "\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n", - "\u001b[?25hRequirement already satisfied: oauthlib in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (3.2.2)\n", - "Requirement already satisfied: requests in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (2.32.3)\n", - "Requirement already satisfied: requests-oauthlib in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (2.0.0)\n", - "Requirement already satisfied: thrift~=0.16.0 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (0.16.0)\n", - "Requirement already satisfied: thrift_connector in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (0.24)\n", - "Requirement already satisfied: paramiko in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (3.5.0)\n", - "Requirement already satisfied: scp in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (0.15.0)\n", - "Requirement already satisfied: pysftp in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (0.2.9)\n", - "Requirement already satisfied: configparser in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (7.1.0)\n", - "Requirement already satisfied: urllib3 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (2.2.3)\n", - "Requirement already satisfied: pyjwt in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (2.10.1)\n", - "Requirement already satisfied: pydantic in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (2.10.3)\n", - "Requirement already satisfied: rich in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (13.9.4)\n", - "Requirement already satisfied: ipywidgets in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from airavata-python-sdk-test==0.0.2) (8.1.5)\n", - "Requirement already satisfied: six>=1.7.2 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from thrift~=0.16.0->airavata-python-sdk-test==0.0.2) (1.16.0)\n", - "Requirement already satisfied: comm>=0.1.3 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipywidgets->airavata-python-sdk-test==0.0.2) (0.2.2)\n", - "Requirement already satisfied: ipython>=6.1.0 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipywidgets->airavata-python-sdk-test==0.0.2) (8.30.0)\n", - "Requirement already satisfied: traitlets>=4.3.1 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipywidgets->airavata-python-sdk-test==0.0.2) (5.14.3)\n", - "Requirement already satisfied: widgetsnbextension~=4.0.12 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipywidgets->airavata-python-sdk-test==0.0.2) (4.0.13)\n", - "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipywidgets->airavata-python-sdk-test==0.0.2) (3.0.13)\n", - "Requirement already satisfied: bcrypt>=3.2 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from paramiko->airavata-python-sdk-test==0.0.2) (4.2.1)\n", - "Requirement already satisfied: cryptography>=3.3 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from paramiko->airavata-python-sdk-test==0.0.2) (44.0.0)\n", - "Requirement already satisfied: pynacl>=1.5 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from paramiko->airavata-python-sdk-test==0.0.2) (1.5.0)\n", - "Requirement already satisfied: annotated-types>=0.6.0 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from pydantic->airavata-python-sdk-test==0.0.2) (0.7.0)\n", - "Requirement already satisfied: pydantic-core==2.27.1 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from pydantic->airavata-python-sdk-test==0.0.2) (2.27.1)\n", - "Requirement already satisfied: typing-extensions>=4.12.2 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from pydantic->airavata-python-sdk-test==0.0.2) (4.12.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from requests->airavata-python-sdk-test==0.0.2) (3.4.0)\n", - "Requirement already satisfied: idna<4,>=2.5 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from requests->airavata-python-sdk-test==0.0.2) (3.10)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from requests->airavata-python-sdk-test==0.0.2) (2024.8.30)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from rich->airavata-python-sdk-test==0.0.2) (3.0.0)\n", - "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from rich->airavata-python-sdk-test==0.0.2) (2.18.0)\n", - "Requirement already satisfied: cffi>=1.12 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from cryptography>=3.3->paramiko->airavata-python-sdk-test==0.0.2) (1.17.1)\n", - "Requirement already satisfied: decorator in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (5.1.1)\n", - "Requirement already satisfied: jedi>=0.16 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (0.19.2)\n", - "Requirement already satisfied: matplotlib-inline in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (0.1.7)\n", - "Requirement already satisfied: pexpect>4.3 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (4.9.0)\n", - "Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (3.0.48)\n", - "Requirement already satisfied: stack_data in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (0.6.2)\n", - "Requirement already satisfied: mdurl~=0.1 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from markdown-it-py>=2.2.0->rich->airavata-python-sdk-test==0.0.2) (0.1.2)\n", - "Requirement already satisfied: pycparser in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from cffi>=1.12->cryptography>=3.3->paramiko->airavata-python-sdk-test==0.0.2) (2.22)\n", - "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (0.8.4)\n", - "Requirement already satisfied: ptyprocess>=0.5 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (0.7.0)\n", - "Requirement already satisfied: wcwidth in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (0.2.13)\n", - "Requirement already satisfied: executing>=1.2.0 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from stack_data->ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (2.1.0)\n", - "Requirement already satisfied: asttokens>=2.1.0 in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from stack_data->ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (3.0.0)\n", - "Requirement already satisfied: pure-eval in /Users/yasith/.mamba/envs/airavata/lib/python3.12/site-packages (from stack_data->ipython>=6.1.0->ipywidgets->airavata-python-sdk-test==0.0.2) (0.2.3)\n", - "Building wheels for collected packages: airavata-python-sdk-test\n", - " Building editable for airavata-python-sdk-test (pyproject.toml) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for airavata-python-sdk-test: filename=airavata_python_sdk_test-0.0.2-0.editable-py3-none-any.whl size=11284 sha256=c3d58cfa6d1cd393fa9ff8ff597e416e466721170a37adcd4b3429d39076dc3e\n", - " Stored in directory: /private/var/folders/_n/fcf6nx4j67gbbt4_8mjqxdc80000gn/T/pip-ephem-wheel-cache-srerellb/wheels/6a/64/3a/ba5bbd28958f1b9f1f2d15d2c8999c899e17c402760ebd7d24\n", - "Successfully built airavata-python-sdk-test\n", - "Installing collected packages: airavata-python-sdk-test\n", - " Attempting uninstall: airavata-python-sdk-test\n", - " Found existing installation: airavata-python-sdk-test 0.0.2\n", - " Uninstalling airavata-python-sdk-test-0.0.2:\n", - " Successfully uninstalled airavata-python-sdk-test-0.0.2\n", - "Successfully installed airavata-python-sdk-test-0.0.2\n", - "Note: you may need to restart the kernel to use updated packages.\n", - "/Users/yasith/projects/artisan/airavata/airavata-api/airavata-client-sdks/airavata-python-sdk/samples\n" - ] - } - ], + "outputs": [], "source": [ - "%pip install --upgrade airavata-python-sdk-test" + "%pip uninstall -y airavata-python-sdk-test\n", + "%pip cache purge\n", + "%pip install -e airavata-api/airavata-client-sdks/airavata-python-sdk" ] }, { @@ -110,8 +41,24 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/Users/yasith/projects/artisan/airavata/airavata-api/airavata-client-sdks/airavata-python-sdk/samples\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "using legacy validation callback\n" + ] + } + ], "source": [ + "%cd airavata-api/airavata-client-sdks/airavata-python-sdk/samples\n", "import airavata_experiments as ae\n", "from airavata_experiments import md" ] @@ -128,17 +75,9 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using saved token\n" - ] - } - ], + "outputs": [], "source": [ "ae.login()" ] @@ -270,8 +209,7 @@ "outputs": [], "source": [ "plan = exp.plan() # this will create a plan for the experiment\n", - "plan.describe() # this will describe the plan\n", - "plan.save_json(\"plan.json\") # save the plan state" + "plan.describe() # this will describe the plan" ] }, { @@ -287,16 +225,16 @@ "metadata": {}, "outputs": [], "source": [ - "plan = ae.load_plan(\"plan.json\")\n", - "plan.launch()\n", - "plan.save_json(\"plan.json\") # save the plan state" + "plan.save() # this will save the plan in DB\n", + "plan.launch() # this will launch the plan\n", + "plan.save_json(\"plan.json\") # this will save the plan locally" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Option A - Wait for Completion" + "## Load and Describe the Launched Plan" ] }, { @@ -305,49 +243,190 @@ "metadata": {}, "outputs": [], "source": [ - "plan = ae.load_plan(\"plan.json\")\n", + "assert plan.id is not None\n", + "plan = ae.plan.load(plan.id)\n", "plan.describe()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## List all Plans the User Created" + ] + }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fc07d00c22d04fc2b7d4eb2235fe810b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, { "data": { "text/html": [ - "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>id</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>16781b12-fd99-496c-b815-c0fdc5889664</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>2a09f1c4-8a0a-46e4-bbdd-ffab13be3d5b</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>2a7896fa-5898-42f6-92b6-c053e4a702ba</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>2e206dab-ada7-45a6-a2ea-940adf9ef646</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>4fb8a73b-8333-4c73-8e74-5dd103f8a22f</td>\n", + " </tr>\n", + " <tr>\n", + " <th>5</th>\n", + " <td>5197d68c-63ec-4d13-bac5-24484e1d0ca6</td>\n", + " </tr>\n", + " <tr>\n", + " <th>6</th>\n", + " <td>54b9dcd6-a5e8-4a05-9690-aacd346de55c</td>\n", + " </tr>\n", + " <tr>\n", + " <th>7</th>\n", + " <td>562a195e-83f9-4de4-af5b-c43b4a2a40f6</td>\n", + " </tr>\n", + " <tr>\n", + " <th>8</th>\n", + " <td>768d97d5-233b-4450-a7e3-4df31f1fac3c</td>\n", + " </tr>\n", + " <tr>\n", + " <th>9</th>\n", + " <td>82814692-63fa-48e1-9e26-78b75269f513</td>\n", + " </tr>\n", + " <tr>\n", + " <th>10</th>\n", + " <td>ae70b7d2-294e-44c1-b2d7-8586642e241e</td>\n", + " </tr>\n", + " <tr>\n", + " <th>11</th>\n", + " <td>af3a2094-5bb3-4452-a9c3-45451bfd23cb</td>\n", + " </tr>\n", + " <tr>\n", + " <th>12</th>\n", + " <td>b82ae820-93bc-4e26-b080-2563824a1c5b</td>\n", + " </tr>\n", + " <tr>\n", + " <th>13</th>\n", + " <td>c51d01a2-4b57-47c7-a4d2-91a8ede53c77</td>\n", + " </tr>\n", + " <tr>\n", + " <th>14</th>\n", + " <td>d5db8cc0-76da-4435-9509-3a5733c41d7e</td>\n", + " </tr>\n", + " <tr>\n", + " <th>15</th>\n", + " <td>d6e5e9f0-dc11-4262-b16a-51fef7be42c1</td>\n", + " </tr>\n", + " <tr>\n", + " <th>16</th>\n", + " <td>d763f645-2f8f-460e-9e83-4a98365508eb</td>\n", + " </tr>\n", + " <tr>\n", + " <th>17</th>\n", + " <td>eff68e72-c585-4066-a8a0-d36cc45f648c</td>\n", + " </tr>\n", + " <tr>\n", + " <th>18</th>\n", + " <td>fcc54603-aa0b-4ca7-89ac-c04d0725f4cb</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" ], - "text/plain": [] + "text/plain": [ + " id\n", + "0 16781b12-fd99-496c-b815-c0fdc5889664\n", + "1 2a09f1c4-8a0a-46e4-bbdd-ffab13be3d5b\n", + "2 2a7896fa-5898-42f6-92b6-c053e4a702ba\n", + "3 2e206dab-ada7-45a6-a2ea-940adf9ef646\n", + "4 4fb8a73b-8333-4c73-8e74-5dd103f8a22f\n", + "5 5197d68c-63ec-4d13-bac5-24484e1d0ca6\n", + "6 54b9dcd6-a5e8-4a05-9690-aacd346de55c\n", + "7 562a195e-83f9-4de4-af5b-c43b4a2a40f6\n", + "8 768d97d5-233b-4450-a7e3-4df31f1fac3c\n", + "9 82814692-63fa-48e1-9e26-78b75269f513\n", + "10 ae70b7d2-294e-44c1-b2d7-8586642e241e\n", + "11 af3a2094-5bb3-4452-a9c3-45451bfd23cb\n", + "12 b82ae820-93bc-4e26-b080-2563824a1c5b\n", + "13 c51d01a2-4b57-47c7-a4d2-91a8ede53c77\n", + "14 d5db8cc0-76da-4435-9509-3a5733c41d7e\n", + "15 d6e5e9f0-dc11-4262-b16a-51fef7be42c1\n", + "16 d763f645-2f8f-460e-9e83-4a98365508eb\n", + "17 eff68e72-c585-4066-a8a0-d36cc45f648c\n", + "18 fcc54603-aa0b-4ca7-89ac-c04d0725f4cb" + ] }, "metadata": {}, "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Interrupted by user.\n" - ] } ], "source": [ - "plan = ae.load_plan(\"plan.json\")\n", + "import pandas as pd\n", + "plans = ae.plan.query()\n", + "display(pd.DataFrame([plan.model_dump(include={\"id\"}) for plan in plans]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check Plan Status" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.status()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Block Until Plan Completes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "plan.join()" ] }, @@ -355,7 +434,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Option B - Terminate Execution" + "## Stop Plan Execution" ] }, { @@ -364,7 +443,6 @@ "metadata": {}, "outputs": [], "source": [ - "plan = ae.load_plan(\"plan.json\")\n", "plan.stop()" ] }, @@ -372,7 +450,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Option C - Monitor Files During Execution" + "## Run File Operations on Plan" ] }, { @@ -384,53 +462,19 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ExperimentStatus(state=4, timeOfStateChange=1733417291473, reason='process started', statusId='EXPERIMENT_STATE_451b84f8-b6e8-472c-84b6-23460a4ecbdf')\n", - "['1',\n", - " 'A1497186742',\n", - " 'NAMD.stderr',\n", - " 'NAMD.stdout',\n", - " 'NAMD_Repl_.dcd',\n", - " 'NAMD_Repl_1.out',\n", - " 'b4pull.pdb',\n", - " 'b4pull.restart.coor',\n", - " 'b4pull.restart.vel',\n", - " 'b4pull.restart.xsc',\n", - " 'job_1017394371.slurm',\n", - " 'par_all36_water.prm',\n", - " 'par_all36m_prot.prm',\n", - " 'pull.conf',\n", - " 'structure.pdb',\n", - " 'structure.psf',\n", - " '']\n" - ] - }, - { - "data": { - "text/plain": [ - "\"ExeTyp=GPU\\nPJobID=\\nrep_list=\\nnum_rep=1\\ninput=pull.conf\\nagent_id=e016b89f-5eef-4e8e-b4eb-d202942dc76d\\nserver_url=api.gateway.cybershuttle.org\\n The Airavata Gateway User is scigap\\n Namd run will use the input pull.conf\\nNo replica array \\nGPU executable will be used\\nLoading gpu modules\\nGPU Run Command is time -p mpirun --hostfile ./HostFile -np 10 namd3 +p10 pull.conf\\nlrwxrwxrwx 1 scigap ind123 103 Dec 5 08:48 35590436 -> /expanse/lustre/scratch/scigap/te [...] - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ - "plan = ae.load_plan(\"plan.json\")\n", - "from pprint import pprint\n", "for task in plan.tasks:\n", " status = task.status()\n", " print(status)\n", - " files = task.files()\n", - " pprint(files)\n", - "\n", - "display(plan.tasks[0].cat(\"NAMD.stdout\"))" + " # task.upload(\"data/sample.txt\")\n", + " files = task.ls()\n", + " display(files)\n", + " display(task.cat(\"NAMD.stderr\"))\n", + " # task.download(\"NAMD.stdout\", \"./results\")\n", + " task.download(\"NAMD_Repl_1.out\", \"./results\")" ] }, { @@ -446,19 +490,18 @@ "metadata": {}, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\n", - "import pandas as pd\n", - "\n", "for index, task in enumerate(plan.tasks):\n", "\n", - " @cs.task_context(task)\n", - " def visualize():\n", - " data = pd.read_csv(\"data.csv\")\n", - " plt.figure(figsize=(8, 6))\n", - " plt.plot(data[\"x\"], data[\"y\"], marker=\"o\", linestyle=\"-\", linewidth=2, markersize=6)\n", - " plt.title(f\"Plot for Replica {index} of {len(plan.tasks)}\")\n", + " @task.context(packages=[\"matplotlib\", \"pandas\"])\n", + " def analyze(x, y, index, num_tasks) -> None:\n", + " from matplotlib import pyplot as plt\n", + " import pandas as pd\n", + " df = pd.read_csv(\"data.csv\")\n", + " plt.figure(figsize=(x, y))\n", + " plt.plot(df[\"x\"], df[\"y\"], marker=\"o\", linestyle=\"-\", linewidth=2, markersize=6)\n", + " plt.title(f\"Plot for Replica {index} of {num_tasks}\")\n", "\n", - " visualize()" + " analyze(3, 4, index+1, len(plan.tasks))" ] } ],
