This is an automated email from the ASF dual-hosted git repository.

lyndsi pushed a commit to branch 
lyndsi/sql-lab-new-explore-button-functionality-and-move-save-dataset-to-split-save-button
in repository https://gitbox.apache.org/repos/asf/superset.git

commit f992a027b5db71485fea86bc5d7d7f67e8c56cb6
Author: Hugh A. Miles II <[email protected]>
AuthorDate: Mon Jun 6 15:09:01 2022 +0000

    Add proof-of-concept (POC) ExploreMixin
---
 superset/models/helpers.py | 673 ++++++++++++++++-----------------------------
 1 file changed, 234 insertions(+), 439 deletions(-)

diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 6cd6d17c7b..08efb59b60 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -24,21 +24,19 @@ from datetime import datetime, timedelta
 from json.decoder import JSONDecodeError
 from typing import (
     Any,
+    Callable,
     cast,
     Dict,
+    Hashable,
     List,
-    Mapping,
     NamedTuple,
     Optional,
     Set,
-    Text,
     Tuple,
     Type,
-    TYPE_CHECKING,
     Union,
 )
 
-import dateutil.parser
 import humanize
 import numpy as np
 import pandas as pd
@@ -52,48 +50,33 @@ from flask_appbuilder.models.decorators import renders
 from flask_appbuilder.models.mixins import AuditMixin
 from flask_appbuilder.security.sqla.models import User
 from flask_babel import lazy_gettext as _
-from jinja2.exceptions import TemplateError
-from sqlalchemy import and_, Column, or_, UniqueConstraint
+from sqlalchemy import and_, or_, UniqueConstraint
 from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.orm import Mapper, Session
 from sqlalchemy.orm.exc import MultipleResultsFound
-from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause
+from sqlalchemy.sql.elements import ColumnClause, TextClause
 from sqlalchemy.sql.expression import Label, Select, TextAsFrom
 from sqlalchemy.sql.selectable import Alias, TableClause
 from sqlalchemy_utils import UUIDType
 
-from superset import app, is_feature_enabled, security_manager
-from superset.advanced_data_type.types import AdvancedDataTypeResponse
+from superset import app, db, is_feature_enabled, security_manager
 from superset.common.db_query_status import QueryStatus
