This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dad102e9c7e Split SerializedBaseOperator from serde logic (#59627)
dad102e9c7e is described below
commit dad102e9c7eec421bac1ff46015a45e8664ec4bf
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Dec 24 08:55:15 2025 +0800
Split SerializedBaseOperator from serde logic (#59627)
---
airflow-core/src/airflow/api/common/mark_tasks.py | 3 +-
.../src/airflow/api_fastapi/common/dagbag.py | 2 +-
.../src/airflow/api_fastapi/common/parameters.py | 2 +-
.../api_fastapi/core_api/datamodels/dag_run.py | 2 +-
.../core_api/services/public/task_instances.py | 2 +-
.../api_fastapi/core_api/services/ui/calendar.py | 2 +-
.../api_fastapi/core_api/services/ui/grid.py | 2 +-
.../api_fastapi/core_api/services/ui/task_group.py | 2 +-
.../execution_api/routes/task_instances.py | 2 +-
.../src/airflow/cli/commands/task_command.py | 2 +-
.../src/airflow/dag_processing/collection.py | 3 +-
.../src/airflow/jobs/scheduler_job_runner.py | 2 +-
airflow-core/src/airflow/models/backfill.py | 2 +-
airflow-core/src/airflow/models/dag.py | 3 +-
airflow-core/src/airflow/models/dagbag.py | 2 +-
airflow-core/src/airflow/models/dagrun.py | 3 +-
airflow-core/src/airflow/models/expandinput.py | 2 +-
airflow-core/src/airflow/models/mappedoperator.py | 92 +++--
airflow-core/src/airflow/models/referencemixin.py | 2 +-
.../src/airflow/models/renderedtifields.py | 2 +-
airflow-core/src/airflow/models/taskinstance.py | 2 +-
airflow-core/src/airflow/models/taskmap.py | 4 +-
airflow-core/src/airflow/models/xcom_arg.py | 5 +-
.../serialization/definitions/baseoperator.py | 452 ++++++++++++++++++++
.../src/airflow/serialization/definitions/dag.py | 2 +-
.../src/airflow/serialization/definitions/node.py | 5 +-
.../airflow/serialization/definitions/taskgroup.py | 4 +-
.../airflow/serialization/serialized_objects.py | 458 ++-------------------
.../ti_deps/deps/mapped_task_upstream_dep.py | 2 +-
.../src/airflow/ti_deps/deps/prev_dagrun_dep.py | 2 +-
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 2 +-
airflow-core/src/airflow/utils/cli.py | 2 +-
airflow-core/src/airflow/utils/dag_edges.py | 3 +-
airflow-core/src/airflow/utils/dot_renderer.py | 3 +-
.../tests/unit/api/common/test_mark_tasks.py | 2 +-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 3 +-
airflow-core/tests/unit/models/test_cleartasks.py | 2 +-
airflow-core/tests/unit/models/test_dagcode.py | 2 +-
airflow-core/tests/unit/models/test_dagrun.py | 4 +-
.../tests/unit/models/test_mappedoperator.py | 2 +-
.../tests/unit/models/test_taskinstance.py | 7 +-
.../unit/serialization/test_dag_serialization.py | 114 ++---
.../unit/serialization/test_serialized_objects.py | 2 +-
devel-common/src/tests_common/pytest_plugin.py | 2 +-
devel-common/src/tests_common/test_utils/compat.py | 8 +-
.../src/tests_common/test_utils/mapping.py | 2 +-
.../fab/auth_manager/security_manager/override.py | 2 +-
.../openlineage/utils/selective_enable.py | 2 +-
.../tests/unit/openlineage/utils/test_utils.py | 11 +-
.../standard/sensors/test_external_task_sensor.py | 6 +-
.../unit/standard/utils/test_sensor_helper.py | 2 +-
scripts/in_container/run_schema_defaults_check.py | 2 +-
52 files changed, 665 insertions(+), 590 deletions(-)
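For downstream code the change is mechanical: the serialized definition
classes now live under airflow.serialization.definitions, while
serialized_objects keeps only the encode/decode logic. A minimal sketch of
the updated imports, derived from the hunks below (illustrative only, not
part of the diff):

    # Before this commit:
    # from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
    # After this commit:
    from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
    from airflow.serialization.definitions.dag import SerializedDAG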
diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py
index 2525250ed82..c67b0dedcd9 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -34,7 +34,8 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session as SASession
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+ from airflow.serialization.definitions.dag import SerializedDAG
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
diff --git a/airflow-core/src/airflow/api_fastapi/common/dagbag.py b/airflow-core/src/airflow/api_fastapi/common/dagbag.py
index 491a7131acc..ce81f6906b7 100644
--- a/airflow-core/src/airflow/api_fastapi/common/dagbag.py
+++ b/airflow-core/src/airflow/api_fastapi/common/dagbag.py
@@ -25,7 +25,7 @@ from airflow.models.dagbag import DBDagBag
if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
def create_dag_bag() -> DBDagBag:
diff --git a/airflow-core/src/airflow/api_fastapi/common/parameters.py b/airflow-core/src/airflow/api_fastapi/common/parameters.py
index 924206602fc..0d460380c1e 100644
--- a/airflow-core/src/airflow/api_fastapi/common/parameters.py
+++ b/airflow-core/src/airflow/api_fastapi/common/parameters.py
@@ -69,7 +69,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql import ColumnElement, Select
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
T = TypeVar("T")
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py
index 56546b58996..19960a5b8c4 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py
@@ -32,7 +32,7 @@ from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
if TYPE_CHECKING:
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
class DAGRunPatchStates(str, Enum):
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
index eda41f27cbe..88f9f1dcf66 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
@@ -42,7 +42,7 @@ from airflow.api_fastapi.core_api.security import GetUserDep
from airflow.api_fastapi.core_api.services.public.common import BulkService
from airflow.listeners.listener import get_listener_manager
from airflow.models.taskinstance import TaskInstance as TI
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.serialization.definitions.dag import SerializedDAG
from airflow.utils.state import TaskInstanceState
log = structlog.get_logger(__name__)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
index eed47b2c4db..0f00af412c6 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py
@@ -35,7 +35,7 @@ from airflow.api_fastapi.core_api.datamodels.ui.calendar import (
CalendarTimeRangeResponse,
)
from airflow.models.dagrun import DagRun
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.serialization.definitions.dag import SerializedDAG
from airflow.timetables._cron import CronMixin
from airflow.timetables.base import DataInterval, TimeRestriction
from airflow.timetables.simple import ContinuousTimetable
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
index 5b20511d31f..bc7bd4b8a23 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py
@@ -26,8 +26,8 @@ from airflow.api_fastapi.common.parameters import state_priority
from airflow.api_fastapi.core_api.services.ui.task_group import get_task_group_children_getter
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmap import TaskMap
+from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
-from airflow.serialization.serialized_objects import SerializedBaseOperator
log = structlog.get_logger(logger_name=__name__)
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
index 0fc3c0b14c7..86f363afa77 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/services/ui/task_group.py
@@ -25,7 +25,7 @@ from operator import methodcaller
from airflow.configuration import conf
from airflow.models.mappedoperator import MappedOperator, is_mapped
-from airflow.serialization.serialized_objects import SerializedBaseOperator
+from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
@cache
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 366cd456de2..3f0eec6ae0b 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -70,7 +70,7 @@ from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
from airflow.serialization.definitions.assets import SerializedAsset, SerializedAssetUniqueKey
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.serialization.definitions.dag import SerializedDAG
from airflow.task.trigger_rule import TriggerRule
from airflow.utils.sqlalchemy import get_dialect_name
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py
index 586edad563f..f54c1740573 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -63,7 +63,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
CreateIfNecessary = Literal[False, "db", "memory"]
Operator = MappedOperator | SerializedBaseOperator
diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py
index 13f2415a3ba..1d40c5d0396 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -60,8 +60,9 @@ from airflow.serialization.definitions.assets import (
SerializedAssetNameRef,
SerializedAssetUriRef,
)
+from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.enums import Encoding
-from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG
+from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG
from airflow.triggers.base import BaseEventTrigger
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import get_dialect_name, with_row_locks
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 80fea348b87..6306187ea4a 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -110,7 +110,7 @@ if TYPE_CHECKING:
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_utils import ExecutorName
from airflow.models.taskinstance import TaskInstanceKey
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
from airflow.utils.sqlalchemy import CommitProhibitorGuard
TI = TaskInstance
diff --git a/airflow-core/src/airflow/models/backfill.py b/airflow-core/src/airflow/models/backfill.py
index 6d2dfb573d2..33df6901d40 100644
--- a/airflow-core/src/airflow/models/backfill.py
+++ b/airflow-core/src/airflow/models/backfill.py
@@ -56,7 +56,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.dagrun import DagRun
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
from airflow.timetables.base import DagRunInfo
log = logging.getLogger(__name__)
diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py
index 4cac57ad72a..4705f9b18a9 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -76,7 +76,8 @@ if TYPE_CHECKING:
SerializedAssetAlias,
SerializedAssetBase,
)
- from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+ from airflow.serialization.definitions.dag import SerializedDAG
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
UKey: TypeAlias = SerializedAssetUniqueKey
diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py
index d1e93217473..abf8c417603 100644
--- a/airflow-core/src/airflow/models/dagbag.py
+++ b/airflow-core/src/airflow/models/dagbag.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from airflow.models import DagRun
from airflow.models.serialized_dag import SerializedDagModel
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
class DBDagBag:
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 89999bdbf29..da66641524f 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -107,7 +107,8 @@ if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.sdk import DAG as SDKDAG
- from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+ from airflow.serialization.definitions.dag import SerializedDAG
CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
AttributeValueType: TypeAlias = (
diff --git a/airflow-core/src/airflow/models/expandinput.py b/airflow-core/src/airflow/models/expandinput.py
index 40cbcdefea6..58f45a83ea9 100644
--- a/airflow-core/src/airflow/models/expandinput.py
+++ b/airflow-core/src/airflow/models/expandinput.py
@@ -41,7 +41,7 @@ if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.xcom_arg import SchedulerXComArg
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py
index 758149875f0..779f07226eb 100644
--- a/airflow-core/src/airflow/models/mappedoperator.py
+++ b/airflow-core/src/airflow/models/mappedoperator.py
@@ -32,11 +32,11 @@ from airflow.exceptions import AirflowException, NotMapped
from airflow.sdk import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_RETRY_DELAY_MULTIPLIER
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
+from airflow.serialization.definitions.baseoperator import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator
from airflow.serialization.definitions.node import DAGNode
from airflow.serialization.definitions.param import SerializedParamsDict
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.enums import DagAttributeTypes
-from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, SerializedBaseOperator
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
if TYPE_CHECKING:
@@ -47,10 +47,11 @@ if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.models.expandinput import SchedulerExpandInput
- from airflow.sdk import BaseOperatorLink, Context
+ from airflow.sdk import Context
from airflow.sdk.definitions._internal.node import DAGNode as TaskSDKDAGNode
from airflow.sdk.definitions.operator_resources import Resources
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
+ from airflow.serialization.serialized_objects import XComOperatorLink
from airflow.task.trigger_rule import TriggerRule
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import StartTriggerArgs
@@ -94,7 +95,7 @@ class MappedOperator(DAGNode):
# Needed for serialization.
task_id: str
params: SerializedParamsDict = attrs.field(init=False, factory=SerializedParamsDict)
- operator_extra_links: Collection[BaseOperatorLink]
+ operator_extra_links: Collection[XComOperatorLink]
template_ext: Sequence[str]
template_fields: Collection[str]
template_fields_renderers: dict[str, str]
@@ -139,6 +140,9 @@ class MappedOperator(DAGNode):
def __repr__(self) -> str:
return f"<SerializedMappedTask({self.task_type}): {self.task_id}>"
+ def _get_partial_kwargs_or_operator_default(self, key: str):
+ return self.partial_kwargs.get(key, getattr(SerializedBaseOperator, key))
+
@property
def node_id(self) -> str:
return self.task_id
@@ -183,98 +187,95 @@ class MappedOperator(DAGNode):
@property
def owner(self) -> str:
- return self.partial_kwargs.get("owner", SerializedBaseOperator.owner)
+ return self._get_partial_kwargs_or_operator_default("owner")
@property
def trigger_rule(self) -> TriggerRule:
- return self.partial_kwargs.get("trigger_rule",
SerializedBaseOperator.trigger_rule)
+ return self._get_partial_kwargs_or_operator_default("trigger_rule")
@property
def is_setup(self) -> bool:
- return bool(self.partial_kwargs.get("is_setup"))
+ return self._get_partial_kwargs_or_operator_default("is_setup")
@property
def is_teardown(self) -> bool:
- return bool(self.partial_kwargs.get("is_teardown"))
+ return self._get_partial_kwargs_or_operator_default("is_teardown")
@property
def depends_on_past(self) -> bool:
- return bool(self.partial_kwargs.get("depends_on_past"))
+ return self._get_partial_kwargs_or_operator_default("depends_on_past")
@property
def ignore_first_depends_on_past(self) -> bool:
- value = self.partial_kwargs.get(
- "ignore_first_depends_on_past", SerializedBaseOperator.ignore_first_depends_on_past
- )
- return bool(value)
+ return self._get_partial_kwargs_or_operator_default("ignore_first_depends_on_past")
@property
def wait_for_downstream(self) -> bool:
- return bool(self.partial_kwargs.get("wait_for_downstream"))
+ return self._get_partial_kwargs_or_operator_default("wait_for_downstream")
@property
def retries(self) -> int:
- return self.partial_kwargs.get("retries",
SerializedBaseOperator.retries)
+ return self._get_partial_kwargs_or_operator_default("retries")
@property
def queue(self) -> str:
- return self.partial_kwargs.get("queue", SerializedBaseOperator.queue)
+ return self._get_partial_kwargs_or_operator_default("queue")
@property
def pool(self) -> str:
- return self.partial_kwargs.get("pool", SerializedBaseOperator.pool)
+ return self._get_partial_kwargs_or_operator_default("pool")
@property
def pool_slots(self) -> int:
- return self.partial_kwargs.get("pool_slots",
SerializedBaseOperator.pool_slots)
+ return self._get_partial_kwargs_or_operator_default("pool_slots")
@property
def resources(self) -> Resources | None:
- return self.partial_kwargs.get("resources")
+ return self._get_partial_kwargs_or_operator_default("resources")
@property
def max_active_tis_per_dag(self) -> int | None:
- return self.partial_kwargs.get("max_active_tis_per_dag")
+ return self._get_partial_kwargs_or_operator_default("max_active_tis_per_dag")
@property
def max_active_tis_per_dagrun(self) -> int | None:
- return self.partial_kwargs.get("max_active_tis_per_dagrun")
+ return self._get_partial_kwargs_or_operator_default("max_active_tis_per_dagrun")
@property
def has_on_execute_callback(self) -> bool:
- return bool(self.partial_kwargs.get("has_on_execute_callback", False))
+ return self._get_partial_kwargs_or_operator_default("has_on_execute_callback")
@property
def has_on_failure_callback(self) -> bool:
- return bool(self.partial_kwargs.get("has_on_failure_callback", False))
+ return self._get_partial_kwargs_or_operator_default("has_on_failure_callback")
@property
def has_on_retry_callback(self) -> bool:
- return bool(self.partial_kwargs.get("has_on_retry_callback", False))
+ return self._get_partial_kwargs_or_operator_default("has_on_retry_callback")
@property
def has_on_success_callback(self) -> bool:
- return bool(self.partial_kwargs.get("has_on_success_callback", False))
+ return self._get_partial_kwargs_or_operator_default("has_on_success_callback")
@property
def has_on_skipped_callback(self) -> bool:
- return bool(self.partial_kwargs.get("has_on_skipped_callback", False))
+ return self._get_partial_kwargs_or_operator_default("has_on_skipped_callback")
@property
def run_as_user(self) -> str | None:
- return self.partial_kwargs.get("run_as_user")
+ return self._get_partial_kwargs_or_operator_default("run_as_user")
@property
def priority_weight(self) -> int:
- return self.partial_kwargs.get("priority_weight",
SerializedBaseOperator.priority_weight)
+ return self._get_partial_kwargs_or_operator_default("priority_weight")
@property
def retry_delay(self) -> datetime.timedelta:
- return self.partial_kwargs.get("retry_delay",
SerializedBaseOperator.retry_delay)
+ return self._get_partial_kwargs_or_operator_default("retry_delay")
@property
def retry_exponential_backoff(self) -> float:
- value = self.partial_kwargs.get("retry_exponential_backoff", 0)
+ value = self._get_partial_kwargs_or_operator_default("retry_exponential_backoff")
if value is True:
return 2.0
if value is False:
@@ -282,8 +283,8 @@ class MappedOperator(DAGNode):
return float(value)
@property
- def max_retry_delay(self) -> datetime.timedelta | None:
- return self.partial_kwargs.get("max_retry_delay")
+ def max_retry_delay(self) -> datetime.timedelta | float | None:
+ return self._get_partial_kwargs_or_operator_default("max_retry_delay")
@property
def retry_delay_multiplier(self) -> float:
@@ -291,45 +292,46 @@ class MappedOperator(DAGNode):
@property
def weight_rule(self) -> PriorityWeightStrategy:
- return validate_and_load_priority_weight_strategy(
- self.partial_kwargs.get("weight_rule",
SerializedBaseOperator._weight_rule)
- )
+ from airflow.serialization.definitions.baseoperator import
SerializedBaseOperator
+
+ value = self.partial_kwargs.get("weight_rule") or
SerializedBaseOperator._weight_rule
+ return validate_and_load_priority_weight_strategy(value)
@property
def executor(self) -> str | None:
- return self.partial_kwargs.get("executor")
+ return self._get_partial_kwargs_or_operator_default("executor")
@property
def executor_config(self) -> dict:
- return self.partial_kwargs.get("executor_config", {})
+ return self._get_partial_kwargs_or_operator_default("executor_config")
@property
def execution_timeout(self) -> datetime.timedelta | None:
- return self.partial_kwargs.get("execution_timeout")
+ return self._get_partial_kwargs_or_operator_default("execution_timeout")
@property
def inlets(self) -> list[Any]:
- return self.partial_kwargs.get("inlets", [])
+ return self._get_partial_kwargs_or_operator_default("inlets")
@property
def outlets(self) -> list[Any]:
- return self.partial_kwargs.get("outlets", [])
+ return self._get_partial_kwargs_or_operator_default("outlets")
@property
def email(self) -> str | Iterable[str] | None:
- return self.partial_kwargs.get("email")
+ return self._get_partial_kwargs_or_operator_default("email")
@property
def email_on_failure(self) -> bool:
- return self.partial_kwargs.get("email_on_failure", True)
+ return self._get_partial_kwargs_or_operator_default("email_on_failure")
@property
def email_on_retry(self) -> bool:
- return self.partial_kwargs.get("email_on_retry", True)
+ return self._get_partial_kwargs_or_operator_default("email_on_retry")
@property
def on_failure_fail_dagrun(self) -> bool:
- return bool(self.partial_kwargs.get("on_failure_fail_dagrun"))
+ return self._get_partial_kwargs_or_operator_default("on_failure_fail_dagrun")
@on_failure_fail_dagrun.setter
def on_failure_fail_dagrun(self, v) -> None:
@@ -368,7 +370,7 @@ class MappedOperator(DAGNode):
)
@functools.cached_property
- def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]:
+ def operator_extra_link_dict(self) -> dict[str, XComOperatorLink]:
"""Returns dictionary of all extra links for the operator."""
op_extra_links_from_plugin: dict[str, Any] = {}
from airflow import plugins_manager
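The property rewrites above all reduce to one lookup rule: a value set via
partial() wins, otherwise the class-level default on SerializedBaseOperator
applies. A standalone sketch of that rule (the _Demo class is hypothetical;
the defaults come from the new baseoperator module further below):

    from airflow.serialization.definitions.baseoperator import SerializedBaseOperator

    class _Demo:
        partial_kwargs = {"retries": 3}

        def _get_partial_kwargs_or_operator_default(self, key: str):
            # An explicit partial() value wins; otherwise fall back to the
            # class attribute default on SerializedBaseOperator.
            return self.partial_kwargs.get(key, getattr(SerializedBaseOperator, key))

    demo = _Demo()
    demo._get_partial_kwargs_or_operator_default("retries")  # 3, from partial_kwargs
    demo._get_partial_kwargs_or_operator_default("owner")    # "airflow", class default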
diff --git a/airflow-core/src/airflow/models/referencemixin.py b/airflow-core/src/airflow/models/referencemixin.py
index 19a775417f8..cb427215bf9 100644
--- a/airflow-core/src/airflow/models/referencemixin.py
+++ b/airflow-core/src/airflow/models/referencemixin.py
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from typing import TypeAlias
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
diff --git a/airflow-core/src/airflow/models/renderedtifields.py b/airflow-core/src/airflow/models/renderedtifields.py
index a544dbce691..7c3ad88bcaa 100644
--- a/airflow-core/src/airflow/models/renderedtifields.py
+++ b/airflow-core/src/airflow/models/renderedtifields.py
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql import FromClause
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
def _get_nested_value(obj: Any, path: str) -> Any:
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index ac854ac162c..99037ec6519 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -115,9 +115,9 @@ if TYPE_CHECKING:
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
from airflow.models.mappedoperator import MappedOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
- from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.context import Context
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py
index 91326d8873a..7560a5a8a11 100644
--- a/airflow-core/src/airflow/models/taskmap.py
+++ b/airflow-core/src/airflow/models/taskmap.py
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
class TaskMapVariant(enum.Enum):
@@ -141,7 +141,7 @@ class TaskMap(TaskInstanceDependencies):
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.mappedoperator import MappedOperator, get_mapped_ti_count
from airflow.models.taskinstance import TaskInstance
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.settings import task_instance_mutation_hook
if not isinstance(task, (MappedOperator, SerializedBaseOperator)):
diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py
index 8ea08a24132..5abd2aec8e1 100644
--- a/airflow-core/src/airflow/models/xcom_arg.py
+++ b/airflow-core/src/airflow/models/xcom_arg.py
@@ -35,7 +35,8 @@ __all__ = ["SchedulerXComArg", "deserialize_xcom_arg", "get_task_map_length"]
if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+ from airflow.serialization.definitions.dag import SerializedDAG
from airflow.typing_compat import Self
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
@@ -71,7 +72,7 @@ class SchedulerXComArg:
collection objects, and instances with ``template_fields`` set.
"""
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
if isinstance(arg, ReferenceMixin):
yield from arg.iter_references()
diff --git a/airflow-core/src/airflow/serialization/definitions/baseoperator.py b/airflow-core/src/airflow/serialization/definitions/baseoperator.py
new file mode 100644
index 00000000000..3e8293156c0
--- /dev/null
+++ b/airflow-core/src/airflow/serialization/definitions/baseoperator.py
@@ -0,0 +1,452 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import datetime
+import functools
+from typing import TYPE_CHECKING, Any
+
+import methodtools
+
+from airflow.exceptions import AirflowException
+from airflow.serialization.definitions.node import DAGNode
+from airflow.serialization.definitions.param import SerializedParamsDict
+from airflow.serialization.enums import DagAttributeTypes
+from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
+from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep
+from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
+from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
+from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
+from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
+
+if TYPE_CHECKING:
+ from collections.abc import Collection, Iterable, Iterator, Sequence
+
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.models.taskinstance import TaskInstance
+ from airflow.sdk import Context
+ from airflow.serialization.definitions.dag import SerializedDAG
+ from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
+ from airflow.serialization.serialized_objects import XComOperatorLink
+ from airflow.task.trigger_rule import TriggerRule
+ from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+ from airflow.triggers.base import StartTriggerArgs
+
+DEFAULT_OPERATOR_DEPS: frozenset[BaseTIDep] = frozenset(
+ (
+ NotInRetryPeriodDep(),
+ PrevDagrunDep(),
+ TriggerRuleDep(),
+ NotPreviouslySkippedDep(),
+ MappedTaskUpstreamDep(),
+ )
+)
+
+
+class SerializedBaseOperator(DAGNode):
+ """
+ Serialized representation of a BaseOperator instance.
+
+ See :mod:`~airflow.serialization.serialized_objects.OperatorSerialization`
+ for more details on operator serialization.
+ """
+
+ _can_skip_downstream: bool
+ _is_empty: bool
+ _needs_expansion: bool
+ _task_display_name: str | None = None
+ _weight_rule: str | PriorityWeightStrategy = "downstream"
+
+ allow_nested_operators: bool = True
+ dag: SerializedDAG | None = None
+ depends_on_past: bool = False
+ do_xcom_push: bool = True
+ doc: str | None = None
+ doc_md: str | None = None
+ doc_json: str | None = None
+ doc_yaml: str | None = None
+ doc_rst: str | None = None
+ downstream_task_ids: set[str] = set()
+ email: str | Sequence[str] | None = None
+
+ # These two are deprecated.
+ email_on_retry: bool = True
+ email_on_failure: bool = True
+
+ execution_timeout: datetime.timedelta | None = None
+ executor: str | None = None
+ executor_config: dict = {}
+ ignore_first_depends_on_past: bool = False
+
+ inlets: Sequence = []
+ is_setup: bool = False
+ is_teardown: bool = False
+
+ map_index_template: str | None = None
+ max_active_tis_per_dag: int | None = None
+ max_active_tis_per_dagrun: int | None = None
+ max_retry_delay: datetime.timedelta | float | None = None
+ multiple_outputs: bool = False
+
+ # Boolean flags for callback existence
+ has_on_execute_callback: bool = False
+ has_on_failure_callback: bool = False
+ has_on_retry_callback: bool = False
+ has_on_success_callback: bool = False
+ has_on_skipped_callback: bool = False
+
+ operator_extra_links: Collection[XComOperatorLink] = []
+ on_failure_fail_dagrun: bool = False
+
+ outlets: Sequence = []
+ owner: str = "airflow"
+ params: SerializedParamsDict = SerializedParamsDict()
+ pool: str = "default_pool"
+ pool_slots: int = 1
+ priority_weight: int = 1
+ queue: str = "default"
+
+ resources: dict[str, Any] | None = None
+ retries: int = 0
+ retry_delay: datetime.timedelta = datetime.timedelta(seconds=300)
+ retry_exponential_backoff: float = 0
+ run_as_user: str | None = None
+ task_group: SerializedTaskGroup | None = None
+
+ start_date: datetime.datetime | None = None
+ end_date: datetime.datetime | None = None
+
+ start_from_trigger: bool = False
+ start_trigger_args: StartTriggerArgs | None = None
+
+ task_type: str = "BaseOperator"
+ template_ext: Sequence[str] = []
+ template_fields: Collection[str] = []
+ template_fields_renderers: dict[str, str] = {}
+
+ trigger_rule: str | TriggerRule = "all_success"
+
+ # TODO: Remove the following, they aren't used anymore
+ ui_color: str = "#fff"
+ ui_fgcolor: str = "#000"
+
+ wait_for_downstream: bool = False
+ wait_for_past_depends_before_skipping: bool = False
+
+ is_mapped = False
+
+ def __init__(self, *, task_id: str, _airflow_from_mapped: bool = False) ->
None:
+ super().__init__()
+ self._BaseOperator__from_mapped = _airflow_from_mapped
+ self.task_id = task_id
+ self.deps = DEFAULT_OPERATOR_DEPS
+ self._operator_name: str | None = None
+
+ # Disable hashing.
+ __hash__ = None # type: ignore[assignment]
+
+ def __eq__(self, other) -> bool:
+ return NotImplemented
+
+ def __repr__(self) -> str:
+ return f"<SerializedTask({self.task_type}): {self.task_id}>"
+
+ @classmethod
+ def get_serialized_fields(cls):
+ """Fields to deserialize from the serialized JSON object."""
+ return frozenset(
+ (
+ "_logger_name",
+ "_needs_expansion",
+ "_task_display_name",
+ "allow_nested_operators",
+ "depends_on_past",
+ "do_xcom_push",
+ "doc",
+ "doc_json",
+ "doc_md",
+ "doc_rst",
+ "doc_yaml",
+ "downstream_task_ids",
+ "email",
+ "email_on_failure",
+ "email_on_retry",
+ "end_date",
+ "execution_timeout",
+ "executor",
+ "executor_config",
+ "ignore_first_depends_on_past",
+ "inlets",
+ "is_setup",
+ "is_teardown",
+ "map_index_template",
+ "max_active_tis_per_dag",
+ "max_active_tis_per_dagrun",
+ "max_retry_delay",
+ "multiple_outputs",
+ "has_on_execute_callback",
+ "has_on_failure_callback",
+ "has_on_retry_callback",
+ "has_on_skipped_callback",
+ "has_on_success_callback",
+ "on_failure_fail_dagrun",
+ "outlets",
+ "owner",
+ "params",
+ "pool",
+ "pool_slots",
+ "priority_weight",
+ "queue",
+ "resources",
+ "retries",
+ "retry_delay",
+ "retry_exponential_backoff",
+ "run_as_user",
+ "start_date",
+ "start_from_trigger",
+ "start_trigger_args",
+ "task_id",
+ "task_type",
+ "template_ext",
+ "template_fields",
+ "template_fields_renderers",
+ "trigger_rule",
+ "ui_color",
+ "ui_fgcolor",
+ "wait_for_downstream",
+ "wait_for_past_depends_before_skipping",
+ "weight_rule",
+ )
+ )
+
+ @property
+ def node_id(self) -> str:
+ return self.task_id
+
+ def get_dag(self) -> SerializedDAG | None:
+ return self.dag
+
+ @property
+ def roots(self) -> Sequence[DAGNode]:
+ """Required by DAGNode."""
+ return [self]
+
+ @property
+ def leaves(self) -> Sequence[DAGNode]:
+ """Required by DAGNode."""
+ return [self]
+
+ @functools.cached_property
+ def operator_extra_link_dict(self) -> dict[str, XComOperatorLink]:
+ """All extra links for the operator."""
+ return {link.name: link for link in self.operator_extra_links}
+
+ @functools.cached_property
+ def global_operator_extra_link_dict(self) -> dict[str, Any]:
+ """All global extra links."""
+ from airflow import plugins_manager
+
+ plugins_manager.initialize_extra_operators_links_plugins()
+ if plugins_manager.global_operator_extra_links is None:
+ raise AirflowException("Can't load operators")
+ return {link.name: link for link in plugins_manager.global_operator_extra_links}
+
+ @functools.cached_property
+ def extra_links(self) -> list[str]:
+ return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))
+
+ def get_extra_links(self, ti: TaskInstance, name: str) -> str | None:
+ """
+ For an operator, gets the URLs that the ``extra_links`` entry points to.
+
+ :meta private:
+
+ :raise ValueError: The error message of a ValueError will be passed on through to
+ the frontend to show up as a tooltip on the disabled link.
+ :param ti: The TaskInstance for the URL being searched for.
+ :param name: The name of the link we're looking for the URL for. Should be
+ one of the options specified in ``extra_links``.
+ """
+ link = self.operator_extra_link_dict.get(name) or self.global_operator_extra_link_dict.get(name)
+ if not link:
+ return None
+ # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives SerializedBaseOperator.
+ return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type]
+
+ @property
+ def inherits_from_empty_operator(self) -> bool:
+ return self._is_empty
+
+ @property
+ def inherits_from_skipmixin(self) -> bool:
+ return self._can_skip_downstream
+
+ @property
+ def operator_name(self) -> str:
+ # Overwrites operator_name of BaseOperator to use _operator_name instead of
+ # __class__.operator_name.
+ return self._operator_name or self.task_type
+
+ @operator_name.setter
+ def operator_name(self, operator_name: str):
+ self._operator_name = operator_name
+
+ @property
+ def task_display_name(self) -> str:
+ return self._task_display_name or self.task_id
+
+ def expand_start_trigger_args(self, *, context: Context) ->
StartTriggerArgs | None:
+ return self.start_trigger_args
+
+ @property
+ def weight_rule(self) -> PriorityWeightStrategy:
+ if isinstance(self._weight_rule, PriorityWeightStrategy):
+ return self._weight_rule
+ return validate_and_load_priority_weight_strategy(self._weight_rule)
+
+ def __getattr__(self, name):
+ # Handle missing attributes with task_type instead of SerializedBaseOperator
+ # Don't intercept special methods that Python internals might check
+ if name.startswith("__") and name.endswith("__"):
+ # For special methods, raise the original error
+ raise AttributeError(f"'{self.__class__.__name__}' object has no
attribute '{name}'")
+ # For regular attributes, use task_type in the error message
+ raise AttributeError(f"'{self.task_type}' object has no attribute
'{name}'")
+
+ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
+ return DagAttributeTypes.OP, self.task_id
+
+ def expand_start_from_trigger(self, *, context: Context) -> bool:
+ """
+ Get the start_from_trigger value of the current abstract operator.
+
+ Since a BaseOperator is not mapped to begin with, this simply returns
+ the original value of start_from_trigger.
+
+ :meta private:
+ """
+ return self.start_from_trigger
+
+ def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator |
SerializedMappedTaskGroup]:
+ """
+ Return mapped nodes that are direct dependencies of the current task.
+
+ For now, this walks the entire DAG to find mapped nodes that have this
+ current task as an upstream. We cannot use ``downstream_list`` since it
+ only contains operators, not task groups. In the future, we should
+ provide a way to record all of a DAG node's downstream nodes instead.
+
+ Note that this does not guarantee the returned tasks actually use the
+ current task for task mapping, but only checks that those tasks are
+ mapped operators and are downstream of the current task.
+
+ To get a list of tasks that use the current task for task mapping, use
+ :meth:`iter_mapped_dependants` instead.
+ """
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
+
+ def _walk_group(group: SerializedTaskGroup) -> Iterable[tuple[str,
DAGNode]]:
+ """
+ Recursively walk children in a task group.
+
+ This yields all direct children (including both tasks and task
+ groups), and all children of any task groups.
+ """
+ for key, child in group.children.items():
+ yield key, child
+ if isinstance(child, SerializedTaskGroup):
+ yield from _walk_group(child)
+
+ if not (dag := self.dag):
+ raise RuntimeError("Cannot check for mapped dependants when not
attached to a DAG")
+ for key, child in _walk_group(dag.task_group):
+ if key == self.node_id:
+ continue
+ if not isinstance(child, MappedOperator | SerializedMappedTaskGroup):
+ continue
+ if self.node_id in child.upstream_task_ids:
+ yield child
+
+ def iter_mapped_dependants(self) -> Iterator[MappedOperator |
SerializedMappedTaskGroup]:
+ """
+ Return mapped nodes that depend on the current task for expansion.
+
+ For now, this walks the entire DAG to find mapped nodes that have this
+ current task as an upstream. We cannot use ``downstream_list`` since it
+ only contains operators, not task groups. In the future, we should
+ provide a way to record all of a DAG node's downstream nodes instead.
+ """
+ return (
+ downstream
+ for downstream in self._iter_all_mapped_downstreams()
+ if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
+ )
+
+ # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
+ def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
+ """
+ Return mapped task groups this task belongs to.
+
+ Groups are returned from the innermost to the outermost.
+
+ :meta private:
+ """
+ if (group := self.task_group) is None:
+ return
+ yield from group.iter_mapped_task_groups()
+
+ # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
+ def get_closest_mapped_task_group(self) -> SerializedMappedTaskGroup |
None:
+ """
+ Get the mapped task group "closest" to this task in the DAG.
+
+ :meta private:
+ """
+ return next(self.iter_mapped_task_groups(), None)
+
+ # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
+ def get_needs_expansion(self) -> bool:
+ """
+ Return true if the task is MappedOperator or is in a mapped task group.
+
+ :meta private:
+ """
+ return self._needs_expansion
+
+ # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
+ @methodtools.lru_cache(maxsize=1)
+ def get_parse_time_mapped_ti_count(self) -> int:
+ """
+ Return the number of mapped task instances that can be created on DAG run creation.
+
+ This only considers literal mapped arguments, and would return *None*
+ when any non-literal values are used for mapping.
+
+ :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
+ :raise NotMapped: If the operator is neither mapped, nor has any parent
+ mapped task groups.
+ :return: Total number of mapped TIs this task should have.
+ """
+ from airflow.exceptions import NotMapped
+
+ group = self.get_closest_mapped_task_group()
+ if group is None:
+ raise NotMapped()
+ return group.get_parse_time_mapped_ti_count()
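One behavioral detail of the new class: __getattr__ reports missing
attributes using the deserialized task_type rather than the Python class
name, so errors read like they came from the user's operator. A minimal
sketch, assuming construction outside the normal deserialization path:

    from airflow.serialization.definitions.baseoperator import SerializedBaseOperator

    op = SerializedBaseOperator(task_id="t1")
    op.task_type = "BashOperator"  # normally populated during deserialization
    try:
        op.bash_command  # never set on this instance
    except AttributeError as exc:
        print(exc)  # 'BashOperator' object has no attribute 'bash_command'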
diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py b/airflow-core/src/airflow/serialization/definitions/dag.py
index 238d2d748dc..53e366223cf 100644
--- a/airflow-core/src/airflow/serialization/definitions/dag.py
+++ b/airflow-core/src/airflow/serialization/definitions/dag.py
@@ -245,7 +245,7 @@ class SerializedDAG:
exclude_original: bool = False,
):
from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
def is_task(obj) -> TypeIs[SerializedOperator]:
return isinstance(obj, (SerializedMappedOperator, SerializedBaseOperator))
diff --git a/airflow-core/src/airflow/serialization/definitions/node.py b/airflow-core/src/airflow/serialization/definitions/node.py
index b17e46234ab..d5a07cb6099 100644
--- a/airflow-core/src/airflow/serialization/definitions/node.py
+++ b/airflow-core/src/airflow/serialization/definitions/node.py
@@ -27,11 +27,14 @@ if TYPE_CHECKING:
from typing import TypeAlias
from airflow.models.mappedoperator import MappedOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
+ from airflow.serialization.definitions.dag import SerializedDAG # noqa: F401
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup # noqa: F401
- from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG # noqa: F401
Operator: TypeAlias = SerializedBaseOperator | MappedOperator
+__all__ = ["DAGNode"]
+
class DAGNode(GenericDAGNode["SerializedDAG", "Operator",
"SerializedTaskGroup"], metaclass=abc.ABCMeta):
"""
diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
index c127353bcfe..414936d7e67 100644
--- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py
+++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
from typing import Any, ClassVar
from airflow.models.expandinput import SchedulerExpandInput
- from airflow.serialization.serialized_objects import SerializedDAG, SerializedOperator
+ from airflow.serialization.definitions.dag import SerializedDAG, SerializedOperator
@attrs.define(eq=False, hash=False, kw_only=True)
@@ -186,7 +186,7 @@ class SerializedTaskGroup(DAGNode):
def iter_tasks(self) -> Iterator[SerializedOperator]:
"""Return an iterator of the child tasks."""
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
groups_to_visit = [self]
while groups_to_visit:
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index 2f6d4b362a8..5bf74f6ee10 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -28,7 +28,7 @@ import logging
import math
import sys
import weakref
-from collections.abc import Collection, Iterable, Iterator, Mapping, Sequence
+from collections.abc import Collection, Iterable, Mapping
from functools import cache, cached_property, lru_cache
from inspect import signature
from textwrap import dedent
@@ -36,7 +36,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, TypeAlias, TypeVar,
import attrs
import lazy_object_proxy
-import methodtools
import pydantic
from dateutil import relativedelta
from pendulum.tz.timezone import FixedTimezone, Timezone
@@ -74,6 +73,7 @@ from airflow.serialization.definitions.assets import (
SerializedAssetBase,
SerializedAssetUniqueKey,
)
+from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.node import DAGNode
from airflow.serialization.definitions.param import SerializedParam, SerializedParamsDict
@@ -94,13 +94,8 @@ from airflow.task.priority_strategy import (
PriorityWeightStrategy,
airflow_priority_weight_strategies,
airflow_priority_weight_strategies_classes,
- validate_and_load_priority_weight_strategy,
)
-from airflow.ti_deps.deps.mapped_task_upstream_dep import MappedTaskUpstreamDep
-from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
-from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep
-from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
-from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
+from airflow.timetables.base import DagRunInfo, Timetable
from airflow.triggers.base import BaseTrigger, StartTriggerArgs
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor
@@ -115,29 +110,16 @@ if TYPE_CHECKING:
from airflow.models.expandinput import SchedulerExpandInput
from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator
- from airflow.models.taskinstance import TaskInstance
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004
from airflow.sdk import BaseOperatorLink
from airflow.sdk.definitions._internal.node import DAGNode as SDKDAGNode
from airflow.serialization.json_schema import Validator
- from airflow.task.trigger_rule import TriggerRule
- from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.timetables.base import DagRunInfo, Timetable
from airflow.timetables.simple import PartitionMapper
SerializedOperator: TypeAlias = "SerializedMappedOperator |
SerializedBaseOperator"
SdkOperator: TypeAlias = BaseOperator | MappedOperator
-DEFAULT_OPERATOR_DEPS: frozenset[BaseTIDep] = frozenset(
- (
- NotInRetryPeriodDep(),
- PrevDagrunDep(),
- TriggerRuleDep(),
- NotPreviouslySkippedDep(),
- MappedTaskUpstreamDep(),
- )
-)
-
log = logging.getLogger(__name__)
@@ -589,10 +571,10 @@ class BaseSerialization:
elif isinstance(var, Resources):
return var.to_dict()
elif isinstance(var, MappedOperator):
- return cls._encode(SerializedBaseOperator.serialize_mapped_operator(var), type_=DAT.OP)
+ return cls._encode(OperatorSerialization.serialize_mapped_operator(var), type_=DAT.OP)
elif isinstance(var, BaseOperator):
var._needs_expansion = var.get_needs_expansion()
- return cls._encode(SerializedBaseOperator.serialize_operator(var), type_=DAT.OP)
+ return cls._encode(OperatorSerialization.serialize_operator(var), type_=DAT.OP)
elif isinstance(var, cls._datetime_types):
return cls._encode(var.timestamp(), type_=DAT.DATETIME)
elif isinstance(var, datetime.timedelta):
@@ -729,7 +711,7 @@ class BaseSerialization:
elif type_ == DAT.DAG:
return DagSerialization.deserialize_dag(var)
elif type_ == DAT.OP:
- return SerializedBaseOperator.deserialize_operator(var)
+ return OperatorSerialization.deserialize_operator(var)
elif type_ == DAT.DATETIME:
return from_timestamp(var)
elif type_ == DAT.POD:
@@ -1023,19 +1005,17 @@ class DependencyDetector:
yield from tt.asset_condition.iter_dag_dependencies(source="",
target=dag.dag_id)
-class SerializedBaseOperator(DAGNode, BaseSerialization):
+class OperatorSerialization(DAGNode, BaseSerialization):
"""
- A JSON serializable representation of operator.
+ Logic to encode an operator and decode the data.
- All operators are casted to SerializedBaseOperator after deserialization.
- Class specific attributes used by UI are move to object attributes.
+ This covers serialization of both BaseOperator and MappedOperator. Creating
+ a serialized operator is a three-step process:
- Creating a SerializedBaseOperator is a three-step process:
-
- 1. Instantiate a :class:`SerializedBaseOperator` object.
- 2. Populate attributes with :func:`SerializedBaseOperator.populated_operator`.
+ 1. Instantiate a :class:`SerializedBaseOperator` or :class:`MappedOperator` object.
+ 2. Populate attributes with :func:`OperatorSerialization.populated_operator`.
3. When the task's containing DAG is available, fix references to the DAG
- with :func:`SerializedBaseOperator.set_task_dag_references`.
+ with :func:`OperatorSerialization.set_task_dag_references`.
"""
_decorated_fields = {"executor_config"}
@@ -1046,199 +1026,6 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
_const_fields: ClassVar[set[str] | None] = None
- _can_skip_downstream: bool
- _is_empty: bool
- _needs_expansion: bool
- _task_display_name: str | None
- _weight_rule: str | PriorityWeightStrategy = "downstream"
-
- dag: SerializedDAG | None = None
- task_group: SerializedTaskGroup | None = None
-
- allow_nested_operators: bool = True
- depends_on_past: bool = False
- do_xcom_push: bool = True
- doc: str | None = None
- doc_md: str | None = None
- doc_json: str | None = None
- doc_yaml: str | None = None
- doc_rst: str | None = None
- downstream_task_ids: set[str] = set()
- email: str | Sequence[str] | None
-
- # Following 2 should be deprecated
- email_on_retry: bool = True
- email_on_failure: bool = True
-
- execution_timeout: datetime.timedelta | None
- executor: str | None
- executor_config: dict = {}
- ignore_first_depends_on_past: bool = False
-
- inlets: Sequence = []
- is_setup: bool = False
- is_teardown: bool = False
-
- map_index_template: str | None = None
- max_active_tis_per_dag: int | None = None
- max_active_tis_per_dagrun: int | None = None
- max_retry_delay: datetime.timedelta | float | None = None
- multiple_outputs: bool = False
-
- # Boolean flags for callback existence
- has_on_execute_callback: bool = False
- has_on_failure_callback: bool = False
- has_on_retry_callback: bool = False
- has_on_success_callback: bool = False
- has_on_skipped_callback: bool = False
-
- operator_extra_links: Collection[BaseOperatorLink] = []
- on_failure_fail_dagrun: bool = False
-
- outlets: Sequence = []
- owner: str = "airflow"
- params: SerializedParamsDict = SerializedParamsDict()
- pool: str = "default_pool"
- pool_slots: int = 1
- priority_weight: int = 1
- queue: str = "default"
-
- resources: dict[str, Any] | None = None
- retries: int = 0
- retry_delay: datetime.timedelta = datetime.timedelta(seconds=300)
- retry_exponential_backoff: float = 0
- run_as_user: str | None = None
-
- start_date: datetime.datetime | None = None
- end_date: datetime.datetime | None = None
-
- start_from_trigger: bool = False
- start_trigger_args: StartTriggerArgs | None = None
-
- task_type: str = "BaseOperator"
- template_ext: Sequence[str] = []
- template_fields: Collection[str] = []
- template_fields_renderers: ClassVar[dict[str, str]] = {}
-
- trigger_rule: str | TriggerRule = "all_success"
-
- # TODO: Remove the following, they aren't used anymore
- ui_color: str = "#fff"
- ui_fgcolor: str = "#000"
-
- wait_for_downstream: bool = False
- wait_for_past_depends_before_skipping: bool = False
-
- is_mapped = False
-
- def __init__(self, *, task_id: str, _airflow_from_mapped: bool = False) ->
None:
- super().__init__()
-
- self._BaseOperator__from_mapped = _airflow_from_mapped
- self.task_id = task_id
- # Move class attributes into object attributes.
- self.deps = DEFAULT_OPERATOR_DEPS
- self._operator_name: str | None = None
-
- def __eq__(self, other: Any) -> bool:
- if not isinstance(other, (SerializedBaseOperator, BaseOperator)):
- return NotImplemented
- return self.task_type == other.task_type and all(
- getattr(self, c, None) == getattr(other, c, None) for c in BaseOperator._comps
- )
-
- def __hash__(self):
- return hash((self.task_type, *[getattr(self, c, None) for c in BaseOperator._comps]))
-
- def __repr__(self) -> str:
- return f"<SerializedTask({self.task_type}): {self.task_id}>"
-
- @property
- def node_id(self) -> str:
- return self.task_id
-
- def get_dag(self) -> SerializedDAG | None:
- return self.dag
-
- @property
- def roots(self) -> Sequence[DAGNode]:
- """Required by DAGNode."""
- return [self]
-
- @property
- def leaves(self) -> Sequence[DAGNode]:
- """Required by DAGNode."""
- return [self]
-
- @cached_property
- def operator_extra_link_dict(self) -> dict[str, BaseOperatorLink]:
- """Returns dictionary of all extra links for the operator."""
- return {link.name: link for link in self.operator_extra_links}
-
- @cached_property
- def global_operator_extra_link_dict(self) -> dict[str, Any]:
- """Returns dictionary of all global extra links."""
- from airflow import plugins_manager
-
- plugins_manager.initialize_extra_operators_links_plugins()
- if plugins_manager.global_operator_extra_links is None:
- raise AirflowException("Can't load operators")
- return {link.name: link for link in plugins_manager.global_operator_extra_links}
-
- @cached_property
- def extra_links(self) -> list[str]:
- return sorted(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))
-
- def get_extra_links(self, ti: TaskInstance, name: str) -> str | None:
- """
- For an operator, gets the URLs that the ``extra_links`` entry points to.
-
- :meta private:
-
- :raise ValueError: The error message of a ValueError will be passed on through to
- the frontend to show up as a tooltip on the disabled link.
- :param ti: The TaskInstance for the URL being searched for.
- :param name: The name of the link we're looking for the URL for. Should be
- one of the options specified in ``extra_links``.
- """
- link = self.operator_extra_link_dict.get(name) or self.global_operator_extra_link_dict.get(name)
- if not link:
- return None
- # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives SerializedBaseOperator.
- return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type]
-
- @property
- def operator_name(self) -> str:
- # Overwrites operator_name of BaseOperator to use _operator_name instead of
- # __class__.operator_name.
- return self._operator_name or self.task_type
-
- @operator_name.setter
- def operator_name(self, operator_name: str):
- self._operator_name = operator_name
-
- @property
- def task_display_name(self) -> str:
- return self._task_display_name or self.task_id
-
- def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | None:
- return self.start_trigger_args
-
- @property
- def weight_rule(self) -> PriorityWeightStrategy:
- if isinstance(self._weight_rule, PriorityWeightStrategy):
- return self._weight_rule
- return validate_and_load_priority_weight_strategy(self._weight_rule)
-
- def __getattr__(self, name):
- # Handle missing attributes with task_type instead of SerializedBaseOperator
- # Don't intercept special methods that Python internals might check
- if name.startswith("__") and name.endswith("__"):
- # For special methods, raise the original error
- raise AttributeError(f"'{self.__class__.__name__}' object has no
attribute '{name}'")
- # For regular attributes, use task_type in the error message
- raise AttributeError(f"'{self.task_type}' object has no attribute
'{name}'")
-
@classmethod
def serialize_mapped_operator(cls, op: MappedOperator) -> dict[str, Any]:
serialized_op = cls._serialize_node(op)
@@ -1804,97 +1591,6 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
def deserialize(cls, encoded_var: Any) -> Any:
return BaseSerialization.deserialize(encoded_var=encoded_var)
- def serialize_for_task_group(self) -> tuple[DAT, Any]:
- """Serialize; required by DAGNode."""
- return DAT.OP, self.task_id
-
- @property
- def inherits_from_empty_operator(self) -> bool:
- return self._is_empty
-
- @property
- def inherits_from_skipmixin(self) -> bool:
- return self._can_skip_downstream
-
- def expand_start_from_trigger(self, *, context: Context) -> bool:
- """
- Get the start_from_trigger value of the current abstract operator.
-
- Since a BaseOperator is not mapped to begin with, this simply returns
- the original value of start_from_trigger.
-
- :meta private:
- """
- return self.start_from_trigger
-
- @classmethod
- def get_serialized_fields(cls):
- """Fields to deserialize from the serialized JSON object."""
- return frozenset(
- {
- "_logger_name",
- "_needs_expansion",
- "_task_display_name",
- "allow_nested_operators",
- "depends_on_past",
- "do_xcom_push",
- "doc",
- "doc_json",
- "doc_md",
- "doc_rst",
- "doc_yaml",
- "downstream_task_ids",
- "email",
- "email_on_failure",
- "email_on_retry",
- "end_date",
- "execution_timeout",
- "executor",
- "executor_config",
- "ignore_first_depends_on_past",
- "inlets",
- "is_setup",
- "is_teardown",
- "map_index_template",
- "max_active_tis_per_dag",
- "max_active_tis_per_dagrun",
- "max_retry_delay",
- "multiple_outputs",
- "has_on_execute_callback",
- "has_on_failure_callback",
- "has_on_retry_callback",
- "has_on_skipped_callback",
- "has_on_success_callback",
- "on_failure_fail_dagrun",
- "outlets",
- "owner",
- "params",
- "pool",
- "pool_slots",
- "priority_weight",
- "queue",
- "resources",
- "retries",
- "retry_delay",
- "retry_exponential_backoff",
- "run_as_user",
- "start_date",
- "start_from_trigger",
- "start_trigger_args",
- "task_id",
- "task_type",
- "template_ext",
- "template_fields",
- "template_fields_renderers",
- "trigger_rule",
- "ui_color",
- "ui_fgcolor",
- "wait_for_downstream",
- "wait_for_past_depends_before_skipping",
- "weight_rule",
- }
- )
-
@classmethod
@lru_cache(maxsize=1)
def generate_client_defaults(cls) -> dict[str, Any]:
@@ -1913,7 +1609,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
# Only include OPERATOR_DEFAULTS values that differ from schema defaults
for k, v in OPERATOR_DEFAULTS.items():
- if k not in cls.get_serialized_fields():
+ if k not in SerializedBaseOperator.get_serialized_fields():
continue
# Exclude values that are None or empty collections
@@ -2045,112 +1741,6 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
return result
- def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | SerializedMappedTaskGroup]:
- """
- Return mapped nodes that are direct dependencies of the current task.
-
- For now, this walks the entire DAG to find mapped nodes that have this
- current task as an upstream. We cannot use ``downstream_list`` since it
- only contains operators, not task groups. In the future, we should
- provide a way to record all of a DAG node's downstream nodes instead.
-
- Note that this does not guarantee the returned tasks actually use the
- current task for task mapping; it only checks that those tasks are mapped
- operators and are downstream of the current task.
-
- To get a list of tasks that uses the current task for task mapping, use
- :meth:`iter_mapped_dependants` instead.
- """
-
- def _walk_group(group: SerializedTaskGroup) -> Iterable[tuple[str, DAGNode]]:
- """
- Recursively walk children in a task group.
-
- This yields all direct children (including both tasks and task
- groups), and all children of any task groups.
- """
- for key, child in group.children.items():
- yield key, child
- if isinstance(child, SerializedTaskGroup):
- yield from _walk_group(child)
-
- if not (dag := self.dag):
- raise RuntimeError("Cannot check for mapped dependants when not
attached to a DAG")
- for key, child in _walk_group(dag.task_group):
- if key == self.node_id:
- continue
- if not isinstance(child, MappedOperator | SerializedMappedTaskGroup):
- continue
- if self.node_id in child.upstream_task_ids:
- yield child
-
- def iter_mapped_dependants(self) -> Iterator[MappedOperator | SerializedMappedTaskGroup]:
- """
- Return mapped nodes that depend on the current task for expansion.
-
- For now, this walks the entire DAG to find mapped nodes that have this
- current task as an upstream. We cannot use ``downstream_list`` since it
- only contains operators, not task groups. In the future, we should
- provide a way to record all of a DAG node's downstream nodes instead.
- """
- return (
- downstream
- for downstream in self._iter_all_mapped_downstreams()
- if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
- )
-
- # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
- def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:
- """
- Return mapped task groups this task belongs to.
-
- Groups are returned from the innermost to the outermost.
-
- :meta private:
- """
- if (group := self.task_group) is None:
- return
- yield from group.iter_mapped_task_groups()
-
- # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
- def get_closest_mapped_task_group(self) -> SerializedMappedTaskGroup | None:
- """
- Get the mapped task group "closest" to this task in the DAG.
-
- :meta private:
- """
- return next(self.iter_mapped_task_groups(), None)
-
- # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
- def get_needs_expansion(self) -> bool:
- """
- Return true if the task is MappedOperator or is in a mapped task group.
-
- :meta private:
- """
- return self._needs_expansion
-
- # TODO (GH-52141): Copied from sdk. Find a better place for this to live in.
- @methodtools.lru_cache(maxsize=1)
- def get_parse_time_mapped_ti_count(self) -> int:
- """
- Return the number of mapped task instances that can be created on DAG run creation.
-
- This only considers literal mapped arguments, and raises when any
- non-literal values are used for mapping.
-
- :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
- :raise NotMapped: If the operator is neither mapped, nor has any parent
- mapped task groups.
- :return: Total number of mapped TIs this task should have.
- """
- from airflow.exceptions import NotMapped
-
- group = self.get_closest_mapped_task_group()
- if group is None:
- raise NotMapped()
- return group.get_parse_time_mapped_ti_count()
-
class DagSerialization(BaseSerialization):
"""Logic to encode a ``DAG`` object and decode the data into
``SerializedDAG``."""
@@ -2184,7 +1774,7 @@ class DagSerialization(BaseSerialization):
dag_deps = [
dep
for task in dag.task_dict.values()
- for dep in SerializedBaseOperator.detect_dependencies(task)
+ for dep in OperatorSerialization.detect_dependencies(task)
]
dag_deps.extend(DependencyDetector.detect_dag_dependencies(dag))
serialized_dag["dag_dependencies"] = [x.__dict__ for x in
sorted(dag_deps)]
@@ -2266,13 +1856,11 @@ class DagSerialization(BaseSerialization):
if k == "_downstream_task_ids":
v = set(v)
elif k == "tasks":
- SerializedBaseOperator._load_operator_extra_links = cls._load_operator_extra_links
+ OperatorSerialization._load_operator_extra_links = cls._load_operator_extra_links
tasks = {}
for obj in v:
if obj.get(Encoding.TYPE) == DAT.OP:
- deser = SerializedBaseOperator.deserialize_operator(
- obj[Encoding.VAR], client_defaults
- )
+ deser = OperatorSerialization.deserialize_operator(obj[Encoding.VAR], client_defaults)
tasks[deser.task_id] = deser
k = "task_dict"
v = tasks
@@ -2344,7 +1932,7 @@ class DagSerialization(BaseSerialization):
setattr(dag, k, None)
for t in dag.task_dict.values():
- SerializedBaseOperator.set_task_dag_references(t, dag)
+ OperatorSerialization.set_task_dag_references(t, dag)
return dag
@@ -2393,13 +1981,13 @@ class DagSerialization(BaseSerialization):
"""Stringifies DAGs and operators contained by var and returns a dict
of var."""
# Clear any cached client_defaults to ensure fresh generation for this
DAG
# Clear lru_cache for client defaults
- SerializedBaseOperator.generate_client_defaults.cache_clear()
+ OperatorSerialization.generate_client_defaults.cache_clear()
json_dict = {"__version": cls.SERIALIZER_VERSION, "dag":
cls.serialize_dag(var)}
# Add client_defaults section with only values that differ from schema
defaults
# for tasks
- client_defaults = SerializedBaseOperator.generate_client_defaults()
+ client_defaults = OperatorSerialization.generate_client_defaults()
if client_defaults:
json_dict["client_defaults"] = {"tasks": client_defaults}
@@ -2831,9 +2419,9 @@ def create_scheduler_operator(op: SdkOperator | SerializedOperator) -> Serialize
if isinstance(op, (SerializedBaseOperator, SerializedMappedOperator)):
return op
if isinstance(op, BaseOperator):
- d = SerializedBaseOperator.serialize_operator(op)
+ d = OperatorSerialization.serialize_operator(op)
elif isinstance(op, MappedOperator):
- d = SerializedBaseOperator.serialize_mapped_operator(op)
+ d = OperatorSerialization.serialize_mapped_operator(op)
else:
raise TypeError(type(op).__name__)
- return SerializedBaseOperator.deserialize_operator(d)
+ return OperatorSerialization.deserialize_operator(d)
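For orientation, a minimal sketch (not part of the patch) of the round trip that create_scheduler_operator performs after this split, using only the names visible in the hunk above:

    # Hedged sketch: SDK operator -> JSON-compatible dict -> scheduler-side operator.
    from airflow.sdk import BaseOperator
    from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
    from airflow.serialization.serialized_objects import OperatorSerialization

    op = BaseOperator(task_id="example")  # SDK-side definition
    data = OperatorSerialization.serialize_operator(op)
    scheduler_op = OperatorSerialization.deserialize_operator(data)
    assert isinstance(scheduler_op, SerializedBaseOperator)
    assert scheduler_op.task_id == "example"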
diff --git a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
index 5e00d4b7b1b..87a6c175b9e 100644
--- a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
diff --git a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py
index 2322476f83c..8b123c4e9a1 100644
--- a/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/prev_dagrun_dep.py
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
SchedulerOperator: TypeAlias = MappedOperator | SerializedBaseOperator
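The ti_deps changes are import-only moves under TYPE_CHECKING; a generic, hedged sketch of that pattern (not code from this patch):

    from __future__ import annotations

    from typing import TYPE_CHECKING, TypeAlias

    if TYPE_CHECKING:  # seen by type checkers, skipped at runtime
        from airflow.models.mappedoperator import MappedOperator
        from airflow.serialization.definitions.baseoperator import SerializedBaseOperator

    # Quoted so the alias stays valid when the imports above are skipped.
    SchedulerOperator: TypeAlias = "MappedOperator | SerializedBaseOperator"

This keeps scheduler-only modules out of the runtime import graph while preserving precise annotations.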
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 2effd1fef6e..e52fa0195f6 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -37,8 +37,8 @@ if TYPE_CHECKING:
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup
- from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py
index ed9eb7e84f6..1e9e98514a5 100644
--- a/airflow-core/src/airflow/utils/cli.py
+++ b/airflow-core/src/airflow/utils/cli.py
@@ -46,7 +46,7 @@ T = TypeVar("T", bound=Callable)
if TYPE_CHECKING:
from airflow.sdk import DAG
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
logger = logging.getLogger(__name__)
diff --git a/airflow-core/src/airflow/utils/dag_edges.py b/airflow-core/src/airflow/utils/dag_edges.py
index 1f3c0fbd254..b4087b6bcf9 100644
--- a/airflow-core/src/airflow/utils/dag_edges.py
+++ b/airflow-core/src/airflow/utils/dag_edges.py
@@ -20,10 +20,11 @@ from typing import TYPE_CHECKING, TypeAlias, cast
from airflow.models.mappedoperator import MappedOperator
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
-from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
if TYPE_CHECKING:
from airflow.sdk import DAG
+ from airflow.serialization.definitions.dag import SerializedDAG
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
diff --git a/airflow-core/src/airflow/utils/dot_renderer.py b/airflow-core/src/airflow/utils/dot_renderer.py
index d0802972980..10d5828fe7c 100644
--- a/airflow-core/src/airflow/utils/dot_renderer.py
+++ b/airflow-core/src/airflow/utils/dot_renderer.py
@@ -27,8 +27,8 @@ from airflow.exceptions import AirflowException
from airflow.models.mappedoperator import MappedOperator as SerializedMappedOperator
from airflow.sdk import DAG, BaseOperator, TaskGroup
from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
-from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.utils.dag_edges import dag_edges
from airflow.utils.state import State
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.serialization.dag_dependency import DagDependency
+ from airflow.serialization.definitions.dag import SerializedDAG
else:
try:
import graphviz
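The dot_renderer context lines show the optional-dependency import idiom; a hedged, self-contained sketch of that idiom (helper name here is illustrative, not from the patch):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import graphviz
    else:
        try:
            import graphviz
        except ImportError:
            graphviz = None  # dot rendering becomes unavailable

    def can_render() -> bool:
        # Illustrative helper: callers check availability before rendering.
        return graphviz is not None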
diff --git a/airflow-core/tests/unit/api/common/test_mark_tasks.py b/airflow-core/tests/unit/api/common/test_mark_tasks.py
index 7dfa8ecac2d..e50a4de7f15 100644
--- a/airflow-core/tests/unit/api/common/test_mark_tasks.py
+++ b/airflow-core/tests/unit/api/common/test_mark_tasks.py
@@ -28,7 +28,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
from tests_common.pytest_plugin import DagMaker
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 1c2203579f9..10d9a5e493b 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -82,7 +82,8 @@ from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.triggers.file import FileDeleteTrigger
from airflow.sdk import DAG, Asset, AssetAlias, AssetWatcher, task
from airflow.sdk.definitions.callback import AsyncCallback, SyncCallback
-from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
+from airflow.serialization.definitions.dag import SerializedDAG
+from airflow.serialization.serialized_objects import LazyDeserializedDAG
from airflow.timetables.base import DataInterval
from airflow.timetables.simple import IdentityMapper, PartitionedAssetTimetable
from airflow.utils.session import create_session, provide_session
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index 8525677f866..a523fa2ac76 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -30,7 +30,7 @@ from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.models.taskreschedule import TaskReschedule
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.sensors.python import PythonSensor
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.serialization.definitions.dag import SerializedDAG
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
diff --git a/airflow-core/tests/unit/models/test_dagcode.py b/airflow-core/tests/unit/models/test_dagcode.py
index b98499ca238..df7b0ac26d6 100644
--- a/airflow-core/tests/unit/models/test_dagcode.py
+++ b/airflow-core/tests/unit/models/test_dagcode.py
@@ -28,7 +28,7 @@ from airflow.dag_processing.dagbag import DagBag
from airflow.models.dag_version import DagVersion
from airflow.models.dagcode import DagCode
from airflow.sdk import task as task_decorator
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.serialization.definitions.dag import SerializedDAG
# To move it to a shared module.
from airflow.utils.file import open_maybe_zipped
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index a09aceccc02..1c85d1762f8 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -48,7 +48,7 @@ from airflow.providers.standard.operators.python import PythonOperator, ShortCir
from airflow.sdk import DAG, BaseOperator, get_current_context, setup, task, task_group, teardown
from airflow.sdk.definitions.callback import AsyncCallback
from airflow.sdk.definitions.deadline import DeadlineAlert, DeadlineReference
-from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
+from airflow.serialization.serialized_objects import LazyDeserializedDAG
from airflow.task.trigger_rule import TriggerRule
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.span_status import SpanStatus
@@ -68,6 +68,8 @@ pytestmark = [pytest.mark.db_test, pytest.mark.need_serialized_dag]
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
+ from airflow.serialization.definitions.dag import SerializedDAG
+
TI = TaskInstance
DEFAULT_DATE = pendulum.instance(_DEFAULT_DATE)
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py
index 35265ca8424..e5bfe78a059 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -33,7 +33,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import DAG, BaseOperator, TaskGroup, setup, task, task_group, teardown
-from airflow.serialization.serialized_objects import SerializedBaseOperator
+from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.task.trigger_rule import TriggerRule
from airflow.utils.state import TaskInstanceState
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index b19e56b4c6e..e83f906fa04 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -80,8 +80,9 @@ from airflow.sdk.definitions.param import process_params
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.execution_time.comms import AssetEventsResult
from airflow.serialization.definitions.assets import SerializedAsset
+from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.encoders import ensure_serialized_asset
-from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+from airflow.serialization.serialized_objects import OperatorSerialization
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
@@ -2679,8 +2680,8 @@ class TestTaskInstance:
# Verify that ti.operator field renders correctly "without" Serialization
assert ti.operator == "EmptyOperator"
- serialized_op = SerializedBaseOperator.serialize_operator(ti.task)
- deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+ serialized_op = OperatorSerialization.serialize_operator(ti.task)
+ deserialized_op = OperatorSerialization.deserialize_operator(serialized_op)
assert deserialized_op.task_type == "EmptyOperator"
# Verify that ti.operator field renders correctly "with" Serialization
ser_ti = TI(task=deserialized_op, run_id=None, dag_version_id=ti.dag_version_id)
diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py
index 4704a51cbc7..cb9883c2d7f 100644
--- a/airflow-core/tests/unit/serialization/test_dag_serialization.py
+++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py
@@ -77,6 +77,7 @@ from airflow.serialization.json_schema import load_dag_schema_dict
from airflow.serialization.serialized_objects import (
BaseSerialization,
DagSerialization,
+ OperatorSerialization,
SerializedBaseOperator,
SerializedParam,
XComOperatorLink,
@@ -105,6 +106,16 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context
+def _operator_equal(a, b):
+ if not isinstance(a, (BaseOperator, SerializedBaseOperator)):
+ return NotImplemented
+ if not isinstance(b, (BaseOperator, SerializedBaseOperator)):
+ return NotImplemented
+ a_fields = {getattr(a, f) for f in BaseOperator._comps}
+ b_fields = {getattr(b, f) for f in BaseOperator._comps}
+ return a_fields == b_fields
+
+
@pytest.fixture
def operator_defaults(monkeypatch):
"""
@@ -117,7 +128,7 @@ def operator_defaults(monkeypatch):
"""
import airflow.sdk.definitions._internal.abstractoperator as abstract_op_module
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.serialized_objects import OperatorSerialization
@contextlib.contextmanager
def _operator_defaults(overrides):
@@ -131,13 +142,13 @@ def operator_defaults(monkeypatch):
monkeypatch.setattr(abstract_op_module, const_name, value)
# Clear the cache to ensure fresh generation
- SerializedBaseOperator.generate_client_defaults.cache_clear()
+ OperatorSerialization.generate_client_defaults.cache_clear()
try:
yield
finally:
# Clear cache again to restore normal behavior
- SerializedBaseOperator.generate_client_defaults.cache_clear()
+ OperatorSerialization.generate_client_defaults.cache_clear()
return _operator_defaults
@@ -866,9 +877,10 @@ class TestStringifiedDAGs:
assert serialized_partial_kwargs == original_partial_kwargs
# ExpandInputs have different classes between scheduler and definition
- assert attrs.asdict(serialized_task._get_specified_expand_input()) == attrs.asdict(
- task._get_specified_expand_input()
- )
+ ser_expand_input_data = attrs.asdict(serialized_task._get_specified_expand_input())
+ sdk_expand_input_data = attrs.asdict(task._get_specified_expand_input())
+ with mock.patch.object(SerializedBaseOperator, "__eq__", _operator_equal):
+ assert ser_expand_input_data == sdk_expand_input_data
@pytest.mark.parametrize(
("dag_start_date", "task_start_date", "expected_task_start_date"),
@@ -1478,8 +1490,8 @@ class TestStringifiedDAGs:
op = MyOperator(task_id="dummy")
assert op.do_xcom_push is False
- blob = SerializedBaseOperator.serialize_operator(op)
- serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+ blob = OperatorSerialization.serialize_operator(op)
+ serialized_op = OperatorSerialization.deserialize_operator(blob)
assert serialized_op.do_xcom_push is False
@@ -1589,11 +1601,9 @@ class TestStringifiedDAGs:
"ui_fgcolor": "#000",
}
- DagSerialization._json_schema.validate(
- blob,
- _schema=load_dag_schema_dict()["definitions"]["operator"],
- )
- serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+ operator_schema = load_dag_schema_dict()["definitions"]["operator"]
+ DagSerialization._json_schema.validate(blob, _schema=operator_schema)
+ serialized_op = OperatorSerialization.deserialize_operator(blob)
assert serialized_op.downstream_task_ids == {"foo"}
def test_task_resources(self):
@@ -1650,11 +1660,11 @@ class TestStringifiedDAGs:
children = node.children.values()
except AttributeError:
# Round-trip serialization and check the result
- expected_serialized = SerializedBaseOperator.serialize_operator(dag.get_task(node.task_id))
- expected_deserialized = SerializedBaseOperator.deserialize_operator(expected_serialized)
- expected_dict = SerializedBaseOperator.serialize_operator(expected_deserialized)
+ expected_serialized = OperatorSerialization.serialize_operator(dag.get_task(node.task_id))
+ expected_deserialized = OperatorSerialization.deserialize_operator(expected_serialized)
+ expected_dict = OperatorSerialization.serialize_operator(expected_deserialized)
assert node
- assert SerializedBaseOperator.serialize_operator(node) == expected_dict
+ assert OperatorSerialization.serialize_operator(node) == expected_dict
return
for child in children:
@@ -1796,11 +1806,11 @@ class TestStringifiedDAGs:
assert op.inlets == []
assert op.outlets == []
- serialized = SerializedBaseOperator.serialize_mapped_operator(op)
+ serialized = OperatorSerialization.serialize_mapped_operator(op)
assert "inlets" not in serialized
assert "outlets" not in serialized
- round_tripped = SerializedBaseOperator.deserialize_operator(serialized)
+ round_tripped = OperatorSerialization.deserialize_operator(serialized)
assert isinstance(round_tripped, MappedOperator)
assert round_tripped.inlets == []
assert round_tripped.outlets == []
@@ -2178,10 +2188,10 @@ class TestStringifiedDAGs:
op = DummySensor(task_id="dummy", mode=mode, poke_interval=23)
- blob = SerializedBaseOperator.serialize_operator(op)
+ blob = OperatorSerialization.serialize_operator(op)
assert "_is_sensor" in blob
- serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+ serialized_op = OperatorSerialization.deserialize_operator(blob)
assert serialized_op.reschedule == (mode == "reschedule")
assert ReadyToRescheduleDep in [type(d) for d in serialized_op.deps]
@@ -2196,11 +2206,11 @@ class TestStringifiedDAGs:
op = DummySensor.partial(task_id="dummy", mode=mode).expand(poke_interval=[23])
- blob = SerializedBaseOperator.serialize_mapped_operator(op)
+ blob = OperatorSerialization.serialize_mapped_operator(op)
assert "_is_sensor" in blob
assert "_is_mapped" in blob
- serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+ serialized_op = OperatorSerialization.deserialize_operator(blob)
assert ReadyToRescheduleDep in [type(d) for d in serialized_op.deps]
@pytest.mark.parametrize(
@@ -2696,7 +2706,8 @@ def test_operator_expand_xcomarg_serde():
@pytest.mark.parametrize("strict", [True, False])
def test_operator_expand_kwargs_literal_serde(strict):
from airflow.sdk.definitions.xcom_arg import XComArg
- from airflow.serialization.serialized_objects import DEFAULT_OPERATOR_DEPS, _XComRef
+ from airflow.serialization.definitions.baseoperator import DEFAULT_OPERATOR_DEPS
+ from airflow.serialization.serialized_objects import _XComRef
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as
dag:
task1 = BaseOperator(task_id="op1")
@@ -2764,7 +2775,7 @@ def test_operator_expand_kwargs_xcomarg_serde(strict):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="task_2").expand_kwargs(XComArg(task1), strict=strict)
- serialized = SerializedBaseOperator.serialize(mapped)
+ serialized = OperatorSerialization.serialize(mapped)
assert serialized["__var"] == {
"_is_mapped": True,
"_task_module": "tests_common.test_utils.mock_operators",
@@ -3038,7 +3049,7 @@ def test_mapped_task_group_serde():
tg.expand(a=[".", ".."])
- ser_dag = SerializedBaseOperator.serialize(dag)
+ ser_dag = OperatorSerialization.serialize(dag)
assert ser_dag[Encoding.VAR]["task_group"]["children"]["tg"] == (
"taskgroup",
{
@@ -3083,7 +3094,7 @@ def test_mapped_task_with_operator_extra_links_property():
with DAG("test-dag", schedule=None, start_date=datetime(2020, 1, 1)) as
dag:
_DummyOperator.partial(task_id="task").expand(inputs=[1, 2, 3])
- serialized_dag = SerializedBaseOperator.serialize(dag)
+ serialized_dag = OperatorSerialization.serialize(dag)
assert serialized_dag[Encoding.VAR]["tasks"][0]["__var"] == {
"task_id": "task",
"expand_input": {
@@ -3127,11 +3138,11 @@ def test_python_callable_in_partial_kwargs():
python_callable=empty_function,
).expand(op_kwargs=[{"x": 1}])
- serialized = SerializedBaseOperator.serialize_mapped_operator(operator)
+ serialized = OperatorSerialization.serialize_mapped_operator(operator)
assert "python_callable" not in serialized["partial_kwargs"]
assert serialized["partial_kwargs"]["python_callable_name"] ==
qualname(empty_function)
- deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ deserialized = OperatorSerialization.deserialize_operator(serialized)
assert "python_callable" not in deserialized.partial_kwargs
assert deserialized.partial_kwargs["python_callable_name"] == qualname(empty_function)
@@ -3422,6 +3433,7 @@ def test_handle_v1_serdag():
expected = DagSerialization.from_dict(expected_sdag)
fields_to_verify = set(vars(expected).keys()) - {
+ "task_dict", # Tested separately
"task_group", # Tested separately
"dag_dependencies", # Tested separately
"last_loaded", # Dynamically set to utcnow
@@ -3430,11 +3442,13 @@ def test_handle_v1_serdag():
for f in fields_to_verify:
dag_value = getattr(dag, f)
expected_value = getattr(expected, f)
-
assert dag_value == expected_value, (
f"V2 DAG field '{f}' differs from V3: V2={dag_value!r} !=
V3={expected_value!r}"
)
+ with mock.patch.object(SerializedBaseOperator, "__eq__", _operator_equal):
+ assert dag.task_dict == expected.task_dict
+
for f in set(vars(expected.task_group).keys()) - {"dag"}:
dag_tg_value = getattr(dag.task_group, f)
expected_tg_value = getattr(expected.task_group, f)
@@ -3627,6 +3641,7 @@ def test_handle_v2_serdag():
expected = DagSerialization.from_dict(expected_sdag)
fields_to_verify = set(vars(expected).keys()) - {
+ "task_dict", # Tested separately
"task_group", # Tested separately
"last_loaded", # Dynamically set to utcnow
}
@@ -3634,15 +3649,16 @@ def test_handle_v2_serdag():
for f in fields_to_verify:
dag_value = getattr(dag, f)
expected_value = getattr(expected, f)
-
assert dag_value == expected_value, (
f"V2 DAG field '{f}' differs from V3: V2={dag_value!r} !=
V3={expected_value!r}"
)
+ with mock.patch.object(SerializedBaseOperator, "__eq__", _operator_equal):
+ assert dag.task_dict == expected.task_dict
+
for f in set(vars(expected.task_group).keys()) - {"dag"}:
dag_tg_value = getattr(dag.task_group, f)
expected_tg_value = getattr(expected.task_group, f)
-
assert dag_tg_value == expected_tg_value, (
f"V2 task_group field '{f}' differs: V2={dag_tg_value!r} !=
V3={expected_tg_value!r}"
)
@@ -3963,7 +3979,7 @@ def test_task_callback_backward_compatibility(old_callback_name, new_callback_na
}
# Test deserialization converts old format to new format
- deserialized_task = SerializedBaseOperator.deserialize_operator(old_serialized_task)
+ deserialized_task = OperatorSerialization.deserialize_operator(old_serialized_task)
# Verify the new format is present and correct
assert hasattr(deserialized_task, new_callback_name)
@@ -3973,7 +3989,7 @@ def test_task_callback_backward_compatibility(old_callback_name, new_callback_na
# Test with empty/None callback (should convert to False)
old_serialized_task[old_callback_name] = None
- deserialized_task_empty = SerializedBaseOperator.deserialize_operator(old_serialized_task)
+ deserialized_task_empty = OperatorSerialization.deserialize_operator(old_serialized_task)
assert getattr(deserialized_task_empty, new_callback_name) is False
@@ -4003,7 +4019,7 @@ class TestClientDefaultsGeneration:
def test_generate_client_defaults_basic(self):
"""Test basic client defaults generation."""
- client_defaults = SerializedBaseOperator.generate_client_defaults()
+ client_defaults = OperatorSerialization.generate_client_defaults()
assert isinstance(client_defaults, dict)
@@ -4014,8 +4030,8 @@ class TestClientDefaultsGeneration:
def test_generate_client_defaults_excludes_schema_defaults(self):
"""Test that client defaults excludes values that match schema
defaults."""
- client_defaults = SerializedBaseOperator.generate_client_defaults()
- schema_defaults = SerializedBaseOperator.get_schema_defaults("operator")
+ client_defaults = OperatorSerialization.generate_client_defaults()
+ schema_defaults = OperatorSerialization.get_schema_defaults("operator")
# Check that values matching schema defaults are excluded
for field, value in client_defaults.items():
@@ -4026,7 +4042,7 @@ class TestClientDefaultsGeneration:
def test_generate_client_defaults_excludes_none_and_empty(self):
"""Test that client defaults excludes None and empty collection
values."""
- client_defaults = SerializedBaseOperator.generate_client_defaults()
+ client_defaults = OperatorSerialization.generate_client_defaults()
for field, value in client_defaults.items():
assert value is not None, f"Field {field} has None value"
@@ -4035,23 +4051,23 @@ class TestClientDefaultsGeneration:
def test_generate_client_defaults_caching(self):
"""Test that client defaults generation is cached."""
# Clear cache first
- SerializedBaseOperator.generate_client_defaults.cache_clear()
+ OperatorSerialization.generate_client_defaults.cache_clear()
# First call
- client_defaults_1 = SerializedBaseOperator.generate_client_defaults()
+ client_defaults_1 = OperatorSerialization.generate_client_defaults()
# Second call should return same object (cached)
- client_defaults_2 = SerializedBaseOperator.generate_client_defaults()
+ client_defaults_2 = OperatorSerialization.generate_client_defaults()
assert client_defaults_1 is client_defaults_2, "Client defaults should be cached"
# Check cache info
- cache_info = SerializedBaseOperator.generate_client_defaults.cache_info()
+ cache_info = OperatorSerialization.generate_client_defaults.cache_info()
assert cache_info.hits >= 1, "Cache should have at least one hit"
def test_generate_client_defaults_only_operator_defaults_fields(self):
"""Test that only fields from OPERATOR_DEFAULTS are considered."""
- client_defaults = SerializedBaseOperator.generate_client_defaults()
+ client_defaults = OperatorSerialization.generate_client_defaults()
# All fields in client_defaults should originate from OPERATOR_DEFAULTS
for field in client_defaults:
@@ -4063,7 +4079,7 @@ class TestSchemaDefaults:
def test_get_schema_defaults_operator(self):
"""Test getting schema defaults for operator type."""
- schema_defaults = SerializedBaseOperator.get_schema_defaults("operator")
+ schema_defaults = OperatorSerialization.get_schema_defaults("operator")
assert isinstance(schema_defaults, dict)
@@ -4086,12 +4102,12 @@ class TestSchemaDefaults:
def test_get_schema_defaults_nonexistent_type(self):
"""Test getting schema defaults for nonexistent type."""
- schema_defaults = SerializedBaseOperator.get_schema_defaults("nonexistent")
+ schema_defaults = OperatorSerialization.get_schema_defaults("nonexistent")
assert schema_defaults == {}
def test_get_operator_optional_fields_from_schema(self):
"""Test getting optional fields from schema."""
- optional_fields = SerializedBaseOperator.get_operator_optional_fields_from_schema()
+ optional_fields = OperatorSerialization.get_operator_optional_fields_from_schema()
assert isinstance(optional_fields, set)
@@ -4127,7 +4143,7 @@ class TestDeserializationDefaultsResolution:
encoded_op = {"task_id": "test_task", "task_type": "BashOperator",
"retries": 10}
client_defaults = {"tasks": {"retry_delay": 300.0, "retries": 2}} #
Fix: wrap in "tasks"
- result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, client_defaults)
+ result = OperatorSerialization._apply_defaults_to_encoded_op(encoded_op, client_defaults)
# Should merge in order: client_defaults, encoded_op
assert result["retry_delay"] == 300.0 # From client_defaults
@@ -4139,7 +4155,7 @@ class TestDeserializationDefaultsResolution:
encoded_op = {"task_id": "test_task"}
# With None client_defaults
- result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
+ result = OperatorSerialization._apply_defaults_to_encoded_op(encoded_op, None)
assert result == encoded_op
def test_multiple_tasks_share_client_defaults(self, operator_defaults):
@@ -4379,7 +4395,7 @@ class TestMappedOperatorSerializationAndClientDefaults:
)
def test_partial_kwargs_deserialization_formats(self, partial_kwargs_data, expected_results):
"""Test deserialization of partial_kwargs in various formats (encoded, non-encoded, mixed)."""
- result = SerializedBaseOperator._deserialize_partial_kwargs(partial_kwargs_data)
+ result = OperatorSerialization._deserialize_partial_kwargs(partial_kwargs_data)
# Verify all expected results
for key, expected_value in expected_results.items():
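The tests above patch SerializedBaseOperator.__eq__ with _operator_equal so that structures embedding operators of different classes (SDK vs. scheduler) can be compared field-by-field. A self-contained sketch of the same trick with hypothetical stand-in classes (nothing here is Airflow API):

    from unittest import mock

    class SdkOp:
        def __init__(self, task_id: str) -> None:
            self.task_id = task_id

    class SchedulerOp:
        def __init__(self, task_id: str) -> None:
            self.task_id = task_id

    def _op_equal(a, b):
        # Compare by fields, ignoring the concrete class.
        return a.task_id == b.task_id

    with mock.patch.object(SchedulerOp, "__eq__", _op_equal):
        # Dict comparison calls SchedulerOp.__eq__ on the embedded values.
        assert {"task": SchedulerOp("x")} == {"task": SdkOp("x")}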
diff --git a/airflow-core/tests/unit/serialization/test_serialized_objects.py b/airflow-core/tests/unit/serialization/test_serialized_objects.py
index 0ddc396bd88..92f72b31e02 100644
--- a/airflow-core/tests/unit/serialization/test_serialized_objects.py
+++ b/airflow-core/tests/unit/serialization/test_serialized_objects.py
@@ -173,7 +173,7 @@ def test_strict_mode():
def test_prevent_re_serialization_of_serialized_operators():
"""SerializedBaseOperator should not be re-serializable."""
- from airflow.serialization.serialized_objects import BaseSerialization, SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
serialized_op = SerializedBaseOperator(task_id="test_task")
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index 39eb4fd2b9c..dde8989f48c 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -52,7 +52,7 @@ if TYPE_CHECKING:
from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.sdk.types import DagRunProtocol, Operator
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Self
from airflow.utils.state import DagRunState, TaskInstanceState
diff --git a/devel-common/src/tests_common/test_utils/compat.py b/devel-common/src/tests_common/test_utils/compat.py
index 6b3b1b03d35..1e2286db9e8 100644
--- a/devel-common/src/tests_common/test_utils/compat.py
+++ b/devel-common/src/tests_common/test_utils/compat.py
@@ -41,12 +41,16 @@ except ImportError:
try:
from airflow.serialization.definitions.dag import SerializedDAG
- from airflow.serialization.serialized_objects import DagSerialization
+ from airflow.serialization.serialized_objects import DagSerialization, OperatorSerialization
except ImportError:
# Compatibility for Airflow < 3.2.*
- from airflow.serialization.serialized_objects import SerializedDAG # type: ignore[no-redef]
+ from airflow.serialization.serialized_objects import ( # type: ignore[no-redef]
+ SerializedBaseOperator,
+ SerializedDAG,
+ )
DagSerialization = SerializedDAG # type: ignore[assignment,misc,no-redef]
+ OperatorSerialization = SerializedBaseOperator # type: ignore[assignment,misc,no-redef]
try:
from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer
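Condensed, the compat shim above is the usual try/except aliasing pattern for spanning Airflow versions; a hedged sketch with the same names:

    try:
        # Airflow >= 3.2: serde logic lives in dedicated classes.
        from airflow.serialization.serialized_objects import (
            DagSerialization,
            OperatorSerialization,
        )
    except ImportError:
        # Older Airflow: the Serialized* classes carried the serde logic too.
        from airflow.serialization.serialized_objects import (
            SerializedBaseOperator as OperatorSerialization,
            SerializedDAG as DagSerialization,
        )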
diff --git a/devel-common/src/tests_common/test_utils/mapping.py b/devel-common/src/tests_common/test_utils/mapping.py
index fb6096c6ea4..f8ed5c52331 100644
--- a/devel-common/src/tests_common/test_utils/mapping.py
+++ b/devel-common/src/tests_common/test_utils/mapping.py
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.mappedoperator import MappedOperator
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
def expand_mapped_task(
diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
index 464ce6b1ea5..3553e85e438 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py
@@ -109,7 +109,7 @@ if TYPE_CHECKING:
RESOURCE_ASSET_ALIAS,
)
from airflow.sdk import DAG
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
else:
from airflow.providers.common.compat.security.permissions import (
RESOURCE_ASSET,
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py
index 529a386259f..c017c2c61d8 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py
@@ -25,7 +25,7 @@ from airflow.providers.common.compat.sdk import DAG, Param, XComArg
if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import BaseOperator, MappedOperator
from airflow.providers.openlineage.utils.utils import AnyOperator
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
T = TypeVar("T", bound=DAG | BaseOperator | MappedOperator)
diff --git a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py
index cad9a386834..e9bfaf052e2 100644
--- a/providers/openlineage/tests/unit/openlineage/utils/test_utils.py
+++ b/providers/openlineage/tests/unit/openlineage/utils/test_utils.py
@@ -62,14 +62,13 @@ from airflow.providers.openlineage.utils.utils import (
get_user_provided_run_facets,
)
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.timetables.events import EventsTimetable
from airflow.timetables.trigger import CronTriggerTimetable
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
-from tests_common.test_utils.compat import BashOperator, PythonOperator
+from tests_common.test_utils.compat import BashOperator, OperatorSerialization, PythonOperator
from tests_common.test_utils.mock_operators import MockOperator
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_3_PLUS, AIRFLOW_V_3_0_PLUS
@@ -271,8 +270,8 @@ def test_get_fully_qualified_class_name_serialized_operator():
op_path_before_serialization = get_fully_qualified_class_name(op)
assert op_path_before_serialization == f"{op_module_path}.{op_name}"
- serialized = SerializedBaseOperator.serialize_operator(op)
- deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ serialized = OperatorSerialization.serialize_operator(op)
+ deserialized = OperatorSerialization.deserialize_operator(serialized)
op_path_after_deserialization = get_fully_qualified_class_name(deserialized)
assert op_path_after_deserialization == f"{op_module_path}.{op_name}"
@@ -406,8 +405,8 @@ def test_get_task_documentation_serialized_operator():
op_doc_before_serialization = get_task_documentation(op)
assert op_doc_before_serialization == ("some_doc", "text/plain")
- serialized = SerializedBaseOperator.serialize_operator(op)
- deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ serialized = OperatorSerialization.serialize_operator(op)
+ deserialized = OperatorSerialization.deserialize_operator(serialized)
op_doc_after_deserialization = get_task_documentation(deserialized)
assert op_doc_after_deserialization == ("some_doc", "text/plain")
diff --git a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
index 274d86f5890..13589b2d12c 100644
--- a/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
+++ b/providers/standard/tests/unit/standard/sensors/test_external_task_sensor.py
@@ -52,12 +52,12 @@ from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor
from airflow.providers.standard.sensors.time import TimeSensor
from airflow.providers.standard.triggers.external_task import WorkflowTrigger
-from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.timetables.base import DataInterval
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State
from airflow.utils.types import DagRunType
+from tests_common.test_utils.compat import OperatorSerialization
from tests_common.test_utils.dag import create_scheduler_dag, sync_dag_to_db, sync_dags_to_db
from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.mock_operators import MockOperator
@@ -1603,8 +1603,8 @@ class TestExternalTaskMarker:
dag=dag,
)
- serialized_op = SerializedBaseOperator.serialize_operator(task)
- deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+ serialized_op = OperatorSerialization.serialize_operator(task)
+ deserialized_op = OperatorSerialization.deserialize_operator(serialized_op)
assert deserialized_op.task_type == "ExternalTaskMarker"
assert getattr(deserialized_op, "external_dag_id") == "external_task_marker_child"
assert getattr(deserialized_op, "external_task_id") == "child_task1"
diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
index d6c5cffd134..9f9bf634f38 100644
--- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
+++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
@@ -50,7 +50,7 @@ except ImportError: # Fallback for Airflow < 3.1
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
- from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.serialization.definitions.dag import SerializedDAG
TI = TaskInstance
diff --git a/scripts/in_container/run_schema_defaults_check.py b/scripts/in_container/run_schema_defaults_check.py
index 8ad321ac8e2..9b754f265ca 100755
--- a/scripts/in_container/run_schema_defaults_check.py
+++ b/scripts/in_container/run_schema_defaults_check.py
@@ -60,7 +60,7 @@ def load_schema_defaults(object_type: str = "operator") -> dict[str, Any]:
def get_server_side_operator_defaults() -> dict[str, Any]:
"""Get default values from server-side SerializedBaseOperator class."""
try:
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.definitions.baseoperator import SerializedBaseOperator
# Get all serializable fields
serialized_fields = SerializedBaseOperator.get_serialized_fields()
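As a rough, hedged illustration of what the check script can do with the relocated class, assuming only the get_serialized_fields API shown earlier in this patch:

    from airflow.serialization.definitions.baseoperator import SerializedBaseOperator

    # Gather the class-level default for every serialized field that defines one.
    server_defaults = {
        field: getattr(SerializedBaseOperator, field)
        for field in SerializedBaseOperator.get_serialized_fields()
        if hasattr(SerializedBaseOperator, field)
    }
    print(sorted(server_defaults))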