This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fcc28a56437 Add use_rest_api parameter for CloudComposerDAGRunSensor for pulling dag_runs using the Airflow REST API (#56138)
fcc28a56437 is described below
commit fcc28a56437c58bc34fea19bd53936c9fdf8799b
Author: Maksim <[email protected]>
AuthorDate: Thu Oct 9 14:11:07 2025 +0200
Add use_rest_api parameter for CloudComposerDAGRunSensor for pulling dag_runs using the Airflow REST API (#56138)
---
.../providers/google/cloud/hooks/cloud_composer.py | 132 ++++++++++++++++++++-
.../google/cloud/sensors/cloud_composer.py | 74 ++++++++----
.../google/cloud/triggers/cloud_composer.py | 72 +++++++----
.../unit/google/cloud/hooks/test_cloud_composer.py | 56 +++++++++
.../google/cloud/sensors/test_cloud_composer.py | 68 ++++++++---
.../google/cloud/triggers/test_cloud_composer.py | 3 +
6 files changed, 344 insertions(+), 61 deletions(-)
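
For context, a minimal usage sketch of the new flag (the project, region, environment, and DAG ids below are illustrative placeholders, not values from this commit):

    from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor

    wait_for_dag_run = CloudComposerDAGRunSensor(
        task_id="wait_for_dag_run",
        project_id="my-project",           # placeholder GCP project
        region="us-central1",              # placeholder Composer region
        environment_id="my-environment",   # placeholder Composer environment
        composer_dag_id="target_dag",      # DAG whose runs are polled
        allowed_states=["success"],
        use_rest_api=True,  # new in this commit: pull dag runs via the Airflow REST API
    )

With use_rest_api=False (the default), the sensor keeps its previous behavior and pulls dag runs by running `dags list-runs` through execute_airflow_command.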
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
index daf6a06a927..423d6cb56df 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py
@@ -24,9 +24,10 @@ from collections.abc import MutableSequence, Sequence
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
+from aiohttp import ClientSession
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.auth.transport.requests import AuthorizedSession
+from google.auth.transport.requests import AuthorizedSession, Request
from google.cloud.orchestration.airflow.service_v1 import (
EnvironmentsAsyncClient,
EnvironmentsClient,
@@ -472,6 +473,38 @@ class CloudComposerHook(GoogleBaseHook, OperationHelper):
return response.json()
+ def get_dag_runs(
+ self,
+ composer_airflow_uri: str,
+ composer_dag_id: str,
+ timeout: float | None = None,
+ ) -> dict:
+ """
+ Get the list of DAG runs for the provided DAG.
+
+ :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within the Composer environment.
+ :param composer_dag_id: The ID of the DAG.
+ :param timeout: The timeout for this request.
+ """
+ response = self.make_composer_airflow_api_request(
+ method="GET",
+ airflow_uri=composer_airflow_uri,
+ path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
+ timeout=timeout,
+ )
+
+ if response.status_code != 200:
+ self.log.error(
+ "Failed to get DAG runs for dag_id=%s from %s (status=%s): %s",
+ composer_dag_id,
+ composer_airflow_uri,
+ response.status_code,
+ response.text,
+ )
+ response.raise_for_status()
+
+ return response.json()
+
class CloudComposerAsyncHook(GoogleBaseAsyncHook):
"""Hook for Google Cloud Composer async APIs."""
@@ -489,6 +522,42 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
client_options=self.client_options,
)
+ async def make_composer_airflow_api_request(
+ self,
+ method: str,
+ airflow_uri: str,
+ path: str,
+ data: Any | None = None,
+ timeout: float | None = None,
+ ):
+ """
+ Make a request to the Cloud Composer environment's web server.
+
+ :param method: The request method to use ('GET', 'OPTIONS', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE').
+ :param airflow_uri: The URI of the Apache Airflow Web UI hosted within this environment.
+ :param path: The path to send the request to.
+ :param data: Dictionary, list of tuples, bytes, or file-like object to send in the body of the request.
+ :param timeout: The timeout for this request.
+ """
+ sync_hook = await self.get_sync_hook()
+ credentials = sync_hook.get_credentials()
+
+ if not credentials.valid:
+ credentials.refresh(Request())
+
+ async with ClientSession() as session:
+ async with session.request(
+ method=method,
+ url=urljoin(airflow_uri, path),
+ data=data,
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {credentials.token}",
+ },
+ timeout=timeout,
+ ) as response:
+ return await response.json(), response.status
+
def get_environment_name(self, project_id, region, environment_id):
return f"projects/{project_id}/locations/{region}/environments/{environment_id}"
@@ -594,6 +663,35 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
metadata=metadata,
)
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def get_environment(
+ self,
+ project_id: str,
+ region: str,
+ environment_id: str,
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Environment:
+ """
+ Get an existing environment.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param region: Required. The ID of the Google Cloud region that the service belongs to.
+ :param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = await self.get_environment_client()
+
+ return await client.get_environment(
+ request={"name": self.get_environment_name(project_id, region, environment_id)},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
@GoogleBaseHook.fallback_to_default_project_id
async def execute_airflow_command(
self,
@@ -719,3 +817,35 @@ class CloudComposerAsyncHook(GoogleBaseAsyncHook):
self.log.info("Sleeping for %s seconds.", poll_interval)
await asyncio.sleep(poll_interval)
+
+ async def get_dag_runs(
+ self,
+ composer_airflow_uri: str,
+ composer_dag_id: str,
+ timeout: float | None = None,
+ ) -> dict:
+ """
+ Get the list of DAG runs for the provided DAG.
+
+ :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within the Composer environment.
+ :param composer_dag_id: The ID of the DAG.
+ :param timeout: The timeout for this request.
+ """
+ response_body, response_status_code = await self.make_composer_airflow_api_request(
+ method="GET",
+ airflow_uri=composer_airflow_uri,
+ path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
+ timeout=timeout,
+ )
+
+ if response_status_code != 200:
+ self.log.error(
+ "Failed to get DAG runs for dag_id=%s from %s (status=%s): %s",
+ composer_dag_id,
+ composer_airflow_uri,
+ response_status_code,
+ response_body["title"],
+ )
+ raise AirflowException(response_body["title"])
+
+ return response_body
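
Both the sync and async get_dag_runs above issue GET /api/v1/dags/{dag_id}/dagRuns against the environment's web server and return the decoded JSON body. Judging from the test fixtures further down, the payload has roughly this shape (values illustrative; the date key is execution_date on Airflow 2 environments and logical_date on Airflow 3):

    {
        "dag_runs": [
            {
                "dag_id": "test_dag_id",
                "dag_run_id": "scheduled__2024-05-22T11:10:00+00:00",
                "state": "success",
                "logical_date": "2024-05-22T11:10:00+00:00",
                "start_date": "2024-05-22T11:20:01.531988+00:00",
                "end_date": "2024-05-22T11:20:11.997479+00:00",
            }
        ],
        "total_entries": 1,
    }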
diff --git a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
index 09aa0e1aa8a..242953dd2ef 100644
--- a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -26,6 +26,7 @@ from functools import cached_property
from typing import TYPE_CHECKING
from dateutil import parser
+from google.api_core.exceptions import NotFound
from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse
from airflow.configuration import conf
@@ -97,6 +98,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 10,
+ use_rest_api: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -111,6 +113,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.poll_interval = poll_interval
+ self.use_rest_api = use_rest_api
if self.composer_dag_run_id and self.execution_range:
self.log.warning(
@@ -161,26 +164,51 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
- cmd_parameters = (
- ["-d", self.composer_dag_id, "-o", "json"]
- if self._composer_airflow_version < 3
- else [self.composer_dag_id, "-o", "json"]
- )
- dag_runs_cmd = self.hook.execute_airflow_command(
- project_id=self.project_id,
- region=self.region,
- environment_id=self.environment_id,
- command="dags",
- subcommand="list-runs",
- parameters=cmd_parameters,
- )
- cmd_result = self.hook.wait_command_execution_result(
- project_id=self.project_id,
- region=self.region,
- environment_id=self.environment_id,
- execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
- )
- dag_runs = json.loads(cmd_result["output"][0]["content"])
+ if self.use_rest_api:
+ try:
+ environment = self.hook.get_environment(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ timeout=self.timeout,
+ )
+ except NotFound as not_found_err:
+ self.log.info("The Composer environment %s does not exist.", self.environment_id)
+ raise AirflowException(not_found_err)
+ composer_airflow_uri = environment.config.airflow_uri
+
+ self.log.info(
+ "Pulling the DAG %s runs from the %s environment...",
+ self.composer_dag_id,
+ self.environment_id,
+ )
+ dag_runs_response = self.hook.get_dag_runs(
+ composer_airflow_uri=composer_airflow_uri,
+ composer_dag_id=self.composer_dag_id,
+ timeout=self.timeout,
+ )
+ dag_runs = dag_runs_response["dag_runs"]
+ else:
+ cmd_parameters = (
+ ["-d", self.composer_dag_id, "-o", "json"]
+ if self._composer_airflow_version < 3
+ else [self.composer_dag_id, "-o", "json"]
+ )
+ dag_runs_cmd = self.hook.execute_airflow_command(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ command="dags",
+ subcommand="list-runs",
+ parameters=cmd_parameters,
+ )
+ cmd_result = self.hook.wait_command_execution_result(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+ )
+ dag_runs = json.loads(cmd_result["output"][0]["content"])
return dag_runs
def _check_dag_runs_states(
@@ -213,7 +241,10 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
for dag_run in dag_runs:
- if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
+ if (
+ dag_run["dag_run_id" if self.use_rest_api else "run_id"] == self.composer_dag_run_id
+ and dag_run["state"] in self.allowed_states
+ ):
return True
return False
@@ -236,6 +267,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
impersonation_chain=self.impersonation_chain,
poll_interval=self.poll_interval,
composer_airflow_version=self._composer_airflow_version,
+ use_rest_api=self.use_rest_api,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
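
Note the run-id key difference between the two sources: the REST API names it dag_run_id, while the CLI JSON output names it run_id. A minimal sketch of the check performed in _check_composer_dag_run_id_states (variable names shortened for illustration):

    # The key under which the run id is stored depends on where the dag runs came from.
    run_id_key = "dag_run_id" if use_rest_api else "run_id"
    matched = any(
        dag_run[run_id_key] == composer_dag_run_id and dag_run["state"] in allowed_states
        for dag_run in dag_runs
    )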
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
index f6654a39351..006480e5da5 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -25,6 +25,7 @@ from datetime import datetime
from typing import Any
from dateutil import parser
+from google.api_core.exceptions import NotFound
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
from airflow.exceptions import AirflowException
@@ -188,6 +189,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: int = 10,
composer_airflow_version: int = 2,
+ use_rest_api: bool = False,
):
super().__init__()
self.project_id = project_id
@@ -202,6 +204,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
self.composer_airflow_version = composer_airflow_version
+ self.use_rest_api = use_rest_api
def serialize(self) -> tuple[str, dict[str, Any]]:
return (
@@ -219,31 +222,55 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
"impersonation_chain": self.impersonation_chain,
"poll_interval": self.poll_interval,
"composer_airflow_version": self.composer_airflow_version,
+ "use_rest_api": self.use_rest_api,
},
)
async def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
- cmd_parameters = (
- ["-d", self.composer_dag_id, "-o", "json"]
- if self.composer_airflow_version < 3
- else [self.composer_dag_id, "-o", "json"]
- )
- dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
- project_id=self.project_id,
- region=self.region,
- environment_id=self.environment_id,
- command="dags",
- subcommand="list-runs",
- parameters=cmd_parameters,
- )
- cmd_result = await self.gcp_hook.wait_command_execution_result(
- project_id=self.project_id,
- region=self.region,
- environment_id=self.environment_id,
- execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
- )
- dag_runs = json.loads(cmd_result["output"][0]["content"])
+ if self.use_rest_api:
+ try:
+ environment = await self.gcp_hook.get_environment(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ )
+ except NotFound as not_found_err:
+ self.log.info("The Composer environment %s does not exist.", self.environment_id)
+ raise AirflowException(not_found_err)
+ composer_airflow_uri = environment.config.airflow_uri
+
+ self.log.info(
+ "Pulling the DAG %s runs from the %s environment...",
+ self.composer_dag_id,
+ self.environment_id,
+ )
+ dag_runs_response = await self.gcp_hook.get_dag_runs(
+ composer_airflow_uri=composer_airflow_uri,
+ composer_dag_id=self.composer_dag_id,
+ )
+ dag_runs = dag_runs_response["dag_runs"]
+ else:
+ cmd_parameters = (
+ ["-d", self.composer_dag_id, "-o", "json"]
+ if self.composer_airflow_version < 3
+ else [self.composer_dag_id, "-o", "json"]
+ )
+ dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ command="dags",
+ subcommand="list-runs",
+ parameters=cmd_parameters,
+ )
+ cmd_result = await self.gcp_hook.wait_command_execution_result(
+ project_id=self.project_id,
+ region=self.region,
+ environment_id=self.environment_id,
+ execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
+ )
+ dag_runs = json.loads(cmd_result["output"][0]["content"])
return dag_runs
def _check_dag_runs_states(
@@ -271,7 +298,10 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
for dag_run in dag_runs:
- if dag_run["run_id"] == self.composer_dag_run_id and dag_run["state"] in self.allowed_states:
+ if (
+ dag_run["dag_run_id" if self.use_rest_api else "run_id"] == self.composer_dag_run_id
+ and dag_run["state"] in self.allowed_states
+ ):
return True
return False
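
The trigger mirrors the sensor's branching, and use_rest_api now round-trips through serialize(), so deferrable mode behaves the same way. A hedged sketch, reusing the placeholder ids from the example above:

    wait_for_dag_run_deferrable = CloudComposerDAGRunSensor(
        task_id="wait_for_dag_run_deferrable",
        project_id="my-project",
        region="us-central1",
        environment_id="my-environment",
        composer_dag_id="target_dag",
        allowed_states=["success"],
        use_rest_api=True,
        deferrable=True,  # polling is handed off to CloudComposerDAGRunTrigger
    )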
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
index cd62056e36e..e577f9bd086 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_composer.py
@@ -283,6 +283,21 @@ class TestCloudComposerHook:
timeout=TEST_TIMEOUT,
)
+ @mock.patch(COMPOSER_STRING.format("CloudComposerHook.make_composer_airflow_api_request"))
+ def test_get_dag_runs(self, mock_composer_airflow_api_request) -> None:
+ self.hook.get_credentials = mock.MagicMock()
+ self.hook.get_dag_runs(
+ composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+ composer_dag_id=TEST_COMPOSER_DAG_ID,
+ timeout=TEST_TIMEOUT,
+ )
+ mock_composer_airflow_api_request.assert_called_once_with(
+ method="GET",
+ airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+ path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns",
+ timeout=TEST_TIMEOUT,
+ )
+
class TestCloudComposerAsyncHook:
def setup_method(self, method):
@@ -365,6 +380,31 @@ class TestCloudComposerAsyncHook:
metadata=TEST_METADATA,
)
+ @pytest.mark.asyncio
+ @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
+ async def test_get_environment(self, mock_client) -> None:
+ mock_env_client = AsyncMock(EnvironmentsAsyncClient)
+ mock_client.return_value = mock_env_client
+ await self.hook.get_environment(
+ project_id=TEST_GCP_PROJECT,
+ region=TEST_GCP_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ retry=TEST_RETRY,
+ timeout=TEST_TIMEOUT,
+ metadata=TEST_METADATA,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.get_environment.assert_called_once_with(
+ request={
+ "name": self.hook.get_environment_name(
+ TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_ENVIRONMENT_ID
+ ),
+ },
+ retry=TEST_RETRY,
+ timeout=TEST_TIMEOUT,
+ metadata=TEST_METADATA,
+ )
+
@pytest.mark.asyncio
@mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.get_environment_client"))
async def test_execute_airflow_command(self, mock_client) -> None:
@@ -428,3 +468,19 @@ class TestCloudComposerAsyncHook:
timeout=TEST_TIMEOUT,
metadata=TEST_METADATA,
)
+
+ @pytest.mark.asyncio
+ @mock.patch(COMPOSER_STRING.format("CloudComposerAsyncHook.make_composer_airflow_api_request"))
+ async def test_get_dag_runs(self, mock_composer_airflow_api_request) -> None:
+ mock_composer_airflow_api_request.return_value = ({}, 200)
+ await self.hook.get_dag_runs(
+ composer_airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+ composer_dag_id=TEST_COMPOSER_DAG_ID,
+ timeout=TEST_TIMEOUT,
+ )
+ mock_composer_airflow_api_request.assert_called_once_with(
+ method="GET",
+ airflow_uri=TEST_COMPOSER_AIRFLOW_URI,
+ path=f"/api/v1/dags/{TEST_COMPOSER_DAG_ID}/dagRuns",
+ timeout=TEST_TIMEOUT,
+ )
diff --git a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
index c2da2daa00b..6988508b056 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
@@ -30,33 +30,39 @@ TEST_OPERATION_NAME = "test_operation_name"
TEST_REGION = "region"
TEST_ENVIRONMENT_ID = "test_env_id"
TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
-TEST_JSON_RESULT = lambda state, date_key: json.dumps(
- [
- {
- "dag_id": "test_dag_id",
- "run_id": TEST_COMPOSER_DAG_RUN_ID,
- "state": state,
- date_key: "2024-05-22T11:10:00+00:00",
- "start_date": "2024-05-22T11:20:01.531988+00:00",
- "end_date": "2024-05-22T11:20:11.997479+00:00",
- }
- ]
-)
+TEST_DAG_RUNS_RESULT = lambda state, date_key, run_id_key: [
+ {
+ "dag_id": "test_dag_id",
+ run_id_key: TEST_COMPOSER_DAG_RUN_ID,
+ "state": state,
+ date_key: "2024-05-22T11:10:00+00:00",
+ "start_date": "2024-05-22T11:20:01.531988+00:00",
+ "end_date": "2024-05-22T11:20:11.997479+00:00",
+ }
+]
TEST_EXEC_RESULT = lambda state, date_key: {
- "output": [{"line_number": 1, "content": TEST_JSON_RESULT(state, date_key)}],
+ "output": [{"line_number": 1, "content": json.dumps(TEST_DAG_RUNS_RESULT(state, date_key, "run_id"))}],
"output_end": True,
"exit_info": {"exit_code": 0, "error": ""},
}
+TEST_GET_RESULT = lambda state, date_key: {
+ "dag_runs": TEST_DAG_RUNS_RESULT(state, date_key, "dag_run_id"),
+ "total_entries": 1,
+}
class TestCloudComposerDAGRunSensor:
+ @pytest.mark.parametrize("use_rest_api", [True, False])
@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
- def test_wait_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
+ def test_wait_ready(self, mock_hook, to_dict_mode, composer_airflow_version, use_rest_api):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"success", "execution_date" if composer_airflow_version < 3 else "logical_date"
)
+ mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+ "success", "execution_date" if composer_airflow_version < 3 else "logical_date"
+ )
task = CloudComposerDAGRunSensor(
task_id="task-id",
@@ -65,18 +71,23 @@ class TestCloudComposerDAGRunSensor:
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id="test_dag_id",
allowed_states=["success"],
+ use_rest_api=use_rest_api,
)
task._composer_airflow_version = composer_airflow_version
assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
+ @pytest.mark.parametrize("use_rest_api", [True, False])
@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
- def test_wait_not_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
+ def test_wait_not_ready(self, mock_hook, to_dict_mode, composer_airflow_version, use_rest_api):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"running", "execution_date" if composer_airflow_version < 3 else "logical_date"
)
+ mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+ "running", "execution_date" if composer_airflow_version < 3 else "logical_date"
+ )
task = CloudComposerDAGRunSensor(
task_id="task-id",
@@ -85,20 +96,26 @@ class TestCloudComposerDAGRunSensor:
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id="test_dag_id",
allowed_states=["success"],
+ use_rest_api=use_rest_api,
)
task._composer_airflow_version = composer_airflow_version
assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
+ @pytest.mark.parametrize("use_rest_api", [True, False])
@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
- def test_dag_runs_empty(self, mock_hook, to_dict_mode, composer_airflow_version):
+ def test_dag_runs_empty(self, mock_hook, to_dict_mode, composer_airflow_version, use_rest_api):
mock_hook.return_value.wait_command_execution_result.return_value = {
"output": [{"line_number": 1, "content": json.dumps([])}],
"output_end": True,
"exit_info": {"exit_code": 0, "error": ""},
}
+ mock_hook.return_value.get_dag_runs.return_value = {
+ "dag_runs": [],
+ "total_entries": 0,
+ }
task = CloudComposerDAGRunSensor(
task_id="task-id",
@@ -107,18 +124,25 @@ class TestCloudComposerDAGRunSensor:
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id="test_dag_id",
allowed_states=["success"],
+ use_rest_api=use_rest_api,
)
task._composer_airflow_version = composer_airflow_version
assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
+ @pytest.mark.parametrize("use_rest_api", [True, False])
@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
- def test_composer_dag_run_id_wait_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
+ def test_composer_dag_run_id_wait_ready(
+ self, mock_hook, to_dict_mode, composer_airflow_version, use_rest_api
+ ):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"success", "execution_date" if composer_airflow_version < 3 else "logical_date"
)
+ mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+ "success", "execution_date" if composer_airflow_version < 3 else "logical_date"
+ )
task = CloudComposerDAGRunSensor(
task_id="task-id",
@@ -128,18 +152,25 @@ class TestCloudComposerDAGRunSensor:
composer_dag_id="test_dag_id",
composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
allowed_states=["success"],
+ use_rest_api=use_rest_api,
)
task._composer_airflow_version = composer_airflow_version
assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})
+ @pytest.mark.parametrize("use_rest_api", [True, False])
@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
- def test_composer_dag_run_id_wait_not_ready(self, mock_hook, to_dict_mode, composer_airflow_version):
+ def test_composer_dag_run_id_wait_not_ready(
+ self, mock_hook, to_dict_mode, composer_airflow_version, use_rest_api
+ ):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"running", "execution_date" if composer_airflow_version < 3 else "logical_date"
)
+ mock_hook.return_value.get_dag_runs.return_value = TEST_GET_RESULT(
+ "running", "execution_date" if composer_airflow_version < 3 else "logical_date"
+ )
task = CloudComposerDAGRunSensor(
task_id="task-id",
@@ -149,6 +180,7 @@ class TestCloudComposerDAGRunSensor:
composer_dag_id="test_dag_id",
composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
allowed_states=["success"],
+ use_rest_api=use_rest_api,
)
task._composer_airflow_version = composer_airflow_version
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
index cde313785d6..f093a76d1d7 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
@@ -46,6 +46,7 @@ TEST_STATES = ["success"]
TEST_GCP_CONN_ID = "test_gcp_conn_id"
TEST_POLL_INTERVAL = 10
TEST_COMPOSER_AIRFLOW_VERSION = 3
+TEST_USE_REST_API = True
TEST_IMPERSONATION_CHAIN = "test_impersonation_chain"
TEST_EXEC_RESULT = {
"output": [{"line_number": 1, "content": "test_content"}],
@@ -90,6 +91,7 @@ def dag_run_trigger(mock_conn):
impersonation_chain=TEST_IMPERSONATION_CHAIN,
poll_interval=TEST_POLL_INTERVAL,
composer_airflow_version=TEST_COMPOSER_AIRFLOW_VERSION,
+ use_rest_api=TEST_USE_REST_API,
)
@@ -146,6 +148,7 @@ class TestCloudComposerDAGRunTrigger:
"impersonation_chain": TEST_IMPERSONATION_CHAIN,
"poll_interval": TEST_POLL_INTERVAL,
"composer_airflow_version": TEST_COMPOSER_AIRFLOW_VERSION,
+ "use_rest_api": TEST_USE_REST_API,
},
)
assert actual_data == expected_data