This is an automated email from the ASF dual-hosted git repository.

dpgaspar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 6ae5388dcf fix: allow db driver distinction on enforced URI params (#23769)
6ae5388dcf is described below

commit 6ae5388dcf0205e89d4abcc5cefcb644e8c7cdbd
Author: Daniel Vaz Gaspar <[email protected]>
AuthorDate: Sun Apr 23 15:44:21 2023 +0100

    fix: allow db driver distinction on enforced URI params (#23769)
---
 superset/db_engine_specs/base.py               | 18 ++++++++++-------
 superset/db_engine_specs/mysql.py              | 10 +++++++--
 tests/integration_tests/model_tests.py         | 12 ++++++++++-
 tests/unit_tests/db_engine_specs/test_mysql.py | 28 ++++++++++++++++++++++++++
 4 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index f1dda401af..a21c0b4100 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -355,10 +355,11 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     # This set will give the keywords for data limit statements
     # to consider for the engines with TOP SQL parsing
     top_keywords: Set[str] = {"TOP"}
-    # A set of disallowed connection query parameters
-    disallow_uri_query_params: Set[str] = set()
+    # A set of disallowed connection query parameters by driver name
+    disallow_uri_query_params: Dict[str, Set[str]] = {}
     # A Dict of query parameters that will always be used on every connection
-    enforce_uri_query_params: Dict[str, Any] = {}
+    # by driver name
+    enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}
 
     force_column_alias_quotes = False
     arraysize = 0
@@ -1099,7 +1100,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         This is important because DB engine specs can be installed from 3rd party
         packages.
         """
-        return uri, {**connect_args, **cls.enforce_uri_query_params}
+        return uri, {
+            **connect_args,
+            **cls.enforce_uri_query_params.get(uri.get_driver_name(), {}),
+        }
 
     @classmethod
     def patch(cls) -> None:
@@ -1853,9 +1857,9 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
 
         :param sqlalchemy_uri:
         """
-        if existing_disallowed := cls.disallow_uri_query_params.intersection(
-            sqlalchemy_uri.query
-        ):
+        if existing_disallowed := cls.disallow_uri_query_params.get(
+            sqlalchemy_uri.get_driver_name(), set()
+        ).intersection(sqlalchemy_uri.query):
             raise ValueError(f"Forbidden query parameter(s): {existing_disallowed}")
 
 
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 07d2aea362..9c5bd0034a 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -175,8 +175,14 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
             {},
         ),
     }
-    disallow_uri_query_params = {"local_infile"}
-    enforce_uri_query_params = {"local_infile": 0}
+    disallow_uri_query_params = {
+        "mysqldb": {"local_infile"},
+        "mysqlconnector": {"allow_local_infile"},
+    }
+    enforce_uri_query_params = {
+        "mysqldb": {"local_infile": 0},
+        "mysqlconnector": {"allow_local_infile": 0},
+    }
 
     @classmethod
     def convert_dttm(
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 35dbcc0a6b..d5684b1b62 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -194,7 +194,7 @@ class TestDatabaseModel(SupersetTestCase):
     @mock.patch("superset.models.core.create_engine")
     def test_adjust_engine_params_mysql(self, mocked_create_engine):
         model = Database(
-            database_name="test_database",
+            database_name="test_database1",
             sqlalchemy_uri="mysql://user:password@localhost",
         )
         model._get_sqla_engine()
@@ -203,6 +203,16 @@ class TestDatabaseModel(SupersetTestCase):
         assert str(call_args[0][0]) == "mysql://user:password@localhost"
         assert call_args[1]["connect_args"]["local_infile"] == 0
 
+        model = Database(
+            database_name="test_database2",
+            sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
+        )
+        model._get_sqla_engine()
+        call_args = mocked_create_engine.call_args
+
+        assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
+        assert call_args[1]["connect_args"]["allow_local_infile"] == 0
+
     @mock.patch("superset.models.core.create_engine")
     def test_impersonate_user_trino(self, mocked_create_engine):
         principal_user = security_manager.find_user(username="gamma")
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 31e01ace58..07ce6838fc 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -104,8 +104,11 @@ def test_convert_dttm(
     "sqlalchemy_uri,error",
     [
         ("mysql://user:password@host/db1?local_infile=1", True),
+        ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=1", True),
         ("mysql://user:password@host/db1?local_infile=0", True),
+        ("mysql+mysqlconnector://user:password@host/db1?allow_local_infile=0", True),
         ("mysql://user:password@host/db1", False),
+        ("mysql+mysqlconnector://user:password@host/db1", False),
     ],
 )
 def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
@@ -123,18 +126,43 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
     "sqlalchemy_uri,connect_args,returns",
     [
         ("mysql://user:password@host/db1", {"local_infile": 1}, {"local_infile": 0}),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": 1},
+            {"allow_local_infile": 0},
+        ),
         ("mysql://user:password@host/db1", {"local_infile": -1}, {"local_infile": 0}),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": -1},
+            {"allow_local_infile": 0},
+        ),
         ("mysql://user:password@host/db1", {"local_infile": 0}, {"local_infile": 0}),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": 0},
+            {"allow_local_infile": 0},
+        ),
         (
             "mysql://user:password@host/db1",
             {"param1": "some_value"},
             {"local_infile": 0, "param1": "some_value"},
         ),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"param1": "some_value"},
+            {"allow_local_infile": 0, "param1": "some_value"},
+        ),
         (
             "mysql://user:password@host/db1",
             {"local_infile": 1, "param1": "some_value"},
             {"local_infile": 0, "param1": "some_value"},
         ),
+        (
+            "mysql+mysqlconnector://user:password@host/db1",
+            {"allow_local_infile": 1, "param1": "some_value"},
+            {"allow_local_infile": 0, "param1": "some_value"},
+        ),
     ],
 )
 def test_adjust_engine_params(

Reply via email to