This is an automated email from the ASF dual-hosted git repository.

dimuthuupe pushed a commit to branch cybershuttle-staging
in repository https://gitbox.apache.org/repos/asf/airavata.git

commit fcb0c50670f7802d9b7fec833b2e27dd16225954
Author: yasith <[email protected]>
AuthorDate: Fri Apr 4 13:00:00 2025 -0400

    make remote cell execution responsive. bump version to 2.0.6
---
 .../airavata_jupyter_magic/__init__.py             | 126 +++++++++++----------
 .../airavata_jupyter_magic/device_auth.py          |   5 +-
 .../airavata-python-sdk/pyproject.toml             |   2 +-
 3 files changed, 69 insertions(+), 64 deletions(-)

diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py
index 51437cc43f..a68752b0dc 100644
--- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py
+++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/__init__.py
@@ -1,5 +1,7 @@
+import asyncio
 import base64
 import binascii
+from functools import partial
 import json
 import os
 import time
@@ -7,13 +9,14 @@ from argparse import ArgumentParser
 from dataclasses import dataclass
 from enum import IntEnum
 from pathlib import Path
-from typing import NamedTuple
+from types import CodeType
+from typing import Any, NamedTuple, Optional
 
 import jwt
 import requests
 import yaml
 from IPython.core.getipython import get_ipython
-from IPython.core.interactiveshell import ExecutionInfo, ExecutionResult
+from IPython.core.interactiveshell import ExecutionResult
 from IPython.core.magic import register_cell_magic, register_line_magic
 from IPython.display import HTML, Image, display
 from rich.console import Console
