This is an automated email from the ASF dual-hosted git repository.

dimuthuupe pushed a commit to branch cybershuttle-staging
in repository https://gitbox.apache.org/repos/asf/airavata.git

commit 13cfa7d3205bc3ace3c1e253ad18171271f0b4e4
Author: yasith <[email protected]>
AuthorDate: Fri Apr 4 09:09:51 2025 -0400

    update sdk and agent service to accept yml file and dependencies.
---
 .../airavata_jupyter_magic/__init__.py             | 220 ++++++++++++++++-----
 .../service/handlers/AgentManagementHandler.java   |  12 +-
 .../service/models/AgentLaunchRequest.java         |  30 +++
 3 files changed, 207 insertions(+), 55 deletions(-)

diff --git 
a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py
 
b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py
index b59cc12f0b..51437cc43f 100644
--- 
a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py
+++ 
b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py
@@ -7,23 +7,27 @@ from argparse import ArgumentParser
 from dataclasses import dataclass
 from enum import IntEnum
 from pathlib import Path
-from rich.console import Console
 from typing import NamedTuple
 
 import jwt
 import requests
-from .device_auth import DeviceFlowAuthenticator
+import yaml
 from IPython.core.getipython import get_ipython
 from IPython.core.interactiveshell import ExecutionInfo, ExecutionResult
 from IPython.core.magic import register_cell_magic, register_line_magic
 from IPython.display import HTML, Image, display
+from rich.console import Console
+
+from .device_auth import DeviceFlowAuthenticator
 
 # ========================================================================
 # DATA STRUCTURES
 
+
 class InvalidStateError(Exception):
     pass
 
+
 class RequestedRuntime:
     cluster: str
     cpus: int
@@ -31,6 +35,7 @@ class RequestedRuntime:
     walltime: int
     queue: str
     group: str
+    file: str | None
 
 
 class ProcessState(IntEnum):
@@ -236,13 +241,14 @@ def submit_agent_job(
     rt_name: str,
     access_token: str,
     app_name: str,
-    cluster: str,
-    cpus: int,
-    memory: int | None,
-    walltime: int,
-    queue: str,
-    group: str,
+    cluster: str | None = None,
+    cpus: int | None = None,
+    memory: int | None = None,
+    walltime: int | None = None,
+    queue: str | None = None,
+    group: str | None = None,
     gateway_id: str = 'default',
+    file: str | None = None,
 ) -> None:
     """
     Submit an agent job to the given runtime
@@ -257,6 +263,7 @@ def submit_agent_job(
     @param queue: the queue
     @param group: the group
     @param gateway_id: the gateway id
+    @param file: environment file
     @returns: None
 
     """
