This is an automated email from the ASF dual-hosted git repository. villebro pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push: new fc28c92 feat: support non-numeric columns in pivot table (#10389) fc28c92 is described below commit fc28c92f57edd7bccac57715bef159c0f18daef1 Author: Ville Brofeldt <33317356+ville...@users.noreply.github.com> AuthorDate: Tue Jul 28 10:40:53 2020 +0300 feat: support non-numeric columns in pivot table (#10389) * fix: support non-numeric columns in pivot table * bump package and add unit tests * mypy --- superset/viz.py | 39 +++++++++++++++++++++++++++++++-------- tests/viz_tests.py | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 8 deletions(-) diff --git a/superset/viz.py b/superset/viz.py index 6ce4d5a..6f0ba53 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -29,7 +29,18 @@ import uuid from collections import defaultdict, OrderedDict from datetime import datetime, timedelta from itertools import product -from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + Set, + Tuple, + TYPE_CHECKING, + Union, +) import dataclasses import geohash @@ -734,6 +745,7 @@ class PivotTableViz(BaseViz): verbose_name = _("Pivot Table") credits = 'a <a href="https://github.com/airbnb/superset">Superset</a> original' is_timeseries = False + enforce_numerical_metrics = False def query_obj(self) -> QueryObjectDict: d = super().query_obj() @@ -764,6 +776,18 @@ class PivotTableViz(BaseViz): raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap")) return d + @staticmethod + def get_aggfunc( + metric: str, df: pd.DataFrame, form_data: Dict[str, Any] + ) -> Union[str, Callable[[Any], Any]]: + aggfunc = form_data.get("pandas_aggfunc") or "sum" + if pd.api.types.is_numeric_dtype(df[metric]): + # Ensure that Pandas's sum function mimics that of SQL. + if aggfunc == "sum": + return lambda x: x.sum(min_count=1) + # only min and max work properly for non-numerics + return aggfunc if aggfunc in ("min", "max") else "max" + def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None @@ -771,22 +795,21 @@ class PivotTableViz(BaseViz): if self.form_data.get("granularity") == "all" and DTTM_ALIAS in df: del df[DTTM_ALIAS] - aggfunc = self.form_data.get("pandas_aggfunc") or "sum" - - # Ensure that Pandas's sum function mimics that of SQL. - if aggfunc == "sum": - aggfunc = lambda x: x.sum(min_count=1) + metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]] + aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {} + for metric in metrics: + aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data) groupby = self.form_data.get("groupby") columns = self.form_data.get("columns") if self.form_data.get("transpose_pivot"): groupby, columns = columns, groupby - metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]] + df = df.pivot_table( index=groupby, columns=columns, values=metrics, - aggfunc=aggfunc, + aggfunc=aggfuncs, margins=self.form_data.get("pivot_margins"), ) diff --git a/tests/viz_tests.py b/tests/viz_tests.py index c6d0c80..d1ac508 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -1292,3 +1292,41 @@ class TestBigNumberViz(SupersetTestCase): ) data = viz.BigNumberViz(datasource, {"metrics": ["y"]}).get_data(df) assert np.isnan(data[2]["y"]) + + +class TestPivotTableViz(SupersetTestCase): + df = pd.DataFrame( + data={ + "intcol": [1, 2, 3, None], + "floatcol": [0.1, 0.2, 0.3, None], + "strcol": ["a", "b", "c", None], + } + ) + + def test_get_aggfunc_numeric(self): + # is a sum function + func = viz.PivotTableViz.get_aggfunc("intcol", self.df, {}) + assert hasattr(func, "__call__") + assert func(self.df["intcol"]) == 6 + + assert ( + viz.PivotTableViz.get_aggfunc("intcol", self.df, {"pandas_aggfunc": "min"}) + == "min" + ) + assert ( + viz.PivotTableViz.get_aggfunc( + "floatcol", self.df, {"pandas_aggfunc": "max"} + ) + == "max" + ) + + def test_get_aggfunc_non_numeric(self): + assert viz.PivotTableViz.get_aggfunc("strcol", self.df, {}) == "max" + assert ( + viz.PivotTableViz.get_aggfunc("strcol", self.df, {"pandas_aggfunc": "sum"}) + == "max" + ) + assert ( + viz.PivotTableViz.get_aggfunc("strcol", self.df, {"pandas_aggfunc": "min"}) + == "min" + )