This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
     new 244677c  style(mypy): Enforcing typing for superset (#9943)
244677c is described below

commit 244677cf5e0ecb7c767455e96655af6c18cc58bc
Author: John Bodley <4567245+john-bod...@users.noreply.github.com>
AuthorDate: Wed Jun 3 15:26:12 2020 -0700

    style(mypy): Enforcing typing for superset (#9943)

    Co-authored-by: John Bodley <john.bod...@airbnb.com>
---
 setup.cfg                 |   2 +-
 superset/app.py           |  66 +++++----
 superset/cli.py           |  56 ++++----
 superset/config.py        |  21 +--
 superset/exceptions.py    |   2 +-
 superset/extensions.py    |  45 +++---
 superset/forms.py         |  16 +--
 superset/jinja_context.py |  26 ++--
 superset/sql_lab.py       |  76 +++++-----
 superset/sql_parse.py     |   6 +-
 superset/stats_logger.py  |  27 ++--
 superset/typing.py        |   1 +
 superset/viz.py           | 357 ++++++++++++++++++++++++++--------------------
 superset/viz_sip38.py     |   3 +-
 tests/viz_tests.py        |   2 +-
 15 files changed, 393 insertions(+), 313 deletions(-)
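The commit extends the strict mypy section of setup.cfg (first hunk below) to cover most of the superset package. As a minimal illustration of what the stricter options reject and accept (the functions here are made up for the example, not taken from the commit):

    from typing import Optional

    def greet(name):  # flagged by disallow_untyped_defs: no annotations
        return "Hello " + name

    def greet_typed(name: str, title: Optional[str] = None) -> str:
        # no_implicit_optional requires spelling out Optional[...] when
        # the default is None.
        prefix = f"{title} " if title else ""
        return f"Hello {prefix}{name}"

With disallow_untyped_calls, calling the untyped greet() from annotated code would also be an error, which is why the diff below annotates callers and callees together.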
diff --git a/setup.cfg b/setup.cfg
index 1115de9..fc94a24 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -53,7 +53,7 @@ order_by_type = false
 ignore_missing_imports = true
 no_implicit_optional = true
-[mypy-superset.bin.*,superset.charts.*,superset.commands.*,superset.common.*,superset.connectors.*,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,uperset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*,superset.utils.*,superset.views.chart.*,superset.views.dashboard.*,superset.views.database.*]
+[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset [...]
 check_untyped_defs = true
 disallow_untyped_calls = true
 disallow_untyped_defs = true
diff --git a/superset/app.py b/superset/app.py
index 98f1459..18165ed 100644
--- a/superset/app.py
+++ b/superset/app.py
@@ -17,6 +17,7 @@
 import logging
 import os
+from typing import Any, Callable, Dict
 
 import wtforms_json
 from flask import Flask, redirect
@@ -41,13 +42,14 @@ from superset.extensions import (
     talisman,
 )
 from superset.security import SupersetSecurityManager
+from superset.typing import FlaskResponse
 from superset.utils.core import pessimistic_connection_handling
 from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value
 
 logger = logging.getLogger(__name__)
 
 
-def create_app():
+def create_app() -> Flask:
     app = Flask(__name__)
 
     try:
@@ -68,7 +70,7 @@ def create_app():
 
 class SupersetIndexView(IndexView):
     @expose("/")
-    def index(self):
+    def index(self) -> FlaskResponse:
         return redirect("/superset/welcome")
 
 
@@ -109,8 +111,8 @@ class SupersetAppInitializer:
             abstract = True
 
             # Grab each call into the task and set up an app context
-            def __call__(self, *args, **kwargs):
-                with flask_app.app_context():
+            def __call__(self, *args: Any, **kwargs: Any) -> Any:
+                with flask_app.app_context():  # type: ignore
                     return task_base.__call__(self, *args, **kwargs)
 
         celery_app.Task = AppContextTask
@@ -454,51 +456,41 @@ class SupersetAppInitializer:
         order to fully init the app
         """
         self.pre_init()
-
         self.setup_db()
-
         self.configure_celery()
-
         self.setup_event_logger()
-
         self.setup_bundle_manifest()
-
         self.register_blueprints()
-
         self.configure_wtf()
-
         self.configure_logging()
-
         self.configure_middlewares()
-
         self.configure_cache()
-
         self.configure_jinja_context()
-        with self.flask_app.app_context():
+        with self.flask_app.app_context():  # type: ignore
             self.init_app_in_ctx()
         self.post_init()
 
-    def setup_event_logger(self):
+    def setup_event_logger(self) -> None:
        _event_logger["event_logger"] = get_event_logger_from_cfg_value(
            self.flask_app.config.get("EVENT_LOGGER", DBEventLogger())
        )
 
-    def configure_data_sources(self):
+    def configure_data_sources(self) -> None:
         # Registering sources
         module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
         module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
         ConnectorRegistry.register_sources(module_datasource_map)
 
-    def configure_cache(self):
+    def configure_cache(self) -> None:
         cache_manager.init_app(self.flask_app)
         results_backend_manager.init_app(self.flask_app)
 
-    def configure_feature_flags(self):
+    def configure_feature_flags(self) -> None:
         feature_flag_manager.init_app(self.flask_app)
 
-    def configure_fab(self):
+    def configure_fab(self) -> None:
         if self.config["SILENCE_FAB"]:
             logging.getLogger("flask_appbuilder").setLevel(logging.ERROR)
@@ -516,7 +508,7 @@ class SupersetAppInitializer:
         appbuilder.update_perms = False
         appbuilder.init_app(self.flask_app, db.session)
 
-    def configure_url_map_converters(self):
+    def configure_url_map_converters(self) -> None:
         #
         # Doing local imports here as model importing causes a reference to
         # app.config to be invoked and we need the current_app to have been setup
@@ -527,10 +519,10 @@ class SupersetAppInitializer:
         self.flask_app.url_map.converters["regex"] = RegexConverter
         self.flask_app.url_map.converters["object_type"] = ObjectTypeConverter
 
-    def configure_jinja_context(self):
+    def configure_jinja_context(self) -> None:
         jinja_context_manager.init_app(self.flask_app)
 
-    def configure_middlewares(self):
+    def configure_middlewares(self) -> None:
         if self.config["ENABLE_CORS"]:
             from flask_cors import CORS
@@ -539,24 +531,28 @@ class SupersetAppInitializer:
         if self.config["ENABLE_PROXY_FIX"]:
             from werkzeug.middleware.proxy_fix import ProxyFix
 
-            self.flask_app.wsgi_app = ProxyFix(
+            self.flask_app.wsgi_app = ProxyFix(  # type: ignore
                 self.flask_app.wsgi_app, **self.config["PROXY_FIX_CONFIG"]
             )
 
         if self.config["ENABLE_CHUNK_ENCODING"]:
 
             class ChunkedEncodingFix:  # pylint: disable=too-few-public-methods
-                def __init__(self, app):
+                def __init__(self, app: Flask) -> None:
                     self.app = app
 
-                def __call__(self, environ, start_response):
+                def __call__(
+                    self, environ: Dict[str, Any], start_response: Callable
+                ) -> Any:
                     # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
                     # content-length and read the stream till the end.
                     if environ.get("HTTP_TRANSFER_ENCODING", "").lower() == "chunked":
                         environ["wsgi.input_terminated"] = True
                     return self.app(environ, start_response)
 
-            self.flask_app.wsgi_app = ChunkedEncodingFix(self.flask_app.wsgi_app)
+            self.flask_app.wsgi_app = ChunkedEncodingFix(  # type: ignore
+                self.flask_app.wsgi_app  # type: ignore
+            )
 
         if self.config["UPLOAD_FOLDER"]:
             try:
@@ -565,7 +561,9 @@ class SupersetAppInitializer:
                 pass
 
         for middleware in self.config["ADDITIONAL_MIDDLEWARE"]:
-            self.flask_app.wsgi_app = middleware(self.flask_app.wsgi_app)
+            self.flask_app.wsgi_app = middleware(  # type: ignore
+                self.flask_app.wsgi_app
+            )
 
         # Flask-Compress
         if self.config["ENABLE_FLASK_COMPRESS"]:
@@ -574,27 +572,27 @@ class SupersetAppInitializer:
         if self.config["TALISMAN_ENABLED"]:
             talisman.init_app(self.flask_app, **self.config["TALISMAN_CONFIG"])
 
-    def configure_logging(self):
+    def configure_logging(self) -> None:
         self.config["LOGGING_CONFIGURATOR"].configure_logging(
             self.config, self.flask_app.debug
         )
 
-    def setup_db(self):
+    def setup_db(self) -> None:
         db.init_app(self.flask_app)
 
-        with self.flask_app.app_context():
+        with self.flask_app.app_context():  # type: ignore
             pessimistic_connection_handling(db.engine)
 
         migrate.init_app(self.flask_app, db=db, directory=APP_DIR + "/migrations")
 
-    def configure_wtf(self):
+    def configure_wtf(self) -> None:
         if self.config["WTF_CSRF_ENABLED"]:
             csrf = CSRFProtect(self.flask_app)
             csrf_exempt_list = self.config["WTF_CSRF_EXEMPT_LIST"]
             for ex in csrf_exempt_list:
                 csrf.exempt(ex)
 
-    def register_blueprints(self):
+    def register_blueprints(self) -> None:
         for bp in self.config["BLUEPRINTS"]:
             try:
                 logger.info(f"Registering blueprint: '{bp.name}'")
@@ -602,5 +600,5 @@ class SupersetAppInitializer:
             except Exception:  # pylint: disable=broad-except
                 logger.exception("blueprint registration failed")
 
-    def setup_bundle_manifest(self):
+    def setup_bundle_manifest(self) -> None:
         manifest_processor.init_app(self.flask_app)
diff --git a/superset/cli.py b/superset/cli.py
index 090e1fb..a136010 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -19,10 +19,11 @@ import logging
 from datetime import datetime
 from subprocess import Popen
 from sys import stdout
-from typing import Type, Union
+from typing import Any, Dict, Type, Union
 
 import click
 import yaml
+from celery.utils.abstract import CallableTask
 from colorama import Fore, Style
 from flask import g
 from flask.cli import FlaskGroup, with_appcontext
@@ -56,17 +57,17 @@ def normalize_token(token_name: str) -> str:
     context_settings={"token_normalize_func": normalize_token},
 )
 @with_appcontext
-def superset():
+def superset() -> None:
     """This is a management script for the Superset application."""
 
     @app.shell_context_processor
-    def make_shell_context():  # pylint: disable=unused-variable
+    def make_shell_context() -> Dict[str, Any]:  # pylint: disable=unused-variable
         return dict(app=app, db=db)
 
 
 @superset.command()
 @with_appcontext
-def init():
+def init() -> None:
     """Inits the Superset application"""
     appbuilder.add_permissions(update_perms=True)
     security_manager.sync_role_definitions()
@@ -75,7 +76,7 @@ def init():
 @superset.command()
 @with_appcontext
 @click.option("--verbose", "-v", is_flag=True, help="Show extra information")
-def version(verbose):
+def version(verbose: bool) -> None:
     """Prints the current version number"""
     print(Fore.BLUE + "-=" * 15)
     print(
@@ -90,7 +91,9 @@ def version(verbose):
     print(Style.RESET_ALL)
 
 
-def load_examples_run(load_test_data, only_metadata=False, force=False):
+def load_examples_run(
+    load_test_data: bool, only_metadata: bool = False, force: bool = False
+) -> None:
     if only_metadata:
         print("Loading examples metadata")
     else:
@@ -160,7 +163,9 @@ def load_examples_run(load_test_data, only_metadata=False, force=False):
 @click.option(
     "--force", "-f", is_flag=True, help="Force load data even if table already exists"
 )
-def load_examples(load_test_data, only_metadata=False, force=False):
+def load_examples(
+    load_test_data: bool, only_metadata: bool = False, force: bool = False
+) -> None:
     """Loads a set of Slices and Dashboards and a supporting dataset """
     load_examples_run(load_test_data, only_metadata, force)
 
 
@@ -169,7 +174,7 @@ def load_examples(load_test_data, only_metadata=False, force=False):
 @superset.command()
 @click.option("--database_name", "-d", help="Database name to change")
 @click.option("--uri", "-u", help="Database URI to change")
-def set_database_uri(database_name, uri):
+def set_database_uri(database_name: str, uri: str) -> None:
     """Updates a database connection URI """
     utils.get_or_create_db(database_name, uri)
@@ -189,7 +194,7 @@ def set_database_uri(database_name, uri):
     default=False,
     help="Specify using 'merge' property during operation. "
     "Default value is False.",
 )
-def refresh_druid(datasource, merge):
+def refresh_druid(datasource: str, merge: bool) -> None:
     """Refresh druid datasources"""
     session = db.session()
     from superset.connectors.druid.models import DruidCluster
@@ -226,7 +231,7 @@ def refresh_druid(datasource, merge):
     default=None,
     help="Specify the user name to assign dashboards to",
 )
-def import_dashboards(path, recursive, username):
+def import_dashboards(path: str, recursive: bool, username: str) -> None:
     """Import dashboards from JSON"""
     from superset.utils import dashboard_import_export
@@ -258,7 +263,7 @@ def import_dashboards(path, recursive, username):
 @click.option(
     "--print_stdout", "-p", is_flag=True, default=False, help="Print JSON to stdout"
 )
-def export_dashboards(print_stdout, dashboard_file):
+def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
     """Export dashboards to JSON"""
     from superset.utils import dashboard_import_export
@@ -295,7 +300,7 @@ def export_dashboards(print_stdout, dashboard_file):
     default=False,
     help="recursively search the path for yaml files",
 )
-def import_datasources(path, sync, recursive):
+def import_datasources(path: str, sync: str, recursive: bool) -> None:
     """Import datasources from YAML"""
     from superset.utils import dict_import_export
@@ -345,8 +350,11 @@ def import_datasources(path, sync, recursive):
     help="Include fields containing defaults",
 )
 def export_datasources(
-    print_stdout, datasource_file, back_references, include_defaults
-):
+    print_stdout: bool,
+    datasource_file: str,
+    back_references: bool,
+    include_defaults: bool,
+) -> None:
     """Export datasources to YAML"""
     from superset.utils import dict_import_export
@@ -373,7 +381,7 @@ def export_datasources(
     default=False,
     help="Include parent back references",
 )
-def export_datasource_schema(back_references):
+def export_datasource_schema(back_references: bool) -> None:
     """Export datasource YAML schema to stdout"""
     from superset.utils import dict_import_export
@@ -383,7 +391,7 @@ def export_datasource_schema(back_references):
 
 @superset.command()
 @with_appcontext
-def update_datasources_cache():
+def update_datasources_cache() -> None:
     """Refresh sqllab datasources cache"""
     from superset.models.core import Database
@@ -406,7 +414,7 @@ def update_datasources_cache():
 @click.option(
     "--workers", "-w", type=int, help="Number of celery server workers to fire up"
 )
-def worker(workers):
+def worker(workers: int) -> None:
     """Starts a Superset worker for async SQL query execution."""
     logger.info(
         "The 'superset worker' command is deprecated. Please use the 'celery "
@@ -431,7 +439,7 @@ def worker(workers):
 @click.option(
     "-a", "--address", default="localhost", help="Address on which to run the service"
 )
-def flower(port, address):
+def flower(port: int, address: str) -> None:
     """Runs a Celery Flower web server
 
     Celery Flower is a UI to monitor the Celery operation on a given
@@ -487,7 +495,7 @@ def compute_thumbnails(
     charts_only: bool,
     force: bool,
     model_id: int,
-):
+) -> None:
     """Compute thumbnails"""
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
@@ -500,8 +508,8 @@ def compute_thumbnails(
         friendly_type: str,
         model_cls: Union[Type[Dashboard], Type[Slice]],
         model_id: int,
-        compute_func,
-    ):
+        compute_func: CallableTask,
+    ) -> None:
         query = db.session.query(model_cls)
         if model_id:
             query = query.filter(model_cls.id.in_(model_id))
@@ -528,7 +536,7 @@ def compute_thumbnails(
 
 @superset.command()
 @with_appcontext
-def load_test_users():
+def load_test_users() -> None:
     """
     Loads admin, alpha, and gamma user for testing purposes
 
@@ -538,7 +546,7 @@ def load_test_users():
     load_test_users_run()
 
 
-def load_test_users_run():
+def load_test_users_run() -> None:
     """
     Loads admin, alpha, and gamma user for testing purposes
 
@@ -583,7 +591,7 @@ def load_test_users_run():
 @superset.command()
 @with_appcontext
-def sync_tags():
+def sync_tags() -> None:
     """Rebuilds special tags (owner, type, favorited by)."""
     # pylint: disable=no-member
     metadata = Model.metadata
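The cli.py changes above all follow one pattern: Click injects the parsed option values, so each command's parameters can carry precise annotations and the command itself returns None. A minimal sketch of the same pattern (the command here is illustrative, not part of Superset):

    import click

    @click.command()
    @click.option("--verbose", "-v", is_flag=True, help="Show extra information")
    def example(verbose: bool) -> None:
        # is_flag=True means Click always passes a bool here.
        click.echo(f"verbose={verbose}")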
diff --git a/superset/config.py b/superset/config.py
index 738c251..35dbbf8 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -28,8 +28,9 @@ import os
 import sys
 from collections import OrderedDict
 from datetime import date
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
 
+from cachelib.base import BaseCache
 from celery.schedules import crontab
 from dateutil import tz
 from flask_appbuilder.security.manager import AUTH_DB
@@ -78,7 +79,7 @@ PACKAGE_JSON_FILE = os.path.join(BASE_DIR, "static", "assets", "package.json")
 FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
 
 
-def _try_json_readversion(filepath):
+def _try_json_readversion(filepath: str) -> Optional[str]:
     try:
         with open(filepath, "r") as f:
             return json.load(f).get("version")
@@ -86,7 +87,9 @@ def _try_json_readversion(filepath):
         return None
 
 
-def _try_json_readsha(filepath, length):  # pylint: disable=unused-argument
+def _try_json_readsha(  # pylint: disable=unused-argument
+    filepath: str, length: int
+) -> Optional[str]:
     try:
         with open(filepath, "r") as f:
             return json.load(f).get("GIT_SHA")[:length]
@@ -453,6 +456,7 @@ BACKUP_COUNT = 30
 #     user=None,
 #     client=None,
 #     security_manager=None,
+#     log_params=None,
 # ):
 #     pass
 QUERY_LOGGER = None
@@ -578,10 +582,9 @@ SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
     Callable[["Database", "models.User", str, str], str]
 ] = None
 
-# An instantiated derivative of cachelib.base.BaseCache
-# if enabled, it can be used to store the results of long-running queries
+# If enabled, it can be used to store the results of long-running queries
 # in SQL Lab by using the "Run Async" button/feature
-RESULTS_BACKEND = None
+RESULTS_BACKEND: Optional[BaseCache] = None
 
 # Use PyArrow and MessagePack for async query results serialization,
 # rather than JSON. This feature requires additional testing from the
@@ -604,7 +607,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC: Callable[
 
 # The namespace within hive where the tables created from
 # uploading CSVs will be stored.
-UPLOADED_CSV_HIVE_NAMESPACE = None
+UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
 
 # Function that computes the allowed schemas for the CSV uploads.
 # Allowed schemas will be a union of schemas_allowed_for_csv_upload
@@ -614,7 +617,7 @@ UPLOADED_CSV_HIVE_NAMESPACE = None
 ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[
     ["Database", "models.User"], List[str]
 ] = lambda database, user: [
-    UPLOADED_CSV_HIVE_NAMESPACE  # type: ignore
+    UPLOADED_CSV_HIVE_NAMESPACE
 ] if UPLOADED_CSV_HIVE_NAMESPACE else []
 
 # A dictionary of items that gets merged into the Jinja context for
@@ -628,7 +631,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {}
 # dictionary, which means the existing keys get overwritten by the content of this
 # dictionary. The customized addons don't necessarily need to use jinjia templating
 # language. This allows you to define custom logic to process macro template.
-CUSTOM_TEMPLATE_PROCESSORS = {}  # type: Dict[str, BaseTemplateProcessor]
+CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
 
 # Roles that are controlled by the API / Superset and should not be changes
 # by humans.
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 59ea042..51bd85f 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -32,7 +32,7 @@ class SupersetException(Exception):
         super().__init__(self.message)
 
     @property
-    def exception(self):
+    def exception(self) -> Optional[Exception]:
         return self._exception
 
 
diff --git a/superset/extensions.py b/superset/extensions.py
index c501eeb..f321046 100644
--- a/superset/extensions.py
+++ b/superset/extensions.py
@@ -20,10 +20,12 @@ import random
 import time
 import uuid
 from datetime import datetime, timedelta
-from typing import Dict, TYPE_CHECKING  # pylint: disable=unused-import
+from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING
 
 import celery
+from cachelib.base import BaseCache
 from dateutil.relativedelta import relativedelta
+from flask import Flask
 from flask_appbuilder import AppBuilder, SQLA
 from flask_migrate import Migrate
 from flask_talisman import Talisman
@@ -32,7 +34,6 @@ from werkzeug.local import LocalProxy
 from superset.utils.cache_manager import CacheManager
 from superset.utils.feature_flag_manager import FeatureFlagManager
 
-# Avoid circular import
 if TYPE_CHECKING:
     from superset.jinja_context import (  # pylint: disable=unused-import
         BaseTemplateProcessor,
@@ -49,18 +50,18 @@ class JinjaContextManager:
             "timedelta": timedelta,
             "uuid": uuid,
         }
-        self._template_processors = {}  # type: Dict[str, BaseTemplateProcessor]
+        self._template_processors: Dict[str, Type["BaseTemplateProcessor"]] = {}
 
-    def init_app(self, app):
+    def init_app(self, app: Flask) -> None:
         self._base_context.update(app.config["JINJA_CONTEXT_ADDONS"])
         self._template_processors.update(app.config["CUSTOM_TEMPLATE_PROCESSORS"])
 
     @property
-    def base_context(self):
+    def base_context(self) -> Dict[str, Any]:
         return self._base_context
 
     @property
-    def template_processors(self):
+    def template_processors(self) -> Dict[str, Type["BaseTemplateProcessor"]]:
         return self._template_processors
 
 
@@ -69,35 +70,35 @@ class ResultsBackendManager:
         self._results_backend = None
         self._use_msgpack = False
 
-    def init_app(self, app):
-        self._results_backend = app.config.get("RESULTS_BACKEND")
-        self._use_msgpack = app.config.get("RESULTS_BACKEND_USE_MSGPACK")
+    def init_app(self, app: Flask) -> None:
+        self._results_backend = app.config["RESULTS_BACKEND"]
+        self._use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"]
 
     @property
-    def results_backend(self):
+    def results_backend(self) -> Optional[BaseCache]:
         return self._results_backend
 
     @property
-    def should_use_msgpack(self):
+    def should_use_msgpack(self) -> bool:
         return self._use_msgpack
 
 
 class UIManifestProcessor:
     def __init__(self, app_dir: str) -> None:
-        self.app = None
-        self.manifest: dict = {}
+        self.app: Optional[Flask] = None
+        self.manifest: Dict[str, Dict[str, List[str]]] = {}
         self.manifest_file = f"{app_dir}/static/assets/manifest.json"
 
-    def init_app(self, app):
+    def init_app(self, app: Flask) -> None:
         self.app = app
         # Preload the cache
         self.parse_manifest_json()
 
         @app.context_processor
-        def get_manifest():  # pylint: disable=unused-variable
+        def get_manifest() -> Dict[str, Callable]:  # pylint: disable=unused-variable
             loaded_chunks = set()
 
-            def get_files(bundle, asset_type="js"):
+            def get_files(bundle: str, asset_type: str = "js") -> List[str]:
                 files = self.get_manifest_files(bundle, asset_type)
                 filtered_files = [f for f in files if f not in loaded_chunks]
                 for f in filtered_files:
@@ -109,18 +110,18 @@ class UIManifestProcessor:
                 css_manifest=lambda bundle: get_files(bundle, "css"),
             )
 
-    def parse_manifest_json(self):
+    def parse_manifest_json(self) -> None:
         try:
             with open(self.manifest_file, "r") as f:
-                # the manifest includes non-entry files
-                # we only need entries in templates
+                # the manifest includes non-entry files we only need entries in
+                # templates
                 full_manifest = json.load(f)
                 self.manifest = full_manifest.get("entrypoints", {})
         except Exception:  # pylint: disable=broad-except
             pass
 
-    def get_manifest_files(self, bundle, asset_type):
-        if self.app.debug:
+    def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]:
+        if self.app and self.app.debug:
             self.parse_manifest_json()
         return self.manifest.get(bundle, {}).get(asset_type, [])
 
@@ -133,7 +134,7 @@ db = SQLA()
 _event_logger: dict = {}
 event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
 feature_flag_manager = FeatureFlagManager()
-jinja_context_manager = JinjaContextManager()  # type: JinjaContextManager
+jinja_context_manager = JinjaContextManager()
 manifest_processor = UIManifestProcessor(APP_DIR)
 migrate = Migrate()
 results_backend_manager = ResultsBackendManager()
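A recurring pattern in extensions.py above is the attribute that stays None until init_app() runs: it is declared Optional and guarded before use, as in UIManifestProcessor.get_manifest_files(). A self-contained sketch of the same idea (class and method names are illustrative, not from the commit):

    from typing import Optional
    from flask import Flask

    class ManifestHolder:
        def __init__(self) -> None:
            # Not available until init_app() is called.
            self.app: Optional[Flask] = None

        def init_app(self, app: Flask) -> None:
            self.app = app

        def in_debug(self) -> bool:
            # The `self.app and ...` guard narrows Optional[Flask] to Flask
            # for mypy and avoids an AttributeError at runtime.
            return bool(self.app and self.app.debug)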
"""Contains the logic to create cohesive forms on the explore view""" -from typing import List # pylint: disable=unused-import +from typing import Any, List, Optional from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from wtforms import Field @@ -25,24 +25,24 @@ class CommaSeparatedListField(Field): widget = BS3TextFieldWidget() data: List[str] = [] - def _value(self): + def _value(self) -> str: if self.data: - return u", ".join(self.data) + return ", ".join(self.data) - return u"" + return "" - def process_formdata(self, valuelist): + def process_formdata(self, valuelist: List[str]) -> None: if valuelist: self.data = [x.strip() for x in valuelist[0].split(",")] else: self.data = [] -def filter_not_empty_values(value): +def filter_not_empty_values(values: Optional[List[Any]]) -> Optional[List[Any]]: """Returns a list of non empty values or None""" - if not value: + if not values: return None - data = [x for x in value if x] + data = [value for value in values if value] if not data: return None return data diff --git a/superset/jinja_context.py b/superset/jinja_context.py index e1a10cd..95ee723 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -17,7 +17,7 @@ """Defines the templating context for SQL Lab""" import inspect import re -from typing import Any, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, cast, List, Optional, Tuple, TYPE_CHECKING from flask import g, request from jinja2.sandbox import SandboxedEnvironment @@ -207,7 +207,7 @@ class BaseTemplateProcessor: # pylint: disable=too-few-public-methods def __init__( self, - database: Optional["Database"] = None, + database: "Database", query: Optional["Query"] = None, table: Optional["SqlaTable"] = None, extra_cache_keys: Optional[List[Any]] = None, @@ -266,7 +266,7 @@ class PrestoTemplateProcessor(BaseTemplateProcessor): schema, table_name = table_name.split(".") return table_name, schema - def first_latest_partition(self, table_name: str) -> str: + def first_latest_partition(self, table_name: str) -> Optional[str]: """ Gets the first value in the array of all latest partitions @@ -275,9 +275,10 @@ class PrestoTemplateProcessor(BaseTemplateProcessor): :raises IndexError: If no partition exists """ - return self.latest_partitions(table_name)[0] + latest_partitions = self.latest_partitions(table_name) + return latest_partitions[0] if latest_partitions else None - def latest_partitions(self, table_name: str) -> List[str]: + def latest_partitions(self, table_name: str) -> Optional[List[str]]: """ Gets the array of all latest partitions @@ -285,16 +286,21 @@ class PrestoTemplateProcessor(BaseTemplateProcessor): :return: the latest partition array """ + from superset.db_engine_specs.presto import PrestoEngineSpec + table_name, schema = self._schema_table(table_name, self.schema) - assert self.database - return self.database.db_engine_spec.latest_partition( # type: ignore + return cast(PrestoEngineSpec, self.database.db_engine_spec).latest_partition( table_name, schema, self.database )[1] - def latest_sub_partition(self, table_name, **kwargs): + def latest_sub_partition(self, table_name: str, **kwargs: Any) -> Any: table_name, schema = self._schema_table(table_name, self.schema) - assert self.database - return self.database.db_engine_spec.latest_sub_partition( + + from superset.db_engine_specs.presto import PrestoEngineSpec + + return cast( + PrestoEngineSpec, self.database.db_engine_spec + ).latest_sub_partition( table_name=table_name, schema=schema, database=self.database, **kwargs ) diff 
--git a/superset/sql_lab.py b/superset/sql_lab.py index ab952db..3ba1e3a 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -19,7 +19,7 @@ import uuid from contextlib import closing from datetime import datetime from sys import getsizeof -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union import backoff import msgpack @@ -27,9 +27,10 @@ import pyarrow as pa import simplejson as json import sqlalchemy from celery.exceptions import SoftTimeLimitExceeded +from celery.task.base import Task from contextlib2 import contextmanager from flask_babel import lazy_gettext as _ -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.pool import NullPool from superset import ( @@ -77,7 +78,9 @@ class SqlLabTimeoutException(SqlLabException): pass -def handle_query_error(msg, query, session, payload=None): +def handle_query_error( + msg: str, query: Query, session: Session, payload: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: """Local method handling error while processing the SQL""" payload = payload or {} troubleshooting_link = config["TROUBLESHOOTING_LINK"] @@ -91,14 +94,14 @@ def handle_query_error(msg, query, session, payload=None): return payload -def get_query_backoff_handler(details): +def get_query_backoff_handler(details: Dict[Any, Any]) -> None: query_id = details["kwargs"]["query_id"] logger.error(f"Query with id `{query_id}` could not be retrieved") stats_logger.incr("error_attempting_orm_query_{}".format(details["tries"] - 1)) logger.error(f"Query {query_id}: Sleeping for a sec before retrying...") -def get_query_giveup_handler(_): +def get_query_giveup_handler(_: Any) -> None: stats_logger.incr("error_failed_at_getting_orm_query") @@ -110,7 +113,7 @@ def get_query_giveup_handler(_): on_giveup=get_query_giveup_handler, max_tries=5, ) -def get_query(query_id, session): +def get_query(query_id: int, session: Session) -> Query: """attempts to get the query and retry if it cannot""" try: return session.query(Query).filter_by(id=query_id).one() @@ -119,7 +122,7 @@ def get_query(query_id, session): @contextmanager -def session_scope(nullpool): +def session_scope(nullpool: bool) -> Iterator[Session]: """Provide a transactional scope around a series of operations.""" database_uri = app.config["SQLALCHEMY_DATABASE_URI"] if "sqlite" in database_uri: @@ -154,16 +157,16 @@ def session_scope(nullpool): soft_time_limit=SQLLAB_TIMEOUT, ) def get_sql_results( # pylint: disable=too-many-arguments - ctask, - query_id, - rendered_query, - return_results=True, - store_results=False, - user_name=None, - start_time=None, - expand_data=False, - log_params=None, -): + ctask: Task, + query_id: int, + rendered_query: str, + return_results: bool = True, + store_results: bool = False, + user_name: Optional[str] = None, + start_time: Optional[float] = None, + expand_data: bool = False, + log_params: Optional[Dict[str, Any]] = None, +) -> Optional[Dict[str, Any]]: """Executes the sql query returns the results.""" with session_scope(not ctask.request.called_directly) as session: @@ -188,7 +191,14 @@ def get_sql_results( # pylint: disable=too-many-arguments # pylint: disable=too-many-arguments -def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_params): +def execute_sql_statement( + sql_statement: str, + query: Query, + user_name: Optional[str], + session: Session, + cursor: Any, + log_params: Optional[Dict[str, Any]], +) -> SupersetResultSet: 
"""Executes a single SQL statement""" database = query.database db_engine_spec = database.db_engine_spec @@ -275,7 +285,7 @@ def execute_sql_statement(sql_statement, query, user_name, session, cursor, log_ def _serialize_payload( - payload: dict, use_msgpack: Optional[bool] = False + payload: Dict[Any, Any], use_msgpack: Optional[bool] = False ) -> Union[bytes, str]: logger.debug(f"Serializing to msgpack: {use_msgpack}") if use_msgpack: @@ -321,24 +331,24 @@ def _serialize_and_expand_data( return (data, selected_columns, all_columns, expanded_columns) -def execute_sql_statements( - query_id, - rendered_query, - return_results=True, - store_results=False, - user_name=None, - session=None, - start_time=None, - expand_data=False, - log_params=None, -): # pylint: disable=too-many-arguments, too-many-locals, too-many-statements +def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-locals, too-many-statements + query_id: int, + rendered_query: str, + return_results: bool, + store_results: bool, + user_name: Optional[str], + session: Session, + start_time: Optional[float], + expand_data: bool, + log_params: Optional[Dict[str, Any]], +) -> Optional[Dict[str, Any]]: """Executes the sql query returns the results.""" if store_results and start_time: # only asynchronous queries stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time) query = get_query(query_id, session) - payload = dict(query_id=query_id) + payload: Dict[str, Any] = dict(query_id=query_id) database = query.database db_engine_spec = database.db_engine_spec db_engine_spec.patch() @@ -406,7 +416,7 @@ def execute_sql_statements( ) query.end_time = now_as_float() - use_arrow_data = store_results and results_backend_use_msgpack + use_arrow_data = store_results and cast(bool, results_backend_use_msgpack) data, selected_columns, all_columns, expanded_columns = _serialize_and_expand_data( result_set, db_engine_spec, use_arrow_data, expand_data ) @@ -432,7 +442,7 @@ def execute_sql_statements( "sqllab.query.results_backend_write_serialization", stats_logger ): serialized_payload = _serialize_payload( - payload, results_backend_use_msgpack + payload, cast(bool, results_backend_use_msgpack) ) cache_timeout = database.cache_timeout if cache_timeout is None: diff --git a/superset/sql_parse.py b/superset/sql_parse.py index be9cf10..3e50386 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -158,7 +158,7 @@ class ParsedQuery: def _is_identifier(token: Token) -> bool: return isinstance(token, (IdentifierList, Identifier)) - def _process_tokenlist(self, token_list: TokenList): + def _process_tokenlist(self, token_list: TokenList) -> None: """ Add table names to table set @@ -204,7 +204,9 @@ class ParsedQuery: exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}" return exec_sql - def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches + def _extract_from_token( # pylint: disable=too-many-branches + self, token: Token + ) -> None: """ Populate self._tables from token diff --git a/superset/stats_logger.py b/superset/stats_logger.py index 37fe3d3..75cfd8a 100644 --- a/superset/stats_logger.py +++ b/superset/stats_logger.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
diff --git a/superset/stats_logger.py b/superset/stats_logger.py
index 37fe3d3..75cfd8a 100644
--- a/superset/stats_logger.py
+++ b/superset/stats_logger.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import logging
+from typing import Optional
 
 from colorama import Fore, Style
 
@@ -40,7 +41,7 @@ class BaseStatsLogger:
         """Decrement a counter"""
         raise NotImplementedError()
 
-    def timing(self, key, value: float) -> None:
+    def timing(self, key: str, value: float) -> None:
         raise NotImplementedError()
 
     def gauge(self, key: str) -> None:
@@ -49,18 +50,18 @@ class BaseStatsLogger:
 
 
 class DummyStatsLogger(BaseStatsLogger):
-    def incr(self, key):
+    def incr(self, key: str) -> None:
         logger.debug(Fore.CYAN + "[stats_logger] (incr) " + key + Style.RESET_ALL)
 
-    def decr(self, key):
+    def decr(self, key: str) -> None:
         logger.debug((Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL))
 
-    def timing(self, key, value):
+    def timing(self, key: str, value: float) -> None:
         logger.debug(
             (Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL)
         )
 
-    def gauge(self, key):
+    def gauge(self, key: str) -> None:
         logger.debug(
             (Fore.CYAN + "[stats_logger] (gauge) " + f"{key}" + Style.RESET_ALL)
         )
@@ -71,8 +72,12 @@ try:
 
     class StatsdStatsLogger(BaseStatsLogger):
         def __init__(  # pylint: disable=super-init-not-called
-            self, host="localhost", port=8125, prefix="superset", statsd_client=None
-        ):
+            self,
+            host: str = "localhost",
+            port: int = 8125,
+            prefix: str = "superset",
+            statsd_client: Optional[StatsClient] = None,
+        ) -> None:
             """
             Initializes from either params or a supplied, pre-constructed statsd
             client.
@@ -84,16 +89,16 @@ try:
             else:
                 self.client = StatsClient(host=host, port=port, prefix=prefix)
 
-        def incr(self, key):
+        def incr(self, key: str) -> None:
             self.client.incr(key)
 
-        def decr(self, key):
+        def decr(self, key: str) -> None:
             self.client.decr(key)
 
-        def timing(self, key, value):
+        def timing(self, key: str, value: float) -> None:
             self.client.timing(key, value)
 
-        def gauge(self, key):
+        def gauge(self, key: str) -> None:
             # pylint: disable=no-value-for-parameter
             self.client.gauge(key)
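StatsdStatsLogger above types its optional pre-built client as Optional[StatsClient] because the parameter defaults to None. Roughly the same shape, with a stand-in Client class instead of the real statsd dependency:

    from typing import Optional

    class Client:
        def incr(self, key: str) -> None:
            print(f"incr {key}")

    class StatsLogger:
        def __init__(self, client: Optional[Client] = None) -> None:
            # Fall back to a default client when none is supplied.
            self.client = client if client is not None else Client()

        def incr(self, key: str) -> None:
            self.client.incr(key)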
Base = Union[bytes, str] diff --git a/superset/viz.py b/superset/viz.py index bf3c110..2e38bf2 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -31,7 +31,7 @@ import uuid from collections import defaultdict, OrderedDict from datetime import datetime, timedelta from itertools import product -from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING +from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union import dataclasses import geohash @@ -55,7 +55,7 @@ from superset.exceptions import ( SpatialException, ) from superset.models.helpers import QueryResult -from superset.typing import VizData +from superset.typing import QueryObjectDict, VizData, VizPayload from superset.utils import core as utils from superset.utils.core import ( DTTM_ALIAS, @@ -101,7 +101,7 @@ class BaseViz: datasource: "BaseDatasource", form_data: Dict[str, Any], force: bool = False, - ): + ) -> None: if not datasource: raise Exception(_("Viz is missing a datasource")) @@ -134,7 +134,7 @@ class BaseViz: self.process_metrics() - def process_metrics(self): + def process_metrics(self) -> None: # metrics in TableViz is order sensitive, so metric_dict should be # OrderedDict self.metric_dict = OrderedDict() @@ -153,8 +153,10 @@ class BaseViz: self.metric_labels = list(self.metric_dict.keys()) @staticmethod - def handle_js_int_overflow(data): - for d in data.get("records", dict()): + def handle_js_int_overflow( + data: Dict[str, List[Dict[str, Any]]] + ) -> Dict[str, List[Dict[str, Any]]]: + for d in data.get("records", {}): for k, v in list(d.items()): if isinstance(v, int): # if an int is too big for Java Script to handle @@ -163,7 +165,7 @@ class BaseViz: d[k] = str(v) return data - def run_extra_queries(self): + def run_extra_queries(self) -> None: """Lifecycle method to use when more than one query is needed In rare-ish cases, a visualization may need to execute multiple @@ -186,7 +188,7 @@ class BaseViz: """ pass - def apply_rolling(self, df): + def apply_rolling(self, df: pd.DataFrame) -> pd.DataFrame: fd = self.form_data rolling_type = fd.get("rolling_type") rolling_periods = int(fd.get("rolling_periods") or 0) @@ -206,7 +208,7 @@ class BaseViz: df = df[min_periods:] return df - def get_samples(self): + def get_samples(self) -> List[Dict[str, Any]]: query_obj = self.query_obj() query_obj.update( { @@ -219,7 +221,7 @@ class BaseViz: df = self.get_df(query_obj) return df.to_dict(orient="records") - def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: """Returns a pandas dataframe based on the query object""" if not query_obj: query_obj = self.query_obj() @@ -281,19 +283,19 @@ class BaseViz: df.replace([np.inf, -np.inf], np.nan, inplace=True) return df - def df_metrics_to_num(self, df): + def df_metrics_to_num(self, df: pd.DataFrame) -> None: """Converting metrics to numeric when pandas.read_sql cannot""" metrics = self.metric_labels for col, dtype in df.dtypes.items(): if dtype.type == np.object_ and col in metrics: df[col] = pd.to_numeric(df[col], errors="coerce") - def process_query_filters(self): + def process_query_filters(self) -> None: utils.convert_legacy_filters_into_adhoc(self.form_data) merge_extra_filters(self.form_data) utils.split_adhoc_filters_into_base_filters(self.form_data) - def query_obj(self) -> Dict[str, Any]: + def query_obj(self) -> QueryObjectDict: """Building a query object""" form_data = self.form_data self.process_query_filters() @@ -362,9 
+364,9 @@ class BaseViz: return d @property - def cache_timeout(self): + def cache_timeout(self) -> int: if self.form_data.get("cache_timeout") is not None: - return int(self.form_data.get("cache_timeout")) + return int(self.form_data["cache_timeout"]) if self.datasource.cache_timeout is not None: return self.datasource.cache_timeout if ( @@ -374,12 +376,12 @@ class BaseViz: return self.datasource.database.cache_timeout return config["CACHE_DEFAULT_TIMEOUT"] - def get_json(self): + def get_json(self) -> str: return json.dumps( self.get_payload(), default=utils.json_int_dttm_ser, ignore_nan=True ) - def cache_key(self, query_obj, **extra): + def cache_key(self, query_obj: QueryObjectDict, **extra: Any) -> str: """ The cache key is made out of the key/values in `query_obj`, plus any other key/values in `extra`. @@ -410,7 +412,7 @@ class BaseViz: json_data = self.json_dumps(cache_dict, sort_keys=True) return hashlib.md5(json_data.encode("utf-8")).hexdigest() - def get_payload(self, query_obj=None): + def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload: """Returns a payload of metadata and data""" self.run_extra_queries() payload = self.get_df_payload(query_obj) @@ -422,7 +424,9 @@ class BaseViz: del payload["df"] return payload - def get_df_payload(self, query_obj=None, **kwargs): + def get_df_payload( + self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any + ) -> Dict[str, Any]: """Handles caching around the df payload retrieval""" if not query_obj: query_obj = self.query_obj() @@ -512,21 +516,21 @@ class BaseViz: "rowcount": len(df.index) if df is not None else 0, } - def json_dumps(self, obj, sort_keys=False): + def json_dumps(self, obj: Any, sort_keys: bool = False) -> str: return json.dumps( obj, default=utils.json_int_dttm_ser, ignore_nan=True, sort_keys=sort_keys ) - def payload_json_and_has_error(self, payload): + def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]: has_error = ( payload.get("status") == utils.QueryStatus.FAILED or payload.get("error") is not None - or len(payload.get("errors")) > 0 + or len(payload.get("errors") or []) > 0 ) return self.json_dumps(payload), has_error @property - def data(self): + def data(self) -> Dict[str, Any]: """This is the data object serialized to the js layer""" content = { "form_data": self.form_data, @@ -536,7 +540,7 @@ class BaseViz: } return content - def get_csv(self): + def get_csv(self) -> Optional[str]: df = self.get_df() include_index = not isinstance(df.index, pd.RangeIndex) return df.to_csv(index=include_index, **config["CSV_EXPORT"]) @@ -545,7 +549,7 @@ class BaseViz: return df.to_dict(orient="records") @property - def json_data(self): + def json_data(self) -> str: return json.dumps(self.data) @@ -559,7 +563,7 @@ class TableViz(BaseViz): is_timeseries = False enforce_numerical_metrics = False - def should_be_timeseries(self): + def should_be_timeseries(self) -> bool: fd = self.form_data # TODO handle datasource-type-specific code in datasource conditions_met = (fd.get("granularity") and fd.get("granularity") != "all") or ( @@ -569,9 +573,9 @@ class TableViz(BaseViz): raise QueryObjectValidationError( _("Pick a granularity in the Time section or " "uncheck 'Include Time'") ) - return fd.get("include_time") + return bool(fd.get("include_time")) - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() fd = self.form_data @@ -660,7 +664,7 @@ class TableViz(BaseViz): return data - def json_dumps(self, obj, sort_keys=False): + def 
json_dumps(self, obj: Any, sort_keys: bool = False) -> str: return json.dumps( obj, default=utils.json_iso_dttm_ser, sort_keys=sort_keys, ignore_nan=True ) @@ -675,14 +679,14 @@ class TimeTableViz(BaseViz): credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original' is_timeseries = True - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() fd = self.form_data if not fd.get("metrics"): raise QueryObjectValidationError(_("Pick at least one metric")) - if fd.get("groupby") and len(fd.get("metrics")) > 1: + if fd.get("groupby") and len(fd["metrics"]) > 1: raise QueryObjectValidationError( _("When using 'Group By' you are limited to use a single metric") ) @@ -694,7 +698,7 @@ class TimeTableViz(BaseViz): fd = self.form_data columns = None - values = self.metric_labels + values: Union[List[str], str] = self.metric_labels if fd.get("groupby"): values = self.metric_labels[0] columns = fd.get("groupby") @@ -717,7 +721,7 @@ class PivotTableViz(BaseViz): credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original' is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() groupby = self.form_data.get("groupby") columns = self.form_data.get("columns") @@ -798,10 +802,10 @@ class MarkupViz(BaseViz): verbose_name = _("Markup") is_timeseries = False - def query_obj(self): - return None + def query_obj(self) -> QueryObjectDict: + return {} - def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: return pd.DataFrame() def get_data(self, df: pd.DataFrame) -> VizData: @@ -832,7 +836,7 @@ class WordCloudViz(BaseViz): verbose_name = _("Word Cloud") is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() d["groupby"] = [self.form_data.get("series")] return d @@ -847,7 +851,7 @@ class TreemapViz(BaseViz): credits = '<a href="https://d3js.org">d3.js</a>' is_timeseries = False - def _nest(self, metric, df): + def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]: nlevels = df.index.nlevels if nlevels == 1: result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])] @@ -927,7 +931,7 @@ class CalHeatmapViz(BaseViz): "range": range_, } - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() fd = self.form_data d["metrics"] = fd.get("metrics") @@ -953,19 +957,21 @@ class BoxPlotViz(NVD3Viz): sort_series = False is_timeseries = True - def to_series(self, df, classed="", title_suffix=""): + def to_series( + self, df: pd.DataFrame, classed: str = "", title_suffix: str = "" + ) -> List[Dict[str, Any]]: label_sep = " - " chart_data = [] for index_value, row in zip(df.index, df.to_dict(orient="records")): if isinstance(index_value, tuple): index_value = label_sep.join(index_value) - boxes = defaultdict(dict) + boxes: Dict[str, Dict[str, Any]] = defaultdict(dict) for (label, key), value in row.items(): if key == "nanmedian": key = "Q2" boxes[label][key] = value for label, box in boxes.items(): - if len(self.form_data.get("metrics")) > 1: + if len(self.form_data["metrics"]) > 1: # need to render data labels with metrics chart_label = label_sep.join([index_value, label]) else: @@ -980,46 +986,45 @@ class BoxPlotViz(NVD3Viz): form_data = self.form_data # conform to NVD3 names - def Q1(series): # need to be named functions - can't use lambdas + def Q1(series: pd.Series) -> float: + # 
need to be named functions - can't use lambdas return np.nanpercentile(series, 25) - def Q3(series): + def Q3(series: pd.Series) -> float: return np.nanpercentile(series, 75) whisker_type = form_data.get("whisker_options") if whisker_type == "Tukey": - def whisker_high(series): + def whisker_high(series: pd.Series) -> float: upper_outer_lim = Q3(series) + 1.5 * (Q3(series) - Q1(series)) return series[series <= upper_outer_lim].max() - def whisker_low(series): + def whisker_low(series: pd.Series) -> float: lower_outer_lim = Q1(series) - 1.5 * (Q3(series) - Q1(series)) return series[series >= lower_outer_lim].min() elif whisker_type == "Min/max (no outliers)": - def whisker_high(series): + def whisker_high(series: pd.Series) -> float: return series.max() - def whisker_low(series): + def whisker_low(series: pd.Series) -> float: return series.min() elif " percentiles" in whisker_type: # type: ignore - low, high = whisker_type.replace(" percentiles", "").split( # type: ignore - "/" - ) + low, high = cast(str, whisker_type).replace(" percentiles", "").split("/") - def whisker_high(series): + def whisker_high(series: pd.Series) -> float: return np.nanpercentile(series, int(high)) - def whisker_low(series): + def whisker_low(series: pd.Series) -> float: return np.nanpercentile(series, int(low)) else: raise ValueError("Unknown whisker type: {}".format(whisker_type)) - def outliers(series): + def outliers(series: pd.Series) -> Set[float]: above = series[series > whisker_high(series)] below = series[series < whisker_low(series)] # pandas sometimes doesn't like getting lists back here @@ -1039,7 +1044,7 @@ class BubbleViz(NVD3Viz): verbose_name = _("Bubble Chart") is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: form_data = self.form_data d = super().query_obj() d["groupby"] = [form_data.get("entity")] @@ -1090,7 +1095,7 @@ class BulletViz(NVD3Viz): verbose_name = _("Bullet Chart") is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: form_data = self.form_data d = super().query_obj() self.metric = form_data["metric"] @@ -1117,7 +1122,7 @@ class BigNumberViz(BaseViz): credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original' is_timeseries = True - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() metric = self.form_data.get("metric") if not metric: @@ -1151,7 +1156,7 @@ class BigNumberTotalViz(BaseViz): credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original' is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() metric = self.form_data.get("metric") if not metric: @@ -1174,7 +1179,9 @@ class NVD3TimeSeriesViz(NVD3Viz): is_timeseries = True pivot_fill_value: Optional[int] = None - def to_series(self, df, classed="", title_suffix=""): + def to_series( + self, df: pd.DataFrame, classed: str = "", title_suffix: str = "" + ) -> List[Dict[str, Any]]: cols = [] for col in df.columns: if col == "": @@ -1191,6 +1198,7 @@ class NVD3TimeSeriesViz(NVD3Viz): ys = series[name] if df[name].dtype.kind not in "biufc": continue + series_title: Union[List[str], str, Tuple[str, ...]] if isinstance(name, list): series_title = [str(title) for title in name] elif isinstance(name, tuple): @@ -1207,7 +1215,9 @@ class NVD3TimeSeriesViz(NVD3Viz): if title_suffix: if isinstance(series_title, str): series_title = (series_title, title_suffix) - elif isinstance(series_title, (list, tuple)): + elif isinstance(series_title, 
list): + series_title = series_title + [title_suffix] + elif isinstance(series_title, tuple): series_title = series_title + (title_suffix,) values = [] @@ -1274,7 +1284,7 @@ class NVD3TimeSeriesViz(NVD3Viz): return df - def run_extra_queries(self): + def run_extra_queries(self) -> None: fd = self.form_data time_compare = fd.get("time_compare") or [] @@ -1364,8 +1374,8 @@ class MultiLineViz(NVD3Viz): is_timeseries = True - def query_obj(self): - return None + def query_obj(self) -> QueryObjectDict: + return {} def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data @@ -1394,7 +1404,7 @@ class NVD3DualLineViz(NVD3Viz): sort_series = False is_timeseries = True - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() m1 = self.form_data.get("metric") m2 = self.form_data.get("metric_2") @@ -1409,7 +1419,7 @@ class NVD3DualLineViz(NVD3Viz): ) return d - def to_series(self, df, classed=""): + def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]: cols = [] for col in df.columns: if col == "": @@ -1421,7 +1431,7 @@ class NVD3DualLineViz(NVD3Viz): df.columns = cols series = df.to_dict("series") chart_data = [] - metrics = [self.form_data.get("metric"), self.form_data.get("metric_2")] + metrics = [self.form_data["metric"], self.form_data["metric_2"]] for i, m in enumerate(metrics): m = utils.get_metric_name(m) ys = series[m] @@ -1476,7 +1486,7 @@ class NVD3TimePivotViz(NVD3TimeSeriesViz): sort_series = True verbose_name = _("Time Series - Period Pivot") - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() d["metrics"] = [self.form_data.get("metric")] return d @@ -1561,7 +1571,7 @@ class HistogramViz(BaseViz): verbose_name = _("Histogram") is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: """Returns the query object for this visualization""" d = super().query_obj() d["row_limit"] = self.form_data.get("row_limit", int(config["VIZ_ROW_LIMIT"])) @@ -1576,9 +1586,9 @@ class HistogramViz(BaseViz): d["groupby"] = [] return d - def labelify(self, keys, column): + def labelify(self, keys: Union[List[str], str], column: str) -> str: if isinstance(keys, str): - keys = (keys,) + keys = [keys] # removing undesirable characters labels = [re.sub(r"\W+", r"_", k) for k in keys] if len(self.columns) > 1 or not self.groupby: @@ -1617,7 +1627,7 @@ class DistributionBarViz(DistributionPieViz): verbose_name = _("Distribution - Bar Chart") is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() fd = self.form_data if len(d["groupby"]) < len(fd.get("groupby") or []) + len( @@ -1708,7 +1718,7 @@ class SunburstViz(BaseViz): df = df[cols] return df.to_numpy().tolist() - def query_obj(self): + def query_obj(self) -> QueryObjectDict: qry = super().query_obj() fd = self.form_data qry["metrics"] = [fd["metric"]] @@ -1727,7 +1737,7 @@ class SankeyViz(BaseViz): is_timeseries = False credits = '<a href="https://www.npmjs.com/package/d3-sankey">d3-sankey on npm</a>' - def query_obj(self): + def query_obj(self) -> QueryObjectDict: qry = super().query_obj() if len(qry["groupby"]) != 2: raise QueryObjectValidationError( @@ -1746,21 +1756,23 @@ class SankeyViz(BaseViz): for row in recs: hierarchy[row["source"]].add(row["target"]) - def find_cycle(g): + def find_cycle(g: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]: """Whether there's a cycle in a directed graph""" path = set() - def visit(vertex): + def visit(vertex: str) -> 
Optional[Tuple[str, str]]: path.add(vertex) for neighbour in g.get(vertex, ()): if neighbour in path or visit(neighbour): return (vertex, neighbour) path.remove(vertex) + return None for v in g: cycle = visit(v) if cycle: return cycle + return None cycle = find_cycle(hierarchy) if cycle: @@ -1782,7 +1794,7 @@ class DirectedForceViz(BaseViz): credits = 'd3noob @<a href="http://bl.ocks.org/d3noob/5141278">bl.ocks.org</a>' is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: qry = super().query_obj() if len(self.form_data["groupby"]) != 2: raise QueryObjectValidationError(_("Pick exactly 2 columns to 'Group By'")) @@ -1803,7 +1815,7 @@ class ChordViz(BaseViz): credits = '<a href="https://github.com/d3/d3-chord">Bostock</a>' is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: qry = super().query_obj() fd = self.form_data qry["groupby"] = [fd.get("groupby"), fd.get("columns")] @@ -1836,7 +1848,7 @@ class CountryMapViz(BaseViz): is_timeseries = False credits = "From bl.ocks.org By john-guerra" - def query_obj(self): + def query_obj(self) -> QueryObjectDict: qry = super().query_obj() qry["metrics"] = [self.form_data["metric"]] qry["groupby"] = [self.form_data["entity"]] @@ -1863,7 +1875,7 @@ class WorldMapViz(BaseViz): is_timeseries = False credits = 'datamaps on <a href="https://www.npmjs.com/package/datamaps">npm</a>' - def query_obj(self): + def query_obj(self) -> QueryObjectDict: qry = super().query_obj() qry["groupby"] = [self.form_data["entity"]] return qry @@ -1923,10 +1935,10 @@ class FilterBoxViz(BaseViz): cache_type = "get_data" filter_row_limit = 1000 - def query_obj(self): - return None + def query_obj(self) -> QueryObjectDict: + return {} - def run_extra_queries(self): + def run_extra_queries(self) -> None: qry = super().query_obj() filters = self.form_data.get("filter_configs") or [] qry["row_limit"] = self.filter_row_limit @@ -1979,10 +1991,10 @@ class IFrameViz(BaseViz): credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original' is_timeseries = False - def query_obj(self): - return None + def query_obj(self) -> QueryObjectDict: + return {} - def get_df(self, query_obj: Optional[Dict[str, Any]] = None) -> pd.DataFrame: + def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: return pd.DataFrame() def get_data(self, df: pd.DataFrame) -> VizData: @@ -2005,7 +2017,7 @@ class ParallelCoordinatesViz(BaseViz): ) is_timeseries = False - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() fd = self.form_data d["groupby"] = [fd.get("series")] @@ -2027,7 +2039,7 @@ class HeatmapViz(BaseViz): "bl.ocks.org</a>" ) - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() fd = self.form_data d["metrics"] = [fd.get("metric")] @@ -2092,7 +2104,7 @@ class MapboxViz(BaseViz): is_timeseries = False credits = "<a href=https://www.mapbox.com/mapbox-gl-js/api/>Mapbox GL JS</a>" - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() fd = self.form_data label_col = fd.get("mapbox_label") @@ -2124,22 +2136,24 @@ class MapboxViz(BaseViz): label_col and len(label_col) >= 1 and label_col[0] != "count" - and label_col[0] not in fd.get("groupby") + and label_col[0] not in fd["groupby"] ): raise QueryObjectValidationError( _("Choice of [Label] must be present in [Group By]") ) - if fd.get("point_radius") != "Auto" and fd.get( - "point_radius" - ) not in fd.get("groupby"): + if ( + 
fd.get("point_radius") != "Auto" + and fd.get("point_radius") not in fd["groupby"] + ): raise QueryObjectValidationError( _("Choice of [Point Radius] must be present in [Group By]") ) - if fd.get("all_columns_x") not in fd.get("groupby") or fd.get( - "all_columns_y" - ) not in fd.get("groupby"): + if ( + fd.get("all_columns_x") not in fd["groupby"] + or fd.get("all_columns_y") not in fd["groupby"] + ): raise QueryObjectValidationError( _( "[Longitude] and [Latitude] columns must be present in " @@ -2226,8 +2240,8 @@ class DeckGLMultiLayer(BaseViz): is_timeseries = False credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>' - def query_obj(self): - return None + def query_obj(self) -> QueryObjectDict: + return {} def get_data(self, df: pd.DataFrame) -> VizData: fd = self.form_data @@ -2251,14 +2265,14 @@ class BaseDeckGLViz(BaseViz): credits = '<a href="https://uber.github.io/deck.gl/">deck.gl</a>' spatial_control_keys: List[str] = [] - def get_metrics(self): + def get_metrics(self) -> List[str]: self.metric = self.form_data.get("size") return [self.metric] if self.metric else [] - def process_spatial_query_obj(self, key, group_by): + def process_spatial_query_obj(self, key: str, group_by: List[str]) -> None: group_by.extend(self.get_spatial_columns(key)) - def get_spatial_columns(self, key): + def get_spatial_columns(self, key: str) -> List[str]: spatial = self.form_data.get(key) if spatial is None: raise ValueError(_("Bad spatial key")) @@ -2269,9 +2283,10 @@ class BaseDeckGLViz(BaseViz): return [spatial.get("lonlatCol")] elif spatial.get("type") == "geohash": return [spatial.get("geohashCol")] + return [] @staticmethod - def parse_coordinates(s): + def parse_coordinates(s: Any) -> Optional[Tuple[float, float]]: if not s: return None try: @@ -2281,15 +2296,15 @@ class BaseDeckGLViz(BaseViz): raise SpatialException(_("Invalid spatial point encountered: %s" % s)) @staticmethod - def reverse_geohash_decode(geohash_code): + def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]: lat, lng = geohash.decode(geohash_code) return (lng, lat) @staticmethod - def reverse_latlong(df, key): + def reverse_latlong(df: pd.DataFrame, key: str) -> None: df[key] = [tuple(reversed(o)) for o in df[key] if isinstance(o, (list, tuple))] - def process_spatial_data_obj(self, key, df): + def process_spatial_data_obj(self, key: str, df: pd.DataFrame) -> pd.DataFrame: spatial = self.form_data.get(key) if spatial is None: raise ValueError(_("Bad spatial key")) @@ -2321,7 +2336,7 @@ class BaseDeckGLViz(BaseViz): ) return df - def add_null_filters(self): + def add_null_filters(self) -> None: fd = self.form_data spatial_columns = set() for key in self.spatial_control_keys: @@ -2339,7 +2354,7 @@ class BaseDeckGLViz(BaseViz): filter_ = to_adhoc({"col": column, "op": "IS NOT NULL", "val": ""}) fd["adhoc_filters"].append(filter_) - def query_obj(self): + def query_obj(self) -> QueryObjectDict: fd = self.form_data # add NULL filters @@ -2347,16 +2362,16 @@ class BaseDeckGLViz(BaseViz): self.add_null_filters() d = super().query_obj() - gb = [] + gb: List[str] = [] for key in self.spatial_control_keys: self.process_spatial_query_obj(key, gb) if fd.get("dimension"): - gb += [fd.get("dimension")] + gb += [fd["dimension"]] if fd.get("js_columns"): - gb += fd.get("js_columns") + gb += fd.get("js_columns") or [] metrics = self.get_metrics() gb = list(set(gb)) if metrics: @@ -2367,7 +2382,7 @@ class BaseDeckGLViz(BaseViz): d["columns"] = gb return d - def get_js_columns(self, d): + def get_js_columns(self, d: 
@@ -2367,7 +2382,7 @@ class BaseDeckGLViz(BaseViz):
         d["columns"] = gb
         return d
 
-    def get_js_columns(self, d):
+    def get_js_columns(self, d: Dict[str, Any]) -> Dict[str, Any]:
         cols = self.form_data.get("js_columns") or []
         return {col: d.get(col) for col in cols}
@@ -2393,7 +2408,7 @@ class BaseDeckGLViz(BaseViz):
             "metricLabels": self.metric_labels,
         }
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         raise NotImplementedError()
@@ -2406,7 +2421,7 @@ class DeckScatterViz(BaseDeckGLViz):
     spatial_control_keys = ["spatial"]
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
         self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
         self.point_radius_fixed = fd.get("point_radius_fixed") or {
@@ -2415,19 +2430,21 @@ class DeckScatterViz(BaseDeckGLViz):
         }
         return super().query_obj()
 
-    def get_metrics(self):
+    def get_metrics(self) -> List[str]:
         self.metric = None
         if self.point_radius_fixed.get("type") == "metric":
-            self.metric = self.point_radius_fixed.get("value")
+            self.metric = self.point_radius_fixed["value"]
             return [self.metric]
-        return None
+        return []
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         return {
-            "metric": d.get(self.metric_label),
+            "metric": d.get(self.metric_label) if self.metric_label else None,
             "radius": self.fixed_value
             if self.fixed_value
-            else d.get(self.metric_label),
+            else d.get(self.metric_label)
+            if self.metric_label
+            else None,
             "cat_color": d.get(self.dim) if self.dim else None,
             "position": d.get("spatial"),
             DTTM_ALIAS: d.get(DTTM_ALIAS),
@@ -2453,20 +2470,20 @@ class DeckScreengrid(BaseDeckGLViz):
     spatial_control_keys = ["spatial"]
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         fd = self.form_data
-        self.is_timeseries = fd.get("time_grain_sqla") or fd.get("granularity")
+        self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity"))
         return super().query_obj()
 
-    def get_properties(self, d):
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "position": d.get("spatial"),
-            "weight": d.get(self.metric_label) or 1,
+            "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
             "__timestamp": d.get(DTTM_ALIAS) or d.get("__time"),
         }
 
     def get_data(self, df: pd.DataFrame) -> VizData:
-        self.metric_label = utils.get_metric_name(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
         return super().get_data(df)
@@ -2478,15 +2495,18 @@ class DeckGrid(BaseDeckGLViz):
     verbose_name = _("Deck.gl - 3D Grid")
     spatial_control_keys = ["spatial"]
 
-    def get_properties(self, d):
-        return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1}
+    def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]:
+        return {
+            "position": d.get("spatial"),
+            "weight": (d.get(self.metric_label) if self.metric_label else None) or 1,
+        }
 
     def get_data(self, df: pd.DataFrame) -> VizData:
-        self.metric_label = utils.get_metric_name(self.metric)
+        self.metric_label = utils.get_metric_name(self.metric) if self.metric else None
         return super().get_data(df)
 
 
-def geohash_to_json(geohash_code):
+def geohash_to_json(geohash_code: str) -> List[List[float]]:
     p = geohash.bbox(geohash_code)
     return [
         [p.get("w"), p.get("n")],
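For context on the repeated `(d.get(self.metric_label) if self.metric_label else None) or 1` pattern above: `metric_label` is now computed only when `self.metric` is set, so it may be `None`, and passing it straight to `dict.get` no longer type-checks against the `str` key type. A hypothetical standalone version of the guard:

    from typing import Any, Dict, Optional

    def weight(d: Dict[str, Any], metric_label: Optional[str]) -> Any:
        # `d.get(metric_label)` alone fails checking once the label may be None;
        # the guard keeps the old `or 1` fallback intact
        return (d.get(metric_label) if metric_label else None) or 1

    print(weight({"count_metric": 7.0}, "count_metric"))  # 7.0
    print(weight({"count_metric": 7.0}, None))            # 1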
fd.get("granularity")) d = super().query_obj() self.metric = fd.get("metric") line_col = fd.get("line_column") @@ -2525,11 +2545,11 @@ class DeckPathViz(BaseDeckGLViz): d["columns"].append(line_col) return d - def get_properties(self, d): + def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: fd = self.form_data - line_type = fd.get("line_type") + line_type = fd["line_type"] deser = self.deser_map[line_type] - line_column = fd.get("line_column") + line_column = fd["line_column"] path = deser(d[line_column]) if fd.get("reverse_long_lat"): path = [(o[1], o[0]) for o in path] @@ -2540,7 +2560,7 @@ class DeckPathViz(BaseDeckGLViz): return d def get_data(self, df: pd.DataFrame) -> VizData: - self.metric_label = utils.get_metric_name(self.metric) + self.metric_label = utils.get_metric_name(self.metric) if self.metric else None return super().get_data(df) @@ -2552,18 +2572,18 @@ class DeckPolygon(DeckPathViz): deck_viz_key = "polygon" verbose_name = _("Deck.gl - Polygon") - def query_obj(self): + def query_obj(self) -> QueryObjectDict: fd = self.form_data self.elevation = fd.get("point_radius_fixed") or {"type": "fix", "value": 500} return super().query_obj() - def get_metrics(self): + def get_metrics(self) -> List[str]: metrics = [self.form_data.get("metric")] if self.elevation.get("type") == "metric": metrics.append(self.elevation.get("value")) return [metric for metric in metrics if metric] - def get_properties(self, d): + def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: super().get_properties(d) fd = self.form_data elevation = fd["point_radius_fixed"]["value"] @@ -2582,11 +2602,14 @@ class DeckHex(BaseDeckGLViz): verbose_name = _("Deck.gl - 3D HEX") spatial_control_keys = ["spatial"] - def get_properties(self, d): - return {"position": d.get("spatial"), "weight": d.get(self.metric_label) or 1} + def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + return { + "position": d.get("spatial"), + "weight": (d.get(self.metric_label) if self.metric_label else None) or 1, + } def get_data(self, df: pd.DataFrame) -> VizData: - self.metric_label = utils.get_metric_name(self.metric) + self.metric_label = utils.get_metric_name(self.metric) if self.metric else None return super(DeckHex, self).get_data(df) @@ -2597,15 +2620,15 @@ class DeckGeoJson(BaseDeckGLViz): viz_type = "deck_geojson" verbose_name = _("Deck.gl - GeoJSON") - def query_obj(self): + def query_obj(self) -> QueryObjectDict: d = super().query_obj() d["columns"] += [self.form_data.get("geojson")] d["metrics"] = [] d["groupby"] = [] return d - def get_properties(self, d): - geojson = d.get(self.form_data.get("geojson")) + def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: + geojson = d[self.form_data["geojson"]] return json.loads(geojson) @@ -2618,12 +2641,12 @@ class DeckArc(BaseDeckGLViz): spatial_control_keys = ["start_spatial", "end_spatial"] is_timeseries = True - def query_obj(self): + def query_obj(self) -> QueryObjectDict: fd = self.form_data self.is_timeseries = bool(fd.get("time_grain_sqla") or fd.get("granularity")) return super().query_obj() - def get_properties(self, d): + def get_properties(self, d: Dict[str, Any]) -> Dict[str, Any]: dim = self.form_data.get("dimension") return { "sourcePosition": d.get("start_spatial"), @@ -2653,15 +2676,15 @@ class EventFlowViz(BaseViz): credits = 'from <a href="https://github.com/williaster/data-ui">@data-ui</a>' is_timeseries = True - def query_obj(self): + def query_obj(self) -> QueryObjectDict: query = super().query_obj() form_data = 
@@ -2653,15 +2676,15 @@ class EventFlowViz(BaseViz):
     credits = 'from <a href="https://github.com/williaster/data-ui">@data-ui</a>'
     is_timeseries = True
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         query = super().query_obj()
         form_data = self.form_data
 
-        event_key = form_data.get("all_columns_x")
-        entity_key = form_data.get("entity")
+        event_key = form_data["all_columns_x"]
+        entity_key = form_data["entity"]
         meta_keys = [
             col
-            for col in form_data.get("all_columns")
+            for col in form_data["all_columns"]
             if col != event_key and col != entity_key
         ]
@@ -2773,14 +2796,16 @@ class PartitionViz(NVD3TimeSeriesViz):
     viz_type = "partition"
     verbose_name = _("Partition Diagram")
 
-    def query_obj(self):
+    def query_obj(self) -> QueryObjectDict:
         query_obj = super().query_obj()
         time_op = self.form_data.get("time_series_option", "not_time")
         # Return time series data if the user specifies so
         query_obj["is_timeseries"] = time_op != "not_time"
         return query_obj
 
-    def levels_for(self, time_op, groups, df):
+    def levels_for(
+        self, time_op: str, groups: List[str], df: pd.DataFrame
+    ) -> Dict[int, pd.Series]:
         """
         Compute the partition at each `level` from the dataframe.
         """
@@ -2794,7 +2819,9 @@ class PartitionViz(NVD3TimeSeriesViz):
         )
         return levels
 
-    def levels_for_diff(self, time_op, groups, df):
+    def levels_for_diff(
+        self, time_op: str, groups: List[str], df: pd.DataFrame
+    ) -> Dict[int, pd.DataFrame]:
         # Obtain a unique list of the time grains
         times = list(set(df[DTTM_ALIAS]))
         times.sort()
@@ -2828,7 +2855,9 @@ class PartitionViz(NVD3TimeSeriesViz):
         )
         return levels
 
-    def levels_for_time(self, groups, df):
+    def levels_for_time(
+        self, groups: List[str], df: pd.DataFrame
+    ) -> Dict[int, VizData]:
         procs = {}
         for i in range(0, len(groups) + 1):
             self.form_data["groupby"] = groups[:i]
@@ -2837,11 +2866,19 @@ class PartitionViz(NVD3TimeSeriesViz):
         self.form_data["groupby"] = groups
         return procs
 
-    def nest_values(self, levels, level=0, metric=None, dims=()):
+    def nest_values(
+        self,
+        levels: Dict[int, pd.DataFrame],
+        level: int = 0,
+        metric: Optional[str] = None,
+        dims: Optional[List[str]] = None,
+    ) -> List[Dict[str, Any]]:
         """
         Nest values at each level on the back-end with access and setting,
         instead of summing from the bottom.
         """
+        if dims is None:
+            dims = []
         if not level:
             return [
                 {
@@ -2856,7 +2893,7 @@ class PartitionViz(NVD3TimeSeriesViz):
                 {
                     "name": i,
                     "val": levels[1][metric][i],
-                    "children": self.nest_values(levels, 2, metric, (i,)),
+                    "children": self.nest_values(levels, 2, metric, [i]),
                 }
                 for i in levels[1][metric].index
             ]
@@ -2866,12 +2903,20 @@ class PartitionViz(NVD3TimeSeriesViz):
             {
                 "name": i,
                 "val": levels[level][metric][dims][i],
-                "children": self.nest_values(levels, level + 1, metric, dims + (i,)),
+                "children": self.nest_values(levels, level + 1, metric, dims + [i]),
             }
             for i in levels[level][metric][dims].index
         ]
 
-    def nest_procs(self, procs, level=-1, dims=(), time=None):
+    def nest_procs(
+        self,
+        procs: Dict[int, pd.DataFrame],
+        level: int = -1,
+        dims: Optional[Tuple[str, ...]] = None,
+        time: Any = None,
+    ) -> List[Dict[str, Any]]:
+        if dims is None:
+            dims = ()
         if level == -1:
             return [
                 {"name": m, "children": self.nest_procs(procs, 0, (m,))}
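A note on the `nest_values` signature above: the shared tuple default `dims=()` becomes `dims: Optional[List[str]] = None` plus a `None` check, the usual idiom for giving each call its own fresh mutable default (`nest_procs` keeps a tuple but gets the same sentinel treatment). A reduced sketch with an invented function name:

    from typing import List, Optional

    def nest(level: int = 0, dims: Optional[List[str]] = None) -> List[str]:
        if dims is None:
            dims = []  # fresh list per call; a literal `dims=[]` default would be shared
        return dims + [f"level-{level}"]

    print(nest())               # ['level-0']
    print(nest(2, ["region"]))  # ['region', 'level-2']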
""" +# mypy: ignore-errors import copy import hashlib import inspect @@ -610,7 +611,7 @@ class TableViz(BaseViz): raise QueryObjectValidationError( _("Pick a granularity in the Time section or " "uncheck 'Include Time'") ) - return fd.get("include_time") + return bool(fd.get("include_time")) def query_obj(self): d = super().query_obj() diff --git a/tests/viz_tests.py b/tests/viz_tests.py index 748d50b..f8eb8ce 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -974,7 +974,7 @@ class BaseDeckGLVizTestCase(SupersetTestCase): test_viz_deckgl = viz.DeckScatterViz(datasource, form_data) test_viz_deckgl.point_radius_fixed = {} result = test_viz_deckgl.get_metrics() - assert result is None + assert result == [] def test_get_js_columns(self): form_data = load_fixture("deck_path_form_data.json")