This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c6a014a370 Add `CloudBatchHook` and operators (#32606)
c6a014a370 is described below
commit c6a014a3707d2e4a5a9d2fe0b4277be09266b63b
Author: Freddy Demiane <[email protected]>
AuthorDate: Fri Aug 18 21:00:08 2023 +0200
Add `CloudBatchHook` and operators (#32606)
---
.../providers/google/cloud/hooks/cloud_batch.py | 204 ++++++++++++
.../google/cloud/operators/cloud_batch.py | 298 ++++++++++++++++++
.../providers/google/cloud/triggers/cloud_batch.py | 156 ++++++++++
airflow/providers/google/provider.yaml | 16 +
.../operators/cloud/cloud_batch.rst | 108 +++++++
generated/provider_dependencies.json | 1 +
.../google/cloud/hooks/test_cloud_batch.py | 343 +++++++++++++++++++++
.../google/cloud/operators/test_cloud_batch.py | 190 ++++++++++++
.../google/cloud/triggers/test_cloud_batch.py | 160 ++++++++++
.../providers/google/cloud/cloud_batch/__init__.py | 16 +
.../cloud/cloud_batch/example_cloud_batch.py | 202 ++++++++++++
11 files changed, 1694 insertions(+)
diff --git a/airflow/providers/google/cloud/hooks/cloud_batch.py
b/airflow/providers/google/cloud/hooks/cloud_batch.py
new file mode 100644
index 0000000000..f85283047d
--- /dev/null
+++ b/airflow/providers/google/cloud/hooks/cloud_batch.py
@@ -0,0 +1,204 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import itertools
+import json
+from time import sleep
+from typing import Iterable, Sequence
+
+from google.api_core import operation # type: ignore
+from google.cloud.batch import ListJobsRequest, ListTasksRequest
+from google.cloud.batch_v1 import (
+ BatchServiceAsyncClient,
+ BatchServiceClient,
+ CreateJobRequest,
+ Job,
+ JobStatus,
+ Task,
+)
+from google.cloud.batch_v1.services.batch_service import pagers
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.common.consts import CLIENT_INFO
+from airflow.providers.google.common.hooks.base_google import
PROVIDE_PROJECT_ID, GoogleBaseHook
+
+
+class CloudBatchHook(GoogleBaseHook):
+    """
+    Hook for the Google Cloud Batch service.
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account.
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+    ) -> None:
+        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+        # Created lazily by get_conn() and reused afterwards.
+        self._client: BatchServiceClient | None = None
+
+    def get_conn(self):
+        """
+        Return the Batch service client, creating it on first use.
+
+        :return: Google Batch Service client object.
+        """
+        if self._client is None:
+            self._client = BatchServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+        return self._client
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def submit_batch_job(
+        self, job_name: str, job: dict | Job, region: str, project_id: str = PROVIDE_PROJECT_ID
+    ) -> Job:
+        """
+        Create a Batch job under ``projects/{project_id}/locations/{region}``.
+
+        :param job_name: The ID to give to the created job.
+        :param job: The job definition, either a ``Job`` message or a dictionary
+            that is converted to a ``Job`` through its JSON representation.
+        :param region: The region in which the job is created.
+        :param project_id: The ID of the Google Cloud project that owns the job.
+        :return: The ``Job`` returned by the create request.
+        """
+        if isinstance(job, dict):
+            job = Job.from_json(json.dumps(job))
+
+        create_request = CreateJobRequest()
+        create_request.job = job
+        create_request.job_id = job_name
+        create_request.parent = f"projects/{project_id}/locations/{region}"
+
+        return self.get_conn().create_job(create_request)
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_job(
+        self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
+    ) -> operation.Operation:
+        """
+        Request the deletion of a Batch job.
+
+        :param job_name: The ID of the job to delete.
+        :param region: The region in which the job lives.
+        :param project_id: The ID of the Google Cloud project that owns the job.
+        :return: The operation returned by the delete request; this method does not wait on it.
+        """
+        return self.get_conn().delete_job(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_jobs(
+        self,
+        region: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        filter: str | None = None,
+        limit: int | None = None,
+    ) -> Iterable[Job]:
+        """
+        List the Batch jobs of a project in a region.
+
+        :param region: The region whose jobs are listed.
+        :param project_id: The ID of the Google Cloud project whose jobs are listed.
+        :param filter: Optional filter passed to the list request.
+        :param limit: Maximum number of jobs to return; ``None`` returns every job.
+        :return: A list holding at most ``limit`` jobs.
+        :raises AirflowException: If ``limit`` is negative.
+        """
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
+
+        list_jobs_request: ListJobsRequest = ListJobsRequest(
+            parent=f"projects/{project_id}/locations/{region}", filter=filter
+        )
+
+        jobs: pagers.ListJobsPager = self.get_conn().list_jobs(request=list_jobs_request)
+
+        # islice stops consuming the pager once ``limit`` items were read.
+        return list(itertools.islice(jobs, limit))
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def list_tasks(
+        self,
+        region: str,
+        job_name: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        group_name: str = "group0",
+        filter: str | None = None,
+        limit: int | None = None,
+    ) -> Iterable[Task]:
+        """
+        List the tasks of one task group of a Batch job.
+
+        :param region: The region in which the job lives.
+        :param job_name: The ID of the job whose tasks are listed.
+        :param project_id: The ID of the Google Cloud project that owns the job.
+        :param group_name: The task group to list; defaults to ``group0``.
+        :param filter: Optional filter passed to the list request.
+        :param limit: Maximum number of tasks to return; ``None`` returns every task.
+        :return: A list holding at most ``limit`` tasks.
+        :raises AirflowException: If ``limit`` is negative.
+        """
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list tasks request should be greater or equal to zero")
+
+        list_tasks_request: ListTasksRequest = ListTasksRequest(
+            parent=f"projects/{project_id}/locations/{region}/jobs/{job_name}/taskGroups/{group_name}",
+            filter=filter,
+        )
+
+        tasks: pagers.ListTasksPager = self.get_conn().list_tasks(request=list_tasks_request)
+
+        # islice stops consuming the pager once ``limit`` items were read.
+        return list(itertools.islice(tasks, limit))
+
+    def wait_for_job(
+        self, job_name: str, polling_period_seconds: float = 10, timeout: float | None = None
+    ) -> Job:
+        """
+        Poll a job until it succeeds, fails or starts being deleted.
+
+        :param job_name: The name passed as-is to the ``get_job`` request.
+        :param polling_period_seconds: Seconds to sleep between two polls.
+        :param timeout: Polling budget in seconds, or ``None`` to poll without limit.
+            The budget is reduced by ``polling_period_seconds`` after every poll;
+            time spent in the requests themselves is not counted.
+        :return: The job as last fetched; it may be in a FAILED or DELETION_IN_PROGRESS state.
+        :raises AirflowException: If the budget is used up before a final state is seen.
+        """
+        final_states = (
+            JobStatus.State.SUCCEEDED,
+            JobStatus.State.FAILED,
+            JobStatus.State.DELETION_IN_PROGRESS,
+        )
+        client = self.get_conn()
+        while timeout is None or timeout > 0:
+            try:
+                job = client.get_job(name=f"{job_name}")
+                status: JobStatus.State = job.status.state
+                if status in final_states:
+                    return job
+                sleep(polling_period_seconds)
+            except Exception:
+                self.log.exception("Exception occurred while checking for job completion.")
+                # Bare raise keeps the original traceback untouched.
+                raise
+
+            if timeout is not None:
+                timeout -= polling_period_seconds
+
+        raise AirflowException(f"Job with name [{job_name}] timed out")
+
+    def get_job(self, job_name: str) -> Job:
+        """
+        Fetch a single job.
+
+        :param job_name: The name passed as-is to the ``get_job`` request.
+        :return: The job returned by the service.
+        """
+        return self.get_conn().get_job(name=job_name)
+
+
+class CloudBatchAsyncHook(GoogleBaseHook):
+    """
+    Async hook for the Google Cloud Batch service.
+
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account.
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        # Async client is built on demand by get_conn().
+        self._client: BatchServiceAsyncClient | None = None
+        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
+
+    def get_conn(self):
+        """Return the async Batch service client, building it the first time it is requested."""
+        if self._client is not None:
+            return self._client
+
+        self._client = BatchServiceAsyncClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
+        return self._client
+
+    async def get_batch_job(
+        self,
+        job_name: str,
+    ) -> Job:
+        """
+        Fetch a job through the async client.
+
+        :param job_name: The name passed to the ``get_job`` request.
+        :return: The job returned by the service.
+        """
+        batch_client = self.get_conn()
+        fetched_job = await batch_client.get_job(name=f"{job_name}")
+        return fetched_job
diff --git a/airflow/providers/google/cloud/operators/cloud_batch.py
b/airflow/providers/google/cloud/operators/cloud_batch.py
new file mode 100644
index 0000000000..26c0af06c1
--- /dev/null
+++ b/airflow/providers/google/cloud/operators/cloud_batch.py
@@ -0,0 +1,298 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Sequence
+
+from google.api_core import operation # type: ignore
+from google.cloud.batch_v1 import Job, Task
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_batch import CloudBatchHook
+from airflow.providers.google.cloud.operators.cloud_base import
GoogleCloudBaseOperator
+from airflow.providers.google.cloud.triggers.cloud_batch import
CloudBatchJobFinishedTrigger
+from airflow.utils.context import Context
+
+
+class CloudBatchSubmitJobOperator(GoogleCloudBaseOperator):
+    """
+    Submit a job and wait for its completion.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to create.
+    :param job: Required. The job descriptor containing the configuration of the job to submit.
+    :param polling_period_seconds: Optional: Control the rate of the poll for the result of the job,
+        in both the deferrable and the non-deferrable mode.
+        By default, the job is polled every 10 seconds.
+    :param timeout_seconds: The maximum time, in seconds, to wait for the job to reach a final state.
+        If left empty, the operator waits without limit.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param deferrable: Run operator in the deferrable mode
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        job: dict | Job,
+        polling_period_seconds: float = 10,
+        timeout_seconds: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.job = job
+        self.polling_period_seconds = polling_period_seconds
+        self.timeout_seconds = timeout_seconds
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+
+    def execute(self, context: Context):
+        """
+        Submit the job, then wait for it either in-process or through the trigger.
+
+        :return: In the non-deferrable mode, the dictionary form of the job returned by
+            ``CloudBatchHook.wait_for_job``.
+        """
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        # Forward the operator's project so the job is created in the project the user asked
+        # for, not in the connection's default project.
+        job = hook.submit_batch_job(
+            job_name=self.job_name, job=self.job, region=self.region, project_id=self.project_id
+        )
+
+        if not self.deferrable:
+            # NOTE(review): wait_for_job also returns jobs that ended FAILED or are being
+            # deleted; this path does not inspect the final state - confirm this is intended.
+            completed_job = hook.wait_for_job(
+                job_name=job.name,
+                polling_period_seconds=self.polling_period_seconds,
+                timeout=self.timeout_seconds,
+            )
+
+            return Job.to_dict(completed_job)
+
+        else:
+            self.defer(
+                trigger=CloudBatchJobFinishedTrigger(
+                    job_name=job.name,
+                    project_id=self.project_id,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    location=self.region,
+                    polling_period_seconds=self.polling_period_seconds,
+                    timeout=self.timeout_seconds,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict):
+        """
+        Resume after the trigger fired.
+
+        :param event: The trigger event payload; ``status`` and ``message`` are read, and
+            ``job_name`` is read when the status is ``success``.
+        :return: The dictionary form of the job when the status is ``success``.
+        :raises AirflowException: For every status other than ``success``.
+        """
+        job_status = event["status"]
+        if job_status == "success":
+            hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+            job = hook.get_job(job_name=event["job_name"])
+            return Job.to_dict(job)
+        else:
+            raise AirflowException(f"Unexpected error in the operation: {event['message']}")
+
+
+class CloudBatchDeleteJobOperator(GoogleCloudBaseOperator):
+    """
+    Deletes a job and wait for the operation to be completed.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job to be deleted.
+    :param timeout: The timeout used when waiting on the delete operation.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = ("project_id", "region", "gcp_conn_id", "impersonation_chain", "job_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        timeout: float | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        # Connection settings.
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        # Job coordinates.
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.timeout = timeout
+
+    def execute(self, context: Context):
+        """Ask the hook to delete the job, then block until the returned operation finishes."""
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        delete_operation = hook.delete_job(
+            job_name=self.job_name,
+            region=self.region,
+            project_id=self.project_id,
+        )
+        self._wait_for_operation(delete_operation)
+
+    def _wait_for_operation(self, operation: operation.Operation):
+        """
+        Wait for the operation result, using ``self.timeout``.
+
+        :raises AirflowException: Wrapping the operation's exception when fetching the result fails.
+        """
+        try:
+            outcome = operation.result(timeout=self.timeout)
+        except Exception:
+            raise AirflowException(operation.exception(timeout=self.timeout))
+        return outcome
+
+
+class CloudBatchListJobsOperator(GoogleCloudBaseOperator):
+    """
+    List Cloud Batch jobs.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param filter: The filter based on which to list the jobs. If left empty, all the jobs are listed.
+    :param limit: The number of jobs to list. If left empty,
+        all the jobs matching the filter will be returned.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "gcp_conn_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        gcp_conn_id: str = "google_cloud_default",
+        filter: str | None = None,
+        limit: int | None = None,
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        # A negative limit is rejected when the operator is constructed.
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
+        self.project_id = project_id
+        self.region = region
+        self.filter = filter
+        self.limit = limit
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def execute(self, context: Context):
+        """Fetch the jobs through the hook and return each one converted to a dictionary."""
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+        found_jobs = hook.list_jobs(
+            region=self.region,
+            project_id=self.project_id,
+            filter=self.filter,
+            limit=self.limit,
+        )
+        as_dicts = []
+        for found_job in found_jobs:
+            as_dicts.append(Job.to_dict(found_job))
+        return as_dicts
+
+
+class CloudBatchListTasksOperator(GoogleCloudBaseOperator):
+    """
+    List Cloud Batch tasks for a given job.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param job_name: Required. The name of the job for which to list tasks.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param filter: The filter based on which to list the tasks. If left empty, all the tasks are listed.
+    :param group_name: The name of the group that owns the task. By default it's `group0`.
+    :param limit: The number of tasks to list.
+        If left empty, all the tasks matching the filter will be returned.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+
+    """
+
+    template_fields = ("project_id", "region", "job_name", "gcp_conn_id", "impersonation_chain", "group_name")
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        job_name: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        group_name: str = "group0",
+        filter: str | None = None,
+        limit: int | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.job_name = job_name
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.group_name = group_name
+        self.filter = filter
+        self.limit = limit
+        # Fail at construction time; the message names the tasks request, matching the
+        # check done by CloudBatchHook.list_tasks.
+        if limit is not None and limit < 0:
+            raise AirflowException("The limit for the list tasks request should be greater or equal to zero")
+
+    def execute(self, context: Context):
+        """
+        List the tasks through the hook.
+
+        :return: One dictionary per task returned by ``CloudBatchHook.list_tasks``.
+        """
+        hook: CloudBatchHook = CloudBatchHook(self.gcp_conn_id, self.impersonation_chain)
+
+        tasks_list = hook.list_tasks(
+            region=self.region,
+            project_id=self.project_id,
+            job_name=self.job_name,
+            group_name=self.group_name,
+            filter=self.filter,
+            limit=self.limit,
+        )
+
+        return [Task.to_dict(task) for task in tasks_list]
diff --git a/airflow/providers/google/cloud/triggers/cloud_batch.py
b/airflow/providers/google/cloud/triggers/cloud_batch.py
new file mode 100644
index 0000000000..211e436c95
--- /dev/null
+++ b/airflow/providers/google/cloud/triggers/cloud_batch.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator, Sequence
+
+from google.cloud.batch_v1 import Job, JobStatus
+
+from airflow.providers.google.cloud.hooks.cloud_batch import
CloudBatchAsyncHook
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+DEFAULT_BATCH_LOCATION = "us-central1"
+
+
+class CloudBatchJobFinishedTrigger(BaseTrigger):
+    """Cloud Batch trigger to check if templated job has been finished.
+
+    :param job_name: Required. Name of the job.
+    :param project_id: Required. the Google Cloud project ID in which the job was started.
+    :param location: Optional. the location where job is executed.
+        Defaults to the value of DEFAULT_BATCH_LOCATION.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional. Service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param polling_period_seconds: Polling period in seconds to check for the status
+    :param timeout: Polling budget in seconds; it is reduced by ``polling_period_seconds``
+        after every poll that finds the job still running. ``None`` polls without limit.
+
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        project_id: str | None,
+        location: str = DEFAULT_BATCH_LOCATION,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        polling_period_seconds: float = 10,
+        timeout: float | None = None,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.job_name = job_name
+        self.location = location
+        self.gcp_conn_id = gcp_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.timeout = timeout
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes class arguments and classpath."""
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchJobFinishedTrigger",
+            {
+                "project_id": self.project_id,
+                "job_name": self.job_name,
+                "location": self.location,
+                "gcp_conn_id": self.gcp_conn_id,
+                "polling_period_seconds": self.polling_period_seconds,
+                "timeout": self.timeout,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Main loop of the class in where it is fetching the job status and yields certain Event.
+
+        If the job has status success then it yields TriggerEvent with success status, if job has
+        status failed - with error status and if the job is being deleted - with deleted status.
+        In any other case Trigger will wait for specified amount of time
+        stored in self.polling_period_seconds variable.
+
+        An exception raised while polling yields an event with error status carrying the
+        exception text, and using up the timeout yields an event with ``timed out`` status.
+        """
+        timeout = self.timeout
+        hook = self._get_async_hook()
+        while timeout is None or timeout > 0:
+
+            try:
+                job: Job = await hook.get_batch_job(job_name=self.job_name)
+
+                status: JobStatus.State = job.status.state
+                if status == JobStatus.State.SUCCEEDED:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "success",
+                            "message": "Job completed",
+                        }
+                    )
+                    return
+                elif status == JobStatus.State.FAILED:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "error",
+                            "message": f"Batch job with name {self.job_name} has failed its execution",
+                        }
+                    )
+                    return
+                elif status == JobStatus.State.DELETION_IN_PROGRESS:
+                    yield TriggerEvent(
+                        {
+                            "job_name": self.job_name,
+                            "status": "deleted",
+                            "message": f"Batch job with name {self.job_name} is being deleted",
+                        }
+                    )
+                    return
+                else:
+                    self.log.info("Current job status is: %s", status)
+                    self.log.info("Sleeping for %s seconds.", self.polling_period_seconds)
+                    if timeout is not None:
+                        timeout -= self.polling_period_seconds
+
+                    # Skip the sleep when the budget is already used up.
+                    if timeout is None or timeout > 0:
+                        await asyncio.sleep(self.polling_period_seconds)
+
+            except Exception as e:
+                self.log.exception("Exception occurred while checking for job completion.")
+                yield TriggerEvent({"status": "error", "message": str(e)})
+                return
+
+        # No exception is being handled here, so log.error is used: log.exception would
+        # append an empty traceback to the record.
+        self.log.error("Job with name [%s] timed out", self.job_name)
+        yield TriggerEvent(
+            {
+                "job_name": self.job_name,
+                "status": "timed out",
+                "message": f"Batch job with name {self.job_name} timed out",
+            }
+        )
+        return
+
+    def _get_async_hook(self) -> CloudBatchAsyncHook:
+        """Build the async hook from this trigger's connection settings."""
+        return CloudBatchAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
diff --git a/airflow/providers/google/provider.yaml
b/airflow/providers/google/provider.yaml
index f3ec0ecee9..29fd2c8073 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -117,6 +117,7 @@ dependencies:
- google-cloud-videointelligence>=2.11.0
- google-cloud-vision>=3.4.0
- google-cloud-workflows>=1.10.0
+ - google-cloud-batch>=0.13.0
- grpcio-gcp>=0.2.2
- httpx
- json-merge-patch>=0.2
@@ -182,6 +183,11 @@ integrations:
how-to-guide:
-
/docs/apache-airflow-providers-google/operators/cloud/cloud_composer.rst
tags: [google]
+ - integration-name: Google Cloud Batch
+ external-doc-url: https://cloud.google.com/batch
+ how-to-guide:
+ - /docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst
+ tags: [google]
- integration-name: Google Cloud Dataform
external-doc-url: https://cloud.google.com/dataform/
how-to-guide:
@@ -611,6 +617,10 @@ operators:
- integration-name: Google Cloud Dataform
python-modules:
- airflow.providers.google.cloud.operators.dataform
+ - integration-name: Google Cloud Batch
+ python-modules:
+ - airflow.providers.google.cloud.operators.cloud_batch
+
sensors:
- integration-name: Google BigQuery
@@ -850,6 +860,9 @@ hooks:
- integration-name: Google Cloud Dataform
python-modules:
- airflow.providers.google.cloud.hooks.dataform
+ - integration-name: Google Cloud Batch
+ python-modules:
+ - airflow.providers.google.cloud.hooks.cloud_batch
triggers:
- integration-name: Google BigQuery Data Transfer Service
@@ -891,6 +904,9 @@ triggers:
- integration-name: Google Cloud Pub/Sub
python-modules:
- airflow.providers.google.cloud.triggers.pubsub
+ - integration-name: Google Cloud
+ python-modules:
+ - airflow.providers.google.cloud.triggers.cloud_batch
transfers:
- source-integration-name: Presto
diff --git
a/docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst
b/docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst
new file mode 100644
index 0000000000..2254ead25b
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/cloud/cloud_batch.rst
@@ -0,0 +1,108 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Google Cloud Batch Operators
+===============================
+
+Cloud Batch is a fully managed batch service to schedule, queue, and execute
batch jobs on Google's infrastructure.
+
+For more information about the service visit `Google Cloud Batch documentation
<https://cloud.google.com/batch>`__.
+
+Prerequisite Tasks
+^^^^^^^^^^^^^^^^^^
+
+.. include:: /operators/_partials/prerequisite_tasks.rst
+
+Submit a job
+---------------------
+
+Before you submit a job in Cloud Batch, you need to define it.
+For more information about the Job object fields, visit `Google Cloud Batch
Job description
<https://cloud.google.com/python/docs/reference/batch/latest/google.cloud.batch_v1.types.Job>`__.
+
+A simple job configuration can look as follows:
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+ :language: python
+ :dedent: 0
+ :start-after: [START howto_operator_batch_job_creation]
+ :end-before: [END howto_operator_batch_job_creation]
+
+With this configuration we can submit the job:
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchSubmitJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_batch_submit_job]
+ :end-before: [END howto_operator_batch_submit_job]
+
+or you can define the same operator in the deferrable mode:
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchSubmitJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_batch_submit_job_deferrable_mode]
+ :end-before: [END howto_operator_batch_submit_job_deferrable_mode]
+
+Note that this operator waits for the job to complete its execution, and the
Job's dictionary representation is pushed to XCom.
+
+List a job's tasks
+------------------
+
+To list the tasks of a certain job, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchListTasksOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_batch_list_tasks]
+ :end-before: [END howto_operator_batch_list_tasks]
+
+The operator takes two optional parameters: "limit" to limit the number of
tasks returned, and "filter" to only list the tasks matching the `filter
<https://cloud.google.com/sdk/gcloud/reference/topic/filters>`__.
+
+List jobs
+----------------------
+
+To list the jobs, you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchListJobsOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_batch_list_jobs]
+ :end-before: [END howto_operator_batch_list_jobs]
+
+The operator takes two optional parameters: "limit" to limit the number of
jobs returned, and "filter" to only list the jobs matching the `filter
<https://cloud.google.com/sdk/gcloud/reference/topic/filters>`__.
+
+Delete a job
+-----------------
+
+To delete a job you can use:
+
+:class:`~airflow.providers.google.cloud.operators.cloud_batch.CloudBatchDeleteJobOperator`
+
+.. exampleinclude::
/../../tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_delete_job]
+ :end-before: [END howto_operator_delete_job]
+
+
+Note that this operator waits for the job to be deleted. The operator returns
no value, so nothing is pushed to XCom.
diff --git a/generated/provider_dependencies.json
b/generated/provider_dependencies.json
index ee7a382267..b96c6baf1c 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -420,6 +420,7 @@
"google-auth>=1.0.0",
"google-cloud-aiplatform>=1.22.1",
"google-cloud-automl>=2.11.0",
+ "google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.11.0",
"google-cloud-bigtable>=2.17.0",
"google-cloud-build>=3.13.0",
diff --git a/tests/providers/google/cloud/hooks/test_cloud_batch.py
b/tests/providers/google/cloud/hooks/test_cloud_batch.py
new file mode 100644
index 0000000000..d4c000067f
--- /dev/null
+++ b/tests/providers/google/cloud/hooks/test_cloud_batch.py
@@ -0,0 +1,343 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud.batch import ListJobsRequest
+from google.cloud.batch_v1 import CreateJobRequest, Job, JobStatus
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.hooks.cloud_batch import
CloudBatchAsyncHook, CloudBatchHook
+from tests.providers.google.cloud.utils.base_gcp_mock import
mock_base_gcp_hook_default_project_id
+
+
+class TestCloudBathHook:
+ def dummy_get_credentials(self):
+ pass
+
+ @pytest.fixture
+ def cloud_batch_hook(self):
+ cloud_batch_hook = CloudBatchHook()
+ cloud_batch_hook.get_credentials = self.dummy_get_credentials
+ return cloud_batch_hook
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_submit(self, mock_batch_service_client, cloud_batch_hook):
+ job = Job()
+ job_name = "jobname"
+ project_id = "test_project_id"
+ region = "us-central1"
+
+ cloud_batch_hook.submit_batch_job(
+ job_name=job_name, job=Job.to_dict(job), region=region,
project_id=project_id
+ )
+
+ create_request = CreateJobRequest()
+ create_request.job = job
+ create_request.job_id = job_name
+ create_request.parent = f"projects/{project_id}/locations/{region}"
+
+ cloud_batch_hook._client.create_job.assert_called_with(create_request)
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_get_job(self, mock_batch_service_client, cloud_batch_hook):
+ cloud_batch_hook.get_job("job1")
+ cloud_batch_hook._client.get_job.assert_called_once_with(name="job1")
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_delete_job(self, mock_batch_service_client, cloud_batch_hook):
+ job_name = "job1"
+ region = "us-east1"
+ project_id = "test_project_id"
+ cloud_batch_hook.delete_job(job_name=job_name, region=region,
project_id=project_id)
+ cloud_batch_hook._client.delete_job.assert_called_once_with(
+ name=f"projects/{project_id}/locations/{region}/jobs/{job_name}"
+ )
+
+ @pytest.mark.parametrize(
+ "state", [JobStatus.State.SUCCEEDED, JobStatus.State.FAILED,
JobStatus.State.DELETION_IN_PROGRESS]
+ )
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_wait_job_succeeded(self, mock_batch_service_client, state,
cloud_batch_hook):
+ mock_job = self._mock_job_with_status(state)
+ mock_batch_service_client.return_value.get_job.return_value = mock_job
+ actual_job = cloud_batch_hook.wait_for_job("job1")
+ assert actual_job == mock_job
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_wait_job_timeout(self, mock_batch_service_client,
cloud_batch_hook):
+ mock_job = self._mock_job_with_status(JobStatus.State.RUNNING)
+ mock_batch_service_client.return_value.get_job.return_value = mock_job
+
+ exception_caught = False
+ try:
+ cloud_batch_hook.wait_for_job("job1", polling_period_seconds=0.01,
timeout=0.02)
+ except AirflowException:
+ exception_caught = True
+
+ assert exception_caught
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_jobs(self, mock_batch_service_client, cloud_batch_hook):
+
+ number_of_jobs = 3
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+
+ page = self._mock_pager(number_of_jobs)
+ mock_batch_service_client.return_value.list_jobs.return_value = page
+
+ jobs_list = cloud_batch_hook.list_jobs(region=region,
project_id=project_id, filter=filter)
+
+ for i in range(number_of_jobs):
+ assert jobs_list[i].name == f"name{i}"
+
+ expected_list_jobs_request: ListJobsRequest = ListJobsRequest(
+ parent=f"projects/{project_id}/locations/{region}", filter=filter
+ )
+
mock_batch_service_client.return_value.list_jobs.assert_called_once_with(
+ request=expected_list_jobs_request
+ )
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_jobs_with_limit(self, mock_batch_service_client,
cloud_batch_hook):
+
+ number_of_jobs = 3
+ limit = 2
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+
+ page = self._mock_pager(number_of_jobs)
+ mock_batch_service_client.return_value.list_jobs.return_value = page
+
+ jobs_list = cloud_batch_hook.list_jobs(
+ region=region, project_id=project_id, filter=filter, limit=limit
+ )
+
+ assert len(jobs_list) == limit
+ for i in range(limit):
+ assert jobs_list[i].name == f"name{i}"
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_jobs_with_limit_zero(self, mock_batch_service_client,
cloud_batch_hook):
+
+ number_of_jobs = 3
+ limit = 0
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+
+ page = self._mock_pager(number_of_jobs)
+ mock_batch_service_client.return_value.list_jobs.return_value = page
+
+ jobs_list = cloud_batch_hook.list_jobs(
+ region=region, project_id=project_id, filter=filter, limit=limit
+ )
+
+ assert len(jobs_list) == 0
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_jobs_with_limit_greater_then_range(self,
mock_batch_service_client, cloud_batch_hook):
+
+ number_of_jobs = 3
+ limit = 5
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+
+ page = self._mock_pager(number_of_jobs)
+ mock_batch_service_client.return_value.list_jobs.return_value = page
+
+ jobs_list = cloud_batch_hook.list_jobs(
+ region=region, project_id=project_id, filter=filter, limit=limit
+ )
+
+ assert len(jobs_list) == number_of_jobs
+ for i in range(number_of_jobs):
+ assert jobs_list[i].name == f"name{i}"
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_jobs_with_limit_less_than_zero(self,
mock_batch_service_client, cloud_batch_hook):
+
+ number_of_jobs = 3
+ limit = -1
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+
+ page = self._mock_pager(number_of_jobs)
+ mock_batch_service_client.return_value.list_jobs.return_value = page
+
+ with pytest.raises(expected_exception=AirflowException):
+ cloud_batch_hook.list_jobs(region=region, project_id=project_id,
filter=filter, limit=limit)
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_tasks_with_limit(self, mock_batch_service_client,
cloud_batch_hook):
+
+ number_of_tasks = 3
+ limit = 2
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+ job_name = "test_job"
+
+ page = self._mock_pager(number_of_tasks)
+ mock_batch_service_client.return_value.list_tasks.return_value = page
+
+ tasks_list = cloud_batch_hook.list_tasks(
+ region=region, project_id=project_id, job_name=job_name,
filter=filter, limit=limit
+ )
+
+ assert len(tasks_list) == limit
+ for i in range(limit):
+ assert tasks_list[i].name == f"name{i}"
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_tasks_with_limit_greater_then_range(self,
mock_batch_service_client, cloud_batch_hook):
+
+ number_of_tasks = 3
+ limit = 5
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+ job_name = "test_job"
+
+ page = self._mock_pager(number_of_tasks)
+ mock_batch_service_client.return_value.list_tasks.return_value = page
+
+ tasks_list = cloud_batch_hook.list_tasks(
+ region=region, project_id=project_id, filter=filter,
job_name=job_name, limit=limit
+ )
+
+ assert len(tasks_list) == number_of_tasks
+ for i in range(number_of_tasks):
+ assert tasks_list[i].name == f"name{i}"
+
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceClient")
+ def test_list_tasks_with_limit_less_than_zero(self,
mock_batch_service_client, cloud_batch_hook):
+
+ number_of_tasks = 3
+ limit = -1
+ region = "us-central1"
+ project_id = "test_project_id"
+ filter = "filter_description"
+ job_name = "test_job"
+
+ page = self._mock_pager(number_of_tasks)
+ mock_batch_service_client.return_value.list_tasks.return_value = page
+
+ with pytest.raises(expected_exception=AirflowException):
+ cloud_batch_hook.list_tasks(
+ region=region, project_id=project_id, job_name=job_name,
filter=filter, limit=limit
+ )
+
+ def _mock_job_with_status(self, status: JobStatus.State):
+ job = mock.MagicMock()
+ job.status.state = status
+ return job
+
+ def _mock_pager(self, number_of_jobs):
+ mock_pager = []
+ for i in range(number_of_jobs):
+ mock_pager.append(Job(name=f"name{i}"))
+
+ return mock_pager
+
+
+class TestCloudBatchAsyncHook:
+ @pytest.mark.asyncio
+ @mock.patch(
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+ new=mock_base_gcp_hook_default_project_id,
+ )
+
@mock.patch("airflow.providers.google.cloud.hooks.cloud_batch.BatchServiceAsyncClient")
+ async def test_get_job(self, mock_client):
+ expected_job = {"name": "somename"}
+
+ async def _get_job(name):
+ return expected_job
+
+ job_name = "jobname"
+ mock_client.return_value = mock.MagicMock()
+ mock_client.return_value.get_job = _get_job
+
+ hook = CloudBatchAsyncHook()
+ hook.get_credentials = self._dummy_get_credentials
+
+ returned_operation = await hook.get_batch_job(job_name=job_name)
+
+ assert returned_operation == expected_job
+
+ def _dummy_get_credentials(self):
+ pass
diff --git a/tests/providers/google/cloud/operators/test_cloud_batch.py
b/tests/providers/google/cloud/operators/test_cloud_batch.py
new file mode 100644
index 0000000000..a4377eebdc
--- /dev/null
+++ b/tests/providers/google/cloud/operators/test_cloud_batch.py
@@ -0,0 +1,190 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud import batch_v1
+
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.google.cloud.operators.cloud_batch import (
+ CloudBatchDeleteJobOperator,
+ CloudBatchListJobsOperator,
+ CloudBatchListTasksOperator,
+ CloudBatchSubmitJobOperator,
+)
+
+CLOUD_BATCH_HOOK_PATH =
"airflow.providers.google.cloud.operators.cloud_batch.CloudBatchHook"
+TASK_ID = "test"
+PROJECT_ID = "testproject"
+REGION = "us-central1"
+JOB_NAME = "test"
+JOB = batch_v1.Job()
+JOB.name = JOB_NAME
+
+
+class TestCloudBatchSubmitJobOperator:
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute(self, mock):
+ mock.return_value.wait_for_job.return_value = JOB
+ operator = CloudBatchSubmitJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, job=JOB
+ )
+
+ completed_job = operator.execute(context=mock.MagicMock())
+
+ assert completed_job["name"] == JOB_NAME
+
+
mock.return_value.submit_batch_job.assert_called_with(job_name=JOB_NAME,
job=JOB, region=REGION)
+ mock.return_value.wait_for_job.assert_called()
+
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute_deferrable(self, mock):
+ operator = CloudBatchSubmitJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, job=JOB, deferrable=True
+ )
+
+ with pytest.raises(expected_exception=TaskDeferred):
+ operator.execute(context=mock.MagicMock())
+
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute_complete(self, mock):
+ mock.return_value.get_job.return_value = JOB
+ operator = CloudBatchSubmitJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, job=JOB, deferrable=True
+ )
+
+ event = {"status": "success", "job_name": JOB_NAME, "message": "test
error"}
+ completed_job = operator.execute_complete(context=mock.MagicMock(),
event=event)
+
+ assert completed_job["name"] == JOB_NAME
+
+ mock.return_value.get_job.assert_called_once_with(job_name=JOB_NAME)
+
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute_complete_exception(self, mock):
+ operator = CloudBatchSubmitJobOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
job_name=JOB_NAME, job=JOB, deferrable=True
+ )
+
+ event = {"status": "error", "job_name": JOB_NAME, "message": "test
error"}
+ with pytest.raises(
+ expected_exception=AirflowException, match="Unexpected error in
the operation: test error"
+ ):
+ operator.execute_complete(context=mock.MagicMock(), event=event)
+
+
+class TestCloudBatchDeleteJobOperator:
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute(self, hook_mock):
+ delete_operation_mock = self._delete_operation_mock()
+ hook_mock.return_value.delete_job.return_value = delete_operation_mock
+
+ operator = CloudBatchDeleteJobOperator(
+ task_id=TASK_ID,
+ project_id=PROJECT_ID,
+ region=REGION,
+ job_name=JOB_NAME,
+ )
+
+ operator.execute(context=mock.MagicMock())
+
+ hook_mock.return_value.delete_job.assert_called_once_with(
+ job_name=JOB_NAME, region=REGION, project_id=PROJECT_ID
+ )
+ delete_operation_mock.result.assert_called_once()
+
+ def _delete_operation_mock(self):
+ operation = mock.MagicMock()
+ operation.result.return_value = mock.MagicMock()
+ return operation
+
+
+class TestCloudBatchListJobsOperator:
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute(self, hook_mock):
+
+ filter = "filter_description"
+ limit = 2
+ operator = CloudBatchListJobsOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
filter=filter, limit=limit
+ )
+
+ operator.execute(context=mock.MagicMock())
+
+ hook_mock.return_value.list_jobs.assert_called_once_with(
+ region=REGION, project_id=PROJECT_ID, filter=filter, limit=limit
+ )
+
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute_with_invalid_limit(self, hook_mock):
+
+ filter = "filter_description"
+ limit = -1
+
+ with pytest.raises(expected_exception=AirflowException):
+ CloudBatchListJobsOperator(
+ task_id=TASK_ID, project_id=PROJECT_ID, region=REGION,
filter=filter, limit=limit
+ )
+
+
+class TestCloudBatchListTasksOperator:
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute(self, hook_mock):
+
+ filter = "filter_description"
+ limit = 2
+ job_name = "test_job"
+
+ operator = CloudBatchListTasksOperator(
+ task_id=TASK_ID,
+ project_id=PROJECT_ID,
+ region=REGION,
+ job_name=job_name,
+ filter=filter,
+ limit=limit,
+ )
+
+ operator.execute(context=mock.MagicMock())
+
+ hook_mock.return_value.list_tasks.assert_called_once_with(
+ region=REGION,
+ project_id=PROJECT_ID,
+ filter=filter,
+ job_name=job_name,
+ limit=limit,
+ group_name="group0",
+ )
+
+ @mock.patch(CLOUD_BATCH_HOOK_PATH)
+ def test_execute_with_invalid_limit(self, hook_mock):
+
+ filter = "filter_description"
+ limit = -1
+ job_name = "test_job"
+
+ with pytest.raises(expected_exception=AirflowException):
+ CloudBatchListTasksOperator(
+ task_id=TASK_ID,
+ project_id=PROJECT_ID,
+ region=REGION,
+ job_name=job_name,
+ filter=filter,
+ limit=limit,
+ )
diff --git a/tests/providers/google/cloud/triggers/test_cloud_batch.py
b/tests/providers/google/cloud/triggers/test_cloud_batch.py
new file mode 100644
index 0000000000..8da083f17e
--- /dev/null
+++ b/tests/providers/google/cloud/triggers/test_cloud_batch.py
@@ -0,0 +1,160 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from google.cloud.batch_v1 import Job, JobStatus
+
+from airflow.providers.google.cloud.triggers.cloud_batch import
CloudBatchJobFinishedTrigger
+from airflow.triggers.base import TriggerEvent
+
+JOB_NAME = "jobName"
+PROJECT_ID = "projectId"
+LOCATION = "us-central1"
+GCP_CONNECTION_ID = "gcp_connection_id"
+POLL_SLEEP = 0.01
+TIMEOUT = 0.02
+IMPERSONATION_CHAIN = "impersonation_chain"
+
+
[email protected]
+def trigger():
+ return CloudBatchJobFinishedTrigger(
+ job_name=JOB_NAME,
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ gcp_conn_id=GCP_CONNECTION_ID,
+ polling_period_seconds=POLL_SLEEP,
+ timeout=TIMEOUT,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+
+class TestCloudBatchJobFinishedTrigger:
+ def test_serialization(self, trigger):
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchJobFinishedTrigger"
+ assert kwargs == {
+ "project_id": PROJECT_ID,
+ "job_name": JOB_NAME,
+ "location": LOCATION,
+ "gcp_conn_id": GCP_CONNECTION_ID,
+ "polling_period_seconds": POLL_SLEEP,
+ "timeout": TIMEOUT,
+ "impersonation_chain": IMPERSONATION_CHAIN,
+ }
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+ async def test_trigger_on_success_yield_successfully(
+ self, mock_hook, trigger: CloudBatchJobFinishedTrigger
+ ):
+ """
+ Tests the CloudBatchJobFinishedTrigger fires once the job execution
reaches a successful state.
+ """
+ state = JobStatus.State.SUCCEEDED
+ mock_hook.return_value.get_batch_job.return_value =
self._mock_job_with_state(state)
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert (
+ TriggerEvent(
+ {
+ "job_name": JOB_NAME,
+ "status": "success",
+ "message": "Job completed",
+ }
+ )
+ == actual
+ )
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+ async def test_trigger_on_deleted_yield_successfully(
+ self, mock_hook, trigger: CloudBatchJobFinishedTrigger
+ ):
+ """
+ Tests the CloudBatchJobFinishedTrigger fires once the job execution
reaches the DELETION_IN_PROGRESS state.
+ """
+ state = JobStatus.State.DELETION_IN_PROGRESS
+ mock_hook.return_value.get_batch_job.return_value =
self._mock_job_with_state(state)
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert (
+ TriggerEvent(
+ {
+ "job_name": JOB_NAME,
+ "status": "deleted",
+ "message": f"Batch job with name {JOB_NAME} is being
deleted",
+ }
+ )
+ == actual
+ )
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+ async def test_trigger_on_deleted_yield_exception(self, mock_hook,
trigger: CloudBatchJobFinishedTrigger):
+ """
+ Tests the CloudBatchJobFinishedTrigger fires once the job execution
+ reaches a state with an error message.
+ """
+ mock_hook.return_value.get_batch_job.side_effect = Exception("Test
Exception")
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert (
+ TriggerEvent(
+ {
+ "status": "error",
+ "message": "Test Exception",
+ }
+ )
+ == actual
+ )
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.google.cloud.triggers.cloud_batch.CloudBatchAsyncHook")
+ async def test_trigger_timeout(self, mock_hook, trigger:
CloudBatchJobFinishedTrigger):
+ """
+ Tests the CloudBatchJobFinishedTrigger fires once the job execution
times out with an error message.
+ """
+
+ async def _mock_job(job_name):
+ job = mock.MagicMock()
+ job.status.state = JobStatus.State.RUNNING
+ return job
+
+ mock_hook.return_value.get_batch_job = _mock_job
+
+ generator = trigger.run()
+ actual = await generator.asend(None)
+ assert (
+ TriggerEvent(
+ {
+ "job_name": JOB_NAME,
+ "status": "timed out",
+ "message": f"Batch job with name {JOB_NAME} timed out",
+ }
+ )
+ == actual
+ )
+
+ async def _mock_job_with_state(self, state: JobStatus.State):
+ job: Job = mock.MagicMock()
+ job.status.state = state
+ return job
diff --git a/tests/system/providers/google/cloud/cloud_batch/__init__.py
b/tests/system/providers/google/cloud/cloud_batch/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_batch/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git
a/tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
b/tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
new file mode 100644
index 0000000000..d3f3d752a8
--- /dev/null
+++ b/tests/system/providers/google/cloud/cloud_batch/example_cloud_batch.py
@@ -0,0 +1,202 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG that uses Google Cloud Batch Operators.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from google.cloud import batch_v1
+
+from airflow import models
+from airflow.operators.python import PythonOperator
+from airflow.providers.google.cloud.operators.cloud_batch import (
+ CloudBatchDeleteJobOperator,
+ CloudBatchListJobsOperator,
+ CloudBatchListTasksOperator,
+ CloudBatchSubmitJobOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+DAG_ID = "example_cloud_batch"
+
+region = "us-central1"
+job_name_prefix = "batch-system-test-job"
+job1_name = f"{job_name_prefix}1"
+job2_name = f"{job_name_prefix}2"
+
+submit1_task_name = "submit-job1"
+submit2_task_name = "submit-job2"
+
+delete1_task_name = "delete-job1"
+delete2_task_name = "delete-job2"
+
+list_jobs_task_name = "list-jobs"
+list_tasks_task_name = "list-tasks"
+
+clean1_task_name = "clean-job1"
+clean2_task_name = "clean-job2"
+
+
+def _assert_jobs(ti):
+ job_names = ti.xcom_pull(task_ids=[list_jobs_task_name],
key="return_value")
+ job_names_str = job_names[0][0]["name"].split("/")[-1] + " " +
job_names[0][1]["name"].split("/")[-1]
+ assert job1_name in job_names_str
+ assert job2_name in job_names_str
+
+
+def _assert_tasks(ti):
+ tasks_names = ti.xcom_pull(task_ids=[list_tasks_task_name],
key="return_value")
+ assert len(tasks_names[0]) == 2
+ assert "tasks/0" in tasks_names[0][0]["name"]
+ assert "tasks/1" in tasks_names[0][1]["name"]
+
+
+# [START howto_operator_batch_job_creation]
+def _create_job():
+ runnable = batch_v1.Runnable()
+ runnable.container = batch_v1.Runnable.Container()
+ runnable.container.image_uri = "gcr.io/google-containers/busybox"
+ runnable.container.entrypoint = "/bin/sh"
+ runnable.container.commands = [
+ "-c",
+ "echo Hello world! This is task ${BATCH_TASK_INDEX}.\
+ This job has a total of ${BATCH_TASK_COUNT} tasks.",
+ ]
+
+ task = batch_v1.TaskSpec()
+ task.runnables = [runnable]
+
+ resources = batch_v1.ComputeResource()
+ resources.cpu_milli = 2000
+ resources.memory_mib = 16
+ task.compute_resource = resources
+ task.max_retry_count = 2
+
+ group = batch_v1.TaskGroup()
+ group.task_count = 2
+ group.task_spec = task
+ policy = batch_v1.AllocationPolicy.InstancePolicy()
+ policy.machine_type = "e2-standard-4"
+ instances = batch_v1.AllocationPolicy.InstancePolicyOrTemplate()
+ instances.policy = policy
+ allocation_policy = batch_v1.AllocationPolicy()
+ allocation_policy.instances = [instances]
+
+ job = batch_v1.Job()
+ job.task_groups = [group]
+ job.allocation_policy = allocation_policy
+ job.labels = {"env": "testing", "type": "container"}
+
+ job.logs_policy = batch_v1.LogsPolicy()
+ job.logs_policy.destination = batch_v1.LogsPolicy.Destination.CLOUD_LOGGING
+
+ return job
+
+
+# [END howto_operator_batch_job_creation]
+
+
+with models.DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "batch"],
+) as dag:
+
+ # [START howto_operator_batch_submit_job]
+ submit1 = CloudBatchSubmitJobOperator(
+ task_id=submit1_task_name,
+ project_id=PROJECT_ID,
+ region=region,
+ job_name=job1_name,
+ job=_create_job(),
+ dag=dag,
+ deferrable=False,
+ )
+ # [END howto_operator_batch_submit_job]
+
+ # [START howto_operator_batch_submit_job_deferrable_mode]
+ submit2 = CloudBatchSubmitJobOperator(
+ task_id=submit2_task_name,
+ project_id=PROJECT_ID,
+ region=region,
+ job_name=job2_name,
+ job=batch_v1.Job.to_dict(_create_job()),
+ dag=dag,
+ deferrable=True,
+ )
+ # [END howto_operator_batch_submit_job_deferrable_mode]
+
+ # [START howto_operator_batch_list_tasks]
+ list_tasks = CloudBatchListTasksOperator(
+ task_id=list_tasks_task_name, project_id=PROJECT_ID, region=region,
job_name=job1_name, dag=dag
+ )
+ # [END howto_operator_batch_list_tasks]
+
+ assert_tasks = PythonOperator(task_id="assert-tasks",
python_callable=_assert_tasks, dag=dag)
+
+ # [START howto_operator_batch_list_jobs]
+ list_jobs = CloudBatchListJobsOperator(
+ task_id=list_jobs_task_name,
+ project_id=PROJECT_ID,
+ region=region,
+ limit=2,
+
filter=f"name:projects/{PROJECT_ID}/locations/{region}/jobs/{job_name_prefix}*",
+ dag=dag,
+ )
+ # [END howto_operator_batch_list_jobs]
+
+ get_name = PythonOperator(task_id="assert-jobs",
python_callable=_assert_jobs, dag=dag)
+
+ # [START howto_operator_delete_job]
+ delete_job1 = CloudBatchDeleteJobOperator(
+ task_id="delete-job1",
+ project_id=PROJECT_ID,
+ region=region,
+ job_name=job1_name,
+ dag=dag,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+ # [END howto_operator_delete_job]
+
+ delete_job2 = CloudBatchDeleteJobOperator(
+ task_id="delete-job2",
+ project_id=PROJECT_ID,
+ region=region,
+ job_name=job2_name,
+ dag=dag,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ ([submit1, submit2] >> list_tasks >> assert_tasks >> list_jobs >> get_name
>> [delete_job1, delete_job2])
+
+ from tests.system.utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)