This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ecbb1bff407 Added insert and update on conflict to renderedtifields.py (#63874)
ecbb1bff407 is described below
commit ecbb1bff407f318218224495cd52a020013e183f
Author: manipatnam <[email protected]>
AuthorDate: Wed Jun 3 06:11:50 2026 +0530
Added insert and update on conflict to renderedtifields.py (#63874)
closes: #61705
---
.../src/airflow/models/renderedtifields.py | 31 +++++++++--
airflow-core/src/airflow/models/variable.py | 34 +-----------
airflow-core/src/airflow/utils/sqlalchemy.py | 50 ++++++++++++++++-
.../tests/unit/models/test_renderedtifields.py | 62 +++++++++++++++++++++-
4 files changed, 140 insertions(+), 37 deletions(-)
diff --git a/airflow-core/src/airflow/models/renderedtifields.py b/airflow-core/src/airflow/models/renderedtifields.py
index d9b5f115b33..e405f3bfce7 100644
--- a/airflow-core/src/airflow/models/renderedtifields.py
+++ b/airflow-core/src/airflow/models/renderedtifields.py
@@ -38,7 +38,7 @@ from airflow.models.base import StringID, TaskInstanceDependencies
from airflow.serialization.helpers import serialize_template_field
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import get_dialect_name
+from airflow.utils.sqlalchemy import build_upsert_stmt, get_dialect_name
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -239,13 +239,38 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
@provide_session
@retry_db_transaction
- def write(self, session: Session):
+ def write(self, session: Session = NEW_SESSION):
"""
Write instance to database.
+ Uses a database-level upsert (INSERT ... ON CONFLICT DO UPDATE) to
+ atomically insert or update the record, avoiding race conditions that
+ can occur with session.merge() when concurrent requests (e.g. from
+ client-side timeout retries) target the same primary key.
+
:param session: SqlAlchemy Session
"""
- session.merge(self)
+ values = {
+ "dag_id": self.dag_id,
+ "task_id": self.task_id,
+ "run_id": self.run_id,
+ "map_index": self.map_index,
+ "rendered_fields": self.rendered_fields,
+ "k8s_pod_yaml": self.k8s_pod_yaml,
+ }
+ update_on_conflict = {
+ "rendered_fields": self.rendered_fields,
+ "k8s_pod_yaml": self.k8s_pod_yaml,
+ }
+
+ stmt = build_upsert_stmt(
+ get_dialect_name(session),
+ RenderedTaskInstanceFields,
+ ["dag_id", "task_id", "run_id", "map_index"],
+ values,
+ update_on_conflict,
+ )
+ session.execute(stmt)
@classmethod
@provide_session
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index 667c3303567..eb50b92f988 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -49,44 +49,14 @@ except ImportError:
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import get_dialect_name
+from airflow.utils.sqlalchemy import build_upsert_stmt, get_dialect_name
if TYPE_CHECKING:
- from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert
- from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert
- from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert
from sqlalchemy.orm import Session
log = logging.getLogger(__name__)
-def _build_variable_upsert_stmt(
- dialect: str | None,
- model: type[Variable],
- conflict_cols: list[str],
- values: dict[str, Any],
- update_fields: dict[str, Any],
-) -> MySQLInsert | PostgreSQLInsert | SQLiteInsert:
- """Return a dialect-specific INSERT ... ON CONFLICT UPDATE statement."""
- stmt: MySQLInsert | PostgreSQLInsert | SQLiteInsert
- if dialect == "postgresql":
- from sqlalchemy.dialects.postgresql import insert as pg_insert
-
- stmt = pg_insert(model).values(**values)
- stmt = stmt.on_conflict_do_update(index_elements=conflict_cols,
set_=update_fields)
- elif dialect == "mysql":
- from sqlalchemy.dialects.mysql import insert as mysql_insert
-
- stmt = mysql_insert(model).values(**values)
- stmt = stmt.on_duplicate_key_update(**update_fields)
- else:
- from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-
- stmt = sqlite_insert(model).values(**values)
- stmt = stmt.on_conflict_do_update(index_elements=conflict_cols,
set_=update_fields)
- return stmt
-
-
class Variable(Base, LoggingMixin):
"""A generic way to store and retrieve arbitrary content or settings as a
simple key/value store."""
@@ -311,7 +281,7 @@ class Variable(Base, LoggingMixin):
is_encrypted=is_encrypted,
team_name=team_name,
)
- stmt = _build_variable_upsert_stmt(
+ stmt = build_upsert_stmt(
get_dialect_name(session), Variable, ["key"], upsert_values,
update_fields
)
session.execute(stmt)
diff --git a/airflow-core/src/airflow/utils/sqlalchemy.py b/airflow-core/src/airflow/utils/sqlalchemy.py
index a767c65fad9..35a6f4ee05b 100644
--- a/airflow-core/src/airflow/utils/sqlalchemy.py
+++ b/airflow-core/src/airflow/utils/sqlalchemy.py
@@ -22,7 +22,7 @@ import copy
import datetime
import logging
from collections.abc import Generator
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from sqlalchemy import TIMESTAMP, PickleType, String, event, nullsfirst
from sqlalchemy.dialects import mysql
@@ -39,6 +39,9 @@ if TYPE_CHECKING:
from collections.abc import Iterable
from kubernetes.client.models.v1_pod import V1Pod
+ from sqlalchemy.dialects.mysql.dml import Insert as MySQLInsert
+ from sqlalchemy.dialects.postgresql.dml import Insert as PostgreSQLInsert
+ from sqlalchemy.dialects.sqlite.dml import Insert as SQLiteInsert
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
@@ -58,6 +61,51 @@ def get_dialect_name(session: Session) -> str | None:
return getattr(bind.dialect, "name", None)
+def build_upsert_stmt(
+ dialect: str | None,
+ model: Any,
+ conflict_cols: list[str],
+ values: dict[str, Any],
+ update_fields: dict[str, Any],
+) -> MySQLInsert | PostgreSQLInsert | SQLiteInsert:
+ """
+ Build a dialect-specific ``INSERT ... ON CONFLICT DO UPDATE`` (upsert)
statement.
+
+ A single-statement upsert is atomic at the database level, which avoids the
+ race conditions that arise from the non-atomic SELECT-then-INSERT
performed by
+ ``session.merge()`` when concurrent transactions target the same primary
key.
+
+ :param dialect: dialect name as returned by :func:`get_dialect_name`
+ :param model: the SQLAlchemy model (or table) to insert into
+ :param conflict_cols: columns that make up the conflict target
(PostgreSQL/SQLite)
+ :param values: column values to insert
+ :param update_fields: column values to set when a conflicting row already
exists
+ :raises ValueError: if the dialect does not support a known upsert syntax
+ """
+ stmt: MySQLInsert | PostgreSQLInsert | SQLiteInsert
+ if dialect == "postgresql":
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+ stmt = pg_insert(model).values(**values)
+ stmt = stmt.on_conflict_do_update(index_elements=conflict_cols,
set_=update_fields)
+ elif dialect == "mysql":
+ from sqlalchemy.dialects.mysql import insert as mysql_insert
+
+ stmt = mysql_insert(model).values(**values)
+ stmt = stmt.on_duplicate_key_update(**update_fields)
+ elif dialect == "sqlite":
+ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+
+ stmt = sqlite_insert(model).values(**values)
+ stmt = stmt.on_conflict_do_update(index_elements=conflict_cols,
set_=update_fields)
+ else:
+ raise ValueError(
+ f"Unsupported database dialect '{dialect}' for upsert. "
+ "Supported dialects are: postgresql, mysql, sqlite."
+ )
+ return stmt
+
+
class random_db_uuid(FunctionElement):
"""
Cross-dialect random UUID generation for use in SQL expressions.
diff --git a/airflow-core/tests/unit/models/test_renderedtifields.py b/airflow-core/tests/unit/models/test_renderedtifields.py
index 37e6088494d..f695c46aac2 100644
--- a/airflow-core/tests/unit/models/test_renderedtifields.py
+++ b/airflow-core/tests/unit/models/test_renderedtifields.py
@@ -27,7 +27,7 @@ from unittest import mock
import pendulum
import pytest
-from sqlalchemy import select
+from sqlalchemy import insert, select
from airflow import settings
from airflow._shared.template_rendering import truncate_rendered_value
@@ -372,6 +372,66 @@ class TestRenderedTaskInstanceFields:
{"bash_command": "echo test_val_updated", "env": None, "cwd":
None},
)
+ def test_write_upsert_existing_record(self, dag_maker, session):
+ """
+ Verify that write() updates an existing row instead of failing on its
primary key.
+
+ A row is seeded via a direct INSERT (bypassing write()) to represent a
record
+ already present for this task instance. Calling write() with different
values
+ must update that row via the upsert's DO UPDATE branch.
+
+ This exercises the upsert's update path within a single transaction;
it does not
+ reproduce the concurrent-transaction race from #61705, which needs two
separate
+ uncommitted transactions and cannot be triggered reliably in a unit
test. The
+ atomic single-statement upsert is what closes that race in production.
+ """
+ with dag_maker("test_write_upsert", session=session):
+ task = BashOperator(task_id="test", bash_command="echo original")
+ dr = dag_maker.create_dagrun()
+ ti = dr.task_instances[0]
+ ti.task = task
+
+ # Seed the row via a direct INSERT to simulate a row already committed
by
+ # the first request. Using write() here would mask whether write()
itself
+ # correctly handles conflicts, since merge() also handles existing
rows.
+ session.execute(
+ insert(RTIF).values(
+ dag_id=ti.dag_id,
+ task_id=ti.task_id,
+ run_id=ti.run_id,
+ map_index=ti.map_index,
+ rendered_fields={"bash_command": "echo original"},
+ k8s_pod_yaml=None,
+ )
+ )
+ session.flush()
+
+ result = session.scalar(
+ select(RTIF).where(
+ RTIF.dag_id == ti.dag_id,
+ RTIF.task_id == ti.task_id,
+ RTIF.run_id == ti.run_id,
+ RTIF.map_index == ti.map_index,
+ )
+ )
+ assert result.rendered_fields == {"bash_command": "echo original"}
+
+ # write() must not raise IntegrityError even though the row already
exists.
+ rtif = RTIF(ti=ti, render_templates=False,
rendered_fields={"bash_command": "echo updated"})
+ rtif.write(session=session)
+ session.flush()
+ session.expire_all()
+
+ result = session.scalar(
+ select(RTIF).where(
+ RTIF.dag_id == ti.dag_id,
+ RTIF.task_id == ti.task_id,
+ RTIF.run_id == ti.run_id,
+ RTIF.map_index == ti.map_index,
+ )
+ )
+ assert result.rendered_fields == {"bash_command": "echo updated"}
+
@mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"})
def test_redact(self, dag_maker):
with mock.patch("airflow._shared.secrets_masker.redact",
autospec=True) as redact: