This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch reformat-sqlglot in repository https://gitbox.apache.org/repos/asf/superset.git
commit ffb97529e7d62308f3723c4531b4b2fbf8c9b6df
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Oct 10 17:34:54 2024 -0400

    WIP
---
 superset/sql/parse.py | 68 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 50 insertions(+), 18 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 377411b944..c2d95b9e24 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -26,6 +26,7 @@ from dataclasses import dataclass
 from typing import Any, Generic, TypeVar
 
 import sqlglot
+import sqlparse
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, Dialects
 from sqlglot.errors import ParseError
@@ -138,9 +139,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
     """
     Base class for SQL statements.
 
-    The class can be instantiated with a string representation of the script or, for
-    efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
-    which will split a script in multiple already parsed statements.
+    The class should be instantiated with a string representation of the script and, for
+    efficiency reasons, optionally with a pre-parsed AST. This is useful with
+    `sqlglot.parse`, which will split a script in multiple already parsed statements.
 
     The `engine` parameters comes from the `engine` attribute in a Superset DB engine
     spec.
@@ -148,14 +149,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
 
     def __init__(
         self,
-        statement: str | InternalRepresentation,
+        statement: str,
         engine: str,
+        ast: InternalRepresentation | None = None,
     ):
-        self._parsed: InternalRepresentation = (
-            self._parse_statement(statement, engine)
-            if isinstance(statement, str)
-            else statement
-        )
+        self._sql = statement
+        self._parsed = ast or self._parse_statement(statement, engine)
         self.engine = engine
         self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
 
@@ -239,11 +238,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
 
     def __init__(
         self,
-        statement: str | exp.Expression,
+        statement: str,
         engine: str,
+        ast: exp.Expression | None = None,
     ):
         self._dialect = SQLGLOT_DIALECTS.get(engine)
-        super().__init__(statement, engine)
+        super().__init__(statement, engine, ast)
 
     @classmethod
     def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -275,11 +275,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         script: str,
         engine: str,
     ) -> list[SQLStatement]:
-        return [
-            cls(statement, engine)
-            for statement in cls._parse(script, engine)
-            if statement
-        ]
+        if engine in SQLGLOT_DIALECTS:
+            try:
+                return [
+                    cls(ast.sql(), engine, ast)
+                    for ast in cls._parse(script, engine)
+                    if ast
+                ]
+            except ValueError:
+                # `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
+                # FROM`). In this case, we rely on the tokenizer to generate the
+                # statements.
+                pass
+
+        # When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
+        # generate the SQL of each statement, so we tokenize the script and split it
+        # based on the location of semi-colons.
+        statements = []
+        for token in sqlglot.tokenize(script):
+            if token.token_type == sqlglot.TokenType.SEMICOLON:
+                statement, script = script[token.start + 1], script[token.end + 1 :]
+                ast = cls._parse_statement(statement, engine)[0]
+                statements.append(cls(statement, engine, ast))
+
+        return statements
 
     @classmethod
     def _parse_statement(
@@ -349,6 +368,14 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         """
         Pretty-format the SQL statement.
         """
+        # Reformatting SQL using the generic sqlglot dialect is known to break queries.
+        # For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
+        # breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
+        # when the dialect is not known.
+        if self._dialect is None:
+            return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
+
+        # XXX: this might fail even if the dialect is supported! SHOW TABLES
         write = Dialect.get_or_raise(self._dialect)
         return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
 
@@ -456,7 +483,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
         for more information.
         """
-        return [cls(statement, engine) for statement in split_kql(script)]
+        return [
+            cls(statement, engine, statement.strip()) for statement in split_kql(script)
+        ]
 
     @classmethod
     def _parse_statement(
@@ -498,7 +527,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         """
         Pretty-format the SQL statement.
         """
-        return self._parsed
+        return self._sql.strip()
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -548,6 +577,9 @@ class SQLScript:
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL script.
+
+        Note that even though KQL is very different from SQL, multiple statements are
+        still separated by semi-colons.
         """
         return ";\n".join(statement.format(comments) for statement in self.statements)
 
