This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 1.3
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8146a5cad360dc299e3c4bad67422a34aff4bb8d
Author: Jesse Yang <[email protected]>
AuthorDate: Tue Aug 3 19:01:39 2021 -0700

    fix(dashboard): 500 error caused by data_for_slices API (#16053)

    (cherry picked from commit 490890de23e876811d9ea9de31a8bd384c3550de)
---
 superset/connectors/base/models.py     |  9 ++++-----
 superset/utils/core.py                 | 12 +++---------
 superset/viz.py                        |  2 +-
 tests/integration_tests/model_tests.py | 14 ++++++++------
 4 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 88c4f40..f332ce8 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -282,10 +282,9 @@ class BaseDatasource(
         column_names = set()
         for slc in slices:
             form_data = slc.form_data
-
             # pull out all required metrics from the form_data
-            for param in METRIC_FORM_DATA_PARAMS:
-                for metric in utils.get_iterable(form_data.get(param) or []):
+            for metric_param in METRIC_FORM_DATA_PARAMS:
+                for metric in utils.get_iterable(form_data.get(metric_param) or []):
                     metric_names.add(utils.get_metric_name(metric))
                     if utils.is_adhoc_metric(metric):
                         column_names.add(
@@ -308,8 +307,8 @@ class BaseDatasource(

             column_names.update(
                 column
-                for column in utils.get_iterable(form_data.get(param) or [])
-                for param in COLUMN_FORM_DATA_PARAMS
+                for column_param in COLUMN_FORM_DATA_PARAMS
+                for column in utils.get_iterable(form_data.get(column_param) or [])
             )

         filtered_metrics = [
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 9e37b7c..9739287 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1217,10 +1217,10 @@ def get_metric_name(metric: Metric) -> str:


 def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
-    return [get_metric_name(metric) for metric in metrics]
+    return [metric for metric in map(get_metric_name, metrics) if metric]


-def get_main_metric_name(metrics: Sequence[Metric]) -> Optional[str]:
+def get_first_metric_name(metrics: Sequence[Metric]) -> Optional[str]:
     metric_labels = get_metric_names(metrics)
     return metric_labels[0] if metric_labels else None

@@ -1427,7 +1427,6 @@ def get_iterable(x: Any) -> List[Any]:
     :param x: The object
     :returns: An iterable representation
     """
-
     return x if isinstance(x, list) else [x]


@@ -1464,12 +1463,7 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
     :param metrics: Ad-hoc metric
     :return: column name if simple metric, otherwise None
     """
-    columns: List[str] = []
-    for metric in metrics:
-        column_name = get_column_name_from_metric(metric)
-        if column_name:
-            columns.append(column_name)
-    return columns
+    return [col for col in map(get_column_name_from_metric, metrics) if col]


 def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
diff --git a/superset/viz.py b/superset/viz.py
index 18b0309..3dc799e 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -1230,7 +1230,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
         d = super().query_obj()
         sort_by = self.form_data.get(
             "timeseries_limit_metric"
-        ) or utils.get_main_metric_name(d.get("metrics") or [])
+        ) or utils.get_first_metric_name(d.get("metrics") or [])
         is_asc = not self.form_data.get("order_desc")
         if sort_by:
             sort_by_label = utils.get_metric_name(sort_by)
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 665f1cd..ee9b3cd 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -497,13 +497,15 @@ class TestSqlaTableModel(SupersetTestCase):
         slc = (
             metadata_db.session.query(Slice)
             .filter_by(
-                datasource_id=tbl.id,
-                datasource_type=tbl.type,
-                slice_name="Participants",
+                datasource_id=tbl.id, datasource_type=tbl.type, slice_name="Genders",
             )
             .first()
         )
         data_for_slices = tbl.data_for_slices([slc])
-        self.assertEqual(len(data_for_slices["columns"]), 0)
-        self.assertEqual(len(data_for_slices["metrics"]), 1)
-        self.assertEqual(len(data_for_slices["verbose_map"].keys()), 2)
+        assert len(data_for_slices["metrics"]) == 1
+        assert len(data_for_slices["columns"]) == 1
+        assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
+        assert data_for_slices["columns"][0]["column_name"] == "gender"
+        assert set(data_for_slices["verbose_map"].keys()) == set(
+            ["__timestamp", "sum__num", "gender",]
+        )
