This is an automated email from the ASF dual-hosted git repository.

hugh pushed a commit to branch cpq-refactor-column
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 1e4d9d2e754a093bff944fa043d2de77d488e424
Author: hughhhh <[email protected]>
AuthorDate: Wed Sep 28 17:56:43 2022 -0400

    init
---
 superset/models/helpers.py              | 83 ++++++++++++++++-----------------
 tests/unit_tests/models/helpers_test.py | 43 +++++++++++++++++
 2 files changed, 83 insertions(+), 43 deletions(-)

diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index c81d268a1e..39d0aa96cc 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -89,6 +89,7 @@ from superset.utils import core as utils
 from superset.utils.core import get_user_id
 
 if TYPE_CHECKING:
+    from superset.columns.models import Column as SLColumn
     from superset.connectors.sqla.models import SqlMetric, TableColumn
     from superset.db_engine_specs import BaseEngineSpec
     from superset.models.core import Database
@@ -1291,7 +1292,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
 
     def get_timestamp_expression(
         self,
-        column: Dict[str, Any],
+        column: "SLColumn",
         time_grain: Optional[str],
         label: Optional[str] = None,
         template_processor: Optional[BaseTemplateProcessor] = None,
@@ -1305,12 +1306,12 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         :return: A TimeExpression object wrapped in a Label if supported by db
         """
         label = label or utils.DTTM_ALIAS
-        column_spec = self.db_engine_spec.get_column_spec(column.get("type"))
+        column_spec = self.db_engine_spec.get_column_spec(column.type)
         type_ = column_spec.sqla_type if column_spec else sa.DateTime
-        col = sa.column(column.get("column_name"), type_=type_)
+        col = sa.column(column.name, type_=type_)
 
         if template_processor:
-            expression = template_processor.process_template(column["column_name"])
+            expression = template_processor.process_template(column.name)
             col = sa.literal_column(expression, type_=type_)
 
         time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain)
@@ -1356,6 +1357,18 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         extras = extras or {}
         time_grain = extras.get("time_grain_sqla")
 
+        # Create mapping for column_name -> ExploreColumn object
+        # todo(hugh): move this to a function to manage converting other column
+        # types into ExploreColumn for generating queries
+        from superset.columns.models import Column as SLColumn
+
+        columns_by_name: Dict[str, "SLColumn"] = {
+            col.get("column_name"): SLColumn(
+                name=col.get("column_name"), type=col.get("type")
+            )
+            for col in self.columns
+        }
+
         template_kwargs = {
             "columns": columns,
             "from_dttm": from_dttm.isoformat() if from_dttm else None,
@@ -1394,11 +1407,6 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         if granularity not in self.dttm_cols and granularity is not None:
             granularity = self.main_dttm_col
 
-        columns_by_name: Dict[str, "TableColumn"] = {
-            col.get("column_name"): col
-            for col in self.columns  # col.column_name: col for col in self.columns
-        }
-
         if not granularity and is_timeseries:
             raise QueryObjectValidationError(
                 _(
@@ -1491,26 +1499,23 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                     # if groupby field/expr equals granularity field/expr
                     if selected == granularity:
                         table_col = columns_by_name[selected]
-                        if isinstance(table_col, dict):
-                            outer = self.get_timestamp_expression(
-                                column=table_col,
-                                time_grain=time_grain,
-                                label=selected,
-                                template_processor=template_processor,
-                            )
-                        else:
-                            outer = table_col.get_timestamp_expression(
-                                time_grain=time_grain,
-                                label=selected,
-                                template_processor=template_processor,
-                            )
+                        outer = self.get_timestamp_expression(
+                            column=table_col,
+                            time_grain=time_grain,
+                            label=selected,
+                            template_processor=template_processor,
+                        )
+                        # else:
+                        #     outer = table_col.get_timestamp_expression(
+                        #         time_grain=time_grain,
+                        #         label=selected,
+                        #         template_processor=template_processor,
+                        #     )
                     # if groupby field equals a selected column
                     elif selected in columns_by_name:
-                        if isinstance(columns_by_name[selected], dict):
-                            outer = sa.column(f"{selected}")
-                            outer = self.make_sqla_column_compatible(outer, selected)
-                        else:
-                            outer = columns_by_name[selected].get_sqla_col()
+                        outer = self.make_sqla_column_compatible(
+                            sa.column(f"{selected}"), selected
+                        )
                     else:
                         selected = self.validate_adhoc_subquery(
                             selected,
@@ -1536,14 +1541,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
                     self.database_id,
                     self.schema,
                 )
-                if isinstance(columns_by_name[selected], dict):
-                    select_exprs.append(sa.column(f"{selected}"))
-                else:
-                    select_exprs.append(
-                        columns_by_name[selected].get_sqla_col()
-                        if selected in columns_by_name
-                        else self.make_sqla_column_compatible(literal_column(selected))
-                    )
+                select_exprs.append(sa.column(f"{selected}"))
             metrics_exprs = []
 
         if granularity:
@@ -1557,14 +1555,13 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             time_filters: List[Any] = []
 
             if is_timeseries:
-                if isinstance(dttm_col, dict):
-                    timestamp = self.get_timestamp_expression(
-                        dttm_col, time_grain, template_processor=template_processor
-                    )
-                else:
-                    timestamp = dttm_col.get_timestamp_expression(
-                        time_grain=time_grain, template_processor=template_processor
-                    )
+                timestamp = self.get_timestamp_expression(
+                    dttm_col, time_grain, template_processor=template_processor
+                )
+                # else:
+                #     timestamp = dttm_col.get_timestamp_expression(
+                #         time_grain=time_grain, template_processor=template_processor
+                #     )
                 # always put timestamp as the first column
                 select_exprs.insert(0, timestamp)
                 groupby_all_columns[timestamp.name] = timestamp
diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py
new file mode 100644
index 0000000000..8d97036cb0
--- /dev/null
+++ b/tests/unit_tests/models/helpers_test.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=import-outside-toplevel
+
+
+def test_explore_mixin_get_timestamp():  # smoke test: call with an SLColumn
+    from superset.columns.models import Column as SLColumn
+    from superset.models.core import Database
+    from superset.models.sql_lab import Query
+
+    db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
+    query = Query(
+        client_id="foo",
+        database=db,
+        tab_name="test_tab",
+        sql_editor_id="test_editor_id",
+        sql="select * from bar",
+        select_sql="select * from bar",
+        executed_sql="select * from bar",
+        limit=100,
+        select_as_cta=False,
+        rows=100,
+        error_message="none",
+        results_key="abc",
+    )
+
+    col = SLColumn(name="foo", type="TIMESTAMP")  # .type feeds get_column_spec
+    query.get_timestamp_expression(col, time_grain=None)  # NOTE(review): result not asserted

Reply via email to