This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 19ebc8532e6 Create CloudComposerExternalTaskSensor for Cloud Composer 
service (#57971)
19ebc8532e6 is described below

commit 19ebc8532e6a436b06aa04afa8545a64ded9a6a0
Author: Maksim <[email protected]>
AuthorDate: Mon Nov 24 12:30:34 2025 +0100

    Create CloudComposerExternalTaskSensor for Cloud Composer service (#57971)
---
 dev/breeze/tests/test_selective_checks.py          |  11 +-
 .../google/docs/operators/cloud/cloud_composer.rst |  20 +
 providers/google/pyproject.toml                    |   4 +
 .../providers/google/cloud/hooks/cloud_composer.py |  74 +++-
 .../google/cloud/sensors/cloud_composer.py         | 443 ++++++++++++++++++++-
 .../google/cloud/triggers/cloud_composer.py        | 184 ++++++++-
 .../cloud/composer/example_cloud_composer.py       |  35 +-
 .../unit/google/cloud/hooks/test_cloud_composer.py |  37 ++
 .../google/cloud/sensors/test_cloud_composer.py    | 173 +++++++-
 .../google/cloud/triggers/test_cloud_composer.py   |  62 ++-
 10 files changed, 1026 insertions(+), 17 deletions(-)

diff --git a/dev/breeze/tests/test_selective_checks.py 
b/dev/breeze/tests/test_selective_checks.py
index f1b8c4e3e4a..a14d7035d8c 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -1852,7 +1852,7 @@ def test_expected_output_push(
                 "selected-providers-list-as-string": "amazon apache.beam 
apache.cassandra apache.kafka "
                 "cncf.kubernetes common.compat common.sql "
                 "facebook google hashicorp http microsoft.azure 
microsoft.mssql mysql "
-                "openlineage oracle postgres presto salesforce samba sftp ssh 
trino",
+                "openlineage oracle postgres presto salesforce samba sftp ssh 
standard trino",
                 "all-python-versions": 
f"['{DEFAULT_PYTHON_MAJOR_MINOR_VERSION}']",
                 "all-python-versions-list-as-string": 
DEFAULT_PYTHON_MAJOR_MINOR_VERSION,
                 "ci-image-build": "true",
@@ -1864,7 +1864,7 @@ def test_expected_output_push(
                 "docs-list-as-string": "apache-airflow helm-chart amazon 
apache.beam apache.cassandra "
                 "apache.kafka cncf.kubernetes common.compat common.sql 
facebook google hashicorp http microsoft.azure "
                 "microsoft.mssql mysql openlineage oracle postgres "
-                "presto salesforce samba sftp ssh trino",
+                "presto salesforce samba sftp ssh standard trino",
                 "skip-prek-hooks": ALL_SKIPPED_COMMITS_IF_NO_UI,
                 "run-kubernetes-tests": "true",
                 "upgrade-to-newer-dependencies": "false",
@@ -1874,12 +1874,13 @@ def test_expected_output_push(
                 "providers-test-types-list-as-strings-in-json": json.dumps(
                     [
                         {
-                            "description": "amazon...google",
+                            "description": "amazon...standard",
                             "test_types": "Providers[amazon] 
Providers[apache.beam,apache.cassandra,"
                             
"apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook,"
                             
"hashicorp,http,microsoft.azure,microsoft.mssql,mysql,"
                             
"openlineage,oracle,postgres,presto,salesforce,samba,sftp,ssh,trino] "
-                            "Providers[google]",
+                            "Providers[google] "
+                            "Providers[standard]",
                         }
                     ]
                 ),
@@ -2122,7 +2123,7 @@ def test_upgrade_to_newer_dependencies(
                 "docs-list-as-string": "amazon apache.beam apache.cassandra 
apache.kafka "
                 "cncf.kubernetes common.compat common.sql facebook google 
hashicorp http "
                 "microsoft.azure microsoft.mssql mysql openlineage oracle "
-                "postgres presto salesforce samba sftp ssh trino",
+                "postgres presto salesforce samba sftp ssh standard trino",
             },
             id="Google provider docs changed",
         ),
diff --git a/providers/google/docs/operators/cloud/cloud_composer.rst 
b/providers/google/docs/operators/cloud/cloud_composer.rst
index 88381b48438..9e18698789f 100644
--- a/providers/google/docs/operators/cloud/cloud_composer.rst
+++ b/providers/google/docs/operators/cloud/cloud_composer.rst
@@ -209,3 +209,23 @@ You can trigger a DAG in another Composer environment, use:
     :dedent: 4
     :start-after: [START howto_operator_trigger_dag_run]
     :end-before: [END howto_operator_trigger_dag_run]
+
+Waits for a different DAG, task group, or task to complete
+----------------------------------------------------------
+
+You can use a sensor that waits for a different DAG, task group, or task to complete in a specific Composer environment:
+:class:`~airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerExternalTaskSensor`
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_external_task]
+    :end-before: [END howto_sensor_external_task]
+
+Or you can run the same sensor in deferrable mode:
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/composer/example_cloud_composer.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_external_task_deferrable_mode]
+    :end-before: [END howto_sensor_external_task_deferrable_mode]
diff --git a/providers/google/pyproject.toml b/providers/google/pyproject.toml
index 646fb0c26fd..7fe5536d162 100644
--- a/providers/google/pyproject.toml
+++ b/providers/google/pyproject.toml
@@ -204,6 +204,9 @@ dependencies = [
 "http" = [
     "apache-airflow-providers-http"
 ]
+"standard" = [
+    "apache-airflow-providers-standard"
+]
 
 [dependency-groups]
 dev = [
@@ -228,6 +231,7 @@ dev = [
     "apache-airflow-providers-salesforce",
     "apache-airflow-providers-sftp",
     "apache-airflow-providers-ssh",
+    "apache-airflow-providers-standard",
     "apache-airflow-providers-trino",
     # Additional devel dependencies (do not remove this line and add extra 
development dependencies)
     "apache-airflow-providers-apache-kafka",
diff --git 
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py 
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
index 423d6cb56df..9549254aaf4 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -22,7 +22,7 @@ import json
 import time
 from collections.abc import MutableSequence, Sequence
 from typing import TYPE_CHECKING, Any
-from urllib.parse import urljoin
+from urllib.parse import urlencode, urljoin
 
 from aiohttp import ClientSession
 from google.api_core.client_options import ClientOptions
@@ -505,6 +505,42 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
 
         return response.json()
 
+    def get_task_instances(
+        self,
+        composer_airflow_uri: str,
+        composer_dag_id: str,
+        query_parameters: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        """
+        Get the list of task instances for provided DAG.
+
+        The request goes to ``/api/v1/dags/<composer_dag_id>/dagRuns/~/taskInstances``; ``~`` is the
+        Airflow REST API wildcard, so task instances of all DAG runs are requested.
+
+        :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within Composer environment.
+        :param composer_dag_id: The ID of DAG.
+        :param query_parameters: Query parameters for this request. They are URL-encoded into the query
+            string; no query string is appended when this is ``None`` or empty.
+        :param timeout: The timeout for this request.
+        :return: The decoded JSON body of the response.
+        :raises: Whatever ``response.raise_for_status()`` raises for an error status. Any non-200
+            response is logged (status and body) before that call.
+        """
+        query_string = f"?{urlencode(query_parameters)}" if query_parameters else ""
+
+        response = self.make_composer_airflow_api_request(
+            method="GET",
+            airflow_uri=composer_airflow_uri,
+            path=f"/api/v1/dags/{composer_dag_id}/dagRuns/~/taskInstances{query_string}",
+            timeout=timeout,
+        )
+
+        if response.status_code != 200:
+            # Log the response body first so the API's error details end up in the task log.
+            self.log.error(
+                "Failed to get task instances for dag_id=%s from %s (status=%s): %s",
+                composer_dag_id,
+                composer_airflow_uri,
+                response.status_code,
+                response.text,
+            )
+            response.raise_for_status()
+
+        return response.json()
+
 
 class CloudComposerAsyncHook(GoogleBaseAsyncHook):
     """Hook for Google Cloud Composer async APIs."""
@@ -849,3 +885,39 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
             raise AirflowException(response_body["title"])
 
         return response_body
+
+    async def get_task_instances(
+        self,
+        composer_airflow_uri: str,
+        composer_dag_id: str,
+        query_parameters: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        """
+        Get the list of task instances for provided DAG.
+
+        The request goes to ``/api/v1/dags/<composer_dag_id>/dagRuns/~/taskInstances``; ``~`` is the
+        Airflow REST API wildcard, so task instances of all DAG runs are requested.
+
+        :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within Composer environment.
+        :param composer_dag_id: The ID of DAG.
+        :param query_parameters: Query parameters for this request. They are URL-encoded into the query
+            string; no query string is appended when this is ``None`` or empty.
+        :param timeout: The timeout for this request.
+        :return: The decoded JSON body of the response.
+        :raises AirflowException: If the response status is not 200; the message is the ``title``
+            field of the error body.
+        """
+        query_string = f"?{urlencode(query_parameters)}" if query_parameters else ""
+
+        response_body, response_status_code = await self.make_composer_airflow_api_request(
+            method="GET",
+            airflow_uri=composer_airflow_uri,
+            path=f"/api/v1/dags/{composer_dag_id}/dagRuns/~/taskInstances{query_string}",
+            timeout=timeout,
+        )
+
+        if response_status_code != 200:
+            self.log.error(
+                "Failed to get task instances for dag_id=%s from %s (status=%s): %s",
+                composer_dag_id,
+                composer_airflow_uri,
+                response_status_code,
+                response_body["title"],
+            )
+            raise AirflowException(response_body["title"])
+
+        return response_body
diff --git 
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py 
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
index 14976293746..5db00f60ce3 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -20,7 +20,7 @@
 from __future__ import annotations
 
 import json
-from collections.abc import Iterable, Sequence
+from collections.abc import Collection, Iterable, Sequence
 from datetime import datetime, timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING
@@ -30,12 +30,21 @@ from google.api_core.exceptions import NotFound
 from google.cloud.orchestration.airflow.service_v1.types import Environment, 
ExecuteAirflowCommandResponse
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.common.compat.sdk import BaseSensorOperator
 from airflow.providers.google.cloud.hooks.cloud_composer import 
CloudComposerHook
-from airflow.providers.google.cloud.triggers.cloud_composer import 
CloudComposerDAGRunTrigger
+from airflow.providers.google.cloud.triggers.cloud_composer import (
+    CloudComposerDAGRunTrigger,
+    CloudComposerExternalTaskTrigger,
+)
 from airflow.providers.google.common.consts import 
GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
-from airflow.utils.state import TaskInstanceState
+from airflow.providers.standard.exceptions import (
+    DuplicateStateError,
+    ExternalDagFailedError,
+    ExternalTaskFailedError,
+    ExternalTaskGroupFailedError,
+)
+from airflow.utils.state import State, TaskInstanceState
 
 if TYPE_CHECKING:
     from airflow.providers.common.compat.sdk import Context
@@ -286,3 +295,429 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
         )
+
+
+class CloudComposerExternalTaskSensor(BaseSensorOperator):
+    """
+    Waits for a different DAG, task group, or task to complete for a specific composer environment.
+
+    If both `composer_external_task_group_id` and `composer_external_task_id` are ``None`` (default), the sensor
+    waits for the DAG; in that case every task instance of the external DAG in the execution range is checked.
+    Values for `composer_external_task_group_id` and `composer_external_task_id` can't be set at the same time.
+
+    By default, the CloudComposerExternalTaskSensor will wait for the external task to
+    succeed, at which point it will also succeed. However, by default it will
+    *not* fail if the external task fails, but will continue to check the status
+    until the sensor times out (thus giving you time to retry the external task
+    without also having to clear the sensor).
+
+    By default, the CloudComposerExternalTaskSensor will not skip if the external task skips.
+    To change this, simply set ``skipped_states=[TaskInstanceState.SKIPPED]``.
+
+    Each state list is evaluated against *all* monitored task instances whose logical date
+    (execution date on Airflow 2) falls within the execution range, bounds inclusive. The sensor
+    fails, skips, or succeeds only when every one of them is in one of the ``failed_states``,
+    ``skipped_states`` or ``allowed_states`` respectively. Failed states are checked first,
+    then skipped states, then allowed states.
+
+    It is possible to alter the default behavior by setting states which
+    cause the sensor to fail, e.g. by setting ``allowed_states=[DagRunState.FAILED]``
+    and ``failed_states=[DagRunState.SUCCESS]`` you will flip the behaviour to
+    get a sensor which goes green when the external task *fails* and immediately
+    goes red if the external task *succeeds*!
+
+    Note that ``soft_fail`` is respected when examining the failed_states. Thus
+    if the external task enters a failed state and ``soft_fail == True`` the
+    sensor will *skip* rather than fail. As a result, setting ``soft_fail=True``
+    and ``failed_states=[TaskInstanceState.SKIPPED]`` (with task ids or a task group set)
+    will result in the sensor skipping if the external task skips. However, this is a contrived
+    example---consider using ``skipped_states`` if you would like this
+    behaviour. Using ``skipped_states`` allows the sensor to skip if the target
+    fails, but still enter failed state on timeout. Using ``soft_fail == True``
+    as above will cause the sensor to skip if the target fails, but also if it
+    times out.
+
+    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+    :param region: Required. The ID of the Google Cloud region that the service belongs to.
+    :param environment_id: The name of the Composer environment.
+    :param composer_external_dag_id: The dag_id that contains the task you want to
+        wait for. (templated)
+    :param composer_external_task_id: The task_id of the task you want to
+        wait for. (templated)
+    :param composer_external_task_ids: The list of task_ids that you want to wait for. (templated)
+        If ``None`` (default value) the sensor waits for the DAG. Either
+        composer_external_task_id or composer_external_task_ids can be passed to
+        CloudComposerExternalTaskSensor, but not both.
+    :param composer_external_task_group_id: The task_group_id that contains the task you want to
+        wait for. (templated)
+    :param allowed_states: Iterable of allowed states, default is ``['success']``
+    :param skipped_states: Iterable of states to make this task mark as skipped, default is ``None``
+    :param failed_states: Iterable of failed or dis-allowed states, default is ``None``
+    :param execution_range: Range of external logical dates to check, relative to this run's logical date.
+        A positive timedelta looks back, e.g. ``datetime.timedelta(days=1)`` for the previous day.
+        A negative timedelta looks forward, e.g. ``datetime.timedelta(days=-1)``.
+        A list of two datetimes gives an explicit range, e.g.
+        ``[datetime(2024,3,22,11,0,0), datetime(2024,3,22,12,0,0)]``.
+        A single-element list such as ``[datetime(2024,3,22,0,0,0)]`` checks from that time until
+        this run's logical date.
+        Default value datetime.timedelta(days=1).
+    :param gcp_conn_id: The connection ID to use when fetching connection info.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    :param poll_interval: Optional: Control the rate of the poll for the result of deferrable run.
+    :param deferrable: Run sensor in deferrable mode.
+    """
+
+    template_fields = (
+        "project_id",
+        "region",
+        "environment_id",
+        "composer_external_dag_id",
+        "composer_external_task_id",
+        "composer_external_task_ids",
+        "composer_external_task_group_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        composer_external_dag_id: str,
+        composer_external_task_id: str | None = None,
+        composer_external_task_ids: Collection[str] | None = None,
+        composer_external_task_group_id: str | None = None,
+        allowed_states: Iterable[str] | None = None,
+        skipped_states: Iterable[str] | None = None,
+        failed_states: Iterable[str] | None = None,
+        execution_range: timedelta | list[datetime] | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+
+        self.allowed_states = list(allowed_states) if allowed_states else [TaskInstanceState.SUCCESS.value]
+        self.skipped_states = list(skipped_states) if skipped_states else []
+        self.failed_states = list(failed_states) if failed_states else []
+
+        total_states = set(self.allowed_states + self.skipped_states + self.failed_states)
+
+        # Every state may appear only once across the three lists (repeats inside a list count too).
+        if len(total_states) != len(self.allowed_states) + len(self.skipped_states) + len(self.failed_states):
+            raise DuplicateStateError(
+                "Duplicate values provided across allowed_states, skipped_states and failed_states."
+            )
+
+        # convert [] to None
+        if not composer_external_task_ids:
+            composer_external_task_ids = None
+
+        # can't set both single task id and a list of task ids
+        if composer_external_task_id is not None and composer_external_task_ids is not None:
+            raise ValueError(
+                "Only one of `composer_external_task_id` or `composer_external_task_ids` may "
+                "be provided to CloudComposerExternalTaskSensor; "
+                "use `composer_external_task_id` or `composer_external_task_ids` or `composer_external_task_group_id`."
+            )
+
+        # since both not set, convert the single id to a 1-elt list - from here on, we only consider the list
+        if composer_external_task_id is not None:
+            composer_external_task_ids = [composer_external_task_id]
+
+        if composer_external_task_group_id is not None and composer_external_task_ids is not None:
+            raise ValueError(
+                "Only one of `composer_external_task_group_id` or `composer_external_task_ids` may "
+                "be provided to CloudComposerExternalTaskSensor; "
+                "use `composer_external_task_id` or `composer_external_task_ids` or `composer_external_task_group_id`."
+            )
+
+        # check the requested states are all valid states for the target type, be it dag or task
+        if composer_external_task_ids or composer_external_task_group_id:
+            if not total_states <= set(State.task_states):
+                raise ValueError(
+                    "Valid values for `allowed_states`, `skipped_states` and `failed_states` "
+                    "when `composer_external_task_id` or `composer_external_task_ids` or `composer_external_task_group_id` "
+                    f"is not `None`: {State.task_states}"
+                )
+        elif not total_states <= set(State.dag_states):
+            raise ValueError(
+                "Valid values for `allowed_states`, `skipped_states` and `failed_states` "
+                f"when `composer_external_task_id` and `composer_external_task_group_id` is `None`: {State.dag_states}"
+            )
+
+        self.execution_range = execution_range
+        self.composer_external_dag_id = composer_external_dag_id
+        self.composer_external_task_id = composer_external_task_id
+        self.composer_external_task_ids = composer_external_task_ids
+        self.composer_external_task_group_id = composer_external_task_group_id
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
+
+    def _get_logical_dates(self, context) -> tuple[datetime, datetime]:
+        """
+        Return the ``(start, end)`` range of external logical dates to check.
+
+        The range is derived from ``execution_range`` and this run's ``logical_date``.
+        A negative timedelta is mirrored so that start is never after end.
+
+        :raises RuntimeError: If the context has no ``logical_date`` (e.g. asset-triggered runs).
+        """
+        logical_date = context.get("logical_date", None)
+        if logical_date is None:
+            raise RuntimeError(
+                "logical_date is None. Please make sure the sensor is not used in an asset-triggered Dag. "
+                "CloudComposerExternalTaskSensor was designed to be used in time-based scheduled Dags only, "
+                "and asset-triggered Dags do not have logical_date. "
+            )
+        if isinstance(self.execution_range, timedelta):
+            if self.execution_range < timedelta(0):
+                return logical_date, logical_date - self.execution_range
+            return logical_date - self.execution_range, logical_date
+        if isinstance(self.execution_range, list) and len(self.execution_range) > 0:
+            range_start = self.execution_range[0]
+            # A single datetime means "from that point until this run's logical date".
+            range_end = self.execution_range[1] if len(self.execution_range) > 1 else logical_date
+            return range_start, range_end
+        # Default when execution_range is None (or an empty list): the preceding day.
+        return logical_date - timedelta(1), logical_date
+
+    def poke(self, context: Context) -> bool:
+        """
+        Check the external task instances once.
+
+        :return: ``True`` when every monitored task instance in the range is in ``allowed_states``;
+            ``False`` when none were found yet or some are in other states.
+        :raises AirflowSkipException: Via ``_handle_failed_states`` / ``_handle_skipped_states``.
+        :raises ExternalTaskFailedError | ExternalTaskGroupFailedError | ExternalDagFailedError:
+            When all monitored instances are in ``failed_states`` and ``soft_fail`` is off.
+        """
+        start_date, end_date = self._get_logical_dates(context)
+
+        task_instances = self._get_task_instances(
+            start_date=start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+            end_date=end_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+        )
+
+        if len(task_instances) == 0:
+            self.log.info("Task Instances are empty. Sensor waits for task instances...")
+            return False
+
+        # Failed states are evaluated before skipped states, so failure takes priority.
+        if self.failed_states:
+            external_task_status = self._check_task_instances_states(
+                task_instances=task_instances,
+                start_date=start_date,
+                end_date=end_date,
+                states=self.failed_states,
+            )
+            self._handle_failed_states(external_task_status)
+
+        if self.skipped_states:
+            external_task_status = self._check_task_instances_states(
+                task_instances=task_instances,
+                start_date=start_date,
+                end_date=end_date,
+                states=self.skipped_states,
+            )
+            self._handle_skipped_states(external_task_status)
+
+        self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+        external_task_status = self._check_task_instances_states(
+            task_instances=task_instances,
+            start_date=start_date,
+            end_date=end_date,
+            states=self.allowed_states,
+        )
+        return external_task_status
+
+    def _get_task_instances(self, start_date: str, end_date: str) -> list[dict]:
+        """
+        Get the list of task instances to monitor.
+
+        Instances are fetched from the external DAG in the given logical-date range (inclusive).
+        They are then narrowed to ``composer_external_task_ids`` or to tasks whose dotted task_id
+        contains ``composer_external_task_group_id``. Without either, all instances are returned.
+
+        :param start_date: Inclusive lower bound, formatted as ``%Y-%m-%dT%H:%M:%SZ``.
+        :param end_date: Inclusive upper bound, formatted as ``%Y-%m-%dT%H:%M:%SZ``.
+        :raises AirflowException: If the Composer environment does not exist.
+        """
+        try:
+            environment = self.hook.get_environment(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+                timeout=self.timeout,
+            )
+        except NotFound as not_found_err:
+            self.log.info("The Composer environment %s does not exist.", self.environment_id)
+            raise AirflowException(not_found_err) from not_found_err
+        composer_airflow_uri = environment.config.airflow_uri
+
+        self.log.info(
+            "Pulling the DAG '%s' task instances from the '%s' environment...",
+            self.composer_external_dag_id,
+            self.environment_id,
+        )
+        # ``_composer_airflow_version`` is set in ``execute`` before the first poke.
+        # Airflow 3 renamed ``execution_date`` to ``logical_date``.
+        date_field = "execution_date" if self._composer_airflow_version < 3 else "logical_date"
+        task_instances_response = self.hook.get_task_instances(
+            composer_airflow_uri=composer_airflow_uri,
+            composer_dag_id=self.composer_external_dag_id,
+            query_parameters={
+                f"{date_field}_gte": start_date,
+                f"{date_field}_lte": end_date,
+            },
+            timeout=self.timeout,
+        )
+        task_instances = task_instances_response["task_instances"]
+
+        if self.composer_external_task_ids:
+            task_instances = [
+                task_instance
+                for task_instance in task_instances
+                if task_instance["task_id"] in self.composer_external_task_ids
+            ]
+        elif self.composer_external_task_group_id:
+            task_instances = [
+                task_instance
+                for task_instance in task_instances
+                if self.composer_external_task_group_id in task_instance["task_id"].split(".")
+            ]
+
+        return task_instances
+
+    def _check_task_instances_states(
+        self,
+        task_instances: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+        states: Iterable[str],
+    ) -> bool:
+        """
+        Return ``True`` unless some task instance inside ``[start_date, end_date]`` is not in ``states``.
+
+        The bounds are inclusive, matching the ``*_gte`` / ``*_lte`` filters used when the instances
+        were fetched. With strict bounds, an instance exactly on a boundary was skipped here. The most
+        common case is an external run with the same logical date as this run. If every fetched
+        instance sat on a boundary, the check passed vacuously without looking at any state.
+        """
+        date_field = "execution_date" if self._composer_airflow_version < 3 else "logical_date"
+        start_ts = start_date.timestamp()
+        end_ts = end_date.timestamp()
+        for task_instance in task_instances:
+            instance_ts = parser.parse(task_instance[date_field]).timestamp()
+            if start_ts <= instance_ts <= end_ts and task_instance["state"] not in states:
+                return False
+        return True
+
+    def _get_composer_airflow_version(self) -> int:
+        """
+        Return Composer Airflow major version.
+
+        The major version is parsed from the environment's ``software_config.image_version``, taking
+        the first number after ``airflow-`` (presumably of the form ``composer-X-airflow-Y.Z.W``).
+        """
+        environment_obj = self.hook.get_environment(
+            project_id=self.project_id,
+            region=self.region,
+            environment_id=self.environment_id,
+        )
+        environment_config = Environment.to_dict(environment_obj)
+        image_version = environment_config["config"]["software_config"]["image_version"]
+        return int(image_version.split("airflow-")[1].split(".")[0])
+
+    def _handle_failed_states(self, failed_status: bool) -> None:
+        """Handle failed states and raise appropriate exceptions (a skip instead when ``soft_fail``)."""
+        if failed_status:
+            if self.composer_external_task_ids:
+                if self.soft_fail:
+                    raise AirflowSkipException(
+                        f"Some of the external tasks '{self.composer_external_task_ids}' "
+                        f"in DAG '{self.composer_external_dag_id}' failed. Skipping due to soft_fail."
+                    )
+                raise ExternalTaskFailedError(
+                    f"Some of the external tasks '{self.composer_external_task_ids}' "
+                    f"in DAG '{self.composer_external_dag_id}' failed."
+                )
+            if self.composer_external_task_group_id:
+                if self.soft_fail:
+                    raise AirflowSkipException(
+                        f"The external task_group '{self.composer_external_task_group_id}' "
+                        f"in DAG '{self.composer_external_dag_id}' failed. Skipping due to soft_fail."
+                    )
+                raise ExternalTaskGroupFailedError(
+                    f"The external task_group '{self.composer_external_task_group_id}' "
+                    f"in DAG '{self.composer_external_dag_id}' failed."
+                )
+            if self.soft_fail:
+                raise AirflowSkipException(
+                    f"The external DAG '{self.composer_external_dag_id}' failed. Skipping due to soft_fail."
+                )
+            raise ExternalDagFailedError(f"The external DAG '{self.composer_external_dag_id}' failed.")
+
+    def _handle_skipped_states(self, skipped_status: bool) -> None:
+        """Handle skipped states and raise appropriate exceptions."""
+        if skipped_status:
+            if self.composer_external_task_ids:
+                raise AirflowSkipException(
+                    f"Some of the external tasks '{self.composer_external_task_ids}' "
+                    f"in DAG '{self.composer_external_dag_id}' reached a state in our states-to-skip-on list. Skipping."
+                )
+            if self.composer_external_task_group_id:
+                raise AirflowSkipException(
+                    f"The external task_group '{self.composer_external_task_group_id}' "
+                    f"in DAG '{self.composer_external_dag_id}' reached a state in our states-to-skip-on list. Skipping."
+                )
+            raise AirflowSkipException(
+                f"The external DAG '{self.composer_external_dag_id}' reached a state in our states-to-skip-on list. "
+                "Skipping."
+            )
+
+    def execute(self, context: Context) -> None:
+        # Resolved once up front; both poke() and the trigger need it to pick date field names.
+        self._composer_airflow_version = self._get_composer_airflow_version()
+
+        if self.composer_external_task_ids and len(self.composer_external_task_ids) > len(
+            set(self.composer_external_task_ids)
+        ):
+            raise ValueError("Duplicate task_ids passed in composer_external_task_ids parameter")
+
+        if self.composer_external_task_ids:
+            self.log.info(
+                "Poking for tasks '%s' in dag '%s' on Composer environment '%s' ... ",
+                self.composer_external_task_ids,
+                self.composer_external_dag_id,
+                self.environment_id,
+            )
+
+        if self.composer_external_task_group_id:
+            self.log.info(
+                "Poking for task_group '%s' in dag '%s' on Composer environment '%s' ... ",
+                self.composer_external_task_group_id,
+                self.composer_external_dag_id,
+                self.environment_id,
+            )
+
+        if (
+            self.composer_external_dag_id
+            and not self.composer_external_task_group_id
+            and not self.composer_external_task_ids
+        ):
+            self.log.info(
+                "Poking for DAG '%s' on Composer environment '%s' ... ",
+                self.composer_external_dag_id,
+                self.environment_id,
+            )
+
+        if self.deferrable:
+            start_date, end_date = self._get_logical_dates(context)
+            self.defer(
+                timeout=timedelta(seconds=self.timeout) if self.timeout else None,
+                trigger=CloudComposerExternalTaskTrigger(
+                    project_id=self.project_id,
+                    region=self.region,
+                    environment_id=self.environment_id,
+                    composer_external_dag_id=self.composer_external_dag_id,
+                    composer_external_task_ids=self.composer_external_task_ids,
+                    composer_external_task_group_id=self.composer_external_task_group_id,
+                    start_date=start_date,
+                    end_date=end_date,
+                    allowed_states=self.allowed_states,
+                    skipped_states=self.skipped_states,
+                    failed_states=self.failed_states,
+                    gcp_conn_id=self.gcp_conn_id,
+                    impersonation_chain=self.impersonation_chain,
+                    poll_interval=self.poll_interval,
+                    composer_airflow_version=self._composer_airflow_version,
+                ),
+                method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
+            )
+        super().execute(context)
+
+    def execute_complete(self, context: Context, event: dict):
+        """
+        Resume after the trigger fires.
+
+        Based on the trigger's ``status``: ``"error"`` raises, ``"failed"`` and ``"skipped"``
+        reuse the poke-mode handlers, and anything else is treated as success.
+        """
+        if event and event["status"] == "error":
+            raise AirflowException(event["message"])
+        if event and event["status"] == "failed":
+            self._handle_failed_states(True)
+        elif event and event["status"] == "skipped":
+            self._handle_skipped_states(True)
+
+        self.log.info("External tasks for DAG '%s' has executed successfully.", self.composer_external_dag_id)
+
+    @cached_property
+    def hook(self) -> CloudComposerHook:
+        return CloudComposerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
diff --git 
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
 
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
index 006480e5da5..bf60acdf93e 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 
 import asyncio
 import json
-from collections.abc import Sequence
+from collections.abc import Collection, Iterable, Sequence
 from datetime import datetime
 from typing import Any
 
@@ -346,3 +346,185 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 }
             )
             return
+
+
+class CloudComposerExternalTaskTrigger(BaseTrigger):
+    """
+    Wait until the task instances of a DAG in a Cloud Composer environment reach the expected states.
+
+    On every poll the trigger pulls the task instances of ``composer_external_dag_id`` whose
+    execution/logical date lies in ``[start_date, end_date]``, optionally narrowed to
+    ``composer_external_task_ids`` or ``composer_external_task_group_id``, and emits one event:
+
+    * ``{"status": "failed"}`` when all of them are in ``failed_states``;
+    * ``{"status": "skipped"}`` when all of them are in ``skipped_states``;
+    * ``{"status": "success"}`` when all of them are in ``allowed_states``;
+    * ``{"status": "error", "message": ...}`` when an ``AirflowException`` is raised while polling,
+      e.g. because the Composer environment does not exist.
+
+    :param project_id: Google Cloud project that hosts the Composer environment.
+    :param region: Region of the Composer environment.
+    :param environment_id: Name of the Composer environment.
+    :param start_date: Inclusive start of the execution/logical date window.
+    :param end_date: Inclusive end of the execution/logical date window.
+    :param allowed_states: Task instance states that count as success.
+    :param skipped_states: Task instance states that make the trigger report ``skipped``.
+    :param failed_states: Task instance states that make the trigger report ``failed``.
+    :param composer_external_dag_id: ID of the DAG in the Composer environment to check.
+    :param composer_external_task_ids: Optional task IDs to restrict the check to.
+    :param composer_external_task_group_id: Optional task group ID to restrict the check to;
+        ignored when ``composer_external_task_ids`` is set.
+    :param gcp_conn_id: Airflow connection ID used to reach Google Cloud.
+    :param impersonation_chain: Optional service account(s) to impersonate; passed to the async hook.
+    :param poll_interval: Seconds to sleep between polls.
+    :param composer_airflow_version: Major Airflow version of the Composer environment; versions
+        below 3 use ``execution_date``, later ones ``logical_date``.
+    """
+
+    def __init__(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        start_date: datetime,
+        end_date: datetime,
+        allowed_states: list[str],
+        skipped_states: list[str],
+        failed_states: list[str],
+        composer_external_dag_id: str,
+        composer_external_task_ids: Collection[str] | None = None,
+        composer_external_task_group_id: str | None = None,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        poll_interval: int = 10,
+        composer_airflow_version: int = 2,
+    ):
+        super().__init__()
+        self.project_id = project_id
+        self.region = region
+        self.environment_id = environment_id
+        self.start_date = start_date
+        self.end_date = end_date
+        self.allowed_states = allowed_states
+        self.skipped_states = skipped_states
+        self.failed_states = failed_states
+        self.composer_external_dag_id = composer_external_dag_id
+        self.composer_external_task_ids = composer_external_task_ids
+        self.composer_external_task_group_id = composer_external_task_group_id
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.poll_interval = poll_interval
+        self.composer_airflow_version = composer_airflow_version
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExternalTaskTrigger",
+            {
+                "project_id": self.project_id,
+                "region": self.region,
+                "environment_id": self.environment_id,
+                "start_date": self.start_date,
+                "end_date": self.end_date,
+                "allowed_states": self.allowed_states,
+                "skipped_states": self.skipped_states,
+                "failed_states": self.failed_states,
+                "composer_external_dag_id": self.composer_external_dag_id,
+                "composer_external_task_ids": self.composer_external_task_ids,
+                "composer_external_task_group_id": self.composer_external_task_group_id,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+                "poll_interval": self.poll_interval,
+                "composer_airflow_version": self.composer_airflow_version,
+            },
+        )
+
+    @property
+    def _date_field(self) -> str:
+        """Name of the task instance date field for the environment's Airflow major version."""
+        return "execution_date" if self.composer_airflow_version < 3 else "logical_date"
+
+    async def _get_task_instances(self, start_date: str, end_date: str) -> list[dict]:
+        """
+        Fetch the external DAG's task instances in the date window and apply the task filter.
+
+        :param start_date: Lower bound sent as ``<date field>_gte`` to the Airflow REST API.
+        :param end_date: Upper bound sent as ``<date field>_lte`` to the Airflow REST API.
+        :raises AirflowException: If the Composer environment does not exist.
+        """
+        try:
+            environment = await self.gcp_hook.get_environment(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+            )
+        except NotFound as not_found_err:
+            self.log.info("The Composer environment %s does not exist.", self.environment_id)
+            # Chain the original error so the API failure stays visible in the traceback.
+            raise AirflowException(not_found_err) from not_found_err
+        composer_airflow_uri = environment.config.airflow_uri
+
+        self.log.info(
+            "Pulling the DAG '%s' task instances from the '%s' environment...",
+            self.composer_external_dag_id,
+            self.environment_id,
+        )
+        date_field = self._date_field
+        task_instances_response = await self.gcp_hook.get_task_instances(
+            composer_airflow_uri=composer_airflow_uri,
+            composer_dag_id=self.composer_external_dag_id,
+            query_parameters={
+                f"{date_field}_gte": start_date,
+                f"{date_field}_lte": end_date,
+            },
+        )
+        task_instances = task_instances_response["task_instances"]
+
+        if self.composer_external_task_ids:
+            task_instances = [
+                task_instance
+                for task_instance in task_instances
+                if task_instance["task_id"] in self.composer_external_task_ids
+            ]
+        elif self.composer_external_task_group_id:
+            task_instances = [
+                task_instance
+                for task_instance in task_instances
+                if self._in_task_group(task_instance["task_id"])
+            ]
+
+        return task_instances
+
+    def _in_task_group(self, task_id: str) -> bool:
+        """
+        Return whether ``task_id`` belongs to ``composer_external_task_group_id``.
+
+        With Airflow's default ``prefix_group_id`` behavior, a task inside a group has an ID of the
+        form ``"<group_id>.<task>"``, where a nested ``group_id`` itself contains dots. So a prefix
+        match is used. Matching a single non-final segment is kept for backward compatibility with
+        short names of nested groups. The final segment is the task's own ID, not a group.
+        """
+        group_id = self.composer_external_task_group_id
+        if not group_id:
+            return False
+        return task_id.startswith(f"{group_id}.") or group_id in task_id.split(".")[:-1]
+
+    def _check_task_instances_states(
+        self,
+        task_instances: list[dict],
+        start_date: datetime,
+        end_date: datetime,
+        states: Iterable[str],
+    ) -> bool:
+        """
+        Return True if at least one task instance lies in the window and all such instances are in ``states``.
+
+        The window is inclusive on both ends, matching the ``*_gte``/``*_lte`` filters used to
+        fetch the task instances. Requiring at least one in-window instance avoids a vacuous True.
+        A vacuous True would otherwise make the ``failed_states`` check fire for instances that
+        are all outside the window.
+        """
+        date_field = self._date_field
+        start_ts = start_date.timestamp()
+        end_ts = end_date.timestamp()
+        # Materialize once: ``states`` may be a one-shot iterable.
+        expected_states = frozenset(states)
+        in_window = [
+            task_instance
+            for task_instance in task_instances
+            if start_ts <= parser.parse(task_instance[date_field]).timestamp() <= end_ts
+        ]
+        return bool(in_window) and all(
+            task_instance["state"] in expected_states for task_instance in in_window
+        )
+
+    def _get_async_hook(self) -> CloudComposerAsyncHook:
+        return CloudComposerAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def run(self):
+        """Poll the Composer environment until the task instances settle, then emit one TriggerEvent."""
+        self.gcp_hook: CloudComposerAsyncHook = self._get_async_hook()
+        try:
+            while True:
+                task_instances = await self._get_task_instances(
+                    start_date=self.start_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+                    end_date=self.end_date.strftime("%Y-%m-%dT%H:%M:%SZ"),
+                )
+
+                if not task_instances:
+                    self.log.info("Task Instances are empty. Sensor waits for task instances...")
+                    self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                    await asyncio.sleep(self.poll_interval)
+                    continue
+
+                # Non-success outcomes take precedence over success: failed first, then skipped.
+                if self.failed_states and self._check_task_instances_states(
+                    task_instances=task_instances,
+                    start_date=self.start_date,
+                    end_date=self.end_date,
+                    states=self.failed_states,
+                ):
+                    yield TriggerEvent({"status": "failed"})
+                    return
+
+                if self.skipped_states and self._check_task_instances_states(
+                    task_instances=task_instances,
+                    start_date=self.start_date,
+                    end_date=self.end_date,
+                    states=self.skipped_states,
+                ):
+                    yield TriggerEvent({"status": "skipped"})
+                    return
+
+                self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
+                if self._check_task_instances_states(
+                    task_instances=task_instances,
+                    start_date=self.start_date,
+                    end_date=self.end_date,
+                    states=self.allowed_states,
+                ):
+                    yield TriggerEvent({"status": "success"})
+                    return
+
+                self.log.info("Sleeping for %s seconds.", self.poll_interval)
+                await asyncio.sleep(self.poll_interval)
+        except AirflowException as ex:
+            yield TriggerEvent(
+                {
+                    "status": "error",
+                    "message": str(ex),
+                }
+            )
+            return
diff --git 
a/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py 
b/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py
index 21e635a3853..a5105dc3eca 100644
--- 
a/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py
+++ 
b/providers/google/tests/system/google/cloud/composer/example_cloud_composer.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import os
-from datetime import datetime
+from datetime import datetime, timedelta
 
 from googleapiclient.discovery import build
 from googleapiclient.errors import HttpError
@@ -43,7 +43,10 @@ from airflow.providers.google.cloud.operators.cloud_composer 
import (
     CloudComposerTriggerDAGRunOperator,
     CloudComposerUpdateEnvironmentOperator,
 )
-from airflow.providers.google.cloud.sensors.cloud_composer import 
CloudComposerDAGRunSensor
+from airflow.providers.google.cloud.sensors.cloud_composer import (
+    CloudComposerDAGRunSensor,
+    CloudComposerExternalTaskSensor,
+)
 
 try:
     from airflow.sdk import TriggerRule
@@ -229,6 +232,33 @@ with DAG(
     )
     # [END howto_operator_trigger_dag_run]
 
+    # [START howto_sensor_external_task]
+    external_task_sensor = CloudComposerExternalTaskSensor(
+        task_id="external_task_sensor",
+        project_id=PROJECT_ID,
+        region=REGION,
+        environment_id=ENVIRONMENT_ID,
+        composer_external_dag_id="airflow_monitoring",
+        composer_external_task_id="echo",
+        allowed_states=["success"],
+        # Look at runs of the external task from the last day.
+        execution_range=[datetime.now() - timedelta(days=1), datetime.now()],
+    )
+    # [END howto_sensor_external_task]
+
+    # [START howto_sensor_external_task_deferrable_mode]
+    defer_external_task_sensor = CloudComposerExternalTaskSensor(
+        task_id="defer_external_task_sensor",
+        project_id=PROJECT_ID,
+        region=REGION,
+        environment_id=ENVIRONMENT_ID_ASYNC,
+        composer_external_dag_id="airflow_monitoring",
+        composer_external_task_id="echo",
+        allowed_states=["success"],
+        # Look at runs of the external task from the last day.
+        execution_range=[datetime.now() - timedelta(days=1), datetime.now()],
+        deferrable=True,
+    )
+    # [END howto_sensor_external_task_deferrable_mode]
+
     # [START howto_operator_delete_composer_environment]
     delete_env = CloudComposerDeleteEnvironmentOperator(
         task_id="delete_env",
@@ -262,6 +292,7 @@ with DAG(
         [run_airflow_cli_cmd, defer_run_airflow_cli_cmd],
         [dag_run_sensor, defer_dag_run_sensor],
         trigger_dag_run,
+        [external_task_sensor, defer_external_task_sensor],
         # TEST TEARDOWN
         [delete_env, defer_delete_env],
     )
diff --git 
a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
index e577f9bd086..680fe791db6 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
@@ -298,6 +298,24 @@ class TestCloudComposerHook:
             timeout=TEST_TIMEOUT,
         )
 
+    @pytest.mark.parametrize("query_parameters", [None, {"test_key": "test_value"}])
+    @mock.patch(COMPOSER_STRING.format("CloudComposerHook.make_composer_airflow_api_request"))
+    def test_get_task_instances(self, mock_composer_airflow_api_request, query_parameters) -> None:
+        """The sync hook GETs the DAG's taskInstances endpoint, appending query parameters if given."""
+        expected_path = f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns/~/taskInstances"
+        if query_parameters:
+            expected_path += "?test_key=test_value"
+        self.hook.get_credentials = mock.MagicMock()
+
+        self.hook.get_task_instances(
+            composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            composer_dag_id=TEST_COMPOSER_DAG_ID,
+            query_parameters=query_parameters,
+            timeout=TEST_TIMEOUT,
+        )
+
+        mock_composer_airflow_api_request.assert_called_once_with(
+            method="GET",
+            airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            path=expected_path,
+            timeout=TEST_TIMEOUT,
+        )
+
 
 class TestCloudComposerAsyncHook:
     def setup_method(self, method):
@@ -484,3 +502,22 @@ class TestCloudComposerAsyncHook:
             path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns",
             timeout=TEST_TIMEOUT,
         )
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("query_parameters", [None, {"test_key": "test_value"}])
+    @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.make_composer_airflow_api_request"))
+    async def test_get_task_instances(self, mock_composer_airflow_api_request, query_parameters) -> None:
+        """The async hook GETs the DAG's taskInstances endpoint, appending query parameters if given."""
+        expected_path = f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns/~/taskInstances"
+        if query_parameters:
+            expected_path += "?test_key=test_value"
+        mock_composer_airflow_api_request.return_value = ({}, 200)
+
+        await self.hook.get_task_instances(
+            composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            composer_dag_id=TEST_COMPOSER_DAG_ID,
+            query_parameters=query_parameters,
+            timeout=TEST_TIMEOUT,
+        )
+
+        mock_composer_airflow_api_request.assert_called_once_with(
+            method="GET",
+            airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            path=expected_path,
+            timeout=TEST_TIMEOUT,
+        )
diff --git 
a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
index 6988508b056..7f639a59685 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
@@ -23,7 +23,10 @@ from unittest import mock
 
 import pytest
 
-from airflow.providers.google.cloud.sensors.cloud_composer import 
CloudComposerDAGRunSensor
+from airflow.providers.google.cloud.sensors.cloud_composer import (
+    CloudComposerDAGRunSensor,
+    CloudComposerExternalTaskSensor,
+)
 
 TEST_PROJECT_ID = "test_project_id"
 TEST_OPERATION_NAME = "test_operation_name"
@@ -49,6 +52,22 @@ TEST_GET_RESULT = lambda state, date_key: {
     "dag_runs": TEST_DAG_RUNS_RESULT(state, date_key, "dag_run_id"),
     "total_entries": 1,
 }
+TEST_COMPOSER_EXTERNAL_TASK_ID = "test_external_task_id"
+TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID = "test_external_task_group_id"
+# Fake task instance list: one task instance of ``test_dag_id`` in ``state``; ``date_key`` names its
+# date field ("execution_date" for Airflow 2 environments, "logical_date" for Airflow 3).
+TEST_TASK_INSTANCES_RESULT = lambda state, date_key, task_id: [
+    {
+        "task_id": task_id,
+        "dag_id": "test_dag_id",
+        "state": state,
+        date_key: "2024-05-22T11:10:00+00:00",
+        "start_date": "2024-05-22T11:20:01.531988+00:00",
+        "end_date": "2024-05-22T11:20:11.997479+00:00",
+    }
+]
+# Fake ``get_task_instances`` response wrapping the list above together with its entry count.
+TEST_GET_TASK_INSTANCES_RESULT = lambda state, date_key, task_id: {
+    "task_instances": TEST_TASK_INSTANCES_RESULT(state, date_key, task_id),
+    "total_entries": 1,
+}
 
 
 class TestCloudComposerDAGRunSensor:
@@ -185,3 +204,155 @@ class TestCloudComposerDAGRunSensor:
         task._composer_airflow_version = composer_airflow_version
 
         assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 
0, 0)})
+
+
+class TestCloudComposerExternalTaskSensor:
+    """Tests for the synchronous ``poke`` path of CloudComposerExternalTaskSensor."""
+
+    @staticmethod
+    def _mock_task_instances(mock_hook, state, composer_airflow_version, task_id):
+        """Make the patched hook return one task instance in ``state`` using the version's date key."""
+        date_key = "execution_date" if composer_airflow_version < 3 else "logical_date"
+        mock_hook.return_value.get_task_instances.return_value = TEST_GET_TASK_INSTANCES_RESULT(
+            state, date_key, task_id
+        )
+
+    @staticmethod
+    def _build_sensor(composer_airflow_version, **task_filter):
+        """Create a sensor on ``test_dag_id`` that only accepts ``success``, pinned to an Airflow version."""
+        sensor = CloudComposerExternalTaskSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            composer_external_dag_id="test_dag_id",
+            allowed_states=["success"],
+            **task_filter,
+        )
+        sensor._composer_airflow_version = composer_airflow_version
+        return sensor
+
+    @staticmethod
+    def _poke(sensor):
+        return sensor.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_wait_ready(self, mock_hook, composer_airflow_version):
+        self._mock_task_instances(mock_hook, "success", composer_airflow_version, TEST_COMPOSER_EXTERNAL_TASK_ID)
+        sensor = self._build_sensor(composer_airflow_version)
+
+        assert self._poke(sensor)
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_wait_not_ready(self, mock_hook, composer_airflow_version):
+        self._mock_task_instances(mock_hook, "running", composer_airflow_version, TEST_COMPOSER_EXTERNAL_TASK_ID)
+        sensor = self._build_sensor(composer_airflow_version)
+
+        assert not self._poke(sensor)
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_task_instances_empty(self, mock_hook, composer_airflow_version):
+        mock_hook.return_value.get_task_instances.return_value = {
+            "task_instances": [],
+            "total_entries": 0,
+        }
+        sensor = self._build_sensor(composer_airflow_version)
+
+        assert not self._poke(sensor)
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_composer_external_task_id_wait_ready(self, mock_hook, composer_airflow_version):
+        self._mock_task_instances(mock_hook, "success", composer_airflow_version, TEST_COMPOSER_EXTERNAL_TASK_ID)
+        sensor = self._build_sensor(
+            composer_airflow_version, composer_external_task_id=TEST_COMPOSER_EXTERNAL_TASK_ID
+        )
+
+        assert self._poke(sensor)
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_composer_external_task_id_wait_not_ready(self, mock_hook, composer_airflow_version):
+        self._mock_task_instances(mock_hook, "running", composer_airflow_version, TEST_COMPOSER_EXTERNAL_TASK_ID)
+        sensor = self._build_sensor(
+            composer_airflow_version, composer_external_task_id=TEST_COMPOSER_EXTERNAL_TASK_ID
+        )
+
+        assert not self._poke(sensor)
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_composer_external_task_group_id_wait_ready(self, mock_hook, composer_airflow_version):
+        grouped_task_id = f"{TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID}.{TEST_COMPOSER_EXTERNAL_TASK_ID}"
+        self._mock_task_instances(mock_hook, "success", composer_airflow_version, grouped_task_id)
+        sensor = self._build_sensor(
+            composer_airflow_version, composer_external_task_group_id=TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID
+        )
+
+        assert self._poke(sensor)
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_composer_external_task_group_id_wait_not_ready(self, mock_hook, composer_airflow_version):
+        grouped_task_id = f"{TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID}.{TEST_COMPOSER_EXTERNAL_TASK_ID}"
+        self._mock_task_instances(mock_hook, "running", composer_airflow_version, grouped_task_id)
+        sensor = self._build_sensor(
+            composer_airflow_version, composer_external_task_group_id=TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID
+        )
+
+        assert not self._poke(sensor)
diff --git 
a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
index f093a76d1d7..1d6307f3814 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
@@ -26,6 +26,7 @@ from airflow.models import Connection
 from airflow.providers.google.cloud.triggers.cloud_composer import (
     CloudComposerAirflowCLICommandTrigger,
     CloudComposerDAGRunTrigger,
+    CloudComposerExternalTaskTrigger,
 )
 from airflow.triggers.base import TriggerEvent
 
