This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e46453966ef Support connection extra parameters in MsSqlHook (#44310)
e46453966ef is described below

commit e46453966ef521cc8f6dc7019566ea2fdcc0063b
Author: jaejun <63435794+jx2...@users.noreply.github.com>
AuthorDate: Tue Nov 26 16:47:16 2024 +0900

    Support connection extra parameters in MsSqlHook (#44310)
    
    * enable extras
    
    * mark db_test
    
    * connections to fixture
---
 .../providers/microsoft/mssql/hooks/mssql.py       |   2 +
 .../tests/microsoft/mssql/hooks/test_mssql.py      | 308 ++++++++++-----------
 2 files changed, 149 insertions(+), 161 deletions(-)

diff --git a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
index d45a43a188c..a367250ed33 100644
--- a/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
+++ b/providers/src/airflow/providers/microsoft/mssql/hooks/mssql.py
@@ -137,12 +137,14 @@ class MsSqlHook(DbApiHook):
     def get_conn(self) -> PymssqlConnection:
         """Return ``pymssql`` connection object."""
         conn = self.connection
+        extra_conn_args = {key: val for key, val in conn.extra_dejson.items() if key != "sqlalchemy_scheme"}
         return pymssql.connect(
             server=conn.host,
             user=conn.login,
             password=conn.password,
             database=self.schema or conn.schema,
             port=str(conn.port),
+            **extra_conn_args,
         )
 
     def set_autocommit(
diff --git a/providers/tests/microsoft/mssql/hooks/test_mssql.py b/providers/tests/microsoft/mssql/hooks/test_mssql.py
index 1b43bb78783..be8f921112a 100644
--- a/providers/tests/microsoft/mssql/hooks/test_mssql.py
+++ b/providers/tests/microsoft/mssql/hooks/test_mssql.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 from unittest import mock
-from urllib.parse import quote_plus
 
 import pytest
 
@@ -31,33 +30,9 @@ try:
 except ImportError:
     pytest.skip("MSSQL not available", allow_module_level=True)
 
-PYMSSQL_CONN = Connection(
-    conn_type="mssql", host="ip", schema="share", login="username", password="password", port=8081
-)
-PYMSSQL_CONN_ALT = Connection(
-    conn_type="mssql", host="ip", schema="", login="username", password="password", port=8081
-)
-PYMSSQL_CONN_ALT_1 = Connection(
-    conn_type="mssql",
-    host="ip",
-    schema="",
-    login="username",
-    password="password",
-    port=8081,
-    extra={"SQlalchemy_Scheme": "mssql+testdriver"},
-)
-PYMSSQL_CONN_ALT_2 = Connection(
-    conn_type="mssql",
-    host="ip",
-    schema="",
-    login="username",
-    password="password",
-    port=8081,
-    extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"},
-)
-
-
-def get_primary_keys(self, table: str) -> list[str]:
+
+@pytest.fixture
+def get_primary_keys():
     return [
         "GroupDisplayName",
         "OwnerPrincipalName",
@@ -66,11 +41,49 @@ def get_primary_keys(self, table: str) -> list[str]:
     ]
 
 
+@pytest.fixture
+def mssql_connections():
+    return {
+        "default": Connection(
+            conn_type="mssql", host="ip", schema="share", login="username", password="password", port=8081
+        ),
+        "alt": Connection(
+            conn_type="mssql", host="ip", schema="", login="username", password="password", port=8081
+        ),
+        "alt_1": Connection(
+            conn_type="mssql",
+            host="ip",
+            schema="",
+            login="username",
+            password="password",
+            port=8081,
+            extra={"SQlalchemy_Scheme": "mssql+testdriver"},
+        ),
+        "alt_2": Connection(
+            conn_type="mssql",
+            host="ip",
+            schema="",
+            login="username",
+            password="password",
+            port=8081,
+            extra={"SQlalchemy_Scheme": "mssql+testdriver", "myparam": "5@-//*"},
+        ),
+    }
+
+
+URI_TEST_CASES = [
+    ("default", "mssql+pymssql://username:password@ip:8081/share"),
+    ("alt", "mssql+pymssql://username:password@ip:8081"),
+    ("alt_1", "mssql+testdriver://username:password@ip:8081/"),
+    ("alt_2", "mssql+testdriver://username:password@ip:8081/?myparam=5%40-%2F%2F%2A"),
+]
+
+
 class TestMsSqlHook:
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn")
     @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection")
-    def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn):
-        get_connection.return_value = PYMSSQL_CONN
+    def test_get_conn_should_return_connection(self, get_connection, mssql_get_conn, mssql_connections):
+        get_connection.return_value = mssql_connections["default"]
         mssql_get_conn.return_value = mock.Mock()
 
         hook = MsSqlHook()
@@ -81,8 +94,8 @@ class TestMsSqlHook:
 
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn")
     @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection")
-    def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get_conn):
-        get_connection.return_value = PYMSSQL_CONN
+    def test_set_autocommit_should_invoke_autocommit(self, get_connection, mssql_get_conn, mssql_connections):
+        get_connection.return_value = mssql_connections["default"]
         mssql_get_conn.return_value = mock.Mock()
         autocommit_value = mock.Mock()
 
@@ -95,8 +108,10 @@ class TestMsSqlHook:
 
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_conn")
     @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_connection")
-    def test_get_autocommit_should_return_autocommit_state(self, get_connection, mssql_get_conn):
-        get_connection.return_value = PYMSSQL_CONN
+    def test_get_autocommit_should_return_autocommit_state(
+        self, get_connection, mssql_get_conn, mssql_connections
+    ):
+        get_connection.return_value = mssql_connections["default"]
         mssql_get_conn.return_value = mock.Mock()
         mssql_get_conn.return_value.autocommit_state = "autocommit_state"
 
@@ -106,47 +121,10 @@ class TestMsSqlHook:
         mssql_get_conn.assert_called_once()
         assert hook.get_autocommit(conn) == "autocommit_state"
 
-    @pytest.mark.parametrize(
-        "conn, exp_uri",
-        [
-            (
-                PYMSSQL_CONN,
-                (
-                    "mssql+pymssql://"
-                    f"{quote_plus(PYMSSQL_CONN.login)}:{quote_plus(PYMSSQL_CONN.password)}"
-                    f"@{PYMSSQL_CONN.host}:{PYMSSQL_CONN.port}/{PYMSSQL_CONN.schema}"
-                ),
-            ),
-            (
-                PYMSSQL_CONN_ALT,
-                (
-                    "mssql+pymssql://"
-                    f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}"
-                    f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}"
-                ),
-            ),
-            (
-                PYMSSQL_CONN_ALT_1,
-                (
-                    f"{PYMSSQL_CONN_ALT_1.extra_dejson['SQlalchemy_Scheme']}://"
-                    f"{quote_plus(PYMSSQL_CONN_ALT.login)}:{quote_plus(PYMSSQL_CONN_ALT.password)}"
-                    f"@{PYMSSQL_CONN_ALT.host}:{PYMSSQL_CONN_ALT.port}/"
-                ),
-            ),
-            (
-                PYMSSQL_CONN_ALT_2,
-                (
-                    f"{PYMSSQL_CONN_ALT_2.extra_dejson['SQlalchemy_Scheme']}://"
-                    f"{quote_plus(PYMSSQL_CONN_ALT_2.login)}:{quote_plus(PYMSSQL_CONN_ALT_2.password)}"
-                    f"@{PYMSSQL_CONN_ALT_2.host}:{PYMSSQL_CONN_ALT_2.port}/"
-                    f"?myparam={quote_plus(PYMSSQL_CONN_ALT_2.extra_dejson['myparam'])}"
-                ),
-            ),
-        ],
-    )
+    @pytest.mark.parametrize("conn_id,exp_uri", URI_TEST_CASES)
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
-    def test_get_uri_driver_rewrite(self, get_connection, conn, exp_uri):
-        get_connection.return_value = conn
+    def test_get_uri_driver_rewrite(self, get_connection, mssql_connections, conn_id, exp_uri):
+        get_connection.return_value = mssql_connections[conn_id]
 
         hook = MsSqlHook()
         res_uri = hook.get_uri()
