This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 00f33431915 Add composer_dag_run_id as optional parameter to CloudComposerDAGRunSensor (#54977)
00f33431915 is described below

commit 00f3343191554e5382260550d06fdef556a67f74
Author: Maksim <[email protected]>
AuthorDate: Thu Sep 4 00:27:41 2025 +0200

    Add composer_dag_run_id as optional parameter to CloudComposerDAGRunSensor (#54977)
---
 .../google/cloud/sensors/cloud_composer.py         | 25 ++++++++++++
 .../google/cloud/triggers/cloud_composer.py        | 35 +++++++++++++----
 .../google/cloud/sensors/test_cloud_composer.py    | 45 +++++++++++++++++++++-
 .../google/cloud/triggers/test_cloud_composer.py   |  3 ++
 4 files changed, 99 insertions(+), 9 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py 
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
index 85f8323225d..d69cce09899 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py
@@ -61,6 +61,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         Or [datetime(2024,3,22,0,0,0)] in this case sensor will check for 
states from specific time in the
         past till current time execution.
         Default value datetime.timedelta(days=1).
+    :param composer_dag_run_id: The Run ID of executable task. The 
'execution_range' param is ignored, if both specified.
     :param gcp_conn_id: The connection ID to use when fetching connection info.
     :param impersonation_chain: Optional service account to impersonate using 
short-term
         credentials, or chained list of accounts required to get the 
access_token
@@ -91,6 +92,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         composer_dag_id: str,
         allowed_states: Iterable[str] | None = None,
         execution_range: timedelta | list[datetime] | None = None,
+        composer_dag_run_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
@@ -104,11 +106,17 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         self.composer_dag_id = composer_dag_id
         self.allowed_states = list(allowed_states) if allowed_states else 
[TaskInstanceState.SUCCESS.value]
         self.execution_range = execution_range
+        self.composer_dag_run_id = composer_dag_run_id
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.deferrable = deferrable
         self.poll_interval = poll_interval
 
+        if self.composer_dag_run_id and self.execution_range:
+            self.log.warning(
+                "The composer_dag_run_id parameter and execution_range 
parameter do not work together. This run will ignore execution_range parameter 
and count only specified composer_dag_run_id parameter."
+            )
+
     def _get_logical_dates(self, context) -> tuple[datetime, datetime]:
         if isinstance(self.execution_range, timedelta):
             if self.execution_range < timedelta(0):
@@ -132,6 +140,16 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
             self.log.info("Dag runs are empty. Sensor waits for dag runs...")
             return False
 
+        if self.composer_dag_run_id:
+            self.log.info(
+                "Sensor waits for allowed states %s for specified RunID: %s",
+                self.allowed_states,
+                self.composer_dag_run_id,
+            )
+            composer_dag_run_id_status = 
self._check_composer_dag_run_id_states(
+                dag_runs=dag_runs,
+            )
+            return composer_dag_run_id_status
         self.log.info("Sensor waits for allowed states: %s", 
self.allowed_states)
         allowed_states_status = self._check_dag_runs_states(
             dag_runs=dag_runs,
@@ -193,6 +211,12 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
         image_version = 
environment_config["config"]["software_config"]["image_version"]
         return int(image_version.split("airflow-")[1].split(".")[0])
 
+    def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
+        for dag_run in dag_runs:
+            if dag_run["run_id"] == self.composer_dag_run_id and 
dag_run["state"] in self.allowed_states:
+                return True
+        return False
+
     def execute(self, context: Context) -> None:
         self._composer_airflow_version = self._get_composer_airflow_version()
         if self.deferrable:
@@ -204,6 +228,7 @@ class CloudComposerDAGRunSensor(BaseSensorOperator):
                     region=self.region,
                     environment_id=self.environment_id,
                     composer_dag_id=self.composer_dag_id,
+                    composer_dag_run_id=self.composer_dag_run_id,
                     start_date=start_date,
                     end_date=end_date,
                     allowed_states=self.allowed_states,
diff --git 
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
 
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
index 4382b114bd4..a2840dcffe2 100644
--- 
a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
+++ 
b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py
@@ -179,6 +179,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         start_date: datetime,
         end_date: datetime,
         allowed_states: list[str],
+        composer_dag_run_id: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         poll_interval: int = 10,
@@ -192,6 +193,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
         self.start_date = start_date
         self.end_date = end_date
         self.allowed_states = allowed_states
+        self.composer_dag_run_id = composer_dag_run_id
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.poll_interval = poll_interval
@@ -213,6 +215,7 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 "start_date": self.start_date,
                 "end_date": self.end_date,
                 "allowed_states": self.allowed_states,
+                "composer_dag_run_id": self.composer_dag_run_id,
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "poll_interval": self.poll_interval,
@@ -261,6 +264,12 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                 return False
         return True
 
+    def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
+        for dag_run in dag_runs:
+            if dag_run["run_id"] == self.composer_dag_run_id and 
dag_run["state"] in self.allowed_states:
+                return True
+        return False
+
     async def run(self):
         try:
             while True:
@@ -273,14 +282,24 @@ class CloudComposerDAGRunTrigger(BaseTrigger):
                         await asyncio.sleep(self.poll_interval)
                         continue
 
-                    self.log.info("Sensor waits for allowed states: %s", 
self.allowed_states)
-                    if self._check_dag_runs_states(
-                        dag_runs=dag_runs,
-                        start_date=self.start_date,
-                        end_date=self.end_date,
-                    ):
-                        yield TriggerEvent({"status": "success"})
-                        return
+                    if self.composer_dag_run_id:
+                        self.log.info(
+                            "Sensor waits for allowed states %s for specified 
RunID: %s",
+                            self.allowed_states,
+                            self.composer_dag_run_id,
+                        )
+                        if 
self._check_composer_dag_run_id_states(dag_runs=dag_runs):
+                            yield TriggerEvent({"status": "success"})
+                            return
+                    else:
+                        self.log.info("Sensor waits for allowed states: %s", 
self.allowed_states)
+                        if self._check_dag_runs_states(
+                            dag_runs=dag_runs,
+                            start_date=self.start_date,
+                            end_date=self.end_date,
+                        ):
+                            yield TriggerEvent({"status": "success"})
+                            return
                 self.log.info("Sleeping for %s seconds.", self.poll_interval)
                 await asyncio.sleep(self.poll_interval)
         except AirflowException as ex:
diff --git 
a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
index ccd034a5008..c2da2daa00b 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_cloud_composer.py
@@ -29,11 +29,12 @@ TEST_PROJECT_ID = "test_project_id"
 TEST_OPERATION_NAME = "test_operation_name"
 TEST_REGION = "region"
 TEST_ENVIRONMENT_ID = "test_env_id"
+TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
 TEST_JSON_RESULT = lambda state, date_key: json.dumps(
     [
         {
             "dag_id": "test_dag_id",
-            "run_id": "scheduled__2024-05-22T11:10:00+00:00",
+            "run_id": TEST_COMPOSER_DAG_RUN_ID,
             "state": state,
             date_key: "2024-05-22T11:10:00+00:00",
             "start_date": "2024-05-22T11:20:01.531988+00:00",
@@ -110,3 +111,45 @@ class TestCloudComposerDAGRunSensor:
         task._composer_airflow_version = composer_airflow_version
 
         assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 
0, 0)})
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+    
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_composer_dag_run_id_wait_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version):
+        mock_hook.return_value.wait_command_execution_result.return_value = 
TEST_EXEC_RESULT(
+            "success", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
+        )
+
+        task = CloudComposerDAGRunSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            composer_dag_id="test_dag_id",
+            composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
+            allowed_states=["success"],
+        )
+        task._composer_airflow_version = composer_airflow_version
+
+        assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 
0)})
+
+    @pytest.mark.parametrize("composer_airflow_version", [2, 3])
+    
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
+    
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
+    def test_composer_dag_run_id_wait_not_ready(self, mock_hook, to_dict_mode, 
composer_airflow_version):
+        mock_hook.return_value.wait_command_execution_result.return_value = 
TEST_EXEC_RESULT(
+            "running", "execution_date" if composer_airflow_version < 3 else 
"logical_date"
+        )
+
+        task = CloudComposerDAGRunSensor(
+            task_id="task-id",
+            project_id=TEST_PROJECT_ID,
+            region=TEST_REGION,
+            environment_id=TEST_ENVIRONMENT_ID,
+            composer_dag_id="test_dag_id",
+            composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
+            allowed_states=["success"],
+        )
+        task._composer_airflow_version = composer_airflow_version
+
+        assert not task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 
0, 0)})
diff --git 
a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py 
b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
index 8716805fa13..cde313785d6 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_composer.py
@@ -39,6 +39,7 @@ TEST_EXEC_CMD_INFO = {
     "error": "test_error",
 }
 TEST_COMPOSER_DAG_ID = "test_dag_id"
