This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fcc28a56437 Add use_rest_api parameter for CloudComposerDAGRunSensor for pulling dag_runs using the Airflow REST API (#56138)
fcc28a56437 is described below

commit fcc28a56437c58bc34fea19bd53936c9fdf8799b
Author: Maksim <[email protected]>
AuthorDate: Thu Oct 9 14:11:07 2025 +0200

    Add use_rest_api parameter for CloudComposerDAGRunSensor for pulling dag_runs using the Airflow REST API (#56138)
---
 .../providers/google/cloud/hooks/cloud_composer.py | 132 ++++++++++++++++++++-
 .../google/cloud/sensors/cloud_composer.py         |  74 ++++++++----
 .../google/cloud/triggers/cloud_composer.py        |  72 +++++++----
 .../unit/google/cloud/hooks/test_cloud_composer.py |  56 +++++++++
 .../google/cloud/sensors/test_cloud_composer.py    |  68 ++++++++---
 .../google/cloud/triggers/test_cloud_composer.py   |   3 +
 6 files changed, 344 insertions(+), 61 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py 
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
index daf6a06a927..423d6cb56df 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -24,9 +24,10 @@ from collections.abc import MutableSequence, Sequence
 from typing import TYPE_CHECKING, Any
 from urllib.parse import urljoin
 
+from aiohttp import ClientSession
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.auth.transport.requests import AuthorizedSession
+from google.auth.transport.requests import AuthorizedSession, Request
 from google.cloud.orchestration.airflow.service_v1 import (
     EnvironmentsAsyncClient,
     EnvironmentsClient,
@@ -472,6 +473,38 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
 
         return response.json()
 
+    def get_dag_runs(
+        self,
+        composer_airflow_uri: str,
+        composer_dag_id: str,
+        timeout: float | None = None,
+    ) -> dict:
+        """
+        Get the list of dag runs for provided DAG.
+
+        :param composer_airflow_uri: The URI of the Apache Airflow Web UI 
hosted within Composer environment.
+        :param composer_dag_id: The ID of DAG.
+        :param timeout: The timeout for this request.
+        """
+        response = self.make_composer_airflow_api_request(
+            method="GET",
+            airflow_uri=composer_airflow_uri,
+            path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
+            timeout=timeout,
+        )
+
+        if response.status_code != 200:
+            self.log.error(
+                "Failed to get DAG runs for dag_id=%s from %s (status=%s): %s",
+                composer_dag_id,
+                composer_airflow_uri,
+                response.status_code,
+                response.text,
+            )
+            response.raise_for_status()
+
+        return response.json()
+
 
 class CloudComposerAsyncHook(GoogleBaseAsyncHook):
     """Hook for Google Cloud Composer async APIs."""
@@ -489,6 +522,42 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
             client_options=self.client_options,
         )
 
+    async def make_composer_airflow_api_request(
+        self,
+        method: str,
+        airflow_uri: str,
+        path: str,
+        data: Any | None = None,
+        timeout: float | None = None,
+    ):
+        """
+        Make a request to Cloud Composer environment's web server.
+
+        :param method: The request method to use ('GET', 'OPTIONS', 'HEAD', 
'POST', 'PUT', 'PATCH', 'DELETE').
+        :param airflow_uri: The URI of the Apache Airflow Web UI hosted within 
this environment.
+        :param path: The path to send the request.
+        :param data: Dictionary, list of tuples, bytes, or file-like object to 
send in the body of the request.
+        :param timeout: The timeout for this request.
+        """
+        sync_hook = await self.get_sync_hook()
+        credentials = sync_hook.get_credentials()
+
+        if not credentials.valid:
+            credentials.refresh(Request())
+
+        async with ClientSession() as session:
+            async with session.request(
+                method=method,
+                url=urljoin(airflow_uri, path),
+                data=data,
+                headers={
+                    "Content-Type": "application/json",
+                    "Authorization": f"Bearer {credentials.token}",
+                },
+                timeout=timeout,
+            ) as response:
+                return await response.json(), response.status
+
     def get_environment_name(self, project_id, region, environment_id):
         return 
f"projects/{project_id}/locations/{region}/environments/{environment_id}"
 
@@ -594,6 +663,35 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
             metadata=metadata,
         )
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def get_environment(
+        self,
+        project_id: str,
+        region: str,
+        environment_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Environment:
+        """
+        Get an existing environment.
+
+        :param project_id: Required. The ID of the Google Cloud project that 
the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the 
service belongs to.
+        :param environment_id: Required. The ID of the Google Cloud 
environment that the service belongs to.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request 
as metadata.
+        """
+        client = await self.get_environment_client()
+
+        return await client.get_environment(
+            request={"name": self.get_environment_name(project_id, region, 
environment_id)},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
     @GoogleBaseHook.fallback_to_default_project_id
     async def execute_airflow_command(
         self,
@@ -719,3 +817,35 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
 
             self.log.info("Sleeping for %s seconds.", poll_interval)
             await asyncio.sleep(poll_interval)
+
+    async def get_dag_runs(
+        self,
+        composer_airflow_uri: str,
+        composer_dag_id: str,
+        timeout: float | None = None,
+    ) -> dict:
+        """
+        Get the list of dag runs for provided DAG.
+
+        :param composer_airflow_uri: The URI of the Apache Airflow Web UI 
hosted within Composer environment.
+        :param composer_dag_id: The ID of DAG.
+        :param timeout: The timeout for this request.
+        """
+        response_body, response_status_code = await 
self.make_composer_airflow_api_request(
+            method="GET",
+            airflow_uri=composer_airflow_uri,
+            path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
+            timeout=timeout,
+        )
+
+        if response_status_code != 200:
+            self.log.error(
+                "Failed to get DAG runs for dag_id=%s from %s (status=%s): %s",
+                composer_dag_id,
+                composer_airflow_uri,
+                response_status_code,
+                response_body["title"],
+            )
+            raise AirflowException(response_body["title"])
+
+        return response_body
diff --git 
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py 
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
index 09aa0e1aa8a..242953dd2ef 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -26,6 +26,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING
 
 from dateutil import parser
+from google.api_core.exceptions import NotFound
 from google.cloud.orchestration.airflow.service_v1.types import Environment, 
ExecuteAirflowCommandResponse
 
 from airflow.configuration import conf
@@ -97,6 +98,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         poll_interval: int = 10,
+        use_rest_api: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -111,6 +113,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.use_rest_api = use_rest_api
 
         if self.composer_dag_run_id and self.execution_range:
             self.log.warning(
@@ -161,26 +164,51 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
 
     def _pull_dag_runs(self) -> list[dict]:
         """Pull the list of dag runs."""
-        cmd_parameters = (
-            ["-d", self.composer_dag_id, "-o", "json"]
-            if self._composer_airflow_version < 3
-            else [self.composer_dag_id, "-o", "json"]
-        )
-        dag_runs_cmd = self.hook.execute_airflow_command(
-            project_id=self.project_id,
-            region=self.region,
-            environment_id=self.environment_id,
-            command="dags",
-            subcommand="list-runs",
-            parameters=cmd_parameters,
-        )
-        cmd_result = self.hook.wait_command_execution_result(
-            project_id=self.project_id,
-            region=self.region,
-            environment_id=self.environment_id,
-            
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
-        )
-        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        if self.use_rest_api:
+            try:
+                environment = self.hook.get_environment(
+                    project_id=self.project_id,
+                    region=self.region,
+                    environment_id=self.environment_id,
+                    timeout=self.timeout,
+                )
+            except NotFound as not_found_err:
+                self.log.info("The Composer environment %s does not exist.", 
self.environment_id)
+                raise AirflowException(not_found_err)
+            composer_airflow_uri = environment.config.airflow_uri
+
+            self.log.info(
+                "Pulling the DAG %s runs from the %s environment...",
+                self.composer_dag_id,
+                self.environment_id,
+            )
+            dag_runs_response = self.hook.get_dag_runs(
+                composer_airflow_uri=composer_airflow_uri,
+                composer_dag_id=self.composer_dag_id,
+                timeout=self.timeout,
+            )
+            dag_runs = dag_runs_response["dag_runs"]
+        else:
+            cmd_parameters = (
+                ["-d", self.composer_dag_id, "-o", "json"]
+                if self._composer_airflow_version < 3
+                else [self.composer_dag_id, "-o", "json"]
+            )
+            dag_runs_cmd = self.hook.execute_airflow_command(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+                command="dags",
+                subcommand="list-runs",
+                parameters=cmd_parameters,
+            )
+            cmd_result = self.hook.wait_command_execution_result(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+                
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+            )
+            dag_runs = json.loads(cmd_result["output"][0]["content"])
         return dag_runs
 
     def _check_dag_runs_states(
@@ -213,7 +241,10 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
 
     def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
         for dag_run in dag_runs:
-            if dag_run["run_id"] == self.composer_dag_run_id and 
dag_run["state"] in self.allowed_states:
+            if (
+                dag_run["dag_run_id" if self.use_rest_api else "run_id"] == 
self.composer_dag_run_id
+                and dag_run["state"] in self.allowed_states
+            ):
                 return True
         return False
 
@@ -236,6 +267,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
                     impersonation_chain=self.impersonation_chain,
                     poll_interval=self.poll_interval,
                     composer_airflow_version=self._composer_airflow_version,
+                    use_rest_api=self.use_rest_api,
                 ),
                 method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
             )
diff --git 
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
 
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
index f6654a39351..006480e5da5 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -25,6 +25,7 @@ from datetime import datetime
 from typing import Any
 
 from dateutil import parser
+from google.api_core.exceptions import NotFound
 from google.cloud.orchestration.airflow.service_v1.types import 
ExecuteAirflowCommandResponse
 
 from airflow.exceptions import AirflowException
@@ -188,6 +189,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         impersonation_chain: str | Sequence[str] | None = None,
         poll_interval: int = 10,
         composer_airflow_version: int = 2,
+        use_rest_api: bool = False,
     ):
         super().__init__()
         self.project_id = project_id
@@ -202,6 +204,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         self.impersonation_chain = impersonation_chain
         self.poll_interval = poll_interval
         self.composer_airflow_version = composer_airflow_version
+        self.use_rest_api = use_rest_api
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
@@ -219,31 +222,55 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 "impersonation_chain": self.impersonation_chain,
                 "poll_interval": self.poll_interval,
                 "composer_airflow_version": self.composer_airflow_version,
+                "use_rest_api": self.use_rest_api,
             },
         )
 
     async def _pull_dag_runs(self) -> list[dict]:
         """Pull the list of dag runs."""
-        cmd_parameters = (
-            ["-d", self.composer_dag_id, "-o", "json"]
-            if self.composer_airflow_version < 3
-            else [self.composer_dag_id, "-o", "json"]
-        )
-        dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
-            project_id=self.project_id,
-            region=self.region,
-            environment_id=self.environment_id,
-            command="dags",
-            subcommand="list-runs",
-            parameters=cmd_parameters,
-        )
-        cmd_result = await self.gcp_hook.wait_command_execution_result(
-            project_id=self.project_id,
-            region=self.region,
-            environment_id=self.environment_id,
-            
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
-        )
-        dag_runs = json.loads(cmd_result["output"][0]["content"])
+        if self.use_rest_api:
+            try:
+                environment = await self.gcp_hook.get_environment(
+                    project_id=self.project_id,
+                    region=self.region,
+                    environment_id=self.environment_id,
+                )
+            except NotFound as not_found_err:
+                self.log.info("The Composer environment %s does not exist.", 
self.environment_id)
+                raise AirflowException(not_found_err)
+            composer_airflow_uri = environment.config.airflow_uri
+
+            self.log.info(
+                "Pulling the DAG %s runs from the %s environment...",
+                self.composer_dag_id,
+                self.environment_id,
+            )
+            dag_runs_response = await self.gcp_hook.get_dag_runs(
+                composer_airflow_uri=composer_airflow_uri,
+                composer_dag_id=self.composer_dag_id,
+            )
+            dag_runs = dag_runs_response["dag_runs"]
+        else:
+            cmd_parameters = (
+                ["-d", self.composer_dag_id, "-o", "json"]
+                if self.composer_airflow_version < 3
+                else [self.composer_dag_id, "-o", "json"]
+            )
+            dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+                command="dags",
+                subcommand="list-runs",
+                parameters=cmd_parameters,
+            )
+            cmd_result = await self.gcp_hook.wait_command_execution_result(
+                project_id=self.project_id,
+                region=self.region,
+                environment_id=self.environment_id,
+                
execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+            )
+            dag_runs = json.loads(cmd_result["output"][0]["content"])
         return dag_runs
 
     def _check_dag_runs_states(
@@ -271,7 +298,10 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
 
     def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
         for dag_run in dag_runs:
-            if dag_run["run_id"] == self.composer_dag_run_id and 
dag_run["state"] in self.allowed_states:
+            if (
+                dag_run["dag_run_id" if self.use_rest_api else "run_id"] == 
self.composer_dag_run_id
+                and dag_run["state"] in self.allowed_states
+            ):
                 return True
         return False
 
diff --git 
a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
index cd62056e36e..e577f9bd086 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
@@ -283,6 +283,21 @@ class TestCloudComposerHook:
             timeout=TEST_TIMEOUT,
         )
 
