This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 521410f03cb AIP-72: Move non-user facing code to `_internal` (#45515)
521410f03cb is described below
commit 521410f03cbe776a0fa1f96a5b572a17908cc327
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Jan 9 21:33:50 2025 +0530
AIP-72: Move non-user facing code to `_internal` (#45515)
Anything within the `_internal` module is not user-facing code; moving it there
makes that explicit. It is not covered by semantic versioning and can have
breaking changes even in patch releases. Users should never use anything in
`airflow/sdk/definitions/_internal`.
---
.pre-commit-config.yaml | 2 +-
airflow/decorators/base.py | 2 +-
airflow/decorators/task_group.py | 2 +-
airflow/models/abstractoperator.py | 4 +-
airflow/models/baseoperator.py | 2 +-
airflow/models/dagbag.py | 8 ++--
airflow/models/expandinput.py | 2 +-
airflow/models/param.py | 2 +-
airflow/models/skipmixin.py | 2 +-
airflow/models/taskinstance.py | 4 +-
airflow/models/taskmixin.py | 8 ++--
airflow/models/xcom_arg.py | 4 +-
airflow/notifications/basenotifier.py | 2 +-
airflow/serialization/serialized_objects.py | 2 +-
airflow/utils/types.py | 6 +--
.../providers/standard/utils/python_virtualenv.py | 2 +-
.../kubernetes/operators/test_spark_kubernetes.py | 4 +-
task_sdk/src/airflow/sdk/__init__.py | 4 +-
.../airflow/sdk/definitions/_internal/__init__.py | 12 ------
.../{ => _internal}/abstractoperator.py | 6 +--
.../definitions/{ => _internal}/contextmanager.py | 30 +--------------
.../sdk/definitions/{ => _internal}/decorators.py | 0
.../sdk/definitions/{ => _internal}/mixins.py | 2 +-
.../sdk/definitions/{ => _internal}/node.py | 4 +-
.../sdk/definitions/{ => _internal}/templater.py | 15 +-------
.../sdk/{ => definitions/_internal}/types.py | 2 +-
.../src/airflow/sdk/definitions/baseoperator.py | 12 +++---
.../sdk/definitions/{decorators.py => context.py} | 45 ++++++++++++++--------
task_sdk/src/airflow/sdk/definitions/dag.py | 12 +++---
task_sdk/src/airflow/sdk/definitions/edges.py | 8 ++--
task_sdk/src/airflow/sdk/definitions/taskgroup.py | 18 ++++-----
.../src/airflow/sdk/definitions/template.py | 18 +++++----
task_sdk/src/airflow/sdk/execution_time/context.py | 4 +-
.../tests/defintions/_internal/__init__.py | 12 ------
.../defintions/{ => _internal}/test_templater.py | 2 +-
task_sdk/tests/defintions/test_baseoperator.py | 4 +-
.../{test_contextmanager.py => test_context.py} | 8 ++--
.../tests/defintions/test_template.py | 29 ++++++++++----
task_sdk/tests/execution_time/test_context.py | 2 +-
task_sdk/tests/execution_time/test_task_runner.py | 4 +-
tests/models/test_dag.py | 4 +-
41 files changed, 141 insertions(+), 174 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 13886b1e2fb..9c8fc2ce8b8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1193,7 +1193,7 @@ repos:
^providers/src/airflow/providers/ |
^(providers/)?tests/ |
task_sdk/src/airflow/sdk/definitions/dag.py$ |
- task_sdk/src/airflow/sdk/definitions/node.py$ |
+ task_sdk/src/airflow/sdk/definitions/_internal/node.py$ |
^dev/.*\.py$ |
^scripts/.*\.py$ |
^docker_tests/.*$ |
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 64081ef4070..f593b805168 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -52,9 +52,9 @@ from airflow.models.expandinput import (
)
from airflow.models.mappedoperator import MappedOperator,
ensure_xcomarg_return_value
from airflow.models.xcom_arg import XComArg
+from airflow.sdk.definitions._internal.contextmanager import DagContext,
TaskGroupContext
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.baseoperator import BaseOperator as
TaskSDKBaseOperator
-from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext
from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py
index 2fabd29157d..c75eb8505a4 100644
--- a/airflow/decorators/task_group.py
+++ b/airflow/decorators/task_group.py
@@ -40,7 +40,7 @@ from airflow.models.expandinput import (
MappedArgument,
)
from airflow.models.xcom_arg import XComArg
-from airflow.sdk.definitions.node import DAGNode
+from airflow.sdk.definitions._internal.node import DAGNode
from airflow.typing_compat import ParamSpec
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index f87b6e06b1c..ed64a5320ce 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -29,7 +29,7 @@ from sqlalchemy import select
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
-from airflow.sdk.definitions.abstractoperator import AbstractOperator as
TaskSDKAbstractOperator
+from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator as TaskSDKAbstractOperator
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -48,8 +48,8 @@ if TYPE_CHECKING:
from airflow.models.dag import DAG as SchedulerDAG
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
+ from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.baseoperator import BaseOperator
- from airflow.sdk.definitions.node import DAGNode
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.task_group import TaskGroup
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index f28e05584fd..d39586c8a19 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -273,7 +273,7 @@ else:
params: collections.abc.MutableMapping | None = None,
**kwargs,
):
- from airflow.sdk.definitions.contextmanager import DagContext,
TaskGroupContext
+ from airflow.sdk.definitions._internal.contextmanager import
DagContext, TaskGroupContext
validate_mapping_kwargs(operator_class, "partial", kwargs)
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 9cebbb6858c..7d0d2efc1bf 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -281,7 +281,7 @@ class DagBag(LoggingMixin):
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
"""Given a path to a python module or zip file, import the module and
look for dag objects within."""
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import DagContext
# if the source file no longer exists in the DB or in the filesystem,
# return an empty list
@@ -358,7 +358,7 @@ class DagBag(LoggingMixin):
return warnings
def _load_modules_from_file(self, filepath, safe_mode):
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import DagContext
if not might_contain_dag(filepath, safe_mode):
# Don't want to spam user with skip messages
@@ -414,7 +414,7 @@ class DagBag(LoggingMixin):
return parse(mod_name, filepath)
def _load_modules_from_zip(self, filepath, safe_mode):
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import DagContext
mods = []
with zipfile.ZipFile(filepath) as current_zip_file:
@@ -464,7 +464,7 @@ class DagBag(LoggingMixin):
def _process_modules(self, filepath, mods, file_last_changed_on_disk):
from airflow.models.dag import DAG # Avoid circular import
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import DagContext
top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if
isinstance(o, DAG)}
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index bf3c6e95056..fcbb55dc3d2 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, NamedTuple, Union
import attr
-from airflow.sdk.definitions.mixins import ResolveMixin
+from airflow.sdk.definitions._internal.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
if TYPE_CHECKING:
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 416d9cfb8b9..c1b47a8f5e4 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -24,7 +24,7 @@ from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, Values
from typing import TYPE_CHECKING, Any, ClassVar
from airflow.exceptions import AirflowException, ParamValidationError
-from airflow.sdk.definitions.mixins import ResolveMixin
+from airflow.sdk.definitions._internal.mixins import ResolveMixin
from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 8b59043ecef..63564ebbc43 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
- from airflow.sdk.definitions.node import DAGNode
+ from airflow.sdk.definitions._internal.node import DAGNode
# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 387ea9122e0..580d8cb7b8d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -106,9 +106,9 @@ from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
+from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
+from airflow.sdk.definitions._internal.templater import SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef,
AssetUniqueKey, AssetUriRef
-from airflow.sdk.definitions.templater import SandboxedEnvironment
-from airflow.sdk.execution_time.context import _CURRENT_CONTEXT
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index fa76a3815cb..7aa8f63ba3c 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -21,8 +21,8 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from airflow.typing_compat import TypeAlias
-import airflow.sdk.definitions.mixins
-import airflow.sdk.definitions.node
+import airflow.sdk.definitions._internal.mixins
+import airflow.sdk.definitions._internal.node
-DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin
-DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode
+DependencyMixin: TypeAlias =
airflow.sdk.definitions._internal.mixins.DependencyMixin
+DAGNode: TypeAlias = airflow.sdk.definitions._internal.node.DAGNode
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index cf4147dcbfc..4bf91a68bee 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -29,8 +29,8 @@ from airflow.exceptions import AirflowException, XComNotFound
from airflow.models import MappedOperator, TaskInstance
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.taskmixin import DependencyMixin
-from airflow.sdk.definitions.mixins import ResolveMixin
-from airflow.sdk.types import NOTSET, ArgNotSet
+from airflow.sdk.definitions._internal.mixins import ResolveMixin
+from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
from airflow.utils.db import exists_query
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
diff --git a/airflow/notifications/basenotifier.py b/airflow/notifications/basenotifier.py
index 398d95cbb8d..ae69f07db26 100644
--- a/airflow/notifications/basenotifier.py
+++ b/airflow/notifications/basenotifier.py
@@ -21,7 +21,7 @@ from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.sdk.definitions.templater import Templater
+from airflow.sdk.definitions._internal.templater import Templater
from airflow.utils.context import context_merge
from airflow.utils.log.logging_mixin import LoggingMixin
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index a0e5da74145..0926f3245e0 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -99,7 +99,7 @@ if TYPE_CHECKING:
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.expandinput import ExpandInput
from airflow.models.operator import Operator
- from airflow.sdk.definitions.node import DAGNode
+ from airflow.sdk.definitions._internal.node import DAGNode
from airflow.serialization.json_schema import Validator
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.timetables.base import DagRunInfo, DataInterval, Timetable
diff --git a/airflow/utils/types.py b/airflow/utils/types.py
index 7dd1ce02b60..04ee05fc9e8 100644
--- a/airflow/utils/types.py
+++ b/airflow/utils/types.py
@@ -19,15 +19,15 @@ from __future__ import annotations
import enum
from typing import TYPE_CHECKING
-import airflow.sdk.types
+import airflow.sdk.definitions._internal.types
from airflow.typing_compat import TypeAlias, TypedDict
if TYPE_CHECKING:
from datetime import datetime
-ArgNotSet: TypeAlias = airflow.sdk.types.ArgNotSet
+ArgNotSet: TypeAlias = airflow.sdk.definitions._internal.types.ArgNotSet
-NOTSET = airflow.sdk.types.NOTSET
+NOTSET = airflow.sdk.definitions._internal.types.NOTSET
class AttributeRemoved:
diff --git a/providers/src/airflow/providers/standard/utils/python_virtualenv.py b/providers/src/airflow/providers/standard/utils/python_virtualenv.py
index 66cc92ee16e..64802b643be 100644
--- a/providers/src/airflow/providers/standard/utils/python_virtualenv.py
+++ b/providers/src/airflow/providers/standard/utils/python_virtualenv.py
@@ -28,7 +28,7 @@ import jinja2
from jinja2 import select_autoescape
from airflow.configuration import conf
-from airflow.sdk.definitions.templater import NativeEnvironment
+from airflow.sdk.definitions._internal.templater import NativeEnvironment
from airflow.utils.process_utils import execute_in_subprocess
diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
index e5d4ecf6610..3c4b3189369 100644
--- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -798,7 +798,7 @@ def test_resolve_application_file_real_file(
from airflow.template.templater import LiteralValue
except ImportError:
# Airflow 3.0+
- from airflow.sdk.definitions.templater import LiteralValue
+ from airflow.sdk.definitions._internal.templater import
LiteralValue
application_file = LiteralValue(application_file)
else:
@@ -828,7 +828,7 @@ def test_resolve_application_file_real_file_not_exists(create_task_instance_of_o
from airflow.template.templater import LiteralValue
except ImportError:
# Airflow 3.0+
- from airflow.sdk.definitions.templater import LiteralValue
+ from airflow.sdk.definitions._internal.templater import LiteralValue
ti = create_task_instance_of_operator(
SparkKubernetesOperator,
diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py
index a71ab7b2dd8..e50b475b006 100644
--- a/task_sdk/src/airflow/sdk/__init__.py
+++ b/task_sdk/src/airflow/sdk/__init__.py
@@ -35,7 +35,7 @@ __version__ = "1.0.0.dev1"
if TYPE_CHECKING:
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
- from airflow.sdk.definitions.contextmanager import get_current_context
+ from airflow.sdk.definitions.context import get_current_context
from airflow.sdk.definitions.dag import DAG, dag
from airflow.sdk.definitions.edges import EdgeModifier, Label
from airflow.sdk.definitions.taskgroup import TaskGroup
@@ -49,7 +49,7 @@ __lazy_imports: dict[str, str] = {
"Label": ".definitions.edges",
"Connection": ".definitions.connection",
"Variable": ".definitions.variable",
- "get_current_context": ".definitions.contextmanager",
+ "get_current_context": ".definitions.context",
}
diff --git a/airflow/models/taskmixin.py b/task_sdk/src/airflow/sdk/definitions/_internal/__init__.py
similarity index 69%
copy from airflow/models/taskmixin.py
copy to task_sdk/src/airflow/sdk/definitions/_internal/__init__.py
index fa76a3815cb..13a83393a91 100644
--- a/airflow/models/taskmixin.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/__init__.py
@@ -14,15 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from airflow.typing_compat import TypeAlias
-
-import airflow.sdk.definitions.mixins
-import airflow.sdk.definitions.node
-
-DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin
-DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
similarity index 98%
rename from task_sdk/src/airflow/sdk/definitions/abstractoperator.py
rename to task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
index 0251033cd55..70667cbbe33 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py
@@ -31,9 +31,9 @@ from typing import (
ClassVar,
)
-from airflow.sdk.definitions.mixins import DependencyMixin
-from airflow.sdk.definitions.node import DAGNode
-from airflow.sdk.definitions.templater import Templater
+from airflow.sdk.definitions._internal.mixins import DependencyMixin
+from airflow.sdk.definitions._internal.node import DAGNode
+from airflow.sdk.definitions._internal.templater import Templater
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
diff --git a/task_sdk/src/airflow/sdk/definitions/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/_internal/contextmanager.py
similarity index 85%
rename from task_sdk/src/airflow/sdk/definitions/contextmanager.py
rename to task_sdk/src/airflow/sdk/definitions/_internal/contextmanager.py
index 3880bb6e357..cecb1a5c6ec 100644
--- a/task_sdk/src/airflow/sdk/definitions/contextmanager.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/contextmanager.py
@@ -28,7 +28,7 @@ from airflow.sdk.definitions.taskgroup import TaskGroup
T = TypeVar("T")
-__all__ = ["DagContext", "TaskGroupContext", "get_current_context"]
+__all__ = ["DagContext", "TaskGroupContext"]
# This is a global variable that stores the current Task context.
# It is used to push the Context dictionary when Task starts execution
@@ -37,33 +37,7 @@ __all__ = ["DagContext", "TaskGroupContext", "get_current_context"]
_CURRENT_CONTEXT: list[Mapping[str, Any]] = []
-def get_current_context() -> Mapping[str, Any]:
- """
- Retrieve the execution context dictionary without altering user method's
signature.
-
- This is the simplest method of retrieving the execution context dictionary.
-
- **Old style:**
-
- .. code:: python
-
- def my_task(**context):
- ti = context["ti"]
-
- **New style:**
-
- .. code:: python
-
- from airflow.providers.standard.operators.python import
get_current_context
-
-
- def my_task():
- context = get_current_context()
- ti = context["ti"]
-
- Current context will only have value if this method was called after an
operator
- was starting to execute.
- """
+def _get_current_context() -> Mapping[str, Any]:
if not _CURRENT_CONTEXT:
raise RuntimeError(
"Current context was requested but no context was found! Are you
running within an Airflow task?"
diff --git a/task_sdk/src/airflow/sdk/definitions/decorators.py b/task_sdk/src/airflow/sdk/definitions/_internal/decorators.py
similarity index 100%
copy from task_sdk/src/airflow/sdk/definitions/decorators.py
copy to task_sdk/src/airflow/sdk/definitions/_internal/decorators.py
diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py
similarity index 98%
rename from task_sdk/src/airflow/sdk/definitions/mixins.py
rename to task_sdk/src/airflow/sdk/definitions/_internal/mixins.py
index 583d8b6491e..958d5a459be 100644
--- a/task_sdk/src/airflow/sdk/definitions/mixins.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/mixins.py
@@ -108,7 +108,7 @@ class DependencyMixin:
@classmethod
def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin,
str]]:
- from airflow.sdk.definitions.abstractoperator import AbstractOperator
+ from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
if isinstance(obj, AbstractOperator):
yield obj, "operator"
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/_internal/node.py
similarity index 98%
rename from task_sdk/src/airflow/sdk/definitions/node.py
rename to task_sdk/src/airflow/sdk/definitions/_internal/node.py
index a29a36a7b7a..b8c02609118 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/node.py
@@ -27,14 +27,14 @@ from typing import TYPE_CHECKING, Any
import methodtools
import re2
-from airflow.sdk.definitions.mixins import DependencyMixin
+from airflow.sdk.definitions._internal.mixins import DependencyMixin
if TYPE_CHECKING:
from airflow.models.operator import Operator
+ from airflow.sdk.definitions._internal.types import Logger
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.edges import EdgeModifier
from airflow.sdk.definitions.taskgroup import TaskGroup
- from airflow.sdk.types import Logger
from airflow.serialization.enums import DagAttributeTypes
diff --git a/task_sdk/src/airflow/sdk/definitions/templater.py b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py
similarity index 95%
rename from task_sdk/src/airflow/sdk/definitions/templater.py
rename to task_sdk/src/airflow/sdk/definitions/_internal/templater.py
index 65e9c70f390..8311f70981f 100644
--- a/task_sdk/src/airflow/sdk/definitions/templater.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/templater.py
@@ -28,7 +28,7 @@ import jinja2.nativetypes
import jinja2.sandbox
from airflow.io.path import ObjectStoragePath
-from airflow.sdk.definitions.mixins import ResolveMixin
+from airflow.sdk.definitions._internal.mixins import ResolveMixin
from airflow.utils.helpers import render_template_as_native,
render_template_to_string
if TYPE_CHECKING:
@@ -38,17 +38,6 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.dag import DAG
-def literal(value: Any) -> LiteralValue:
- """
- Wrap a value to ensure it is rendered as-is without applying Jinja
templating to its contents.
-
- Designed for use in an operator's template field.
-
- :param value: The value to be rendered without templating
- """
- return LiteralValue(value)
-
-
@dataclass(frozen=True)
class LiteralValue(ResolveMixin):
"""
@@ -69,8 +58,6 @@ class LiteralValue(ResolveMixin):
log = logging.getLogger(__name__)
-# TODO: Task-SDK: Should everything below this line live in
`_internal/templater.py`?
-# so that it is not exposed to the public API.
class Templater:
"""
This renders the template fields of object.
diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/definitions/_internal/types.py
similarity index 97%
rename from task_sdk/src/airflow/sdk/types.py
rename to task_sdk/src/airflow/sdk/definitions/_internal/types.py
index ffde2170b17..0e3a39cde20 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/types.py
@@ -52,7 +52,7 @@ NOTSET = ArgNotSet()
if TYPE_CHECKING:
import logging
- from airflow.sdk.definitions.node import DAGNode
+ from airflow.sdk.definitions._internal.node import DAGNode
Logger = logging.Logger
else:
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index 8dee46f00e4..cbd5284671a 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -34,7 +34,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final, TypeVar, cast
import attrs
from airflow.models.param import ParamsDict
-from airflow.sdk.definitions.abstractoperator import (
+from airflow.sdk.definitions._internal.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
DEFAULT_OWNER,
DEFAULT_POOL_SLOTS,
@@ -48,9 +48,9 @@ from airflow.sdk.definitions.abstractoperator import (
DEFAULT_WEIGHT_RULE,
AbstractOperator,
)
-from airflow.sdk.definitions.decorators import fixup_decorator_warning_stack
-from airflow.sdk.definitions.node import validate_key
-from airflow.sdk.types import NOTSET, validate_instance_args
+from airflow.sdk.definitions._internal.decorators import
fixup_decorator_warning_stack
+from airflow.sdk.definitions._internal.node import validate_key
+from airflow.sdk.definitions._internal.types import NOTSET,
validate_instance_args
from airflow.task.priority_strategy import (
PriorityWeightStrategy,
airflow_priority_weight_strategies,
@@ -143,7 +143,7 @@ class BaseOperatorMeta(abc.ABCMeta):
@wraps(func)
def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) ->
Any:
- from airflow.sdk.definitions.contextmanager import DagContext,
TaskGroupContext
+ from airflow.sdk.definitions._internal.contextmanager import
DagContext, TaskGroupContext
if args:
raise TypeError("Use keyword arguments when initializing
operators")
@@ -1176,7 +1176,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
if not cls.__serialized_fields:
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import
DagContext
# make sure the following "fake" task is not added to current
active
# dag in context, otherwise, it will result in
diff --git a/task_sdk/src/airflow/sdk/definitions/decorators.py b/task_sdk/src/airflow/sdk/definitions/context.py
similarity index 50%
rename from task_sdk/src/airflow/sdk/definitions/decorators.py
rename to task_sdk/src/airflow/sdk/definitions/context.py
index ab73ba0c924..41911143a1f 100644
--- a/task_sdk/src/airflow/sdk/definitions/decorators.py
+++ b/task_sdk/src/airflow/sdk/definitions/context.py
@@ -1,3 +1,4 @@
+#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -14,29 +15,39 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
from __future__ import annotations
-import sys
-from types import FunctionType
+from collections.abc import Mapping
+from typing import Any
+
+
+def get_current_context() -> Mapping[str, Any]:
+ """
+ Retrieve the execution context dictionary without altering user method's
signature.
+
+ This is the simplest method of retrieving the execution context dictionary.
+
+ **Old style:**
+
+ .. code:: python
+
+ def my_task(**context):
+ ti = context["ti"]
+ **New style:**
-class _autostacklevel_warn:
- def __init__(self):
- self.warnings = __import__("warnings")
+ .. code:: python
- def __getattr__(self, name: str):
- return getattr(self.warnings, name)
+ from airflow.sdk import get_current_context
- def __dir__(self):
- return dir(self.warnings)
- def warn(self, message, category=None, stacklevel=1, source=None):
- self.warnings.warn(message, category, stacklevel + 2, source)
+ def my_task():
+ context = get_current_context()
+ ti = context["ti"]
+ Current context will only have value if this method was called after an
operator
+ was starting to execute.
+ """
+ from airflow.sdk.definitions._internal.contextmanager import
_get_current_context
-def fixup_decorator_warning_stack(func: FunctionType):
- if func.__globals__.get("warnings") is sys.modules["warnings"]:
- # Yes, this is more than slightly hacky, but it _automatically_ sets
the right stacklevel parameter to
- # `warnings.warn` to ignore the decorator.
- func.__globals__["warnings"] = _autostacklevel_warn()
+ return _get_current_context()
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 0e0eead4f09..90b8b74360b 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -52,10 +52,10 @@ from airflow.exceptions import (
TaskNotFound,
)
from airflow.models.param import DagParam, ParamsDict
-from airflow.sdk.definitions.abstractoperator import AbstractOperator
+from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
+from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.definitions.asset import AssetAll, BaseAsset
from airflow.sdk.definitions.baseoperator import BaseOperator
-from airflow.sdk.types import NOTSET
from airflow.timetables.base import Timetable
from airflow.timetables.simple import (
AssetTriggeredTimetable,
@@ -561,13 +561,13 @@ class DAG:
return hash(tuple(hash_components))
def __enter__(self) -> Self:
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import DagContext
DagContext.push(self)
return self
def __exit__(self, _type, _value, _tb):
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import DagContext
_ = DagContext.pop()
@@ -656,7 +656,7 @@ class DAG:
def get_template_env(self, *, force_sandboxed: bool = False) ->
jinja2.Environment:
"""Build a Jinja2 environment."""
- from airflow.sdk.definitions.templater import NativeEnvironment,
SandboxedEnvironment
+ from airflow.sdk.definitions._internal.templater import
NativeEnvironment, SandboxedEnvironment
# Collect directories to search for template files
searchpath = [self.folder]
@@ -892,7 +892,7 @@ class DAG:
"""
# FailStopDagInvalidTriggerRule.check(dag=self,
trigger_rule=task.trigger_rule)
- from airflow.sdk.definitions.contextmanager import TaskGroupContext
+ from airflow.sdk.definitions._internal.contextmanager import
TaskGroupContext
# if the task has no start date, assign it the same as the DAG
if not task.start_date:
diff --git a/task_sdk/src/airflow/sdk/definitions/edges.py b/task_sdk/src/airflow/sdk/definitions/edges.py
index 7e50431b497..556cedbe78a 100644
--- a/task_sdk/src/airflow/sdk/definitions/edges.py
+++ b/task_sdk/src/airflow/sdk/definitions/edges.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
-from airflow.sdk.definitions.mixins import DependencyMixin
+from airflow.sdk.definitions._internal.mixins import DependencyMixin
if TYPE_CHECKING:
from airflow.sdk.definitions.dag import DAG
@@ -71,7 +71,7 @@ class EdgeModifier(DependencyMixin):
stream: list[DependencyMixin],
):
from airflow.models.xcom_arg import XComArg
- from airflow.sdk.definitions.node import DAGNode
+ from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.taskgroup import TaskGroup
for node in self._make_list(nodes):
@@ -93,7 +93,7 @@ class EdgeModifier(DependencyMixin):
convert them to TaskGroups
"""
from airflow.models.xcom_arg import XComArg
- from airflow.sdk.definitions.node import DAGNode
+ from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.taskgroup import TaskGroup
group_ids = set()
@@ -118,7 +118,7 @@ class EdgeModifier(DependencyMixin):
self._downstream =
self._convert_stream_to_task_groups(self._downstream)
def _convert_stream_to_task_groups(self, stream:
Sequence[DependencyMixin]) -> Sequence[DependencyMixin]:
- from airflow.sdk.definitions.node import DAGNode
+ from airflow.sdk.definitions._internal.node import DAGNode
return [
node.task_group
diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 07f8b452c19..cb5ee3eeece 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -36,21 +36,21 @@ from airflow.exceptions import (
DuplicateTaskIdFound,
TaskAlreadyInTaskGroup,
)
-from airflow.sdk.definitions.node import DAGNode
+from airflow.sdk.definitions._internal.node import DAGNode
from airflow.utils.trigger_rule import TriggerRule
if TYPE_CHECKING:
from airflow.models.expandinput import ExpandInput
- from airflow.sdk.definitions.abstractoperator import AbstractOperator
+ from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
+ from airflow.sdk.definitions._internal.mixins import DependencyMixin
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.edges import EdgeModifier
- from airflow.sdk.definitions.mixins import DependencyMixin
from airflow.serialization.enums import DagAttributeTypes
def _default_parent_group() -> TaskGroup | None:
- from airflow.sdk.definitions.contextmanager import TaskGroupContext
+ from airflow.sdk.definitions._internal.contextmanager import
TaskGroupContext
return TaskGroupContext.get_current()
@@ -65,7 +65,7 @@ def _parent_used_group_ids(tg: TaskGroup) -> set:
# that it makes Mypy (1.9.0 and 1.13.0 tested) seem to entirely loose track
that this is an Attrs class. So
# we've gone with this and moved on with our lives, mypy is to much of a dark
beast to battle over this.
def _default_dag(instance: TaskGroup):
- from airflow.sdk.definitions.contextmanager import DagContext
+ from airflow.sdk.definitions._internal.contextmanager import DagContext
if (pg := instance.parent_group) is not None:
return pg.dag
@@ -217,8 +217,8 @@ class TaskGroup(DAGNode):
:meta private:
"""
- from airflow.sdk.definitions.abstractoperator import AbstractOperator
- from airflow.sdk.definitions.contextmanager import TaskGroupContext
+ from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
+ from airflow.sdk.definitions._internal.contextmanager import
TaskGroupContext
if TaskGroupContext.active:
if task.task_group and task.task_group != self:
@@ -346,13 +346,13 @@ class TaskGroup(DAGNode):
task.set_downstream(task_or_task_list)
def __enter__(self) -> TaskGroup:
- from airflow.sdk.definitions.contextmanager import TaskGroupContext
+ from airflow.sdk.definitions._internal.contextmanager import
TaskGroupContext
TaskGroupContext.push(self)
return self
def __exit__(self, _type, _value, _tb):
- from airflow.sdk.definitions.contextmanager import TaskGroupContext
+ from airflow.sdk.definitions._internal.contextmanager import
TaskGroupContext
TaskGroupContext.pop()
diff --git a/airflow/models/taskmixin.py
b/task_sdk/src/airflow/sdk/definitions/template.py
similarity index 67%
copy from airflow/models/taskmixin.py
copy to task_sdk/src/airflow/sdk/definitions/template.py
index fa76a3815cb..87032c6977d 100644
--- a/airflow/models/taskmixin.py
+++ b/task_sdk/src/airflow/sdk/definitions/template.py
@@ -16,13 +16,17 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import Any
-if TYPE_CHECKING:
- from airflow.typing_compat import TypeAlias
+from airflow.sdk.definitions._internal.templater import LiteralValue
-import airflow.sdk.definitions.mixins
-import airflow.sdk.definitions.node
-DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin
-DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode
+def literal(value: Any) -> LiteralValue:
+ """
+ Wrap a value to ensure it is rendered as-is without applying Jinja
templating to its contents.
+
+ Designed for use in an operator's template field.
+
+ :param value: The value to be rendered without templating
+ """
+ return LiteralValue(value)
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py
b/task_sdk/src/airflow/sdk/execution_time/context.py
index c5a1e9dbee4..50cbcf0a995 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -22,9 +22,9 @@ from typing import TYPE_CHECKING, Any
import structlog
-from airflow.sdk.definitions.contextmanager import _CURRENT_CONTEXT
+from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
+from airflow.sdk.definitions._internal.types import NOTSET
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
-from airflow.sdk.types import NOTSET
if TYPE_CHECKING:
from airflow.sdk.definitions.connection import Connection
diff --git a/airflow/models/taskmixin.py
b/task_sdk/tests/defintions/_internal/__init__.py
similarity index 69%
copy from airflow/models/taskmixin.py
copy to task_sdk/tests/defintions/_internal/__init__.py
index fa76a3815cb..13a83393a91 100644
--- a/airflow/models/taskmixin.py
+++ b/task_sdk/tests/defintions/_internal/__init__.py
@@ -14,15 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
- from airflow.typing_compat import TypeAlias
-
-import airflow.sdk.definitions.mixins
-import airflow.sdk.definitions.node
-
-DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin
-DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode
diff --git a/task_sdk/tests/defintions/test_templater.py
b/task_sdk/tests/defintions/_internal/test_templater.py
similarity index 97%
rename from task_sdk/tests/defintions/test_templater.py
rename to task_sdk/tests/defintions/_internal/test_templater.py
index 69855b4ac28..7e33c0b4092 100644
--- a/task_sdk/tests/defintions/test_templater.py
+++ b/task_sdk/tests/defintions/_internal/test_templater.py
@@ -22,8 +22,8 @@ from datetime import datetime, timezone
import jinja2
import pytest
+from airflow.sdk.definitions._internal.templater import LiteralValue,
SandboxedEnvironment, Templater
from airflow.sdk.definitions.dag import DAG
-from airflow.sdk.definitions.templater import LiteralValue,
SandboxedEnvironment, Templater
class TestTemplater:
diff --git a/task_sdk/tests/defintions/test_baseoperator.py
b/task_sdk/tests/defintions/test_baseoperator.py
index e64a31717bc..35f33818dc1 100644
--- a/task_sdk/tests/defintions/test_baseoperator.py
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -29,7 +29,7 @@ import pytest
from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta
from airflow.sdk.definitions.dag import DAG
-from airflow.sdk.definitions.templater import literal
+from airflow.sdk.definitions.template import literal
from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy,
_UpstreamPriorityWeightStrategy
DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc)
@@ -492,7 +492,7 @@ class TestBaseOperator:
with pytest.raises(jinja2.exceptions.TemplateSyntaxError):
task.render_template("{{ invalid expression }}", {})
- @mock.patch("airflow.sdk.definitions.templater.SandboxedEnvironment",
autospec=True)
+
@mock.patch("airflow.sdk.definitions._internal.templater.SandboxedEnvironment",
autospec=True)
def test_jinja_env_creation(self, mock_jinja_env):
"""Verify if a Jinja environment is created only once when
templating."""
task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}")
diff --git a/task_sdk/tests/defintions/test_contextmanager.py
b/task_sdk/tests/defintions/test_context.py
similarity index 82%
rename from task_sdk/tests/defintions/test_contextmanager.py
rename to task_sdk/tests/defintions/test_context.py
index be624aff3d1..dc25ec378bd 100644
--- a/task_sdk/tests/defintions/test_contextmanager.py
+++ b/task_sdk/tests/defintions/test_context.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import pytest
-from airflow.sdk import get_current_context
+from airflow.sdk.definitions.context import get_current_context
class TestCurrentContext:
@@ -29,11 +29,13 @@ class TestCurrentContext:
def test_get_current_context_with_context(self, monkeypatch):
mock_context = {"ti": "task_instance", "key": "value"}
-
monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT",
[mock_context])
+ monkeypatch.setattr(
+
"airflow.sdk.definitions._internal.contextmanager._CURRENT_CONTEXT",
[mock_context]
+ )
result = get_current_context()
assert result == mock_context
def test_get_current_context_without_context(self, monkeypatch):
-
monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT",
[])
+
monkeypatch.setattr("airflow.sdk.definitions._internal.contextmanager._CURRENT_CONTEXT",
[])
with pytest.raises(RuntimeError, match="Current context was requested
but no context was found!"):
get_current_context()
diff --git a/airflow/models/taskmixin.py
b/task_sdk/tests/defintions/test_template.py
similarity index 54%
copy from airflow/models/taskmixin.py
copy to task_sdk/tests/defintions/test_template.py
index fa76a3815cb..deb8b96f64f 100644
--- a/airflow/models/taskmixin.py
+++ b/task_sdk/tests/defintions/test_template.py
@@ -14,15 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
-from typing import TYPE_CHECKING
+from airflow.sdk.definitions._internal.templater import Templater
+from airflow.sdk.definitions.template import literal
+
+
+def test_not_render_literal_value():
+ templater = Templater()
+ templater.template_ext = []
+ context = {}
+ content = literal("Hello {{ name }}")
+
+ rendered_content = templater.render_template(content, context)
+
+ assert rendered_content == "Hello {{ name }}"
+
-if TYPE_CHECKING:
- from airflow.typing_compat import TypeAlias
+def test_not_render_file_literal_value():
+ templater = Templater()
+ templater.template_ext = [".txt"]
+ context = {}
+ content = literal("template_file.txt")
-import airflow.sdk.definitions.mixins
-import airflow.sdk.definitions.node
+ rendered_content = templater.render_template(content, context)
-DependencyMixin: TypeAlias = airflow.sdk.definitions.mixins.DependencyMixin
-DAGNode: TypeAlias = airflow.sdk.definitions.node.DAGNode
+ assert rendered_content == "template_file.txt"
diff --git a/task_sdk/tests/execution_time/test_context.py
b/task_sdk/tests/execution_time/test_context.py
index 21a79ae5c3e..6527d517e37 100644
--- a/task_sdk/tests/execution_time/test_context.py
+++ b/task_sdk/tests/execution_time/test_context.py
@@ -21,8 +21,8 @@ from unittest.mock import MagicMock, patch
import pytest
+from airflow.sdk import get_current_context
from airflow.sdk.definitions.connection import Connection
-from airflow.sdk.definitions.contextmanager import get_current_context
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse,
VariableResult
diff --git a/task_sdk/tests/execution_time/test_task_runner.py
b/task_sdk/tests/execution_time/test_task_runner.py
index ebe2e05ec95..e4f1afaf881 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -33,7 +33,7 @@ from airflow.exceptions import (
AirflowSkipException,
AirflowTaskTerminated,
)
-from airflow.sdk import DAG, BaseOperator, Connection
+from airflow.sdk import DAG, BaseOperator, Connection, get_current_context
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import (
@@ -450,8 +450,6 @@ def test_get_context_in_task(create_runtime_ti,
time_machine, mock_supervisor_co
class MyContextAssertOperator(BaseOperator):
def execute(self, context):
- from airflow.sdk import get_current_context
-
# Ensure the context returned by get_current_context is the same
as the
# context passed to the operator
assert context == get_current_context()
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 03c0c2e1dae..6eaa4e3ac3a 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -70,9 +70,9 @@ from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import TaskGroup
+from airflow.sdk.definitions._internal.contextmanager import TaskGroupContext
+from airflow.sdk.definitions._internal.templater import NativeEnvironment,
SandboxedEnvironment
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny
-from airflow.sdk.definitions.contextmanager import TaskGroupContext
-from airflow.sdk.definitions.templater import NativeEnvironment,
SandboxedEnvironment
from airflow.security import permissions
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction,
Timetable
from airflow.timetables.simple import (