shahar1 commented on code in PR #59558: URL: https://github.com/apache/airflow/pull/59558#discussion_r2724209147
########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. 
Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. + :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + except Exception: + raise AirflowException("Some error occurred when trying to get Job's status.") + + def _get_job_logs(self, cluster_address, job_id): + "Output Job logs." + try: + logs = self.hook.get_job_logs(cluster_address=cluster_address, job_id=job_id) + self.log.info("Got job logs:\n%s\n", logs) Review Comment: I'd move the `self.log.info` call out of the `try` block.
Also I'm not sure why the exception is needed at all. ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. 
Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. + :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: Review Comment: small suggestion: put the values in an enum ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. 
+ :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + except Exception: + raise AirflowException("Some error occurred when trying to get Job's status.") Review Comment: 1. Following a recent decision, we should avoid the usage of `AirflowException` (see [dev list thread](https://lists.apache.org/thread/5rv4tz0oc27bgr4khx0on0jz8fpxvh55)) - please refactor all existing usages. 2. 
For all existing `try...except Exception` - please reconsider if raising any exception but the original is necessary, because it obfuscates the real exception, which makes it harder for debugging. If they are necessary - they should be tested. ########## providers/google/tests/system/google/cloud/ray/example_ray_job.py: ########## @@ -0,0 +1,168 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Example Airflow DAG for Jobs on Ray operations. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.ray import ( + RayDeleteJobOperator, + RayGetJobInfoOperator, + RayListJobsOperator, + RayStopJobOperator, + RaySubmitJobOperator, +) +from airflow.providers.google.cloud.operators.vertex_ai.ray import ( + CreateRayClusterOperator, + DeleteRayClusterOperator, + GetRayClusterOperator, +) + +try: + from google.cloud.aiplatform.vertex_ray.util import resources +except ImportError: + raise AirflowOptionalProviderFeatureException( + "The ray provider is optional and requires the `google-cloud-aiplatform` package to be installed. 
" + ) +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "ray_job_operations" +LOCATION = "us-central1" +JOB_ID = f"{DAG_ID}_{ENV_ID}".replace("-", "_") +WORKER_NODE_RESOURCES = resources.Resources( + node_count=1, +) + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + render_template_as_native_obj=True, + tags=["example", "job", "ray"], +) as dag: + create_ray_cluster = CreateRayClusterOperator( + task_id="create_ray_cluster", + project_id=PROJECT_ID, + location=LOCATION, + worker_node_types=[WORKER_NODE_RESOURCES], + python_version="3.10", + ray_version="2.33", + ) + + get_ray_cluster = GetRayClusterOperator( + task_id="get_ray_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=create_ray_cluster.output["cluster_id"], + ) + + # [START how_to_ray_submit_job] + submit_ray_job = RaySubmitJobOperator( + task_id="submit_ray_job", + cluster_address="{{ task_instance.xcom_pull(task_ids='get_ray_cluster')['dashboard_address'] }}", + entrypoint="python3 heavy.py", + runtime_env={ + "working_dir": "./providers/google/tests/system/google/cloud/vertex_ai/resources", Review Comment: Path seems to be wrong (according to current structure): ```suggestion "working_dir": "./providers/google/tests/system/google/cloud/ray/resources", ``` ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. 
+ :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue Review Comment: 1. There should be a sleep interval + timeout handling. 2. Would be great if you could test this logic ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. 
+ :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + except Exception: + raise AirflowException("Some error occurred when trying to get Job's status.") + + def _get_job_logs(self, cluster_address, job_id): + "Output Job logs." 
+ try: + logs = self.hook.get_job_logs(cluster_address=cluster_address, job_id=job_id) + self.log.info("Got job logs:\n%s\n", logs) + except Exception: + raise AirflowException("Some error occurred when trying to get logs from Job.") + + def execute(self, context: Context): + if self.get_job_logs and not self.wait_for_job_done: + raise ValueError( + "Retrieving Job logs can be possible only after Job completion. " + "Please, enable wait_for_job_done parameter to be able to get logs." + ) + try: + self.log.info("Submitting Job on a Ray cluster...") + submitted_job_id = self.hook.submit_job( + cluster_address=self.cluster_address, + entrypoint=self.entrypoint, + runtime_env=self.runtime_env, + metadata=self.metadata, + submission_id=self.submission_id, + entrypoint_num_cpus=self.entrypoint_num_cpus, + entrypoint_num_gpus=self.entrypoint_num_gpus, + entrypoint_memory=self.entrypoint_memory, + entrypoint_resources=self.entrypoint_resources, + ) + self.log.info("Submitted Ray Job id=%s", submitted_job_id) + RayJobLink.persist( + context=context, + cluster_address=self.cluster_address, + job_id=submitted_job_id, + ) + except RuntimeError as exc: + raise exc Review Comment: Seems redundant ########## providers/google/src/airflow/providers/google/cloud/hooks/ray.py: ########## @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job hook.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ray.job_submission import JobSubmissionClient + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + +if TYPE_CHECKING: + from ray.dashboard.modules.job.common import JobStatus + from ray.dashboard.modules.job.pydantic_models import JobDetails + + +class RayJobHook(GoogleBaseHook): + """Hook for Jobs APIs.""" + + def get_client(self, address: str): Review Comment: Please add a type hint ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. 
+ :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + except Exception: + raise AirflowException("Some error occurred when trying to get Job's status.") + + def _get_job_logs(self, cluster_address, job_id): + "Output Job logs." 
+ try: + logs = self.hook.get_job_logs(cluster_address=cluster_address, job_id=job_id) + self.log.info("Got job logs:\n%s\n", logs) + except Exception: + raise AirflowException("Some error occurred when trying to get logs from Job.") + + def execute(self, context: Context): Review Comment: What about making this a deferrable operator? (not urgent to implement as part of this PR, but make sure that it's reflected if currently not supported) ########## providers/google/src/airflow/providers/google/cloud/hooks/ray.py: ########## @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job hook.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ray.job_submission import JobSubmissionClient + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + +if TYPE_CHECKING: + from ray.dashboard.modules.job.common import JobStatus + from ray.dashboard.modules.job.pydantic_models import JobDetails + + +class RayJobHook(GoogleBaseHook): + """Hook for Jobs APIs.""" + + def get_client(self, address: str): + """ + Create a client for submitting and interacting with jobs on a remote cluster. 
+ + :param address: Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + """ + if address.endswith("aiplatform-training.googleusercontent.com"): Review Comment: This security comment should be handled -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
