This is an automated email from the ASF dual-hosted git repository. lahirujayathilake pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airavata.git
commit 1046150be6f402054a39be08ef0b70092e9778a1 Author: yasith <[email protected]> AuthorDate: Sun Dec 15 21:06:25 2024 -0600 make settings.ini lean, add runtime picker {by cluster, category}. change agent base dir. --- .../airavata_experiments/airavata.py | 89 +++++++++++----------- .../airavata_experiments/runtime.py | 39 +++++++--- .../airavata-python-sdk/samples/poc.ipynb | 2 +- 3 files changed, 73 insertions(+), 57 deletions(-) diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py index b6556f1770..36c2efbf43 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/airavata.py @@ -26,7 +26,6 @@ from airavata.model.scheduling.ttypes import ComputationalResourceSchedulingMode from airavata.model.data.replica.ttypes import DataProductModel, DataProductType, DataReplicaLocationModel, ReplicaLocationCategory from airavata_sdk.clients.api_server_client import APIServerClient -from airavata_sdk.transport.settings import ExperimentSettings, GatewaySettings, APIServerClientSettings logger = logging.getLogger("airavata_sdk.clients") logger.setLevel(logging.INFO) @@ -38,6 +37,31 @@ LaunchState = NamedTuple("LaunchState", [ ("sr_host", str), ]) +class Settings: + + def __init__(self, config_path: str) -> None: + + import configparser + config = configparser.ConfigParser() + config.read(config_path) + + # api server client settings + self.API_SERVER_HOST = config.get('APIServer', 'API_HOST') + self.API_SERVER_PORT = config.getint('APIServer', 'API_PORT') + self.API_SERVER_SECURE = config.getboolean('APIServer', 'API_SECURE') + + # gateway settings + self.GATEWAY_ID = config.get('Gateway', 'GATEWAY_ID') + self.GATEWAY_URL = config.get('Gateway', 'GATEWAY_URL') + self.GATEWAY_DATA_STORE_DIR = config.get('Gateway', 'GATEWAY_DATA_STORE_DIR') + self.STORAGE_RESOURCE_HOST = config.get('Gateway', 'STORAGE_RESOURCE_HOST') + self.SFTP_PORT = config.get('Gateway', 'SFTP_PORT') + + # runtime-specific settings + self.PROJECT_NAME = config.get('User', 'PROJECT_NAME') + self.GROUP_RESOURCE_PROFILE_NAME = config.get('User', 'GROUP_RESOURCE_PROFILE_NAME') + + class AiravataOperator: def register_input_file( @@ -128,44 +152,30 @@ class AiravataOperator: def __init__(self, access_token: str, config_file: str = "settings.ini"): # store variables self.access_token = access_token - self.api_settings = APIServerClientSettings(config_file) - self.gateway_settings = GatewaySettings(config_file) - self.experiment_settings = ExperimentSettings(config_file) + self.settings = Settings(config_file) # load api server settings and create client - self.api_server_client = APIServerClient(api_server_settings=self.api_settings) + self.api_server_client = APIServerClient(api_server_settings=self.settings) # load gateway settings gateway_id = self.default_gateway_id() self.airavata_token = self.__airavata_token__(self.access_token, gateway_id) def default_gateway_id(self): - return self.gateway_settings.GATEWAY_ID + return self.settings.GATEWAY_ID + + def default_gateway_grp_name(self): + return self.settings.GROUP_RESOURCE_PROFILE_NAME def default_gateway_data_store_dir(self): - return self.gateway_settings.GATEWAY_DATA_STORE_DIR + return self.settings.GATEWAY_DATA_STORE_DIR def default_sftp_port(self): - return self.experiment_settings.SFTP_PORT - - def default_experiment_queue(self): - return self.experiment_settings.QUEUE_NAME - - def default_grp_name(self): - return self.experiment_settings.GROUP_RESOURCE_PROFILE_NAME + return self.settings.SFTP_PORT def default_sr_hostname(self): - return self.experiment_settings.STORAGE_RESOURCE_HOST + return self.settings.STORAGE_RESOURCE_HOST def default_project_name(self): - return self.experiment_settings.PROJECT_NAME - - def default_node_count(self): - return self.experiment_settings.NODE_COUNT - - def default_cpu_count(self): - return self.experiment_settings.TOTAL_CPU_COUNT - - def default_walltime(self): - return self.experiment_settings.WALL_TIME_LIMIT + return self.settings.PROJECT_NAME def __airavata_token__(self, access_token: str, gateway_id: str): """ @@ -234,7 +244,7 @@ class AiravataOperator: """ # use defaults for missing values - grp_name = grp_name or self.default_grp_name() + grp_name = grp_name or self.default_gateway_grp_name() # logic grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) @@ -251,7 +261,7 @@ class AiravataOperator: """ # use defaults for missing values - grp_name = grp_name or self.default_grp_name() + grp_name = grp_name or self.default_gateway_grp_name() # logic grps: list = self.api_server_client.get_group_resource_list(self.airavata_token, self.default_gateway_id()) # type: ignore grp_id = next((grp.groupResourceProfileId for grp in grps if grp.groupResourceProfileName == grp_name)) @@ -368,17 +378,17 @@ class AiravataOperator: self, experiment_name: str, app_name: str, - computation_resource_name: str, inputs: dict[str, str | int | float | list[str]], + computation_resource_name: str, + queue_name: str, + node_count: int, + cpu_count: int, + walltime: int, *, gateway_id: str | None = None, - queue_name: str | None = None, grp_name: str | None = None, sr_host: str | None = None, project_name: str | None = None, - node_count: int | None = None, - cpu_count: int | None = None, - walltime: int | None = None, auto_schedule: bool = False, ) -> LaunchState: """ @@ -388,18 +398,11 @@ class AiravataOperator: # preprocess args (str) print("[AV] Preprocessing args...") gateway_id = str(gateway_id or self.default_gateway_id()) - mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id - - queue_name = str(queue_name or self.default_experiment_queue()) - grp_name = str(grp_name or self.default_grp_name()) + grp_name = str(grp_name or self.default_gateway_grp_name()) sr_host = str(sr_host or self.default_sr_hostname()) + mount_point = Path(self.default_gateway_data_store_dir()) / self.user_id project_name = str(project_name or self.default_project_name()) - # preprocess args (int) - node_count = int(node_count or self.default_node_count() or "1") - cpu_count = int(cpu_count or self.default_cpu_count() or "1") - walltime = int(walltime or self.default_walltime() or "30") - # validate args (str) print("[AV] Validating args...") assert len(experiment_name) > 0 @@ -421,10 +424,6 @@ class AiravataOperator: # setup runtime params print("[AV] Setting up runtime params...") storage = self.get_storage(sr_host) - queue_name = queue_name or self.default_experiment_queue() - node_count = int(node_count or self.default_node_count() or "1") - cpu_count = int(cpu_count or self.default_cpu_count() or "1") - walltime = int(walltime or self.default_walltime() or "30") sr_id = storage.storageResourceId # setup application interface diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py index 1cdfdb4284..16f5c41dc1 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/airavata_experiments/runtime.py @@ -67,11 +67,10 @@ class Runtime(abc.ABC, pydantic.BaseModel): @staticmethod def default(): - # return Mock() return Remote.default() @staticmethod - def create(id: str, args: dict[str, Any]) -> "Runtime": + def create(id: str, args: dict[str, Any]) -> Runtime: if id == "mock": return Mock(**args) elif id == "remote": @@ -133,8 +132,15 @@ class Mock(Runtime): class Remote(Runtime): - def __init__(self, **kwargs) -> None: - super().__init__(id="remote", args=kwargs) + def __init__(self, cluster: str, category: str, queue_name: str, node_count: int, cpu_count: int, walltime: int) -> None: + super().__init__(id="remote", args=dict( + cluster=cluster, + category=category, + queue_name=queue_name, + node_count=node_count, + cpu_count=cpu_count, + walltime=walltime, + )) def execute(self, task: Task) -> None: assert task.ref is None @@ -148,8 +154,12 @@ class Remote(Runtime): launch_state = av.launch_experiment( experiment_name=task.name, app_name=task.app_id, + inputs={**task.inputs, "agent_id": task.agent_ref, "server_url": conn_svc_url}, computation_resource_name=str(self.args["cluster"]), - inputs={**task.inputs, "agent_id": task.agent_ref, "server_url": conn_svc_url} + queue_name=str(self.args["queue_name"]), + node_count=int(self.args["node_count"]), + cpu_count=int(self.args["cpu_count"]), + walltime=int(self.args["walltime"]), ) task.ref = launch_state.experiment_id task.workdir = launch_state.experiment_dir @@ -328,9 +338,16 @@ class Remote(Runtime): @staticmethod def default(): - return Remote( - cluster="login.expanse.sdsc.edu", - ) - -def list_runtimes(**kwargs) -> list[Runtime]: - return [Remote(cluster="login.expanse.sdsc.edu"), Remote(cluster="anvil.rcac.purdue.edu")] \ No newline at end of file + return Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=24, walltime=30) + + +def list_runtimes( + cluster: str | None = None, + category: str | None = None, +) -> list[Runtime]: + all_runtimes = list[Runtime]([ + Remote(cluster="login.expanse.sdsc.edu", category="gpu", queue_name="gpu-shared", node_count=1, cpu_count=10, walltime=30), + Remote(cluster="login.expanse.sdsc.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=10, walltime=30), + Remote(cluster="anvil.rcac.purdue.edu", category="cpu", queue_name="shared", node_count=1, cpu_count=24, walltime=30), + ]) + return [*filter(lambda r: (cluster in [None, r.args["cluster"]]) and (category in [None, r.args["category"]]), all_runtimes)] \ No newline at end of file diff --git a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb b/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb index 7ec5ed6651..8b1fc79325 100644 --- a/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb +++ b/airavata-api/airavata-client-sdks/airavata-python-sdk/samples/poc.ipynb @@ -161,7 +161,7 @@ " ],\n", " parallelism=\"GPU\",\n", ")\n", - "exp.add_replica()\n", + "exp.add_replica(*ae.list_runtimes(cluster=\"login.expanse.sdsc.edu\", category=\"gpu\"))\n", "ae.display(exp)" ] },