@@ -155,8 +133,8 @@ class TestMsSqlHook:
         assert res_uri == exp_uri
 
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
-    def test_sqlalchemy_scheme_is_default(self, get_connection):
-        get_connection.return_value = PYMSSQL_CONN
+    def test_sqlalchemy_scheme_is_default(self, get_connection, mssql_connections):
+        get_connection.return_value = mssql_connections["default"]
 
         hook = MsSqlHook()
         assert hook.sqlalchemy_scheme == hook.DEFAULT_SQLALCHEMY_SCHEME
@@ -167,101 +145,109 @@ class TestMsSqlHook:
         assert hook.sqlalchemy_scheme == "mssql+mytestdriver"
 
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
-    def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection):
-        get_connection.return_value = PYMSSQL_CONN_ALT_1
+    def test_sqlalchemy_scheme_is_from_conn_extra(self, get_connection, mssql_connections):
+        get_connection.return_value = mssql_connections["alt_1"]
 
         hook = MsSqlHook()
         scheme = hook.sqlalchemy_scheme
         get_connection.assert_called()
-        assert scheme == PYMSSQL_CONN_ALT_1.extra_dejson["SQlalchemy_Scheme"]
+        assert scheme == mssql_connections["alt_1"].extra_dejson["SQlalchemy_Scheme"]
 
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
-    def test_get_sqlalchemy_engine(self, get_connection):
-        get_connection.return_value = PYMSSQL_CONN
+    def test_get_sqlalchemy_engine(self, get_connection, mssql_connections):
+        get_connection.return_value = mssql_connections["default"]
 
         hook = MsSqlHook()
         hook.get_sqlalchemy_engine()
 
     @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
