This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 46e50595019 Fix #47919: get_uri() is not implemented in task sdk (#48245)
46e50595019 is described below

commit 46e505950192fc8b6d35216ad3727e0a5d6021aa
Author: Johnny1cyber <[email protected]>
AuthorDate: Tue Apr 1 17:00:47 2025 +0100

    Fix #47919: get_uri() is not implemented in task sdk (#48245)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 task-sdk/src/airflow/sdk/definitions/connection.py | 56 +++++++++++++++++++++-
 .../tests/task_sdk/definitions/test_connections.py | 26 ++++++++++
 2 files changed, 81 insertions(+), 1 deletion(-)

diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py
index 5a447895679..c66b264dce0 100644
--- a/task-sdk/src/airflow/sdk/definitions/connection.py
+++ b/task-sdk/src/airflow/sdk/definitions/connection.py
@@ -56,7 +56,61 @@ class Connection:
     port: int | None = None
     extra: str | None = None
 
-    def get_uri(self): ...
+    EXTRA_KEY = "__extra__"
+
+    def get_uri(self) -> str:
+        """Generate and return connection in URI format."""
+        from urllib.parse import parse_qsl, quote, urlencode
+
+        if self.conn_type and "_" in self.conn_type:
+            log.warning(
+                "Connection schemes (type: %s) shall not contain '_' according to RFC3986.",
+                self.conn_type,
+            )
+        if self.conn_type:
+            uri = f"{self.conn_type.lower().replace('_', '-')}://"
+        else:
+            uri = "//"
+
+        if self.host and "://" in self.host:
+            protocol, host = self.host.split("://", 1)
+        else:
+            protocol, host = None, self.host or ""
+        if protocol:
+            uri += f"{protocol}://"
+
+        authority_block = ""
+        if self.login is not None:
+            authority_block += quote(self.login, safe="")
+        if self.password is not None:
+            authority_block += ":" + quote(self.password, safe="")
+        if authority_block > "":
+            authority_block += "@"
+            uri += authority_block
+
+        host_block = ""
+        if host != "":
+            host_block += quote(host, safe="")
+        if self.port:
+            if host_block == "" and authority_block == "":
+                host_block += f"@:{self.port}"
+            else:
+                host_block += f":{self.port}"
+        if self.schema:
+            host_block += f"/{quote(self.schema, safe='')}"
+        uri += host_block
+
+        if self.extra:
+            try:
+                query: str | None = urlencode(self.extra_dejson)
+            except TypeError:
+                query = None
+            if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)):
+                uri += ("?" if self.schema else "/?") + query
+            else:
+                uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra})
+
+        return uri
 
     def get_hook(self, *, hook_params=None):
         """Return hook based on conn_type."""
diff --git a/task-sdk/tests/task_sdk/definitions/test_connections.py b/task-sdk/tests/task_sdk/definitions/test_connections.py
index b92090ff81c..102e85d36b3 100644
--- a/task-sdk/tests/task_sdk/definitions/test_connections.py
+++ b/task-sdk/tests/task_sdk/definitions/test_connections.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
+from urllib.parse import urlparse
 
 import pytest
 
@@ -76,6 +77,31 @@ class TestConnections:
         with pytest.raises(AirflowException, match='Unknown hook type "unknown_type"'):
             conn.get_hook()
 
+    def test_get_uri(self):
+        """Test that get_uri generates the correct URI based on connection attributes."""
+
+        conn = Connection(
+            conn_id="test_conn",
+            conn_type="mysql",
+            host="localhost",
+            login="user",
+            password="password",
+            schema="test_schema",
+            port=3306,
+            extra=None,
+        )
+
+        uri = conn.get_uri()
+        parsed_uri = urlparse(uri)
+
+        assert uri == "mysql://user:password@localhost:3306/test_schema"
+        assert parsed_uri.scheme == "mysql"
+        assert parsed_uri.hostname == "localhost"
+        assert parsed_uri.username == "user"
+        assert parsed_uri.password == "password"
+        assert parsed_uri.port == 3306
+        assert parsed_uri.path.lstrip("/") == "test_schema"
+
     def test_conn_get(self, mock_supervisor_comms):
         conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
         mock_supervisor_comms.get_message.return_value = conn_result

Reply via email to