This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dc5d007fe57 feat: overwrite `get_uri` for `JDBC` (#48915)
dc5d007fe57 is described below

commit dc5d007fe579c13bd7854318bfdc0749153b455a
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Mon Apr 14 22:33:48 2025 +0800

    feat: overwrite `get_uri` for `JDBC` (#48915)
    
    * feat: overwrite `get_uri` for `JDBC`
    
    * fix: apply suggestions from code review
    
    Co-authored-by: Wei Lee <[email protected]>
    
    * fix: make string as format string
    
    ---------
    
    Co-authored-by: Wei Lee <[email protected]>
    Co-authored-by: LIU ZHE YOU <[email protected]>
---
 .../jdbc/src/airflow/providers/jdbc/hooks/jdbc.py  | 36 ++++++++++++
 providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py  | 67 ++++++++++++++++++++++
 2 files changed, 103 insertions(+)

diff --git a/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py b/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
index 07b5fc42d9a..705ed847e02 100644
--- a/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
+++ b/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
@@ -22,6 +22,7 @@ import warnings
 from contextlib import contextmanager
 from threading import RLock
 from typing import TYPE_CHECKING, Any
+from urllib.parse import quote, quote_plus, urlencode
 
 import jaydebeapi
 import jpype
@@ -220,3 +221,38 @@ class JdbcHook(DbApiHook):
         with suppress_and_warn(jaydebeapi.Error, jpype.JException):
             return conn.jconn.getAutoCommit()
         return False
+
+    def get_uri(self) -> str:
+        """Return a SQLAlchemy URI built from the ``sqlalchemy_*`` extras, or ``host`` if no scheme is set."""
+        conn = self.connection
+        extra = conn.extra_dejson
+        # No sqlalchemy_scheme: return host unchanged (the JDBC fallback, e.g. "jdbc:mysql://...").
+        scheme = extra.get("sqlalchemy_scheme")
+        if not scheme:
+            return conn.host
+
+        driver = extra.get("sqlalchemy_driver")
+        uri_prefix = f"{scheme}+{driver}" if driver else scheme
+        # Not quote_plus: SQLAlchemy decodes credentials with unquote(), so spaces must be %20, not "+".
+        auth_part = ""
+        if conn.login:
+            auth_part = quote(conn.login, safe="")
+            if conn.password:
+                auth_part = f"{auth_part}:{quote(conn.password, safe='')}"
+            auth_part = f"{auth_part}@"
+        # Fall back to localhost when the connection has no host.
+        host_part = conn.host or "localhost"
+        if conn.port:
+            host_part = f"{host_part}:{conn.port}"
+
+        schema_part = f"/{quote_plus(conn.schema)}" if conn.schema else ""
+
+        uri = f"{uri_prefix}://{auth_part}{host_part}{schema_part}"
+        # Optional query options from extras; None values are dropped and the rest stringified.
+        sqlalchemy_query = extra.get("sqlalchemy_query", {})
+        if isinstance(sqlalchemy_query, dict):
+            query_string = urlencode({k: str(v) for k, v in sqlalchemy_query.items() if v is not None})
+            if query_string:
+                uri = f"{uri}?{query_string}"
+
+        return uri
diff --git a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
index 646b3e9c09e..1e8da576004 100644
--- a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
+++ b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
@@ -309,3 +309,70 @@ class TestJdbcHook:
                     future.result()  # This will raise OSError if get_conn 
isn't threadsafe
 
             assert mock_connect.call_count == 10
+
+    @pytest.mark.parametrize(
+        "params,expected_uri",
+        [
+            # JDBC URL fallback cases
+            pytest.param(
+                {"host": "jdbc:mysql://localhost:3306/test"},
+                "jdbc:mysql://localhost:3306/test",
+                id="jdbc-mysql",
+            ),
+            pytest.param(
+                {"host": "jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word"},
+                "jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word",
+                id="jdbc-postgresql",
+            ),
+            pytest.param(
+                {"host": "jdbc:oracle:thin:@localhost:1521:xe"},
+                "jdbc:oracle:thin:@localhost:1521:xe",
+                id="jdbc-oracle",
+            ),
+            pytest.param(
+                {"host": "jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true"},
+                "jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true",
+                id="jdbc-sqlserver",
+            ),
+            # SQLAlchemy URI cases
+            pytest.param(
+                {
+                    "conn_params": {
+                        "extra": json.dumps(
+                            {"sqlalchemy_scheme": "mssql", "sqlalchemy_query": {"servicename": "test"}}
+                        )
+                    }
+                },
+                "mssql://login:password@host:1234/schema?servicename=test",
+                id="sqlalchemy-scheme-with-query",
+            ),
+            pytest.param(
+                {
+                    "conn_params": {
+                        "extra": json.dumps(
+                            {"sqlalchemy_scheme": "postgresql", "sqlalchemy_driver": "psycopg2"}
+                        )
+                    }
+                },
+                "postgresql+psycopg2://login:password@host:1234/schema",
+                id="sqlalchemy-scheme-with-driver",
+            ),
+            pytest.param(
+                {
+                    "login": "user@domain",
+                    "password": "pass/word",
+                    "schema": "my/db",
+                    "conn_params": {"extra": json.dumps({"sqlalchemy_scheme": "mysql"})},
+                },
+                "mysql://user%40domain:pass%2Fword@host:1234/my%2Fdb",
+                id="sqlalchemy-with-encoding",
+            ),
+        ],
+    )
+    def test_get_uri(self, params, expected_uri):
+        """Test get_uri: ``host`` verbatim without ``sqlalchemy_scheme``, else an encoded SQLAlchemy URI."""
+        valid_keys = {"host", "login", "password", "schema", "conn_params"}
+        hook_params = {key: params[key] for key in valid_keys & params.keys()}
+        # Omitted fields use get_hook's defaults (login:password@host:1234/schema, per the expected URIs).
+        jdbc_hook = get_hook(**hook_params)
+        assert jdbc_hook.get_uri() == expected_uri

Reply via email to