VladaZakharova commented on code in PR #59558: URL: https://github.com/apache/airflow/pull/59558#discussion_r2745310362
########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. 
Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. + :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + except Exception: + raise AirflowException("Some error occurred when trying to get Job's status.") Review Comment: I think you are right, maybe exception here is too much. 
Playing "too safe" will make any good here :D ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. 
Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. + :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + except Exception: + raise AirflowException("Some error occurred when trying to get Job's status.") + + def _get_job_logs(self, cluster_address, job_id): + "Output Job logs." 
+ try: + logs = self.hook.get_job_logs(cluster_address=cluster_address, job_id=job_id) + self.log.info("Got job logs:\n%s\n", logs) + except Exception: + raise AirflowException("Some error occurred when trying to get logs from Job.") + + def execute(self, context: Context): Review Comment: For me the task was to implement the basic operator logic. Usually deferrable mode comes later, once the general idea of adding the new operators has been supported by the community. But I agree, it is needed. I would prefer to add deferrable mode for all the operators haha ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. 
+ :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue Review Comment: I have tested this in the system tests :) I think you meant unit tests then, I will add ########## providers/google/src/airflow/providers/google/cloud/operators/ray.py: ########## @@ -0,0 +1,430 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001), + or "auto", or "localhost:<port>". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. 
+ :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + + def _check_job_status(self, cluster_address, job_id) -> str: + "Check if the Job has reached terminated state." + while True: + try: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status not in ["SUCCEEDED", "FAILED"]: + self.log.info("Job status: %s...", job_status) + continue + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + except Exception: + raise AirflowException("Some error occurred when trying to get Job's status.") + + def _get_job_logs(self, cluster_address, job_id): + "Output Job logs." 
+ try: + logs = self.hook.get_job_logs(cluster_address=cluster_address, job_id=job_id) + self.log.info("Got job logs:\n%s\n", logs) + except Exception: + raise AirflowException("Some error occurred when trying to get logs from Job.") + + def execute(self, context: Context): + if self.get_job_logs and not self.wait_for_job_done: + raise ValueError( + "Retrieving Job logs can be possible only after Job completion. " + "Please, enable wait_for_job_done parameter to be able to get logs." + ) + try: + self.log.info("Submitting Job on a Ray cluster...") + submitted_job_id = self.hook.submit_job( + cluster_address=self.cluster_address, + entrypoint=self.entrypoint, + runtime_env=self.runtime_env, + metadata=self.metadata, + submission_id=self.submission_id, + entrypoint_num_cpus=self.entrypoint_num_cpus, + entrypoint_num_gpus=self.entrypoint_num_gpus, + entrypoint_memory=self.entrypoint_memory, + entrypoint_resources=self.entrypoint_resources, + ) + self.log.info("Submitted Ray Job id=%s", submitted_job_id) + RayJobLink.persist( + context=context, + cluster_address=self.cluster_address, + job_id=submitted_job_id, + ) + except RuntimeError as exc: + raise exc Review Comment: why? I am catching the error that API can send, so at least the operator will not be marked as succeeded if the job was actually not submitted and there was an error ########## providers/google/tests/system/google/cloud/ray/example_ray_job.py: ########## @@ -0,0 +1,168 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Example Airflow DAG for Jobs on Ray operations. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.ray import ( + RayDeleteJobOperator, + RayGetJobInfoOperator, + RayListJobsOperator, + RayStopJobOperator, + RaySubmitJobOperator, +) +from airflow.providers.google.cloud.operators.vertex_ai.ray import ( + CreateRayClusterOperator, + DeleteRayClusterOperator, + GetRayClusterOperator, +) + +try: + from google.cloud.aiplatform.vertex_ray.util import resources +except ImportError: + raise AirflowOptionalProviderFeatureException( + "The ray provider is optional and requires the `google-cloud-aiplatform` package to be installed. 
" + ) +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "ray_job_operations" +LOCATION = "us-central1" +JOB_ID = f"{DAG_ID}_{ENV_ID}".replace("-", "_") +WORKER_NODE_RESOURCES = resources.Resources( + node_count=1, +) + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + render_template_as_native_obj=True, + tags=["example", "job", "ray"], +) as dag: + create_ray_cluster = CreateRayClusterOperator( + task_id="create_ray_cluster", + project_id=PROJECT_ID, + location=LOCATION, + worker_node_types=[WORKER_NODE_RESOURCES], + python_version="3.10", + ray_version="2.33", + ) + + get_ray_cluster = GetRayClusterOperator( + task_id="get_ray_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=create_ray_cluster.output["cluster_id"], + ) + + # [START how_to_ray_submit_job] + submit_ray_job = RaySubmitJobOperator( + task_id="submit_ray_job", + cluster_address="{{ task_instance.xcom_pull(task_ids='get_ray_cluster')['dashboard_address'] }}", + entrypoint="python3 heavy.py", + runtime_env={ + "working_dir": "./providers/google/tests/system/google/cloud/vertex_ai/resources", Review Comment: thanks! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
