This is an automated email from the ASF dual-hosted git repository. dimuthuupe pushed a commit to branch cybershuttle-staging in repository https://gitbox.apache.org/repos/asf/airavata.git
commit 13cfa7d3205bc3ace3c1e253ad18171271f0b4e4 Author: yasith <[email protected]> AuthorDate: Fri Apr 4 09:09:51 2025 -0400 update sdk and agent service to accept yml file and dependencies. --- .../airavata_jupyter_magic/__init__.py | 220 ++++++++++++++++----- .../service/handlers/AgentManagementHandler.java | 12 +- .../service/models/AgentLaunchRequest.java | 30 +++ 3 files changed, 207 insertions(+), 55 deletions(-) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py index b59cc12f0b..51437cc43f 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py @@ -7,23 +7,27 @@ from argparse import ArgumentParser from dataclasses import dataclass from enum import IntEnum from pathlib import Path -from rich.console import Console from typing import NamedTuple import jwt import requests -from .device_auth import DeviceFlowAuthenticator +import yaml from IPython.core.getipython import get_ipython from IPython.core.interactiveshell import ExecutionInfo, ExecutionResult from IPython.core.magic import register_cell_magic, register_line_magic from IPython.display import HTML, Image, display +from rich.console import Console + +from .device_auth import DeviceFlowAuthenticator # ======================================================================== # DATA STRUCTURES + class InvalidStateError(Exception): pass + class RequestedRuntime: cluster: str cpus: int @@ -31,6 +35,7 @@ class RequestedRuntime: walltime: int queue: str group: str + file: str | None class ProcessState(IntEnum): @@ -236,13 +241,14 @@ def submit_agent_job( rt_name: str, access_token: str, app_name: str, - cluster: str, - cpus: int, - memory: int | None, - walltime: int, - queue: str, - group: str, + cluster: str | None = None, + cpus: 
int | None = None, + memory: int | None = None, + walltime: int | None = None, + queue: str | None = None, + group: str | None = None, gateway_id: str = 'default', + file: str | None = None, ) -> None: """ Submit an agent job to the given runtime @@ -257,6 +263,7 @@ def submit_agent_job( @param queue: the queue @param group: the group @param gateway_id: the gateway id + @param file: environment file @returns: None """ @@ -264,23 +271,53 @@ def submit_agent_job( url = api_base_url + '/api/v1/exp/launch' # Data to be sent in the POST request - data = { - 'experimentName': app_name, - 'remoteCluster': cluster, - 'cpuCount': cpus, - 'nodeCount': 1, - 'memory': memory, - 'wallTime': walltime, - 'queue': queue, - 'group': group, - } - - # Convert the data to JSON format - json_data = json.dumps(data) - + if file is not None: + fp = Path(file) + assert fp.exists(), f"File {file} does not exist" + with open(fp, "r") as f: + content = yaml.safe_load(f) + # workspace + resources = content["workspace"]["resources"] + models = content["workspace"]["model_collection"] + datasets = content["workspace"]["data_collection"] + collection = models + datasets + mounts = [f"{i['identifier']}:{i['mount_point']}" for i in collection] + # dependencies + condas = content["additional_dependencies"]["conda"] + pips = content["additional_dependencies"]["pip"] + data = { + 'experimentName': app_name, + 'nodeCount': 1, + 'cpuCount': resources["min_cpu"], + 'memory': resources["min_mem"], + 'wallTime': resources["walltime"], + 'remoteCluster': resources["cluster"], + 'group': resources["group"], + 'queue': resources["queue"], + 'libraries': condas, + 'pip': pips, + 'mounts': mounts, + } + else: + data = { + 'experimentName': app_name, + 'nodeCount': 1, + 'cpuCount': cpus, + 'memory': memory, + 'wallTime': walltime, + 'remoteCluster': cluster, + 'group': group, + 'queue': queue, + 'libraries': [], + 'pip': [], + 'mounts': [], + } + + print(f"Requesting runtime={rt_name}", flush=True) + 
print(yaml.dump(data, indent=2), flush=True) # Send the POST request headers = generate_headers(access_token, gateway_id) - res = requests.post(url, headers=headers, data=json_data) + res = requests.post(url, headers=headers, data=json.dumps(data)) code = res.status_code # Check if the request was successful @@ -292,16 +329,16 @@ def submit_agent_job( print(msg) raise InvalidStateError(msg) rt = RuntimeInfo( + gateway_id=gateway_id, + processId=pid, agentId=obj['agentId'], experimentId=obj['experimentId'], - processId=pid, - cluster=cluster, - queue=queue, - cpus=cpus, - memory=memory, - walltime=walltime, - gateway_id=gateway_id, - group=group, + cluster=data['remoteCluster'], + queue=data['queue'], + cpus=data['cpuCount'], + memory=data['memory'], + walltime=data['wallTime'], + group=data['group'], ) state.all_runtimes[rt_name] = rt print(f'Requested runtime={rt_name}. state={pstate.name}') @@ -328,15 +365,67 @@ def wait_until_runtime_ready(access_token: str, rt_name: str): while True: ready, rstate = is_runtime_ready(access_token, rt, rt_name) if ready: - status.update(f"Connecting to runtime={rt_name}... status=READY") + status.update( + f"Connecting to runtime={rt_name}... status=READY") break else: - status.update(f"Connecting to runtime={rt_name}... status={rstate}") + status.update( + f"Connecting to runtime={rt_name}... status={rstate}") time.sleep(5) status.stop() console.clear() +def restart_runtime_kernel(access_token: str, rt_name: str, env_name: str, runtime: RuntimeInfo): + """ + Restart the kernel runtime on the given runtime. 
+ + @param access_token: the access token + @param env_name: the environment name + @param runtime: the runtime info + @returns: None + + """ + + url = api_base_url + '/api/v1/agent/setup/restart' + + decode = jwt.decode(access_token, options={"verify_signature": False}) + user_id = decode['preferred_username'] + claimsMap = { + "userName": user_id, + "gatewayID": runtime.gateway_id + } + + # Headers + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + access_token, + 'X-Claims': json.dumps(claimsMap) + } + + # Send the POST request + res = requests.post(url, headers=headers, data=json.dumps({ + "agentId": runtime.agentId, + "envName": env_name, + })) + data = res.json() + + executionId = data.get("executionId") + if not executionId: + print(f"Failed to restart kernel runtime={runtime.agentId}") + return + + # Check if the request was successful + while True: + url = api_base_url + "/api/v1/agent/setup/restart/" + executionId + res = requests.get(url, headers={'Accept': 'application/json'}) + data = res.json() + if data.get('restarted'): + print(f"Restarted kernel={env_name} on runtime={rt_name}") + break + time.sleep(1) + + def stop_agent_job(access_token: str, runtime_name: str, runtime: RuntimeInfo): """ Stop the agent job on the given runtime. 
@@ -677,25 +766,39 @@ def request_runtime(line: str): prog="request_runtime", description="Request a runtime with given capabilities", ) - p.add_argument("--cluster", type=str, help="cluster", required=True) - p.add_argument("--cpus", type=int, help="CPU cores", required=True) + p.add_argument("--cluster", type=str, help="cluster", required=False) + p.add_argument("--cpus", type=int, help="CPU cores", required=False) p.add_argument("--memory", type=int, help="memory (MB)", required=False) - p.add_argument("--walltime", type=int, help="time (mins)", required=True) - p.add_argument("--queue", type=str, help="resource queue", required=True) - p.add_argument("--group", type=str, help="resource group", required=True) + p.add_argument("--walltime", type=int, help="time (mins)", required=False) + p.add_argument("--queue", type=str, help="resource queue", required=False) + p.add_argument("--group", type=str, help="resource group", required=False) + p.add_argument("--file", type=str, help="yml file", required=False) args = p.parse_args(cmd_args, namespace=RequestedRuntime()) - submit_agent_job( - rt_name=rt_name, - access_token=access_token, - app_name='CS_Agent', - cluster=args.cluster, - cpus=args.cpus, - memory=args.memory, - walltime=args.walltime, - queue=args.queue, - group=args.group, - ) + if args.file is not None: + submit_agent_job( + rt_name=rt_name, + access_token=access_token, + app_name='CS_Agent', + file=args.file, + ) + else: + assert args.cluster is not None + assert args.cpus is not None + assert args.walltime is not None + assert args.queue is not None + assert args.group is not None + submit_agent_job( + rt_name=rt_name, + access_token=access_token, + app_name='CS_Agent', + cluster=args.cluster, + cpus=args.cpus, + memory=args.memory, + walltime=args.walltime, + queue=args.queue, + group=args.group, + ) @register_line_magic @@ -724,6 +827,21 @@ def stat_runtime(line: str): print(f"Runtime={rt_name} is still preparing... 
{rstate}") +@register_line_magic +def restart_runtime(rt_name: str): + """ + Restart the runtime + + """ + access_token = get_access_token() + assert access_token is not None + + rt = state.all_runtimes.get(rt_name, None) + if rt is None: + return print(f"Runtime {rt_name} not found.") + restart_runtime_kernel(access_token, rt_name, "base", rt) + + @register_line_magic def stop_runtime(rt_name: str): """ @@ -804,7 +922,8 @@ def run_cell(raw_cell, store_history=False, silent=False, shell_futures=True, ce wait_until_runtime_ready(access_token, rt) return run_on_runtime(rt, raw_cell, store_history, silent, shell_futures, cell_id) except Exception as e: - info = ExecutionInfo(raw_cell, store_history, silent, shell_futures, cell_id) + info = ExecutionInfo(raw_cell, store_history, + silent, shell_futures, cell_id) result = ExecutionResult(info) print(f"Error: {e}") result.error_in_exec = e @@ -819,6 +938,7 @@ Loaded airavata_jupyter_magic %authenticate -- Authenticate to access high-performance runtimes. %request_runtime <rt> [args] -- Request a runtime named <rt> with configuration <args>. Call multiple times to request multiple runtimes. + %restart_runtime <rt> -- Restart runtime <rt>. Run this if you install new dependencies or if the runtime hangs. %stop_runtime <rt> -- Stop runtime <rt> when no longer needed. %switch_runtime <rt> -- Switch active runtime to <rt>. All subsequent executions will use this runtime. %%run_on <rt> -- Force a cell to always execute on <rt>, regardless of the active runtime. 
diff --git a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java index acc2e003a0..b8d6645c27 100644 --- a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java +++ b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java @@ -164,11 +164,13 @@ public class AgentManagementHandler { List<InputDataObjectType> applicationInputs = airavataClient.getApplicationInputs(authzToken, appInterfaceId); List<InputDataObjectType> experimentInputs = applicationInputs.stream() .peek(input -> { - if ("agent_id".equals(input.getName())) { - input.setValue(agentId); - - } else if ("server_url".equals(input.getName())) { - input.setValue(airavataService.getServerUrl()); + switch (input.getName()) { + case "agent_id" -> input.setValue(agentId); + case "server_url" -> input.setValue(airavataService.getServerUrl()); + case "libraries" -> input.setValue(String.join(",", req.getLibraries())); + case "pip" -> input.setValue(String.join(",", req.getPip())); + case "mounts" -> input.setValue(String.join(",", req.getMounts())); + default -> {} } }) .collect(Collectors.toList()); diff --git a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java index c39536f1bf..7934ddf74e 100644 --- a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java +++ 
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java @@ -1,5 +1,8 @@ package org.apache.airavata.agent.connection.service.models; +import java.util.ArrayList; +import java.util.List; + import org.apache.commons.lang3.StringUtils; public class AgentLaunchRequest { @@ -8,6 +11,9 @@ public class AgentLaunchRequest { private String projectName; private String remoteCluster; private String group; + private List<String> libraries = new ArrayList<>(); + private List<String> pip = new ArrayList<>(); + private List<String> mounts = new ArrayList<>(); private String queue = "shared"; private int wallTime = 30; @@ -90,4 +96,28 @@ public class AgentLaunchRequest { public void setProjectName(String projectName) { this.projectName = projectName; } + + public List<String> getLibraries() { + return libraries; + } + + public void setLibraries(List<String> libraries) { + this.libraries = libraries; + } + + public List<String> getPip() { + return pip; + } + + public void setPip(List<String> pip) { + this.pip = pip; + } + + public List<String> getMounts() { + return mounts; + } + + public void setMounts(List<String> mounts) { + this.mounts = mounts; + } }
