This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d22505da575 Remove SDK reference for NOTSET in Airflow Core (#58258)
d22505da575 is described below
commit d22505da575fe06b825e999ea03cde2291abf5b5
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Fri Nov 14 22:47:26 2025 +0800
Remove SDK reference for NOTSET in Airflow Core (#58258)
---
airflow-core/src/airflow/dag_processing/dagbag.py | 21 ++++------
airflow-core/src/airflow/models/dagrun.py | 6 +--
airflow-core/src/airflow/models/variable.py | 8 +---
airflow-core/src/airflow/models/xcom_arg.py | 17 ++++----
.../airflow/serialization/definitions/notset.py | 22 ++++++++--
.../src/airflow/serialization/definitions/param.py | 4 +-
.../airflow/serialization/serialized_objects.py | 12 ++++--
airflow-core/src/airflow/utils/context.py | 12 +++---
airflow-core/src/airflow/utils/helpers.py | 10 +----
airflow-core/src/airflow/utils/types.py | 22 ++++++----
airflow-core/tests/unit/models/test_xcom_arg.py | 2 +-
airflow-core/tests/unit/utils/test_helpers.py | 2 +-
devel-common/src/tests_common/pytest_plugin.py | 10 ++---
.../src/tests_common/test_utils/version_compat.py | 5 ++-
.../src/airflow/providers/amazon/aws/hooks/s3.py | 6 +--
.../src/airflow/providers/amazon/aws/hooks/ssm.py | 8 ++--
.../providers/amazon/aws/operators/base_aws.py | 2 +-
.../airflow/providers/amazon/aws/operators/emr.py | 2 +-
.../providers/amazon/aws/sensors/base_aws.py | 2 +-
.../airflow/providers/amazon/aws/transfers/base.py | 8 ++--
.../amazon/aws/transfers/dynamodb_to_s3.py | 4 +-
.../amazon/aws/transfers/redshift_to_s3.py | 10 ++---
.../amazon/aws/transfers/s3_to_redshift.py | 10 ++---
.../providers/amazon/aws/triggers/bedrock.py | 2 +-
.../amazon/aws/utils/connection_wrapper.py | 2 +-
.../src/airflow/providers/amazon/version_compat.py | 19 +++++++++
.../unit/amazon/aws/hooks/test_redshift_sql.py | 2 +-
.../unit/amazon/aws/notifications/test_ses.py | 2 +-
.../unit/amazon/aws/notifications/test_sns.py | 2 +-
.../unit/amazon/aws/notifications/test_sqs.py | 2 +-
.../unit/amazon/aws/operators/test_comprehend.py | 2 +-
.../tests/unit/amazon/aws/operators/test_ecs.py | 2 +-
.../amazon/aws/operators/test_emr_serverless.py | 2 +-
.../tests/unit/amazon/aws/sensors/test_ecs.py | 2 +-
.../unit/amazon/aws/utils/test_identifiers.py | 2 +-
.../providers/edge3/example_dags/win_test.py | 48 ++++++++++++----------
.../providers/google/cloud/hooks/compute_ssh.py | 6 ++-
.../google/cloud/log/stackdriver_task_handler.py | 13 +++---
.../tests/unit/postgres/hooks/test_postgres.py | 3 +-
.../src/airflow/providers/slack/utils/__init__.py | 5 ++-
.../ssh/src/airflow/providers/ssh/hooks/ssh.py | 17 ++++++--
.../ssh/src/airflow/providers/ssh/operators/ssh.py | 6 ++-
providers/ssh/tests/unit/ssh/operators/test_ssh.py | 3 +-
.../airflow/providers/standard/operators/bash.py | 3 +-
.../providers/standard/operators/trigger_dagrun.py | 7 +++-
.../tests/unit/standard/decorators/test_bash.py | 4 +-
.../tests/unit/standard/operators/test_python.py | 9 +++-
.../src/airflow/sdk/definitions/_internal/types.py | 47 ++++++++++-----------
task-sdk/src/airflow/sdk/definitions/dag.py | 4 +-
task-sdk/src/airflow/sdk/definitions/param.py | 4 +-
task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 28 ++++++-------
.../src/airflow/sdk/execution_time/task_runner.py | 4 +-
.../task_sdk/execution_time/test_task_runner.py | 6 +--
53 files changed, 263 insertions(+), 200 deletions(-)
diff --git a/airflow-core/src/airflow/dag_processing/dagbag.py b/airflow-core/src/airflow/dag_processing/dagbag.py
index cadd26412c9..173f5b05b4e 100644
--- a/airflow-core/src/airflow/dag_processing/dagbag.py
+++ b/airflow-core/src/airflow/dag_processing/dagbag.py
@@ -49,6 +49,7 @@ from airflow.exceptions import (
)
from airflow.executors.executor_loader import ExecutorLoader
from airflow.listeners.listener import get_listener_manager
+from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set
from airflow.serialization.serialized_objects import LazyDeserializedDAG
from airflow.utils.docs import get_docs_url
from airflow.utils.file import (
@@ -59,7 +60,6 @@ from airflow.utils.file import (
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.types import NOTSET
if TYPE_CHECKING:
from collections.abc import Generator
@@ -68,7 +68,6 @@ if TYPE_CHECKING:
from airflow import DAG
from airflow.models.dagwarning import DagWarning
- from airflow.utils.types import ArgNotSet
@contextlib.contextmanager
@@ -231,14 +230,6 @@ class DagBag(LoggingMixin):
super().__init__()
self.bundle_path = bundle_path
self.bundle_name = bundle_name
- include_examples = (
- include_examples
- if isinstance(include_examples, bool)
- else conf.getboolean("core", "LOAD_EXAMPLES")
- )
- safe_mode = (
-            safe_mode if isinstance(safe_mode, bool) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE")
- )
dag_folder = dag_folder or settings.DAGS_FOLDER
self.dag_folder = dag_folder
@@ -259,8 +250,14 @@ class DagBag(LoggingMixin):
if collect_dags:
self.collect_dags(
dag_folder=dag_folder,
- include_examples=include_examples,
- safe_mode=safe_mode,
+ include_examples=(
+ include_examples
+ if is_arg_set(include_examples)
+ else conf.getboolean("core", "LOAD_EXAMPLES")
+ ),
+ safe_mode=(
+            safe_mode if is_arg_set(safe_mode) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE")
+ ),
)
# Should the extra operator link be loaded via plugins?
    # This flag is set to False in Scheduler so that Extra Operator links are not loaded
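
The pattern this hunk adopts, in a self-contained sketch (the sentinel and the config default below are local stand-ins, not Airflow's conf API): an explicit False can no longer be confused with "argument omitted", which the removed isinstance(x, bool) checks only handled indirectly.

    from __future__ import annotations


    class ArgNotSet:
        """Local stand-in for airflow's ArgNotSet sentinel."""


    NOTSET = ArgNotSet()


    def is_arg_set(value: object) -> bool:
        return not isinstance(value, ArgNotSet)


    def collect_dags(safe_mode: bool | ArgNotSet = NOTSET) -> bool:
        # Fall back to the configured default only when the caller omitted
        # the argument; an explicit False is respected as-is.
        return safe_mode if is_arg_set(safe_mode) else True  # stand-in config default


    assert collect_dags() is True        # omitted -> fall back to the default
    assert collect_dags(False) is False  # explicit False survives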
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 97598a3c2de..54ca61a75b3 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -69,6 +69,7 @@ from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH
from airflow.models.tasklog import LogTemplate
from airflow.models.taskmap import TaskMap
from airflow.sdk.definitions.deadline import DeadlineReference
+from airflow.serialization.definitions.notset import NOTSET, ArgNotSet, is_arg_set
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
@@ -90,7 +91,7 @@ from airflow.utils.sqlalchemy import (
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.strings import get_random_string
from airflow.utils.thread_safe_dict import ThreadSafeDict
-from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType
+from airflow.utils.types import DagRunTriggeredByType, DagRunType
if TYPE_CHECKING:
from typing import Literal, TypeAlias
@@ -105,7 +106,6 @@ if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.sdk import DAG as SDKDAG
    from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
- from airflow.utils.types import ArgNotSet
CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
AttributeValueType: TypeAlias = (
@@ -348,7 +348,7 @@ class DagRun(Base, LoggingMixin):
self.conf = conf or {}
if state is not None:
self.state = state
- if queued_at is NOTSET:
+ if not is_arg_set(queued_at):
            self.queued_at = timezone.utcnow() if state == DagRunState.QUEUED else None
elif queued_at is not None:
self.queued_at = queued_at
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index a13eb4fe158..83c457eb772 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -156,13 +156,9 @@ class Variable(Base, LoggingMixin):
stacklevel=1,
)
from airflow.sdk import Variable as TaskSDKVariable
- from airflow.sdk.definitions._internal.types import NOTSET
-        var_val = TaskSDKVariable.get(
-            key,
-            default=NOTSET if default_var is cls.__NO_DEFAULT_SENTINEL else default_var,
-            deserialize_json=deserialize_json,
-        )
+        default_kwargs = {} if default_var is cls.__NO_DEFAULT_SENTINEL else {"default": default_var}
+        var_val = TaskSDKVariable.get(key, deserialize_json=deserialize_json, **default_kwargs)
if isinstance(var_val, str):
mask_secret(var_val, key)
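
The rewritten call uses keyword forwarding so the SDK's sentinel never needs to be imported: when the caller supplied no default, the `default` keyword is simply omitted. A minimal illustration of the idiom, with a plain dict standing in for the variable store:

    _SENTINEL = object()  # private "no default given" marker


    def get(key: str, default=_SENTINEL):
        store = {"answer": 42}
        if default is _SENTINEL:
            return store[key]  # no default: missing keys raise KeyError
        return store.get(key, default)


    def lookup(key: str, default=_SENTINEL):
        # Forward `default` only when one was supplied; the sentinel itself
        # never crosses the call boundary.
        kwargs = {} if default is _SENTINEL else {"default": default}
        return get(key, **kwargs)


    assert lookup("answer") == 42
    assert lookup("missing", default=None) is None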
diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py
index bc9326b2b2f..8da146ca3af 100644
--- a/airflow-core/src/airflow/models/xcom_arg.py
+++ b/airflow-core/src/airflow/models/xcom_arg.py
@@ -27,11 +27,10 @@ from sqlalchemy.orm import Session
from airflow.models.referencemixin import ReferenceMixin
from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.sdk.definitions._internal.types import ArgNotSet
from airflow.sdk.definitions.xcom_arg import XComArg
+from airflow.serialization.definitions.notset import NOTSET, is_arg_set
from airflow.utils.db import exists_query
from airflow.utils.state import State
-from airflow.utils.types import NOTSET
__all__ = ["XComArg", "get_task_map_length"]
@@ -150,7 +149,7 @@ def get_task_map_length(xcom_arg: SchedulerXComArg, run_id: str, *, session: Ses
@get_task_map_length.register
-def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session):
+def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session) -> int | None:
from airflow.models.mappedoperator import is_mapped
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
@@ -193,23 +192,23 @@ def _(xcom_arg: SchedulerPlainXComArg, run_id: str, *, session: Session):
@get_task_map_length.register
-def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session):
+def _(xcom_arg: SchedulerMapXComArg, run_id: str, *, session: Session) -> int | None:
return get_task_map_length(xcom_arg.arg, run_id, session=session)
@get_task_map_length.register
-def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session):
+def _(xcom_arg: SchedulerZipXComArg, run_id: str, *, session: Session) -> int | None:
    all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args)
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(xcom_arg.args):
    return None  # If any of the referenced XComs is not ready, we are not ready either.
- if isinstance(xcom_arg.fillvalue, ArgNotSet):
- return min(ready_lengths)
- return max(ready_lengths)
+ if is_arg_set(xcom_arg.fillvalue):
+ return max(ready_lengths)
+ return min(ready_lengths)
@get_task_map_length.register
-def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session):
+def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session) -> int | None:
    all_lengths = (get_task_map_length(arg, run_id, session=session) for arg in xcom_arg.args)
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(xcom_arg.args):
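
The swapped branches in the zip hunk preserve the original semantics: without a fillvalue the zipped length is the shortest input (plain zip truncates), with one it is the longest (zip_longest pads). A two-line illustration:

    from itertools import zip_longest

    a, b = [1, 2, 3], ["x", "y"]
    assert len(list(zip(a, b))) == min(len(a), len(b))                          # truncates -> 2
    assert len(list(zip_longest(a, b, fillvalue=None))) == max(len(a), len(b))  # pads -> 3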
diff --git a/airflow-core/src/airflow/serialization/definitions/notset.py b/airflow-core/src/airflow/serialization/definitions/notset.py
index 0e2057c45d0..a7731daed20 100644
--- a/airflow-core/src/airflow/serialization/definitions/notset.py
+++ b/airflow-core/src/airflow/serialization/definitions/notset.py
@@ -18,7 +18,23 @@
from __future__ import annotations
-from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+from typing import TYPE_CHECKING, TypeVar
-# TODO (GH-52141): Have different NOTSET and ArgNotSet in the scheduler.
-__all__ = ["NOTSET", "ArgNotSet"]
+if TYPE_CHECKING:
+ from typing_extensions import TypeIs
+
+ T = TypeVar("T")
+
+__all__ = ["NOTSET", "ArgNotSet", "is_arg_set"]
+
+
+class ArgNotSet:
+ """Sentinel type for annotations, useful when None is not viable."""
+
+
+NOTSET = ArgNotSet()
+"""Sentinel value for argument default. See ``ArgNotSet``."""
+
+
+def is_arg_set(value: T | ArgNotSet) -> TypeIs[T]:
+ return not isinstance(value, ArgNotSet)
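
The TypeIs return type (PEP 742, via typing_extensions) is what makes is_arg_set more useful than a bare isinstance check: type checkers narrow `T | ArgNotSet` to `T` in the guarded branch. A self-contained sketch of the narrowing, with the sentinel redefined locally:

    from __future__ import annotations

    from typing import TypeVar

    from typing_extensions import TypeIs

    T = TypeVar("T")


    class ArgNotSet: ...


    NOTSET = ArgNotSet()


    def is_arg_set(value: T | ArgNotSet) -> TypeIs[T]:
        return not isinstance(value, ArgNotSet)


    def run(timeout: float | ArgNotSet = NOTSET) -> float:
        if is_arg_set(timeout):
            return timeout  # narrowed to float here
        return 10.0  # sentinel branch: apply the default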
diff --git a/airflow-core/src/airflow/serialization/definitions/param.py b/airflow-core/src/airflow/serialization/definitions/param.py
index 8169b23f59c..733131f3eab 100644
--- a/airflow-core/src/airflow/serialization/definitions/param.py
+++ b/airflow-core/src/airflow/serialization/definitions/param.py
@@ -22,7 +22,7 @@ import collections.abc
import copy
from typing import TYPE_CHECKING, Any
-from airflow.serialization.definitions.notset import NOTSET, ArgNotSet
+from airflow.serialization.definitions.notset import NOTSET, is_arg_set
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping
@@ -51,7 +51,7 @@ class SerializedParam:
import jsonschema
try:
- if isinstance(value := self.value, ArgNotSet):
+ if not is_arg_set(value := self.value):
raise ValueError("No value passed")
            jsonschema.validate(value, self.schema, format_checker=jsonschema.FormatChecker())
except Exception:
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index acb05228bef..9a2a8efe96e 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -122,7 +122,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState, TaskInstanceState
-from airflow.utils.types import NOTSET, ArgNotSet, DagRunTriggeredByType, DagRunType
+from airflow.utils.types import DagRunTriggeredByType, DagRunType
if TYPE_CHECKING:
from inspect import Parameter
@@ -736,7 +736,11 @@ class BaseSerialization:
:meta private:
"""
- if cls._is_primitive(var):
+ from airflow.sdk.definitions._internal.types import is_arg_set
+
+ if not is_arg_set(var):
+ return cls._encode(None, type_=DAT.ARG_NOT_SET)
+ elif cls._is_primitive(var):
            # enum.IntEnum is an int instance, it causes json dumps error so we use its value.
if isinstance(var, enum.Enum):
return var.value
@@ -867,8 +871,6 @@ class BaseSerialization:
obj = cls.serialize(v, strict=strict)
d[str(k)] = obj
return cls._encode(d, type_=DAT.TASK_CONTEXT)
- elif isinstance(var, ArgNotSet):
- return cls._encode(None, type_=DAT.ARG_NOT_SET)
else:
return cls.default_serialization(strict, var)
@@ -981,6 +983,8 @@ class BaseSerialization:
elif type_ == DAT.TASK_INSTANCE_KEY:
return TaskInstanceKey(**var)
elif type_ == DAT.ARG_NOT_SET:
+ from airflow.serialization.definitions.notset import NOTSET
+
return NOTSET
elif type_ == DAT.DEADLINE_ALERT:
return DeadlineAlert.deserialize_deadline_alert(var)
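
Moving the ArgNotSet check to the top of serialize() and localizing the NOTSET import in deserialize keeps the sentinel round-trippable. Conceptually it works like the toy below; the tag and envelope here are illustrative, not Airflow's actual DAT wire format:

    class ArgNotSet: ...


    NOTSET = ArgNotSet()


    def serialize(var):
        if isinstance(var, ArgNotSet):
            return {"__type": "arg_not_set", "__var": None}  # tagged placeholder
        return var


    def deserialize(encoded):
        if isinstance(encoded, dict) and encoded.get("__type") == "arg_not_set":
            return NOTSET  # always the singleton, so `is NOTSET` checks keep working
        return encoded


    assert deserialize(serialize(NOTSET)) is NOTSET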
diff --git a/airflow-core/src/airflow/utils/context.py b/airflow-core/src/airflow/utils/context.py
index abf6bb1a530..4793d628323 100644
--- a/airflow-core/src/airflow/utils/context.py
+++ b/airflow-core/src/airflow/utils/context.py
@@ -30,17 +30,15 @@ from typing import (
from sqlalchemy import select
-from airflow.models.asset import (
- AssetModel,
-)
+from airflow.models.asset import AssetModel
from airflow.sdk.definitions.context import Context
from airflow.sdk.execution_time.context import (
ConnectionAccessor as ConnectionAccessorSDK,
OutletEventAccessors as OutletEventAccessorsSDK,
VariableAccessor as VariableAccessorSDK,
)
+from airflow.serialization.definitions.notset import NOTSET, is_arg_set
from airflow.utils.session import create_session
-from airflow.utils.types import NOTSET
if TYPE_CHECKING:
from airflow.sdk.definitions.asset import Asset
@@ -100,9 +98,9 @@ class VariableAccessor(VariableAccessorSDK):
def get(self, key, default: Any = NOTSET) -> Any:
from airflow.models.variable import Variable
- if default is NOTSET:
- return Variable.get(key, deserialize_json=self._deserialize_json)
-        return Variable.get(key, default, deserialize_json=self._deserialize_json)
+        if is_arg_set(default):
+            return Variable.get(key, default, deserialize_json=self._deserialize_json)
+        return Variable.get(key, deserialize_json=self._deserialize_json)
class ConnectionAccessor(ConnectionAccessorSDK):
diff --git a/airflow-core/src/airflow/utils/helpers.py b/airflow-core/src/airflow/utils/helpers.py
index 4d79a28d41e..0e4c5f325b1 100644
--- a/airflow-core/src/airflow/utils/helpers.py
+++ b/airflow-core/src/airflow/utils/helpers.py
@@ -30,7 +30,7 @@ from lazy_object_proxy import Proxy
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.utils.types import NOTSET
+from airflow.serialization.definitions.notset import is_arg_set
if TYPE_CHECKING:
from datetime import datetime
@@ -283,13 +283,7 @@ def at_most_one(*args) -> bool:
If user supplies an iterable, we raise ValueError and force them to unpack.
"""
-
- def is_set(val):
- if val is NOTSET:
- return False
- return bool(val)
-
- return sum(map(is_set, args)) in (0, 1)
+ return sum(is_arg_set(a) and bool(a) for a in args) in (0, 1)
def prune_dict(val: Any, mode="strict"):
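
The one-liner keeps the old semantics of at_most_one: an argument counts only if it is both set (not the sentinel) and truthy. For example, with a local stand-in for the sentinel:

    class ArgNotSet: ...


    NOTSET = ArgNotSet()


    def is_arg_set(value) -> bool:
        return not isinstance(value, ArgNotSet)


    def at_most_one(*args) -> bool:
        return sum(is_arg_set(a) and bool(a) for a in args) in (0, 1)


    assert at_most_one(NOTSET, None, False)  # nothing set and truthy -> True
    assert at_most_one("one", NOTSET)        # exactly one -> True
    assert not at_most_one(1, "two")         # two truthy values -> False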
diff --git a/airflow-core/src/airflow/utils/types.py b/airflow-core/src/airflow/utils/types.py
index 276901f94bc..4e1aa11ffd8 100644
--- a/airflow-core/src/airflow/utils/types.py
+++ b/airflow-core/src/airflow/utils/types.py
@@ -14,19 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
from __future__ import annotations
import enum
-from typing import TYPE_CHECKING
-
-import airflow.sdk.definitions._internal.types
-
-if TYPE_CHECKING:
- from typing import TypeAlias
-ArgNotSet: TypeAlias = airflow.sdk.definitions._internal.types.ArgNotSet
-
-NOTSET = airflow.sdk.definitions._internal.types.NOTSET
+from airflow.utils.deprecation_tools import add_deprecated_classes
class DagRunType(str, enum.Enum):
@@ -68,3 +61,14 @@ class DagRunTriggeredByType(enum.Enum):
TIMETABLE = "timetable" # for timetable based triggering
ASSET = "asset" # for asset_triggered run type
BACKFILL = "backfill"
+
+
+add_deprecated_classes(
+ {
+ __name__: {
+ "ArgNotSet": "airflow.serialization.definitions.notset.ArgNotSet",
+            "NOTSET": "airflow.serialization.definitions.notset.NOTSET",
+ },
+ },
+ package=__name__,
+)
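
add_deprecated_classes keeps the old import path working by installing a module-level __getattr__ (PEP 562) that warns and forwards to the new location. A rough hand-rolled equivalent of the shim it sets up:

    import importlib
    import warnings

    _MOVED = {
        "NOTSET": ("airflow.serialization.definitions.notset", "NOTSET"),
        "ArgNotSet": ("airflow.serialization.definitions.notset", "ArgNotSet"),
    }


    def __getattr__(name: str):
        try:
            module, attr = _MOVED[name]
        except KeyError:
            raise AttributeError(name) from None
        warnings.warn(f"{name} moved to {module}", DeprecationWarning, stacklevel=2)
        return getattr(importlib.import_module(module), attr)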
diff --git a/airflow-core/tests/unit/models/test_xcom_arg.py b/airflow-core/tests/unit/models/test_xcom_arg.py
index f5ce83df4a9..ee293544449 100644
--- a/airflow-core/tests/unit/models/test_xcom_arg.py
+++ b/airflow-core/tests/unit/models/test_xcom_arg.py
@@ -21,7 +21,7 @@ import pytest
from airflow.models.xcom_arg import XComArg
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.python import PythonOperator
-from airflow.utils.types import NOTSET
+from airflow.serialization.definitions.notset import NOTSET
from tests_common.test_utils.db import clear_db_dags, clear_db_runs
diff --git a/airflow-core/tests/unit/utils/test_helpers.py b/airflow-core/tests/unit/utils/test_helpers.py
index 6f297904eb7..6179acadfb9 100644
--- a/airflow-core/tests/unit/utils/test_helpers.py
+++ b/airflow-core/tests/unit/utils/test_helpers.py
@@ -26,6 +26,7 @@ import pytest
from airflow._shared.timezones import timezone
from airflow.exceptions import AirflowException
from airflow.jobs.base_job_runner import BaseJobRunner
+from airflow.serialization.definitions.notset import NOTSET
from airflow.utils import helpers
from airflow.utils.helpers import (
at_most_one,
@@ -35,7 +36,6 @@ from airflow.utils.helpers import (
prune_dict,
validate_key,
)
-from airflow.utils.types import NOTSET
from tests_common.test_utils.db import clear_db_dags, clear_db_runs
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 96f1bfe814d..e70ebd0e693 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -883,7 +883,7 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
# This fixture is "called" early on in the pytest collection process, and
# if we import airflow.* here the wrong (non-test) config will be loaded
# and "baked" in to various constants
-    from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
+    from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, NOTSET
want_serialized = False
    want_activate_assets = True  # Only has effect if want_serialized=True on Airflow 3.
@@ -896,7 +896,6 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
(want_activate_assets,) = serialized_marker.args or (True,)
from airflow.utils.log.logging_mixin import LoggingMixin
- from airflow.utils.types import NOTSET
class DagFactory(LoggingMixin, DagMaker):
_own_session = False
@@ -1465,9 +1464,8 @@ def create_task_instance(dag_maker: DagMaker, create_dummy_dag: CreateDummyDAG)
Uses ``create_dummy_dag`` to create the dag structure.
"""
from airflow.providers.standard.operators.empty import EmptyOperator
- from airflow.utils.types import NOTSET, ArgNotSet
- from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+    from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, NOTSET, ArgNotSet
def maker(
logical_date: datetime | None | ArgNotSet = NOTSET,
@@ -1574,7 +1572,7 @@ class CreateTaskInstanceOfOperator(Protocol):
@pytest.fixture
def create_serialized_task_instance_of_operator(dag_maker: DagMaker) -> CreateTaskInstanceOfOperator:
- from airflow.utils.types import NOTSET
+ from tests_common.test_utils.version_compat import NOTSET
def _create_task_instance(
operator_class,
@@ -1594,7 +1592,7 @@ def create_serialized_task_instance_of_operator(dag_maker: DagMaker) -> CreateTa
@pytest.fixture
def create_task_instance_of_operator(dag_maker: DagMaker) -> CreateTaskInstanceOfOperator:
- from airflow.utils.types import NOTSET
+ from tests_common.test_utils.version_compat import NOTSET
def _create_task_instance(
operator_class,
diff --git a/devel-common/src/tests_common/test_utils/version_compat.py b/devel-common/src/tests_common/test_utils/version_compat.py
index ad093637b79..e30c692278f 100644
--- a/devel-common/src/tests_common/test_utils/version_compat.py
+++ b/devel-common/src/tests_common/test_utils/version_compat.py
@@ -39,17 +39,18 @@ AIRFLOW_V_3_1_PLUS = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_3_PLUS = get_base_airflow_version_tuple() >= (3, 1, 3)
AIRFLOW_V_3_2_PLUS = get_base_airflow_version_tuple() >= (3, 2, 0)
-
if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import PokeReturnValue, timezone
from airflow.sdk.bases.xcom import BaseXCom
    from airflow.sdk.definitions._internal.decorators import remove_task_decorator
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
XCOM_RETURN_KEY = BaseXCom.XCOM_RETURN_KEY
else:
from airflow.sensors.base import PokeReturnValue # type: ignore[no-redef]
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
    from airflow.utils.decorators import remove_task_decorator  # type: ignore[no-redef]
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
from airflow.utils.xcom import XCOM_RETURN_KEY # type: ignore[no-redef]
@@ -70,9 +71,11 @@ __all__ = [
"AIRFLOW_V_3_0_1",
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_2_PLUS",
+ "NOTSET",
"SQLALCHEMY_V_1_4",
"SQLALCHEMY_V_2_0",
"XCOM_RETURN_KEY",
+ "ArgNotSet",
"PokeReturnValue",
"remove_task_decorator",
"timezone",
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
index 5b275cb641f..87404dedb41 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
@@ -43,15 +43,13 @@ from urllib.parse import urlsplit
from uuid import uuid4
if TYPE_CHECKING:
+ from aiobotocore.client import AioBaseClient
from mypy_boto3_s3.service_resource import (
Bucket as S3Bucket,
Object as S3ResourceObject,
)
- from airflow.utils.types import ArgNotSet
-
- with suppress(ImportError):
- from aiobotocore.client import AioBaseClient
+ from airflow.providers.amazon.version_compat import ArgNotSet
from asgiref.sync import sync_to_async
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py
index 7bc2c821426..ef8c30323a7 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ssm.py
@@ -20,7 +20,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.utils.types import NOTSET, ArgNotSet
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set
if TYPE_CHECKING:
from airflow.sdk.execution_time.secrets_masker import mask_secret
@@ -71,9 +71,9 @@ class SsmHook(AwsBaseHook):
mask_secret(value)
return value
except self.conn.exceptions.ParameterNotFound:
- if isinstance(default, ArgNotSet):
- raise
- return default
+ if is_arg_set(default):
+ return default
+ raise
    def get_command_invocation(self, command_id: str, instance_id: str) -> dict:
"""
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py
index 29fb40d3e67..b293c52e3a7 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/base_aws.py
@@ -25,8 +25,8 @@ from airflow.providers.amazon.aws.utils.mixins import (
AwsHookType,
aws_template_fields,
)
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
from airflow.providers.common.compat.sdk import BaseOperator
-from airflow.utils.types import NOTSET, ArgNotSet
class AwsBaseOperator(BaseOperator, AwsBaseHookMixin[AwsHookType]):
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index c023bc04e14..b1296b0f9eb 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -57,8 +57,8 @@ from airflow.providers.amazon.aws.utils.waiter import (
waiter,
)
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
from airflow.utils.helpers import exactly_one, prune_dict
-from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.utils.context import Context
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py
index b13634bc2bd..562e0816cea 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/base_aws.py
@@ -25,8 +25,8 @@ from airflow.providers.amazon.aws.utils.mixins import (
AwsHookType,
aws_template_fields,
)
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
from airflow.providers.common.compat.sdk import BaseSensorOperator
-from airflow.utils.types import NOTSET, ArgNotSet
class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py
index 612e57701cb..9be7fb54991 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/base.py
@@ -22,8 +22,8 @@ from __future__ import annotations
from collections.abc import Sequence
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set
from airflow.providers.common.compat.sdk import BaseOperator
-from airflow.utils.types import NOTSET, ArgNotSet
class AwsToAwsBaseOperator(BaseOperator):
@@ -55,7 +55,7 @@ class AwsToAwsBaseOperator(BaseOperator):
self.source_aws_conn_id = source_aws_conn_id
self.dest_aws_conn_id = dest_aws_conn_id
self.source_aws_conn_id = source_aws_conn_id
- if isinstance(dest_aws_conn_id, ArgNotSet):
- self.dest_aws_conn_id = self.source_aws_conn_id
- else:
+ if is_arg_set(dest_aws_conn_id):
self.dest_aws_conn_id = dest_aws_conn_id
+ else:
+ self.dest_aws_conn_id = self.source_aws_conn_id
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index d50fcfd00ae..a683b4c9e21 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -36,8 +36,8 @@ from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator
from airflow.utils.helpers import prune_dict
if TYPE_CHECKING:
- from airflow.utils.context import Context
- from airflow.utils.types import ArgNotSet
+ from airflow.providers.amazon.version_compat import ArgNotSet
+ from airflow.sdk import Context
class JSONEncoder(json.JSONEncoder):
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index a285af2f65b..0fe68099ace 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -28,8 +28,8 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set
from airflow.providers.common.compat.sdk import BaseOperator
-from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -131,12 +131,12 @@ class RedshiftToS3Operator(BaseOperator):
        # actually provide a connection note that, because we don't want to let the exception bubble up in
        # that case (since we're silently injecting a connection on their behalf).
self._aws_conn_id: str | None
- if isinstance(aws_conn_id, ArgNotSet):
- self.conn_set = False
- self._aws_conn_id = "aws_default"
- else:
+ if is_arg_set(aws_conn_id):
self.conn_set = True
self._aws_conn_id = aws_conn_id
+ else:
+ self.conn_set = False
+ self._aws_conn_id = "aws_default"
def _build_unload_query(
        self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index ae36822976b..87d1d752e94 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -24,8 +24,8 @@ from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set
from airflow.providers.common.compat.sdk import BaseOperator
-from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -122,12 +122,12 @@ class S3ToRedshiftOperator(BaseOperator):
        # actually provide a connection note that, because we don't want to let the exception bubble up in
        # that case (since we're silently injecting a connection on their behalf).
self._aws_conn_id: str | None
- if isinstance(aws_conn_id, ArgNotSet):
- self.conn_set = False
- self._aws_conn_id = "aws_default"
- else:
+ if is_arg_set(aws_conn_id):
self.conn_set = True
self._aws_conn_id = aws_conn_id
+ else:
+ self.conn_set = False
+ self._aws_conn_id = "aws_default"
if self.redshift_data_api_kwargs:
for arg in ["sql", "parameters"]:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
index faac0f90753..70d8254a316 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/bedrock.py
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
-from airflow.utils.types import NOTSET, ArgNotSet
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py
index 3ed84db4843..e1f7bbeb0f4 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -28,8 +28,8 @@ from botocore.config import Config
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Connection
diff --git a/providers/amazon/src/airflow/providers/amazon/version_compat.py b/providers/amazon/src/airflow/providers/amazon/version_compat.py
index a7d116ec043..dc76a025ccf 100644
--- a/providers/amazon/src/airflow/providers/amazon/version_compat.py
+++ b/providers/amazon/src/airflow/providers/amazon/version_compat.py
@@ -20,9 +20,13 @@
# ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR PROVIDER AND IMPORT
# THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR TEST CODE
#
+
from __future__ import annotations
+import functools
+
+@functools.cache
def get_base_airflow_version_tuple() -> tuple[int, int, int]:
from packaging.version import Version
@@ -36,8 +40,23 @@ AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
AIRFLOW_V_3_1_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 1)
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+except ImportError:
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
+try:
+ from airflow.sdk.definitions._internal.types import is_arg_set
+except ImportError:
+
+ def is_arg_set(value): # type: ignore[misc,no-redef]
+ return value is not NOTSET
+
+
__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_1_1_PLUS",
+ "NOTSET",
+ "ArgNotSet",
+ "is_arg_set",
]
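
Provider code cannot assume which Airflow version (and hence which home for NOTSET) is installed, so it goes through the shim above. A sketch of a consumer, with a made-up parameter store for illustration; only the import path is taken from the diff:

    from __future__ import annotations

    from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet, is_arg_set


    def fetch_parameter(name: str, default: str | ArgNotSet = NOTSET) -> str:
        store: dict[str, str] = {"region": "eu-west-1"}
        try:
            return store[name]
        except KeyError:
            if is_arg_set(default):
                return default  # caller opted in to a fallback
            raise  # no default supplied: propagate, mirroring SsmHook above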
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py
index 399245face3..f873b6dba70 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_redshift_sql.py
@@ -24,7 +24,7 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
LOGIN_USER = "login"
LOGIN_PASSWORD = "password"
diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py
index b848f38c5b3..423c41a8b2e 100644
--- a/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py
+++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
from airflow.providers.amazon.aws.notifications.ses import SesNotifier, send_ses_notification
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
TEST_EMAIL_PARAMS = {
"mail_from": "[email protected]",
diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
index b09098d0d49..ef3abcfbacb 100644
--- a/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
+++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_sns.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
from airflow.providers.amazon.aws.notifications.sns import SnsNotifier, send_sns_notification
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
PUBLISH_KWARGS = {
"target_arn": "arn:aws:sns:us-west-2:123456789098:TopicName",
diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py
index 2c33e77f2e4..10a9d115f73 100644
--- a/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py
+++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_sqs.py
@@ -21,7 +21,7 @@ from unittest import mock
import pytest
from airflow.providers.amazon.aws.notifications.sqs import SqsNotifier, send_sqs_notification
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
PARAM_DEFAULT_VALUE = pytest.param(NOTSET, id="default-value")
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
index c7aeb0830ba..fea16411a29 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_comprehend.py
@@ -29,7 +29,7 @@ from airflow.providers.amazon.aws.operators.comprehend import (
ComprehendCreateDocumentClassifierOperator,
ComprehendStartPiiEntitiesDetectionJobOperator,
)
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
index 14438fdbde7..14b442d19c5 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py
@@ -37,7 +37,7 @@ from airflow.providers.amazon.aws.operators.ecs import (
)
from airflow.providers.amazon.aws.triggers.ecs import TaskDoneTrigger
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py
index e88880289bc..e43d7b9793c 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py
@@ -31,7 +31,7 @@ from airflow.providers.amazon.aws.operators.emr import (
EmrServerlessStartJobOperator,
EmrServerlessStopApplicationOperator,
)
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
index 8a116fd484b..11968340793 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_ecs.py
@@ -35,12 +35,12 @@ from airflow.providers.amazon.aws.sensors.ecs import (
EcsTaskStates,
EcsTaskStateSensor,
)
+from airflow.providers.amazon.version_compat import NOTSET
try:
from airflow.sdk import timezone
except ImportError:
from airflow.utils import timezone # type: ignore[attr-defined,no-redef]
-from airflow.utils.types import NOTSET
_Operator = TypeVar("_Operator")
TEST_CLUSTER_NAME = "fake-cluster"
diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py b/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py
index 40c1aba2420..4ac5ff28dc2 100644
--- a/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py
+++ b/providers/amazon/tests/unit/amazon/aws/utils/test_identifiers.py
@@ -23,7 +23,7 @@ import uuid
import pytest
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
-from airflow.utils.types import NOTSET
+from airflow.providers.amazon.version_compat import NOTSET
class TestGenerateUuid:
diff --git a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py
index 363fa5e9bac..7aa2822f219 100644
--- a/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py
+++ b/providers/edge3/src/airflow/providers/edge3/example_dags/win_test.py
@@ -32,42 +32,48 @@ from subprocess import STDOUT, Popen
from time import sleep
from typing import TYPE_CHECKING, Any
-try:
- from airflow.sdk import task, task_group
-except ImportError:
- # Airflow 2 path
-    from airflow.decorators import task, task_group  # type: ignore[attr-defined,no-redef]
from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowSkipException
from airflow.models import BaseOperator
from airflow.models.dag import DAG
from airflow.models.variable import Variable
from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk.execution_time.context import context_to_airflow_vars
+try:
+ from airflow.sdk import task, task_group
+except ImportError:
+    from airflow.decorators import task, task_group  # type: ignore[attr-defined,no-redef]
try:
from airflow.sdk import BaseHook
except ImportError:
    from airflow.hooks.base import BaseHook  # type: ignore[attr-defined,no-redef]
-from airflow.sdk import Param
-
+try:
+ from airflow.sdk import Param
+except ImportError:
+ from airflow.models import Param # type: ignore[attr-defined,no-redef]
try:
from airflow.sdk import TriggerRule
except ImportError:
- # Compatibility for Airflow < 3.1
    from airflow.utils.trigger_rule import TriggerRule  # type: ignore[no-redef,attr-defined]
-from airflow.sdk.execution_time.context import context_to_airflow_vars
-from airflow.utils.types import ArgNotSet
-
-if TYPE_CHECKING:
- try:
-        from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
- except ImportError:
- from airflow.models import TaskInstance # type: ignore[assignment]
- from airflow.utils.context import Context
-
try:
- from airflow.operators.python import PythonOperator
+    from airflow.providers.common.compat.standard.operators import PythonOperator
+except ImportError:
+    from airflow.operators.python import PythonOperator  # type: ignore[no-redef]
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
except ImportError:
-    from airflow.providers.common.compat.standard.operators import PythonOperator  # type: ignore[no-redef]
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
+try:
+ from airflow.sdk.definitions._internal.types import is_arg_set
+except ImportError:
+
+ def is_arg_set(value): # type: ignore[misc,no-redef]
+ return value is not NOTSET
+
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+ from airflow.sdk.types import RuntimeTaskInstanceProtocol as TaskInstance
class CmdOperator(BaseOperator):
@@ -163,7 +169,7 @@ class CmdOperator(BaseOperator):
        # When using the @task.command decorator, the command is not known until the underlying Python
        # callable is executed and therefore set to NOTSET initially. This flag is useful during execution to
# determine whether the command value needs to re-rendered.
- self._init_command_not_set = isinstance(self.command, ArgNotSet)
+ self._init_command_not_set = not is_arg_set(self.command)
@staticmethod
def refresh_command(ti: TaskInstance) -> None:
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py
index a8e00ff4532..54f7aa93c6e 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/compute_ssh.py
@@ -31,7 +31,11 @@ from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
from airflow.providers.google.cloud.hooks.os_login import OSLoginHook
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.providers.ssh.hooks.ssh import SSHHook
-from airflow.utils.types import NOTSET, ArgNotSet
+
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+except ImportError:
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
# Paramiko should be imported after airflow.providers.ssh. Then the import will fail with
# cannot import "airflow.providers.ssh" and will be correctly discovered as optional feature
diff --git a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
index c9cc797206e..2ef1ce2efef 100644
--- a/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
+++ b/providers/google/src/airflow/providers/google/cloud/log/stackdriver_task_handler.py
@@ -35,17 +35,20 @@ from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.utils.types import NOTSET, ArgNotSet
+
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+except ImportError:
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
+
+if not AIRFLOW_V_3_0_PLUS:
+ from airflow.utils.log.trigger_handler import ctx_indiv_trigger
if TYPE_CHECKING:
from google.auth.credentials import Credentials
from airflow.models import TaskInstance
-
-if not AIRFLOW_V_3_0_PLUS:
- from airflow.utils.log.trigger_handler import ctx_indiv_trigger
-
DEFAULT_LOGGER_NAME = "airflow"
_GLOBAL_RESOURCE = Resource(type="global", labels={})
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index b286a169adf..f531dd50436 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -31,10 +31,9 @@ from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureE
from airflow.models import Connection
from airflow.providers.postgres.dialects.postgres import PostgresDialect
from airflow.providers.postgres.hooks.postgres import CompatConnection, PostgresHook
-from airflow.utils.types import NOTSET
from tests_common.test_utils.common_sql import mock_db_hook
-from tests_common.test_utils.version_compat import SQLALCHEMY_V_1_4
+from tests_common.test_utils.version_compat import NOTSET, SQLALCHEMY_V_1_4
INSERT_SQL_STATEMENT = "INSERT INTO connection (id, conn_id, conn_type, description, host, {}, login, password, port, is_encrypted, is_extra_encrypted, extra) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)"
diff --git a/providers/slack/src/airflow/providers/slack/utils/__init__.py b/providers/slack/src/airflow/providers/slack/utils/__init__.py
index 6c59b85c053..7c46594034f 100644
--- a/providers/slack/src/airflow/providers/slack/utils/__init__.py
+++ b/providers/slack/src/airflow/providers/slack/utils/__init__.py
@@ -20,7 +20,10 @@ import warnings
from collections.abc import Sequence
from typing import Any
-from airflow.utils.types import NOTSET
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET
+except ImportError:
+    from airflow.utils.types import NOTSET  # type: ignore[attr-defined,no-redef]
class ConnectionExtraConfig:
diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
index 55497d22d3e..b5860a931a5 100644
--- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
+++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py
@@ -36,7 +36,18 @@ from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random
from airflow.exceptions import AirflowException
from airflow.providers.common.compat.sdk import BaseHook
from airflow.utils.platform import getuser
-from airflow.utils.types import NOTSET, ArgNotSet
+
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+except ImportError:
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
+try:
+ from airflow.sdk.definitions._internal.types import is_arg_set
+except ImportError:
+
+ def is_arg_set(value): # type: ignore[misc,no-redef]
+ return value is not NOTSET
+
CMD_TIMEOUT = 10
@@ -438,9 +449,9 @@ class SSHHook(BaseHook):
self.log.info("Running command: %s", command)
cmd_timeout: float | None
- if not isinstance(timeout, ArgNotSet):
+ if is_arg_set(timeout):
cmd_timeout = timeout
- elif not isinstance(self.cmd_timeout, ArgNotSet):
+ elif is_arg_set(self.cmd_timeout):
cmd_timeout = self.cmd_timeout
else:
cmd_timeout = CMD_TIMEOUT
diff --git a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py
index f2f53132376..3aef97df0c4 100644
--- a/providers/ssh/src/airflow/providers/ssh/operators/ssh.py
+++ b/providers/ssh/src/airflow/providers/ssh/operators/ssh.py
@@ -26,7 +26,11 @@ from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.common.compat.sdk import BaseOperator
from airflow.providers.ssh.hooks.ssh import SSHHook
-from airflow.utils.types import NOTSET, ArgNotSet
+
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+except ImportError:
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
if TYPE_CHECKING:
from paramiko.client import SSHClient
diff --git a/providers/ssh/tests/unit/ssh/operators/test_ssh.py b/providers/ssh/tests/unit/ssh/operators/test_ssh.py
index 1d94fe57683..5747a738fc8 100644
--- a/providers/ssh/tests/unit/ssh/operators/test_ssh.py
+++ b/providers/ssh/tests/unit/ssh/operators/test_ssh.py
@@ -30,11 +30,10 @@ from airflow.models import TaskInstance
from airflow.providers.common.compat.sdk import timezone
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
-from airflow.utils.types import NOTSET
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.dag import sync_dag_to_db
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, NOTSET
datetime = timezone.datetime
diff --git a/providers/standard/src/airflow/providers/standard/operators/bash.py b/providers/standard/src/airflow/providers/standard/operators/bash.py
index 533b1dee5fb..7c5f0269a96 100644
--- a/providers/standard/src/airflow/providers/standard/operators/bash.py
+++ b/providers/standard/src/airflow/providers/standard/operators/bash.py
@@ -31,7 +31,8 @@ from airflow.providers.standard.version_compat import BaseOperator
if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
- from airflow.utils.types import ArgNotSet
+
+    from airflow.sdk.definitions._internal.types import ArgNotSet
class BashOperator(BaseOperator):
diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
index 3bec158c8b7..c0f8709fa87 100644
--- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
+++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
@@ -42,7 +42,12 @@ from airflow.providers.common.compat.sdk import BaseOperatorLink, XCom, timezone
from airflow.providers.standard.triggers.external_task import DagStateTrigger
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
from airflow.utils.state import DagRunState
-from airflow.utils.types import NOTSET, ArgNotSet, DagRunType
+from airflow.utils.types import DagRunType
+
+try:
+ from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+except ImportError:
+    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]
XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"
diff --git a/providers/standard/tests/unit/standard/decorators/test_bash.py b/providers/standard/tests/unit/standard/decorators/test_bash.py
index e3828db3f11..d868182f3b7 100644
--- a/providers/standard/tests/unit/standard/decorators/test_bash.py
+++ b/providers/standard/tests/unit/standard/decorators/test_bash.py
@@ -41,7 +41,9 @@ if AIRFLOW_V_3_0_PLUS:
else:
# bad hack but does the job
from airflow.decorators import task # type: ignore[attr-defined,no-redef]
-    from airflow.utils.types import NOTSET as SET_DURING_EXECUTION  # type: ignore[assignment]
+ from airflow.utils.types import ( # type: ignore[attr-defined,no-redef]
+ NOTSET as SET_DURING_EXECUTION, # type: ignore[assignment]
+ )
if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import timezone
else:
diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py
index 9787b85a6c6..cbc2bca2958 100644
--- a/providers/standard/tests/unit/standard/operators/test_python.py
+++ b/providers/standard/tests/unit/standard/operators/test_python.py
@@ -63,10 +63,15 @@ from airflow.providers.standard.operators.python import (
from airflow.providers.standard.utils.python_virtualenv import _execute_in_subprocess, prepare_virtualenv
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
-from airflow.utils.types import NOTSET, DagRunType
+from airflow.utils.types import DagRunType
from tests_common.test_utils.db import clear_db_runs
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_1, AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
+from tests_common.test_utils.version_compat import (
+ AIRFLOW_V_3_0_1,
+ AIRFLOW_V_3_0_PLUS,
+ AIRFLOW_V_3_1_PLUS,
+ NOTSET,
+)
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/types.py b/task-sdk/src/airflow/sdk/definitions/_internal/types.py
index 8ae8ef1b1cb..47270f823fc 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/types.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/types.py
@@ -17,36 +17,37 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, TypeVar
+if TYPE_CHECKING:
+ from typing_extensions import TypeIs
-class ArgNotSet:
- """
- Sentinel type for annotations, useful when None is not viable.
-
- Use like this::
+ from airflow.sdk.definitions._internal.node import DAGNode
- def is_arg_passed(arg: Union[ArgNotSet, None] = NOTSET) -> bool:
- if arg is NOTSET:
- return False
- return True
+ T = TypeVar("T")
+__all__ = [
+ "NOTSET",
+ "SET_DURING_EXECUTION",
+ "ArgNotSet",
+ "SetDuringExecution",
+ "is_arg_set",
+ "validate_instance_args",
+]
- is_arg_passed() # False.
- is_arg_passed(None) # True.
- """
+try:
+ # If core and SDK exist together, use core to avoid identity issues.
+ from airflow.serialization.definitions.notset import NOTSET, ArgNotSet
+except ModuleNotFoundError:
- @staticmethod
- def serialize():
- return "NOTSET"
+ class ArgNotSet: # type: ignore[no-redef]
+ """Sentinel type for annotations, useful when None is not viable."""
- @classmethod
- def deserialize(cls):
- return cls
+ NOTSET = ArgNotSet() # type: ignore[no-redef]
-NOTSET = ArgNotSet()
-"""Sentinel value for argument default. See ``ArgNotSet``."""
+def is_arg_set(value: T | ArgNotSet) -> TypeIs[T]:
+ return not isinstance(value, ArgNotSet)
class SetDuringExecution(ArgNotSet):
@@ -61,10 +62,6 @@ SET_DURING_EXECUTION = SetDuringExecution()
"""Sentinel value for argument default. See ``SetDuringExecution``."""
-if TYPE_CHECKING:
- from airflow.sdk.definitions._internal.node import DAGNode
-
-
def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None:
"""Validate that the instance has the expected types for the arguments."""
from airflow.sdk.definitions.taskgroup import TaskGroup
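
The new is_arg_set helper replaces the scattered isinstance(..., ArgNotSet) checks and, via typing_extensions.TypeIs, lets static type checkers narrow the checked value. A short sketch of the intended call-site pattern (resolve_timeout and its default are hypothetical, not from this commit):

    from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set

    def resolve_timeout(timeout: int | None | ArgNotSet = NOTSET) -> int | None:
        if is_arg_set(timeout):
            # Checkers narrow `timeout` to `int | None` here thanks to TypeIs.
            return timeout
        return 60  # hypothetical default, used only for illustration

Note the try/except at the top of the rewritten module: when core and SDK are installed together, both sides now share the single ArgNotSet class from airflow.serialization.definitions.notset, so identity checks like `value is NOTSET` keep working across the core/SDK boundary.
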
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 4e26fce37e7..6fc32c188b8 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -46,7 +46,7 @@ from airflow.exceptions import (
from airflow.sdk import TaskInstanceState, TriggerRule
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.node import validate_key
-from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
from airflow.sdk.definitions.asset import AssetAll, BaseAsset
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.deadline import DeadlineAlert
@@ -1197,7 +1197,7 @@ class DAG:
self.validate()
# Allow users to explicitly pass None. If it isn't set, we default to current time.
- logical_date = logical_date if not isinstance(logical_date, ArgNotSet) else timezone.utcnow()
+ logical_date = logical_date if is_arg_set(logical_date) else timezone.utcnow()
log.debug("Clearing existing task instances for logical date %s",
logical_date)
# TODO: Replace with calling client.dag_run.clear in Execution API at some point
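
DAG.clear keeps the three-way distinction the sentinel exists for: an argument that was never passed defaults to the current time, while an explicitly passed value, including None, is respected. A compressed sketch of that idiom (function name is illustrative; timezone import as used elsewhere in this diff):

    from datetime import datetime
    from airflow.sdk import timezone
    from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set

    def effective_logical_date(logical_date: datetime | None | ArgNotSet = NOTSET) -> datetime | None:
        # Not passed -> default to now; passed (even None) -> kept as given.
        return logical_date if is_arg_set(logical_date) else timezone.utcnow()
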
diff --git a/task-sdk/src/airflow/sdk/definitions/param.py b/task-sdk/src/airflow/sdk/definitions/param.py
index 2c853ce1ffc..5c0136d7de5 100644
--- a/task-sdk/src/airflow/sdk/definitions/param.py
+++ b/task-sdk/src/airflow/sdk/definitions/param.py
@@ -25,7 +25,7 @@ from typing import TYPE_CHECKING, Any, ClassVar
from airflow.exceptions import AirflowException, ParamValidationError
from airflow.sdk.definitions._internal.mixins import ResolveMixin
-from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
@@ -90,7 +90,7 @@ class Param:
if value is not NOTSET:
self._check_json(value)
final_val = self.value if value is NOTSET else value
- if isinstance(final_val, ArgNotSet):
+ if not is_arg_set(final_val):
if suppress_exception:
return None
raise ParamValidationError("No value passed and Param has no default value")
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index fa540333cf3..a674a47e19b 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -31,7 +31,7 @@ from airflow.sdk import TriggerRule
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
-from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set
from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
from airflow.sdk.execution_time.xcom import BaseXCom
@@ -347,7 +347,7 @@ class PlainXComArg(XComArg):
default=NOTSET,
map_indexes=map_indexes,
)
- if not isinstance(result, ArgNotSet):
+ if is_arg_set(result):
return result
if self.key == BaseXCom.XCOM_RETURN_KEY:
return None
@@ -452,9 +452,9 @@ class _ZipResult(Sequence):
def __len__(self) -> int:
lengths = (len(v) for v in self.values)
- if isinstance(self.fillvalue, ArgNotSet):
- return min(lengths)
- return max(lengths)
+ if is_arg_set(self.fillvalue):
+ return max(lengths)
+ return min(lengths)
@attrs.define
@@ -474,15 +474,15 @@ class ZipXComArg(XComArg):
args_iter = iter(self.args)
first = repr(next(args_iter))
rest = ", ".join(repr(arg) for arg in args_iter)
- if isinstance(self.fillvalue, ArgNotSet):
- return f"{first}.zip({rest})"
- return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+ if is_arg_set(self.fillvalue):
+ return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+ return f"{first}.zip({rest})"
def _serialize(self) -> dict[str, Any]:
args = [serialize_xcom_arg(arg) for arg in self.args]
- if isinstance(self.fillvalue, ArgNotSet):
- return {"args": args}
- return {"args": args, "fillvalue": self.fillvalue}
+ if is_arg_set(self.fillvalue):
+ return {"args": args, "fillvalue": self.fillvalue}
+ return {"args": args}
def iter_references(self) -> Iterator[tuple[Operator, str]]:
for arg in self.args:
@@ -602,9 +602,9 @@ def _(xcom_arg: ZipXComArg, resolved_val: Sized, upstream_map_indexes: dict[str,
ready_lengths = [length for length in all_lengths if length is not None]
if len(ready_lengths) != len(xcom_arg.args):
return None  # If any of the referenced XComs is not ready, we are not ready either.
- if isinstance(xcom_arg.fillvalue, ArgNotSet):
- return min(ready_lengths)
- return max(ready_lengths)
+ if is_arg_set(xcom_arg.fillvalue):
+ return max(ready_lengths)
+ return min(ready_lengths)
@get_task_map_length.register
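
The three flipped branches in this file all encode one rule: without a fillvalue, zipped XComs truncate to the shortest input (like builtin zip); with a fillvalue, they pad to the longest (like itertools.zip_longest). The stdlib analogy, runnable as-is:

    from itertools import zip_longest

    a, b = [1, 2, 3], ["x", "y"]
    list(zip(a, b))                        # [(1, 'x'), (2, 'y')]          -> min length
    list(zip_longest(a, b, fillvalue=0))   # [(1, 'x'), (2, 'y'), (3, 0)]  -> max length
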
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 26c1e5ee762..74d085f245a 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -54,7 +54,7 @@ from airflow.sdk.api.datamodels._generated import (
from airflow.sdk.bases.operator import BaseOperator, ExecutorSafeguard
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager
-from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
+from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.param import process_params
@@ -360,7 +360,7 @@ class RuntimeTaskInstance(TaskInstance):
task_ids = [task_ids]
# If map_indexes is not specified, pull xcoms from all map indexes for each task
- if isinstance(map_indexes, ArgNotSet):
+ if not is_arg_set(map_indexes):
xcoms: list[Any] = []
for t_id in task_ids:
values = XCom.get_all(
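
In xcom_pull, leaving map_indexes at its NOTSET default means "pull across all map indexes", which is distinct from passing an explicit index. A hedged usage sketch (`ti` stands in for a RuntimeTaskInstance inside a running task; task id is illustrative):

    # map_indexes omitted -> NOTSET -> values from every map index of "upstream"
    all_values = ti.xcom_pull(task_ids="upstream")
    # explicit index -> only that mapped task instance's value
    first_value = ti.xcom_pull(task_ids="upstream", map_indexes=0)
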
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index cb2687a9676..56ee821dd2f 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -66,7 +66,7 @@ from airflow.sdk.api.datamodels._generated import (
TIRunContext,
)
from airflow.sdk.bases.xcom import BaseXCom
-from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, ArgNotSet
+from airflow.sdk.definitions._internal.types import NOTSET, SET_DURING_EXECUTION, is_arg_set
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, Dataset, Model
from airflow.sdk.definitions.param import DagParam
from airflow.sdk.exceptions import ErrorType
@@ -1570,9 +1570,7 @@ class TestRuntimeTaskInstance:
for task_id_raw in task_ids:
# Without task_ids (or None) expected behavior is to pull with calling task_id
- task_id = (
- test_task_id if task_id_raw is None or isinstance(task_id_raw, ArgNotSet) else task_id_raw
- )
+ task_id = task_id_raw if is_arg_set(task_id_raw) and task_id_raw is not None else test_task_id
for map_index in map_indexes:
if map_index == NOTSET: