This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 1.3
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8146a5cad360dc299e3c4bad67422a34aff4bb8d
Author: Jesse Yang <[email protected]>
AuthorDate: Tue Aug 3 19:01:39 2021 -0700

    fix(dashboard): 500 error caused by data_for_slices API (#16053)
    
    (cherry picked from commit 490890de23e876811d9ea9de31a8bd384c3550de)
---
 superset/connectors/base/models.py     |  9 ++++-----
 superset/utils/core.py                 | 12 +++---------
 superset/viz.py                        |  2 +-
 tests/integration_tests/model_tests.py | 14 ++++++++------
 4 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 88c4f40..f332ce8 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -282,10 +282,9 @@ class BaseDatasource(
         column_names = set()
         for slc in slices:
             form_data = slc.form_data
-
             # pull out all required metrics from the form_data
-            for param in METRIC_FORM_DATA_PARAMS:
-                for metric in utils.get_iterable(form_data.get(param) or []):
+            for metric_param in METRIC_FORM_DATA_PARAMS:
+                for metric in utils.get_iterable(form_data.get(metric_param) 
or []):
                     metric_names.add(utils.get_metric_name(metric))
                     if utils.is_adhoc_metric(metric):
                         column_names.add(
@@ -308,8 +307,8 @@ class BaseDatasource(
 
             column_names.update(
                 column
-                for column in utils.get_iterable(form_data.get(param) or [])
-                for param in COLUMN_FORM_DATA_PARAMS
+                for column_param in COLUMN_FORM_DATA_PARAMS
+                for column in utils.get_iterable(form_data.get(column_param) 
or [])
             )
 
         filtered_metrics = [
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 9e37b7c..9739287 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1217,10 +1217,10 @@ def get_metric_name(metric: Metric) -> str:
 
 
 def get_metric_names(metrics: Sequence[Metric]) -> List[str]:
-    return [get_metric_name(metric) for metric in metrics]
+    return [metric for metric in map(get_metric_name, metrics) if metric]
 
 
-def get_main_metric_name(metrics: Sequence[Metric]) -> Optional[str]:
+def get_first_metric_name(metrics: Sequence[Metric]) -> Optional[str]:
     metric_labels = get_metric_names(metrics)
     return metric_labels[0] if metric_labels else None
 
@@ -1427,7 +1427,6 @@ def get_iterable(x: Any) -> List[Any]:
     :param x: The object
     :returns: An iterable representation
     """
-
     return x if isinstance(x, list) else [x]
 
 
@@ -1464,12 +1463,7 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
     :param metrics: Ad-hoc metric
     :return: column name if simple metric, otherwise None
     """
-    columns: List[str] = []
-    for metric in metrics:
-        column_name = get_column_name_from_metric(metric)
-        if column_name:
-            columns.append(column_name)
-    return columns
+    return [col for col in map(get_column_name_from_metric, metrics) if col]
 
 
 def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
diff --git a/superset/viz.py b/superset/viz.py
index 18b0309..3dc799e 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -1230,7 +1230,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
         d = super().query_obj()
         sort_by = self.form_data.get(
             "timeseries_limit_metric"
-        ) or utils.get_main_metric_name(d.get("metrics") or [])
+        ) or utils.get_first_metric_name(d.get("metrics") or [])
         is_asc = not self.form_data.get("order_desc")
         if sort_by:
             sort_by_label = utils.get_metric_name(sort_by)
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 665f1cd..ee9b3cd 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -497,13 +497,15 @@ class TestSqlaTableModel(SupersetTestCase):
         slc = (
             metadata_db.session.query(Slice)
             .filter_by(
-                datasource_id=tbl.id,
-                datasource_type=tbl.type,
-                slice_name="Participants",
+                datasource_id=tbl.id, datasource_type=tbl.type, 
slice_name="Genders",
             )
             .first()
         )
         data_for_slices = tbl.data_for_slices([slc])
-        self.assertEqual(len(data_for_slices["columns"]), 0)
-        self.assertEqual(len(data_for_slices["metrics"]), 1)
-        self.assertEqual(len(data_for_slices["verbose_map"].keys()), 2)
+        assert len(data_for_slices["metrics"]) == 1
+        assert len(data_for_slices["columns"]) == 1
+        assert data_for_slices["metrics"][0]["metric_name"] == "sum__num"
+        assert data_for_slices["columns"][0]["column_name"] == "gender"
+        assert set(data_for_slices["verbose_map"].keys()) == set(
+            ["__timestamp", "sum__num", "gender",]
+        )

Reply via email to