This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d39367bcbf Fix try_number handling when db isolation enabled (#38943)
d39367bcbf is described below
commit d39367bcbfb1117236f23caa12c28d19daa970c9
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Apr 12 09:17:19 2024 -0700
Fix try_number handling when db isolation enabled (#38943)
There was an error in the refresh_from_db code, and because of try_number
inconsistency, the same run was going into two different log files. There is
some ugliness here, but some ugliness is unavoidable when dealing with
try_number as it is right now.
---
airflow/models/dagrun.py | 1 -
airflow/models/taskinstance.py | 27 ++++++++++++++++++++++++--
airflow/serialization/pydantic/taskinstance.py | 1 -
airflow/utils/log/file_task_handler.py | 13 ++++++++-----
tests/models/test_taskinstance.py | 15 +++++++++++++-
5 files changed, 47 insertions(+), 10 deletions(-)
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 20d92cfa95..fb7a2ae6cd 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -650,7 +650,6 @@ class DagRun(Base, LoggingMixin):
)
@staticmethod
- @internal_api_call
@provide_session
def fetch_task_instance(
dag_id: str,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b07aed936d..43b388ef68 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -539,7 +539,7 @@ def _refresh_from_db(
task_instance.end_date = ti.end_date
task_instance.duration = ti.duration
task_instance.state = ti.state
- task_instance.try_number = ti._try_number # private attr to get value unaltered by accessor
+ task_instance.try_number = _get_private_try_number(task_instance=ti)
task_instance.max_tries = ti.max_tries
task_instance.hostname = ti.hostname
task_instance.unixname = ti.unixname
@@ -925,7 +925,7 @@ def _handle_failure(
TaskInstance.save_to_db(failure_context["ti"], session)
-def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
+def _get_try_number(*, task_instance: TaskInstance):
"""
Return the try number that a task number will be when it is actually run.
@@ -943,6 +943,23 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
return task_instance._try_number + 1
+def _get_private_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
+ """
+ Opposite of _get_try_number.
+
+ Given the value returned by try_number, return the value of _try_number that
+ should produce the same result.
+ This is needed for setting _try_number on TaskInstance from the value on PydanticTaskInstance, which has no private attrs.
+
+ :param task_instance: the task instance
+
+ :meta private:
+ """
+ if task_instance.state == TaskInstanceState.RUNNING:
+ return task_instance.try_number
+ return task_instance.try_number - 1
+
+
def _set_try_number(*, task_instance: TaskInstance | TaskInstancePydantic, value: int) -> None:
"""
Set a task try number.
@@ -3000,6 +3017,12 @@ class TaskInstance(Base, LoggingMixin):
_stop_remaining_tasks(task_instance=ti, session=session)
else:
if ti.state == TaskInstanceState.QUEUED:
+ from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+
+ if isinstance(ti, TaskInstancePydantic):
+ # todo: (AIP-44) we should probably "coalesce" `ti` to TaskInstance before here
+ # e.g. we could make refresh_from_db return a TI and replace ti with that
+ raise RuntimeError("Expected TaskInstance here. Further AIP-44 work required.")
# We increase the try_number to fail the task if it fails to start after sometime
ti._try_number += 1
ti.state = State.UP_FOR_RETRY
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index cf27d755b5..cc52aa9989 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -85,7 +85,6 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
duration: Optional[float]
state: Optional[str]
try_number: int
- _try_number: int
max_tries: int
hostname: str
unixname: str
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 95d6849b20..72b8deedbb 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -145,7 +145,8 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, sessio
Will raise exception if no TI is found in the database.
"""
- from airflow.models.taskinstance import TaskInstance
+ from airflow.models.taskinstance import TaskInstance, _get_private_try_number
+ from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
if isinstance(ti, TaskInstance):
return ti
@@ -159,11 +160,13 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, sessio
)
.one_or_none()
)
- if isinstance(val, TaskInstance):
- val._try_number = ti.try_number
- return val
- else:
+ if not val:
raise AirflowException(f"Could not find TaskInstance for {ti}")
+ if isinstance(ti, TaskInstancePydantic):
+ val.try_number = _get_private_try_number(task_instance=ti)
+ else: # TaskInstanceKey
+ val.try_number = ti.try_number
+ return val
class FileTaskHandler(logging.Handler):
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 4e42882dd7..fe01569e90 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -29,7 +29,7 @@ import urllib
from traceback import format_exception
from typing import cast
from unittest import mock
-from unittest.mock import call, mock_open, patch
+from unittest.mock import MagicMock, call, mock_open, patch
from uuid import uuid4
import pendulum
@@ -66,6 +66,8 @@ from airflow.models.taskinstance import (
TaskInstance,
TaskInstance as TI,
TaskInstanceNote,
+ _get_private_try_number,
+ _get_try_number,
_run_finished_callback,
)
from airflow.models.taskmap import TaskMap
@@ -4636,3 +4638,14 @@ def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
assert ti.try_number == 1 # stays 1
ti.refresh_from_db()
assert ti.try_number == 1 # stays 1
+
+
+@pytest.mark.parametrize("state", list(TaskInstanceState))
+def test_get_private_try_number(state: str):
+ mock_ti = MagicMock()
+ mock_ti.state = state
+ private_try_number = 2
+ mock_ti._try_number = private_try_number
+ mock_ti.try_number = _get_try_number(task_instance=mock_ti)
+ delattr(mock_ti, "_try_number")
+ assert _get_private_try_number(task_instance=mock_ti) == private_try_number