@@ -264,23 +271,53 @@ def submit_agent_job(
     url = api_base_url + '/api/v1/exp/launch'
 
     # Data to be sent in the POST request
-    data = {
-        'experimentName': app_name,
-        'remoteCluster': cluster,
-        'cpuCount': cpus,
-        'nodeCount': 1,
-        'memory': memory,
-        'wallTime': walltime,
-        'queue': queue,
-        'group': group,
-    }
-
-    # Convert the data to JSON format
-    json_data = json.dumps(data)
-
+    if file is not None:
+        fp = Path(file)
+        assert fp.exists(), f"File {file} does not exist"
+        with open(fp, "r") as f:
+            content = yaml.safe_load(f)
+        # workspace
+        resources = content["workspace"]["resources"]
+        models = content["workspace"]["model_collection"]
+        datasets = content["workspace"]["data_collection"]
+        collection = models + datasets
+        mounts = [f"{i['identifier']}:{i['mount_point']}" for i in collection]
+        # dependencies
+        condas = content["additional_dependencies"]["conda"]
+        pips = content["additional_dependencies"]["pip"]
+        data = {
+            'experimentName': app_name,
+            'nodeCount': 1,
+            'cpuCount': resources["min_cpu"],
+            'memory': resources["min_mem"],
+            'wallTime': resources["walltime"],
+            'remoteCluster': resources["cluster"],
+            'group': resources["group"],
+            'queue': resources["queue"],
+            'libraries': condas,
+            'pip': pips,
+            'mounts': mounts,
+        }
+    else:
+        data = {
+            'experimentName': app_name,
+            'nodeCount': 1,
+            'cpuCount': cpus,
+            'memory': memory,
+            'wallTime': walltime,
+            'remoteCluster': cluster,
+            'group': group,
+            'queue': queue,
+            'libraries': [],
+            'pip': [],
+            'mounts': [],
+        }
+
+    print(f"Requesting runtime={rt_name}", flush=True)
+    print(yaml.dump(data, indent=2), flush=True)
     # Send the POST request
     headers = generate_headers(access_token, gateway_id)
-    res = requests.post(url, headers=headers, data=json_data)
+    res = requests.post(url, headers=headers, data=json.dumps(data))
     code = res.status_code
 
     # Check if the request was successful
@@ -292,16 +329,16 @@ def submit_agent_job(
             print(msg)
             raise InvalidStateError(msg)
         rt = RuntimeInfo(
+            gateway_id=gateway_id,
+            processId=pid,
             agentId=obj['agentId'],
             experimentId=obj['experimentId'],
-            processId=pid,
-            cluster=cluster,
-            queue=queue,
-            cpus=cpus,
-            memory=memory,
-            walltime=walltime,
-            gateway_id=gateway_id,
-            group=group,
+            cluster=data['remoteCluster'],
+            queue=data['queue'],
+            cpus=data['cpuCount'],
+            memory=data['memory'],
+            walltime=data['wallTime'],
+            group=data['group'],
         )
         state.all_runtimes[rt_name] = rt
         print(f'Requested runtime={rt_name}. state={pstate.name}')
@@ -328,15 +365,67 @@ def wait_until_runtime_ready(access_token: str, rt_name: 
str):
         while True:
             ready, rstate = is_runtime_ready(access_token, rt, rt_name)
             if ready:
-                status.update(f"Connecting to runtime={rt_name}... 
status=READY")
+                status.update(
+                    f"Connecting to runtime={rt_name}... status=READY")
                 break
             else:
-                status.update(f"Connecting to runtime={rt_name}... 
status={rstate}")
+                status.update(
+                    f"Connecting to runtime={rt_name}... status={rstate}")
                 time.sleep(5)
         status.stop()
     console.clear()
 
 
+def restart_runtime_kernel(access_token: str, rt_name: str, env_name: str, 
runtime: RuntimeInfo):
+    """
+    Restart the kernel runtime on the given runtime.
+
+    @param access_token: the access token
+    @param env_name: the environment name
+    @param runtime: the runtime info
+    @returns: None
+
+    """
+
+    url = api_base_url + '/api/v1/agent/setup/restart'
+
+    decode = jwt.decode(access_token, options={"verify_signature": False})
+    user_id = decode['preferred_username']
+    claimsMap = {
+        "userName": user_id,
+        "gatewayID": runtime.gateway_id
+    }
+
+    # Headers
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': 'Bearer ' + access_token,
+        'X-Claims': json.dumps(claimsMap)
+    }
+
+    # Send the POST request
+    res = requests.post(url, headers=headers, data=json.dumps({
+        "agentId": runtime.agentId,
+        "envName": env_name,
+    }))
+    data = res.json()
+
+    executionId = data.get("executionId")
+    if not executionId:
+        print(f"Failed to restart kernel runtime={runtime.agentId}")
+        return
+
+    # Check if the request was successful
+    while True:
+        url = api_base_url + "/api/v1/agent/setup/restart/" + executionId
+        res = requests.get(url, headers={'Accept': 'application/json'})
+        data = res.json()
+        if data.get('restarted'):
+            print(f"Restarted kernel={env_name} on runtime={rt_name}")
+            break
+        time.sleep(1)
+
+
 def stop_agent_job(access_token: str, runtime_name: str, runtime: RuntimeInfo):
     """
     Stop the agent job on the given runtime.
@@ -677,25 +766,39 @@ def request_runtime(line: str):
         prog="request_runtime",
         description="Request a runtime with given capabilities",
     )
-    p.add_argument("--cluster", type=str, help="cluster", required=True)
-    p.add_argument("--cpus", type=int, help="CPU cores", required=True)
+    p.add_argument("--cluster", type=str, help="cluster", required=False)
+    p.add_argument("--cpus", type=int, help="CPU cores", required=False)
     p.add_argument("--memory", type=int, help="memory (MB)", required=False)
-    p.add_argument("--walltime", type=int, help="time (mins)", required=True)
-    p.add_argument("--queue", type=str, help="resource queue", required=True)
-    p.add_argument("--group", type=str, help="resource group", required=True)
+    p.add_argument("--walltime", type=int, help="time (mins)", required=False)
+    p.add_argument("--queue", type=str, help="resource queue", required=False)
+    p.add_argument("--group", type=str, help="resource group", required=False)
+    p.add_argument("--file", type=str, help="yml file", required=False)
     args = p.parse_args(cmd_args, namespace=RequestedRuntime())
 
-    submit_agent_job(
-        rt_name=rt_name,
-        access_token=access_token,
-        app_name='CS_Agent',
-        cluster=args.cluster,
-        cpus=args.cpus,
-        memory=args.memory,
-        walltime=args.walltime,
-        queue=args.queue,
-        group=args.group,
-    )
+    if args.file is not None:
+        submit_agent_job(
+            rt_name=rt_name,
+            access_token=access_token,
+            app_name='CS_Agent',
+            file=args.file,
+        )
+    else:
+        assert args.cluster is not None
+        assert args.cpus is not None
+        assert args.walltime is not None
+        assert args.queue is not None
+        assert args.group is not None
+        submit_agent_job(
+            rt_name=rt_name,
+            access_token=access_token,
+            app_name='CS_Agent',
+            cluster=args.cluster,
+            cpus=args.cpus,
+            memory=args.memory,
+            walltime=args.walltime,
+            queue=args.queue,
+            group=args.group,
+        )
 
 
 @register_line_magic
@@ -724,6 +827,21 @@ def stat_runtime(line: str):
         print(f"Runtime={rt_name} is still preparing... {rstate}")
 
 
+@register_line_magic
+def restart_runtime(rt_name: str):
+    """
+    Restart the runtime
+
+    """
+    access_token = get_access_token()
+    assert access_token is not None
+
+    rt = state.all_runtimes.get(rt_name, None)
+    if rt is None:
+        return print(f"Runtime {rt_name} not found.")
+    restart_runtime_kernel(access_token, rt_name, "base", rt)
+
+
 @register_line_magic
 def stop_runtime(rt_name: str):
     """
@@ -804,7 +922,8 @@ def run_cell(raw_cell, store_history=False, silent=False, 
shell_futures=True, ce
             wait_until_runtime_ready(access_token, rt)
             return run_on_runtime(rt, raw_cell, store_history, silent, 
shell_futures, cell_id)
         except Exception as e:
-            info = ExecutionInfo(raw_cell, store_history, silent, 
shell_futures, cell_id)
+            info = ExecutionInfo(raw_cell, store_history,
+                                 silent, shell_futures, cell_id)
             result = ExecutionResult(info)
             print(f"Error: {e}")
             result.error_in_exec = e
@@ -819,6 +938,7 @@ Loaded airavata_jupyter_magic
 
   %authenticate                      -- Authenticate to access 
high-performance runtimes.
   %request_runtime <rt> [args]       -- Request a runtime named <rt> with 
configuration <args>. Call multiple times to request multiple runtimes.
+  %restart_runtime <rt>              -- Restart runtime <rt>. Run this if you 
install new dependencies or if the runtime hangs.
   %stop_runtime <rt>                 -- Stop runtime <rt> when no longer 
needed.
   %switch_runtime <rt>               -- Switch active runtime to <rt>. All 
subsequent executions will use this runtime.
   %%run_on <rt>                      -- Force a cell to always execute on 
<rt>, regardless of the active runtime.
diff --git 
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
 
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
index acc2e003a0..b8d6645c27 100644
--- 
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
+++ 
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/handlers/AgentManagementHandler.java
@@ -164,11 +164,13 @@ public class AgentManagementHandler {
         List<InputDataObjectType> applicationInputs = 
airavataClient.getApplicationInputs(authzToken, appInterfaceId);
         List<InputDataObjectType> experimentInputs = applicationInputs.stream()
                 .peek(input -> {
-                    if ("agent_id".equals(input.getName())) {
-                        input.setValue(agentId);
-
-                    } else if ("server_url".equals(input.getName())) {
-                        input.setValue(airavataService.getServerUrl());
+                    switch (input.getName()) {
+                        case "agent_id" -> input.setValue(agentId);
+                        case "server_url" -> 
input.setValue(airavataService.getServerUrl());
+                        case "libraries" -> input.setValue(String.join(",", 
req.getLibraries()));
+                        case "pip" -> input.setValue(String.join(",", 
req.getPip()));
+                        case "mounts" -> input.setValue(String.join(",", 
req.getMounts()));
+                        default -> {}
                     }
                 })
                 .collect(Collectors.toList());
diff --git 
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
 
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
index c39536f1bf..7934ddf74e 100644
--- 
a/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
+++ 
b/modules/agent-framework/agent-service/src/main/java/org/apache/airavata/agent/connection/service/models/AgentLaunchRequest.java
@@ -1,5 +1,8 @@
 package org.apache.airavata.agent.connection.service.models;
 
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.commons.lang3.StringUtils;
 
 public class AgentLaunchRequest {
@@ -8,6 +11,9 @@ public class AgentLaunchRequest {
     private String projectName;
     private String remoteCluster;
     private String group;
+    private List<String> libraries = new ArrayList<>();
+    private List<String> pip = new ArrayList<>();
+    private List<String> mounts = new ArrayList<>();
 
     private String queue = "shared";
     private int wallTime = 30;
@@ -90,4 +96,28 @@ public class AgentLaunchRequest {
     public void setProjectName(String projectName) {
         this.projectName = projectName;
     }
+
+    public List<String> getLibraries() {
+        return libraries;
+    }
+
+    public void setLibraries(List<String> libraries) {
+        this.libraries = libraries;
+    }
+
+    public List<String> getPip() {
+        return pip;
+    }
+
+    public void setPip(List<String> pip) {
+        this.pip = pip;
+    }
+
+    public List<String> getMounts() {
+        return mounts;
+    }
+
+    public void setMounts(List<String> mounts) {
+        this.mounts = mounts;
+    }
 }

Reply via email to