This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch engine-manager in repository https://gitbox.apache.org/repos/asf/superset.git
commit baf6e03d16f95820ee480b53a8c5c09e62574a25 Author: Beto Dealmeida <[email protected]> AuthorDate: Tue Jul 29 20:18:49 2025 -0400 Add extension --- superset/config.py | 16 ++++++ superset/engines/manager.py | 47 ++++++++++------ superset/extensions/__init__.py | 2 + superset/extensions/engine_manager.py | 103 ++++++++++++++++++++++++++++++++++ superset/initialization/__init__.py | 5 ++ 5 files changed, 157 insertions(+), 16 deletions(-) diff --git a/superset/config.py b/superset/config.py index 9731d80bd58..3b0e3146197 100644 --- a/superset/config.py +++ b/superset/config.py @@ -260,6 +260,22 @@ SQLALCHEMY_ENGINE_OPTIONS = {} # SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password SQLALCHEMY_CUSTOM_PASSWORD_STORE = None +# --------------------------------------------------------- +# Engine Manager Configuration +# --------------------------------------------------------- + +# Engine manager mode: "NEW" creates a new engine for every connection (default), +# "SINGLETON" reuses engines with connection pooling +ENGINE_MANAGER_MODE = "NEW" + +# Cleanup interval for abandoned locks in seconds (default: 5 minutes) +ENGINE_MANAGER_CLEANUP_INTERVAL = 300.0 + +# Automatically start cleanup thread for SINGLETON mode (default: True) +ENGINE_MANAGER_AUTO_START_CLEANUP = True + +# --------------------------------------------------------- + # # The EncryptedFieldTypeAdapter is used whenever we're building SqlAlchemy models # which include sensitive fields that should be app-encrypted BEFORE sending diff --git a/superset/engines/manager.py b/superset/engines/manager.py index 01eae105635..9f2b41caad6 100644 --- a/superset/engines/manager.py +++ b/superset/engines/manager.py @@ -30,16 +30,12 @@ from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL from sshtunnel import SSHTunnelForwarder -from superset import is_feature_enabled -from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe -from superset.extensions import security_manager from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource from superset.utils.json import dumps -from superset.utils.oauth2 import check_for_oauth2, get_oauth2_access_token -from superset.utils.ssh_tunnel import get_default_port if TYPE_CHECKING: + from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database @@ -111,7 +107,7 @@ class EngineManager: @contextmanager def get_engine( self, - database: Database, + database: "Database", catalog: str | None, schema: str | None, source: QuerySource | None, @@ -125,12 +121,14 @@ class EngineManager: with customization(database, catalog, schema): # we need to check for errors indicating that OAuth2 is needed, and # return the proper exception so it starts the authentication flow + from superset.utils.oauth2 import check_for_oauth2 + with check_for_oauth2(database): yield self._get_engine(database, catalog, schema, source) def _get_engine( self, - database: Database, + database: "Database", catalog: str | None, schema: str | None, source: QuerySource | None, @@ -177,7 +175,7 @@ class EngineManager: def _get_engine_key( self, - database: Database, + database: "Database", catalog: str | None, schema: str | None, source: QuerySource | None, @@ -200,7 +198,7 @@ class EngineManager: def _get_engine_args( self, - database: Database, + database: "Database", catalog: str | None, schema: str | None, source: QuerySource | None, @@ -242,7 +240,15 @@ class EngineManager: # get effective username username = database.get_effective_user(uri) - if username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"): + + # Import here to avoid circular imports + from superset.extensions import security_manager + from superset.utils.feature_flag_manager import FeatureFlagManager + + feature_flag_manager = FeatureFlagManager() + if username and feature_flag_manager.is_feature_enabled( + "IMPERSONATE_WITH_EMAIL_PREFIX" + ): user = security_manager.find_user(username=username) if user and user.email and "@" in user.email: username = user.email.split("@")[0] @@ -250,6 +256,9 @@ class EngineManager: # update URI/kwargs for user impersonation if database.impersonate_user: oauth2_config = database.get_oauth2_config() + # Import here to avoid circular imports + from superset.utils.oauth2 import get_oauth2_access_token + access_token = ( get_oauth2_access_token( oauth2_config, @@ -275,6 +284,9 @@ class EngineManager: # mutate URI if mutator := current_app.config["DB_CONNECTION_MUTATOR"]: source = source or get_query_source_from_request() + # Import here to avoid circular imports + from superset.extensions import security_manager + uri, kwargs = mutator( uri, kwargs, @@ -290,7 +302,7 @@ class EngineManager: def _create_engine( self, - database: Database, + database: "Database", catalog: str | None, schema: str | None, source: QuerySource | None, @@ -324,7 +336,7 @@ class EngineManager: return engine - def _get_tunnel(self, ssh_tunnel: SSHTunnel, uri: URL) -> SSHTunnelForwarder: + def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder: tunnel_key = self._get_tunnel_key(ssh_tunnel, uri) # tunnel exists and is healthy @@ -345,7 +357,7 @@ class EngineManager: def _replace_tunnel( self, tunnel_key: str, - ssh_tunnel: SSHTunnel, + ssh_tunnel: "SSHTunnel", uri: URL, old_tunnel: SSHTunnelForwarder | None, ) -> SSHTunnelForwarder: @@ -371,7 +383,7 @@ class EngineManager: return new_tunnel - def _get_tunnel_key(self, ssh_tunnel: SSHTunnel, uri: URL) -> TunnelKey: + def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey: """ Build a unique key for the SSH tunnel. """ @@ -379,15 +391,18 @@ class EngineManager: return dumps(keys, sort_keys=True) - def _create_tunnel(self, ssh_tunnel: SSHTunnel, uri: URL) -> SSHTunnelForwarder: + def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder: kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri) tunnel = SSHTunnelForwarder(**kwargs) tunnel.start() return tunnel - def _get_tunnel_kwargs(self, ssh_tunnel: SSHTunnel, uri: URL) -> dict[str, Any]: + def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]: backend = uri.get_backend_name() + # Import here to avoid circular imports + from superset.utils.ssh_tunnel import get_default_port + kwargs = { "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port), "ssh_username": ssh_tunnel.username, diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index 628af40cd62..07ffbe743a3 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -41,6 +41,7 @@ from werkzeug.local import LocalProxy from superset.async_events.async_query_manager import AsyncQueryManager from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory +from superset.extensions.engine_manager import EngineManagerExtension from superset.extensions.ssh import SSHManagerFactory from superset.extensions.stats_logger import BaseStatsLoggerManager from superset.security.manager import SupersetSecurityManager @@ -136,6 +137,7 @@ cache_manager = CacheManager() celery_app = celery.Celery() csrf = CSRFProtect() db = get_sqla_class()() +engine_manager_extension = EngineManagerExtension() _event_logger: dict[str, Any] = {} encrypted_field_factory = EncryptedFieldFactory() event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) diff --git a/superset/extensions/engine_manager.py b/superset/extensions/engine_manager.py new file mode 100644 index 00000000000..5c07d1cdd9e --- /dev/null +++ b/superset/extensions/engine_manager.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from typing import TYPE_CHECKING + +from flask import Flask + +from superset.engines.manager import EngineManager, EngineModes + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +class EngineManagerExtension: + """ + Flask extension for managing SQLAlchemy engines in Superset. + + This extension creates and configures an EngineManager instance based on + Flask configuration, handling startup and shutdown of background cleanup + threads as needed. + """ + + def __init__(self) -> None: + self.engine_manager: EngineManager | None = None + + def init_app(self, app: Flask) -> None: + """ + Initialize the EngineManager with Flask app configuration. + """ + # Get configuration values with defaults + mode_name = app.config.get("ENGINE_MANAGER_MODE", "NEW") + cleanup_interval = app.config.get("ENGINE_MANAGER_CLEANUP_INTERVAL", 300.0) + auto_start_cleanup = app.config.get("ENGINE_MANAGER_AUTO_START_CLEANUP", True) + + # Convert mode string to enum + try: + mode = EngineModes[mode_name.upper()] + except KeyError: + logger.warning( + f"Invalid ENGINE_MANAGER_MODE '{mode_name}', defaulting to NEW" + ) + mode = EngineModes.NEW + + # Create the engine manager + self.engine_manager = EngineManager( + mode=mode, + cleanup_interval=cleanup_interval, + ) + + # Start cleanup thread if requested and in SINGLETON mode + if auto_start_cleanup and mode == EngineModes.SINGLETON: + self.engine_manager.start_cleanup_thread() + logger.info("Started EngineManager cleanup thread") + + # Register shutdown handler + def shutdown_engine_manager() -> None: + if self.engine_manager: + self.engine_manager.stop_cleanup_thread() + logger.info("Stopped EngineManager cleanup thread") + + app.teardown_appcontext_funcs.append(lambda exc: None) + + # Register with atexit for clean shutdown + import atexit + + atexit.register(shutdown_engine_manager) + + logger.info( + f"Initialized EngineManager with mode={mode.name}, " + f"cleanup_interval={cleanup_interval}s" + ) + + @property + def manager(self) -> EngineManager: + """ + Get the EngineManager instance. + + Raises: + RuntimeError: If the extension hasn't been initialized with an app. + """ + if self.engine_manager is None: + raise RuntimeError( + "EngineManager extension not initialized. " + "Call init_app() with a Flask app first." + ) + return self.engine_manager diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 3a34d315bf5..f18149d61f8 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -49,6 +49,7 @@ from superset.extensions import ( csrf, db, encrypted_field_factory, + engine_manager_extension, feature_flag_manager, machine_auth_provider_factory, manifest_processor, @@ -585,6 +586,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.configure_url_map_converters() self.configure_data_sources() self.configure_auth_provider() + self.configure_engine_manager() self.configure_async_queries() self.configure_ssh_manager() self.configure_stats_manager() @@ -761,6 +763,9 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods def configure_auth_provider(self) -> None: machine_auth_provider_factory.init_app(self.superset_app) + def configure_engine_manager(self) -> None: + engine_manager_extension.init_app(self.superset_app) + def configure_ssh_manager(self) -> None: ssh_manager_factory.init_app(self.superset_app)
