This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airavata.git
commit 83052b074dca49a90b94112ec3a48ed7880dd47b Author: yasith <[email protected]> AuthorDate: Thu Feb 20 18:53:33 2025 -0600 choose group/project from API, silent plan launch, add_replica() -> create_task(), display task state, planid and taskid, improve messaging --- .../airavata_experiments/__init__.py | 120 ++++++++++++++++++--- .../airavata_experiments/airavata.py | 88 ++++++++------- .../airavata_experiments/auth/device_auth.py | 16 +-- .../airavata_experiments/base.py | 12 +-- .../airavata_experiments/plan.py | 17 ++- .../airavata_experiments/runtime.py | 31 ++++-- .../airavata_experiments/task.py | 3 +- .../airavata_sdk/clients/api_server_client.py | 32 +++--- .../airavata-python-sdk/pyproject.toml | 4 +- .../jupyterhub/data/1_experiment_sdk.ipynb | 6 +- .../deployments/jupyterhub/data/smd_cpu.ipynb | 10 +- 11 files changed, 227 insertions(+), 112 deletions(-) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py index f45c820c4f..98391d4fa4 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/__init__.py @@ -19,10 +19,12 @@ from __future__ import annotations from . import base, plan from .auth import login, logout from .runtime import list_runtimes, Runtime +from typing import Any __all__ = ["login", "logout", "list_runtimes", "base", "plan"] -def display_runtimes(runtimes: list[Runtime]): + +def display_runtimes(runtimes: list[Runtime]) -> None: """ Display runtimes in a tabular format """ @@ -33,9 +35,11 @@ def display_runtimes(runtimes: list[Runtime]): record = dict(id=runtime.id, **runtime.args) records.append(record) - return pd.DataFrame(records) + d = get_display_fn() + d(pd.DataFrame(records).set_index("id")) + -def display_experiments(experiments: list[base.Experiment]): +def display_experiments(experiments: list[base.Experiment]) -> None: """ Display experiments in a tabular format """ @@ -48,23 +52,113 @@ def display_experiments(experiments: list[base.Experiment]): record[k] = ", ".join(v) if isinstance(v, list) else str(v) records.append(record) - return pd.DataFrame(records) + d = get_display_fn() + d(pd.DataFrame(records).set_index("name")) + -def display_plans(plans: list[plan.Plan]): +def display_plans(plans: list[plan.Plan]) -> None: """ Display plans in a tabular format """ - import pandas as pd + + from IPython.display import HTML - records = [] + html = """ + <table border='1'> + <tr> + <th><b>Plan Id</b></th> + <th><b>Task Id</b></th> + <th><b>State</b></th> + <th><b>Name</b></th> + <th><b>App</b></th> + <th><b>Inputs</b></th> + <th><b>Cluster</b></th> + <th><b>Group</b></th> + <th><b>Category</b></th> + <th><b>Queue</b></th> + <th><b>Nodes</b></th> + <th><b>CPUs</b></th> + <th><b>GPUs</b></th> + <th><b>Time</b></th> + </tr> + """ + for plan in plans: for task in plan.tasks: - record = dict(plan_id=str(plan.id)) - for k, v in task.model_dump().items(): - record[k] = ", ".join(v) if isinstance(v, list) else str(v) - records.append(record) + html += generate_task_row(plan.id or "N/A", task) + html += "<tr></tr>" - return pd.DataFrame(records) + html += "</table>" + + script = """ + <script type="text/javascript"> + document.querySelectorAll('.hover-container').forEach(item => { + item.addEventListener('mouseover', () => item.querySelector('.hover-content').style.visibility = 'visible'); + item.addEventListener('mouseout', () => item.querySelector('.hover-content').style.visibility = 'hidden'); + }); + </script> + """ + + d = get_display_fn() + d(HTML(html + script)) + + +def generate_task_row(plan_id: str, task: plan.Task): + """ + Generate a row for the task + """ + + task_id, task_state = task.status() if task.ref else ("N/A", "N/A") + + return f""" + <tr> + <td>{plan_id[:8]}...</td> + <td>{task_id}</td> + <td>{task_state}</td> + <td>{task.name}</td> + <td>{task.app_id}</td> + <td> + <div style="position:relative;" class="hover-container"> + [Hover] + <div class="hover-content" style="visibility:hidden; position:absolute; top:20px; left:0; background:white; border:1px solid black; padding:5px; z-index:10;"> + {generate_inputs_table(task.inputs)} + </div> + </div> + </td> + <td>{task.runtime.args.get("cluster", "N/A")}</td> + <td>{task.runtime.args.get("group", "N/A")}</td> + <td>{task.runtime.args.get("category", "N/A")}</td> + <td>{task.runtime.args.get("queue_name", "N/A")}</td> + <td>{task.runtime.args.get("node_count", "N/A")}</td> + <td>{task.runtime.args.get("cpu_count", "N/A")}</td> + <td>{task.runtime.args.get("gpu_count", "N/A")}</td> + <td>{task.runtime.args.get("walltime", "N/A")}</td> + </tr> + """ + + +def generate_inputs_table(inputs: dict): + + html = """ + <table border='1'> + <tr><th><b>Input</b></th><th><b>Type</b></th><th><b>Value</b></th></tr> + """ + for k, v in inputs.items(): + html += f"""<tr><th>{k}</th><th>{v.get("type")}</th><th>{v.get("value")}</th></tr>""" + html += "</table>" + return html + +def get_display_fn() -> Any: + try: + from IPython.core.getipython import get_ipython + from IPython.display import display as d + if get_ipython() is not None and d is not None: + return d + else: + raise Exception("Not in IPython environment") + except: + return print + def display(arg): @@ -83,4 +177,4 @@ def display(arg): if isinstance(arg, plan.Plan): return display_plans([arg]) - raise NotImplementedError(f"Cannot display object of type {type(arg)}") \ No newline at end of file + raise NotImplementedError(f"Cannot display object of type {type(arg)}") diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py index ebab013cba..ab0c273d17 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py @@ -68,10 +68,6 @@ class Settings: self.STORAGE_RESOURCE_HOST = config.get('Gateway', 'STORAGE_RESOURCE_HOST') self.SFTP_PORT = config.get('Gateway', 'SFTP_PORT') - # runtime-specific settings - self.PROJECT_NAME = config.get('User', 'PROJECT_NAME') - self.GROUP_RESOURCE_PROFILE_NAME = config.get('User', 'GROUP_RESOURCE_PROFILE_NAME') - class AiravataOperator: @@ -130,7 +126,7 @@ class AiravataOperator: self, experiment_model: ExperimentModel, computation_resource_name: str, - group_resource_profile_name: str, + group: str, storageId: str, node_count: int, total_cpu_count: int, @@ -140,7 +136,7 @@ class AiravataOperator: auto_schedule=False, ) -> ExperimentModel: resource_host_id = self.get_resource_host_id(computation_resource_name) - groupResourceProfileId = self.get_group_resource_profile_id(group_resource_profile_name) + groupResourceProfileId = self.get_group_resource_profile_id(group) computRes = ComputationalResourceSchedulingModel() computRes.resourceHostId = resource_host_id computRes.nodeCount = node_count @@ -173,9 +169,6 @@ class AiravataOperator: def default_gateway_id(self): return self.settings.GATEWAY_ID - def default_gateway_grp_name(self): - return self.settings.GROUP_RESOURCE_PROFILE_NAME - def default_gateway_data_store_dir(self): return self.settings.GATEWAY_DATA_STORE_DIR @@ -185,9 +178,6 @@ class AiravataOperator: def default_sr_hostname(self): return self.settings.STORAGE_RESOURCE_HOST - def default_project_name(self): - return self.settings.PROJECT_NAME - def connection_svc_url(self): return self.settings.CONNECTION_SVC_URL @@ -258,32 +248,28 @@ class AiravataOperator: storage = self.api_server_client.get_storage_resource(self.airavata_token, sr_id) return storage - def get_group_resource_profile_id(self, grp_name: str | None = None) -> str: + def get_group_resource_profile_id(self, group: str) -> str: """ Get group resource profile id by name """ - # use defaults for missing values - grp_name = grp_name or self.default_gateway_grp_name() # logic grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore - grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == group)) return str(grp_id) - def get_group_resource_profile(self, grp_id: str): - grp: any = self.api_server_client.get_group_resource_profile(self.airavata_token, grp_id) # type: ignore + def get_group_resource_profile(self, group_id: str): + grp: any = self.api_server_client.get_group_resource_profile(self.airavata_token, group_id) # type: ignore return grp - def get_compatible_deployments(self, app_interface_id: str, grp_name: str | None = None): + def get_compatible_deployments(self, app_interface_id: str, group: str): """ Get compatible deployments for an application interface and group resource profile """ - # use defaults for missing values - grp_name = grp_name or self.default_gateway_grp_name() # logic grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore - grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) + grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == group)) deployments = self.api_server_client.get_application_deployments_for_app_module_and_group_resource_profile(self.airavata_token, app_interface_id, grp_id) return deployments @@ -514,6 +500,7 @@ class AiravataOperator: def launch_experiment( self, experiment_name: str, + project: str, app_name: str, inputs: dict[str, dict[str, str | int | float | list[str]]], computation_resource_name: str, @@ -521,11 +508,10 @@ class AiravataOperator: node_count: int, cpu_count: int, walltime: int, + group: str = "Default", *, gateway_id: str | None = None, - grp_name: str | None = None, sr_host: str | None = None, - project_name: str | None = None, auto_schedule: bool = False, ) -> LaunchState: """ @@ -535,10 +521,8 @@ class AiravataOperator: # preprocess args (str) print("[AV] Preprocessing args...") gateway_id = str(gateway_id or self.default_gateway_id()) - grp_name = str(grp_name or self.default_gateway_grp_name()) sr_host = str(sr_host or self.default_sr_hostname()) mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id - project_name = str(project_name or self.default_project_name()) server_url = urlparse(self.connection_svc_url()).netloc # validate args (str) @@ -549,9 +533,9 @@ class AiravataOperator: assert len(inputs) > 0, f"Invalid inputs: {inputs}" assert len(gateway_id) > 0, f"Invalid gateway_id: {gateway_id}" assert len(queue_name) > 0, f"Invalid queue_name: {queue_name}" - assert len(grp_name) > 0, f"Invalid grp_name: {grp_name}" + assert len(group) > 0, f"Invalid group name: {group}" assert len(sr_host) > 0, f"Invalid sr_host: {sr_host}" - assert len(project_name) > 0, f"Invalid project_name: {project_name}" + assert len(project) > 0, f"Invalid project_name: {project}" assert len(mount_point.as_posix()) > 0, f"Invalid mount_point: {mount_point}" # validate args (int) @@ -592,7 +576,7 @@ class AiravataOperator: experiment = self.create_experiment_model( experiment_name=experiment_name, application_name=app_name, - project_name=project_name, + project_name=project, description=experiment_name, gateway_id=gateway_id, ) @@ -600,7 +584,7 @@ class AiravataOperator: print("[AV] Setting up experiment directory...") exp_dir = self.make_experiment_dir( sr_host=storage.hostName, - project_name=project_name, + project_name=project, experiment_name=experiment_name, ) abs_path = (mount_point / exp_dir.lstrip("/")).as_posix().rstrip("/") + "/" @@ -610,7 +594,7 @@ class AiravataOperator: experiment = self.configure_computation_resource_scheduling( experiment_model=experiment, computation_resource_name=computation_resource_name, - group_resource_profile_name=grp_name, + group=group, storageId=sr_id, node_count=node_count, total_cpu_count=cpu_count, @@ -671,7 +655,7 @@ class AiravataOperator: self.api_server_client.launch_experiment(self.airavata_token, ex_id, gateway_id) print(f"[AV] Experiment {experiment_name} STARTED with id: {ex_id}") - # get process id + # wait until experiment begins, then get process id print(f"[AV] Experiment {experiment_name} WAITING until experiment begins...") process_id = None while process_id is None: @@ -683,6 +667,18 @@ class AiravataOperator: time.sleep(2) print(f"[AV] Experiment {experiment_name} EXECUTING with pid: {process_id}") + # wait until task begins, then get job id + print(f"[AV] Experiment {experiment_name} WAITING until task begins...") + job_id = job_state = None + while job_state is None: + try: + job_id, job_state = self.get_task_status(ex_id) + except: + time.sleep(2) + else: + time.sleep(2) + print(f"[AV] Experiment {experiment_name} - Task {job_state} with id: {job_id}") + return LaunchState( experiment_id=ex_id, agent_ref=str(data_inputs["agent_id"]), @@ -702,7 +698,7 @@ class AiravataOperator: self.airavata_token, experiment_id, self.default_gateway_id()) return status - def execute_py(self, libraries: list[str], code: str, agent_id: str, pid: str, runtime_args: dict, cold_start: bool = True) -> str | None: + def execute_py(self, project: str, libraries: list[str], code: str, agent_id: str, pid: str, runtime_args: dict, cold_start: bool = True) -> str | None: # lambda to send request print(f"[av] Attempting to submit to agent {agent_id}...") make_request = lambda: requests.post(f"{self.connection_svc_url()}/agent/executepythonrequest", json={ @@ -723,6 +719,7 @@ class AiravataOperator: self.launch_experiment( experiment_name="Agent", app_name="AiravataAgent", + project=project, inputs={ "agent_id": {"type": "str", "value": agent_id}, "server_url": {"type": "str", "value": urlparse(self.connection_svc_url()).netloc}, @@ -733,8 +730,9 @@ class AiravataOperator: node_count=1, cpu_count=runtime_args["cpu_count"], walltime=runtime_args["walltime"], + group=runtime_args["group"], ) - return self.execute_py(libraries, code, agent_id, pid, runtime_args, cold_start=False) + return self.execute_py(project, libraries, code, agent_id, pid, runtime_args, cold_start=False) elif data["executionId"] is not None: print(f"[av] Submitted to Python Interpreter") # agent response @@ -774,7 +772,23 @@ class AiravataOperator: def get_available_runtimes(self): from .runtime import Remote return [ - Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, walltime=30), - Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, walltime=30), - Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, walltime=30), + Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, gpu_count=1, walltime=30, group="Default"), + Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, gpu_count=0, walltime=30, group="Default"), + Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, group="Default"), + Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, gpu_count=1, walltime=30, group="GaussianGroup"), + Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, gpu_count=0, walltime=30, group="GaussianGroup"), + Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, gpu_count=0, walltime=30, group="GaussianGroup"), ] + + def get_task_status(self, experiment_id: str) -> tuple[str, Literal["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] | None]: + states = ["SUBMITTED", "UN_SUBMITTED", "SETUP", "QUEUED", "ACTIVE", "COMPLETE", "CANCELING", "CANCELED", "FAILED", "HELD", "SUSPENDED", "UNKNOWN"] + job_details: dict = self.api_server_client.get_job_statuses(self.airavata_token, experiment_id) # type: ignore + job_id = job_state = None + # get the most recent job id and state + for job_id, v in job_details.items(): + if v.reason in states: + job_state = v.reason + else: + job_state = states[int(v.jobState)] + return job_id or "N/A", job_state # type: ignore + diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py index 165d1a7175..8ed873f181 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/auth/device_auth.py @@ -101,7 +101,7 @@ class DeviceFlowAuthenticator: self.__persist_token__(self._refresh_token, self._access_token) def login(self, interactive: bool = True) -> None: - + auth_warning = None try: # [Flow A] Reuse saved token if os.path.exists("auth.state"): @@ -112,24 +112,24 @@ class DeviceFlowAuthenticator: self._refresh_token = str(data["refresh_token"]) self._access_token = str(data["access_token"]) except: - print("Failed to load auth.state file!") + auth_warning = "Failed to load auth.state file!" else: # [A2] Check if access token is valid, if so, return if not self.__has_expired__(self._access_token): - print("Authenticated via saved access token!") - return None + return print("Authenticated via saved access token!") else: - print("Access token is invalid!") + auth_warning = "Access token is invalid!" # [A3] Check if refresh token is valid. if so, refresh try: if not self.__has_expired__(self._refresh_token): self.refresh() - print("Authenticated via saved refresh token!") - return None + return print("Authenticated via saved refresh token!") else: - print("Refresh token is invalid!") + auth_warning = "Refresh token is invalid!" except Exception as e: print(*e.args) + if auth_warning: + print(auth_warning) # [Flow B] Request device and user code diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/base.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/base.py index 1967cf6f73..56104ca5c8 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/base.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/base.py @@ -85,24 +85,22 @@ class Experiment(Generic[T], abc.ABC): self.resource = resource return self - def add_replica(self, *allowed_runtimes: Runtime) -> None: + def create_task(self, *allowed_runtimes: Runtime, name: str | None = None) -> None: """ - Add a replica to the experiment. - This will create a copy of the application with the given inputs. - + Create a task to run the experiment on a given runtime. """ runtime = random.choice(allowed_runtimes) if len(allowed_runtimes) > 0 else self.resource uuid_str = str(uuid.uuid4())[:4].upper() self.tasks.append( Task( - name=f"{self.name}_{uuid_str}", + name=name or f"{self.name}_{uuid_str}", app_id=self.application.app_id, inputs={**self.inputs}, runtime=runtime, ) ) - print(f"Added replica. ({len(self.tasks)} tasks in total)") + print(f"Task created. ({len(self.tasks)} tasks in total)") def add_sweep(self, *allowed_runtimes: Runtime, **space: list) -> None: """ @@ -126,7 +124,7 @@ class Experiment(Generic[T], abc.ABC): def plan(self, **kwargs) -> Plan: if len(self.tasks) == 0: - self.add_replica(self.resource) + self.create_task(self.resource) tasks = [] for t in self.tasks: agg_inputs = {**self.inputs, **t.inputs} diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py index b6cdaa497c..d9bad69ad3 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/plan.py @@ -41,15 +41,14 @@ class Plan(pydantic.BaseModel): return v def __stage_prepare__(self) -> None: - print("Preparing execution plan...") + print("Preparing to launch...") def __stage_confirm__(self, silent: bool) -> None: - print("Confirming execution plan...") if not silent: while True: - res = input("Here is the execution plan. continue? (Y/n) ") + res = input("Ready to launch. Continue? (Y/n) ") if res.upper() in ["N"]: - raise Exception("Execution was aborted by user.") + raise Exception("Launch aborted by user.") elif res.upper() in ["Y", ""]: break else: @@ -81,7 +80,7 @@ class Plan(pydantic.BaseModel): self.save_json(os.path.join(local_dir, "plan.json")) return fps - def launch(self, silent: bool = False) -> None: + def launch(self, silent: bool = True) -> None: try: self.__stage_prepare__() self.__stage_confirm__(silent) @@ -93,8 +92,8 @@ class Plan(pydantic.BaseModel): def status(self) -> None: statuses = self.__stage_status__() print(f"Plan {self.id} ({len(self.tasks)} tasks):") - for task, status in zip(self.tasks, statuses): - print(f"* {task.name}: {status}") + for task, (task_id, status) in zip(self.tasks, statuses): + print(f"* {task.name}: {task_id}: {status}") def wait_for_completion(self, check_every_n_mins: float = 0.1) -> None: n = len(self.tasks) @@ -104,9 +103,9 @@ class Plan(pydantic.BaseModel): while True: completed = [False] * n statuses = self.__stage_status__() - for i, (task, status, pbar) in enumerate(zip(self.tasks, statuses, pbars)): + for i, (task, (task_id, status), pbar) in enumerate(zip(self.tasks, statuses, pbars)): completed[i] = is_terminal_state(status) - progress.update(pbar, description=f"{task.name} ({i+1}/{n}): {status}", completed=completed[i], refresh=True) + progress.update(pbar, description=f"{task.name} ({i+1}/{n}): {task_id}: {status}", completed=completed[i], refresh=True) if all(completed): break sleep_time = check_every_n_mins * 60 diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py index 260b784e65..36c4dc4c93 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py @@ -36,7 +36,7 @@ class Runtime(abc.ABC, pydantic.BaseModel): def execute_py(self, libraries: list[str], code: str, task: Task) -> None: ... @abc.abstractmethod - def status(self, task: Task) -> str: ... + def status(self, task: Task) -> tuple[str, str]: ... @abc.abstractmethod def signal(self, signal: str, task: Task) -> None: ... @@ -93,13 +93,13 @@ class Mock(Runtime): def execute_py(self, libraries: list[str], code: str, task: Task) -> None: pass - def status(self, task: Task) -> str: + def status(self, task: Task) -> tuple[str, str]: import random self._state += random.randint(0, 5) if self._state > 10: - return "COMPLETED" - return "RUNNING" + return "N/A", "COMPLETED" + return "N/A", "RUNNING" def signal(self, signal: str, task: Task) -> None: pass @@ -123,20 +123,22 @@ class Mock(Runtime): class Remote(Runtime): - def __init__(self, cluster: str, category: str, queue_name: str, node_count: int, cpu_count: int, walltime: int) -> None: + def __init__(self, cluster: str, category: str, queue_name: str, node_count: int, cpu_count: int, walltime: int, gpu_count: int = 0, group: str = "Default") -> None: super().__init__(id="remote", args=dict( cluster=cluster, category=category, queue_name=queue_name, node_count=node_count, cpu_count=cpu_count, + gpu_count=gpu_count, walltime=walltime, + group=group, )) def execute(self, task: Task) -> None: assert task.ref is None assert task.agent_ref is None - assert {"cluster", "queue_name", "node_count", "cpu_count", "walltime"}.issubset(self.args.keys()) + assert {"cluster", "group", "queue_name", "node_count", "cpu_count", "gpu_count", "walltime"}.issubset(self.args.keys()) print(f"[Remote] Creating Experiment: name={task.name}") from .airavata import AiravataOperator @@ -145,12 +147,14 @@ class Remote(Runtime): launch_state = av.launch_experiment( experiment_name=task.name, app_name=task.app_id, + project=task.project, inputs=task.inputs, computation_resource_name=str(self.args["cluster"]), queue_name=str(self.args["queue_name"]), node_count=int(self.args["node_count"]), cpu_count=int(self.args["cpu_count"]), walltime=int(self.args["walltime"]), + group=str(self.args["group"]), ) task.agent_ref = launch_state.agent_ref task.pid = launch_state.process_id @@ -169,17 +173,21 @@ class Remote(Runtime): from .airavata import AiravataOperator av = AiravataOperator(context.access_token) - result = av.execute_py(libraries, code, task.agent_ref, task.pid, task.runtime.args) + result = av.execute_py(task.project, libraries, code, task.agent_ref, task.pid, task.runtime.args) print(result) - def status(self, task: Task): + def status(self, task: Task) -> tuple[str, str]: assert task.ref is not None assert task.agent_ref is not None from .airavata import AiravataOperator av = AiravataOperator(context.access_token) - status = av.get_experiment_status(task.ref) - return status + # prioritize job state, fallback to experiment state + job_id, job_state = av.get_task_status(task.ref) + if not job_state or job_state == "UN_SUBMITTED": + return job_id, av.get_experiment_status(task.ref) + else: + return job_id, job_state def signal(self, signal: str, task: Task) -> None: assert task.ref is not None @@ -245,6 +253,7 @@ class Remote(Runtime): def list_runtimes( cluster: str | None = None, category: str | None = None, + group: str | None = None, node_count: int | None = None, cpu_count: int | None = None, walltime: int | None = None, @@ -254,7 +263,7 @@ def list_runtimes( all_runtimes = av.get_available_runtimes() out_runtimes = [] for r in all_runtimes: - if (cluster in [None, r.args["cluster"]]) and (category in [None, r.args["category"]]): + if (cluster in [None, r.args["cluster"]]) and (category in [None, r.args["category"]]) and (group in [None, r.args["group"]]): r.args["node_count"] = node_count or r.args["node_count"] r.args["cpu_count"] = cpu_count or r.args["cpu_count"] r.args["walltime"] = walltime or r.args["walltime"] diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py index 1612f44f74..285c959c66 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/task.py @@ -23,6 +23,7 @@ class Task(pydantic.BaseModel): name: str app_id: str + project: str = pydantic.Field(default="Default Project") inputs: dict[str, Any] runtime: Runtime ref: str | None = pydantic.Field(default=None) @@ -53,7 +54,7 @@ class Task(pydantic.BaseModel): print(f"[Task] Executing {self.name} on {self.runtime}") self.runtime.execute(self) - def status(self) -> str: + def status(self) -> tuple[str, str]: assert self.ref is not None return self.runtime.status(self) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py index 1c120c10d3..873f23c07c 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_sdk/clients/api_server_client.py @@ -1753,19 +1753,19 @@ class APIServerClient(object): return self.api_server_client_pool.cloneExperiment(authz_token, existing_experiment_id, new_experiment_name, new_experiment_projectId) except InvalidRequestException: - logger.exception("Error occurred in get_job_details, probably due to invalid parameters ", + logger.exception("Error occurred in clone_experiment, probably due to invalid parameters ", ) raise except AiravataClientException: - logger.exception("Error occurred in get_job_details, probably due to client misconfiguration ", + logger.exception("Error occurred in clone_experiment, probably due to client misconfiguration ", ) raise except AiravataSystemException: - logger.exception("Error occurred in get_job_details, probably due to server side error ", + logger.exception("Error occurred in clone_experiment, probably due to server side error ", ) raise except AuthorizationException: - logger.exception("Error occurred in get_job_details, probably due to invalid authz token ", + logger.exception("Error occurred in clone_experiment, probably due to invalid authz token ", ) raise @@ -1826,19 +1826,19 @@ class APIServerClient(object): new_experiment_name, new_experiment_projectId) except InvalidRequestException: - logger.exception("Error occurred in get_job_details, probably due to invalid parameters ", + logger.exception("Error occurred in clone_experiment_by_admin, probably due to invalid parameters ", ) raise except AiravataClientException: - logger.exception("Error occurred in get_job_details, probably due to client misconfiguration ", + logger.exception("Error occurred in clone_experiment_by_admin, probably due to client misconfiguration ", ) raise except AiravataSystemException: - logger.exception("Error occurred in get_job_details, probably due to server side error ", + logger.exception("Error occurred in clone_experiment_by_admin, probably due to server side error ", ) raise except AuthorizationException: - logger.exception("Error occurred in get_job_details, probably due to invalid authz token ", + logger.exception("Error occurred in clone_experiment_by_admin, probably due to invalid authz token ", ) raise @@ -1889,19 +1889,19 @@ class APIServerClient(object): return self.api_server_client_pool.terminateExperiment(authz_token, airavata_experiment_id, gateway_id) except InvalidRequestException: - logger.exception("Error occurred in get_job_details, probably due to invalid parameters ", + logger.exception("Error occurred in terminate_experiment, probably due to invalid parameters ", ) raise except AiravataClientException: - logger.exception("Error occurred in get_job_details, probably due to client misconfiguration ", + logger.exception("Error occurred in terminate_experiment, probably due to client misconfiguration ", ) raise except AiravataSystemException: - logger.exception("Error occurred in get_job_details, probably due to server side error ", + logger.exception("Error occurred in terminate_experiment, probably due to server side error ", ) raise except AuthorizationException: - logger.exception("Error occurred in get_job_details, probably due to invalid authz token ", + logger.exception("Error occurred in terminate_experiment, probably due to invalid authz token ", ) raise @@ -1929,19 +1929,19 @@ class APIServerClient(object): return self.api_server_client_pool.registerApplicationModule(authz_token, gateway_id, application_module) except InvalidRequestException: - logger.exception("Error occurred in get_job_details, probably due to invalid parameters ", + logger.exception("Error occurred in register_application_module, probably due to invalid parameters ", ) raise except AiravataClientException: - logger.exception("Error occurred in get_job_details, probably due to client misconfiguration ", + logger.exception("Error occurred in register_application_module, probably due to client misconfiguration ", ) raise except AiravataSystemException: - logger.exception("Error occurred in get_job_details, probably due to server side error ", + logger.exception("Error occurred in register_application_module, probably due to server side error ", ) raise except AuthorizationException: - logger.exception("Error occurred in get_job_details, probably due to invalid authz token ", + logger.exception("Error occurred in register_application_module, probably due to invalid authz token ", ) raise diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml index 3cda818551..1bbe41c4fd 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml @@ -3,8 +3,8 @@ requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" [project] -name = "airavata-python-sdk-test" -version = "0.0.16" +name = "airavata-python-sdk" +version = "2.0.0" description = "Apache Airavata Python SDK" readme = "README.md" license = { text = "Apache License 2.0" } diff --git a/modules/agent-framework/deployments/jupyterhub/data/1_experiment_sdk.ipynb b/modules/agent-framework/deployments/jupyterhub/data/1_experiment_sdk.ipynb index d9de2a5b38..7bfa0bf4f4 100644 --- a/modules/agent-framework/deployments/jupyterhub/data/1_experiment_sdk.ipynb +++ b/modules/agent-framework/deployments/jupyterhub/data/1_experiment_sdk.ipynb @@ -130,8 +130,8 @@ ") -> Experiment[ExperimentApp]\n", "```\n", "\n", - "To add replica runs, simply call the `exp.add_replica()` function.\n", - "You can call the `add_replica()` function as many times as you want replicas.\n", + "To add replica runs, simply call the `exp.create_task()` function.\n", + "You can call the `create_task()` function as many times as you want replicas.\n", "Any optional resource constraint can be provided here.\n", "\n", "You can also call `ae.display()` to pretty-print the experiment." @@ -160,7 +160,7 @@ " ],\n", " parallelism=\"GPU\",\n", ")\n", - "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"gpu\", walltime=180))\n", + "exp.create_task(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"gpu\", walltime=180))\n", "ae.display(exp)" ] }, diff --git a/modules/agent-framework/deployments/jupyterhub/data/smd_cpu.ipynb b/modules/agent-framework/deployments/jupyterhub/data/smd_cpu.ipynb index 22b4faa37e..44a54a39fb 100644 --- a/modules/agent-framework/deployments/jupyterhub/data/smd_cpu.ipynb +++ b/modules/agent-framework/deployments/jupyterhub/data/smd_cpu.ipynb @@ -130,8 +130,8 @@ ") -> Experiment[ExperimentApp]\n", "```\n", "\n", - "To add replica runs, simply call the `exp.add_replica()` function.\n", - "You can call the `add_replica()` function as many times as you want replicas.\n", + "To add replica runs, simply call the `exp.create_task()` function.\n", + "You can call the `create_task()` function as many times as you want replicas.\n", "Any optional resource constraint can be provided here.\n", "\n", "You can also call `ae.display()` to pretty-print the experiment." @@ -160,7 +160,7 @@ " ],\n", " parallelism=\"CPU\",\n", ")\n", - "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"cpu\", walltime=60))\n", + "exp.create_task(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"cpu\", walltime=60))\n", "ae.display(exp)" ] }, @@ -379,7 +379,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "airavata", "language": "python", "name": "python3" }, @@ -393,7 +393,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.12.7" } }, "nbformat": 4,
