This is an automated email from the ASF dual-hosted git repository. arivero pushed a commit to branch table-time-comparison in repository https://gitbox.apache.org/repos/asf/superset.git
commit 4d1b4ab3ab619982242a03c075970c9a8188afc6 Author: Antonio Rivero <[email protected]> AuthorDate: Tue Mar 26 16:39:41 2024 +0100 Table with Time Comparison: - Handle case sensitive labels so DBs that rely on that don't break - Add a test to consider case sensitive labels --- superset/connectors/sqla/models.py | 26 +++++++++- tests/unit_tests/connectors/test_models.py | 82 ++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index c41cdf90bc..3116fac2a9 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1431,6 +1431,14 @@ class SqlaTable( column_names.append(column_name) return column_names + def find_column_by_insensitive_name( + self, col_list: list[TableColumn], target_name: str + ) -> TableColumn | None: + for col in col_list: + if col.name.lower() == target_name.lower(): + return col + return None + def process_time_compare_join( # pylint: disable=too-many-locals self, query_obj: QueryObjectDict, @@ -1492,7 +1500,23 @@ class SqlaTable( column_names_a = [column.key for column in original_query_a.c] exclude_columns_b = set(query_obj_clone.get("columns") or []) # Let's prepare the columns set to be used in query A and B - selected_columns_a = [query_a_cte.c[col].label(col) for col in column_names_a] + # We need to use the label from the CTE in order to guarantee that + # the columns are correctly selected by the UI later on + original_case_mapping = {col.lower(): col for col in query_a_cte.c.keys()} + selected_columns_a = [] + for col_name in column_names_a: + matched_column = self.find_column_by_insensitive_name( + query_a_cte.columns, col_name + ) + if matched_column is not None: + original_case_name = original_case_mapping.get( + matched_column.name.lower() + ) + if original_case_name: + selected_columns_a.append(matched_column.label(original_case_name)) + else: + # Just as fallback + 
selected_columns_a.append(matched_column.label(matched_column.name)) # Renamed columns from Query B (with "prev_" prefix) selected_columns_b = [ shifted_query_b_subquery.c[col].label(f"prev_{col}") diff --git a/tests/unit_tests/connectors/test_models.py b/tests/unit_tests/connectors/test_models.py index a601440e64..6deb98d487 100644 --- a/tests/unit_tests/connectors/test_models.py +++ b/tests/unit_tests/connectors/test_models.py @@ -384,3 +384,85 @@ class TestInstantTimeComparisonQueryGeneration: simplified_query2 = " ".join(expected_str.split()).lower() assert table.id == 1 assert simplified_query1 == simplified_query2 + + @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True) + def test_creates_time_comparison_query_case_sensitivity(session: Session): + table = TestInstantTimeComparisonQueryGeneration.base_setup(session) + query_obj = TestInstantTimeComparisonQueryGeneration.generate_base_query_obj() + query_obj["columns"] = ["Name"] # Changed case + query_obj["metrics"] = [ + { + "aggregate": "SUM", + "column": { + "column_name": "num_boys", + "type": "BIGINT", + "filterable": True, + "groupby": True, + "id": 334, + "is_certified": False, + "is_dttm": False, + "type_generic": 0, + }, + "datasourceWarning": False, + "expressionType": "SIMPLE", + "hasCustomLabel": False, + "label": "SUM(Num_Boys)", + "optionName": "metric_gzp6eq9g1lc_d8o0mj0mhq4", + "sqlExpression": None, + }, + { + "aggregate": "SUM", + "column": { + "column_name": "num_girls", + "type": "BIGINT", + "filterable": True, + "groupby": True, # Note: This will need adjustment in some cases + "id": 335, + "is_certified": False, + "is_dttm": False, + "type_generic": 0, + }, + "datasourceWarning": False, + "expressionType": "SIMPLE", + "hasCustomLabel": False, + "label": "SUM(Num_Girls)", + "optionName": "metric_5gyhtmyfw1t_d42py86jpco", + "sqlExpression": None, + }, + ] + query_obj["orderby"] = [("SUM(Num_Boys)", False)] # Changed case + str = table.get_query_str_extended(query_obj) + expected_str = """ + 
WITH query_a_results AS + (SELECT (Name) AS "Name", + sum(num_boys) AS "SUM(Num_Boys)", + sum(num_girls) AS "SUM(Num_Girls)" + FROM my_schema.my_table + WHERE ds >= '1984-01-01 00:00:00' + AND ds < '2024-02-14 00:00:00' + GROUP BY (Name) + ORDER BY "SUM(Num_Boys)" DESC + LIMIT 10 + OFFSET 0) + SELECT query_a_results."Name" AS "Name", + query_a_results."SUM(Num_Boys)" AS "SUM(Num_Boys)", + query_a_results."SUM(Num_Girls)" AS "SUM(Num_Girls)", + anon_1."SUM(Num_Boys)" AS "prev_SUM(Num_Boys)", + anon_1."SUM(Num_Girls)" AS "prev_SUM(Num_Girls)" + FROM query_a_results + LEFT OUTER JOIN + (SELECT (Name) AS "Name", + sum(num_boys) AS "SUM(Num_Boys)", + sum(num_girls) AS "SUM(Num_Girls)" + FROM my_schema.my_table + WHERE ds >= '1983-01-01 00:00:00' + AND ds < '2023-02-14 00:00:00' + GROUP BY (Name) + ORDER BY "SUM(Num_Boys)" DESC) AS anon_1 ON anon_1."Name" = query_a_results."Name" + """ + simplified_query1 = " ".join(str.sql.split()).lower() + simplified_query2 = " ".join(expected_str.split()).lower() + + # Assertions + assert table.id == 1 + assert simplified_query1 == simplified_query2
