This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 00f33431915 Add composer_dag_run_id as optional parameter to
CloudComposerDAGRunSensor (#54977)
00f33431915 is described below
commit 00f3343191554e5382260550d06fdef556a67f74
Author: Maksim <[email protected]>
AuthorDate: Thu Sep 4 00:27:41 2025 +0200
Add composer_dag_run_id as optional parameter to CloudComposerDAGRunSensor
(#54977)
---
.../google/cloud/sensors/cloud_composer.py | 25 ++++++++++++
.../google/cloud/triggers/cloud_composer.py | 35 +++++++++++++----
.../google/cloud/sensors/test_cloud_composer.py | 45 +++++++++++++++++++++-
.../google/cloud/triggers/test_cloud_composer.py | 3 ++
4 files changed, 99 insertions(+), 9 deletions(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
index 85f8323225d..d69cce09899 100644
---
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
+++
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -61,6 +61,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for
states from specific time in the
past till current time execution.
Default value datetime.timedelta(days=1).
+ :param composer_dag_run_id: The run ID of the Composer DAG run to wait for. If both this and
'execution_range' are specified, 'execution_range' is ignored.
:param gcp_conn_id: The connection ID to use when fetching connection info.
:param impersonation_chain: Optional service account to impersonate using
short-term
credentials, or chained list of accounts required to get the
access_token
@@ -91,6 +92,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
composer_dag_id: str,
allowed_states: Iterable[str] | None = None,
execution_range: timedelta | list[datetime] | None = None,
+ composer_dag_run_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
@@ -104,11 +106,17 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
self.composer_dag_id = composer_dag_id
self.allowed_states = list(allowed_states) if allowed_states else
[TaskInstanceState.SUCCESS.value]
self.execution_range = execution_range
+ self.composer_dag_run_id = composer_dag_run_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.poll_interval = poll_interval
+ if self.composer_dag_run_id and self.execution_range:
+ self.log.warning(
+ "The composer_dag_run_id parameter and execution_range
parameter do not work together. This run will ignore execution_range parameter
and count only specified composer_dag_run_id parameter."
+ )
+
def _get_logical_dates(self, context) -> tuple[datetime, datetime]:
if isinstance(self.execution_range, timedelta):
if self.execution_range < timedelta(0):
@@ -132,6 +140,16 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
self.log.info("Dag runs are empty. Sensor waits for dag runs...")
return False
+ if self.composer_dag_run_id:
+ self.log.info(
+ "Sensor waits for allowed states %s for specified RunID: %s",
+ self.allowed_states,
+ self.composer_dag_run_id,
+ )
+ composer_dag_run_id_status =
self._check_composer_dag_run_id_states(
+ dag_runs=dag_runs,
+ )
+ return composer_dag_run_id_status
self.log.info("Sensor waits for allowed states: %s",
self.allowed_states)
allowed_states_status = self._check_dag_runs_states(
dag_runs=dag_runs,
@@ -193,6 +211,12 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
image_version =
environment_config["config"]["software_config"]["image_version"]
return int(image_version.split("airflow-")[1].split(".")[0])
+ def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
+ for dag_run in dag_runs:
+ if dag_run["run_id"] == self.composer_dag_run_id and
dag_run["state"] in self.allowed_states:
+ return True
+ return False
+
def execute(self, context: Context) -> None:
self._composer_airflow_version = self._get_composer_airflow_version()
if self.deferrable:
@@ -204,6 +228,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
region=self.region,
environment_id=self.environment_id,
composer_dag_id=self.composer_dag_id,
+ composer_dag_run_id=self.composer_dag_run_id,
start_date=start_date,
end_date=end_date,
allowed_states=self.allowed_states,
diff --git
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
index 4382b114bd4..a2840dcffe2 100644
---
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
+++
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -179,6 +179,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
start_date: datetime,
end_date: datetime,
allowed_states: list[str],
+ composer_dag_run_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: int = 10,
@@ -192,6 +193,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
self.start_date = start_date
self.end_date = end_date
self.allowed_states = allowed_states
+ self.composer_dag_run_id = composer_dag_run_id
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
@@ -213,6 +215,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
"start_date": self.start_date,
"end_date": self.end_date,
"allowed_states": self.allowed_states,
+ "composer_dag_run_id": self.composer_dag_run_id,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"poll_interval": self.poll_interval,
@@ -261,6 +264,12 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
return False
return True
+ def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
+ for dag_run in dag_runs:
+ if dag_run["run_id"] == self.composer_dag_run_id and
dag_run["state"] in self.allowed_states:
+ return True
+ return False
+
async def run(self):
try:
while True:
@@ -273,14 +282,24 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
await asyncio.sleep(self.poll_interval)
continue
- self.log.info("Sensor waits for allowed states: %s",
self.allowed_states)
- if self._check_dag_runs_states(
- dag_runs=dag_runs,
- start_date=self.start_date,
- end_date=self.end_date,
- ):
- yield TriggerEvent({"status": "success"})
- return
+ if self.composer_dag_run_id:
+ self.log.info(
+ "Sensor waits for allowed states %s for specified
RunID: %s",
+ self.allowed_states,
+ self.composer_dag_run_id,
+ )
+ if
self._check_composer_dag_run_id_states(dag_runs=dag_runs):
+ yield TriggerEvent({"status": "success"})
+ return
+ else:
+ self.log.info("Sensor waits for allowed states: %s",
self.allowed_states)
+ if self._check_dag_runs_states(
+ dag_runs=dag_runs,
+ start_date=self.start_date,
+ end_date=self.end_date,
+ ):
+ yield TriggerEvent({"status": "success"})
+ return
self.log.info("Sleeping for %s seconds.", self.poll_interval)
await asyncio.sleep(self.poll_interval)
except AirflowException as ex:
diff --git
a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
index ccd034a5008..c2da2daa00b 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
@@ -29,11 +29,12 @@ TEST_PROJECT_ID = "test_project_id"
TEST_OPERATION_NAME = "test_operation_name"
TEST_REGION = "region"
TEST_ENVIRONMENT_ID = "test_env_id"
+TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
TEST_JSON_RESULT = lambda state, date_key: json.dumps(
[
{
"dag_id": "test_dag_id",
- "run_id": "scheduled__2024-05-22T11:10:00+00:00",
+ "run_id": TEST_COMPOSER_DAG_RUN_ID,
"state": state,
date_key: "2024-05-22T11:10:00+00:00",
"start_date": "2024-05-22T11:20:01.531988+00:00",
@@ -110,3 +111,45 @@ class TestCloudComposerDAGRunSensor:
task._composer_airflow_version = composer_airflow_version
assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_composer_dag_run_id_wait_ready(self, mock_hook, to_dict_mode,
composer_airflow_version):
+ mock_hook.return_value.wait_command_execution_result.return_value =
TEST_EXEC_RESULT(
+ "success", "execution_date" if composer_airflow_version < 3 else
"logical_date"
+ )
+
+ task = CloudComposerDAGRunSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_dag_id="test_dag_id",
+ composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0,
0)})
+
+ @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+ def test_composer_dag_run_id_wait_not_ready(self, mock_hook, to_dict_mode,
composer_airflow_version):
+ mock_hook.return_value.wait_command_execution_result.return_value =
TEST_EXEC_RESULT(
+ "running", "execution_date" if composer_airflow_version < 3 else
"logical_date"
+ )
+
+ task = CloudComposerDAGRunSensor(
+ task_id="task-id",
+ project_id=TEST_PROJECT_ID,
+ region=TEST_REGION,
+ environment_id=TEST_ENVIRONMENT_ID,
+ composer_dag_id="test_dag_id",
+ composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
+ allowed_states=["success"],
+ )
+ task._composer_airflow_version = composer_airflow_version
+
+ assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0,
0, 0)})
diff --git
a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
index 8716805fa13..cde313785d6 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
@@ -39,6 +39,7 @@ TEST_EXEC_CMD_INFO = {
"error": "test_error",
}
TEST_COMPOSER_DAG_ID = "test_dag_id"
+TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
TEST_STATES = ["success"]
@@ -81,6 +82,7 @@ def dag_run_trigger(mock_conn):
region=TEST_LOCATION,
environment_id=TEST_ENVIRONMENT_ID,
composer_dag_id=TEST_COMPOSER_DAG_ID,
+ composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
start_date=TEST_START_DATE,
end_date=TEST_END_DATE,
allowed_states=TEST_STATES,
@@ -136,6 +138,7 @@ class TestCloudComposerDAGRunTrigger:
"region": TEST_LOCATION,
"environment_id": TEST_ENVIRONMENT_ID,
"composer_dag_id": TEST_COMPOSER_DAG_ID,
+ "composer_dag_run_id": TEST_COMPOSER_DAG_RUN_ID,
"start_date": TEST_START_DATE,
"end_date": TEST_END_DATE,
"allowed_states": TEST_STATES,