This is an automated email from the ASF dual-hosted git repository. hugh pushed a commit to branch cpq-refactor-column in repository https://gitbox.apache.org/repos/asf/superset.git
commit 1e4d9d2e754a093bff944fa043d2de77d488e424 Author: hughhhh <[email protected]> AuthorDate: Wed Sep 28 17:56:43 2022 -0400 init --- superset/models/helpers.py | 83 ++++++++++++++++----------------- tests/unit_tests/models/helpers_test.py | 43 +++++++++++++++++ 2 files changed, 83 insertions(+), 43 deletions(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index c81d268a1e..39d0aa96cc 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -89,6 +89,7 @@ from superset.utils import core as utils from superset.utils.core import get_user_id if TYPE_CHECKING: + from superset.columns.models import Column as SLColumn from superset.connectors.sqla.models import SqlMetric, TableColumn from superset.db_engine_specs import BaseEngineSpec from superset.models.core import Database @@ -1291,7 +1292,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def get_timestamp_expression( self, - column: Dict[str, Any], + column: "SLColumn", time_grain: Optional[str], label: Optional[str] = None, template_processor: Optional[BaseTemplateProcessor] = None, @@ -1305,12 +1306,12 @@ class ExploreMixin: # pylint: disable=too-many-public-methods :return: A TimeExpression object wrapped in a Label if supported by db """ label = label or utils.DTTM_ALIAS - column_spec = self.db_engine_spec.get_column_spec(column.get("type")) + column_spec = self.db_engine_spec.get_column_spec(column.type) type_ = column_spec.sqla_type if column_spec else sa.DateTime - col = sa.column(column.get("column_name"), type_=type_) + col = sa.column(column.name, type_=type_) if template_processor: - expression = template_processor.process_template(column["column_name"]) + expression = template_processor.process_template(column.name) col = sa.literal_column(expression, type_=type_) time_expr = self.db_engine_spec.get_timestamp_expr(col, None, time_grain) @@ -1356,6 +1357,18 @@ class ExploreMixin: # pylint: disable=too-many-public-methods extras = extras or {} time_grain = extras.get("time_grain_sqla") + # Create mapping for column_name -> ExploreColumn object + # todo(hugh): move this to a function to manage converting other column + # types into ExploreColumn for generating queries + from superset.columns.models import Column as SLColumn + + columns_by_name: Dict[str, "SLColumn"] = { + col.get("column_name"): SLColumn( + name=col.get("column_name"), type=col.get("type") + ) + for col in self.columns + } + template_kwargs = { "columns": columns, "from_dttm": from_dttm.isoformat() if from_dttm else None, @@ -1394,11 +1407,6 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - columns_by_name: Dict[str, "TableColumn"] = { - col.get("column_name"): col - for col in self.columns # col.column_name: col for col in self.columns - } - if not granularity and is_timeseries: raise QueryObjectValidationError( _( @@ -1491,26 +1499,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods # if groupby field/expr equals granularity field/expr if selected == granularity: table_col = columns_by_name[selected] - if isinstance(table_col, dict): - outer = self.get_timestamp_expression( - column=table_col, - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) - else: - outer = table_col.get_timestamp_expression( - time_grain=time_grain, - label=selected, - template_processor=template_processor, - ) + outer = self.get_timestamp_expression( + column=table_col, + time_grain=time_grain, + label=selected, + template_processor=template_processor, + ) + # else: + # outer = table_col.get_timestamp_expression( + # time_grain=time_grain, + # label=selected, + # template_processor=template_processor, + # ) # if groupby field equals a selected column elif selected in columns_by_name: - if isinstance(columns_by_name[selected], dict): - outer = sa.column(f"{selected}") - outer = self.make_sqla_column_compatible(outer, selected) - else: - outer = columns_by_name[selected].get_sqla_col() + outer = self.make_sqla_column_compatible( + sa.column(f"{selected}"), selected + ) else: selected = self.validate_adhoc_subquery( selected, @@ -1536,14 +1541,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods self.database_id, self.schema, ) - if isinstance(columns_by_name[selected], dict): - select_exprs.append(sa.column(f"{selected}")) - else: - select_exprs.append( - columns_by_name[selected].get_sqla_col() - if selected in columns_by_name - else self.make_sqla_column_compatible(literal_column(selected)) - ) + select_exprs.append(sa.column(f"{selected}")) metrics_exprs = [] if granularity: @@ -1557,14 +1555,13 @@ class ExploreMixin: # pylint: disable=too-many-public-methods time_filters: List[Any] = [] if is_timeseries: - if isinstance(dttm_col, dict): - timestamp = self.get_timestamp_expression( - dttm_col, time_grain, template_processor=template_processor - ) - else: - timestamp = dttm_col.get_timestamp_expression( - time_grain=time_grain, template_processor=template_processor - ) + timestamp = self.get_timestamp_expression( + dttm_col, time_grain, template_processor=template_processor + ) + # else: + # timestamp = dttm_col.get_timestamp_expression( + # time_grain=time_grain, template_processor=template_processor + # ) # always put timestamp as the first column select_exprs.insert(0, timestamp) groupby_all_columns[timestamp.name] = timestamp diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py new file mode 100644 index 0000000000..8d97036cb0 --- /dev/null +++ b/tests/unit_tests/models/helpers_test.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel + + +def test_explore_mixin_get_timestamp(): + from superset.columns.models import Column as SLColumn + from superset.models.core import Database + from superset.models.sql_lab import Query + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + query = Query( + client_id="foo", + database=db, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=100, + error_message="none", + results_key="abc", + ) + + col = SLColumn(name="foo", type="TIMESTAMP") + query.get_timestamp_expression(col, time_grain=None)
