This is an automated email from the ASF dual-hosted git repository.

lahirujayathilake pushed a commit to branch cybershuttle-dev
in repository https://gitbox.apache.org/repos/asf/airavata.git

commit 919bfd1d9d16882a9bc154d701be85add5482fa9
Author: yasith <[email protected]>
AuthorDate: Sun Dec 15 21:06:25 2024 -0600

    make settings.ini lean, add runtime picker {by cluster, category}. change agent base dir.
---
 .../airavata_experiments/airavata.py               | 89 +++++++++++-----------
 .../airavata_experiments/runtime.py                | 39 +++++++---
 .../airavata-python-sdk/samples/poc.ipynb          |  2 +-
 3 files changed, 73 insertions(+), 57 deletions(-)

diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py
index b6556f1770..36c2efbf43 100644
--- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py
+++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py
@@ -26,7 +26,6 @@ from airavata.model.scheduling.ttypes import ComputationalResourceSchedulingMode
 from airavata.model.data.replica.ttypes import DataProductModel, DataProductType, DataReplicaLocationModel, ReplicaLocationCategory
 
 from airavata_sdk.clients.api_server_client import APIServerClient
-from airavata_sdk.transport.settings import ExperimentSettings, GatewaySettings, APIServerClientSettings
 
 logger = logging.getLogger("airavata_sdk.clients")
 logger.setLevel(logging.INFO)
@@ -38,6 +37,31 @@ LaunchState = NamedTuple("LaunchState", [
   ("sr_host", str),
 ])
 
+class Settings:
+
+  def __init__(self, config_path: str) -> None:
+    
+    import configparser
+    config = configparser.ConfigParser()
+    config.read(config_path)
+
+    # api server client settings
+    self.API_SERVER_HOST = config.get('APIServer', 'API_HOST')
+    self.API_SERVER_PORT = config.getint('APIServer', 'API_PORT')
+    self.API_SERVER_SECURE = config.getboolean('APIServer', 'API_SECURE')
+    
+    # gateway settings
+    self.GATEWAY_ID = config.get('Gateway', 'GATEWAY_ID')
+    self.GATEWAY_URL = config.get('Gateway', 'GATEWAY_URL')
+    self.GATEWAY_DATA_STORE_DIR = config.get('Gateway', 'GATEWAY_DATA_STORE_DIR')
+    self.STORAGE_RESOURCE_HOST = config.get('Gateway', 'STORAGE_RESOURCE_HOST')
+    self.SFTP_PORT = config.get('Gateway', 'SFTP_PORT')
+    
+    # runtime-specific settings
+    self.PROJECT_NAME = config.get('User', 'PROJECT_NAME')
+    self.GROUP_RESOURCE_PROFILE_NAME = config.get('User', 'GROUP_RESOURCE_PROFILE_NAME')
+    
+
 class AiravataOperator:
 
   def register_input_file(
@@ -128,44 +152,30 @@ class AiravataOperator:
   def __init__(self, access_token: str, config_file: str = "settings.ini"):
     # store variables
     self.access_token = access_token
-    self.api_settings = APIServerClientSettings(config_file)
-    self.gateway_settings = GatewaySettings(config_file)
-    self.experiment_settings = ExperimentSettings(config_file)
+    self.settings = Settings(config_file)
     # load api server settings and create client
-    self.api_server_client = APIServerClient(api_server_settings=self.api_settings)
+    self.api_server_client = APIServerClient(api_server_settings=self.settings)
     # load gateway settings
     gateway_id = self.default_gateway_id()
     self.airavata_token = self.__airavata_token__(self.access_token, gateway_id)
 
   def default_gateway_id(self):
-    return self.gateway_settings.GATEWAY_ID
+    return self.settings.GATEWAY_ID
+  
+  def default_gateway_grp_name(self):
+    return self.settings.GROUP_RESOURCE_PROFILE_NAME
   
   def default_gateway_data_store_dir(self):
-    return self.gateway_settings.GATEWAY_DATA_STORE_DIR
+    return self.settings.GATEWAY_DATA_STORE_DIR
   
   def default_sftp_port(self):
-    return self.experiment_settings.SFTP_PORT
-  
-  def default_experiment_queue(self):
-    return self.experiment_settings.QUEUE_NAME
-  
-  def default_grp_name(self):
-    return self.experiment_settings.GROUP_RESOURCE_PROFILE_NAME
+    return self.settings.SFTP_PORT
   
   def default_sr_hostname(self):
-    return self.experiment_settings.STORAGE_RESOURCE_HOST
+    return self.settings.STORAGE_RESOURCE_HOST
   
   def default_project_name(self):
-    return self.experiment_settings.PROJECT_NAME
-  
-  def default_node_count(self):
-    return self.experiment_settings.NODE_COUNT
-  
-  def default_cpu_count(self):
-    return self.experiment_settings.TOTAL_CPU_COUNT
-  
-  def default_walltime(self):
-    return self.experiment_settings.WALL_TIME_LIMIT
+    return self.settings.PROJECT_NAME
 
   def __airavata_token__(self, access_token: str, gateway_id: str):
     """
@@ -234,7 +244,7 @@ class AiravataOperator:
 
     """
     # use defaults for missing values
-    grp_name = grp_name or self.default_grp_name()
+    grp_name = grp_name or self.default_gateway_grp_name()
     # logic
     grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore
     grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name))
