This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 66d0222947 Add CloudRunHook and operators (#33067)
66d0222947 is described below
commit 66d0222947bf6ec779064d25c16391e22c1a9368
Author: Freddy Demiane <[email protected]>
AuthorDate: Fri Aug 25 00:42:57 2023 +0200
Add CloudRunHook and operators (#33067)
---
airflow/providers/google/cloud/hooks/cloud_run.py | 176 ++++++++++
.../providers/google/cloud/operators/cloud_run.py | 354 +++++++++++++++++++++
.../providers/google/cloud/triggers/cloud_run.py | 142 +++++++++
airflow/providers/google/provider.yaml | 15 +
.../operators/cloud/cloud_run.rst | 128 ++++++++
generated/provider_dependencies.json | 1 +
.../providers/google/cloud/hooks/test_cloud_run.py | 273 ++++++++++++++++
.../google/cloud/operators/test_cloud_run.py | 308 ++++++++++++++++++
.../google/cloud/triggers/test_cloud_run.py | 155 +++++++++
.../providers/google/cloud/cloud_run/__init__.py | 16 +
.../google/cloud/cloud_run/example_cloud_run.py | 261 +++++++++++++++
11 files changed, 1829 insertions(+)
diff --git a/airflow/providers/google/cloud/hooks/cloud_run.py
b/airflow/providers/google/cloud/hooks/cloud_run.py
new file mode 100644
index 0000000000..f33b5fa3af
--- /dev/null
+++ b/airflow/providers/google/cloud/hooks/cloud_run.py
@@ -0,0 +1,176 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import itertools
+from typing import Iterable, Sequence
+
+from google.api_core import operation
+from google.cloud.run_v2 import (
+ CreateJobRequest,
+ DeleteJobRequest,
+ GetJobRequest,
+ Job,
+ JobsAsyncClient,
+ JobsClient,
+ ListJobsRequest,
+ RunJobRequest,
+ UpdateJobRequest,
+)
+from google.cloud.run_v2.services.jobs import pagers
+from google.longrunning import operations_pb2
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID, GoogleBaseHook
+
+
class CloudRunHook(GoogleBaseHook):
    """
    Hook for the Google Cloud Run service.

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account.
    """

    def __init__(
        self,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
    ) -> None:
        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
        # Built lazily by get_conn() and reused afterwards.
        self._client: JobsClient | None = None

    def get_conn(self) -> JobsClient:
        """
        Retrieve connection to Cloud Run.

        :return: Cloud Run Jobs client object.
        """
        if self._client is None:
            self._client = JobsClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
        return self._client

    @staticmethod
    def _job_path(project_id: str, region: str, job_name: str) -> str:
        """Build the fully-qualified resource name of a Cloud Run job."""
        return f"projects/{project_id}/locations/{region}/jobs/{job_name}"

    @GoogleBaseHook.fallback_to_default_project_id
    def delete_job(self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID) -> Job:
        """
        Delete a job and block until the delete operation finishes.

        :param job_name: The name of the job to delete.
        :param region: The region the job belongs to.
        :param project_id: The project the job belongs to. When omitted,
            ``fallback_to_default_project_id`` supplies the connection's default project.
        :return: The result of the delete long-running operation.
        """
        delete_request = DeleteJobRequest()
        delete_request.name = self._job_path(project_id, region, job_name)

        # Local name chosen so it does not shadow the module-level ``operation`` import.
        delete_operation = self.get_conn().delete_job(delete_request)
        return delete_operation.result()

    @GoogleBaseHook.fallback_to_default_project_id
    def create_job(
        self, job_name: str, job: Job | dict, region: str, project_id: str = PROVIDE_PROJECT_ID
    ) -> Job:
        """
        Create a job (without running it) and block until the create operation finishes.

        :param job_name: The ID to give the new job.
        :param job: The job descriptor, either a ``Job`` or a dict convertible to one.
        :param region: The region to create the job in.
        :param project_id: The project to create the job in.
        :return: The created job.
        """
        if isinstance(job, dict):
            job = Job(job)

        create_request = CreateJobRequest()
        create_request.job = job
        create_request.job_id = job_name
        create_request.parent = f"projects/{project_id}/locations/{region}"

        create_operation = self.get_conn().create_job(create_request)
        return create_operation.result()

    @GoogleBaseHook.fallback_to_default_project_id
    def update_job(
        self, job_name: str, job: Job | dict, region: str, project_id: str = PROVIDE_PROJECT_ID
    ) -> Job:
        """
        Update a job and block until the update operation finishes.

        The ``name`` of ``job`` is always replaced by the resource name derived from
        ``project_id``, ``region`` and ``job_name``. When ``job`` is already a ``Job``
        instance, that replacement happens on the caller's object.

        :param job_name: The name of the job to update.
        :param job: The new job descriptor, either a ``Job`` or a dict convertible to one.
        :param region: The region the job belongs to.
        :param project_id: The project the job belongs to.
        :return: The updated job.
        """
        if isinstance(job, dict):
            job = Job(job)

        update_request = UpdateJobRequest()
        job.name = self._job_path(project_id, region, job_name)
        update_request.job = job

        update_operation = self.get_conn().update_job(update_request)
        return update_operation.result()

    @GoogleBaseHook.fallback_to_default_project_id
    def execute_job(
        self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
    ) -> operation.Operation:
        """
        Start an execution of the job without waiting for it to finish.

        :param job_name: The name of the job to run.
        :param region: The region the job belongs to.
        :param project_id: The project the job belongs to.
        :return: The long-running operation tracking the execution.
        """
        run_job_request = RunJobRequest(name=self._job_path(project_id, region, job_name))
        return self.get_conn().run_job(request=run_job_request)

    @GoogleBaseHook.fallback_to_default_project_id
    def get_job(self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID) -> Job:
        """
        Fetch a job.

        :param job_name: The name of the job to fetch.
        :param region: The region the job belongs to.
        :param project_id: The project the job belongs to.
        :return: The job resource.
        """
        get_job_request = GetJobRequest(name=self._job_path(project_id, region, job_name))
        return self.get_conn().get_job(get_job_request)

    @GoogleBaseHook.fallback_to_default_project_id
    def list_jobs(
        self,
        region: str,
        project_id: str = PROVIDE_PROJECT_ID,
        show_deleted: bool = False,
        limit: int | None = None,
    ) -> Iterable[Job]:
        """
        List the jobs of a project/region.

        :param region: The region to list jobs in.
        :param project_id: The project to list jobs in.
        :param show_deleted: If true, deleted (but unexpired) jobs are included.
        :param limit: Maximum number of jobs to return; ``None`` returns all of them.
        :return: A list with at most ``limit`` jobs.
        :raises AirflowException: If ``limit`` is negative.
        """
        if limit is not None and limit < 0:
            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")

        list_jobs_request: ListJobsRequest = ListJobsRequest(
            parent=f"projects/{project_id}/locations/{region}", show_deleted=show_deleted
        )

        jobs: pagers.ListJobsPager = self.get_conn().list_jobs(request=list_jobs_request)

        # islice with a ``None`` stop consumes the whole pager; otherwise it stops after ``limit`` items.
        return list(itertools.islice(jobs, limit))
+
+
class CloudRunAsyncHook(GoogleBaseHook):
    """
    Async hook for the Google Cloud Run service.

    :param gcp_conn_id: The connection ID to use when fetching connection info.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account.
    """

    def __init__(
        self,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
    ):
        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
        # Start empty so get_conn() builds the client with this hook's credentials
        # (gcp_conn_id / impersonation_chain). A client created here would be built without
        # those credentials, and get_conn() would never replace it.
        self._client: JobsAsyncClient | None = None

    def get_conn(self) -> JobsAsyncClient:
        """
        Retrieve the async connection to Cloud Run, creating it on first use.

        :return: Cloud Run Jobs async client object.
        """
        if self._client is None:
            self._client = JobsAsyncClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)

        return self._client

    async def get_operation(self, operation_name: str) -> operations_pb2.Operation:
        """
        Fetch the current state of a long-running operation.

        :param operation_name: The full name of the operation.
        :return: The operation as returned by the Jobs async client.
        """
        return await self.get_conn().get_operation(operations_pb2.GetOperationRequest(name=operation_name))
diff --git a/airflow/providers/google/cloud/operators/cloud_run.py
b/airflow/providers/google/cloud/operators/cloud_run.py
new file mode 100644
index 0000000000..d27b17973d
--- /dev/null
+++ b/airflow/providers/google/cloud/operators/cloud_run.py
@@ -0,0 +1,354 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+from google.api_core import operation
+from google.cloud.run_v2 import Job
+from google.cloud.run_v2.types import Execution
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_run import CloudRunHook
+from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.cloud_run import
CloudRunJobFinishedTrigger, RunJobStatus
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
class CloudRunCreateJobOperator(GoogleCloudBaseOperator):
    """
    Create a Cloud Run job without running it.

    The created job is converted to a dict and returned, which pushes it to XCom.

    :param project_id: Required. ID of the Google Cloud project the job belongs to (templated).
    :param region: Required. ID of the Google Cloud region the job belongs to (templated).
    :param job_name: Required. Name of the job to create (templated).
    :param job: Required. Job descriptor (a ``Job`` or an equivalent dict) holding the
        configuration of the job to create.
    :param gcp_conn_id: Connection ID used to reach Google Cloud (templated).
    :param impersonation_chain: Optional service account to impersonate with short-term
        credentials, or a chain of accounts where each must grant the Service Account Token
        Creator IAM role to the one before it (the first to the originating account); the
        last account is the one impersonated in the request (templated).
    """

    template_fields = (
        "project_id",
        "region",
        "gcp_conn_id",
        "impersonation_chain",
        "job_name",
    )

    def __init__(
        self,
        project_id: str,
        region: str,
        job_name: str,
        job: dict | Job,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Target location of the job.
        self.project_id = project_id
        self.region = region
        # What to create.
        self.job_name = job_name
        self.job = job
        # How to authenticate.
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        cloud_run = CloudRunHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        created_job = cloud_run.create_job(
            job_name=self.job_name,
            job=self.job,
            region=self.region,
            project_id=self.project_id,
        )
        return Job.to_dict(created_job)
+
+
class CloudRunUpdateJobOperator(GoogleCloudBaseOperator):
    """
    Update a Cloud Run job and wait for the update to finish.

    The updated job is converted to a dict and returned, which pushes it to XCom.

    :param project_id: Required. ID of the Google Cloud project the job belongs to (templated).
    :param region: Required. ID of the Google Cloud region the job belongs to (templated).
    :param job_name: Required. Name of the job to update (templated).
    :param job: Required. Job descriptor with the new configuration. Its name field is
        replaced by the resource name built from ``job_name``.
    :param gcp_conn_id: Connection ID used to reach Google Cloud (templated).
    :param impersonation_chain: Optional service account to impersonate with short-term
        credentials, or a chain of accounts where each must grant the Service Account Token
        Creator IAM role to the one before it (the first to the originating account); the
        last account is the one impersonated in the request (templated).
    """

    template_fields = (
        "project_id",
        "region",
        "gcp_conn_id",
        "impersonation_chain",
        "job_name",
    )

    def __init__(
        self,
        project_id: str,
        region: str,
        job_name: str,
        job: dict | Job,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Target location of the job.
        self.project_id = project_id
        self.region = region
        # Which job, and its new definition.
        self.job_name = job_name
        self.job = job
        # How to authenticate.
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        cloud_run = CloudRunHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        updated_job = cloud_run.update_job(
            job_name=self.job_name,
            job=self.job,
            region=self.region,
            project_id=self.project_id,
        )
        return Job.to_dict(updated_job)
+
+
class CloudRunDeleteJobOperator(GoogleCloudBaseOperator):
    """
    Delete a Cloud Run job and wait for the deletion to finish.

    The deleted job is converted to a dict and returned, which pushes it to XCom.

    :param project_id: Required. ID of the Google Cloud project the job belongs to (templated).
    :param region: Required. ID of the Google Cloud region the job belongs to (templated).
    :param job_name: Required. Name of the job to delete (templated).
    :param gcp_conn_id: Connection ID used to reach Google Cloud (templated).
    :param impersonation_chain: Optional service account to impersonate with short-term
        credentials, or a chain of accounts where each must grant the Service Account Token
        Creator IAM role to the one before it (the first to the originating account); the
        last account is the one impersonated in the request (templated).
    """

    template_fields = (
        "project_id",
        "region",
        "gcp_conn_id",
        "impersonation_chain",
        "job_name",
    )

    def __init__(
        self,
        project_id: str,
        region: str,
        job_name: str,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Which job to remove and where it lives.
        self.project_id = project_id
        self.region = region
        self.job_name = job_name
        # How to authenticate.
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        cloud_run = CloudRunHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        deleted_job = cloud_run.delete_job(
            job_name=self.job_name,
            region=self.region,
            project_id=self.project_id,
        )
        return Job.to_dict(deleted_job)
+
+
class CloudRunListJobsOperator(GoogleCloudBaseOperator):
    """
    List the Cloud Run jobs of a project and region.

    Each job is converted to a dict; the resulting list is returned, which pushes it to XCom.

    :param project_id: Required. ID of the Google Cloud project to list jobs in (templated).
    :param region: Required. ID of the Google Cloud region to list jobs in (templated).
    :param show_deleted: When true, deleted (but unexpired) jobs are returned together
        with active ones.
    :param limit: Maximum number of jobs to return. ``None`` returns every job.
        Negative values are rejected when the operator is constructed.
    :param gcp_conn_id: Connection ID used to reach Google Cloud (templated).
    :param impersonation_chain: Optional service account to impersonate with short-term
        credentials, or a chain of accounts where each must grant the Service Account Token
        Creator IAM role to the one before it (the first to the originating account); the
        last account is the one impersonated in the request (templated).
    """

    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain")

    def __init__(
        self,
        project_id: str,
        region: str,
        show_deleted: bool = False,
        limit: int | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.show_deleted = show_deleted
        self.limit = limit

        # Reject an invalid limit up front instead of waiting for execute().
        has_negative_limit = self.limit is not None and self.limit < 0
        if has_negative_limit:
            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")

    def execute(self, context: Context):
        cloud_run = CloudRunHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        found_jobs = cloud_run.list_jobs(
            region=self.region,
            project_id=self.project_id,
            show_deleted=self.show_deleted,
            limit=self.limit,
        )
        return list(map(Job.to_dict, found_jobs))
+
+
class CloudRunExecuteJobOperator(GoogleCloudBaseOperator):
    """
    Execute a job and wait for the execution to be completed. Pushes the executed job to XCom.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :param region: Required. The ID of the Google Cloud region that the service belongs to.
    :param job_name: Required. The name of the job to execute.
    :param polling_period_seconds: Optional. Control the rate of the poll for the result of a
        deferrable run. By default, the trigger will poll every 10 seconds.
    :param timeout_seconds: Optional. How long to wait for the execution. In non-deferrable mode it
        is passed to the operation's ``result()``/``exception()`` calls; in deferrable mode it is
        forwarded to the trigger, where ``None`` means polling continues until the operation is done.
    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    :param deferrable: Run operator in the deferrable mode.
    """

    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")

    def __init__(
        self,
        project_id: str,
        region: str,
        job_name: str,
        polling_period_seconds: float = 10,
        timeout_seconds: float | None = None,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.job_name = job_name
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain
        self.polling_period_seconds = polling_period_seconds
        self.timeout_seconds = timeout_seconds
        self.deferrable = deferrable
        # Set by execute() to the long-running operation returned by the hook.
        self.operation: operation.Operation | None = None

    def execute(self, context: Context):
        hook: CloudRunHook = CloudRunHook(
            gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
        )
        self.operation = hook.execute_job(
            region=self.region, project_id=self.project_id, job_name=self.job_name
        )

        if not self.deferrable:
            result: Execution = self._wait_for_operation(self.operation)
            self._fail_if_execution_failed(result)
            # Pass project_id explicitly: otherwise the hook falls back to the connection's
            # default project, which may differ from the project the job was run in.
            job = hook.get_job(job_name=result.job, region=self.region, project_id=self.project_id)
            return Job.to_dict(job)
        else:
            self.defer(
                trigger=CloudRunJobFinishedTrigger(
                    operation_name=self.operation.operation.name,
                    job_name=self.job_name,
                    project_id=self.project_id,
                    location=self.region,
                    gcp_conn_id=self.gcp_conn_id,
                    impersonation_chain=self.impersonation_chain,
                    polling_period_seconds=self.polling_period_seconds,
                    # Forward the timeout so the deferred wait honours timeout_seconds as well.
                    timeout=self.timeout_seconds,
                ),
                method_name="execute_complete",
            )

    def execute_complete(self, context: Context, event: dict):
        """
        Resume after the trigger fires and return the job as a dict.

        :raises AirflowException: If the trigger reports a timeout or a failed operation.
        """
        status = event["status"]

        # Match both the enum member and its plain value, in case the event payload reached
        # the worker as a string after serialization.
        if status in (RunJobStatus.TIMEOUT, RunJobStatus.TIMEOUT.value):
            raise AirflowException("Operation timed out")

        if status in (RunJobStatus.FAIL, RunJobStatus.FAIL.value):
            error_code = event["operation_error_code"]
            error_message = event["operation_error_message"]
            raise AirflowException(
                f"Operation failed with error code [{error_code}] and error message [{error_message}]"
            )

        hook: CloudRunHook = CloudRunHook(self.gcp_conn_id, self.impersonation_chain)

        job = hook.get_job(job_name=event["job_name"], region=self.region, project_id=self.project_id)
        return Job.to_dict(job)

    def _fail_if_execution_failed(self, execution: Execution):
        """
        Raise if the execution did not finish every task, or if any task failed.

        :raises AirflowException: On unfinished or failed tasks.
        """
        task_count = execution.task_count
        succeeded_count = execution.succeeded_count
        failed_count = execution.failed_count

        if succeeded_count + failed_count != task_count:
            raise AirflowException("Not all tasks finished execution")

        if failed_count > 0:
            raise AirflowException("Some tasks failed execution")

    def _wait_for_operation(self, operation: operation.Operation):
        """
        Block until the long-running operation completes and return its result.

        :raises AirflowException: Wrapping the operation's exception when waiting fails.
        """
        try:
            return operation.result(timeout=self.timeout_seconds)
        except Exception as err:
            error = operation.exception(timeout=self.timeout_seconds)
            raise AirflowException(error) from err
diff --git a/airflow/providers/google/cloud/triggers/cloud_run.py
b/airflow/providers/google/cloud/triggers/cloud_run.py
new file mode 100644
index 0000000000..ddbd74864c
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/cloud_run.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from enum import Enum
+from typing import Any, AsyncIterator, Sequence
+
+from google.longrunning import operations_pb2
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
# Default ``location`` of CloudRunJobFinishedTrigger when none is passed.
# NOTE(review): the name mentions "BATCH" although it is used as a Cloud Run region default here.
DEFAULT_BATCH_LOCATION = "us-central1"


class RunJobStatus(str, Enum):
    """
    Enum to represent the status of a job run.

    Members mix in ``str`` so they are plain strings: a TriggerEvent payload containing them
    serializes to the value (e.g. ``"Success"``), and that value still compares equal to the
    member on the receiving side (``RunJobStatus.SUCCESS == "Success"``).
    """

    SUCCESS = "Success"
    FAIL = "Fail"
    TIMEOUT = "Timeout"
+
+
class CloudRunJobFinishedTrigger(BaseTrigger):
    """
    Cloud Run trigger to check if a job execution has finished.

    Polls the long-running operation started for the job and emits exactly one TriggerEvent:
    success, failure (with the operation's error code and message) or timeout.

    :param operation_name: Required. Name of the long-running operation to poll.
    :param job_name: Required. Name of the job; included in every emitted event.
    :param project_id: Required. The Google Cloud project ID in which the job was started.
    :param location: Optional. The location where the job is executed.
        Defaults to DEFAULT_BATCH_LOCATION when not provided.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional. Service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token of the last account
        in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account.
    :param polling_period_seconds: Polling period in seconds to check for the status.
    :param timeout: Optional. Time in seconds to wait before emitting a timeout event.
        ``None`` means polling continues until the operation is done.
    """

    def __init__(
        self,
        operation_name: str,
        job_name: str,
        project_id: str | None,
        location: str = DEFAULT_BATCH_LOCATION,
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        polling_period_seconds: float = 10,
        timeout: float | None = None,
    ):
        super().__init__()
        self.project_id = project_id
        self.job_name = job_name
        self.operation_name = operation_name
        self.location = location
        self.gcp_conn_id = gcp_conn_id
        self.polling_period_seconds = polling_period_seconds
        self.timeout = timeout
        self.impersonation_chain = impersonation_chain

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serializes class arguments and classpath."""
        return (
            "airflow.providers.google.cloud.triggers.cloud_run.CloudRunJobFinishedTrigger",
            {
                "project_id": self.project_id,
                "operation_name": self.operation_name,
                "job_name": self.job_name,
                "location": self.location,
                "gcp_conn_id": self.gcp_conn_id,
                "polling_period_seconds": self.polling_period_seconds,
                "timeout": self.timeout,
                "impersonation_chain": self.impersonation_chain,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        timeout = self.timeout
        hook = self._get_async_hook()
        while timeout is None or timeout > 0:
            operation: operations_pb2.Operation = await hook.get_operation(self.operation_name)
            if operation.done:
                # ``error`` and ``response`` are members of the same ``result`` oneof, so exactly one
                # of them is set on a finished operation. Protobuf message fields are never ``None``
                # (unset ones read as empty defaults), so presence must be tested with HasField().
                if operation.HasField("error"):
                    yield TriggerEvent(
                        {
                            "status": RunJobStatus.FAIL,
                            "operation_error_code": operation.error.code,
                            "operation_error_message": operation.error.message,
                            "job_name": self.job_name,
                        }
                    )
                else:
                    yield TriggerEvent(
                        {
                            "status": RunJobStatus.SUCCESS,
                            "job_name": self.job_name,
                        }
                    )
                # Stop after the terminal event; otherwise the loop polls again and re-emits it.
                return

            if operation.error.message:
                raise AirflowException(f"Cloud Run Job error: {operation.error.message}")

            if timeout is not None:
                timeout -= self.polling_period_seconds

            if timeout is None or timeout > 0:
                await asyncio.sleep(self.polling_period_seconds)

        yield TriggerEvent(
            {
                "status": RunJobStatus.TIMEOUT,
                "job_name": self.job_name,
            }
        )

    def _get_async_hook(self) -> CloudRunAsyncHook:
        return CloudRunAsyncHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
diff --git a/airflow/providers/google/provider.yaml
b/airflow/providers/google/provider.yaml
index 94f9353eb0..a6df5f3f7c 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -117,6 +117,7 @@ dependencies:
- google-cloud-videointelligence>=2.11.0
- google-cloud-vision>=3.4.0
- google-cloud-workflows>=1.10.0
+ - google-cloud-run>=0.9.0
- google-cloud-batch>=0.13.0
- grpcio-gcp>=0.2.2
- httpx
@@ -183,6 +184,11 @@ integrations:
how-to-guide:
-
/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
tags: [google]
+ - integration-name: Google Cloud Run
+ external-doc-url: https://cloud.google.com/run
+ how-to-guide:
+ - /docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
+ tags: [google]
- integration-name: Google Cloud Batch
external-doc-url: https://cloud.google.com/batch
how-to-guide:
@@ -489,6 +495,9 @@ operators:
- integration-name: Google Cloud Composer
python-modules:
- airflow.providers.google.cloud.operators.cloud_composer
+ - integration-name: Google Cloud Run
+ python-modules:
+ - airflow.providers.google.cloud.operators.cloud_run
- integration-name: Google Cloud Memorystore
python-modules:
- airflow.providers.google.cloud.operators.cloud_memorystore
@@ -709,6 +718,9 @@ hooks:
- integration-name: Google Cloud Composer
python-modules:
- airflow.providers.google.cloud.hooks.cloud_composer
+ - integration-name: Google Cloud Run
+ python-modules:
+ - airflow.providers.google.cloud.hooks.cloud_run
- integration-name: Google Cloud Memorystore
python-modules:
- airflow.providers.google.cloud.hooks.cloud_memorystore
@@ -877,6 +889,9 @@ triggers:
- integration-name: Google Cloud Composer
python-modules:
- airflow.providers.google.cloud.triggers.cloud_composer
+ - integration-name: Google Cloud Run
+ python-modules:
+ - airflow.providers.google.cloud.triggers.cloud_run
- integration-name: Google Cloud Storage Transfer Service
python-modules:
- airflow.providers.google.cloud.triggers.cloud_storage_transfer_service
diff --git a/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
new file mode 100644
index 0000000000..7c80f86d15
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_run.rst
@@ -0,0 +1,128 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Google Cloud Run Operators
+===============================
+
+Cloud Run is used to build and deploy scalable containerized apps written in
any language (including Go, Python, Java, Node.js, .NET, and Ruby) on a fully
managed platform.
+
+For more information about the service visit `Google Cloud Run documentation
<https://cloud.google.com/run/docs/>`__.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+Create a job
+---------------------
+
+Before you create a job in Cloud Run, you need to define it.
+For more information about the Job object fields, visit `Google Cloud Run Job
description
<https://cloud.google.com/run/docs/reference/rpc/google.cloud.run.v2#google.cloud.run.v2.Job>`__.
+
+A simple job configuration can look as follows:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_cloud_run_job_creation]
+ :end-before: [END howto_operator_cloud_run_job_creation]
+
+
+With this configuration we can create the job:
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunCreateJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloud_run_create_job]
+ :end-before: [END howto_operator_cloud_run_create_job]
+
+
+Note that this operator only creates the job without executing it. The Job's
dictionary representation is pushed to XCom.
+
+Execute a job
+---------------------
+
+To execute a job, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloud_run_execute_job]
+ :end-before: [END howto_operator_cloud_run_execute_job]
+
+Or you can define the same operator in deferrable mode:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode]
+ :end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode]
+
+
+
+Update a job
+------------------
+
+To update a job, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunUpdateJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloud_update_job]
+ :end-before: [END howto_operator_cloud_update_job]
+
+
+The Job's dictionary representation is pushed to XCom.
+
+
+List jobs
+----------------------
+
+To list the jobs, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunListJobsOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloud_run_list_jobs]
+ :end-before: [END howto_operator_cloud_run_list_jobs]
+
+The operator takes two optional parameters: "limit" to limit the number of
jobs returned, and "show_deleted" to include deleted jobs in the result.
+
+
+Delete a job
+-----------------
+
+To delete a job, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunDeleteJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_cloud_delete_job]
+ :end-before: [END howto_operator_cloud_delete_job]
+
+Note that this operator waits for the job to be deleted, and the deleted Job's
dictionary representation is pushed to XCom.
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index 0b1c5d5a1d..5191600a98 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -440,6 +440,7 @@
"google-cloud-os-login>=2.9.1",
"google-cloud-pubsub>=2.15.0",
"google-cloud-redis>=2.12.0",
+ "google-cloud-run>=0.9.0",
"google-cloud-secret-manager>=2.16.0",
"google-cloud-spanner>=3.11.1",
"google-cloud-speech>=2.18.0",
diff --git a/tests/providers/google/cloud/hooks/test_cloud_run.py
b/tests/providers/google/cloud/hooks/test_cloud_run.py
new file mode 100644
index 0000000000..ba53b6de92
--- /dev/null
+++ b/tests/providers/google/cloud/hooks/test_cloud_run.py
@@ -0,0 +1,273 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud.run_v2 import (
+ CreateJobRequest,
+ GetJobRequest,
+ Job,
+ ListJobsRequest,
+ RunJobRequest,
+ UpdateJobRequest,
+)
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_run import CloudRunAsyncHook,
CloudRunHook
+from tests.providers.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
+
+
+class TestCloudRunHook:
+    """Tests for :class:`CloudRunHook`.
+
+    Every test patches ``JobsClient`` in the hook module, so the assertions only check
+    which request objects the hook builds and passes to the client.
+    """
+
+    def dummy_get_credentials(self):
+        pass
+
+    @pytest.fixture
+    def cloud_run_hook(self):
+        """A ``CloudRunHook`` whose ``get_credentials`` is replaced by a no-op."""
+        cloud_run_hook = CloudRunHook()
+        cloud_run_hook.get_credentials = self.dummy_get_credentials
+        return cloud_run_hook
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_get_job(self, mock_jobs_client, cloud_run_hook):
+        job_name = "job1"
+        region = "region1"
+        project_id = "projectid"
+
+        get_job_request = GetJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+
+        cloud_run_hook.get_job(job_name=job_name, region=region, project_id=project_id)
+        cloud_run_hook._client.get_job.assert_called_once_with(get_job_request)
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_update_job(self, mock_jobs_client, cloud_run_hook):
+        job_name = "job1"
+        region = "region1"
+        project_id = "projectid"
+        job = Job()
+        job.name = f"projects/{project_id}/locations/{region}/jobs/{job_name}"
+
+        update_request = UpdateJobRequest()
+        update_request.job = job
+
+        cloud_run_hook.update_job(
+            job=Job.to_dict(job), job_name=job_name, region=region, project_id=project_id
+        )
+
+        cloud_run_hook._client.update_job.assert_called_once_with(update_request)
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_create_job(self, mock_jobs_client, cloud_run_hook):
+        job_name = "job1"
+        region = "region1"
+        project_id = "projectid"
+        job = Job()
+
+        create_request = CreateJobRequest()
+        create_request.job = job
+        create_request.job_id = job_name
+        create_request.parent = f"projects/{project_id}/locations/{region}"
+
+        cloud_run_hook.create_job(
+            job=Job.to_dict(job), job_name=job_name, region=region, project_id=project_id
+        )
+
+        cloud_run_hook._client.create_job.assert_called_once_with(create_request)
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_execute_job(self, mock_jobs_client, cloud_run_hook):
+        job_name = "job1"
+        region = "region1"
+        project_id = "projectid"
+        run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+
+        cloud_run_hook.execute_job(job_name=job_name, region=region, project_id=project_id)
+        cloud_run_hook._client.run_job.assert_called_once_with(request=run_job_request)
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_list_jobs(self, mock_jobs_client, cloud_run_hook):
+        number_of_jobs = 3
+        region = "us-central1"
+        project_id = "test_project_id"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_jobs_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_run_hook.list_jobs(region=region, project_id=project_id)
+
+        for i in range(number_of_jobs):
+            assert jobs_list[i].name == f"name{i}"
+
+        expected_list_jobs_request: ListJobsRequest = ListJobsRequest(
+            parent=f"projects/{project_id}/locations/{region}"
+        )
+        mock_jobs_client.return_value.list_jobs.assert_called_once_with(
+            request=expected_list_jobs_request
+        )
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_list_jobs_show_deleted(self, mock_jobs_client, cloud_run_hook):
+        number_of_jobs = 3
+        region = "us-central1"
+        project_id = "test_project_id"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_jobs_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_run_hook.list_jobs(region=region, project_id=project_id, show_deleted=True)
+
+        for i in range(number_of_jobs):
+            assert jobs_list[i].name == f"name{i}"
+
+        expected_list_jobs_request: ListJobsRequest = ListJobsRequest(
+            parent=f"projects/{project_id}/locations/{region}", show_deleted=True
+        )
+        mock_jobs_client.return_value.list_jobs.assert_called_once_with(
+            request=expected_list_jobs_request
+        )
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_list_jobs_with_limit(self, mock_jobs_client, cloud_run_hook):
+        number_of_jobs = 3
+        limit = 2
+        region = "us-central1"
+        project_id = "test_project_id"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_jobs_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_run_hook.list_jobs(region=region, project_id=project_id, limit=limit)
+
+        assert len(jobs_list) == limit
+        for i in range(limit):
+            assert jobs_list[i].name == f"name{i}"
+
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_list_jobs_with_limit_zero(self, mock_jobs_client, cloud_run_hook):
+        number_of_jobs = 3
+        limit = 0
+        region = "us-central1"
+        project_id = "test_project_id"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_jobs_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_run_hook.list_jobs(region=region, project_id=project_id, limit=limit)
+
+        assert len(jobs_list) == 0
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_list_jobs_with_limit_greater_than_range(self, mock_jobs_client, cloud_run_hook):
+        number_of_jobs = 3
+        limit = 5
+        region = "us-central1"
+        project_id = "test_project_id"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_jobs_client.return_value.list_jobs.return_value = page
+
+        jobs_list = cloud_run_hook.list_jobs(region=region, project_id=project_id, limit=limit)
+
+        # A limit larger than the number of available jobs returns every job.
+        assert len(jobs_list) == number_of_jobs
+        for i in range(number_of_jobs):
+            assert jobs_list[i].name == f"name{i}"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_list_jobs_with_limit_less_than_zero(self, mock_jobs_client, cloud_run_hook):
+        number_of_jobs = 3
+        limit = -1
+        region = "us-central1"
+        project_id = "test_project_id"
+
+        page = self._mock_pager(number_of_jobs)
+        mock_jobs_client.return_value.list_jobs.return_value = page
+
+        with pytest.raises(expected_exception=AirflowException):
+            cloud_run_hook.list_jobs(region=region, project_id=project_id, limit=limit)
+
+    def _mock_pager(self, number_of_jobs):
+        """Return a plain list standing in for the ``list_jobs`` pager: jobs named ``name0..nameN-1``."""
+        return [Job(name=f"name{i}") for i in range(number_of_jobs)]
+
+
+class TestCloudRunAsyncHook:
+    """Tests for :class:`CloudRunAsyncHook` with ``JobsAsyncClient`` patched out."""
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsAsyncClient")
+    async def test_get_operation(self, mock_client):
+        expected_operation = {"name": "somename"}
+
+        # Coroutine stub: the hook awaits the client's get_operation call.
+        async def _get_operation(name):
+            return expected_operation
+
+        operation_name = "operationname"
+        mock_client.return_value = mock.MagicMock()
+        mock_client.return_value.get_operation = _get_operation
+        hook = CloudRunAsyncHook()
+        hook.get_credentials = self._dummy_get_credentials
+
+        returned_operation = await hook.get_operation(operation_name=operation_name)
+
+        assert returned_operation == expected_operation
+
+    def _dummy_get_credentials(self):
+        pass
diff --git a/tests/providers/google/cloud/operators/test_cloud_run.py
b/tests/providers/google/cloud/operators/test_cloud_run.py
new file mode 100644
index 0000000000..8dbac06b1e
--- /dev/null
+++ b/tests/providers/google/cloud/operators/test_cloud_run.py
@@ -0,0 +1,308 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+This module contains various unit tests for Google Cloud Run Operators
+"""
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud.run_v2 import Job
+
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.google.cloud.operators.cloud_run import (
+ CloudRunCreateJobOperator,
+ CloudRunDeleteJobOperator,
+ CloudRunExecuteJobOperator,
+ CloudRunListJobsOperator,
+ CloudRunUpdateJobOperator,
+)
+from airflow.providers.google.cloud.triggers.cloud_run import RunJobStatus
+
+# Patch target: CloudRunHook as looked up from the operators module.
+CLOUD_RUN_HOOK_PATH = "airflow.providers.google.cloud.operators.cloud_run.CloudRunHook"
+TASK_ID = "test"
+PROJECT_ID = "testproject"
+REGION = "us-central1"
+JOB_NAME = "jobname"
+JOB = Job()
+JOB.name = JOB_NAME
+
+
+def _assert_common_template_fields(template_fields):
+    """Assert that ``template_fields`` includes the fields checked for every operator in this module."""
+    assert "project_id" in template_fields
+    assert "region" in template_fields
+    assert "gcp_conn_id" in template_fields
+    assert "impersonation_chain" in template_fields
+
+
+class TestCloudRunCreateJobOperator:
+    """Tests for :class:`CloudRunCreateJobOperator` with ``CloudRunHook`` mocked."""
+
+    def test_template_fields(self):
+        operator = CloudRunCreateJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB
+        )
+
+        _assert_common_template_fields(operator.template_fields)
+        assert "job_name" in operator.template_fields
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_create(self, hook_mock):
+        hook_mock.return_value.create_job.return_value = JOB
+
+        operator = CloudRunCreateJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB
+        )
+
+        operator.execute(context=mock.MagicMock())
+
+        hook_mock.return_value.create_job.assert_called_once_with(
+            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID, job=JOB
+        )
+
+
+class TestCloudRunExecuteJobOperator:
+    """Tests for :class:`CloudRunExecuteJobOperator` in synchronous and deferrable mode."""
+
+    def test_template_fields(self):
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        _assert_common_template_fields(operator.template_fields)
+        assert "job_name" in operator.template_fields
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_success(self, hook_mock):
+        hook_mock.return_value.get_job.return_value = JOB
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 3, 0)
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        operator.execute(context=mock.MagicMock())
+
+        hook_mock.return_value.execute_job.assert_called_once_with(
+            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
+        )
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_fail_one_failed_task(self, hook_mock):
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 2, 1)
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        with pytest.raises(AirflowException) as exception:
+            operator.execute(context=mock.MagicMock())
+
+        assert "Some tasks failed execution" in str(exception.value)
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_fail_all_failed_tasks(self, hook_mock):
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 0, 3)
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        with pytest.raises(AirflowException) as exception:
+            operator.execute(context=mock.MagicMock())
+
+        assert "Some tasks failed execution" in str(exception.value)
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_fail_incomplete_failed_tasks(self, hook_mock):
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 2, 0)
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        with pytest.raises(AirflowException) as exception:
+            operator.execute(context=mock.MagicMock())
+
+        assert "Not all tasks finished execution" in str(exception.value)
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_fail_incomplete_succeeded_tasks(self, hook_mock):
+        hook_mock.return_value.execute_job.return_value = self._mock_operation(3, 0, 2)
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        with pytest.raises(AirflowException) as exception:
+            operator.execute(context=mock.MagicMock())
+
+        assert "Not all tasks finished execution" in str(exception.value)
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_deferrable(self, hook_mock):
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True
+        )
+
+        with pytest.raises(TaskDeferred):
+            operator.execute(mock.MagicMock())
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_deferrable_execute_complete_method_timeout(self, hook_mock):
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True
+        )
+
+        event = {"status": RunJobStatus.TIMEOUT, "job_name": JOB_NAME}
+
+        with pytest.raises(AirflowException) as e:
+            operator.execute_complete(mock.MagicMock(), event)
+
+        assert "Operation timed out" in str(e.value)
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_deferrable_execute_complete_method_fail(self, hook_mock):
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True
+        )
+
+        error_code = 10
+        error_message = "error message"
+
+        event = {
+            "status": RunJobStatus.FAIL,
+            "operation_error_code": error_code,
+            "operation_error_message": error_message,
+            "job_name": JOB_NAME,
+        }
+
+        with pytest.raises(AirflowException) as e:
+            operator.execute_complete(mock.MagicMock(), event)
+
+        assert f"Operation failed with error code [{error_code}] and error message [{error_message}]" in str(
+            e.value
+        )
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_deferrable_execute_complete_method_success(self, hook_mock):
+        hook_mock.return_value.get_job.return_value = JOB
+
+        operator = CloudRunExecuteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, deferrable=True
+        )
+
+        event = {"status": RunJobStatus.SUCCESS, "job_name": JOB_NAME}
+
+        result = operator.execute_complete(mock.MagicMock(), event)
+        assert result["name"] == JOB_NAME
+
+    def _mock_operation(self, task_count, succeeded_count, failed_count):
+        """Mock operation whose ``result()`` returns a mock execution with the given counts."""
+        operation = mock.MagicMock()
+        operation.result.return_value = self._mock_execution(task_count, succeeded_count, failed_count)
+        return operation
+
+    def _mock_execution(self, task_count, succeeded_count, failed_count):
+        """Mock execution exposing ``task_count``, ``succeeded_count`` and ``failed_count``."""
+        execution = mock.MagicMock()
+        execution.task_count = task_count
+        execution.succeeded_count = succeeded_count
+        execution.failed_count = failed_count
+        return execution
+
+
+class TestCloudRunDeleteJobOperator:
+    """Tests for :class:`CloudRunDeleteJobOperator` with ``CloudRunHook`` mocked."""
+
+    def test_template_fields(self):
+        operator = CloudRunDeleteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        _assert_common_template_fields(operator.template_fields)
+        assert "job_name" in operator.template_fields
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute(self, hook_mock):
+        hook_mock.return_value.delete_job.return_value = JOB
+
+        operator = CloudRunDeleteJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME
+        )
+
+        deleted_job = operator.execute(context=mock.MagicMock())
+
+        # The operator returns the deleted job as a dict.
+        assert deleted_job["name"] == JOB.name
+
+        hook_mock.return_value.delete_job.assert_called_once_with(
+            job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
+        )
+
+
+class TestCloudRunUpdateJobOperator:
+    """Tests for :class:`CloudRunUpdateJobOperator` with ``CloudRunHook`` mocked."""
+
+    def test_template_fields(self):
+        operator = CloudRunUpdateJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB
+        )
+
+        _assert_common_template_fields(operator.template_fields)
+        assert "job_name" in operator.template_fields
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute(self, hook_mock):
+        hook_mock.return_value.update_job.return_value = JOB
+
+        operator = CloudRunUpdateJobOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, job_name=JOB_NAME, job=JOB
+        )
+
+        updated_job = operator.execute(context=mock.MagicMock())
+
+        # The operator returns the updated job as a dict.
+        assert updated_job["name"] == JOB.name
+
+        hook_mock.return_value.update_job.assert_called_once_with(
+            job_name=JOB_NAME, job=JOB, region=REGION, project_id=PROJECT_ID
+        )
+
+
+class TestCloudRunListJobsOperator:
+    """Tests for :class:`CloudRunListJobsOperator` with ``CloudRunHook`` mocked."""
+
+    def test_template_fields(self):
+        operator = CloudRunListJobsOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, limit=2, show_deleted=False
+        )
+
+        _assert_common_template_fields(operator.template_fields)
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute(self, hook_mock):
+        limit = 2
+        show_deleted = True
+        operator = CloudRunListJobsOperator(
+            task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, limit=limit, show_deleted=show_deleted
+        )
+
+        operator.execute(context=mock.MagicMock())
+
+        hook_mock.return_value.list_jobs.assert_called_once_with(
+            region=REGION, project_id=PROJECT_ID, limit=limit, show_deleted=show_deleted
+        )
+
+    @mock.patch(CLOUD_RUN_HOOK_PATH)
+    def test_execute_with_invalid_limit(self, hook_mock):
+        # A negative limit is rejected when the operator is constructed, before execute() runs.
+        limit = -1
+        with pytest.raises(expected_exception=AirflowException):
+            CloudRunListJobsOperator(task_id=TASK_ID, project_id=PROJECT_ID, region=REGION, limit=limit)
diff --git a/tests/providers/google/cloud/triggers/test_cloud_run.py
b/tests/providers/google/cloud/triggers/test_cloud_run.py
new file mode 100644
index 0000000000..42c57d9b93
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_cloud_run.py
@@ -0,0 +1,155 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.triggers.cloud_run import
CloudRunJobFinishedTrigger, RunJobStatus
+from airflow.triggers.base import TriggerEvent
+
+# Constructor arguments for the CloudRunJobFinishedTrigger fixture; test_serialization
+# expects serialize() to return exactly these values.
+OPERATION_NAME = "operation"
+JOB_NAME = "jobName"
+PROJECT_ID = "projectId"
+LOCATION = "us-central1"
+GCP_CONNECTION_ID = "gcp_connection_id"
+# Polling interval and timeout (seconds) are kept tiny, presumably so the timeout test
+# returns almost immediately.
+POLL_SLEEP = 0.01
+TIMEOUT = 0.02
+IMPERSONATION_CHAIN = "impersonation_chain"
+
+
+@pytest.fixture
+def trigger():
+    """A ``CloudRunJobFinishedTrigger`` built from the module-level test constants."""
+    return CloudRunJobFinishedTrigger(
+        operation_name=OPERATION_NAME,
+        job_name=JOB_NAME,
+        project_id=PROJECT_ID,
+        location=LOCATION,
+        gcp_conn_id=GCP_CONNECTION_ID,
+        polling_period_seconds=POLL_SLEEP,
+        timeout=TIMEOUT,
+        impersonation_chain=IMPERSONATION_CHAIN,
+    )
+
+
+class TestCloudRunJobFinishedTrigger:
+    """Tests for :class:`CloudRunJobFinishedTrigger`; the async tests patch ``CloudRunAsyncHook``."""
+
+    def test_serialization(self, trigger):
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.google.cloud.triggers.cloud_run.CloudRunJobFinishedTrigger"
+        assert kwargs == {
+            "project_id": PROJECT_ID,
+            "operation_name": OPERATION_NAME,
+            "job_name": JOB_NAME,
+            "location": LOCATION,
+            "gcp_conn_id": GCP_CONNECTION_ID,
+            "polling_period_seconds": POLL_SLEEP,
+            "timeout": TIMEOUT,
+            "impersonation_chain": IMPERSONATION_CHAIN,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.cloud_run.CloudRunAsyncHook")
+    async def test_trigger_on_operation_completed_yield_successfully(
+        self, mock_hook, trigger: CloudRunJobFinishedTrigger
+    ):
+        """
+        Tests the CloudRunJobFinishedTrigger fires once the job execution reaches a successful state.
+        """
+
+        done = True
+        name = "name"
+        error_code = 10
+        error_message = "message"
+
+        mock_hook.return_value.get_operation.return_value = self._mock_operation(
+            done, name, error_code, error_message
+        )
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert (
+            TriggerEvent(
+                {
+                    "status": RunJobStatus.SUCCESS,
+                    "job_name": JOB_NAME,
+                }
+            )
+            == actual
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.cloud_run.CloudRunAsyncHook")
+    async def test_trigger_on_operation_failed_yield_error(
+        self, mock_hook, trigger: CloudRunJobFinishedTrigger
+    ):
+        """
+        Tests the CloudRunJobFinishedTrigger raises an exception once the job execution fails.
+        """
+
+        done = False
+        name = "name"
+        error_code = 10
+        error_message = "message"
+
+        mock_hook.return_value.get_operation.return_value = self._mock_operation(
+            done, name, error_code, error_message
+        )
+        generator = trigger.run()
+
+        with pytest.raises(expected_exception=AirflowException):
+            await generator.asend(None)
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.cloud_run.CloudRunAsyncHook")
+    async def test_trigger_timeout(self, mock_hook, trigger: CloudRunJobFinishedTrigger):
+        """
+        Tests the CloudRunJobFinishedTrigger fires once the job execution times out with an error message.
+        """
+
+        # Never done and no error, so only the trigger's timeout can end the polling loop.
+        async def _mock_operation(name):
+            operation = mock.MagicMock()
+            operation.done = False
+            operation.error = mock.MagicMock()
+            operation.error.message = None
+            operation.error.code = None
+            return operation
+
+        mock_hook.return_value.get_operation = _mock_operation
+
+        generator = trigger.run()
+        actual = await generator.asend(None)
+
+        assert (
+            TriggerEvent(
+                {
+                    "status": RunJobStatus.TIMEOUT,
+                    "job_name": JOB_NAME,
+                }
+            )
+            == actual
+        )
+
+    async def _mock_operation(self, done, name, error_code, error_message):
+        """Coroutine returning a mock operation with the given ``done`` flag, name and error details."""
+        operation = mock.MagicMock()
+        operation.done = done
+        operation.name = name
+        operation.error = mock.MagicMock()
+        operation.error.message = error_message
+        operation.error.code = error_code
+        return operation
diff --git a/tests/system/providers/google/cloud/cloud_run/__init__.py
b/tests/system/providers/google/cloud/cloud_run/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_run/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
new file mode 100644
index 0000000000..e5f83578cd
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_run/example_cloud_run.py
@@ -0,0 +1,261 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that uses Google Cloud Run Operators.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from google.cloud.run_v2 import Job
+from google.cloud.run_v2.types import k8s_min
+
+from airflow import models
+from airflow.operators.python import PythonOperator
+from airflow.providers.google.cloud.operators.cloud_run import (
+ CloudRunCreateJobOperator,
+ CloudRunDeleteJobOperator,
+ CloudRunExecuteJobOperator,
+ CloudRunListJobsOperator,
+ CloudRunUpdateJobOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+# GCP project the system test runs against; None when SYSTEM_TESTS_GCP_PROJECT is unset.
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "example_cloud_run"
+
+# Cloud Run region and the names of the two jobs this DAG creates, executes and deletes.
+region = "us-central1"
+job_name_prefix = "cloudrun-system-test-job"
+job1_name = f"{job_name_prefix}1"
+job2_name = f"{job_name_prefix}2"
+
+# Task ids. Several assertion callables pull XComs by these ids, so each value
+# must match the task_id given to the corresponding operator.
+create1_task_name = "create-job1"
+create2_task_name = "create-job2"
+
+execute1_task_name = "execute-job1"
+execute2_task_name = "execute-job2"
+
+update_job1_task_name = "update-job1"
+
+delete1_task_name = "delete-job1"
+delete2_task_name = "delete-job2"
+
+list_jobs_limit_task_name = "list-jobs-limit"
+list_jobs_task_name = "list-jobs"
+
+# NOTE(review): clean1_task_name / clean2_task_name are not referenced anywhere in
+# this file; presumably leftovers - confirm before removing.
+clean1_task_name = "clean-job1"
+clean2_task_name = "clean-job2"
+
+
+def _assert_executed_jobs_xcom(ti):
+    """Verify each execute task's XCom refers to the job it was asked to run.
+
+    ``xcom_pull`` is given a one-element list of task ids, so its result is
+    indexed with ``[0]``; that entry's ``"name"`` must contain the job name.
+    Job1 is checked before job2, failing on the first mismatch.
+    """
+    checks = ((execute1_task_name, job1_name), (execute2_task_name, job2_name))
+    for task_id, job_name in checks:
+        pulled_name = ti.xcom_pull(task_ids=[task_id], key="return_value")[0]["name"]
+        assert job_name in pulled_name, (
+            f"{task_id} returned {pulled_name!r}, expected it to contain {job_name!r}"
+        )
+
+
+def _assert_created_jobs_xcom(ti):
+    """Verify each create task's XCom describes the job it was asked to create.
+
+    Job1 is checked before job2; the first mismatch raises AssertionError.
+    """
+    checks = ((create1_task_name, job1_name), (create2_task_name, job2_name))
+    for task_id, job_name in checks:
+        pulled = ti.xcom_pull(task_ids=[task_id], key="return_value")
+        assert job_name in pulled[0]["name"]
+
+
+def _assert_updated_job(ti):
+    """Verify the update task's XCom shows the job carrying ``somelabel=label1``.
+
+    Raises KeyError if the returned job has no ``labels`` / ``somelabel`` entry,
+    and AssertionError if the label has a different value.
+    """
+    job_dict = ti.xcom_pull(task_ids=[update_job1_task_name], key="return_value")[0]
+    label = job_dict["labels"]["somelabel"]
+    assert label == "label1", f"unexpected 'somelabel' value after update: {label!r}"
+
+
+def _assert_jobs(ti):
+    """Verify the unlimited list-jobs task returned both jobs this DAG created.
+
+    A job counts as present when its name is a substring of some listed job's
+    ``"name"``. Each missing job produces an AssertionError naming it.
+    """
+    jobs = ti.xcom_pull(task_ids=[list_jobs_task_name], key="return_value")[0]
+
+    for expected_name in (job1_name, job2_name):
+        assert any(expected_name in job_dict["name"] for job_dict in jobs), (
+            f"job {expected_name!r} missing from {list_jobs_task_name} output"
+        )
+
+
+def _assert_one_job(ti):
+    """Verify the list-jobs task run with ``limit=1`` returned exactly one job."""
+    jobs = ti.xcom_pull(task_ids=[list_jobs_limit_task_name], key="return_value")[0]
+
+    assert len(jobs) == 1, f"expected exactly 1 job with limit=1, got {len(jobs)}"
+
+
+# [START howto_operator_cloud_run_job_creation]
+def _create_job():
+    """Build a minimal Job whose single container runs the sample job image."""
+    sample_container = k8s_min.Container(image="us-docker.pkg.dev/cloudrun/container/job:latest")
+    job = Job()
+    job.template.template.containers.append(sample_container)
+    return job
+
+
+# [END howto_operator_cloud_run_job_creation]
+
+
+def _create_job_with_label():
+    """Return the job built by ``_create_job`` with labels set to ``{"somelabel": "label1"}``."""
+    labelled_job = _create_job()
+    labelled_job.labels = dict(somelabel="label1")
+    return labelled_job
+
+
+with models.DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example"],
+) as dag:
+
+    # [START howto_operator_cloud_run_create_job]
+    create1 = CloudRunCreateJobOperator(
+        task_id=create1_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        job=_create_job(),
+        dag=dag,
+    )
+    # [END howto_operator_cloud_run_create_job]
+
+    # Same as create1, but the job is passed as a dict (Job.to_dict) instead of a Job message.
+    create2 = CloudRunCreateJobOperator(
+        task_id=create2_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job2_name,
+        job=Job.to_dict(_create_job()),
+        dag=dag,
+    )
+
+    assert_created_jobs = PythonOperator(
+        task_id="assert-created-jobs", python_callable=_assert_created_jobs_xcom, dag=dag
+    )
+
+    # [START howto_operator_cloud_run_execute_job]
+    execute1 = CloudRunExecuteJobOperator(
+        task_id=execute1_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        dag=dag,
+        deferrable=False,
+    )
+    # [END howto_operator_cloud_run_execute_job]
+
+    # [START howto_operator_cloud_run_execute_job_deferrable_mode]
+    execute2 = CloudRunExecuteJobOperator(
+        task_id=execute2_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job2_name,
+        dag=dag,
+        deferrable=True,
+    )
+    # [END howto_operator_cloud_run_execute_job_deferrable_mode]
+
+    assert_executed_jobs = PythonOperator(
+        task_id="assert-executed-jobs", python_callable=_assert_executed_jobs_xcom, dag=dag
+    )
+
+    list_jobs_limit = CloudRunListJobsOperator(
+        task_id=list_jobs_limit_task_name, project_id=PROJECT_ID, region=region, dag=dag, limit=1
+    )
+
+    assert_jobs_limit = PythonOperator(task_id="assert-jobs-limit", python_callable=_assert_one_job, dag=dag)
+
+    # [START howto_operator_cloud_run_list_jobs]
+    list_jobs = CloudRunListJobsOperator(
+        task_id=list_jobs_task_name, project_id=PROJECT_ID, region=region, dag=dag
+    )
+    # [END howto_operator_cloud_run_list_jobs]
+
+    assert_jobs = PythonOperator(task_id="assert-jobs", python_callable=_assert_jobs, dag=dag)
+
+    # [START howto_operator_cloud_update_job]
+    update_job1 = CloudRunUpdateJobOperator(
+        task_id=update_job1_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        job=_create_job_with_label(),
+        dag=dag,
+    )
+    # [END howto_operator_cloud_update_job]
+
+    assert_job_updated = PythonOperator(
+        task_id="assert-job-updated", python_callable=_assert_updated_job, dag=dag
+    )
+
+    # Teardown: TriggerRule.ALL_DONE runs the deletes once upstream tasks finish,
+    # whatever their state, so the jobs are removed even if an earlier task failed.
+    # [START howto_operator_cloud_delete_job]
+    delete_job1 = CloudRunDeleteJobOperator(
+        task_id=delete1_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job1_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+    # [END howto_operator_cloud_delete_job]
+
+    delete_job2 = CloudRunDeleteJobOperator(
+        task_id=delete2_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job2_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    (
+        (create1, create2)
+        >> assert_created_jobs
+        >> (execute1, execute2)
+        >> assert_executed_jobs
+        >> list_jobs_limit
+        >> assert_jobs_limit
+        >> list_jobs
+        >> assert_jobs
+        >> update_job1
+        >> assert_job_updated
+        >> (delete_job1, delete_job2)
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)