This is an automated email from the ASF dual-hosted git repository.

dabla pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b35ecd59e18 Add configurable UPSERT update fields to PostgresHook (#67045)
b35ecd59e18 is described below

commit b35ecd59e18dd85519036572d45e43d85c372b2c
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Jun 16 11:23:59 2026 +0100

    Add configurable UPSERT update fields to PostgresHook (#67045)
    
    Extend PostgreSQL ON CONFLICT support by allowing callers to
    specify which columns are updated when conflicts occur.
    
    Preserve existing behavior when no update fields are provided,
    support DO NOTHING semantics via an empty update field list,
    and add an upsert_rows convenience wrapper built on top of the
    existing insert_rows(replace=True) implementation.
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../providers/postgres/dialects/postgres.py        | 12 ++++-
 .../airflow/providers/postgres/hooks/postgres.py   | 41 +++++++++++++++-
 .../tests/unit/postgres/dialects/test_postgres.py  | 57 ++++++++++++++++++++++
 .../tests/unit/postgres/hooks/test_postgres.py     | 29 +++++++++++
 4 files changed, 136 insertions(+), 3 deletions(-)

diff --git a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py
index 446ea199fa1..3635cca98c3 100644
--- a/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/dialects/postgres.py
@@ -106,9 +106,11 @@ class PostgresDialect(Dialect):
         :param table: Name of the target table
         :param values: The row to insert into the table
         :param target_fields: The names of the columns to fill in the table
-        :param replace: Whether to replace instead of insert
         :param replace_index: the column or list of column names to act as
             index for the ON CONFLICT clause
+        :param replace_target: Column name or list of column names to update when
+            a conflict occurs. If omitted, all non-conflict columns are updated.
+            If an empty list is provided, ``DO NOTHING`` is used.
         :return: The generated INSERT or REPLACE SQL statement
         """
         if not target_fields:
@@ -124,7 +126,13 @@ class PostgresDialect(Dialect):
 
         sql = self.generate_insert_sql(table, values, target_fields, **kwargs)
         on_conflict_str = f" ON CONFLICT ({', '.join(map(self.escape_word, replace_index))})"
-        replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index]
+
+        replace_target = kwargs.get("replace_target")
+
+        if replace_target is None:
+            replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index]
+        else:
+            replace_target = [self.escape_word(f) for f in replace_target]
 
         if replace_target:
             replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target)
diff --git a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index 66f692eb004..77ca0537b93 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import os
-from collections.abc import Mapping
+from collections.abc import Iterable, Mapping
 from contextlib import closing
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeAlias, cast, overload
@@ -741,3 +741,42 @@ class PostgresHook(DbApiHook):
 
         self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table)
         return None
+
+    def upsert_rows(
+        self,
+        table: str,
+        rows: Iterable[tuple[Any, ...]],
+        target_fields: list[str],
+        conflict_fields: list[str],
+        update_fields: list[str] | None = None,
+        commit_every: int = 1000,
+        *,
+        fast_executemany: bool = False,
+        autocommit: bool = False,
+    ) -> None:
+        """
+        Upsert rows into a PostgreSQL table using ``ON CONFLICT``.
+
+        :param table: Name of the target table.
+        :param rows: Rows to upsert.
+        :param target_fields: Non-empty column names used in the ``INSERT`` statement.
+        :param conflict_fields: Non-empty column names used in the ``ON CONFLICT`` clause.
+        :param update_fields: Columns updated on conflict. If omitted, all
+            non-conflict columns are updated. If an empty list is provided,
+            conflicting rows are ignored via ``DO NOTHING``.
+        :param commit_every: Maximum number of rows per transaction. Default value is 1000.
+        :param fast_executemany: Use ``psycopg2.extras.execute_batch`` for improved
+            batch performance.
+        :param autocommit: Connection autocommit setting.
+        """
+        return self.insert_rows(
+            table=table,
+            rows=rows,
+            target_fields=target_fields,
+            replace_index=conflict_fields,
+            replace_target=update_fields,
+            commit_every=commit_every,
+            replace=True,
+            fast_executemany=fast_executemany,
+            autocommit=autocommit,
+        )
diff --git a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py
index 999386baac1..946805ec40f 100644
--- a/providers/postgres/tests/unit/postgres/dialects/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/dialects/test_postgres.py
@@ -19,6 +19,8 @@ from __future__ import annotations
 
 from unittest.mock import MagicMock
 
+import pytest
+
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.postgres.dialects.postgres import PostgresDialect
 
@@ -103,3 +105,58 @@ class TestPostgresDialect:
             INSERT INTO hollywood.actors ("id", "name", "firstname", "age") VALUES (?,?,?,?,?) ON CONFLICT ("id") DO UPDATE SET "name" = excluded."name", "firstname" = excluded."firstname", "age" = excluded."age"
         """.strip()
         )
+
+    @pytest.mark.parametrize(
+        ("replace_index", "replace_target", "expected_clause"),
+        [
+            (
+                None,
+                ["name"],
+                "ON CONFLICT (id) DO UPDATE SET name = excluded.name",
+            ),
+            (
+                None,
+                ["name", "age"],
+                "ON CONFLICT (id) DO UPDATE SET name = excluded.name, age = excluded.age",
+            ),
+            (
+                None,
+                [],
+                "ON CONFLICT (id) DO NOTHING",
+            ),
+            (
+                ["id", "name"],
+                ["age"],
+                "ON CONFLICT (id, name) DO UPDATE SET age = excluded.age",
+            ),
+        ],
+    )
+    def test_generate_replace_sql_with_replace_target(
+        self,
+        replace_index,
+        replace_target,
+        expected_clause,
+    ):
+        values = [
+            {"id": 1, "name": "Stallone", "firstname": "Sylvester", "age": "78"},
+            {"id": 2, "name": "Statham", "firstname": "Jason", "age": "57"},
+            {"id": 3, "name": "Li", "firstname": "Jet", "age": "61"},
+            {"id": 4, "name": "Lundgren", "firstname": "Dolph", "age": "66"},
+            {"id": 5, "name": "Norris", "firstname": "Chuck", "age": "84"},
+        ]
+
+        target_fields = ["id", "name", "firstname", "age"]
+
+        sql = PostgresDialect(self.test_db_hook).generate_replace_sql(
+            "hollywood.actors",
+            values,
+            target_fields,
+            replace_index=replace_index,
+            replace_target=replace_target,
+        )
+
+        assert (
+            sql
+            == f"""
+            INSERT INTO hollywood.actors (id, name, firstname, age) VALUES (?,?,?,?,?) {expected_clause}""".strip()
+        )
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index f356e3f233e..e8c1dd5eff9 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -911,6 +911,35 @@ class _BasePostgresHookRuntimeTests:
         )
         self.cur.executemany.assert_any_call(sql, rows)
 
+    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.insert_rows")
+    def test_upsert_rows(self, mock_insert_rows):
+
+        rows = [(1, "hello")]
+        table = "table"
+
+        self.db_hook.upsert_rows(
+            table=table,
+            rows=rows,
+            target_fields=["id", "value"],
+            conflict_fields=["id"],
+            update_fields=["value"],
+            commit_every=123,
+            fast_executemany=True,
+            autocommit=True,
+        )
+
+        mock_insert_rows.assert_called_once_with(
+            table=table,
+            rows=rows,
+            target_fields=["id", "value"],
+            replace_index=["id"],
+            replace_target=["value"],
+            commit_every=123,
+            replace=True,
+            fast_executemany=True,
+            autocommit=True,
+        )
+
     def test_dialect_name(self):
         assert self.db_hook.dialect_name == "postgresql"
 

Reply via email to