This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 14260f9  feat: add contribution operation and fix cache_key bug (#10286)
14260f9 is described below

commit 14260f984334c0adedf813cd821f3fc92d3a2bae
Author: Ville Brofeldt <33317356+ville...@users.noreply.github.com>
AuthorDate: Fri Jul 10 17:06:05 2020 +0300

    feat: add contribution operation and fix cache_key bug (#10286)
    
    * feat: add contribution operation and fix cache_key bug
    
    * Add contribution schema
---
 superset/charts/schemas.py              | 17 ++++++++++++++++-
 superset/common/query_object.py         |  8 +++++---
 superset/utils/core.py                  |  9 +++++++++
 superset/utils/pandas_postprocessing.py | 31 +++++++++++++++++++++++++++++--
 tests/charts/schema_tests.py            | 10 ++++++++++
 tests/pandas_postprocessing_tests.py    | 29 ++++++++++++++++++++++++++++-
 tests/query_context_tests.py            | 27 +++++++++++++++++++++++++++
 7 files changed, 124 insertions(+), 7 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 8ab4859..4f2e3f0 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -395,6 +395,19 @@ class ChartDataSortOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     aggregates = ChartDataAggregateConfigField()
 
 
+class ChartDataContributionOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
+    """
+    Contribution operation config.
+    """
+
+    orientation = fields.String(
+        description="Should cell values be calculated across the row or 
column.",
+        required=True,
+        validate=validate.OneOf(choices=("row", "column",)),
+        example="row",
+    )
+
+
 class ChartDataPivotOptionsSchema(ChartDataPostProcessingOperationOptionsSchema):
     """
     Pivot operation config.
@@ -500,6 +513,7 @@ class ChartDataPostProcessingOperationSchema(Schema):
         validate=validate.OneOf(
             choices=(
                 "aggregate",
+                "contribution",
                 "cum",
                 "geodetic_parse",
                 "geohash_decode",
@@ -637,7 +651,7 @@ class ChartDataQueryObjectSchema(Schema):
         "`ChartDataAdhocMetricSchema` for the structure of ad-hoc metrics.",
     )
     post_processing = fields.List(
-        fields.Nested(ChartDataPostProcessingOperationSchema),
+        fields.Nested(ChartDataPostProcessingOperationSchema, allow_none=True),
         description="Post processing operations to be applied to the result 
set. "
         "Operations are applied to the result set in sequential order.",
     )
@@ -812,6 +826,7 @@ CHART_DATA_SCHEMAS = (
     #  by Marshmallow<3, this is not currently possible.
     ChartDataAdhocMetricSchema,
     ChartDataAggregateOptionsSchema,
+    ChartDataContributionOptionsSchema,
     ChartDataPivotOptionsSchema,
     ChartDataRollingOptionsSchema,
     ChartDataSelectOptionsSchema,
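
With the schema additions above, a chart data request can now carry a
"contribution" post-processing step. As a rough sketch (assuming the
operation/options envelope of ChartDataPostProcessingOperationSchema; all
unrelated query fields are elided), a query's post_processing list might
look like:

    post_processing = [
        {
            "operation": "contribution",
            # "orientation" is required and must be "row" or "column",
            # per the OneOf validator above
            "options": {"orientation": "row"},
        }
    ]
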
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 8de2165..a2676b9 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -94,7 +94,7 @@ class QueryObject:
         extras: Optional[Dict[str, Any]] = None,
         columns: Optional[List[str]] = None,
         orderby: Optional[List[List[str]]] = None,
-        post_processing: Optional[List[Dict[str, Any]]] = None,
+        post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
         **kwargs: Any,
     ):
         metrics = metrics or []
@@ -114,7 +114,9 @@ class QueryObject:
         self.is_timeseries = is_timeseries
         self.time_range = time_range
         self.time_shift = utils.parse_human_timedelta(time_shift)
-        self.post_processing = post_processing or []
+        self.post_processing = [
+            post_proc for post_proc in post_processing or [] if post_proc
+        ]
         if not is_sip_38:
             self.groupby = groupby or []
 
@@ -224,9 +226,9 @@ class QueryObject:
             del cache_dict[k]
         if self.time_range:
             cache_dict["time_range"] = self.time_range
-        json_data = self.json_dumps(cache_dict, sort_keys=True)
         if self.post_processing:
             cache_dict["post_processing"] = self.post_processing
+        json_data = self.json_dumps(cache_dict, sort_keys=True)
         return hashlib.md5(json_data.encode("utf-8")).hexdigest()
 
     def json_dumps(self, obj: Any, sort_keys: bool = False) -> str:
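
The reordering in the second hunk is the actual cache_key fix: before this
change, cache_dict was serialized before post_processing was copied into it,
so the post-processing pipeline never influenced the MD5 digest, and two
queries differing only in post_processing shared a cache entry. A minimal
standalone sketch of the failure mode (plain Python, not Superset code):

    import hashlib
    import json

    def buggy_key(cache_dict, post_processing):
        # serialize first...
        json_data = json.dumps(cache_dict, sort_keys=True)
        # ...then mutate the dict; the mutation never reaches the digest
        if post_processing:
            cache_dict["post_processing"] = post_processing
        return hashlib.md5(json_data.encode("utf-8")).hexdigest()

    def fixed_key(cache_dict, post_processing):
        # mutate first, then serialize, as in the hunk above
        if post_processing:
            cache_dict["post_processing"] = post_processing
        json_data = json.dumps(cache_dict, sort_keys=True)
        return hashlib.md5(json_data.encode("utf-8")).hexdigest()

    ops = [{"operation": "contribution", "options": {"orientation": "row"}}]
    assert buggy_key({"metrics": ["count"]}, ops) == buggy_key({"metrics": ["count"]}, [])
    assert fixed_key({"metrics": ["count"]}, ops) != fixed_key({"metrics": ["count"]}, [])
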
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 9edee1c..c464d78 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1476,3 +1476,12 @@ class TemporalType(str, Enum):
     TEXT = "TEXT"
     TIME = "TIME"
     TIMESTAMP = "TIMESTAMP"
+
+
+class PostProcessingContributionOrientation(str, Enum):
+    """
+    Calculate cell contribution to row/column total
+    """
+
+    ROW = "row"
+    COLUMN = "column"
diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py
index b693977..12b49bc 100644
--- a/superset/utils/pandas_postprocessing.py
+++ b/superset/utils/pandas_postprocessing.py
@@ -15,15 +15,16 @@
 # specific language governing permissions and limitations
 # under the License.
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import geohash as geohash_lib
 import numpy as np
 from flask_babel import gettext as _
 from geopy.point import Point
-from pandas import DataFrame, NamedAgg
+from pandas import DataFrame, NamedAgg, Series
 
 from superset.exceptions import QueryObjectValidationError
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
 
 WHITELIST_NUMPY_FUNCTIONS = (
     "average",
@@ -517,3 +518,29 @@ def geodetic_parse(
         return _append_columns(df, geodetic_df, columns)
     except ValueError:
         raise QueryObjectValidationError(_("Invalid geodetic string"))
+
+
+def contribution(
+    df: DataFrame, orientation: PostProcessingContributionOrientation
+) -> DataFrame:
+    """
+    Calculate cell contribution to row/column total.
+
+    :param df: DataFrame containing all-numeric data (temporal column ignored)
+    :param orientation: calculate by dividing cell with row/column total
+    :return: DataFrame with contributions, with temporal column at beginning if present
+    """
+    temporal_series: Optional[Series] = None
+    contribution_df = df.copy()
+    if DTTM_ALIAS in df.columns:
+        temporal_series = cast(Series, contribution_df.pop(DTTM_ALIAS))
+
+    if orientation == PostProcessingContributionOrientation.ROW:
+        contribution_dft = contribution_df.T
+        contribution_df = (contribution_dft / contribution_dft.sum()).T
+    else:
+        contribution_df = contribution_df / contribution_df.sum()
+
+    if temporal_series is not None:
+        contribution_df.insert(0, DTTM_ALIAS, temporal_series)
+    return contribution_df
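
For a feel of the arithmetic, here is a quick usage sketch of the new
function (the same numbers the test below asserts): with rows [1, 1] and
[3, 9], row orientation divides each cell by its row total (2 and 12),
column orientation by its column total (4 for "a", 10 for "b"):

    from pandas import DataFrame

    from superset.utils import pandas_postprocessing as proc
    from superset.utils.core import PostProcessingContributionOrientation

    df = DataFrame({"a": [1, 3], "b": [1, 9]})

    # each cell divided by its row total; every row now sums to 1.0
    proc.contribution(df, PostProcessingContributionOrientation.ROW)
    #       a     b
    # 0  0.50  0.50
    # 1  0.25  0.75

    # each cell divided by its column total; every column now sums to 1.0
    proc.contribution(df, PostProcessingContributionOrientation.COLUMN)
    #       a    b
    # 0  0.25  0.1
    # 1  0.75  0.9
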
diff --git a/tests/charts/schema_tests.py b/tests/charts/schema_tests.py
index 354ed82..ecb2c97 100644
--- a/tests/charts/schema_tests.py
+++ b/tests/charts/schema_tests.py
@@ -69,3 +69,13 @@ class TestSchema(SupersetTestCase):
 
         payload["queries"][0]["extras"]["time_grain_sqla"] = None
         _ = ChartDataQueryContextSchema().load(payload)
+
+    def test_query_context_null_post_processing_op(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+
+        payload["queries"][0]["post_processing"] = [None]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        self.assertEqual(query_context.queries[0].post_processing, [])
diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py
index 87d2cc1..ea70834 100644
--- a/tests/pandas_postprocessing_tests.py
+++ b/tests/pandas_postprocessing_tests.py
@@ -15,13 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 # isort:skip_file
+from datetime import datetime
 import math
 from typing import Any, List, Optional
 
-from pandas import Series
+from pandas import DataFrame, Series
 
 from superset.exceptions import QueryObjectValidationError
 from superset.utils import pandas_postprocessing as proc
+from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
 
 from .base_tests import SupersetTestCase
 from .fixtures.dataframes import categories_df, lonlat_df, timeseries_df
@@ -481,3 +483,28 @@ class TestPostProcessing(SupersetTestCase):
         self.assertListEqual(
             series_to_list(post_df["latitude"]), series_to_list(lonlat_df["latitude"]),
         )
+
+    def test_contribution(self):
+        df = DataFrame(
+            {
+                DTTM_ALIAS: [
+                    datetime(2020, 7, 16, 14, 49),
+                    datetime(2020, 7, 16, 14, 50),
+                ],
+                "a": [1, 3],
+                "b": [1, 9],
+            }
+        )
+
+        # cell contribution across row
+        row_df = proc.contribution(df, PostProcessingContributionOrientation.ROW)
+        self.assertListEqual(df.columns.tolist(), [DTTM_ALIAS, "a", "b"])
+        self.assertListEqual(series_to_list(row_df["a"]), [0.5, 0.25])
+        self.assertListEqual(series_to_list(row_df["b"]), [0.5, 0.75])
+
+        # cell contribution across column without temporal column
+        df.pop(DTTM_ALIAS)
+        column_df = proc.contribution(df, PostProcessingContributionOrientation.COLUMN)
+        self.assertListEqual(df.columns.tolist(), ["a", "b"])
+        self.assertListEqual(series_to_list(column_df["a"]), [0.25, 0.75])
+        self.assertListEqual(series_to_list(column_df["b"]), [0.1, 0.9])
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index 4b625b5..f816bcd 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -99,6 +99,33 @@ class TestQueryContext(SupersetTestCase):
         # the new cache_key should be different due to updated datasource
         self.assertNotEqual(cache_key_original, cache_key_new)
 
+    def test_cache_key_changes_when_post_processing_is_updated(self):
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(
+            table.name, table.id, table.type, add_postprocessing_operations=True
+        )
+
+        # construct baseline cache_key from query_context with post processing operation
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_original = query_context.cache_key(query_object)
+
+        # ensure added None post_processing operation doesn't change cache_key
+        payload["queries"][0]["post_processing"].append(None)
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_with_null = query_context.cache_key(query_object)
+        self.assertEqual(cache_key_original, cache_key_with_null)
+
+        # ensure query without post processing operation is different
+        payload["queries"][0].pop("post_processing")
+        query_context = QueryContext(**payload)
+        query_object = query_context.queries[0]
+        cache_key_without_post_processing = query_context.cache_key(query_object)
+        self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
+
     def test_query_context_time_range_endpoints(self):
         """
        Ensure that time_range_endpoints are populated automatically when missing
