This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0a136f96cc3 Update Dependency Detector to handle dependencies set via partial (#42578)
0a136f96cc3 is described below
commit 0a136f96cc3652c860e159a06364e3107bdaaa6d
Author: Fred Thomsen <[email protected]>
AuthorDate: Wed Nov 27 22:19:30 2024 -0500
Update Dependency Detector to handle dependencies set via partial (#42578)
The DAG dependency view shows dependencies for DAGs that are available
at DAG parsing/serialization time. Dynamically mapped tasks that
trigger (via `TriggerDagRunOperator`) or wait for (via
`ExternalTaskSensor`) external DAGs are now listed as DAG dependencies
provided that those dependencies are set via `partial`, and not
dynamically expanded.
---
airflow/serialization/serialized_objects.py | 26 ++++++++++++++++++
tests/serialization/test_dag_serialization.py | 38 +++++++++++++++++++--------
2 files changed, 53 insertions(+), 11 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 2e1a5cd14f9..ced1bdd6837 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1034,6 +1034,19 @@ class DependencyDetector:
dependency_id=task.task_id,
)
)
+ elif (
+ isinstance(task, MappedOperator)
+ and issubclass(cast(type[BaseOperator], task.operator_class), TriggerDagRunOperator)
+ and "trigger_dag_id" in task.partial_kwargs
+ ):
+ deps.append(
+ DagDependency(
+ source=task.dag_id,
+ target=task.partial_kwargs["trigger_dag_id"],
+ dependency_type="trigger",
+ dependency_id=task.task_id,
+ )
+ )
elif isinstance(task, ExternalTaskSensor):
deps.append(
DagDependency(
@@ -1043,6 +1056,19 @@ class DependencyDetector:
dependency_id=task.task_id,
)
)
+ elif (
+ isinstance(task, MappedOperator)
+ and issubclass(cast(type[BaseOperator], task.operator_class), ExternalTaskSensor)
+ and "external_dag_id" in task.partial_kwargs
+ ):
+ deps.append(
+ DagDependency(
+ source=task.partial_kwargs["external_dag_id"],
+ target=task.dag_id,
+ dependency_type="sensor",
+ dependency_id=task.task_id,
+ )
+ )
for obj in task.outlets or []:
if isinstance(obj, Asset):
deps.append(
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 5dbe6c968ab..42d07b39220 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1621,7 +1621,8 @@ class TestStringifiedDAGs:
assert round_tripped.outlets == []
@pytest.mark.db_test
- def test_derived_dag_deps_sensor(self):
+ @pytest.mark.parametrize("mapped", [False, True])
+ def test_derived_dag_deps_sensor(self, mapped):
"""
Tests DAG dependency detection for sensors, including derived classes
"""
@@ -1634,11 +1635,19 @@ class TestStringifiedDAGs:
logical_date = datetime(2020, 1, 1)
for class_ in [ExternalTaskSensor, DerivedSensor]:
with DAG(dag_id="test_derived_dag_deps_sensor", schedule=None, start_date=logical_date) as dag:
- task1 = class_(
- task_id="task1",
- external_dag_id="external_dag_id",
- mode="reschedule",
- )
+ if mapped:
+ task1 = class_.partial(
+ task_id="task1",
+ external_dag_id="external_dag_id",
+ mode="reschedule",
+ ).expand(external_task_id=["some_task", "some_other_task"])
+ else:
+ task1 = class_(
+ task_id="task1",
+ external_dag_id="external_dag_id",
+ mode="reschedule",
+ )
+
task2 = EmptyOperator(task_id="task2")
task1 >> task2
@@ -1806,7 +1815,8 @@ class TestStringifiedDAGs:
)
assert actual == expected
- def test_derived_dag_deps_operator(self):
+ @pytest.mark.parametrize("mapped", [False, True])
+ def test_derived_dag_deps_operator(self, mapped):
"""
Tests DAG dependency detection for operators, including derived classes
"""
@@ -1820,10 +1830,16 @@ class TestStringifiedDAGs:
for class_ in [TriggerDagRunOperator, DerivedOperator]:
with DAG(dag_id="test_derived_dag_deps_trigger", schedule=None, start_date=logical_date) as dag:
task1 = EmptyOperator(task_id="task1")
- task2 = class_(
- task_id="task2",
- trigger_dag_id="trigger_dag_id",
- )
+ if mapped:
+ task2 = class_.partial(
+ task_id="task2",
+ trigger_dag_id="trigger_dag_id",
+ ).expand(trigger_run_id=["one", "two"])
+ else:
+ task2 = class_(
+ task_id="task2",
+ trigger_dag_id="trigger_dag_id",
+ )
task1 >> task2
dag = SerializedDAG.to_dict(dag)