+    
@mock.patch(COMPOSER_STRING.format("CloudComposerHook.make_composer_airflow_api_request"))
+    def test_get_dag_runs(self, mock_composer_airflow_api_request) -> None:
+        self.hook.get_credentials = mock.MagicMock()
+        self.hook.get_dag_runs(
+            composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            composer_dag_id=TEST_COMPOSER_DAG_ID,
+            timeout=TEST_TIMEOUT,
+        )
+        mock_composer_airflow_api_request.assert_called_once_with(
+            method="GET",
+            airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns",
+            timeout=TEST_TIMEOUT,
+        )
+
 
 class TestCloudComposerAsyncHook:
     def setup_method(self, method):
@@ -365,6 +380,31 @@ class TestCloudComposerAsyncHook:
             metadata=TEST_METADATA,
         )
 
+    @pytest.mark.asyncio
+    
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+    async def test_get_environment(self, mock_client) -> None:
+        mock_env_client = AsyncMock(EnvironmentsAsyncClient)
+        mock_client.return_value = mock_env_client
+        await self.hook.get_environment(
+            project_id=TEST_GCP_PROJECT,
+            region=TEST_GCP_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+        mock_client.assert_called_once()
+        mock_client.return_value.get_environment.assert_called_once_with(
+            request={
+                "name": self.hook.get_environment_name(
+                    TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID
+                ),
+            },
+            retry=TEST_RETRY,
+            timeout=TEST_TIMEOUT,
+            metadata=TEST_METADATA,
+        )
+
     @pytest.mark.asyncio
     
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
     async def test_execute_airflow_command(self, mock_client) -> None:
@@ -428,3 +468,19 @@ class TestCloudComposerAsyncHook:
             timeout=TEST_TIMEOUT,
             metadata=TEST_METADATA,
         )
