This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b5e018920da4b59276d7488bfa9fc0b1de3385e0 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Wed Oct 23 13:11:27 2024 +0100 [skip-ci] --- airflow/decorators/base.py | 13 ++++++------- airflow/utils/task_group.py | 21 --------------------- tests/models/test_dag.py | 5 +++-- 3 files changed, 9 insertions(+), 30 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index fe85c07ccd0..6129dc1dd42 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -59,26 +59,25 @@ from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_ from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator -from airflow.sdk.definitions.contextmanager import DagContext +from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext from airflow.typing_compat import ParamSpec, Protocol from airflow.utils import timezone from airflow.utils.context import KNOWN_CONTEXT_KEYS from airflow.utils.decorators import remove_task_decorator from airflow.utils.helpers import prevent_duplicates -from airflow.utils.task_group import TaskGroupContext from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import NOTSET if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.dag import DAG from airflow.models.expandinput import ( ExpandInput, OperatorExpandArgument, OperatorExpandKwargsArgument, ) from airflow.models.mappedoperator import ValidationSource + from airflow.sdk import DAG from airflow.utils.context import Context from airflow.utils.task_group import TaskGroup @@ -142,13 +141,13 @@ def get_unique_task_id( ... task_id__20 """ - dag = dag or DagContext.get_current_dag() + dag = dag or DagContext.get_current() if not dag: return task_id # We need to check if we are in the context of TaskGroup as the task_id may # already be altered - task_group = task_group or TaskGroupContext.get_current_task_group(dag) + task_group = task_group or TaskGroupContext.get_current(dag) tg_task_id = task_group.child_id(task_id) if task_group else task_id if tg_task_id not in dag.task_ids: @@ -429,8 +428,8 @@ class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubcla ensure_xcomarg_return_value(expand_input.value) task_kwargs = self.kwargs.copy() - dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag() - task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag) + dag = task_kwargs.pop("dag", None) or DagContext.get_current() + task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current(dag) default_args, partial_params = get_merged_defaults( dag=dag, diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 6b760a112af..1f94880902c 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -23,13 +23,11 @@ import functools import operator from typing import TYPE_CHECKING, Iterator -import airflow.sdk.definitions.contextmanager import airflow.sdk.definitions.taskgroup if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.dag import DAG from airflow.models.operator import Operator from airflow.typing_compat import TypeAlias @@ -78,25 +76,6 @@ class MappedTaskGroup(airflow.sdk.definitions.taskgroup.MappedTaskGroup): ) -class TaskGroupContext(airflow.sdk.definitions.contextmanager.TaskGroupContext, share_parent_context=True): - """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" - - @classmethod - def push_context_managed_task_group(cls, task_group: TaskGroup): - """Push a TaskGroup into the list of managed TaskGroups.""" - return cls.push(task_group) - - @classmethod - def pop_context_managed_task_group(cls) -> TaskGroup | None: - """Pops the last TaskGroup from the list of managed TaskGroups and update the current TaskGroup.""" - return cls.pop() - - @classmethod - def get_current_task_group(cls, dag: DAG | None) -> TaskGroup | None: - """Get the current TaskGroup.""" - return cls.get_current(dag) - - def task_group_to_dict(task_item_or_group): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" from airflow.models.abstractoperator import AbstractOperator diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 43029fbab98..c126e2d3039 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -69,6 +69,8 @@ from airflow.models.taskinstance import TaskInstance as TI from airflow.operators.empty import EmptyOperator from airflow.operators.python import PythonOperator from airflow.providers.standard.operators.bash import BashOperator +from airflow.sdk import TaskGroup +from airflow.sdk.definitions.contextmanager import TaskGroupContext from airflow.security import permissions from airflow.templates import NativeEnvironment, SandboxedEnvironment from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable @@ -82,7 +84,6 @@ from airflow.utils import timezone from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session from airflow.utils.state import DagRunState, State, TaskInstanceState -from airflow.utils.task_group import TaskGroup, TaskGroupContext from airflow.utils.timezone import datetime as datetime_tz from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -1505,7 +1506,7 @@ class TestDag: def test_dag_add_task_sets_default_task_group(self): dag = DAG(dag_id="test_dag_add_task_sets_default_task_group", schedule=None, start_date=DEFAULT_DATE) task_without_task_group = EmptyOperator(task_id="task_without_group_id") - default_task_group = TaskGroupContext.get_current_task_group(dag) + default_task_group = TaskGroupContext.get_current(dag) dag.add_task(task_without_task_group) assert default_task_group.get_child_by_label("task_without_group_id") == task_without_task_group
