This is an automated email from the ASF dual-hosted git repository.

lyndsi pushed a commit to branch lyndsi/sql-lab-new-explore-button-functionality-and-move-save-dataset-to-split-save-button
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8d087f2d413a54698536c1585920b5b2b6d6cc24
Author: Hugh A. Miles II <[email protected]>
AuthorDate: Fri Jun 24 12:46:39 2022 +0000

    patch for pre-commit
---
 superset/models/helpers.py | 160 +++++++++++++++++++++++++++++----------------
 superset/models/sql_lab.py |  14 +---
 superset/utils/core.py     |   2 +-
 3 files changed, 106 insertions(+), 70 deletions(-)

diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 73f10741fb..62eec0ec57 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -29,15 +29,18 @@ from typing import (
     Dict,
     Hashable,
     List,
+    Mapping,
     NamedTuple,
     Optional,
     Set,
+    Text,
     Tuple,
     Type,
     TYPE_CHECKING,
     Union,
 )
 
+import dateutil.parser
 import humanize
 import numpy as np
 import pandas as pd
@@ -62,10 +65,15 @@ from sqlalchemy.sql.selectable import Alias, TableClause
 from sqlalchemy_utils import UUIDType
 
 from superset import app, db, is_feature_enabled, security_manager
+from superset.advanced_data_type.types import AdvancedDataTypeResponse
 from superset.common.db_query_status import QueryStatus
 from superset.constants import EMPTY_STRING, NULL_STRING
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import SupersetSecurityException
+from superset.exceptions import (
+    AdvancedDataTypeResponseError,
+    QueryClauseValidationException,
+    SupersetSecurityException,
+)
 from superset.extensions import feature_flag_manager
 from superset.jinja_context import (
     BaseTemplateProcessor,
@@ -716,7 +724,7 @@ class ExploreMixin:
         raise NotImplementedError()
 
     @property
-    def main_dttm_col(self) -> str:
+    def main_dttm_col(self) -> Optional[str]:
         raise NotImplementedError()
 
     @property
@@ -751,6 +759,37 @@ class ExploreMixin:
     def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
         raise NotImplementedError()
 
+    @staticmethod
+    def _process_sql_expression(
+        expression: Optional[str],
+        database_id: int,
+        schema: str,
+        template_processor: Optional[BaseTemplateProcessor],
+    ) -> Optional[str]:
+        """
+        Render, validate and sanitize an ad-hoc SQL expression.
+
+        :param expression: Raw SQL expression; falsy values are returned unchanged
+        :param database_id: Database id passed to `validate_adhoc_subquery`
+        :param schema: Schema passed to `validate_adhoc_subquery`
+        :param template_processor: Optional processor used to render the template
+        :return: The processed expression
+        :raises QueryObjectValidationError: If `sanitize_clause` rejects the clause
+        """
+        if template_processor and expression:
+            expression = template_processor.process_template(expression)
+        if expression:
+            expression = validate_adhoc_subquery(
+                expression,
+                database_id,
+                schema,
+            )
+            try:
+                expression = sanitize_clause(expression)
+            except QueryClauseValidationException as ex:
+                raise QueryObjectValidationError(ex.message) from ex
+        return expression
+
     def make_sqla_column_compatible(
         self, sqla_col: ColumnElement, label: Optional[str] = None
     ) -> ColumnElement:
@@ -832,13 +860,10 @@ class ExploreMixin:

     def get_query_str_extended(self, query_obj: QueryObjectDict) -> QueryStringExtended:
         sqlaq = self.get_sqla_query(**query_obj)
-        sql = self.database.compile_sqla_query(sqlaq.sqla_query)
+        sql = self.database.compile_sqla_query(sqlaq.sqla_query)  # type: ignore
         sql = self._apply_cte(sql, sqlaq.cte)
         sql = sqlparse.format(sql, reindent=True)
         sql = self.mutate_query_from_config(sql)
-        # from pprint import pprint
-        # pprint(sql)
-        # breakpoint()
         return QueryStringExtended(
             applied_template_filters=sqlaq.applied_template_filters,
             labels_expected=sqlaq.labels_expected,
@@ -846,6 +871,43 @@ class ExploreMixin:
             sql=sql,
         )

+    def _normalize_prequery_result_type(
+        self,
+        row: pd.Series,
+        dimension: str,
+        columns_by_name: Dict[str, "TableColumn"],
+    ) -> Union[str, int, float, bool, Text]:
+        """
+        Convert a prequery result type to its equivalent Python type.
+
+        Some databases like Druid will return timestamps as strings, but do not perform
+        automatic casting when comparing these strings to a timestamp. For cases like
+        this we convert the value via the appropriate SQL transform.
+
+        :param row: A prequery record
+        :param dimension: The dimension name
+        :param columns_by_name: The mapping of columns by name
+        :return: equivalent primitive python type
+        """
+
+        value = row[dimension]
+
+        if isinstance(value, np.generic):
+            value = value.item()
+
+        column_ = columns_by_name[dimension]
+        db_extra: Dict[str, Any] = self.database.get_extra()  # type: ignore
+
+        if column_.type and column_.is_temporal and isinstance(value, str):
+            sql = self.db_engine_spec.convert_dttm(
+                column_.type, dateutil.parser.parse(value), db_extra=db_extra
+            )
+
+            if sql:
+                value = self.text(sql)
+
+        return value
+
     def make_orderby_compatible(
         self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
     ) -> None:
@@ -883,49 +945,40 @@ class ExploreMixin:
         errors = None
         error_message = None

-        # def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
-        #     """
-        #     Some engines change the case or generate bespoke column names, either by
-        #     default or due to lack of support for aliasing. This function ensures that
-        #     the column names in the DataFrame correspond to what is expected by
-        #     the viz components.
-        #     Sometimes a query may also contain only order by columns that are not used
-        #     as metrics or groupby columns, but need to present in the SQL `select`,
-        #     filtering by `labels_expected` make sure we only return columns users want.
-        #     :param df: Original DataFrame returned by the engine
-        #     :return: Mutated DataFrame
-        #     """
-        #     labels_expected = query_str_ext.labels_expected
-        #     if df is not None and not df.empty:
-        #         if len(df.columns) < len(labels_expected):
-        #             raise QueryObjectValidationError(
-        #                 _("Db engine did not return all queried columns")
-        #             )
-        #         if len(df.columns) > len(labels_expected):
-        #             df = df.iloc[:, 0: len(labels_expected)]
-        #         df.columns = labels_expected
-        #     return df
+        def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
+            """
+            Some engines change the case or generate bespoke column names, either by
+            default or due to lack of support for aliasing. This function ensures that
+            the column names in the DataFrame correspond to what is expected by
+            the viz components.
+            Sometimes a query may also contain only order by columns that are not used
+            as metrics or groupby columns, but need to present in the SQL `select`,
+            filtering by `labels_expected` make sure we only return columns users want.
+            :param df: Original DataFrame returned by the engine
+            :return: Mutated DataFrame
+            """
+            labels_expected = query_str_ext.labels_expected
+            if df is not None and not df.empty:
+                if len(df.columns) < len(labels_expected):
+                    raise QueryObjectValidationError(
+                        _("Db engine did not return all queried columns")
+                    )
+                if len(df.columns) > len(labels_expected):
+                    df = df.iloc[:, 0 : len(labels_expected)]
+                df.columns = labels_expected
+            return df

         try:
-            # todo(hugh) fix this
-            # df = self.database.get_df(
-            #     sql, self.schema, mutator=assign_column_label)
-            df = self.database.get_df(sql, self.schema)
+            df = self.database.get_df(sql, self.schema, mutator=assign_column_label)  # type: ignore
         except Exception as ex:  # pylint: disable=broad-except
             df = pd.DataFrame()
             status = QueryStatus.FAILED
             logger.warning(
                 "Query %s on schema %s failed", sql, self.schema, exc_info=True
             )
-            # todo(hugh): how are we handling errors
-            # db_engine_spec = self.db_engine_spec
-            # errors = [
-            #     dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex)
-            # ]
             error_message = utils.error_msg_from_exception(ex)

         return QueryResult(
-            # applied_template_filters=query_str_ext.applied_template_filters,
             status=status,
             df=df,
             duration=datetime.now() - qry_start_dttm,
@@ -1016,16 +1069,10 @@ class ExploreMixin:
         if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
             metric_column = metric.get("column") or {}
             column_name = cast(str, metric_column.get("column_name"))
-            table_column: Optional[TableColumn] = columns_by_name.get(column_name)
             sqla_column = sa.column(column_name)
-            # todo(hughhh): understand how this works?
-            # if table_column:
-            #     sqla_column = table_column.get_sqla_col()
-            # else:
-            #     sqla_column = column(column_name)
             sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
         elif expression_type == utils.AdhocMetricExpressionType.SQL:
-            expression = _process_sql_expression(
+            expression = _process_sql_expression(  # type: ignore
                 expression=metric["sqlExpression"],
                 database_id=self.database_id,
                 schema=self.schema,
@@ -1295,7 +1343,7 @@ class ExploreMixin:
                 metrics_exprs.append(
                     self.adhoc_metric_to_sqla(
                         metric=metric,
-                        columns_by_name=columns_by_name,
+                        columns_by_name=columns_by_name,  # type: ignore
                         template_processor=template_processor,
                     )
                 )
@@ -1325,7 +1373,7 @@ class ExploreMixin:
             if isinstance(col, dict):
                 col = cast(AdhocMetric, col)
                 if col.get("sqlExpression"):
-                    col["sqlExpression"] = _process_sql_expression(
+                    col["sqlExpression"] = _process_sql_expression(  # type: ignore
                         expression=col["sqlExpression"],
                         database_id=self.database_id,
                         schema=self.schema,
@@ -1333,7 +1381,7 @@ class ExploreMixin:
                     )
                 if utils.is_adhoc_metric(col):
                     # add adhoc sort by column to columns_by_name if not exists
-                    col = self.adhoc_metric_to_sqla(col, columns_by_name)
+                    col = self.adhoc_metric_to_sqla(col, columns_by_name)  # type: ignore
                     # if the adhoc metric has been defined before
                     # use the existing instance.
                     col = metrics_exprs_by_expr.get(str(col), col)
@@ -1486,7 +1534,7 @@ class ExploreMixin:
             if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
                 col_obj = dttm_col
             elif utils.is_adhoc_column(flt_col):
-                sqla_col = self.adhoc_column_to_sqla(flt_col)
+                sqla_col = self.adhoc_column_to_sqla(flt_col)  # type: ignore
             else:
                 col_obj = columns_by_name.get(flt_col)
             filter_grain = flt.get("grain")
@@ -1514,7 +1562,7 @@ class ExploreMixin:
                     col_type = col_obj.type if col_obj else None
                 col_spec = db_engine_spec.get_column_spec(
                     native_type=col_type,
-                    db_extra=self.database.get_extra(),
+                    db_extra=self.database.get_extra(),  # type: ignore
                 )
                 is_list_target = op in (
                     utils.FilterOperator.IN.value,
@@ -1538,7 +1586,7 @@ class ExploreMixin:
                     target_native_type=col_type,
                     is_list_target=is_list_target,
                     db_engine_spec=db_engine_spec,
-                    db_extra=self.database.get_extra(),
+                    db_extra=self.database.get_extra(),  # type: ignore
                 )
                 if (
                     col_advanced_data_type != ""
@@ -1627,7 +1675,7 @@ class ExploreMixin:
             where = extras.get("where")
             if where:
                 try:
-                    where = template_processor.process_template(f"({where})")
+                    where = template_processor.process_template(f"({where})")  # type: ignore
                 except TemplateError as ex:
                     raise QueryObjectValidationError(
                         _(
@@ -1639,7 +1687,7 @@ class ExploreMixin:
             having = extras.get("having")
             if having:
                 try:
-                    having = template_processor.process_template(f"({having})")
+                    having = template_processor.process_template(f"({having})")  # type: ignore
                 except TemplateError as ex:
                     raise QueryObjectValidationError(
                         _(
@@ -1648,8 +1696,8 @@ class ExploreMixin:
                         )
                     ) from ex
                 having_clause_and += [self.text(having)]
-        if apply_fetch_values_predicate and self.fetch_values_predicate:
-            qry = qry.where(self.get_fetch_values_predicate())
+        if apply_fetch_values_predicate and self.fetch_values_predicate:  # type: ignore
+            qry = qry.where(self.get_fetch_values_predicate())  # type: ignore
         if granularity:
             qry = qry.where(and_(*(time_filters + where_clause_and)))
         else:
@@ -1755,7 +1803,7 @@ class ExploreMixin:
                     "order_desc": True,
                 }
 
-                result = self.query(prequery_obj)  # ignore: typing
+                result = self.query(prequery_obj)  # type: ignore
                 prequeries.append(result.query)
                 dimensions = [
                     c
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 6c6e10b514..d610410ded 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -225,18 +225,6 @@ class Query(Model, ExtraJSONMixin, ExploreMixin):
             "database": {"id": self.database_id, "backend": self.database.backend},
         }

-    @property
-    def data(self) -> Dict[str, Any]:
-        return {
-            "columns": self.columns,
-            "metrics": [],
-            "id": self.id,
-            "type": self.type,
-            "sql": self.sql,
-            "owners": self.owners_data,
-            "database": {"id": self.database_id, "backend": self.database.backend},
-        }
-
     def raise_for_access(self) -> None:
         """
         Raise an exception if the user cannot access the resource.
@@ -282,7 +270,7 @@ class Query(Model, ExtraJSONMixin, ExploreMixin):
     def main_dttm_col(self) -> Optional[str]:
         for col in self.columns:
             if col.get("is_dttm"):
-                return col.get("column_name")
+                return col.get("column_name")  # type: ignore
         return None
 
     @property
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 0ecff893c5..f76995fb32 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1739,7 +1739,7 @@ def get_time_filter_status(  # pylint: disable=too-many-branches

     # todo(hugh): fix this
     # temporal_columns = {col.column_name for col in datasource.columns if col.is_dttm}
-    temporal_columns = {}
+    temporal_columns: Dict[str, Any] = {}
     applied: List[Dict[str, str]] = []
     rejected: List[Dict[str, str]] = []
     time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)

Reply via email to