This is an automated email from the ASF dual-hosted git repository.

arivero pushed a commit to branch joined_time_comparison
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/joined_time_comparison by this 
push:
     new 6fe13adca8 Time Comparison with JOINs:
6fe13adca8 is described below

commit 6fe13adca8999176c91c0ba259ed99cfe9bd5849
Author: Antonio Rivero <[email protected]>
AuthorDate: Tue Apr 2 15:12:00 2024 +0200

    Time Comparison with JOINs:
    
    - Add optional arg to extras schema in queryObject
    - Perform a left outer join if the new arg is present so the query is 
processed in a single query
    - Add tests for new functionality
---
 superset/charts/schemas.py                 |  17 ++
 superset/constants.py                      |   3 +-
 superset/models/helpers.py                 | 147 ++++++++++-
 tests/unit_tests/connectors/__init__.py    |  16 ++
 tests/unit_tests/connectors/test_models.py | 386 +++++++++++++++++++++++++++++
 5 files changed, 565 insertions(+), 4 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 611f7af597..06ff334b6e 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -26,6 +26,7 @@ from marshmallow.validate import Length, Range
 
 from superset import app
 from superset.common.chart_data import ChartDataResultFormat, 
ChartDataResultType
+from superset.constants import InstantTimeComparison
 from superset.db_engine_specs.base import builtin_time_grains
 from superset.tags.models import TagType
 from superset.utils import pandas_postprocessing, schema as utils
@@ -948,6 +949,14 @@ class ChartDataFilterSchema(Schema):
     )
 
 
+class InstantTimeComparisonInfoSchema(Schema):
+    range = fields.String(
+        metadata={"description": "Type of time comparison to be used"},
+        validate=validate.OneOf(choices=[ran.value for ran in 
InstantTimeComparison]),
+    )
+    filter = fields.Nested(ChartDataFilterSchema, allow_none=True)
+
+
 class ChartDataExtrasSchema(Schema):
     relative_start = fields.String(
         metadata={
@@ -998,6 +1007,14 @@ class ChartDataExtrasSchema(Schema):
         },
         allow_none=True,
     )
+    instant_time_comparison_info = fields.Nested(
+        InstantTimeComparisonInfoSchema,
+        metadata={
+            "description": "Extra parameters to use instant time comparison"
+            " with JOINs using a single query"
+        },
+        allow_none=True,
+    )
 
 
 class AnnotationLayerSchema(Schema):
diff --git a/superset/constants.py b/superset/constants.py
index bf4e7717d5..b08966f8a2 100644
--- a/superset/constants.py
+++ b/superset/constants.py
@@ -44,10 +44,11 @@ LRU_CACHE_MAX_SIZE = 256
 
 # Used when calculating the time shift for time comparison
 class InstantTimeComparison(StrEnum):
+    CUSTOM = "c"
     INHERITED = "r"
-    YEAR = "y"
     MONTH = "m"
     WEEK = "w"
+    YEAR = "y"
 
 
 class RouteMethod:  # pylint: disable=too-few-public-methods
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 450def33b2..9bfac8e88b 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -17,6 +17,7 @@
 # pylint: disable=too-many-lines
 """a collection of model-related helper classes and functions"""
 import builtins
+import copy
 import dataclasses
 import json
 import logging
@@ -57,7 +58,7 @@ from superset import app, db, is_feature_enabled, 
security_manager
 from superset.advanced_data_type.types import AdvancedDataTypeResponse
 from superset.common.db_query_status import QueryStatus
 from superset.common.utils.time_range_utils import 