@@ -40,9 +41,13 @@ TEST_EXEC_CMD_INFO = {
 }
 TEST_COMPOSER_DAG_ID = "test_dag_id"
 TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
+TEST_COMPOSER_EXTERNAL_TASK_IDS = ["test_external_task_id"]
+TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID = "test_external_task_group_id"
 TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
 TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
-TEST_STATES = ["success"]
+TEST_ALLOWED_STATES = ["success"]
+TEST_SKIPPED_STATES = ["skipped"]
+TEST_FAILED_STATES = ["failed"]
 TEST_GCP_CONN_ID = "test_gcp_conn_id"
 TEST_POLL_INTERVAL = 10
 TEST_COMPOSER_AIRFLOW_VERSION = 3
@@ -86,7 +91,7 @@ def dag_run_trigger(mock_conn):
         composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
         start_date=TEST_START_DATE,
         end_date=TEST_END_DATE,
-        allowed_states=TEST_STATES,
+        allowed_states=TEST_ALLOWED_STATES,
         gcp_conn_id=TEST_GCP_CONN_ID,
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
         poll_interval=TEST_POLL_INTERVAL,
@@ -95,6 +100,31 @@ def dag_run_trigger(mock_conn):
     )
 
 
+@pytest.fixture
+@mock.patch(
+    "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_connection",
+    return_value=Connection(conn_id="test_conn"),
+)
+def external_task_trigger(mock_conn):
+    """CloudComposerExternalTaskTrigger built from this module's test constants, with connection lookup patched."""
+    return CloudComposerExternalTaskTrigger(
+        project_id=TEST_PROJECT_ID,
+        region=TEST_LOCATION,
+        environment_id=TEST_ENVIRONMENT_ID,
+        start_date=TEST_START_DATE,
+        end_date=TEST_END_DATE,
+        allowed_states=TEST_ALLOWED_STATES,
+        skipped_states=TEST_SKIPPED_STATES,
+        failed_states=TEST_FAILED_STATES,
+        composer_external_dag_id=TEST_COMPOSER_DAG_ID,
+        composer_external_task_ids=TEST_COMPOSER_EXTERNAL_TASK_IDS,
+        composer_external_task_group_id=TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        poll_interval=TEST_POLL_INTERVAL,
+        composer_airflow_version=TEST_COMPOSER_AIRFLOW_VERSION,
+    )
+
+
 class TestCloudComposerAirflowCLICommandTrigger:
     def test_serialize(self, cli_command_trigger):
         actual_data = cli_command_trigger.serialize()
