This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d39367bcbf Fix try_number handling when db isolation enabled (#38943)
d39367bcbf is described below

commit d39367bcbfb1117236f23caa12c28d19daa970c9
Author: Daniel Standish <[email protected]>
AuthorDate: Fri Apr 12 09:17:19 2024 -0700

    Fix try_number handling when db isolation enabled (#38943)
    
    There was an error in the refresh_from_db code, and because of
    try_number inconsistency, the same run was going into two different
    log files. There is some ugliness here, but some ugliness is
    unavoidable when dealing with try_number as it is right now.
---
 airflow/models/dagrun.py                       |  1 -
 airflow/models/taskinstance.py                 | 27 ++++++++++++++++++++++++--
 airflow/serialization/pydantic/taskinstance.py |  1 -
 airflow/utils/log/file_task_handler.py         | 13 ++++++++-----
 tests/models/test_taskinstance.py              | 15 +++++++++++++-
 5 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 20d92cfa95..fb7a2ae6cd 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -650,7 +650,6 @@ class DagRun(Base, LoggingMixin):
         )
 
     @staticmethod
-    @internal_api_call
     @provide_session
     def fetch_task_instance(
         dag_id: str,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b07aed936d..43b388ef68 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -539,7 +539,7 @@ def _refresh_from_db(
         task_instance.end_date = ti.end_date
         task_instance.duration = ti.duration
         task_instance.state = ti.state
-        task_instance.try_number = ti._try_number  # private attr to get value unaltered by accessor
+        task_instance.try_number = _get_private_try_number(task_instance=ti)
         task_instance.max_tries = ti.max_tries
         task_instance.hostname = ti.hostname
         task_instance.unixname = ti.unixname
@@ -925,7 +925,7 @@ def _handle_failure(
         TaskInstance.save_to_db(failure_context["ti"], session)
 
 
-def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
+def _get_try_number(*, task_instance: TaskInstance):
     """
     Return the try number that a task number will be when it is actually run.
 
@@ -943,6 +943,23 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
     return task_instance._try_number + 1
 
 
+def _get_private_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
+    """
+    Opposite of _get_try_number.
+
+    Given the value returned by try_number, return the value of _try_number that
+    should produce the same result.
+    This is needed for setting _try_number on TaskInstance from the value on PydanticTaskInstance, which has no private attrs.
+
+    :param task_instance: the task instance
+
+    :meta private:
+    """
+    if task_instance.state == TaskInstanceState.RUNNING:
+        return task_instance.try_number
+    return task_instance.try_number - 1
+
+
 def _set_try_number(*, task_instance: TaskInstance | TaskInstancePydantic, value: int) -> None:
     """
     Set a task try number.
@@ -3000,6 +3017,12 @@ class TaskInstance(Base, LoggingMixin):
                 _stop_remaining_tasks(task_instance=ti, session=session)
         else:
             if ti.state == TaskInstanceState.QUEUED:
+                from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
+
+                if isinstance(ti, TaskInstancePydantic):
+                    # todo: (AIP-44) we should probably "coalesce" `ti` to TaskInstance before here
+                    #  e.g. we could make refresh_from_db return a TI and replace ti with that
+                    raise RuntimeError("Expected TaskInstance here. Further AIP-44 work required.")
                 # We increase the try_number to fail the task if it fails to start after sometime
                 ti._try_number += 1
             ti.state = State.UP_FOR_RETRY
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index cf27d755b5..cc52aa9989 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -85,7 +85,6 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
     duration: Optional[float]
     state: Optional[str]
     try_number: int
-    _try_number: int
     max_tries: int
     hostname: str
     unixname: str
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 95d6849b20..72b8deedbb 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -145,7 +145,8 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, sessio
 
     Will raise exception if no TI is found in the database.
     """
-    from airflow.models.taskinstance import TaskInstance
+    from airflow.models.taskinstance import TaskInstance, _get_private_try_number
+    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 
     if isinstance(ti, TaskInstance):
         return ti
@@ -159,11 +160,13 @@ def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, sessio
         )
         .one_or_none()
     )
-    if isinstance(val, TaskInstance):
-        val._try_number = ti.try_number
-        return val
-    else:
+    if not val:
         raise AirflowException(f"Could not find TaskInstance for {ti}")
+    if isinstance(ti, TaskInstancePydantic):
+        val.try_number = _get_private_try_number(task_instance=ti)
+    else:  # TaskInstanceKey
+        val.try_number = ti.try_number
+    return val
 
 
 class FileTaskHandler(logging.Handler):
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 4e42882dd7..fe01569e90 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -29,7 +29,7 @@ import urllib
 from traceback import format_exception
 from typing import cast
 from unittest import mock
-from unittest.mock import call, mock_open, patch
+from unittest.mock import MagicMock, call, mock_open, patch
 from uuid import uuid4
 
 import pendulum
@@ -66,6 +66,8 @@ from airflow.models.taskinstance import (
     TaskInstance,
     TaskInstance as TI,
     TaskInstanceNote,
+    _get_private_try_number,
+    _get_try_number,
     _run_finished_callback,
 )
 from airflow.models.taskmap import TaskMap
@@ -4636,3 +4638,14 @@ def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
     assert ti.try_number == 1  # stays 1
     ti.refresh_from_db()
     assert ti.try_number == 1  # stays 1
+
+
+@pytest.mark.parametrize("state", list(TaskInstanceState))
+def test_get_private_try_number(state: str):
+    mock_ti = MagicMock()
+    mock_ti.state = state
+    private_try_number = 2
+    mock_ti._try_number = private_try_number
+    mock_ti.try_number = _get_try_number(task_instance=mock_ti)
+    delattr(mock_ti, "_try_number")
+    assert _get_private_try_number(task_instance=mock_ti) == private_try_number

Reply via email to