This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e374b2af42d Refactor CloudSQLDatabaseHook.create_connection method to 
align with new Connection from airflow.sdk. (#56323)
e374b2af42d is described below

commit e374b2af42d33e39875887c8579df016a0358152
Author: Nitochkin <[email protected]>
AuthorDate: Mon Oct 13 19:42:50 2025 +0200

    Refactor CloudSQLDatabaseHook.create_connection method to align with new 
Connection from airflow.sdk. (#56323)
    
    Co-authored-by: Anton Nitochkin <[email protected]>
---
 .../providers/google/cloud/hooks/cloud_sql.py      |  99 ++++++-
 .../unit/google/cloud/hooks/test_cloud_sql.py      | 292 ++++++++++++++-------
 2 files changed, 289 insertions(+), 102 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py 
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
index 32b86cd9c02..ab05183aef7 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -50,7 +50,13 @@ from googleapiclient.errors import HttpError
 # Number of retries - used by googleapiclient method calls to perform retries
 # For requests that are "retriable"
 from airflow.exceptions import AirflowException
-from airflow.models import Connection
+from airflow.providers.google.version_compat import AIRFLOW_V_3_1_PLUS
+
+if AIRFLOW_V_3_1_PLUS:
+    from airflow.sdk import Connection
+else:
+    from airflow.models import Connection  # type: 
ignore[assignment,attr-defined,no-redef]
+
 from airflow.providers.google.cloud.hooks.secret_manager import (
     GoogleCloudSecretManagerHook,
 )
@@ -1045,15 +1051,26 @@ class CloudSQLDatabaseHook(BaseHook):
     def _quote(value) -> str | None:
         return quote_plus(value) if value else None
 
-    def _generate_connection_uri(self) -> str:
+    def _reserve_port(self):
+        """
+        Prepare Cloud SQL Proxy endpoints before connection details are generated.
+
+        Does nothing unless ``use_proxy`` is set. When it is, this calls
+        ``reserve_free_tcp_port()`` if TCP mode is enabled and no ``sql_proxy_tcp_port`` is
+        configured yet. It also fills ``sql_proxy_unique_path`` via ``_generate_unique_path()``
+        when that path is not already set. That second step runs in both TCP and socket
+        mode, despite the method name.
+        """
         if self.use_proxy:
             if self.sql_proxy_use_tcp:
                 if not self.sql_proxy_tcp_port:
                     self.reserve_free_tcp_port()
             if not self.sql_proxy_unique_path:
                 self.sql_proxy_unique_path = self._generate_unique_path()
+
+    def _generate_connection_uri(self) -> str:
+        self._reserve_port()
         if not self.database_type:
             raise ValueError("The database_type should be set")
+        if not self.user:
+            raise AirflowException("The login parameter needs to be set in 
connection")
+        if not self.public_ip:
+            raise AirflowException("The location parameter needs to be set in 
connection")
+        if not self.password:
+            raise AirflowException("The password parameter needs to be set in 
connection")
+        if not self.database:
+            raise AirflowException("The database parameter needs to be set in 
connection")
 
         database_uris = CONNECTION_URIS[self.database_type]
         ssl_spec = None
@@ -1072,14 +1089,6 @@ class CloudSQLDatabaseHook(BaseHook):
                 ssl_spec = {"cert": self.sslcert, "key": self.sslkey, "ca": 
self.sslrootcert}
             else:
                 format_string = public_uris["non-ssl"]
-        if not self.user:
-            raise AirflowException("The login parameter needs to be set in 
connection")
-        if not self.public_ip:
-            raise AirflowException("The location parameter needs to be set in 
connection")
-        if not self.password:
-            raise AirflowException("The password parameter needs to be set in 
connection")
-        if not self.database:
-            raise AirflowException("The database parameter needs to be set in 
connection")
 
         connection_uri = format_string.format(
             user=quote_plus(self.user) if self.user else "",
@@ -1113,6 +1122,69 @@ class CloudSQLDatabaseHook(BaseHook):
             instance_specification += f"=tcp:{self.sql_proxy_tcp_port}"
         return instance_specification
 
+    def _generate_connection_parameters(self) -> dict:
+        """
+        Build keyword arguments for ``Connection`` describing the target database.
+
+        This is the field-by-field counterpart of ``_generate_connection_uri``. Instead of
+        formatting a URI, it sets ``conn_type``, ``login``, ``password``, ``schema``,
+        ``host``, ``port`` and ``extra`` directly. The ``CONNECTION_URIS`` templates are
+        consulted only to decide how socket and SSL settings are expressed.
+
+        :return: kwargs for ``Connection(conn_id=..., **kwargs)``. ``extra`` is a JSON string,
+            or ``None`` when there are no extra parameters.
+        :raises ValueError: if ``database_type`` is not set.
+        :raises AirflowException: if the user, public IP, password or database is not set.
+        """
+        self._reserve_port()
+        if not self.database_type:
+            raise ValueError("The database_type should be set")
+        if not self.user:
+            raise AirflowException("The login parameter needs to be set in connection")
+        if not self.public_ip:
+            raise AirflowException("The location parameter needs to be set in connection")
+        if not self.password:
+            raise AirflowException("The password parameter needs to be set in connection")
+        if not self.database:
+            raise AirflowException("The database parameter needs to be set in connection")
+
+        connection_parameters: dict = {
+            "conn_type": self.database_type,
+            "login": self.user,
+            "password": self.password,
+            "schema": self.database,
+        }
+        # Extra parameters are collected separately and serialized once at the end.
+        extra: dict = {}
+
+        database_uris = CONNECTION_URIS[self.database_type]
+        if self.use_proxy:
+            if self.sql_proxy_use_tcp:
+                connection_parameters["host"] = "127.0.0.1"
+                connection_parameters["port"] = self.sql_proxy_tcp_port
+            else:
+                socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}"
+                # Templates that connect to "localhost" pass the socket via the ``unix_socket``
+                # extra; otherwise the socket path itself is used as the host.
+                if "localhost" in database_uris["proxy"]["socket"]:
+                    connection_parameters["host"] = "localhost"
+                    extra["unix_socket"] = socket_path
+                else:
+                    connection_parameters["host"] = socket_path
+        else:
+            connection_parameters["host"] = self.public_ip
+            connection_parameters["port"] = self.public_port
+            if self.use_ssl:
+                if "ssl_spec" in database_uris["public"]["ssl"]:
+                    # The template expects a single JSON-encoded ``ssl`` spec.
+                    extra["ssl"] = json.dumps(
+                        {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
+                    )
+                else:
+                    extra.update(
+                        {
+                            "sslmode": "verify-ca",
+                            "sslcert": self.sslcert,
+                            "sslkey": self.sslkey,
+                            "sslrootcert": self.sslrootcert,
+                        }
+                    )
+
+        # ``Connection.extra`` holds a JSON string; use None (the default) rather than a raw
+        # empty dict when no extra parameters are needed.
+        connection_parameters["extra"] = json.dumps(extra) if extra else None
+        return connection_parameters
+
     def create_connection(self) -> Connection:
         """
         Create a connection.
@@ -1120,8 +1192,11 @@ class CloudSQLDatabaseHook(BaseHook):
         Connection ID will be randomly generated according to whether it uses
         proxy, TCP, UNIX sockets, SSL.
         """
-        uri = self._generate_connection_uri()
-        connection = Connection(conn_id=self.db_conn_id, uri=uri)
+        if AIRFLOW_V_3_1_PLUS:
+            kwargs = self._generate_connection_parameters()
+        else:
+            kwargs = {"uri": self._generate_connection_uri()}
+        connection = Connection(conn_id=self.db_conn_id, **kwargs)
         self.log.info("Creating connection %s", self.db_conn_id)
         return connection
 
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py 
b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
index 52d1190fbc3..1a35b439756 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
@@ -24,6 +24,7 @@ import platform
 import tempfile
 from unittest import mock
 from unittest.mock import PropertyMock, call, mock_open
+from urllib.parse import parse_qsl, unquote, urlsplit
 
 import aiohttp
 import httplib2
@@ -33,7 +34,13 @@ from googleapiclient.errors import HttpError
 from yarl import URL
 
 from airflow.exceptions import AirflowException
-from airflow.models import Connection
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+
+if AIRFLOW_V_3_1_PLUS:
+    from airflow.sdk import Connection
+else:
+    from airflow.models import Connection  # type: 
ignore[assignment,attr-defined,no-redef]
 from airflow.providers.google.cloud.hooks.cloud_sql import (
     CloudSQLAsyncHook,
     CloudSQLDatabaseHook,
@@ -761,13 +768,40 @@ class TestGcpSqlHookNoDefaultProjectID:
         )
 
 
+def _parse_from_uri(uri: str) -> dict:
+    """
+    Split a connection URI into keyword arguments for ``Connection``.
+
+    Test helper used on Airflow 3.1+, where these tests build ``Connection`` from explicit
+    fields instead of ``Connection(uri=...)``. A missing schema, login or password becomes
+    ``""``, and a missing port becomes ``None``. Query parameters are JSON-encoded into
+    ``extra``.
+    """
+    connection_parameters: dict = {}
+    scheme = urlsplit(uri).scheme
+    connection_parameters["conn_type"] = scheme
+    # Re-parse as a scheme-relative URL. Replace only the leading scheme so that a later
+    # "<scheme>://" (e.g. inside a query value) is left intact.
+    uri_parts = urlsplit(uri.replace(f"{scheme}://", "//", 1))
+    connection_parameters["host"] = unquote(uri_parts.hostname or "")
+    quoted_schema = uri_parts.path[1:]
+    connection_parameters["schema"] = unquote(quoted_schema) if quoted_schema else ""
+    connection_parameters["login"] = unquote(uri_parts.username) if uri_parts.username else ""
+    connection_parameters["password"] = unquote(uri_parts.password) if uri_parts.password else ""
+    connection_parameters["port"] = uri_parts.port
+    if uri_parts.query:
+        query = dict(parse_qsl(uri_parts.query, keep_blank_values=True))
+        connection_parameters["extra"] = json.dumps(query)
+    return connection_parameters
+
+
 class TestCloudSqlDatabaseHook:
     
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
     def test_cloudsql_database_hook_validate_ssl_certs_no_ssl(self, 
get_connection):
-        connection = Connection()
-        connection.set_extra(
-            json.dumps({"location": "test", "instance": "instance", 
"database_type": "postgres"})
+        connection = Connection(
+            conn_id="my_test_connection",
+            conn_type="gcpcloudsqldb",
         )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = json.dumps(
+                {"location": "test", "instance": "instance", "database_type": 
"postgres"}
+            )
+        else:
+            connection.set_extra(
+                json.dumps({"location": "test", "instance": "instance", 
"database_type": "postgres"})
+            )
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -794,10 +828,16 @@ class TestCloudSqlDatabaseHook:
     ):
         mock_is_file.side_effects = True
         mock_set_temporary_ssl_file.side_effect = cert_dict.values()
-        connection = Connection()
+        connection = Connection(
+            conn_id="my_test_connection",
+            conn_type="gcpcloudsqldb",
+        )
         extras = {"location": "test", "instance": "instance", "database_type": 
"postgres", "use_ssl": "True"}
         extras.update(cert_dict)
-        connection.set_extra(json.dumps(extras))
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = json.dumps(extras)
+        else:
+            connection.set_extra(json.dumps(extras))
 
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
@@ -814,26 +854,31 @@ class TestCloudSqlDatabaseHook:
     def test_cloudsql_database_hook_validate_ssl_certs_with_ssl(
         self, get_connection, mock_set_temporary_ssl_file, mock_is_file
     ):
-        connection = Connection()
+        connection = Connection(
+            conn_id="my_test_connection",
+            conn_type="gcpcloudsqldb",
+        )
         mock_is_file.return_value = True
         mock_set_temporary_ssl_file.side_effect = [
             "/tmp/cert_file.pem",
             "/tmp/rootcert_file.pem",
             "/tmp/key_file.pem",
         ]
-        connection.set_extra(
-            json.dumps(
-                {
-                    "location": "test",
-                    "instance": "instance",
-                    "database_type": "postgres",
-                    "use_ssl": "True",
-                    "sslcert": "cert_file.pem",
-                    "sslrootcert": "rootcert_file.pem",
-                    "sslkey": "key_file.pem",
-                }
-            )
+        extras = json.dumps(
+            {
+                "location": "test",
+                "instance": "instance",
+                "database_type": "postgres",
+                "use_ssl": "True",
+                "sslcert": "cert_file.pem",
+                "sslrootcert": "rootcert_file.pem",
+                "sslkey": "key_file.pem",
+            }
         )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -846,26 +891,31 @@ class TestCloudSqlDatabaseHook:
     def 
test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
         self, get_connection, mock_set_temporary_ssl_file, mock_is_file
     ):
-        connection = Connection()
+        connection = Connection(
+            conn_id="my_test_connection",
+            conn_type="gcpcloudsqldb",
+        )
         mock_is_file.return_value = False
         mock_set_temporary_ssl_file.side_effect = [
             "/tmp/cert_file.pem",
             "/tmp/rootcert_file.pem",
             "/tmp/key_file.pem",
         ]
-        connection.set_extra(
-            json.dumps(
-                {
-                    "location": "test",
-                    "instance": "instance",
-                    "database_type": "postgres",
-                    "use_ssl": "True",
-                    "sslcert": "cert_file.pem",
-                    "sslrootcert": "rootcert_file.pem",
-                    "sslkey": "key_file.pem",
-                }
-            )
+        extras = json.dumps(
+            {
+                "location": "test",
+                "instance": "instance",
+                "database_type": "postgres",
+                "use_ssl": "True",
+                "sslcert": "cert_file.pem",
+                "sslrootcert": "rootcert_file.pem",
+                "sslkey": "key_file.pem",
+            }
         )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -881,18 +931,23 @@ class TestCloudSqlDatabaseHook:
         self, get_connection, gettempdir_mock
     ):
         gettempdir_mock.return_value = "/tmp"
-        connection = Connection()
-        connection.set_extra(
-            json.dumps(
-                {
-                    "location": "test",
-                    "instance": 
"very_long_instance_name_that_will_be_too_long_to_build_socket_length",
-                    "database_type": "postgres",
-                    "use_proxy": "True",
-                    "use_tcp": "False",
-                }
-            )
+        connection = Connection(
+            conn_id="my_test_connection",
+            conn_type="gcpcloudsqldb",
         )
+        extras = json.dumps(
+            {
+                "location": "test",
+                "instance": 
"very_long_instance_name_that_will_be_too_long_to_build_socket_length",
+                "database_type": "postgres",
+                "use_proxy": "True",
+                "use_tcp": "False",
+            }
+        )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -908,18 +963,23 @@ class TestCloudSqlDatabaseHook:
         self, get_connection, gettempdir_mock
     ):
         gettempdir_mock.return_value = "/tmp"
-        connection = Connection()
-        connection.set_extra(
-            json.dumps(
-                {
-                    "location": "test",
-                    "instance": "short_instance_name",
-                    "database_type": "postgres",
-                    "use_proxy": "True",
-                    "use_tcp": "False",
-                }
-            )
+        connection = Connection(
+            conn_id="my_test_connection",
+            conn_type="gcpcloudsqldb",
+        )
+        extras = json.dumps(
+            {
+                "location": "test",
+                "instance": "short_instance_name",
+                "database_type": "postgres",
+                "use_proxy": "True",
+                "use_tcp": "False",
+            }
         )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -940,7 +1000,10 @@ class TestCloudSqlDatabaseHook:
     )
     
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
     def test_cloudsql_database_hook_create_connection_missing_fields(self, 
get_connection, uri):
-        connection = Connection(uri=uri)
+        if AIRFLOW_V_3_1_PLUS:
+            connection = Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))
+        else:
+            connection = Connection(uri=uri)
         params = {
             "location": "test",
             "instance": "instance",
@@ -948,7 +1011,11 @@ class TestCloudSqlDatabaseHook:
             "use_proxy": "True",
             "use_tcp": "False",
         }
-        connection.set_extra(json.dumps(params))
+        extras = json.dumps(params)
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -960,16 +1027,23 @@ class TestCloudSqlDatabaseHook:
 
     
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
     def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self, 
get_connection):
-        connection = Connection(uri="http://user:password@host:80/database";)
-        connection.set_extra(
-            json.dumps(
-                {
-                    "location": "test",
-                    "instance": "instance",
-                    "database_type": "postgres",
-                }
+        if AIRFLOW_V_3_1_PLUS:
+            connection = Connection(
+                conn_id="test_conn_id", 
**_parse_from_uri("http://user:password@host:80/database";)
             )
+        else:
+            connection = 
Connection(uri="http://user:password@host:80/database";)
+        extras = json.dumps(
+            {
+                "location": "test",
+                "instance": "instance",
+                "database_type": "postgres",
+            }
         )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -981,18 +1055,25 @@ class TestCloudSqlDatabaseHook:
 
     
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
     def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection):
-        connection = Connection(uri="http://user:password@host:80/database";)
-        connection.set_extra(
-            json.dumps(
-                {
-                    "location": "test",
-                    "instance": "instance",
-                    "database_type": "postgres",
-                    "use_proxy": "True",
-                    "use_tcp": "False",
-                }
+        if AIRFLOW_V_3_1_PLUS:
+            connection = Connection(
+                conn_id="test_conn_id", 
**_parse_from_uri("http://user:password@host:80/database";)
             )
+        else:
+            connection = 
Connection(uri="http://user:password@host:80/database";)
+        extras = json.dumps(
+            {
+                "location": "test",
+                "instance": "instance",
+                "database_type": "postgres",
+                "use_proxy": "True",
+                "use_tcp": "False",
+            }
         )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -1003,16 +1084,23 @@ class TestCloudSqlDatabaseHook:
 
     
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
     def test_cloudsql_database_hook_get_database_hook(self, get_connection):
-        connection = Connection(uri="http://user:password@host:80/database";)
-        connection.set_extra(
-            json.dumps(
-                {
-                    "location": "test",
-                    "instance": "instance",
-                    "database_type": "postgres",
-                }
+        if AIRFLOW_V_3_1_PLUS:
+            connection = Connection(
+                conn_id="test_conn_id", 
**_parse_from_uri("http://user:password@host:80/database";)
             )
+        else:
+            connection = 
Connection(uri="http://user:password@host:80/database";)
+        extras = json.dumps(
+            {
+                "location": "test",
+                "instance": "instance",
+                "database_type": "postgres",
+            }
         )
+        if AIRFLOW_V_3_1_PLUS:
+            connection.extra = extras
+        else:
+            connection.set_extra(extras)
         get_connection.return_value = connection
         hook = CloudSQLDatabaseHook(
             gcp_cloudsql_conn_id="cloudsql_connection", 
default_gcp_project_id="google_connection"
@@ -1414,7 +1502,10 @@ class TestCloudSqlDatabaseQueryHook:
             "key_path": "/var/local/google_cloud_default.json",
         }
         conn_extra_json = json.dumps(conn_extra)
-        self.connection.set_extra(conn_extra_json)
+        if AIRFLOW_V_3_1_PLUS:
+            self.connection.extra = conn_extra_json
+        else:
+            self.connection.set_extra(conn_extra_json)
 
         mock_get_conn.side_effect = [self.sql_connection, self.connection]
         self.db_hook = CloudSQLDatabaseHook(
@@ -1440,14 +1531,20 @@ class TestCloudSqlDatabaseQueryHook:
             "test_db_with_longname_but_with_limit_of_UNIX_socket&"
             "use_proxy=True&sql_proxy_use_tcp=False"
         )
-        get_connection.side_effect = [Connection(uri=uri)]
+        if AIRFLOW_V_3_1_PLUS:
+            get_connection.side_effect = [Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))]
+        else:
+            get_connection.side_effect = [Connection(uri=uri)]
         hook = CloudSQLDatabaseHook()
         connection = hook.create_connection()
         assert connection.conn_type == "postgres"
         assert connection.schema == "testdb"
 
     def _verify_postgres_connection(self, get_connection, uri):
-        get_connection.side_effect = [Connection(uri=uri)]
+        if AIRFLOW_V_3_1_PLUS:
+            get_connection.side_effect = [Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))]
+        else:
+            get_connection.side_effect = [Connection(uri=uri)]
         hook = CloudSQLDatabaseHook()
         connection = hook.create_connection()
         assert connection.conn_type == "postgres"
@@ -1490,7 +1587,10 @@ class TestCloudSqlDatabaseQueryHook:
             "project_id=example-project&location=europe-west1&instance=testdb&"
             "use_proxy=True&sql_proxy_use_tcp=False"
         )
-        get_connection.side_effect = [Connection(uri=uri)]
+        if AIRFLOW_V_3_1_PLUS:
+            get_connection.side_effect = [Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))]
+        else:
+            get_connection.side_effect = [Connection(uri=uri)]
         hook = CloudSQLDatabaseHook()
         connection = hook.create_connection()
         assert connection.conn_type == "postgres"
@@ -1509,7 +1609,10 @@ class TestCloudSqlDatabaseQueryHook:
         self.verify_mysql_connection(get_connection, uri)
 
     def verify_mysql_connection(self, get_connection, uri):
-        get_connection.side_effect = [Connection(uri=uri)]
+        if AIRFLOW_V_3_1_PLUS:
+            get_connection.side_effect = [Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))]
+        else:
+            get_connection.side_effect = [Connection(uri=uri)]
         hook = CloudSQLDatabaseHook()
         connection = hook.create_connection()
         assert connection.conn_type == "mysql"
@@ -1525,7 +1628,10 @@ class TestCloudSqlDatabaseQueryHook:
             "project_id=example-project&location=europe-west1&instance=testdb&"
             "use_proxy=True&sql_proxy_use_tcp=True"
         )
-        get_connection.side_effect = [Connection(uri=uri)]
+        if AIRFLOW_V_3_1_PLUS:
+            get_connection.side_effect = [Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))]
+        else:
+            get_connection.side_effect = [Connection(uri=uri)]
         hook = CloudSQLDatabaseHook()
         connection = hook.create_connection()
         assert connection.conn_type == "postgres"
@@ -1567,7 +1673,10 @@ class TestCloudSqlDatabaseQueryHook:
             "project_id=example-project&location=europe-west1&instance=testdb&"
             "use_proxy=True&sql_proxy_use_tcp=False"
         )
-        get_connection.side_effect = [Connection(uri=uri)]
+        if AIRFLOW_V_3_1_PLUS:
+            get_connection.side_effect = [Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))]
+        else:
+            get_connection.side_effect = [Connection(uri=uri)]
         hook = CloudSQLDatabaseHook()
         connection = hook.create_connection()
         assert connection.conn_type == "mysql"
@@ -1584,7 +1693,10 @@ class TestCloudSqlDatabaseQueryHook:
             "project_id=example-project&location=europe-west1&instance=testdb&"
             "use_proxy=True&sql_proxy_use_tcp=True"
         )
-        get_connection.side_effect = [Connection(uri=uri)]
+        if AIRFLOW_V_3_1_PLUS:
+            get_connection.side_effect = [Connection(conn_id="test_conn_id", 
**_parse_from_uri(uri))]
+        else:
+            get_connection.side_effect = [Connection(uri=uri)]
         hook = CloudSQLDatabaseHook()
         connection = hook.create_connection()
         assert connection.conn_type == "mysql"

Reply via email to