This is an automated email from the ASF dual-hosted git repository. johnbodley pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push: new 91517a5 style(mypy): Spit-and-polish pass (#10001) 91517a5 is described below commit 91517a56a3bcacd4c8f2dca233d8df248bfca10e Author: John Bodley <4567245+john-bod...@users.noreply.github.com> AuthorDate: Sun Jun 7 08:53:46 2020 -0700 style(mypy): Spit-and-polish pass (#10001) Co-authored-by: John Bodley <john.bod...@airbnb.com> --- setup.cfg | 4 +- superset/app.py | 4 +- superset/charts/commands/create.py | 4 +- superset/charts/commands/update.py | 4 +- superset/common/query_object.py | 13 +++--- superset/config.py | 7 +-- superset/connectors/base/models.py | 13 +++--- superset/connectors/connector_registry.py | 5 ++- superset/connectors/druid/models.py | 51 ++++++++++------------ superset/connectors/sqla/models.py | 15 ++++--- superset/dao/base.py | 8 ++-- superset/dashboards/commands/create.py | 4 +- superset/dashboards/commands/update.py | 4 +- superset/datasets/commands/create.py | 4 +- superset/datasets/commands/update.py | 10 ++--- superset/datasets/dao.py | 12 ++--- superset/db_engine_specs/base.py | 12 ++--- superset/db_engine_specs/bigquery.py | 2 +- superset/db_engine_specs/exasol.py | 2 +- superset/db_engine_specs/hive.py | 6 +-- superset/db_engine_specs/mssql.py | 2 +- superset/db_engine_specs/postgres.py | 2 +- superset/db_engine_specs/presto.py | 19 ++++---- superset/extensions.py | 6 ++- superset/models/core.py | 9 ++-- superset/models/dashboard.py | 8 ++-- superset/models/helpers.py | 2 +- superset/models/slice.py | 2 +- superset/models/sql_types/presto_sql_types.py | 12 ++--- superset/queries/filters.py | 4 +- superset/result_set.py | 6 +-- superset/security/manager.py | 15 +++---- superset/sql_lab.py | 7 +-- superset/tasks/celery_app.py | 2 +- superset/tasks/schedules.py | 10 +++-- superset/utils/cache.py | 7 +-- superset/utils/core.py | 18 +++++--- .../utils/dashboard_filter_scopes_converter.py | 11 ++--- superset/utils/decorators.py | 4 +- superset/utils/import_datasource.py | 8 ++-- superset/utils/log.py | 4 +- superset/utils/logging_configurator.py | 2 +- superset/utils/pandas_postprocessing.py | 6 +-- superset/utils/screenshots.py | 14 +++--- superset/views/base.py | 14 +++--- superset/views/base_api.py | 2 +- superset/views/base_schemas.py | 4 +- superset/views/core.py | 14 +++--- superset/views/database/api.py | 10 ++--- superset/views/database/decorators.py | 2 +- superset/views/schedules.py | 8 ++-- superset/views/sql_lab.py | 4 +- superset/views/utils.py | 12 ++--- superset/viz.py | 4 +- tests/base_tests.py | 10 +++-- tests/superset_test_config_thumbnails.py | 2 +- 56 files changed, 243 insertions(+), 207 deletions(-) diff --git a/setup.cfg b/setup.cfg index 81c7ed2..93e33af 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,10 +50,12 @@ multi_line_output = 3 order_by_type = false [mypy] +disallow_any_generics = true ignore_missing_imports = true no_implicit_optional = true +warn_unused_ignores = true -[mypy-superset,superset.app,superset.bin.*,superset.charts.*,superset.cli,superset.commands.*,superset.common.*,superset.config,superset.connectors.*,superset.constants,superset.dataframe,superset.dao.*,superset.dashboards.*,superset.datasets.*,superset.db_engine_specs.*,superset.db_engines.*,superset.errors,superset.examples.*,superset.exceptions,superset.extensions,superset.forms,superset.jinja_context,superset.legacy,superset.migrations.*,superset.models.*,superset.result_set,superset [...] +[mypy-superset.*] check_untyped_defs = true disallow_untyped_calls = true disallow_untyped_defs = true diff --git a/superset/app.py b/superset/app.py index 18165ed..b36df75 100644 --- a/superset/app.py +++ b/superset/app.py @@ -80,7 +80,7 @@ class SupersetAppInitializer: self.flask_app = app self.config = app.config - self.manifest: dict = {} + self.manifest: Dict[Any, Any] = {} def pre_init(self) -> None: """ @@ -542,7 +542,7 @@ class SupersetAppInitializer: self.app = app def __call__( - self, environ: Dict[str, Any], start_response: Callable + self, environ: Dict[str, Any], start_response: Callable[..., Any] ) -> Any: # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore # content-length and read the stream till the end. diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py index 8e7dcb7..1425396 100644 --- a/superset/charts/commands/create.py +++ b/superset/charts/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) class CreateChartCommand(BaseCommand): - def __init__(self, user: User, data: Dict): + def __init__(self, user: User, data: Dict[str, Any]): self._actor = user self._properties = data.copy() diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index 21c236c..70055bf 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User @@ -42,7 +42,7 @@ logger = logging.getLogger(__name__) class UpdateChartCommand(BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict): + def __init__(self, user: User, model_id: int, data: Dict[str, Any]): self._actor = user self._model_id = model_id self._properties = data.copy() diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 3c5b778..188d0b3 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -26,6 +26,7 @@ from pandas import DataFrame from superset import app, is_feature_enabled from superset.exceptions import QueryObjectValidationError +from superset.typing import Metric from superset.utils import core as utils, pandas_postprocessing from superset.views.utils import get_time_range_endpoints @@ -67,11 +68,11 @@ class QueryObject: row_limit: int filter: List[Dict[str, Any]] timeseries_limit: int - timeseries_limit_metric: Optional[Dict] + timeseries_limit_metric: Optional[Metric] order_desc: bool - extras: Dict + extras: Dict[str, Any] columns: List[str] - orderby: List[List] + orderby: List[List[str]] post_processing: List[Dict[str, Any]] def __init__( @@ -85,11 +86,11 @@ class QueryObject: is_timeseries: bool = False, timeseries_limit: int = 0, row_limit: int = app.config["ROW_LIMIT"], - timeseries_limit_metric: Optional[Dict] = None, + timeseries_limit_metric: Optional[Metric] = None, order_desc: bool = True, - extras: Optional[Dict] = None, + extras: Optional[Dict[str, Any]] = None, columns: Optional[List[str]] = None, - orderby: Optional[List[List]] = None, + orderby: Optional[List[List[str]]] = None, post_processing: Optional[List[Dict[str, Any]]] = None, **kwargs: Any, ): diff --git a/superset/config.py b/superset/config.py index 35dbbf8..6da24b4 100644 --- a/superset/config.py +++ b/superset/config.py @@ -33,6 +33,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING from cachelib.base import BaseCache from celery.schedules import crontab from dateutil import tz +from flask import Blueprint from flask_appbuilder.security.manager import AUTH_DB from superset.jinja_context import ( # pylint: disable=unused-import @@ -421,7 +422,7 @@ DEFAULT_MODULE_DS_MAP = OrderedDict( ] ) ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {} -ADDITIONAL_MIDDLEWARE: List[Callable] = [] +ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = [] # 1) https://docs.python-guide.org/writing/logging/ # 2) https://docs.python.org/2/library/logging.config.html @@ -624,7 +625,7 @@ ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[ # SQL Lab. The existing context gets updated with this dictionary, # meaning values for existing keys get overwritten by the content of this # dictionary. -JINJA_CONTEXT_ADDONS: Dict[str, Callable] = {} +JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {} # A dictionary of macro template processors that gets merged into global # template processors. The existing template processors get updated with this @@ -684,7 +685,7 @@ PERMISSION_INSTRUCTIONS_LINK = "" # Integrate external Blueprints to the app by passing them to your # configuration. These blueprints will get integrated in the app -BLUEPRINTS: List[Callable] = [] +BLUEPRINTS: List[Blueprint] = [] # Provide a callable that receives a tracking_url and returns another # URL. This is used to translate internal Hadoop job tracker URL diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 0533aa1..fb2d5eb 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import json -from typing import Any, Dict, Hashable, List, Optional, Type +from typing import Any, Dict, Hashable, List, Optional, Type, Union from flask_appbuilder.security.sqla.models import User from sqlalchemy import and_, Boolean, Column, Integer, String, Text @@ -64,12 +64,12 @@ class BaseDatasource( baselink: Optional[str] = None # url portion pointing to ModelView endpoint @property - def column_class(self) -> Type: + def column_class(self) -> Type["BaseColumn"]: # link to derivative of BaseColumn raise NotImplementedError() @property - def metric_class(self) -> Type: + def metric_class(self) -> Type["BaseMetric"]: # link to derivative of BaseMetric raise NotImplementedError() @@ -368,7 +368,7 @@ class BaseDatasource( """ raise NotImplementedError() - def values_for_column(self, column_name: str, limit: int = 10000) -> List: + def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: """Given a column, returns an iterable of distinct values This is used to populate the dropdown showing a list of @@ -389,7 +389,10 @@ class BaseDatasource( @staticmethod def get_fk_many_from_list( - object_list: List[Any], fkmany: List[Column], fkmany_class: Type, key_attr: str, + object_list: List[Any], + fkmany: List[Column], + fkmany_class: Type[Union["BaseColumn", "BaseMetric"]], + key_attr: str, ) -> List[Column]: # pylint: disable=too-many-locals """Update ORM one-to-many list from object list diff --git a/superset/connectors/connector_registry.py b/superset/connectors/connector_registry.py index 4097066..3b11973 100644 --- a/superset/connectors/connector_registry.py +++ b/superset/connectors/connector_registry.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from collections import OrderedDict from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING from sqlalchemy import or_ @@ -22,6 +21,8 @@ from sqlalchemy.orm import Session, subqueryload if TYPE_CHECKING: # pylint: disable=unused-import + from collections import OrderedDict + from superset.models.core import Database from superset.connectors.base.models import BaseDatasource @@ -32,7 +33,7 @@ class ConnectorRegistry: sources: Dict[str, Type["BaseDatasource"]] = {} @classmethod - def register_sources(cls, datasource_config: OrderedDict) -> None: + def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None: for module_name, class_names in datasource_config.items(): class_names = [str(s) for s in class_names] module_obj = __import__(module_name, fromlist=class_names) diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index 50f1637..4de56c9 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -24,18 +24,7 @@ from copy import deepcopy from datetime import datetime, timedelta from distutils.version import LooseVersion from multiprocessing.pool import ThreadPool -from typing import ( - Any, - Callable, - cast, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Union, -) +from typing import Any, cast, Dict, Iterable, List, Optional, Set, Tuple, Union import pandas as pd import sqlalchemy as sa @@ -173,7 +162,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin): return self.__repr__() @property - def data(self) -> Dict: + def data(self) -> Dict[str, Any]: return {"id": self.id, "name": self.cluster_name, "backend": "druid"} @staticmethod @@ -354,7 +343,7 @@ class DruidColumn(Model, BaseColumn): return self.dimension_spec_json @property - def dimension_spec(self) -> Optional[Dict]: + def dimension_spec(self) -> Optional[Dict[str, Any]]: if self.dimension_spec_json: return json.loads(self.dimension_spec_json) return None @@ -438,7 +427,7 @@ class DruidMetric(Model, BaseMetric): return self.json @property - def json_obj(self) -> Dict: + def json_obj(self) -> Dict[str, Any]: try: obj = json.loads(self.json) except Exception: @@ -614,7 +603,7 @@ class DruidDatasource(Model, BaseDatasource): name = escape(self.datasource_name) return Markup(f'<a href="{url}">{name}</a>') - def get_metric_obj(self, metric_name: str) -> Dict: + def get_metric_obj(self, metric_name: str) -> Dict[str, Any]: return [m.json_obj for m in self.metrics if m.metric_name == metric_name][0] @classmethod @@ -705,7 +694,11 @@ class DruidDatasource(Model, BaseDatasource): @classmethod def sync_to_db_from_config( - cls, druid_config: Dict, user: User, cluster: DruidCluster, refresh: bool = True + cls, + druid_config: Dict[str, Any], + user: User, + cluster: DruidCluster, + refresh: bool = True, ) -> None: """Merges the ds config from druid_config into one stored in the db.""" session = db.session @@ -901,7 +894,7 @@ class DruidDatasource(Model, BaseDatasource): return postagg_metrics @staticmethod - def recursive_get_fields(_conf: Dict) -> List[str]: + def recursive_get_fields(_conf: Dict[str, Any]) -> List[str]: _type = _conf.get("type") _field = _conf.get("field") _fields = _conf.get("fields") @@ -957,8 +950,8 @@ class DruidDatasource(Model, BaseDatasource): @staticmethod def metrics_and_post_aggs( - metrics: List[Union[Dict, str]], metrics_dict: Dict[str, DruidMetric], - ) -> Tuple[OrderedDict, OrderedDict]: + metrics: List[Metric], metrics_dict: Dict[str, DruidMetric], + ) -> Tuple["OrderedDict[str, Any]", "OrderedDict[str, Any]"]: # Separate metrics into those that are aggregations # and those that are post aggregations saved_agg_names = set() @@ -987,7 +980,7 @@ class DruidDatasource(Model, BaseDatasource): ) return aggs, post_aggs - def values_for_column(self, column_name: str, limit: int = 10000) -> List: + def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: """Retrieve some values for the given column""" logger.info( "Getting values for columns [{}] limited to [{}]".format(column_name, limit) @@ -1079,8 +1072,10 @@ class DruidDatasource(Model, BaseDatasource): @staticmethod def get_aggregations( - metrics_dict: Dict, saved_metrics: Set[str], adhoc_metrics: List[Dict] = [] - ) -> OrderedDict: + metrics_dict: Dict[str, Any], + saved_metrics: Set[str], + adhoc_metrics: Optional[List[Dict[str, Any]]] = None, + ) -> "OrderedDict[str, Any]": """ Returns a dictionary of aggregation metric names to aggregation json objects @@ -1089,7 +1084,9 @@ class DruidDatasource(Model, BaseDatasource): :param adhoc_metrics: list of adhoc metric names :raise SupersetException: if one or more metric names are not aggregations """ - aggregations: OrderedDict = OrderedDict() + if not adhoc_metrics: + adhoc_metrics = [] + aggregations = OrderedDict() invalid_metric_names = [] for metric_name in saved_metrics: if metric_name in metrics_dict: @@ -1115,7 +1112,7 @@ class DruidDatasource(Model, BaseDatasource): def get_dimensions( self, columns: List[str], columns_dict: Dict[str, DruidColumn] - ) -> List[Union[str, Dict]]: + ) -> List[Union[str, Dict[str, Any]]]: dimensions = [] columns = [col for col in columns if col in columns_dict] for column_name in columns: @@ -1433,7 +1430,7 @@ class DruidDatasource(Model, BaseDatasource): df[columns] = df[columns].fillna(NULL_STRING).astype("unicode") return df - def query(self, query_obj: Dict) -> QueryResult: + def query(self, query_obj: QueryObjectDict) -> QueryResult: qry_start_dttm = datetime.now() client = self.cluster.get_pydruid_client() query_str = self.get_query_str(client=client, query_obj=query_obj, phase=2) @@ -1583,7 +1580,7 @@ class DruidDatasource(Model, BaseDatasource): dimension=col, value=eq, extraction_function=extraction_fn ) elif is_list_target: - eq = cast(list, eq) + eq = cast(List[Any], eq) fields = [] # ignore the filter if it has no value if not len(eq): diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 0e91bd2..4e93d5f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -597,7 +597,7 @@ class SqlaTable(Model, BaseDatasource): ) @property - def data(self) -> Dict: + def data(self) -> Dict[str, Any]: d = super().data if self.type == "table": grains = self.database.grains() or [] @@ -684,7 +684,9 @@ class SqlaTable(Model, BaseDatasource): return TextAsFrom(sa.text(from_sql), []).alias("expr_qry") return self.get_sqla_table() - def adhoc_metric_to_sqla(self, metric: Dict, cols: Dict) -> Optional[Column]: + def adhoc_metric_to_sqla( + self, metric: Dict[str, Any], cols: Dict[str, Any] + ) -> Optional[Column]: """ Turn an adhoc metric into a sqlalchemy column. @@ -804,7 +806,7 @@ class SqlaTable(Model, BaseDatasource): main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label) select_exprs: List[Column] = [] - groupby_exprs_sans_timestamp: OrderedDict = OrderedDict() + groupby_exprs_sans_timestamp = OrderedDict() if (is_sip_38 and metrics and columns) or (not is_sip_38 and groupby): # dedup columns while preserving order @@ -874,7 +876,7 @@ class SqlaTable(Model, BaseDatasource): qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] - having_clause_and: List = [] + having_clause_and = [] for flt in filter: # type: ignore if not all([flt.get(s) for s in ["col", "op"]]): @@ -1082,7 +1084,10 @@ class SqlaTable(Model, BaseDatasource): return ob def _get_top_groups( - self, df: pd.DataFrame, dimensions: List, groupby_exprs: OrderedDict + self, + df: pd.DataFrame, + dimensions: List[str], + groupby_exprs: "OrderedDict[str, Any]", ) -> ColumnElement: groups = [] for unused, row in df.iterrows(): diff --git a/superset/dao/base.py b/superset/dao/base.py index 020feed..59791ff 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model @@ -75,7 +75,7 @@ class BaseDAO: return query.all() @classmethod - def create(cls, properties: Dict, commit: bool = True) -> Model: + def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model: """ Generic for creating models :raises: DAOCreateFailedError @@ -95,7 +95,9 @@ class BaseDAO: return model @classmethod - def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model: + def update( + cls, model: Model, properties: Dict[str, Any], commit: bool = True + ) -> Model: """ Generic update a model :raises: DAOCreateFailedError diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py index 0aa1241..8376369 100644 --- a/superset/dashboards/commands/create.py +++ b/superset/dashboards/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) class CreateDashboardCommand(BaseCommand): - def __init__(self, user: User, data: Dict): + def __init__(self, user: User, data: Dict[str, Any]): self._actor = user self._properties = data.copy() diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py index 7746b7e..54c5de1 100644 --- a/superset/dashboards/commands/update.py +++ b/superset/dashboards/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) class UpdateDashboardCommand(BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict): + def __init__(self, user: User, model_id: int, data: Dict[str, Any]): self._actor = user self._model_id = model_id self._properties = data.copy() diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 3114a4f..436fdd2 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) class CreateDatasetCommand(BaseCommand): - def __init__(self, user: User, data: Dict): + def __init__(self, user: User, data: Dict[str, Any]): self._actor = user self._properties = data.copy() diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index c7f70dd..14cc087 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -16,7 +16,7 @@ # under the License. import logging from collections import Counter -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask_appbuilder.models.sqla import Model from flask_appbuilder.security.sqla.models import User @@ -48,7 +48,7 @@ logger = logging.getLogger(__name__) class UpdateDatasetCommand(BaseCommand): - def __init__(self, user: User, model_id: int, data: Dict): + def __init__(self, user: User, model_id: int, data: Dict[str, Any]): self._actor = user self._model_id = model_id self._properties = data.copy() @@ -111,7 +111,7 @@ class UpdateDatasetCommand(BaseCommand): raise exception def _validate_columns( - self, columns: List[Dict], exceptions: List[ValidationError] + self, columns: List[Dict[str, Any]], exceptions: List[ValidationError] ) -> None: # Validate duplicates on data if self._get_duplicates(columns, "column_name"): @@ -133,7 +133,7 @@ class UpdateDatasetCommand(BaseCommand): exceptions.append(DatasetColumnsExistsValidationError()) def _validate_metrics( - self, metrics: List[Dict], exceptions: List[ValidationError] + self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError] ) -> None: if self._get_duplicates(metrics, "metric_name"): exceptions.append(DatasetMetricsDuplicateValidationError()) @@ -152,7 +152,7 @@ class UpdateDatasetCommand(BaseCommand): exceptions.append(DatasetMetricsExistsValidationError()) @staticmethod - def _get_duplicates(data: List[Dict], key: str) -> List[str]: + def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]: duplicates = [ name for name, count in Counter([item[key] for item in data]).items() diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index 5dfe4ef..ef20a69 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from flask import current_app from sqlalchemy.exc import SQLAlchemyError @@ -116,7 +116,7 @@ class DatasetDAO(BaseDAO): @classmethod def update( - cls, model: SqlaTable, properties: Dict, commit: bool = True + cls, model: SqlaTable, properties: Dict[str, Any], commit: bool = True ) -> Optional[SqlaTable]: """ Updates a Dataset model on the metadata DB @@ -151,13 +151,13 @@ class DatasetDAO(BaseDAO): @classmethod def update_column( - cls, model: TableColumn, properties: Dict, commit: bool = True + cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True ) -> Optional[TableColumn]: return DatasetColumnDAO.update(model, properties, commit=commit) @classmethod def create_column( - cls, properties: Dict, commit: bool = True + cls, properties: Dict[str, Any], commit: bool = True ) -> Optional[TableColumn]: """ Creates a Dataset model on the metadata DB @@ -166,13 +166,13 @@ class DatasetDAO(BaseDAO): @classmethod def update_metric( - cls, model: SqlMetric, properties: Dict, commit: bool = True + cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True ) -> Optional[SqlMetric]: return DatasetMetricDAO.update(model, properties, commit=commit) @classmethod def create_metric( - cls, properties: Dict, commit: bool = True + cls, properties: Dict[str, Any], commit: bool = True ) -> Optional[SqlMetric]: """ Creates a Dataset model on the metadata DB diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a593f59..7b0d537 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -151,7 +151,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods try_remove_schema_from_table_name = True # pylint: disable=invalid-name # default matching patterns for identifying column types - db_column_types: Dict[utils.DbColumnType, Tuple[Pattern, ...]] = { + db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = { utils.DbColumnType.NUMERIC: ( re.compile(r".*DOUBLE.*", re.IGNORECASE), re.compile(r".*FLOAT.*", re.IGNORECASE), @@ -296,7 +296,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return select_exprs @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: """ :param cursor: Cursor instance @@ -311,8 +311,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def expand_data( - cls, columns: List[dict], data: List[dict] - ) -> Tuple[List[dict], List[dict], List[dict]]: + cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]] + ) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]: """ Some engines support expanding nested fields. See implementation in Presto spec for details. @@ -645,7 +645,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods schema: Optional[str], database: "Database", query: Select, - columns: Optional[List] = None, + columns: Optional[List[Dict[str, str]]] = None, ) -> Optional[Select]: """ Add a where clause to a query to reference only the most recent partition @@ -925,7 +925,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return [] @staticmethod - def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple]: + def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]: """ Convert pyodbc.Row objects from `fetch_data` to tuples. diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 992b5fe..3091d65 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -83,7 +83,7 @@ class BigQueryEngineSpec(BaseEngineSpec): return None @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Support type BigQuery Row, introduced here PR #4071 # google.cloud.bigquery.table.Row diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py index 480a8c2..23449f0 100644 --- a/superset/db_engine_specs/exasol.py +++ b/superset/db_engine_specs/exasol.py @@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method } @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 3fb09ef..63c27bd 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -93,7 +93,7 @@ class HiveEngineSpec(PrestoEngineSpec): return BaseEngineSpec.get_all_datasource_names(database, datasource_type) @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: import pyhive from TCLIService import ttypes @@ -304,7 +304,7 @@ class HiveEngineSpec(PrestoEngineSpec): schema: Optional[str], database: "Database", query: Select, - columns: Optional[List] = None, + columns: Optional[List[Dict[str, str]]] = None, ) -> Optional[Select]: try: col_names, values = cls.latest_partition( @@ -323,7 +323,7 @@ class HiveEngineSpec(PrestoEngineSpec): return None @classmethod - def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]: + def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access @classmethod diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index fde69b3..45c2f23 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -66,7 +66,7 @@ class MssqlEngineSpec(BaseEngineSpec): return None @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index b5f1b2c..c5e4221 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): } @classmethod - def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]: + def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple[Any, ...]]: cursor.tzinfo_factory = FixedOffsetTimezone if not cursor.description: return [] diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 9bc9307..e8c9603 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -164,7 +164,7 @@ class PrestoEngineSpec(BaseEngineSpec): return [row[0] for row in results] @classmethod - def _create_column_info(cls, name: str, data_type: str) -> dict: + def _create_column_info(cls, name: str, data_type: str) -> Dict[str, Any]: """ Create column info object :param name: column name @@ -213,7 +213,10 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def _parse_structural_column( # pylint: disable=too-many-locals,too-many-branches - cls, parent_column_name: str, parent_data_type: str, result: List[dict] + cls, + parent_column_name: str, + parent_data_type: str, + result: List[Dict[str, Any]], ) -> None: """ Parse a row or array column @@ -322,7 +325,7 @@ class PrestoEngineSpec(BaseEngineSpec): (i.e. column name and data type) """ columns = cls._show_columns(inspector, table_name, schema) - result: List[dict] = [] + result: List[Dict[str, Any]] = [] for column in columns: try: # parse column if it is a row or array @@ -361,7 +364,7 @@ class PrestoEngineSpec(BaseEngineSpec): return column_name.startswith('"') and column_name.endswith('"') @classmethod - def _get_fields(cls, cols: List[dict]) -> List[ColumnClause]: + def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: """ Format column clauses where names are in quotes and labels are specified :param cols: columns @@ -561,8 +564,8 @@ class PrestoEngineSpec(BaseEngineSpec): @classmethod def expand_data( # pylint: disable=too-many-locals - cls, columns: List[dict], data: List[dict] - ) -> Tuple[List[dict], List[dict], List[dict]]: + cls, columns: List[Dict[Any, Any]], data: List[Dict[Any, Any]] + ) -> Tuple[List[Dict[Any, Any]], List[Dict[Any, Any]], List[Dict[Any, Any]]]: """ We do not immediately display rows and arrays clearly in the data grid. This method separates out nested fields and data values to help clearly display @@ -590,7 +593,7 @@ class PrestoEngineSpec(BaseEngineSpec): # process each column, unnesting ARRAY types and # expanding ROW types into new columns to_process = deque((column, 0) for column in columns) - all_columns: List[dict] = [] + all_columns: List[Dict[str, Any]] = [] expanded_columns = [] current_array_level = None while to_process: @@ -843,7 +846,7 @@ class PrestoEngineSpec(BaseEngineSpec): schema: Optional[str], database: "Database", query: Select, - columns: Optional[List] = None, + columns: Optional[List[Dict[str, str]]] = None, ) -> Optional[Select]: try: col_names, values = cls.latest_partition( diff --git a/superset/extensions.py b/superset/extensions.py index f321046..a0dad81 100644 --- a/superset/extensions.py +++ b/superset/extensions.py @@ -95,7 +95,9 @@ class UIManifestProcessor: self.parse_manifest_json() @app.context_processor - def get_manifest() -> Dict[str, Callable]: # pylint: disable=unused-variable + def get_manifest() -> Dict[ # pylint: disable=unused-variable + str, Callable[[str], List[str]] + ]: loaded_chunks = set() def get_files(bundle: str, asset_type: str = "js") -> List[str]: @@ -131,7 +133,7 @@ appbuilder = AppBuilder(update_perms=False) cache_manager = CacheManager() celery_app = celery.Celery() db = SQLA() -_event_logger: dict = {} +_event_logger: Dict[str, Any] = {} event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) feature_flag_manager = FeatureFlagManager() jinja_context_manager = JinjaContextManager() diff --git a/superset/models/core.py b/superset/models/core.py index 015abc2..ed90e4a 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -341,11 +341,14 @@ class Database( def get_reserved_words(self) -> Set[str]: return self.get_dialect().preparer.reserved_words - def get_quoter(self) -> Callable: + def get_quoter(self) -> Callable[[str, Any], str]: return self.get_dialect().identifier_preparer.quote def get_df( # pylint: disable=too-many-locals - self, sql: str, schema: Optional[str] = None, mutator: Optional[Callable] = None + self, + sql: str, + schema: Optional[str] = None, + mutator: Optional[Callable[[pd.DataFrame], None]] = None, ) -> pd.DataFrame: sqls = [str(s).strip(" ;") for s in sqlparse.parse(sql)] @@ -450,7 +453,7 @@ class Database( @cache_util.memoized_func( key=lambda *args, **kwargs: "db:{}:schema:None:view_list", - attribute_in_key="id", # type: ignore + attribute_in_key="id", ) def get_all_view_names_in_database( self, diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index de42285..d10809e 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -240,7 +240,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes self.json_metadata = value @property - def position(self) -> Dict: + def position(self) -> Dict[str, Any]: if self.position_json: return json.loads(self.position_json) return {} @@ -315,7 +315,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes old_to_new_slc_id_dict: Dict[int, int] = {} new_timed_refresh_immune_slices = [] new_expanded_slices = {} - new_filter_scopes: Dict[str, Dict] = {} + new_filter_scopes = {} i_params_dict = dashboard_to_import.params_dict remote_id_slice_map = { slc.params_dict["remote_id"]: slc @@ -351,7 +351,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes # are converted to filter_scopes # but dashboard create from import may still have old dashboard filter metadata # here we convert them to new filter_scopes metadata first - filter_scopes: Dict = {} + filter_scopes = {} if ( "filter_immune_slices" in i_params_dict or "filter_immune_slice_fields" in i_params_dict @@ -415,7 +415,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes @classmethod def export_dashboards( # pylint: disable=too-many-locals - cls, dashboard_ids: List + cls, dashboard_ids: List[int] ) -> str: copied_dashboards = [] datasource_ids = set() diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 42169e6..4ffe5e9 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -81,7 +81,7 @@ class ImportMixin: for u in cls.__table_args__ # type: ignore if isinstance(u, UniqueConstraint) ] - unique.extend( # type: ignore + unique.extend( {c.name} for c in cls.__table__.columns if c.unique # type: ignore ) return unique diff --git a/superset/models/slice.py b/superset/models/slice.py index 76eb457..4f73e43 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -36,7 +36,7 @@ from superset.tasks.thumbnails import cache_chart_thumbnail from superset.utils import core as utils if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): - from superset.viz_sip38 import BaseViz, viz_types # type: ignore + from superset.viz_sip38 import BaseViz, viz_types else: from superset.viz import BaseViz, viz_types # type: ignore diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py index a50b4c2..47486cf 100644 --- a/superset/models/sql_types/presto_sql_types.py +++ b/superset/models/sql_types/presto_sql_types.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Optional, Type +from typing import Any, Dict, List, Optional, Type from sqlalchemy import types from sqlalchemy.sql.sqltypes import Integer @@ -29,7 +29,7 @@ class TinyInteger(Integer): A type for tiny ``int`` integers. """ - def python_type(self) -> Type: + def python_type(self) -> Type[int]: return int @classmethod @@ -42,7 +42,7 @@ class Interval(TypeEngine): A type for intervals. """ - def python_type(self) -> Optional[Type]: + def python_type(self) -> Optional[Type[Any]]: return None @classmethod @@ -55,7 +55,7 @@ class Array(TypeEngine): A type for arrays. """ - def python_type(self) -> Optional[Type]: + def python_type(self) -> Optional[Type[List[Any]]]: return list @classmethod @@ -68,7 +68,7 @@ class Map(TypeEngine): A type for maps. """ - def python_type(self) -> Optional[Type]: + def python_type(self) -> Optional[Type[Dict[Any, Any]]]: return dict @classmethod @@ -81,7 +81,7 @@ class Row(TypeEngine): A type for rows. """ - def python_type(self) -> Optional[Type]: + def python_type(self) -> Optional[Type[Any]]: return None @classmethod diff --git a/superset/queries/filters.py b/superset/queries/filters.py index 323c3c6..22cf45f 100644 --- a/superset/queries/filters.py +++ b/superset/queries/filters.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Callable +from typing import Any from flask import g from flask_sqlalchemy import BaseQuery @@ -25,7 +25,7 @@ from superset.views.base import BaseFilter class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods - def apply(self, query: BaseQuery, value: Callable) -> BaseQuery: + def apply(self, query: BaseQuery, value: Any) -> BaseQuery: """ Filter queries to only those owned by current user. If can_access_all_queries permission is set a user can list all queries diff --git a/superset/result_set.py b/superset/result_set.py index 4880511..dd6f0ff 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -20,7 +20,7 @@ import datetime import json import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import numpy as np import pandas as pd @@ -64,7 +64,7 @@ def stringify(obj: Any) -> str: def stringify_values(array: np.ndarray) -> np.ndarray: - vstringify: Callable = np.vectorize(stringify) + vstringify = np.vectorize(stringify) return vstringify(array) @@ -172,7 +172,7 @@ class SupersetResultSet: return table.to_pandas(integer_object_nulls=True) @staticmethod - def first_nonempty(items: List) -> Any: + def first_nonempty(items: List[Any]) -> Any: return next((i for i in items if i), None) def is_temporal(self, db_type_str: Optional[str]) -> bool: diff --git a/superset/security/manager.py b/superset/security/manager.py index 51cb3e0..772edbb 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -21,11 +21,11 @@ from typing import Any, Callable, List, Optional, Set, Tuple, TYPE_CHECKING, Uni from flask import current_app, g from flask_appbuilder import Model -from flask_appbuilder.security.sqla import models as ab_models from flask_appbuilder.security.sqla.manager import SecurityManager from flask_appbuilder.security.sqla.models import ( assoc_permissionview_role, assoc_user_role, + PermissionView, ) from flask_appbuilder.security.views import ( PermissionModelView, @@ -602,11 +602,8 @@ class SupersetSecurityManager(SecurityManager): logger.info("Cleaning faulty perms") sesh = self.get_session - pvms = sesh.query(ab_models.PermissionView).filter( - or_( - ab_models.PermissionView.permission == None, - ab_models.PermissionView.view_menu == None, - ) + pvms = sesh.query(PermissionView).filter( + or_(PermissionView.permission == None, PermissionView.view_menu == None,) ) deleted_count = pvms.delete() sesh.commit() @@ -640,7 +637,9 @@ class SupersetSecurityManager(SecurityManager): self.get_session.commit() self.clean_perms() - def set_role(self, role_name: str, pvm_check: Callable) -> None: + def set_role( + self, role_name: str, pvm_check: Callable[[PermissionView], bool] + ) -> None: """ Set the FAB permission/views for the role. @@ -650,7 +649,7 @@ class SupersetSecurityManager(SecurityManager): logger.info("Syncing {} perms".format(role_name)) sesh = self.get_session - pvms = sesh.query(ab_models.PermissionView).all() + pvms = sesh.query(PermissionView).all() pvms = [p for p in pvms if p.permission and p.view_menu] role = self.add_role(role_name) role_pvms = [p for p in pvms if pvm_check(p)] diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 3ba1e3a..d9f5b38 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -299,9 +299,10 @@ def _serialize_and_expand_data( db_engine_spec: BaseEngineSpec, use_msgpack: Optional[bool] = False, expand_data: bool = False, -) -> Tuple[Union[bytes, str], list, list, list]: - selected_columns: List[Dict] = result_set.columns - expanded_columns: List[Dict] +) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]: + selected_columns = result_set.columns + all_columns: List[Any] + expanded_columns: List[Any] if use_msgpack: with stats_timing( diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py index 0344b59..0f3cd0e 100644 --- a/superset/tasks/celery_app.py +++ b/superset/tasks/celery_app.py @@ -25,7 +25,7 @@ from superset import create_app from superset.extensions import celery_app # Init the Flask app / configure everything -create_app() # type: ignore +create_app() # Need to import late, as the celery_app will have been setup by "create_app()" # pylint: disable=wrong-import-position, unused-import diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 2a5733e..3e6c1dd 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -23,7 +23,7 @@ import urllib.request from collections import namedtuple from datetime import datetime, timedelta from email.utils import make_msgid, parseaddr -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union from urllib.error import URLError # pylint: disable=ungrouped-imports import croniter @@ -36,7 +36,6 @@ from flask_login import login_user from retry.api import retry_call from selenium.common.exceptions import WebDriverException from selenium.webdriver import chrome, firefox -from werkzeug.datastructures import TypeConversionDict from werkzeug.http import parse_cookie # Superset framework imports @@ -53,6 +52,11 @@ from superset.models.schedules import ( ) from superset.utils.core import get_email_address_list, send_email_smtp +if TYPE_CHECKING: + # pylint: disable=unused-import + from werkzeug.datastructures import TypeConversionDict + + # Globals config = app.config logger = logging.getLogger("tasks.email_reports") @@ -131,7 +135,7 @@ def _generate_mail_content( return EmailContent(body, data, images) -def _get_auth_cookies() -> List[TypeConversionDict]: +def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]: # Login with the user specified to get the reports with app.test_request_context(): user = security_manager.find_user(config["EMAIL_REPORTS_USER"]) diff --git a/superset/utils/cache.py b/superset/utils/cache.py index bd39f87..586cb2b 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -27,8 +27,9 @@ def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused- def memoized_func( - key: Callable = view_cache_key, attribute_in_key: Optional[str] = None -) -> Callable: + key: Callable[..., str] = view_cache_key, # pylint: disable=bad-whitespace + attribute_in_key: Optional[str] = None, +) -> Callable[..., Any]: """Use this decorator to cache functions that have predefined first arg. enable_cache is treated as True by default, @@ -45,7 +46,7 @@ def memoized_func( returns the caching key. """ - def wrap(f: Callable) -> Callable: + def wrap(f: Callable[..., Any]) -> Callable[..., Any]: if cache_manager.tables_cache: def wrapped_f(self: Any, *args: Any, **kwargs: Any) -> Any: diff --git a/superset/utils/core.py b/superset/utils/core.py index 00e3484..1620542 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -85,7 +85,7 @@ from superset.exceptions import ( SupersetException, SupersetTimeoutException, ) -from superset.typing import FormData, Metric +from superset.typing import FlaskResponse, FormData, Metric from superset.utils.dates import datetime_to_epoch, EPOCH try: @@ -147,7 +147,9 @@ class _memoized: should account for instance variable changes. """ - def __init__(self, func: Callable, watch: Optional[Tuple[str, ...]] = None) -> None: + def __init__( + self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None + ) -> None: self.func = func self.cache: Dict[Any, Any] = {} self.is_method = False @@ -173,7 +175,7 @@ class _memoized: """Return the function's docstring.""" return self.func.__doc__ or "" - def __get__(self, obj: Any, objtype: Type) -> functools.partial: + def __get__(self, obj: Any, objtype: Type[Any]) -> functools.partial: # type: ignore if not self.is_method: self.is_method = True """Support instance methods.""" @@ -181,13 +183,13 @@ class _memoized: def memoized( - func: Optional[Callable] = None, watch: Optional[Tuple[str, ...]] = None -) -> Callable: + func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None +) -> Callable[..., Any]: if func: return _memoized(func) else: - def wrapper(f: Callable) -> Callable: + def wrapper(f: Callable[..., Any]) -> Callable[..., Any]: return _memoized(f, watch) return wrapper @@ -1241,7 +1243,9 @@ def create_ssl_cert_file(certificate: str) -> str: return path -def time_function(func: Callable, *args: Any, **kwargs: Any) -> Tuple[float, Any]: +def time_function( + func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any +) -> Tuple[float, Any]: """ Measures the amount of time a function takes to execute in ms diff --git a/superset/utils/dashboard_filter_scopes_converter.py b/superset/utils/dashboard_filter_scopes_converter.py index f77e0e0..d95582d 100644 --- a/superset/utils/dashboard_filter_scopes_converter.py +++ b/superset/utils/dashboard_filter_scopes_converter.py @@ -29,7 +29,7 @@ def convert_filter_scopes( ) -> Dict[int, Dict[str, Dict[str, Any]]]: filter_scopes = {} immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or [] - immuned_by_column: Dict = defaultdict(list) + immuned_by_column: Dict[str, List[int]] = defaultdict(list) for slice_id, columns in json_metadata.get( "filter_immune_slice_fields", {} ).items(): @@ -52,7 +52,7 @@ def convert_filter_scopes( logging.info(f"slice [{filter_id}] has invalid field: {filter_field}") for filter_slice in filters: - filter_fields: Dict = {} + filter_fields: Dict[str, Dict[str, Any]] = {} filter_id = filter_slice.id slice_params = json.loads(filter_slice.params or "{}") configs = slice_params.get("filter_configs") or [] @@ -77,9 +77,10 @@ def convert_filter_scopes( def copy_filter_scopes( - old_to_new_slc_id_dict: Dict[int, int], old_filter_scopes: Dict[str, Dict] -) -> Dict: - new_filter_scopes: Dict[str, Dict] = {} + old_to_new_slc_id_dict: Dict[int, int], + old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]], +) -> Dict[str, Dict[Any, Any]]: + new_filter_scopes: Dict[str, Dict[Any, Any]] = {} for (filter_id, scopes) in old_filter_scopes.items(): new_filter_key = old_to_new_slc_id_dict.get(int(filter_id)) if new_filter_key: diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index a1165c5..bb0219c 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -46,7 +46,7 @@ def stats_timing(stats_key: str, stats_logger: BaseStatsLogger) -> Iterator[floa stats_logger.timing(stats_key, now_as_float() - start_ts) -def etag_cache(max_age: int, check_perms: Callable) -> Callable: +def etag_cache(max_age: int, check_perms: Callable[..., Any]) -> Callable[..., Any]: """ A decorator for caching views and handling etag conditional requests. @@ -60,7 +60,7 @@ def etag_cache(max_age: int, check_perms: Callable) -> Callable: """ - def decorator(f: Callable) -> Callable: + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin: # check if the user can access the resource diff --git a/superset/utils/import_datasource.py b/superset/utils/import_datasource.py index 19f6d59..50f375c 100644 --- a/superset/utils/import_datasource.py +++ b/superset/utils/import_datasource.py @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) def import_datasource( session: Session, i_datasource: Model, - lookup_database: Callable, - lookup_datasource: Callable, + lookup_database: Callable[[Model], Model], + lookup_datasource: Callable[[Model], Model], import_time: Optional[int] = None, ) -> int: """Imports the datasource from the object to the database. @@ -82,7 +82,9 @@ def import_datasource( return datasource.id -def import_simple_obj(session: Session, i_obj: Model, lookup_obj: Callable) -> Model: +def import_simple_obj( + session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model] +) -> Model: make_transient(i_obj) i_obj.id = None i_obj.table = None diff --git a/superset/utils/log.py b/superset/utils/log.py index aafe3b8..b31abce 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -35,7 +35,7 @@ class AbstractEventLogger(ABC): ) -> None: pass - def log_this(self, f: Callable) -> Callable: + def log_this(self, f: Callable[..., Any]) -> Callable[..., Any]: @functools.wraps(f) def wrapper(*args: Any, **kwargs: Any) -> Any: user_id = None @@ -124,7 +124,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger: ) ) - event_logger_type = cast(Type, cfg_value) + event_logger_type = cast(Type[Any], cfg_value) result = event_logger_type() # Verify that we have a valid logger impl diff --git a/superset/utils/logging_configurator.py b/superset/utils/logging_configurator.py index 396d35e..09f1e58 100644 --- a/superset/utils/logging_configurator.py +++ b/superset/utils/logging_configurator.py @@ -58,7 +58,7 @@ class DefaultLoggingConfigurator(LoggingConfigurator): if app_config["ENABLE_TIME_ROTATE"]: logging.getLogger().setLevel(app_config["TIME_ROTATE_LOG_LEVEL"]) - handler = TimedRotatingFileHandler( # type: ignore + handler = TimedRotatingFileHandler( app_config["FILENAME"], when=app_config["ROLLOVER"], interval=app_config["INTERVAL"], diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index 39a4278..e62b393 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -72,8 +72,8 @@ WHITELIST_CUMULATIVE_FUNCTIONS = ( ) -def validate_column_args(*argnames: str) -> Callable: - def wrapper(func: Callable) -> Callable: +def validate_column_args(*argnames: str) -> Callable[..., Any]: + def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: def wrapped(df: DataFrame, **options: Any) -> Any: columns = df.columns.tolist() for name in argnames: @@ -471,7 +471,7 @@ def geodetic_parse( Parse a string containing a geodetic point and return latitude, longitude and altitude """ - point = Point(location) # type: ignore + point = Point(location) return point[0], point[1], point[2] try: diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index e07d2a2..b2f222f 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -51,7 +51,7 @@ SELENIUM_HEADSTART = 3 WindowSize = Tuple[int, int] -def get_auth_cookies(user: "User") -> List[Dict]: +def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]: # Login with the user specified to get the reports with current_app.test_request_context("/login"): login_user(user) @@ -101,14 +101,14 @@ class AuthWebDriverProxy: self, driver_type: str, window: Optional[WindowSize] = None, - auth_func: Optional[Callable] = None, + auth_func: Optional[ + Callable[..., Any] + ] = None, # pylint: disable=bad-whitespace ): self._driver_type = driver_type self._window: WindowSize = window or (800, 600) - config_auth_func: Callable = current_app.config.get( - "WEBDRIVER_AUTH_FUNC", auth_driver - ) - self._auth_func: Callable = auth_func or config_auth_func + config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver) + self._auth_func = auth_func or config_auth_func def create(self) -> WebDriver: if self._driver_type == "firefox": @@ -123,7 +123,7 @@ class AuthWebDriverProxy: raise Exception(f"Webdriver name ({self._driver_type}) not supported") # Prepare args for the webdriver init options.add_argument("--headless") - kwargs: Dict = dict(options=options) + kwargs: Dict[Any, Any] = dict(options=options) kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"]) logger.info("Init selenium driver") return driver_class(**kwargs) diff --git a/superset/views/base.py b/superset/views/base.py index 1238218..bbf18c1 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -143,7 +143,7 @@ def generate_download_headers( return headers -def api(f: Callable) -> Callable: +def api(f: Callable[..., FlaskResponse]) -> Callable[..., FlaskResponse]: """ A decorator to label an endpoint as an API. Catches uncaught exceptions and return the response in the JSON format @@ -383,11 +383,11 @@ class DeleteMixin: # pylint: disable=too-few-public-methods :param primary_key: record primary key to delete """ - item = self.datamodel.get(primary_key, self._base_filters) # type: ignore + item = self.datamodel.get(primary_key, self._base_filters) if not item: abort(404) try: - self.pre_delete(item) # type: ignore + self.pre_delete(item) except Exception as ex: # pylint: disable=broad-except flash(str(ex), "danger") else: @@ -400,8 +400,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods .all() ) - if self.datamodel.delete(item): # type: ignore - self.post_delete(item) # type: ignore + if self.datamodel.delete(item): + self.post_delete(item) for pv in pvs: security_manager.get_session.delete(pv) @@ -411,8 +411,8 @@ class DeleteMixin: # pylint: disable=too-few-public-methods security_manager.get_session.commit() - flash(*self.datamodel.message) # type: ignore - self.update_redirect() # type: ignore + flash(*self.datamodel.message) + self.update_redirect() @action( "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 3d40c33..a72a1c5 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -41,7 +41,7 @@ get_related_schema = { } -def statsd_metrics(f: Callable) -> Callable: +def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]: """ Handle sending all statsd metrics from the REST API """ diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py index a4436dd..87a9190 100644 --- a/superset/views/base_schemas.py +++ b/superset/views/base_schemas.py @@ -88,7 +88,9 @@ class BaseOwnedSchema(BaseSupersetSchema): owners_field_name = "owners" @post_load - def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Model: + def make_object( + self, data: Dict[str, Any], discard: Optional[List[str]] = None + ) -> Model: discard = discard or [] discard.append(self.owners_field_name) instance = super().make_object(data, discard) diff --git a/superset/views/core.py b/superset/views/core.py index c561222..44cade9 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -251,7 +251,7 @@ def check_slice_perms(self: "Superset", slice_id: int) -> None: def _deserialize_results_payload( payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False -) -> Dict[Any, Any]: +) -> Dict[str, Any]: logger.debug(f"Deserializing from msgpack: {use_msgpack}") if use_msgpack: with stats_timing( @@ -278,7 +278,7 @@ def _deserialize_results_payload( with stats_timing( "sqllab.query.results_backend_json_deserialize", stats_logger ): - return json.loads(payload) # type: ignore + return json.loads(payload) def get_cta_schema_name( @@ -1343,7 +1343,7 @@ class Superset(BaseSupersetView): if "timed_refresh_immune_slices" not in md: md["timed_refresh_immune_slices"] = [] - new_filter_scopes: Dict[str, Dict] = {} + new_filter_scopes = {} if "filter_scopes" in data: # replace filter_id and immune ids from old slice id to new slice id: # and remove slice ids that are not in dash anymore @@ -2137,7 +2137,7 @@ class Superset(BaseSupersetView): f"deprecated.{self.__class__.__name__}.select_star.database_not_found" ) return json_error_response("Not found", 404) - schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) # type: ignore + schema = utils.parse_js_uri_path_item(schema, eval_undefined=True) table_name = utils.parse_js_uri_path_item(table_name) # type: ignore # Check that the user can access the datasource if not self.appbuilder.sm.can_access_datasource( @@ -2245,7 +2245,7 @@ class Superset(BaseSupersetView): ) payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack) - obj: dict = _deserialize_results_payload( + obj = _deserialize_results_payload( payload, query, cast(bool, results_backend_use_msgpack) ) @@ -2474,9 +2474,7 @@ class Superset(BaseSupersetView): schema: str = cast(str, query_params.get("schema")) sql: str = cast(str, query_params.get("sql")) try: - template_params: dict = json.loads( - query_params.get("templateParams") or "{}" - ) + template_params = json.loads(query_params.get("templateParams") or "{}") except json.JSONDecodeError: logger.warning( f"Invalid template parameter {query_params.get('templateParams')}" diff --git a/superset/views/database/api.py b/superset/views/database/api.py index 0050326..fe328d6 100644 --- a/superset/views/database/api.py +++ b/superset/views/database/api.py @@ -61,7 +61,7 @@ def get_col_type(col: Dict[Any, Any]) -> str: def get_table_metadata( database: Database, table_name: str, schema_name: Optional[str] -) -> Dict: +) -> Dict[str, Any]: """ Get table metadata information, including type, pk, fks. This function raises SQLAlchemyError when a schema is not found. @@ -72,7 +72,7 @@ def get_table_metadata( :param schema_name: schema name :return: Dict table metadata ready for API response """ - keys: List = [] + keys = [] columns = database.get_columns(table_name, schema_name) primary_key = database.get_pk_constraint(table_name, schema_name) if primary_key and primary_key.get("constrained_columns"): @@ -82,7 +82,7 @@ def get_table_metadata( foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) indexes = get_indexes_metadata(database, table_name, schema_name) keys += foreign_keys + indexes - payload_columns: List[Dict] = [] + payload_columns: List[Dict[str, Any]] = [] for col in columns: dtype = get_col_type(col) payload_columns.append( @@ -90,7 +90,7 @@ def get_table_metadata( "name": col["name"], "type": dtype.split("(")[0] if "(" in dtype else dtype, "longType": dtype, - "keys": [k for k in keys if col["name"] in k.get("column_names")], + "keys": [k for k in keys if col["name"] in k["column_names"]], } ) return { @@ -270,7 +270,7 @@ class DatabaseRestApi(DatabaseMixin, BaseSupersetModelRestApi): """ self.incr_stats("init", self.table_metadata.__name__) try: - table_info: Dict = get_table_metadata(database, table_name, schema_name) + table_info = get_table_metadata(database, table_name, schema_name) except SQLAlchemyError as ex: self.incr_stats("error", self.table_metadata.__name__) return self.response_422(error_msg_from_exception(ex)) diff --git a/superset/views/database/decorators.py b/superset/views/database/decorators.py index 0d2e83b..291a1af 100644 --- a/superset/views/database/decorators.py +++ b/superset/views/database/decorators.py @@ -29,7 +29,7 @@ from superset.views.base_api import BaseSupersetModelRestApi logger = logging.getLogger(__name__) -def check_datasource_access(f: Callable) -> Callable: +def check_datasource_access(f: Callable[..., Any]) -> Callable[..., Any]: """ A Decorator that checks if a user has datasource access """ diff --git a/superset/views/schedules.py b/superset/views/schedules.py index 68ae6ff..de09c31 100644 --- a/superset/views/schedules.py +++ b/superset/views/schedules.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import enum -from typing import Type +from typing import Type, Union import simplejson as json from croniter import croniter @@ -55,7 +55,7 @@ class EmailScheduleView( raise NotImplementedError() @property - def schedule_type_model(self) -> Type: + def schedule_type_model(self) -> Type[Union[Dashboard, Slice]]: raise NotImplementedError() page_size = 20 @@ -154,9 +154,7 @@ class EmailScheduleView( info[col] = info[col].username info["user"] = schedule.user.username - info[self.schedule_type] = getattr( # type: ignore - schedule, self.schedule_type - ).id + info[self.schedule_type] = getattr(schedule, self.schedule_type).id schedules.append(info) return json_success(json.dumps(schedules, default=json_iso_dttm_ser)) diff --git a/superset/views/sql_lab.py b/superset/views/sql_lab.py index 3476bb3..534c6fb 100644 --- a/superset/views/sql_lab.py +++ b/superset/views/sql_lab.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Callable +from typing import Any import simplejson as json from flask import g, redirect, request, Response @@ -40,7 +40,7 @@ from .base import ( class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods - def apply(self, query: BaseQuery, value: Callable) -> BaseQuery: + def apply(self, query: BaseQuery, value: Any) -> BaseQuery: """ Filter queries to only those owned by current user. If can_access_all_queries permission is set a user can list all queries diff --git a/superset/views/utils.py b/superset/views/utils.py index 4edd2e7..2a8b2bf 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -35,7 +35,7 @@ from superset.utils.core import QueryStatus, TimeRangeEndpoint from superset.viz import BaseViz if is_feature_enabled("SIP_38_VIZ_REARCHITECTURE"): - from superset import viz_sip38 as viz # type: ignore + from superset import viz_sip38 as viz else: from superset import viz # type: ignore @@ -318,9 +318,9 @@ def get_dashboard_extra_filters( def build_extra_filters( - layout: Dict, - filter_scopes: Dict, - default_filters: Dict[str, Dict[str, List]], + layout: Dict[str, Dict[str, Any]], + filter_scopes: Dict[str, Dict[str, Any]], + default_filters: Dict[str, Dict[str, List[Any]]], slice_id: int, ) -> List[Dict[str, Any]]: extra_filters = [] @@ -343,7 +343,9 @@ def build_extra_filters( return extra_filters -def is_slice_in_container(layout: Dict, container_id: str, slice_id: int) -> bool: +def is_slice_in_container( + layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int +) -> bool: if container_id == "ROOT_ID": return True diff --git a/superset/viz.py b/superset/viz.py index d53dcf2..d34405c 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -2720,7 +2720,7 @@ class PairedTTestViz(BaseViz): else: cols.append(col) df.columns = cols - data: Dict = {} + data: Dict[str, List[Dict[str, Any]]] = {} series = df.to_dict("series") for nameSet in df.columns: # If no groups are defined, nameSet will be the metric name @@ -2750,7 +2750,7 @@ class RoseViz(NVD3TimeSeriesViz): return None data = super().get_data(df) - result: Dict = {} + result: Dict[str, List[Dict[str, str]]] = {} for datum in data: # type: ignore key = datum["key"] for val in datum["values"]: diff --git a/tests/base_tests.py b/tests/base_tests.py index 88c5f7b..d6d6516 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -18,7 +18,7 @@ """Unit tests for Superset""" import imp import json -from typing import Dict, Union, List +from typing import Any, Dict, Union, List from unittest.mock import Mock, patch import pandas as pd @@ -397,7 +397,9 @@ class SupersetTestCase(TestCase): mock_method.assert_called_once_with("error", func_name) return rv - def post_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response: + def post_assert_metric( + self, uri: str, data: Dict[str, Any], func_name: str + ) -> Response: """ Simple client post with an extra assertion for statsd metrics @@ -417,7 +419,9 @@ class SupersetTestCase(TestCase): mock_method.assert_called_once_with("error", func_name) return rv - def put_assert_metric(self, uri: str, data: Dict, func_name: str) -> Response: + def put_assert_metric( + self, uri: str, data: Dict[str, Any], func_name: str + ) -> Response: """ Simple client put with an extra assertion for statsd metrics diff --git a/tests/superset_test_config_thumbnails.py b/tests/superset_test_config_thumbnails.py index bfcb3a3..3b97604 100644 --- a/tests/superset_test_config_thumbnails.py +++ b/tests/superset_test_config_thumbnails.py @@ -20,7 +20,7 @@ from copy import copy from cachelib.redis import RedisCache from flask import Flask -from superset.config import * # type: ignore +from superset.config import * AUTH_USER_REGISTRATION_ROLE = "alpha" SQLALCHEMY_DATABASE_URI = "sqlite:///" + os.path.join(DATA_DIR, "unittests.db")