+
+    @pytest.mark.asyncio
+    
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.make_composer_airflow_api_request"))
+    async def test_get_dag_runs(self, mock_composer_airflow_api_request) -> 
None:
+        mock_composer_airflow_api_request.return_value = ({}, 200)
+        await self.hook.get_dag_runs(
+            composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            composer_dag_id=TEST_COMPOSER_DAG_ID,
+            timeout=TEST_TIMEOUT,
+        )
+        mock_composer_airflow_api_request.assert_called_once_with(
+            method="GET",
+            airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+            path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns",
+            timeout=TEST_TIMEOUT,
+        )
diff --git 
a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
index c2da2daa00b..6988508b056 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
@@ -30,33 +30,39 @@ TEST_OPERATION_NAME = "test_operation_name"
 TEST_REGION = "region"
 TEST_ENVIRONMENT_ID = "test_env_id"
 TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
-TEST_JSON_RESULT = lambda state, date_key: json.dumps(
-    [
-        {
-            "dag_id": "test_dag_id",
-            "run_id": TEST_COMPOSER_DAG_RUN_ID,
-            "state": state,
-            date_key: "2024-05-22T11:10:00+00:00",
-            "start_date": "2024-05-22T11:20:01.531988+00:00",
-            "end_date": "2024-05-22T11:20:11.997479+00:00",
-        }
-    ]
-)
+TEST_DAG_RUNS_RESULT = lambda state, date_key, run_id_key: [
+    {
+        "dag_id": "test_dag_id",
+        run_id_key: TEST_COMPOSER_DAG_RUN_ID,
+        "state": state,
+        date_key: "2024-05-22T11:10:00+00:00",
+        "start_date": "2024-05-22T11:20:01.531988+00:00",
+        "end_date": "2024-05-22T11:20:11.997479+00:00",
+    }
+]
 TEST_EXEC_RESULT = lambda state, date_key: {
-    "output": [{"line_number": 1, "content": TEST_JSON_RESULT(state, 
date_key)}],
+    "output": [{"line_number": 1, "content": 
json.dumps(TEST_DAG_RUNS_RESULT(state, date_key, "run_id"))}],
     "output_end": True,
     "exit_info": {"exit_code": 0, "error": ""},
 }
