This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch engine-manager in repository https://gitbox.apache.org/repos/asf/superset.git
commit de8c250f86d67a008dfec9d8d3fa72e7e1cd261e Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Dec 3 20:00:29 2025 -0500 Update existing tests --- superset/extensions/engine_manager.py | 6 - tests/integration_tests/conftest.py | 1 - .../integration_tests/databases/commands_tests.py | 22 +- tests/unit_tests/engines/manager_test.py | 232 +++++++++++++++++++++ tests/unit_tests/initialization_test.py | 2 +- tests/unit_tests/models/core_test.py | 167 --------------- 6 files changed, 244 insertions(+), 186 deletions(-) diff --git a/superset/extensions/engine_manager.py b/superset/extensions/engine_manager.py index df391ad5d8e..e15ead09b43 100644 --- a/superset/extensions/engine_manager.py +++ b/superset/extensions/engine_manager.py @@ -70,12 +70,6 @@ class EngineManagerExtension: def shutdown_engine_manager() -> None: if self.engine_manager: self.engine_manager.stop_cleanup_thread() - # Use a try-except to handle closed log file handlers during tests - try: - logger.info("Stopped EngineManager cleanup thread") - except ValueError: - # Ignore logging errors during test shutdown when file handles are closed - pass app.teardown_appcontext_funcs.append(lambda exc: None) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 4f6ce10b0f9..95f1015e85f 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -170,7 +170,6 @@ def example_db_provider() -> Callable[[], Database]: return self._db def _load_lazy_data_to_decouple_from_session(self) -> None: - self._db._get_sqla_engine() # type: ignore self._db.backend # type: ignore # noqa: B018 def remove(self) -> None: diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 27c1ce56542..2d43ee14d0f 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -897,7 +897,7 @@ class TestImportDatabasesCommand(SupersetTestCase): class TestTestConnectionDatabaseCommand(SupersetTestCase): - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_db_exception( @@ -906,19 +906,19 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): """Test to make sure event_logger is called when an exception is raised""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.side_effect = Exception("An error has occurred!") + mock_get_sqla_engine.__enter__.side_effect = Exception("An error has occurred!") db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012 command_without_db_name.run() - assert str(excinfo.value) == ( - "Unexpected error occurred, please check your logs for details" - ) + assert str(excinfo.value) == ( + "Unexpected error occurred, please check your logs for details" + ) mock_event_logger.assert_called() - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_do_ping_exception( @@ -927,7 +927,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): """Test to make sure do_ping exceptions gets captured""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception( + mock_get_sqla_engine.__enter__().dialect.do_ping.side_effect = Exception( "An error has occurred!" ) db_uri = database.sqlalchemy_uri_decrypted @@ -967,7 +967,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): == SupersetErrorType.CONNECTION_DATABASE_TIMEOUT ) - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_superset_security_connection( @@ -977,7 +977,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): connection exc is raised""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.side_effect = SupersetSecurityException( + mock_get_sqla_engine.__enter__.side_effect = SupersetSecurityException( SupersetError(error_type=500, message="test", level="info") ) db_uri = database.sqlalchemy_uri_decrypted @@ -990,7 +990,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): mock_event_logger.assert_called() - @patch("superset.models.core.Database._get_sqla_engine") + @patch("superset.models.core.Database.get_sqla_engine") @patch("superset.commands.database.test_connection.event_logger.log_with_context") @patch("superset.utils.core.g") def test_connection_db_api_exc( @@ -999,7 +999,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): """Test to make sure event_logger is called when DBAPIError is raised""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.side_effect = DBAPIError( + mock_get_sqla_engine.__enter__.side_effect = DBAPIError( statement="error", params={}, orig={} ) db_uri = database.sqlalchemy_uri_decrypted diff --git a/tests/unit_tests/engines/manager_test.py b/tests/unit_tests/engines/manager_test.py index 287820eaf1b..871624f2a42 100644 --- a/tests/unit_tests/engines/manager_test.py +++ b/tests/unit_tests/engines/manager_test.py @@ -96,7 +96,9 @@ class TestEngineManager: @pytest.fixture def engine_manager(self): """Create a mock EngineManager instance.""" + from contextlib import contextmanager + @contextmanager def dummy_context_manager( database: MagicMock, catalog: str | None, schema: str | None ) -> Iterator[None]: @@ -293,3 +295,233 @@ class TestEngineManager: result2 = engine_manager._get_tunnel(ssh_tunnel, uri) assert result2 is active_tunnel assert mock_tunnel_class.call_count == 2 + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_get_engine_args_basic( + self, mock_make_url, mock_create_engine, engine_manager + ): + """Test _get_engine_args returns correct URI and kwargs.""" + from sqlalchemy.engine.url import make_url + + from superset.engines.manager import EngineModes + + engine_manager.mode = EngineModes.NEW + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + database = MagicMock() + database.id = 1 + database.sqlalchemy_uri_decrypted = "trino://" + database.get_extra.return_value = { + "engine_params": {}, + "connect_args": {"source": "Apache Superset"}, + } + database.get_effective_user.return_value = "alice" + database.impersonate_user = False + database.update_params_from_encrypted_extra = MagicMock() + database.db_engine_spec = MagicMock() + database.db_engine_spec.adjust_engine_params.return_value = ( + mock_uri, + {"source": "Apache Superset"}, + ) + database.db_engine_spec.validate_database_uri = MagicMock() + + uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None) + + assert str(uri) == "trino://" + assert "connect_args" in database.get_extra.return_value + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_get_engine_args_user_impersonation( + self, mock_make_url, mock_create_engine, engine_manager + ): + """Test user impersonation in _get_engine_args.""" + from sqlalchemy.engine.url import make_url + + from superset.engines.manager import EngineModes + + engine_manager.mode = EngineModes.NEW + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + database = MagicMock() + database.id = 1 + database.sqlalchemy_uri_decrypted = "trino://" + database.get_extra.return_value = { + "engine_params": {}, + "connect_args": {"source": "Apache Superset"}, + } + database.get_effective_user.return_value = "alice" + database.impersonate_user = True + database.get_oauth2_config.return_value = None + database.update_params_from_encrypted_extra = MagicMock() + database.db_engine_spec = MagicMock() + database.db_engine_spec.adjust_engine_params.return_value = ( + mock_uri, + {"source": "Apache Superset"}, + ) + database.db_engine_spec.impersonate_user.return_value = ( + mock_uri, + {"connect_args": {"user": "alice", "source": "Apache Superset"}}, + ) + database.db_engine_spec.validate_database_uri = MagicMock() + + uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None) + + # Verify impersonate_user was called + database.db_engine_spec.impersonate_user.assert_called_once() + call_args = database.db_engine_spec.impersonate_user.call_args + assert call_args[0][0] is database # database + assert call_args[0][1] == "alice" # username + assert call_args[0][2] is None # access_token (no OAuth2) + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_get_engine_args_user_impersonation_email_prefix( + self, + mock_make_url, + mock_create_engine, + engine_manager, + ): + """Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag.""" + from sqlalchemy.engine.url import make_url + + from superset.engines.manager import EngineModes + + engine_manager.mode = EngineModes.NEW + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + # Mock user with email + mock_user = MagicMock() + mock_user.email = "[email protected]" + + database = MagicMock() + database.id = 1 + database.sqlalchemy_uri_decrypted = "trino://" + database.get_extra.return_value = { + "engine_params": {}, + "connect_args": {"source": "Apache Superset"}, + } + database.get_effective_user.return_value = "alice" + database.impersonate_user = True + database.get_oauth2_config.return_value = None + database.update_params_from_encrypted_extra = MagicMock() + database.db_engine_spec = MagicMock() + database.db_engine_spec.adjust_engine_params.return_value = ( + mock_uri, + {"source": "Apache Superset"}, + ) + database.db_engine_spec.impersonate_user.return_value = ( + mock_uri, + {"connect_args": {"user": "alice.doe", "source": "Apache Superset"}}, + ) + database.db_engine_spec.validate_database_uri = MagicMock() + + with ( + patch( + "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled", + return_value=True, + ), + patch( + "superset.extensions.security_manager.find_user", + return_value=mock_user, + ), + ): + uri, kwargs = engine_manager._get_engine_args( + database, None, None, None, None + ) + + # Verify impersonate_user was called with the email prefix + database.db_engine_spec.impersonate_user.assert_called_once() + call_args = database.db_engine_spec.impersonate_user.call_args + assert call_args[0][1] == "alice.doe" # username from email prefix + + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_engine_context_manager_called( + self, mock_make_url, mock_create_engine, engine_manager, mock_database + ): + """Test that the engine context manager is properly called.""" + from sqlalchemy.engine.url import make_url + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + # Track context manager calls + context_manager_calls = [] + + def tracking_context_manager(database, catalog, schema): + from contextlib import contextmanager + + @contextmanager + def inner(): + context_manager_calls.append(("enter", database, catalog, schema)) + yield + context_manager_calls.append(("exit", database, catalog, schema)) + + return inner() + + engine_manager.engine_context_manager = tracking_context_manager + + with engine_manager.get_engine(mock_database, "catalog1", "schema1", None): + pass + + assert len(context_manager_calls) == 2 + assert context_manager_calls[0][0] == "enter" + assert context_manager_calls[0][1] is mock_database + assert context_manager_calls[0][2] == "catalog1" + assert context_manager_calls[0][3] == "schema1" + assert context_manager_calls[1][0] == "exit" + + @patch("superset.utils.oauth2.check_for_oauth2") + @patch("superset.engines.manager.create_engine") + @patch("superset.engines.manager.make_url_safe") + def test_engine_oauth2_error_handling( + self, + mock_make_url, + mock_create_engine, + mock_check_for_oauth2, + engine_manager, + mock_database, + ): + """Test that OAuth2 errors are properly propagated from get_engine.""" + from contextlib import contextmanager + + from sqlalchemy.engine.url import make_url + + mock_uri = make_url("trino://") + mock_make_url.return_value = mock_uri + + # Simulate OAuth2 error during engine creation + class OAuth2TestError(Exception): + pass + + oauth_error = OAuth2TestError("OAuth2 required") + mock_create_engine.side_effect = oauth_error + + # Make get_dbapi_mapped_exception return the original exception + mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = ( + oauth_error + ) + + # Mock check_for_oauth2 to re-raise the exception + @contextmanager + def mock_oauth2_context(database): + try: + yield + except OAuth2TestError: + raise + + mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database) + + with pytest.raises(OAuth2TestError, match="OAuth2 required"): + with engine_manager.get_engine(mock_database, "catalog1", "schema1", None): + pass diff --git a/tests/unit_tests/initialization_test.py b/tests/unit_tests/initialization_test.py index 01fde0967c9..93fdf4d352e 100644 --- a/tests/unit_tests/initialization_test.py +++ b/tests/unit_tests/initialization_test.py @@ -123,7 +123,7 @@ class TestSupersetAppInitializer: patch.object(app_initializer, "configure_data_sources"), patch.object(app_initializer, "configure_auth_provider"), patch.object(app_initializer, "configure_async_queries"), - patch.object(app_initializer, "configure_ssh_manager"), + patch.object(app_initializer, "configure_engine_manager"), patch.object(app_initializer, "configure_stats_manager"), patch.object(app_initializer, "init_views"), ): diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 7d7aa96ea19..b2a48df0592 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -19,7 +19,6 @@ from datetime import datetime import pytest -from flask import current_app from pytest_mock import MockerFixture from sqlalchemy import ( Column, @@ -29,7 +28,6 @@ from sqlalchemy import ( Table as SqlalchemyTable, ) from sqlalchemy.engine.reflection import Inspector -from sqlalchemy.engine.url import make_url from sqlalchemy.orm.session import Session from sqlalchemy.sql import Select @@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2( assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT -def test_get_sqla_engine(mocker: MockerFixture) -> None: - """ - Test `_get_sqla_engine`. - """ - from superset.models.core import Database - - user = mocker.MagicMock() - user.email = "[email protected]" - mocker.patch( - "superset.models.core.security_manager.find_user", - return_value=user, - ) - mocker.patch("superset.models.core.get_username", return_value="alice") - - create_engine = mocker.patch("superset.models.core.create_engine") - - database = Database(database_name="my_db", sqlalchemy_uri="trino://") - database._get_sqla_engine(nullpool=False) - - create_engine.assert_called_with( - make_url("trino:///"), - connect_args={"source": "Apache Superset"}, - ) - - -def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None: - """ - Test user impersonation in `_get_sqla_engine`. - """ - from superset.models.core import Database - - user = mocker.MagicMock() - user.email = "[email protected]" - mocker.patch( - "superset.models.core.security_manager.find_user", - return_value=user, - ) - mocker.patch("superset.models.core.get_username", return_value="alice") - - create_engine = mocker.patch("superset.models.core.create_engine") - - database = Database( - database_name="my_db", - sqlalchemy_uri="trino://", - impersonate_user=True, - ) - database._get_sqla_engine(nullpool=False) - - create_engine.assert_called_with( - make_url("trino:///"), - connect_args={"user": "alice", "source": "Apache Superset"}, - ) - - def test_add_database_to_signature(): args = ["param1", "param2"] @@ -604,36 +548,6 @@ def test_add_database_to_signature(): assert args3 == ["param1", "param2", database] -@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True) -def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None: - """ - Test user impersonation in `_get_sqla_engine` with `username_from_email`. - """ - from superset.models.core import Database - - user = mocker.MagicMock() - user.email = "[email protected]" - mocker.patch( - "superset.models.core.security_manager.find_user", - return_value=user, - ) - mocker.patch("superset.models.core.get_username", return_value="alice") - - create_engine = mocker.patch("superset.models.core.create_engine") - - database = Database( - database_name="my_db", - sqlalchemy_uri="trino://", - impersonate_user=True, - ) - database._get_sqla_engine(nullpool=False) - - create_engine.assert_called_with( - make_url("trino:///"), - connect_args={"user": "alice.doe", "source": "Apache Superset"}, - ) - - def test_is_oauth2_enabled() -> None: """ Test the `is_oauth2_enabled` method. @@ -753,37 +667,6 @@ def test_get_oauth2_config_redirect_uri_from_config( assert config["redirect_uri"] == custom_redirect_uri -def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None: - """ - Test that we can start OAuth2 from `raw_connection()` errors. - - With OAuth2, some databases will raise an exception when the engine is first created - (eg, BigQuery). Others, like, Snowflake, when the connection is created. And - finally, GSheets will raise an exception when the query is executed. - - This tests verifies that when calling `raw_connection()` the OAuth2 flow is - triggered when the engine is created. - """ - g = mocker.patch("superset.db_engine_specs.base.g") - g.user = mocker.MagicMock() - g.user.id = 42 - - database = Database( - id=1, - database_name="my_db", - sqlalchemy_uri="sqlite://", - encrypted_extra=json.dumps(oauth2_client_info), - ) - database.db_engine_spec.oauth2_exception = OAuth2Error # type: ignore - _get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine") - _get_sqla_engine.side_effect = OAuth2Error("OAuth2 required") - - with pytest.raises(OAuth2RedirectError) as excinfo: - with database.get_raw_connection() as conn: - conn.cursor() - assert str(excinfo.value) == "You don't have permission to access the data." - - def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None: """ Test that we can start OAuth2 from `raw_connection()` errors. @@ -879,56 +762,6 @@ def test_get_schema_access_for_file_upload() -> None: assert database.get_schema_access_for_file_upload() == {"public"} -def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None: - """ - Test the engine context manager. - """ - from unittest.mock import MagicMock - - engine_context_manager = MagicMock() - mocker.patch.dict( - current_app.config, - {"ENGINE_CONTEXT_MANAGER": engine_context_manager}, - ) - _get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine") - - database = Database(database_name="my_db", sqlalchemy_uri="trino://") - with database.get_sqla_engine("catalog", "schema"): - pass - - engine_context_manager.assert_called_once_with(database, "catalog", "schema") - engine_context_manager().__enter__.assert_called_once() - engine_context_manager().__exit__.assert_called_once_with(None, None, None) - _get_sqla_engine.assert_called_once_with( - catalog="catalog", - schema="schema", - nullpool=True, - source=None, - sqlalchemy_uri="trino://", - ) - - -def test_engine_oauth2(mocker: MockerFixture) -> None: - """ - Test that we handle OAuth2 when `create_engine` fails. - """ - database = Database(database_name="my_db", sqlalchemy_uri="trino://") - mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception) - mocker.patch.object(database, "is_oauth2_enabled", return_value=True) - mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True) - start_oauth2_dance = mocker.patch.object( - database.db_engine_spec, - "start_oauth2_dance", - side_effect=OAuth2Error("OAuth2 required"), - ) - - with pytest.raises(OAuth2Error): - with database.get_sqla_engine("catalog", "schema"): - pass - - start_oauth2_dance.assert_called_with(database) - - def test_purge_oauth2_tokens(session: Session) -> None: """ Test the `purge_oauth2_tokens` method.
