This is an automated email from the ASF dual-hosted git repository.
arivero pushed a commit to branch table-time-comparison
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/table-time-comparison by this
push:
new a08667853f Table with Time Comparison:
a08667853f is described below
commit a08667853f8e8d93fea99987fb00d9b18445dd83
Author: Antonio Rivero <[email protected]>
AuthorDate: Tue Mar 26 16:39:41 2024 +0100
Table with Time Comparison:
- Handle case-sensitive labels so that databases relying on label case don't break
- Add a test covering case-sensitive labels
---
superset/connectors/sqla/models.py | 26 +++++++++-
tests/unit_tests/connectors/test_models.py | 82 ++++++++++++++++++++++++++++++
2 files changed, 107 insertions(+), 1 deletion(-)
diff --git a/superset/connectors/sqla/models.py
b/superset/connectors/sqla/models.py
index c41cdf90bc..3116fac2a9 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1431,6 +1431,14 @@ class SqlaTable(
column_names.append(column_name)
return column_names
+ def find_column_by_insensitive_name(
+ self, col_list: list[TableColumn], target_name: str
+ ) -> TableColumn | None:
+ for col in col_list:
+ if col.name.lower() == target_name.lower():
+ return col
+ return None
+
def process_time_compare_join( # pylint: disable=too-many-locals
self,
query_obj: QueryObjectDict,
@@ -1492,7 +1500,23 @@ class SqlaTable(
column_names_a = [column.key for column in original_query_a.c]
exclude_columns_b = set(query_obj_clone.get("columns") or [])
# Let's prepare the columns set to be used in query A and B
- selected_columns_a = [query_a_cte.c[col].label(col) for col in
column_names_a]
+ # We need to use the label from the CTE in order to guarantee that
+ # the columns are correctly selected by the UI later on
+ original_case_mapping = {col.lower(): col for col in
query_a_cte.c.keys()}
+ selected_columns_a = []
+ for col_name in column_names_a:
+ matched_column = self.find_column_by_insensitive_name(
+ query_a_cte.columns, col_name
+ )
+ if matched_column is not None:
+ original_case_name = original_case_mapping.get(
+ matched_column.name.lower()
+ )
+ if original_case_name:
+
selected_columns_a.append(matched_column.label(original_case_name))
+ else:
+ # Just as fallback
+
selected_columns_a.append(matched_column.label(matched_column.name))
# Renamed columns from Query B (with "prev_" prefix)
selected_columns_b = [
shifted_query_b_subquery.c[col].label(f"prev_{col}")
diff --git a/tests/unit_tests/connectors/test_models.py
b/tests/unit_tests/connectors/test_models.py
index a601440e64..6deb98d487 100644
--- a/tests/unit_tests/connectors/test_models.py
+++ b/tests/unit_tests/connectors/test_models.py
@@ -384,3 +384,85 @@ class TestInstantTimeComparisonQueryGeneration:
simplified_query2 = " ".join(expected_str.split()).lower()
assert table.id == 1
assert simplified_query1 == simplified_query2
+
+ @with_feature_flags(CHART_PLUGINS_EXPERIMENTAL=True)
+ def test_creates_time_comparison_query_case_sensitivity(session: Session):
+ table = TestInstantTimeComparisonQueryGeneration.base_setup(session)
+ query_obj =
TestInstantTimeComparisonQueryGeneration.generate_base_query_obj()
+ query_obj["columns"] = ["Name"] # Changed case
+ query_obj["metrics"] = [
+ {
+ "aggregate": "SUM",
+ "column": {
+ "column_name": "num_boys",
+ "type": "BIGINT",
+ "filterable": True,
+ "groupby": True,
+ "id": 334,
+ "is_certified": False,
+ "is_dttm": False,
+ "type_generic": 0,
+ },
+ "datasourceWarning": False,
+ "expressionType": "SIMPLE",
+ "hasCustomLabel": False,
+ "label": "SUM(Num_Boys)",
+ "optionName": "metric_gzp6eq9g1lc_d8o0mj0mhq4",
+ "sqlExpression": None,
+ },
+ {
+ "aggregate": "SUM",
+ "column": {
+ "column_name": "num_girls",
+ "type": "BIGINT",
+ "filterable": True,
+ "groupby": True, # Note: This will need adjustment in
some cases
+ "id": 335,
+ "is_certified": False,
+ "is_dttm": False,
+ "type_generic": 0,
+ },
+ "datasourceWarning": False,
+ "expressionType": "SIMPLE",
+ "hasCustomLabel": False,
+ "label": "SUM(Num_Girls)",
+ "optionName": "metric_5gyhtmyfw1t_d42py86jpco",
+ "sqlExpression": None,
+ },
+ ]
+ query_obj["orderby"] = [("SUM(Num_Boys)", False)] # Changed case
+ str = table.get_query_str_extended(query_obj)
+ expected_str = """
+ WITH query_a_results AS
+ (SELECT (Name) AS "Name",
+ sum(num_boys) AS "SUM(Num_Boys)",
+ sum(num_girls) AS "SUM(Num_Girls)"
+ FROM my_schema.my_table
+ WHERE ds >= '1984-01-01 00:00:00'
+ AND ds < '2024-02-14 00:00:00'
+ GROUP BY (Name)
+ ORDER BY "SUM(Num_Boys)" DESC
+ LIMIT 10
+ OFFSET 0)
+ SELECT query_a_results."Name" AS "Name",
+ query_a_results."SUM(Num_Boys)" AS "SUM(Num_Boys)",
+ query_a_results."SUM(Num_Girls)" AS "SUM(Num_Girls)",
+ anon_1."SUM(Num_Boys)" AS "prev_SUM(Num_Boys)",
+ anon_1."SUM(Num_Girls)" AS "prev_SUM(Num_Girls)"
+ FROM query_a_results
+ LEFT OUTER JOIN
+ (SELECT (Name) AS "Name",
+ sum(num_boys) AS "SUM(Num_Boys)",
+ sum(num_girls) AS "SUM(Num_Girls)"
+ FROM my_schema.my_table
+ WHERE ds >= '1983-01-01 00:00:00'
+ AND ds < '2023-02-14 00:00:00'
+ GROUP BY (Name)
+ ORDER BY "SUM(Num_Boys)" DESC) AS anon_1 ON anon_1."Name" =
query_a_results."Name"
+ """
+ simplified_query1 = " ".join(str.sql.split()).lower()
+ simplified_query2 = " ".join(expected_str.split()).lower()
+
+ # Assertions
+ assert table.id == 1
+ assert simplified_query1 == simplified_query2