@@ -251,7 +261,7 @@ class AiravataOperator:
 
     """
     # use defaults for missing values
-    grp_name = grp_name or self.default_grp_name()
+    grp_name = grp_name or self.default_gateway_grp_name()
     # logic
     grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore
     grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name))
@@ -368,17 +378,17 @@ class AiravataOperator:
       self,
       experiment_name: str,
       app_name: str,
-      computation_resource_name: str,
       inputs: dict[str, str | int | float | list[str]],
+      computation_resource_name: str,
+      queue_name: str,
+      node_count: int,
+      cpu_count: int,
+      walltime: int,
       *,
       gateway_id: str | None = None,
-      queue_name: str | None = None,
       grp_name: str | None = None,
       sr_host: str | None = None,
       project_name: str | None = None,
-      node_count: int | None = None,
-      cpu_count: int | None = None,
-      walltime: int | None = None,
       auto_schedule: bool = False,
   ) -> LaunchState:
     """
@@ -388,18 +398,11 @@ class AiravataOperator:
     # preprocess args (str)
     print("[AV] Preprocessing args...")
     gateway_id = str(gateway_id or self.default_gateway_id())
-    mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id
-
-    queue_name = str(queue_name or self.default_experiment_queue())
-    grp_name = str(grp_name or self.default_grp_name())
+    grp_name = str(grp_name or self.default_gateway_grp_name())
     sr_host = str(sr_host or self.default_sr_hostname())
+    mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id
     project_name = str(project_name or self.default_project_name())
 
-    # preprocess args (int)
-    node_count = int(node_count or self.default_node_count() or "1")
-    cpu_count = int(cpu_count or self.default_cpu_count() or "1")
-    walltime = int(walltime or self.default_walltime() or "30")
-
     # validate args (str)
     print("[AV] Validating args...")
     assert len(experiment_name) > 0
@@ -421,10 +424,6 @@ class AiravataOperator:
     # setup runtime params
     print("[AV] Setting up runtime params...")
     storage = self.get_storage(sr_host)
-    queue_name = queue_name or self.default_experiment_queue()
-    node_count = int(node_count or self.default_node_count() or "1")
-    cpu_count = int(cpu_count or self.default_cpu_count() or "1")
-    walltime = int(walltime or self.default_walltime() or "30")
     sr_id = storage.storageResourceId
 
     # setup application interface
diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py
index 1cdfdb4284..16f5c41dc1 100644
--- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py
+++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py
@@ -67,11 +67,10 @@ class Runtime(abc.ABC, pydantic.BaseModel):
 
   @staticmethod
   def default():
-    # return Mock()
     return Remote.default()
 
   @staticmethod
-  def create(id: str, args: dict[str, Any]) -> "Runtime":
+  def create(id: str, args: dict[str, Any]) -> Runtime:
     if id == "mock":
       return Mock(**args)
     elif id == "remote":
@@ -133,8 +132,15 @@ class Mock(Runtime):
 
 class Remote(Runtime):
 
-  def __init__(self, **kwargs) -> None:
-    super().__init__(id="remote", args=kwargs)
+  def __init__(self, cluster: str, category: str, queue_name: str, node_count: int, cpu_count: int, walltime: int) -> None:
+    super().__init__(id="remote", args=dict(
+        cluster=cluster,
+        category=category,
+        queue_name=queue_name,
+        node_count=node_count,
+        cpu_count=cpu_count,
+        walltime=walltime,
+    ))
 
   def execute(self, task: Task) -> None:
     assert task.ref is None
@@ -148,8 +154,12 @@ class Remote(Runtime):
     launch_state = av.launch_experiment(
         experiment_name=task.name,
         app_name=task.app_id,
+        inputs={**task.inputs, "agent_id": task.agent_ref, "server_url": conn_svc_url},
         computation_resource_name=str(self.args["cluster"]),
-        inputs={**task.inputs, "agent_id": task.agent_ref, "server_url": conn_svc_url}
+        queue_name=str(self.args["queue_name"]),
+        node_count=int(self.args["node_count"]),
+        cpu_count=int(self.args["cpu_count"]),
+        walltime=int(self.args["walltime"]),
     )
     task.ref = launch_state.experiment_id
     task.workdir = launch_state.experiment_dir
@@ -328,9 +338,16 @@ class Remote(Runtime):
 
   @staticmethod
   def default():
-    return Remote(
-        cluster="login.expanse.sdsc.edu",
-    )
-
-def list_runtimes(**kwargs) -> list[Runtime]:
-  return [Remote(cluster="login.expanse.sdsc.edu"), Remote(cluster="anvil.rcac.purdue.edu")]
\ No newline at end of file
+    return Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=24, walltime=30)
+    
+
+def list_runtimes(
+    cluster: str | None = None,
+    category: str | None = None,
+) -> list[Runtime]:
+  all_runtimes = list[Runtime]([
+    Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, walltime=30),
+    Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, walltime=30),
+    Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, walltime=30),
+  ])
+  return [*filter(lambda r: (cluster in [None, r.args["cluster"]]) and (category in [None, r.args["category"]]), all_runtimes)]
\ No newline at end of file
diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb b/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb
index 7ec5ed6651..8b1fc79325 100644
--- a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb
+++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb
@@ -161,7 +161,7 @@
     "    ],\n",
     "    parallelism=\"GPU\",\n",
     ")\n",
-    "exp.add_replica()\n",
+    "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"gpu\"))\n",
     "ae.display(exp)"
    ]
   },

Reply via email to