@@ -121,7 +124,7 @@ def get_access_token(envar_name: str = "CS_ACCESS_TOKEN", state_path: str = "/tm
     return token
 
 
-def is_runtime_ready(access_token: str, rt: RuntimeInfo, rt_name: str) -> tuple[bool, str]:
+def is_runtime_ready(access_token: str, rt: RuntimeInfo, rt_name: str):
     """
     Check if the runtime (i.e., agent job) is ready to receive requests
 
@@ -361,16 +364,14 @@ def wait_until_runtime_ready(access_token: str, rt_name: str):
     if rt_name == "local":
         return
     console = Console()
-    with console.status(f"Connecting to runtime={rt_name}...") as status:
+    with console.status(f"Connecting to={rt_name}...") as status:
         while True:
             ready, rstate = is_runtime_ready(access_token, rt, rt_name)
             if ready:
-                status.update(
-                    f"Connecting to runtime={rt_name}... status=READY")
+                status.update(f"Connecting to={rt_name}... status=READY")
                 break
             else:
-                status.update(
-                    f"Connecting to runtime={rt_name}... status={rstate}")
+                status.update(f"Connecting to={rt_name}... status={rstate}")
                 time.sleep(5)
         status.stop()
     console.clear()
@@ -467,19 +468,17 @@ def stop_agent_job(access_token: str, runtime_name: str, runtime: RuntimeInfo):
             f'[{status}] Failed to terminate runtime={runtime_name}: error={res.text}')
 
 
-def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, shell_futures=True, cell_id=None):
-    info = ExecutionInfo(cell, store_history, silent, shell_futures, cell_id)
-    excResult = ExecutionResult(info)
+def run_on_runtime(rt_name: str, code_obj: str, result: ExecutionResult) -> bool:
     rt = state.all_runtimes.get(rt_name, None)
     if rt is None:
-        excResult.error_in_exec = Exception(f"Runtime {rt_name} not found.")
-        return excResult
+        result.error_in_exec = Exception(f"Runtime {rt_name} not found.")
+        return False
 
     url = api_base_url + '/api/v1/agent/execute/jupyter'
     data = {
         "agentId": rt.agentId,
         "envName": "base",
-        "code": cell,
+        "code": code_obj,
     }
     json_data = json.dumps(data)
     response = requests.post(
@@ -488,14 +487,14 @@ def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, s
 
     execution_id = execution_resp.get("executionId")
     if not execution_id:
-        excResult.error_in_exec = Exception("Failed to start cell execution")
-        return excResult
+        result.error_in_exec = Exception("Failed to start cell execution")
+        return False
 
     error = execution_resp.get("error")
     if error:
-        excResult.error_in_exec = Exception(
+        result.error_in_exec = Exception(
             "Cell execution failed. Error: " + error)
-        return excResult
+        return False
 
     while True:
         url = api_base_url + "/api/v1/agent/execute/jupyter/" + execution_id
@@ -505,16 +504,16 @@ def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, s
             break
         time.sleep(1)
 
-    result_str = json_response.get('responseString')
+    exec_result_str = json_response.get('responseString')
     try:
-        result = json.loads(result_str)
+        exec_result = json.loads(exec_result_str)
     except json.JSONDecodeError as e:
-        excResult.error_in_exec = Exception(
+        result.error_in_exec = Exception(
             f"Failed to decode response from runtime={rt_name}: {e.msg}")
-        return excResult
+        return False
 
-    if 'outputs' in result:
-        for output in result['outputs']:
+    if 'outputs' in exec_result:
+        for output in exec_result['outputs']:
             output_type = output.get('output_type')
             if output_type == 'display_data':
                 data_obj = output.get('data', {})
@@ -524,9 +523,9 @@ def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, s
                         image_bytes = base64.b64decode(image_data)
                         display(Image(data=image_bytes, format='png'))
                     except binascii.Error as e:
-                        excResult.error_in_exec = Exception(
+                        result.error_in_exec = Exception(
                             f"Failed to decode image data: {e}")
-                        return excResult
+                        return False
 
             elif output_type == 'stream':
                 stream_name = output.get('name', 'stdout')
@@ -546,8 +545,8 @@ def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, s
                     </div>
                     """
                     display(HTML(error_html))
-                    excResult.error_in_exec = Exception(stream_text)
-                    return excResult
+                    result.error_in_exec = Exception(stream_text)
+                    return False
                 else:
                     print(stream_text)
 
@@ -570,31 +569,31 @@ def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, s
                     error_html += f"{line}\n"
                 error_html += "</pre></div>"
                 display(HTML(error_html))
-                excResult.error_in_exec = Exception(f"{ename}: {evalue}")
-                return excResult
+                result.error_in_exec = Exception(f"{ename}: {evalue}")
+                return False
 
             elif output_type == 'execute_result':
                 data_obj = output.get('data', {})
                 if 'text/plain' in data_obj:
                     print(data_obj['text/plain'])
     else:
-        if 'result' in result:
-            print(result['result'])
-        elif 'error' in result:
-            print(result['error']['ename'])
-            print(result['error']['evalue'])
-            print(result['error']['traceback'])
-        elif 'display' in result:
-            data_obj = result['display'].get('data', {})
+        if 'result' in exec_result:
+            print(exec_result['result'])
+        elif 'error' in exec_result:
+            print(exec_result['error']['ename'])
+            print(exec_result['error']['evalue'])
+            print(exec_result['error']['traceback'])
+        elif 'display' in exec_result:
+            data_obj = exec_result['display'].get('data', {})
             if 'image/png' in data_obj:
                 image_data = data_obj['image/png']
                 try:
                     image_bytes = base64.b64decode(image_data)
                     display(Image(data=image_bytes, format='png'))
                 except binascii.Error as e:
-                    excResult.error_in_exec = Exception(
+                    result.error_in_exec = Exception(
                         f"Failed to decode image data: {e}")
-                    return excResult
+                    return False
 
         else:
             # Mark as failed execution if no recognized output format is found
@@ -610,12 +609,12 @@ def run_on_runtime(rt_name: str, cell: str, store_history=False, silent=False, s
             <strong>Error:</strong> Execution failed with unrecognized output format from remote runtime.
             <pre>{}</pre>
           </div>
-          """.format(result_str)
+          """.format(exec_result)
             display(HTML(error_html))
-            excResult.error_in_exec = Exception(
+            result.error_in_exec = Exception(
                 "Execution failed with unrecognized output format from remote runtime.")
-            return excResult
-    return excResult
+            return False
+    return True
 
 
 def push_remote(local_path: str, remot_rt: str, remot_path: str) -> None:
@@ -685,7 +684,7 @@ def run_on(line: str, cell: str):
     try:
         if cell_runtime in ["local", *state.all_runtimes]:
             state.current_runtime = cell_runtime
-            ipython.run_cell(cell, silent=True)
+            return ipython.run_cell(cell)
         else:
             msg = f"Runtime {cell_runtime} not found."
             print(msg)
@@ -776,7 +775,7 @@ def request_runtime(line: str):
     args = p.parse_args(cmd_args, namespace=RequestedRuntime())
 
     if args.file is not None:
-        submit_agent_job(
+        return submit_agent_job(
             rt_name=rt_name,
             access_token=access_token,
             app_name='CS_Agent',
@@ -788,7 +787,7 @@ def request_runtime(line: str):
         assert args.walltime is not None
         assert args.queue is not None
         assert args.group is not None
-        submit_agent_job(
+        return submit_agent_job(
             rt_name=rt_name,
             access_token=access_token,
             app_name='CS_Agent',
@@ -807,7 +806,6 @@ def stat_runtime(line: str):
     Show the status of the runtime
 
     """
-
     access_token = get_access_token()
     assert access_token is not None
 
@@ -839,7 +837,7 @@ def restart_runtime(rt_name: str):
     rt = state.all_runtimes.get(rt_name, None)
     if rt is None:
         return print(f"Runtime {rt_name} not found.")
-    restart_runtime_kernel(access_token, rt_name, "base", rt)
+    return restart_runtime_kernel(access_token, rt_name, "base", rt)
 
 
 @register_line_magic
@@ -854,7 +852,7 @@ def stop_runtime(rt_name: str):
     rt = state.all_runtimes.get(rt_name, None)
     if rt is None:
         return print(f"Runtime {rt_name} not found.")
-    stop_agent_job(access_token, rt_name, rt)
+    return stop_agent_job(access_token, rt_name, rt)
 
 
 @register_line_magic
@@ -895,13 +893,13 @@ def copy_data(line: str):
 ipython = get_ipython()
 if ipython is None:
     raise RuntimeError("airavata_jupyter_magic requires an ipython session")
-
+assert ipython is not None
 api_base_url = "https://api.gateway.cybershuttle.org";
 file_server_url = "http://3.142.234.94:8050";
 MSG_NOT_INITIALIZED = r"Runtime not found. Please run %request_runtime name=<name> cluster=<cluster> cpu=<cpu> memory=<memory mb> queue=<queue> walltime=<walltime minutes> group=<group> to request one."
 
 state = State(current_runtime="local", all_runtimes={})
-orig_run_cell = ipython.run_cell
+orig_run_code = ipython.run_cell_async
 
 
 def cell_has_magic(raw_cell: str) -> bool:
@@ -911,26 +909,32 @@ def cell_has_magic(raw_cell: str) -> bool:
     return any(line.strip().startswith(magics) for line in lines)
 
 
-def run_cell(raw_cell, store_history=False, silent=False, shell_futures=True, cell_id=None):
+async def run_cell_async(
+    raw_cell: str,
+    store_history=False,
+    silent=False,
+    shell_futures=True,
+    *,
+    transformed_cell: Optional[str] = None,
+    preprocessing_exc_tuple: Optional[Any] = None,
+    cell_id=None,
+) -> ExecutionResult:
     rt = state.current_runtime
     if rt == "local" or cell_has_magic(raw_cell):
-        return orig_run_cell(raw_cell, store_history, silent, shell_futures, cell_id)
+        return await orig_run_code(raw_cell, store_history, silent, shell_futures, transformed_cell=transformed_cell, preprocessing_exc_tuple=preprocessing_exc_tuple, cell_id=cell_id)
     else:
         access_token = get_access_token()
         assert access_token is not None
+        result = ExecutionResult(info=None)
         try:
             wait_until_runtime_ready(access_token, rt)
-            return run_on_runtime(rt, raw_cell, store_history, silent, shell_futures, cell_id)
+            run_on_runtime(rt, raw_cell, result)
+            return result
         except Exception as e:
-            info = ExecutionInfo(raw_cell, store_history,
-                                 silent, shell_futures, cell_id)
-            result = ExecutionResult(info)
-            print(f"Error: {e}")
             result.error_in_exec = e
             return result
 
-
-ipython.run_cell = run_cell
+ipython.run_cell_async = run_cell_async
 
 print(r"""
 Loaded airavata_jupyter_magic
diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/device_auth.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/device_auth.py
index 72fd44b769..e6bf83396d 100644
--- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/device_auth.py
+++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_jupyter_magic/device_auth.py
@@ -1,6 +1,7 @@
-import requests
-import time
 import os
+import time
+
+import requests
 from rich.console import Console
 
 # Load environment variables from .env file
diff --git 
a/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml 
b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml
index dadd1463cb..58c49a9e4e 100644
--- a/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml
+++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "airavata-python-sdk"
-version = "2.0.5.post2"
+version = "2.0.6"
 description = "Apache Airavata Python SDK"
 readme = "README.md"
 license = "Apache-2.0"

Reply via email to