This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 39a373cc841 fix: overwrite `get-uri` for `Trino` (#48917)
39a373cc841 is described below

commit 39a373cc841eecf1ea9b4885539394d93ca6eb60
Author: Guan Ming(Wesley) Chiu <105915352+guan404m...@users.noreply.github.com>
AuthorDate: Wed Apr 16 14:52:57 2025 +0800

    fix: overwrite `get-uri` for `Trino` (#48917)
---
 .../src/airflow/providers/trino/hooks/trino.py     | 36 +++++++++++++++
 .../trino/tests/unit/trino/hooks/test_trino.py     | 52 ++++++++++++++++++++++
 2 files changed, 88 insertions(+)

diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py
index f907e95f4ee..39be65dd7a2 100644
--- a/providers/trino/src/airflow/providers/trino/hooks/trino.py
+++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py
@@ -22,6 +22,7 @@ import os
 from collections.abc import Iterable, Mapping
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, TypeVar
+from urllib.parse import quote_plus, urlencode
 
 import trino
 from trino.exceptions import DatabaseError
@@ -322,3 +323,55 @@ class TrinoHook(DbApiHook):
     def get_openlineage_default_schema(self):
         """Return Trino default schema."""
         return trino.constants.DEFAULT_SCHEMA
+
+    def get_uri(self) -> str:
+        """
+        Return the Trino URI for the connection.
+
+        The URI has the form
+        ``trino://[login[:password]@]host[:port][/schema[/extra_schema]][?params]``,
+        where ``host`` defaults to ``localhost``. Login, password and path
+        segments are escaped with ``quote_plus``. Every ``extra`` key except
+        ``schema`` becomes a query parameter: ``None`` values are skipped,
+        ``dict``/``list`` values are JSON-encoded (``str()`` would emit a Python
+        repr that JSON parsers reject) and all other values go through ``str()``.
+
+        :return: the connection URI.
+        """
+        import json
+
+        conn = self.connection
+
+        auth_part = ""
+        if conn.login:
+            auth_part = quote_plus(conn.login)
+            # A password is only rendered when a login is present.
+            if conn.password:
+                auth_part = f"{auth_part}:{quote_plus(conn.password)}"
+            auth_part = f"{auth_part}@"
+
+        host_part = conn.host or "localhost"
+        if conn.port:
+            host_part = f"{host_part}:{conn.port}"
+
+        extra = conn.extra_dejson.copy()
+        # ``schema`` from extra is rendered as a path segment, never as a query parameter.
+        extra_schema = extra.pop("schema", None)
+
+        schema_part = ""
+        if conn.schema:
+            schema_part = f"/{quote_plus(conn.schema)}"
+            # Without a connection schema the extra schema is dropped (it was popped above).
+            if extra_schema:
+                schema_part = f"{schema_part}/{quote_plus(extra_schema)}"
+
+        uri = f"trino://{auth_part}{host_part}{schema_part}"
+
+        query_params = {
+            key: json.dumps(value) if isinstance(value, (dict, list)) else str(value)
+            for key, value in extra.items()
+            if value is not None
+        }
+        if query_params:
+            uri = f"{uri}?{urlencode(query_params)}"
+        return uri
diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py
index b0608023118..7b6dc819320 100644
--- a/providers/trino/tests/unit/trino/hooks/test_trino.py
+++ b/providers/trino/tests/unit/trino/hooks/test_trino.py
@@ -444,3 +444,55 @@ def test_execute_openlineage_events():
             },
         )
     ]
+
+
+@pytest.mark.parametrize(
+    "conn_params, expected_uri",
+    [
+        (
+            {"login": "user", "password": "pass", "host": "localhost", "port": 8080, "schema": "hive"},
+            "trino://user:pass@localhost:8080/hive",
+        ),
+        (
+            {
+                "login": "user",
+                "password": "pass",
+                "host": "localhost",
+                "port": 8080,
+                "schema": "hive",
+                "extra": json.dumps({"schema": "sales"}),
+            },
+            "trino://user:pass@localhost:8080/hive/sales",
+        ),
+        (
+            {"login": "user@example.com", "password": "p@ss:word", "host": "localhost", "schema": "hive"},
+            "trino://user%40example.com:p%40ss%3Aword@localhost/hive",
+        ),
+        (
+            {"host": "localhost", "port": 8080, "schema": "hive"},
+            "trino://localhost:8080/hive",
+        ),
+        (
+            {
+                "login": "user",
+                "host": "host.example.com",
+                "schema": "hive",
+                "extra": json.dumps({"param1": "value1", "param2": "value2"}),
+            },
+            "trino://user@host.example.com/hive?param1=value1&param2=value2",
+        ),
+    ],
+    ids=[
+        "basic-connection",
+        "with-extra-schema",
+        "special-chars",
+        "no-credentials",
+        "extra-params",
+    ],
+)
+def test_get_uri(conn_params, expected_uri):
+    """Check that TrinoHook.get_uri renders credentials, host, schemas and extras into a URI."""
+    with patch(HOOK_GET_CONNECTION) as mock_get_connection:
+        mock_get_connection.return_value = Connection(**conn_params)
+        hook = TrinoHook()
+        assert hook.get_uri() == expected_uri

Reply via email to