-from superset.constants import EMPTY_STRING, NULL_STRING
-from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
-from superset.exceptions import (
-    AdvancedDataTypeResponseError,
-    QueryClauseValidationException,
-    QueryObjectValidationError,
-    SupersetSecurityException,
+from superset.jinja_context import (
+    BaseTemplateProcessor,
+    ExtraCache,
+    get_template_processor,
 )
-from superset.extensions import feature_flag_manager
-from superset.jinja_context import BaseTemplateProcessor
-from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, 
sanitize_clause
-from superset.superset_typing import (
-    AdhocMetric,
-    FilterValue,
-    FilterValues,
-    Metric,
-    OrderBy,
-    QueryObjectDict,
+from superset.sql_parse import (
+    extract_table_references,
+    ParsedQuery,
+    sanitize_clause,
+    Table as TableName,
 )
 from superset.utils import core as utils
-from superset.utils.core import get_user_id
 
-if TYPE_CHECKING:
-    from superset.connectors.sqla.models import SqlMetric, TableColumn
-    from superset.db_engine_specs import BaseEngineSpec
-    from superset.models.core import Database
+VIRTUAL_TABLE_ALIAS = "virtual_table"
 
 
-config = app.config
 logger = logging.getLogger(__name__)
 
 CTE_ALIAS = "__cte"
@@ -634,6 +617,17 @@ def clone_model(
     return target.__class__(**data)
 
 
+from typing import Any, Dict, List, NamedTuple
+
+import sqlparse
+from sqlalchemy import Column
+from sqlalchemy.sql.elements import ColumnElement, Label, literal_column
+
+from superset.exceptions import QueryObjectValidationError
+from superset.superset_typing import AdhocMetric, Metric, OrderBy, 
QueryObjectDict
+from superset.utils import core as utils
+
+
 # todo(hugh): centralize where this code lives
 class QueryStringExtended(NamedTuple):
     applied_template_filters: Optional[List[str]]
@@ -651,113 +645,63 @@ class SqlaQuery(NamedTuple):
     sqla_query: Select
 
 
-class ExploreMixin:  # pylint: disable=too-many-public-methods
+class ExploreMixin:
     """
     Allows any flask_appbuilder.Model (Query, Table, etc.)
     to be used to power a chart inside /explore
     """
 
-    sqla_aggregations = {
-        "COUNT_DISTINCT": lambda column_name: 
sa.func.COUNT(sa.distinct(column_name)),
-        "COUNT": sa.func.COUNT,
-        "SUM": sa.func.SUM,
-        "AVG": sa.func.AVG,
-        "MIN": sa.func.MIN,
-        "MAX": sa.func.MAX,
-    }
-
     @property
-    def query(self) -> str:
-        raise NotImplementedError()
+    def data(self):
+        return {"foo": "bar"}
 
     @property
-    def database_id(self) -> int:
-        raise NotImplementedError()
+    def owners_data(self):
+        return []
 
     @property
-    def owners_data(self) -> List[Any]:
-        raise NotImplementedError()
+    def metrics(self):
+        return []
 
     @property
-    def metrics(self) -> List[Any]:
-        raise NotImplementedError()
+    def uid(self):
+        return "foo"
 
     @property
-    def uid(self) -> str:
-        raise NotImplementedError()
+    def is_rls_supported(self):
+        return False
 
     @property
-    def is_rls_supported(self) -> bool:
-        raise NotImplementedError()
+    def cache_timeout(self):
+        return None
 
     @property
-    def cache_timeout(self) -> int:
-        raise NotImplementedError()
+    def column_names(self):
+        return ["ethnic_minority", "gender"]
 
     @property
-    def column_names(self) -> List[str]:
-        raise NotImplementedError()
+    def columns(self):
+        return ["<col_name>"]
 
     @property
-    def offset(self) -> int:
-        raise NotImplementedError()
+    def offset(self):
+        return 0
 
     @property
-    def main_dttm_col(self) -> Optional[str]:
-        raise NotImplementedError()
+    def main_dttm_col(self) -> str:  # todo - this should be a real column
+        return "ds"
 
     @property
     def dttm_cols(self) -> List[str]:
-        raise NotImplementedError()
-
-    @property
-    def db_engine_spec(self) -> Type["BaseEngineSpec"]:
-        raise NotImplementedError()
-
-    @property
-    def database(self) -> Type["Database"]:
-        raise NotImplementedError()
-
-    @property
-    def schema(self) -> str:
-        raise NotImplementedError()
-
-    @property
-    def sql(self) -> str:
-        raise NotImplementedError()
-
-    @property
-    def columns(self) -> List[Any]:
-        raise NotImplementedError()
-
-    @property
-    def get_fetch_values_predicate(self) -> List[Any]:
-        raise NotImplementedError()
+        return []
+        # l = [c.column_name for c in self.columns if c.is_dttm]
+        # if self.main_dttm_col and self.main_dttm_col not in l:
+        #     l.append(self.main_dttm_col)
+        # return l
 
     @staticmethod
-    def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[str]:
-        raise NotImplementedError()
-
-    def _process_sql_expression(  # type: ignore # pylint: disable=no-self-use
-        self,
-        expression: Optional[str],
-        database_id: int,
-        schema: str,
-        template_processor: Optional[BaseTemplateProcessor],
-    ) -> Optional[str]:
-        if template_processor and expression:
-            expression = template_processor.process_template(expression)
-        if expression:
-            expression = validate_adhoc_subquery(
-                expression,
-                database_id,
-                schema,
-            )
-            try:
-                expression = sanitize_clause(expression)
-            except QueryClauseValidationException as ex:
-                raise QueryObjectValidationError(ex.message) from ex
-        return expression
+    def get_extra_cache_keys(query_obj):
+        return []
 
     def make_sqla_column_compatible(
         self, sqla_col: ColumnElement, label: Optional[str] = None
@@ -776,72 +720,12 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
         sqla_col.key = label_expected
         return sqla_col
 
-    def mutate_query_from_config(self, sql: str) -> str:
-        """Apply config's SQL_QUERY_MUTATOR
-
-        Typically adds comments to the query with context"""
-        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
-        if sql_query_mutator:
-            sql = sql_query_mutator(
-                sql,
-                user_name=utils.get_username(),  # TODO(john-bodley): 
Deprecate in 3.0.
-                security_manager=security_manager,
-                database=self.database,
-            )
-        return sql
-
-    @staticmethod
-    def _apply_cte(sql: str, cte: Optional[str]) -> str:
-        """
-        Append a CTE before the SELECT statement if defined
-
-        :param sql: SELECT statement
-        :param cte: CTE statement
-        :return:
-        """
-        if cte:
-            sql = f"{cte}\n{sql}"
-        return sql
-
-    @staticmethod
-    def validate_adhoc_subquery(
-        sql: str,
-        database_id: int,
-        default_schema: str,
-    ) -> str:
-        """
-        Check if adhoc SQL contains sub-queries or nested sub-queries with 
table.
-
-        If sub-queries are allowed, the adhoc SQL is modified to insert any 
applicable RLS
-        predicates to it.
-
-        :param sql: adhoc sql expression
-        :raise SupersetSecurityException if sql contains sub-queries or
-        nested sub-queries with table
-        """
-
-        statements = []
-        for statement in sqlparse.parse(sql):
-            if has_table_query(statement):
-                if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
-                    raise SupersetSecurityException(
-                        SupersetError(
-                            
error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
-                            message=_("Custom SQL fields cannot contain 
sub-queries."),
-                            level=ErrorLevel.ERROR,
-                        )
-                    )
-                statement = insert_rls(statement, database_id, default_schema)
-            statements.append(statement)
-
-        return ";\n".join(str(statement) for statement in statements)
-
     def get_query_str_extended(self, query_obj: QueryObjectDict) -> 
QueryStringExtended:
         sqlaq = self.get_sqla_query(**query_obj)
-        sql = self.database.compile_sqla_query(sqlaq.sqla_query)  # type: 
ignore
-        sql = self._apply_cte(sql, sqlaq.cte)
+        sql = self.database.compile_sqla_query(sqlaq.sqla_query)
+        # sql = self._apply_cte(sql, sqlaq.cte)
         sql = sqlparse.format(sql, reindent=True)
-        sql = self.mutate_query_from_config(sql)
+        # sql = self.mutate_query_from_config(sql)
         return QueryStringExtended(
             applied_template_filters=sqlaq.applied_template_filters,
             labels_expected=sqlaq.labels_expected,
@@ -849,43 +733,6 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
             sql=sql,
         )
 
-    def _normalize_prequery_result_type(
-        self,
-        row: pd.Series,
-        dimension: str,
-        columns_by_name: Dict[str, "TableColumn"],
-    ) -> Union[str, int, float, bool, Text]:
-        """
-        Convert a prequery result type to its equivalent Python type.
-
-        Some databases like Druid will return timestamps as strings, but do 
not perform
-        automatic casting when comparing these strings to a timestamp. For 
cases like
-        this we convert the value via the appropriate SQL transform.
-
-        :param row: A prequery record
-        :param dimension: The dimension name
-        :param columns_by_name: The mapping of columns by name
-        :return: equivalent primitive python type
-        """
-
-        value = row[dimension]
-
-        if isinstance(value, np.generic):
-            value = value.item()
-
-        column_ = columns_by_name[dimension]
-        db_extra: Dict[str, Any] = self.database.get_extra()  # type: ignore
-
-        if column_.type and column_.is_temporal and isinstance(value, str):
-            sql = self.db_engine_spec.convert_dttm(
-                column_.type, dateutil.parser.parse(value), db_extra=db_extra
-            )
-
-            if sql:
-                value = self.text(sql)
-
-        return value
-
     def make_orderby_compatible(
         self, select_exprs: List[ColumnElement], orderby_exprs: 
List[ColumnElement]
     ) -> None:
@@ -917,48 +764,56 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
 
     def exc_query(self, qry: Any) -> QueryResult:
         qry_start_dttm = datetime.now()
+        # todo(hugh): apply filters for extended query
         query_str_ext = self.get_query_str_extended(qry)
         sql = query_str_ext.sql
         status = QueryStatus.SUCCESS
         errors = None
         error_message = None
 
-        def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
-            """
-            Some engines change the case or generate bespoke column names, 
either by
-            default or due to lack of support for aliasing. This function 
ensures that
-            the column names in the DataFrame correspond to what is expected by
-            the viz components.
-            Sometimes a query may also contain only order by columns that are 
not used
-            as metrics or groupby columns, but need to present in the SQL 
`select`,
-            filtering by `labels_expected` make sure we only return columns 
users want.
-            :param df: Original DataFrame returned by the engine
-            :return: Mutated DataFrame
-            """
-            labels_expected = query_str_ext.labels_expected
-            if df is not None and not df.empty:
-                if len(df.columns) < len(labels_expected):
-                    raise QueryObjectValidationError(
-                        _("Db engine did not return all queried columns")
-                    )
-                if len(df.columns) > len(labels_expected):
-                    df = df.iloc[:, 0 : len(labels_expected)]
-                df.columns = labels_expected
-            return df
+        # def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
+        #     """
+        #     Some engines change the case or generate bespoke column names, 
either by
+        #     default or due to lack of support for aliasing. This function 
ensures that
+        #     the column names in the DataFrame correspond to what is expected 
by
+        #     the viz components.
+        #     Sometimes a query may also contain only order by columns that 
are not used
+        #     as metrics or groupby columns, but need to present in the SQL 
`select`,
+        #     filtering by `labels_expected` make sure we only return columns 
users want.
+        #     :param df: Original DataFrame returned by the engine
+        #     :return: Mutated DataFrame
+        #     """
+        #     labels_expected = query_str_ext.labels_expected
+        #     if df is not None and not df.empty:
+        #         if len(df.columns) < len(labels_expected):
+        #             raise QueryObjectValidationError(
+        #                 _("Db engine did not return all queried columns")
+        #             )
+        #         if len(df.columns) > len(labels_expected):
+        #             df = df.iloc[:, 0: len(labels_expected)]
+        #         df.columns = labels_expected
+        #     return df
 
         try:
-            df = self.database.get_df(
-                sql, self.schema, mutator=assign_column_label  # type: ignore
-            )
+            # todo(hugh) fix this
+            # df = self.database.get_df(
+            #     sql, self.schema, mutator=assign_column_label)
+            df = self.database.get_df(sql, self.schema)
         except Exception as ex:  # pylint: disable=broad-except
             df = pd.DataFrame()
             status = QueryStatus.FAILED
             logger.warning(
                 "Query %s on schema %s failed", sql, self.schema, exc_info=True
             )
+            # todo(hugh): how are we handling errors
+            # db_engine_spec = self.db_engine_spec
+            # errors = [
+            #     dataclasses.asdict(error) for error in 
db_engine_spec.extract_errors(ex)
+            # ]
             error_message = utils.error_msg_from_exception(ex)
 
         return QueryResult(
+            # applied_template_filters=query_str_ext.applied_template_filters,
             status=status,
             df=df,
             duration=datetime.now() - qry_start_dttm,
@@ -1005,6 +860,9 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
         or a virtual table with it's own subquery. If the FROM is referencing a
         CTE, the CTE is returned as the second value in the return tuple.
         """
+        # todo(hugh): fix this
+        # if not self.is_virtual:
+        #     return self.get_sqla_table(), None
 
         from_sql = self.get_rendered_sql(template_processor)
         parsed_query = ParsedQuery(from_sql)
@@ -1018,49 +876,13 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
 
         cte = self.db_engine_spec.get_cte_query(from_sql)
         from_clause = (
-            sa.table(CTE_ALIAS)
+            table(CTE_ALIAS)
             if cte
             else TextAsFrom(self.text(from_sql), []).alias(VIRTUAL_TABLE_ALIAS)
         )
 
         return from_clause, cte
 
-    def adhoc_metric_to_sqla(
-        self,
-        metric: AdhocMetric,
-        columns_by_name: Dict[str, "TableColumn"],  # # pylint: 
disable=unused-argument
-        template_processor: Optional[BaseTemplateProcessor] = None,
-    ) -> ColumnElement:
-        """
-        Turn an adhoc metric into a sqlalchemy column.
-
-        :param dict metric: Adhoc metric definition
-        :param dict columns_by_name: Columns for the current table
-        :param template_processor: template_processor instance
-        :returns: The metric defined as a sqlalchemy column
-        :rtype: sqlalchemy.sql.column
-        """
-        expression_type = metric.get("expressionType")
-        label = utils.get_metric_name(metric)
-
-        if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
-            metric_column = metric.get("column") or {}
-            column_name = cast(str, metric_column.get("column_name"))
-            sqla_column = sa.column(column_name)
-            sqla_metric = 
self.sqla_aggregations[metric["aggregate"]](sqla_column)
-        elif expression_type == utils.AdhocMetricExpressionType.SQL:
-            expression = self._process_sql_expression(  # type: ignore
-                expression=metric["sqlExpression"],
-                database_id=self.database_id,
-                schema=self.schema,
-                template_processor=template_processor,
-            )
-            sqla_metric = literal_column(expression)
-        else:
-            raise QueryObjectValidationError("Adhoc metric expressionType is 
invalid")
-
-        return self.make_sqla_column_compatible(sqla_metric, label)
-
     @property
     def template_params_dict(self) -> Dict[Any, Any]:
         return {}
@@ -1247,7 +1069,7 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
         row_offset: Optional[int] = None,
         timeseries_limit: Optional[int] = None,
         timeseries_limit_metric: Optional[Metric] = None,
-    ) -> SqlaQuery:
+    ) -> Any:
         """Querying any sqla table from this common interface"""
         if granularity not in self.dttm_cols and granularity is not None:
             granularity = self.main_dttm_col
@@ -1265,7 +1087,7 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
             "time_column": granularity,
             "time_grain": time_grain,
             "to_dttm": to_dttm.isoformat() if to_dttm else None,
-            "table_columns": [col.get("column_name") for col in self.columns],
+            "table_columns": self.column_names,
             "filter": filter,
         }
         columns = columns or []
@@ -1275,14 +1097,16 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
         if is_timeseries and timeseries_limit:
             series_limit = timeseries_limit
         series_limit_metric = series_limit_metric or timeseries_limit_metric
-        template_kwargs.update(self.template_params_dict)
+        template_kwargs.update(self.template_params_dict)  # todo
         extra_cache_keys: List[Any] = []
         template_kwargs["extra_cache_keys"] = extra_cache_keys
         removed_filters: List[str] = []
         applied_template_filters: List[str] = []
         template_kwargs["removed_filters"] = removed_filters
         template_kwargs["applied_filters"] = applied_template_filters
-        template_processor = None  # 
self.get_template_processor(**template_kwargs)
+        template_processor = (
+            None  # self.get_template_processor(**template_kwargs) #todo
+        )
         db_engine_spec = self.db_engine_spec
         prequeries: List[str] = []
         orderby = orderby or []
@@ -1293,36 +1117,40 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
         if granularity not in self.dttm_cols and granularity is not None:
             granularity = self.main_dttm_col
 
-        columns_by_name: Dict[str, "TableColumn"] = {
-            col.get("column_name"): col
-            for col in self.columns  # col.column_name: col for col in 
self.columns
-        }
-
-        if not granularity and is_timeseries:
-            raise QueryObjectValidationError(
-                _(
-                    "Datetime column not provided as part table configuration "
-                    "and is required by this type of chart"
-                )
-            )
-        if not metrics and not columns and not groupby:
-            raise QueryObjectValidationError(_("Empty query?"))
+        # columns_by_name: Dict[str, sa.Table] = {
+        #     col.column_name: col for col in self.columns
+        # }
+        # todo(hugh): fix this
+        columns_by_name = {}
+
+        # todo(hugh): how are we handling metrics
+        # metrics_by_name: Dict[str, Column] = {  # todo column vs metric?
+        #     m.metric_name: m for m in self.metrics
+        # }
+        metrics_by_name: Dict[str, Column] = {}
+
+        # if not granularity and is_timeseries:
+        #     raise QueryObjectValidationError(
+        #         _(
+        #             "Datetime column not provided as part table 
configuration "
+        #             "and is required by this type of chart"
+        #         )
+        #     )
+        # if not metrics and not columns and not groupby:
+        #     raise QueryObjectValidationError(_("Empty query?"))
 
         metrics_exprs: List[ColumnElement] = []
-        for metric in metrics:
-            if utils.is_adhoc_metric(metric):
-                assert isinstance(metric, dict)
-                metrics_exprs.append(
-                    self.adhoc_metric_to_sqla(
-                        metric=metric,
-                        columns_by_name=columns_by_name,  # type: ignore
-                        template_processor=template_processor,
-                    )
-                )
-            else:
-                raise QueryObjectValidationError(
-                    _("Metric '%(metric)s' does not exist", metric=metric)
-                )
+        # for metric in metrics:
+        #     if utils.is_adhoc_metric(metric):
+        #         assert isinstance(metric, dict)
+        #         # metrics_exprs.append(
+        #         #     self.adhoc_metric_to_sqla(metric, columns_by_name))
+        #     elif isinstance(metric, str) and metric in metrics_by_name:
+        #         metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
+        #     else:
+        #         raise QueryObjectValidationError(
+        #             _("Metric '%(metric)s' does not exist", metric=metric)
+        #         )
 
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
@@ -1342,16 +1170,10 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
             col: Union[AdhocMetric, ColumnElement] = orig_col
             if isinstance(col, dict):
                 col = cast(AdhocMetric, col)
-                if col.get("sqlExpression"):
-                    col["sqlExpression"] = self._process_sql_expression(  # 
type: ignore
-                        expression=col["sqlExpression"],
-                        database_id=self.database_id,
-                        schema=self.schema,
-                        template_processor=template_processor,
-                    )
                 if utils.is_adhoc_metric(col):
                     # add adhoc sort by column to columns_by_name if not exists
-                    col = self.adhoc_metric_to_sqla(col, columns_by_name)  # 
type: ignore
+                    # todo(hugh): figure out if we should have metrics
+                    # col = self.adhoc_metric_to_sqla(col, columns_by_name)
                     # if the adhoc metric has been defined before
                     # use the existing instance.
                     col = metrics_exprs_by_expr.get(str(col), col)
@@ -1361,14 +1183,18 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
             elif col in metrics_exprs_by_label:
                 col = metrics_exprs_by_label[col]
                 need_groupby = True
+            elif col in metrics_by_name:
+                col = metrics_by_name[col].get_sqla_col()
+                need_groupby = True
 
-            if isinstance(col, ColumnElement):
-                orderby_exprs.append(col)
-            else:
-                # Could not convert a column reference to valid ColumnElement
-                raise QueryObjectValidationError(
-                    _("Unknown column used in orderby: %(col)s", col=orig_col)
-                )
+            # todo(hugh): fix this
+            # if isinstance(col, ColumnElement):
+            #     orderby_exprs.append(col)
+            # else:
+            #     # Could not convert a column reference to valid ColumnElement
+            #     raise QueryObjectValidationError(
+            #         _("Unknown column used in orderby: %(col)s", 
col=orig_col)
+            #     )
 
         select_exprs: List[Union[Column, Label]] = []
         groupby_all_columns = {}
@@ -1389,26 +1215,17 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                         outer = table_col.get_timestamp_expression(
                             time_grain=time_grain,
                             label=selected,
-                            template_processor=template_processor,
+                            # template_processor=template_processor,
                         )
                     # if groupby field equals a selected column
                     elif selected in columns_by_name:
-                        if isinstance(columns_by_name[selected], dict):
-                            outer = sa.column(f"{selected}")
-                            outer = self.make_sqla_column_compatible(outer, 
selected)
-                        else:
-                            outer = columns_by_name[selected].get_sqla_col()
+                        outer = columns_by_name[selected].get_sqla_col()
                     else:
-                        selected = self.validate_adhoc_subquery(
-                            selected,
-                            self.database_id,
-                            self.schema,
-                        )
-                        outer = sa.column(f"{selected}")
+                        outer = literal_column(f"({selected})")
                         outer = self.make_sqla_column_compatible(outer, 
selected)
                 else:
                     outer = self.adhoc_column_to_sqla(
-                        col=selected, template_processor=template_processor
+                        col=selected,  # template_processor=template_processor
                     )
                 groupby_all_columns[outer.name] = outer
                 if not series_column_names or outer.name in 
series_column_names:
@@ -1416,38 +1233,45 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                 select_exprs.append(outer)
         elif columns:
             for selected in columns:
-                selected = self.validate_adhoc_subquery(
-                    selected,
-                    self.database_id,
-                    self.schema,
+                select_exprs.append(
+                    columns_by_name[selected].get_sqla_col()
+                    if selected in columns_by_name
+                    else 
self.make_sqla_column_compatible(literal_column(selected))
                 )
-                if isinstance(columns_by_name[selected], dict):
-                    select_exprs.append(sa.column(f"{selected}"))
-                else:
-                    select_exprs.append(
-                        columns_by_name[selected].get_sqla_col()
-                        if selected in columns_by_name
-                        else 
self.make_sqla_column_compatible(literal_column(selected))
-                    )
             metrics_exprs = []
 
-        if granularity:
-            if granularity not in columns_by_name or not dttm_col:
-                raise QueryObjectValidationError(
-                    _(
-                        'Time column "%(col)s" does not exist in dataset',
-                        col=granularity,
-                    )
-                )
-            time_filters: List[Any] = []
-
-            if is_timeseries:
-                timestamp = dttm_col.get_timestamp_expression(
-                    time_grain=time_grain, 
template_processor=template_processor
-                )
-                # always put timestamp as the first column
-                select_exprs.insert(0, timestamp)
-                groupby_all_columns[timestamp.name] = timestamp
+        # todo(hugh): fix this
+        # if granularity:
+        #     if granularity not in columns_by_name or not dttm_col:
+        #         raise QueryObjectValidationError(
+        #             _(
+        #                 'Time column "%(col)s" does not exist in dataset',
+        #                 col=granularity,
+        #             )
+        #         )
+        #     time_filters = []
+
+        #     if is_timeseries:
+        #         timestamp = dttm_col.get_timestamp_expression(
+        #             time_grain=time_grain, 
template_processor=template_processor
+        #         )
+        #         # always put timestamp as the first column
+        #         select_exprs.insert(0, timestamp)
+        #         groupby_all_columns[timestamp.name] = timestamp
+
+        #     # Use main dttm column to support index with secondary dttm 
columns.
+        #     if (
+        #         db_engine_spec.time_secondary_columns
+        #         and self.main_dttm_col in self.dttm_cols
+        #         and self.main_dttm_col != dttm_col.column_name
+        #     ):
+        #         time_filters.append(
+        #             columns_by_name[self.main_dttm_col].get_time_filter(
+        #                 from_dttm,
+        #                 to_dttm,
+        #             )
+        #         )
+        #     time_filters.append(dttm_col.get_time_filter(from_dttm, to_dttm))
 
         # Always remove duplicates by column name, as sometimes `metrics_exprs`
         # can have the same name as a groupby column (e.g. when users use
@@ -1466,6 +1290,8 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
 
         qry = sa.select(select_exprs)
 
+        # todo(hugh) fix templating
+        # tbl, cte = self.get_from_clause(template_processor)
         tbl, cte = self.get_from_clause(template_processor)
 
         if groupby_all_columns:
@@ -1480,18 +1306,18 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
             flt_col = flt["col"]
             val = flt.get("val")
             op = flt["op"].upper()
-            col_obj: Optional["TableColumn"] = None
-            sqla_col: Optional[Column] = None
+            col_obj: Optional[Column] = None
+            sqla_col: Optional[sa.Column] = None
             if flt_col == utils.DTTM_ALIAS and is_timeseries and dttm_col:
                 col_obj = dttm_col
             elif utils.is_adhoc_column(flt_col):
-                sqla_col = self.adhoc_column_to_sqla(flt_col)  # type: ignore
+                sqla_col = self.adhoc_column_to_sqla(flt_col)
             else:
                 col_obj = columns_by_name.get(flt_col)
             filter_grain = flt.get("grain")
 
             if is_feature_enabled("ENABLE_TEMPLATE_REMOVE_FILTERS"):
-                if utils.get_column_name(flt_col) in removed_filters:
+                if get_column_name(flt_col) in removed_filters:
                     # Skip generating SQLA filter when the jinja template 
handles it.
                     continue
 
@@ -1502,81 +1328,35 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                     sqla_col = col_obj.get_timestamp_expression(
                         time_grain=filter_grain, 
template_processor=template_processor
                     )
-                elif col_obj and isinstance(col_obj, dict):
-                    sqla_col = sa.column(col_obj.get("column_name"))
                 elif col_obj:
                     sqla_col = col_obj.get_sqla_col()
-
-                if col_obj and isinstance(col_obj, dict):
-                    col_type = col_obj.get("type")
-                else:
-                    col_type = col_obj.type if col_obj else None
                 col_spec = db_engine_spec.get_column_spec(
-                    native_type=col_type,
-                    db_extra=self.database.get_extra(),  # type: ignore
+                    col_obj.type if col_obj else None
                 )
                 is_list_target = op in (
                     utils.FilterOperator.IN.value,
                     utils.FilterOperator.NOT_IN.value,
                 )
-
-                if col_obj and isinstance(col_obj, dict):
-                    col_advanced_data_type = ""
+                if col_spec:
+                    target_type = col_spec.generic_type
                 else:
-                    col_advanced_data_type = (
-                        col_obj.advanced_data_type if col_obj else ""
-                    )
-
-                if col_spec and not col_advanced_data_type:
-                    target_generic_type = col_spec.generic_type
-                else:
-                    target_generic_type = utils.GenericDataType.STRING
+                    target_type = GenericDataType.STRING
                 eq = self.filter_values_handler(
                     values=val,
-                    target_generic_type=target_generic_type,
-                    target_native_type=col_type,
+                    target_column_type=target_type,
                     is_list_target=is_list_target,
-                    db_engine_spec=db_engine_spec,
-                    db_extra=self.database.get_extra(),  # type: ignore
                 )
-                if (
-                    col_advanced_data_type != ""
-                    and feature_flag_manager.is_feature_enabled(
-                        "ENABLE_ADVANCED_DATA_TYPES"
-                    )
-                    and col_advanced_data_type in ADVANCED_DATA_TYPES
-                ):
-                    values = eq if is_list_target else [eq]  # type: ignore
-                    bus_resp: AdvancedDataTypeResponse = ADVANCED_DATA_TYPES[
-                        col_advanced_data_type
-                    ].translate_type(
-                        {
-                            "type": col_advanced_data_type,
-                            "values": values,
-                        }
-                    )
-                    if bus_resp["error_message"]:
-                        raise AdvancedDataTypeResponseError(
-                            _(bus_resp["error_message"])
-                        )
-
-                    where_clause_and.append(
-                        
ADVANCED_DATA_TYPES[col_advanced_data_type].translate_filter(
-                            sqla_col, op, bus_resp["values"]
-                        )
-                    )
-                elif is_list_target:
+                if is_list_target:
                     assert isinstance(eq, (tuple, list))
                     if len(eq) == 0:
                         raise QueryObjectValidationError(
                             _("Filter value list cannot be empty")
                         )
-                    if len(eq) > len(
-                        eq_without_none := [x for x in eq if x is not None]
-                    ):
+                    if None in eq:
+                        eq = [x for x in eq if x is not None]
                         is_null_cond = sqla_col.is_(None)
                         if eq:
-                            cond = or_(is_null_cond, 
sqla_col.in_(eq_without_none))
+                            cond = or_(is_null_cond, sqla_col.in_(eq))
                         else:
                             cond = is_null_cond
                     else:
@@ -1620,15 +1400,13 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                         raise QueryObjectValidationError(
                             _("Invalid filter operation type: %(op)s", op=op)
                         )
-        # todo(hugh): fix this w/ template_processor
-        # where_clause_and += 
self.get_sqla_row_level_filters(template_processor)
+        if is_feature_enabled("ROW_LEVEL_SECURITY"):
+            where_clause_and += 
self._get_sqla_row_level_filters(template_processor)
         if extras:
             where = extras.get("where")
             if where:
                 try:
-                    where = template_processor.process_template(  # type: 
ignore
-                        f"({where})"
-                    )
+                    where = template_processor.process_template(where)
                 except TemplateError as ex:
                     raise QueryObjectValidationError(
                         _(
@@ -1636,13 +1414,11 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                             msg=ex.message,
                         )
                     ) from ex
-                where_clause_and += [self.text(where)]
+                where_clause_and += [self.text(f"({where})")]
             having = extras.get("having")
             if having:
                 try:
-                    having = template_processor.process_template(  # type: 
ignore
-                        f"({having})"
-                    )
+                    having = template_processor.process_template(having)
                 except TemplateError as ex:
                     raise QueryObjectValidationError(
                         _(
@@ -1650,10 +1426,13 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                             msg=ex.message,
                         )
                     ) from ex
-                having_clause_and += [self.text(having)]
-        if apply_fetch_values_predicate and self.fetch_values_predicate:  # 
type: ignore
-            qry = qry.where(self.get_fetch_values_predicate())  # type: ignore
+                having_clause_and += [self.text(f"({having})")]
+        if apply_fetch_values_predicate and self.fetch_values_predicate:
+            qry = qry.where(self.get_fetch_values_predicate())
         if granularity:
+            time_filters = (
+                []
+            )  # todo(hugh): remove this once time filters are actually set
             qry = qry.where(and_(*(time_filters + where_clause_and)))
         else:
             qry = qry.where(and_(*where_clause_and))
@@ -1673,7 +1452,7 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                 and col.name in [select_col.name for select_col in 
select_exprs]
             ):
                 col = literal_column(col.name)
-            direction = sa.asc if ascending else sa.desc
+            direction = asc if ascending else desc
             qry = qry.order_by(direction(col))
 
         if row_limit:
@@ -1692,13 +1471,13 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                 inner_groupby_exprs = []
                 inner_select_exprs = []
                 for gby_name, gby_obj in groupby_series_columns.items():
-                    label = utils.get_column_name(gby_name)
+                    label = get_column_name(gby_name)
                     inner = self.make_sqla_column_compatible(gby_obj, gby_name 
+ "__")
                     inner_groupby_exprs.append(inner)
                     inner_select_exprs.append(inner)
 
                 inner_select_exprs += [inner_main_metric_expr]
-                subq = sa.select(inner_select_exprs).select_from(tbl)
+                subq = select(inner_select_exprs).select_from(tbl)
                 inner_time_filter = []
 
                 if dttm_col and not db_engine_spec.time_groupby_inline:
@@ -1712,7 +1491,11 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                 subq = subq.group_by(*inner_groupby_exprs)
 
                 ob = inner_main_metric_expr
-                direction = sa.desc if order_desc else sa.asc
+                if series_limit_metric:
+                    ob = self._get_series_orderby(
+                        series_limit_metric, metrics_by_name, columns_by_name
+                    )
+                direction = desc if order_desc else asc
                 subq = subq.order_by(direction(ob))
                 subq = subq.limit(series_limit)
 
@@ -1722,9 +1505,21 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                     # conditionally mutated, as it refers to the column alias 
in
                     # the inner query
                     col_name = db_engine_spec.make_label_compatible(gby_name + 
"__")
-                    on_clause.append(gby_obj == sa.column(col_name))
+                    on_clause.append(gby_obj == column(col_name))
 
                 tbl = tbl.join(subq.alias(), and_(*on_clause))
+            else:
+                if series_limit_metric:
+                    orderby = [
+                        (
+                            self._get_series_orderby(
+                                series_limit_metric,
+                                metrics_by_name,
+                                columns_by_name,
+                            ),
+                            not order_desc,
+                        )
+                    ]
 
                 # run prequery to get top groups
                 prequery_obj = {
@@ -1742,7 +1537,7 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                     "order_desc": True,
                 }
 
-                result = self.query(prequery_obj)  # type: ignore
+                result = self.query(prequery_obj)
                 prequeries.append(result.query)
                 dimensions = [
                     c
@@ -1763,7 +1558,7 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
                 )
             label = "rowcount"
             col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), 
label)
-            qry = sa.select([col]).select_from(qry.alias("rowcount_qry"))
+            qry = select([col]).select_from(qry.alias("rowcount_qry"))
             labels_expected = [label]
 
         return SqlaQuery(

Reply via email to