This is an automated email from the ASF dual-hosted git repository.

yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 94282b7  fix: time comparison can't guarantee the accuracy (#16895)
94282b7 is described below

commit 94282b7ecdb0851af08b7954efbc46a84c1b2408
Author: Yongjie Zhao <[email protected]>
AuthorDate: Thu Sep 30 12:59:57 2021 +0100

    fix: time comparison can't guarantee the accuracy (#16895)
    
    * fix: time comparison can't guarantee the accuracy
    
    * fix multiple series
    
    * fix lint
    
    * fix ut
    
    * fix lint
    
    * more ut
    
    * fix typo
---
 superset/common/query_context.py               | 60 ++++++++--------
 tests/integration_tests/query_context_tests.py | 99 ++++++++++++++++++++++++++
 2 files changed, 131 insertions(+), 28 deletions(-)
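
For context: before this change, the offset dataframe was left-joined on the
time column alone, so a query with multiple series (a groupby) matched every
row sharing a timestamp against every offset row with that timestamp. The fix
joins on the timestamp plus all non-metric columns, aligning each series with
itself. A minimal pandas sketch of the difference (the data values are
illustrative, not taken from the commit):

    import pandas as pd

    main = pd.DataFrame({
        "__timestamp": ["2021-01-01", "2021-01-01"],
        "state": ["CA", "NY"],
        "sum__num": [10, 20],
    })
    offset = pd.DataFrame({
        "__timestamp": ["2021-01-01", "2021-01-01"],
        "state": ["CA", "NY"],
        "sum__num__1 year ago": [8, 15],
    })

    # Old behaviour: joining on the time column only cannot tell CA from NY,
    # so each left row matches both offset rows -- 4 rows instead of 2.
    bad = main.set_index("__timestamp").join(
        offset.set_index("__timestamp"), rsuffix="_offset"
    )
    assert len(bad) == 4

    # New behaviour: join on the timestamp plus the dimension columns.
    good = main.set_index(["__timestamp", "state"]).join(
        offset.set_index(["__timestamp", "state"])
    )
    assert len(good) == 2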

diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 9f1f4bf..c1162b9 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -113,11 +113,11 @@ class QueryContext:
         }
 
     @staticmethod
-    def left_join_on_dttm(
-        left_df: pd.DataFrame, right_df: pd.DataFrame
+    def left_join_df(
+        left_df: pd.DataFrame, right_df: pd.DataFrame, join_keys: List[str],
     ) -> pd.DataFrame:
-        df = left_df.set_index(DTTM_ALIAS).join(right_df.set_index(DTTM_ALIAS))
-        df.reset_index(level=0, inplace=True)
+        df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
+        df.reset_index(inplace=True)
         return df
 
     def processing_time_offsets(  # pylint: disable=too-many-locals
@@ -125,8 +125,9 @@ class QueryContext:
     ) -> CachedTimeOffset:
         # ensure query_object is immutable
         query_object_clone = copy.copy(query_object)
-        queries = []
-        cache_keys = []
+        queries: List[str] = []
+        cache_keys: List[Optional[str]] = []
+        rv_dfs: List[pd.DataFrame] = [df]
 
         time_offsets = query_object.time_offsets
         outer_from_dttm = query_object.from_dttm
@@ -155,31 +156,34 @@ class QueryContext:
             # `offset` is added to the hash function
             cache_key = self.query_cache_key(query_object_clone, time_offset=offset)
             cache = QueryCacheManager.get(cache_key, CacheRegion.DATA, self.force)
-            # whether hit in the cache
+            # whether the cache was hit
             if cache.is_loaded:
-                df = self.left_join_on_dttm(df, cache.df)
+                rv_dfs.append(cache.df)
                 queries.append(cache.query)
                 cache_keys.append(cache_key)
                 continue
 
             query_object_clone_dct = query_object_clone.to_dict()
-            result = self.datasource.query(query_object_clone_dct)
-            queries.append(result.query)
-            cache_keys.append(None)
-
             # rename metrics: SUM(value) => SUM(value) 1 year ago
-            columns_name_mapping = {
+            metrics_mapping = {
                 metric: TIME_COMPARISION.join([metric, offset])
                 for metric in get_metric_names(
                     query_object_clone_dct.get("metrics", [])
                 )
             }
-            columns_name_mapping[DTTM_ALIAS] = DTTM_ALIAS
+            join_keys = [col for col in df.columns if col not in metrics_mapping.keys()]
+
+            result = self.datasource.query(query_object_clone_dct)
+            queries.append(result.query)
+            cache_keys.append(None)
 
             offset_metrics_df = result.df
             if offset_metrics_df.empty:
                 offset_metrics_df = pd.DataFrame(
-                    {col: [np.NaN] for col in columns_name_mapping.values()}
+                    {
+                        col: [np.NaN]
+                        for col in join_keys + list(metrics_mapping.values())
+                    }
                 )
             else:
                 # 1. normalize df, set dttm column
@@ -187,25 +191,23 @@ class QueryContext:
                     offset_metrics_df, query_object_clone
                 )
 
-                # 2. extract `metrics` columns and `dttm` column from extra query
-                offset_metrics_df = offset_metrics_df[columns_name_mapping.keys()]
+                # 2. rename extra query columns
+                offset_metrics_df = offset_metrics_df.rename(columns=metrics_mapping)
 
-                # 3. rename extra query columns
-                offset_metrics_df = offset_metrics_df.rename(
-                    columns=columns_name_mapping
-                )
-
-                # 4. set offset for dttm column
+                # 3. set time offset for dttm column
                 offset_metrics_df[DTTM_ALIAS] = offset_metrics_df[
                     DTTM_ALIAS
                 ] - DateOffset(**normalize_time_delta(offset))
 
-            # df left join `offset_metrics_df` on `DTTM`
-            df = self.left_join_on_dttm(df, offset_metrics_df)
+            # df left join `offset_metrics_df`
+            offset_df = self.left_join_df(
+                left_df=df, right_df=offset_metrics_df, join_keys=join_keys,
+            )
+            offset_slice = offset_df[metrics_mapping.values()]
 
-            # set offset df to cache.
+            # set offset_slice to cache and stack.
             value = {
-                "df": offset_metrics_df,
+                "df": offset_slice,
                 "query": result.query,
             }
             cache.set(
@@ -215,8 +217,10 @@ class QueryContext:
                 datasource_uid=self.datasource.uid,
                 region=CacheRegion.DATA,
             )
+            rv_dfs.append(offset_slice)
 
-        return CachedTimeOffset(df=df, queries=queries, cache_keys=cache_keys)
+        rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
+        return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
 
     def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
         timestamp_format = None
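
A note on the reassembly above: each offset slice that gets cached and stacked
holds only the renamed metric columns, row-aligned with the main dataframe
(the left join preserves left-side row order, and an offset query that returns
no rows is replaced by a single-row NaN frame so the shapes still line up).
The final pd.concat(axis=1) then stitches columns together without another
join. A rough sketch, assuming the default RangeIndex the frames carry after
the join (values illustrative):

    import pandas as pd

    main = pd.DataFrame({
        "__timestamp": ["2021-01-01", "2021-01-01"],
        "state": ["CA", "NY"],
        "sum__num": [10, 20],
    })
    # An offset slice keeps only the renamed metrics, row-aligned with `main`.
    offset_slice = pd.DataFrame({"sum__num__1 year ago": [8, 15]})

    rv_df = pd.concat([main, offset_slice], axis=1, copy=False)
    # columns: __timestamp, state, sum__num, sum__num__1 year ago
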
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index ecc69b7..cd76540 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -20,6 +20,7 @@ import time
 from typing import Any, Dict
 
 import pytest
+from pandas import DateOffset
 
 from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
@@ -546,11 +547,15 @@ class TestQueryContext(SupersetTestCase):
         self.login(username="admin")
         payload = get_query_context("birth_names")
         payload["queries"][0]["metrics"] = ["sum__num"]
+        # should process an empty dataframe correctly
+        # because "name" is randomly generated, each time_offset slice will be empty
         payload["queries"][0]["groupby"] = ["name"]
         payload["queries"][0]["is_timeseries"] = True
         payload["queries"][0]["timeseries_limit"] = 5
         payload["queries"][0]["time_offsets"] = []
         payload["queries"][0]["time_range"] = "1990 : 1991"
+        payload["queries"][0]["granularity"] = "ds"
+        payload["queries"][0]["extras"]["time_grain_sqla"] = "P1Y"
         query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         query_result = query_context.get_query_result(query_object)
@@ -588,3 +593,97 @@ class TestQueryContext(SupersetTestCase):
         self.assertIs(rv["df"], df)
         self.assertEqual(rv["queries"], [])
         self.assertEqual(rv["cache_keys"], [])
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_time_offsets_sql(self):
+        payload = get_query_context("birth_names")
+        payload["queries"][0]["metrics"] = ["sum__num"]
+        payload["queries"][0]["groupby"] = ["state"]
+        payload["queries"][0]["is_timeseries"] = True
+        payload["queries"][0]["timeseries_limit"] = 5
+        payload["queries"][0]["time_offsets"] = []
+        payload["queries"][0]["time_range"] = "1980 : 1991"
+        payload["queries"][0]["granularity"] = "ds"
+        payload["queries"][0]["extras"]["time_grain_sqla"] = "P1Y"
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        query_result = query_context.get_query_result(query_object)
+        # get main query dataframe
+        df = query_result.df
+
+        # set time_offsets to query_object
+        payload["queries"][0]["time_offsets"] = ["3 years ago", "3 years 
later"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        time_offsets_obj = query_context.processing_time_offsets(df, query_object)
+        query_from_1977_to_1988 = time_offsets_obj["queries"][0]
+        query_from_1983_to_1994 = time_offsets_obj["queries"][1]
+
+        # should generate expected date range in sql
+        assert "1977-01-01" in query_from_1977_to_1988
+        assert "1988-01-01" in query_from_1977_to_1988
+        assert "1983-01-01" in query_from_1983_to_1994
+        assert "1994-01-01" in query_from_1983_to_1994
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_time_offsets_accuracy(self):
+        payload = get_query_context("birth_names")
+        payload["queries"][0]["metrics"] = ["sum__num"]
+        payload["queries"][0]["groupby"] = ["state"]
+        payload["queries"][0]["is_timeseries"] = True
+        payload["queries"][0]["timeseries_limit"] = 5
+        payload["queries"][0]["time_offsets"] = []
+        payload["queries"][0]["time_range"] = "1980 : 1991"
+        payload["queries"][0]["granularity"] = "ds"
+        payload["queries"][0]["extras"]["time_grain_sqla"] = "P1Y"
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        query_result = query_context.get_query_result(query_object)
+        # get main query dataframe
+        df = query_result.df
+
+        # set time_offsets to query_object
+        payload["queries"][0]["time_offsets"] = ["3 years ago", "3 years 
later"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        time_offsets_obj = query_context.processing_time_offsets(df, query_object)
+        df_with_offsets = time_offsets_obj["df"]
+        df_with_offsets = df_with_offsets.set_index(["__timestamp", "state"])
+
+        # should get correct data when apply "3 years ago"
+        payload["queries"][0]["time_offsets"] = []
+        payload["queries"][0]["time_range"] = "1977 : 1988"
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        query_result = query_context.get_query_result(query_object)
+        # get df for "3 years ago"
+        df_3_years_ago = query_result.df
+        df_3_years_ago["__timestamp"] = df_3_years_ago["__timestamp"] + 
DateOffset(
+            years=3
+        )
+        df_3_years_ago = df_3_years_ago.set_index(["__timestamp", "state"])
+        for index, row in df_with_offsets.iterrows():
+            if index in df_3_years_ago.index:
+                assert (
+                    row["sum__num__3 years ago"]
+                    == df_3_years_ago.loc[index]["sum__num"]
+                )
+
+        # should get correct data when apply "3 years later"
+        payload["queries"][0]["time_offsets"] = []
+        payload["queries"][0]["time_range"] = "1983 : 1994"
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_object = query_context.queries[0]
+        query_result = query_context.get_query_result(query_object)
+        # get df for "3 years later"
+        df_3_years_later = query_result.df
+        df_3_years_later["__timestamp"] = df_3_years_later["__timestamp"] - 
DateOffset(
+            years=3
+        )
+        df_3_years_later = df_3_years_later.set_index(["__timestamp", "state"])
+        for index, row in df_with_offsets.iterrows():
+            if index in df_3_years_later.index:
+                assert (
+                    row["sum__num__3 years later"]
+                    == df_3_years_later.loc[index]["sum__num"]
+                )
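
The accuracy test above re-runs each shifted time range as a plain query and
then moves its timestamps back onto the main query's axis before comparing
values row by row. The alignment rests on pandas' calendar-aware DateOffset;
roughly (dates illustrative):

    import pandas as pd
    from pandas import DateOffset

    ts = pd.Series(pd.to_datetime(["1977-01-01", "1978-01-01"]))
    # Rows fetched for "3 years ago" are shifted forward by 3 years so they
    # land on the same timestamps as the main 1980-1991 query.
    shifted = ts + DateOffset(years=3)
    assert shifted.tolist() == list(pd.to_datetime(["1980-01-01", "1981-01-01"]))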
