This is an automated email from the ASF dual-hosted git repository.

arivero pushed a commit to branch table-time-comparison
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/table-time-comparison by this push:
     new a08667853f Table with Time Comparison:
a08667853f is described below

commit a08667853f8e8d93fea99987fb00d9b18445dd83
Author: Antonio Rivero <[email protected]>
AuthorDate: Tue Mar 26 16:39:41 2024 +0100

    Table with Time Comparison:
    
    - Handle case-sensitive labels so that databases relying on them don't break
    - Add a test covering case-sensitive labels
---
 superset/connectors/sqla/models.py         | 26 +++++++++-
 tests/unit_tests/connectors/test_models.py | 82 ++++++++++++++++++++++++++++++
 2 files changed, 107 insertions(+), 1 deletion(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c41cdf90bc..3116fac2a9 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1431,6 +1431,14 @@ class SqlaTable(
                 column_names.append(column_name)
         return column_names
 
+    def find_column_by_insensitive_name(
+        self, col_list: list[TableColumn], target_name: str
+    ) -> TableColumn | None:
+        for col in col_list:
+            if col.name.lower() == target_name.lower():
+                return col
+        return None
+
     def process_time_compare_join(  # pylint: disable=too-many-locals
         self,
         query_obj: QueryObjectDict,
@@ -1492,7 +1500,23 @@ class SqlaTable(
         column_names_a = [column.key for column in original_query_a.c]
         exclude_columns_b = set(query_obj_clone.get("columns") or [])
         # Let's prepare the columns set to be used in query A and B
-        selected_columns_a = [query_a_cte.c[col].label(col) for col in column_names_a]
+        # We need to use the label from the CTE in order to guarantee that
+        # the columns are correctly selected by the UI later on. Prefer an
+        # exact key match so columns that differ only by case stay distinct,
+        # and fall back to a case-insensitive lookup only when there is none.
+        selected_columns_a = []
+        for col_name in column_names_a:
+            matched_column = (
+                query_a_cte.c[col_name]
+                if col_name in query_a_cte.c
+                else self.find_column_by_insensitive_name(
+                    query_a_cte.columns, col_name
+                )
+            )
+            if matched_column is None:
+                raise KeyError(col_name)
+            # Label with the CTE column's own name to keep its exact casing
+            selected_columns_a.append(matched_column.label(matched_column.name))
         # Renamed columns from Query B (with "prev_" prefix)
         selected_columns_b = [
             shifted_query_b_subquery.c[col].label(f"prev_{col}")
diff --git a/tests/unit_tests/connectors/test_models.py b/tests/unit_tests/connectors/test_models.py
index a601440e64..6deb98d487 100644
--- a/tests/unit_tests/connectors/test_models.py
+++ b/tests/unit_tests/connectors/test_models.py
@@ -384,3 +384,85 @@ class TestInstantTimeComparisonQueryGeneration:
         simplified_query2 = " ".join(expected_str.split()).lower()
         assert table.id == 1
         assert simplified_query1 == simplified_query2
+
+    @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+    def test_creates_time_comparison_query_case_sensitivity(session: Session):
+        table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+        query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+        query_obj["columns"] = ["Name"]  # Changed case
+        query_obj["metrics"] = [
+            {
+                "aggregate": "SUM",
+                "column": {
+                    "column_name": "num_boys",
+                    "type": "BIGINT",
+                    "filterable": True,
+                    "groupby": True,
+                    "id": 334,
+                    "is_certified": False,
+                    "is_dttm": False,
+                    "type_generic": 0,
+                },
+                "datasourceWarning": False,
+                "expressionType": "SIMPLE",
+                "hasCustomLabel": False,
+                "label": "SUM(Num_Boys)",
+                "optionName": "metric_gzp6eq9g1lc_d8o0mj0mhq4",
+                "sqlExpression": None,
+            },
+            {
+                "aggregate": "SUM",
+                "column": {
+                    "column_name": "num_girls",
+                    "type": "BIGINT",
+                    "filterable": True,
+                    "groupby": True,
+                    "id": 335,
+                    "is_certified": False,
+                    "is_dttm": False,
+                    "type_generic": 0,
+                },
+                "datasourceWarning": False,
+                "expressionType": "SIMPLE",
+                "hasCustomLabel": False,
+                "label": "SUM(Num_Girls)",
+                "optionName": "metric_5gyhtmyfw1t_d42py86jpco",
+                "sqlExpression": None,
+            },
+        ]
+        query_obj["orderby"] = [("SUM(Num_Boys)", False)]  # Changed case
+        query_str = table.get_query_str_extended(query_obj)
+        expected_str = """
+            WITH query_a_results AS
+            (SELECT (Name) AS "Name",
+                    sum(num_boys) AS "SUM(Num_Boys)",
+                    sum(num_girls) AS "SUM(Num_Girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1984-01-01 00:00:00'
+                AND ds < '2024-02-14 00:00:00'
+            GROUP BY (Name)
+            ORDER BY "SUM(Num_Boys)" DESC
+            LIMIT 10
+            OFFSET 0)
+            SELECT query_a_results."Name" AS "Name",
+                query_a_results."SUM(Num_Boys)" AS "SUM(Num_Boys)",
+                query_a_results."SUM(Num_Girls)" AS "SUM(Num_Girls)",
+                anon_1."SUM(Num_Boys)" AS "prev_SUM(Num_Boys)",
+                anon_1."SUM(Num_Girls)" AS "prev_SUM(Num_Girls)"
+            FROM query_a_results
+            LEFT OUTER JOIN
+            (SELECT (Name) AS "Name",
+                    sum(num_boys) AS "SUM(Num_Boys)",
+                    sum(num_girls) AS "SUM(Num_Girls)"
+            FROM my_schema.my_table
+            WHERE ds >= '1983-01-01 00:00:00'
+                AND ds < '2023-02-14 00:00:00'
+            GROUP BY (Name)
+            ORDER BY "SUM(Num_Boys)" DESC) AS anon_1 ON anon_1."Name" = query_a_results."Name"
+        """
+        simplified_query1 = " ".join(query_str.sql.split()).lower()
+        simplified_query2 = " ".join(expected_str.split()).lower()
+        # The lower-cased comparison is case-blind, so pin exact label casing
+        assert query_str.sql.count('AS "Name"') == 3
+        assert table.id == 1
+        assert simplified_query1 == simplified_query2

Reply via email to