This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 8af1bbdf25e Remove Pydantic models introduced for AIP-44 (#44552)
8af1bbdf25e is described below
commit 8af1bbdf25e2650e617d456f729d1d4f46465524
Author: Jarek Potiuk <[email protected]>
AuthorDate: Thu Dec 12 22:36:44 2024 +0100
Remove Pydantic models introduced for AIP-44 (#44552)
The Pydantic models were still used in a number of places, and
we were also using them for context passing for the PythonVirtualenv
and ExternalPythonOperator - this PR removes all the models and
their usages.
Closes: #44436
# Please enter the commit message for your changes. Lines starting
---
.../api_fastapi/core_api/datamodels/dag_tags.py | 13 +-
airflow/api_fastapi/core_api/datamodels/dags.py | 4 +-
.../api_fastapi/core_api/openapi/v1-generated.yaml | 13 +-
.../execution_api/datamodels/taskinstance.py | 2 +-
.../cli/commands/remote_commands/task_command.py | 20 +-
airflow/jobs/JOB_LIFECYCLE.md | 8 +-
airflow/jobs/base_job_runner.py | 3 +-
airflow/models/__init__.py | 2 +
airflow/models/param.py | 3 +-
airflow/models/skipmixin.py | 5 +-
airflow/models/taskinstance.py | 8 +-
airflow/serialization/pydantic/__init__.py | 16 --
airflow/serialization/pydantic/asset.py | 74 -------
airflow/serialization/pydantic/dag.py | 116 ----------
airflow/serialization/pydantic/dag_run.py | 114 ----------
airflow/serialization/pydantic/job.py | 70 ------
airflow/serialization/pydantic/taskinstance.py | 246 ---------------------
airflow/serialization/pydantic/tasklog.py | 30 ---
airflow/serialization/pydantic/trigger.py | 63 ------
airflow/serialization/serde.py | 20 +-
airflow/serialization/serialized_objects.py | 109 ++-------
airflow/ui/openapi-gen/requests/schemas.gen.ts | 13 +-
airflow/ui/openapi-gen/requests/types.gen.ts | 10 +-
airflow/ui/src/pages/DagsList/DagCard.test.tsx | 4 +-
airflow/ui/src/pages/DagsList/DagTags.tsx | 4 +-
docs/apache-airflow/extra-packages-ref.rst | 2 -
.../providers/edge/worker_api/routes/_v2_routes.py | 6 +-
.../airflow/providers/standard/operators/python.py | 51 +++--
.../standard/utils/python_virtualenv_script.jinja2 | 4 +-
providers/tests/standard/operators/test_python.py | 97 ++++----
tests/always/test_example_dags.py | 6 +
tests/assets/test_manager.py | 47 +---
tests/jobs/test_base_job.py | 6 +-
tests/serialization/test_serde.py | 10 -
tests/serialization/test_serialized_objects.py | 171 +-------------
35 files changed, 166 insertions(+), 1204 deletions(-)
diff --git a/airflow/api_fastapi/core_api/datamodels/dag_tags.py
b/airflow/api_fastapi/core_api/datamodels/dag_tags.py
index 9e67e1ce7b1..8d5014fdf34 100644
--- a/airflow/api_fastapi/core_api/datamodels/dag_tags.py
+++ b/airflow/api_fastapi/core_api/datamodels/dag_tags.py
@@ -17,7 +17,18 @@
from __future__ import annotations
-from pydantic import BaseModel
+from pydantic import ConfigDict
+
+from airflow.api_fastapi.core_api.base import BaseModel
+
+
+class DagTagResponse(BaseModel):
+ """DAG Tag serializer for responses."""
+
+ model_config = ConfigDict(populate_by_name=True, from_attributes=True)
+
+ name: str
+ dag_id: str
class DAGTagCollectionResponse(BaseModel):
diff --git a/airflow/api_fastapi/core_api/datamodels/dags.py
b/airflow/api_fastapi/core_api/datamodels/dags.py
index eddf0e1be22..30399b42f8d 100644
--- a/airflow/api_fastapi/core_api/datamodels/dags.py
+++ b/airflow/api_fastapi/core_api/datamodels/dags.py
@@ -32,8 +32,8 @@ from pydantic import (
)
from airflow.api_fastapi.core_api.base import BaseModel
+from airflow.api_fastapi.core_api.datamodels.dag_tags import DagTagResponse
from airflow.configuration import conf
-from airflow.serialization.pydantic.dag import DagTagPydantic
class DAGResponse(BaseModel):
@@ -50,7 +50,7 @@ class DAGResponse(BaseModel):
description: str | None
timetable_summary: str | None
timetable_description: str | None
- tags: list[DagTagPydantic]
+ tags: list[DagTagResponse]
max_active_tasks: int
max_active_runs: int | None
max_consecutive_failed_dag_runs: int
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 0bfe7f6fe63..d5b4f66b06d 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -6714,7 +6714,7 @@ components:
title: Timetable Description
tags:
items:
- $ref: '#/components/schemas/DagTagPydantic'
+ $ref: '#/components/schemas/DagTagResponse'
type: array
title: Tags
max_active_tasks:
@@ -6936,7 +6936,7 @@ components:
title: Timetable Description
tags:
items:
- $ref: '#/components/schemas/DagTagPydantic'
+ $ref: '#/components/schemas/DagTagResponse'
type: array
title: Tags
max_active_tasks:
@@ -7412,7 +7412,7 @@ components:
title: Timetable Description
tags:
items:
- $ref: '#/components/schemas/DagTagPydantic'
+ $ref: '#/components/schemas/DagTagResponse'
type: array
title: Tags
max_active_tasks:
@@ -7665,7 +7665,7 @@ components:
- count
title: DagStatsStateResponse
description: DagStatsState serializer for responses.
- DagTagPydantic:
+ DagTagResponse:
properties:
name:
type: string
@@ -7677,9 +7677,8 @@ components:
required:
- name
- dag_id
- title: DagTagPydantic
- description: Serializable representation of the DagTag ORM
SqlAlchemyModel used
- by internal API.
+ title: DagTagResponse
+ description: DAG Tag serializer for responses.
DagWarningType:
type: string
enum:
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index 2b8830c6de1..e0d8f371f09 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -33,7 +33,7 @@ class TIEnterRunningPayload(BaseModel):
state: Annotated[
Literal[TIState.RUNNING],
- # Specify a default in the schema, but not in code, so Pydantic marks
it as required.
+ # Specify a default in the schema, but not in code.
WithJsonSchema({"type": "string", "enum": [TIState.RUNNING],
"default": TIState.RUNNING}),
]
hostname: str
diff --git a/airflow/cli/commands/remote_commands/task_command.py
b/airflow/cli/commands/remote_commands/task_command.py
index b2795b7bf9d..e4c72b94d62 100644
--- a/airflow/cli/commands/remote_commands/task_command.py
+++ b/airflow/cli/commands/remote_commands/task_command.py
@@ -47,7 +47,6 @@ from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
from airflow.models.param import ParamsDict
from airflow.models.taskinstance import TaskReturnCode
-from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
@@ -74,7 +73,6 @@ if TYPE_CHECKING:
from sqlalchemy.orm.session import Session
from airflow.models.operator import Operator
- from airflow.serialization.pydantic.dag_run import DagRunPydantic
log = logging.getLogger(__name__)
@@ -96,7 +94,7 @@ def _fetch_dag_run_from_run_id_or_logical_date_string(
dag_id: str,
value: str,
session: Session,
-) -> tuple[DagRun | DagRunPydantic, pendulum.DateTime | None]:
+) -> tuple[DagRun, pendulum.DateTime | None]:
"""
Try to find a DAG run with a given string value.
@@ -132,7 +130,7 @@ def _get_dag_run(
create_if_necessary: CreateIfNecessary,
logical_date_or_run_id: str | None = None,
session: Session | None = None,
-) -> tuple[DagRun | DagRunPydantic, bool]:
+) -> tuple[DagRun, bool]:
"""
Try to retrieve a DAG run from a string representing either a run ID or
logical date.
@@ -259,8 +257,6 @@ def _run_task_by_selected_method(args, dag: DAG, ti:
TaskInstance) -> None | Tas
- as raw task
- by executor
"""
- if TYPE_CHECKING:
- assert not isinstance(ti, TaskInstancePydantic) # Wait for AIP-44
implementation to complete
if args.local:
return _run_task_by_local_task_job(args, ti)
if args.raw:
@@ -497,9 +493,6 @@ def task_failed_deps(args) -> None:
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index,
logical_date_or_run_id=args.logical_date_or_run_id)
- # tasks_failed-deps is executed with access to the database.
- if isinstance(ti, TaskInstancePydantic):
- raise ValueError("not a TaskInstance")
dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
# TODO, Do we want to print or log this
@@ -524,9 +517,6 @@ def task_state(args) -> None:
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.map_index,
logical_date_or_run_id=args.logical_date_or_run_id)
- # task_state is executed with access to the database.
- if isinstance(ti, TaskInstancePydantic):
- raise ValueError("not a TaskInstance")
print(ti.current_state())
@@ -654,9 +644,6 @@ def task_test(args, dag: DAG | None = None, session:
Session = NEW_SESSION) -> N
ti, dr_created = _get_ti(
task, args.map_index,
logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="db"
)
- # task_test is executed with access to the database.
- if isinstance(ti, TaskInstancePydantic):
- raise ValueError("not a TaskInstance")
try:
with redirect_stdout(RedactedIO()):
if args.dry_run:
@@ -705,9 +692,6 @@ def task_render(args, dag: DAG | None = None) -> None:
ti, _ = _get_ti(
task, args.map_index,
logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="memory"
)
- # task_render is executed with access to the database.
- if isinstance(ti, TaskInstancePydantic):
- raise ValueError("not a TaskInstance")
with create_session() as session,
set_current_task_instance_session(session=session):
ti.render_templates()
for attr in task.template_fields:
diff --git a/airflow/jobs/JOB_LIFECYCLE.md b/airflow/jobs/JOB_LIFECYCLE.md
index 61742ea44bc..bdb91381506 100644
--- a/airflow/jobs/JOB_LIFECYCLE.md
+++ b/airflow/jobs/JOB_LIFECYCLE.md
@@ -100,7 +100,7 @@ sequenceDiagram
DB --> Internal API: Close Session
deactivate DB
- Internal API->>CLI component: JobPydantic object
+ Internal API->>CLI component: Job object
CLI component->>JobRunner: Create Job Runner
JobRunner ->> CLI component: JobRunner object
@@ -109,7 +109,7 @@ sequenceDiagram
activate JobRunner
- JobRunner->>Internal API: prepare_for_execution [JobPydantic]
+ JobRunner->>Internal API: prepare_for_execution [Job]
Internal API-->>DB: Create Session
activate DB
@@ -131,7 +131,7 @@ sequenceDiagram
deactivate DB
Internal API ->> JobRunner: returned data
and
- JobRunner->>Internal API: perform_heartbeat <br> [Job Pydantic]
+ JobRunner->>Internal API: perform_heartbeat <br> [Job]
Internal API-->>DB: Create Session
activate DB
Internal API->>DB: perform_heartbeat [Job]
@@ -142,7 +142,7 @@ sequenceDiagram
deactivate DB
end
- JobRunner->>Internal API: complete_execution <br> [Job Pydantic]
+ JobRunner->>Internal API: complete_execution <br> [Job]
Internal API-->>DB: Create Session
Internal API->>DB: complete_execution [Job]
activate DB
diff --git a/airflow/jobs/base_job_runner.py b/airflow/jobs/base_job_runner.py
index df6fcc67abb..05671e2050a 100644
--- a/airflow/jobs/base_job_runner.py
+++ b/airflow/jobs/base_job_runner.py
@@ -26,7 +26,6 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.jobs.job import Job
- from airflow.serialization.pydantic.job import JobPydantic
class BaseJobRunner:
@@ -64,7 +63,7 @@ class BaseJobRunner:
@classmethod
@provide_session
- def most_recent_job(cls, session: Session = NEW_SESSION) -> Job |
JobPydantic | None:
+ def most_recent_job(cls, session: Session = NEW_SESSION) -> Job | None:
"""Return the most recent job of this type, if any, based on last
heartbeat received."""
from airflow.jobs.job import most_recent_job
diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py
index ffbfdf4d4b2..6e4d0149966 100644
--- a/airflow/models/__init__.py
+++ b/airflow/models/__init__.py
@@ -82,6 +82,7 @@ def __getattr__(name):
__lazy_imports = {
+ "Job": "airflow.jobs.job",
"DAG": "airflow.models.dag",
"ID_LEN": "airflow.models.base",
"Base": "airflow.models.base",
@@ -112,6 +113,7 @@ __lazy_imports = {
if TYPE_CHECKING:
# I was unable to get mypy to respect a airflow/models/__init__.pyi, so
# having to resort back to this hacky method
+ from airflow.jobs.job import Job
from airflow.models.base import ID_LEN, Base
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink
diff --git a/airflow/models/param.py b/airflow/models/param.py
index e6150ee50cb..ab7d2facd7e 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -31,7 +31,6 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
- from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils.context import Context
logger = logging.getLogger(__name__)
@@ -334,7 +333,7 @@ class DagParam(ResolveMixin):
def process_params(
dag: DAG,
task: Operator,
- dag_run: DagRun | DagRunPydantic | None,
+ dag_run: DagRun | None,
*,
suppress_exception: bool,
) -> dict[str, Any]:
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 257a9bbf3a5..ad5c5d01539 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -37,7 +37,6 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.node import DAGNode
- from airflow.serialization.pydantic.dag_run import DagRunPydantic
# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -61,7 +60,7 @@ class SkipMixin(LoggingMixin):
@staticmethod
def _set_state_to_skipped(
- dag_run: DagRun | DagRunPydantic,
+ dag_run: DagRun,
tasks: Sequence[str] | Sequence[tuple[str, int]],
session: Session,
) -> None:
@@ -95,7 +94,7 @@ class SkipMixin(LoggingMixin):
@provide_session
def skip(
self,
- dag_run: DagRun | DagRunPydantic,
+ dag_run: DagRun,
tasks: Iterable[DAGNode],
map_index: int = -1,
session: Session = NEW_SESSION,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index dbdba3658ad..15ac435ff7d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -163,8 +163,6 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
- from airflow.serialization.pydantic.asset import AssetEventPydantic
- from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
from airflow.utils.task_group import TaskGroup
@@ -984,7 +982,7 @@ def _get_template_context(
return None
return timezone.coerce_datetime(dagrun.end_date)
- def get_triggering_events() -> dict[str, list[AssetEvent |
AssetEventPydantic]]:
+ def get_triggering_events() -> dict[str, list[AssetEvent]]:
if TYPE_CHECKING:
assert session is not None
@@ -995,7 +993,7 @@ def _get_template_context(
if dag_run not in session:
dag_run = session.merge(dag_run, load=False)
asset_events = dag_run.consumed_asset_events
- triggering_events: dict[str, list[AssetEvent | AssetEventPydantic]] =
defaultdict(list)
+ triggering_events: dict[str, list[AssetEvent]] = defaultdict(list)
for event in asset_events:
if event.asset:
triggering_events[event.asset.uri].append(event)
@@ -1890,7 +1888,7 @@ class TaskInstance(Base, LoggingMixin):
pool: str | None = None,
cfg_path: str | None = None,
) -> list[str]:
- dag: DAG | DagModel | DagModelPydantic | None
+ dag: DAG | DagModel | None
# Use the dag if we have it, else fallback to the ORM dag_model, which
might not be loaded
if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
if TYPE_CHECKING:
diff --git a/airflow/serialization/pydantic/__init__.py
b/airflow/serialization/pydantic/__init__.py
deleted file mode 100644
index 13a83393a91..00000000000
--- a/airflow/serialization/pydantic/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
diff --git a/airflow/serialization/pydantic/asset.py
b/airflow/serialization/pydantic/asset.py
deleted file mode 100644
index 0e5623099ea..00000000000
--- a/airflow/serialization/pydantic/asset.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from typing import Optional
-
-from pydantic import BaseModel as BaseModelPydantic, ConfigDict
-
-
-class DagScheduleAssetReferencePydantic(BaseModelPydantic):
- """Serializable version of the DagScheduleAssetReference ORM
SqlAlchemyModel used by internal API."""
-
- asset_id: int
- dag_id: str
- created_at: datetime
- updated_at: datetime
-
- model_config = ConfigDict(from_attributes=True)
-
-
-class TaskOutletAssetReferencePydantic(BaseModelPydantic):
- """Serializable version of the TaskOutletAssetReference ORM
SqlAlchemyModel used by internal API."""
-
- asset_id: int
- dag_id: str
- task_id: str
- created_at: datetime
- updated_at: datetime
-
- model_config = ConfigDict(from_attributes=True)
-
-
-class AssetPydantic(BaseModelPydantic):
- """Serializable representation of the Asset ORM SqlAlchemyModel used by
internal API."""
-
- id: int
- uri: str
- extra: Optional[dict]
- created_at: datetime
- updated_at: datetime
-
- consuming_dags: list[DagScheduleAssetReferencePydantic]
- producing_tasks: list[TaskOutletAssetReferencePydantic]
-
- model_config = ConfigDict(from_attributes=True)
-
-
-class AssetEventPydantic(BaseModelPydantic):
- """Serializable representation of the AssetEvent ORM SqlAlchemyModel used
by internal API."""
-
- id: int
- asset_id: Optional[int]
- extra: dict
- source_task_id: Optional[str]
- source_dag_id: Optional[str]
- source_run_id: Optional[str]
- source_map_index: Optional[int]
- timestamp: datetime
- asset: Optional[AssetPydantic]
-
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
diff --git a/airflow/serialization/pydantic/dag.py
b/airflow/serialization/pydantic/dag.py
deleted file mode 100644
index 5a1199887f9..00000000000
--- a/airflow/serialization/pydantic/dag.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import pathlib
-from datetime import datetime
-from typing import Annotated, Any, Optional
-
-from pydantic import (
- BaseModel as BaseModelPydantic,
- ConfigDict,
- PlainSerializer,
- PlainValidator,
- ValidationInfo,
-)
-
-from airflow import DAG, settings
-from airflow.configuration import conf as airflow_conf
-
-
-def serialize_operator(x: DAG) -> dict:
- from airflow.serialization.serialized_objects import SerializedDAG
-
- return SerializedDAG.serialize_dag(x)
-
-
-def validate_operator(x: DAG | dict[str, Any], _info: ValidationInfo) -> Any:
- from airflow.serialization.serialized_objects import SerializedDAG
-
- if isinstance(x, DAG):
- return x
- return SerializedDAG.deserialize_dag(x)
-
-
-PydanticDag = Annotated[
- DAG,
- PlainValidator(validate_operator),
- PlainSerializer(serialize_operator, return_type=dict),
-]
-
-
-class DagOwnerAttributesPydantic(BaseModelPydantic):
- """Serializable representation of the DagOwnerAttributes ORM
SqlAlchemyModel used by internal API."""
-
- owner: str
- link: str
-
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
-
-class DagTagPydantic(BaseModelPydantic):
- """Serializable representation of the DagTag ORM SqlAlchemyModel used by
internal API."""
-
- name: str
- dag_id: str
-
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
-
-class DagModelPydantic(BaseModelPydantic):
- """Serializable representation of the DagModel ORM SqlAlchemyModel used by
internal API."""
-
- dag_id: str
- is_paused_at_creation: bool = airflow_conf.getboolean("core",
"dags_are_paused_at_creation")
- is_paused: bool = is_paused_at_creation
- is_active: Optional[bool] = False
- last_parsed_time: Optional[datetime]
- last_expired: Optional[datetime]
- fileloc: str
- processor_subdir: Optional[str]
- owners: Optional[str]
- description: Optional[str]
- default_view: Optional[str]
- timetable_summary: Optional[str]
- timetable_description: Optional[str]
- tags: list[DagTagPydantic]
- dag_owner_links: list[DagOwnerAttributesPydantic]
-
- max_active_tasks: int
- max_active_runs: Optional[int]
- max_consecutive_failed_dag_runs: Optional[int]
-
- has_task_concurrency_limits: bool
- has_import_errors: Optional[bool] = False
-
- _processor_dags_folder: Optional[str] = None
-
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
- @property
- def relative_fileloc(self) -> pathlib.Path:
- """File location of the importable dag 'file' relative to the
configured DAGs folder."""
- path = pathlib.Path(self.fileloc)
- try:
- rel_path = path.relative_to(self._processor_dags_folder or
settings.DAGS_FOLDER)
- if rel_path == pathlib.Path("."):
- return path
- else:
- return rel_path
- except ValueError:
- # Not relative to DAGS_FOLDER.
- return path
diff --git a/airflow/serialization/pydantic/dag_run.py
b/airflow/serialization/pydantic/dag_run.py
deleted file mode 100644
index a31f2c35927..00000000000
--- a/airflow/serialization/pydantic/dag_run.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from collections.abc import Iterable
-from datetime import datetime
-from typing import TYPE_CHECKING, Optional
-from uuid import UUID
-
-from pydantic import BaseModel as BaseModelPydantic, ConfigDict
-
-from airflow.models.dagrun import DagRun
-from airflow.serialization.pydantic.asset import AssetEventPydantic
-from airflow.serialization.pydantic.dag import PydanticDag
-from airflow.utils.types import DagRunTriggeredByType
-
-if TYPE_CHECKING:
- from sqlalchemy.orm import Session
-
- from airflow.jobs.scheduler_job_runner import TI
- from airflow.utils.state import TaskInstanceState
-
-
-class DagRunPydantic(BaseModelPydantic):
- """Serializable representation of the DagRun ORM SqlAlchemyModel used by
internal API."""
-
- id: int
- dag_id: str
- queued_at: Optional[datetime]
- logical_date: datetime
- start_date: Optional[datetime]
- end_date: Optional[datetime]
- state: str
- run_id: str
- creating_job_id: Optional[int]
- external_trigger: bool
- run_type: str
- conf: dict
- data_interval_start: Optional[datetime]
- data_interval_end: Optional[datetime]
- last_scheduling_decision: Optional[datetime]
- dag_version_id: Optional[UUID]
- updated_at: Optional[datetime]
- dag: Optional[PydanticDag]
- consumed_asset_events: list[AssetEventPydantic]
- log_template_id: Optional[int]
- triggered_by: Optional[DagRunTriggeredByType]
-
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
- def get_task_instances(
- self,
- state: Iterable[TaskInstanceState | None] | None = None,
- session=None,
- ) -> list[TI]:
- """
- Return the task instances for this dag run.
-
- Redirect to DagRun.fetch_task_instances method.
- Keep this method because it is widely used across the code.
- """
- task_ids = DagRun._get_partial_task_ids(self.dag)
- return DagRun.fetch_task_instances(
- dag_id=self.dag_id,
- run_id=self.run_id,
- task_ids=task_ids,
- state=state,
- session=session,
- )
-
- def get_task_instance(
- self,
- task_id: str,
- session: Session,
- *,
- map_index: int = -1,
- ) -> TI | None:
- """
- Return the task instance specified by task_id for this dag run.
-
- :param task_id: the task id
- :param session: Sqlalchemy ORM Session
- """
- from airflow.models.dagrun import DagRun
-
- return DagRun.fetch_task_instance(
- dag_id=self.dag_id,
- dag_run_id=self.run_id,
- task_id=task_id,
- session=session,
- map_index=map_index,
- )
-
- def get_log_template(self, session: Session):
- from airflow.models.dagrun import DagRun
-
- return DagRun._get_log_template(log_template_id=self.log_template_id)
-
-
-DagRunPydantic.model_rebuild()
diff --git a/airflow/serialization/pydantic/job.py
b/airflow/serialization/pydantic/job.py
deleted file mode 100644
index ab2dafa1787..00000000000
--- a/airflow/serialization/pydantic/job.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import datetime
-from functools import cached_property
-from typing import TYPE_CHECKING, Optional
-
-from pydantic import BaseModel as BaseModelPydantic, ConfigDict
-
-from airflow.executors.executor_loader import ExecutorLoader
-from airflow.jobs.base_job_runner import BaseJobRunner
-
-
-def check_runner_initialized(job_runner: Optional[BaseJobRunner], job_type:
str) -> BaseJobRunner:
- if job_runner is None:
- raise ValueError(f"In order to run {job_type} you need to initialize
the {job_type}Runner first.")
- return job_runner
-
-
-class JobPydantic(BaseModelPydantic):
- """Serializable representation of the Job ORM SqlAlchemyModel used by
internal API."""
-
- id: Optional[int]
- dag_id: Optional[str]
- state: Optional[str]
- job_type: Optional[str]
- start_date: Optional[datetime.datetime]
- end_date: Optional[datetime.datetime]
- latest_heartbeat: datetime.datetime
- executor_class: Optional[str]
- hostname: Optional[str]
- unixname: Optional[str]
- grace_multiplier: float = 2.1
-
- model_config = ConfigDict(from_attributes=True)
-
- @cached_property
- def executor(self):
- return ExecutorLoader.get_default_executor()
-
- @cached_property
- def heartrate(self) -> float:
- from airflow.jobs.job import Job
-
- if TYPE_CHECKING:
- assert self.job_type is not None
- return Job._heartrate(self.job_type)
-
- def is_alive(self) -> bool:
- """Is this job currently alive."""
- from airflow.jobs.job import Job, health_check_threshold
-
- return Job._is_alive(
- state=self.state,
- health_check_threshold_value=health_check_threshold(self.job_type,
self.heartrate),
- latest_heartbeat=self.latest_heartbeat,
- )
diff --git a/airflow/serialization/pydantic/taskinstance.py
b/airflow/serialization/pydantic/taskinstance.py
deleted file mode 100644
index 43bfd527a74..00000000000
--- a/airflow/serialization/pydantic/taskinstance.py
+++ /dev/null
@@ -1,246 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from collections.abc import Iterable
-from datetime import datetime
-from typing import TYPE_CHECKING, Annotated, Any, Optional
-from uuid import UUID
-
-from pydantic import (
- BaseModel as BaseModelPydantic,
- ConfigDict,
- PlainSerializer,
- PlainValidator,
-)
-
-from airflow.exceptions import AirflowRescheduleException
-from airflow.models import Operator
-from airflow.models.baseoperator import BaseOperator
-from airflow.models.taskinstance import (
- TaskInstance,
- _handle_reschedule,
- _set_ti_attrs,
-)
-from airflow.serialization.pydantic.dag import DagModelPydantic
-from airflow.serialization.pydantic.dag_run import DagRunPydantic
-from airflow.utils import timezone
-from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.xcom import XCOM_RETURN_KEY
-
-if TYPE_CHECKING:
- from pydantic import ValidationInfo
- from sqlalchemy.orm import Session
-
- from airflow.models.dagrun import DagRun
- from airflow.utils.context import Context
-
-
-def serialize_operator(x: Operator | None) -> dict | None:
- if x:
- from airflow.serialization.serialized_objects import BaseSerialization
-
- return BaseSerialization.serialize(x, use_pydantic_models=True)
- return None
-
-
-def validated_operator(x: dict[str, Any] | Operator, _info: ValidationInfo) ->
Any:
- from airflow.models.mappedoperator import MappedOperator
-
- if isinstance(x, BaseOperator) or isinstance(x, MappedOperator) or x is
None:
- return x
- from airflow.serialization.serialized_objects import BaseSerialization
-
- return BaseSerialization.deserialize(x, use_pydantic_models=True)
-
-
-PydanticOperator = Annotated[
- Operator,
- PlainValidator(validated_operator),
- PlainSerializer(serialize_operator, return_type=dict),
-]
-
-
-class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
- """Serializable representation of the TaskInstance ORM SqlAlchemyModel
used by internal API."""
-
- id: str
- task_id: str
- dag_id: str
- run_id: str
- map_index: int
- start_date: Optional[datetime]
- end_date: Optional[datetime]
- logical_date: Optional[datetime]
- duration: Optional[float]
- state: Optional[str]
- try_number: int
- max_tries: int
- hostname: str
- unixname: str
- pool: str
- pool_slots: int
- queue: str
- priority_weight: Optional[int]
- operator: str
- custom_operator_name: Optional[str]
- queued_dttm: Optional[datetime]
- queued_by_job_id: Optional[int]
- last_heartbeat_at: Optional[datetime] = None
- pid: Optional[int]
- executor: Optional[str]
- executor_config: Any
- updated_at: Optional[datetime]
- rendered_map_index: Optional[str]
- external_executor_id: Optional[str]
- trigger_id: Optional[int]
- trigger_timeout: Optional[datetime]
- next_method: Optional[str]
- next_kwargs: Optional[dict]
- dag_version_id: Optional[UUID]
- run_as_user: Optional[str]
- task: Optional[PydanticOperator]
- test_mode: bool
- dag_run: Optional[DagRunPydantic]
- dag_model: Optional[DagModelPydantic]
- raw: Optional[bool]
- is_trigger_log_context: Optional[bool]
- model_config = ConfigDict(from_attributes=True,
arbitrary_types_allowed=True)
-
- @property
- def _logger_name(self):
- return "airflow.task"
-
- def _run_execute_callback(self, context, task):
- TaskInstance._run_execute_callback(self=self, context=context,
task=task) # type: ignore[arg-type]
-
- def render_templates(self, context: Context | None = None, jinja_env=None):
- return TaskInstance.render_templates(self=self, context=context,
jinja_env=jinja_env) # type: ignore[arg-type]
-
- def init_run_context(self, raw: bool = False) -> None:
- """Set the log context."""
- self.raw = raw
- self._set_context(self)
-
- def xcom_pull(
- self,
- task_ids: str | Iterable[str] | None = None,
- dag_id: str | None = None,
- key: str = XCOM_RETURN_KEY,
- include_prior_dates: bool = False,
- session: Session | None = None,
- *,
- map_indexes: int | Iterable[int] | None = None,
- default: Any = None,
- ) -> Any:
- """
- Pull an XCom value for this task instance.
-
- :param task_ids: task id or list of task ids, if None, the task_id of
the current task is used
- :param dag_id: dag id, if None, the dag_id of the current task is used
- :param key: the key to identify the XCom value
- :param include_prior_dates: whether to include prior logical dates
- :param session: the sqlalchemy session
- :param map_indexes: map index or list of map indexes, if None, the
map_index of the current task
- is used
- :param default: the default value to return if the XCom value does not
exist
- :return: Xcom value
- """
- return TaskInstance.xcom_pull(
- self=self, # type: ignore[arg-type]
- task_ids=task_ids,
- dag_id=dag_id,
- key=key,
- include_prior_dates=include_prior_dates,
- map_indexes=map_indexes,
- default=default,
- session=session,
- )
-
- def xcom_push(
- self,
- key: str,
- value: Any,
- session: Session | None = None,
- ) -> None:
- """
- Push an XCom value for this task instance.
-
- :param key: the key to identify the XCom value
- :param value: the value of the XCom
- """
- return TaskInstance.xcom_push(
- self=self, # type: ignore[arg-type]
- key=key,
- value=value,
- session=session,
- )
-
- def get_dagrun(self, session: Session | None = None) -> DagRun:
- """
- Return the DagRun for this TaskInstance.
-
- :param session: SQLAlchemy ORM Session
-
- :return: Pydantic serialized version of DagRun
- """
- return TaskInstance._get_dagrun(dag_id=self.dag_id,
run_id=self.run_id, session=session)
-
- def _execute_task(self, context, task_orig):
- """
- Execute Task (optionally with a Timeout) and push Xcom results.
-
- :param context: Jinja2 context
- :param task_orig: origin task
- """
- from airflow.models.taskinstance import _execute_task
-
- return _execute_task(task_instance=self, context=context,
task_orig=task_orig)
-
- def update_heartbeat(self):
- """Update the recorded heartbeat for this task to "now"."""
- from airflow.models.taskinstance import _update_ti_heartbeat
-
- return _update_ti_heartbeat(self.id, timezone.utcnow())
-
- def is_eligible_to_retry(self):
- """Is task instance is eligible for retry."""
- from airflow.models.taskinstance import _is_eligible_to_retry
-
- return _is_eligible_to_retry(task_instance=self)
-
- def _register_asset_changes(self, *, events, session: Session | None =
None) -> None:
- TaskInstance._register_asset_changes(self=self, events=events,
session=session) # type: ignore[arg-type]
-
- def _handle_reschedule(
- self,
- actual_start_date: datetime,
- reschedule_exception: AirflowRescheduleException,
- test_mode: bool = False,
- session: Session | None = None,
- ):
- updated_ti = _handle_reschedule(
- ti=self,
- actual_start_date=actual_start_date,
- reschedule_exception=reschedule_exception,
- test_mode=test_mode,
- session=session,
- )
- _set_ti_attrs(self, updated_ti) # _handle_reschedule is a remote call
that mutates the TI
-
-
-TaskInstancePydantic.model_rebuild()
diff --git a/airflow/serialization/pydantic/tasklog.py
b/airflow/serialization/pydantic/tasklog.py
deleted file mode 100644
index a23204400c1..00000000000
--- a/airflow/serialization/pydantic/tasklog.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-
-from pydantic import BaseModel as BaseModelPydantic, ConfigDict
-
-
-class LogTemplatePydantic(BaseModelPydantic):
- """Serializable version of the LogTemplate ORM SqlAlchemyModel used by
internal API."""
-
- id: int
- filename: str
- elasticsearch_id: str
- created_at: datetime
-
- model_config = ConfigDict(from_attributes=True)
diff --git a/airflow/serialization/pydantic/trigger.py
b/airflow/serialization/pydantic/trigger.py
deleted file mode 100644
index 4c120148aff..00000000000
--- a/airflow/serialization/pydantic/trigger.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import datetime
-from typing import Any, Optional
-
-from pydantic import BaseModel as BaseModelPydantic, ConfigDict
-
-from airflow.utils import timezone
-
-
-class TriggerPydantic(BaseModelPydantic):
- """Serializable representation of the Trigger ORM SqlAlchemyModel used by
internal API."""
-
- # This is technically non-optional, however when we serialize it in
from_object we do not have the ID
- id: Optional[int]
- classpath: str
- encrypted_kwargs: str
- created_date: datetime.datetime
- triggerer_id: Optional[int]
-
- model_config = ConfigDict(from_attributes=True)
-
- def __init__(self, **kwargs) -> None:
- from airflow.models.trigger import Trigger
-
- # Here we have to handle two ways the object can be created:
- # * when Pydantic recreates it from Trigger model, we need a default
__init__ behavior
- # * when we create it in from_object - the kwargs will contain
classpath, kwargs to create it and
- # created_date
- if "kwargs" in kwargs:
- self.classpath = kwargs.pop("classpath")
- self.encrypted_kwargs =
Trigger.encrypt_kwargs(kwargs.pop("kwargs"))
- self.created_date = kwargs.pop("created_date", timezone.utcnow())
- super().__init__(**kwargs)
-
- @property
- def kwargs(self) -> dict[str, Any]:
- """Return the decrypted kwargs of the trigger."""
- from airflow.models import Trigger
-
- return Trigger._decrypt_kwargs(self.encrypted_kwargs)
-
- @kwargs.setter
- def kwargs(self, kwargs: dict[str, Any]) -> None:
- """Set the encrypted kwargs of the trigger."""
- from airflow.models import Trigger
-
- self.encrypted_kwargs = Trigger.encrypt_kwargs(kwargs)
diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py
index 7d1b583ce9a..ae289b70f6c 100644
--- a/airflow/serialization/serde.py
+++ b/airflow/serialization/serde.py
@@ -168,12 +168,6 @@ def serialize(o: object, depth: int = 0) -> U | None:
dct[DATA] = data
return dct
- # pydantic models are recursive
- if _is_pydantic(cls):
- data = o.model_dump() # type: ignore[attr-defined]
- dct[DATA] = serialize(data, depth + 1)
- return dct
-
# dataclasses
if dataclasses.is_dataclass(cls):
# fixme: unfortunately using asdict with nested dataclasses it looses
information
@@ -268,8 +262,8 @@ def deserialize(o: T | None, full=True, type_hint: Any =
None) -> object:
if hasattr(cls, "deserialize"):
return getattr(cls, "deserialize")(deserialize(value), version)
- # attr or dataclass or pydantic
- if attr.has(cls) or dataclasses.is_dataclass(cls) or _is_pydantic(cls):
+ # attr or dataclass
+ if attr.has(cls) or dataclasses.is_dataclass(cls):
class_version = getattr(cls, "__version__", 0)
if int(version) > class_version:
raise TypeError(
@@ -339,16 +333,6 @@ def _stringify(classname: str, version: int, value: T |
None) -> str:
return s
-def _is_pydantic(cls: Any) -> bool:
- """
- Return True if the class is a pydantic model.
-
- Checking is done by attributes as it is significantly faster than
- using isinstance.
- """
- return hasattr(cls, "model_config") and hasattr(cls, "model_fields") and
hasattr(cls, "model_fields_set")
-
-
def _is_namedtuple(cls: Any) -> bool:
"""
Return True if the class is a namedtuple.
diff --git a/airflow/serialization/serialized_objects.py
b/airflow/serialization/serialized_objects.py
index 95d2b82f551..fa0984f6d5f 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -38,12 +38,9 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
from airflow import macros
from airflow.callbacks.callback_requests import DagCallbackRequest,
TaskCallbackRequest
from airflow.exceptions import AirflowException, SerializationError,
TaskDeferred
-from airflow.jobs.job import Job
-from airflow.models import Trigger
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
-from airflow.models.dag import DAG, DagModel
-from airflow.models.dagrun import DagRun
+from airflow.models.dag import DAG
from airflow.models.expandinput import (
EXPAND_INPUT_EMPTY,
create_expand_input,
@@ -51,9 +48,8 @@ from airflow.models.expandinput import (
)
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
+from airflow.models.taskinstance import SimpleTaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
-from airflow.models.tasklog import LogTemplate
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg,
serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
from airflow.sdk.definitions.asset import (
@@ -71,13 +67,6 @@ from airflow.serialization.dag_dependency import
DagDependency
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
from airflow.serialization.json_schema import load_dag_schema
-from airflow.serialization.pydantic.asset import AssetPydantic
-from airflow.serialization.pydantic.dag import DagModelPydantic
-from airflow.serialization.pydantic.dag_run import DagRunPydantic
-from airflow.serialization.pydantic.job import JobPydantic
-from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
-from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
-from airflow.serialization.pydantic.trigger import TriggerPydantic
from airflow.settings import DAGS_FOLDER, json
from airflow.task.priority_strategy import (
PriorityWeightStrategy,
@@ -105,8 +94,6 @@ from airflow.utils.types import NOTSET, ArgNotSet,
AttributeRemoved
if TYPE_CHECKING:
from inspect import Parameter
- from pydantic import BaseModel
-
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.expandinput import ExpandInput
from airflow.models.operator import Operator
@@ -528,34 +515,6 @@ class _ExpandInputRef(NamedTuple):
return create_expand_input(self.key, value)
-_orm_to_model = {
- Job: JobPydantic,
- TaskInstance: TaskInstancePydantic,
- DagRun: DagRunPydantic,
- DagModel: DagModelPydantic,
- LogTemplate: LogTemplatePydantic,
- Asset: AssetPydantic,
- Trigger: TriggerPydantic,
-}
-_type_to_class: dict[DAT | str, list] = {
- DAT.BASE_JOB: [JobPydantic, Job],
- DAT.TASK_INSTANCE: [TaskInstancePydantic, TaskInstance],
- DAT.DAG_RUN: [DagRunPydantic, DagRun],
- DAT.DAG_MODEL: [DagModelPydantic, DagModel],
- DAT.LOG_TEMPLATE: [LogTemplatePydantic, LogTemplate],
- DAT.ASSET: [AssetPydantic, Asset],
- DAT.TRIGGER: [TriggerPydantic, Trigger],
-}
-_class_to_type = {cls_: type_ for type_, classes in _type_to_class.items() for
cls_ in classes}
-
-
-def add_pydantic_class_type_mapping(attribute_type: str, orm_class,
pydantic_class):
- _orm_to_model[orm_class] = pydantic_class
- _type_to_class[attribute_type] = [pydantic_class, orm_class]
- _class_to_type[pydantic_class] = attribute_type
- _class_to_type[orm_class] = attribute_type
-
-
class BaseSerialization:
"""BaseSerialization provides utils for serialization."""
@@ -674,7 +633,7 @@ class BaseSerialization:
@classmethod
def serialize(
- cls, var: Any, *, strict: bool = False, use_pydantic_models: bool =
False
+ cls, var: Any, *, strict: bool = False
) -> Any: # Unfortunately there is no support for recursive types in mypy
"""
Serialize an object; helper function of depth first search for
serialization.
@@ -696,14 +655,11 @@ class BaseSerialization:
return var
elif isinstance(var, dict):
return cls._encode(
- {
- str(k): cls.serialize(v, strict=strict,
use_pydantic_models=use_pydantic_models)
- for k, v in var.items()
- },
+ {str(k): cls.serialize(v, strict=strict) for k, v in
var.items()},
type_=DAT.DICT,
)
elif isinstance(var, list):
- return [cls.serialize(v, strict=strict,
use_pydantic_models=use_pydantic_models) for v in var]
+ return [cls.serialize(v, strict=strict) for v in var]
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and
isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
@@ -749,7 +705,6 @@ class BaseSerialization:
return cls._encode(
cls.serialize(
{"exc_cls_name": exc_cls_name, "args": args, "kwargs":
kwargs},
- use_pydantic_models=use_pydantic_models,
strict=strict,
),
type_=DAT.AIRFLOW_EXC_SER,
@@ -762,7 +717,6 @@ class BaseSerialization:
"args": [var.args],
"kwargs": {},
},
- use_pydantic_models=use_pydantic_models,
strict=strict,
),
type_=DAT.BASE_EXC_SER,
@@ -771,7 +725,6 @@ class BaseSerialization:
return cls._encode(
cls.serialize(
var.serialize(),
- use_pydantic_models=use_pydantic_models,
strict=strict,
),
type_=DAT.BASE_TRIGGER,
@@ -782,20 +735,18 @@ class BaseSerialization:
# FIXME: casts set to list in customized serialization in future.
try:
return cls._encode(
- sorted(
- cls.serialize(v, strict=strict,
use_pydantic_models=use_pydantic_models) for v in var
- ),
+ sorted(cls.serialize(v, strict=strict) for v in var),
type_=DAT.SET,
)
except TypeError:
return cls._encode(
- [cls.serialize(v, strict=strict,
use_pydantic_models=use_pydantic_models) for v in var],
+ [cls.serialize(v, strict=strict) for v in var],
type_=DAT.SET,
)
elif isinstance(var, tuple):
# FIXME: casts tuple to list in customized serialization in future.
return cls._encode(
- [cls.serialize(v, strict=strict,
use_pydantic_models=use_pydantic_models) for v in var],
+ [cls.serialize(v, strict=strict) for v in var],
type_=DAT.TUPLE,
)
elif isinstance(var, TaskGroup):
@@ -813,7 +764,7 @@ class BaseSerialization:
return cls._encode({"name": var.name}, type_=DAT.ASSET_REF)
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
- cls.serialize(var.__dict__, strict=strict,
use_pydantic_models=use_pydantic_models),
+ cls.serialize(var.__dict__, strict=strict),
type_=DAT.SIMPLE_TASK_INSTANCE,
)
elif isinstance(var, Connection):
@@ -825,21 +776,9 @@ class BaseSerialization:
elif var.__class__ == Context:
d = {}
for k, v in var._context.items():
- obj = cls.serialize(v, strict=strict,
use_pydantic_models=use_pydantic_models)
+ obj = cls.serialize(v, strict=strict)
d[str(k)] = obj
return cls._encode(d, type_=DAT.TASK_CONTEXT)
- elif use_pydantic_models:
-
- def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) ->
dict[str, Any]:
- return model_cls.model_validate(var).model_dump(mode="json")
# type: ignore[attr-defined]
-
- if var.__class__ in _class_to_type:
- pyd_mod = _orm_to_model.get(var.__class__, var)
- mod = _pydantic_model_dump(pyd_mod, var)
- type_ = _class_to_type[var.__class__]
- return cls._encode(mod, type_=type_)
- else:
- return cls.default_serialization(strict, var)
elif isinstance(var, ArgNotSet):
return cls._encode(None, type_=DAT.ARG_NOT_SET)
else:
@@ -853,7 +792,7 @@ class BaseSerialization:
return str(var)
@classmethod
- def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
+ def deserialize(cls, encoded_var: Any) -> Any:
"""
Deserialize an object; helper function of depth first search for
deserialization.
@@ -862,7 +801,7 @@ class BaseSerialization:
if cls._is_primitive(encoded_var):
return encoded_var
elif isinstance(encoded_var, list):
- return [cls.deserialize(v, use_pydantic_models) for v in
encoded_var]
+ return [cls.deserialize(v) for v in encoded_var]
if not isinstance(encoded_var, dict):
raise ValueError(f"The encoded_var should be dict and is
{type(encoded_var)}")
@@ -873,7 +812,7 @@ class BaseSerialization:
for k, v in var.items():
if k == "task": # todo: add `_encode` of Operator so we don't
need this
continue
- d[k] = cls.deserialize(v, use_pydantic_models=True)
+ d[k] = cls.deserialize(v)
d["task"] = d["task_instance"].task # todo: add `_encode` of
Operator so we don't need this
d["macros"] = macros
d["var"] = {
@@ -883,7 +822,7 @@ class BaseSerialization:
d["conn"] = ConnectionAccessor()
return Context(**d)
elif type_ == DAT.DICT:
- return {k: cls.deserialize(v, use_pydantic_models) for k, v in
var.items()}
+ return {k: cls.deserialize(v) for k, v in var.items()}
elif type_ == DAT.ASSET_EVENT_ACCESSORS:
return decode_outlet_event_accessors(var)
elif type_ == DAT.ASSET_UNIQUE_KEY:
@@ -908,7 +847,7 @@ class BaseSerialization:
elif type_ == DAT.RELATIVEDELTA:
return decode_relativedelta(var)
elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER:
- deser = cls.deserialize(var,
use_pydantic_models=use_pydantic_models)
+ deser = cls.deserialize(var)
exc_cls_name = deser["exc_cls_name"]
args = deser["args"]
kwargs = deser["kwargs"]
@@ -919,13 +858,13 @@ class BaseSerialization:
exc_cls = import_string(f"builtins.{exc_cls_name}")
return exc_cls(*args, **kwargs)
elif type_ == DAT.BASE_TRIGGER:
- tr_cls_name, kwargs = cls.deserialize(var,
use_pydantic_models=use_pydantic_models)
+ tr_cls_name, kwargs = cls.deserialize(var)
tr_cls = import_string(tr_cls_name)
return tr_cls(**kwargs)
elif type_ == DAT.SET:
- return {cls.deserialize(v, use_pydantic_models) for v in var}
+ return {cls.deserialize(v) for v in var}
elif type_ == DAT.TUPLE:
- return tuple(cls.deserialize(v, use_pydantic_models) for v in var)
+ return tuple(cls.deserialize(v) for v in var)
elif type_ == DAT.PARAM:
return cls._deserialize_param(var)
elif type_ == DAT.XCOM_REF:
@@ -950,8 +889,6 @@ class BaseSerialization:
return DagCallbackRequest.from_json(var)
elif type_ == DAT.TASK_INSTANCE_KEY:
return TaskInstanceKey(**var)
- elif use_pydantic_models:
- return _type_to_class[type_][0].model_validate(var)
elif type_ == DAT.ARG_NOT_SET:
return NOTSET
else:
@@ -1420,7 +1357,7 @@ class SerializedBaseOperator(BaseOperator,
BaseSerialization):
elif k in {"expand_input", "op_kwargs_expand_input"}:
v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
elif k == "operator_class":
- v = {k_: cls.deserialize(v_, use_pydantic_models=True) for k_,
v_ in v.items()}
+ v = {k_: cls.deserialize(v_) for k_, v_ in v.items()}
elif (
k in cls._decorated_fields
or k not in op.get_serialized_fields()
@@ -1662,13 +1599,13 @@ class SerializedBaseOperator(BaseOperator,
BaseSerialization):
return serialize_operator_extra_links
@classmethod
- def serialize(cls, var: Any, *, strict: bool = False, use_pydantic_models:
bool = False) -> Any:
+ def serialize(cls, var: Any, *, strict: bool = False) -> Any:
# the wonders of multiple inheritance BaseOperator defines an instance
method
- return BaseSerialization.serialize(var=var, strict=strict,
use_pydantic_models=use_pydantic_models)
+ return BaseSerialization.serialize(var=var, strict=strict)
@classmethod
- def deserialize(cls, encoded_var: Any, use_pydantic_models: bool = False)
-> Any:
- return BaseSerialization.deserialize(encoded_var=encoded_var,
use_pydantic_models=use_pydantic_models)
+ def deserialize(cls, encoded_var: Any) -> Any:
+ return BaseSerialization.deserialize(encoded_var=encoded_var)
class SerializedDAG(DAG, BaseSerialization):
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 965355b0c19..0b9333c500c 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -1148,7 +1148,7 @@ export const $DAGDetailsResponse = {
},
tags: {
items: {
- $ref: "#/components/schemas/DagTagPydantic",
+ $ref: "#/components/schemas/DagTagResponse",
},
type: "array",
title: "Tags",
@@ -1521,7 +1521,7 @@ export const $DAGResponse = {
},
tags: {
items: {
- $ref: "#/components/schemas/DagTagPydantic",
+ $ref: "#/components/schemas/DagTagResponse",
},
type: "array",
title: "Tags",
@@ -2267,7 +2267,7 @@ export const $DAGWithLatestDagRunsResponse = {
},
tags: {
items: {
- $ref: "#/components/schemas/DagTagPydantic",
+ $ref: "#/components/schemas/DagTagResponse",
},
type: "array",
title: "Tags",
@@ -2605,7 +2605,7 @@ export const $DagStatsStateResponse = {
description: "DagStatsState serializer for responses.",
} as const;
-export const $DagTagPydantic = {
+export const $DagTagResponse = {
properties: {
name: {
type: "string",
@@ -2618,9 +2618,8 @@ export const $DagTagPydantic = {
},
type: "object",
required: ["name", "dag_id"],
- title: "DagTagPydantic",
- description:
- "Serializable representation of the DagTag ORM SqlAlchemyModel used by
internal API.",
+ title: "DagTagResponse",
+ description: "DAG Tag serializer for responses.",
} as const;
export const $DagWarningType = {
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow/ui/openapi-gen/requests/types.gen.ts
index 42196f1a892..f7640ef45a0 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -293,7 +293,7 @@ export type DAGDetailsResponse = {
description: string | null;
timetable_summary: string | null;
timetable_description: string | null;
- tags: Array<DagTagPydantic>;
+ tags: Array<DagTagResponse>;
max_active_tasks: number;
max_active_runs: number | null;
max_consecutive_failed_dag_runs: number;
@@ -352,7 +352,7 @@ export type DAGResponse = {
description: string | null;
timetable_summary: string | null;
timetable_description: string | null;
- tags: Array<DagTagPydantic>;
+ tags: Array<DagTagResponse>;
max_active_tasks: number;
max_active_runs: number | null;
max_consecutive_failed_dag_runs: number;
@@ -515,7 +515,7 @@ export type DAGWithLatestDagRunsResponse = {
description: string | null;
timetable_summary: string | null;
timetable_description: string | null;
- tags: Array<DagTagPydantic>;
+ tags: Array<DagTagResponse>;
max_active_tasks: number;
max_active_runs: number | null;
max_consecutive_failed_dag_runs: number;
@@ -620,9 +620,9 @@ export type DagStatsStateResponse = {
};
/**
- * Serializable representation of the DagTag ORM SqlAlchemyModel used by
internal API.
+ * DAG Tag serializer for responses.
*/
-export type DagTagPydantic = {
+export type DagTagResponse = {
name: string;
dag_id: string;
};
diff --git a/airflow/ui/src/pages/DagsList/DagCard.test.tsx
b/airflow/ui/src/pages/DagsList/DagCard.test.tsx
index 44705863343..d32cd8786f6 100644
--- a/airflow/ui/src/pages/DagsList/DagCard.test.tsx
+++ b/airflow/ui/src/pages/DagsList/DagCard.test.tsx
@@ -20,7 +20,7 @@
*/
import { render, screen } from "@testing-library/react";
import type {
- DagTagPydantic,
+ DagTagResponse,
DAGWithLatestDagRunsResponse,
} from "openapi-gen/requests/types.gen";
import { afterEach, describe, it, vi, expect } from "vitest";
@@ -74,7 +74,7 @@ describe("DagCard", () => {
{ dag_id: "id", name: "tag2" },
{ dag_id: "id", name: "tag3" },
{ dag_id: "id", name: "tag4" },
- ] satisfies Array<DagTagPydantic>;
+ ] satisfies Array<DagTagResponse>;
const expandedMockDag = {
...mockDag,
diff --git a/airflow/ui/src/pages/DagsList/DagTags.tsx
b/airflow/ui/src/pages/DagsList/DagTags.tsx
index 8716824312c..77253391d59 100644
--- a/airflow/ui/src/pages/DagsList/DagTags.tsx
+++ b/airflow/ui/src/pages/DagsList/DagTags.tsx
@@ -19,14 +19,14 @@
import { Flex, Text, VStack } from "@chakra-ui/react";
import { FiTag } from "react-icons/fi";
-import type { DagTagPydantic } from "openapi/requests/types.gen";
+import type { DagTagResponse } from "openapi/requests/types.gen";
import { Tooltip } from "src/components/ui";
const MAX_TAGS = 3;
type Props = {
readonly hideIcon?: boolean;
- readonly tags: Array<DagTagPydantic>;
+ readonly tags: Array<DagTagResponse>;
};
export const DagTags = ({ hideIcon = false, tags }: Props) =>
diff --git a/docs/apache-airflow/extra-packages-ref.rst
b/docs/apache-airflow/extra-packages-ref.rst
index 5d54a94a422..70d8e3b5fe6 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -71,8 +71,6 @@ python dependencies for the provided package.
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| password | ``pip install 'apache-airflow[password]'`` |
Password authentication for users |
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
-| pydantic | ``pip install 'apache-airflow[pydantic]'`` |
Pydantic serialization for internal-api |
-+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| rabbitmq | ``pip install 'apache-airflow[rabbitmq]'`` |
RabbitMQ support as a Celery backend |
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| sentry | ``pip install 'apache-airflow[sentry]'`` |
Sentry service for application logging and monitoring |
diff --git
a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index 2b68879531d..b4e9d440563 100644
--- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
+++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -88,7 +88,7 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
params = {}
try:
if request_obj.params:
- params = BaseSerialization.deserialize(request_obj.params,
use_pydantic_models=True)
+ params = BaseSerialization.deserialize(request_obj.params)
except Exception:
raise error_response("Error deserializing parameters.",
status.HTTP_400_BAD_REQUEST)
@@ -97,13 +97,13 @@ def rpcapi_v2(body: dict[str, Any]) -> APIResponse:
# Session must be created there as it may be needed by serializer
for lazy-loaded fields.
with create_session() as session:
output = handler(**params, session=session)
- output_json = BaseSerialization.serialize(output,
use_pydantic_models=True)
+ output_json = BaseSerialization.serialize(output)
log.debug(
"Sending response: %s", json.dumps(output_json) if
output_json is not None else None
)
# In case of AirflowException or other selective known types,
transport the exception class back to caller
except (KeyError, AttributeError, AirflowException) as e:
- output_json = BaseSerialization.serialize(e,
use_pydantic_models=True)
+ output_json = BaseSerialization.serialize(e)
log.debug(
"Sending exception response: %s", json.dumps(output_json) if
output_json is not None else None
)
diff --git a/providers/src/airflow/providers/standard/operators/python.py
b/providers/src/airflow/providers/standard/operators/python.py
index ce1447769e1..40a0cb7a922 100644
--- a/providers/src/airflow/providers/standard/operators/python.py
+++ b/providers/src/airflow/providers/standard/operators/python.py
@@ -57,14 +57,12 @@ from airflow.utils.context import context_copy_partial,
context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess,
execute_in_subprocess_with_kwargs
-from airflow.utils.session import create_session
log = logging.getLogger(__name__)
if TYPE_CHECKING:
from pendulum.datetime import DateTime
- from airflow.serialization.enums import Encoding
from airflow.utils.context import Context
@@ -530,18 +528,19 @@ class _BasePythonVirtualenvOperator(PythonOperator,
metaclass=ABCMeta):
render_template_as_native_obj=self.dag.render_template_as_native_obj,
)
if self.use_airflow_context:
- from airflow.serialization.serialized_objects import
BaseSerialization
-
- context = get_current_context()
- with create_session() as session:
- # FIXME: DetachedInstanceError
- dag_run, task_instance = context["dag_run"],
context["task_instance"]
- session.add_all([dag_run, task_instance])
- serializable_context: dict[Encoding, Any] =
BaseSerialization.serialize(
- context, use_pydantic_models=True
- )
- with airflow_context_path.open("w+") as file:
- json.dump(serializable_context, file)
+ # TODO: replace with commented code when context serialization
is implemented in AIP-72
+ raise AirflowException(
+ "The `use_airflow_context=True` is not yet implemented. "
+ "It will work in Airflow 3 after AIP-72 context "
+ "serialization is ready."
+ )
+ # context = get_current_context()
+ # with create_session() as session:
+ # dag_run, task_instance = context["dag_run"],
context["task_instance"]
+ # session.add_all([dag_run, task_instance])
+ # serializable_context: dict[Encoding, Any] = # Get
serializable context here
+ # with airflow_context_path.open("w+") as file:
+ # json.dump(serializable_context, file)
env_vars = dict(os.environ) if self.inherit_env else {}
if self.env_vars:
@@ -653,6 +652,7 @@ class
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
of the parent process (``os.environ``). If set to ``False``, the
virtual environment will be
executed with a clean environment.
:param use_airflow_context: Whether to provide ``get_current_context()``
to the python_callable.
+ NOT YET IMPLEMENTED - waits for AIP-72 context serialization.
"""
template_fields: Sequence[str] = tuple(
@@ -697,15 +697,18 @@ class
PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
raise AirflowException(
"Passing non-string types (e.g. int or float) as
python_version not supported"
)
- if use_airflow_context and not AIRFLOW_V_3_0_PLUS:
- raise AirflowException(
- "The `use_airflow_context=True` is only supported in Airflow
3.0.0 and later."
- )
if use_airflow_context and (not expect_airflow and not
system_site_packages):
raise AirflowException(
"The `use_airflow_context` parameter is set to True, but "
"expect_airflow and system_site_packages are set to False."
)
+ # TODO: remove when context serialization is implemented in AIP-72
+ if use_airflow_context and not AIRFLOW_V_3_0_PLUS:
+ raise AirflowException(
+ "The `use_airflow_context=True` is not yet implemented. "
+ "It will work in Airflow 3 after AIP-72 context "
+ "serialization is ready."
+ )
if not requirements:
self.requirements: list[str] = []
elif isinstance(requirements, str):
@@ -951,6 +954,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
of the parent process (``os.environ``). If set to ``False``, the
virtual environment will be
executed with a clean environment.
    :param use_airflow_context: Whether to provide ``get_current_context()``
to the python_callable.
+        NOT YET IMPLEMENTED - pending AIP-72 context serialization.
"""
template_fields: Sequence[str] =
tuple({"python"}.union(PythonOperator.template_fields))
@@ -976,14 +980,17 @@ class
ExternalPythonOperator(_BasePythonVirtualenvOperator):
):
if not python:
raise ValueError("Python Path must be defined in
ExternalPythonOperator")
- if use_airflow_context and not AIRFLOW_V_3_0_PLUS:
- raise AirflowException(
- "The `use_airflow_context=True` is only supported in Airflow
3.0.0 and later."
- )
if use_airflow_context and not expect_airflow:
raise AirflowException(
"The `use_airflow_context` parameter is set to True, but
expect_airflow is set to False."
)
+ # TODO: remove when context serialization is implemented in AIP-72
+ if use_airflow_context:
+ raise AirflowException(
+ "The `use_airflow_context=True` is not yet implemented. "
+ "It will work in Airflow 3 after AIP-72 context "
+ "serialization is ready."
+ )
self.python = python
self.expect_pendulum = expect_pendulum
super().__init__(
diff --git
a/providers/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2
b/providers/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2
index 54c8412d1d8..6b803b408f2 100644
---
a/providers/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2
+++
b/providers/src/airflow/providers/standard/utils/python_virtualenv_script.jinja2
@@ -70,7 +70,6 @@ if len(sys.argv) > 5:
from types import ModuleType
from airflow.providers.standard.operators import python as airflow_python
- from airflow.serialization.serialized_objects import BaseSerialization
class _MockPython(ModuleType):
@@ -78,7 +77,8 @@ if len(sys.argv) > 5:
def get_current_context():
with open(sys.argv[5]) as file:
context = json.load(file)
- return BaseSerialization.deserialize(context,
use_pydantic_models=True)
+ raise Exception("Not yet implemented")
+ # TODO: return deserialized context
def __getattr__(self, name: str):
return getattr(airflow_python, name)
diff --git a/providers/tests/standard/operators/test_python.py
b/providers/tests/standard/operators/test_python.py
index a434448ff33..2d2fe9e2c16 100644
--- a/providers/tests/standard/operators/test_python.py
+++ b/providers/tests/standard/operators/test_python.py
@@ -90,9 +90,7 @@ DILL_MARKER = pytest.mark.skipif(not DILL_INSTALLED,
reason="`dill` is not insta
CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None
CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED,
reason="`cloudpickle` is not installed")
-AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE = (
- r"The `use_airflow_context=True` is only supported in Airflow 3.0.0 and
later."
-)
+AIRFLOW_CONTEXT_NOT_IMPLEMENTED_YET_MESSAGE = r"The `use_airflow_context=True`
is not yet implemented."
class BasePythonTest:
@@ -1054,12 +1052,12 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
return []
- if AIRFLOW_V_3_0_PLUS:
- ti = self.run_as_task(f, return_ti=True, multiple_outputs=False,
use_airflow_context=True)
- assert ti.state == TaskInstanceState.SUCCESS
- else:
- with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE):
- self.run_as_task(f, return_ti=True, use_airflow_context=True)
+ # TODO: replace with commented code when context serialization is
implemented in AIP-72
+ with pytest.raises(Exception,
match=AIRFLOW_CONTEXT_NOT_IMPLEMENTED_YET_MESSAGE):
+ self.run_as_task(f, return_ti=True, use_airflow_context=True)
+
+ # ti = self.run_as_task(f, return_ti=True, multiple_outputs=False,
use_airflow_context=True)
+ # assert ti.state == TaskInstanceState.SUCCESS
def test_current_context_not_found_error(self):
def f():
@@ -1085,6 +1083,7 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
def test_current_context_airflow_not_found_error(self):
airflow_flag: dict[str, bool] = {"expect_airflow": False}
+
error_msg = r"The `use_airflow_context` parameter is set to True, but
expect_airflow is set to False."
if not issubclass(self.opcls, ExternalPythonOperator):
@@ -1100,14 +1099,10 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
get_current_context()
return []
- if AIRFLOW_V_3_0_PLUS:
- with pytest.raises(AirflowException, match=error_msg):
- self.run_as_task(
- f, return_ti=True, multiple_outputs=False,
use_airflow_context=True, **airflow_flag
- )
- else:
- with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE):
- self.run_as_task(f, return_ti=True, use_airflow_context=True,
**airflow_flag)
+ with pytest.raises(AirflowException, match=error_msg):
+ self.run_as_task(
+ f, return_ti=True, multiple_outputs=False,
use_airflow_context=True, **airflow_flag
+ )
def test_use_airflow_context_touch_other_variables(self):
def f():
@@ -1121,12 +1116,12 @@ class BaseTestPythonVirtualenvOperator(BasePythonTest):
return []
- if AIRFLOW_V_3_0_PLUS:
- ti = self.run_as_task(f, return_ti=True, multiple_outputs=False,
use_airflow_context=True)
- assert ti.state == TaskInstanceState.SUCCESS
- else:
- with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE):
- self.run_as_task(f, return_ti=True, use_airflow_context=True)
+ # TODO: replace with commented code when context serialization is
implemented in AIP-72
+ with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_NOT_IMPLEMENTED_YET_MESSAGE):
+ self.run_as_task(f, return_ti=True, use_airflow_context=True)
+
+ # ti = self.run_as_task(f, return_ti=True, multiple_outputs=False,
use_airflow_context=True)
+ # assert ti.state == TaskInstanceState.SUCCESS
venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path")
@@ -1499,27 +1494,27 @@ class
TestPythonVirtualenvOperator(BaseTestPythonVirtualenvOperator):
return []
- if AIRFLOW_V_3_0_PLUS:
- ti = self.run_as_task(
+ # TODO: replace with commented code when context serialization is
implemented in AIP-72
+ with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_NOT_IMPLEMENTED_YET_MESSAGE):
+ self.run_as_task(
f,
return_ti=True,
- multiple_outputs=False,
use_airflow_context=True,
session=session,
expect_airflow=False,
system_site_packages=True,
)
- assert ti.state == TaskInstanceState.SUCCESS
- else:
- with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE):
- self.run_as_task(
- f,
- return_ti=True,
- use_airflow_context=True,
- session=session,
- expect_airflow=False,
- system_site_packages=True,
- )
+
+ # ti = self.run_as_task(
+ # f,
+ # return_ti=True,
+ # multiple_outputs=False,
+ # use_airflow_context=True,
+ # session=session,
+ # expect_airflow=False,
+ # system_site_packages=True,
+ # )
+ # assert ti.state == TaskInstanceState.SUCCESS
# when venv tests are run in parallel to other test they create new processes
and this might take
@@ -1862,27 +1857,27 @@ class
TestBranchPythonVirtualenvOperator(BaseTestBranchPythonVirtualenvOperator)
return []
- if AIRFLOW_V_3_0_PLUS:
- ti = self.run_as_task(
+ # TODO: replace with commented code when context serialization is
implemented in AIP-72
+ with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_NOT_IMPLEMENTED_YET_MESSAGE):
+ self.run_as_task(
f,
return_ti=True,
- multiple_outputs=False,
use_airflow_context=True,
session=session,
expect_airflow=False,
system_site_packages=True,
)
- assert ti.state == TaskInstanceState.SUCCESS
- else:
- with pytest.raises(AirflowException,
match=AIRFLOW_CONTEXT_BEFORE_V3_0_MESSAGE):
- self.run_as_task(
- f,
- return_ti=True,
- use_airflow_context=True,
- session=session,
- expect_airflow=False,
- system_site_packages=True,
- )
+
+ # ti = self.run_as_task(
+ # f,
+ # return_ti=True,
+ # multiple_outputs=False,
+ # use_airflow_context=True,
+ # session=session,
+ # expect_airflow=False,
+ # system_site_packages=True,
+ # )
+ # assert ti.state == TaskInstanceState.SUCCESS
# when venv tests are run in parallel to other test they create new processes
and this might take
diff --git a/tests/always/test_example_dags.py
b/tests/always/test_example_dags.py
index da45c64c0bb..3953db2cea2 100644
--- a/tests/always/test_example_dags.py
+++ b/tests/always/test_example_dags.py
@@ -149,6 +149,12 @@ def example_not_excluded_dags(xfail_db_exception: bool =
False):
pytest.mark.skip(reason=f"Not supported for Python
{CURRENT_PYTHON_VERSION}")
)
+ # TODO: remove when context serialization is implemented in AIP-72
+ if "/example_python_context_" in candidate:
+ param_marks.append(
+                pytest.mark.skip(reason="Temporarily excluded until AIP-72
context serialization is done.")
+ )
+
for optional, dependencies in
OPTIONAL_PROVIDERS_DEPENDENCIES.items():
if re.match(optional, candidate):
for distribution_name, specifier in dependencies.items():
diff --git a/tests/assets/test_manager.py b/tests/assets/test_manager.py
index b716056e814..afbbfc23ade 100644
--- a/tests/assets/test_manager.py
+++ b/tests/assets/test_manager.py
@@ -18,7 +18,6 @@
from __future__ import annotations
import itertools
-from datetime import datetime
from unittest import mock
import pytest
@@ -37,7 +36,6 @@ from airflow.models.asset import (
from airflow.models.dag import DagModel
from airflow.models.dagbag import DagPriorityParsingRequest
from airflow.sdk.definitions.asset import Asset, AssetAlias
-from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from tests.listeners import asset_listener
@@ -58,49 +56,8 @@ def clear_assets():
@pytest.fixture
def mock_task_instance():
- return TaskInstancePydantic(
- id="1",
- task_id="5",
- dag_id="7",
- run_id="11",
- map_index="13",
- start_date=datetime.now(),
- end_date=datetime.now(),
- logical_date=datetime.now(),
- duration=0.1,
- state="success",
- try_number=1,
- max_tries=4,
- hostname="host",
- unixname="unix",
- job_id=13,
- pool="default",
- pool_slots=1,
- queue="default",
- priority_weight=77,
- operator="DummyOperator",
- custom_operator_name="DummyOperator",
- queued_dttm=datetime.now(),
- queued_by_job_id=3,
- pid=12345,
- executor="default",
- executor_config=None,
- updated_at=datetime.now(),
- rendered_map_index="1",
- external_executor_id="x",
- trigger_id=1,
- trigger_timeout=datetime.now(),
- next_method="bla",
- next_kwargs=None,
- dag_version_id=None,
- run_as_user=None,
- task=None,
- test_mode=False,
- dag_run=None,
- dag_model=None,
- raw=False,
- is_trigger_log_context=False,
- )
+    # TODO: provide a proper mock task instance here (replacement for the
removed TaskInstancePydantic)
+    return None
def create_mock_dag():
diff --git a/tests/jobs/test_base_job.py b/tests/jobs/test_base_job.py
index 63f5ef1910a..f5f60605ca1 100644
--- a/tests/jobs/test_base_job.py
+++ b/tests/jobs/test_base_job.py
@@ -20,7 +20,6 @@ from __future__ import annotations
import datetime
import logging
import sys
-from typing import TYPE_CHECKING
from unittest.mock import ANY, Mock, patch
import pytest
@@ -37,9 +36,6 @@ from tests.listeners import lifecycle_listener
from tests.utils.test_helpers import MockJobRunner, SchedulerJobRunner,
TriggererJobRunner
from tests_common.test_utils.config import conf_vars
-if TYPE_CHECKING:
- from airflow.serialization.pydantic.job import JobPydantic
-
pytestmark = pytest.mark.db_test
@@ -131,7 +127,7 @@ class TestJob:
# heartrate should be 12 since we passed that to the constructor
directly
assert job.heartrate == 12
- def _compare_jobs(self, job1: Job | JobPydantic, job2: Job | JobPydantic):
+ def _compare_jobs(self, job1: Job, job2: Job):
        """Helper to compare two jobs where one can be Pydantic and the other
not."""
assert job1.id == job2.id
assert job1.dag_id == job2.dag_id
diff --git a/tests/serialization/test_serde.py
b/tests/serialization/test_serde.py
index 2fc8ad8d17b..60f895e3efa 100644
--- a/tests/serialization/test_serde.py
+++ b/tests/serialization/test_serde.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import datetime
import enum
-import warnings
from collections import namedtuple
from dataclasses import dataclass
from importlib import import_module
@@ -439,15 +438,6 @@ class TestSerDe:
s = deserialize(e)
assert i == s
- def test_pydantic(self):
- pydantic = pytest.importorskip("pydantic", minversion="2.0.0")
- with warnings.catch_warnings():
- warnings.simplefilter("error",
category=pydantic.warnings.PydanticDeprecationWarning)
- i = U(x=10, v=V(W(10), ["l1", "l2"], (1, 2), 10), u=(1, 2))
- e = serialize(i)
- s = deserialize(e)
- assert i == s
-
def test_error_when_serializing_callable_without_name(self):
i = C()
with pytest.raises(
diff --git a/tests/serialization/test_serialized_objects.py
b/tests/serialization/test_serialized_objects.py
index 8100c2a84bc..0faeed038e6 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -17,19 +17,15 @@
from __future__ import annotations
-import inspect
import json
-import warnings
from collections.abc import Iterator
from datetime import datetime, timedelta
-from importlib import import_module
import pendulum
import pytest
from dateutil import relativedelta
from kubernetes.client import models as k8s
from pendulum.tz.timezone import Timezone
-from pydantic import BaseModel
from airflow.exceptions import (
AirflowException,
@@ -38,25 +34,16 @@ from airflow.exceptions import (
SerializationError,
TaskDeferred,
)
-from airflow.jobs.job import Job
-from airflow.models.asset import AssetEvent
from airflow.models.connection import Connection
-from airflow.models.dag import DAG, DagModel, DagTag
+from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.param import Param
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
-from airflow.models.tasklog import LogTemplate
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
-from airflow.serialization.pydantic.asset import AssetEventPydantic,
AssetPydantic
-from airflow.serialization.pydantic.dag import DagModelPydantic, DagTagPydantic
-from airflow.serialization.pydantic.dag_run import DagRunPydantic
-from airflow.serialization.pydantic.job import JobPydantic
-from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
-from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
@@ -353,162 +340,6 @@ def test_backcompat_deserialize_connection(conn_uri):
assert deserialized.get_uri() == conn_uri
-sample_objects = {
- JobPydantic: Job(state=State.RUNNING, latest_heartbeat=timezone.utcnow()),
- TaskInstancePydantic: TI_WITH_START_DAY,
- DagRunPydantic: DAG_RUN,
- DagModelPydantic: DagModel(
- dag_id="TEST_DAG_1",
- fileloc="/tmp/dag_1.py",
- timetable_summary="2 2 * * *",
- is_paused=True,
- ),
- LogTemplatePydantic: LogTemplate(
- id=1,
- filename="test_file",
- elasticsearch_id="test_id",
- created_at=datetime.now(),
- ),
- DagTagPydantic: DagTag(),
- AssetPydantic: Asset(name="test", uri="test://asset1", extra={}),
- AssetEventPydantic: AssetEvent(),
-}
-
-
[email protected](
- "input, pydantic_class, encoded_type, cmp_func",
- [
- (
- sample_objects.get(JobPydantic),
- JobPydantic,
- DAT.BASE_JOB,
- lambda a, b: equal_time(a.latest_heartbeat, b.latest_heartbeat),
- ),
- (
- sample_objects.get(TaskInstancePydantic),
- TaskInstancePydantic,
- DAT.TASK_INSTANCE,
- lambda a, b: equal_time(a.start_date, b.start_date),
- ),
- (
- sample_objects.get(DagRunPydantic),
- DagRunPydantic,
- DAT.DAG_RUN,
- lambda a, b: equal_time(a.logical_date, b.logical_date)
- and equal_time(a.start_date, b.start_date),
- ),
- # Asset is already serialized by non-Pydantic serialization. Is
AssetPydantic needed then?
- # (
- # Asset(
- # uri="foo://bar",
- # extra={"foo": "bar"},
- # ),
- # AssetPydantic,
- # DAT.ASSET,
- # lambda a, b: a.uri == b.uri and a.extra == b.extra,
- # ),
- (
- sample_objects.get(DagModelPydantic),
- DagModelPydantic,
- DAT.DAG_MODEL,
- lambda a, b: a.fileloc == b.fileloc and a.timetable_summary ==
b.timetable_summary,
- ),
- (
- sample_objects.get(LogTemplatePydantic),
- LogTemplatePydantic,
- DAT.LOG_TEMPLATE,
- lambda a, b: a.id == b.id and a.filename == b.filename and
equal_time(a.created_at, b.created_at),
- ),
- ],
-)
-def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type,
cmp_func):
- """If use_pydantic_models=True the objects should be serialized to
Pydantic objects."""
- pydantic = pytest.importorskip("pydantic", minversion="2.0.0")
-
- from airflow.serialization.serialized_objects import BaseSerialization
-
- with warnings.catch_warnings():
- warnings.simplefilter("error",
category=pydantic.warnings.PydanticDeprecationWarning)
-
- serialized = BaseSerialization.serialize(input,
use_pydantic_models=True) # does not raise
- # Verify the result is JSON-serializable
- json.dumps(serialized) # does not raise
- assert serialized["__type"] == encoded_type
- assert serialized["__var"] is not None
- deserialized = BaseSerialization.deserialize(serialized,
use_pydantic_models=True)
- assert isinstance(deserialized, pydantic_class)
- assert cmp_func(input, deserialized)
-
- # verify that when we round trip a pydantic model we get the same thing
- reserialized = BaseSerialization.serialize(deserialized,
use_pydantic_models=True)
- dereserialized = BaseSerialization.deserialize(reserialized,
use_pydantic_models=True)
- assert isinstance(dereserialized, pydantic_class)
-
- if encoded_type == "task_instance":
- deserialized.task.dag = None
- dereserialized.task.dag = None
-
- assert dereserialized == deserialized
-
- # Verify recursive behavior
- obj = [[input]]
- BaseSerialization.serialize(obj, use_pydantic_models=True) # does not
raise
-
-
-def test_all_pydantic_models_round_trip():
- pytest.importorskip("pydantic", minversion="2.0.0")
- classes = set()
- mods_folder = REPO_ROOT / "airflow/serialization/pydantic"
- for p in mods_folder.iterdir():
- if p.name.startswith("__"):
- continue
- relpath = str(p.relative_to(REPO_ROOT).stem)
- mod = import_module(f"airflow.serialization.pydantic.{relpath}")
- for _, obj in inspect.getmembers(mod):
- if inspect.isclass(obj) and issubclass(obj, BaseModel):
- if obj == BaseModel:
- continue
- classes.add(obj)
- exclusion_list = {
- "AssetPydantic",
- "DagTagPydantic",
- "DagScheduleAssetReferencePydantic",
- "TaskOutletAssetReferencePydantic",
- "DagOwnerAttributesPydantic",
- "AssetEventPydantic",
- "TriggerPydantic",
- }
- for c in sorted(classes, key=str):
- if c.__name__ in exclusion_list:
- continue
- orm_instance = sample_objects.get(c)
- if not orm_instance:
- pytest.fail(
- f"Class {c.__name__} not set up for testing. Either (1) add"
- f" to `sample_objects` an object for testing roundtrip or"
- f" (2) add class name to `exclusion list` if it does not"
- f" need to be serialized directly."
- )
- orm_ser = BaseSerialization.serialize(orm_instance,
use_pydantic_models=True)
- pydantic_instance = BaseSerialization.deserialize(orm_ser,
use_pydantic_models=True)
- if isinstance(pydantic_instance, str):
- pytest.fail(
- f"The model object {orm_instance.__class__} came back as a
string "
- f"after round trip. Probably you need to define a
DagAttributeType "
- f"for it and define it in mappings `_orm_to_model` and
`_type_to_class` "
- f"in `serialized_objects.py`"
- )
- assert isinstance(pydantic_instance, c)
- serialized = BaseSerialization.serialize(pydantic_instance,
use_pydantic_models=True)
- deserialized = BaseSerialization.deserialize(serialized,
use_pydantic_models=True)
- assert isinstance(deserialized, c)
- if isinstance(pydantic_instance, TaskInstancePydantic):
- # we can't access the dag on deserialization; but there is no dag
here.
- deserialized.task.dag = None
- pydantic_instance.task.dag = None
- assert pydantic_instance == deserialized
-
-
@pytest.mark.db_test
def test_serialized_mapped_operator_unmap(dag_maker):
from airflow.serialization.serialized_objects import SerializedDAG