This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f58bd73eb0c Remove Pydantic classes from models/dag (#44509)
f58bd73eb0c is described below
commit f58bd73eb0c8d5f2b90174b9f26cbea6afe8bd48
Author: LIU ZHE YOU <[email protected]>
AuthorDate: Sun Dec 1 04:01:20 2024 +0800
Remove Pydantic classes from models/dag (#44509)
* Fix: Remove Pydantic classes from models/dag
* Fix mypy error
---
airflow/models/dag.py | 10 ++++------
airflow/utils/log/file_task_handler.py | 3 +--
2 files changed, 5 insertions(+), 8 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index e7947499f7d..5cc7bb47d21 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -118,8 +118,6 @@ if TYPE_CHECKING:
from airflow.models.abstractoperator import TaskStateChangeCallback
from airflow.models.dagbag import DagBag
from airflow.models.operator import Operator
- from airflow.serialization.pydantic.dag import DagModelPydantic
- from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import Literal
log = logging.getLogger(__name__)
@@ -513,7 +511,7 @@ class DAG(TaskSDKDag, LoggingMixin):
# infer from the logical date.
return self.infer_automated_data_interval(dag_model.next_dagrun)
- def get_run_data_interval(self, run: DagRun | DagRunPydantic) -> DataInterval:
+ def get_run_data_interval(self, run: DagRun) -> DataInterval:
"""
Get the data interval of this run.
@@ -873,7 +871,7 @@ class DAG(TaskSDKDag, LoggingMixin):
@staticmethod
@provide_session
- def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION) -> DagRun | DagRunPydantic:
+ def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION) -> DagRun:
"""
Return the dag run for a given run_id if it exists, otherwise none.
@@ -885,7 +883,7 @@ class DAG(TaskSDKDag, LoggingMixin):
return session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))
@provide_session
- def get_dagrun(self, run_id: str, session: Session = NEW_SESSION) -> DagRun | DagRunPydantic:
+ def get_dagrun(self, run_id: str, session: Session = NEW_SESSION) -> DagRun:
return DAG.fetch_dagrun(dag_id=self.dag_id, run_id=run_id, session=session)
@provide_session
@@ -2139,7 +2137,7 @@ class DagModel(Base):
@classmethod
@provide_session
- def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel | DagModelPydantic:
+ def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel:
return session.scalar(select(cls).where(cls.dag_id == dag_id))
@provide_session
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 9ec14c98dbb..38b17a9a9f1 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -48,7 +48,6 @@ if TYPE_CHECKING:
from airflow.models import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
- from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
logger = logging.getLogger(__name__)
@@ -271,7 +270,7 @@ class FileTaskHandler(logging.Handler):
@provide_session
def _render_filename_db_access(
*, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None
- ) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
+ ) -> tuple[DagRun, TaskInstance | TaskInstancePydantic, str | None, str | None]:
ti = _ensure_ti(ti, session)
dag_run = ti.get_dagrun(session=session)
template = dag_run.get_log_template(session=session).filename