This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch late_config
in repository https://gitbox.apache.org/repos/asf/superset.git

commit aa4742353b5b57d45e2a47b9d1e0288789153548
Author: Maxime Beauchemin <[email protected]>
AuthorDate: Tue May 7 18:12:33 2024 -0700

    fixes
---
 superset/commands/chart/warm_up_cache.py   |  4 +--
 superset/common/query_context_processor.py |  4 +--
 superset/connectors/sqla/models.py         | 10 +++---
 superset/models/slice.py                   | 12 +++----
 superset/utils/screenshots.py              |  8 ++---
 superset/views/utils.py                    |  6 ++--
 superset/viz.py                            | 54 +++++++++++++++---------------
 7 files changed, 47 insertions(+), 51 deletions(-)
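
The diff below applies one "late config" pattern throughout: values previously
read from app.config at import time (when no Flask application context may
exist yet) are now read lazily, at call time, through the current_app proxy.
A minimal runnable sketch of the pattern, using a hypothetical SOME_SETTING
key rather than any real Superset config:

    from flask import Flask, current_app as app

    # Eager read: runs when the module is imported, which fails outside an
    # application context and freezes whatever value the config had then.
    # SOME_SETTING = app.config["SOME_SETTING"]

    def handler() -> str:
        # Late read: evaluated on each call, inside an active app context,
        # so app factories and config overrides are picked up correctly.
        return app.config["SOME_SETTING"]

    if __name__ == "__main__":
        flask_app = Flask(__name__)
        flask_app.config["SOME_SETTING"] = "value"
        with flask_app.app_context():
            print(handler())  # -> "value"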

diff --git a/superset/commands/chart/warm_up_cache.py b/superset/commands/chart/warm_up_cache.py
index 2e5c0ac3a3..80aa045e17 100644
--- a/superset/commands/chart/warm_up_cache.py
+++ b/superset/commands/chart/warm_up_cache.py
@@ -31,7 +31,7 @@ from superset.extensions import db
 from superset.models.slice import Slice
 from superset.utils.core import error_msg_from_exception
 from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz
-from superset.viz import viz_types
+from superset.viz import get_viz_types
 
 
 class ChartWarmUpCacheCommand(BaseCommand):
@@ -52,7 +52,7 @@ class ChartWarmUpCacheCommand(BaseCommand):
         try:
             form_data = get_form_data(chart.id, use_slice_data=True)[0]
 
-            if form_data.get("viz_type") in viz_types:
+            if form_data.get("viz_type") in get_viz_types():
                 # Legacy visualizations.
                 if not chart.datasource:
                    raise ChartInvalidError("Chart's datasource does not exist")
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 55c80386a3..f10d1717e2 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -66,7 +66,7 @@ from superset.utils.core import (
 from superset.utils.date_parser import get_past_or_future, normalize_time_delta
 from superset.utils.pandas_postprocessing.utils import unescape_separator
 from superset.views.utils import get_viz
-from superset.viz import viz_types
+from superset.viz import get_viz_types
 
 if TYPE_CHECKING:
     from superset.common.query_context import QueryContext
@@ -704,7 +704,7 @@ class QueryContextProcessor:
             raise QueryObjectValidationError(_("The chart does not exist"))
 
         try:
-            if chart.viz_type in viz_types:
+            if chart.viz_type in get_viz_types():
                 if not chart.datasource:
                     raise QueryObjectValidationError(
                         _("The chart datasource does not exist"),
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 12fbdc3bd7..b4eee0c42e 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -122,10 +122,8 @@ from superset.utils import core as utils
 from superset.utils.backports import StrEnum
 from superset.utils.core import GenericDataType, MediumText
 
-config = app.config
 metadata = Model.metadata  # pylint: disable=no-member
 logger = logging.getLogger(__name__)
-ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"]
 VIRTUAL_TABLE_ALIAS = "virtual_table"
 
 # a non-exhaustive set of additive metrics
@@ -1120,7 +1118,7 @@ class SqlaTable(
     )
     metric_class = SqlMetric
     column_class = TableColumn
-    owner_class = security_manager.user_model
+    owner_class = "User"
 
     __tablename__ = "tables"
 
@@ -1358,7 +1356,7 @@ class SqlaTable(
 
     @property
     def health_check_message(self) -> str | None:
-        check = config["DATASET_HEALTH_CHECK"]
+        check = app.config["DATASET_HEALTH_CHECK"]
         return check(self) if check else None
 
     @property
@@ -1876,7 +1874,7 @@ class SqlaTable(
         self.add_missing_metrics(metrics)
 
         # Apply config supplied mutations.
-        config["SQLA_TABLE_MUTATOR"](self)
+        app.config["SQLA_TABLE_MUTATOR"](self)
 
         db.session.merge(self)
         if commit:
@@ -2105,7 +2103,7 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
     )
     group_key = Column(String(255), nullable=True)
     roles = relationship(
-        security_manager.role_model,
+        "Role",
         secondary=RLSFilterRoles,
         backref="row_level_security_filters",
     )
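
The "User" and "Role" strings above lean on SQLAlchemy's late-binding
relationships: a target may be named by its mapped-class name and is resolved
from the declarative registry when mappers are configured, so this module no
longer has to touch security_manager (and hence app config) at import time.
A self-contained sketch of the mechanism, with toy Parent/Child models rather
than Superset's:

    from sqlalchemy import Column, ForeignKey, Integer, create_engine
    from sqlalchemy.orm import Session, declarative_base, relationship

    Base = declarative_base()

    class Parent(Base):
        __tablename__ = "parent"
        id = Column(Integer, primary_key=True)
        # "Child" is looked up in the declarative registry at mapper
        # configuration time; the class object isn't needed here yet.
        children = relationship("Child", back_populates="parent")

    class Child(Base):
        __tablename__ = "child"
        id = Column(Integer, primary_key=True)
        parent_id = Column(Integer, ForeignKey("parent.id"))
        parent = relationship("Parent", back_populates="children")

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(Parent(children=[Child()]))
        session.commit()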
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 2a0734b107..7d645ea7fd 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -39,14 +39,14 @@ from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm.mapper import Mapper
 
-from superset import db, is_feature_enabled, security_manager
+from superset import db, is_feature_enabled
 from superset.legacy import update_time_range
 from superset.models.helpers import AuditMixinNullable, ImportExportMixin
 from superset.tasks.thumbnails import cache_chart_thumbnail
 from superset.tasks.utils import get_current_user
 from superset.thumbnails.digest import get_chart_digest
 from superset.utils import core as utils
-from superset.viz import BaseViz, viz_types
+from superset.viz import BaseViz, get_viz_types
 
 if TYPE_CHECKING:
     from superset.common.query_context import QueryContext
@@ -93,11 +93,9 @@ class Slice(  # pylint: disable=too-many-public-methods
     certification_details = Column(Text)
     is_managed_externally = Column(Boolean, nullable=False, default=False)
     external_url = Column(Text, nullable=True)
-    last_saved_by = relationship(
-        security_manager.user_model, foreign_keys=[last_saved_by_fk]
-    )
+    last_saved_by = relationship("User", foreign_keys=[last_saved_by_fk])
     owners = relationship(
-        security_manager.user_model,
+        "User",
         secondary=slice_user,
         passive_deletes=True,
     )
@@ -204,7 +202,7 @@ class Slice(  # pylint: disable=too-many-public-methods
     @property
     def viz(self) -> BaseViz | None:
         form_data = json.loads(self.params)
-        viz_class = viz_types.get(self.viz_type)
+        viz_class = get_viz_types().get(self.viz_type)
         datasource = self.datasource
         if viz_class and datasource:
             return viz_class(datasource=datasource, form_data=form_data)
diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py
index bf6ed0f9e8..ad7726dc5d 100644
--- a/superset/utils/screenshots.py
+++ b/superset/utils/screenshots.py
@@ -20,7 +20,7 @@ import logging
 from io import BytesIO
 from typing import TYPE_CHECKING
 
-from flask import current_app
+from flask import current_app as app
 
 from superset import feature_flag_manager
 from superset.utils.hashing import md5_sha_from_dict
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
 
 
 class BaseScreenshot:
-    driver_type = current_app.config["WEBDRIVER_TYPE"]
     thumbnail_type: str = ""
     element: str = ""
     window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE
@@ -66,9 +65,10 @@ class BaseScreenshot:
 
     def driver(self, window_size: WindowSize | None = None) -> WebDriver:
         window_size = window_size or self.window_size
+        driver_type = app.config["WEBDRIVER_TYPE"]
         if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"):
-            return WebDriverPlaywright(self.driver_type, window_size)
-        return WebDriverSelenium(self.driver_type, window_size)
+            return WebDriverPlaywright(driver_type, window_size)
+        return WebDriverSelenium(driver_type, window_size)
 
     def cache_key(
         self,
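
The screenshots change moves the WEBDRIVER_TYPE read from the class body,
which executes once at import, into driver(), which executes per call. The
difference matters whenever config is mutated or the app is created after the
module loads; a toy illustration (Settings here is a stand-in, not Superset
code):

    class Settings:
        store = {"WEBDRIVER_TYPE": "firefox"}

    class Eager:
        # Evaluated exactly once, when this class body runs.
        driver_type = Settings.store["WEBDRIVER_TYPE"]

    class Late:
        def driver_type(self) -> str:
            # Evaluated on every call, so it sees later changes.
            return Settings.store["WEBDRIVER_TYPE"]

    Settings.store["WEBDRIVER_TYPE"] = "chrome"
    print(Eager.driver_type)     # "firefox" (frozen at import)
    print(Late().driver_type())  # "chrome"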
diff --git a/superset/views/utils.py b/superset/views/utils.py
index 2d8fcd68da..c446978d6f 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -49,7 +49,6 @@ from superset.models.sql_lab import Query
 from superset.superset_typing import FormData
 from superset.utils.core import DatasourceType
 from superset.utils.decorators import stats_timing
-from superset.viz import BaseViz
 
 logger = logging.getLogger(__name__)
 stats_logger = app.config["STATS_LOGGER"]
@@ -124,13 +123,14 @@ def get_viz(
     datasource_id: int,
     force: bool = False,
     force_cached: bool = False,
-) -> BaseViz:
+) -> viz.BaseViz:
     viz_type = form_data.get("viz_type", "table")
     datasource = DatasourceDAO.get_datasource(
         DatasourceType(datasource_type),
         datasource_id,
     )
-    viz_obj = viz.viz_types[viz_type](
+    viz_types = viz.get_viz_types()
+    viz_obj = viz_types[viz_type](
         datasource, form_data=form_data, force=force, force_cached=force_cached
     )
     return viz_obj
diff --git a/superset/viz.py b/superset/viz.py
index 5a4b323079..654e156185 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -86,10 +86,6 @@ from superset.utils.hashing import md5_sha_from_str
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import BaseDatasource
 
-config = app.config
-stats_logger = config["STATS_LOGGER"]
-relative_start = config["DEFAULT_RELATIVE_START_TIME"]
-relative_end = config["DEFAULT_RELATIVE_END_TIME"]
 logger = logging.getLogger(__name__)
 
 METRIC_KEYS = [
@@ -132,6 +128,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
 
         self.query = ""
         self.token = utils.get_form_data_token(form_data)
+        self.stats_logger = app.config["STATS_LOGGER"]
 
         self.groupby: list[Column] = self.form_data.get("groupby") or []
         self.time_shift = timedelta()
@@ -250,7 +247,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
                 "groupby": [],
                 "metrics": [],
                 "orderby": [],
-                "row_limit": config["SAMPLES_ROW_LIMIT"],
+                "row_limit": app.config["SAMPLES_ROW_LIMIT"],
                 "columns": [o.column_name for o in self.datasource.columns],
                 "from_dttm": None,
                 "to_dttm": None,
@@ -362,7 +359,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         timeseries_limit_metric = self.form_data.get("timeseries_limit_metric")
 
         # apply row limit to query
-        row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"])
+        row_limit = int(self.form_data.get("row_limit") or app.config["ROW_LIMIT"])
         row_limit = apply_max_row_limit(row_limit)
 
         # default order direction
@@ -370,8 +367,8 @@ class BaseViz:  # pylint: disable=too-many-public-methods
 
         try:
             since, until = get_since_until(
-                relative_start=relative_start,
-                relative_end=relative_end,
+                relative_start=app.config["DEFAULT_RELATIVE_START_TIME"],
+                relative_end=app.config["DEFAULT_RELATIVE_END_TIME"],
                 time_range=self.form_data.get("time_range"),
                 since=self.form_data.get("since"),
                 until=self.form_data.get("until"),
@@ -434,9 +431,9 @@ class BaseViz:  # pylint: disable=too-many-public-methods
             and self.datasource.database.cache_timeout
         ) is not None:
             return self.datasource.database.cache_timeout
-        if config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not None:
-            return config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
-        return config["CACHE_DEFAULT_TIMEOUT"]
+        if app.config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not None:
+            return app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"]
+        return app.config["CACHE_DEFAULT_TIMEOUT"]
 
     @deprecated(deprecated_in="3.0")
     def get_json(self) -> str:
@@ -532,7 +529,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         if cache_key and cache_manager.data_cache and not force:
             cache_value = cache_manager.data_cache.get(cache_key)
             if cache_value:
-                stats_logger.incr("loading_from_cache")
+                self.stats_logger.incr("loading_from_cache")
                 try:
                     df = cache_value["df"]
                     self.query = cache_value["query"]
@@ -544,7 +541,7 @@ class BaseViz:  # pylint: disable=too-many-public-methods
                     )
                     self.status = QueryStatus.SUCCESS
                     is_loaded = True
-                    stats_logger.incr("loaded_from_cache")
+                    self.stats_logger.incr("loaded_from_cache")
                 except Exception as ex:  # pylint: disable=broad-except
                     logger.exception(ex)
                     logger.error(
@@ -582,9 +579,9 @@ class BaseViz:  # pylint: disable=too-many-public-methods
                     )
                 df = self.get_df(query_obj)
                 if self.status != QueryStatus.FAILED:
-                    stats_logger.incr("loaded_from_source")
+                    self.stats_logger.incr("loaded_from_source")
                     if not self.force:
-                        stats_logger.incr("loaded_from_source_without_force")
+                        self.stats_logger.incr("loaded_from_source_without_force")
                     is_loaded = True
             except QueryObjectValidationError as ex:
                 error = dataclasses.asdict(
@@ -677,7 +674,9 @@ class BaseViz:  # pylint: disable=too-many-public-methods
     def get_csv(self) -> str | None:
         df = self.get_df_payload()["df"]  # leverage caching logic
         include_index = not isinstance(df.index, pd.RangeIndex)
-        return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"])
+        return csv.df_to_escaped_csv(
+            df, index=include_index, **app.config["CSV_EXPORT"]
+        )
 
     @deprecated(deprecated_in="3.0")
     def get_data(self, df: pd.DataFrame) -> VizData:
@@ -772,8 +771,8 @@ class CalHeatmapViz(BaseViz):
 
         try:
             start, end = get_since_until(
-                relative_start=relative_start,
-                relative_end=relative_end,
+                relative_start=app.config["DEFAULT_RELATIVE_START_TIME"],
+                relative_end=app.config["DEFAULT_RELATIVE_END_TIME"],
                 time_range=form_data.get("time_range"),
                 since=form_data.get("since"),
                 until=form_data.get("until"),
@@ -1794,7 +1793,7 @@ class MapboxViz(BaseViz):
         return {
             "geoJSON": geo_json,
             "hasCustomMetric": has_custom_metric,
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
             "mapStyle": self.form_data.get("mapbox_style"),
             "aggregatorName": self.form_data.get("pandas_aggfunc"),
             "clusteringRadius": self.form_data.get("clustering_radius"),
@@ -1830,7 +1829,7 @@ class DeckGLMultiLayer(BaseViz):
         slice_ids = self.form_data.get("deck_slices")
         slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
         return {
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
             "slices": [slc.data for slc in slices],
         }
 
@@ -2002,7 +2001,7 @@ class BaseDeckGLViz(BaseViz):
 
         return {
             "features": features,
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
             "metricLabels": self.metric_labels,
         }
 
@@ -2325,7 +2324,7 @@ class DeckArc(BaseDeckGLViz):
 
         return {
             "features": super().get_data(df)["features"],
-            "mapboxApiKey": config["MAPBOX_API_KEY"],
+            "mapboxApiKey": app.config["MAPBOX_API_KEY"],
         }
 
 
@@ -2660,8 +2659,9 @@ def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]:
     )
 
 
-viz_types = {
-    o.viz_type: o
-    for o in get_subclasses(BaseViz)
-    if o.viz_type not in config["VIZ_TYPE_DENYLIST"]
-}
+def get_viz_types() -> dict[str, type[BaseViz]]:
+    return {
+        o.viz_type: o
+        for o in get_subclasses(BaseViz)
+        if o.viz_type not in app.config["VIZ_TYPE_DENYLIST"]
+    }
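
With viz_types now computed inside get_viz_types(), VIZ_TYPE_DENYLIST is read
per call rather than frozen when superset.viz is first imported. If walking
the subclass tree on every call ever shows up in profiles, one option (an
assumption, not part of this commit) would be to memoize per denylist value:

    from functools import lru_cache

    @lru_cache(maxsize=None)
    def _viz_types_for(denylist: frozenset[str]) -> dict[str, type[BaseViz]]:
        # Hypothetical memoized variant, keyed on the (hashable) denylist so
        # apps with different configs get different registries. Note it would
        # miss any subclasses defined after the first call.
        return {
            o.viz_type: o
            for o in get_subclasses(BaseViz)
            if o.viz_type not in denylist
        }

    def get_viz_types() -> dict[str, type[BaseViz]]:
        return _viz_types_for(frozenset(app.config["VIZ_TYPE_DENYLIST"]))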
