This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch late_config
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 9487db44c1c921f75d29875bd92f22c8a6e457c9
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Tue May 7 09:58:57 2024 -0700

    chore: module scope should not require the app context
    
    This is a major issue that has been plaguing the backend for a long time.
    Currently you can't run a plain `from superset import models` without
    getting an error about the missing app context.
    
    This DRAFT PR is an attempt to find out what's getting in the way. So far
    I've identified:
    
    - `app.config` being referenced at module scope: this is mostly easy to fix
      for common configuration, where we can read `flask.current_app.config` at
      call time instead of at module scope
    - dynamic/configurable models: this seems to be the core issue, e.g. we let
      people configure the type of their `EncryptedField`
    - some reliance on `SecurityManager`, notably `SecurityManager.user_model`
      used as a relationship target; I think this can be worked around by
      passing a string reference to SQLAlchemy's `relationship()`, e.g.
      `relationship("User")`
    - ???
---
 superset/initialization/__init__.py |  1 +
 superset/models/core.py             | 49 ++++++++++++++++++++++---------------
 superset/models/helpers.py          | 16 ++++++------
 superset/models/sql_lab.py          |  4 +--
 superset/models/user_attributes.py  |  5 +---
 superset/utils/cache.py             | 13 +++++-----
 superset/utils/encrypt.py           | 26 ++++++++------------
 7 files changed, 56 insertions(+), 58 deletions(-)

diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py
index 83d3afb3d1..77dc4cfe81 100644
--- a/superset/initialization/__init__.py
+++ b/superset/initialization/__init__.py
@@ -71,6 +71,7 @@ class SupersetAppInitializer:  # pylint: disable=too-many-public-methods
         self.superset_app = app
         self.config = app.config
         self.manifest: dict[Any, Any] = {}
+        self.stats_logger = app.config["STATS_LOGGER"]
 
     @deprecated(details="use self.superset_app instead of self.flask_app")  # type: ignore
     @property
diff --git a/superset/models/core.py b/superset/models/core.py
index fe486bf2b1..b9ebd760f7 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -81,10 +81,6 @@ from superset.utils.backports import StrEnum
 from superset.utils.core import DatasourceName, get_username
 from superset.utils.oauth2 import get_oauth2_access_token
 
-config = app.config
-custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
-stats_logger = config["STATS_LOGGER"]
-log_query = config["QUERY_LOGGER"]
 metadata = Model.metadata  # pylint: disable=no-member
 logger = logging.getLogger(__name__)
 
@@ -92,8 +88,6 @@ if TYPE_CHECKING:
     from superset.databases.ssh_tunnel.models import SSHTunnel
     from superset.models.sql_lab import Query
 
-DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
-
 
 class KeyValue(Model):  # pylint: disable=too-few-public-methods
     """Used for any type of key-value store"""
@@ -369,6 +363,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
 
     def set_sqlalchemy_uri(self, uri: str) -> None:
         conn = make_url_safe(uri.strip())
+        custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
         if conn.password != PASSWORD_MASK and not custom_password_store:
             # do not over-write the password with the password mask
             self.password = conn.password
@@ -446,7 +441,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
                 sqlalchemy_uri=sqlalchemy_uri,
             )
 
