This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5ea5f05b3f Implement `set_state` on `TaskInstancePydantic` (#38297)
5ea5f05b3f is described below
commit 5ea5f05b3ffd12ccf58a7ef5d0f31fca4c2b3fec
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Mar 20 09:24:06 2024 -0700
Implement `set_state` on `TaskInstancePydantic` (#38297)
---
airflow/api_internal/endpoints/rpc_api_endpoint.py | 1 +
airflow/models/taskinstance.py | 41 +++++++++++++++-------
airflow/serialization/pydantic/taskinstance.py | 3 ++
3 files changed, 32 insertions(+), 13 deletions(-)
diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 7a8f2edfdb..656fc70002 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -81,6 +81,7 @@ def _initialize_map() -> dict[str, Callable]:
TaskInstance._check_and_change_state_before_execution,
TaskInstance.get_task_instance,
TaskInstance._get_dagrun,
+ TaskInstance._set_state,
TaskInstance.fetch_handle_failure_context,
TaskInstance.save_to_db,
TaskInstance._schedule_downstream_tasks,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b16fa633a6..ad956e68b4 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -62,7 +62,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
-from sqlalchemy.sql.expression import case
+from sqlalchemy.sql.expression import case, select
from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
@@ -1845,6 +1845,32 @@ class TaskInstance(Base, LoggingMixin):
"""Returns a tuple that identifies the task instance uniquely."""
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
+ @staticmethod
+ @internal_api_call
+ def _set_state(ti: TaskInstance | TaskInstancePydantic, state, session: Session) -> bool:
+ if not isinstance(ti, TaskInstance):
+ ti = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
+ )
+ ).one()
+
+ if ti.state == state:
+ return False
+
+ current_time = timezone.utcnow()
+ ti.log.debug("Setting task state for %s to %s", ti, state)
+ ti.state = state
+ ti.start_date = ti.start_date or current_time
+ if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY:
+ ti.end_date = ti.end_date or current_time
+ ti.duration = (ti.end_date - ti.start_date).total_seconds()
+ session.merge(ti)
+ return True
+
@provide_session
def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
"""
@@ -1854,18 +1880,7 @@ class TaskInstance(Base, LoggingMixin):
:param session: SQLAlchemy ORM Session
:return: Was the state changed
"""
- if self.state == state:
- return False
-
- current_time = timezone.utcnow()
- self.log.debug("Setting task state for %s to %s", self, state)
- self.state = state
- self.start_date = self.start_date or current_time
- if self.state in State.finished or self.state == TaskInstanceState.UP_FOR_RETRY:
- self.end_date = self.end_date or current_time
- self.duration = (self.end_date - self.start_date).total_seconds()
- session.merge(self)
- return True
+ return self._set_state(ti=self, state=state, session=session)
@property
def is_premature(self) -> bool:
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index 44db550e5b..7c60c5afc5 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -123,6 +123,9 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
def clear_xcom_data(self, session: Session | None = None):
TaskInstance._clear_xcom_data(ti=self, session=session)
+ def set_state(self, state, session: Session | None = None) -> bool:
+ return TaskInstance._set_state(ti=self, state=state, session=session)
+
def init_run_context(self, raw: bool = False) -> None:
"""Set the log context."""
self.raw = raw