This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch late_config
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8c78a3207a9169d019d30be37fdf783f2505e60f
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Tue May 7 09:58:57 2024 -0700

    chore: module scope should not require the app context
    
    This is a major issue that has been plaguing the backend for a long time. Currently you can't run a simple `from superset import models` without getting an error about the missing app context.
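    
    To make the failure concrete, this is roughly what you hit today (a hedged sketch; the exact traceback depends on which module trips the config lookup first):
    
        # run outside of any Flask app context, e.g. a bare REPL or script
        from superset import models  # noqa: F401
        # RuntimeError: Working outside of application context.
        # ...raised because some module evaluates app.config (or similar) at import time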
    
    This DRAFT PR tries to evaluate what's getting in the way. So far I've 
identified:
    
    - app.config being referenced in module scope; this is mostly easy to fix for common configuration, where we can rely on flask.current_app instead of module-scope references (see the sketch after this list)
    - dynamic/configurable models: this seems to be the core issue, e.g. where we let people configure their EncryptedField's type
    - some reliance on SecurityManager, and on SecurityManager.user_model being used in relationships; I think this can be worked around using sqlalchemy's `relationship("use-a-string-reference")` call
    - ???
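    
    A minimal sketch of the first and third workarounds (illustrative names only, not the exact Superset code):
    
        from flask import current_app as app
        from flask_appbuilder import Model
        from sqlalchemy import Column, ForeignKey, Integer
        from sqlalchemy.orm import relationship
    
        def get_row_limit() -> int:
            # defer the config read to call time instead of module import time
            return app.config["ROW_LIMIT"]
    
        class ExampleModel(Model):  # hypothetical model, for illustration only
            __tablename__ = "example"
            id = Column(Integer, primary_key=True)
            user_id = Column(Integer, ForeignKey("ab_user.id"))
            # a string reference avoids touching security_manager.user_model at module scope
            user = relationship("User", foreign_keys=[user_id])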
---
 superset/commands/chart/warm_up_cache.py   |  4 +--
 superset/common/query_context_processor.py |  4 +--
 superset/connectors/sqla/models.py         | 10 +++---
 superset/initialization/__init__.py        |  1 +
 superset/models/core.py                    | 49 ++++++++++++++++-----------
 superset/models/helpers.py                 | 16 ++++-----
 superset/models/slice.py                   | 12 +++----
 superset/models/sql_lab.py                 |  4 +--
 superset/models/user_attributes.py         |  5 +--
 superset/utils/cache.py                    | 13 ++++---
 superset/utils/encrypt.py                  | 26 ++++++--------
 superset/utils/screenshots.py              |  8 ++---
 superset/views/utils.py                    |  6 ++--
 superset/viz.py                            | 54 +++++++++++++++---------------
 test.py                                    |  1 +
 test_imports.py                            | 31 +++++++++++++++++
 16 files changed, 135 insertions(+), 109 deletions(-)

diff --git a/superset/commands/chart/warm_up_cache.py b/superset/commands/chart/warm_up_cache.py
index 2e5c0ac3a3..80aa045e17 100644
--- a/superset/commands/chart/warm_up_cache.py
+++ b/superset/commands/chart/warm_up_cache.py
@@ -31,7 +31,7 @@ from superset.extensions import db
 from superset.models.slice import Slice
 from superset.utils.core import error_msg_from_exception
 from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz
-from superset.viz import viz_types
+from superset.viz import get_viz_types
 
 
 class ChartWarmUpCacheCommand(BaseCommand):
@@ -52,7 +52,7 @@ class ChartWarmUpCacheCommand(BaseCommand):
         try:
             form_data = get_form_data(chart.id, use_slice_data=True)[0]
 
-            if form_data.get("viz_type") in viz_types:
+            if form_data.get("viz_type") in get_viz_types():
                 # Legacy visualizations.
                 if not chart.datasource:
                     raise ChartInvalidError("Chart's datasource does not 
exist")
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 55c80386a3..f10d1717e2 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -66,7 +66,7 @@ from superset.utils.core import (
 from superset.utils.date_parser import get_past_or_future, normalize_time_delta
 from superset.utils.pandas_postprocessing.utils import unescape_separator
 from superset.views.utils import get_viz
-from superset.viz import viz_types
+from superset.viz import get_viz_types
 
 if TYPE_CHECKING:
     from superset.common.query_context import QueryContext
@@ -704,7 +704,7 @@ class QueryContextProcessor:
             raise QueryObjectValidationError(_("The chart does not exist"))
 
         try:
-            if chart.viz_type in viz_types:
+            if chart.viz_type in get_viz_types():
                 if not chart.datasource:
                     raise QueryObjectValidationError(
                         _("The chart datasource does not exist"),
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 12fbdc3bd7..b4eee0c42e 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -122,10 +122,8 @@ from superset.utils import core as utils
 from superset.utils.backports import StrEnum
 from superset.utils.core import GenericDataType, MediumText
 
-config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
 logger = logging.getLogger(__name__)
-ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
 VIRTUAL_TABLE_ALIAS = "virtual_table"
 
 # a non-exhaustive set of additive metrics
@@ -1120,7 +1118,7 @@ class SqlaTable(
     )
     metric_class = SqlMetric
     column_class = TableColumn
-    owner_class = security_manager.user_model
+    owner_class = "User"
 
     __tablename__ = "tables"
 
@@ -1358,7 +1356,7 @@ class SqlaTable(
 
     @property
     def health_check_message(self) -> str | None:
-        check = config["DATASET_HEALTH_CHECK"]
+        check = app.config["DATASET_HEALTH_CHECK"]
         return check(self) if check else None
 
     @property
@@ -1876,7 +1874,7 @@ class SqlaTable(
         self.add_missing_metrics(metrics)
 
         # Apply config supplied mutations.
-        config["SQLA_TABLE_MUTATOR"](self)
+        app.config["SQLA_TABLE_MUTATOR"](self)
 
         db.session.merge(self)
         if commit:
@@ -2105,7 +2103,7 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
     )
     group_key = Column(String(255), nullable=True)
     roles = relationship(
-        security_manager.role_model,
+        "Role",
         secondary=RLSFilterRoles,
         backref="row_level_security_filters",
     )
diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py
index 83d3afb3d1..77dc4cfe81 100644
--- a/superset/initialization/__init__.py
+++ b/superset/initialization/__init__.py
@@ -71,6 +71,7 @@ class SupersetAppInitializer:  # pylint: disable=too-many-public-methods
         self.superset_app = app
         self.config = app.config
         self.manifest: dict[Any, Any] = {}
+        self.stats_logger = app.config["STATS_LOGGER"]
 
     @deprecated(details="use self.superset_app instead of self.flask_app")  # type: ignore
     @property
diff --git a/superset/models/core.py b/superset/models/core.py
index fe486bf2b1..b9ebd760f7 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -81,10 +81,6 @@ from superset.utils.backports import StrEnum
 from superset.utils.core import DatasourceName, get_username
 from superset.utils.oauth2 import get_oauth2_access_token
 
-config = app.config
-custom_password_store = config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
-stats_logger = config["STATS_LOGGER"]
-log_query = config["QUERY_LOGGER"]
 metadata = Model.metadata  # pylint: disable=no-member
 logger = logging.getLogger(__name__)
 
@@ -92,8 +88,6 @@ if TYPE_CHECKING:
     from superset.databases.ssh_tunnel.models import SSHTunnel
     from superset.models.sql_lab import Query
 
-DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"]
-
 
 class KeyValue(Model):  # pylint: disable=too-few-public-methods
     """Used for any type of key-value store"""
@@ -369,6 +363,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
 
     def set_sqlalchemy_uri(self, uri: str) -> None:
         conn = make_url_safe(uri.strip())
+        custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]
         if conn.password != PASSWORD_MASK and not custom_password_store:
             # do not over-write the password with the password mask
             self.password = conn.password
@@ -446,7 +441,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
                 sqlalchemy_uri=sqlalchemy_uri,
             )
 
-    def _get_sqla_engine(
+    def _get_sqla_engine(  # pylint: disable=too-many-locals
         self,
         catalog: str | None = None,
         schema: str | None = None,
@@ -510,7 +505,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
 
         self.update_params_from_encrypted_extra(params)
 
-        if DB_CONNECTION_MUTATOR:
+        if db_conn_mutator := app.config["DB_CONNECTION_MUTATOR"]:
             if not source and request and request.referrer:
                 if "/superset/dashboard/" in request.referrer:
                     source = utils.QuerySource.DASHBOARD
@@ -519,7 +514,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
                 elif "/sqllab/" in request.referrer:
                     source = utils.QuerySource.SQL_LAB
 
-            sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
+            sqlalchemy_url, params = db_conn_mutator(
                 sqlalchemy_url,
                 params,
                 effective_username,
@@ -618,8 +613,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
          on the group of queries as a whole. Here the called passes the context
           as to whether the SQL is split or already.
         """
-        sql_mutator = config["SQL_QUERY_MUTATOR"]
-        if sql_mutator and (is_split == config["MUTATE_AFTER_SPLIT"]):
+        sql_mutator = app.config["SQL_QUERY_MUTATOR"]
+        if sql_mutator and (is_split == app.config["MUTATE_AFTER_SPLIT"]):
             return sql_mutator(
                 sql_,
                 security_manager=security_manager,
@@ -638,6 +633,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         with self.get_sqla_engine(catalog=catalog, schema=schema) as engine:
             engine_url = engine.url
 
+        log_query = app.config["QUERY_LOGGER"]
+
         def _log_query(sql: str) -> None:
             if log_query:
                 log_query(
@@ -974,7 +971,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
             allowed_databases = literal_eval(allowed_databases)
 
         if hasattr(g, "user"):
-            extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
+            extra_allowed_databases = app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](
                 self, g.user
             )
             allowed_databases += extra_allowed_databases
@@ -988,7 +985,8 @@ class Database(Model, AuditMixinNullable, 
ImportExportMixin):  # pylint: disable
             # if the URI is invalid, ignore and return a placeholder url
             # (so users see 500 less often)
             return "dialect://invalid_uri"
-        if custom_password_store:
+
+        if custom_password_store := app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]:
             conn = conn.set(password=custom_password_store(conn))
         else:
             conn = conn.set(password=self.password)
@@ -1071,9 +1069,22 @@ class Database(Model, AuditMixinNullable, ImportExportMixin):  # pylint: disable
         return self.db_engine_spec.get_oauth2_config()
 
 
-sqla.event.listen(Database, "after_insert", 
security_manager.database_after_insert)
-sqla.event.listen(Database, "after_update", 
security_manager.database_after_update)
-sqla.event.listen(Database, "after_delete", 
security_manager.database_after_delete)
+# Using lambdas for security manager to prevent referencing it in module scope
+sqla.event.listen(
+    Database,
+    "after_insert",
+    lambda *args, **kwargs: security_manager.database_after_insert(*args, **kwargs),  # pylint: disable=unnecessary-lambda
+)
+sqla.event.listen(
+    Database,
+    "after_update",
+    lambda *args, **kwargs: security_manager.database_after_update(*args, **kwargs),  # pylint: disable=unnecessary-lambda
+)
+sqla.event.listen(
+    Database,
+    "after_delete",
+    lambda *args, **kwargs: security_manager.database_after_delete(*args, **kwargs),  # pylint: disable=unnecessary-lambda
+)
 
 
 class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
@@ -1091,7 +1102,7 @@ class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable):
         ForeignKey("ab_user.id", ondelete="CASCADE"),
         nullable=False,
     )
-    user = relationship(security_manager.user_model, foreign_keys=[user_id])
+    user = relationship("User", foreign_keys=[user_id])
 
     database_id = Column(
         Integer,
@@ -1116,9 +1127,7 @@ class Log(Model):  # pylint: disable=too-few-public-methods
     dashboard_id = Column(Integer)
     slice_id = Column(Integer)
     json = Column(utils.MediumText())
-    user = relationship(
-        security_manager.user_model, backref="logs", foreign_keys=[user_id]
-    )
+    user = relationship("User", backref="logs", foreign_keys=[user_id])
     dttm = Column(DateTime, default=datetime.utcnow)
     duration_ms = Column(Integer)
     referrer = Column(String(1024))
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 100391086c..1ebf9244f7 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -105,12 +105,10 @@ if TYPE_CHECKING:
     from superset.models.core import Database
 
 
-config = app.config
 logger = logging.getLogger(__name__)
 
 VIRTUAL_TABLE_ALIAS = "virtual_table"
 SERIES_LIMIT_SUBQ_ALIAS = "series_limit"
-ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
 
 
 def validate_adhoc_subquery(
@@ -1822,12 +1820,12 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                     and feature_flag_manager.is_feature_enabled(
                         "ENABLE_ADVANCED_DATA_TYPES"
                     )
-                    and col_advanced_data_type in ADVANCED_DATA_TYPES
+                    and col_advanced_data_type in app.config["ADVANCED_DATA_TYPES"]
                 ):
                     values = eq if is_list_target else [eq]  # type: ignore
-                    bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[
-                        col_advanced_data_type
-                    ].translate_type(
+                    bus_resp: AdvancedDataTypeResponse = app.config[
+                        "ADVANCED_DATA_TYPES"
+                    ][col_advanced_data_type].translate_type(
                         {
                             "type": col_advanced_data_type,
                             "values": values,
@@ -1839,9 +1837,9 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                         )
 
                     where_clause_and.append(
-                        ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter(
-                            sqla_col, op, bus_resp["values"]
-                        )
+                        app.config["ADVANCED_DATA_TYPES"][
+                            col_advanced_data_type
+                        ].translate_filter(sqla_col, op, bus_resp["values"])
                     )
                 elif is_list_target:
                     assert isinstance(eq, (tuple, list))
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 2a0734b107..7d645ea7fd 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -39,14 +39,14 @@ from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm.mapper import Mapper
 
-from superset import db, is_feature_enabled, security_manager
+from superset import db, is_feature_enabled
 from superset.legacy import update_time_range
 from superset.models.helpers import AuditMixinNullable, ImportExportMixin
 from superset.tasks.thumbnails import cache_chart_thumbnail
 from superset.tasks.utils import get_current_user
 from superset.thumbnails.digest import get_chart_digest
 from superset.utils import core as utils
-from superset.viz import BaseViz, viz_types
+from superset.viz import BaseViz, get_viz_types
 
 if TYPE_CHECKING:
     from superset.common.query_context import QueryContext
@@ -93,11 +93,9 @@ class Slice(  # pylint: disable=too-many-public-methods
     certification_details = Column(Text)
     is_managed_externally = Column(Boolean, nullable=False, default=False)
     external_url = Column(Text, nullable=True)
-    last_saved_by = relationship(
-        security_manager.user_model, foreign_keys=[last_saved_by_fk]
-    )
+    last_saved_by = relationship("User", foreign_keys=[last_saved_by_fk])
     owners = relationship(
-        security_manager.user_model,
+        "User",
         secondary=slice_user,
         passive_deletes=True,
     )
@@ -204,7 +202,7 @@ class Slice(  # pylint: disable=too-many-public-methods
     @property
     def viz(self) -> BaseViz | None:
         form_data = json.loads(self.params)
-        viz_class = viz_types.get(self.viz_type)
+        viz_class = get_viz_types().get(self.viz_type)
         datasource = self.datasource
         if viz_class and datasource:
             return viz_class(datasource=datasource, form_data=form_data)
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 31443b4bb1..9707b442bb 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -148,7 +148,7 @@ class Query(
         foreign_keys=[database_id],
         backref=backref("queries", cascade="all, delete-orphan"),
     )
-    user = relationship(security_manager.user_model, foreign_keys=[user_id])
+    user = relationship("User", foreign_keys=[user_id])
 
     __table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),)
 
@@ -393,7 +393,7 @@ class SavedQuery(
     sql = Column(MediumText())
     template_parameters = Column(Text)
     user = relationship(
-        security_manager.user_model,
+        "User",
         backref=backref("saved_queries", cascade="all, delete-orphan"),
         foreign_keys=[user_id],
     )
diff --git a/superset/models/user_attributes.py b/superset/models/user_attributes.py
index 512270c89c..f5438ccc6e 100644
--- a/superset/models/user_attributes.py
+++ b/superset/models/user_attributes.py
@@ -19,7 +19,6 @@ from flask_appbuilder import Model
 from sqlalchemy import Column, ForeignKey, Integer, String
 from sqlalchemy.orm import relationship
 
-from superset import security_manager
 from superset.models.helpers import AuditMixinNullable
 
 
@@ -36,9 +35,7 @@ class UserAttribute(Model, AuditMixinNullable):
     __tablename__ = "user_attribute"
     id = Column(Integer, primary_key=True)
     user_id = Column(Integer, ForeignKey("ab_user.id"))
-    user = relationship(
-        security_manager.user_model, backref="extra_attributes", foreign_keys=[user_id]
-    )
+    user = relationship("User", backref="extra_attributes", foreign_keys=[user_id])
     welcome_dashboard_id = Column(Integer, ForeignKey("dashboards.id"))
     welcome_dashboard = relationship("Dashboard")
     avatar_url = Column(String(100))
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index 00216fc4b1..c7ba4e8382 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -20,7 +20,7 @@ import inspect
 import logging
 from datetime import datetime, timedelta
 from functools import wraps
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 from flask import current_app as app, request
 from flask_caching import Cache
@@ -34,10 +34,8 @@ from superset.utils.core import json_int_dttm_ser
 from superset.utils.hashing import md5_sha_from_dict
 
 if TYPE_CHECKING:
-    from superset.stats_logger import BaseStatsLogger
+    pass
 
-config = app.config
-stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
 logger = logging.getLogger(__name__)
 
 
@@ -65,9 +63,9 @@ def set_and_log_cache(
         dttm = datetime.utcnow().isoformat().split(".")[0]
         value = {**cache_value, "dttm": dttm}
         cache_instance.set(cache_key, value, timeout=timeout)
-        stats_logger.incr("set_cache_key")
+        app.stats_logger.incr("set_cache_key")
 
-        if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]:
+        if datasource_uid and app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]:
             ck = CacheKey(
                 cache_key=cache_key,
                 cache_timeout=cache_timeout,
@@ -147,7 +145,7 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[...,
 def etag_cache(
     cache: Cache = cache_manager.cache,
     get_last_modified: Callable[..., datetime] | None = None,
-    max_age: int | float = app.config["CACHE_DEFAULT_TIMEOUT"],
+    max_age: Optional[int | float] = None,
     raise_for_access: Callable[..., Any] | None = None,
     skip: Callable[..., bool] | None = None,
 ) -> Callable[..., Any]:
@@ -163,6 +161,7 @@ def etag_cache(
     dataframe cache for requests that produce the same SQL.
 
     """
+    max_age = max_age or app.config["CACHE_DEFAULT_TIMEOUT"]
 
     def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
         @wraps(f)
diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py
index e5bbd13825..ed58a49f6a 100644
--- a/superset/utils/encrypt.py
+++ b/superset/utils/encrypt.py
@@ -18,7 +18,7 @@ import logging
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
-from flask import Flask
+from flask import current_app as app
 from flask_babel import lazy_gettext as _
 from sqlalchemy import text, TypeDecorator
 from sqlalchemy.engine import Connection, Dialect, Row
@@ -27,6 +27,10 @@ from sqlalchemy_utils import EncryptedType
 logger = logging.getLogger(__name__)
 
 
+def get_key() -> str:
+    return app.config["SECRET_KEY"]
+
+
 class AbstractEncryptedFieldAdapter(ABC):  # pylint: disable=too-few-public-methods
     @abstractmethod
     def create(
@@ -47,22 +51,15 @@ class SQLAlchemyUtilsAdapter(  # pylint: disable=too-few-public-methods
         *args: list[Any],
         **kwargs: Optional[dict[str, Any]],
     ) -> TypeDecorator:
-        if app_config:
-            return EncryptedType(*args, app_config["SECRET_KEY"], **kwargs)
-
-        raise Exception(  # pylint: disable=broad-exception-raised
-            "Missing app_config kwarg"
-        )
+        return EncryptedType(*args, get_key, **kwargs)
 
 
 class EncryptedFieldFactory:
     def __init__(self) -> None:
         self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None
-        self._config: Optional[dict[str, Any]] = None
 
-    def init_app(self, app: Flask) -> None:
-        self._config = app.config
-        self._concrete_type_adapter = self._config[  # type: ignore
+    def init_app(self, *args, **kwargs) -> None:  # type: ignore # pylint: disable=unused-argument
+        self._concrete_type_adapter = app.config[
             "SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER"
         ]()
 
@@ -70,11 +67,8 @@ class EncryptedFieldFactory:
         self, *args: list[Any], **kwargs: Optional[dict[str, Any]]
     ) -> TypeDecorator:
         if self._concrete_type_adapter:
-            return self._concrete_type_adapter.create(self._config, *args, **kwargs)
-
-        raise Exception(  # pylint: disable=broad-exception-raised
-            "App not initialized yet. Please call init_app first"
-        )
+            return self._concrete_type_adapter.create(app.config, *args, **kwargs)
+        return None
 
 
 class SecretsMigrator:
diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py
index bf6ed0f9e8..ad7726dc5d 100644
--- a/superset/utils/screenshots.py
+++ b/superset/utils/screenshots.py
@@ -20,7 +20,7 @@ import logging
 from io import BytesIO
 from typing import TYPE_CHECKING
 
-from flask import current_app
+from flask import current_app as app
 
 from superset import feature_flag_manager
 from superset.utils.hashing import md5_sha_from_dict
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
 
 
 class BaseScreenshot:
-    driver_type = current_app.config["WEBDRIVER_TYPE"]
     thumbnail_type: str = ""
     element: str = ""
     window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE
@@ -66,9 +65,10 @@ class BaseScreenshot:
 
     def driver(self, window_size: WindowSize | None = None) -> WebDriver:
         window_size = window_size or self.window_size
+        driver_type = app.config["WEBDRIVER_TYPE"]
         if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"):
-            return WebDriverPlaywright(self.driver_type, window_size)
-        return WebDriverSelenium(self.driver_type, window_size)
+            return WebDriverPlaywright(driver_type, window_size)
+        return WebDriverSelenium(driver_type, window_size)
 
     def cache_key(
         self,
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 2d8fcd68da..c446978d6f 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -49,7 +49,6 @@ from superset.models.sql_lab import Query
 from superset.superset_typing import FormData
 from superset.utils.core import DatasourceType
 from superset.utils.decorators import stats_timing
-from superset.viz import BaseViz
 
 logger = logging.getLogger(__name__)
 stats_logger = app.config["STATS_LOGGER"]
@@ -124,13 +123,14 @@ def get_viz(
     datasource_id: int,
     force: bool = False,
     force_cached: bool = False,
-) -> BaseViz:
+) -> viz.BaseViz:
     viz_type = form_data.get("viz_type", "table")
     datasource = DatasourceDAO.get_datasource(
         DatasourceType(datasource_type),
         datasource_id,
     )
-    viz_obj = viz.viz_types[viz_type](
+    viz_types = viz.get_viz_types()
+    viz_obj = viz_types[viz_type](
         datasource, form_data=form_data, force=force, force_cached=force_cached
     )
     return viz_obj
diff --git a/superset/viz.py b/superset/viz.py
index 5a4b323079..654e156185 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -86,10 +86,6 @@ from superset.utils.hashing import md5_sha_from_str
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import BaseDatasource
 
-config = app.config
-stats_logger = config["STATS_LOGGER"]
-relative_start = config["DEFAULT_RELATIVE_START_TIME"]
-relative_end = config["DEFAULT_RELATIVE_END_TIME"]
 logger = logging.getLogger(__name__)
 
 METRIC_KEYS = [
@@ -132,6 +128,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
 
         self.query = ""
         self.token = utils.get_form_data_token(form_data)
+        self.stats_logger = app.config["STATS_LOGGER"]
 
         self.groupby: list[Column] = self.form_data.get("groupby") or []
         self.time_shift = timedelta()
@@ -250,7 +247,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
                 "groupby": [],
                 "metrics": [],
                 "orderby": [],
-                "row_limit": config["SAMPLES_ROW_LIMIT"],
+                "row_limit": app.config["SAMPLES_ROW_LIMIT"],
                 "columns": [o.column_name for o in self.datasource.columns],
                 "from_dttm": None,
                 "to_dttm": None,
@@ -362,7 +359,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         timeseries_limit_metric = self.form_data.get("timeseries_limit_metric")
 
         # apply row limit to query
-        row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"])
+        row_limit = int(self.form_data.get("row_limit") or 
app.config["ROW_LIMIT"])
         row_limit = apply_max_row_limit(row_limit)
 
         # default order direction
@@ -370,8 +367,8 @@ class BaseViz:  # pylint: disable=too-many-public-methods
 
         try:
             since, until = get_since_until(
-                relative_start=relative_start,
-                relative_end=relative_end,
+                relative_start=app.config["DEFAULT_RELATIVE_START_TIME"],
+                relative_end=app.config["DEFAULT_RELATIVE_END_TIME"],
                 time_range=self.form_data.get("time_range"),
                 since=self.form_data.get("since"),
                 until=self.form_data.get("until"),
@@ -434,9 +431,9 @@ class BaseViz:  # pylint: disable=too-many-public-methods
             and self.datasource.database.cache_timeout
         ) is not None:
             return self.datasource.database.cache_timeout
-        if config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not 
None:
-            return config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
-        return config["CACHE_DEFAULT_TIMEOUT"]
+        if app.config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not 
None:
+            return app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
+        return app.config["CACHE_DEFAULT_TIMEOUT"]
 
     @deprecated(deprecated_in="3.0")
     def get_json(self) -> str:
@@ -532,7 +529,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         if cache_key and cache_manager.data_cache and not force:
             cache_value = cache_manager.data_cache.get(cache_key)
             if cache_value:
-                stats_logger.incr("loading_from_cache")
+                self.stats_logger.incr("loading_from_cache")
                 try:
                     df = cache_value["df"]
                     self.query = cache_value["query"]
@@ -544,7 +541,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
                     )
                     self.status = QueryStatus.SUCCESS
                     is_loaded = True
-                    stats_logger.incr("loaded_from_cache")
+                    self.stats_logger.incr("loaded_from_cache")
                 except Exception as ex:  # pylint: disable=broad-except
                     logger.exception(ex)
                     logger.error(
@@ -582,9 +579,9 @@ class BaseViz:  # pylint: disable=too-many-public-methods
                     )
                 df = self.get_df(query_obj)
                 if self.status != QueryStatus.FAILED:
-                    stats_logger.incr("loaded_from_source")
+                    self.stats_logger.incr("loaded_from_source")
                     if not self.force:
-                        stats_logger.incr("loaded_from_source_without_force")
+                        self.stats_logger.incr("loaded_from_source_without_force")
                     is_loaded = True
             except QueryObjectValidationError as ex:
                 error = dataclasses.asdict(
@@ -677,7 +674,9 @@ class BaseViz:  # pylint: disable=too-many-public-methods
     def get_csv(self) -> str | None:
         df = self.get_df_payload()["df"]  # leverage caching logic
         include_index = not isinstance(df.index, pd.RangeIndex)
-        return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"])
+        return csv.df_to_escaped_csv(
+            df, index=include_index, **app.config["CSV_EXPORT"]
+        )
 
     @deprecated(deprecated_in="3.0")
     def get_data(self, df: pd.DataFrame) -> VizData:
@@ -772,8 +771,8 @@ class CalHeatmapViz(BaseViz):
 
         try:
             start, end = get_since_until(
-                relative_start=relative_start,
-                relative_end=relative_end,
+                relative_start=app.config["DEFAULT_RELATIVE_START_TIME"],
+                relative_end=app.config["DEFAULT_RELATIVE_END_TIME"],
                 time_range=form_data.get("time_range"),
                 since=form_data.get("since"),
                 until=form_data.get("until"),
@@ -1794,7 +1793,7 @@ class MapboxViz(BaseViz):
         return {
             "geoJSON": geo_json,
             "hasCustomMetric": has_custom_metric,
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
             "mapStyle": self.form_data.get("mapbox_style"),
             "aggregatorName": self.form_data.get("pandas_aggfunc"),
             "clusteringRadius": self.form_data.get("clustering_radius"),
@@ -1830,7 +1829,7 @@ class DeckGLMultiLayer(BaseViz):
         slice_ids = self.form_data.get("deck_slices")
         slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
         return {
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
             "slices": [slc.data for slc in slices],
         }
 
@@ -2002,7 +2001,7 @@ class BaseDeckGLViz(BaseViz):
 
         return {
             "features": features,
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
             "metricLabels": self.metric_labels,
         }
 
@@ -2325,7 +2324,7 @@ class DeckArc(BaseDeckGLViz):
 
         return {
             "features": super().get_data(df)["features"],
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
         }
 
 
@@ -2660,8 +2659,9 @@ def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]:
     )
 
 
-viz_types = {
-    o.viz_type: o
-    for o in get_subclasses(BaseViz)
-    if o.viz_type not in config["VIZ_TYPE_DENYLIST"]
-}
+def get_viz_types() -> dict[str, type[BaseViz]]:
+    return {
+        o.viz_type: o
+        for o in get_subclasses(BaseViz)
+        if o.viz_type not in app.config["VIZ_TYPE_DENYLIST"]
+    }
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000..4da9b14742
--- /dev/null
+++ b/test.py
@@ -0,0 +1 @@
+print("YO")
diff --git a/test_imports.py b/test_imports.py
new file mode 100644
index 0000000000..b51d69d8c9
--- /dev/null
+++ b/test_imports.py
@@ -0,0 +1,31 @@
+import os
+
+# mypy: ignore-errors
+
+# pylint: skip-file
+# pylint: disable-all
+
+
+def generate_import_statements(root_dir):
+    import_statements = set()
+    for dirpath, _, filenames in os.walk(root_dir):
+        for filename in filenames:
+            if filename.endswith(".py"):
+                module_path = os.path.relpath(os.path.join(dirpath, filename), root_dir)
+                module_name = os.path.splitext(module_path.replace(os.path.sep, "."))[0]
+                splitted = module_name.split(".")
+                if splitted[-1] == "__init__":
+                    splitted.pop()
+                if len(splitted) > 1 and "migration" not in module_path:
+                    package = ".".join(splitted[:-1])
+                    import_statements.add(
+                        f"from superset.{package} import {splitted[-1]}"
+                    )
+    return sorted(import_statements)
+
+
+if __name__ == "__main__":
+    root_dir = "superset"
+    import_statements = generate_import_statements(root_dir)
+    for statement in import_statements:
+        print(statement)
