This is an automated email from the ASF dual-hosted git repository. weilee pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 39a373cc841 fix: overwrite `get-uri` for `Trino` (#48917) 39a373cc841 is described below commit 39a373cc841eecf1ea9b4885539394d93ca6eb60 Author: Guan Ming(Wesley) Chiu <105915352+guan404m...@users.noreply.github.com> AuthorDate: Wed Apr 16 14:52:57 2025 +0800 fix: overwrite `get-uri` for `Trino` (#48917) --- .../src/airflow/providers/trino/hooks/trino.py | 36 +++++++++++++++ .../trino/tests/unit/trino/hooks/test_trino.py | 52 ++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/providers/trino/src/airflow/providers/trino/hooks/trino.py b/providers/trino/src/airflow/providers/trino/hooks/trino.py index f907e95f4ee..39be65dd7a2 100644 --- a/providers/trino/src/airflow/providers/trino/hooks/trino.py +++ b/providers/trino/src/airflow/providers/trino/hooks/trino.py @@ -22,6 +22,7 @@ import os from collections.abc import Iterable, Mapping from pathlib import Path from typing import TYPE_CHECKING, Any, TypeVar +from urllib.parse import quote_plus, urlencode import trino from trino.exceptions import DatabaseError @@ -322,3 +323,38 @@ class TrinoHook(DbApiHook): def get_openlineage_default_schema(self): """Return Trino default schema.""" return trino.constants.DEFAULT_SCHEMA + + def get_uri(self) -> str: + """Return the Trino URI for the connection.""" + conn = self.connection + uri = "trino://" + + auth_part = "" + if conn.login: + auth_part = quote_plus(conn.login) + if conn.password: + auth_part = f"{auth_part}:{quote_plus(conn.password)}" + auth_part = f"{auth_part}@" + + host_part = conn.host or "localhost" + if conn.port: + host_part = f"{host_part}:{conn.port}" + + schema_part = "" + if conn.schema: + schema_part = f"/{quote_plus(conn.schema)}" + extra_schema = conn.extra_dejson.get("schema") + if extra_schema: + schema_part = f"{schema_part}/{quote_plus(extra_schema)}" + + uri = f"{uri}{auth_part}{host_part}{schema_part}" + + extra = conn.extra_dejson.copy() + if "schema" in 
extra: + extra.pop("schema") + + query_params = {k: str(v) for k, v in extra.items() if v is not None} + if query_params: + uri = f"{uri}?{urlencode(query_params)}" + + return uri diff --git a/providers/trino/tests/unit/trino/hooks/test_trino.py b/providers/trino/tests/unit/trino/hooks/test_trino.py index b0608023118..7b6dc819320 100644 --- a/providers/trino/tests/unit/trino/hooks/test_trino.py +++ b/providers/trino/tests/unit/trino/hooks/test_trino.py @@ -444,3 +444,55 @@ def test_execute_openlineage_events(): }, ) ] + + +@pytest.mark.parametrize( + "conn_params, expected_uri", + [ + ( + {"login": "user", "password": "pass", "host": "localhost", "port": 8080, "schema": "hive"}, + "trino://user:pass@localhost:8080/hive", + ), + ( + { + "login": "user", + "password": "pass", + "host": "localhost", + "port": 8080, + "schema": "hive", + "extra": json.dumps({"schema": "sales"}), + }, + "trino://user:pass@localhost:8080/hive/sales", + ), + ( + {"login": "user@example.com", "password": "p@ss:word", "host": "localhost", "schema": "hive"}, + "trino://user%40example.com:p%40ss%3Aword@localhost/hive", + ), + ( + {"host": "localhost", "port": 8080, "schema": "hive"}, + "trino://localhost:8080/hive", + ), + ( + { + "login": "user", + "host": "host.example.com", + "schema": "hive", + "extra": json.dumps({"param1": "value1", "param2": "value2"}), + }, + "trino://user@host.example.com/hive?param1=value1&param2=value2", + ), + ], + ids=[ + "basic-connection", + "with-extra-schema", + "special-chars", + "no-credentials", + "extra-params", + ], +) +def test_get_uri(conn_params, expected_uri): + """Test TrinoHook.get_uri properly formats connection URIs.""" + with patch(HOOK_GET_CONNECTION) as mock_get_connection: + mock_get_connection.return_value = Connection(**conn_params) + hook = TrinoHook() + assert hook.get_uri() == expected_uri