This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dc5d007fe57 feat: overwrite `get_uri` for `JDBC` (#48915)
dc5d007fe57 is described below
commit dc5d007fe579c13bd7854318bfdc0749153b455a
Author: Guan Ming(Wesley) Chiu <[email protected]>
AuthorDate: Mon Apr 14 22:33:48 2025 +0800
feat: overwrite `get_uri` for `JDBC` (#48915)
* feat: overwrite `get_uri` for `JDBC`
* fix: apply suggestions from code review
Co-authored-by: Wei Lee <[email protected]>
* fix: make string as format string
---------
Co-authored-by: Wei Lee <[email protected]>
Co-authored-by: LIU ZHE YOU <[email protected]>
---
.../jdbc/src/airflow/providers/jdbc/hooks/jdbc.py | 36 ++++++++++++
providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py | 67 ++++++++++++++++++++++
2 files changed, 103 insertions(+)
diff --git a/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
b/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
index 07b5fc42d9a..705ed847e02 100644
--- a/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
+++ b/providers/jdbc/src/airflow/providers/jdbc/hooks/jdbc.py
@@ -22,6 +22,7 @@ import warnings
from contextlib import contextmanager
from threading import RLock
from typing import TYPE_CHECKING, Any
+from urllib.parse import quote, quote_plus, urlencode
import jaydebeapi
import jpype
@@ -220,3 +221,38 @@ class JdbcHook(DbApiHook):
with suppress_and_warn(jaydebeapi.Error, jpype.JException):
return conn.jconn.getAutoCommit()
return False
+
+ def get_uri(self) -> str:
+ """Get the connection URI for the JDBC connection."""
+ conn = self.connection
+ extra = conn.extra_dejson
+
+ scheme = extra.get("sqlalchemy_scheme")
+ if not scheme:
+ return conn.host
+
+ driver = extra.get("sqlalchemy_driver")
+ uri_prefix = f"{scheme}+{driver}" if driver else scheme
+
+ auth_part = ""
+ if conn.login:
+ auth_part = quote_plus(conn.login)
+ if conn.password:
+ auth_part = f"{auth_part}:{quote_plus(conn.password)}"
+ auth_part = f"{auth_part}@"
+
+ host_part = conn.host or "localhost"
+ if conn.port:
+ host_part = f"{host_part}:{conn.port}"
+
+ schema_part = f"/{quote_plus(conn.schema)}" if conn.schema else ""
+
+ uri = f"{uri_prefix}://{auth_part}{host_part}{schema_part}"
+
+ sqlalchemy_query = extra.get("sqlalchemy_query", {})
+ if isinstance(sqlalchemy_query, dict):
+ query_string = urlencode({k: str(v) for k, v in
sqlalchemy_query.items() if v is not None})
+ if query_string:
+ uri = f"{uri}?{query_string}"
+
+ return uri
diff --git a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
index 646b3e9c09e..1e8da576004 100644
--- a/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
+++ b/providers/jdbc/tests/unit/jdbc/hooks/test_jdbc.py
@@ -309,3 +309,70 @@ class TestJdbcHook:
future.result() # This will raise OSError if get_conn
isn't threadsafe
assert mock_connect.call_count == 10
+
+ @pytest.mark.parametrize(
+ "params,expected_uri",
+ [
+ # JDBC URL fallback cases
+ pytest.param(
+ {"host": "jdbc:mysql://localhost:3306/test"},
+ "jdbc:mysql://localhost:3306/test",
+ id="jdbc-mysql",
+ ),
+ pytest.param(
+ {"host":
"jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word"},
+
"jdbc:postgresql://localhost:5432/test?user=user&password=pass%40word",
+ id="jdbc-postgresql",
+ ),
+ pytest.param(
+ {"host": "jdbc:oracle:thin:@localhost:1521:xe"},
+ "jdbc:oracle:thin:@localhost:1521:xe",
+ id="jdbc-oracle",
+ ),
+ pytest.param(
+ {"host":
"jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true"},
+
"jdbc:sqlserver://localhost:1433;databaseName=test;trustServerCertificate=true",
+ id="jdbc-sqlserver",
+ ),
+ # SQLAlchemy URI cases
+ pytest.param(
+ {
+ "conn_params": {
+ "extra": json.dumps(
+ {"sqlalchemy_scheme": "mssql", "sqlalchemy_query":
{"servicename": "test"}}
+ )
+ }
+ },
+ "mssql://login:password@host:1234/schema?servicename=test",
+ id="sqlalchemy-scheme-with-query",
+ ),
+ pytest.param(
+ {
+ "conn_params": {
+ "extra": json.dumps(
+ {"sqlalchemy_scheme": "postgresql",
"sqlalchemy_driver": "psycopg2"}
+ )
+ }
+ },
+ "postgresql+psycopg2://login:password@host:1234/schema",
+ id="sqlalchemy-scheme-with-driver",
+ ),
+ pytest.param(
+ {
+ "login": "user@domain",
+ "password": "pass/word",
+ "schema": "my/db",
+ "conn_params": {"extra": json.dumps({"sqlalchemy_scheme":
"mysql"})},
+ },
+ "mysql://user%40domain:pass%2Fword@host:1234/my%2Fdb",
+ id="sqlalchemy-with-encoding",
+ ),
+ ],
+ )
+ def test_get_uri(self, params, expected_uri):
+ """Test get_uri with different configurations including JDBC URLs and
SQLAlchemy URIs."""
+ valid_keys = {"host", "login", "password", "schema", "conn_params"}
+ hook_params = {key: params[key] for key in valid_keys & params.keys()}
+
+ jdbc_hook = get_hook(**hook_params)
+ assert jdbc_hook.get_uri() == expected_uri