This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b58f5d30dcd Deprecate and move `airflow.utils.task_group` to SDK
(#53450)
b58f5d30dcd is described below
commit b58f5d30dcd2328cfe7a6f9562882c820ee7b923
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Sat Jul 19 13:23:27 2025 +0100
Deprecate and move `airflow.utils.task_group` to SDK (#53450)
* Deprecate and move `airflow.utils.task_group` to SDK
Part of the module had already been moved to the SDK; this commit
completes the move.
* fixup! Deprecate and move `airflow.utils.task_group` to SDK
* fixup! fixup! Deprecate and move `airflow.utils.task_group` to SDK
---
.../airflow/api_fastapi/core_api/routes/ui/grid.py | 2 +-
.../api_fastapi/core_api/routes/ui/structure.py | 2 +-
.../api_fastapi/core_api/services/ui/grid.py | 3 +-
.../airflow/example_dags/example_setup_teardown.py | 2 +-
.../src/airflow/example_dags/example_task_group.py | 2 +-
airflow-core/src/airflow/models/taskinstance.py | 3 +-
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 2 +-
airflow-core/src/airflow/utils/__init__.py | 6 +
airflow-core/src/airflow/utils/dot_renderer.py | 2 +-
airflow-core/src/airflow/utils/task_group.py | 128 ---------------------
.../api_fastapi/core_api/routes/ui/test_grid.py | 2 +-
airflow-core/tests/unit/models/test_dagrun.py | 2 +-
.../tests/unit/models/test_mappedoperator.py | 2 +-
.../tests/unit/models/test_taskinstance.py | 2 +-
.../unit/serialization/test_dag_serialization.py | 2 +-
.../unit/serialization/test_serialized_objects.py | 2 +-
airflow-core/tests/unit/utils/test_dag_cycle.py | 2 +-
airflow-core/tests/unit/utils/test_dot_renderer.py | 2 +-
airflow-core/tests/unit/utils/test_edgemodifier.py | 2 +-
airflow-core/tests/unit/utils/test_task_group.py | 2 +-
.../tests/unit/standard/decorators/test_python.py | 3 +-
.../standard/operators/test_branch_operator.py | 7 +-
.../standard/sensors/test_external_task_sensor.py | 8 +-
.../unit/standard/utils/test_sensor_helper.py | 8 +-
task-sdk/src/airflow/sdk/definitions/taskgroup.py | 85 +++++++++++---
25 files changed, 112 insertions(+), 171 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
index 5c7494d3b70..f4c6d7d558c 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py
@@ -52,7 +52,7 @@ from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance
-from airflow.utils.task_group import (
+from airflow.sdk.definitions.taskgroup import (
get_task_group_children_getter,
task_group_to_dict_grid,
)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
index fff7325f41a..c308ae21432 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/structure.py
@@ -32,8 +32,8 @@ from airflow.api_fastapi.core_api.services.ui.structure
import (
)
from airflow.models.dag_version import DagVersion
from airflow.models.serialized_dag import SerializedDagModel
+from airflow.sdk.definitions.taskgroup import task_group_to_dict
from airflow.utils.dag_edges import dag_edges
-from airflow.utils.task_group import task_group_to_dict
structure_router = AirflowRouter(tags=["Structure"], prefix="/structure")
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
index 3bf22517da2..b8e569f877e 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
@@ -25,9 +25,8 @@ import structlog
from airflow.api_fastapi.common.parameters import state_priority
from airflow.models.taskmap import TaskMap
from airflow.sdk.definitions.mappedoperator import MappedOperator
-from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
+from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup,
get_task_group_children_getter
from airflow.serialization.serialized_objects import SerializedBaseOperator
-from airflow.utils.task_group import get_task_group_children_getter
log = structlog.get_logger(logger_name=__name__)
diff --git a/airflow-core/src/airflow/example_dags/example_setup_teardown.py
b/airflow-core/src/airflow/example_dags/example_setup_teardown.py
index a36e79a55e5..052377736ea 100644
--- a/airflow-core/src/airflow/example_dags/example_setup_teardown.py
+++ b/airflow-core/src/airflow/example_dags/example_setup_teardown.py
@@ -23,7 +23,7 @@ import pendulum
from airflow.providers.standard.operators.bash import BashOperator
from airflow.sdk import DAG
-from airflow.utils.task_group import TaskGroup
+from airflow.sdk.definitions.taskgroup import TaskGroup
with DAG(
dag_id="example_setup_teardown",
diff --git a/airflow-core/src/airflow/example_dags/example_task_group.py
b/airflow-core/src/airflow/example_dags/example_task_group.py
index e83ac2e9989..c882c269c47 100644
--- a/airflow-core/src/airflow/example_dags/example_task_group.py
+++ b/airflow-core/src/airflow/example_dags/example_task_group.py
@@ -24,7 +24,7 @@ import pendulum
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import DAG
-from airflow.utils.task_group import TaskGroup
+from airflow.sdk.definitions.taskgroup import TaskGroup
# [START howto_task_group]
with DAG(
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 40ad420d430..ae55f6a65ad 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -124,11 +124,10 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.asset import AssetNameRef, AssetUniqueKey,
AssetUriRef
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.mappedoperator import MappedOperator
- from airflow.sdk.definitions.taskgroup import MappedTaskGroup
+ from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
from airflow.sdk.types import RuntimeTaskInstanceProtocol
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.context import Context
- from airflow.utils.task_group import TaskGroup
Operator: TypeAlias = BaseOperator | MappedOperator
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 9a4f30c8dff..9a0096f569d 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -26,9 +26,9 @@ from typing import TYPE_CHECKING, NamedTuple
from sqlalchemy import and_, func, or_, select
from airflow.models.taskinstance import PAST_DEPENDS_MET
+from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.state import TaskInstanceState
-from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule as TR
if TYPE_CHECKING:
diff --git a/airflow-core/src/airflow/utils/__init__.py
b/airflow-core/src/airflow/utils/__init__.py
index 153438c8c16..563d133eb2b 100644
--- a/airflow-core/src/airflow/utils/__init__.py
+++ b/airflow-core/src/airflow/utils/__init__.py
@@ -28,6 +28,12 @@ __deprecated_classes = {
"xcom": {
"XCOM_RETURN_KEY": "airflow.models.xcom.XCOM_RETURN_KEY",
},
+    # Shim for the removed ``airflow.utils.task_group`` module: every public name
+    # it exported now resolves to its new home in the Task SDK.
+    "task_group": {
+        "TaskGroup": "airflow.sdk.definitions.taskgroup.TaskGroup",
+        "MappedTaskGroup": "airflow.sdk.definitions.taskgroup.MappedTaskGroup",
+        "get_task_group_children_getter": "airflow.sdk.definitions.taskgroup.get_task_group_children_getter",
+        "task_group_to_dict": "airflow.sdk.definitions.taskgroup.task_group_to_dict",
+        "task_group_to_dict_grid": "airflow.sdk.definitions.taskgroup.task_group_to_dict_grid",
+    },
}
add_deprecated_classes(__deprecated_classes, __name__)
diff --git a/airflow-core/src/airflow/utils/dot_renderer.py
b/airflow-core/src/airflow/utils/dot_renderer.py
index 5b624d30534..50911572d39 100644
--- a/airflow-core/src/airflow/utils/dot_renderer.py
+++ b/airflow-core/src/airflow/utils/dot_renderer.py
@@ -26,10 +26,10 @@ from typing import TYPE_CHECKING, Any
from airflow.exceptions import AirflowException
from airflow.sdk import BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.dag_edges import dag_edges
from airflow.utils.state import State
-from airflow.utils.task_group import TaskGroup
if TYPE_CHECKING:
import graphviz
diff --git a/airflow-core/src/airflow/utils/task_group.py
b/airflow-core/src/airflow/utils/task_group.py
deleted file mode 100644
index 675f9f19faf..00000000000
--- a/airflow-core/src/airflow/utils/task_group.py
+++ /dev/null
@@ -1,128 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""A collection of closely related tasks on the same DAG that should be
grouped together visually."""
-
-from __future__ import annotations
-
-from collections.abc import Callable
-from functools import cache
-from operator import methodcaller
-from typing import TYPE_CHECKING
-
-import airflow.sdk.definitions.taskgroup
-from airflow.configuration import conf
-
-if TYPE_CHECKING:
- from typing import TypeAlias
-
-TaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.TaskGroup
-MappedTaskGroup: TypeAlias = airflow.sdk.definitions.taskgroup.MappedTaskGroup
-
-
-@cache
-def get_task_group_children_getter() -> Callable:
- """Get the Task Group Children Getter for the DAG."""
- sort_order = conf.get("api", "grid_view_sorting_order")
- if sort_order == "topological":
- return methodcaller("topological_sort")
- return methodcaller("hierarchical_alphabetical_sort")
-
-
-def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
- """Create a nested dict representation of this TaskGroup and its children
used to construct the Graph."""
- from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
- from airflow.sdk.definitions.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
-
- if isinstance(task := task_item_or_group, (AbstractOperator,
SerializedBaseOperator)):
- node_operator = {
- "id": task.task_id,
- "label": task.label,
- "operator": task.operator_name,
- "type": "task",
- }
- if task.is_setup:
- node_operator["setup_teardown_type"] = "setup"
- elif task.is_teardown:
- node_operator["setup_teardown_type"] = "teardown"
- if isinstance(task, MappedOperator) or parent_group_is_mapped:
- node_operator["is_mapped"] = True
- return node_operator
-
- task_group = task_item_or_group
- is_mapped = isinstance(task_group, MappedTaskGroup)
- children = [
- task_group_to_dict(child,
parent_group_is_mapped=parent_group_is_mapped or is_mapped)
- for child in get_task_group_children_getter()(task_group)
- ]
-
- if task_group.upstream_group_ids or task_group.upstream_task_ids:
- # This is the join node used to reduce the number of edges between two
TaskGroup.
- children.append({"id": task_group.upstream_join_id, "label": "",
"type": "join"})
-
- if task_group.downstream_group_ids or task_group.downstream_task_ids:
- # This is the join node used to reduce the number of edges between two
TaskGroup.
- children.append({"id": task_group.downstream_join_id, "label": "",
"type": "join"})
-
- return {
- "id": task_group.group_id,
- "label": task_group.label,
- "tooltip": task_group.tooltip,
- "is_mapped": is_mapped,
- "children": children,
- "type": "task",
- }
-
-
-def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
- """Create a nested dict representation of this TaskGroup and its children
used to construct the Graph."""
- from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
- from airflow.sdk.definitions.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
-
- if isinstance(task := task_item_or_group, (AbstractOperator,
SerializedBaseOperator)):
- is_mapped = None
- if isinstance(task, MappedOperator) or parent_group_is_mapped:
- is_mapped = True
- setup_teardown_type = None
- if task.is_setup is True:
- setup_teardown_type = "setup"
- elif task.is_teardown is True:
- setup_teardown_type = "teardown"
- return {
- "id": task.task_id,
- "label": task.label,
- "is_mapped": is_mapped,
- "children": None,
- "setup_teardown_type": setup_teardown_type,
- }
-
- task_group = task_item_or_group
- task_group_sort = get_task_group_children_getter()
- is_mapped_group = isinstance(task_group, MappedTaskGroup)
- children = [
- task_group_to_dict_grid(x,
parent_group_is_mapped=parent_group_is_mapped or is_mapped_group)
- for x in task_group_sort(task_group)
- ]
-
- return {
- "id": task_group.group_id,
- "label": task_group.label,
- "is_mapped": is_mapped_group or None,
- "children": children or None,
- }
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
index 314c6bf8437..9db616f4dbd 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_grid.py
@@ -29,10 +29,10 @@ from airflow.models.dag import DagModel
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import task_group
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.utils import timezone
from airflow.utils.session import provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunTriggeredByType, DagRunType
from tests_common.test_utils.db import clear_db_assets, clear_db_dags,
clear_db_runs, clear_db_serialized_dags
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 32445dd2958..50440d4f44e 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -2774,7 +2774,7 @@ def test_teardown_and_fail_fast(dag_maker):
in this case, the second teardown skips because its setup skips.
"""
from airflow.sdk import task as task_decorator
- from airflow.utils.task_group import TaskGroup
+ from airflow.sdk.definitions.taskgroup import TaskGroup
with dag_maker(fail_fast=True) as dag:
for num in (1, 2):
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py
b/airflow-core/tests/unit/models/test_mappedoperator.py
index 2922ee1c23e..ea9f35b63f7 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -34,8 +34,8 @@ from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import setup, task, task_group, teardown
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.utils.state import TaskInstanceState
-from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
from tests_common.test_utils.mapping import expand_mapped_task
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index d1b3a750db3..7e041917056 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -67,6 +67,7 @@ from airflow.sdk import BaseSensorOperator, task, task_group
from airflow.sdk.api.datamodels._generated import AssetEventResponse,
AssetResponse
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.sdk.definitions.param import process_params
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.execution_time.comms import (
AssetEventsResult,
)
@@ -83,7 +84,6 @@ from airflow.utils.db import merge_conn
from airflow.utils.session import create_session, provide_session
from airflow.utils.span_status import SpanStatus
from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunTriggeredByType, DagRunType
from tests_common.test_utils import db
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py
b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 0331b325ac1..c61284c7850 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -68,6 +68,7 @@ from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.expandinput import EXPAND_INPUT_EMPTY
from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
from airflow.sdk.definitions.param import Param, ParamsDict
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.security import permissions
from airflow.serialization.enums import Encoding
from airflow.serialization.json_schema import load_dag_schema_dict
@@ -84,7 +85,6 @@ from airflow.triggers.base import StartTriggerArgs
from airflow.utils import timezone
from airflow.utils.module_loading import qualname
from airflow.utils.operator_resources import Resources
-from airflow.utils.task_group import TaskGroup
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.markers import
skip_if_force_lowest_dependencies_marker
diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py
b/airflow-core/tests/unit/serialization/test_serialized_objects.py
index 402c169cc08..dc944035c75 100644
--- a/airflow-core/tests/unit/serialization/test_serialized_objects.py
+++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py
@@ -59,6 +59,7 @@ from airflow.sdk.definitions.asset import (
from airflow.sdk.definitions.deadline import DeadlineAlert,
DeadlineAlertFields, DeadlineReference
from airflow.sdk.definitions.decorators import task
from airflow.sdk.definitions.param import Param
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.execution_time.context import OutletEventAccessor,
OutletEventAccessors
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.serialized_objects import BaseSerialization,
LazyDeserializedDAG, SerializedDAG
@@ -68,7 +69,6 @@ from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
from airflow.utils.operator_resources import Resources
from airflow.utils.state import DagRunState, State
-from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
from unit.models import DEFAULT_DATE
diff --git a/airflow-core/tests/unit/utils/test_dag_cycle.py
b/airflow-core/tests/unit/utils/test_dag_cycle.py
index c436af01c7d..e17ff7c5f3c 100644
--- a/airflow-core/tests/unit/utils/test_dag_cycle.py
+++ b/airflow-core/tests/unit/utils/test_dag_cycle.py
@@ -22,8 +22,8 @@ from airflow.exceptions import AirflowDagCycleException
from airflow.models.dag import DAG
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.sdk import Label
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.utils.dag_cycle_tester import check_cycle
-from airflow.utils.task_group import TaskGroup
from unit.models import DEFAULT_DATE
diff --git a/airflow-core/tests/unit/utils/test_dot_renderer.py
b/airflow-core/tests/unit/utils/test_dot_renderer.py
index d3ba7acaab7..240876ec7f4 100644
--- a/airflow-core/tests/unit/utils/test_dot_renderer.py
+++ b/airflow-core/tests/unit/utils/test_dot_renderer.py
@@ -24,10 +24,10 @@ import pytest
from airflow.models.dag import DAG
from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.serialization.dag_dependency import DagDependency
from airflow.utils import dot_renderer, timezone
from airflow.utils.state import State
-from airflow.utils.task_group import TaskGroup
from tests_common.test_utils.compat import BashOperator
from tests_common.test_utils.db import clear_db_dags
diff --git a/airflow-core/tests/unit/utils/test_edgemodifier.py
b/airflow-core/tests/unit/utils/test_edgemodifier.py
index 0885230e113..98ea514af3c 100644
--- a/airflow-core/tests/unit/utils/test_edgemodifier.py
+++ b/airflow-core/tests/unit/utils/test_edgemodifier.py
@@ -25,8 +25,8 @@ from airflow.models.xcom_arg import XComArg
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import Label
+from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.utils.dag_edges import dag_edges
-from airflow.utils.task_group import TaskGroup
DEFAULT_ARGS = {
"owner": "test",
diff --git a/airflow-core/tests/unit/utils/test_task_group.py
b/airflow-core/tests/unit/utils/test_task_group.py
index 8d6e6c0b334..7af3b3326f4 100644
--- a/airflow-core/tests/unit/utils/test_task_group.py
+++ b/airflow-core/tests/unit/utils/test_task_group.py
@@ -34,8 +34,8 @@ from airflow.sdk import (
task_group as task_group_decorator,
teardown,
)
+from airflow.sdk.definitions.taskgroup import TaskGroup, task_group_to_dict
from airflow.utils.dag_edges import dag_edges
-from airflow.utils.task_group import TaskGroup, task_group_to_dict
from tests_common.test_utils.compat import BashOperator, PythonOperator
from unit.models import DEFAULT_DATE
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py
b/providers/standard/tests/unit/standard/decorators/test_python.py
index 1d65c02f533..3300c51a9fd 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -37,6 +37,7 @@ if AIRFLOW_V_3_0_PLUS:
from airflow.sdk.bases.decorator import DecoratedMappedOperator
from airflow.sdk.definitions._internal.expandinput import
DictOfListsExpandInput
from airflow.sdk.definitions.mappedoperator import MappedOperator
+
else:
from airflow.decorators import setup, task as task_decorator, teardown
from airflow.decorators.base import DecoratedMappedOperator # type:
ignore[no-redef]
@@ -45,7 +46,7 @@ else:
from airflow.models.expandinput import DictOfListsExpandInput
from airflow.models.mappedoperator import MappedOperator
from airflow.models.xcom_arg import XComArg
- from airflow.utils.task_group import TaskGroup
+ from airflow.utils.task_group import TaskGroup # type: ignore[no-redef]
pytestmark = pytest.mark.db_test
diff --git
a/providers/standard/tests/unit/standard/operators/test_branch_operator.py
b/providers/standard/tests/unit/standard/operators/test_branch_operator.py
index 670ce77415b..bcecbd26b75 100644
--- a/providers/standard/tests/unit/standard/operators/test_branch_operator.py
+++ b/providers/standard/tests/unit/standard/operators/test_branch_operator.py
@@ -28,9 +28,14 @@ from airflow.providers.standard.utils.skipmixin import
XCOM_SKIPMIXIN_FOLLOWED,
from airflow.timetables.base import DataInterval
from airflow.utils import timezone
from airflow.utils.state import State
-from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
+try:
+ from airflow.sdk.definitions.taskgroup import TaskGroup
+except ImportError:
+ # Fallback for Airflow < 3.1
+ from airflow.utils.task_group import TaskGroup # type: ignore[no-redef]
+
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1,
AIRFLOW_V_3_0_PLUS
if AIRFLOW_V_3_0_PLUS:
diff --git
a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
index daff66dcb13..cb881617faf 100644
---
a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
+++
b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
@@ -66,7 +66,6 @@ from airflow.serialization.serialized_objects import
SerializedBaseOperator
from airflow.timetables.base import DataInterval
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.task_group import TaskGroup
from airflow.utils.timezone import coerce_datetime, datetime
from airflow.utils.types import DagRunType
@@ -80,6 +79,13 @@ if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType
else:
from airflow.decorators import task as task_deco
+
+try:
+ from airflow.sdk.definitions.taskgroup import TaskGroup
+except ImportError:
+ # Fallback for Airflow < 3.1
+ from airflow.utils.task_group import TaskGroup # type: ignore[no-redef]
+
pytestmark = pytest.mark.db_test
diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
index 346e956a981..89735e206db 100644
--- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
+++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
@@ -35,15 +35,21 @@ from airflow.providers.standard.utils.sensor_helper import (
)
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType
from tests_common.test_utils import db
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+try:
+ from airflow.sdk.definitions.taskgroup import TaskGroup
+except ImportError:
+ # Fallback for Airflow < 3.1
+ from airflow.utils.task_group import TaskGroup # type: ignore[no-redef]
+
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
+
TI = TaskInstance
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index c5afe65e0d8..ed79ae4202e 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -24,12 +24,15 @@ import functools
import operator
import re
import weakref
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Sequence
+from functools import cache
+from operator import methodcaller
from typing import TYPE_CHECKING, Any
import attrs
import methodtools
+from airflow.configuration import conf
from airflow.exceptions import (
AirflowDagCycleException,
AirflowException,
@@ -669,36 +672,41 @@ class MappedTaskGroup(TaskGroup):
yield op
-def task_group_to_dict(task_item_or_group):
+# ``@cache`` memoises the returned callable, so the sort-order option below is
+# read at most once per process; later config changes are not picked up.
+@cache
+def get_task_group_children_getter() -> Callable:
+ """
+ Return a callable that, given a TaskGroup, returns its children in the configured order.
+
+ ``[api] grid_view_sorting_order == "topological"`` selects the group's
+ ``topological_sort``; any other value selects ``hierarchical_alphabetical_sort``.
+ """
+ sort_order = conf.get("api", "grid_view_sorting_order")
+ if sort_order == "topological":
+ return methodcaller("topological_sort")
+ return methodcaller("hierarchical_alphabetical_sort")
+
+
+def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False):
"""Create a nested dict representation of this TaskGroup and its children
used to construct the Graph."""
from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.serialization.serialized_objects import SerializedBaseOperator
if isinstance(task := task_item_or_group, (AbstractOperator,
SerializedBaseOperator)):
- setup_teardown_type = {}
- is_mapped = {}
- node_type = {"type": "task"}
- if task.is_setup is True:
- setup_teardown_type["setup_teardown_type"] = "setup"
- elif task.is_teardown is True:
- setup_teardown_type["setup_teardown_type"] = "teardown"
- if isinstance(task, MappedOperator):
- is_mapped["is_mapped"] = True
- if getattr(task, "_is_sensor", False):
- node_type["type"] = "sensor"
- return {
+ # NOTE(review): the previous SDK version emitted "type": "sensor" for sensors;
+ # this version always emits "task" for leaves. Confirm this is intended.
+ node_operator = {
"id": task.task_id,
"label": task.label,
- **is_mapped,
- **setup_teardown_type,
- **node_type,
+ "operator": task.operator_name,
+ "type": "task",
}
+ if task.is_setup:
+ node_operator["setup_teardown_type"] = "setup"
+ elif task.is_teardown:
+ node_operator["setup_teardown_type"] = "teardown"
+ if isinstance(task, MappedOperator) or parent_group_is_mapped:
+ node_operator["is_mapped"] = True
+ return node_operator
task_group = task_item_or_group
is_mapped = isinstance(task_group, MappedTaskGroup)
+ # Mapped-ness is inherited: every task anywhere below a MappedTaskGroup is flagged as mapped.
children = [
- task_group_to_dict(child) for child in
sorted(task_group.children.values(), key=lambda t: t.label)
+ task_group_to_dict(child,
parent_group_is_mapped=parent_group_is_mapped or is_mapped)
+ for child in get_task_group_children_getter()(task_group)
]
if task_group.upstream_group_ids or task_group.upstream_task_ids:
@@ -715,5 +723,44 @@ def task_group_to_dict(task_item_or_group):
"tooltip": task_group.tooltip,
"is_mapped": is_mapped,
"children": children,
- "type": "task_group",
+ "type": "task",
+ }
+
+
+def task_group_to_dict_grid(task_item_or_group, parent_group_is_mapped=False):
+    """
+    Create a nested dict representation of a task or TaskGroup for the Grid view.
+
+    An operator becomes ``{"id", "label", "is_mapped", "children": None,
+    "setup_teardown_type"}``. A group becomes ``{"id", "label", "is_mapped",
+    "children"}``; its children are ordered by ``get_task_group_children_getter()``,
+    and ``children`` is ``None`` when the group is empty.
+
+    For an operator, ``is_mapped`` is ``True`` when it is a MappedOperator or sits
+    below a MappedTaskGroup. For a group, it is ``True`` only if the group itself
+    is a MappedTaskGroup. In every other case it is ``None``, never ``False``.
+
+    :param task_item_or_group: an operator or a TaskGroup to render.
+    :param parent_group_is_mapped: whether any enclosing group is a MappedTaskGroup.
+    """
+    from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
+    from airflow.sdk.definitions.mappedoperator import MappedOperator
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
+
+    if isinstance(task := task_item_or_group, (AbstractOperator, SerializedBaseOperator)):
+        is_mapped = None
+        if isinstance(task, MappedOperator) or parent_group_is_mapped:
+            is_mapped = True
+        setup_teardown_type = None
+        if task.is_setup is True:
+            setup_teardown_type = "setup"
+        elif task.is_teardown is True:
+            setup_teardown_type = "teardown"
+        return {
+            "id": task.task_id,
+            "label": task.label,
+            "is_mapped": is_mapped,
+            "children": None,
+            "setup_teardown_type": setup_teardown_type,
+        }
+
+    task_group = task_item_or_group
+    task_group_sort = get_task_group_children_getter()
+    is_mapped_group = isinstance(task_group, MappedTaskGroup)
+    children = [
+        # Propagate mapped-ness downward so leaves under a mapped group are flagged.
+        task_group_to_dict_grid(x, parent_group_is_mapped=parent_group_is_mapped or is_mapped_group)
+        for x in task_group_sort(task_group)
+    ]
+
+    return {
+        "id": task_group.group_id,
+        "label": task_group.label,
+        "is_mapped": is_mapped_group or None,
+        "children": children or None,
     }