This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c0a714cb5 [AIP-44] Migrate run_task_by_local_task_job to Internal API 
(#35527)
3c0a714cb5 is described below

commit 3c0a714cb57894b0816bf39079e29d79ea0b1d0a
Author: mhenc <[email protected]>
AuthorDate: Wed Nov 15 19:41:33 2023 +0100

    [AIP-44] Migrate run_task_by_local_task_job to Internal API (#35527)
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |   2 +
 airflow/cli/commands/task_command.py               |   2 +-
 airflow/jobs/local_task_job_runner.py              |   7 +-
 airflow/models/taskinstance.py                     | 242 ++++++++++++++-------
 airflow/serialization/pydantic/dag.py              |  81 ++++++-
 airflow/serialization/pydantic/dag_run.py          |  28 +--
 airflow/serialization/pydantic/taskinstance.py     |  85 +++++++-
 tests/serialization/test_pydantic_models.py        |  20 +-
 8 files changed, 353 insertions(+), 114 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py 
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 945587b4e9..f451659cc0 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -77,9 +77,11 @@ def _initialize_map() -> dict[str, Callable]:
         DagRun.get_previous_scheduled_dagrun,
         DagRun.fetch_task_instance,
         SerializedDagModel.get_serialized_dag,
+        TaskInstance._check_and_change_state_before_execution,
         TaskInstance.get_task_instance,
         TaskInstance.fetch_handle_failure_context,
         TaskInstance.save_to_db,
+        TaskInstance._schedule_downstream_tasks,
         Trigger.from_object,
         Trigger.bulk_fetch,
         Trigger.clean_unused,
diff --git a/airflow/cli/commands/task_command.py 
b/airflow/cli/commands/task_command.py
index 8a83fa8e8a..5c7c47d69b 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -259,7 +259,7 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance) 
-> None:
     executor.end()
 
 
-def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode | 
None:
+def _run_task_by_local_task_job(args, ti: TaskInstance | TaskInstancePydantic) 
-> TaskReturnCode | None:
     """Run LocalTaskJob, which monitors the raw task execution process."""
     job_runner = LocalTaskJobRunner(
         job=Job(dag_id=ti.dag_id),
diff --git a/airflow/jobs/local_task_job_runner.py 
b/airflow/jobs/local_task_job_runner.py
index e068d88203..e12ecefcc0 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
 
     from airflow.jobs.job import Job
     from airflow.models.taskinstance import TaskInstance
+    from airflow.serialization.pydantic.taskinstance import 
TaskInstancePydantic
 
 SIGSEGV_MESSAGE = """
 ******************************************* Received SIGSEGV 
*******************************************
@@ -81,7 +82,7 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
     def __init__(
         self,
         job: Job,
-        task_instance: TaskInstance,  # TODO add TaskInstancePydantic
+        task_instance: TaskInstance | TaskInstancePydantic,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
         wait_for_past_depends_before_skipping: bool = False,
@@ -284,7 +285,9 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
             if ti.state == TaskInstanceState.SKIPPED:
                 # A DagRun timeout will cause tasks to be externally marked as 
skipped.
                 dagrun = ti.get_dagrun(session=session)
-                execution_time = (dagrun.end_date or timezone.utcnow()) - 
dagrun.start_date
+                execution_time = (dagrun.end_date or timezone.utcnow()) - (
+                    dagrun.start_date or timezone.utcnow()
+                )
                 if ti.task.dag is not None:
                     dagrun_timeout = ti.task.dag.dagrun_timeout
                 else:
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index fe09bba732..95a2f5945f 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -146,6 +146,7 @@ if TYPE_CHECKING:
     from airflow.models.dagrun import DagRun
     from airflow.models.dataset import DatasetEvent
     from airflow.models.operator import Operator
+    from airflow.serialization.pydantic.dag import DagModelPydantic
     from airflow.serialization.pydantic.taskinstance import 
TaskInstancePydantic
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Literal, TypeGuard
@@ -1276,7 +1277,7 @@ class TaskInstance(Base, LoggingMixin):
         ),
     )
 
-    dag_model = relationship(
+    dag_model: DagModel = relationship(
         "DagModel",
         primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
         foreign_keys=dag_id,
@@ -1451,32 +1452,31 @@ class TaskInstance(Base, LoggingMixin):
         """@property: use a more friendly display name for the operator, if 
set."""
         return self.custom_operator_name or self.operator
 
-    def command_as_list(
-        self,
-        mark_success=False,
-        ignore_all_deps=False,
-        ignore_task_deps=False,
-        ignore_depends_on_past=False,
-        wait_for_past_depends_before_skipping=False,
-        ignore_ti_state=False,
-        local=False,
+    @staticmethod
+    def _command_as_list(
+        ti: TaskInstance | TaskInstancePydantic,
+        mark_success: bool = False,
+        ignore_all_deps: bool = False,
+        ignore_task_deps: bool = False,
+        ignore_depends_on_past: bool = False,
+        wait_for_past_depends_before_skipping: bool = False,
+        ignore_ti_state: bool = False,
+        local: bool = False,
         pickle_id: int | None = None,
-        raw=False,
-        job_id=None,
-        pool=None,
-        cfg_path=None,
+        raw: bool = False,
+        job_id: str | None = None,
+        pool: str | None = None,
+        cfg_path: str | None = None,
     ) -> list[str]:
-        """
-        Return a command that can be executed anywhere where airflow is 
installed.
-
-        This command is part of the message sent to executors by the 
orchestrator.
-        """
-        dag: DAG | DagModel
+        dag: DAG | DagModel | DagModelPydantic | None
         # Use the dag if we have it, else fallback to the ORM dag_model, which 
might not be loaded
-        if hasattr(self, "task") and hasattr(self.task, "dag") and 
self.task.dag is not None:
-            dag = self.task.dag
+        if hasattr(ti, "task") and hasattr(ti.task, "dag") and ti.task.dag is 
not None:
+            dag = ti.task.dag
         else:
-            dag = self.dag_model
+            dag = ti.dag_model
+
+        if dag is None:
+            raise ValueError("DagModel is empty")
 
         should_pass_filepath = not pickle_id and dag
         path: PurePath | None = None
@@ -1493,9 +1493,9 @@ class TaskInstance(Base, LoggingMixin):
                     path = "DAGS_FOLDER" / path
 
         return TaskInstance.generate_command(
-            self.dag_id,
-            self.task_id,
-            run_id=self.run_id,
+            ti.dag_id,
+            ti.task_id,
+            run_id=ti.run_id,
             mark_success=mark_success,
             ignore_all_deps=ignore_all_deps,
             ignore_task_deps=ignore_task_deps,
@@ -1509,7 +1509,43 @@ class TaskInstance(Base, LoggingMixin):
             job_id=job_id,
             pool=pool,
             cfg_path=cfg_path,
-            map_index=self.map_index,
+            map_index=ti.map_index,
+        )
+
+    def command_as_list(
+        self,
+        mark_success: bool = False,
+        ignore_all_deps: bool = False,
+        ignore_task_deps: bool = False,
+        ignore_depends_on_past: bool = False,
+        wait_for_past_depends_before_skipping: bool = False,
+        ignore_ti_state: bool = False,
+        local: bool = False,
+        pickle_id: int | None = None,
+        raw: bool = False,
+        job_id: str | None = None,
+        pool: str | None = None,
+        cfg_path: str | None = None,
+    ) -> list[str]:
+        """
+        Return a command that can be executed anywhere where airflow is 
installed.
+
+        This command is part of the message sent to executors by the 
orchestrator.
+        """
+        return TaskInstance._command_as_list(
+            ti=self,
+            mark_success=mark_success,
+            ignore_all_deps=ignore_all_deps,
+            ignore_task_deps=ignore_task_deps,
+            ignore_depends_on_past=ignore_depends_on_past,
+            
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+            ignore_ti_state=ignore_ti_state,
+            local=local,
+            pickle_id=pickle_id,
+            raw=raw,
+            job_id=job_id,
+            pool=pool,
+            cfg_path=cfg_path,
         )
 
     @staticmethod
@@ -2014,9 +2050,12 @@ class TaskInstance(Base, LoggingMixin):
 
         return dr
 
+    @classmethod
+    @internal_api_call
     @provide_session
-    def check_and_change_state_before_execution(
-        self,
+    def _check_and_change_state_before_execution(
+        cls,
+        task_instance: TaskInstance | TaskInstancePydantic,
         verbose: bool = True,
         ignore_all_deps: bool = False,
         ignore_depends_on_past: bool = False,
@@ -2025,6 +2064,7 @@ class TaskInstance(Base, LoggingMixin):
         ignore_ti_state: bool = False,
         mark_success: bool = False,
         test_mode: bool = False,
+        hostname: str = "",
         job_id: str | None = None,
         pool: str | None = None,
         external_executor_id: str | None = None,
@@ -2044,22 +2084,28 @@ class TaskInstance(Base, LoggingMixin):
         :param ignore_ti_state: Disregards previous task instance state
         :param mark_success: Don't run the task, mark its state as success
         :param test_mode: Doesn't record success or failure in the DB
+        :param hostname: The hostname of the worker running the task instance.
         :param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
         :param pool: specifies the pool to use to run the task instance
         :param external_executor_id: The identifier of the celery executor
         :param session: SQLAlchemy ORM Session
         :return: whether the state was changed to running or not
         """
-        task = self.task
-        self.refresh_from_task(task, pool_override=pool)
-        self.test_mode = test_mode
-        self.refresh_from_db(session=session, lock_for_update=True)
-        self.job_id = job_id
-        self.hostname = get_hostname()
-        self.pid = None
-
-        if not ignore_all_deps and not ignore_ti_state and self.state == 
TaskInstanceState.SUCCESS:
-            Stats.incr("previously_succeeded", tags=self.stats_tags)
+        if isinstance(task_instance, TaskInstance):
+            ti: TaskInstance = task_instance
+        else:  # isinstance(task_instance,TaskInstancePydantic)
+            filters = (col == getattr(task_instance, col.name) for col in 
inspect(TaskInstance).primary_key)
+            ti = session.query(TaskInstance).filter(*filters).scalar()
+        task = task_instance.task
+        ti.refresh_from_task(task, pool_override=pool)
+        ti.test_mode = test_mode
+        ti.refresh_from_db(session=session, lock_for_update=True)
+        ti.job_id = job_id
+        ti.hostname = hostname
+        ti.pid = None
+
+        if not ignore_all_deps and not ignore_ti_state and ti.state == 
TaskInstanceState.SUCCESS:
+            Stats.incr("previously_succeeded", tags=ti.stats_tags)
 
         if not mark_success:
             # Firstly find non-runnable and non-requeueable tis.
@@ -2073,7 +2119,7 @@ class TaskInstance(Base, LoggingMixin):
                 ignore_task_deps=ignore_task_deps,
                 description="non-requeueable deps",
             )
-            if not self.are_dependencies_met(
+            if not ti.are_dependencies_met(
                 dep_context=non_requeueable_dep_context, session=session, 
verbose=True
             ):
                 session.commit()
@@ -2085,15 +2131,13 @@ class TaskInstance(Base, LoggingMixin):
             # Set the task start date. In case it was re-scheduled use the 
initial
             # start date that is recorded in task_reschedule table
             # If the task continues after being deferred (next_method is set), 
use the original start_date
-            self.start_date = self.start_date if self.next_method else 
timezone.utcnow()
-            if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
+            ti.start_date = ti.start_date if ti.next_method else 
timezone.utcnow()
+            if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
                 tr_start_date = session.scalar(
-                    TR.stmt_for_task_instance(self, descending=False)
-                    .with_only_columns(TR.start_date)
-                    .limit(1)
+                    TR.stmt_for_task_instance(ti, 
descending=False).with_only_columns(TR.start_date).limit(1)
                 )
                 if tr_start_date:
-                    self.start_date = tr_start_date
+                    ti.start_date = tr_start_date
 
             # Secondly we find non-runnable but requeueable tis. We reset its 
state.
             # This is because we might have hit concurrency limits,
@@ -2107,35 +2151,35 @@ class TaskInstance(Base, LoggingMixin):
                 ignore_ti_state=ignore_ti_state,
                 description="requeueable deps",
             )
-            if not self.are_dependencies_met(dep_context=dep_context, 
session=session, verbose=True):
-                self.state = None
-                self.log.warning(
+            if not ti.are_dependencies_met(dep_context=dep_context, 
session=session, verbose=True):
+                ti.state = None
+                cls.logger().warning(
                     "Rescheduling due to concurrency limits reached "
                     "at task runtime. Attempt %s of "
                     "%s. State set to NONE.",
-                    self.try_number,
-                    self.max_tries + 1,
+                    ti.try_number,
+                    ti.max_tries + 1,
                 )
-                self.queued_dttm = timezone.utcnow()
-                session.merge(self)
+                ti.queued_dttm = timezone.utcnow()
+                session.merge(ti)
                 session.commit()
                 return False
 
-        if self.next_kwargs is not None:
-            self.log.info("Resuming after deferral")
+        if ti.next_kwargs is not None:
+            cls.logger().info("Resuming after deferral")
         else:
-            self.log.info("Starting attempt %s of %s", self.try_number, 
self.max_tries + 1)
-        self._try_number += 1
+            cls.logger().info("Starting attempt %s of %s", ti.try_number, 
ti.max_tries + 1)
+        ti._try_number += 1
 
         if not test_mode:
-            session.add(Log(TaskInstanceState.RUNNING.value, self))
+            session.add(Log(TaskInstanceState.RUNNING.value, ti))
 
-        self.state = TaskInstanceState.RUNNING
-        self.emit_state_change_metric(TaskInstanceState.RUNNING)
-        self.external_executor_id = external_executor_id
-        self.end_date = None
+        ti.state = TaskInstanceState.RUNNING
+        ti.emit_state_change_metric(TaskInstanceState.RUNNING)
+        ti.external_executor_id = external_executor_id
+        ti.end_date = None
         if not test_mode:
-            session.merge(self).task = task
+            session.merge(ti).task = task
         session.commit()
 
         # Closing all pooled connections to prevent
@@ -2143,11 +2187,44 @@ class TaskInstance(Base, LoggingMixin):
         settings.engine.dispose()  # type: ignore
         if verbose:
             if mark_success:
-                self.log.info("Marking success for %s on %s", self.task, 
self.execution_date)
+                cls.logger().info("Marking success for %s on %s", ti.task, 
ti.execution_date)
             else:
-                self.log.info("Executing %s on %s", self.task, 
self.execution_date)
+                cls.logger().info("Executing %s on %s", ti.task, 
ti.execution_date)
         return True
 
+    @provide_session
+    def check_and_change_state_before_execution(
+        self,
+        verbose: bool = True,
+        ignore_all_deps: bool = False,
+        ignore_depends_on_past: bool = False,
+        wait_for_past_depends_before_skipping: bool = False,
+        ignore_task_deps: bool = False,
+        ignore_ti_state: bool = False,
+        mark_success: bool = False,
+        test_mode: bool = False,
+        job_id: str | None = None,
+        pool: str | None = None,
+        external_executor_id: str | None = None,
+        session: Session = NEW_SESSION,
+    ) -> bool:
+        return TaskInstance._check_and_change_state_before_execution(
+            task_instance=self,
+            verbose=verbose,
+            ignore_all_deps=ignore_all_deps,
+            ignore_depends_on_past=ignore_depends_on_past,
+            
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+            ignore_task_deps=ignore_task_deps,
+            ignore_ti_state=ignore_ti_state,
+            mark_success=mark_success,
+            test_mode=test_mode,
+            hostname=get_hostname(),
+            job_id=job_id,
+            pool=pool,
+            external_executor_id=external_executor_id,
+            session=session,
+        )
+
     def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
         """
         Send a time metric representing how much time a given state transition 
took.
@@ -3137,14 +3214,16 @@ class TaskInstance(Base, LoggingMixin):
             return filters[0]
         return or_(*filters)
 
+    @classmethod
+    @internal_api_call
     @Sentry.enrich_errors
     @provide_session
-    def schedule_downstream_tasks(self, session: Session = NEW_SESSION, 
max_tis_per_query: int | None = None):
-        """
-        Schedule downstream tasks of this task instance.
-
-        :meta: private
-        """
+    def _schedule_downstream_tasks(
+        cls,
+        ti: TaskInstance | TaskInstancePydantic,
+        session: Session = NEW_SESSION,
+        max_tis_per_query: int | None = None,
+    ):
         from sqlalchemy.exc import OperationalError
 
         from airflow.models.dagrun import DagRun
@@ -3153,13 +3232,13 @@ class TaskInstance(Base, LoggingMixin):
             # Re-select the row with a lock
             dag_run = with_row_locks(
                 session.query(DagRun).filter_by(
-                    dag_id=self.dag_id,
-                    run_id=self.run_id,
+                    dag_id=ti.dag_id,
+                    run_id=ti.run_id,
                 ),
                 session=session,
             ).one()
 
-            task = self.task
+            task = ti.task
             if TYPE_CHECKING:
                 assert task.dag
 
@@ -3196,19 +3275,30 @@ class TaskInstance(Base, LoggingMixin):
                     schedulable_ti.task = 
task.dag.get_task(schedulable_ti.task_id)
 
             num = dag_run.schedule_tis(schedulable_tis, session=session, 
max_tis_per_query=max_tis_per_query)
-            self.log.info("%d downstream tasks scheduled from follow-on 
schedule check", num)
+            cls.logger().info("%d downstream tasks scheduled from follow-on 
schedule check", num)
 
             session.flush()
 
         except OperationalError as e:
             # Any kind of DB error here is _non fatal_ as this block is just 
an optimisation.
-            self.log.info(
+            cls.logger().info(
                 "Skipping mini scheduling run due to exception: %s",
                 e.statement,
                 exc_info=True,
             )
             session.rollback()
 
+    @provide_session
+    def schedule_downstream_tasks(self, session: Session = NEW_SESSION, 
max_tis_per_query: int | None = None):
+        """
+        Schedule downstream tasks of this task instance.
+
+        :meta: private
+        """
+        return TaskInstance._schedule_downstream_tasks(
+            ti=self, session=session, max_tis_per_query=max_tis_per_query
+        )
+
     def get_relevant_upstream_map_indexes(
         self,
         upstream: Operator,
diff --git a/airflow/serialization/pydantic/dag.py 
b/airflow/serialization/pydantic/dag.py
index 48ce1f0c56..6631afdf73 100644
--- a/airflow/serialization/pydantic/dag.py
+++ b/airflow/serialization/pydantic/dag.py
@@ -16,12 +16,68 @@
 # under the License.
 from __future__ import annotations
 
-from datetime import datetime
-from typing import List, Optional
+import pathlib
+from datetime import datetime, timedelta
+from typing import Any, List, Optional
 
-from pydantic import BaseModel as BaseModelPydantic
+from dateutil import relativedelta
+from pydantic import BaseModel as BaseModelPydantic, PlainSerializer, 
PlainValidator, ValidationInfo
+from typing_extensions import Annotated
 
+from airflow import DAG, settings
 from airflow.configuration import conf as airflow_conf
+from airflow.utils.sqlalchemy import Interval
+
+
+def serialize_interval(value: Interval) -> Interval:
+    interval = Interval()
+    return interval.process_bind_param(value, None)
+
+
+def validate_interval(value: Interval | Any, _info: ValidationInfo) -> Any:
+    if (
+        isinstance(value, Interval)
+        or isinstance(value, timedelta)
+        or isinstance(value, relativedelta.relativedelta)
+    ):
+        return value
+    interval = Interval()
+    try:
+        return interval.process_result_value(value, None)
+    except ValueError as e:
+        # Interval may be provided in string format (cron),
+        # so it must be returned as valid value.
+        if isinstance(value, str):
+            return value
+        raise e
+
+
+PydanticInterval = Annotated[
+    Interval,
+    PlainValidator(validate_interval),
+    PlainSerializer(serialize_interval, return_type=Interval),
+]
+
+
+def serialize_operator(x: DAG) -> dict:
+    from airflow.serialization.serialized_objects import SerializedDAG
+
+    return SerializedDAG.serialize_dag(x)
+
+
+def validate_operator(x: DAG | dict[str, Any], _info: ValidationInfo) -> Any:
+    from airflow.serialization.serialized_objects import SerializedDAG
+
+    if isinstance(x, DAG):
+        return x
+    return SerializedDAG.deserialize_dag(x)
+
+
+PydanticDag = Annotated[
+    DAG,
+    PlainValidator(validate_operator),
+    PlainSerializer(serialize_operator, return_type=dict),
+]
 
 
 class DagOwnerAttributesPydantic(BaseModelPydantic):
@@ -71,10 +127,11 @@ class DagModelPydantic(BaseModelPydantic):
     owners: Optional[str]
     description: Optional[str]
     default_view: Optional[str]
-    schedule_interval: Optional[str]
+    schedule_interval: Optional[PydanticInterval]
     timetable_description: Optional[str]
     tags: List[DagTagPydantic]  # noqa
     dag_owner_links: List[DagOwnerAttributesPydantic]  # noqa
+    parent_dag: Optional[PydanticDag]
 
     max_active_tasks: int
     max_active_runs: Optional[int]
@@ -82,9 +139,25 @@ class DagModelPydantic(BaseModelPydantic):
     has_task_concurrency_limits: bool
     has_import_errors: Optional[bool] = False
 
+    _processor_dags_folder: Optional[str] = None
+
     class Config:
         """Make sure it deals automatically with SQLAlchemy ORM classes."""
 
         from_attributes = True
         orm_mode = True  # Pydantic 1.x compatibility.
         arbitrary_types_allowed = True
+
+    @property
+    def relative_fileloc(self) -> pathlib.Path:
+        """File location of the importable dag 'file' relative to the 
configured DAGs folder."""
+        path = pathlib.Path(self.fileloc)
+        try:
+            rel_path = path.relative_to(self._processor_dags_folder or 
settings.DAGS_FOLDER)
+            if rel_path == pathlib.Path("."):
+                return path
+            else:
+                return rel_path
+        except ValueError:
+            # Not relative to DAGS_FOLDER.
+            return path
diff --git a/airflow/serialization/pydantic/dag_run.py 
b/airflow/serialization/pydantic/dag_run.py
index d7f7aae27d..aaa4372a50 100644
--- a/airflow/serialization/pydantic/dag_run.py
+++ b/airflow/serialization/pydantic/dag_run.py
@@ -17,12 +17,11 @@
 from __future__ import annotations
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional
 
-from pydantic import BaseModel as BaseModelPydantic, PlainSerializer, 
PlainValidator, ValidationInfo
-from typing_extensions import Annotated
+from pydantic import BaseModel as BaseModelPydantic
 
-from airflow import DAG
+from airflow.serialization.pydantic.dag import PydanticDag
 from airflow.serialization.pydantic.dataset import DatasetEventPydantic
 from airflow.utils.session import NEW_SESSION, provide_session
 
@@ -34,27 +33,6 @@ if TYPE_CHECKING:
     from airflow.utils.state import TaskInstanceState
 
 
-def serialize_operator(x: DAG) -> dict:
-    from airflow.serialization.serialized_objects import SerializedDAG
-
-    return SerializedDAG.serialize_dag(x)
-
-
-def validated_operator(x: DAG | dict[str, Any], _info: ValidationInfo) -> Any:
-    from airflow.serialization.serialized_objects import SerializedDAG
-
-    if isinstance(x, DAG):
-        return x
-    return SerializedDAG.deserialize_dag(x)
-
-
-PydanticDag = Annotated[
-    DAG,
-    PlainValidator(validated_operator),
-    PlainSerializer(serialize_operator, return_type=dict),
-]
-
-
 class DagRunPydantic(BaseModelPydantic):
     """Serializable representation of the DagRun ORM SqlAlchemyModel used by 
internal API."""
 
diff --git a/airflow/serialization/pydantic/taskinstance.py 
b/airflow/serialization/pydantic/taskinstance.py
index 7cd4868781..0043bfaef0 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -24,8 +24,11 @@ from typing_extensions import Annotated
 
 from airflow.models import Operator
 from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskinstance import TaskInstance
+from airflow.serialization.pydantic.dag import DagModelPydantic
 from airflow.serialization.pydantic.dag_run import DagRunPydantic
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.net import get_hostname
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.xcom import XCOM_RETURN_KEY
 
@@ -35,7 +38,6 @@ if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
     from airflow.models.dagrun import DagRun
-    from airflow.models.taskinstance import TaskInstance
     from airflow.utils.context import Context
     from airflow.utils.state import DagRunState
 
@@ -101,6 +103,7 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
     task: PydanticOperator
     test_mode: bool
     dag_run: Optional[DagRunPydantic]
+    dag_model: Optional[DagModelPydantic]
 
     class Config:
         """Make sure it deals automatically with SQLAlchemy ORM classes."""
@@ -351,5 +354,85 @@ class TaskInstancePydantic(BaseModelPydantic, 
LoggingMixin):
 
         return _get_previous_ti(task_instance=self, state=state, 
session=session)
 
+    @provide_session
+    def check_and_change_state_before_execution(
+        self,
+        verbose: bool = True,
+        ignore_all_deps: bool = False,
+        ignore_depends_on_past: bool = False,
+        wait_for_past_depends_before_skipping: bool = False,
+        ignore_task_deps: bool = False,
+        ignore_ti_state: bool = False,
+        mark_success: bool = False,
+        test_mode: bool = False,
+        job_id: str | None = None,
+        pool: str | None = None,
+        external_executor_id: str | None = None,
+        session: Session = NEW_SESSION,
+    ) -> bool:
+        return TaskInstance._check_and_change_state_before_execution(
+            task_instance=self,
+            verbose=verbose,
+            ignore_all_deps=ignore_all_deps,
+            ignore_depends_on_past=ignore_depends_on_past,
+            
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+            ignore_task_deps=ignore_task_deps,
+            ignore_ti_state=ignore_ti_state,
+            mark_success=mark_success,
+            test_mode=test_mode,
+            hostname=get_hostname(),
+            job_id=job_id,
+            pool=pool,
+            external_executor_id=external_executor_id,
+            session=session,
+        )
+
+    @provide_session
+    def schedule_downstream_tasks(self, session: Session = NEW_SESSION, 
max_tis_per_query: int | None = None):
+        """
+        Schedule downstream tasks of this task instance.
+
+        :meta: private
+        """
+        return TaskInstance._schedule_downstream_tasks(
+            ti=self, session=session, max_tis_per_query=max_tis_per_query
+        )
+
+    def command_as_list(
+        self,
+        mark_success: bool = False,
+        ignore_all_deps: bool = False,
+        ignore_task_deps: bool = False,
+        ignore_depends_on_past: bool = False,
+        wait_for_past_depends_before_skipping: bool = False,
+        ignore_ti_state: bool = False,
+        local: bool = False,
+        pickle_id: int | None = None,
+        raw: bool = False,
+        job_id: str | None = None,
+        pool: str | None = None,
+        cfg_path: str | None = None,
+    ) -> list[str]:
+        """
+        Return a command that can be executed anywhere where airflow is 
installed.
+
+        This command is part of the message sent to executors by the 
orchestrator.
+        """
+        return TaskInstance._command_as_list(
+            ti=self,
+            mark_success=mark_success,
+            ignore_all_deps=ignore_all_deps,
+            ignore_task_deps=ignore_task_deps,
+            ignore_depends_on_past=ignore_depends_on_past,
+            
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+            ignore_ti_state=ignore_ti_state,
+            local=local,
+            pickle_id=pickle_id,
+            raw=raw,
+            job_id=job_id,
+            pool=pool,
+            cfg_path=cfg_path,
+        )
+
 
 TaskInstancePydantic.model_rebuild()
diff --git a/tests/serialization/test_pydantic_models.py 
b/tests/serialization/test_pydantic_models.py
index b64d0b5aa8..326e4f239d 100644
--- a/tests/serialization/test_pydantic_models.py
+++ b/tests/serialization/test_pydantic_models.py
@@ -17,7 +17,10 @@
 # under the License.
 from __future__ import annotations
 
+import datetime
+
 import pytest
+from dateutil import relativedelta
 
 from airflow.jobs.job import Job
 from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
@@ -80,23 +83,31 @@ def test_serializing_pydantic_dagrun(session, 
create_task_instance):
     assert deserialized_model.state == State.RUNNING
 
 
-def test_serializing_pydantic_dagmodel():
[email protected](
+    "schedule_interval",
+    [
+        None,
+        "*/10 * * *",
+        datetime.timedelta(days=1),
+        relativedelta.relativedelta(days=+12),
+    ],
+)
+def test_serializing_pydantic_dagmodel(schedule_interval):
     dag_model = DagModel(
         dag_id="test-dag",
         fileloc="/tmp/dag_1.py",
-        schedule_interval="2 2 * * *",
+        schedule_interval=schedule_interval,
         is_active=True,
         is_paused=False,
     )
 
     pydantic_dag_model = DagModelPydantic.model_validate(dag_model)
     json_string = pydantic_dag_model.model_dump_json()
-    print(json_string)
 
     deserialized_model = DagModelPydantic.model_validate_json(json_string)
     assert deserialized_model.dag_id == "test-dag"
     assert deserialized_model.fileloc == "/tmp/dag_1.py"
-    assert deserialized_model.schedule_interval == "2 2 * * *"
+    assert deserialized_model.schedule_interval == schedule_interval
     assert deserialized_model.is_active is True
     assert deserialized_model.is_paused is False
 
@@ -111,7 +122,6 @@ def test_serializing_pydantic_local_task_job(session, 
create_task_instance):
     pydantic_job = JobPydantic.model_validate(ltj)
 
     json_string = pydantic_job.model_dump_json()
-    print(json_string)
 
     deserialized_model = JobPydantic.model_validate_json(json_string)
     assert deserialized_model.dag_id == dag_id

Reply via email to