This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e789a35  [mypy] Enforcing typing for superset.models (#9883)
e789a35 is described below

commit e789a3555843d9791b9230a61454a3abb8cb07e0
Author: John Bodley <[email protected]>
AuthorDate: Fri May 22 20:31:21 2020 -0700

    [mypy] Enforcing typing for superset.models (#9883)
    
    Co-authored-by: John Bodley <[email protected]>
---
 setup.cfg                                     |   2 +-
 superset/connectors/sqla/models.py            |   8 +-
 superset/legacy.py                            |   3 +-
 superset/models/annotations.py                |   6 +-
 superset/models/core.py                       |  21 ++--
 superset/models/dashboard.py                  |  26 +++--
 superset/models/helpers.py                    | 140 +++++++++++++++-----------
 superset/models/schedules.py                  |   6 +-
 superset/models/slice.py                      |  12 ++-
 superset/models/sql_lab.py                    |  22 ++--
 superset/models/sql_types/presto_sql_types.py |  22 ++--
 superset/models/tags.py                       |  59 ++++++++---
 superset/utils/cache.py                       |   6 +-
 superset/utils/core.py                        |   4 +-
 14 files changed, 207 insertions(+), 130 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index dc3e701..bfef8af 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -53,7 +53,7 @@ order_by_type = false
 ignore_missing_imports = true
 no_implicit_optional = true
 
-[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
+[mypy-superset.bin.*,superset.charts.*,superset.datasets.*,superset.dashboards.*,superset.commands.*,superset.common.*,superset.dao.*,superset.db_engine_specs.*,superset.db_engines.*,superset.examples.*,superset.migrations.*,superset.models.*,superset.queries.*,superset.security.*,superset.sql_validators.*,superset.tasks.*,superset.translations.*]
 check_untyped_defs = true
 disallow_untyped_calls = true
 disallow_untyped_defs = true
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index ad45474..3e50280 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -18,7 +18,7 @@
 import logging
 import re
 from collections import OrderedDict
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
 
 import pandas as pd
@@ -103,7 +103,11 @@ class AnnotationDatasource(BaseDatasource):
             logger.exception(ex)
             error_message = utils.error_msg_from_exception(ex)
         return QueryResult(
-            status=status, df=df, duration=0, query="", error_message=error_message
+            status=status,
+            df=df,
+            duration=timedelta(0),
+            query="",
+            error_message=error_message,
         )
 
     def get_query_str(self, query_obj):
diff --git a/superset/legacy.py b/superset/legacy.py
index b867edc..168b9c0 100644
--- a/superset/legacy.py
+++ b/superset/legacy.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """Code related with dealing with legacy / change management"""
+from typing import Any, Dict
 
 
-def update_time_range(form_data):
+def update_time_range(form_data: Dict[str, Any]) -> None:
     """Move since and until to time_range."""
     if "since" in form_data or "until" in form_data:
         form_data["time_range"] = "{} : {}".format(
diff --git a/superset/models/annotations.py b/superset/models/annotations.py
index 07e2351..ec8d3c0 100644
--- a/superset/models/annotations.py
+++ b/superset/models/annotations.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """a collection of Annotation-related models"""
+from typing import Any, Dict
+
 from flask_appbuilder import Model
 from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text
 from sqlalchemy.orm import relationship
@@ -31,7 +33,7 @@ class AnnotationLayer(Model, AuditMixinNullable):
     name = Column(String(250))
     descr = Column(Text)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.name
 
 
@@ -52,7 +54,7 @@ class Annotation(Model, AuditMixinNullable):
     __table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),)
 
     @property
-    def data(self):
+    def data(self) -> Dict[str, Any]:
         return {
             "layer_id": self.layer_id,
             "start_dttm": self.start_dttm,
diff --git a/superset/models/core.py b/superset/models/core.py
index abcb210..69d306a 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -152,7 +152,7 @@ class Database(
     ]
     export_children = ["tables"]
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.name
 
     @property
@@ -234,7 +234,9 @@ class Database(
         return self.get_extra().get("default_schemas", [])
 
     @classmethod
-    def get_password_masked_url_from_uri(cls, uri: str):  # pylint: disable=invalid-name
+    def get_password_masked_url_from_uri(  # pylint: disable=invalid-name
+        cls, uri: str
+    ) -> URL:
         sqlalchemy_url = make_url(uri)
         return cls.get_password_masked_url(sqlalchemy_url)
 
@@ -279,7 +281,7 @@ class Database(
                 effective_username = g.user.username
         return effective_username
 
-    @utils.memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
+    @utils.memoized(watch=["impersonate_user", "sqlalchemy_uri_decrypted", "extra"])
     def get_sqla_engine(
         self,
         schema: Optional[str] = None,
@@ -339,7 +341,7 @@ class Database(
     def get_reserved_words(self) -> Set[str]:
         return self.get_dialect().preparer.reserved_words
 
-    def get_quoter(self):
+    def get_quoter(self) -> Callable:
         return self.get_dialect().identifier_preparer.quote
 
     def get_df(  # pylint: disable=too-many-locals
@@ -405,7 +407,7 @@ class Database(
         indent: bool = True,
         latest_partition: bool = False,
         cols: Optional[List[Dict[str, Any]]] = None,
-    ):
+    ) -> str:
         """Generates a ``select *`` statement in the proper dialect"""
         eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
         return self.db_engine_spec.select_star(
@@ -436,7 +438,10 @@ class Database(
         attribute_in_key="id",
     )
     def get_all_table_names_in_database(
-        self, cache: bool = False, cache_timeout: Optional[bool] = None, force=False
+        self,
+        cache: bool = False,
+        cache_timeout: Optional[bool] = None,
+        force: bool = False,
     ) -> List[utils.DatasourceName]:
         """Parameters need to be passed as keyword arguments."""
         if not self.allow_multi_schema_metadata_fetch:
@@ -547,7 +552,7 @@ class Database(
 
     @classmethod
     def get_db_engine_spec_for_backend(
-        cls, backend
+        cls, backend: str
     ) -> Type[db_engine_specs.BaseEngineSpec]:
         return db_engine_specs.engines.get(backend, db_engine_specs.BaseEngineSpec)
 
@@ -565,7 +570,7 @@ class Database(
     def get_extra(self) -> Dict[str, Any]:
         return self.db_engine_spec.get_extra_params(self)
 
-    def get_encrypted_extra(self):
+    def get_encrypted_extra(self) -> Dict[str, Any]:
         encrypted_extra = {}
         if self.encrypted_extra:
             try:
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index c86f8ff..de42285 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -36,7 +36,9 @@ from sqlalchemy import (
     Text,
     UniqueConstraint,
 )
+from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import relationship, sessionmaker, subqueryload
+from sqlalchemy.orm.mapper import Mapper
 
 from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
 from superset.models.helpers import AuditMixinNullable, ImportMixin
@@ -59,7 +61,7 @@ config = app.config
 logger = logging.getLogger(__name__)
 
 
-def copy_dashboard(mapper, connection, target):
+def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
     # pylint: disable=unused-argument
     dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
     if dashboard_id is None:
@@ -140,7 +142,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         "slug",
     ]
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.dashboard_title or str(self.id)
 
     @property
@@ -202,13 +204,13 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         return f"/api/v1/dashboard/{self.id}/thumbnail/{self.digest}/"
 
     @property
-    def changed_by_name(self):
+    def changed_by_name(self) -> str:
         if not self.changed_by:
             return ""
         return str(self.changed_by)
 
     @property
-    def changed_by_url(self):
+    def changed_by_url(self) -> str:
         if not self.changed_by:
             return ""
         return f"/superset/profile/{self.changed_by.username}"
@@ -229,8 +231,8 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
             "position_json": positions,
         }
 
-    @property
-    def params(self) -> str:
+    @property  # type: ignore
+    def params(self) -> str:  # type: ignore
         return self.json_metadata
 
     @params.setter
@@ -257,7 +259,9 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
          Audit metadata isn't copied over.
         """
 
-        def alter_positions(dashboard, old_to_new_slc_id_dict):
+        def alter_positions(
+            dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
+        ) -> None:
             """ Updates slice_ids in the position json.
 
             Sample position_json data:
@@ -291,9 +295,9 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
                 if (
                     isinstance(value, dict)
                     and value.get("meta")
-                    and value.get("meta").get("chartId")
+                    and value.get("meta", {}).get("chartId")
                 ):
-                    old_slice_id = value.get("meta").get("chartId")
+                    old_slice_id = value["meta"]["chartId"]
 
                     if old_slice_id in old_to_new_slc_id_dict:
                         value["meta"]["chartId"] = old_to_new_slc_id_dict[old_slice_id]
@@ -470,8 +474,8 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
 
 
 def event_after_dashboard_changed(  # pylint: disable=unused-argument
-    mapper, connection, target
-):
+    mapper: Mapper, connection: Connection, target: Dashboard
+) -> None:
     cache_dashboard_thumbnail.delay(target.id, force=True)
 
 
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 93b3b7f..42169e6 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -18,8 +18,8 @@
 import json
 import logging
 import re
-from datetime import datetime
-from typing import Any, Dict, List, Optional
+from datetime import datetime, timedelta
+from typing import Any, Dict, List, Optional, Set, Union
 
 # isort and pylint disagree, isort should win
 # pylint: disable=ungrouped-imports
@@ -30,8 +30,10 @@ import yaml
 from flask import escape, g, Markup
 from flask_appbuilder.models.decorators import renders
 from flask_appbuilder.models.mixins import AuditMixin
+from flask_appbuilder.security.sqla.models import User
 from sqlalchemy import and_, or_, UniqueConstraint
 from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import MultipleResultsFound
 
 from superset.utils.core import QueryStatus
@@ -39,7 +41,7 @@ from superset.utils.core import QueryStatus
 logger = logging.getLogger(__name__)
 
 
-def json_to_dict(json_str):
+def json_to_dict(json_str: str) -> Dict[Any, Any]:
     if json_str:
         val = re.sub(",[ \t\r\n]+}", "}", json_str)
         val = re.sub(
@@ -64,48 +66,56 @@ class ImportMixin:
     # that are available for import and export
 
     @classmethod
-    def _parent_foreign_key_mappings(cls):
+    def _parent_foreign_key_mappings(cls) -> Dict[str, str]:
         """Get a mapping of foreign name to the local name of foreign keys"""
-        parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
+        parent_rel = cls.__mapper__.relationships.get(cls.export_parent)  # type: ignore
         if parent_rel:
             return {l.name: r.name for (l, r) in parent_rel.local_remote_pairs}
         return {}
 
     @classmethod
-    def _unique_constrains(cls):
+    def _unique_constrains(cls) -> List[Set[str]]:
         """Get all (single column and multi column) unique constraints"""
         unique = [
             {c.name for c in u.columns}
-            for u in cls.__table_args__
+            for u in cls.__table_args__  # type: ignore
             if isinstance(u, UniqueConstraint)
         ]
-        unique.extend({c.name} for c in cls.__table__.columns if c.unique)
+        unique.extend(  # type: ignore
+            {c.name} for c in cls.__table__.columns if c.unique  # type: ignore
+        )
         return unique
 
     @classmethod
-    def export_schema(cls, recursive=True, include_parent_ref=False):
+    def export_schema(
+        cls, recursive: bool = True, include_parent_ref: bool = False
+    ) -> Dict[str, Any]:
         """Export schema as a dictionary"""
-        parent_excludes = {}
+        parent_excludes = set()
         if not include_parent_ref:
-            parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+            parent_ref = cls.__mapper__.relationships.get(  # type: ignore
+                cls.export_parent
+            )
             if parent_ref:
                 parent_excludes = {column.name for column in parent_ref.local_columns}
 
-        def formatter(column):
+        def formatter(column: sa.Column) -> str:
             return (
                 "{0} Default ({1})".format(str(column.type), 
column.default.arg)
                 if column.default
                 else str(column.type)
             )
 
-        schema = {
+        schema: Dict[str, Any] = {
             column.name: formatter(column)
-            for column in cls.__table__.columns
+            for column in cls.__table__.columns  # type: ignore
             if (column.name in cls.export_fields and column.name not in parent_excludes)
         }
         if recursive:
             for column in cls.export_children:
-                child_class = cls.__mapper__.relationships[column].argument.class_
+                child_class = cls.__mapper__.relationships[  # type: ignore
+                    column
+                ].argument.class_
                 schema[column] = [
                     child_class.export_schema(
                         recursive=recursive, include_parent_ref=include_parent_ref
@@ -114,17 +124,20 @@ class ImportMixin:
         return schema
 
     @classmethod
-    def import_from_dict(
-        cls, session, dict_rep, parent=None, recursive=True, sync=None
-    ):  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
+    def import_from_dict(  # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
+        cls,
+        session: Session,
+        dict_rep: Dict[Any, Any],
+        parent: Optional[Any] = None,
+        recursive: bool = True,
+        sync: Optional[List[str]] = None,
+    ) -> Any:  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches
         """Import obj from a dictionary"""
         if sync is None:
             sync = []
         parent_refs = cls._parent_foreign_key_mappings()
         export_fields = set(cls.export_fields) | set(parent_refs.keys())
-        new_children = {
-            c: dict_rep.get(c) for c in cls.export_children if c in dict_rep
-        }
+        new_children = {c: dict_rep[c] for c in cls.export_children if c in dict_rep}
         unique_constrains = cls._unique_constrains()
 
         filters = []  # Using these filters to check if obj already exists
@@ -178,7 +191,7 @@ class ImportMixin:
         if not obj:
             is_new_obj = True
             # Create new DB object
-            obj = cls(**dict_rep)
+            obj = cls(**dict_rep)  # type: ignore
             logger.info("Importing new %s %s", obj.__tablename__, str(obj))
             if cls.export_parent and parent:
                 setattr(obj, cls.export_parent, parent)
@@ -193,7 +206,9 @@ class ImportMixin:
         # Recursively create children
         if recursive:
             for child in cls.export_children:
-                child_class = cls.__mapper__.relationships[child].argument.class_
+                child_class = cls.__mapper__.relationships[  # type: ignore
+                    child
+                ].argument.class_
                 added = []
                 for c_obj in new_children.get(child, []):
                     added.append(
@@ -221,18 +236,23 @@ class ImportMixin:
         return obj
 
     def export_to_dict(
-        self, recursive=True, include_parent_ref=False, include_defaults=False
-    ):
+        self,
+        recursive: bool = True,
+        include_parent_ref: bool = False,
+        include_defaults: bool = False,
+    ) -> Dict[Any, Any]:
         """Export obj to dictionary"""
         cls = self.__class__
-        parent_excludes = {}
+        parent_excludes = set()
         if recursive and not include_parent_ref:
-            parent_ref = cls.__mapper__.relationships.get(cls.export_parent)
+            parent_ref = cls.__mapper__.relationships.get(  # type: ignore
+                cls.export_parent
+            )
             if parent_ref:
                 parent_excludes = {c.name for c in parent_ref.local_columns}
         dict_rep = {
             c.name: getattr(self, c.name)
-            for c in cls.__table__.columns
+            for c in cls.__table__.columns  # type: ignore
             if (
                 c.name in self.export_fields
                 and c.name not in parent_excludes
@@ -262,18 +282,18 @@ class ImportMixin:
 
         return dict_rep
 
-    def override(self, obj):
+    def override(self, obj: Any) -> None:
         """Overrides the plain fields of the dashboard."""
         for field in obj.__class__.export_fields:
             setattr(self, field, getattr(obj, field))
 
-    def copy(self):
+    def copy(self) -> Any:
         """Creates a copy of the dashboard without relationships."""
         new_obj = self.__class__()
         new_obj.override(self)
         return new_obj
 
-    def alter_params(self, **kwargs):
+    def alter_params(self, **kwargs: Any) -> None:
         params = self.params_dict
         params.update(kwargs)
         self.params = json.dumps(params)
@@ -283,7 +303,7 @@ class ImportMixin:
         params.pop(param_to_remove, None)
         self.params = json.dumps(params)
 
-    def reset_ownership(self):
+    def reset_ownership(self) -> None:
         """ object will belong to the user the current user """
         # make sure the object doesn't have relations to a user
         # it will be filled by appbuilder on save
@@ -297,15 +317,15 @@ class ImportMixin:
             self.owners = []
 
     @property
-    def params_dict(self):
+    def params_dict(self) -> Dict[Any, Any]:
         return json_to_dict(self.params)
 
     @property
-    def template_params_dict(self):
-        return json_to_dict(self.template_params)
+    def template_params_dict(self) -> Dict[Any, Any]:
+        return json_to_dict(self.template_params)  # type: ignore
 
 
-def _user_link(user):  # pylint: disable=no-self-use
+def _user_link(user: User) -> Union[Markup, str]:  # pylint: disable=no-self-use
     if not user:
         return ""
     url = "/superset/profile/{}/".format(user.username)
@@ -325,7 +345,7 @@ class AuditMixinNullable(AuditMixin):
     )
 
     @declared_attr
-    def created_by_fk(self):
+    def created_by_fk(self) -> sa.Column:
         return sa.Column(
             sa.Integer,
             sa.ForeignKey("ab_user.id"),
@@ -334,7 +354,7 @@ class AuditMixinNullable(AuditMixin):
         )
 
     @declared_attr
-    def changed_by_fk(self):
+    def changed_by_fk(self) -> sa.Column:
         return sa.Column(
             sa.Integer,
             sa.ForeignKey("ab_user.id"),
@@ -343,29 +363,29 @@ class AuditMixinNullable(AuditMixin):
             nullable=True,
         )
 
-    def changed_by_name(self):
+    def changed_by_name(self) -> str:
         if self.created_by:
             return escape("{}".format(self.created_by))
         return ""
 
     @renders("created_by")
-    def creator(self):
+    def creator(self) -> Union[Markup, str]:
         return _user_link(self.created_by)
 
     @property
-    def changed_by_(self):
+    def changed_by_(self) -> Union[Markup, str]:
         return _user_link(self.changed_by)
 
     @renders("changed_on")
-    def changed_on_(self):
+    def changed_on_(self) -> Markup:
         return Markup(f'<span class="no-wrap">{self.changed_on}</span>')
 
     @property
-    def changed_on_humanized(self):
+    def changed_on_humanized(self) -> str:
         return humanize.naturaltime(datetime.now() - self.changed_on)
 
     @renders("changed_on")
-    def modified(self):
+    def modified(self) -> Markup:
         return Markup(f'<span class="no-wrap">{self.changed_on_humanized}</span>')
 
 
@@ -375,19 +395,19 @@ class QueryResult:  # pylint: disable=too-few-public-methods
 
     def __init__(  # pylint: disable=too-many-arguments
         self,
-        df,
-        query,
-        duration,
-        status=QueryStatus.SUCCESS,
-        error_message=None,
-        errors=None,
-    ):
-        self.df: pd.DataFrame = df
-        self.query: str = query
-        self.duration: int = duration
-        self.status: str = status
-        self.error_message: Optional[str] = error_message
-        self.errors: List[Dict[str, Any]] = errors or []
+        df: pd.DataFrame,
+        query: str,
+        duration: timedelta,
+        status: str = QueryStatus.SUCCESS,
+        error_message: Optional[str] = None,
+        errors: Optional[List[Dict[str, Any]]] = None,
+    ) -> None:
+        self.df = df
+        self.query = query
+        self.duration = duration
+        self.status = status
+        self.error_message = error_message
+        self.errors = errors or []
 
 
 class ExtraJSONMixin:
@@ -396,16 +416,16 @@ class ExtraJSONMixin:
     extra_json = sa.Column(sa.Text, default="{}")
 
     @property
-    def extra(self):
+    def extra(self) -> Dict[str, Any]:
         try:
             return json.loads(self.extra_json)
         except Exception:  # pylint: disable=broad-except
             return {}
 
-    def set_extra_json(self, extras):
+    def set_extra_json(self, extras: Dict[str, Any]) -> None:
         self.extra_json = json.dumps(extras)
 
-    def set_extra_json_key(self, key, value):
+    def set_extra_json_key(self, key: str, value: Any) -> None:
         extra = self.extra
         extra[key] = value
         self.extra_json = json.dumps(extra)
diff --git a/superset/models/schedules.py b/superset/models/schedules.py
index f70b076..0eb31a5 100644
--- a/superset/models/schedules.py
+++ b/superset/models/schedules.py
@@ -21,7 +21,7 @@ from typing import Optional, Type
 from flask_appbuilder import Model
 from sqlalchemy import Boolean, Column, Enum, ForeignKey, Integer, String, Text
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import relationship
+from sqlalchemy.orm import relationship, RelationshipProperty
 
 from superset import security_manager
 from superset.models.helpers import AuditMixinNullable, ImportMixin
@@ -55,11 +55,11 @@ class EmailSchedule:
     crontab = Column(String(50))
 
     @declared_attr
-    def user_id(self):
+    def user_id(self) -> int:
         return Column(Integer, ForeignKey("ab_user.id"))
 
     @declared_attr
-    def user(self):
+    def user(self) -> RelationshipProperty:
         return relationship(
             security_manager.user_model,
             backref=self.__tablename__,
diff --git a/superset/models/slice.py b/superset/models/slice.py
index a570bcf..76eb457 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -24,7 +24,9 @@ from flask_appbuilder import Model
 from flask_appbuilder.models.decorators import renders
 from markupsafe import escape, Markup
 from sqlalchemy import Column, ForeignKey, Integer, String, Table, Text
+from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import make_transient, relationship
+from sqlalchemy.orm.mapper import Mapper
 
 from superset import ConnectorRegistry, db, is_feature_enabled, security_manager
 from superset.legacy import update_time_range
@@ -92,7 +94,7 @@ class Slice(
         "cache_timeout",
     ]
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return self.slice_name or str(self.id)
 
     @property
@@ -263,7 +265,7 @@ class Slice(
 
     @property
     def changed_by_url(self) -> str:
-        return f"/superset/profile/{self.created_by.username}"
+        return f"/superset/profile/{self.created_by.username}"  # type: ignore
 
     @property
     def icons(self) -> str:
@@ -324,7 +326,7 @@ class Slice(
         return f"/superset/explore/?form_data=%7B%22slice_id%22%3A%20{self.id}%7D"
 
 
-def set_related_perm(mapper, connection, target):
+def set_related_perm(mapper: Mapper, connection: Connection, target: Slice) -> None:
     # pylint: disable=unused-argument
     src_class = target.cls_model
     id_ = target.datasource_id
@@ -336,8 +338,8 @@ def set_related_perm(mapper, connection, target):
 
 
 def event_after_chart_changed(  # pylint: disable=unused-argument
-    mapper, connection, target
-):
+    mapper: Mapper, connection: Connection, target: Slice
+) -> None:
     cache_chart_thumbnail.delay(target.id, force=True)
 
 
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 654bc2f..9c3c239 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -17,6 +17,7 @@
 """A collection of ORM sqlalchemy models for SQL Lab"""
 import re
 from datetime import datetime
+from typing import Any, Dict
 
 # pylint: disable=ungrouped-imports
 import simplejson as json
@@ -33,6 +34,7 @@ from sqlalchemy import (
     String,
     Text,
 )
+from sqlalchemy.engine.url import URL
 from sqlalchemy.orm import backref, relationship
 
 from superset import security_manager
@@ -99,7 +101,7 @@ class Query(Model, ExtraJSONMixin):
 
     __table_args__ = (sqla.Index("ti_user_id_changed_on", user_id, changed_on),)
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         return {
             "changedOn": self.changed_on,
             "changed_on": self.changed_on.isoformat(),
@@ -130,7 +132,7 @@ class Query(Model, ExtraJSONMixin):
         }
 
     @property
-    def name(self):
+    def name(self) -> str:
         """Name property"""
         ts = datetime.now().isoformat()
         ts = ts.replace("-", "").replace(":", "").split(".")[0]
@@ -139,11 +141,11 @@ class Query(Model, ExtraJSONMixin):
         return f"sqllab_{tab}_{ts}"
 
     @property
-    def database_name(self):
+    def database_name(self) -> str:
         return self.database.name
 
     @property
-    def username(self):
+    def username(self) -> str:
         return self.user.username
 
 
@@ -170,7 +172,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
     )
 
     @property
-    def pop_tab_link(self):
+    def pop_tab_link(self) -> Markup:
         return Markup(
             f"""
             <a href="/superset/sqllab?savedQueryId={self.id}">
@@ -180,14 +182,14 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin):
         )
 
     @property
-    def user_email(self):
+    def user_email(self) -> str:
         return self.user.email
 
     @property
-    def sqlalchemy_uri(self):
+    def sqlalchemy_uri(self) -> URL:
         return self.database.sqlalchemy_uri
 
-    def url(self):
+    def url(self) -> str:
         return "/superset/sqllab?savedQueryId={0}".format(self.id)
 
 
@@ -226,7 +228,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin):
     autorun = Column(Boolean, default=False)
     template_params = Column(Text)
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         return {
             "id": self.id,
             "user_id": self.user_id,
@@ -260,7 +262,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin):
 
     expanded = Column(Boolean, default=False)
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         try:
             description = json.loads(self.description)
         except json.JSONDecodeError:
diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py
index f0b46fa..a50b4c2 100644
--- a/superset/models/sql_types/presto_sql_types.py
+++ b/superset/models/sql_types/presto_sql_types.py
@@ -14,10 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import Any, Optional, Type
 
 from sqlalchemy import types
 from sqlalchemy.sql.sqltypes import Integer
 from sqlalchemy.sql.type_api import TypeEngine
+from sqlalchemy.sql.visitors import Visitable
 
 # _compiler_dispatch is defined to help with type compilation
 
@@ -27,11 +29,11 @@ class TinyInteger(Integer):
     A type for tiny ``int`` integers.
     """
 
-    def python_type(self):
+    def python_type(self) -> Type:
         return int
 
     @classmethod
-    def _compiler_dispatch(cls, _visitor, **_kw):
+    def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
         return "TINYINT"
 
 
@@ -40,11 +42,11 @@ class Interval(TypeEngine):
     A type for intervals.
     """
 
-    def python_type(self):
+    def python_type(self) -> Optional[Type]:
         return None
 
     @classmethod
-    def _compiler_dispatch(cls, _visitor, **_kw):
+    def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
         return "INTERVAL"
 
 
@@ -53,11 +55,11 @@ class Array(TypeEngine):
     A type for arrays.
     """
 
-    def python_type(self):
+    def python_type(self) -> Optional[Type]:
         return list
 
     @classmethod
-    def _compiler_dispatch(cls, _visitor, **_kw):
+    def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
         return "ARRAY"
 
 
@@ -66,11 +68,11 @@ class Map(TypeEngine):
     A type for maps.
     """
 
-    def python_type(self):
+    def python_type(self) -> Optional[Type]:
         return dict
 
     @classmethod
-    def _compiler_dispatch(cls, _visitor, **_kw):
+    def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
         return "MAP"
 
 
@@ -79,11 +81,11 @@ class Row(TypeEngine):
     A type for rows.
     """
 
-    def python_type(self):
+    def python_type(self) -> Optional[Type]:
         return None
 
     @classmethod
-    def _compiler_dispatch(cls, _visitor, **_kw):
+    def _compiler_dispatch(cls, _visitor: Visitable, **_kw: Any) -> str:
         return "ROW"
 
 
diff --git a/superset/models/tags.py b/superset/models/tags.py
index 0cb00cc..c09bb16 100644
--- a/superset/models/tags.py
+++ b/superset/models/tags.py
@@ -17,15 +17,23 @@
 from __future__ import absolute_import, division, print_function, 
unicode_literals
 
 import enum
-from typing import Optional
+from typing import List, Optional, TYPE_CHECKING, Union
 
 from flask_appbuilder import Model
 from sqlalchemy import Column, Enum, ForeignKey, Integer, String
-from sqlalchemy.orm import relationship, sessionmaker
+from sqlalchemy.engine.base import Connection
+from sqlalchemy.orm import relationship, Session, sessionmaker
 from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm.mapper import Mapper
 
 from superset.models.helpers import AuditMixinNullable
 
+if TYPE_CHECKING:
+    from superset.models.core import FavStar  # pylint: disable=unused-import
+    from superset.models.dashboard import Dashboard  # pylint: 
disable=unused-import
+    from superset.models.slice import Slice  # pylint: disable=unused-import
+    from superset.models.sql_lab import Query  # pylint: disable=unused-import
+
 Session = sessionmaker(autoflush=False)
 
 
@@ -80,7 +88,7 @@ class TaggedObject(Model, AuditMixinNullable):
     tag = relationship("Tag", backref="objects")
 
 
-def get_tag(name, session, type_):
+def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
     try:
         tag = session.query(Tag).filter_by(name=name, type=type_).one()
     except NoResultFound:
@@ -91,7 +99,7 @@ def get_tag(name, session, type_):
     return tag
 
 
-def get_object_type(class_name):
+def get_object_type(class_name: str) -> ObjectTypes:
     mapping = {
         "slice": ObjectTypes.chart,
         "dashboard": ObjectTypes.dashboard,
@@ -108,11 +116,15 @@ class ObjectUpdater:
     object_type: Optional[str] = None
 
     @classmethod
-    def get_owners_ids(cls, target):
+    def get_owners_ids(
+        cls, target: Union["Dashboard", "FavStar", "Slice"]
+    ) -> List[int]:
         raise NotImplementedError("Subclass should implement `get_owners_ids`")
 
     @classmethod
-    def _add_owners(cls, session, target):
+    def _add_owners(
+        cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
+    ) -> None:
         for owner_id in cls.get_owners_ids(target):
             name = "owner:{0}".format(owner_id)
             tag = get_tag(name, session, TagTypes.owner)
@@ -122,7 +134,12 @@ class ObjectUpdater:
             session.add(tagged_object)
 
     @classmethod
-    def after_insert(cls, mapper, connection, target):
+    def after_insert(
+        cls,
+        mapper: Mapper,
+        connection: Connection,
+        target: Union["Dashboard", "FavStar", "Slice"],
+    ) -> None:
         # pylint: disable=unused-argument
         session = Session(bind=connection)
 
@@ -139,7 +156,12 @@ class ObjectUpdater:
         session.commit()
 
     @classmethod
-    def after_update(cls, mapper, connection, target):
+    def after_update(
+        cls,
+        mapper: Mapper,
+        connection: Connection,
+        target: Union["Dashboard", "FavStar", "Slice"],
+    ) -> None:
         # pylint: disable=unused-argument
         session = Session(bind=connection)
 
@@ -164,7 +186,12 @@ class ObjectUpdater:
         session.commit()
 
     @classmethod
-    def after_delete(cls, mapper, connection, target):
+    def after_delete(
+        cls,
+        mapper: Mapper,
+        connection: Connection,
+        target: Union["Dashboard", "FavStar", "Slice"],
+    ) -> None:
         # pylint: disable=unused-argument
         session = Session(bind=connection)
 
@@ -182,7 +209,7 @@ class ChartUpdater(ObjectUpdater):
     object_type = "chart"
 
     @classmethod
-    def get_owners_ids(cls, target):
+    def get_owners_ids(cls, target: "Slice") -> List[int]:
         return [owner.id for owner in target.owners]
 
 
@@ -191,7 +218,7 @@ class DashboardUpdater(ObjectUpdater):
     object_type = "dashboard"
 
     @classmethod
-    def get_owners_ids(cls, target):
+    def get_owners_ids(cls, target: "Dashboard") -> List[int]:
         return [owner.id for owner in target.owners]
 
 
@@ -200,13 +227,15 @@ class QueryUpdater(ObjectUpdater):
     object_type = "query"
 
     @classmethod
-    def get_owners_ids(cls, target):
+    def get_owners_ids(cls, target: "Query") -> List[int]:
         return [target.user_id]
 
 
 class FavStarUpdater:
     @classmethod
-    def after_insert(cls, mapper, connection, target):
+    def after_insert(
+        cls, mapper: Mapper, connection: Connection, target: "FavStar"
+    ) -> None:
         # pylint: disable=unused-argument
         session = Session(bind=connection)
         name = "favorited_by:{0}".format(target.user_id)
@@ -221,7 +250,9 @@ class FavStarUpdater:
         session.commit()
 
     @classmethod
-    def after_delete(cls, mapper, connection, target):
+    def after_delete(
+        cls, mapper: Mapper, connection: Connection, target: "FavStar"
+    ) -> None:
         # pylint: disable=unused-argument
         session = Session(bind=connection)
         name = "favorited_by:{0}".format(target.user_id)
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index a7a5dc6..b555005 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -14,6 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import Callable, Optional
+
 from flask import request
 
 from superset.extensions import cache_manager
@@ -24,7 +26,9 @@ def view_cache_key(*_, **__) -> str:
     return "view/{}/{}".format(request.path, args_hash)
 
 
-def memoized_func(key=view_cache_key, attribute_in_key=None):
+def memoized_func(
+    key: Callable = view_cache_key, attribute_in_key: Optional[str] = None
+) -> Callable:
     """Use this decorator to cache functions that have predefined first arg.
 
     enable_cache is treated as True by default,
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 19e39ab..e093d3d 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -143,7 +143,7 @@ class _memoized:
         self.func = func
         self.cache = {}
         self.is_method = False
-        self.watch = watch
+        self.watch = watch or []
 
     def __call__(self, *args, **kwargs):
         key = [args, frozenset(kwargs.items())]
@@ -172,7 +172,7 @@ class _memoized:
         return functools.partial(self.__call__, obj)
 
 
-def memoized(func=None, watch=None):
+def memoized(func: Optional[Callable] = None, watch: Optional[List[str]] = 
None):
     if func:
         return _memoized(func)
     else:

Reply via email to