-    def _get_sqla_engine(
+    def _get_sqla_engine(  # pylint: disable=too-many-locals
         self,
         catalog: str | None = None,
         schema: str | None = None,
@@ -510,7 +505,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
 
         self.update_params_from_encrypted_extra(params)
 
-        if DB_CONNECTION_MUTATOR:
+        if db_conn_mutator := app.config["DB_CONNECTION_MUTATOR"]:
             if not source and request and request.referrer:
                 if "/superset/dashboard/" in request.referrer:
                     source = utils.QuerySource.DASHBOARD
@@ -519,7 +514,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
                 elif "/sqllab/" in request.referrer:
                     source = utils.QuerySource.SQL_LAB
 
-            sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
+            sqlalchemy_url, params = db_conn_mutator(
                 sqlalchemy_url,
                 params,
                 effective_username,
@@ -618,8 +613,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
           on the group of queries as a whole. Here the called passes the context
           as to whether the SQL is split or already.
         """
-        sql_mutator = config["SQL_QUERY_MUTATOR"]
-        if sql_mutator and (is_split == config["MUTATE_AFTER_SPLIT"]):
+        sql_mutator = app.config["SQL_QUERY_MUTATOR"]
+        if sql_mutator and (is_split == app.config["MUTATE_AFTER_SPLIT"]):
             return sql_mutator(
                 sql_,
                 security_manager=security_manager,
@@ -638,6 +633,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
             engine_url = engine.url
 
+        log_query = app.config["QUERY_LOGGER"]
+
         def _log_query(sql: str) -> None:
             if log_query:
                 log_query(
@@ -974,7 +971,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
             allowed_databases = literal_eval(allowed_databases)
 
         if hasattr(g, "user"):
-            extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
+            extra_allowed_databases = app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
                 self, g.user
             )
             allowed_databases += extra_allowed_databases
@@ -988,7 +985,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
             # if the URI is invalid, ignore and return a placeholder url
             # (so users see 500 less often)
             return "dialect://invalid_uri"
-        if custom_password_store:
+
+        if custom_password_store := app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]:
             conn = conn.set(password=custom_password_store(conn))
         else:
             conn = conn.set(password=self.password)
@@ -1071,9 +1069,22 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         return self.db_engine_spec.get_oauth2_config()
 
 
-sqla.event.listen(Database, "after_insert", security_manager.database_after_insert)
-sqla.event.listen(Database, "after_update", security_manager.database_after_update)
-sqla.event.listen(Database, "after_delete", security_manager.database_after_delete)
+# Lambdas defer the security_manager lookup from import time to event time
+sqla.event.listen(
+    Database,
+    "after_insert",
+    lambda *args, **kwargs: security_manager.database_after_insert(*args, **kwargs),  # pylint: disable=unnecessary-lambda
+)
+sqla.event.listen(
+    Database,
+    "after_update",
+    lambda *args, **kwargs: security_manager.database_after_update(*args, **kwargs),  # pylint: disable=unnecessary-lambda
+)
+sqla.event.listen(
+    Database,
+    "after_delete",
+    lambda *args, **kwargs: security_manager.database_after_delete(*args, **kwargs),  # pylint: disable=unnecessary-lambda
+)
 
 
 class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
@@ -1091,7 +1102,7 @@ class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
         ForeignKey("ab_user.id", ondelete="CASCADE"),
         nullable=False,
     )
-    user = relationship(security_manager.user_model, foreign_keys=[user_id])
+    user = relationship("User", foreign_keys=[user_id])
 
     database_id = Column(
         Integer,
@@ -1116,9 +1127,7 @@ class Log(Model):  # pylint: disable=too-few-public-methods
     dashboard_id = Column(Integer)
     slice_id = Column(Integer)
     json = Column(utils.MediumText())
-    user = relationship(
-        security_manager.user_model, backref="logs", foreign_keys=[user_id]
-    )
+    user = relationship("User", backref="logs", foreign_keys=[user_id])
     dttm = Column(DateTime, default=datetime.utcnow)
     duration_ms = Column(Integer)
     referrer = Column(String(1024))
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 100391086c..1ebf9244f7 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -105,12 +105,10 @@ if TYPE_CHECKING:
     from superset.models.core import Database
 
 
-config = app.config
 logger = logging.getLogger(__name__)
 
 VIRTUAL_TABLE_ALIAS = "virtual_table"
 SERIES_LIMIT_SUBQ_ALIAS = "series_limit"
-ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
 
 
 def validate_adhoc_subquery(
@@ -1822,12 +1820,12 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                     and feature_flag_manager.is_feature_enabled(
                         "ENABLE_ADVANCED_DATA_TYPES"
                     )
-                    and col_advanced_data_type in ADVANCED_DATA_TYPES
+                    and col_advanced_data_type in app.config["ADVANCED_DATA_TYPES"]
                 ):
                     values = eq if is_list_target else [eq]  # type: ignore
-                    bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[
-                        col_advanced_data_type
-                    ].translate_type(
+                    bus_resp: AdvancedDataTypeResponse = app.config[
+                        "ADVANCED_DATA_TYPES"
+                    ][col_advanced_data_type].translate_type(
                         {
                             "type": col_advanced_data_type,
                             "values": values,
@@ -1839,9 +1837,9 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                         )
 
                     where_clause_and.append(
-                        ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter(
-                            sqla_col, op, bus_resp["values"]
-                        )
+                        app.config["ADVANCED_DATA_TYPES"][
+                            col_advanced_data_type
+                        ].translate_filter(sqla_col, op, bus_resp["values"])
                     )
                 elif is_list_target:
                     assert isinstance(eq, (tuple, list))
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 31443b4bb1..9707b442bb 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -148,7 +148,7 @@ class Query(
         foreign_keys=[database_id],
         backref=backref("queries", cascade="all, delete-orphan"),
     )
-    user = relationship(security_manager.user_model, foreign_keys=[user_id])
+    user = relationship("User", foreign_keys=[user_id])
 
     __table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),)
 
@@ -393,7 +393,7 @@ class SavedQuery(
     sql = Column(MediumText())
     template_parameters = Column(Text)
     user = relationship(
-        security_manager.user_model,
+        "User",
         backref=backref("saved_queries", cascade="all, delete-orphan"),
         foreign_keys=[user_id],
     )
diff --git a/superset/models/user_attributes.py b/superset/models/user_attributes.py
index 512270c89c..f5438ccc6e 100644
--- a/superset/models/user_attributes.py
+++ b/superset/models/user_attributes.py
@@ -19,7 +19,6 @@ from flask_appbuilder import Model
 from sqlalchemy import Column, ForeignKey, Integer, String
 from sqlalchemy.orm import relationship
 
-from superset import security_manager
 from superset.models.helpers import AuditMixinNullable
 
 
@@ -36,9 +35,7 @@ class UserAttribute(Model, AuditMixinNullable):
     __tablename__ = "user_attribute"
     id = Column(Integer, primary_key=True)
     user_id = Column(Integer, ForeignKey("ab_user.id"))
-    user = relationship(
-        security_manager.user_model, backref="extra_attributes", foreign_keys=[user_id]
-    )
+    user = relationship("User", backref="extra_attributes", foreign_keys=[user_id])
     welcome_dashboard_id = Column(Integer, ForeignKey("dashboards.id"))
     welcome_dashboard = relationship("Dashboard")
     avatar_url = Column(String(100))
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index 00216fc4b1..c7ba4e8382 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -20,7 +20,7 @@ import inspect
 import logging
 from datetime import datetime, timedelta
 from functools import wraps
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 from flask import current_app as app, request
 from flask_caching import Cache
@@ -34,10 +34,8 @@ from superset.utils.core import json_int_dttm_ser
 from superset.utils.hashing import md5_sha_from_dict
 
 if TYPE_CHECKING:
-    from superset.stats_logger import BaseStatsLogger
+    pass
 
-config = app.config
-stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
 logger = logging.getLogger(__name__)
 
 
@@ -65,9 +63,9 @@ def set_and_log_cache(
         dttm = datetime.utcnow().isoformat().split(".")[0]
         value = {**cache_value, "dttm": dttm}
         cache_instance.set(cache_key, value, timeout=timeout)
-        stats_logger.incr("set_cache_key")
+        app.config["STATS_LOGGER"].incr("set_cache_key")
 
-        if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]:
+        if datasource_uid and app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]:
             ck = CacheKey(
                 cache_key=cache_key,
                 cache_timeout=cache_timeout,
@@ -147,7 +145,7 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
 def etag_cache(
     cache: Cache = cache_manager.cache,
     get_last_modified: Callable[..., datetime] | None = None,
-    max_age: int | float = app.config["CACHE_DEFAULT_TIMEOUT"],
+    max_age: Optional[int | float] = None,
     raise_for_access: Callable[..., Any] | None = None,
     skip: Callable[..., bool] | None = None,
 ) -> Callable[..., Any]:
@@ -163,6 +161,7 @@ def etag_cache(
     dataframe cache for requests that produce the same SQL.
 
     """
+    max_age = app.config["CACHE_DEFAULT_TIMEOUT"] if max_age is None else max_age
 
     def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(f)
diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py
index e5bbd13825..ed58a49f6a 100644
--- a/superset/utils/encrypt.py
+++ b/superset/utils/encrypt.py
@@ -18,7 +18,7 @@ import logging
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
-from flask import Flask
+from flask import current_app as app
 from flask_babel import lazy_gettext as _
 from sqlalchemy import text, TypeDecorator
 from sqlalchemy.engine import Connection, Dialect, Row
@@ -27,6 +27,10 @@ from sqlalchemy_utils import EncryptedType
 logger = logging.getLogger(__name__)
 
 
+def get_key() -> str:
+    return app.config["SECRET_KEY"]  # resolved per call, not at import time
+
+
 class AbstractEncryptedFieldAdapter(ABC):  # pylint: disable=too-few-public-methods
     @abstractmethod
     def create(
@@ -47,22 +51,15 @@ class SQLAlchemyUtilsAdapter(  # pylint: disable=too-few-public-methods
         *args: list[Any],
         **kwargs: Optional[dict[str, Any]],
     ) -> TypeDecorator:
-        if app_config:
-            return EncryptedType(*args, app_config["SECRET_KEY"], **kwargs)
-
-        raise Exception(  # pylint: disable=broad-exception-raised
-            "Missing app_config kwarg"
-        )
+        return EncryptedType(*args, get_key, **kwargs)
 
 
 class EncryptedFieldFactory:
     def __init__(self) -> None:
         self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None
-        self._config: Optional[dict[str, Any]] = None
 
-    def init_app(self, app: Flask) -> None:
-        self._config = app.config
-        self._concrete_type_adapter = self._config[  # type: ignore
+    def init_app(self, flask_app: Any = None) -> None:
+        self._concrete_type_adapter = (flask_app or app).config[
             "SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER"
         ]()
 
@@ -70,11 +67,8 @@ class EncryptedFieldFactory:
         self, *args: list[Any], **kwargs: Optional[dict[str, Any]]
     ) -> TypeDecorator:
         if self._concrete_type_adapter:
-            return self._concrete_type_adapter.create(self._config, *args, **kwargs)
-
-        raise Exception(  # pylint: disable=broad-exception-raised
-            "App not initialized yet. Please call init_app first"
-        )
+            return self._concrete_type_adapter.create(app.config, *args, **kwargs)
+        raise RuntimeError("App not initialized yet. Please call init_app first")
 
 
 class SecretsMigrator:

Reply via email to