This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit b5e018920da4b59276d7488bfa9fc0b1de3385e0
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Oct 23 13:11:27 2024 +0100

    [skip-ci]
---
 airflow/decorators/base.py  | 13 ++++++-------
 airflow/utils/task_group.py | 21 ---------------------
 tests/models/test_dag.py    |  5 +++--
 3 files changed, 9 insertions(+), 30 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index fe85c07ccd0..6129dc1dd42 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -59,26 +59,25 @@ from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_
 from airflow.models.pool import Pool
 from airflow.models.xcom_arg import XComArg
 from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
-from airflow.sdk.definitions.contextmanager import DagContext
+from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
 from airflow.typing_compat import ParamSpec, Protocol
 from airflow.utils import timezone
 from airflow.utils.context import KNOWN_CONTEXT_KEYS
 from airflow.utils.decorators import remove_task_decorator
 from airflow.utils.helpers import prevent_duplicates
-from airflow.utils.task_group import TaskGroupContext
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
-    from airflow.models.dag import DAG
     from airflow.models.expandinput import (
         ExpandInput,
         OperatorExpandArgument,
         OperatorExpandKwargsArgument,
     )
     from airflow.models.mappedoperator import ValidationSource
+    from airflow.sdk import DAG
     from airflow.utils.context import Context
     from airflow.utils.task_group import TaskGroup
 
@@ -142,13 +141,13 @@ def get_unique_task_id(
       ...
       task_id__20
     """
-    dag = dag or DagContext.get_current_dag()
+    dag = dag or DagContext.get_current()
     if not dag:
         return task_id
 
     # We need to check if we are in the context of TaskGroup as the task_id may
     # already be altered
-    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+    task_group = task_group or TaskGroupContext.get_current(dag)
     tg_task_id = task_group.child_id(task_id) if task_group else task_id
 
     if tg_task_id not in dag.task_ids:
@@ -429,8 +428,8 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla
         ensure_xcomarg_return_value(expand_input.value)
 
         task_kwargs = self.kwargs.copy()
-        dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
-        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)
+        dag = task_kwargs.pop("dag", None) or DagContext.get_current()
+        task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag)
 
         default_args, partial_params = get_merged_defaults(
             dag=dag,
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 6b760a112af..1f94880902c 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -23,13 +23,11 @@ import functools
 import operator
 from typing import TYPE_CHECKING, Iterator
 
-import airflow.sdk.definitions.contextmanager
 import airflow.sdk.definitions.taskgroup
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
-    from airflow.models.dag import DAG
     from airflow.models.operator import Operator
     from airflow.typing_compat import TypeAlias
 
@@ -78,25 +76,6 @@ class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup):
         )
 
 
-class TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext, share_parent_context=True):
-    """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
-
-    @classmethod
-    def push_context_managed_task_group(cls, task_group: TaskGroup):
-        """Push a TaskGroup into the list of managed TaskGroups."""
-        return cls.push(task_group)
-
-    @classmethod
-    def pop_context_managed_task_group(cls) -> TaskGroup | None:
-        """Pops the last TaskGroup from the list of managed TaskGroups and update the current TaskGroup."""
-        return cls.pop()
-
-    @classmethod
-    def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None:
-        """Get the current TaskGroup."""
-        return cls.get_current(dag)
-
-
 def task_group_to_dict(task_item_or_group):
     """Create a nested dict representation of this TaskGroup and its children used to construct the Graph."""
     from airflow.models.abstractoperator import AbstractOperator
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 43029fbab98..c126e2d3039 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -69,6 +69,8 @@ from airflow.models.taskinstance import TaskInstance as TI
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
 from airflow.providers.standard.operators.bash import BashOperator
+from airflow.sdk import TaskGroup
+from airflow.sdk.definitions.contextmanager import TaskGroupContext
 from airflow.security import permissions
 from airflow.templates import NativeEnvironment, SandboxedEnvironment
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
@@ -82,7 +84,6 @@ from airflow.utils import timezone
 from airflow.utils.file import list_py_file_paths
 from airflow.utils.session import create_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.task_group import TaskGroup, TaskGroupContext
 from airflow.utils.timezone import datetime as datetime_tz
 from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
@@ -1505,7 +1506,7 @@ class TestDag:
     def test_dag_add_task_sets_default_task_group(self):
         dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", schedule=None, start_date=DEFAULT_DATE)
         task_without_task_group = EmptyOperator(task_id="task_without_group_id")
-        default_task_group = TaskGroupContext.get_current_task_group(dag)
+        default_task_group = TaskGroupContext.get_current(dag)
         dag.add_task(task_without_task_group)
         assert default_task_group.get_child_by_label("task_without_group_id") == task_without_task_group
 

Reply via email to