This is an automated email from the ASF dual-hosted git repository. maximebeauchemin pushed a commit to branch late_config in repository https://gitbox.apache.org/repos/asf/superset.git
commit 9487db44c1c921f75d29875bd92f22c8a6e457c9 Author: Maxime Beauchemin <[email protected]> AuthorDate: Tue May 7 09:58:57 2024 -0700 chore: module scope should not require the app context This is a major issue that's been plaguing the backend for a long time. Currently you can't run a simple `from superset import models` without getting an error about the app context missing. This DRAFT PR tries to evaluate what's getting in the way. So far I've identified: - app.config being referenced in module scope; this is mostly easy to fix for common configuration, where we can rely on flask.current_app and avoid module scope - dynamic/configurable model: this seems to be the core issue, where, say, we let people configure their EncryptedField's type - some reliance on SecurityManager, and SecurityManager.user_model used as a relationship, I think this can be worked around using SQLAlchemy's `relationship("use-a-string-reference")` call - ??? --- superset/initialization/__init__.py | 1 + superset/models/core.py | 49 ++++++++++++++++++++++--------------- superset/models/helpers.py | 16 ++++++------ superset/models/sql_lab.py | 4 +-- superset/models/user_attributes.py | 5 +--- superset/utils/cache.py | 13 +++++----- superset/utils/encrypt.py | 26 ++++++++------------ 7 files changed, 56 insertions(+), 58 deletions(-) diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 83d3afb3d1..77dc4cfe81 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -71,6 +71,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.superset_app = app self.config = app.config self.manifest: dict[Any, Any] = {} + self.stats_logger = app.config["STATS_LOGGER"] @deprecated(details="use self.superset_app instead of self.flask_app") # type: ignore @property diff --git a/superset/models/core.py b/superset/models/core.py index fe486bf2b1..b9ebd760f7 100755 --- a/superset/models/core.py +++ 
b/superset/models/core.py @@ -81,10 +81,6 @@ from superset.utils.backports import StrEnum from superset.utils.core import DatasourceName, get_username from superset.utils.oauth2 import get_oauth2_access_token -config = app.config -custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] -stats_logger = config["STATS_LOGGER"] -log_query = config["QUERY_LOGGER"] metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) @@ -92,8 +88,6 @@ if TYPE_CHECKING: from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.sql_lab import Query -DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] - class KeyValue(Model): # pylint: disable=too-few-public-methods """Used for any type of key-value store""" @@ -369,6 +363,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable def set_sqlalchemy_uri(self, uri: str) -> None: conn = make_url_safe(uri.strip()) + custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password @@ -446,7 +441,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable sqlalchemy_uri=sqlalchemy_uri, ) - def _get_sqla_engine( + def _get_sqla_engine( # pylint: disable=too-many-locals self, catalog: str | None = None, schema: str | None = None, @@ -510,7 +505,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable self.update_params_from_encrypted_extra(params) - if DB_CONNECTION_MUTATOR: + if db_conn_mutator := app.config["DB_CONNECTION_MUTATOR"]: if not source and request and request.referrer: if "/superset/dashboard/" in request.referrer: source = utils.QuerySource.DASHBOARD @@ -519,7 +514,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable elif "/sqllab/" in request.referrer: source = utils.QuerySource.SQL_LAB - 
sqlalchemy_url, params = DB_CONNECTION_MUTATOR( + sqlalchemy_url, params = db_conn_mutator( sqlalchemy_url, params, effective_username, @@ -618,8 +613,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable on the group of queries as a whole. Here the called passes the context as to whether the SQL is split or already. """ - sql_mutator = config["SQL_QUERY_MUTATOR"] - if sql_mutator and (is_split == config["MUTATE_AFTER_SPLIT"]): + sql_mutator = app.config["SQL_QUERY_MUTATOR"] + if sql_mutator and (is_split == app.config["MUTATE_AFTER_SPLIT"]): return sql_mutator( sql_, security_manager=security_manager, @@ -638,6 +633,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url + log_query = app.config["QUERY_LOGGER"] + def _log_query(sql: str) -> None: if log_query: log_query( @@ -974,7 +971,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable allowed_databases = literal_eval(allowed_databases) if hasattr(g, "user"): - extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( + extra_allowed_databases = app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( self, g.user ) allowed_databases += extra_allowed_databases @@ -988,7 +985,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable # if the URI is invalid, ignore and return a placeholder url # (so users see 500 less often) return "dialect://invalid_uri" - if custom_password_store: + + if custom_password_store := app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]: conn = conn.set(password=custom_password_store(conn)) else: conn = conn.set(password=self.password) @@ -1071,9 +1069,22 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable return self.db_engine_spec.get_oauth2_config() -sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) 
-sqla.event.listen(Database, "after_update", security_manager.database_after_update) -sqla.event.listen(Database, "after_delete", security_manager.database_after_delete) +# Using lambdas for security manager to prevent referencing it in module scope +sqla.event.listen( + Database, + "after_insert", + lambda *args, **kwargs: security_manager.database_after_insert(*args, **kwargs), # pylint: disable=unnecessary-lambda +) +sqla.event.listen( + Database, + "after_update", + lambda *args, **kwargs: security_manager.database_after_update(*args, **kwargs), # pylint: disable=unnecessary-lambda +) +sqla.event.listen( + Database, + "after_delete", + lambda *args, **kwargs: security_manager.database_after_delete(*args, **kwargs), # pylint: disable=unnecessary-lambda +) class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable): @@ -1091,7 +1102,7 @@ class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable): ForeignKey("ab_user.id", ondelete="CASCADE"), nullable=False, ) - user = relationship(security_manager.user_model, foreign_keys=[user_id]) + user = relationship("User", foreign_keys=[user_id]) database_id = Column( Integer, @@ -1116,9 +1127,7 @@ class Log(Model): # pylint: disable=too-few-public-methods dashboard_id = Column(Integer) slice_id = Column(Integer) json = Column(utils.MediumText()) - user = relationship( - security_manager.user_model, backref="logs", foreign_keys=[user_id] - ) + user = relationship("User", backref="logs", foreign_keys=[user_id]) dttm = Column(DateTime, default=datetime.utcnow) duration_ms = Column(Integer) referrer = Column(String(1024)) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 100391086c..1ebf9244f7 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -105,12 +105,10 @@ if TYPE_CHECKING: from superset.models.core import Database -config = app.config logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" SERIES_LIMIT_SUBQ_ALIAS = "series_limit" -ADVANCED_DATA_TYPES = 
config["ADVANCED_DATA_TYPES"] def validate_adhoc_subquery( @@ -1822,12 +1820,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods and feature_flag_manager.is_feature_enabled( "ENABLE_ADVANCED_DATA_TYPES" ) - and col_advanced_data_type in ADVANCED_DATA_TYPES + and col_advanced_data_type in app.config["ADVANCED_DATA_TYPES"] ): values = eq if is_list_target else [eq] # type: ignore - bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ - col_advanced_data_type - ].translate_type( + bus_resp: AdvancedDataTypeResponse = app.config[ + "ADVANCED_DATA_TYPES" + ][col_advanced_data_type].translate_type( { "type": col_advanced_data_type, "values": values, @@ -1839,9 +1837,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) where_clause_and.append( - ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( - sqla_col, op, bus_resp["values"] - ) + app.config["ADVANCED_DATA_TYPES"][ + col_advanced_data_type + ].translate_filter(sqla_col, op, bus_resp["values"]) ) elif is_list_target: assert isinstance(eq, (tuple, list)) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 31443b4bb1..9707b442bb 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -148,7 +148,7 @@ class Query( foreign_keys=[database_id], backref=backref("queries", cascade="all, delete-orphan"), ) - user = relationship(security_manager.user_model, foreign_keys=[user_id]) + user = relationship("User", foreign_keys=[user_id]) __table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),) @@ -393,7 +393,7 @@ class SavedQuery( sql = Column(MediumText()) template_parameters = Column(Text) user = relationship( - security_manager.user_model, + "User", backref=backref("saved_queries", cascade="all, delete-orphan"), foreign_keys=[user_id], ) diff --git a/superset/models/user_attributes.py b/superset/models/user_attributes.py index 512270c89c..f5438ccc6e 100644 --- a/superset/models/user_attributes.py +++ 
b/superset/models/user_attributes.py @@ -19,7 +19,6 @@ from flask_appbuilder import Model from sqlalchemy import Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship -from superset import security_manager from superset.models.helpers import AuditMixinNullable @@ -36,9 +35,7 @@ class UserAttribute(Model, AuditMixinNullable): __tablename__ = "user_attribute" id = Column(Integer, primary_key=True) user_id = Column(Integer, ForeignKey("ab_user.id")) - user = relationship( - security_manager.user_model, backref="extra_attributes", foreign_keys=[user_id] - ) + user = relationship("User", backref="extra_attributes", foreign_keys=[user_id]) welcome_dashboard_id = Column(Integer, ForeignKey("dashboards.id")) welcome_dashboard = relationship("Dashboard") avatar_url = Column(String(100)) diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 00216fc4b1..c7ba4e8382 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -20,7 +20,7 @@ import inspect import logging from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING from flask import current_app as app, request from flask_caching import Cache @@ -34,10 +34,8 @@ from superset.utils.core import json_int_dttm_ser from superset.utils.hashing import md5_sha_from_dict if TYPE_CHECKING: - from superset.stats_logger import BaseStatsLogger + pass -config = app.config -stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) @@ -65,9 +63,9 @@ def set_and_log_cache( dttm = datetime.utcnow().isoformat().split(".")[0] value = {**cache_value, "dttm": dttm} cache_instance.set(cache_key, value, timeout=timeout) - stats_logger.incr("set_cache_key") + app.stats_logger.incr("set_cache_key") - if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]: + if datasource_uid and app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]: ck = 
CacheKey( cache_key=cache_key, cache_timeout=cache_timeout, @@ -147,7 +145,7 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., def etag_cache( cache: Cache = cache_manager.cache, get_last_modified: Callable[..., datetime] | None = None, - max_age: int | float = app.config["CACHE_DEFAULT_TIMEOUT"], + max_age: Optional[int | float] = None, raise_for_access: Callable[..., Any] | None = None, skip: Callable[..., bool] | None = None, ) -> Callable[..., Any]: @@ -163,6 +161,7 @@ def etag_cache( dataframe cache for requests that produce the same SQL. """ + max_age = max_age or app.config["CACHE_DEFAULT_TIMEOUT"] def decorator(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index e5bbd13825..ed58a49f6a 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -18,7 +18,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Optional -from flask import Flask +from flask import current_app as app from flask_babel import lazy_gettext as _ from sqlalchemy import text, TypeDecorator from sqlalchemy.engine import Connection, Dialect, Row @@ -27,6 +27,10 @@ from sqlalchemy_utils import EncryptedType logger = logging.getLogger(__name__) +def get_key() -> str: + return app.config["SECRET_KEY"] + + class AbstractEncryptedFieldAdapter(ABC): # pylint: disable=too-few-public-methods @abstractmethod def create( @@ -47,22 +51,15 @@ class SQLAlchemyUtilsAdapter( # pylint: disable=too-few-public-methods *args: list[Any], **kwargs: Optional[dict[str, Any]], ) -> TypeDecorator: - if app_config: - return EncryptedType(*args, app_config["SECRET_KEY"], **kwargs) - - raise Exception( # pylint: disable=broad-exception-raised - "Missing app_config kwarg" - ) + return EncryptedType(*args, get_key, **kwargs) class EncryptedFieldFactory: def __init__(self) -> None: self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None - 
self._config: Optional[dict[str, Any]] = None - def init_app(self, app: Flask) -> None: - self._config = app.config - self._concrete_type_adapter = self._config[ # type: ignore + def init_app(self, *args, **kwargs) -> None: # type: ignore # pylint: disable=unused-argument + self._concrete_type_adapter = app.config[ "SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER" ]() @@ -70,11 +67,8 @@ class EncryptedFieldFactory: self, *args: list[Any], **kwargs: Optional[dict[str, Any]] ) -> TypeDecorator: if self._concrete_type_adapter: - return self._concrete_type_adapter.create(self._config, *args, **kwargs) - - raise Exception( # pylint: disable=broad-exception-raised - "App not initialized yet. Please call init_app first" - ) + return self._concrete_type_adapter.create(app.config, *args, **kwargs) + return None class SecretsMigrator:
