This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a10b3fccb09 Allow configuration of sqlalchemy query parameter for JdbcHook and PostgresHook through extras (#44910)
a10b3fccb09 is described below

commit a10b3fccb09805397e607df4cd3ded6194d20170
Author: David Blain <[email protected]>
AuthorDate: Wed Dec 18 16:39:46 2024 +0100

    Allow configuration of sqlalchemy query parameter for JdbcHook and PostgresHook through extras (#44910)
---
 providers/src/airflow/providers/jdbc/hooks/jdbc.py |  4 +++
 .../airflow/providers/postgres/hooks/postgres.py   | 28 +++++++++++------
 providers/tests/jdbc/hooks/test_jdbc.py            | 17 ++++++++++
 providers/tests/postgres/hooks/test_postgres.py    | 36 +++++++++++++++++++++-
 4 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/providers/src/airflow/providers/jdbc/hooks/jdbc.py b/providers/src/airflow/providers/jdbc/hooks/jdbc.py
index 808b946bd97..07b5fc42d9a 100644
--- a/providers/src/airflow/providers/jdbc/hooks/jdbc.py
+++ b/providers/src/airflow/providers/jdbc/hooks/jdbc.py
@@ -152,6 +152,9 @@ class JdbcHook(DbApiHook):
     @property
     def sqlalchemy_url(self) -> URL:
         conn = self.connection
+        sqlalchemy_query = conn.extra_dejson.get("sqlalchemy_query", {})
+        if not isinstance(sqlalchemy_query, dict):
+            raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!")
         sqlalchemy_scheme = conn.extra_dejson.get("sqlalchemy_scheme")
         if sqlalchemy_scheme is None:
             raise AirflowException(
@@ -164,6 +167,7 @@ class JdbcHook(DbApiHook):
             host=conn.host,
             port=conn.port,
             database=conn.schema,
+            query=sqlalchemy_query,
         )
 
     def get_sqlalchemy_engine(self, engine_kwargs=None):
diff --git a/providers/src/airflow/providers/postgres/hooks/postgres.py b/providers/src/airflow/providers/postgres/hooks/postgres.py
index f5dcfe2df49..9b657c14416 100644
--- a/providers/src/airflow/providers/postgres/hooks/postgres.py
+++ b/providers/src/airflow/providers/postgres/hooks/postgres.py
@@ -29,6 +29,7 @@ import psycopg2.extras
 from psycopg2.extras import DictCursor, NamedTupleCursor, RealDictCursor
 from sqlalchemy.engine import URL
 
+from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 if TYPE_CHECKING:
@@ -85,6 +86,17 @@ class PostgresHook(DbApiHook):
     hook_name = "Postgres"
     supports_autocommit = True
     supports_executemany = True
+    ignored_extra_options = {
+        "iam",
+        "redshift",
+        "redshift-serverless",
+        "cursor",
+        "cluster-identifier",
+        "workgroup-name",
+        "aws_conn_id",
+        "sqlalchemy_scheme",
+        "sqlalchemy_query",
+    }
 
     def __init__(
         self, *args, options: str | None = None, enable_log_db_messages: bool = False, **kwargs
@@ -97,7 +109,10 @@ class PostgresHook(DbApiHook):
 
     @property
     def sqlalchemy_url(self) -> URL:
-        conn = self.get_connection(self.get_conn_id())
+        conn = self.connection
+        query = conn.extra_dejson.get("sqlalchemy_query", {})
+        if not isinstance(query, dict):
+            raise AirflowException("The parameter 'sqlalchemy_query' must be of type dict!")
         return URL.create(
             drivername="postgresql",
             username=conn.login,
@@ -105,6 +120,7 @@ class PostgresHook(DbApiHook):
             host=conn.host,
             port=conn.port,
             database=self.database or conn.schema,
+            query=query,
         )
 
     def _get_cursor(self, raw_cursor: str) -> CursorType:
@@ -143,15 +159,7 @@ class PostgresHook(DbApiHook):
             conn_args["options"] = self.options
 
         for arg_name, arg_val in conn.extra_dejson.items():
-            if arg_name not in [
-                "iam",
-                "redshift",
-                "redshift-serverless",
-                "cursor",
-                "cluster-identifier",
-                "workgroup-name",
-                "aws_conn_id",
-            ]:
+            if arg_name not in self.ignored_extra_options:
                 conn_args[arg_name] = arg_val
 
         self.conn = psycopg2.connect(**conn_args)
diff --git a/providers/tests/jdbc/hooks/test_jdbc.py b/providers/tests/jdbc/hooks/test_jdbc.py
index 73015b5b522..ce4e5266234 100644
--- a/providers/tests/jdbc/hooks/test_jdbc.py
+++ b/providers/tests/jdbc/hooks/test_jdbc.py
@@ -219,6 +219,23 @@ class TestJdbcHook:
 
         assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema"
 
+    def test_sqlalchemy_url_with_sqlalchemy_scheme_and_query(self):
+        conn_params = dict(
+            extra=json.dumps(dict(sqlalchemy_scheme="mssql", sqlalchemy_query={"servicename": "test"}))
+        )
+        hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"}
+        hook = get_hook(conn_params=conn_params, hook_params=hook_params)
+
+        assert str(hook.sqlalchemy_url) == "mssql://login:password@host:1234/schema?servicename=test"
+
+    def test_sqlalchemy_url_with_sqlalchemy_scheme_and_wrong_query_value(self):
+        conn_params = dict(extra=json.dumps(dict(sqlalchemy_scheme="mssql", sqlalchemy_query="wrong type")))
+        hook_params = {"driver_path": "ParamDriverPath", "driver_class": "ParamDriverClass"}
+        hook = get_hook(conn_params=conn_params, hook_params=hook_params)
+
+        with pytest.raises(AirflowException):
+            hook.sqlalchemy_url
+
     def test_get_sqlalchemy_engine_verify_creator_is_being_used(self):
         jdbc_hook = get_hook(
             conn_params=dict(extra={"sqlalchemy_scheme": "sqlite"}),
diff --git a/providers/tests/postgres/hooks/test_postgres.py b/providers/tests/postgres/hooks/test_postgres.py
index 7a720534d4b..76206d57958 100644
--- a/providers/tests/postgres/hooks/test_postgres.py
+++ b/providers/tests/postgres/hooks/test_postgres.py
@@ -25,6 +25,7 @@ from unittest import mock
 import psycopg2.extras
 import pytest
 
+from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.utils.types import NOTSET
@@ -65,9 +66,42 @@ class TestPostgresHookConn:
         assert mock_connect.call_count == 1
         assert self.db_hook.get_uri() == "postgresql://login:password@host:5432/database"
 
+    def test_sqlalchemy_url(self):
+        conn = Connection(login="login-conn", password="password-conn", host="host", schema="database")
+        hook = PostgresHook(connection=conn)
+        assert str(hook.sqlalchemy_url) == "postgresql://login-conn:password-conn@host/database"
+
+    def test_sqlalchemy_url_with_sqlalchemy_query(self):
+        conn = Connection(
+            login="login-conn",
+            password="password-conn",
+            host="host",
+            schema="database",
+            extra=dict(sqlalchemy_query={"gssencmode": "disable"}),
+        )
+        hook = PostgresHook(connection=conn)
+
+        assert (
+            str(hook.sqlalchemy_url)
+            == "postgresql://login-conn:password-conn@host/database?gssencmode=disable"
+        )
+
+    def test_sqlalchemy_url_with_wrong_sqlalchemy_query_value(self):
+        conn = Connection(
+            login="login-conn",
+            password="password-conn",
+            host="host",
+            schema="database",
+            extra=dict(sqlalchemy_query="wrong type"),
+        )
+        hook = PostgresHook(connection=conn)
+
+        with pytest.raises(AirflowException):
+            hook.sqlalchemy_url
+
     @mock.patch("airflow.providers.postgres.hooks.postgres.psycopg2.connect")
     def test_get_conn_cursor(self, mock_connect):
-        self.connection.extra = '{"cursor": "dictcursor"}'
+        self.connection.extra = '{"cursor": "dictcursor", "sqlalchemy_query": {"gssencmode": "disable"}}'
         self.db_hook.get_conn()
         mock_connect.assert_called_once_with(
             cursor_factory=psycopg2.extras.DictCursor,

Reply via email to