This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 583fa2da38 Remove select_column option in
TaskInstance.get_task_instance (#38571)
583fa2da38 is described below
commit 583fa2da387ef08ce3ff999dea9e6e61524b0cb7
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Apr 2 09:16:58 2024 -0700
Remove select_column option in TaskInstance.get_task_instance (#38571)
Fundamentally what's going on here is we need a TaskInstance object instead
of a Row object when sending over the wire in an RPC call. But the full story on
this one is actually somewhat complicated.
It was back in 2.2.0 in #25312 when we converted to query with the column
attrs instead of the TI object (#28900 only refactored this logic into a
function). The reason was to avoid locking the dag_run table since TI newly
had a dag_run relationship attr. Now, this causes a problem with AIP-44
because the RPC api does not know how to serialize a Row object.
This PR switches back to querying a TaskInstance object, but avoids locking
dag_run by using the lazyload option. Meanwhile, since try_number is a horrible
attribute (which gives you a different answer depending on the state), we have
to switch it back to look at the underlying private attr instead of the public
accessor.
---
airflow/models/taskinstance.py | 24 +++++++++++-------------
tests/models/test_taskinstance.py | 13 +++++++++++++
2 files changed, 24 insertions(+), 13 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 2107781041..14fc0fc8f7 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -61,7 +61,7 @@ from sqlalchemy import (
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import reconstructor, relationship
+from sqlalchemy.orm import lazyload, reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.sql.expression import case, select
@@ -523,7 +523,6 @@ def _refresh_from_db(
task_id=task_instance.task_id,
run_id=task_instance.run_id,
map_index=task_instance.map_index,
- select_columns=True,
lock_for_update=lock_for_update,
session=session,
)
@@ -534,8 +533,7 @@ def _refresh_from_db(
task_instance.end_date = ti.end_date
task_instance.duration = ti.duration
task_instance.state = ti.state
- # Since we selected columns, not the object, this is the raw value
- task_instance.try_number = ti.try_number
+ task_instance.try_number = ti._try_number # private attr to get value
unaltered by accessor
task_instance.max_tries = ti.max_tries
task_instance.hostname = ti.hostname
task_instance.unixname = ti.unixname
@@ -914,7 +912,7 @@ def _get_try_number(*, task_instance: TaskInstance |
TaskInstancePydantic):
:meta private:
"""
- if task_instance.state == TaskInstanceState.RUNNING.RUNNING:
+ if task_instance.state == TaskInstanceState.RUNNING:
return task_instance._try_number
return task_instance._try_number + 1
@@ -1798,18 +1796,18 @@ class TaskInstance(Base, LoggingMixin):
run_id: str,
task_id: str,
map_index: int,
- select_columns: bool = False,
lock_for_update: bool = False,
session: Session = NEW_SESSION,
) -> TaskInstance | TaskInstancePydantic | None:
query = (
- session.query(*TaskInstance.__table__.columns) if select_columns
else session.query(TaskInstance)
- )
- query = query.filter_by(
- dag_id=dag_id,
- run_id=run_id,
- task_id=task_id,
- map_index=map_index,
+ session.query(TaskInstance)
+ .options(lazyload("dag_run")) # lazy load dag run to avoid
locking it
+ .filter_by(
+ dag_id=dag_id,
+ run_id=run_id,
+ task_id=task_id,
+ map_index=map_index,
+ )
)
if lock_for_update:
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 8dacc839cb..46654d564d 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -4562,3 +4562,16 @@ def test_taskinstance_with_note(create_task_instance,
session):
assert
session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None
assert
session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None
+
+
+def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
+ with dag_maker():
+ BashOperator(task_id="hello", bash_command="hi")
+ dag_maker.create_dagrun(state="success")
+ ti = session.scalar(select(TaskInstance))
+ assert ti.task_id == "hello" # just to confirm...
+ assert ti.try_number == 1 # starts out as 1
+ ti.refresh_from_db()
+ assert ti.try_number == 1 # stays 1
+ ti.refresh_from_db()
+ assert ti.try_number == 1 # stays 1