This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 074548d76d7 Split DAGNode in Core and SDK (#59708)
074548d76d7 is described below
commit 074548d76d7c2e2579dc0d603e4503eb7b58752f
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Dec 24 04:22:33 2025 +0800
Split DAGNode in Core and SDK (#59708)
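    This moves the relationship-traversal half of DAGNode (the upstream and
    downstream lists, flat-relative walks, setup/teardown helpers, and the
    cached logger) into a new shared distribution,
    apache-airflow-shared-dagnode, as a GenericDAGNode parameterized over
    the Dag, Task, and TaskGroup types. airflow-core and the Task SDK each
    bind their own concrete types, which removes the cast() calls and
    "type: ignore" comments tracked under GH-52141. As a condensed sketch
    of the new shape (the full code is in the diff below):

        Dag = TypeVar("Dag")
        Task = TypeVar("Task")
        TaskGroup = TypeVar("TaskGroup")

        class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
            dag: Dag | None
            task_group: TaskGroup | None

        # airflow-core binds the serialized (scheduler) types:
        class DAGNode(GenericDAGNode["SerializedDAG", "Operator", "SerializedTaskGroup"], metaclass=abc.ABCMeta): ...

        # the Task SDK binds the authoring types and keeps DependencyMixin:
        class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta): ...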
---
airflow-core/pyproject.toml | 6 +-
airflow-core/src/airflow/_shared/dagnode | 1 +
.../execution_api/routes/task_instances.py | 4 +-
airflow-core/src/airflow/models/mappedoperator.py | 11 +-
airflow-core/src/airflow/models/taskinstance.py | 7 +-
.../src/airflow/serialization/definitions/dag.py | 4 +-
.../src/airflow/serialization/definitions/node.py | 51 +++++
.../airflow/serialization/definitions/taskgroup.py | 5 +-
.../airflow/serialization/serialized_objects.py | 15 +-
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 7 +-
airflow-core/src/airflow/utils/dag_edges.py | 8 +-
airflow-core/src/airflow/utils/dot_renderer.py | 3 +-
pyproject.toml | 1 +
shared/dagnode/pyproject.toml | 48 +++++
.../dagnode/src/airflow_shared/dagnode/__init__.py | 16 ++
.../dagnode/src/airflow_shared/dagnode}/node.py | 212 ++++-----------------
shared/dagnode/tests/__init__.py | 16 ++
shared/dagnode/tests/dagnode/__init__.py | 16 ++
shared/dagnode/tests/dagnode/test_node.py | 83 ++++++++
task-sdk/pyproject.toml | 4 +-
task-sdk/src/airflow/sdk/_shared/dagnode | 1 +
.../airflow/sdk/definitions/_internal/mixins.py | 2 -
.../src/airflow/sdk/definitions/_internal/node.py | 187 +-----------------
task-sdk/src/airflow/sdk/definitions/dag.py | 3 +-
24 files changed, 302 insertions(+), 409 deletions(-)
diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml
index 3cf114af0bf..a0ff1a2466b 100644
--- a/airflow-core/pyproject.toml
+++ b/airflow-core/pyproject.toml
@@ -229,8 +229,9 @@ exclude = [
[tool.hatch.build.targets.sdist.force-include]
"../shared/configuration/src/airflow_shared/configuration" =
"src/airflow/_shared/configuration"
-"../shared/module_loading/src/airflow_shared/module_loading" =
"src/airflow/_shared/module_loading"
+"../shared/dagnode/src/airflow_shared/dagnode" = "src/airflow/_shared/dagnode"
"../shared/logging/src/airflow_shared/logging" = "src/airflow/_shared/logging"
+"../shared/module_loading/src/airflow_shared/module_loading" =
"src/airflow/_shared/module_loading"
"../shared/observability/src/airflow_shared/observability" =
"src/airflow/_shared/observability"
"../shared/secrets_backend/src/airflow_shared/secrets_backend" =
"src/airflow/_shared/secrets_backend"
"../shared/secrets_masker/src/airflow_shared/secrets_masker" =
"src/airflow/_shared/secrets_masker"
@@ -303,10 +304,11 @@ apache-airflow-devel-common = { workspace = true }
[tool.airflow]
shared_distributions = [
"apache-airflow-shared-configuration",
+ "apache-airflow-shared-dagnode",
"apache-airflow-shared-logging",
"apache-airflow-shared-module-loading",
+ "apache-airflow-shared-observability",
"apache-airflow-shared-secrets-backend",
"apache-airflow-shared-secrets-masker",
"apache-airflow-shared-timezones",
- "apache-airflow-shared-observability",
]
diff --git a/airflow-core/src/airflow/_shared/dagnode b/airflow-core/src/airflow/_shared/dagnode
new file mode 120000
index 00000000000..ad88febb9c0
--- /dev/null
+++ b/airflow-core/src/airflow/_shared/dagnode
@@ -0,0 +1 @@
+../../../../shared/dagnode/src/airflow_shared/dagnode
\ No newline at end of file
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 61b205864e6..366cd456de2 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -78,8 +78,6 @@ from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
if TYPE_CHECKING:
from sqlalchemy.sql.dml import Update
- from airflow.models.expandinput import SchedulerExpandInput
-
router = VersionedAPIRouter()
ti_id_router = VersionedAPIRouter(
@@ -315,7 +313,7 @@ def _get_upstream_map_indexes(
except NotFullyPopulated:
# Second try: resolve XCom for correct count
try:
- expand_input = cast("SchedulerExpandInput", upstream_mapped_group._expand_input)
+ expand_input = upstream_mapped_group._expand_input
mapped_ti_count = expand_input.get_total_map_length(ti.run_id, session=session)
except NotFullyPopulated:
# For these trigger rules, unresolved map indexes are acceptable.
diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py
index c227bbfc54b..758149875f0 100644
--- a/airflow-core/src/airflow/models/mappedoperator.py
+++ b/airflow-core/src/airflow/models/mappedoperator.py
@@ -31,8 +31,8 @@ from sqlalchemy.orm import Session
from airflow.exceptions import AirflowException, NotMapped
from airflow.sdk import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_RETRY_DELAY_MULTIPLIER
-from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
+from airflow.serialization.definitions.node import DAGNode
from airflow.serialization.definitions.param import SerializedParamsDict
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.enums import DagAttributeTypes
@@ -48,6 +48,7 @@ if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk import BaseOperatorLink, Context
+ from airflow.sdk.definitions._internal.node import DAGNode as TaskSDKDAGNode
from airflow.sdk.definitions.operator_resources import Resources
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.task.trigger_rule import TriggerRule
@@ -83,7 +84,6 @@ def is_mapped(obj: Operator | SerializedTaskGroup) -> TypeGuard[MappedOperator |
getstate_setstate=False,
repr=False,
)
-# TODO (GH-52141): Duplicate DAGNode in the scheduler.
class MappedOperator(DAGNode):
"""Object representing a mapped operator in a DAG."""
@@ -110,11 +110,6 @@ class MappedOperator(DAGNode):
start_from_trigger: bool = False
_needs_expansion: bool = True
- # TODO (GH-52141): These should contain serialized containers, but currently
- # this class inherits from an SDK one.
- dag: SerializedDAG = attrs.field(init=False) # type: ignore[assignment]
- task_group: SerializedTaskGroup = attrs.field(init=False) # type: ignore[assignment]
-
doc: str | None = attrs.field(init=False)
doc_json: str | None = attrs.field(init=False)
doc_rst: str | None = attrs.field(init=False)
@@ -503,7 +498,7 @@ class MappedOperator(DAGNode):
@functools.singledispatch
-def get_mapped_ti_count(task: DAGNode, run_id: str, *, session: Session) -> int:
+def get_mapped_ti_count(task: DAGNode | TaskSDKDAGNode, run_id: str, *, session: Session) -> int:
raise NotImplementedError(f"Not implemented for {type(task)}")
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 0814f438688..ac854ac162c 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -28,7 +28,7 @@ from collections import defaultdict
from collections.abc import Collection, Iterable
from datetime import datetime, timedelta
from functools import cache
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any
from urllib.parse import quote
import attrs
@@ -2332,10 +2332,7 @@ def find_relevant_relatives(
# Treat it as a normal task instead.
_visit_relevant_relatives_for_normal([task_id])
continue
- # TODO (GH-52141): This should return scheduler operator types, but
- # currently get_flat_relatives is inherited from SDK DAGNode.
- relatives = cast("Iterable[Operator]",
task.get_flat_relatives(upstream=direction == "upstream"))
- for relative in relatives:
+ for relative in task.get_flat_relatives(upstream=direction ==
"upstream"):
if relative.task_id in visited:
continue
relative_map_indexes = _get_relevant_map_indexes(
diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py b/airflow-core/src/airflow/serialization/definitions/dag.py
index f6556dcfd1c..238d2d748dc 100644
--- a/airflow-core/src/airflow/serialization/definitions/dag.py
+++ b/airflow-core/src/airflow/serialization/definitions/dag.py
@@ -284,9 +284,7 @@ class SerializedDAG:
direct_upstreams: list[SerializedOperator] = []
if include_direct_upstream:
for t in itertools.chain(matched_tasks, also_include):
- # TODO (GH-52141): This should return scheduler types, but currently we reuse SDK DAGNode.
- upstream = (u for u in cast("Iterable[SerializedOperator]", t.upstream_list) if is_task(u))
- direct_upstreams.extend(upstream)
+ direct_upstreams.extend(u for u in t.upstream_list if is_task(u))
# Make sure to not recursively deepcopy the dag or task_group while copying the task.
# task_group is reset later
diff --git a/airflow-core/src/airflow/serialization/definitions/node.py b/airflow-core/src/airflow/serialization/definitions/node.py
new file mode 100644
index 00000000000..b17e46234ab
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/definitions/node.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING
+
+from airflow._shared.dagnode.node import GenericDAGNode
+
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+ from typing import TypeAlias
+
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.serialization.definitions.taskgroup import SerializedTaskGroup # noqa: F401
+ from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG # noqa: F401
+
+ Operator: TypeAlias = SerializedBaseOperator | MappedOperator
+
+
+class DAGNode(GenericDAGNode["SerializedDAG", "Operator",
"SerializedTaskGroup"], metaclass=abc.ABCMeta):
+ """
+ Base class for a node in the graph of a workflow.
+
+ A node may be an operator or task group, either mapped or unmapped.
+ """
+
+ @property
+ @abc.abstractmethod
+ def roots(self) -> Sequence[DAGNode]:
+ raise NotImplementedError()
+
+ @property
+ @abc.abstractmethod
+ def leaves(self) -> Sequence[DAGNode]:
+ raise NotImplementedError()
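The core DAGNode re-declares roots and leaves as abstract because the shared
base no longer defines them; the concrete subclasses in this commit
(SerializedBaseOperator and SerializedTaskGroup) supply them. As a rough
illustration only, with a hypothetical subclass name, an operator-like node
typically resolves both to itself:

    class ExampleOperatorNode(DAGNode):  # hypothetical, for illustration
        @property
        def roots(self) -> Sequence[DAGNode]:
            # an operator is its own root ...
            return [self]

        @property
        def leaves(self) -> Sequence[DAGNode]:
            # ... and its own leaf
            return [self]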
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index 3dcb62aa30f..c127353bcfe 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING
import attrs
import methodtools
-from airflow.sdk.definitions._internal.node import DAGNode
+from airflow.serialization.definitions.node import DAGNode
if TYPE_CHECKING:
from collections.abc import Generator, Iterator
@@ -45,8 +45,7 @@ class SerializedTaskGroup(DAGNode):
group_display_name: str | None = attrs.field()
prefix_group_id: bool = attrs.field()
parent_group: SerializedTaskGroup | None = attrs.field()
- # TODO (GH-52141): Replace DAGNode dependency.
- dag: SerializedDAG = attrs.field() # type: ignore[assignment]
+ dag: SerializedDAG = attrs.field()
tooltip: str = attrs.field()
default_args: dict[str, Any] = attrs.field(factory=dict)
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index ce136f998ce..2f6d4b362a8 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -53,7 +53,6 @@ from airflow.models.xcom import XComModel
from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, XComArg
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler?
-from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.asset import (
AssetAliasEvent,
AssetAliasUniqueKey,
@@ -76,6 +75,7 @@ from airflow.serialization.definitions.assets import (
SerializedAssetUniqueKey,
)
from airflow.serialization.definitions.dag import SerializedDAG
+from airflow.serialization.definitions.node import DAGNode
from airflow.serialization.definitions.param import SerializedParam, SerializedParamsDict
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.encoders import (
@@ -118,6 +118,7 @@ if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004
from airflow.sdk import BaseOperatorLink
+ from airflow.sdk.definitions._internal.node import DAGNode as SDKDAGNode
from airflow.serialization.json_schema import Validator
from airflow.task.trigger_rule import TriggerRule
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -1022,7 +1023,6 @@ class DependencyDetector:
yield from tt.asset_condition.iter_dag_dependencies(source="",
target=dag.dag_id)
-# TODO (GH-52141): Duplicate DAGNode in the scheduler.
class SerializedBaseOperator(DAGNode, BaseSerialization):
"""
A JSON serializable representation of operator.
@@ -1052,10 +1052,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
_task_display_name: str | None
_weight_rule: str | PriorityWeightStrategy = "downstream"
- # TODO (GH-52141): These should contain serialized containers, but currently
- # this class inherits from an SDK one.
- dag: SerializedDAG | None = None # type: ignore[assignment]
- task_group: SerializedTaskGroup | None = None # type: ignore[assignment]
+ dag: SerializedDAG | None = None
+ task_group: SerializedTaskGroup | None = None
allow_nested_operators: bool = True
depends_on_past: bool = False
@@ -1159,8 +1157,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
def node_id(self) -> str:
return self.task_id
- # TODO (GH-52141): Replace DAGNode with a scheduler type.
- def get_dag(self) -> SerializedDAG | None: # type: ignore[override]
+ def get_dag(self) -> SerializedDAG | None:
return self.dag
@property
@@ -1680,7 +1677,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
return False
@classmethod
- def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
+ def _is_excluded(cls, var: Any, attrname: str, op: SDKDAGNode) -> bool:
"""
Determine if a variable is excluded from the serialized object.
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 3c0fc3be5c0..2effd1fef6e 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -620,12 +620,7 @@ class TriggerRuleDep(BaseTIDep):
if not task.is_teardown:
# a teardown cannot have any indirect setups
- relevant_setups: dict[str, MappedOperator | SerializedBaseOperator] = {
- # TODO (GH-52141): This should return scheduler types, but
- # currently we reuse logic in SDK DAGNode.
- t.task_id: t # type: ignore[misc]
- for t in task.get_upstreams_only_setups()
- }
+ relevant_setups = {t.task_id: t for t in task.get_upstreams_only_setups()}
if relevant_setups:
for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups):
yield status
diff --git a/airflow-core/src/airflow/utils/dag_edges.py b/airflow-core/src/airflow/utils/dag_edges.py
index 94c6069f91b..1f3c0fbd254 100644
--- a/airflow-core/src/airflow/utils/dag_edges.py
+++ b/airflow-core/src/airflow/utils/dag_edges.py
@@ -23,8 +23,6 @@ from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
if TYPE_CHECKING:
- from collections.abc import Iterable
-
from airflow.sdk import DAG
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
@@ -118,11 +116,7 @@ def dag_edges(dag: DAG | SerializedDAG):
while tasks_to_trace:
tasks_to_trace_next: list[Operator] = []
for task in tasks_to_trace:
- # TODO (GH-52141): downstream_list on DAGNode needs to be able to
- # return scheduler types when used in scheduler, but SDK types when
- # used at runtime. This means DAGNode needs to be rewritten as a
- # generic class.
- for child in cast("Iterable[Operator]", task.downstream_list):
+ for child in task.downstream_list:
edge = (task.task_id, child.task_id)
if task.is_setup and child.is_teardown:
setup_teardown_edges.add(edge)
diff --git a/airflow-core/src/airflow/utils/dot_renderer.py b/airflow-core/src/airflow/utils/dot_renderer.py
index 586789f1722..d0802972980 100644
--- a/airflow-core/src/airflow/utils/dot_renderer.py
+++ b/airflow-core/src/airflow/utils/dot_renderer.py
@@ -36,7 +36,6 @@ if TYPE_CHECKING:
import graphviz
from airflow.models import TaskInstance
- from airflow.models.taskmixin import DependencyMixin
from airflow.serialization.dag_dependency import DagDependency
else:
try:
@@ -136,7 +135,7 @@ def _draw_task_group(
def _draw_nodes(
- node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str | None] | None
+ node: object, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str | None] | None
) -> None:
"""Draw the node and its children on the given parent_graph recursively."""
if isinstance(node, (BaseOperator, MappedOperator, SerializedBaseOperator, SerializedMappedOperator)):
diff --git a/pyproject.toml b/pyproject.toml
index ba650152d60..423326157a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1346,6 +1346,7 @@ apache-airflow-kubernetes-tests = { workspace = true }
apache-airflow-providers = { workspace = true }
apache-airflow-docker-stack = { workspace = true }
apache-airflow-shared-configuration = { workspace = true }
+apache-airflow-shared-dagnode = { workspace = true }
apache-airflow-shared-logging = { workspace = true }
apache-airflow-shared-module-loading = { workspace = true }
apache-airflow-shared-secrets-backend = { workspace = true }
diff --git a/shared/dagnode/pyproject.toml b/shared/dagnode/pyproject.toml
new file mode 100644
index 00000000000..d75d1cf3c54
--- /dev/null
+++ b/shared/dagnode/pyproject.toml
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[project]
+name = "apache-airflow-shared-dagnode"
+description = "Shared DAGNode logic for Airflow distributions"
+version = "0.0"
+classifiers = [
+ "Private :: Do Not Upload",
+]
+
+dependencies = [
+ "structlog>=25.4.0",
+]
+
+[dependency-groups]
+dev = [
+ "apache-airflow-devel-common",
+]
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/airflow_shared"]
+
+[tool.ruff]
+extend = "../../pyproject.toml"
+src = ["src"]
+
+[tool.ruff.lint.per-file-ignores]
+# Ignore Doc rules et al for anything outside of src (i.e. tests)
+"!src/*" = ["D", "S101", "TRY002"]
diff --git a/shared/dagnode/src/airflow_shared/dagnode/__init__.py b/shared/dagnode/src/airflow_shared/dagnode/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/shared/dagnode/src/airflow_shared/dagnode/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/shared/dagnode/src/airflow_shared/dagnode/node.py
similarity index 50%
copy from task-sdk/src/airflow/sdk/definitions/_internal/node.py
copy to shared/dagnode/src/airflow_shared/dagnode/node.py
index 86979e442cd..2f4504818e2 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py
+++ b/shared/dagnode/src/airflow_shared/dagnode/node.py
@@ -17,66 +17,29 @@
from __future__ import annotations
-import re
-from abc import ABCMeta, abstractmethod
-from collections.abc import Collection, Iterable, Sequence
-from datetime import datetime
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Generic, TypeVar
import structlog
-from airflow.sdk.definitions._internal.mixins import DependencyMixin
-
if TYPE_CHECKING:
- from airflow.sdk.definitions.dag import DAG
- from airflow.sdk.definitions.edges import EdgeModifier
- from airflow.sdk.definitions.taskgroup import TaskGroup
- from airflow.sdk.types import Logger, Operator
- from airflow.serialization.enums import DagAttributeTypes
-
-
-KEY_REGEX = re.compile(r"^[\w.-]+$")
-GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
-CAMELCASE_TO_SNAKE_CASE_REGEX = re.compile(r"(?!^)([A-Z]+)")
-
-
-def validate_key(k: str, max_length: int = 250):
- """Validate value used as a key."""
- if not isinstance(k, str):
- raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
- if (length := len(k)) > max_length:
- raise ValueError(f"The key has to be less than {max_length}
characters, not {length}")
- if not KEY_REGEX.match(k):
- raise ValueError(
- f"The key {k!r} has to be made of alphanumeric characters, dashes,
"
- f"dots, and underscores exclusively"
- )
+ from collections.abc import Collection, Iterable
+ from ..logging.types import Logger
-def validate_group_key(k: str, max_length: int = 200):
- """Validate value used as a group key."""
- if not isinstance(k, str):
- raise TypeError(f"The key has to be a string and is {type(k)}:{k}")
- if (length := len(k)) > max_length:
- raise ValueError(f"The key has to be less than {max_length}
characters, not {length}")
- if not GROUP_KEY_REGEX.match(k):
- raise ValueError(
- f"The key {k!r} has to be made of alphanumeric characters, dashes,
and underscores exclusively"
- )
+Dag = TypeVar("Dag")
+Task = TypeVar("Task")
+TaskGroup = TypeVar("TaskGroup")
-class DAGNode(DependencyMixin, metaclass=ABCMeta):
+class GenericDAGNode(Generic[Dag, Task, TaskGroup]):
"""
- A base class for a node in the graph of a workflow.
+ Generic class for a node in the graph of a workflow.
- A node may be an Operator or a Task Group, either mapped or unmapped.
+ A node may be an operator or task group, either mapped or unmapped.
"""
- dag: DAG | None
+ dag: Dag | None
task_group: TaskGroup | None
- """The task_group that contains this node"""
- start_date: datetime | None
- end_date: datetime | None
upstream_task_ids: set[str]
downstream_task_ids: set[str]
@@ -85,46 +48,12 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
_cached_logger: Logger | None = None
def __init__(self):
+ super().__init__()
self.upstream_task_ids = set()
self.downstream_task_ids = set()
- super().__init__()
-
- def get_dag(self) -> DAG | None:
- return self.dag
-
- @property
- @abstractmethod
- def node_id(self) -> str:
- raise NotImplementedError()
-
- @property
- def label(self) -> str | None:
- tg = self.task_group
- if tg and tg.node_id and tg.prefix_group_id:
- # "task_group_id.task_id" -> "task_id"
- return self.node_id[len(tg.node_id) + 1 :]
- return self.node_id
-
- def has_dag(self) -> bool:
- return self.dag is not None
-
- @property
- def dag_id(self) -> str:
- """Returns dag id if it has one or an adhoc/meaningless ID."""
- if self.dag:
- return self.dag.dag_id
- return "_in_memory_dag_"
@property
def log(self) -> Logger:
- """
- Get a logger for this node.
-
- The logger name is determined by:
- 1. Using _logger_name if provided
- 2. Otherwise, using the class's module and qualified name
- 3. Prefixing with _log_config_logger_name if set
- """
if self._cached_logger is not None:
return self._cached_logger
@@ -145,101 +74,40 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
return self._cached_logger
@property
- @abstractmethod
- def roots(self) -> Sequence[DAGNode]:
- raise NotImplementedError()
+ def dag_id(self) -> str:
+ if self.dag:
+ return self.dag.dag_id
+ return "_in_memory_dag_"
@property
- @abstractmethod
- def leaves(self) -> Sequence[DAGNode]:
+ def node_id(self) -> str:
raise NotImplementedError()
- def _set_relatives(
- self,
- task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
- upstream: bool = False,
- edge_modifier: EdgeModifier | None = None,
- ) -> None:
- """Set relatives for the task or task list."""
- from airflow.sdk.bases.operator import BaseOperator
- from airflow.sdk.definitions.mappedoperator import MappedOperator
-
- if not isinstance(task_or_task_list, Sequence):
- task_or_task_list = [task_or_task_list]
-
- task_list: list[BaseOperator | MappedOperator] = []
- for task_object in task_or_task_list:
- task_object.update_relative(self, not upstream, edge_modifier=edge_modifier)
- relatives = task_object.leaves if upstream else task_object.roots
- for task in relatives:
- if not isinstance(task, (BaseOperator, MappedOperator)):
- raise TypeError(
- f"Relationships can only be set between Operators;
received {task.__class__.__name__}"
- )
- task_list.append(task)
-
- # relationships can only be set if the tasks share a single Dag. Tasks
- # without a Dag are assigned to that Dag.
- dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}
-
- if len(dags) > 1:
- raise RuntimeError(f"Tried to set relationships between tasks in
more than one Dag: {dags}")
- if len(dags) == 1:
- dag = dags.pop()
- else:
- raise ValueError(
- "Tried to create relationships between tasks that don't have
Dags yet. "
- f"Set the Dag for at least one task and try again: {[self,
*task_list]}"
- )
-
- if not self.has_dag():
- # If this task does not yet have a Dag, add it to the same Dag as the other task.
- self.dag = dag
-
- for task in task_list:
- if dag and not task.has_dag():
- # If the other task does not yet have a Dag, add it to the same Dag as this task and
- dag.add_task(task) # type: ignore[arg-type]
- if upstream:
- task.downstream_task_ids.add(self.node_id)
- self.upstream_task_ids.add(task.node_id)
- if edge_modifier:
- edge_modifier.add_edge_info(dag, task.node_id, self.node_id)
- else:
- self.downstream_task_ids.add(task.node_id)
- task.upstream_task_ids.add(self.node_id)
- if edge_modifier:
- edge_modifier.add_edge_info(dag, self.node_id, task.node_id)
-
- def set_downstream(
- self,
- task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
- edge_modifier: EdgeModifier | None = None,
- ) -> None:
- """Set a node (or nodes) to be directly downstream from the current
node."""
- self._set_relatives(task_or_task_list, upstream=False,
edge_modifier=edge_modifier)
-
- def set_upstream(
- self,
- task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
- edge_modifier: EdgeModifier | None = None,
- ) -> None:
- """Set a node (or nodes) to be directly upstream from the current
node."""
- self._set_relatives(task_or_task_list, upstream=True,
edge_modifier=edge_modifier)
+ @property
+ def label(self) -> str | None:
+ tg = self.task_group
+ if tg and tg.node_id and tg.prefix_group_id:
+ # "task_group_id.task_id" -> "task_id"
+ return self.node_id[len(tg.node_id) + 1 :]
+ return self.node_id
@property
- def downstream_list(self) -> Iterable[Operator]:
- """List of nodes directly downstream."""
+ def upstream_list(self) -> Iterable[Task]:
if not self.dag:
raise RuntimeError(f"Operator {self} has not been assigned to a
Dag yet")
- return [self.dag.get_task(tid) for tid in self.downstream_task_ids]
+ return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
@property
- def upstream_list(self) -> Iterable[Operator]:
- """List of nodes directly upstream."""
+ def downstream_list(self) -> Iterable[Task]:
if not self.dag:
raise RuntimeError(f"Operator {self} has not been assigned to a
Dag yet")
- return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
+ return [self.dag.get_task(tid) for tid in self.downstream_task_ids]
+
+ def has_dag(self) -> bool:
+ return self.dag is not None
+
+ def get_dag(self) -> Dag | None:
+ return self.dag
def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
"""Get set of the direct relative ids to the current task, upstream or
downstream."""
@@ -247,7 +115,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
return self.upstream_task_ids
return self.downstream_task_ids
- def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
+ def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]:
"""Get list of the direct relatives to the current task, upstream or
downstream."""
if upstream:
return self.upstream_list
@@ -283,14 +151,14 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
return relatives
- def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
+ def get_flat_relatives(self, upstream: bool = False) -> Collection[Task]:
"""Get a flat list of relatives, either upstream or downstream."""
dag = self.get_dag()
if not dag:
return set()
return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]
- def get_upstreams_follow_setups(self) -> Iterable[Operator]:
+ def get_upstreams_follow_setups(self) -> Iterable[Task]:
"""All upstreams and, for each upstream setup, its respective
teardowns."""
for task in self.get_flat_relatives(upstream=True):
yield task
@@ -299,7 +167,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
if t.is_teardown and t != self:
yield t
- def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
+ def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]:
"""
Only *relevant* upstream setups and their teardowns.
@@ -323,7 +191,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
if t.is_teardown and t != self:
yield t
- def get_upstreams_only_setups(self) -> Iterable[Operator]:
+ def get_upstreams_only_setups(self) -> Iterable[Task]:
"""
Return relevant upstream setups.
@@ -333,7 +201,3 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
for task in self.get_upstreams_only_setups_and_teardowns():
if task.is_setup:
yield task
-
- def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
- """Serialize a task group's content; used by TaskGroupSerialization."""
- raise NotImplementedError()
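The three type parameters above are unconstrained TypeVars, so the contract
between GenericDAGNode and its Dag, Task, and TaskGroup types is purely
structural. Written as Protocols for illustration only (no such protocols
exist in the code), it amounts to roughly this:

    from typing import Protocol

    class DagLike(Protocol):  # what GenericDAGNode reads off self.dag
        dag_id: str
        task_dict: dict[str, "TaskLike"]
        def get_task(self, task_id: str) -> "TaskLike": ...

    class TaskLike(Protocol):  # what the setup/teardown helpers read off tasks
        task_id: str
        is_setup: bool
        is_teardown: bool
        downstream_task_ids: set[str]

    class TaskGroupLike(Protocol):  # what the label property reads
        node_id: str | None
        prefix_group_id: bool

In practice the Task instances are DAGNode subclasses themselves, so they
also provide downstream_list and get_direct_relative_ids for the recursive
walks.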
diff --git a/shared/dagnode/tests/__init__.py b/shared/dagnode/tests/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/shared/dagnode/tests/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/shared/dagnode/tests/dagnode/__init__.py b/shared/dagnode/tests/dagnode/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/shared/dagnode/tests/dagnode/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/shared/dagnode/tests/dagnode/test_node.py b/shared/dagnode/tests/dagnode/test_node.py
new file mode 100644
index 00000000000..4259ca7555f
--- /dev/null
+++ b/shared/dagnode/tests/dagnode/test_node.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import attrs
+import pytest
+
+from airflow_shared.dagnode.node import GenericDAGNode
+
+
+class Task:
+ """Task type for tests."""
+
+
+@attrs.define
+class TaskGroup:
+ """Task group type for tests."""
+
+ node_id: str = attrs.field(init=False, default="test_group_id")
+ prefix_group_id: bool
+
+
+class Dag:
+ """Dag type for tests."""
+
+ dag_id = "test_dag_id"
+
+
+class ConcreteDAGNode(GenericDAGNode[Dag, Task, TaskGroup]):
+ """Concrete DAGNode variant for tests."""
+
+ dag = None
+ task_group = None
+
+ @property
+ def node_id(self) -> str:
+ return "test_group_id.test_node_id"
+
+
+class TestDAGNode:
+ @pytest.fixture
+ def node(self):
+ return ConcreteDAGNode()
+
+ def test_log(self, node: ConcreteDAGNode) -> None:
+ assert node._cached_logger is None
+ with mock.patch("structlog.get_logger") as mock_get_logger:
+ log = node.log
+ assert log is node._cached_logger
+ assert mock_get_logger.mock_calls == [mock.call("tests.dagnode.test_node.ConcreteDAGNode")]
+
+ def test_dag_id(self, node: ConcreteDAGNode) -> None:
+ assert node.dag is None
+ assert node.dag_id == "_in_memory_dag_"
+ node.dag = Dag()
+ assert node.dag_id == "test_dag_id"
+
+ @pytest.mark.parametrize(
+ ("prefix_group_id", "expected_label"),
+ [(True, "test_node_id"), (False, "test_group_id.test_node_id")],
+ )
+ def test_label(self, node: ConcreteDAGNode, prefix_group_id: bool, expected_label: str) -> None:
+ assert node.task_group is None
+ assert node.label == "test_group_id.test_node_id"
+ node.task_group = TaskGroup(prefix_group_id)
+ assert node.label == expected_label
diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml
index 98d91657634..7c26f1babb5 100644
--- a/task-sdk/pyproject.toml
+++ b/task-sdk/pyproject.toml
@@ -116,8 +116,9 @@ path = "src/airflow/sdk/__init__.py"
[tool.hatch.build.targets.sdist.force-include]
"../shared/configuration/src/airflow_shared/configuration" =
"src/airflow/sdk/_shared/configuration"
-"../shared/module_loading/src/airflow_shared/module_loading" =
"src/airflow/sdk/_shared/module_loading"
+"../shared/dagnode/src/airflow_shared/dagnode" =
"src/airflow/sdk/_shared/dagnode"
"../shared/logging/src/airflow_shared/logging" =
"src/airflow/sdk/_shared/logging"
+"../shared/module_loading/src/airflow_shared/module_loading" =
"src/airflow/sdk/_shared/module_loading"
"../shared/observability/src/airflow_shared/observability" =
"src/airflow/_shared/observability"
"../shared/secrets_backend/src/airflow_shared/secrets_backend" =
"src/airflow/sdk/_shared/secrets_backend"
"../shared/secrets_masker/src/airflow_shared/secrets_masker" =
"src/airflow/sdk/_shared/secrets_masker"
@@ -264,6 +265,7 @@ tmp_path_retention_policy = "failed"
[tool.airflow]
shared_distributions = [
"apache-airflow-shared-configuration",
+ "apache-airflow-shared-dagnode",
"apache-airflow-shared-logging",
"apache-airflow-shared-module-loading",
"apache-airflow-shared-secrets-backend",
diff --git a/task-sdk/src/airflow/sdk/_shared/dagnode b/task-sdk/src/airflow/sdk/_shared/dagnode
new file mode 120000
index 00000000000..9455ba69b08
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/_shared/dagnode
@@ -0,0 +1 @@
+../../../../../shared/dagnode/src/airflow_shared/dagnode
\ No newline at end of file
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py
index d14d6299159..e186ef97e64 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py
@@ -31,8 +31,6 @@ if TYPE_CHECKING:
Operator: TypeAlias = BaseOperator | MappedOperator
-# TODO: Should this all just live on DAGNode?
-
class DependencyMixin:
"""Mixing implementing common dependency setting methods like >> and <<."""
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
index 86979e442cd..b2cb651efe1 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -19,19 +19,18 @@ from __future__ import annotations
import re
from abc import ABCMeta, abstractmethod
-from collections.abc import Collection, Iterable, Sequence
+from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any
-import structlog
-
+from airflow.sdk._shared.dagnode.node import GenericDAGNode
from airflow.sdk.definitions._internal.mixins import DependencyMixin
if TYPE_CHECKING:
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.edges import EdgeModifier
- from airflow.sdk.definitions.taskgroup import TaskGroup
- from airflow.sdk.types import Logger, Operator
+ from airflow.sdk.definitions.taskgroup import TaskGroup # noqa: F401
+ from airflow.sdk.types import Operator # noqa: F401
from airflow.serialization.enums import DagAttributeTypes
@@ -65,84 +64,15 @@ def validate_group_key(k: str, max_length: int = 200):
)
-class DAGNode(DependencyMixin, metaclass=ABCMeta):
+class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin,
metaclass=ABCMeta):
"""
A base class for a node in the graph of a workflow.
A node may be an Operator or a Task Group, either mapped or unmapped.
"""
- dag: DAG | None
- task_group: TaskGroup | None
- """The task_group that contains this node"""
start_date: datetime | None
end_date: datetime | None
- upstream_task_ids: set[str]
- downstream_task_ids: set[str]
-
- _log_config_logger_name: str | None = None
- _logger_name: str | None = None
- _cached_logger: Logger | None = None
-
- def __init__(self):
- self.upstream_task_ids = set()
- self.downstream_task_ids = set()
- super().__init__()
-
- def get_dag(self) -> DAG | None:
- return self.dag
-
- @property
- @abstractmethod
- def node_id(self) -> str:
- raise NotImplementedError()
-
- @property
- def label(self) -> str | None:
- tg = self.task_group
- if tg and tg.node_id and tg.prefix_group_id:
- # "task_group_id.task_id" -> "task_id"
- return self.node_id[len(tg.node_id) + 1 :]
- return self.node_id
-
- def has_dag(self) -> bool:
- return self.dag is not None
-
- @property
- def dag_id(self) -> str:
- """Returns dag id if it has one or an adhoc/meaningless ID."""
- if self.dag:
- return self.dag.dag_id
- return "_in_memory_dag_"
-
- @property
- def log(self) -> Logger:
- """
- Get a logger for this node.
-
- The logger name is determined by:
- 1. Using _logger_name if provided
- 2. Otherwise, using the class's module and qualified name
- 3. Prefixing with _log_config_logger_name if set
- """
- if self._cached_logger is not None:
- return self._cached_logger
-
- typ = type(self)
-
- logger_name: str = (
- self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}"
- )
-
- if self._log_config_logger_name:
- logger_name = (
- f"{self._log_config_logger_name}.{logger_name}"
- if logger_name
- else self._log_config_logger_name
- )
-
- self._cached_logger = structlog.get_logger(logger_name)
- return self._cached_logger
@property
@abstractmethod
@@ -227,113 +157,6 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
"""Set a node (or nodes) to be directly upstream from the current
node."""
self._set_relatives(task_or_task_list, upstream=True,
edge_modifier=edge_modifier)
- @property
- def downstream_list(self) -> Iterable[Operator]:
- """List of nodes directly downstream."""
- if not self.dag:
- raise RuntimeError(f"Operator {self} has not been assigned to a
Dag yet")
- return [self.dag.get_task(tid) for tid in self.downstream_task_ids]
-
- @property
- def upstream_list(self) -> Iterable[Operator]:
- """List of nodes directly upstream."""
- if not self.dag:
- raise RuntimeError(f"Operator {self} has not been assigned to a
Dag yet")
- return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
-
- def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
- """Get set of the direct relative ids to the current task, upstream or
downstream."""
- if upstream:
- return self.upstream_task_ids
- return self.downstream_task_ids
-
- def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]:
- """Get list of the direct relatives to the current task, upstream or downstream."""
- if upstream:
- return self.upstream_list
- return self.downstream_list
-
- def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]:
- """
- Get a flat set of relative IDs, upstream or downstream.
-
- Will recurse each relative found in the direction specified.
-
- :param upstream: Whether to look for upstream or downstream relatives.
- """
- dag = self.get_dag()
- if not dag:
- return set()
-
- relatives: set[str] = set()
-
- # This is intentionally implemented as a loop, instead of calling
- # get_direct_relative_ids() recursively, since Python has significant
- # limitation on stack level, and a recursive implementation can blow up
- # if a DAG contains very long routes.
- task_ids_to_trace = self.get_direct_relative_ids(upstream)
- while task_ids_to_trace:
- task_ids_to_trace_next: set[str] = set()
- for task_id in task_ids_to_trace:
- if task_id in relatives:
- continue
- task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
- relatives.add(task_id)
- task_ids_to_trace = task_ids_to_trace_next
-
- return relatives
-
- def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
- """Get a flat list of relatives, either upstream or downstream."""
- dag = self.get_dag()
- if not dag:
- return set()
- return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)]
-
- def get_upstreams_follow_setups(self) -> Iterable[Operator]:
- """All upstreams and, for each upstream setup, its respective
teardowns."""
- for task in self.get_flat_relatives(upstream=True):
- yield task
- if task.is_setup:
- for t in task.downstream_list:
- if t.is_teardown and t != self:
- yield t
-
- def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]:
- """
- Only *relevant* upstream setups and their teardowns.
-
- This method is meant to be used when we are clearing the task (non-upstream) and we need
- to add in the *relevant* setups and their teardowns.
-
- Relevant in this case means, the setup has a teardown that is downstream of ``self``,
- or the setup has no teardowns.
- """
- downstream_teardown_ids = {
- x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown
- }
- for task in self.get_flat_relatives(upstream=True):
- if not task.is_setup:
- continue
- has_no_teardowns = not any(x.is_teardown for x in task.downstream_list)
- # if task has no teardowns or has teardowns downstream of self
- if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids):
- yield task
- for t in task.downstream_list:
- if t.is_teardown and t != self:
- yield t
-
- def get_upstreams_only_setups(self) -> Iterable[Operator]:
- """
- Return relevant upstream setups.
-
- This method is meant to be used when we are checking task dependencies where we need
- to wait for all the upstream setups to complete before we can run the task.
- """
- for task in self.get_upstreams_only_setups_and_teardowns():
- if task.is_setup:
- yield task
-
def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Serialize a task group's content; used by TaskGroupSerialization."""
raise NotImplementedError()
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 887493e1ca7..479958b99d7 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -915,8 +915,7 @@ class DAG:
direct_upstreams: list[Operator] = []
if include_direct_upstream:
for t in itertools.chain(matched_tasks, also_include):
- upstream = (u for u in t.upstream_list if is_task(u))
- direct_upstreams.extend(upstream)
+ direct_upstreams.extend(u for u in t.upstream_list if is_task(u))
# Make sure to not recursively deepcopy the dag or task_group while copying the task.
# task_group is reset later