get_since_until_from_time_range
-from superset.constants import EMPTY_STRING, NULL_STRING
+from superset.constants import EMPTY_STRING, InstantTimeComparison, NULL_STRING
 from superset.db_engine_specs.base import TimestampExpression
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
@@ -89,6 +90,7 @@ from superset.superset_typing import (
 )
 from superset.utils import core as utils
 from superset.utils.core import (
+    FilterOperator,
     GenericDataType,
     get_column_name,
     get_user_id,
@@ -904,6 +906,123 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
             sql = f"{cte}\n{sql}"
         return sql
 
+    def extract_column_names(self, final_selected_columns: Any) -> list[str]:
+        column_names = []
+        for selected_col in final_selected_columns:
+            # The key attribute usually holds the name or alias of the column
+            column_name = selected_col.key if hasattr(selected_col, "key") 
else None
+            # If the column has a name attribute, use it as a fallback
+            if not column_name and hasattr(selected_col, "name"):
+                column_name = selected_col.name
+            # For labeled elements, the name is stored in the 'name' attribute
+            if hasattr(selected_col, "name"):
+                column_name = selected_col.name
+            # Append the extracted name to the list
+            if column_name:
+                column_names.append(column_name)
+        return column_names
+
+    def process_time_compare_join(  # pylint: disable=too-many-locals
+        self,
+        query_obj: QueryObjectDict,
+        sqlaq: SqlaQuery,
+        mutate: bool,
+        instant_time_comparison_info: dict[str, Any],
+    ) -> tuple[str, list[str]]:
+        """
+        Main goal of this method is to create a JOIN between a given query 
object and
+        other that shifts the time filters. This is different from 
time_offsets because
+        we are not joining result sets but rather we're applying the JOIN at 
query level.
+        Use case: Compare paginated data in a Table Chart. But ideally can be 
leveraged by
+        anything that needs the experimental instant time comparison.
+        """
+        # So we don't override the original QueryObject
+        query_obj_clone = copy.copy(query_obj)
+        final_query_sql = ""
+        # The inner query object doesn't need limits nor offset
+        query_obj_clone["row_limit"] = None
+        query_obj_clone["row_offset"] = None
+        # Let's get what range should we be using when building the 
time_comparison shift
+        # This is computing the time_shift based on some predefined options of 
deltas
+        instant_time_comparison_range = 
instant_time_comparison_info.get("range")
+        if instant_time_comparison_range == InstantTimeComparison.CUSTOM:
+            # If it's a custom filter, we take the 1st temporal filter and 
change it with
+            # whatever value we received in the request as the custom filter.
+            custom_filter = instant_time_comparison_info.get("filter", {})
+            temporal_filters = [
+                filter["col"]
+                for filter in query_obj_clone.get("filter", {})
+                if filter.get("op", None) == FilterOperator.TEMPORAL_RANGE
+            ]
+            non_temporal_filters = [
+                filter["col"]
+                for filter in query_obj_clone.get("filter", {})
+                if filter.get("op", None) != FilterOperator.TEMPORAL_RANGE
+            ]
+            if len(temporal_filters) > 0:
+                # Edit the first temporal filter to include the custom filter
+                temporal_filters[0] = custom_filter
+
+            new_filters = temporal_filters + non_temporal_filters
+            query_obj_clone["filter"] = new_filters
+        if instant_time_comparison_range != InstantTimeComparison.CUSTOM:
+            # When not custom, we're supposed to use the predefined time ranges
+            # Year, Month, Week or Inherited
+            query_obj_clone["extras"] = {
+                **query_obj_clone.get("extras", {}),
+                "instant_time_comparison_range": instant_time_comparison_range,
+            }
+        shifted_sqlaq = self.get_sqla_query(**query_obj_clone)
+        # We JOIN only over columns, not metrics or anything else since those 
cannot be
+        # joined
+        join_columns = query_obj_clone.get("columns") or []
+        original_query_a = sqlaq.sqla_query
+        shifted_query_b = shifted_sqlaq.sqla_query
+        shifted_query_b_subquery = shifted_query_b.subquery()
+        query_a_cte = original_query_a.cte("query_a_results")
+        column_names_a = [column.key for column in original_query_a.c]
+        exclude_columns_b = set(query_obj_clone.get("columns") or [])
+        # Let's prepare the columns set to be used in query A and B
+        selected_columns_a = [query_a_cte.c[col].label(col) for col in 
column_names_a]
+        # Renamed columns from Query B (with "prev_" prefix)
+        selected_columns_b = [
+            shifted_query_b_subquery.c[col].label(f"prev_{col}")
+            for col in shifted_query_b_subquery.c.keys()
+            if col not in exclude_columns_b
+        ]
+        # Combine selected columns from both queries
+        final_selected_columns = selected_columns_a + selected_columns_b
+        if join_columns and not query_obj_clone.get("is_rowcount"):
+            # Proceed with JOIN operation as before since join_columns is not 
empty
+            join_conditions = [
+                shifted_query_b_subquery.c[col] == query_a_cte.c[col]
+                for col in join_columns
+                if col in shifted_query_b_subquery.c and col in query_a_cte.c
+            ]
+            final_query = sa.select(*final_selected_columns).select_from(
+                shifted_query_b_subquery.join(query_a_cte, 
sa.and_(*join_conditions))
+            )
+        else:
+            # When dealing with queries that have no columns or that are 
totals,
+            # rowcounts etc we join with the 1 = 1 to create a result set that 
have
+            # both sets (original and prev)
+            final_query = sa.select(*final_selected_columns).select_from(
+                shifted_query_b_subquery.join(
+                    query_a_cte, sa.literal(True) == sa.literal(True)
+                )
+            )
+        # Transform the query as you would within get_query_str_extended
+        final_query_sql = self.database.compile_sqla_query(final_query)
+        final_query_sql = self._apply_cte(final_query_sql, sqlaq.cte)
+        final_query_sql = sqlparse.format(final_query_sql, reindent=True)
+        if mutate:
+            final_query_sql = self.mutate_query_from_config(final_query_sql)
+
+        # Prepare the labels for the columns to be used
+        labels_expected = self.extract_column_names(final_selected_columns)
+
+        return final_query_sql, labels_expected
+
     def get_query_str_extended(
         self,
         query_obj: QueryObjectDict,
@@ -917,15 +1036,37 @@ class ExploreMixin:  # pylint: 
disable=too-many-public-methods
         except SupersetParseError:
             logger.warning("Unable to parse SQL to format it, passing it 
as-is")
 
+        # Need to tell apart the regular queries from the ones that need
+        # Time comparison
+        query_obj_clone = copy.copy(query_obj)
+        query_object_extras: dict[str, Any] = query_obj.get("extras", {})
+        instant_time_comparison_info = query_object_extras.get(
+            "instant_time_comparison_info", {}
+        )
+
+        if (
+            is_feature_enabled("CHART_PLUGINS_EXPERIMENTAL")
+            and instant_time_comparison_info
+        ):
+            (
+                final_query_sql,
+                labels_expected,
+            ) = self.process_time_compare_join(
+                query_obj_clone, sqlaq, mutate, instant_time_comparison_info
+            )
+        else:
+            final_query_sql = sql
+            labels_expected = sqlaq.labels_expected
+
         if mutate:
             sql = self.mutate_query_from_config(sql)
         return QueryStringExtended(
             applied_template_filters=sqlaq.applied_template_filters,
             applied_filter_columns=sqlaq.applied_filter_columns,
             rejected_filter_columns=sqlaq.rejected_filter_columns,
-            labels_expected=sqlaq.labels_expected,
+            labels_expected=labels_expected,
             prequeries=sqlaq.prequeries,
-            sql=sql,
+            sql=final_query_sql if final_query_sql else sql,
         )
 
     def _normalize_prequery_result_type(
diff --git a/tests/unit_tests/connectors/__init__.py 
b/tests/unit_tests/connectors/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/unit_tests/connectors/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/unit_tests/connectors/test_models.py 
b/tests/unit_tests/connectors/test_models.py
new file mode 100644
index 0000000000..1a176f4860
--- /dev/null
+++ b/tests/unit_tests/connectors/test_models.py
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import datetime
+
+from sqlalchemy.orm.session import Session
+
+from superset import db
+from tests.unit_tests.conftest import with_feature_flags
+
+
+class TestInstantTimeComparisonQueryGeneration:
+    @staticmethod
+    def base_setup(session: Session):
+        from superset.connectors.sqla.models import SqlaTable, SqlMetric, 
TableColumn
+        from superset.models.core import Database
+
+        engine = db.session.get_bind()
+        SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member
+
+        table = SqlaTable(
+            table_name="my_table",
+            schema="my_schema",
+            database=Database(database_name="my_database", 
sqlalchemy_uri="sqlite://"),
+        )
+
+        # Common columns
+        columns = [
+            {"column_name": "ds", "type": "DATETIME"},
+            {"column_name": "gender", "type": "VARCHAR(255)"},
+            {"column_name": "name", "type": "VARCHAR(255)"},
+            {"column_name": "state", "type": "VARCHAR(255)"},
+        ]
+
+        # Add columns to the table
+        for col in columns:
+            TableColumn(column_name=col["column_name"], type=col["type"], 
table=table)
+
+        # Common metrics
+        metrics = [
+            {"metric_name": "count", "expression": "count(*)"},
+            {"metric_name": "sum_sum", "expression": "SUM"},
+        ]
+
+        # Add metrics to the table
+        for metric in metrics:
+            SqlMetric(
+                metric_name=metric["metric_name"],
+                expression=metric["expression"],
+                table=table,
+            )
+
+        db.session.add(table)
+        db.session.flush()
+
+        return table
+
+    @staticmethod
+    def generate_base_query_obj():
+        return {
+            "apply_fetch_values_predicate": False,
+            "columns": ["name"],
+            "extras": {
+                "having": "",
+                "where": "",
+                "instant_time_comparison_info": {
+                    "range": "y",
+                },
+            },
+            "filter": [
+                {"op": "TEMPORAL_RANGE", "val": "1984-01-01 : 2024-02-14", 
"col": "ds"}
+            ],
+            "from_dttm": datetime.datetime(1984, 1, 1, 0, 0),
+            "granularity": None,
+            "inner_from_dttm": None,
+            "inner_to_dttm": None,
+            "is_rowcount": False,
+            "is_timeseries": False,
+            "order_desc": True,
+            "orderby": [("SUM(num_boys)", False)],
+            "row_limit": 10,
+            "row_offset": 0,
+            "series_columns": [],
+            "series_limit": 0,
+            "series_limit_metric": None,
+            "to_dttm": datetime.datetime(2024, 2, 14, 0, 0),
+            "time_shift": None,
+            "metrics": [
+                {
+                    "aggregate": "SUM",
+                    "column": {
+                        "column_name": "num_boys",
+                        "type": "BIGINT",
+                        "filterable": True,
+                        "groupby": True,
+                        "id": 334,
+                        "is_certified": False,
+                        "is_dttm": False,
+                        "type_generic": 0,
+                    },
+                    "datasourceWarning": False,
+                    "expressionType": "SIMPLE",
+                    "hasCustomLabel": False,
+                    "label": "SUM(num_boys)",
+                    "optionName": "metric_gzp6eq9g1lc_d8o0mj0mhq4",
+                    "sqlExpression": None,
+                },
+                {
+                    "aggregate": "SUM",
+                    "column": {
+                        "column_name": "num_girls",
+                        "type": "BIGINT",
+                        "filterable": True,
+                        "groupby": True,  # Note: This will need adjustment in 
some cases
+                        "id": 335,
+                        "is_certified": False,
+                        "is_dttm": False,
+                        "type_generic": 0,
+                    },
+                    "datasourceWarning": False,
+                    "expressionType": "SIMPLE",
+                    "hasCustomLabel": False,
+                    "label": "SUM(num_girls)",
+                    "optionName": "metric_5gyhtmyfw1t_d42py86jpco",
+                    "sqlExpression": None,
+                },
+            ],
+        }
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = 
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            WITH query_a_results AS
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+                AND ds < '2024-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC
+            LIMIT 10
+            OFFSET 0)
+            SELECT query_a_results.name AS name,
+                query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+                query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+                anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+                anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+            FROM
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1983-01-01 00:00:00'
+                AND ds < '2023-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC) AS anon_1
+            JOIN query_a_results ON anon_1.name = query_a_results.name
+        """
+        simplified_query1 = " ".join(str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query_no_columns(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = 
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        query_obj["columns"] = []
+        query_obj["metrics"][0]["column"]["groupby"] = False
+        query_obj["metrics"][1]["column"]["groupby"] = False
+
+        str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            WITH query_a_results AS
+            (SELECT sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+                AND ds < '2024-02-14 00:00:00'
+            ORDER BY "SUM(num_boys)" DESC
+            LIMIT 10
+            OFFSET 0)
+            SELECT query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+                query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+                anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+                anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+            FROM
+            (SELECT sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1983-01-01 00:00:00'
+                AND ds < '2023-02-14 00:00:00'
+            ORDER BY "SUM(num_boys)" DESC) AS anon_1
+            JOIN query_a_results ON 1 = 1
+        """
+        simplified_query1 = " ".join(str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_rowcount_query(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = 
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        query_obj["is_rowcount"] = True
+        str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            WITH query_a_results AS
+        (SELECT COUNT(*) AS rowcount
+        FROM
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+                AND ds < '2024-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC
+            LIMIT 10
+            OFFSET 0) AS rowcount_qry)
+        SELECT query_a_results.rowcount AS rowcount,
+            anon_1.rowcount AS prev_rowcount
+        FROM
+        (SELECT COUNT(*) AS rowcount
+        FROM
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1983-01-01 00:00:00'
+                AND ds < '2023-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC) AS rowcount_qry) AS anon_1
+        JOIN query_a_results ON 1 = 1
+        """
+        simplified_query1 = " ".join(str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_query_without_time_comparison(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = 
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        query_obj["extras"]["instant_time_comparison_info"] = None
+        str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            SELECT name AS name,
+                sum(num_boys) AS "SUM(num_boys)",
+                sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+            AND ds < '2024-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC
+            LIMIT 10
+            OFFSET 0
+        """
+        simplified_query1 = " ".join(str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query_custom_filters(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = 
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        query_obj["extras"]["instant_time_comparison_info"] = {
+            "range": "c",
+            "filter": {
+                "op": "TEMPORAL_RANGE",
+                "val": "1900-01-01 : 1950-02-14",
+                "col": "ds",
+            },
+        }
+        str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            WITH query_a_results AS
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+                AND ds < '2024-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC
+            LIMIT 10
+            OFFSET 0)
+            SELECT query_a_results.name AS name,
+                query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+                query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+                anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+                anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+            FROM
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1900-01-01 00:00:00'
+                AND ds < '1950-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC) AS anon_1
+            JOIN query_a_results ON anon_1.name = query_a_results.name
+        """
+        simplified_query1 = " ".join(str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query_paginated(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = 
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        query_obj["row_offset"] = 20
+        str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            WITH query_a_results AS
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+                AND ds < '2024-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC
+            LIMIT 10
+            OFFSET 20)
+            SELECT query_a_results.name AS name,
+                query_a_results."SUM(num_boys)" AS "SUM(num_boys)",
+                query_a_results."SUM(num_girls)" AS "SUM(num_girls)",
+                anon_1."SUM(num_boys)" AS "prev_SUM(num_boys)",
+                anon_1."SUM(num_girls)" AS "prev_SUM(num_girls)"
+            FROM
+            (SELECT name AS name,
+                    sum(num_boys) AS "SUM(num_boys)",
+                    sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1983-01-01 00:00:00'
+                AND ds < '2023-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC) AS anon_1
+            JOIN query_a_results ON anon_1.name = query_a_results.name
+        """
+        simplified_query1 = " ".join(str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=False)
+    def test_ignore_if_ff_off(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = 
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            SELECT name AS name,
+                sum(num_boys) AS "SUM(num_boys)",
+                sum(num_girls) AS "SUM(num_girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+            AND ds < '2024-02-14 00:00:00'
+            GROUP BY name
+            ORDER BY "SUM(num_boys)" DESC
+            LIMIT 10
+            OFFSET 0
+        """
+        simplified_query1 = " ".join(str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2

Reply via email to