+TEST_GET_RESULT = lambda state, date_key: {
+    "dag_runs": TEST_DAG_RUNS_RESULT(state, date_key, "dag_run_id"),
+    "total_entries": 1,
+}
 
 
 class TestCloudComposerDAGRunSensor:
+    @pytest.mark.parametrize("use_rest_api", [True, False])
     @pytest.mark.parametrize("composer_airflow_version", [2, 3])
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
-    def test_wait_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version):
+    def test_wait_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version, use_rest_api):
         mock_hook.return_value.wait_command_execution_result.return_value = 
TEST_EXEC_RESULT(
             "success", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
         )
+        mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+            "success", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
+        )
 
         task = CloudComposerDAGRunSensor(
             task_id="task-id",
@@ -65,18 +71,23 @@ class TestCloudComposerDAGRunSensor:
             environment_id=TEST_ENVIRONMENT_ID,
             composer_dag_id="test_dag_id",
             allowed_states=["success"],
+            use_rest_api=use_rest_api,
         )
         task._composer_airflow_version = composer_airflow_version
 
         assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 
0)})
 
+    @pytest.mark.parametrize("use_rest_api", [True, False])
     @pytest.mark.parametrize("composer_airflow_version", [2, 3])
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
-    def test_wait_not_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version):
+    def test_wait_not_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version, use_rest_api):
         mock_hook.return_value.wait_command_execution_result.return_value = 
