This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 19ebc8532e6 Create CloudComposerExternalTaskSensor for Cloud Composer service (#57971)
19ebc8532e6 is described below
commit 19ebc8532e6a436b06aa04afa8545a64ded9a6a0
Author: Maksim <[email protected]>
AuthorDate: Mon Nov 24 12:30:34 2025 +0100
Create CloudComposerExternalTaskSensor for Cloud Composer service (#57971)
---
dev/breeze/tests/test_selective_checks.py | 11 +-
.../google/docs/operators/cloud/cloud_composer.rst | 20 +
providers/google/pyproject.toml | 4 +
.../providers/google/cloud/hooks/cloud_composer.py | 74 +++-
.../google/cloud/sensors/cloud_composer.py | 443 ++++++++++++++++++++-
.../google/cloud/triggers/cloud_composer.py | 184 ++++++++-
.../cloud/composer/example_cloud_composer.py | 35 +-
.../unit/google/cloud/hooks/test_cloud_composer.py | 37 ++
.../google/cloud/sensors/test_cloud_composer.py | 173 +++++++-
.../google/cloud/triggers/test_cloud_composer.py | 62 ++-
10 files changed, 1026 insertions(+), 17 deletions(-)
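For readers who want a quick feel for the new sensor before going through the diff, here is a minimal, hedged usage sketch (not part of this commit). It uses only parameters that appear in the diff below; the project, region, environment, and external DAG/task IDs are illustrative placeholders.

from __future__ import annotations

from datetime import datetime, timedelta

from airflow import DAG
from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerExternalTaskSensor

with DAG(
    dag_id="example_wait_for_composer_task",  # placeholder DAG id
    start_date=datetime(2025, 1, 1),
    schedule="@daily",
    catchup=False,
):
    # Wait for a task in a DAG that runs in another Cloud Composer environment.
    wait_for_external_task = CloudComposerExternalTaskSensor(
        task_id="wait_for_external_task",
        project_id="my-project",                   # placeholder GCP project
        region="us-central1",                      # placeholder region
        environment_id="my-composer-environment",  # placeholder Composer environment
        composer_external_dag_id="target_dag",     # placeholder external DAG id
        composer_external_task_id="target_task",   # placeholder external task id
        allowed_states=["success"],
        execution_range=timedelta(days=1),         # look one day back, matching the default
    )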
diff --git a/dev/breeze/tests/test_selective_checks.py
b/dev/breeze/tests/test_selective_checks.py
index f1b8c4e3e4a..a14d7035d8c 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -1852,7 +1852,7 @@ def test_expected_output_push(
"selected-providers-list-as-string": "amazon apache.beam
apache.cassandra apache.kafka "
"cncf.kubernetes common.compat common.sql "
"facebook google hashicorp http microsoft.azure
microsoft.mssql mysql "
- "openlineage oracle postgres presto salesforce samba sftp ssh
trino",
+ "openlineage oracle postgres presto salesforce samba sftp ssh
standard trino",
"all-python-versions":
f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
"all-python-versions-list-as-string":
DEFAULT_PYTHON_MAJOR_MINOR_VERSION,
"ci-image-build": "true",
@@ -1864,7 +1864,7 @@ def test_expected_output_push(
"docs-list-as-string": "apache-airflow helm-chart amazon
apache.beam apache.cassandra "
"apache.kafka cncf.kubernetes common.compat common.sql
facebook google hashicorp http microsoft.azure "
"microsoft.mssql mysql openlineage oracle postgres "
- "presto salesforce samba sftp ssh trino",
+ "presto salesforce samba sftp ssh standard trino",
"skip-prek-hooks": ALL_SKIPPED_COMMITS_IF_NO_UI,
"run-kubernetes-tests": "true",
"upgrade-to-newer-dependencies": "false",
@@ -1874,12 +1874,13 @@ def test_expected_output_push(
"providers-test-types-list-as-strings-in-json": json.dumps(
[
{
- "description": "amazon...google",
+ "description": "amazon...standard",
"test_types": "Providers[amazon]
Providers[apache.beam,apache.cassandra,"
"apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook,"
"hashicorp,http,microsoft.azure,microsoft.mssql,mysql,"
"openlineage,oracle,postgres,presto,salesforce,samba,sftp,ssh,trino] "
- "Providers[google]",
+ "Providers[google] "
+ "Providers[standard]",
}
]
),
@@ -2122,7 +2123,7 @@ def test_upgrade_to_newer_dependencies(
"docs-list-as-string": "amazon apache.beam apache.cassandra
apache.kafka "
"cncf.kubernetes common.compat common.sql facebook google
hashicorp http "
"microsoft.azure microsoft.mssql mysql openlineage oracle "
- "postgres presto salesforce samba sftp ssh trino",
+ "postgres presto salesforce samba sftp ssh standard trino",
},
id="Google provider docs changed",
),
diff --git a/providers/google/docs/operators/cloud/cloud_composer.rst
b/providers/google/docs/operators/cloud/cloud_composer.rst
index 88381b48438..9e18698789f 100644
--- a/providers/google/docs/operators/cloud/cloud_composer.rst
+++ b/providers/google/docs/operators/cloud/cloud_composer.rst
@@ -209,3 +209,23 @@ You can trigger a DAG in another Composer environment, use:
:dedent: 4
:start-after: [START howto_operator_trigger_dag_run]
:end-before: [END howto_operator_trigger_dag_run]
+
+Waits for a different DAG, task group, or task to complete
+----------------------------------------------------------
+
+To wait for a different DAG, task group, or task to complete in a specific Composer environment, use:
+:class:`~airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerExternalTaskSensor`
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/composer/example_cloud_composer.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_external_task]
+ :end-before: [END howto_sensor_external_task]
+
+or you can define the same sensor in deferrable mode:
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/composer/example_cloud_composer.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_external_task_deferrable_mode]
+ :end-before: [END howto_sensor_external_task_deferrable_mode]
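The documentation above covers the single-task and deferrable cases; as a complement, here is a hedged sketch (not part of this commit) of waiting for a whole task group via the composer_external_task_group_id parameter added in this change. The environment and DAG identifiers are placeholders.

from datetime import datetime, timedelta

from airflow import DAG
from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerExternalTaskSensor

with DAG(
    dag_id="example_wait_for_composer_task_group",  # placeholder DAG id
    start_date=datetime(2025, 1, 1),
    schedule="@daily",
    catchup=False,
):
    wait_for_external_task_group = CloudComposerExternalTaskSensor(
        task_id="wait_for_external_task_group",
        project_id="my-project",                      # placeholder
        region="us-central1",                         # placeholder
        environment_id="my-composer-environment",     # placeholder
        composer_external_dag_id="target_dag",        # placeholder external DAG id
        composer_external_task_group_id="etl_group",  # wait for every task whose dotted task_id contains this group
        allowed_states=["success"],
        skipped_states=["skipped"],                   # optionally mark this sensor skipped if the group skips
        execution_range=timedelta(days=1),
    )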
diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml
index 646fb0c26fd..7fe5536d162 100644
--- a/providers/google/pyproject.toml
+++ b/providers/google/pyproject.toml
@@ -204,6 +204,9 @@ dependencies = [
"http" = [
"apache-airflow-providers-http"
]
+"standard" = [
+ "apache-airflow-providers-standard"
+]
[dependency-groups]
dev = [
@@ -228,6 +231,7 @@ dev = [
"apache-airflow-providers-salesforce",
"apache-airflow-providers-sftp",
"apache-airflow-providers-ssh",
+ "apache-airflow-providers-standard",
"apache-airflow-providers-trino",
# Additional devel dependencies (do not remove this line and add extra
development dependencies)
"apache-airflow-providers-apache-kafka",
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
index 423d6cb56df..9549254aaf4 100644
---
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
+++
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -22,7 +22,7 @@ import json
import time
from collections.abc import MutableSequence, Sequence
from typing import TYPE_CHECKING, Any
-from urllib.parse import urljoin
+from urllib.parse import urlencode, urljoin
from aiohttp import ClientSession
from google.api_core.client_options import ClientOptions
@@ -505,6 +505,42 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
return response.json()
+ def get_task_instances(
+ self,
+ composer_airflow_uri: str,
+ composer_dag_id: str,
+ query_parameters: dict | None = None,
+ timeout: float | None = None,
+ ) -> dict:
+ """
+ Get the list of task instances for the provided DAG.
+
+ :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within the Composer environment.
+ :param composer_dag_id: The ID of the DAG.
+ :param query_parameters: Query parameters for this request.
+ :param timeout: The timeout for this request.
+ """
+ query_string = f"?{urlencode(query_parameters)}" if query_parameters
else ""
+
+ response = self.make_composer_airflow_api_request(
+ method="GET",
+ airflow_uri=composer_airflow_uri,
+
path=f"/api/v1/dags/{composer_dag_id}/dagRuns/~/taskInstances{query_string}",
+ timeout=timeout,
+ )
+
+ if response.status_code != 200:
+ self.log.error(
+ "Failed to get task instances for dag_id=%s from %s
(status=%s): %s",
+ composer_dag_id,
+ composer_airflow_uri,
+ response.status_code,
+ response.text,
+ )
+ response.raise_for_status()
+
+ return response.json()
+
class CloudComposerAsyncHook(GoogleBaseAsyncHook):
"""Hook for Google Cloud Composer async APIs."""
@@ -849,3 +885,39 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
raise AirflowException(response_body["title"])
return response_body
+
+ async def get_task_instances(
+ self,
+ composer_airflow_uri: str,
+ composer_dag_id: str,
+ query_parameters: dict | None = None,
+ timeout: float | None = None,
+ ) -> dict:
+ """
+ Get the list of task instances for the provided DAG.
+
+ :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within the Composer environment.
+ :param composer_dag_id: The ID of the DAG.
+ :param query_parameters: Query parameters for this request.
+ :param timeout: The timeout for this request.
+ """
+ query_string = f"?{urlencode(query_parameters)}" if query_parameters
else ""
+
+ response_body, response_status_code = await
self.make_composer_airflow_api_request(
+ method="GET",
+ airflow_uri=composer_airflow_uri,
+
path=f"/api/v1/dags/{composer_dag_id}/dagRuns/~/taskInstances{query_string}",
+ timeout=timeout,
+ )
+
+ if response_status_code != 200:
+ self.log.error(
+ "Failed to get task instances for dag_id=%s from %s
(status=%s): %s",
+ composer_dag_id,
+ composer_airflow_uri,
+ response_status_code,
+ response_body["title"],
+ )
+ raise AirflowException(response_body["title"])
+
+ return response_body
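The new get_task_instances hook method can also be called directly; the following is a hedged sketch (not part of this commit) that mirrors how the sensor below uses it. The Airflow web UI URI, DAG id, and date values are placeholders, and the execution_date_* keys apply to Composer environments running Airflow 2 (the sensor switches to logical_date_* for Airflow 3).

from airflow.providers.google.cloud.hooks.cloud_composer import CloudComposerHook

hook = CloudComposerHook(gcp_conn_id="google_cloud_default")

# Query task instances of an external DAG across all of its DAG runs,
# filtered to a one-day window (placeholder values throughout).
response = hook.get_task_instances(
    composer_airflow_uri="https://<composer-airflow-web-ui-uri>",  # from environment.config.airflow_uri
    composer_dag_id="target_dag",
    query_parameters={
        "execution_date_gte": "2024-05-22T00:00:00Z",
        "execution_date_lte": "2024-05-23T00:00:00Z",
    },
    timeout=60.0,
)
task_instances = response["task_instances"]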
diff --git
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
index 14976293746..5db00f60ce3 100644
---
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
+++
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -20,7 +20,7 @@
from __future__ import annotations
import json
-from collections.abc import Iterable, Sequence
+from collections.abc import Collection, Iterable, Sequence
from datetime import datetime, timedelta
from functools import cached_property
from typing import TYPE_CHECKING
@@ -30,12 +30,21 @@ from google.api_core.exceptions import NotFound
from google.cloud.orchestration.airflow.service_v1.types import Environment,
ExecuteAirflowCommandResponse
from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.common.compat.sdk import BaseSensorOperator
from airflow.providers.google.cloud.hooks.cloud_composer import
CloudComposerHook
-from airflow.providers.google.cloud.triggers.cloud_composer import
CloudComposerDAGRunTrigger
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+ CloudComposerDAGRunTrigger,
+ CloudComposerExternalTaskTrigger,
+)
from airflow.providers.google.common.consts import
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
-from airflow.utils.state import TaskInstanceState
+from airflow.providers.standard.exceptions import (
+ DuplicateStateError,
+ ExternalDagFailedError,
+ ExternalTaskFailedError,
+ ExternalTaskGroupFailedError,
+)
+from airflow.utils.state import State, TaskInstanceState
if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
@@ -286,3 +295,429 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
+
+
+class CloudComposerExternalTaskSensor(BaseSensorOperator):
+ """
+ Waits for a different DAG, task group, or task to complete in a specific Composer environment.
+
+ If both `composer_external_task_group_id` and `composer_external_task_id` are ``None`` (default),
+ the sensor waits for the DAG. Values for `composer_external_task_group_id` and
+ `composer_external_task_id` can't be set at the same time.
+
+ By default, the CloudComposerExternalTaskSensor will wait for the external task to
+ succeed, at which point it will also succeed. However, by default it will
+ *not* fail if the external task fails, but will continue to check the status
+ until the sensor times out (thus giving you time to retry the external task
+ without also having to clear the sensor).
+
+ By default, the CloudComposerExternalTaskSensor will not skip if the external task skips.
+ To change this, simply set ``skipped_states=[TaskInstanceState.SKIPPED]``.
+ Note that if you are monitoring multiple tasks, and one enters an error state
+ while another enters a skipped state, the sensor will react to
+ whichever one it sees first. If both happen together, then the failed state
+ takes priority.
+
+ It is possible to alter the default behavior by setting states which
+ cause the sensor to fail, e.g. by setting ``allowed_states=[DagRunState.FAILED]``
+ and ``failed_states=[DagRunState.SUCCESS]`` you will flip the behavior to
+ get a sensor which goes green when the external task *fails* and immediately
+ goes red if the external task *succeeds*!
+
+ Note that ``soft_fail`` is respected when examining the failed_states. Thus
+ if the external task enters a failed state and ``soft_fail == True`` the
+ sensor will _skip_ rather than fail. As a result, setting ``soft_fail=True``
+ and ``failed_states=[DagRunState.SKIPPED]`` will result in the sensor
+ skipping if the external task skips. However, this is a contrived
+ example; consider using ``skipped_states`` if you would like this
+ behavior. Using ``skipped_states`` allows the sensor to skip if the target
+ fails, but still enter failed state on timeout. Using ``soft_fail == True``
+ as above will cause the sensor to skip if the target fails, but also if it
+ times out.
+
+ :param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
+ :param region: Required. The ID of the Google Cloud region that the
service belongs to.
+ :param environment_id: The name of the Composer environment.
+ :param composer_external_dag_id: The dag_id that contains the task you
want to
+ wait for. (templated)
+ :param composer_external_task_id: The task_id that contains the task you
want to
+ wait for. (templated)
+ :param composer_external_task_ids: The list of task_ids that you want to
wait for. (templated)
+ If ``None`` (default value) the sensor waits for the DAG. Either
+ composer_external_task_id or composer_external_task_ids can be passed
to
+ CloudComposerExternalTaskSensor, but not both.
+ :param composer_external_task_group_id: The task_group_id that contains
the task you want to
+ wait for. (templated)
+ :param allowed_states: Iterable of allowed states, default is
``['success']``
+ :param skipped_states: Iterable of states to make this task mark as
skipped, default is ``None``
+ :param failed_states: Iterable of failed or dis-allowed states, default is
``None``
+ :param execution_range: The DAG run time range to examine. The sensor checks states only for runs
+ started within this range. For yesterday, use [positive!] datetime.timedelta(days=1).
+ For the future, use [negative!] datetime.timedelta(days=-1). For a specific window, pass a list of
+ datetimes, e.g. [datetime(2024,3,22,11,0,0), datetime(2024,3,22,12,0,0)].
+ A single-element list such as [datetime(2024,3,22,0,0,0)] makes the sensor check states from that
+ time up to the current execution time.
+ Default value is datetime.timedelta(days=1).
+ :param gcp_conn_id: The connection ID to use when fetching connection info.
+ :param impersonation_chain: Optional service account to impersonate using
short-term
+ credentials, or chained list of accounts required to get the
access_token
+ of the last account in the list, which will be impersonated in the
request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding
identity, with first
+ account from the list granting this role to the originating account
(templated).
+ :param poll_interval: Optional: Control the rate of the poll for the
result of deferrable run.
+ :param deferrable: Run sensor in deferrable mode.
+ """
+
+ template_fields = (
+ "project_id",
+ "region",
+ "environment_id",
+ "composer_external_dag_id",
+ "composer_external_task_id",
+ "composer_external_task_ids",
+ "composer_external_task_group_id",
+ "impersonation_chain",
+ )
+
+ def __init__(
+ self,
+ *,
+ project_id: str,
+ region: str,
+ environment_id: str,
+ composer_external_dag_id: str,
+ composer_external_task_id: str | None = None,
+ composer_external_task_ids: Collection[str] | None = None,
+ composer_external_task_group_id: str | None = None,
+ allowed_states: Iterable[str] | None = None,
+ skipped_states: Iterable[str] | None = None,
+ failed_states: Iterable[str] | None = None,
+ execution_range: timedelta | list[datetime] | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ poll_interval: int = 10,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.project_id = project_id
+ self.region = region
+ self.environment_id = environment_id
+
+ self.allowed_states = list(allowed_states) if allowed_states else
[TaskInstanceState.SUCCESS.value]
+ self.skipped_states = list(skipped_states) if skipped_states else []
+ self.failed_states = list(failed_states) if failed_states else []
+
+ total_states = set(self.allowed_states + self.skipped_states +
self.failed_states)
+
+ if len(total_states) != len(self.allowed_states) +
len(self.skipped_states) + len(self.failed_states):
+ raise DuplicateStateError(
+ "Duplicate values provided across allowed_states,
skipped_states and failed_states."
+ )
+
+ # convert [] to None
+ if not composer_external_task_ids:
+ composer_external_task_ids = None
+
+ # can't set both single task id and a list of task ids
+ if composer_external_task_id is not None and
composer_external_task_ids is not None:
+ raise ValueError(
+ "Only one of `composer_external_task_id` or
`composer_external_task_ids` may "
+ "be provided to CloudComposerExternalTaskSensor; "
+ "use `composer_external_task_id` or
`composer_external_task_ids` or `composer_external_task_group_id`."
+ )
+
+ # since both not set, convert the single id to a 1-elt list - from
here on, we only consider the list
+ if composer_external_task_id is not None:
+ composer_external_task_ids = [composer_external_task_id]
+
+ if composer_external_task_group_id is not None and
composer_external_task_ids is not None:
+ raise ValueError(
+ "Only one of `composer_external_task_group_id` or
`composer_external_task_ids` may "
+ "be provided to CloudComposerExternalTaskSensor; "
+ "use `composer_external_task_id` or
`composer_external_task_ids` or `composer_external_task_group_id`."
+ )
+
+ # check the requested states are all valid states for the target type,
be it dag or task
+ if composer_external_task_ids or composer_external_task_group_id:
+ if not total_states <= set(State.task_states):
+ raise ValueError(
+ "Valid values for `allowed_states`, `skipped_states` and
`failed_states` "
+ "when `composer_external_task_id` or
`composer_external_task_ids` or `composer_external_task_group_id` "
+ f"is not `None`: {State.task_states}"
+ )
+ elif not total_states <= set(State.dag_states):
+ raise ValueError(
+ "Valid values for `allowed_states`, `skipped_states` and
`failed_states` "
+ f"when `composer_external_task_id` and
`composer_external_task_group_id` is `None`: {State.dag_states}"
+ )
+
+ self.execution_range = execution_range
+ self.composer_external_dag_id = composer_external_dag_id
+ self.composer_external_task_id = composer_external_task_id
+ self.composer_external_task_ids = composer_external_task_ids
+ self.composer_external_task_group_id = composer_external_task_group_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.deferrable = deferrable
+ self.poll_interval = poll_interval
+
+ def _get_logical_dates(self, context) -> tuple[datetime, datetime]:
+ logical_date = context.get("logical_date", None)
+ if logical_date is None:
+ raise RuntimeError(
+ "logical_date is None. Please make sure the sensor is not used
in an asset-triggered Dag. "
+ "CloudComposerDAGRunSensor was designed to be used in
time-based scheduled Dags only, "
+ "and asset-triggered Dags do not have logical_date. "
+ )
+ if isinstance(self.execution_range, timedelta):
+ if self.execution_range < timedelta(0):
+ return logical_date, logical_date - self.execution_range
+ return logical_date - self.execution_range, logical_date
+ if isinstance(self.execution_range, list) and
len(self.execution_range) > 0:
+ return self.execution_range[0], self.execution_range[1] if len(
+ self.execution_range
+ ) > 1 else logical_date
+ return logical_date - timedelta(1), logical_date
+
+ def poke(self, context: Context) -> bool:
+ start_date, end_date = self._get_logical_dates(context)
+
+ task_instances = self._get_task_instances(
+ start_date=start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+ end_date=end_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+ )
+
+ if len(task_instances) == 0:
+ self.log.info("Task Instances are empty. Sensor waits for task
instances...")
+ return False
+
+ if self.failed_states:
+ external_task_status = self._check_task_instances_states(
+ task_instances=task_instances,
+ start_date=start_date,
+ end_date=end_date,
+ states=self.failed_states,
+ )
+ self._handle_failed_states(external_task_status)
+
+ if self.skipped_states:
+ external_task_status = self._check_task_instances_states(
+ task_instances=task_instances,
+ start_date=start_date,
+ end_date=end_date,
+ states=self.skipped_states,
+ )
+ self._handle_skipped_states(external_task_status)
+
+ self.log.info("Sensor waits for allowed states: %s",
self.allowed_states)
+ external_task_status = self._check_task_instances_states(
+ task_instances=task_instances,
+ start_date=start_date,
+ end_date=end_date,
+ states=self.allowed_states,
+ )
+ return external_task_status
+
+ def _get_task_instances(self, start_date: str, end_date: str) ->
list[dict]:
+ """Get the list of task instances."""
+ try:
+ environment = self.hook.get_environment(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ timeout=self.timeout,
+ )
+ except NotFound as not_found_err:
+ self.log.info("The Composer environment %s does not exist.",
self.environment_id)
+ raise AirflowException(not_found_err)
+ composer_airflow_uri = environment.config.airflow_uri
+
+ self.log.info(
+ "Pulling the DAG '%s' task instances from the '%s' environment...",
+ self.composer_external_dag_id,
+ self.environment_id,
+ )
+ task_instances_response = self.hook.get_task_instances(
+ composer_airflow_uri=composer_airflow_uri,
+ composer_dag_id=self.composer_external_dag_id,
+ query_parameters={
+ "execution_date_gte"
+ if self._composer_airflow_version < 3
+ else "logical_date_gte": start_date,
+ "execution_date_lte" if self._composer_airflow_version < 3
else "logical_date_lte": end_date,
+ },
+ timeout=self.timeout,
+ )
+ task_instances = task_instances_response["task_instances"]
+
+ if self.composer_external_task_ids:
+ task_instances = [
+ task_instance
+ for task_instance in task_instances
+ if task_instance["task_id"] in self.composer_external_task_ids
+ ]
+ elif self.composer_external_task_group_id:
+ task_instances = [
+ task_instance
+ for task_instance in task_instances
+ if self.composer_external_task_group_id in
task_instance["task_id"].split(".")
+ ]
+
+ return task_instances
+
+ def _check_task_instances_states(
+ self,
+ task_instances: list[dict],
+ start_date: datetime,
+ end_date: datetime,
+ states: Iterable[str],
+ ) -> bool:
+ for task_instance in task_instances:
+ if (
+ start_date.timestamp()
+ < parser.parse(
+ task_instance["execution_date" if
self._composer_airflow_version < 3 else "logical_date"]
+ ).timestamp()
+ < end_date.timestamp()
+ ) and task_instance["state"] not in states:
+ return False
+ return True
+
+ def _get_composer_airflow_version(self) -> int:
+ """Return Composer Airflow version."""
+ environment_obj = self.hook.get_environment(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ )
+ environment_config = Environment.to_dict(environment_obj)
+ image_version =
environment_config["config"]["software_config"]["image_version"]
+ return int(image_version.split("airflow-")[1].split(".")[0])
+
+ def _handle_failed_states(self, failed_status: bool) -> None:
+ """Handle failed states and raise appropriate exceptions."""
+ if failed_status:
+ if self.composer_external_task_ids:
+ if self.soft_fail:
+ raise AirflowSkipException(
+ f"Some of the external tasks
'{self.composer_external_task_ids}' "
+ f"in DAG '{self.composer_external_dag_id}' failed.
Skipping due to soft_fail."
+ )
+ raise ExternalTaskFailedError(
+ f"Some of the external tasks
'{self.composer_external_task_ids}' "
+ f"in DAG '{self.composer_external_dag_id}' failed."
+ )
+ if self.composer_external_task_group_id:
+ if self.soft_fail:
+ raise AirflowSkipException(
+ f"The external task_group
'{self.composer_external_task_group_id}' "
+ f"in DAG '{self.composer_external_dag_id}' failed.
Skipping due to soft_fail."
+ )
+ raise ExternalTaskGroupFailedError(
+ f"The external task_group
'{self.composer_external_task_group_id}' "
+ f"in DAG '{self.composer_external_dag_id}' failed."
+ )
+ if self.soft_fail:
+ raise AirflowSkipException(
+ f"The external DAG '{self.composer_external_dag_id}'
failed. Skipping due to soft_fail."
+ )
+ raise ExternalDagFailedError(f"The external DAG
'{self.composer_external_dag_id}' failed.")
+
+ def _handle_skipped_states(self, skipped_status: bool) -> None:
+ """Handle skipped states and raise appropriate exceptions."""
+ if skipped_status:
+ if self.composer_external_task_ids:
+ raise AirflowSkipException(
+ f"Some of the external tasks
'{self.composer_external_task_ids}' "
+ f"in DAG '{self.composer_external_dag_id}' reached a state
in our states-to-skip-on list. Skipping."
+ )
+ if self.composer_external_task_group_id:
+ raise AirflowSkipException(
+ f"The external task_group
'{self.composer_external_task_group_id}' "
+ f"in DAG '{self.composer_external_dag_id}' reached a state
in our states-to-skip-on list. Skipping."
+ )
+ raise AirflowSkipException(
+ f"The external DAG '{self.composer_external_dag_id}' reached a
state in our states-to-skip-on list. "
+ "Skipping."
+ )
+
+ def execute(self, context: Context) -> None:
+ self._composer_airflow_version = self._get_composer_airflow_version()
+
+ if self.composer_external_task_ids and
len(self.composer_external_task_ids) > len(
+ set(self.composer_external_task_ids)
+ ):
+ raise ValueError("Duplicate task_ids passed in
composer_external_task_ids parameter")
+
+ if self.composer_external_task_ids:
+ self.log.info(
+ "Poking for tasks '%s' in dag '%s' on Composer environment
'%s' ... ",
+ self.composer_external_task_ids,
+ self.composer_external_dag_id,
+ self.environment_id,
+ )
+
+ if self.composer_external_task_group_id:
+ self.log.info(
+ "Poking for task_group '%s' in dag '%s' on Composer
environment '%s' ... ",
+ self.composer_external_task_group_id,
+ self.composer_external_dag_id,
+ self.environment_id,
+ )
+
+ if (
+ self.composer_external_dag_id
+ and not self.composer_external_task_group_id
+ and not self.composer_external_task_ids
+ ):
+ self.log.info(
+ "Poking for DAG '%s' on Composer environment '%s' ... ",
+ self.composer_external_dag_id,
+ self.environment_id,
+ )
+
+ if self.deferrable:
+ start_date, end_date = self._get_logical_dates(context)
+ self.defer(
+ timeout=timedelta(seconds=self.timeout) if self.timeout else
None,
+ trigger=CloudComposerExternalTaskTrigger(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ composer_external_dag_id=self.composer_external_dag_id,
+ composer_external_task_ids=self.composer_external_task_ids,
+
composer_external_task_group_id=self.composer_external_task_group_id,
+ start_date=start_date,
+ end_date=end_date,
+ allowed_states=self.allowed_states,
+ skipped_states=self.skipped_states,
+ failed_states=self.failed_states,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ poll_interval=self.poll_interval,
+ composer_airflow_version=self._composer_airflow_version,
+ ),
+ method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+ )
+ super().execute(context)
+
+ def execute_complete(self, context: Context, event: dict):
+ if event and event["status"] == "error":
+ raise AirflowException(event["message"])
+ if event and event["status"] == "failed":
+ self._handle_failed_states(True)
+ elif event and event["status"] == "skipped":
+ self._handle_skipped_states(True)
+
+ self.log.info("External tasks for DAG '%s' has executed
successfully.", self.composer_external_dag_id)
+
+ @cached_property
+ def hook(self) -> CloudComposerHook:
+ return CloudComposerHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
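As a usage note for the docstring above, here is a hedged sketch (not part of this commit) of the "flipped" configuration it describes, where the sensor succeeds when the external DAG fails and fails if it succeeds. The identifiers are placeholders, and the sensor is assumed to be declared inside a DAG as in the system test further below.

from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerExternalTaskSensor

# No task id or task group id is given, so the sensor waits at DAG level and
# the states must be valid DAG-run states ("failed" and "success" both are).
alert_on_external_failure = CloudComposerExternalTaskSensor(
    task_id="alert_on_external_failure",
    project_id="my-project",                   # placeholder
    region="us-central1",                      # placeholder
    environment_id="my-composer-environment",  # placeholder
    composer_external_dag_id="target_dag",     # placeholder external DAG id
    allowed_states=["failed"],                 # go green when the external DAG fails
    failed_states=["success"],                 # go red immediately if the external DAG succeeds
)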
diff --git
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
index 006480e5da5..bf60acdf93e 100644
---
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
+++
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio
import json
-from collections.abc import Sequence
+from collections.abc import Collection, Iterable, Sequence
from datetime import datetime
from typing import Any
@@ -346,3 +346,185 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
}
)
return
+
+
+class CloudComposerExternalTaskTrigger(BaseTrigger):
+ """The trigger wait for the external task completion."""
+
+ def __init__(
+ self,
+ project_id: str,
+ region: str,
+ environment_id: str,
+ start_date: datetime,
+ end_date: datetime,
+ allowed_states: list[str],
+ skipped_states: list[str],
+ failed_states: list[str],
+ composer_external_dag_id: str,
+ composer_external_task_ids: Collection[str] | None = None,
+ composer_external_task_group_id: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ impersonation_chain: str | Sequence[str] | None = None,
+ poll_interval: int = 10,
+ composer_airflow_version: int = 2,
+ ):
+ super().__init__()
+ self.project_id = project_id
+ self.region = region
+ self.environment_id = environment_id
+ self.start_date = start_date
+ self.end_date = end_date
+ self.allowed_states = allowed_states
+ self.skipped_states = skipped_states
+ self.failed_states = failed_states
+ self.composer_external_dag_id = composer_external_dag_id
+ self.composer_external_task_ids = composer_external_task_ids
+ self.composer_external_task_group_id = composer_external_task_group_id
+ self.gcp_conn_id = gcp_conn_id
+ self.impersonation_chain = impersonation_chain
+ self.poll_interval = poll_interval
+ self.composer_airflow_version = composer_airflow_version
+
+ def serialize(self) -> tuple[str, dict[str, Any]]:
+ return (
+
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExternalTaskTrigger",
+ {
+ "project_id": self.project_id,
+ "region": self.region,
+ "environment_id": self.environment_id,
+ "start_date": self.start_date,
+ "end_date": self.end_date,
+ "allowed_states": self.allowed_states,
+ "skipped_states": self.skipped_states,
+ "failed_states": self.failed_states,
+ "composer_external_dag_id": self.composer_external_dag_id,
+ "composer_external_task_ids": self.composer_external_task_ids,
+ "composer_external_task_group_id":
self.composer_external_task_group_id,
+ "gcp_conn_id": self.gcp_conn_id,
+ "impersonation_chain": self.impersonation_chain,
+ "poll_interval": self.poll_interval,
+ "composer_airflow_version": self.composer_airflow_version,
+ },
+ )
+
+ async def _get_task_instances(self, start_date: str, end_date: str) ->
list[dict]:
+ """Get the list of task instances."""
+ try:
+ environment = await self.gcp_hook.get_environment(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ )
+ except NotFound as not_found_err:
+ self.log.info("The Composer environment %s does not exist.",
self.environment_id)
+ raise AirflowException(not_found_err)
+ composer_airflow_uri = environment.config.airflow_uri
+
+ self.log.info(
+ "Pulling the DAG '%s' task instances from the '%s' environment...",
+ self.composer_external_dag_id,
+ self.environment_id,
+ )
+ task_instances_response = await self.gcp_hook.get_task_instances(
+ composer_airflow_uri=composer_airflow_uri,
+ composer_dag_id=self.composer_external_dag_id,
+ query_parameters={
+ "execution_date_gte" if self.composer_airflow_version < 3 else
"logical_date_gte": start_date,
+ "execution_date_lte" if self.composer_airflow_version < 3 else
"logical_date_lte": end_date,
+ },
+ )
+ task_instances = task_instances_response["task_instances"]
+
+ if self.composer_external_task_ids:
+ task_instances = [
+ task_instance
+ for task_instance in task_instances
+ if task_instance["task_id"] in self.composer_external_task_ids
+ ]
+ elif self.composer_external_task_group_id:
+ task_instances = [
+ task_instance
+ for task_instance in task_instances
+ if self.composer_external_task_group_id in
task_instance["task_id"].split(".")
+ ]
+
+ return task_instances
+
+ def _check_task_instances_states(
+ self,
+ task_instances: list[dict],
+ start_date: datetime,
+ end_date: datetime,
+ states: Iterable[str],
+ ) -> bool:
+ for task_instance in task_instances:
+ if (
+ start_date.timestamp()
+ < parser.parse(
+ task_instance["execution_date" if
self.composer_airflow_version < 3 else "logical_date"]
+ ).timestamp()
+ < end_date.timestamp()
+ ) and task_instance["state"] not in states:
+ return False
+ return True
+
+ def _get_async_hook(self) -> CloudComposerAsyncHook:
+ return CloudComposerAsyncHook(
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+ )
+
+ async def run(self):
+ self.gcp_hook: CloudComposerAsyncHook = self._get_async_hook()
+ try:
+ while True:
+ task_instances = await self._get_task_instances(
+ start_date=self.start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+ end_date=self.end_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+ )
+
+ if len(task_instances) == 0:
+ self.log.info("Task Instances are empty. Sensor waits for
task instances...")
+ self.log.info("Sleeping for %s seconds.",
self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ continue
+
+ if self.failed_states and self._check_task_instances_states(
+ task_instances=task_instances,
+ start_date=self.start_date,
+ end_date=self.end_date,
+ states=self.failed_states,
+ ):
+ yield TriggerEvent({"status": "failed"})
+ return
+
+ if self.skipped_states and self._check_task_instances_states(
+ task_instances=task_instances,
+ start_date=self.start_date,
+ end_date=self.end_date,
+ states=self.skipped_states,
+ ):
+ yield TriggerEvent({"status": "skipped"})
+ return
+
+ self.log.info("Sensor waits for allowed states: %s",
self.allowed_states)
+ if self._check_task_instances_states(
+ task_instances=task_instances,
+ start_date=self.start_date,
+ end_date=self.end_date,
+ states=self.allowed_states,
+ ):
+ yield TriggerEvent({"status": "success"})
+ return
+
+ self.log.info("Sleeping for %s seconds.", self.poll_interval)
+ await asyncio.sleep(self.poll_interval)
+ except AirflowException as ex:
+ yield TriggerEvent(
+ {
+ "status": "error",
+ "message": str(ex),
+ }
+ )
+ return
diff --git
a/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py
b/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py
index 21e635a3853..a5105dc3eca 100644
---
a/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py
+++
b/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py
@@ -18,7 +18,7 @@
from __future__ import annotations
import os
-from datetime import datetime
+from datetime import datetime, timedelta
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
@@ -43,7 +43,10 @@ from airflow.providers.google.cloud.operators.cloud_composer
import (
CloudComposerTriggerDAGRunOperator,
CloudComposerUpdateEnvironmentOperator,
)
-from airflow.providers.google.cloud.sensors.cloud_composer import
CloudComposerDAGRunSensor
+from airflow.providers.google.cloud.sensors.cloud_composer import (
+ CloudComposerDAGRunSensor,
+ CloudComposerExternalTaskSensor,
+)
try:
from airflow.sdk import TriggerRule
@@ -229,6 +232,33 @@ with DAG(
)
# [END howto_operator_trigger_dag_run]
+ # [START howto_sensor_external_task]
+ external_task_sensor = CloudComposerExternalTaskSensor(
+ task_id="external_task_sensor",
+ project_id=PROJECT_ID,
+ region=REGION,
+ environment_id=ENVIRONMENT_ID,
+ composer_external_dag_id="airflow_monitoring",
+ composer_external_task_id="echo",
+ allowed_states=["success"],
+ execution_range=[datetime.now() - timedelta(1), datetime.now()],
+ )
+ # [END howto_sensor_external_task]
+
+ # [START howto_sensor_external_task_deferrable_mode]
+ defer_external_task_sensor = CloudComposerExternalTaskSensor(
+ task_id="defer_external_task_sensor",
+ project_id=PROJECT_ID,
+ region=REGION,
+ environment_id=ENVIRONMENT_ID_ASYNC,
+ composer_external_dag_id="airflow_monitoring",
+ composer_external_task_id="echo",
+ allowed_states=["success"],
+ execution_range=[datetime.now() - timedelta(1), datetime.now()],
+ deferrable=True,
+ )
+ # [END howto_sensor_external_task_deferrable_mode]
+
# [START howto_operator_delete_composer_environment]
delete_env = CloudComposerDeleteEnvironmentOperator(
task_id="delete_env",
@@ -262,6 +292,7 @@ with DAG(
[run_airflow_cli_cmd, defer_run_airflow_cli_cmd],
[dag_run_sensor, defer_dag_run_sensor],
trigger_dag_run,
+ [external_task_sensor, defer_external_task_sensor],
# TEST TEARDOWN
[delete_env, defer_delete_env],
)
diff --git
a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
index e577f9bd086..680fe791db6 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
@@ -298,6 +298,24 @@ class TestCloudComposerHook:
timeout=TEST_TIMEOUT,
)
+ @pytest.mark.parametrize("query_parameters", [None, {"test_key":
"test_value"}])
+
@mock.patch(COMPOSER_STRING.format("CloudComposerHook.make_composer_airflow_api_request"))
+ def test_get_task_instances(self, mock_composer_airflow_api_request,
query_parameters) -> None:
+ query_string = "?test_key=test_value" if query_parameters else ""
+ self.hook.get_credentials = mock.MagicMock()
+ self.hook.get_task_instances(
+ composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+ composer_dag_id=TEST_COMPOSER_DAG_ID,
+ query_parameters=query_parameters,
+ timeout=TEST_TIMEOUT,
+ )
+ mock_composer_airflow_api_request.assert_called_once_with(
+ method="GET",
+ airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+
path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns/~/taskInstances{query_string}",
+ timeout=TEST_TIMEOUT,
+ )
+
class TestCloudComposerAsyncHook:
def setup_method(self, method):
@@ -484,3 +502,22 @@ class TestCloudComposerAsyncHook:
path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns",
timeout=TEST_TIMEOUT,
)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("query_parameters", [None, {"test_key":
"test_value"}])
+
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.make_composer_airflow_api_request"))
+ async def test_get_task_instances(self, mock_composer_airflow_api_request,
query_parameters) -> None:
+ query_string = "?test_key=test_value" if query_parameters else ""
+ mock_composer_airflow_api_request.return_value = ({}, 200)
+ await self.hook.get_task_instances(
+ composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+ composer_dag_id=TEST_COMPOSER_DAG_ID,
+ query_parameters=query_parameters,
+ timeout=TEST_TIMEOUT,
+ )
+ mock_composer_airflow_api_request.assert_called_once_with(
+ method="GET",
+ airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+
path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns/~/taskInstances{query_string}",
+ timeout=TEST_TIMEOUT,
+ )
diff --git
a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
index 6988508b056..7f639a59685 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
@@ -23,7 +23,10 @@ from unittest import mock
import pytest
-from airflow.providers.google.cloud.sensors.cloud_composer import
CloudComposerDAGRunSensor
+from airflow.providers.google.cloud.sensors.cloud_composer import (
+ CloudComposerDAGRunSensor,
+ CloudComposerExternalTaskSensor,
+)
TEST_PROJECT_ID = "test_project_id"
TEST_OPERATION_NAME = "test_operation_name"
@@ -49,6 +52,22 @@ TEST_GET_RESULT = lambda state, date_key: {
"dag_runs": TEST_DAG_RUNS_RESULT(state, date_key, "dag_run_id"),
"total_entries": 1,
}
+TEST_COMPOSER_EXTERNAL_TASK_ID = "test_external_task_id"
+TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID = "test_external_task_group_id"
+TEST_TASK_INSTANCES_RESULT = lambda state, date_key, task_id: [
+ {
+ "task_id": task_id,
+ "dag_id": "test_dag_id",
+ "state": state,
+ date_key: "2024-05-22T11:10:00+00:00",
+ "start_date": "2024-05-22T11:20:01.531988+00:00",
+ "end_date": "2024-05-22T11:20:11.997479+00:00",
+ }
+]
+TEST_GET_TASK_INSTANCES_RESULT = lambda state, date_key, task_id: {
+ "task_instances": TEST_TASK_INSTANCES_RESULT(state, date_key, task_id),
+ "total_entries": 1,
+}
class TestCloudComposerDAGRunSensor:
@@ -185,3 +204,155 @@ class TestCloudComposerDAGRunSensor:
task._composer_airflow_version = composer_airflow_version
assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
+
+
+class TestCloudComposerExternalTaskSensor:
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_wait_ready(self, mock_hook, composer_airflow_version):
+ mock_hook.return_value.get_task_instances.return_value =
TEST_GET_TASK_INSTANCES_RESULT(
+ "success",
+ "execution_date" if composer_airflow_version < 3 else
"logical_date",
+ TEST_COMPOSER_EXTERNAL_TASK_ID,
+ )
+
+ task = CloudComposerExternalTaskSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_external_dag_id="test_dag_id",
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0,
0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_wait_not_ready(self, mock_hook, composer_airflow_version):
+ mock_hook.return_value.get_task_instances.return_value =
TEST_GET_TASK_INSTANCES_RESULT(
+ "running",
+ "execution_date" if composer_airflow_version < 3 else
"logical_date",
+ TEST_COMPOSER_EXTERNAL_TASK_ID,
+ )
+
+ task = CloudComposerExternalTaskSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_external_dag_id="test_dag_id",
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_task_instances_empty(self, mock_hook, composer_airflow_version):
+ mock_hook.return_value.get_task_instances.return_value = {
+ "task_instances": [],
+ "total_entries": 0,
+ }
+
+ task = CloudComposerExternalTaskSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_external_dag_id="test_dag_id",
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_composer_external_task_id_wait_ready(self, mock_hook,
composer_airflow_version):
+ mock_hook.return_value.get_task_instances.return_value =
TEST_GET_TASK_INSTANCES_RESULT(
+ "success",
+ "execution_date" if composer_airflow_version < 3 else
"logical_date",
+ TEST_COMPOSER_EXTERNAL_TASK_ID,
+ )
+
+ task = CloudComposerExternalTaskSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_external_dag_id="test_dag_id",
+ composer_external_task_id=TEST_COMPOSER_EXTERNAL_TASK_ID,
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0,
0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_composer_external_task_id_wait_not_ready(self, mock_hook,
composer_airflow_version):
+ mock_hook.return_value.get_task_instances.return_value =
TEST_GET_TASK_INSTANCES_RESULT(
+ "running",
+ "execution_date" if composer_airflow_version < 3 else
"logical_date",
+ TEST_COMPOSER_EXTERNAL_TASK_ID,
+ )
+
+ task = CloudComposerExternalTaskSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_external_dag_id="test_dag_id",
+ composer_external_task_id=TEST_COMPOSER_EXTERNAL_TASK_ID,
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_composer_external_task_group_id_wait_ready(self, mock_hook,
composer_airflow_version):
+ mock_hook.return_value.get_task_instances.return_value =
TEST_GET_TASK_INSTANCES_RESULT(
+ "success",
+ "execution_date" if composer_airflow_version < 3 else
"logical_date",
+
f"{TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID}.{TEST_COMPOSER_EXTERNAL_TASK_ID}",
+ )
+
+ task = CloudComposerExternalTaskSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_external_dag_id="test_dag_id",
+
composer_external_task_group_id=TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID,
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0,
0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_composer_external_task_group_id_wait_not_ready(self, mock_hook,
composer_airflow_version):
+ mock_hook.return_value.get_task_instances.return_value =
TEST_GET_TASK_INSTANCES_RESULT(
+ "running",
+ "execution_date" if composer_airflow_version < 3 else
"logical_date",
+
f"{TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID}.{TEST_COMPOSER_EXTERNAL_TASK_ID}",
+ )
+
+ task = CloudComposerExternalTaskSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_external_dag_id="test_dag_id",
+
composer_external_task_group_id=TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID,
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
diff --git
a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
index f093a76d1d7..1d6307f3814 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
@@ -26,6 +26,7 @@ from airflow.models import Connection
from airflow.providers.google.cloud.triggers.cloud_composer import (
CloudComposerAirflowCLICommandTrigger,
CloudComposerDAGRunTrigger,
+ CloudComposerExternalTaskTrigger,
)
from airflow.triggers.base import TriggerEvent
@@ -40,9 +41,13 @@ TEST_EXEC_CMD_INFO = {
}
TEST_COMPOSER_DAG_ID = "test_dag_id"
TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
+TEST_COMPOSER_EXTERNAL_TASK_IDS = ["test_external_task_id"]
+TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID = "test_external_task_group_id"
TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
-TEST_STATES = ["success"]
+TEST_ALLOWED_STATES = ["success"]
+TEST_SKIPPED_STATES = ["skipped"]
+TEST_FAILED_STATES = ["failed"]
TEST_GCP_CONN_ID = "test_gcp_conn_id"
TEST_POLL_INTERVAL = 10
TEST_COMPOSER_AIRFLOW_VERSION = 3
@@ -86,7 +91,7 @@ def dag_run_trigger(mock_conn):
composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
- allowed_states=TEST_STATES,
+ allowed_states=TEST_ALLOWED_STATES,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
poll_interval=TEST_POLL_INTERVAL,
@@ -95,6 +100,31 @@ def dag_run_trigger(mock_conn):
)
[email protected]
[email protected](
+
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+ return_value=Connection(conn_id="test_conn"),
+)
+def external_task_trigger(mock_conn):
+ return CloudComposerExternalTaskTrigger(
+ project_id=TEST_PROJECT_ID,
+ region=TEST_LOCATION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ start_date=TEST_START_DATE,
+ end_date=TEST_END_DATE,
+ allowed_states=TEST_ALLOWED_STATES,
+ skipped_states=TEST_SKIPPED_STATES,
+ failed_states=TEST_FAILED_STATES,
+ composer_external_dag_id=TEST_COMPOSER_DAG_ID,
+ composer_external_task_ids=TEST_COMPOSER_EXTERNAL_TASK_IDS,
+ composer_external_task_group_id=TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID,
+ gcp_conn_id=TEST_GCP_CONN_ID,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ poll_interval=TEST_POLL_INTERVAL,
+ composer_airflow_version=TEST_COMPOSER_AIRFLOW_VERSION,
+ )
+
+
class TestCloudComposerAirflowCLICommandTrigger:
def test_serialize(self, cli_command_trigger):
actual_data = cli_command_trigger.serialize()
@@ -143,7 +173,7 @@ class TestCloudComposerDAGRunTrigger:
"composer_dag_run_id": TEST_COMPOSER_DAG_RUN_ID,
"start_date": TEST_START_DATE,
"end_date": TEST_END_DATE,
- "allowed_states": TEST_STATES,
+ "allowed_states": TEST_ALLOWED_STATES,
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": TEST_IMPERSONATION_CHAIN,
"poll_interval": TEST_POLL_INTERVAL,
@@ -152,3 +182,29 @@ class TestCloudComposerDAGRunTrigger:
},
)
assert actual_data == expected_data
+
+
+class TestCloudComposerExternalTaskTrigger:
+ def test_serialize(self, external_task_trigger):
+ actual_data = external_task_trigger.serialize()
+ expected_data = (
+
"airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExternalTaskTrigger",
+ {
+ "project_id": TEST_PROJECT_ID,
+ "region": TEST_LOCATION,
+ "environment_id": TEST_ENVIRONMENT_ID,
+ "start_date": TEST_START_DATE,
+ "end_date": TEST_END_DATE,
+ "allowed_states": TEST_ALLOWED_STATES,
+ "skipped_states": TEST_SKIPPED_STATES,
+ "failed_states": TEST_FAILED_STATES,
+ "composer_external_dag_id": TEST_COMPOSER_DAG_ID,
+ "composer_external_task_ids": TEST_COMPOSER_EXTERNAL_TASK_IDS,
+ "composer_external_task_group_id":
TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID,
+ "gcp_conn_id": TEST_GCP_CONN_ID,
+ "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+ "poll_interval": TEST_POLL_INTERVAL,
+ "composer_airflow_version": TEST_COMPOSER_AIRFLOW_VERSION,
+ },
+ )
+ assert actual_data == expected_data