-    @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_primary_keys", get_primary_keys)
-    def test_generate_insert_sql(self, get_connection):
-        get_connection.return_value = PYMSSQL_CONN
+    def test_generate_insert_sql(self, get_connection, mssql_connections, get_primary_keys):
+        get_connection.return_value = mssql_connections["default"]
+
+        hook = MsSqlHook()
+        with mock.patch.object(hook, "get_primary_keys", return_value=get_primary_keys):
+            sql = hook._generate_insert_sql(
+                table="YAMMER_GROUPS_ACTIVITY_DETAIL",
+                values=[
+                    "2024-07-17",
+                    "daa5b44c-80d6-4e22-85b5-a94e04cf7206",
+                    "no-re...@microsoft.com",
+                    "2024-07-17",
+                    0,
+                    0.0,
+                    "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5",
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    0,
+                    "PT0S",
+                    "PT0S",
+                    "PT0S",
+                    0,
+                    0,
+                    0,
+                    "Yes",
+                    0,
+                    0,
+                    "APACHE",
+                    0.0,
+                    0,
+                    "Yes",
+                    1,
+                    "2024-07-17T00:00:00+00:00",
+                ],
+                target_fields=[
+                    "ReportRefreshDate",
+                    "UserId",
+                    "UserPrincipalName",
+                    "LastActivityDate",
+                    "IsDeleted",
+                    "DeletedDate",
+                    "AssignedProducts",
+                    "TeamChatMessageCount",
+                    "PrivateChatMessageCount",
+                    "CallCount",
+                    "MeetingCount",
+                    "MeetingsOrganizedCount",
+                    "MeetingsAttendedCount",
+                    "AdHocMeetingsOrganizedCount",
+                    "AdHocMeetingsAttendedCount",
+                    "ScheduledOne-timeMeetingsOrganizedCount",
+                    "ScheduledOne-timeMeetingsAttendedCount",
+                    "ScheduledRecurringMeetingsOrganizedCount",
+                    "ScheduledRecurringMeetingsAttendedCount",
+                    "AudioDuration",
+                    "VideoDuration",
+                    "ScreenShareDuration",
+                    "AudioDurationInSeconds",
+                    "VideoDurationInSeconds",
+                    "ScreenShareDurationInSeconds",
+                    "HasOtherAction",
+                    "UrgentMessages",
+                    "PostMessages",
+                    "TenantDisplayName",
+                    "SharedChannelTenantDisplayNames",
+                    "ReplyMessages",
+                    "IsLicensed",
+                    "ReportPeriod",
+                    "LoadDate",
+                ],
+                replace=True,
+            )
+            assert sql == load_file("resources", "replace.sql")
+
+    @pytest.mark.db_test
+    @mock.patch("airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook.get_connection")
+    def test_get_extra(self, get_connection, mssql_connections):
+        get_connection.return_value = mssql_connections["alt_2"]
 
         hook = MsSqlHook()
-        sql = hook._generate_insert_sql(
-            table="YAMMER_GROUPS_ACTIVITY_DETAIL",
-            values=[
-                "2024-07-17",
-                "daa5b44c-80d6-4e22-85b5-a94e04cf7206",
-                "no-re...@microsoft.com",
-                "2024-07-17",
-                0,
-                0.0,
-                "MICROSOFT FABRIC (FREE)+MICROSOFT 365 E5",
-                0,
-                0,
-                0,
-                0,
-                0,
-                0,
-                0,
-                0,
-                0,
-                0,
-                0,
-                0,
-                "PT0S",
-                "PT0S",
-                "PT0S",
-                0,
-                0,
-                0,
-                "Yes",
-                0,
-                0,
-                "APACHE",
-                0.0,
-                0,
-                "Yes",
-                1,
-                "2024-07-17T00:00:00+00:00",
-            ],
-            target_fields=[
-                "ReportRefreshDate",
-                "UserId",
-                "UserPrincipalName",
-                "LastActivityDate",
-                "IsDeleted",
-                "DeletedDate",
-                "AssignedProducts",
-                "TeamChatMessageCount",
-                "PrivateChatMessageCount",
-                "CallCount",
-                "MeetingCount",
-                "MeetingsOrganizedCount",
-                "MeetingsAttendedCount",
-                "AdHocMeetingsOrganizedCount",
-                "AdHocMeetingsAttendedCount",
-                "ScheduledOne-timeMeetingsOrganizedCount",
-                "ScheduledOne-timeMeetingsAttendedCount",
-                "ScheduledRecurringMeetingsOrganizedCount",
-                "ScheduledRecurringMeetingsAttendedCount",
-                "AudioDuration",
-                "VideoDuration",
-                "ScreenShareDuration",
-                "AudioDurationInSeconds",
-                "VideoDurationInSeconds",
-                "ScreenShareDurationInSeconds",
-                "HasOtherAction",
-                "UrgentMessages",
-                "PostMessages",
-                "TenantDisplayName",
-                "SharedChannelTenantDisplayNames",
-                "ReplyMessages",
-                "IsLicensed",
-                "ReportPeriod",
-                "LoadDate",
-            ],
-            replace=True,
-        )
-        assert sql == load_file("resources", "replace.sql")
+        assert hook.get_connection().extra

Reply via email to