This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 46e50595019 Fix #47919: get_uri() is not implemented in task sdk (#48245)
46e50595019 is described below
commit 46e505950192fc8b6d35216ad3727e0a5d6021aa
Author: Johnny1cyber <[email protected]>
AuthorDate: Tue Apr 1 17:00:47 2025 +0100
Fix #47919: get_uri() is not implemented in task sdk (#48245)
Co-authored-by: Tzu-ping Chung <[email protected]>
---
task-sdk/src/airflow/sdk/definitions/connection.py | 56 +++++++++++++++++++++-
.../tests/task_sdk/definitions/test_connections.py | 26 ++++++++++
2 files changed, 81 insertions(+), 1 deletion(-)
diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py
index 5a447895679..c66b264dce0 100644
--- a/task-sdk/src/airflow/sdk/definitions/connection.py
+++ b/task-sdk/src/airflow/sdk/definitions/connection.py
@@ -56,7 +56,61 @@ class Connection:
port: int | None = None
extra: str | None = None
- def get_uri(self): ...
+ EXTRA_KEY = "__extra__"
+
+ def get_uri(self) -> str:
+ """Generate and return connection in URI format."""
+ from urllib.parse import parse_qsl, quote, urlencode
+
+ if self.conn_type and "_" in self.conn_type:
+ log.warning(
+ "Connection schemes (type: %s) shall not contain '_' according
to RFC3986.",
+ self.conn_type,
+ )
+ if self.conn_type:
+ uri = f"{self.conn_type.lower().replace('_', '-')}://"
+ else:
+ uri = "//"
+
+ if self.host and "://" in self.host:
+ protocol, host = self.host.split("://", 1)
+ else:
+ protocol, host = None, self.host or ""
+ if protocol:
+ uri += f"{protocol}://"
+
+ authority_block = ""
+ if self.login is not None:
+ authority_block += quote(self.login, safe="")
+ if self.password is not None:
+ authority_block += ":" + quote(self.password, safe="")
+ if authority_block > "":
+ authority_block += "@"
+ uri += authority_block
+
+ host_block = ""
+ if host != "":
+ host_block += quote(host, safe="")
+ if self.port:
+ if host_block == "" and authority_block == "":
+ host_block += f"@:{self.port}"
+ else:
+ host_block += f":{self.port}"
+ if self.schema:
+ host_block += f"/{quote(self.schema, safe='')}"
+ uri += host_block
+
+ if self.extra:
+ try:
+ query: str | None = urlencode(self.extra_dejson)
+ except TypeError:
+ query = None
+ if query and self.extra_dejson == dict(parse_qsl(query,
keep_blank_values=True)):
+ uri += ("?" if self.schema else "/?") + query
+ else:
+ uri += ("?" if self.schema else "/?") +
urlencode({self.EXTRA_KEY: self.extra})
+
+ return uri
def get_hook(self, *, hook_params=None):
"""Return hook based on conn_type."""
diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connections.py
index b92090ff81c..102e85d36b3 100644
--- a/task-sdk/tests/task_sdk/definitions/test_connections.py
+++ b/task-sdk/tests/task_sdk/definitions/test_connections.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from unittest import mock
+from urllib.parse import urlparse
import pytest
@@ -76,6 +77,31 @@ class TestConnections:
with pytest.raises(AirflowException, match='Unknown hook type
"unknown_type"'):
conn.get_hook()
+ def test_get_uri(self):
+ """Test that get_uri generates the correct URI based on connection
attributes."""
+
+ conn = Connection(
+ conn_id="test_conn",
+ conn_type="mysql",
+ host="localhost",
+ login="user",
+ password="password",
+ schema="test_schema",
+ port=3306,
+ extra=None,
+ )
+
+ uri = conn.get_uri()
+ parsed_uri = urlparse(uri)
+
+ assert uri == "mysql://user:password@localhost:3306/test_schema"
+ assert parsed_uri.scheme == "mysql"
+ assert parsed_uri.hostname == "localhost"
+ assert parsed_uri.username == "user"
+ assert parsed_uri.password == "password"
+ assert parsed_uri.port == 3306
+ assert parsed_uri.path.lstrip("/") == "test_schema"
+
def test_conn_get(self, mock_supervisor_comms):
conn_result = ConnectionResult(conn_id="mysql_conn",
conn_type="mysql", host="mysql", port=3306)
mock_supervisor_comms.get_message.return_value = conn_result