@@ -143,7 +173,7 @@ class TestCloudComposerDAGRunTrigger:
                 "composer_dag_run_id": TEST_COMPOSER_DAG_RUN_ID,
                 "start_date": TEST_START_DATE,
                 "end_date": TEST_END_DATE,
-                "allowed_states": TEST_STATES,
+                "allowed_states": TEST_ALLOWED_STATES,
                 "gcp_conn_id": TEST_GCP_CONN_ID,
                 "impersonation_chain": TEST_IMPERSONATION_CHAIN,
                 "poll_interval": TEST_POLL_INTERVAL,
@@ -152,3 +182,29 @@ class TestCloudComposerDAGRunTrigger:
             },
         )
         assert actual_data == expected_data
+
+
+class TestCloudComposerExternalTaskTrigger:
+    def test_serialize(self, external_task_trigger):
+        """The trigger serializes to its import path plus every constructor argument."""
+        classpath, kwargs = external_task_trigger.serialize()
+
+        assert classpath == (
+            "airflow.providers.google.cloud.triggers.cloud_composer.CloudComposerExternalTaskTrigger"
+        )
+        assert kwargs == {
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_LOCATION,
+            "environment_id": TEST_ENVIRONMENT_ID,
+            "start_date": TEST_START_DATE,
+            "end_date": TEST_END_DATE,
+            "allowed_states": TEST_ALLOWED_STATES,
+            "skipped_states": TEST_SKIPPED_STATES,
+            "failed_states": TEST_FAILED_STATES,
+            "composer_external_dag_id": TEST_COMPOSER_DAG_ID,
+            "composer_external_task_ids": TEST_COMPOSER_EXTERNAL_TASK_IDS,
+            "composer_external_task_group_id": TEST_COMPOSER_EXTERNAL_TASK_GROUP_ID,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+            "poll_interval": TEST_POLL_INTERVAL,
+            "composer_airflow_version": TEST_COMPOSER_AIRFLOW_VERSION,
+        }

Reply via email to