This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e374b2af42d Refactor CloudSQLDatabaseHook.create_connection method to
align with new Connection from airflow.sdk. (#56323)
e374b2af42d is described below
commit e374b2af42d33e39875887c8579df016a0358152
Author: Nitochkin <[email protected]>
AuthorDate: Mon Oct 13 19:42:50 2025 +0200
Refactor CloudSQLDatabaseHook.create_connection method to align with new
Connection from airflow.sdk. (#56323)
Co-authored-by: Anton Nitochkin <[email protected]>
---
.../providers/google/cloud/hooks/cloud_sql.py | 99 ++++++-
.../unit/google/cloud/hooks/test_cloud_sql.py | 292 ++++++++++++++-------
2 files changed, 289 insertions(+), 102 deletions(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
index 32b86cd9c02..ab05183aef7 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -50,7 +50,13 @@ from googleapiclient.errors import HttpError
# Number of retries - used by googleapiclient method calls to perform retries
# For requests that are "retriable"
from airflow.exceptions import AirflowException
-from airflow.models import Connection
+from airflow.providers.google.version_compat import AIRFLOW_V_3_1_PLUS
+
+if AIRFLOW_V_3_1_PLUS:
+ from airflow.sdk import Connection
+else:
+ from airflow.models import Connection # type:
ignore[assignment,attr-defined,no-redef]
+
from airflow.providers.google.cloud.hooks.secret_manager import (
GoogleCloudSecretManagerHook,
)
@@ -1045,15 +1051,26 @@ class CloudSQLDatabaseHook(BaseHook):
def _quote(value) -> str | None:
return quote_plus(value) if value else None
- def _generate_connection_uri(self) -> str:
def _reserve_port(self):
    """
    Prepare the Cloud SQL proxy resources needed before building a connection.

    Does nothing unless the proxy is enabled (``use_proxy``). With the proxy:

    * in TCP mode, a free local TCP port is reserved, but only when
      ``sql_proxy_tcp_port`` is not already set;
    * in every proxy mode, ``sql_proxy_unique_path`` is filled in with a freshly
      generated unique path, but only when it is not already set.

    Despite its name, this method therefore also assigns the proxy socket path.
    """
    if not self.use_proxy:
        return
    needs_tcp_port = self.sql_proxy_use_tcp and not self.sql_proxy_tcp_port
    if needs_tcp_port:
        self.reserve_free_tcp_port()
    if not self.sql_proxy_unique_path:
        self.sql_proxy_unique_path = self._generate_unique_path()
+
+ def _generate_connection_uri(self) -> str:
+ self._reserve_port()
if not self.database_type:
raise ValueError("The database_type should be set")
+ if not self.user:
+ raise AirflowException("The login parameter needs to be set in
connection")
+ if not self.public_ip:
+ raise AirflowException("The location parameter needs to be set in
connection")
+ if not self.password:
+ raise AirflowException("The password parameter needs to be set in
connection")
+ if not self.database:
+ raise AirflowException("The database parameter needs to be set in
connection")
database_uris = CONNECTION_URIS[self.database_type]
ssl_spec = None
@@ -1072,14 +1089,6 @@ class CloudSQLDatabaseHook(BaseHook):
ssl_spec = {"cert": self.sslcert, "key": self.sslkey, "ca":
self.sslrootcert}
else:
format_string = public_uris["non-ssl"]
- if not self.user:
- raise AirflowException("The login parameter needs to be set in
connection")
- if not self.public_ip:
- raise AirflowException("The location parameter needs to be set in
connection")
- if not self.password:
- raise AirflowException("The password parameter needs to be set in
connection")
- if not self.database:
- raise AirflowException("The database parameter needs to be set in
connection")
connection_uri = format_string.format(
user=quote_plus(self.user) if self.user else "",
@@ -1113,6 +1122,69 @@ class CloudSQLDatabaseHook(BaseHook):
instance_specification += f"=tcp:{self.sql_proxy_tcp_port}"
return instance_specification
def _generate_connection_parameters(self) -> dict:
    """
    Build the keyword arguments for ``Connection`` from this hook's settings.

    Used by ``create_connection`` on Airflow 3.1+, where the Task SDK ``Connection``
    is populated field by field instead of from the URI produced by
    ``_generate_connection_uri``.

    :return: dict with ``conn_type``, ``login``, ``password``, ``schema``, ``host``
        and ``extra`` keys, plus ``port`` for every mode except the UNIX-socket
        proxy. ``extra`` is a JSON-encoded string, or ``None`` when there are no
        extra options.
    :raises ValueError: if ``database_type`` is not set.
    :raises AirflowException: if the login, location (public IP), password or
        database is missing.
    """
    # Reserve the proxy TCP port / unique socket path first: the host, port and
    # socket path computed below are derived from them.
    self._reserve_port()
    if not self.database_type:
        raise ValueError("The database_type should be set")
    if not self.user:
        raise AirflowException("The login parameter needs to be set in connection")
    if not self.public_ip:
        raise AirflowException("The location parameter needs to be set in connection")
    if not self.password:
        raise AirflowException("The password parameter needs to be set in connection")
    if not self.database:
        raise AirflowException("The database parameter needs to be set in connection")

    connection_parameters: dict = {
        "conn_type": self.database_type,
        "login": self.user,
        "password": self.password,
        "schema": self.database,
    }
    # Extra options are collected as a dict and serialised once at the end.
    extra: dict = {}

    database_uris = CONNECTION_URIS[self.database_type]
    if self.use_proxy:
        proxy_uris = database_uris["proxy"]
        if self.sql_proxy_use_tcp:
            connection_parameters["host"] = "127.0.0.1"
            connection_parameters["port"] = self.sql_proxy_tcp_port
        else:
            socket_path = f"{self.sql_proxy_unique_path}/{self._get_instance_socket_name()}"
            # When this database's socket URI template connects to ``localhost``,
            # the socket is passed through the ``unix_socket`` extra; otherwise the
            # socket path itself is used as the host.
            if "localhost" in proxy_uris["socket"]:
                connection_parameters["host"] = "localhost"
                extra["unix_socket"] = socket_path
            else:
                connection_parameters["host"] = socket_path
    else:
        public_uris = database_uris["public"]
        connection_parameters["host"] = self.public_ip
        connection_parameters["port"] = self.public_port
        if self.use_ssl:
            # Match the SSL options of the public SSL URI template: a single
            # JSON-encoded ``ssl`` spec when the template mentions ``ssl_spec``,
            # otherwise individual sslmode/sslcert/sslkey/sslrootcert options.
            if "ssl_spec" in public_uris["ssl"]:
                extra["ssl"] = json.dumps(
                    {"cert": self.sslcert, "key": self.sslkey, "ca": self.sslrootcert}
                )
            else:
                extra.update(
                    {
                        "sslmode": "verify-ca",
                        "sslcert": self.sslcert,
                        "sslkey": self.sslkey,
                        "sslrootcert": self.sslrootcert,
                    }
                )

    # ``Connection.extra`` holds a JSON string or None; pass None rather than an
    # empty dict when no extra options were collected.
    connection_parameters["extra"] = json.dumps(extra) if extra else None
    return connection_parameters
+
def create_connection(self) -> Connection:
    """
    Build the in-memory ``Connection`` object for the target Cloud SQL database.

    The connection is created under ``self.db_conn_id``. On Airflow 3.1+ the
    Task SDK ``Connection`` is filled in from explicit fields; on older versions
    it is constructed from a generated connection URI. In both cases the details
    depend on whether the proxy, TCP, UNIX sockets or SSL are used.
    """
    connection_kwargs = (
        self._generate_connection_parameters()
        if AIRFLOW_V_3_1_PLUS
        else {"uri": self._generate_connection_uri()}
    )
    connection = Connection(conn_id=self.db_conn_id, **connection_kwargs)
    self.log.info("Creating connection %s", self.db_conn_id)
    return connection
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
index 52d1190fbc3..1a35b439756 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_sql.py
@@ -24,6 +24,7 @@ import platform
import tempfile
from unittest import mock
from unittest.mock import PropertyMock, call, mock_open
+from urllib.parse import parse_qsl, unquote, urlsplit
import aiohttp
import httplib2
@@ -33,7 +34,13 @@ from googleapiclient.errors import HttpError
from yarl import URL
from airflow.exceptions import AirflowException
-from airflow.models import Connection
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+
+if AIRFLOW_V_3_1_PLUS:
+ from airflow.sdk import Connection
+else:
+ from airflow.models import Connection # type:
ignore[assignment,attr-defined,no-redef]
from airflow.providers.google.cloud.hooks.cloud_sql import (
CloudSQLAsyncHook,
CloudSQLDatabaseHook,
@@ -761,13 +768,40 @@ class TestGcpSqlHookNoDefaultProjectID:
)
def _parse_from_uri(uri: str) -> dict:
    """
    Split a connection URI into keyword arguments for ``Connection``.

    Test helper for Airflow 3.1+, where tests build ``Connection`` from explicit
    fields instead of ``Connection(uri=...)`` as the pre-3.1 branches do.

    * The scheme becomes ``conn_type``.
    * Host, port, login and password come from the network location.
    * The path (without its leading ``/``) becomes ``schema``.
    * The query string becomes a JSON-encoded ``extra``.

    Host, schema, login and password are percent-decoded.

    :param uri: connection URI, e.g. ``"http://user:password@host:80/database"``.
    :return: dict with ``conn_type``, ``host``, ``schema``, ``login``, ``password``
        and ``port`` keys, plus ``extra`` only when the URI has a query string.
    """
    scheme = urlsplit(uri).scheme
    # Re-split what follows "<scheme>://" as a network-path reference
    # ("//user:pw@host..."). Only the leading prefix is stripped: ``str.replace``
    # would also rewrite any later "<scheme>://", e.g. inside a query value.
    prefix = f"{scheme}://"
    rest_of_the_url = "//" + uri[len(prefix):] if uri.startswith(prefix) else uri
    parts = urlsplit(rest_of_the_url)

    connection_parameters: dict = {
        "conn_type": scheme,
        "host": unquote(parts.hostname or ""),
        "schema": unquote(parts.path[1:]),
        "login": unquote(parts.username) if parts.username else "",
        "password": unquote(parts.password) if parts.password else "",
        "port": parts.port,
    }
    if parts.query:
        query = dict(parse_qsl(parts.query, keep_blank_values=True))
        connection_parameters["extra"] = json.dumps(query)
    return connection_parameters
+
+
class TestCloudSqlDatabaseHook:
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_validate_ssl_certs_no_ssl(self,
get_connection):
- connection = Connection()
- connection.set_extra(
- json.dumps({"location": "test", "instance": "instance",
"database_type": "postgres"})
+ connection = Connection(
+ conn_id="my_test_connection",
+ conn_type="gcpcloudsqldb",
)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = json.dumps(
+ {"location": "test", "instance": "instance", "database_type":
"postgres"}
+ )
+ else:
+ connection.set_extra(
+ json.dumps({"location": "test", "instance": "instance",
"database_type": "postgres"})
+ )
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -794,10 +828,16 @@ class TestCloudSqlDatabaseHook:
):
mock_is_file.side_effects = True
mock_set_temporary_ssl_file.side_effect = cert_dict.values()
- connection = Connection()
+ connection = Connection(
+ conn_id="my_test_connection",
+ conn_type="gcpcloudsqldb",
+ )
extras = {"location": "test", "instance": "instance", "database_type":
"postgres", "use_ssl": "True"}
extras.update(cert_dict)
- connection.set_extra(json.dumps(extras))
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = json.dumps(extras)
+ else:
+ connection.set_extra(json.dumps(extras))
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
@@ -814,26 +854,31 @@ class TestCloudSqlDatabaseHook:
def test_cloudsql_database_hook_validate_ssl_certs_with_ssl(
self, get_connection, mock_set_temporary_ssl_file, mock_is_file
):
- connection = Connection()
+ connection = Connection(
+ conn_id="my_test_connection",
+ conn_type="gcpcloudsqldb",
+ )
mock_is_file.return_value = True
mock_set_temporary_ssl_file.side_effect = [
"/tmp/cert_file.pem",
"/tmp/rootcert_file.pem",
"/tmp/key_file.pem",
]
- connection.set_extra(
- json.dumps(
- {
- "location": "test",
- "instance": "instance",
- "database_type": "postgres",
- "use_ssl": "True",
- "sslcert": "cert_file.pem",
- "sslrootcert": "rootcert_file.pem",
- "sslkey": "key_file.pem",
- }
- )
+ extras = json.dumps(
+ {
+ "location": "test",
+ "instance": "instance",
+ "database_type": "postgres",
+ "use_ssl": "True",
+ "sslcert": "cert_file.pem",
+ "sslrootcert": "rootcert_file.pem",
+ "sslkey": "key_file.pem",
+ }
)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -846,26 +891,31 @@ class TestCloudSqlDatabaseHook:
def
test_cloudsql_database_hook_validate_ssl_certs_with_ssl_files_not_readable(
self, get_connection, mock_set_temporary_ssl_file, mock_is_file
):
- connection = Connection()
+ connection = Connection(
+ conn_id="my_test_connection",
+ conn_type="gcpcloudsqldb",
+ )
mock_is_file.return_value = False
mock_set_temporary_ssl_file.side_effect = [
"/tmp/cert_file.pem",
"/tmp/rootcert_file.pem",
"/tmp/key_file.pem",
]
- connection.set_extra(
- json.dumps(
- {
- "location": "test",
- "instance": "instance",
- "database_type": "postgres",
- "use_ssl": "True",
- "sslcert": "cert_file.pem",
- "sslrootcert": "rootcert_file.pem",
- "sslkey": "key_file.pem",
- }
- )
+ extras = json.dumps(
+ {
+ "location": "test",
+ "instance": "instance",
+ "database_type": "postgres",
+ "use_ssl": "True",
+ "sslcert": "cert_file.pem",
+ "sslrootcert": "rootcert_file.pem",
+ "sslkey": "key_file.pem",
+ }
)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -881,18 +931,23 @@ class TestCloudSqlDatabaseHook:
self, get_connection, gettempdir_mock
):
gettempdir_mock.return_value = "/tmp"
- connection = Connection()
- connection.set_extra(
- json.dumps(
- {
- "location": "test",
- "instance":
"very_long_instance_name_that_will_be_too_long_to_build_socket_length",
- "database_type": "postgres",
- "use_proxy": "True",
- "use_tcp": "False",
- }
- )
+ connection = Connection(
+ conn_id="my_test_connection",
+ conn_type="gcpcloudsqldb",
)
+ extras = json.dumps(
+ {
+ "location": "test",
+ "instance":
"very_long_instance_name_that_will_be_too_long_to_build_socket_length",
+ "database_type": "postgres",
+ "use_proxy": "True",
+ "use_tcp": "False",
+ }
+ )
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -908,18 +963,23 @@ class TestCloudSqlDatabaseHook:
self, get_connection, gettempdir_mock
):
gettempdir_mock.return_value = "/tmp"
- connection = Connection()
- connection.set_extra(
- json.dumps(
- {
- "location": "test",
- "instance": "short_instance_name",
- "database_type": "postgres",
- "use_proxy": "True",
- "use_tcp": "False",
- }
- )
+ connection = Connection(
+ conn_id="my_test_connection",
+ conn_type="gcpcloudsqldb",
+ )
+ extras = json.dumps(
+ {
+ "location": "test",
+ "instance": "short_instance_name",
+ "database_type": "postgres",
+ "use_proxy": "True",
+ "use_tcp": "False",
+ }
)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -940,7 +1000,10 @@ class TestCloudSqlDatabaseHook:
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_create_connection_missing_fields(self,
get_connection, uri):
- connection = Connection(uri=uri)
+ if AIRFLOW_V_3_1_PLUS:
+ connection = Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))
+ else:
+ connection = Connection(uri=uri)
params = {
"location": "test",
"instance": "instance",
@@ -948,7 +1011,11 @@ class TestCloudSqlDatabaseHook:
"use_proxy": "True",
"use_tcp": "False",
}
- connection.set_extra(json.dumps(params))
+ extras = json.dumps(params)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -960,16 +1027,23 @@ class TestCloudSqlDatabaseHook:
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_get_sqlproxy_runner_no_proxy(self,
get_connection):
- connection = Connection(uri="http://user:password@host:80/database")
- connection.set_extra(
- json.dumps(
- {
- "location": "test",
- "instance": "instance",
- "database_type": "postgres",
- }
+ if AIRFLOW_V_3_1_PLUS:
+ connection = Connection(
+ conn_id="test_conn_id",
**_parse_from_uri("http://user:password@host:80/database")
)
+ else:
+ connection =
Connection(uri="http://user:password@host:80/database")
+ extras = json.dumps(
+ {
+ "location": "test",
+ "instance": "instance",
+ "database_type": "postgres",
+ }
)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -981,18 +1055,25 @@ class TestCloudSqlDatabaseHook:
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_get_sqlproxy_runner(self, get_connection):
- connection = Connection(uri="http://user:password@host:80/database")
- connection.set_extra(
- json.dumps(
- {
- "location": "test",
- "instance": "instance",
- "database_type": "postgres",
- "use_proxy": "True",
- "use_tcp": "False",
- }
+ if AIRFLOW_V_3_1_PLUS:
+ connection = Connection(
+ conn_id="test_conn_id",
**_parse_from_uri("http://user:password@host:80/database")
)
+ else:
+ connection =
Connection(uri="http://user:password@host:80/database")
+ extras = json.dumps(
+ {
+ "location": "test",
+ "instance": "instance",
+ "database_type": "postgres",
+ "use_proxy": "True",
+ "use_tcp": "False",
+ }
)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -1003,16 +1084,23 @@ class TestCloudSqlDatabaseHook:
@mock.patch("airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook.get_connection")
def test_cloudsql_database_hook_get_database_hook(self, get_connection):
- connection = Connection(uri="http://user:password@host:80/database")
- connection.set_extra(
- json.dumps(
- {
- "location": "test",
- "instance": "instance",
- "database_type": "postgres",
- }
+ if AIRFLOW_V_3_1_PLUS:
+ connection = Connection(
+ conn_id="test_conn_id",
**_parse_from_uri("http://user:password@host:80/database")
)
+ else:
+ connection =
Connection(uri="http://user:password@host:80/database")
+ extras = json.dumps(
+ {
+ "location": "test",
+ "instance": "instance",
+ "database_type": "postgres",
+ }
)
+ if AIRFLOW_V_3_1_PLUS:
+ connection.extra = extras
+ else:
+ connection.set_extra(extras)
get_connection.return_value = connection
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id="cloudsql_connection",
default_gcp_project_id="google_connection"
@@ -1414,7 +1502,10 @@ class TestCloudSqlDatabaseQueryHook:
"key_path": "/var/local/google_cloud_default.json",
}
conn_extra_json = json.dumps(conn_extra)
- self.connection.set_extra(conn_extra_json)
+ if AIRFLOW_V_3_1_PLUS:
+ self.connection.extra = conn_extra_json
+ else:
+ self.connection.set_extra(conn_extra_json)
mock_get_conn.side_effect = [self.sql_connection, self.connection]
self.db_hook = CloudSQLDatabaseHook(
@@ -1440,14 +1531,20 @@ class TestCloudSqlDatabaseQueryHook:
"test_db_with_longname_but_with_limit_of_UNIX_socket&"
"use_proxy=True&sql_proxy_use_tcp=False"
)
- get_connection.side_effect = [Connection(uri=uri)]
+ if AIRFLOW_V_3_1_PLUS:
+ get_connection.side_effect = [Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))]
+ else:
+ get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert connection.conn_type == "postgres"
assert connection.schema == "testdb"
def _verify_postgres_connection(self, get_connection, uri):
- get_connection.side_effect = [Connection(uri=uri)]
+ if AIRFLOW_V_3_1_PLUS:
+ get_connection.side_effect = [Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))]
+ else:
+ get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert connection.conn_type == "postgres"
@@ -1490,7 +1587,10 @@ class TestCloudSqlDatabaseQueryHook:
"project_id=example-project&location=europe-west1&instance=testdb&"
"use_proxy=True&sql_proxy_use_tcp=False"
)
- get_connection.side_effect = [Connection(uri=uri)]
+ if AIRFLOW_V_3_1_PLUS:
+ get_connection.side_effect = [Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))]
+ else:
+ get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert connection.conn_type == "postgres"
@@ -1509,7 +1609,10 @@ class TestCloudSqlDatabaseQueryHook:
self.verify_mysql_connection(get_connection, uri)
def verify_mysql_connection(self, get_connection, uri):
- get_connection.side_effect = [Connection(uri=uri)]
+ if AIRFLOW_V_3_1_PLUS:
+ get_connection.side_effect = [Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))]
+ else:
+ get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert connection.conn_type == "mysql"
@@ -1525,7 +1628,10 @@ class TestCloudSqlDatabaseQueryHook:
"project_id=example-project&location=europe-west1&instance=testdb&"
"use_proxy=True&sql_proxy_use_tcp=True"
)
- get_connection.side_effect = [Connection(uri=uri)]
+ if AIRFLOW_V_3_1_PLUS:
+ get_connection.side_effect = [Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))]
+ else:
+ get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert connection.conn_type == "postgres"
@@ -1567,7 +1673,10 @@ class TestCloudSqlDatabaseQueryHook:
"project_id=example-project&location=europe-west1&instance=testdb&"
"use_proxy=True&sql_proxy_use_tcp=False"
)
- get_connection.side_effect = [Connection(uri=uri)]
+ if AIRFLOW_V_3_1_PLUS:
+ get_connection.side_effect = [Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))]
+ else:
+ get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert connection.conn_type == "mysql"
@@ -1584,7 +1693,10 @@ class TestCloudSqlDatabaseQueryHook:
"project_id=example-project&location=europe-west1&instance=testdb&"
"use_proxy=True&sql_proxy_use_tcp=True"
)
- get_connection.side_effect = [Connection(uri=uri)]
+ if AIRFLOW_V_3_1_PLUS:
+ get_connection.side_effect = [Connection(conn_id="test_conn_id",
**_parse_from_uri(uri))]
+ else:
+ get_connection.side_effect = [Connection(uri=uri)]
hook = CloudSQLDatabaseHook()
connection = hook.create_connection()
assert connection.conn_type == "mysql"