This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3c0a714cb5 [AIP-44] Migrate run_task_by_local_task_job to Internal API
(#35527)
3c0a714cb5 is described below
commit 3c0a714cb57894b0816bf39079e29d79ea0b1d0a
Author: mhenc <[email protected]>
AuthorDate: Wed Nov 15 19:41:33 2023 +0100
[AIP-44] Migrate run_task_by_local_task_job to Internal API (#35527)
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 2 +
airflow/cli/commands/task_command.py | 2 +-
airflow/jobs/local_task_job_runner.py | 7 +-
airflow/models/taskinstance.py | 242 ++++++++++++++-------
airflow/serialization/pydantic/dag.py | 81 ++++++-
airflow/serialization/pydantic/dag_run.py | 28 +--
airflow/serialization/pydantic/taskinstance.py | 85 +++++++-
tests/serialization/test_pydantic_models.py | 20 +-
8 files changed, 353 insertions(+), 114 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py
b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 945587b4e9..f451659cc0 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -77,9 +77,11 @@ def _initialize_map() -> dict[str, Callable]:
DagRun.get_previous_scheduled_dagrun,
DagRun.fetch_task_instance,
SerializedDagModel.get_serialized_dag,
+ TaskInstance._check_and_change_state_before_execution,
TaskInstance.get_task_instance,
TaskInstance.fetch_handle_failure_context,
TaskInstance.save_to_db,
+ TaskInstance._schedule_downstream_tasks,
Trigger.from_object,
Trigger.bulk_fetch,
Trigger.clean_unused,
diff --git a/airflow/cli/commands/task_command.py
b/airflow/cli/commands/task_command.py
index 8a83fa8e8a..5c7c47d69b 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -259,7 +259,7 @@ def _run_task_by_executor(args, dag: DAG, ti: TaskInstance)
-> None:
executor.end()
-def _run_task_by_local_task_job(args, ti: TaskInstance) -> TaskReturnCode |
None:
+def _run_task_by_local_task_job(args, ti: TaskInstance | TaskInstancePydantic)
-> TaskReturnCode | None:
"""Run LocalTaskJob, which monitors the raw task execution process."""
job_runner = LocalTaskJobRunner(
job=Job(dag_id=ti.dag_id),
diff --git a/airflow/jobs/local_task_job_runner.py
b/airflow/jobs/local_task_job_runner.py
index e068d88203..e12ecefcc0 100644
--- a/airflow/jobs/local_task_job_runner.py
+++ b/airflow/jobs/local_task_job_runner.py
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
from airflow.jobs.job import Job
from airflow.models.taskinstance import TaskInstance
+ from airflow.serialization.pydantic.taskinstance import
TaskInstancePydantic
SIGSEGV_MESSAGE = """
******************************************* Received SIGSEGV
*******************************************
@@ -81,7 +82,7 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
def __init__(
self,
job: Job,
- task_instance: TaskInstance, # TODO add TaskInstancePydantic
+ task_instance: TaskInstance | TaskInstancePydantic,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
wait_for_past_depends_before_skipping: bool = False,
@@ -284,7 +285,9 @@ class LocalTaskJobRunner(BaseJobRunner, LoggingMixin):
if ti.state == TaskInstanceState.SKIPPED:
# A DagRun timeout will cause tasks to be externally marked as
skipped.
dagrun = ti.get_dagrun(session=session)
- execution_time = (dagrun.end_date or timezone.utcnow()) -
dagrun.start_date
+ execution_time = (dagrun.end_date or timezone.utcnow()) - (
+ dagrun.start_date or timezone.utcnow()
+ )
if ti.task.dag is not None:
dagrun_timeout = ti.task.dag.dagrun_timeout
else:
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index fe09bba732..95a2f5945f 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -146,6 +146,7 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.dataset import DatasetEvent
from airflow.models.operator import Operator
+ from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.taskinstance import
TaskInstancePydantic
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Literal, TypeGuard
@@ -1276,7 +1277,7 @@ class TaskInstance(Base, LoggingMixin):
),
)
- dag_model = relationship(
+ dag_model: DagModel = relationship(
"DagModel",
primaryjoin="TaskInstance.dag_id == DagModel.dag_id",
foreign_keys=dag_id,
@@ -1451,32 +1452,31 @@ class TaskInstance(Base, LoggingMixin):
"""@property: use a more friendly display name for the operator, if
set."""
return self.custom_operator_name or self.operator
- def command_as_list(
- self,
- mark_success=False,
- ignore_all_deps=False,
- ignore_task_deps=False,
- ignore_depends_on_past=False,
- wait_for_past_depends_before_skipping=False,
- ignore_ti_state=False,
- local=False,
+ @staticmethod
+ def _command_as_list(
+ ti: TaskInstance | TaskInstancePydantic,
+ mark_success: bool = False,
+ ignore_all_deps: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ wait_for_past_depends_before_skipping: bool = False,
+ ignore_ti_state: bool = False,
+ local: bool = False,
pickle_id: int | None = None,
- raw=False,
- job_id=None,
- pool=None,
- cfg_path=None,
+ raw: bool = False,
+ job_id: str | None = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
) -> list[str]:
- """
- Return a command that can be executed anywhere where airflow is
installed.
-
- This command is part of the message sent to executors by the
orchestrator.
- """
- dag: DAG | DagModel
+ dag: DAG | DagModel | DagModelPydantic | None
# Use the dag if we have it, else fallback to the ORM dag_model, which
might not be loaded
- if hasattr(self, "task") and hasattr(self.task, "dag") and
self.task.dag is not None:
- dag = self.task.dag
+ if hasattr(ti, "task") and hasattr(ti.task, "dag") and ti.task.dag is
not None:
+ dag = ti.task.dag
else:
- dag = self.dag_model
+ dag = ti.dag_model
+
+ if dag is None:
+ raise ValueError("DagModel is empty")
should_pass_filepath = not pickle_id and dag
path: PurePath | None = None
@@ -1493,9 +1493,9 @@ class TaskInstance(Base, LoggingMixin):
path = "DAGS_FOLDER" / path
return TaskInstance.generate_command(
- self.dag_id,
- self.task_id,
- run_id=self.run_id,
+ ti.dag_id,
+ ti.task_id,
+ run_id=ti.run_id,
mark_success=mark_success,
ignore_all_deps=ignore_all_deps,
ignore_task_deps=ignore_task_deps,
@@ -1509,7 +1509,43 @@ class TaskInstance(Base, LoggingMixin):
job_id=job_id,
pool=pool,
cfg_path=cfg_path,
- map_index=self.map_index,
+ map_index=ti.map_index,
+ )
+
+ def command_as_list(
+ self,
+ mark_success: bool = False,
+ ignore_all_deps: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ wait_for_past_depends_before_skipping: bool = False,
+ ignore_ti_state: bool = False,
+ local: bool = False,
+ pickle_id: int | None = None,
+ raw: bool = False,
+ job_id: str | None = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
+ ) -> list[str]:
+ """
+ Return a command that can be executed anywhere where airflow is
installed.
+
+ This command is part of the message sent to executors by the
orchestrator.
+ """
+ return TaskInstance._command_as_list(
+ ti=self,
+ mark_success=mark_success,
+ ignore_all_deps=ignore_all_deps,
+ ignore_task_deps=ignore_task_deps,
+ ignore_depends_on_past=ignore_depends_on_past,
+
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+ ignore_ti_state=ignore_ti_state,
+ local=local,
+ pickle_id=pickle_id,
+ raw=raw,
+ job_id=job_id,
+ pool=pool,
+ cfg_path=cfg_path,
)
@staticmethod
@@ -2014,9 +2050,12 @@ class TaskInstance(Base, LoggingMixin):
return dr
+ @classmethod
+ @internal_api_call
@provide_session
- def check_and_change_state_before_execution(
- self,
+ def _check_and_change_state_before_execution(
+ cls,
+ task_instance: TaskInstance | TaskInstancePydantic,
verbose: bool = True,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
@@ -2025,6 +2064,7 @@ class TaskInstance(Base, LoggingMixin):
ignore_ti_state: bool = False,
mark_success: bool = False,
test_mode: bool = False,
+ hostname: str = "",
job_id: str | None = None,
pool: str | None = None,
external_executor_id: str | None = None,
@@ -2044,22 +2084,28 @@ class TaskInstance(Base, LoggingMixin):
:param ignore_ti_state: Disregards previous task instance state
:param mark_success: Don't run the task, mark its state as success
:param test_mode: Doesn't record success or failure in the DB
+ :param hostname: The hostname of the worker running the task instance.
:param job_id: Job (BackfillJob / LocalTaskJob / SchedulerJob) ID
:param pool: specifies the pool to use to run the task instance
:param external_executor_id: The identifier of the celery executor
:param session: SQLAlchemy ORM Session
:return: whether the state was changed to running or not
"""
- task = self.task
- self.refresh_from_task(task, pool_override=pool)
- self.test_mode = test_mode
- self.refresh_from_db(session=session, lock_for_update=True)
- self.job_id = job_id
- self.hostname = get_hostname()
- self.pid = None
-
- if not ignore_all_deps and not ignore_ti_state and self.state ==
TaskInstanceState.SUCCESS:
- Stats.incr("previously_succeeded", tags=self.stats_tags)
+ if isinstance(task_instance, TaskInstance):
+ ti: TaskInstance = task_instance
+ else: # isinstance(task_instance,TaskInstancePydantic)
+ filters = (col == getattr(task_instance, col.name) for col in
inspect(TaskInstance).primary_key)
+ ti = session.query(TaskInstance).filter(*filters).scalar()
+ task = task_instance.task
+ ti.refresh_from_task(task, pool_override=pool)
+ ti.test_mode = test_mode
+ ti.refresh_from_db(session=session, lock_for_update=True)
+ ti.job_id = job_id
+ ti.hostname = hostname
+ ti.pid = None
+
+ if not ignore_all_deps and not ignore_ti_state and ti.state ==
TaskInstanceState.SUCCESS:
+ Stats.incr("previously_succeeded", tags=ti.stats_tags)
if not mark_success:
# Firstly find non-runnable and non-requeueable tis.
@@ -2073,7 +2119,7 @@ class TaskInstance(Base, LoggingMixin):
ignore_task_deps=ignore_task_deps,
description="non-requeueable deps",
)
- if not self.are_dependencies_met(
+ if not ti.are_dependencies_met(
dep_context=non_requeueable_dep_context, session=session,
verbose=True
):
session.commit()
@@ -2085,15 +2131,13 @@ class TaskInstance(Base, LoggingMixin):
# Set the task start date. In case it was re-scheduled use the
initial
# start date that is recorded in task_reschedule table
# If the task continues after being deferred (next_method is set),
use the original start_date
- self.start_date = self.start_date if self.next_method else
timezone.utcnow()
- if self.state == TaskInstanceState.UP_FOR_RESCHEDULE:
+ ti.start_date = ti.start_date if ti.next_method else
timezone.utcnow()
+ if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
tr_start_date = session.scalar(
- TR.stmt_for_task_instance(self, descending=False)
- .with_only_columns(TR.start_date)
- .limit(1)
+ TR.stmt_for_task_instance(ti,
descending=False).with_only_columns(TR.start_date).limit(1)
)
if tr_start_date:
- self.start_date = tr_start_date
+ ti.start_date = tr_start_date
# Secondly we find non-runnable but requeueable tis. We reset its
state.
# This is because we might have hit concurrency limits,
@@ -2107,35 +2151,35 @@ class TaskInstance(Base, LoggingMixin):
ignore_ti_state=ignore_ti_state,
description="requeueable deps",
)
- if not self.are_dependencies_met(dep_context=dep_context,
session=session, verbose=True):
- self.state = None
- self.log.warning(
+ if not ti.are_dependencies_met(dep_context=dep_context,
session=session, verbose=True):
+ ti.state = None
+ cls.logger().warning(
"Rescheduling due to concurrency limits reached "
"at task runtime. Attempt %s of "
"%s. State set to NONE.",
- self.try_number,
- self.max_tries + 1,
+ ti.try_number,
+ ti.max_tries + 1,
)
- self.queued_dttm = timezone.utcnow()
- session.merge(self)
+ ti.queued_dttm = timezone.utcnow()
+ session.merge(ti)
session.commit()
return False
- if self.next_kwargs is not None:
- self.log.info("Resuming after deferral")
+ if ti.next_kwargs is not None:
+ cls.logger().info("Resuming after deferral")
else:
- self.log.info("Starting attempt %s of %s", self.try_number,
self.max_tries + 1)
- self._try_number += 1
+ cls.logger().info("Starting attempt %s of %s", ti.try_number,
ti.max_tries + 1)
+ ti._try_number += 1
if not test_mode:
- session.add(Log(TaskInstanceState.RUNNING.value, self))
+ session.add(Log(TaskInstanceState.RUNNING.value, ti))
- self.state = TaskInstanceState.RUNNING
- self.emit_state_change_metric(TaskInstanceState.RUNNING)
- self.external_executor_id = external_executor_id
- self.end_date = None
+ ti.state = TaskInstanceState.RUNNING
+ ti.emit_state_change_metric(TaskInstanceState.RUNNING)
+ ti.external_executor_id = external_executor_id
+ ti.end_date = None
if not test_mode:
- session.merge(self).task = task
+ session.merge(ti).task = task
session.commit()
# Closing all pooled connections to prevent
@@ -2143,11 +2187,44 @@ class TaskInstance(Base, LoggingMixin):
settings.engine.dispose() # type: ignore
if verbose:
if mark_success:
- self.log.info("Marking success for %s on %s", self.task,
self.execution_date)
+ cls.logger().info("Marking success for %s on %s", ti.task,
ti.execution_date)
else:
- self.log.info("Executing %s on %s", self.task,
self.execution_date)
+ cls.logger().info("Executing %s on %s", ti.task,
ti.execution_date)
return True
+ @provide_session
+ def check_and_change_state_before_execution(
+ self,
+ verbose: bool = True,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ wait_for_past_depends_before_skipping: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ job_id: str | None = None,
+ pool: str | None = None,
+ external_executor_id: str | None = None,
+ session: Session = NEW_SESSION,
+ ) -> bool:
+ return TaskInstance._check_and_change_state_before_execution(
+ task_instance=self,
+ verbose=verbose,
+ ignore_all_deps=ignore_all_deps,
+ ignore_depends_on_past=ignore_depends_on_past,
+
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+ ignore_task_deps=ignore_task_deps,
+ ignore_ti_state=ignore_ti_state,
+ mark_success=mark_success,
+ test_mode=test_mode,
+ hostname=get_hostname(),
+ job_id=job_id,
+ pool=pool,
+ external_executor_id=external_executor_id,
+ session=session,
+ )
+
def emit_state_change_metric(self, new_state: TaskInstanceState) -> None:
"""
Send a time metric representing how much time a given state transition
took.
@@ -3137,14 +3214,16 @@ class TaskInstance(Base, LoggingMixin):
return filters[0]
return or_(*filters)
+ @classmethod
+ @internal_api_call
@Sentry.enrich_errors
@provide_session
- def schedule_downstream_tasks(self, session: Session = NEW_SESSION,
max_tis_per_query: int | None = None):
- """
- Schedule downstream tasks of this task instance.
-
- :meta: private
- """
+ def _schedule_downstream_tasks(
+ cls,
+ ti: TaskInstance | TaskInstancePydantic,
+ session: Session = NEW_SESSION,
+ max_tis_per_query: int | None = None,
+ ):
from sqlalchemy.exc import OperationalError
from airflow.models.dagrun import DagRun
@@ -3153,13 +3232,13 @@ class TaskInstance(Base, LoggingMixin):
# Re-select the row with a lock
dag_run = with_row_locks(
session.query(DagRun).filter_by(
- dag_id=self.dag_id,
- run_id=self.run_id,
+ dag_id=ti.dag_id,
+ run_id=ti.run_id,
),
session=session,
).one()
- task = self.task
+ task = ti.task
if TYPE_CHECKING:
assert task.dag
@@ -3196,19 +3275,30 @@ class TaskInstance(Base, LoggingMixin):
schedulable_ti.task =
task.dag.get_task(schedulable_ti.task_id)
num = dag_run.schedule_tis(schedulable_tis, session=session,
max_tis_per_query=max_tis_per_query)
- self.log.info("%d downstream tasks scheduled from follow-on
schedule check", num)
+ cls.logger().info("%d downstream tasks scheduled from follow-on
schedule check", num)
session.flush()
except OperationalError as e:
# Any kind of DB error here is _non fatal_ as this block is just
an optimisation.
- self.log.info(
+ cls.logger().info(
"Skipping mini scheduling run due to exception: %s",
e.statement,
exc_info=True,
)
session.rollback()
+ @provide_session
+ def schedule_downstream_tasks(self, session: Session = NEW_SESSION,
max_tis_per_query: int | None = None):
+ """
+ Schedule downstream tasks of this task instance.
+
+    :meta private:
+ """
+ return TaskInstance._schedule_downstream_tasks(
+ ti=self, session=session, max_tis_per_query=max_tis_per_query
+ )
+
def get_relevant_upstream_map_indexes(
self,
upstream: Operator,
diff --git a/airflow/serialization/pydantic/dag.py
b/airflow/serialization/pydantic/dag.py
index 48ce1f0c56..6631afdf73 100644
--- a/airflow/serialization/pydantic/dag.py
+++ b/airflow/serialization/pydantic/dag.py
@@ -16,12 +16,68 @@
# under the License.
from __future__ import annotations
-from datetime import datetime
-from typing import List, Optional
+import pathlib
+from datetime import datetime, timedelta
+from typing import Any, List, Optional
-from pydantic import BaseModel as BaseModelPydantic
+from dateutil import relativedelta
+from pydantic import BaseModel as BaseModelPydantic, PlainSerializer,
PlainValidator, ValidationInfo
+from typing_extensions import Annotated
+from airflow import DAG, settings
from airflow.configuration import conf as airflow_conf
+from airflow.utils.sqlalchemy import Interval
+
+
+def serialize_interval(value: Interval) -> Interval:
+ interval = Interval()
+ return interval.process_bind_param(value, None)
+
+
+def validate_interval(value: Interval | Any, _info: ValidationInfo) -> Any:
+ if (
+ isinstance(value, Interval)
+ or isinstance(value, timedelta)
+ or isinstance(value, relativedelta.relativedelta)
+ ):
+ return value
+ interval = Interval()
+ try:
+ return interval.process_result_value(value, None)
+ except ValueError as e:
+ # Interval may be provided in string format (cron),
+ # so it must be returned as valid value.
+ if isinstance(value, str):
+ return value
+ raise e
+
+
+PydanticInterval = Annotated[
+ Interval,
+ PlainValidator(validate_interval),
+ PlainSerializer(serialize_interval, return_type=Interval),
+]
+
+
+def serialize_operator(x: DAG) -> dict:
+ from airflow.serialization.serialized_objects import SerializedDAG
+
+ return SerializedDAG.serialize_dag(x)
+
+
+def validate_operator(x: DAG | dict[str, Any], _info: ValidationInfo) -> Any:
+ from airflow.serialization.serialized_objects import SerializedDAG
+
+ if isinstance(x, DAG):
+ return x
+ return SerializedDAG.deserialize_dag(x)
+
+
+PydanticDag = Annotated[
+ DAG,
+ PlainValidator(validate_operator),
+ PlainSerializer(serialize_operator, return_type=dict),
+]
class DagOwnerAttributesPydantic(BaseModelPydantic):
@@ -71,10 +127,11 @@ class DagModelPydantic(BaseModelPydantic):
owners: Optional[str]
description: Optional[str]
default_view: Optional[str]
- schedule_interval: Optional[str]
+ schedule_interval: Optional[PydanticInterval]
timetable_description: Optional[str]
tags: List[DagTagPydantic] # noqa
dag_owner_links: List[DagOwnerAttributesPydantic] # noqa
+ parent_dag: Optional[PydanticDag]
max_active_tasks: int
max_active_runs: Optional[int]
@@ -82,9 +139,25 @@ class DagModelPydantic(BaseModelPydantic):
has_task_concurrency_limits: bool
has_import_errors: Optional[bool] = False
+ _processor_dags_folder: Optional[str] = None
+
class Config:
"""Make sure it deals automatically with SQLAlchemy ORM classes."""
from_attributes = True
orm_mode = True # Pydantic 1.x compatibility.
arbitrary_types_allowed = True
+
+ @property
+ def relative_fileloc(self) -> pathlib.Path:
+ """File location of the importable dag 'file' relative to the
configured DAGs folder."""
+ path = pathlib.Path(self.fileloc)
+ try:
+ rel_path = path.relative_to(self._processor_dags_folder or
settings.DAGS_FOLDER)
+ if rel_path == pathlib.Path("."):
+ return path
+ else:
+ return rel_path
+ except ValueError:
+ # Not relative to DAGS_FOLDER.
+ return path
diff --git a/airflow/serialization/pydantic/dag_run.py
b/airflow/serialization/pydantic/dag_run.py
index d7f7aae27d..aaa4372a50 100644
--- a/airflow/serialization/pydantic/dag_run.py
+++ b/airflow/serialization/pydantic/dag_run.py
@@ -17,12 +17,11 @@
from __future__ import annotations
from datetime import datetime
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional
-from pydantic import BaseModel as BaseModelPydantic, PlainSerializer,
PlainValidator, ValidationInfo
-from typing_extensions import Annotated
+from pydantic import BaseModel as BaseModelPydantic
-from airflow import DAG
+from airflow.serialization.pydantic.dag import PydanticDag
from airflow.serialization.pydantic.dataset import DatasetEventPydantic
from airflow.utils.session import NEW_SESSION, provide_session
@@ -34,27 +33,6 @@ if TYPE_CHECKING:
from airflow.utils.state import TaskInstanceState
-def serialize_operator(x: DAG) -> dict:
- from airflow.serialization.serialized_objects import SerializedDAG
-
- return SerializedDAG.serialize_dag(x)
-
-
-def validated_operator(x: DAG | dict[str, Any], _info: ValidationInfo) -> Any:
- from airflow.serialization.serialized_objects import SerializedDAG
-
- if isinstance(x, DAG):
- return x
- return SerializedDAG.deserialize_dag(x)
-
-
-PydanticDag = Annotated[
- DAG,
- PlainValidator(validated_operator),
- PlainSerializer(serialize_operator, return_type=dict),
-]
-
-
class DagRunPydantic(BaseModelPydantic):
"""Serializable representation of the DagRun ORM SqlAlchemyModel used by
internal API."""
diff --git a/airflow/serialization/pydantic/taskinstance.py
b/airflow/serialization/pydantic/taskinstance.py
index 7cd4868781..0043bfaef0 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -24,8 +24,11 @@ from typing_extensions import Annotated
from airflow.models import Operator
from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskinstance import TaskInstance
+from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -35,7 +38,6 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
from airflow.models.dagrun import DagRun
- from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.state import DagRunState
@@ -101,6 +103,7 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
task: PydanticOperator
test_mode: bool
dag_run: Optional[DagRunPydantic]
+ dag_model: Optional[DagModelPydantic]
class Config:
"""Make sure it deals automatically with SQLAlchemy ORM classes."""
@@ -351,5 +354,85 @@ class TaskInstancePydantic(BaseModelPydantic,
LoggingMixin):
return _get_previous_ti(task_instance=self, state=state,
session=session)
+ @provide_session
+ def check_and_change_state_before_execution(
+ self,
+ verbose: bool = True,
+ ignore_all_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ wait_for_past_depends_before_skipping: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_ti_state: bool = False,
+ mark_success: bool = False,
+ test_mode: bool = False,
+ job_id: str | None = None,
+ pool: str | None = None,
+ external_executor_id: str | None = None,
+ session: Session = NEW_SESSION,
+ ) -> bool:
+ return TaskInstance._check_and_change_state_before_execution(
+ task_instance=self,
+ verbose=verbose,
+ ignore_all_deps=ignore_all_deps,
+ ignore_depends_on_past=ignore_depends_on_past,
+
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+ ignore_task_deps=ignore_task_deps,
+ ignore_ti_state=ignore_ti_state,
+ mark_success=mark_success,
+ test_mode=test_mode,
+ hostname=get_hostname(),
+ job_id=job_id,
+ pool=pool,
+ external_executor_id=external_executor_id,
+ session=session,
+ )
+
+ @provide_session
+ def schedule_downstream_tasks(self, session: Session = NEW_SESSION,
max_tis_per_query: int | None = None):
+ """
+ Schedule downstream tasks of this task instance.
+
+    :meta private:
+ """
+ return TaskInstance._schedule_downstream_tasks(
+        ti=self, session=session, max_tis_per_query=max_tis_per_query
+ )
+
+ def command_as_list(
+ self,
+ mark_success: bool = False,
+ ignore_all_deps: bool = False,
+ ignore_task_deps: bool = False,
+ ignore_depends_on_past: bool = False,
+ wait_for_past_depends_before_skipping: bool = False,
+ ignore_ti_state: bool = False,
+ local: bool = False,
+ pickle_id: int | None = None,
+ raw: bool = False,
+ job_id: str | None = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
+ ) -> list[str]:
+ """
+ Return a command that can be executed anywhere where airflow is
installed.
+
+ This command is part of the message sent to executors by the
orchestrator.
+ """
+ return TaskInstance._command_as_list(
+ ti=self,
+ mark_success=mark_success,
+ ignore_all_deps=ignore_all_deps,
+ ignore_task_deps=ignore_task_deps,
+ ignore_depends_on_past=ignore_depends_on_past,
+
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+ ignore_ti_state=ignore_ti_state,
+ local=local,
+ pickle_id=pickle_id,
+ raw=raw,
+ job_id=job_id,
+ pool=pool,
+ cfg_path=cfg_path,
+ )
+
TaskInstancePydantic.model_rebuild()
diff --git a/tests/serialization/test_pydantic_models.py
b/tests/serialization/test_pydantic_models.py
index b64d0b5aa8..326e4f239d 100644
--- a/tests/serialization/test_pydantic_models.py
+++ b/tests/serialization/test_pydantic_models.py
@@ -17,7 +17,10 @@
# under the License.
from __future__ import annotations
+import datetime
+
import pytest
+from dateutil import relativedelta
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
@@ -80,23 +83,31 @@ def test_serializing_pydantic_dagrun(session,
create_task_instance):
assert deserialized_model.state == State.RUNNING
-def test_serializing_pydantic_dagmodel():
[email protected](
+ "schedule_interval",
+ [
+ None,
+ "*/10 * * *",
+ datetime.timedelta(days=1),
+ relativedelta.relativedelta(days=+12),
+ ],
+)
+def test_serializing_pydantic_dagmodel(schedule_interval):
dag_model = DagModel(
dag_id="test-dag",
fileloc="/tmp/dag_1.py",
- schedule_interval="2 2 * * *",
+ schedule_interval=schedule_interval,
is_active=True,
is_paused=False,
)
pydantic_dag_model = DagModelPydantic.model_validate(dag_model)
json_string = pydantic_dag_model.model_dump_json()
- print(json_string)
deserialized_model = DagModelPydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == "test-dag"
assert deserialized_model.fileloc == "/tmp/dag_1.py"
- assert deserialized_model.schedule_interval == "2 2 * * *"
+ assert deserialized_model.schedule_interval == schedule_interval
assert deserialized_model.is_active is True
assert deserialized_model.is_paused is False
@@ -111,7 +122,6 @@ def test_serializing_pydantic_local_task_job(session,
create_task_instance):
pydantic_job = JobPydantic.model_validate(ltj)
json_string = pydantic_job.model_dump_json()
- print(json_string)
deserialized_model = JobPydantic.model_validate_json(json_string)
assert deserialized_model.dag_id == dag_id