This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-2-test by this push:
new 1ff7aace593 [v3-2-test] fix(scheduler): catch StaleDataError in verify_integrity to prevent scheduler crash (#64503) (#66727)
1ff7aace593 is described below
commit 1ff7aace5939eabe379dd5cb14ca4d62d0aae93a
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue May 12 05:10:31 2026 +0200
[v3-2-test] fix(scheduler): catch StaleDataError in verify_integrity to prevent scheduler crash (#64503) (#66727)
Closes #63926
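SQLAlchemy raises StaleDataError at flush time when an UPDATE or DELETE
matches an unexpected number of rows. This happens, for example, when a
concurrent session completes or removes a task instance between
verify_integrity loading it and session.flush() persisting it. Previously
this exception crashed the scheduler. Fix by catching StaleDataError
alongside IntegrityError in DagRun.verify_integrity(), where it is logged
and the session is rolled back, just like the existing IntegrityError path.
StaleDataError is also added to the retryable exceptions in
run_with_db_retries() and retry_db_transaction(), so operations wrapped by
those helpers are retried automatically.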
(cherry picked from commit dcfa2715632de7f665c3eba1b42d2e3084f08361)
Co-authored-by: Pradeep Kalluri <[email protected]>
---
airflow-core/newsfragments/64503.bugfix.rst | 1 +
airflow-core/src/airflow/models/dagrun.py | 8 ++++++--
airflow-core/src/airflow/utils/retries.py | 5 +++--
airflow-core/tests/unit/models/test_dagrun.py | 24 ++++++++++++++++++++++
airflow-core/tests/unit/utils/test_retries.py | 29 ++++++++++++++++++---------
5 files changed, 54 insertions(+), 13 deletions(-)
diff --git a/airflow-core/newsfragments/64503.bugfix.rst
b/airflow-core/newsfragments/64503.bugfix.rst
new file mode 100644
index 00000000000..0358708ea1f
--- /dev/null
+++ b/airflow-core/newsfragments/64503.bugfix.rst
@@ -0,0 +1 @@
+Fix scheduler crashing with ``StaleDataError`` when a task instance is
completed or removed by another session between ``verify_integrity`` loading
task instances and ``session.flush()`` persisting them. Now caught and rolled
back like the existing ``IntegrityError`` path.
diff --git a/airflow-core/src/airflow/models/dagrun.py
b/airflow-core/src/airflow/models/dagrun.py
index 7eabadd73cf..afe73a43b96 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -57,6 +57,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import Mapped, declared_attr, joinedload, mapped_column,
relationship, synonym, validates
+from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.sql.expression import false, select
from sqlalchemy.sql.functions import coalesce
@@ -1873,14 +1874,17 @@ class DagRun(Base, LoggingMixin):
extra_tags={"task_type": task_type},
)
session.flush()
- except IntegrityError:
+ except (IntegrityError, StaleDataError) as exc:
self.log.info(
- "Hit IntegrityError while creating the TIs for %s- %s",
+ "Hit %s while creating the TIs for %s- %s",
+ type(exc).__name__,
dag_id,
run_id,
exc_info=True,
)
self.log.info("Doing session rollback.")
+ # Catching StaleDataError and rolling back is sufficient here
because
+ # the next scheduler loop will re-read the latest state from the
DB.
# TODO[HA]: We probably need to savepoint this so we can keep the
transaction alive.
session.rollback()
diff --git a/airflow-core/src/airflow/utils/retries.py
b/airflow-core/src/airflow/utils/retries.py
index a30d6766853..69b71046acb 100644
--- a/airflow-core/src/airflow/utils/retries.py
+++ b/airflow-core/src/airflow/utils/retries.py
@@ -23,6 +23,7 @@ from inspect import signature
from typing import TYPE_CHECKING, TypeVar, overload
from sqlalchemy.exc import DBAPIError
+from sqlalchemy.orm.exc import StaleDataError
from airflow.configuration import conf
@@ -40,7 +41,7 @@ def run_with_db_retries(max_retries: int = MAX_DB_RETRIES,
logger: Logger | None
# Default kwargs
retry_kwargs = dict(
- retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError)),
+ retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError,
StaleDataError)),
wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
stop=tenacity.stop_after_attempt(max_retries),
reraise=True,
@@ -104,7 +105,7 @@ def retry_db_transaction(_func: Callable | None = None, *,
retries: int = MAX_DB
)
try:
return func(*args, **kwargs)
- except DBAPIError:
+ except (DBAPIError, StaleDataError):
session.rollback()
raise
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 93bf2dcbdf4..b259b62552e 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -39,6 +39,7 @@ from sqlalchemy import (
update,
)
from sqlalchemy.orm import joinedload
+from sqlalchemy.orm.exc import StaleDataError
from airflow import settings
from airflow._shared.observability.metrics.stats import Stats
@@ -1443,6 +1444,29 @@ def
test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session)
assert indices == [0, 1, 2, 3]
+def test_verify_integrity_handles_stale_data_error(dag_maker, session):
+ """Test that StaleDataError during _create_task_instances is caught and
session is rolled back."""
+ with dag_maker("test_stale_data_error_dag", session=session) as dag:
+ task = EmptyOperator(task_id="task1")
+
+ dr = dag_maker.create_dagrun()
+ dag_version_id = DagVersion.get_latest_version(dag.dag_id,
session=session).id
+
+ with mock.patch.object(session, "flush", side_effect=StaleDataError()):
+ with mock.patch.object(session, "rollback") as mock_rollback:
+ # Should not raise — StaleDataError must be caught gracefully.
+ # Call _create_task_instances directly with a non-empty task list
so the
+ # test exercises the session.flush() → StaleDataError →
session.rollback() path.
+ dr._create_task_instances(
+ dag_id=dag.dag_id,
+ tasks=[TI(task=task, run_id=dr.run_id,
dag_version_id=dag_version_id)],
+ created_counts={"EmptyOperator": 1},
+ hook_is_noop=False,
+ session=session,
+ )
+ mock_rollback.assert_called_once()
+
+
def test_mapped_literal_verify_integrity(dag_maker, session):
"""Test that when the length of a mapped literal changes we remove extra
TIs"""
diff --git a/airflow-core/tests/unit/utils/test_retries.py
b/airflow-core/tests/unit/utils/test_retries.py
index 1f44ee9ebf8..f0976d0e358 100644
--- a/airflow-core/tests/unit/utils/test_retries.py
+++ b/airflow-core/tests/unit/utils/test_retries.py
@@ -18,17 +18,14 @@
from __future__ import annotations
import logging
-from typing import TYPE_CHECKING
from unittest import mock
import pytest
from sqlalchemy.exc import InternalError, OperationalError
+from sqlalchemy.orm.exc import StaleDataError
from airflow.utils.retries import retry_db_transaction
-if TYPE_CHECKING:
- from sqlalchemy.exc import DBAPIError
-
class TestRetries:
def test_retry_db_transaction_with_passing_retries(self):
@@ -48,15 +45,29 @@ class TestRetries:
assert mock_obj.call_count == 2
- @pytest.mark.db_test
- @pytest.mark.parametrize("excection_type", [OperationalError,
InternalError])
- def test_retry_db_transaction_with_default_retries(self, caplog,
excection_type: type[DBAPIError]):
+ @pytest.mark.parametrize(
+ ("exception_type", "exception_kwargs"),
+ [
+ pytest.param(
+ InternalError,
+ {"statement": mock.ANY, "params": mock.ANY, "orig": mock.ANY},
+ id="dbapi-internal",
+ ),
+ pytest.param(
+ OperationalError,
+ {"statement": mock.ANY, "params": mock.ANY, "orig": mock.ANY},
+ id="dbapi-operational",
+ ),
+ pytest.param(StaleDataError, {}, id="stale-data"),
+ ],
+ )
+ def test_retry_db_transaction_with_default_retries(self, caplog,
exception_type, exception_kwargs):
"""Test that by default 3 retries will be carried out"""
mock_obj = mock.MagicMock()
mock_session = mock.MagicMock()
mock_rollback = mock.MagicMock()
mock_session.rollback = mock_rollback
- db_error = excection_type(statement=mock.ANY, params=mock.ANY,
orig=mock.ANY)
+ db_error = exception_type(**exception_kwargs)
@retry_db_transaction
def test_function(session):
@@ -66,7 +77,7 @@ class TestRetries:
caplog.set_level(logging.DEBUG)
caplog.clear()
- with pytest.raises(excection_type):
+ with pytest.raises(exception_type):
test_function(session=mock_session)
for try_no in range(1, 4):