This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 583fa2da38 Remove select_column option in 
TaskInstance.get_task_instance (#38571)
583fa2da38 is described below

commit 583fa2da387ef08ce3ff999dea9e6e61524b0cb7
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 2 09:16:58 2024 -0700

    Remove select_column option in TaskInstance.get_task_instance (#38571)
    
    Fundamentally what's going on here is we need a TaskInstance object instead 
of a Row object when sending over the wire in an RPC call.  But the full story on 
this one is actually somewhat complicated.
    It was back in 2.2.0 in #25312 when we converted to query with the column 
attrs instead of the TI object (#28900 only refactored this logic into a 
function).  The reason was to avoid locking the dag_run table since TI newly 
had a dag_run relationship attr.  Now, this causes a problem with AIP-44 
because the RPC api does not know how to serialize a Row object.
    This PR switches back to querying a TaskInstance object, but avoids locking 
dag_run by using lazy_load option.  Meanwhile, since try_number is a horrible 
attribute (which gives you a different answer depending on the state), we have 
to switch it back to look at the underlying private attr instead of the public 
accessor.
---
 airflow/models/taskinstance.py    | 24 +++++++++++-------------
 tests/models/test_taskinstance.py | 13 +++++++++++++
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2107781041..14fc0fc8f7 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -61,7 +61,7 @@ from sqlalchemy import (
 )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import reconstructor, relationship
+from sqlalchemy.orm import lazyload, reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
 from sqlalchemy.sql.expression import case, select
 
@@ -523,7 +523,6 @@ def _refresh_from_db(
         task_id=task_instance.task_id,
         run_id=task_instance.run_id,
         map_index=task_instance.map_index,
-        select_columns=True,
         lock_for_update=lock_for_update,
         session=session,
     )
@@ -534,8 +533,7 @@ def _refresh_from_db(
         task_instance.end_date = ti.end_date
         task_instance.duration = ti.duration
         task_instance.state = ti.state
-        # Since we selected columns, not the object, this is the raw value
-        task_instance.try_number = ti.try_number
+        task_instance.try_number = ti._try_number  # private attr to get value 
unaltered by accessor
         task_instance.max_tries = ti.max_tries
         task_instance.hostname = ti.hostname
         task_instance.unixname = ti.unixname
@@ -914,7 +912,7 @@ def _get_try_number(*, task_instance: TaskInstance | 
TaskInstancePydantic):
 
     :meta private:
     """
-    if task_instance.state == TaskInstanceState.RUNNING.RUNNING:
+    if task_instance.state == TaskInstanceState.RUNNING:
         return task_instance._try_number
     return task_instance._try_number + 1
 
@@ -1798,18 +1796,18 @@ class TaskInstance(Base, LoggingMixin):
         run_id: str,
         task_id: str,
         map_index: int,
-        select_columns: bool = False,
         lock_for_update: bool = False,
         session: Session = NEW_SESSION,
     ) -> TaskInstance | TaskInstancePydantic | None:
         query = (
-            session.query(*TaskInstance.__table__.columns) if select_columns 
else session.query(TaskInstance)
-        )
-        query = query.filter_by(
-            dag_id=dag_id,
-            run_id=run_id,
-            task_id=task_id,
-            map_index=map_index,
+            session.query(TaskInstance)
+            .options(lazyload("dag_run"))  # lazy load dag run to avoid 
locking it
+            .filter_by(
+                dag_id=dag_id,
+                run_id=run_id,
+                task_id=task_id,
+                map_index=map_index,
+            )
         )
 
         if lock_for_update:
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index 8dacc839cb..46654d564d 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -4562,3 +4562,16 @@ def test_taskinstance_with_note(create_task_instance, 
session):
 
     assert 
session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None
     assert 
session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
+
+
+def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
+    with dag_maker():
+        BashOperator(task_id="hello", bash_command="hi")
+    dag_maker.create_dagrun(state="success")
+    ti = session.scalar(select(TaskInstance))
+    assert ti.task_id == "hello"  # just to confirm...
+    assert ti.try_number == 1  # starts out as 1
+    ti.refresh_from_db()
+    assert ti.try_number == 1  # stays 1
+    ti.refresh_from_db()
+    assert ti.try_number == 1  # stays 1

Reply via email to