This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 13ce305a15d Refactor PostgresHook and associated runtime tests (#66893)
13ce305a15d is described below

commit 13ce305a15de47e38ecd14eabee65bacb61cdf3b
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Jun 2 12:36:09 2026 +0100

    Refactor PostgresHook and associated runtime tests (#66893)
    
    * Refactor shared PostgresHook runtime test coverage for psycopg2 and 
psycopg3. Consolidate duplicated insert/upsert and dialect tests into a shared 
base class while preserving version-specific behavior and lineage coverage.
    
    * Remove unnecessary typing casts
    
    ---------
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../airflow/providers/postgres/hooks/postgres.py   |  50 +--
 .../tests/unit/postgres/hooks/test_postgres.py     | 435 +++++++--------------
 2 files changed, 163 insertions(+), 322 deletions(-)

diff --git 
a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py 
b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index d38171bbe21..ec28bcc834f 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -230,6 +230,29 @@ class PostgresHook(DbApiHook):
         valid_cursors = ", ".join(cursor_types.keys())
         raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are: 
{valid_cursors}")
 
+    def _get_cursor_config(self, raw_cursor: str) -> tuple[str, Any]:
+        cursor = self._get_cursor(raw_cursor)
+
+        if USE_PSYCOPG3:
+            return "row_factory", cursor
+
+        return "cursor_factory", cursor
+
+    def _create_connection(self, conn_args: dict[str, Any]) -> 
CompatConnection:
+        if USE_PSYCOPG3:
+            from psycopg.connection import Connection as pgConnection
+
+            connection = pgConnection.connect(**cast("Any", conn_args))
+
+            register_default_adapters(connection)
+
+            if self.enable_log_db_messages and hasattr(connection, 
"add_notice_handler"):
+                connection.add_notice_handler(self._notice_handler)
+
+            return connection
+
+        return ppg2_connect(**conn_args)
+
     def _generate_cursor_name(self):
         """Generate a unique name for server-side cursor."""
         import uuid
@@ -262,30 +285,13 @@ class PostgresHook(DbApiHook):
             if arg_name not in self.ignored_extra_options:
                 conn_args[arg_name] = arg_val
 
-        if USE_PSYCOPG3:
-            from psycopg.connection import Connection as pgConnection
-
-            raw_cursor = conn.extra_dejson.get("cursor")
-            if raw_cursor:
-                conn_args["row_factory"] = self._get_cursor(raw_cursor)
-
-            # Use Any type for the connection args to avoid type conflicts
-            connection = pgConnection.connect(**cast("Any", conn_args))
-            self.conn = cast("CompatConnection", connection)
-
-            # Register JSON handlers for both json and jsonb types
-            # This ensures JSON data is properly decoded from bytes to Python 
objects
-            register_default_adapters(connection)
+        raw_cursor = conn.extra_dejson.get("cursor")
 
-            # Add the notice handler AFTER the connection is established
-            if self.enable_log_db_messages and hasattr(self.conn, 
"add_notice_handler"):
-                self.conn.add_notice_handler(self._notice_handler)
-        else:  # psycopg2
-            raw_cursor = conn.extra_dejson.get("cursor", False)
-            if raw_cursor:
-                conn_args["cursor_factory"] = self._get_cursor(raw_cursor)
+        if raw_cursor:
+            key, value = self._get_cursor_config(raw_cursor)
+            conn_args[key] = value
 
-            self.conn = cast("CompatConnection", ppg2_connect(**conn_args))
+        self.conn = self._create_connection(conn_args)
 
         return self.conn
 
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py 
b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index 6cad536c07b..fab8bc6f5d7 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 
 import json
 import os
-from types import SimpleNamespace
 from unittest import mock
 
 import pandas as pd
@@ -30,7 +29,7 @@ import sqlalchemy
 from airflow.models import Connection
 from airflow.providers.common.compat.sdk import AirflowException, 
AirflowOptionalProviderFeatureException
 from airflow.providers.postgres.dialects.postgres import PostgresDialect
-from airflow.providers.postgres.hooks.postgres import CompatConnection, 
PostgresHook
+from airflow.providers.postgres.hooks.postgres import PostgresHook
 
 from tests_common.test_utils.common_sql import mock_db_hook
 from tests_common.test_utils.version_compat import NOTSET
@@ -59,36 +58,6 @@ else:
     import psycopg2.extras
 
 
-@pytest.fixture
-def postgres_hook_setup():
-    """Set up mock PostgresHook for testing."""
-    table = "test_postgres_hook_table"
-    cur = mock.MagicMock(rowcount=0)
-    conn = mock.MagicMock(spec=CompatConnection)
-    conn.cursor.return_value = cur
-
-    class UnitTestPostgresHook(PostgresHook):
-        conn_name_attr = "test_conn_id"
-
-        def get_conn(self):
-            return conn
-
-    db_hook = UnitTestPostgresHook()
-
-    # Return a namespace with all the objects
-    setup = SimpleNamespace(table=table, cur=cur, conn=conn, db_hook=db_hook)
-
-    yield setup
-
-    # Teardown - only for real database tests
-    try:
-        with PostgresHook().get_conn() as real_conn:
-            with real_conn.cursor() as real_cur:
-                real_cur.execute(f"DROP TABLE IF EXISTS {table}")
-    except Exception:
-        pass  # Ignore cleanup errors for unit tests
-
-
 @pytest.fixture
 def mock_connect(mocker):
     """Mock the connection object according to the correct psycopg version."""
@@ -816,10 +785,8 @@ class TestPostgresHook:
         ) == INSERT_SQL_STATEMENT.format('"schema"')
 
 
-@pytest.mark.backend("postgres")
-@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available")
-class TestPostgresHookPPG2:
-    """PostgresHook tests that are specific to psycopg2."""
+class _BasePostgresHookRuntimeTests:
+    """Shared runtime tests for psycopg2 and psycopg3."""
 
     table = "test_postgres_hook_table"
 
@@ -841,6 +808,121 @@ class TestPostgresHookPPG2:
             with conn.cursor() as cur:
                 cur.execute(f"DROP TABLE IF EXISTS {self.table}")
 
+    def test_insert_rows(self):
+        table = "table"
+        rows = [("hello",), ("world",)]
+
+        self.db_hook.insert_rows(table, rows)
+
+        assert self.conn.close.call_count == 1
+        assert self.cur.close.call_count == 1
+        assert self.conn.commit.call_count == 2
+
+        sql = f"INSERT INTO {table}  VALUES (%s)"
+        self.cur.executemany.assert_any_call(sql, rows)
+
+    def test_insert_rows_replace(self):
+        table = "table"
+        rows = [
+            (1, "hello"),
+            (2, "world"),
+        ]
+        fields = ("id", "value")
+
+        self.db_hook.insert_rows(
+            table,
+            rows,
+            fields,
+            replace=True,
+            replace_index=fields[0],
+        )
+
+        assert self.conn.close.call_count == 1
+        assert self.cur.close.call_count == 1
+        assert self.conn.commit.call_count == 2
+
+        sql = (
+            f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
+            f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = 
excluded.{fields[1]}"
+        )
+        self.cur.executemany.assert_any_call(sql, rows)
+
+    def test_insert_rows_replace_missing_target_field_arg(self):
+        table = "table"
+        rows = [
+            (1, "hello"),
+            (2, "world"),
+        ]
+        fields = ("id", "value")
+
+        with pytest.raises(
+            ValueError,
+            match="PostgreSQL ON CONFLICT upsert syntax requires column names",
+        ):
+            self.db_hook.insert_rows(
+                table,
+                rows,
+                replace=True,
+                replace_index=fields[0],
+            )
+
+    def test_insert_rows_replace_missing_replace_index_arg(self):
+        table = "table"
+        rows = [
+            (1, "hello"),
+            (2, "world"),
+        ]
+        fields = ("id", "value")
+
+        with pytest.raises(
+            ValueError,
+            match="PostgreSQL ON CONFLICT upsert syntax requires an unique 
index",
+        ):
+            self.db_hook.insert_rows(
+                table,
+                rows,
+                fields,
+                replace=True,
+            )
+
+    def test_insert_rows_replace_all_index(self):
+        table = "table"
+        rows = [
+            (1, "hello"),
+            (2, "world"),
+        ]
+        fields = ("id", "value")
+
+        self.db_hook.insert_rows(
+            table,
+            rows,
+            fields,
+            replace=True,
+            replace_index=fields,
+        )
+
+        assert self.conn.close.call_count == 1
+        assert self.cur.close.call_count == 1
+        assert self.conn.commit.call_count == 2
+
+        sql = (
+            f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
+            f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
+        )
+        self.cur.executemany.assert_any_call(sql, rows)
+
+    def test_dialect_name(self):
+        assert self.db_hook.dialect_name == "postgresql"
+
+    def test_dialect(self):
+        assert isinstance(self.db_hook.dialect, PostgresDialect)
+
+
+@pytest.mark.backend("postgres")
+@pytest.mark.skipif(USE_PSYCOPG3, reason="psycopg v3 is available")
+class TestPostgresHookPPG2(_BasePostgresHookRuntimeTests):
+    """PostgresHook tests that are specific to psycopg2."""
+
     def test_copy_expert(self, mocker):
         open_mock = mocker.mock_open(read_data='{"some": "json"}')
         mocker.patch("airflow.providers.postgres.hooks.postgres.open", 
open_mock)
@@ -915,169 +997,59 @@ class TestPostgresHookPPG2:
         assert call_kw["sql"] == sql
         assert call_kw["sql_parameters"] == parameters
 
-    def test_insert_rows(self, postgres_hook_setup):
-        setup = postgres_hook_setup
-        table = "table"
-        rows = [("hello",), ("world",)]
-
-        setup.db_hook.insert_rows(table, rows)
-
-        assert setup.conn.close.call_count == 1
-        assert setup.cur.close.call_count == 1
-
-        commit_count = 2  # The first and last commit
-        assert commit_count == setup.conn.commit.call_count
-
-        sql = f"INSERT INTO {table}  VALUES (%s)"
-        setup.cur.executemany.assert_any_call(sql, rows)
-
     @mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
-    def test_insert_rows_hook_lineage(self, mock_send_lineage, 
postgres_hook_setup):
-        setup = postgres_hook_setup
+    def test_insert_rows_hook_lineage(self, mock_send_lineage):
         table = "table"
         rows = [("hello",), ("world",)]
 
-        setup.db_hook.insert_rows(table, rows)
+        self.db_hook.insert_rows(table, rows)
 
         mock_send_lineage.assert_called_once()
+
         call_kw = mock_send_lineage.call_args.kwargs
-        assert call_kw["context"] is setup.db_hook
+
+        assert call_kw["context"] is self.db_hook
         assert call_kw["sql"] == f"INSERT INTO {table}  VALUES (%s)"
         assert call_kw["row_count"] == 2
 
     @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
-    def test_insert_rows_fast_executemany(self, mock_execute_batch, 
postgres_hook_setup):
-        setup = postgres_hook_setup
+    def test_insert_rows_fast_executemany(self, mock_execute_batch):
         table = "table"
         rows = [("hello",), ("world",)]
 
-        setup.db_hook.insert_rows(table, rows, fast_executemany=True)
+        self.db_hook.insert_rows(table, rows, fast_executemany=True)
 
-        assert setup.conn.close.call_count == 1
-        assert setup.cur.close.call_count == 1
+        assert self.conn.close.call_count == 1
+        assert self.cur.close.call_count == 1
 
         commit_count = 2  # The first and last commit
-        assert setup.conn.commit.call_count == commit_count
+        assert self.conn.commit.call_count == commit_count
 
         mock_execute_batch.assert_called_once_with(
-            setup.cur,
+            self.cur,
             f"INSERT INTO {table}  VALUES (%s)",  # expected SQL
             [("hello",), ("world",)],  # expected values
             page_size=1000,
         )
 
         # executemany should NOT be called in this mode
-        setup.cur.executemany.assert_not_called()
+        self.cur.executemany.assert_not_called()
 
     
@mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage")
     @mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
-    def test_insert_rows_fast_executemany_hook_lineage(
-        self, mock_execute_batch, mock_send_lineage, postgres_hook_setup
-    ):
-        setup = postgres_hook_setup
+    def test_insert_rows_fast_executemany_hook_lineage(self, 
mock_execute_batch, mock_send_lineage):
+
         table = "table"
         rows = [("hello",), ("world",)]
 
-        setup.db_hook.insert_rows(table, rows, fast_executemany=True)
+        self.db_hook.insert_rows(table, rows, fast_executemany=True)
 
         mock_send_lineage.assert_called_once()
         call_kw = mock_send_lineage.call_args.kwargs
-        assert call_kw["context"] is setup.db_hook
+        assert call_kw["context"] is self.db_hook
         assert call_kw["sql"] == f"INSERT INTO {table}  VALUES (%s)"
         assert call_kw["row_count"] == 2
 
-    def test_insert_rows_replace(self, postgres_hook_setup):
-        setup = postgres_hook_setup
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-
-        setup.db_hook.insert_rows(table, rows, fields, replace=True, 
replace_index=fields[0])
-
-        assert setup.conn.close.call_count == 1
-        assert setup.cur.close.call_count == 1
-
-        commit_count = 2  # The first and last commit
-        assert commit_count == setup.conn.commit.call_count
-
-        sql = (
-            f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
-            f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = 
excluded.{fields[1]}"
-        )
-        setup.cur.executemany.assert_any_call(sql, rows)
-
-    def test_insert_rows_replace_missing_target_field_arg(self, 
postgres_hook_setup):
-        setup = postgres_hook_setup
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-        with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert 
syntax requires column names"):
-            setup.db_hook.insert_rows(table, rows, replace=True, 
replace_index=fields[0])
-
-    def test_insert_rows_replace_missing_replace_index_arg(self, 
postgres_hook_setup):
-        setup = postgres_hook_setup
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-        with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert 
syntax requires an unique index"):
-            setup.db_hook.insert_rows(table, rows, fields, replace=True)
-
-    def test_insert_rows_replace_all_index(self, postgres_hook_setup):
-        setup = postgres_hook_setup
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-
-        setup.db_hook.insert_rows(table, rows, fields, replace=True, 
replace_index=fields)
-
-        assert setup.conn.close.call_count == 1
-        assert setup.cur.close.call_count == 1
-
-        commit_count = 2  # The first and last commit
-        assert commit_count == setup.conn.commit.call_count
-
-        sql = (
-            f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
-            f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
-        )
-        setup.cur.executemany.assert_any_call(sql, rows)
-
     @pytest.mark.usefixtures("reset_logging_config")
     def test_get_all_db_log_messages(self, mocker):
         messages = ["a", "b", "c"]
@@ -1120,40 +1092,12 @@ class TestPostgresHookPPG2:
         finally:
             hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
 
-    def test_dialect_name(self, postgres_hook_setup):
-        setup = postgres_hook_setup
-        assert setup.db_hook.dialect_name == "postgresql"
-
-    def test_dialect(self, postgres_hook_setup):
-        setup = postgres_hook_setup
-        assert isinstance(setup.db_hook.dialect, PostgresDialect)
-
 
 @pytest.mark.backend("postgres")
 @pytest.mark.skipif(not USE_PSYCOPG3, reason="psycopg v3 or sqlalchemy v2 are 
not available")
-class TestPostgresHookPPG3:
+class TestPostgresHookPPG3(_BasePostgresHookRuntimeTests):
     """PostgresHook tests that are specific to psycopg3."""
 
-    table = "test_postgres_hook_table"
-
-    def setup_method(self):
-        self.cur = mock.MagicMock(rowcount=0)
-        self.conn = conn = mock.MagicMock()
-        self.conn.cursor.return_value = self.cur
-
-        class UnitTestPostgresHook(PostgresHook):
-            conn_name_attr = "test_conn_id"
-
-            def get_conn(self):
-                return conn
-
-        self.db_hook = UnitTestPostgresHook()
-
-    def teardown_method(self):
-        with PostgresHook().get_conn() as conn:
-            with conn.cursor() as cur:
-                cur.execute(f"DROP TABLE IF EXISTS {self.table}")
-
     def test_copy_expert_from(self, mocker):
         """Tests copy_expert with a 'COPY FROM STDIN' operation."""
         statement = "COPY test_table FROM STDIN"
@@ -1235,109 +1179,6 @@ class TestPostgresHookPPG3:
         )
         self.conn.commit.assert_called_once()
 
-    def test_insert_rows(self):
-        table = "table"
-        rows = [("hello",), ("world",)]
-
-        self.db_hook.insert_rows(table, rows)
-
-        assert self.conn.close.call_count == 1
-        assert self.cur.close.call_count == 1
-
-        commit_count = 2  # The first and last commit
-        assert commit_count == self.conn.commit.call_count
-
-        sql = f"INSERT INTO {table}  VALUES (%s)"
-        self.cur.executemany.assert_any_call(sql, rows)
-
-    def test_insert_rows_replace(self):
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-
-        self.db_hook.insert_rows(table, rows, fields, replace=True, 
replace_index=fields[0])
-
-        assert self.conn.close.call_count == 1
-        assert self.cur.close.call_count == 1
-
-        commit_count = 2  # The first and last commit
-        assert commit_count == self.conn.commit.call_count
-
-        sql = (
-            f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
-            f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} = 
excluded.{fields[1]}"
-        )
-        self.cur.executemany.assert_any_call(sql, rows)
-
-    def test_insert_rows_replace_missing_target_field_arg(self):
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-        with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert 
syntax requires column names"):
-            self.db_hook.insert_rows(table, rows, replace=True, 
replace_index=fields[0])
-
-    def test_insert_rows_replace_missing_replace_index_arg(self):
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-        with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert 
syntax requires an unique index"):
-            self.db_hook.insert_rows(table, rows, fields, replace=True)
-
-    def test_insert_rows_replace_all_index(self):
-        table = "table"
-        rows = [
-            (
-                1,
-                "hello",
-            ),
-            (
-                2,
-                "world",
-            ),
-        ]
-        fields = ("id", "value")
-
-        self.db_hook.insert_rows(table, rows, fields, replace=True, 
replace_index=fields)
-
-        assert self.conn.close.call_count == 1
-        assert self.cur.close.call_count == 1
-
-        commit_count = 2  # The first and last commit
-        assert commit_count == self.conn.commit.call_count
-
-        sql = (
-            f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
-            f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
-        )
-        self.cur.executemany.assert_any_call(sql, rows)
-
     @pytest.mark.skip(reason="Notice handling is callback-based in psycopg3 
and cannot be tested this way.")
     def test_get_all_db_log_messages(self, mocker):
         pass
@@ -1366,9 +1207,3 @@ class TestPostgresHookPPG3:
             mock_logger.info.assert_any_call("Message from db: 42")
         finally:
             hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
-
-    def test_dialect_name(self):
-        assert self.db_hook.dialect_name == "postgresql"
-
-    def test_dialect(self):
-        assert isinstance(self.db_hook.dialect, PostgresDialect)

Reply via email to