This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new e789a35 [mypy] Enforcing typing for superset.models (#9883)
e789a35 is described below
commit e789a3555843d9791b9230a61454a3abb8cb07e0
Author: John Bodley <[email protected]>
AuthorDate: Fri May 22 20:31:21 2020 -0700
[mypy] Enforcing typing for superset.models (#9883)
Co-authored-by: John Bodley <[email protected]>
---
setup.cfg | 2 +-
superset/connectors/sqla/models.py | 8 +-
superset/legacy.py | 3 +-
superset/models/annotations.py | 6 +-
superset/models/core.py | 21 ++--
superset/models/dashboard.py | 26 +++--
superset/models/helpers.py | 140 +++++++++++++++-----------
superset/models/schedules.py | 6 +-
superset/models/slice.py | 12 ++-
superset/models/sql_lab.py | 22 ++--
superset/models/sql_types/presto_sql_types.py | 22 ++--
superset/models/tags.py | 59 ++++++++---
superset/utils/cache.py | 6 +-
superset/utils/core.py | 4 +-
14 files changed, 207 insertions(+), 130 deletions(-)
diff --git a/setup.cfg b/setup.cfg
index dc3e701..bfef8af 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true
-[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
+[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index ad45474..3e50280 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -18,7 +18,7 @@
import logging
import re
from collections import OrderedDict
-from datetime import datetime
+from datetime import datetime, timedelta
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple,
Union
import pandas as pd
@@ -103,7 +103,11 @@ class AnnotationDatasource(BaseDatasource):
logger.exception(ex)
error_message = utils.error_msg_from_exception(ex)
return QueryResult(
- status=status, df=df, duration=0, query="",
error_message=error_message
+ status=status,
+ df=df,
+ duration=timedelta(0),
+ query="",
+ error_message=error_message,
)
def get_query_str(self, query_obj):
diff --git a/superset/legacy.py b/superset/legacy.py
index b867edc..168b9c0 100644
--- a/superset/legacy.py
+++ b/superset/legacy.py
@@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Code related with dealing with legacy / change management"""
+from typing import Any, Dict
-def update_time_range(form_data):
+def update_time_range(form_data: Dict[str, Any]) -> None:
"""Move since and until to time_range."""
if "since" in form_data or "until" in form_data:
form_data["time_range"] = "{} : {}".format(
diff --git a/superset/models/annotations.py b/superset/models/annotations.py
index 07e2351..ec8d3c0 100644
--- a/superset/models/annotations.py
+++ b/superset/models/annotations.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""a collection of Annotation-related models"""
+from typing import Any, Dict
+
from flask_appbuilder import Model
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String,
Text
from sqlalchemy.orm import relationship
@@ -31,7 +33,7 @@ class AnnotationLayer(Model, AuditMixinNullable):
name = Column(String(250))
descr = Column(Text)
- def __repr__(self):
+ def __repr__(self) -> str:
return self.name
@@ -52,7 +54,7 @@ class Annotation(Model, AuditMixinNullable):
__table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),)
@property
- def data(self):
+ def data(self) -> Dict[str, Any]:
return {
"layer_id": self.layer_id,
"start_dttm": self.start_dttm,
diff --git a/superset/models/core.py b/superset/models/core.py
index abcb210..69d306a 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -152,7 +152,7 @@ class Database(
]
export_children = ["tables"]
- def __repr__(self):
+ def __repr__(self) -> str:
return self.name
@property
@@ -234,7 +234,9 @@ class Database(
return self.get_extra().get("default_schemas", [])
@classmethod
- def get_password_masked_url_from_uri(cls, uri: str): # pylint:
disable=invalid-name
+ def get_password_masked_url_from_uri( # pylint: disable=invalid-name
+ cls, uri: str
+ ) -> URL:
sqlalchemy_url = make_url(uri)
return cls.get_password_masked_url(sqlalchemy_url)
@@ -279,7 +281,7 @@ class Database(
effective_username = g.user.username
return effective_username
- @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted",
"extra"))
+ @utils.memoized(watch=["impersonate_user", "sqlalchemy_uri_decrypted",
"extra"])
def get_sqla_engine(
self,
schema: Optional[str] = None,
@@ -339,7 +341,7 @@ class Database(
def get_reserved_words(self) -> Set[str]:
return self.get_dialect().preparer.reserved_words
- def get_quoter(self):
+ def get_quoter(self) -> Callable:
return self.get_dialect().identifier_preparer.quote
def get_df( # pylint: disable=too-many-locals
@@ -405,7 +407,7 @@ class Database(
indent: bool = True,
latest_partition: bool = False,
cols: Optional[List[Dict[str, Any]]] = None,
- ):
+ ) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self.get_sqla_engine(schema=schema,
source=utils.QuerySource.SQL_LAB)
return self.db_engine_spec.select_star(
@@ -436,7 +438,10 @@ class Database(
attribute_in_key="id",
)
def get_all_table_names_in_database(
- self, cache: bool = False, cache_timeout: Optional[bool] = None,
force=False
+ self,
+ cache: bool = False,
+ cache_timeout: Optional[bool] = None,
+ force: bool = False,
) -> List[utils.DatasourceName]:
"""Parameters need to be passed as keyword arguments."""
if not self.allow_multi_schema_metadata_fetch:
@@ -547,7 +552,7 @@ class Database(
@classmethod
def get_db_engine_spec_for_backend(
- cls, backend
+ cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
return db_engine_specs.engines.get(backend,
db_engine_specs.BaseEngineSpec)
@@ -565,7 +570,7 @@ class Database(
def get_extra(self) -> Dict[str, Any]:
return self.db_engine_spec.get_extra_params(self)
- def get_encrypted_extra(self):
+ def get_encrypted_extra(self) -> Dict[str, Any]:
encrypted_extra = {}
if self.encrypted_extra:
try:
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index c86f8ff..de42285 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -36,7 +36,9 @@ from sqlalchemy import (
Text,
UniqueConstraint,
)
+from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
+from sqlalchemy.orm.mapper import Mapper
from superset import app, ConnectorRegistry, db, is_feature_enabled,
security_manager
from superset.models.helpers import AuditMixinNullable, ImportMixin
@@ -59,7 +61,7 @@ config = app.config
logger = logging.getLogger(__name__)
-def copy_dashboard(mapper, connection, target):
+def copy_dashboard(mapper: Mapper, connection: Connection, target:
"Dashboard") -> None:
# pylint: disable=unused-argument
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
@@ -140,7 +142,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
"slug",
]
- def __repr__(self):
+ def __repr__(self) -> str:
return self.dashboard_title or str(self.id)
@property
@@ -202,13 +204,13 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
return f"/api/v1/dashboard/{self.id}/thumbnail/{self.digest}/"
@property
- def changed_by_name(self):
+ def changed_by_name(self) -> str:
if not self.changed_by:
return ""
return str(self.changed_by)
@property
- def changed_by_url(self):
+ def changed_by_url(self) -> str:
if not self.changed_by:
return ""
return f"/superset/profile/{self.changed_by.username}"
@@ -229,8 +231,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
"position_json": positions,
}
- @property
- def params(self) -> str:
+ @property # type: ignore
+ def params(self) -> str: # type: ignore
return self.json_metadata
@params.setter
@@ -257,7 +259,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
Audit metadata isn't copied over.
"""
- def alter_positions(dashboard, old_to_new_slc_id_dict):
+ def alter_positions(
+ dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
+ ) -> None:
""" Updates slice_ids in the position json.
Sample position_json data:
@@ -291,9 +295,9 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
if (
isinstance(value, dict)
and value.get("meta")
- and value.get("meta").get("chartId")
+ and value.get("meta", {}).get("chartId")
):
- old_slice_id = value.get("meta").get("chartId")
+ old_slice_id = value["meta"]["chartId"]
if old_slice_id in old_to_new_slc_id_dict:
value["meta"]["chartId"] =
old_to_new_slc_id_dict[old_slice_id]
@@ -470,8 +474,8 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
def event_after_dashboard_changed( # pylint: disable=unused-argument
- mapper, connection, target
-):
+ mapper: Mapper, connection: Connection, target: Dashboard
+) -> None:
cache_dashboard_thumbnail.delay(target.id, force=True)
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 93b3b7f..42169e6 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -18,8 +18,8 @@
import json
import logging
import re
-from datetime import datetime
-from typing import Any, Dict, List, Optional
+from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional, Set, Union
# isort and pylint disagree, isort should win
# pylint: disable=ungrouped-imports
@@ -30,8 +30,10 @@ import yaml
from flask import escape, g, Markup
from flask_appbuilder.models.decorators import renders
from flask_appbuilder.models.mixins import AuditMixin
+from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, or_, UniqueConstraint
from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import MultipleResultsFound
from superset.utils.core import QueryStatus
@@ -39,7 +41,7 @@ from superset.utils.core import QueryStatus
logger = logging.getLogger(__name__)
-def json_to_dict(json_str):
+def json_to_dict(json_str: str) -> Dict[Any, Any]:
if json_str:
val = re.sub(",[ \t\r\n]+}", "}", json_str)
val = re.sub(
@@ -64,48 +66,56 @@ class ImportMixin:
# that are available for import and export
@classmethod
- def _parent_foreign_key_mappings(cls):
+ def _parent_foreign_key_mappings(cls) -> Dict[str, str]:
"""Get a mapping of foreign name to the local name of foreign keys"""
- parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
+ parent_rel = cls.__mapper__.relationships.get(cls.export_parent) #
type: ignore
if parent_rel:
return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
return {}
@classmethod
- def _unique_constrains(cls):
+ def _unique_constrains(cls) -> List[Set[str]]:
"""Get all (single column and multi column) unique constraints"""
unique = [
{c.name for c in u.columns}
- for u in cls.__table_args__
+ for u in cls.__table_args__ # type: ignore
if isinstance(u, UniqueConstraint)
]
- unique.extend({c.name} for c in cls.__table__.columns if c.unique)
+ unique.extend( # type: ignore
+ {c.name} for c in cls.__table__.columns if c.unique # type: ignore
+ )
return unique
@classmethod
- def export_schema(cls, recursive=True, include_parent_ref=False):
+ def export_schema(
+ cls, recursive: bool = True, include_parent_ref: bool = False
+ ) -> Dict[str, Any]:
"""Export schema as a dictionary"""
- parent_excludes = {}
+ parent_excludes = set()
if not include_parent_ref:
- parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+ parent_ref = cls.__mapper__.relationships.get( # type: ignore
+ cls.export_parent
+ )
if parent_ref:
parent_excludes = {column.name for column in
parent_ref.local_columns}
- def formatter(column):
+ def formatter(column: sa.Column) -> str:
return (
"{0} Default ({1})".format(str(column.type),
column.default.arg)
if column.default
else str(column.type)
)
- schema = {
+ schema: Dict[str, Any] = {
column.name: formatter(column)
- for column in cls.__table__.columns
+ for column in cls.__table__.columns # type: ignore
if (column.name in cls.export_fields and column.name not in
parent_excludes)
}
if recursive:
for column in cls.export_children:
- child_class =
cls.__mapper__.relationships[column].argument.class_
+ child_class = cls.__mapper__.relationships[ # type: ignore
+ column
+ ].argument.class_
schema[column] = [
child_class.export_schema(
recursive=recursive,
include_parent_ref=include_parent_ref
@@ -114,17 +124,20 @@ class ImportMixin:
return schema
@classmethod
- def import_from_dict(
- cls, session, dict_rep, parent=None, recursive=True, sync=None
- ): # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
+ def import_from_dict( # pylint:
disable=too-many-arguments,too-many-branches,too-many-locals
+ cls,
+ session: Session,
+ dict_rep: Dict[Any, Any],
+ parent: Optional[Any] = None,
+ recursive: bool = True,
+ sync: Optional[List[str]] = None,
+ ) -> Any: # pylint:
disable=too-many-arguments,too-many-locals,too-many-branches
"""Import obj from a dictionary"""
if sync is None:
sync = []
parent_refs = cls._parent_foreign_key_mappings()
export_fields = set(cls.export_fields) | set(parent_refs.keys())
- new_children = {
- c: dict_rep.get(c) for c in cls.export_children if c in dict_rep
- }
+ new_children = {c: dict_rep[c] for c in cls.export_children if c in
dict_rep}
unique_constrains = cls._unique_constrains()
filters = [] # Using these filters to check if obj already exists
@@ -178,7 +191,7 @@ class ImportMixin:
if not obj:
is_new_obj = True
# Create new DB object
- obj = cls(**dict_rep)
+ obj = cls(**dict_rep) # type: ignore
logger.info("Importing new %s %s", obj.__tablename__, str(obj))
if cls.export_parent and parent:
setattr(obj, cls.export_parent, parent)
@@ -193,7 +206,9 @@ class ImportMixin:
# Recursively create children
if recursive:
for child in cls.export_children:
- child_class =
cls.__mapper__.relationships[child].argument.class_
+ child_class = cls.__mapper__.relationships[ # type: ignore
+ child
+ ].argument.class_
added = []
for c_obj in new_children.get(child, []):
added.append(
@@ -221,18 +236,23 @@ class ImportMixin:
return obj
def export_to_dict(
- self, recursive=True, include_parent_ref=False, include_defaults=False
- ):
+ self,
+ recursive: bool = True,
+ include_parent_ref: bool = False,
+ include_defaults: bool = False,
+ ) -> Dict[Any, Any]:
"""Export obj to dictionary"""
cls = self.__class__
- parent_excludes = {}
+ parent_excludes = set()
if recursive and not include_parent_ref:
- parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+ parent_ref = cls.__mapper__.relationships.get( # type: ignore
+ cls.export_parent
+ )
if parent_ref:
parent_excludes = {c.name for c in parent_ref.local_columns}
dict_rep = {
c.name: getattr(self, c.name)
- for c in cls.__table__.columns
+ for c in cls.__table__.columns # type: ignore
if (
c.name in self.export_fields
and c.name not in parent_excludes
@@ -262,18 +282,18 @@ class ImportMixin:
return dict_rep
- def override(self, obj):
+ def override(self, obj: Any) -> None:
"""Overrides the plain fields of the dashboard."""
for field in obj.__class__.export_fields:
setattr(self, field, getattr(obj, field))
- def copy(self):
+ def copy(self) -> Any:
"""Creates a copy of the dashboard without relationships."""
new_obj = self.__class__()
new_obj.override(self)
return new_obj
- def alter_params(self, **kwargs):
+ def alter_params(self, **kwargs: Any) -> None:
params = self.params_dict
params.update(kwargs)
self.params = json.dumps(params)
@@ -283,7 +303,7 @@ class ImportMixin:
params.pop(param_to_remove, None)
self.params = json.dumps(params)
- def reset_ownership(self):
+ def reset_ownership(self) -> None:
""" object will belong to the user the current user """
# make sure the object doesn't have relations to a user
# it will be filled by appbuilder on save
@@ -297,15 +317,15 @@ class ImportMixin:
self.owners = []
@property
- def params_dict(self):
+ def params_dict(self) -> Dict[Any, Any]:
return json_to_dict(self.params)
@property
- def template_params_dict(self):
- return json_to_dict(self.template_params)
+ def template_params_dict(self) -> Dict[Any, Any]:
+ return json_to_dict(self.template_params) # type: ignore
-def _user_link(user): # pylint: disable=no-self-use
+def _user_link(user: User) -> Union[Markup, str]: # pylint:
disable=no-self-use
if not user:
return ""
url = "/superset/profile/{}/".format(user.username)
@@ -325,7 +345,7 @@ class AuditMixinNullable(AuditMixin):
)
@declared_attr
- def created_by_fk(self):
+ def created_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
@@ -334,7 +354,7 @@ class AuditMixinNullable(AuditMixin):
)
@declared_attr
- def changed_by_fk(self):
+ def changed_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
@@ -343,29 +363,29 @@ class AuditMixinNullable(AuditMixin):
nullable=True,
)
- def changed_by_name(self):
+ def changed_by_name(self) -> str:
if self.created_by:
return escape("{}".format(self.created_by))
return ""
@renders("created_by")
- def creator(self):
+ def creator(self) -> Union[Markup, str]:
return _user_link(self.created_by)
@property
- def changed_by_(self):
+ def changed_by_(self) -> Union[Markup, str]:
return _user_link(self.changed_by)
@renders("changed_on")
- def changed_on_(self):
+ def changed_on_(self) -> Markup:
return Markup(f'<span class="no-wrap">{self.changed_on}</span>')
@property
- def changed_on_humanized(self):
+ def changed_on_humanized(self) -> str:
return humanize.naturaltime(datetime.now() - self.changed_on)
@renders("changed_on")
- def modified(self):
+ def modified(self) -> Markup:
return Markup(f'<span
class="no-wrap">{self.changed_on_humanized}</span>')
@@ -375,19 +395,19 @@ class QueryResult: # pylint: disable=too-few-public-methods
def __init__( # pylint: disable=too-many-arguments
self,
- df,
- query,
- duration,
- status=QueryStatus.SUCCESS,
- error_message=None,
- errors=None,
- ):
- self.df: pd.DataFrame = df
- self.query: str = query
- self.duration: int = duration
- self.status: str = status
- self.error_message: Optional[str] = error_message
- self.errors: List[Dict[str, Any]] = errors or []
+ df: pd.DataFrame,
+ query: str,
+ duration: timedelta,
+ status: str = QueryStatus.SUCCESS,
+ error_message: Optional[str] = None,
+ errors: Optional[List[Dict[str, Any]]] = None,
+ ) -> None:
+ self.df = df
+ self.query = query
+ self.duration = duration
+ self.status = status
+ self.error_message = error_message
+ self.errors = errors or []
class ExtraJSONMixin:
@@ -396,16 +416,16 @@ class ExtraJSONMixin:
extra_json = sa.Column(sa.Text, default="{}")
@property
- def extra(self):
+ def extra(self) -> Dict[str, Any]:
try:
return json.loads(self.extra_json)
except Exception: # pylint: disable=broad-except
return {}
- def set_extra_json(self, extras):
+ def set_extra_json(self, extras: Dict[str, Any]) -> None:
self.extra_json = json.dumps(extras)
- def set_extra_json_key(self, key, value):
+ def set_extra_json_key(self, key: str, value: Any) -> None:
extra = self.extra
extra[key] = value
self.extra_json = json.dumps(extra)
diff --git a/superset/models/schedules.py b/superset/models/schedules.py
index f70b076..0eb31a5 100644
--- a/superset/models/schedules.py
+++ b/superset/models/schedules.py
@@ -21,7 +21,7 @@ from typing import Optional, Type
from flask_appbuilder import Model
from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, RelationshipProperty
from superset import security_manager
from superset.models.helpers import AuditMixinNullable, ImportMixin
@@ -55,11 +55,11 @@ class EmailSchedule:
crontab = Column(String(50))
@declared_attr
- def user_id(self):
+ def user_id(self) -> int:
return Column(Integer, ForeignKey("ab_user.id"))
@declared_attr
- def user(self):
+ def user(self) -> RelationshipProperty:
return relationship(
security_manager.user_model,
backref=self.__tablename__,
diff --git a/superset/models/slice.py b/superset/models/slice.py
index a570bcf..76eb457 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -24,7 +24,9 @@ from flask_appbuilder import Model
from flask_appbuilder.models.decorators import renders
from markupsafe import escape, Markup
from sqlalchemy import Column, ForeignKey, Integer, String, Table, Text
+from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import make_transient, relationship
+from sqlalchemy.orm.mapper import Mapper
from superset import ConnectorRegistry, db, is_feature_enabled,
security_manager
from superset.legacy import update_time_range
@@ -92,7 +94,7 @@ class Slice(
"cache_timeout",
]
- def __repr__(self):
+ def __repr__(self) -> str:
return self.slice_name or str(self.id)
@property
@@ -263,7 +265,7 @@ class Slice(
@property
def changed_by_url(self) -> str:
- return f"/superset/profile/{self.created_by.username}"
+ return f"/superset/profile/{self.created_by.username}" # type: ignore
@property
def icons(self) -> str:
@@ -324,7 +326,7 @@ class Slice(
return
f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"
-def set_related_perm(mapper, connection, target):
+def set_related_perm(mapper: Mapper, connection: Connection, target: Slice) ->
None:
# pylint: disable=unused-argument
src_class = target.cls_model
id_ = target.datasource_id
@@ -336,8 +338,8 @@ def set_related_perm(mapper, connection, target):
def event_after_chart_changed( # pylint: disable=unused-argument
- mapper, connection, target
-):
+ mapper: Mapper, connection: Connection, target: Slice
+) -> None:
cache_chart_thumbnail.delay(target.id, force=True)
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 654bc2f..9c3c239 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -17,6 +17,7 @@
"""A collection of ORM sqlalchemy models for SQL Lab"""
import re
from datetime import datetime
+from typing import Any, Dict
# pylint: disable=ungrouped-imports
import simplejson as json
@@ -33,6 +34,7 @@ from sqlalchemy import (
String,
Text,
)
+from sqlalchemy.engine.url import URL
from sqlalchemy.orm import backref, relationship
from superset import security_manager
@@ -99,7 +101,7 @@ class Query(Model, ExtraJSONMixin):
__table_args__ = (sqla.Index("ti_user_id_changed_on", user_id,
changed_on),)
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
return {
"changedOn": self.changed_on,
"changed_on": self.changed_on.isoformat(),
@@ -130,7 +132,7 @@ class Query(Model, ExtraJSONMixin):
}
@property
- def name(self):
+ def name(self) -> str:
"""Name property"""
ts = datetime.now().isoformat()
ts = ts.replace("-", "").replace(":", "").split(".")[0]
@@ -139,11 +141,11 @@ class Query(Model, ExtraJSONMixin):
return f"sqllab_{tab}_{ts}"
@property
- def database_name(self):
+ def database_name(self) -> str:
return self.database.name
@property
- def username(self):
+ def username(self) -> str:
return self.user.username
@@ -170,7 +172,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
)
@property
- def pop_tab_link(self):
+ def pop_tab_link(self) -> Markup:
return Markup(
f"""
<a href="/superset/sqllab?savedQueryId={self.id}">
@@ -180,14 +182,14 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
)
@property
- def user_email(self):
+ def user_email(self) -> str:
return self.user.email
@property
- def sqlalchemy_uri(self):
+ def sqlalchemy_uri(self) -> URL:
return self.database.sqlalchemy_uri
- def url(self):
+ def url(self) -> str:
return "/superset/sqllab?savedQueryId={0}".format(self.id)
@@ -226,7 +228,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin):
autorun = Column(Boolean, default=False)
template_params = Column(Text)
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"user_id": self.user_id,
@@ -260,7 +262,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin):
expanded = Column(Boolean, default=False)
- def to_dict(self):
+ def to_dict(self) -> Dict[str, Any]:
try:
description = json.loads(self.description)
except json.JSONDecodeError:
diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py
index f0b46fa..a50b4c2 100644
--- a/superset/models/sql_types/presto_sql_types.py
+++ b/superset/models/sql_types/presto_sql_types.py
@@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Any, Optional, Type
from sqlalchemy import types
from sqlalchemy.sql.sqltypes import Integer
from sqlalchemy.sql.type_api import TypeEngine
+from sqlalchemy.sql.visitors import Visitable
# _compiler_dispatch is defined to help with type compilation
@@ -27,11 +29,11 @@ class TinyInteger(Integer):
A type for tiny ``int`` integers.
"""
- def python_type(self):
+ def python_type(self) -> Type:
return int
@classmethod
- def _compiler_dispatch(cls, _visitor, **_kw):
+ def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "TINYINT"
@@ -40,11 +42,11 @@ class Interval(TypeEngine):
A type for intervals.
"""
- def python_type(self):
+ def python_type(self) -> Optional[Type]:
return None
@classmethod
- def _compiler_dispatch(cls, _visitor, **_kw):
+ def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "INTERVAL"
@@ -53,11 +55,11 @@ class Array(TypeEngine):
A type for arrays.
"""
- def python_type(self):
+ def python_type(self) -> Optional[Type]:
return list
@classmethod
- def _compiler_dispatch(cls, _visitor, **_kw):
+ def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ARRAY"
@@ -66,11 +68,11 @@ class Map(TypeEngine):
A type for maps.
"""
- def python_type(self):
+ def python_type(self) -> Optional[Type]:
return dict
@classmethod
- def _compiler_dispatch(cls, _visitor, **_kw):
+ def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "MAP"
@@ -79,11 +81,11 @@ class Row(TypeEngine):
A type for rows.
"""
- def python_type(self):
+ def python_type(self) -> Optional[Type]:
return None
@classmethod
- def _compiler_dispatch(cls, _visitor, **_kw):
+ def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
return "ROW"
diff --git a/superset/models/tags.py b/superset/models/tags.py
index 0cb00cc..c09bb16 100644
--- a/superset/models/tags.py
+++ b/superset/models/tags.py
@@ -17,15 +17,23 @@
from __future__ import absolute_import, division, print_function,
unicode_literals
import enum
-from typing import Optional
+from typing import List, Optional, TYPE_CHECKING, Union
from flask_appbuilder import Model
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
-from sqlalchemy.orm import relationship, sessionmaker
+from sqlalchemy.engine.base import Connection
+from sqlalchemy.orm import relationship, Session, sessionmaker
from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm.mapper import Mapper
from superset.models.helpers import AuditMixinNullable
+if TYPE_CHECKING:
+ from superset.models.core import FavStar # pylint: disable=unused-import
+ from superset.models.dashboard import Dashboard # pylint:
disable=unused-import
+ from superset.models.slice import Slice # pylint: disable=unused-import
+ from superset.models.sql_lab import Query # pylint: disable=unused-import
+
Session = sessionmaker(autoflush=False)
@@ -80,7 +88,7 @@ class TaggedObject(Model, AuditMixinNullable):
tag = relationship("Tag", backref="objects")
-def get_tag(name, session, type_):
+def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
try:
tag = session.query(Tag).filter_by(name=name, type=type_).one()
except NoResultFound:
@@ -91,7 +99,7 @@ def get_tag(name, session, type_):
return tag
-def get_object_type(class_name):
+def get_object_type(class_name: str) -> ObjectTypes:
mapping = {
"slice": ObjectTypes.chart,
"dashboard": ObjectTypes.dashboard,
@@ -108,11 +116,15 @@ class ObjectUpdater:
object_type: Optional[str] = None
@classmethod
- def get_owners_ids(cls, target):
+ def get_owners_ids(
+ cls, target: Union["Dashboard", "FavStar", "Slice"]
+ ) -> List[int]:
raise NotImplementedError("Subclass should implement `get_owners_ids`")
@classmethod
- def _add_owners(cls, session, target):
+ def _add_owners(
+ cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
+ ) -> None:
for owner_id in cls.get_owners_ids(target):
name = "owner:{0}".format(owner_id)
tag = get_tag(name, session, TagTypes.owner)
@@ -122,7 +134,12 @@ class ObjectUpdater:
session.add(tagged_object)
@classmethod
- def after_insert(cls, mapper, connection, target):
+ def after_insert(
+ cls,
+ mapper: Mapper,
+ connection: Connection,
+ target: Union["Dashboard", "FavStar", "Slice"],
+ ) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
@@ -139,7 +156,12 @@ class ObjectUpdater:
session.commit()
@classmethod
- def after_update(cls, mapper, connection, target):
+ def after_update(
+ cls,
+ mapper: Mapper,
+ connection: Connection,
+ target: Union["Dashboard", "FavStar", "Slice"],
+ ) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
@@ -164,7 +186,12 @@ class ObjectUpdater:
session.commit()
@classmethod
- def after_delete(cls, mapper, connection, target):
+ def after_delete(
+ cls,
+ mapper: Mapper,
+ connection: Connection,
+ target: Union["Dashboard", "FavStar", "Slice"],
+ ) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
@@ -182,7 +209,7 @@ class ChartUpdater(ObjectUpdater):
object_type = "chart"
@classmethod
- def get_owners_ids(cls, target):
+ def get_owners_ids(cls, target: "Slice") -> List[int]:
return [owner.id for owner in target.owners]
@@ -191,7 +218,7 @@ class DashboardUpdater(ObjectUpdater):
object_type = "dashboard"
@classmethod
- def get_owners_ids(cls, target):
+ def get_owners_ids(cls, target: "Dashboard") -> List[int]:
return [owner.id for owner in target.owners]
@@ -200,13 +227,15 @@ class QueryUpdater(ObjectUpdater):
object_type = "query"
@classmethod
- def get_owners_ids(cls, target):
+ def get_owners_ids(cls, target: "Query") -> List[int]:
return [target.user_id]
class FavStarUpdater:
@classmethod
- def after_insert(cls, mapper, connection, target):
+ def after_insert(
+ cls, mapper: Mapper, connection: Connection, target: "FavStar"
+ ) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
@@ -221,7 +250,9 @@ class FavStarUpdater:
session.commit()
@classmethod
- def after_delete(cls, mapper, connection, target):
+ def after_delete(
+ cls, mapper: Mapper, connection: Connection, target: "FavStar"
+ ) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index a7a5dc6..b555005 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from typing import Callable, Optional
+
from flask import request
from superset.extensions import cache_manager
@@ -24,7 +26,9 @@ def view_cache_key(*_, **__) -> str:
return "view/{}/{}".format(request.path, args_hash)
-def memoized_func(key=view_cache_key, attribute_in_key=None):
+def memoized_func(
+ key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
+) -> Callable:
"""Use this decorator to cache functions that have predefined first arg.
enable_cache is treated as True by default,
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 19e39ab..e093d3d 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -143,7 +143,7 @@ class _memoized:
self.func = func
self.cache = {}
self.is_method = False
- self.watch = watch
+ self.watch = watch or []
def __call__(self, *args, **kwargs):
key = [args, frozenset(kwargs.items())]
@@ -172,7 +172,7 @@ class _memoized:
return functools.partial(self.__call__, obj)
-def memoized(func=None, watch=None):
+def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] =
None):
if func:
return _memoized(func)
else: