This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 13ce305a15d Refactor PostgresHook and associated runtime tests (#66893)
13ce305a15d is described below
commit 13ce305a15de47e38ecd14eabee65bacb61cdf3b
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Jun 2 12:36:09 2026 +0100
Refactor PostgresHook and associated runtime tests (#66893)
* Refactor shared PostgresHook runtime test coverage for psycopg2 and
psycopg3. Consolidate duplicated insert/upsert and dialect tests into a shared
base class while preserving version-specific behavior and lineage coverage.
* Remove unnecessary typing casts
---------
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../airflow/providers/postgres/hooks/postgres.py | 50 +--
.../tests/unit/postgres/hooks/test_postgres.py | 435 +++++++--------------
2 files changed, 163 insertions(+), 322 deletions(-)
diff --git
a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
index d38171bbe21..ec28bcc834f 100644
--- a/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/postgres/src/airflow/providers/postgres/hooks/postgres.py
@@ -230,6 +230,29 @@ class PostgresHook(DbApiHook):
valid_cursors = ", ".join(cursor_types.keys())
raise ValueError(f"Invalid cursor passed {_cursor}. Valid options are:
{valid_cursors}")
+ def _get_cursor_config(self, raw_cursor: str) -> tuple[str, Any]:
+ cursor = self._get_cursor(raw_cursor)
+
+ if USE_PSYCOPG3:
+ return "row_factory", cursor
+
+ return "cursor_factory", cursor
+
+ def _create_connection(self, conn_args: dict[str, Any]) ->
CompatConnection:
+ if USE_PSYCOPG3:
+ from psycopg.connection import Connection as pgConnection
+
+ connection = pgConnection.connect(**cast("Any", conn_args))
+
+ register_default_adapters(connection)
+
+ if self.enable_log_db_messages and hasattr(connection,
"add_notice_handler"):
+ connection.add_notice_handler(self._notice_handler)
+
+ return connection
+
+ return ppg2_connect(**conn_args)
+
def _generate_cursor_name(self):
"""Generate a unique name for server-side cursor."""
import uuid
@@ -262,30 +285,13 @@ class PostgresHook(DbApiHook):
if arg_name not in self.ignored_extra_options:
conn_args[arg_name] = arg_val
- if USE_PSYCOPG3:
- from psycopg.connection import Connection as pgConnection
-
- raw_cursor = conn.extra_dejson.get("cursor")
- if raw_cursor:
- conn_args["row_factory"] = self._get_cursor(raw_cursor)
-
- # Use Any type for the connection args to avoid type conflicts
- connection = pgConnection.connect(**cast("Any", conn_args))
- self.conn = cast("CompatConnection", connection)
-
- # Register JSON handlers for both json and jsonb types
- # This ensures JSON data is properly decoded from bytes to Python
objects
- register_default_adapters(connection)
+ raw_cursor = conn.extra_dejson.get("cursor")
- # Add the notice handler AFTER the connection is established
- if self.enable_log_db_messages and hasattr(self.conn,
"add_notice_handler"):
- self.conn.add_notice_handler(self._notice_handler)
- else: # psycopg2
- raw_cursor = conn.extra_dejson.get("cursor", False)
- if raw_cursor:
- conn_args["cursor_factory"] = self._get_cursor(raw_cursor)
+ if raw_cursor:
+ key, value = self._get_cursor_config(raw_cursor)
+ conn_args[key] = value
- self.conn = cast("CompatConnection", ppg2_connect(**conn_args))
+ self.conn = self._create_connection(conn_args)
return self.conn
diff --git a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
index 6cad536c07b..fab8bc6f5d7 100644
--- a/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
+++ b/providers/postgres/tests/unit/postgres/hooks/test_postgres.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import json
import os
-from types import SimpleNamespace
from unittest import mock
import pandas as pd
@@ -30,7 +29,7 @@ import sqlalchemy
from airflow.models import Connection
from airflow.providers.common.compat.sdk import AirflowException,
AirflowOptionalProviderFeatureException
from airflow.providers.postgres.dialects.postgres import PostgresDialect
-from airflow.providers.postgres.hooks.postgres import CompatConnection,
PostgresHook
+from airflow.providers.postgres.hooks.postgres import PostgresHook
from tests_common.test_utils.common_sql import mock_db_hook
from tests_common.test_utils.version_compat import NOTSET
@@ -59,36 +58,6 @@ else:
import psycopg2.extras
[email protected]
-def postgres_hook_setup():
- """Set up mock PostgresHook for testing."""
- table = "test_postgres_hook_table"
- cur = mock.MagicMock(rowcount=0)
- conn = mock.MagicMock(spec=CompatConnection)
- conn.cursor.return_value = cur
-
- class UnitTestPostgresHook(PostgresHook):
- conn_name_attr = "test_conn_id"
-
- def get_conn(self):
- return conn
-
- db_hook = UnitTestPostgresHook()
-
- # Return a namespace with all the objects
- setup = SimpleNamespace(table=table, cur=cur, conn=conn, db_hook=db_hook)
-
- yield setup
-
- # Teardown - only for real database tests
- try:
- with PostgresHook().get_conn() as real_conn:
- with real_conn.cursor() as real_cur:
- real_cur.execute(f"DROP TABLE IF EXISTS {table}")
- except Exception:
- pass # Ignore cleanup errors for unit tests
-
-
@pytest.fixture
def mock_connect(mocker):
"""Mock the connection object according to the correct psycopg version."""
@@ -816,10 +785,8 @@ class TestPostgresHook:
) == INSERT_SQL_STATEMENT.format('"schema"')
[email protected]("postgres")
[email protected](USE_PSYCOPG3, reason="psycopg v3 is available")
-class TestPostgresHookPPG2:
- """PostgresHook tests that are specific to psycopg2."""
+class _BasePostgresHookRuntimeTests:
+ """Shared runtime tests for psycopg2 and psycopg3."""
table = "test_postgres_hook_table"
@@ -841,6 +808,121 @@ class TestPostgresHookPPG2:
with conn.cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table}")
+ def test_insert_rows(self):
+ table = "table"
+ rows = [("hello",), ("world",)]
+
+ self.db_hook.insert_rows(table, rows)
+
+ assert self.conn.close.call_count == 1
+ assert self.cur.close.call_count == 1
+ assert self.conn.commit.call_count == 2
+
+ sql = f"INSERT INTO {table} VALUES (%s)"
+ self.cur.executemany.assert_any_call(sql, rows)
+
+ def test_insert_rows_replace(self):
+ table = "table"
+ rows = [
+ (1, "hello"),
+ (2, "world"),
+ ]
+ fields = ("id", "value")
+
+ self.db_hook.insert_rows(
+ table,
+ rows,
+ fields,
+ replace=True,
+ replace_index=fields[0],
+ )
+
+ assert self.conn.close.call_count == 1
+ assert self.cur.close.call_count == 1
+ assert self.conn.commit.call_count == 2
+
+ sql = (
+ f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
+ f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} =
excluded.{fields[1]}"
+ )
+ self.cur.executemany.assert_any_call(sql, rows)
+
+ def test_insert_rows_replace_missing_target_field_arg(self):
+ table = "table"
+ rows = [
+ (1, "hello"),
+ (2, "world"),
+ ]
+ fields = ("id", "value")
+
+ with pytest.raises(
+ ValueError,
+ match="PostgreSQL ON CONFLICT upsert syntax requires column names",
+ ):
+ self.db_hook.insert_rows(
+ table,
+ rows,
+ replace=True,
+ replace_index=fields[0],
+ )
+
+ def test_insert_rows_replace_missing_replace_index_arg(self):
+ table = "table"
+ rows = [
+ (1, "hello"),
+ (2, "world"),
+ ]
+ fields = ("id", "value")
+
+ with pytest.raises(
+ ValueError,
+ match="PostgreSQL ON CONFLICT upsert syntax requires an unique
index",
+ ):
+ self.db_hook.insert_rows(
+ table,
+ rows,
+ fields,
+ replace=True,
+ )
+
+ def test_insert_rows_replace_all_index(self):
+ table = "table"
+ rows = [
+ (1, "hello"),
+ (2, "world"),
+ ]
+ fields = ("id", "value")
+
+ self.db_hook.insert_rows(
+ table,
+ rows,
+ fields,
+ replace=True,
+ replace_index=fields,
+ )
+
+ assert self.conn.close.call_count == 1
+ assert self.cur.close.call_count == 1
+ assert self.conn.commit.call_count == 2
+
+ sql = (
+ f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
+ f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
+ )
+ self.cur.executemany.assert_any_call(sql, rows)
+
+ def test_dialect_name(self):
+ assert self.db_hook.dialect_name == "postgresql"
+
+ def test_dialect(self):
+ assert isinstance(self.db_hook.dialect, PostgresDialect)
+
+
[email protected]("postgres")
[email protected](USE_PSYCOPG3, reason="psycopg v3 is available")
+class TestPostgresHookPPG2(_BasePostgresHookRuntimeTests):
+ """PostgresHook tests that are specific to psycopg2."""
+
def test_copy_expert(self, mocker):
open_mock = mocker.mock_open(read_data='{"some": "json"}')
mocker.patch("airflow.providers.postgres.hooks.postgres.open",
open_mock)
@@ -915,169 +997,59 @@ class TestPostgresHookPPG2:
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] == parameters
- def test_insert_rows(self, postgres_hook_setup):
- setup = postgres_hook_setup
- table = "table"
- rows = [("hello",), ("world",)]
-
- setup.db_hook.insert_rows(table, rows)
-
- assert setup.conn.close.call_count == 1
- assert setup.cur.close.call_count == 1
-
- commit_count = 2 # The first and last commit
- assert commit_count == setup.conn.commit.call_count
-
- sql = f"INSERT INTO {table} VALUES (%s)"
- setup.cur.executemany.assert_any_call(sql, rows)
-
@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
- def test_insert_rows_hook_lineage(self, mock_send_lineage,
postgres_hook_setup):
- setup = postgres_hook_setup
+ def test_insert_rows_hook_lineage(self, mock_send_lineage):
table = "table"
rows = [("hello",), ("world",)]
- setup.db_hook.insert_rows(table, rows)
+ self.db_hook.insert_rows(table, rows)
mock_send_lineage.assert_called_once()
+
call_kw = mock_send_lineage.call_args.kwargs
- assert call_kw["context"] is setup.db_hook
+
+ assert call_kw["context"] is self.db_hook
assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)"
assert call_kw["row_count"] == 2
@mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
- def test_insert_rows_fast_executemany(self, mock_execute_batch,
postgres_hook_setup):
- setup = postgres_hook_setup
+ def test_insert_rows_fast_executemany(self, mock_execute_batch):
table = "table"
rows = [("hello",), ("world",)]
- setup.db_hook.insert_rows(table, rows, fast_executemany=True)
+ self.db_hook.insert_rows(table, rows, fast_executemany=True)
- assert setup.conn.close.call_count == 1
- assert setup.cur.close.call_count == 1
+ assert self.conn.close.call_count == 1
+ assert self.cur.close.call_count == 1
commit_count = 2 # The first and last commit
- assert setup.conn.commit.call_count == commit_count
+ assert self.conn.commit.call_count == commit_count
mock_execute_batch.assert_called_once_with(
- setup.cur,
+ self.cur,
f"INSERT INTO {table} VALUES (%s)", # expected SQL
[("hello",), ("world",)], # expected values
page_size=1000,
)
# executemany should NOT be called in this mode
- setup.cur.executemany.assert_not_called()
+ self.cur.executemany.assert_not_called()
@mock.patch("airflow.providers.postgres.hooks.postgres.send_sql_hook_lineage")
@mock.patch("airflow.providers.postgres.hooks.postgres.execute_batch")
- def test_insert_rows_fast_executemany_hook_lineage(
- self, mock_execute_batch, mock_send_lineage, postgres_hook_setup
- ):
- setup = postgres_hook_setup
+ def test_insert_rows_fast_executemany_hook_lineage(self,
mock_execute_batch, mock_send_lineage):
+
table = "table"
rows = [("hello",), ("world",)]
- setup.db_hook.insert_rows(table, rows, fast_executemany=True)
+ self.db_hook.insert_rows(table, rows, fast_executemany=True)
mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
- assert call_kw["context"] is setup.db_hook
+ assert call_kw["context"] is self.db_hook
assert call_kw["sql"] == f"INSERT INTO {table} VALUES (%s)"
assert call_kw["row_count"] == 2
- def test_insert_rows_replace(self, postgres_hook_setup):
- setup = postgres_hook_setup
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
-
- setup.db_hook.insert_rows(table, rows, fields, replace=True,
replace_index=fields[0])
-
- assert setup.conn.close.call_count == 1
- assert setup.cur.close.call_count == 1
-
- commit_count = 2 # The first and last commit
- assert commit_count == setup.conn.commit.call_count
-
- sql = (
- f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
- f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} =
excluded.{fields[1]}"
- )
- setup.cur.executemany.assert_any_call(sql, rows)
-
- def test_insert_rows_replace_missing_target_field_arg(self,
postgres_hook_setup):
- setup = postgres_hook_setup
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
- with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert
syntax requires column names"):
- setup.db_hook.insert_rows(table, rows, replace=True,
replace_index=fields[0])
-
- def test_insert_rows_replace_missing_replace_index_arg(self,
postgres_hook_setup):
- setup = postgres_hook_setup
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
- with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert
syntax requires an unique index"):
- setup.db_hook.insert_rows(table, rows, fields, replace=True)
-
- def test_insert_rows_replace_all_index(self, postgres_hook_setup):
- setup = postgres_hook_setup
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
-
- setup.db_hook.insert_rows(table, rows, fields, replace=True,
replace_index=fields)
-
- assert setup.conn.close.call_count == 1
- assert setup.cur.close.call_count == 1
-
- commit_count = 2 # The first and last commit
- assert commit_count == setup.conn.commit.call_count
-
- sql = (
- f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
- f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
- )
- setup.cur.executemany.assert_any_call(sql, rows)
-
@pytest.mark.usefixtures("reset_logging_config")
def test_get_all_db_log_messages(self, mocker):
messages = ["a", "b", "c"]
@@ -1120,40 +1092,12 @@ class TestPostgresHookPPG2:
finally:
hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
- def test_dialect_name(self, postgres_hook_setup):
- setup = postgres_hook_setup
- assert setup.db_hook.dialect_name == "postgresql"
-
- def test_dialect(self, postgres_hook_setup):
- setup = postgres_hook_setup
- assert isinstance(setup.db_hook.dialect, PostgresDialect)
-
@pytest.mark.backend("postgres")
@pytest.mark.skipif(not USE_PSYCOPG3, reason="psycopg v3 or sqlalchemy v2 are
not available")
-class TestPostgresHookPPG3:
+class TestPostgresHookPPG3(_BasePostgresHookRuntimeTests):
"""PostgresHook tests that are specific to psycopg3."""
- table = "test_postgres_hook_table"
-
- def setup_method(self):
- self.cur = mock.MagicMock(rowcount=0)
- self.conn = conn = mock.MagicMock()
- self.conn.cursor.return_value = self.cur
-
- class UnitTestPostgresHook(PostgresHook):
- conn_name_attr = "test_conn_id"
-
- def get_conn(self):
- return conn
-
- self.db_hook = UnitTestPostgresHook()
-
- def teardown_method(self):
- with PostgresHook().get_conn() as conn:
- with conn.cursor() as cur:
- cur.execute(f"DROP TABLE IF EXISTS {self.table}")
-
def test_copy_expert_from(self, mocker):
"""Tests copy_expert with a 'COPY FROM STDIN' operation."""
statement = "COPY test_table FROM STDIN"
@@ -1235,109 +1179,6 @@ class TestPostgresHookPPG3:
)
self.conn.commit.assert_called_once()
- def test_insert_rows(self):
- table = "table"
- rows = [("hello",), ("world",)]
-
- self.db_hook.insert_rows(table, rows)
-
- assert self.conn.close.call_count == 1
- assert self.cur.close.call_count == 1
-
- commit_count = 2 # The first and last commit
- assert commit_count == self.conn.commit.call_count
-
- sql = f"INSERT INTO {table} VALUES (%s)"
- self.cur.executemany.assert_any_call(sql, rows)
-
- def test_insert_rows_replace(self):
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
-
- self.db_hook.insert_rows(table, rows, fields, replace=True,
replace_index=fields[0])
-
- assert self.conn.close.call_count == 1
- assert self.cur.close.call_count == 1
-
- commit_count = 2 # The first and last commit
- assert commit_count == self.conn.commit.call_count
-
- sql = (
- f"INSERT INTO {table} ({fields[0]}, {fields[1]}) VALUES (%s,%s) "
- f"ON CONFLICT ({fields[0]}) DO UPDATE SET {fields[1]} =
excluded.{fields[1]}"
- )
- self.cur.executemany.assert_any_call(sql, rows)
-
- def test_insert_rows_replace_missing_target_field_arg(self):
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
- with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert
syntax requires column names"):
- self.db_hook.insert_rows(table, rows, replace=True,
replace_index=fields[0])
-
- def test_insert_rows_replace_missing_replace_index_arg(self):
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
- with pytest.raises(ValueError, match="PostgreSQL ON CONFLICT upsert
syntax requires an unique index"):
- self.db_hook.insert_rows(table, rows, fields, replace=True)
-
- def test_insert_rows_replace_all_index(self):
- table = "table"
- rows = [
- (
- 1,
- "hello",
- ),
- (
- 2,
- "world",
- ),
- ]
- fields = ("id", "value")
-
- self.db_hook.insert_rows(table, rows, fields, replace=True,
replace_index=fields)
-
- assert self.conn.close.call_count == 1
- assert self.cur.close.call_count == 1
-
- commit_count = 2 # The first and last commit
- assert commit_count == self.conn.commit.call_count
-
- sql = (
- f"INSERT INTO {table} ({', '.join(fields)}) VALUES (%s,%s) "
- f"ON CONFLICT ({', '.join(fields)}) DO NOTHING"
- )
- self.cur.executemany.assert_any_call(sql, rows)
-
@pytest.mark.skip(reason="Notice handling is callback-based in psycopg3
and cannot be tested this way.")
def test_get_all_db_log_messages(self, mocker):
pass
@@ -1366,9 +1207,3 @@ class TestPostgresHookPPG3:
mock_logger.info.assert_any_call("Message from db: 42")
finally:
hook.run(sql=f"DROP PROCEDURE {proc_name} (s text)")
-
- def test_dialect_name(self):
- assert self.db_hook.dialect_name == "postgresql"
-
- def test_dialect(self):
- assert isinstance(self.db_hook.dialect, PostgresDialect)