TEST_EXEC_RESULT(
             "running", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
         )
+        mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+            "running", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
+        )
 
         task = CloudComposerDAGRunSensor(
             task_id="task-id",
@@ -85,20 +96,26 @@ class TestCloudComposerDAGRunSensor:
             environment_id=TEST_ENVIRONMENT_ID,
             composer_dag_id="test_dag_id",
             allowed_states=["success"],
+            use_rest_api=use_rest_api,
         )
         task._composer_airflow_version = composer_airflow_version
 
         assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 
0, 0)})
 
+    @pytest.mark.parametrize("use_rest_api", [True, False])
     @pytest.mark.parametrize("composer_airflow_version", [2, 3])
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
-    def test_dag_runs_empty(self, mock_hook, to_dict_mode, 
composer_airflow_version):
+    def test_dag_runs_empty(self, mock_hook, to_dict_mode, 
composer_airflow_version, use_rest_api):
         mock_hook.return_value.wait_command_execution_result.return_value = {
             "output": [{"line_number": 1, "content": json.dumps([])}],
             "output_end": True,
             "exit_info": {"exit_code": 0, "error": ""},
         }
+        mock_hook.return_value.get_dag_runs.return_value = {
+            "dag_runs": [],
+            "total_entries": 0,
+        }
 
         task = CloudComposerDAGRunSensor(
             task_id="task-id",
@@ -107,18 +124,25 @@ class TestCloudComposerDAGRunSensor:
             environment_id=TEST_ENVIRONMENT_ID,
             composer_dag_id="test_dag_id",
             allowed_states=["success"],
+            use_rest_api=use_rest_api,
         )
         task._composer_airflow_version = composer_airflow_version
 
         assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 