+TEST_COMPOSER_DAG_RUN_ID = "scheduled__2024-05-22T11:10:00+00:00"
 TEST_START_DATE = datetime(2024, 3, 22, 11, 0, 0)
 TEST_END_DATE = datetime(2024, 3, 22, 12, 0, 0)
 TEST_STATES = ["success"]
@@ -81,6 +82,7 @@ def dag_run_trigger(mock_conn):
         region=TEST_LOCATION,
         environment_id=TEST_ENVIRONMENT_ID,
         composer_dag_id=TEST_COMPOSER_DAG_ID,
+        composer_dag_run_id=TEST_COMPOSER_DAG_RUN_ID,
         start_date=TEST_START_DATE,
         end_date=TEST_END_DATE,
         allowed_states=TEST_STATES,
@@ -136,6 +138,7 @@ class TestCloudComposerDAGRunTrigger:
                 "region": TEST_LOCATION,
                 "environment_id": TEST_ENVIRONMENT_ID,
                 "composer_dag_id": TEST_COMPOSER_DAG_ID,
+                "composer_dag_run_id": TEST_COMPOSER_DAG_RUN_ID,
                 "start_date": TEST_START_DATE,
                 "end_date": TEST_END_DATE,
                 "allowed_states": TEST_STATES,

Reply via email to