This is an automated email from the ASF dual-hosted git repository. elizabeth pushed a commit to branch elizabeth/test-2.1.1 in repository https://gitbox.apache.org/repos/asf/superset.git
commit 882117492117378bce0c002c7e250322ed560931 Author: Daniel Vaz Gaspar <[email protected]> AuthorDate: Tue Apr 18 17:07:37 2023 +0100 feat: add enforce URI query params with a specific for MySQL (#23723) --- superset/db_engine_specs/base.py | 9 +++++++- superset/db_engine_specs/mysql.py | 6 ++++- tests/integration_tests/model_tests.py | 15 ++++++++++++ tests/unit_tests/db_engine_specs/test_mysql.py | 32 ++++++++++++++++++++++++-- 4 files changed, 58 insertions(+), 4 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 5243b4660d..21aa171323 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -356,6 +356,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods top_keywords: Set[str] = {"TOP"} # A set of disallowed connection query parameters disallow_uri_query_params: Set[str] = set() + # A Dict of query parameters that will always be used on every connection + enforce_uri_query_params: Dict[str, Any] = {} force_column_alias_quotes = False arraysize = 0 @@ -1016,8 +1018,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods Some database drivers like Presto accept '{catalog}/{schema}' in the database component of the URL, that can be handled here. + + Currently, changing the catalog is not supported. The method accepts a catalog so + that when catalog support is added to Superset the interface remains the same. + This is important because DB engine specs can be installed from 3rd party + packages. 
""" - return uri + return uri, {**cls.enforce_uri_query_params} @classmethod def patch(cls) -> None: diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 348b3287e3..28ef442319 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -174,6 +174,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): ), } disallow_uri_query_params = {"local_infile"} + enforce_uri_query_params = {"local_infile": 0} @classmethod def convert_dttm( @@ -192,10 +193,13 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): def adjust_database_uri( cls, uri: URL, selected_schema: Optional[str] = None ) -> URL: + uri, new_connect_args = super( + MySQLEngineSpec, MySQLEngineSpec + ).adjust_database_uri(uri) if selected_schema: uri = uri.set(database=parse.quote(selected_schema, safe="")) - return uri + return uri, new_connect_args @classmethod def get_datatype(cls, type_code: Any) -> Optional[str]: diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index da6c5e6a3c..35dbcc0a6b 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -188,6 +188,21 @@ class TestDatabaseModel(SupersetTestCase): "password": "original_user_password", } + @unittest.skipUnless( + SupersetTestCase.is_module_installed("MySQLdb"), "mysqlclient not installed" + ) + @mock.patch("superset.models.core.create_engine") + def test_adjust_engine_params_mysql(self, mocked_create_engine): + model = Database( + database_name="test_database", + sqlalchemy_uri="mysql://user:password@localhost", + ) + model._get_sqla_engine() + call_args = mocked_create_engine.call_args + + assert str(call_args[0][0]) == "mysql://user:password@localhost" + assert call_args[1]["connect_args"]["local_infile"] == 0 + @mock.patch("superset.models.core.create_engine") def test_impersonate_user_trino(self, mocked_create_engine): principal_user = 
security_manager.find_user(username="gamma") diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index a512e71a97..3a24e1c2dc 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Tuple, Type from unittest.mock import Mock, patch import pytest @@ -33,7 +33,7 @@ from sqlalchemy.dialects.mysql import ( TINYINT, TINYTEXT, ) -from sqlalchemy.engine.url import make_url +from sqlalchemy.engine.url import make_url, URL from superset.utils.core import GenericDataType from tests.unit_tests.db_engine_specs.utils import ( @@ -119,6 +119,34 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: MySQLEngineSpec.validate_database_uri(url) +@pytest.mark.parametrize( + "sqlalchemy_uri,connect_args,returns", + [ + ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}), + ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}), + ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}), + ( + "mysql://user:password@host/db1", + {"param1": "some_value"}, + {"local_infile": 0, "param1": "some_value"}, + ), + ( + "mysql://user:password@host/db1", + {"local_infile": 1, "param1": "some_value"}, + {"local_infile": 0, "param1": "some_value"}, + ), + ], +) +def test_adjust_database_uri( + sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any] +) -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + url = make_url(sqlalchemy_uri) + returned_url, returned_connect_args = MySQLEngineSpec.adjust_database_uri(url) + assert returned_connect_args == returns + + @patch("sqlalchemy.engine.Engine.connect") def test_get_cancel_query_id(engine_mock: Mock) -> None: from superset.db_engine_specs.mysql import MySQLEngineSpec
