This is an automated email from the ASF dual-hosted git repository. maximebeauchemin pushed a commit to branch late_config in repository https://gitbox.apache.org/repos/asf/superset.git
commit aa4742353b5b57d45e2a47b9d1e0288789153548 Author: Maxime Beauchemin <[email protected]> AuthorDate: Tue May 7 18:12:33 2024 -0700 fixes --- superset/commands/chart/warm_up_cache.py | 4 +-- superset/common/query_context_processor.py | 4 +-- superset/connectors/sqla/models.py | 10 +++--- superset/models/slice.py | 12 +++---- superset/utils/screenshots.py | 8 ++--- superset/views/utils.py | 6 ++-- superset/viz.py | 54 +++++++++++++++--------------- 7 files changed, 47 insertions(+), 51 deletions(-) diff --git a/superset/commands/chart/warm_up_cache.py b/superset/commands/chart/warm_up_cache.py index 2e5c0ac3a3..80aa045e17 100644 --- a/superset/commands/chart/warm_up_cache.py +++ b/superset/commands/chart/warm_up_cache.py @@ -31,7 +31,7 @@ from superset.extensions import db from superset.models.slice import Slice from superset.utils.core import error_msg_from_exception from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz -from superset.viz import viz_types +from superset.viz import get_viz_types class ChartWarmUpCacheCommand(BaseCommand): @@ -52,7 +52,7 @@ class ChartWarmUpCacheCommand(BaseCommand): try: form_data = get_form_data(chart.id, use_slice_data=True)[0] - if form_data.get("viz_type") in viz_types: + if form_data.get("viz_type") in get_viz_types(): # Legacy visualizations. if not chart.datasource: raise ChartInvalidError("Chart's datasource does not exist") diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 55c80386a3..f10d1717e2 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -66,7 +66,7 @@ from superset.utils.core import ( from superset.utils.date_parser import get_past_or_future, normalize_time_delta from superset.utils.pandas_postprocessing.utils import unescape_separator from superset.views.utils import get_viz -from superset.viz import viz_types +from superset.viz import get_viz_types if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -704,7 +704,7 @@ class QueryContextProcessor: raise QueryObjectValidationError(_("The chart does not exist")) try: - if chart.viz_type in viz_types: + if chart.viz_type in get_viz_types(): if not chart.datasource: raise QueryObjectValidationError( _("The chart datasource does not exist"), diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 12fbdc3bd7..b4eee0c42e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -122,10 +122,8 @@ from superset.utils import core as utils from superset.utils.backports import StrEnum from superset.utils.core import GenericDataType, MediumText -config = app.config metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) -ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] VIRTUAL_TABLE_ALIAS = "virtual_table" # a non-exhaustive set of additive metrics @@ -1120,7 +1118,7 @@ class SqlaTable( ) metric_class = SqlMetric column_class = TableColumn - owner_class = security_manager.user_model + owner_class = "User" __tablename__ = "tables" @@ -1358,7 +1356,7 @@ class SqlaTable( @property def health_check_message(self) -> str | None: - check = config["DATASET_HEALTH_CHECK"] + check = app.config["DATASET_HEALTH_CHECK"] return check(self) if check else None @property @@ -1876,7 +1874,7 @@ class SqlaTable( self.add_missing_metrics(metrics) # Apply config supplied mutations. - config["SQLA_TABLE_MUTATOR"](self) + app.config["SQLA_TABLE_MUTATOR"](self) db.session.merge(self) if commit: @@ -2105,7 +2103,7 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): ) group_key = Column(String(255), nullable=True) roles = relationship( - security_manager.role_model, + "Role", secondary=RLSFilterRoles, backref="row_level_security_filters", ) diff --git a/superset/models/slice.py b/superset/models/slice.py index 2a0734b107..7d645ea7fd 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -39,14 +39,14 @@ from sqlalchemy.engine.base import Connection from sqlalchemy.orm import relationship from sqlalchemy.orm.mapper import Mapper -from superset import db, is_feature_enabled, security_manager +from superset import db, is_feature_enabled from superset.legacy import update_time_range from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.tasks.thumbnails import cache_chart_thumbnail from superset.tasks.utils import get_current_user from superset.thumbnails.digest import get_chart_digest from superset.utils import core as utils -from superset.viz import BaseViz, viz_types +from superset.viz import BaseViz, get_viz_types if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -93,11 +93,9 @@ class Slice( # pylint: disable=too-many-public-methods certification_details = Column(Text) is_managed_externally = Column(Boolean, nullable=False, default=False) external_url = Column(Text, nullable=True) - last_saved_by = relationship( - security_manager.user_model, foreign_keys=[last_saved_by_fk] - ) + last_saved_by = relationship("User", foreign_keys=[last_saved_by_fk]) owners = relationship( - security_manager.user_model, + "User", secondary=slice_user, passive_deletes=True, ) @@ -204,7 +202,7 @@ class Slice( # pylint: disable=too-many-public-methods @property def viz(self) -> BaseViz | None: form_data = json.loads(self.params) - viz_class = viz_types.get(self.viz_type) + viz_class = get_viz_types().get(self.viz_type) datasource = self.datasource if viz_class and datasource: return viz_class(datasource=datasource, form_data=form_data) diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index bf6ed0f9e8..ad7726dc5d 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -20,7 +20,7 @@ import logging from io import BytesIO from typing import TYPE_CHECKING -from flask import current_app +from flask import current_app as app from superset import feature_flag_manager from superset.utils.hashing import md5_sha_from_dict @@ -53,7 +53,6 @@ if TYPE_CHECKING: class BaseScreenshot: - driver_type = current_app.config["WEBDRIVER_TYPE"] thumbnail_type: str = "" element: str = "" window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE @@ -66,9 +65,10 @@ class BaseScreenshot: def driver(self, window_size: WindowSize | None = None) -> WebDriver: window_size = window_size or self.window_size + driver_type = app.config["WEBDRIVER_TYPE"] if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"): - return WebDriverPlaywright(self.driver_type, window_size) - return WebDriverSelenium(self.driver_type, window_size) + return WebDriverPlaywright(driver_type, window_size) + return WebDriverSelenium(driver_type, window_size) def cache_key( self, diff --git a/superset/views/utils.py b/superset/views/utils.py index 2d8fcd68da..c446978d6f 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -49,7 +49,6 @@ from superset.models.sql_lab import Query from superset.superset_typing import FormData from superset.utils.core import DatasourceType from superset.utils.decorators import stats_timing -from superset.viz import BaseViz logger = logging.getLogger(__name__) stats_logger = app.config["STATS_LOGGER"] @@ -124,13 +123,14 @@ def get_viz( datasource_id: int, force: bool = False, force_cached: bool = False, -) -> BaseViz: +) -> viz.BaseViz: viz_type = form_data.get("viz_type", "table") datasource = DatasourceDAO.get_datasource( DatasourceType(datasource_type), datasource_id, ) - viz_obj = viz.viz_types[viz_type]( + viz_types = viz.get_viz_types() + viz_obj = viz_types[viz_type]( datasource, form_data=form_data, force=force, force_cached=force_cached ) return viz_obj diff --git a/superset/viz.py b/superset/viz.py index 5a4b323079..654e156185 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -86,10 +86,6 @@ from superset.utils.hashing import md5_sha_from_str if TYPE_CHECKING: from superset.connectors.sqla.models import BaseDatasource -config = app.config -stats_logger = config["STATS_LOGGER"] -relative_start = config["DEFAULT_RELATIVE_START_TIME"] -relative_end = config["DEFAULT_RELATIVE_END_TIME"] logger = logging.getLogger(__name__) METRIC_KEYS = [ @@ -132,6 +128,7 @@ class BaseViz: # pylint: disable=too-many-public-methods self.query = "" self.token = utils.get_form_data_token(form_data) + self.stats_logger = app.config["STATS_LOGGER"] self.groupby: list[Column] = self.form_data.get("groupby") or [] self.time_shift = timedelta() @@ -250,7 +247,7 @@ class BaseViz: # pylint: disable=too-many-public-methods "groupby": [], "metrics": [], "orderby": [], - "row_limit": config["SAMPLES_ROW_LIMIT"], + "row_limit": app.config["SAMPLES_ROW_LIMIT"], "columns": [o.column_name for o in self.datasource.columns], "from_dttm": None, "to_dttm": None, @@ -362,7 +359,7 @@ class BaseViz: # pylint: disable=too-many-public-methods timeseries_limit_metric = self.form_data.get("timeseries_limit_metric") # apply row limit to query - row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"]) + row_limit = int(self.form_data.get("row_limit") or app.config["ROW_LIMIT"]) row_limit = apply_max_row_limit(row_limit) # default order direction @@ -370,8 +367,8 @@ class BaseViz: # pylint: disable=too-many-public-methods try: since, until = get_since_until( - relative_start=relative_start, - relative_end=relative_end, + relative_start=app.config["DEFAULT_RELATIVE_START_TIME"], + relative_end=app.config["DEFAULT_RELATIVE_END_TIME"], time_range=self.form_data.get("time_range"), since=self.form_data.get("since"), until=self.form_data.get("until"), @@ -434,9 +431,9 @@ class BaseViz: # pylint: disable=too-many-public-methods and self.datasource.database.cache_timeout ) is not None: return self.datasource.database.cache_timeout - if config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not None: - return config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] - return config["CACHE_DEFAULT_TIMEOUT"] + if app.config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not None: + return app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] + return app.config["CACHE_DEFAULT_TIMEOUT"] @deprecated(deprecated_in="3.0") def get_json(self) -> str: @@ -532,7 +529,7 @@ class BaseViz: # pylint: disable=too-many-public-methods if cache_key and cache_manager.data_cache and not force: cache_value = cache_manager.data_cache.get(cache_key) if cache_value: - stats_logger.incr("loading_from_cache") + self.stats_logger.incr("loading_from_cache") try: df = cache_value["df"] self.query = cache_value["query"] @@ -544,7 +541,7 @@ class BaseViz: # pylint: disable=too-many-public-methods ) self.status = QueryStatus.SUCCESS is_loaded = True - stats_logger.incr("loaded_from_cache") + self.stats_logger.incr("loaded_from_cache") except Exception as ex: # pylint: disable=broad-except logger.exception(ex) logger.error( @@ -582,9 +579,9 @@ class BaseViz: # pylint: disable=too-many-public-methods ) df = self.get_df(query_obj) if self.status != QueryStatus.FAILED: - stats_logger.incr("loaded_from_source") + self.stats_logger.incr("loaded_from_source") if not self.force: - stats_logger.incr("loaded_from_source_without_force") + self.stats_logger.incr("loaded_from_source_without_force") is_loaded = True except QueryObjectValidationError as ex: error = dataclasses.asdict( @@ -677,7 +674,9 @@ class BaseViz: # pylint: disable=too-many-public-methods def get_csv(self) -> str | None: df = self.get_df_payload()["df"] # leverage caching logic include_index = not isinstance(df.index, pd.RangeIndex) - return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"]) + return csv.df_to_escaped_csv( + df, index=include_index, **app.config["CSV_EXPORT"] + ) @deprecated(deprecated_in="3.0") def get_data(self, df: pd.DataFrame) -> VizData: @@ -772,8 +771,8 @@ class CalHeatmapViz(BaseViz): try: start, end = get_since_until( - relative_start=relative_start, - relative_end=relative_end, + relative_start=app.config["DEFAULT_RELATIVE_START_TIME"], + relative_end=app.config["DEFAULT_RELATIVE_END_TIME"], time_range=form_data.get("time_range"), since=form_data.get("since"), until=form_data.get("until"), @@ -1794,7 +1793,7 @@ class MapboxViz(BaseViz): return { "geoJSON": geo_json, "hasCustomMetric": has_custom_metric, - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], "mapStyle": self.form_data.get("mapbox_style"), "aggregatorName": self.form_data.get("pandas_aggfunc"), "clusteringRadius": self.form_data.get("clustering_radius"), @@ -1830,7 +1829,7 @@ class DeckGLMultiLayer(BaseViz): slice_ids = self.form_data.get("deck_slices") slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all() return { - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], "slices": [slc.data for slc in slices], } @@ -2002,7 +2001,7 @@ class BaseDeckGLViz(BaseViz): return { "features": features, - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], "metricLabels": self.metric_labels, } @@ -2325,7 +2324,7 @@ class DeckArc(BaseDeckGLViz): return { "features": super().get_data(df)["features"], - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], } @@ -2660,8 +2659,9 @@ def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]: ) -viz_types = { - o.viz_type: o - for o in get_subclasses(BaseViz) - if o.viz_type not in config["VIZ_TYPE_DENYLIST"] -} +def get_viz_types() -> dict[str, type[BaseViz]]: + return { + o.viz_type: o + for o in get_subclasses(BaseViz) + if o.viz_type not in app.config["VIZ_TYPE_DENYLIST"] + }
