This is an automated email from the ASF dual-hosted git repository. johnbodley pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push: new b296a0f [mypy] Enforcing typing for superset.utils (#9905) b296a0f is described below commit b296a0f250979bf70e9cb2a2a2b48fd10038a363 Author: John Bodley <4567245+john-bod...@users.noreply.github.com> AuthorDate: Wed May 27 22:57:30 2020 -0700 [mypy] Enforcing typing for superset.utils (#9905) Co-authored-by: John Bodley <john.bod...@airbnb.com> --- setup.cfg | 2 +- superset/config.py | 2 +- superset/typing.py | 1 + superset/utils/cache.py | 10 +- superset/utils/core.py | 161 +++++++++++---------- .../utils/dashboard_filter_scopes_converter.py | 22 +-- superset/utils/dashboard_import_export.py | 12 +- superset/utils/dates.py | 4 +- superset/utils/decorators.py | 21 ++- superset/utils/dict_import_export.py | 15 +- superset/utils/feature_flag_manager.py | 24 ++- superset/utils/log.py | 30 ++-- superset/utils/pandas_postprocessing.py | 6 +- superset/utils/screenshots.py | 6 +- superset/utils/url_map_converters.py | 12 +- superset/views/core.py | 2 +- tests/utils_tests.py | 12 -- 17 files changed, 194 insertions(+), 148 deletions(-) diff --git a/setup.cfg b/setup.cfg index 99c09ce..1115de9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,7 +53,7 @@ order_by_type = false ignore_missing_imports = true no_implicit_optional = true -[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*] +[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*] check_untyped_defs = true disallow_untyped_calls = true disallow_untyped_defs = true diff --git a/superset/config.py b/superset/config.py index e3d3ffb..e0f22f7 100644 --- a/superset/config.py +++ b/superset/config.py @@ -279,7 +279,7 @@ LANGUAGES = { # For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here # and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py # will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True } -DEFAULT_FEATURE_FLAGS = { +DEFAULT_FEATURE_FLAGS: Dict[str, bool] = { # Experimental feature introducing a client (browser) cache "CLIENT_CACHE": False, "ENABLE_EXPLORE_JSON_CSRF_PROTECTION": False, diff --git a/superset/typing.py b/superset/typing.py index f3db6ae..09a3393 100644 --- a/superset/typing.py +++ b/superset/typing.py @@ -28,6 +28,7 @@ DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, . DbapiResult = List[Union[List[Any], Tuple[Any, ...]]] FilterValue = Union[float, int, str] FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] +FormData = Dict[str, Any] Granularity = Union[str, Dict[str, Union[str, float]]] Metric = Union[Dict[str, str], str] QueryObjectDict = Dict[str, Any] diff --git a/superset/utils/cache.py b/superset/utils/cache.py index b555005..bd39f87 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Callable, Optional +from typing import Any, Callable, Optional from flask import request from superset.extensions import cache_manager -def view_cache_key(*_, **__) -> str: +def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument args_hash = hash(frozenset(request.args.items())) return "view/{}/{}".format(request.path, args_hash) @@ -45,10 +45,10 @@ def memoized_func( returns the caching key. """ - def wrap(f): + def wrap(f: Callable) -> Callable: if cache_manager.tables_cache: - def wrapped_f(self, *args, **kwargs): + def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any: if not kwargs.get("cache", True): return f(self, *args, **kwargs) @@ -69,7 +69,7 @@ def memoized_func( else: # noop - def wrapped_f(self, *args, **kwargs): + def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any: return f(self, *args, **kwargs) return wrapped_f diff --git a/superset/utils/core.py b/superset/utils/core.py index 3618a28..b23136d 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -39,6 +39,7 @@ from email.utils import formatdate from enum import Enum from time import struct_time from timeit import default_timer +from types import TracebackType from typing import ( Any, Callable, @@ -51,6 +52,7 @@ from typing import ( Sequence, Set, Tuple, + Type, TYPE_CHECKING, Union, ) @@ -69,10 +71,12 @@ from dateutil.parser import parse from dateutil.relativedelta import relativedelta from flask import current_app, flash, g, Markup, render_template from flask_appbuilder import SQLA -from flask_appbuilder.security.sqla.models import User +from flask_appbuilder.security.sqla.models import Role, User from flask_babel import gettext as __, lazy_gettext as _ from sqlalchemy import event, exc, select, Text from sqlalchemy.dialects.mysql import MEDIUMTEXT +from sqlalchemy.engine import Connection, Engine +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql.type_api import Variant from sqlalchemy.types import TEXT, TypeDecorator @@ -81,7 +85,7 @@ from superset.exceptions import ( SupersetException, SupersetTimeoutException, ) -from superset.typing import Metric +from superset.typing import FormData, Metric from superset.utils.dates import datetime_to_epoch, EPOCH try: @@ -90,6 +94,7 @@ except ImportError: pass if TYPE_CHECKING: + from superset.connectors.base.models import BaseDatasource from superset.models.core import Database @@ -121,7 +126,7 @@ except NameError: pass -def flasher(msg: str, severity: str) -> None: +def flasher(msg: str, severity: str = "message") -> None: """Flask's flash if available, logging call if not""" try: flash(msg, severity) @@ -142,17 +147,17 @@ class _memoized: should account for instance variable changes. """ - def __init__(self, func, watch=()): + def __init__(self, func: Callable, watch: Optional[List[str]] = None) -> None: self.func = func - self.cache = {} + self.cache: Dict[Any, Any] = {} self.is_method = False self.watch = watch or [] - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: key = [args, frozenset(kwargs.items())] if self.is_method: key.append(tuple([getattr(args[0], v, None) for v in self.watch])) - key = tuple(key) + key = tuple(key) # type: ignore if key in self.cache: return self.cache[key] try: @@ -164,23 +169,25 @@ class _memoized: # Better to not cache than to blow up entirely. return self.func(*args, **kwargs) - def __repr__(self): + def __repr__(self) -> str: """Return the function's docstring.""" - return self.func.__doc__ + return self.func.__doc__ or "" - def __get__(self, obj, objtype): + def __get__(self, obj: Any, objtype: Type) -> functools.partial: if not self.is_method: self.is_method = True """Support instance methods.""" return functools.partial(self.__call__, obj) -def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = None): +def memoized( + func: Optional[Callable] = None, watch: Optional[List[str]] = None +) -> Callable: if func: return _memoized(func) else: - def wrapper(f): + def wrapper(f: Callable) -> Callable: return _memoized(f, watch) return wrapper @@ -229,7 +236,7 @@ def cast_to_num(value: Union[float, int, str]) -> Optional[Union[float, int]]: return None -def list_minus(l: List, minus: List) -> List: +def list_minus(l: List[Any], minus: List[Any]) -> List[Any]: """Returns l without what is in minus >>> list_minus([1, 2, 3], [2]) @@ -284,19 +291,19 @@ def md5_hex(data: str) -> str: class DashboardEncoder(json.JSONEncoder): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.sort_keys = True # pylint: disable=E0202 - def default(self, o): + def default(self, o: Any) -> Dict[Any, Any]: try: vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"} return {"__{}__".format(o.__class__.__name__): vals} except Exception: if type(o) == datetime: return {"__datetime__": o.replace(microsecond=0).isoformat()} - return json.JSONEncoder(sort_keys=True).default(self, o) + return json.JSONEncoder(sort_keys=True).default(o) def parse_human_timedelta(s: Optional[str]) -> timedelta: @@ -332,28 +339,15 @@ class JSONEncodedDict(TypeDecorator): impl = TEXT - def process_bind_param(self, value, dialect): - if value is not None: - value = json.dumps(value) + def process_bind_param( + self, value: Optional[Dict[Any, Any]], dialect: str + ) -> Optional[str]: + return json.dumps(value) if value is not None else None - return value - - def process_result_value(self, value, dialect): - if value is not None: - value = json.loads(value) - return value - - -def datetime_f(dttm): - """Formats datetime to take less room when it is recent""" - if dttm: - dttm = dttm.isoformat() - now_iso = datetime.now().isoformat() - if now_iso[:10] == dttm[:10]: - dttm = dttm[11:] - elif now_iso[:4] == dttm[:4]: - dttm = dttm[5:] - return "<nobr>{}</nobr>".format(dttm) + def process_result_value( + self, value: Optional[str], dialect: str + ) -> Optional[Dict[Any, Any]]: + return json.loads(value) if value is not None else None def format_timedelta(td: timedelta) -> str: @@ -373,7 +367,7 @@ def format_timedelta(td: timedelta) -> str: return str(td) -def base_json_conv(obj): +def base_json_conv(obj: Any) -> Any: if isinstance(obj, memoryview): obj = obj.tobytes() if isinstance(obj, np.int64): @@ -397,7 +391,7 @@ def base_json_conv(obj): return "[bytes]" -def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False): +def json_iso_dttm_ser(obj: Any, pessimistic: bool = False) -> str: """ json serializer that deals with dates @@ -420,14 +414,14 @@ def json_iso_dttm_ser(obj, pessimistic: Optional[bool] = False): return obj -def pessimistic_json_iso_dttm_ser(obj): +def pessimistic_json_iso_dttm_ser(obj: Any) -> str: """Proxy to call json_iso_dttm_ser in a pessimistic way If one of object is not serializable to json, it will still succeed""" return json_iso_dttm_ser(obj, pessimistic=True) -def json_int_dttm_ser(obj): +def json_int_dttm_ser(obj: Any) -> float: """json serializer that deals with dates""" val = base_json_conv(obj) if val is not None: @@ -441,7 +435,7 @@ def json_int_dttm_ser(obj): return obj -def json_dumps_w_dates(payload): +def json_dumps_w_dates(payload: Dict[Any, Any]) -> str: return json.dumps(payload, default=json_int_dttm_ser) @@ -522,7 +516,7 @@ def readfile(file_path: str) -> Optional[str]: def generic_find_constraint_name( table: str, columns: Set[str], referenced: str, db: SQLA -): +) -> Optional[str]: """Utility to find a constraint name in alembic migrations""" t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) @@ -530,10 +524,12 @@ def generic_find_constraint_name( if fk.referred_table.name == referenced and set(fk.column_keys) == columns: return fk.name + return None + def generic_find_fk_constraint_name( - table: str, columns: Set[str], referenced: str, insp -): + table: str, columns: Set[str], referenced: str, insp: Inspector +) -> Optional[str]: """Utility to find a foreign-key constraint name in alembic migrations""" for fk in insp.get_foreign_keys(table): if ( @@ -542,8 +538,12 @@ def generic_find_fk_constraint_name( ): return fk["name"] + return None + -def generic_find_fk_constraint_names(table, columns, referenced, insp): +def generic_find_fk_constraint_names( + table: str, columns: Set[str], referenced: str, insp: Inspector +) -> Set[str]: """Utility to find foreign-key constraint names in alembic migrations""" names = set() @@ -557,13 +557,17 @@ def generic_find_fk_constraint_names(table, columns, referenced, insp): return names -def generic_find_uq_constraint_name(table, columns, insp): +def generic_find_uq_constraint_name( + table: str, columns: Set[str], insp: Inspector +) -> Optional[str]: """Utility to find a unique constraint name in alembic migrations""" for uq in insp.get_unique_constraints(table): if columns == set(uq["column_names"]): return uq["name"] + return None + def get_datasource_full_name( database_name: str, datasource_name: str, schema: Optional[str] = None @@ -582,30 +586,20 @@ def validate_json(obj: Union[bytes, bytearray, str]) -> None: raise SupersetException("JSON is not valid") -def table_has_constraint(table, name, db): - """Utility to find a constraint name in alembic migrations""" - t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) - - for c in t.constraints: - if c.name == name: - return True - return False - - class timeout: """ To be used in a ``with`` block and timeout its content. """ - def __init__(self, seconds=1, error_message="Timeout"): + def __init__(self, seconds: int = 1, error_message: str = "Timeout") -> None: self.seconds = seconds self.error_message = error_message - def handle_timeout(self, signum, frame): + def handle_timeout(self, signum: int, frame: Any) -> None: logger.error("Process timed out") raise SupersetTimeoutException(self.error_message) - def __enter__(self): + def __enter__(self) -> None: try: signal.signal(signal.SIGALRM, self.handle_timeout) signal.alarm(self.seconds) @@ -613,7 +607,7 @@ class timeout: logger.warning("timeout can't be used in the current context") logger.exception(ex) - def __exit__(self, type, value, traceback): + def __exit__(self, type: Any, value: Any, traceback: TracebackType) -> None: try: signal.alarm(0) except ValueError as ex: @@ -621,9 +615,9 @@ class timeout: logger.exception(ex) -def pessimistic_connection_handling(some_engine): +def pessimistic_connection_handling(some_engine: Engine) -> None: @event.listens_for(some_engine, "engine_connect") - def ping_connection(connection, branch): + def ping_connection(connection: Connection, branch: bool) -> None: if branch: # 'branch' refers to a sub-connection of a connection, # we don't want to bother pinging on these. @@ -670,7 +664,14 @@ class QueryStatus: TIMED_OUT: str = "timed_out" -def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, config): +def notify_user_about_perm_udate( + granter: User, + user: User, + role: Role, + datasource: "BaseDatasource", + tpl_name: str, + config: Dict[str, Any], +) -> None: msg = render_template( tpl_name, granter=granter, user=user, role=role, datasource=datasource ) @@ -762,7 +763,13 @@ def send_email_smtp( send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun) -def send_MIME_email(e_from, e_to, mime_msg, config, dryrun=False): +def send_MIME_email( + e_from: str, + e_to: List[str], + mime_msg: MIMEMultipart, + config: Dict[str, Any], + dryrun: bool = False, +) -> None: SMTP_HOST = config["SMTP_HOST"] SMTP_PORT = config["SMTP_PORT"] SMTP_USER = config["SMTP_USER"] @@ -800,7 +807,7 @@ def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]: return [(v, v) for v in values] -def zlib_compress(data): +def zlib_compress(data: Union[bytes, str]) -> bytes: """ Compress things in a py2/3 safe fashion >>> json_str = '{"test": 1}' @@ -827,7 +834,9 @@ def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, return decompressed.decode("utf-8") if decode else decompressed -def to_adhoc(filt, expressionType="SIMPLE", clause="where"): +def to_adhoc( + filt: Dict[str, Any], expressionType: str = "SIMPLE", clause: str = "where" +) -> Dict[str, Any]: result = { "clause": clause.upper(), "expressionType": expressionType, @@ -849,7 +858,7 @@ def to_adhoc(filt, expressionType="SIMPLE", clause="where"): return result -def merge_extra_filters(form_data: dict): +def merge_extra_filters(form_data: Dict[str, Any]) -> None: # extra_filters are temporary/contextual filters (using the legacy constructs) # that are external to the slice definition. We use those for dynamic # interactive filters like the ones emitted by the "Filter Box" visualization. @@ -872,7 +881,7 @@ def merge_extra_filters(form_data: dict): } # Grab list of existing filters 'keyed' on the column and operator - def get_filter_key(f): + def get_filter_key(f: Dict[str, Any]) -> str: if "expressionType" in f: return "{}__{}".format(f["subject"], f["operator"]) else: @@ -945,7 +954,9 @@ def user_label(user: User) -> Optional[str]: return None -def get_or_create_db(database_name, sqlalchemy_uri, *args, **kwargs): +def get_or_create_db( + database_name: str, sqlalchemy_uri: str, *args: Any, **kwargs: Any +) -> "Database": from superset import db from superset.models import core as models @@ -996,7 +1007,7 @@ def get_metric_names(metrics: Sequence[Metric]) -> List[str]: return [get_metric_name(metric) for metric in metrics] -def ensure_path_exists(path: str): +def ensure_path_exists(path: str) -> None: try: os.makedirs(path) except OSError as exc: @@ -1119,7 +1130,7 @@ def add_ago_to_since(since: str) -> str: return since -def convert_legacy_filters_into_adhoc(fd): +def convert_legacy_filters_into_adhoc(fd: FormData) -> None: mapping = {"having": "having_filters", "where": "filters"} if not fd.get("adhoc_filters"): @@ -1138,7 +1149,7 @@ def convert_legacy_filters_into_adhoc(fd): del fd[key] -def split_adhoc_filters_into_base_filters(fd): +def split_adhoc_filters_into_base_filters(fd: FormData) -> None: """ Mutates form data to restructure the adhoc filters in the form of the four base filters, `where`, `having`, `filters`, and `having_filters` which represent @@ -1230,7 +1241,7 @@ def create_ssl_cert_file(certificate: str) -> str: return path -def time_function(func: Callable, *args, **kwargs) -> Tuple[float, Any]: +def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]: """ Measures the amount of time a function takes to execute in ms @@ -1296,7 +1307,7 @@ def split( yield s[i:] -def get_iterable(x: Any) -> List: +def get_iterable(x: Any) -> List[Any]: """ Get an iterable (list) representation of the object. diff --git a/superset/utils/dashboard_filter_scopes_converter.py b/superset/utils/dashboard_filter_scopes_converter.py index 6954990..f77e0e0 100644 --- a/superset/utils/dashboard_filter_scopes_converter.py +++ b/superset/utils/dashboard_filter_scopes_converter.py @@ -17,14 +17,16 @@ import json import logging from collections import defaultdict -from typing import Dict, List +from typing import Any, Dict, List from superset.models.slice import Slice logger = logging.getLogger(__name__) -def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]): +def convert_filter_scopes( + json_metadata: Dict[Any, Any], filters: List[Slice] +) -> Dict[int, Dict[str, Dict[str, Any]]]: filter_scopes = {} immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or [] immuned_by_column: Dict = defaultdict(list) @@ -34,7 +36,9 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]): for column in columns: immuned_by_column[column].append(int(slice_id)) - def add_filter_scope(filter_field, filter_id): + def add_filter_scope( + filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int + ) -> None: # in case filter field is invalid if isinstance(filter_field, str): current_filter_immune = list( @@ -54,17 +58,17 @@ def convert_filter_scopes(json_metadata: Dict, filters: List[Slice]): configs = slice_params.get("filter_configs") or [] if slice_params.get("date_filter"): - add_filter_scope("__time_range", filter_id) + add_filter_scope(filter_fields, "__time_range", filter_id) if slice_params.get("show_sqla_time_column"): - add_filter_scope("__time_col", filter_id) + add_filter_scope(filter_fields, "__time_col", filter_id) if slice_params.get("show_sqla_time_granularity"): - add_filter_scope("__time_grain", filter_id) + add_filter_scope(filter_fields, "__time_grain", filter_id) if slice_params.get("show_druid_time_granularity"): - add_filter_scope("__granularity", filter_id) + add_filter_scope(filter_fields, "__granularity", filter_id) if slice_params.get("show_druid_time_origin"): - add_filter_scope("druid_time_origin", filter_id) + add_filter_scope(filter_fields, "druid_time_origin", filter_id) for config in configs: - add_filter_scope(config.get("column"), filter_id) + add_filter_scope(filter_fields, config.get("column"), filter_id) if filter_fields: filter_scopes[filter_id] = filter_fields diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py index 53100f8..19dd2e2 100644 --- a/superset/utils/dashboard_import_export.py +++ b/superset/utils/dashboard_import_export.py @@ -19,6 +19,10 @@ import json import logging import time from datetime import datetime +from io import BytesIO +from typing import Any, Dict, Optional + +from sqlalchemy.orm import Session from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.models.dashboard import Dashboard @@ -27,7 +31,7 @@ from superset.models.slice import Slice logger = logging.getLogger(__name__) -def decode_dashboards(o): +def decode_dashboards(o: Dict[str, Any]) -> Any: """ Function to be passed into json.loads obj_hook parameter Recreates the dashboard object from a json representation. @@ -50,7 +54,9 @@ def decode_dashboards(o): return o -def import_dashboards(session, data_stream, import_time=None): +def import_dashboards( + session: Session, data_stream: BytesIO, import_time: Optional[int] = None +) -> None: """Imports dashboards from a stream to databases""" current_tt = int(time.time()) import_time = current_tt if import_time is None else import_time @@ -64,7 +70,7 @@ def import_dashboards(session, data_stream, import_time=None): session.commit() -def export_dashboards(session): +def export_dashboards(session: Session) -> str: """Returns all dashboards metadata as a json dump""" logger.info("Starting export") dashboards = session.query(Dashboard) diff --git a/superset/utils/dates.py b/superset/utils/dates.py index a1826e2..021ec7f 100644 --- a/superset/utils/dates.py +++ b/superset/utils/dates.py @@ -21,7 +21,7 @@ import pytz EPOCH = datetime(1970, 1, 1) -def datetime_to_epoch(dttm): +def datetime_to_epoch(dttm: datetime) -> float: if dttm.tzinfo: dttm = dttm.replace(tzinfo=pytz.utc) epoch_with_tz = pytz.utc.localize(EPOCH) @@ -29,5 +29,5 @@ def datetime_to_epoch(dttm): return (dttm - EPOCH).total_seconds() * 1000 -def now_as_float(): +def now_as_float() -> float: return datetime_to_epoch(datetime.utcnow()) diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 52ba61f..a1165c5 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -17,11 +17,14 @@ import logging from datetime import datetime, timedelta from functools import wraps +from typing import Any, Callable, Iterator from contextlib2 import contextmanager from flask import request +from werkzeug.wrappers.etag import ETagResponseMixin from superset import app, cache +from superset.stats_logger import BaseStatsLogger from superset.utils.dates import now_as_float # If a user sets `max_age` to 0, for long the browser should cache the @@ -32,7 +35,7 @@ logger = logging.getLogger(__name__) @contextmanager -def stats_timing(stats_key, stats_logger): +def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[float]: """Provide a transactional scope around a series of operations.""" start_ts = now_as_float() try: @@ -43,7 +46,7 @@ def stats_timing(stats_key, stats_logger): stats_logger.timing(stats_key, now_as_float() - start_ts) -def etag_cache(max_age, check_perms=bool): +def etag_cache(max_age: int, check_perms: Callable) -> Callable: """ A decorator for caching views and handling etag conditional requests. @@ -57,9 +60,9 @@ def etag_cache(max_age, check_perms=bool): """ - def decorator(f): + def decorator(f: Callable) -> Callable: @wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin: # check if the user can access the resource check_perms(*args, **kwargs) @@ -77,7 +80,9 @@ def etag_cache(max_age, check_perms=bool): key_args = list(args) key_kwargs = kwargs.copy() key_kwargs.update(request.args) - cache_key = wrapper.make_cache_key(f, *key_args, **key_kwargs) + cache_key = wrapper.make_cache_key( # type: ignore + f, *key_args, **key_kwargs + ) response = cache.get(cache_key) except Exception: # pylint: disable=broad-except if app.debug: @@ -109,9 +114,9 @@ def etag_cache(max_age, check_perms=bool): return response.make_conditional(request) if cache: - wrapper.uncached = f - wrapper.cache_timeout = max_age - wrapper.make_cache_key = cache._memoize_make_cache_key( # pylint: disable=protected-access + wrapper.uncached = f # type: ignore + wrapper.cache_timeout = max_age # type: ignore + wrapper.make_cache_key = cache._memoize_make_cache_key( # type: ignore # pylint: disable=protected-access make_name=None, timeout=max_age ) diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index d7ede85..a58635d 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -16,6 +16,9 @@ # under the License. # pylint: disable=C,R,W import logging +from typing import Any, Dict, List, Optional + +from sqlalchemy.orm import Session from superset.connectors.druid.models import DruidCluster from superset.models.core import Database @@ -25,7 +28,7 @@ DRUID_CLUSTERS_KEY = "druid_clusters" logger = logging.getLogger(__name__) -def export_schema_to_dict(back_references): +def export_schema_to_dict(back_references: bool) -> Dict[str, Any]: """Exports the supported import/export schema to a dictionary""" databases = [ Database.export_schema(recursive=True, include_parent_ref=back_references) @@ -41,7 +44,9 @@ def export_schema_to_dict(back_references): return data -def export_to_dict(session, recursive, back_references, include_defaults): +def export_to_dict( + session: Session, recursive: bool, back_references: bool, include_defaults: bool +) -> Dict[str, Any]: """Exports databases and druid clusters to a dictionary""" logger.info("Starting export") dbs = session.query(Database) @@ -72,8 +77,12 @@ def export_to_dict(session, recursive, back_references, include_defaults): return data -def import_from_dict(session, data, sync=[]): +def import_from_dict( + session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None +) -> None: """Imports databases and druid clusters from dictionary""" + if not sync: + sync = [] if isinstance(data, dict): logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) for database in data.get(DATABASES_KEY, []): diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py index 654607b..88f19c2 100644 --- a/superset/utils/feature_flag_manager.py +++ b/superset/utils/feature_flag_manager.py @@ -15,25 +15,33 @@ # specific language governing permissions and limitations # under the License. from copy import deepcopy +from typing import Any, Dict + +from flask import Flask class FeatureFlagManager: def __init__(self) -> None: super().__init__() self._get_feature_flags_func = None - self._feature_flags = None + self._feature_flags: Dict[str, Any] = {} - def init_app(self, app): - self._get_feature_flags_func = app.config.get("GET_FEATURE_FLAGS_FUNC") - self._feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {} - self._feature_flags.update(app.config.get("FEATURE_FLAGS") or {}) + def init_app(self, app: Flask) -> None: + self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"] + self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"] + self._feature_flags.update(app.config["FEATURE_FLAGS"]) - def get_feature_flags(self): + def get_feature_flags(self) -> Dict[str, Any]: if self._get_feature_flags_func: return self._get_feature_flags_func(deepcopy(self._feature_flags)) return self._feature_flags - def is_feature_enabled(self, feature) -> bool: + def is_feature_enabled(self, feature: str) -> bool: """Utility function for checking whether a feature is turned on""" - return self.get_feature_flags().get(feature) + feature_flags = self.get_feature_flags() + + if feature_flags and feature in feature_flags: + return feature_flags[feature] + + return False diff --git a/superset/utils/log.py b/superset/utils/log.py index 5d8c52e..aafe3b8 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -21,19 +21,23 @@ import logging import textwrap from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, cast, Type +from typing import Any, Callable, cast, Optional, Type from flask import current_app, g, request +from superset.stats_logger import BaseStatsLogger + class AbstractEventLogger(ABC): @abstractmethod - def log(self, user_id, action, *args, **kwargs): + def log( + self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any + ) -> None: pass - def log_this(self, f): + def log_this(self, f: Callable) -> Callable: @functools.wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: user_id = None if g.user: user_id = g.user.get_id() @@ -49,7 +53,12 @@ class AbstractEventLogger(ABC): try: slice_id = int( - slice_id or json.loads(form_data.get("form_data")).get("slice_id") + slice_id + or json.loads( + form_data.get("form_data") # type: ignore + ).get( + "slice_id" + ) ) except (ValueError, TypeError): slice_id = 0 @@ -62,7 +71,7 @@ class AbstractEventLogger(ABC): # bulk insert try: explode_by = form_data.get("explode") - records = json.loads(form_data.get(explode_by)) + records = json.loads(form_data.get(explode_by)) # type: ignore except Exception: # pylint: disable=broad-except records = [form_data] @@ -82,11 +91,11 @@ class AbstractEventLogger(ABC): return wrapper @property - def stats_logger(self): + def stats_logger(self) -> BaseStatsLogger: return current_app.config["STATS_LOGGER"] -def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger: +def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger: """ This function implements the deprecation of assignment of class objects to EVENT_LOGGER configuration, and validates @@ -130,7 +139,9 @@ def get_event_logger_from_cfg_value(cfg_value: object) -> AbstractEventLogger: class DBEventLogger(AbstractEventLogger): - def log(self, user_id, action, *args, **kwargs): # pylint: disable=too-many-locals + def log( # pylint: disable=too-many-locals + self, user_id: Optional[int], action: str, *args: Any, **kwargs: Any + ) -> None: from superset.models.core import Log records = kwargs.get("records", list()) @@ -141,6 +152,7 @@ class DBEventLogger(AbstractEventLogger): logs = list() for record in records: + json_string: Optional[str] try: json_string = json.dumps(record) except Exception: # pylint: disable=broad-except diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index dabebed..39a4278 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -73,8 +73,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = ( def validate_column_args(*argnames: str) -> Callable: - def wrapper(func): - def wrapped(df, **options): + def wrapper(func: Callable) -> Callable: + def wrapped(df: DataFrame, **options: Any) -> Any: columns = df.columns.tolist() for name in argnames: if name in options and not all( @@ -159,7 +159,7 @@ def pivot( # pylint: disable=too-many-arguments metric_fill_value: Optional[Any] = None, column_fill_value: Optional[str] = None, drop_missing_columns: Optional[bool] = True, - combine_value_with_metric=False, + combine_value_with_metric: bool = False, marginal_distributions: Optional[bool] = None, marginal_distribution_name: Optional[str] = None, ) -> DataFrame: diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index 18283e7..e07d2a2 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -18,7 +18,7 @@ import logging import time import urllib.parse from io import BytesIO -from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING from flask import current_app, request, Response, session, url_for from flask_login import login_user @@ -91,7 +91,7 @@ def headless_url(path: str) -> str: return urllib.parse.urljoin(current_app.config.get("WEBDRIVER_BASEURL", ""), path) -def get_url_path(view: str, **kwargs) -> str: +def get_url_path(view: str, **kwargs: Any) -> str: with current_app.test_request_context(): return headless_url(url_for(view, **kwargs)) @@ -135,7 +135,7 @@ class AuthWebDriverProxy: return self._auth_func(driver, user) @staticmethod - def destroy(driver: WebDriver, tries=2): + def destroy(driver: WebDriver, tries: int = 2) -> None: """Destroy a driver""" # This is some very flaky code in selenium. Hence the retries # and catch-all exceptions diff --git a/superset/utils/url_map_converters.py b/superset/utils/url_map_converters.py index 6697b7d..463dfcd 100644 --- a/superset/utils/url_map_converters.py +++ b/superset/utils/url_map_converters.py @@ -14,22 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from werkzeug.routing import BaseConverter +from typing import Any, List + +from werkzeug.routing import BaseConverter, Map from superset.models.tags import ObjectTypes class RegexConverter(BaseConverter): - def __init__(self, url_map, *items): - super(RegexConverter, self).__init__(url_map) + def __init__(self, url_map: Map, *items: List[str]) -> None: + super(RegexConverter, self).__init__(url_map) # type: ignore self.regex = items[0] class ObjectTypeConverter(BaseConverter): """Validate that object_type is indeed an object type.""" - def to_python(self, value): + def to_python(self, value: str) -> Any: return ObjectTypes[value] - def to_url(self, value): + def to_url(self, value: Any) -> str: return value.name diff --git a/superset/views/core.py b/superset/views/core.py index 716c11a..60454d9 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -2164,7 +2164,7 @@ class Superset(BaseSupersetView): return json_error_response(str(ex)) spec = mydb.db_engine_spec - query_cost_formatters = get_feature_flags().get( + query_cost_formatters: Dict[str, Any] = get_feature_flags().get( "QUERY_COST_FORMATTERS_BY_ENGINE", {} ) query_cost_formatter = query_cost_formatters.get( diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 4a4b640..02d4a88 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -38,7 +38,6 @@ from superset.utils.core import ( base_json_conv, convert_legacy_filters_into_adhoc, create_ssl_cert_file, - datetime_f, format_timedelta, get_iterable, get_email_address_list, @@ -560,17 +559,6 @@ class UtilsTestCase(SupersetTestCase): url_params["dashboard_ids"], form_data["url_params"]["dashboard_ids"] ) - def test_datetime_f(self): - self.assertEqual( - datetime_f(datetime(1990, 9, 21, 19, 11, 19, 626096)), - "<nobr>1990-09-21T19:11:19.626096</nobr>", - ) - self.assertEqual(len(datetime_f(datetime.now())), 28) - self.assertEqual(datetime_f(None), "<nobr>None</nobr>") - iso = datetime.now().isoformat()[:10].split("-") - [a, b, c] = [int(v) for v in iso] - self.assertEqual(datetime_f(datetime(a, b, c)), "<nobr>00:00:00</nobr>") - def test_format_timedelta(self): self.assertEqual(format_timedelta(timedelta(0)), "0:00:00") self.assertEqual(format_timedelta(timedelta(days=1)), "1 day, 0:00:00")