0, 0)})
 
+    @pytest.mark.parametrize("use_rest_api", [True, False])
     @pytest.mark.parametrize("composer_airflow_version", [2, 3])
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
-    def test_composer_dag_run_id_wait_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version):
+    def test_composer_dag_run_id_wait_ready(
+        self, mock_hook, to_dict_mode, composer_airflow_version, use_rest_api
+    ):
         mock_hook.return_value.wait_command_execution_result.return_value = 
TEST_EXEC_RESULT(
             "success", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
         )
+        mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+            "success", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
+        )
 
         task = CloudComposerDAGRunSensor(
             task_id="task-id",
@@ -128,18 +152,25 @@ class TestCloudComposerDAGRunSensor:
             composer_dag_id="test_dag_id",
             composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
             allowed_states=["success"],
+            use_rest_api=use_rest_api,
         )
         task._composer_airflow_version = composer_airflow_version
 
         assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 
0)})
 
+    @pytest.mark.parametrize("use_rest_api", [True, False])
     @pytest.mark.parametrize("composer_airflow_version", [2, 3])
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
     
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
-    def test_composer_dag_run_id_wait_not_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version):
+    def test_composer_dag_run_id_wait_not_ready(
+        self, mock_hook, to_dict_mode, composer_airflow_version, use_rest_api
+    ):
         mock_hook.return_value.wait_command_execution_result.return_value = 
TEST_EXEC_RESULT(
             "running", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
         )
+        mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+            "running", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
+        )
 
         task = CloudComposerDAGRunSensor(
             task_id="task-id",
@@ -149,6 +180,7 @@ class TestCloudComposerDAGRunSensor:
             composer_dag_id="test_dag_id",
             composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
             allowed_states=["success"],
+            use_rest_api=use_rest_api,
         )
         task._composer_airflow_version = composer_airflow_version
 
diff --git 
a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
index cde313785d6..f093a76d1d7 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
@@ -46,6 +46,7 @@ TEST_STATES = ["success"]
 TEST_GCP_CONN_ID = "test_gcp_conn_id"
 TEST_POLL_INTERVAL = 10
 TEST_COMPOSER_AIRFLOW_VERSION = 3
+TEST_USE_REST_API = True
 TEST_IMPERSONATION_CHAIN = "test_impersonation_chain"
 TEST_EXEC_RESULT = {
     "output": [{"line_number": 1, "content": "test_content"}],
@@ -90,6 +91,7 @@ def dag_run_trigger(mock_conn):
         impersonation_chain=TEST_IMPERSONATION_CHAIN,
         poll_interval=TEST_POLL_INTERVAL,
         composer_airflow_version=TEST_COMPOSER_AIRFLOW_VERSION,
+        use_rest_api=TEST_USE_REST_API,
     )
 
 
@@ -146,6 +148,7 @@ class TestCloudComposerDAGRunTrigger:
                 "impersonation_chain": TEST_IMPERSONATION_CHAIN,
                 "poll_interval": TEST_POLL_INTERVAL,
                 "composer_airflow_version": TEST_COMPOSER_AIRFLOW_VERSION,
+                "use_rest_api": TEST_USE_REST_API,
             },
         )
         assert actual_data == expected_data

Reply via email to