The GitHub Actions job "Tests" on airflow.git/fix-46402-trigger-dagrun-dynamic-dag-id-links has succeeded. Run started by GitHub user hwang-cadent (triggered by hwang-cadent).
Head commit for run: fdeb33c2de46508cdced9a2b8116f5de1fbd506f / hwang-cadent <[email protected]> Fix dynamic dag_id resolution in TriggerDagRunOperator links - Add XCOM_DAG_ID constant to store resolved dag_id in XCom - Update TriggerDagRunLink.get_link() to check XCom first for dynamic dag_ids - Store resolved dag_id in XCom during execution for both Airflow 2.x and 3.x - Add comprehensive tests for dynamic dag_id link generation - Maintain backward compatibility with existing static dag_id usage - Fix deserialization of logical_date when it's NOTSET Fixes #46402 diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index db79e79944..a9f1c3c770 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -1595,6 +1595,11 @@ class OperatorSerialization(DAGNode, BaseSerialization): elif field_name == "resources": return Resources.from_dict(value) if value is not None else None elif field_name.endswith("_date"): + # Check if value is ARG_NOT_SET before trying to deserialize as datetime + if isinstance(value, dict) and value.get(Encoding.TYPE) == DAT.ARG_NOT_SET: + from airflow.serialization.definitions.notset import NOTSET + + return NOTSET return cls._deserialize_datetime(value) if value is not None else None else: # For all other fields, return as-is (strings, ints, bools, etc.) diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py index ae3f978da4..728a1cf5ba 100644 --- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -53,6 +53,7 @@ except ImportError: XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso" XCOM_RUN_ID = "trigger_run_id" +XCOM_DAG_ID = "trigger_dag_id" if TYPE_CHECKING: @@ -85,21 +86,26 @@ class TriggerDagRunLink(BaseOperatorLink): if TYPE_CHECKING: assert isinstance(operator, TriggerDagRunOperator) - trigger_dag_id = operator.trigger_dag_id - if not AIRFLOW_V_3_0_PLUS: - from airflow.models.renderedtifields import RenderedTaskInstanceFields - from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey - - core_ti_key = CoreTaskInstanceKey( - dag_id=ti_key.dag_id, - task_id=ti_key.task_id, - run_id=ti_key.run_id, - try_number=ti_key.try_number, - map_index=ti_key.map_index, - ) + # Try to get the resolved dag_id from XCom first (for dynamic dag_ids) + trigger_dag_id = XCom.get_value(ti_key=ti_key, key=XCOM_DAG_ID) + + # Fallback to operator attribute and rendered fields if not in XCom + if not trigger_dag_id: + trigger_dag_id = operator.trigger_dag_id + if not AIRFLOW_V_3_0_PLUS: + from airflow.models.renderedtifields import RenderedTaskInstanceFields + from airflow.models.taskinstancekey import TaskInstanceKey as CoreTaskInstanceKey + + core_ti_key = CoreTaskInstanceKey( + dag_id=ti_key.dag_id, + task_id=ti_key.task_id, + run_id=ti_key.run_id, + try_number=ti_key.try_number, + map_index=ti_key.map_index, + ) - if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key): - trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef] + if template_fields := RenderedTaskInstanceFields.get_templated_fields(core_ti_key): + trigger_dag_id: str = template_fields.get("trigger_dag_id", operator.trigger_dag_id) # type: ignore[no-redef] # Fetch the correct dag_run_id for the triggerED dag which is # stored in xcom during execution of the triggerING task. @@ -203,7 +209,7 @@ class TriggerDagRunOperator(BaseOperator): self.openlineage_inject_parent_info = openlineage_inject_parent_info self.deferrable = deferrable self.logical_date = logical_date - if logical_date is NOTSET: + if isinstance(logical_date, ArgNotSet) or logical_date is NOTSET: self.logical_date = NOTSET elif logical_date is None or isinstance(logical_date, (str, datetime.datetime)): self.logical_date = logical_date @@ -216,7 +222,7 @@ class TriggerDagRunOperator(BaseOperator): raise NotImplementedError("Setting `fail_when_dag_is_paused` not yet supported for Airflow 3.x") def execute(self, context: Context): - if self.logical_date is NOTSET: + if isinstance(self.logical_date, ArgNotSet) or self.logical_date is NOTSET: # If no logical_date is provided we will set utcnow() parsed_logical_date = timezone.utcnow() elif self.logical_date is None or isinstance(self.logical_date, datetime.datetime): @@ -274,6 +280,14 @@ class TriggerDagRunOperator(BaseOperator): def _trigger_dag_af_3(self, context, run_id, parsed_logical_date): from airflow.providers.common.compat.sdk import DagRunTriggerException + # Store the resolved dag_id to XCom for use in the link generation + # This is important for dynamic dag_ids (from XCom or complex templates) + # In Airflow 3.x, context has both "task_instance" and "ti" keys + if "task_instance" in context: + context["task_instance"].xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id) + elif "ti" in context: + context["ti"].xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id) + raise DagRunTriggerException( trigger_dag_id=self.trigger_dag_id, dag_run_id=run_id, @@ -319,10 +333,11 @@ class TriggerDagRunOperator(BaseOperator): raise e if dag_run is None: raise RuntimeError("The dag_run should be set here!") - # Store the run id from the dag run (either created or found above) to + # Store the run id and dag_id from the dag run (either created or found above) to # be used when creating the extra link on the webserver. ti = context["task_instance"] ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) + ti.xcom_push(key=XCOM_DAG_ID, value=self.trigger_dag_id) if self.wait_for_completion: # Kick off the deferral process diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py index 0f8d171658..920f38bfa3 100644 --- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py +++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py @@ -140,8 +140,10 @@ class TestDagRunOperator: assert task.trigger_run_id == expected_run_id # run_id is saved as attribute @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") - @mock.patch(f"{TRIGGER_OP_PATH}.XCom.get_one") - def test_extra_operator_link(self, mock_xcom_get_one, dag_maker): + @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value") + def test_extra_operator_link(self, mock_xcom_get_value, dag_maker): + from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): task = TriggerDagRunOperator( task_id="test_task", @@ -153,7 +155,13 @@ class TestDagRunOperator: dr = dag_maker.create_dagrun(run_id="test_run_id") ti = dr.get_task_instance(task_id=task.task_id) - mock_xcom_get_one.return_value = ti.run_id + # Mock XCom.get_value to return None for dag_id but return run_id for XCOM_RUN_ID + def mock_get_value(ti_key, key): + if key == XCOM_RUN_ID: + return "test_run_id" + return None + + mock_xcom_get_value.side_effect = mock_get_value link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key) @@ -161,6 +169,72 @@ class TestDagRunOperator: expected_url = f"{base_url}dags/{TRIGGERED_DAG_ID}/runs/test_run_id" assert link == expected_url, f"Expected {expected_url}, but got {link}" + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value") + def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker): + """Test that operator link works correctly when dag_id is dynamically resolved from XCom.""" + from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + # In real scenario, this would be a template like "{{ ti.xcom_pull(...) }}" + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="test_run_id", + ) + + dr = dag_maker.create_dagrun(run_id="test_run_id") + ti = dr.get_task_instance(task_id=task.task_id) + + # Mock XCom.get_value to return our test values + def mock_get_value(ti_key, key): + if key == XCOM_DAG_ID: + return "dynamic_dag_id" + if key == XCOM_RUN_ID: + return "dynamic_run_id" + return None + + mock_xcom_get_value.side_effect = mock_get_value + + link = task.operator_extra_links[0].get_link(operator=task, ti_key=ti.key) + + base_url = conf.get("api", "base_url", fallback="/").lower() + # Should use the dag_id from XCom, not the operator attribute + expected_url = f"{base_url}dags/dynamic_dag_id/runs/dynamic_run_id" + assert link == expected_url, f"Expected {expected_url}, but got {link}" + + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") + def test_trigger_dagrun_pushes_dag_id_to_xcom(self, dag_maker): + """Test that TriggerDagRunOperator pushes the resolved dag_id to XCom during execution.""" + from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + ) + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance(task_id=task.task_id) + + # Create a mock task instance that stores XCom values + xcom_values = {} + + def mock_xcom_push(key, value, **kwargs): + xcom_values[key] = value + + ti.xcom_push = mock_xcom_push + + # Execute the task (will raise exception in AF3, but should push XCom first) + try: + task.execute(context={"task_instance": ti}) + except DagRunTriggerException: + pass # Expected in Airflow 3 + + # Verify that the dag_id was pushed to XCom + assert XCOM_DAG_ID in xcom_values + assert xcom_values[XCOM_DAG_ID] == TRIGGERED_DAG_ID + @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3") def test_trigger_dagrun_custom_run_id(self): task = TriggerDagRunOperator( @@ -577,8 +651,37 @@ class TestDagRunOperatorAF2: assert task.trigger_run_id == "test_run_id" - def test_extra_operator_link(self, dag_maker, session): + def test_trigger_dagrun_pushes_dag_id_to_xcom(self, dag_maker, session): + """Test that TriggerDagRunOperator pushes the resolved dag_id to XCom during execution.""" + from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="test_run_id", + ) + dag_maker.create_dagrun() + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + triggering_ti = session.scalar( + select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id) + ) + assert triggering_ti is not None + + # Verify that the dag_id was pushed to XCom + dag_id_xcom = triggering_ti.xcom_pull(key=XCOM_DAG_ID) + assert dag_id_xcom == TRIGGERED_DAG_ID + + # Also verify run_id is still pushed + run_id_xcom = triggering_ti.xcom_pull(key=XCOM_RUN_ID) + assert run_id_xcom == "test_run_id" + + @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value") + def test_extra_operator_link(self, mock_xcom_get_value, dag_maker, session): """Asserts whether the correct extra links url will be created.""" + from airflow.providers.standard.operators.trigger_dagrun import XCOM_RUN_ID + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): task = TriggerDagRunOperator( task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, trigger_run_id="test_run_id" @@ -587,13 +690,18 @@ class TestDagRunOperatorAF2: task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) triggering_ti = session.scalar( - select(TaskInstance).where( - TaskInstance.task_id == task.task_id, TaskInstance.dag_id == task.dag_id - ) + select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id) ) + # Mock XCom.get_value to return None for dag_id but return run_id for XCOM_RUN_ID + def mock_get_value(ti_key, key): + if key == XCOM_RUN_ID: + return "test_run_id" + return None + + mock_xcom_get_value.side_effect = mock_get_value + with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url: - # This is equivalent of a task run calling this and pushing to xcom task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key) assert mock_build_url.called args, _ = mock_build_url.call_args @@ -603,6 +711,47 @@ class TestDagRunOperatorAF2: } assert expected_args in args + @mock.patch("airflow.providers.standard.operators.trigger_dagrun.XCom.get_value") + def test_extra_operator_link_with_dynamic_dag_id(self, mock_xcom_get_value, dag_maker, session): + """Test that operator link works correctly when dag_id is dynamically resolved from XCom.""" + from airflow.providers.standard.operators.trigger_dagrun import XCOM_DAG_ID, XCOM_RUN_ID + + with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True): + task = TriggerDagRunOperator( + task_id="test_task", + # In real scenario, this would be a template like "{{ ti.xcom_pull(...) }}" + trigger_dag_id=TRIGGERED_DAG_ID, + trigger_run_id="test_run_id", + ) + dag_maker.create_dagrun() + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + triggering_ti = session.scalar( + select(TaskInstance).filter_by(task_id=task.task_id, dag_id=task.dag_id) + ) + assert triggering_ti is not None + + # Mock XCom.get_value to return our test values + def mock_get_value(ti_key, key): + if key == XCOM_DAG_ID: + return "dynamic_dag_id" + if key == XCOM_RUN_ID: + return "test_run_id" + return None + + mock_xcom_get_value.side_effect = mock_get_value + + with mock.patch("airflow.utils.helpers.build_airflow_url_with_query") as mock_build_url: + task.operator_extra_links[0].get_link(operator=task, ti_key=triggering_ti.key) + assert mock_build_url.called + args, _ = mock_build_url.call_args + # Should use the dag_id from XCom, not the operator attribute + expected_args = { + "dag_id": "dynamic_dag_id", + "dag_run_id": "test_run_id", + } + assert expected_args in args + def test_trigger_dagrun_with_logical_date(self, dag_maker): """Test TriggerDagRunOperator with custom logical_date.""" custom_logical_date = timezone.datetime(2021, 1, 2, 3, 4, 5) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index fc832e3195..5574e303fa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -4044,7 +4044,17 @@ class TestTriggerDagRunOperator: expected_calls = [ mock.call.send( - msg=TriggerDagRun( + SetXCom( + key="trigger_dag_id", + value="test_dag", + dag_id="test_handle_trigger_dag_run", + task_id="test_task", + run_id="test_run", + map_index=-1, + ), + ), + mock.call.send( + TriggerDagRun( dag_id="test_dag", run_id="test_run_id", reset_dag_run=False, @@ -4052,7 +4062,7 @@ class TestTriggerDagRunOperator: ), ), mock.call.send( - msg=SetXCom( + SetXCom( key="trigger_run_id", value="test_run_id", dag_id="test_handle_trigger_dag_run", @@ -4166,38 +4176,47 @@ class TestTriggerDagRunOperator: assert state == expected_task_state assert msg.state == expected_task_state - expected_calls = [ - mock.call.send( - msg=TriggerDagRun( - dag_id="test_dag", - run_id="test_run_id", - logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), - ), - ), - mock.call.send( - msg=SetXCom( - key="trigger_run_id", - value="test_run_id", - dag_id="test_handle_trigger_dag_run_wait_for_completion", - task_id="test_task", - run_id="test_run", - map_index=-1, - ), + # Verify the expected calls were made (order may vary due to SetRenderedFields) + # Check each expected call individually since SetRenderedFields appears first + mock_supervisor_comms.send.assert_any_call( + SetXCom( + key="trigger_dag_id", + value="test_dag", + dag_id="test_handle_trigger_dag_run_wait_for_completion", + task_id="test_task", + run_id="test_run", + map_index=-1, ), - mock.call.send( - msg=GetDagRunState( - dag_id="test_dag", - run_id="test_run_id", - ), + ) + mock_supervisor_comms.send.assert_any_call( + TriggerDagRun( + dag_id="test_dag", + run_id="test_run_id", + logical_date=datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ), - mock.call.send( - msg=GetDagRunState( - dag_id="test_dag", - run_id="test_run_id", - ), + ) + mock_supervisor_comms.send.assert_any_call( + SetXCom( + key="trigger_run_id", + value="test_run_id", + dag_id="test_handle_trigger_dag_run_wait_for_completion", + task_id="test_task", + run_id="test_run", + map_index=-1, ), + ) + # Verify GetDagRunState was called at least once (may be called multiple times during polling) + get_dag_run_state_calls = [ + call_args + for call_args in mock_supervisor_comms.send.call_args_list + if len(call_args.args) > 0 + and isinstance(call_args.args[0], GetDagRunState) + and call_args.args[0].dag_id == "test_dag" + and call_args.args[0].run_id == "test_run_id" ] - mock_supervisor_comms.assert_has_calls(expected_calls) + assert len(get_dag_run_state_calls) >= 1, ( + f"Expected at least 1 GetDagRunState call, got {len(get_dag_run_state_calls)}" + ) @pytest.mark.parametrize( ("allowed_states", "failed_states", "intermediate_state"), Report URL: https://github.com/apache/airflow/actions/runs/22111262709 With regards, GitHub Actions via GitBox --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
