This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ea5f05b3f Implement `set_state` on `TaskInstancePydantic` (#38297)
5ea5f05b3f is described below

commit 5ea5f05b3ffd12ccf58a7ef5d0f31fca4c2b3fec
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Mar 20 09:24:06 2024 -0700

    Implement `set_state` on `TaskInstancePydantic` (#38297)
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py |  1 +
 airflow/models/taskinstance.py                     | 41 +++++++++++++++-------
 airflow/serialization/pydantic/taskinstance.py     |  3 ++
 3 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index 7a8f2edfdb..656fc70002 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -81,6 +81,7 @@ def _initialize_map() -> dict[str, Callable]:
         TaskInstance._check_and_change_state_before_execution,
         TaskInstance.get_task_instance,
         TaskInstance._get_dagrun,
+        TaskInstance._set_state,
         TaskInstance.fetch_handle_failure_context,
         TaskInstance.save_to_db,
         TaskInstance._schedule_downstream_tasks,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b16fa633a6..ad956e68b4 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -62,7 +62,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.orm import reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
-from sqlalchemy.sql.expression import case
+from sqlalchemy.sql.expression import case, select
 
 from airflow import settings
 from airflow.api_internal.internal_api_call import internal_api_call
@@ -1845,6 +1845,32 @@ class TaskInstance(Base, LoggingMixin):
         """Returns a tuple that identifies the task instance uniquely."""
         return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
 
+    @staticmethod
+    @internal_api_call
+    def _set_state(ti: TaskInstance | TaskInstancePydantic, state, session: Session) -> bool:
+        if not isinstance(ti, TaskInstance):
+            ti = session.scalars(
+                select(TaskInstance).where(
+                    TaskInstance.task_id == ti.task_id,
+                    TaskInstance.dag_id == ti.dag_id,
+                    TaskInstance.run_id == ti.run_id,
+                    TaskInstance.map_index == ti.map_index,
+                )
+            ).one()
+
+        if ti.state == state:
+            return False
+
+        current_time = timezone.utcnow()
+        ti.log.debug("Setting task state for %s to %s", ti, state)
+        ti.state = state
+        ti.start_date = ti.start_date or current_time
+        if ti.state in State.finished or ti.state == TaskInstanceState.UP_FOR_RETRY:
+            ti.end_date = ti.end_date or current_time
+            ti.duration = (ti.end_date - ti.start_date).total_seconds()
+        session.merge(ti)
+        return True
+
     @provide_session
     def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
         """
@@ -1854,18 +1880,7 @@ class TaskInstance(Base, LoggingMixin):
         :param session: SQLAlchemy ORM Session
         :return: Was the state changed
         """
-        if self.state == state:
-            return False
-
-        current_time = timezone.utcnow()
-        self.log.debug("Setting task state for %s to %s", self, state)
-        self.state = state
-        self.start_date = self.start_date or current_time
-        if self.state in State.finished or self.state == TaskInstanceState.UP_FOR_RETRY:
-            self.end_date = self.end_date or current_time
-            self.duration = (self.end_date - self.start_date).total_seconds()
-        session.merge(self)
-        return True
+        return self._set_state(ti=self, state=state, session=session)
 
     @property
     def is_premature(self) -> bool:
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index 44db550e5b..7c60c5afc5 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -123,6 +123,9 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
     def clear_xcom_data(self, session: Session | None = None):
         TaskInstance._clear_xcom_data(ti=self, session=session)
 
+    def set_state(self, state, session: Session | None = None) -> bool:
+        return TaskInstance._set_state(ti=self, state=state, session=session)
+
     def init_run_context(self, raw: bool = False) -> None:
         """Set the log context."""
         self.raw = raw

Reply via email to