This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit d3366fcdde01b7702d534ccaca44035d39a53518
Author: Maksim <[email protected]>
AuthorDate: Mon Jun 20 00:25:59 2022 +0300

    Fix bugs in URI constructor for MySQL connection (#24320)
    
    * Fix bugs in URI constructor for MySQL connection
    
    * Update unit tests
    
    (cherry picked from commit ea54faf290cc15a4ace50c23fe9ab2fa9593059f)
---
 airflow/models/connection.py                  |  10 +--
 tests/cli/commands/test_connection_command.py |   6 +-
 tests/hooks/test_dbapi.py                     | 106 ++++++++++++++++++++++++++
 tests/secrets/test_local_filesystem.py        |   4 +-
 4 files changed, 116 insertions(+), 10 deletions(-)

diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index 8134f372ca..4935023292 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -231,10 +231,10 @@ class Connection(Base, LoggingMixin):
             host_block += quote(self.host, safe='')
 
         if self.port:
-            if host_block > '':
-                host_block += f':{self.port}'
-            else:
+            if host_block == '' and authority_block == '':
                 host_block += f'@:{self.port}'
+            else:
+                host_block += f':{self.port}'
 
         if self.schema:
             host_block += f"/{quote(self.schema, safe='')}"
@@ -247,9 +247,9 @@ class Connection(Base, LoggingMixin):
             except TypeError:
                 query = None
             if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)):
-                uri += '?' + query
+                uri += ('?' if self.schema else '/?') + query
             else:
-                uri += '?' + urlencode({self.EXTRA_KEY: self.extra})
+                uri += ('?' if self.schema else '/?') + urlencode({self.EXTRA_KEY: self.extra})
 
         return uri
 
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index 621be5916a..eb2d575d02 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -242,14 +242,14 @@ class TestCliExportConnections:
                 'uri',
                 [
                     "airflow_db=mysql://root:plainpassword@mysql/airflow",
-                    "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql",
+                    "druid_broker_default=druid://druid-broker:8082/?endpoint=druid%2Fv2%2Fsql",
                 ],
             ),
             (
                 None,  # tests that default is URI
                 [
                     "airflow_db=mysql://root:plainpassword@mysql/airflow",
-                    "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql",
+                    "druid_broker_default=druid://druid-broker:8082/?endpoint=druid%2Fv2%2Fsql",
                 ],
             ),
             (
@@ -287,7 +287,7 @@ class TestCliExportConnections:
         connection_command.connections_export(args)
         expected_connections = [
             "airflow_db=mysql://root:plainpassword@mysql/airflow",
-            "druid_broker_default=druid://druid-broker:8082?endpoint=druid%2Fv2%2Fsql",
+            "druid_broker_default=druid://druid-broker:8082/?endpoint=druid%2Fv2%2Fsql",
         ]
 
         assert output_filepath.read_text().splitlines() == expected_connections
diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py
index fd2bbd9132..a17c24aedb 100644
--- a/tests/hooks/test_dbapi.py
+++ b/tests/hooks/test_dbapi.py
@@ -17,6 +17,7 @@
 # under the License.
 #
 
+import json
 import unittest
 from unittest import mock
 
@@ -235,6 +236,111 @@ class TestDbApiHook(unittest.TestCase):
         )
         assert "conn-type://host:1/schema" == self.db_hook.get_uri()
 
+    def test_get_uri_extra(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login='login',
+                password='password',
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://login:password@host/?charset=utf-8"
+
+    def test_get_uri_extra_with_schema(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login='login',
+                password='password',
+                schema="schema",
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://login:password@host/schema?charset=utf-8"
+
+    def test_get_uri_extra_with_port(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login='login',
+                password='password',
+                port=3306,
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://login:password@host:3306/?charset=utf-8"
+
+    def test_get_uri_extra_with_port_and_empty_host(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                login='login',
+                password='password',
+                port=3306,
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://login:password@:3306/?charset=utf-8"
+
+    def test_get_uri_extra_with_port_and_schema(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login='login',
+                password='password',
+                schema="schema",
+                port=3306,
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://login:password@host:3306/schema?charset=utf-8"
+
+    def test_get_uri_without_password(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login='login',
+                password=None,
+                schema="schema",
+                port=3306,
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://login@host:3306/schema?charset=utf-8"
+
+    def test_get_uri_without_auth(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                host="host",
+                login=None,
+                password=None,
+                schema="schema",
+                port=3306,
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://host:3306/schema?charset=utf-8"
+
+    def test_get_uri_without_auth_and_empty_host(self):
+        self.db_hook.get_connection = mock.MagicMock(
+            return_value=Connection(
+                conn_type="conn-type",
+                login=None,
+                password=None,
+                schema="schema",
+                port=3306,
+                extra=json.dumps({'charset': 'utf-8'}),
+            )
+        )
+        assert self.db_hook.get_uri() == "conn-type://@:3306/schema?charset=utf-8"
+
     def test_run_log(self):
         statement = 'SQL'
         self.db_hook.run(statement)
diff --git a/tests/secrets/test_local_filesystem.py b/tests/secrets/test_local_filesystem.py
index a7bbd824f0..82ef7d469f 100644
--- a/tests/secrets/test_local_filesystem.py
+++ b/tests/secrets/test_local_filesystem.py
@@ -157,8 +157,8 @@ class TestLoadConnection(unittest.TestCase):
     @parameterized.expand(
         (
             (
-                "CONN_ID=mysql://host_1?param1=val1&param2=val2",
-                {"CONN_ID": "mysql://host_1?param1=val1&param2=val2"},
+                "CONN_ID=mysql://host_1/?param1=val1&param2=val2",
+                {"CONN_ID": "mysql://host_1/?param1=val1&param2=val2"},
             ),
         )
     )

Reply via email to