This is an automated email from the ASF dual-hosted git repository. maximebeauchemin pushed a commit to branch late_config in repository https://gitbox.apache.org/repos/asf/superset.git
commit 8c78a3207a9169d019d30be37fdf783f2505e60f Author: Maxime Beauchemin <[email protected]> AuthorDate: Tue May 7 09:58:57 2024 -0700 chore: module scope should not require the app context This is a major issue that's been plaguing the backend for a long time. Currently you can't simply run `from superset import models` without getting an error about the missing app context. This DRAFT PR tries to evaluate what's getting in the way. So far I've identified: - app.config being referenced in module scope: this is mostly easy to fix for common configuration, where we can rely on flask.current_app and avoid module scope - dynamic/configurable models: this seems to be the core issue, for example where we let people configure their EncryptedField's type - some reliance on SecurityManager, and SecurityManager.user_model being used as a relationship target: I think this can be worked around using sqlalchemy's `relationship("use-a-string-reference")` call - other causes still to be identified --- superset/commands/chart/warm_up_cache.py | 4 +-- superset/common/query_context_processor.py | 4 +-- superset/connectors/sqla/models.py | 10 +++--- superset/initialization/__init__.py | 1 + superset/models/core.py | 49 ++++++++++++++++----------- superset/models/helpers.py | 16 ++++----- superset/models/slice.py | 12 +++---- superset/models/sql_lab.py | 4 +-- superset/models/user_attributes.py | 5 +-- superset/utils/cache.py | 13 ++++--- superset/utils/encrypt.py | 26 ++++++-------- superset/utils/screenshots.py | 8 ++--- superset/views/utils.py | 6 ++-- superset/viz.py | 54 +++++++++++++++++------------- test.py | 1 + test_imports.py | 31 +++++++++++++++++ 16 files changed, 135 insertions(+), 109 deletions(-) diff --git a/superset/commands/chart/warm_up_cache.py b/superset/commands/chart/warm_up_cache.py index 2e5c0ac3a3..80aa045e17 100644 --- a/superset/commands/chart/warm_up_cache.py +++ b/superset/commands/chart/warm_up_cache.py @@ -31,7 +31,7 @@ from superset.extensions import db from superset.models.slice import Slice 
from superset.utils.core import error_msg_from_exception from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz -from superset.viz import viz_types +from superset.viz import get_viz_types class ChartWarmUpCacheCommand(BaseCommand): @@ -52,7 +52,7 @@ class ChartWarmUpCacheCommand(BaseCommand): try: form_data = get_form_data(chart.id, use_slice_data=True)[0] - if form_data.get("viz_type") in viz_types: + if form_data.get("viz_type") in get_viz_types(): # Legacy visualizations. if not chart.datasource: raise ChartInvalidError("Chart's datasource does not exist") diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 55c80386a3..f10d1717e2 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -66,7 +66,7 @@ from superset.utils.core import ( from superset.utils.date_parser import get_past_or_future, normalize_time_delta from superset.utils.pandas_postprocessing.utils import unescape_separator from superset.views.utils import get_viz -from superset.viz import viz_types +from superset.viz import get_viz_types if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -704,7 +704,7 @@ class QueryContextProcessor: raise QueryObjectValidationError(_("The chart does not exist")) try: - if chart.viz_type in viz_types: + if chart.viz_type in get_viz_types(): if not chart.datasource: raise QueryObjectValidationError( _("The chart datasource does not exist"), diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 12fbdc3bd7..b4eee0c42e 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -122,10 +122,8 @@ from superset.utils import core as utils from superset.utils.backports import StrEnum from superset.utils.core import GenericDataType, MediumText -config = app.config metadata = Model.metadata # pylint: disable=no-member logger = 
logging.getLogger(__name__) -ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] VIRTUAL_TABLE_ALIAS = "virtual_table" # a non-exhaustive set of additive metrics @@ -1120,7 +1118,7 @@ class SqlaTable( ) metric_class = SqlMetric column_class = TableColumn - owner_class = security_manager.user_model + owner_class = "User" __tablename__ = "tables" @@ -1358,7 +1356,7 @@ class SqlaTable( @property def health_check_message(self) -> str | None: - check = config["DATASET_HEALTH_CHECK"] + check = app.config["DATASET_HEALTH_CHECK"] return check(self) if check else None @property @@ -1876,7 +1874,7 @@ class SqlaTable( self.add_missing_metrics(metrics) # Apply config supplied mutations. - config["SQLA_TABLE_MUTATOR"](self) + app.config["SQLA_TABLE_MUTATOR"](self) db.session.merge(self) if commit: @@ -2105,7 +2103,7 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable): ) group_key = Column(String(255), nullable=True) roles = relationship( - security_manager.role_model, + "Role", secondary=RLSFilterRoles, backref="row_level_security_filters", ) diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 83d3afb3d1..77dc4cfe81 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -71,6 +71,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.superset_app = app self.config = app.config self.manifest: dict[Any, Any] = {} + self.stats_logger = app.config["STATS_LOGGER"] @deprecated(details="use self.superset_app instead of self.flask_app") # type: ignore @property diff --git a/superset/models/core.py b/superset/models/core.py index fe486bf2b1..b9ebd760f7 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -81,10 +81,6 @@ from superset.utils.backports import StrEnum from superset.utils.core import DatasourceName, get_username from superset.utils.oauth2 import get_oauth2_access_token -config = app.config -custom_password_store = 
config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] -stats_logger = config["STATS_LOGGER"] -log_query = config["QUERY_LOGGER"] metadata = Model.metadata # pylint: disable=no-member logger = logging.getLogger(__name__) @@ -92,8 +88,6 @@ if TYPE_CHECKING: from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.sql_lab import Query -DB_CONNECTION_MUTATOR = config["DB_CONNECTION_MUTATOR"] - class KeyValue(Model): # pylint: disable=too-few-public-methods """Used for any type of key-value store""" @@ -369,6 +363,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable def set_sqlalchemy_uri(self, uri: str) -> None: conn = make_url_safe(uri.strip()) + custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] if conn.password != PASSWORD_MASK and not custom_password_store: # do not over-write the password with the password mask self.password = conn.password @@ -446,7 +441,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable sqlalchemy_uri=sqlalchemy_uri, ) - def _get_sqla_engine( + def _get_sqla_engine( # pylint: disable=too-many-locals self, catalog: str | None = None, schema: str | None = None, @@ -510,7 +505,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable self.update_params_from_encrypted_extra(params) - if DB_CONNECTION_MUTATOR: + if db_conn_mutator := app.config["DB_CONNECTION_MUTATOR"]: if not source and request and request.referrer: if "/superset/dashboard/" in request.referrer: source = utils.QuerySource.DASHBOARD @@ -519,7 +514,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable elif "/sqllab/" in request.referrer: source = utils.QuerySource.SQL_LAB - sqlalchemy_url, params = DB_CONNECTION_MUTATOR( + sqlalchemy_url, params = db_conn_mutator( sqlalchemy_url, params, effective_username, @@ -618,8 +613,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable on the group of queries as 
a whole. Here the called passes the context as to whether the SQL is split or already. """ - sql_mutator = config["SQL_QUERY_MUTATOR"] - if sql_mutator and (is_split == config["MUTATE_AFTER_SPLIT"]): + sql_mutator = app.config["SQL_QUERY_MUTATOR"] + if sql_mutator and (is_split == app.config["MUTATE_AFTER_SPLIT"]): return sql_mutator( sql_, security_manager=security_manager, @@ -638,6 +633,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable with self.get_sqla_engine(catalog=catalog, schema=schema) as engine: engine_url = engine.url + log_query = app.config["QUERY_LOGGER"] + def _log_query(sql: str) -> None: if log_query: log_query( @@ -974,7 +971,7 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable allowed_databases = literal_eval(allowed_databases) if hasattr(g, "user"): - extra_allowed_databases = config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( + extra_allowed_databases = app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"]( self, g.user ) allowed_databases += extra_allowed_databases @@ -988,7 +985,8 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable # if the URI is invalid, ignore and return a placeholder url # (so users see 500 less often) return "dialect://invalid_uri" - if custom_password_store: + + if custom_password_store := app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"]: conn = conn.set(password=custom_password_store(conn)) else: conn = conn.set(password=self.password) @@ -1071,9 +1069,22 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable return self.db_engine_spec.get_oauth2_config() -sqla.event.listen(Database, "after_insert", security_manager.database_after_insert) -sqla.event.listen(Database, "after_update", security_manager.database_after_update) -sqla.event.listen(Database, "after_delete", security_manager.database_after_delete) +# Using lambdas for security manager to prevent referencing it in module scope +sqla.event.listen( + Database, + 
"after_insert", + lambda *args, **kwargs: security_manager.database_after_insert(*args, **kwargs), # pylint: disable=unnecessary-lambda +) +sqla.event.listen( + Database, + "after_update", + lambda *args, **kwargs: security_manager.database_after_update(*args, **kwargs), # pylint: disable=unnecessary-lambda +) +sqla.event.listen( + Database, + "after_delete", + lambda *args, **kwargs: security_manager.database_after_delete(*args, **kwargs), # pylint: disable=unnecessary-lambda +) class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable): @@ -1091,7 +1102,7 @@ class DatabaseUserOAuth2Tokens(Model, AuditMixinNullable): ForeignKey("ab_user.id", ondelete="CASCADE"), nullable=False, ) - user = relationship(security_manager.user_model, foreign_keys=[user_id]) + user = relationship("User", foreign_keys=[user_id]) database_id = Column( Integer, @@ -1116,9 +1127,7 @@ class Log(Model): # pylint: disable=too-few-public-methods dashboard_id = Column(Integer) slice_id = Column(Integer) json = Column(utils.MediumText()) - user = relationship( - security_manager.user_model, backref="logs", foreign_keys=[user_id] - ) + user = relationship("User", backref="logs", foreign_keys=[user_id]) dttm = Column(DateTime, default=datetime.utcnow) duration_ms = Column(Integer) referrer = Column(String(1024)) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 100391086c..1ebf9244f7 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -105,12 +105,10 @@ if TYPE_CHECKING: from superset.models.core import Database -config = app.config logger = logging.getLogger(__name__) VIRTUAL_TABLE_ALIAS = "virtual_table" SERIES_LIMIT_SUBQ_ALIAS = "series_limit" -ADVANCED_DATA_TYPES = config["ADVANCED_DATA_TYPES"] def validate_adhoc_subquery( @@ -1822,12 +1820,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods and feature_flag_manager.is_feature_enabled( "ENABLE_ADVANCED_DATA_TYPES" ) - and col_advanced_data_type in ADVANCED_DATA_TYPES + and 
col_advanced_data_type in app.config["ADVANCED_DATA_TYPES"] ): values = eq if is_list_target else [eq] # type: ignore - bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[ - col_advanced_data_type - ].translate_type( + bus_resp: AdvancedDataTypeResponse = app.config[ + "ADVANCED_DATA_TYPES" + ][col_advanced_data_type].translate_type( { "type": col_advanced_data_type, "values": values, @@ -1839,9 +1837,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) where_clause_and.append( - ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter( - sqla_col, op, bus_resp["values"] - ) + app.config["ADVANCED_DATA_TYPES"][ + col_advanced_data_type + ].translate_filter(sqla_col, op, bus_resp["values"]) ) elif is_list_target: assert isinstance(eq, (tuple, list)) diff --git a/superset/models/slice.py b/superset/models/slice.py index 2a0734b107..7d645ea7fd 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -39,14 +39,14 @@ from sqlalchemy.engine.base import Connection from sqlalchemy.orm import relationship from sqlalchemy.orm.mapper import Mapper -from superset import db, is_feature_enabled, security_manager +from superset import db, is_feature_enabled from superset.legacy import update_time_range from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.tasks.thumbnails import cache_chart_thumbnail from superset.tasks.utils import get_current_user from superset.thumbnails.digest import get_chart_digest from superset.utils import core as utils -from superset.viz import BaseViz, viz_types +from superset.viz import BaseViz, get_viz_types if TYPE_CHECKING: from superset.common.query_context import QueryContext @@ -93,11 +93,9 @@ class Slice( # pylint: disable=too-many-public-methods certification_details = Column(Text) is_managed_externally = Column(Boolean, nullable=False, default=False) external_url = Column(Text, nullable=True) - last_saved_by = relationship( - security_manager.user_model, 
foreign_keys=[last_saved_by_fk] - ) + last_saved_by = relationship("User", foreign_keys=[last_saved_by_fk]) owners = relationship( - security_manager.user_model, + "User", secondary=slice_user, passive_deletes=True, ) @@ -204,7 +202,7 @@ class Slice( # pylint: disable=too-many-public-methods @property def viz(self) -> BaseViz | None: form_data = json.loads(self.params) - viz_class = viz_types.get(self.viz_type) + viz_class = get_viz_types().get(self.viz_type) datasource = self.datasource if viz_class and datasource: return viz_class(datasource=datasource, form_data=form_data) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 31443b4bb1..9707b442bb 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -148,7 +148,7 @@ class Query( foreign_keys=[database_id], backref=backref("queries", cascade="all, delete-orphan"), ) - user = relationship(security_manager.user_model, foreign_keys=[user_id]) + user = relationship("User", foreign_keys=[user_id]) __table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),) @@ -393,7 +393,7 @@ class SavedQuery( sql = Column(MediumText()) template_parameters = Column(Text) user = relationship( - security_manager.user_model, + "User", backref=backref("saved_queries", cascade="all, delete-orphan"), foreign_keys=[user_id], ) diff --git a/superset/models/user_attributes.py b/superset/models/user_attributes.py index 512270c89c..f5438ccc6e 100644 --- a/superset/models/user_attributes.py +++ b/superset/models/user_attributes.py @@ -19,7 +19,6 @@ from flask_appbuilder import Model from sqlalchemy import Column, ForeignKey, Integer, String from sqlalchemy.orm import relationship -from superset import security_manager from superset.models.helpers import AuditMixinNullable @@ -36,9 +35,7 @@ class UserAttribute(Model, AuditMixinNullable): __tablename__ = "user_attribute" id = Column(Integer, primary_key=True) user_id = Column(Integer, ForeignKey("ab_user.id")) - user = relationship( - 
security_manager.user_model, backref="extra_attributes", foreign_keys=[user_id] - ) + user = relationship("User", backref="extra_attributes", foreign_keys=[user_id]) welcome_dashboard_id = Column(Integer, ForeignKey("dashboards.id")) welcome_dashboard = relationship("Dashboard") avatar_url = Column(String(100)) diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 00216fc4b1..c7ba4e8382 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -20,7 +20,7 @@ import inspect import logging from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING from flask import current_app as app, request from flask_caching import Cache @@ -34,10 +34,8 @@ from superset.utils.core import json_int_dttm_ser from superset.utils.hashing import md5_sha_from_dict if TYPE_CHECKING: - from superset.stats_logger import BaseStatsLogger + pass -config = app.config -stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) @@ -65,9 +63,9 @@ def set_and_log_cache( dttm = datetime.utcnow().isoformat().split(".")[0] value = {**cache_value, "dttm": dttm} cache_instance.set(cache_key, value, timeout=timeout) - stats_logger.incr("set_cache_key") + app.stats_logger.incr("set_cache_key") - if datasource_uid and config["STORE_CACHE_KEYS_IN_METADATA_DB"]: + if datasource_uid and app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]: ck = CacheKey( cache_key=cache_key, cache_timeout=cache_timeout, @@ -147,7 +145,7 @@ def memoized_func(key: str, cache: Cache = cache_manager.cache) -> Callable[..., def etag_cache( cache: Cache = cache_manager.cache, get_last_modified: Callable[..., datetime] | None = None, - max_age: int | float = app.config["CACHE_DEFAULT_TIMEOUT"], + max_age: Optional[int | float] = None, raise_for_access: Callable[..., Any] | None = None, skip: Callable[..., bool] | None = None, ) -> Callable[..., Any]: @@ -163,6 
+161,7 @@ def etag_cache( dataframe cache for requests that produce the same SQL. """ + max_age = max_age or app.config["CACHE_DEFAULT_TIMEOUT"] def decorator(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index e5bbd13825..ed58a49f6a 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -18,7 +18,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Optional -from flask import Flask +from flask import current_app as app from flask_babel import lazy_gettext as _ from sqlalchemy import text, TypeDecorator from sqlalchemy.engine import Connection, Dialect, Row @@ -27,6 +27,10 @@ from sqlalchemy_utils import EncryptedType logger = logging.getLogger(__name__) +def get_key() -> str: + return app.config["SECRET_KEY"] + + class AbstractEncryptedFieldAdapter(ABC): # pylint: disable=too-few-public-methods @abstractmethod def create( @@ -47,22 +51,15 @@ class SQLAlchemyUtilsAdapter( # pylint: disable=too-few-public-methods *args: list[Any], **kwargs: Optional[dict[str, Any]], ) -> TypeDecorator: - if app_config: - return EncryptedType(*args, app_config["SECRET_KEY"], **kwargs) - - raise Exception( # pylint: disable=broad-exception-raised - "Missing app_config kwarg" - ) + return EncryptedType(*args, get_key, **kwargs) class EncryptedFieldFactory: def __init__(self) -> None: self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None - self._config: Optional[dict[str, Any]] = None - def init_app(self, app: Flask) -> None: - self._config = app.config - self._concrete_type_adapter = self._config[ # type: ignore + def init_app(self, *args, **kwargs) -> None: # type: ignore # pylint: disable=unused-argument + self._concrete_type_adapter = app.config[ "SQLALCHEMY_ENCRYPTED_FIELD_TYPE_ADAPTER" ]() @@ -70,11 +67,8 @@ class EncryptedFieldFactory: self, *args: list[Any], **kwargs: Optional[dict[str, Any]] ) -> TypeDecorator: if 
self._concrete_type_adapter: - return self._concrete_type_adapter.create(self._config, *args, **kwargs) - - raise Exception( # pylint: disable=broad-exception-raised - "App not initialized yet. Please call init_app first" - ) + return self._concrete_type_adapter.create(app.config, *args, **kwargs) + return None class SecretsMigrator: diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index bf6ed0f9e8..ad7726dc5d 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -20,7 +20,7 @@ import logging from io import BytesIO from typing import TYPE_CHECKING -from flask import current_app +from flask import current_app as app from superset import feature_flag_manager from superset.utils.hashing import md5_sha_from_dict @@ -53,7 +53,6 @@ if TYPE_CHECKING: class BaseScreenshot: - driver_type = current_app.config["WEBDRIVER_TYPE"] thumbnail_type: str = "" element: str = "" window_size: WindowSize = DEFAULT_SCREENSHOT_WINDOW_SIZE @@ -66,9 +65,10 @@ class BaseScreenshot: def driver(self, window_size: WindowSize | None = None) -> WebDriver: window_size = window_size or self.window_size + driver_type = app.config["WEBDRIVER_TYPE"] if feature_flag_manager.is_feature_enabled("PLAYWRIGHT_REPORTS_AND_THUMBNAILS"): - return WebDriverPlaywright(self.driver_type, window_size) - return WebDriverSelenium(self.driver_type, window_size) + return WebDriverPlaywright(driver_type, window_size) + return WebDriverSelenium(driver_type, window_size) def cache_key( self, diff --git a/superset/views/utils.py b/superset/views/utils.py index 2d8fcd68da..c446978d6f 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -49,7 +49,6 @@ from superset.models.sql_lab import Query from superset.superset_typing import FormData from superset.utils.core import DatasourceType from superset.utils.decorators import stats_timing -from superset.viz import BaseViz logger = logging.getLogger(__name__) stats_logger = app.config["STATS_LOGGER"] @@ -124,13 
+123,14 @@ def get_viz( datasource_id: int, force: bool = False, force_cached: bool = False, -) -> BaseViz: +) -> viz.BaseViz: viz_type = form_data.get("viz_type", "table") datasource = DatasourceDAO.get_datasource( DatasourceType(datasource_type), datasource_id, ) - viz_obj = viz.viz_types[viz_type]( + viz_types = viz.get_viz_types() + viz_obj = viz_types[viz_type]( datasource, form_data=form_data, force=force, force_cached=force_cached ) return viz_obj diff --git a/superset/viz.py b/superset/viz.py index 5a4b323079..654e156185 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -86,10 +86,6 @@ from superset.utils.hashing import md5_sha_from_str if TYPE_CHECKING: from superset.connectors.sqla.models import BaseDatasource -config = app.config -stats_logger = config["STATS_LOGGER"] -relative_start = config["DEFAULT_RELATIVE_START_TIME"] -relative_end = config["DEFAULT_RELATIVE_END_TIME"] logger = logging.getLogger(__name__) METRIC_KEYS = [ @@ -132,6 +128,7 @@ class BaseViz: # pylint: disable=too-many-public-methods self.query = "" self.token = utils.get_form_data_token(form_data) + self.stats_logger = app.config["STATS_LOGGER"] self.groupby: list[Column] = self.form_data.get("groupby") or [] self.time_shift = timedelta() @@ -250,7 +247,7 @@ class BaseViz: # pylint: disable=too-many-public-methods "groupby": [], "metrics": [], "orderby": [], - "row_limit": config["SAMPLES_ROW_LIMIT"], + "row_limit": app.config["SAMPLES_ROW_LIMIT"], "columns": [o.column_name for o in self.datasource.columns], "from_dttm": None, "to_dttm": None, @@ -362,7 +359,7 @@ class BaseViz: # pylint: disable=too-many-public-methods timeseries_limit_metric = self.form_data.get("timeseries_limit_metric") # apply row limit to query - row_limit = int(self.form_data.get("row_limit") or config["ROW_LIMIT"]) + row_limit = int(self.form_data.get("row_limit") or app.config["ROW_LIMIT"]) row_limit = apply_max_row_limit(row_limit) # default order direction @@ -370,8 +367,8 @@ class BaseViz: # pylint: 
disable=too-many-public-methods try: since, until = get_since_until( - relative_start=relative_start, - relative_end=relative_end, + relative_start=app.config["DEFAULT_RELATIVE_START_TIME"], + relative_end=app.config["DEFAULT_RELATIVE_END_TIME"], time_range=self.form_data.get("time_range"), since=self.form_data.get("since"), until=self.form_data.get("until"), @@ -434,9 +431,9 @@ class BaseViz: # pylint: disable=too-many-public-methods and self.datasource.database.cache_timeout ) is not None: return self.datasource.database.cache_timeout - if config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not None: - return config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] - return config["CACHE_DEFAULT_TIMEOUT"] + if app.config["DATA_CACHE_CONFIG"].get("CACHE_DEFAULT_TIMEOUT") is not None: + return app.config["DATA_CACHE_CONFIG"]["CACHE_DEFAULT_TIMEOUT"] + return app.config["CACHE_DEFAULT_TIMEOUT"] @deprecated(deprecated_in="3.0") def get_json(self) -> str: @@ -532,7 +529,7 @@ class BaseViz: # pylint: disable=too-many-public-methods if cache_key and cache_manager.data_cache and not force: cache_value = cache_manager.data_cache.get(cache_key) if cache_value: - stats_logger.incr("loading_from_cache") + self.stats_logger.incr("loading_from_cache") try: df = cache_value["df"] self.query = cache_value["query"] @@ -544,7 +541,7 @@ class BaseViz: # pylint: disable=too-many-public-methods ) self.status = QueryStatus.SUCCESS is_loaded = True - stats_logger.incr("loaded_from_cache") + self.stats_logger.incr("loaded_from_cache") except Exception as ex: # pylint: disable=broad-except logger.exception(ex) logger.error( @@ -582,9 +579,9 @@ class BaseViz: # pylint: disable=too-many-public-methods ) df = self.get_df(query_obj) if self.status != QueryStatus.FAILED: - stats_logger.incr("loaded_from_source") + self.stats_logger.incr("loaded_from_source") if not self.force: - stats_logger.incr("loaded_from_source_without_force") + 
self.stats_logger.incr("loaded_from_source_without_force") is_loaded = True except QueryObjectValidationError as ex: error = dataclasses.asdict( @@ -677,7 +674,9 @@ class BaseViz: # pylint: disable=too-many-public-methods def get_csv(self) -> str | None: df = self.get_df_payload()["df"] # leverage caching logic include_index = not isinstance(df.index, pd.RangeIndex) - return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"]) + return csv.df_to_escaped_csv( + df, index=include_index, **app.config["CSV_EXPORT"] + ) @deprecated(deprecated_in="3.0") def get_data(self, df: pd.DataFrame) -> VizData: @@ -772,8 +771,8 @@ class CalHeatmapViz(BaseViz): try: start, end = get_since_until( - relative_start=relative_start, - relative_end=relative_end, + relative_start=app.config["DEFAULT_RELATIVE_START_TIME"], + relative_end=app.config["DEFAULT_RELATIVE_END_TIME"], time_range=form_data.get("time_range"), since=form_data.get("since"), until=form_data.get("until"), @@ -1794,7 +1793,7 @@ class MapboxViz(BaseViz): return { "geoJSON": geo_json, "hasCustomMetric": has_custom_metric, - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], "mapStyle": self.form_data.get("mapbox_style"), "aggregatorName": self.form_data.get("pandas_aggfunc"), "clusteringRadius": self.form_data.get("clustering_radius"), @@ -1830,7 +1829,7 @@ class DeckGLMultiLayer(BaseViz): slice_ids = self.form_data.get("deck_slices") slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all() return { - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], "slices": [slc.data for slc in slices], } @@ -2002,7 +2001,7 @@ class BaseDeckGLViz(BaseViz): return { "features": features, - "mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], "metricLabels": self.metric_labels, } @@ -2325,7 +2324,7 @@ class DeckArc(BaseDeckGLViz): return { "features": super().get_data(df)["features"], - 
"mapboxApiKey": config["MAPBOX_API_KEY"], + "mapboxApiKey": app.config["MAPBOX_API_KEY"], } @@ -2660,8 +2659,9 @@ def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]: ) -viz_types = { - o.viz_type: o - for o in get_subclasses(BaseViz) - if o.viz_type not in config["VIZ_TYPE_DENYLIST"] -} +def get_viz_types() -> dict[str, type[BaseViz]]: + return { + o.viz_type: o + for o in get_subclasses(BaseViz) + if o.viz_type not in app.config["VIZ_TYPE_DENYLIST"] + } diff --git a/test.py b/test.py new file mode 100644 index 0000000000..4da9b14742 --- /dev/null +++ b/test.py @@ -0,0 +1 @@ +print("YO") diff --git a/test_imports.py b/test_imports.py new file mode 100644 index 0000000000..b51d69d8c9 --- /dev/null +++ b/test_imports.py @@ -0,0 +1,31 @@ +import os + +# mypy: ignore-errors + +# pylint: skip-file +# pylint: disable-all + + +def generate_import_statements(root_dir): + import_statements = set() + for dirpath, _, filenames in os.walk(root_dir): + for filename in filenames: + if filename.endswith(".py"): + module_path = os.path.relpath(os.path.join(dirpath, filename), root_dir) + module_name = os.path.splitext(module_path.replace(os.path.sep, "."))[0] + splitted = module_name.split(".") + if splitted[-1] == "__init__": + splitted.pop() + if len(splitted) > 1 and "migration" not in module_path: + package = ".".join(splitted[:-1]) + import_statements.add( + f"from superset.{package} import {splitted[-1]}" + ) + return sorted(import_statements) + + +if __name__ == "__main__": + root_dir = "superset" + import_statements = generate_import_statements(root_dir) + for statement in import_statements: + print(statement)
