This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch hackathon-12-2025 in repository https://gitbox.apache.org/repos/asf/superset.git
commit bf342d66db6fe15f47e3acf60ca4883f7a3bbf32 Author: Beto Dealmeida <[email protected]> AuthorDate: Fri Dec 19 13:22:58 2025 -0500 Fix case --- superset/sql/parse.py | 96 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 71 insertions(+), 25 deletions(-) diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 709085d728..4ec056e6fd 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -35,6 +35,7 @@ from sqlglot.dialects.dialect import ( ) from sqlglot.dialects.singlestore import SingleStore from sqlglot.errors import ParseError +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.pushdown_predicates import ( pushdown_predicates, ) @@ -292,24 +293,40 @@ class RLSTransformer: catalog: str | None, schema: str | None, rules: dict[Table, list[exp.Expression]], + dialect: Dialects | type[Dialect] | None = None, ) -> None: self.catalog = catalog self.schema = schema - # Normalize table keys to lowercase for case-insensitive matching - # This is needed because apply_cls calls qualify() which may change - # identifier case (e.g., Snowflake uppercases identifiers) + self.dialect = dialect + # Normalize table keys using dialect-aware normalization + # This ensures matching works correctly regardless of how the dialect + # handles identifier case (e.g., Snowflake uppercases, Postgres lowercases) self.rules = { self._normalize_table(table): predicates for table, predicates in rules.items() } - @staticmethod - def _normalize_table(table: Table) -> Table: - """Normalize table to lowercase for case-insensitive matching.""" + def _normalize_table(self, table: Table) -> Table: + """ + Normalize table identifiers using dialect-aware normalization. + + This uses sqlglot's normalize_identifiers to match how the dialect + handles identifier case: + - Snowflake: uppercases unquoted identifiers + - PostgreSQL: lowercases unquoted identifiers + - Quoted identifiers preserve their case + """ + # Create a temporary exp.Table node for normalization + table_exp = exp.Table( + this=exp.Identifier(this=table.table) if table.table else None, + db=exp.Identifier(this=table.schema) if table.schema else None, + catalog=exp.Identifier(this=table.catalog) if table.catalog else None, + ) + normalized = normalize_identifiers(table_exp, dialect=self.dialect) return Table( - table=table.table.lower() if table.table else table.table, - schema=table.schema.lower() if table.schema else table.schema, - catalog=table.catalog.lower() if table.catalog else table.catalog, + table=normalized.name if normalized.name else table.table, + schema=normalized.db if normalized.db else table.schema, + catalog=normalized.catalog if normalized.catalog else table.catalog, ) def get_predicate(self, table_node: exp.Table) -> exp.Expression | None: @@ -473,20 +490,47 @@ class CLSTransformer: rules: CLSRules, dialect: Dialects | type[Dialect] | None, ) -> None: - self.rules = self._normalize_rules(rules) self.dialect = dialect + self.rules = self._normalize_rules(rules) self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None]) + def _normalize_identifier(self, name: str) -> str: + """Normalize an identifier using dialect-aware normalization.""" + ident = exp.Identifier(this=name) + normalized = normalize_identifiers(ident, dialect=self.dialect) + return normalized.name + + def _normalize_table(self, table: Table) -> Table: + """ + Normalize table identifiers using dialect-aware normalization. + + This uses sqlglot's normalize_identifiers to match how the dialect + handles identifier case: + - Snowflake: uppercases unquoted identifiers + - PostgreSQL: lowercases unquoted identifiers + - Quoted identifiers preserve their case + """ + table_exp = exp.Table( + this=exp.Identifier(this=table.table) if table.table else None, + db=exp.Identifier(this=table.schema) if table.schema else None, + catalog=exp.Identifier(this=table.catalog) if table.catalog else None, + ) + normalized = normalize_identifiers(table_exp, dialect=self.dialect) + return Table( + table=normalized.name if normalized.name else table.table, + schema=normalized.db if normalized.db else table.schema, + catalog=normalized.catalog if normalized.catalog else table.catalog, + ) + def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]: """ - Normalize table and column names to lowercase for case-insensitive matching. + Normalize table and column names using dialect-aware normalization. """ return { - Table( - table=table.table.lower(), - schema=table.schema.lower() if table.schema else None, - catalog=table.catalog.lower() if table.catalog else None, - ): {col.lower(): action for col, action in cols.items()} + self._normalize_table(table): { + self._normalize_identifier(col): action + for col, action in cols.items() + } for table, cols in rules.items() } @@ -509,22 +553,24 @@ class CLSTransformer: return None # Create a normalized Table for lookup - lookup_table = Table( - table=table_name.lower(), - schema=schema.lower() if schema else None, - catalog=catalog.lower() if catalog else None, + lookup_table = self._normalize_table( + Table( + table=table_name, + schema=schema, + catalog=catalog, + ) ) + normalized_column = self._normalize_identifier(column_name) # First try exact match with schema/catalog - table_rules = self.rules.get(lookup_table) - if table_rules: - return table_rules.get(column_name.lower()) + if (table_rules := self.rules.get(lookup_table)): + return table_rules.get(normalized_column) # Fallback: match by table name only # This handles cases where the rule has schema/catalog but the query doesn't for rule_table, cols in self.rules.items(): if rule_table.table == lookup_table.table: - action = cols.get(column_name.lower()) + action = cols.get(normalized_column) if action: return action @@ -1535,7 +1581,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): if method not in transformers: raise ValueError(f"Invalid RLS method: {method}") - transformer = transformers[method](catalog, schema, predicates) + transformer = transformers[method](catalog, schema, predicates, self._dialect) self._parsed = self._parsed.transform(transformer) def apply_cls(
