This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch reformat-sqlglot
in repository https://gitbox.apache.org/repos/asf/superset.git

commit e14ad062ffeb9bc4b73e952f05631870c7867869
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Oct 10 17:34:54 2024 -0400

    fix: sqlparse fallback for formatting queries
---
 superset/sql/parse.py               | 86 ++++++++++++++++++++++++++++---------
 tests/unit_tests/sql/parse_tests.py | 22 ++++++++++
 2 files changed, 88 insertions(+), 20 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 377411b944..1866870262 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -26,6 +26,7 @@ from dataclasses import dataclass
 from typing import Any, Generic, TypeVar
 
 import sqlglot
+import sqlparse
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, Dialects
 from sqlglot.errors import ParseError
@@ -138,9 +139,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
     """
     Base class for SQL statements.
 
-    The class can be instantiated with a string representation of the script 
or, for
-    efficiency reasons, with a pre-parsed AST. This is useful with 
`sqlglot.parse`,
-    which will split a script in multiple already parsed statements.
+    The class should be instantiated with a string representation of the 
script and, for
+    efficiency reasons, optionally with a pre-parsed AST. This is useful with
+    `sqlglot.parse`, which will split a script in multiple already parsed 
statements.
 
     The `engine` parameters comes from the `engine` attribute in a Superset DB 
engine
     spec.
@@ -148,14 +149,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
 
     def __init__(
         self,
-        statement: str | InternalRepresentation,
+        statement: str,
         engine: str,
+        ast: InternalRepresentation | None = None,
     ):
-        self._parsed: InternalRepresentation = (
-            self._parse_statement(statement, engine)
-            if isinstance(statement, str)
-            else statement
-        )
+        self._sql = statement
+        self._parsed = ast or self._parse_statement(statement, engine)
         self.engine = engine
         self.tables = self._extract_tables_from_statement(self._parsed, 
self.engine)
 
@@ -239,11 +238,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
 
     def __init__(
         self,
-        statement: str | exp.Expression,
+        statement: str,
         engine: str,
+        ast: exp.Expression | None = None,
     ):
         self._dialect = SQLGLOT_DIALECTS.get(engine)
-        super().__init__(statement, engine)
+        super().__init__(statement, engine, ast)
 
     @classmethod
     def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -275,11 +275,37 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         script: str,
         engine: str,
     ) -> list[SQLStatement]:
-        return [
-            cls(statement, engine)
-            for statement in cls._parse(script, engine)
-            if statement
-        ]
+        if dialect := SQLGLOT_DIALECTS.get(engine):
+            try:
+                return [
+                    cls(ast.sql(), engine, ast)
+                    for ast in cls._parse(script, engine)
+                    if ast
+                ]
+            except ValueError:
+                # `ast.sql()` might raise an error on some cases (eg, `SHOW 
TABLES
+                # FROM`). In this case, we rely on the tokenizer to generate 
the
+                # statements.
+                pass
+
+        # When we don't have a sqlglot dialect we can't rely on `ast.sql()` to 
correctly
+        # generate the SQL of each statement, so we tokenize the script and 
split it
+        # based on the location of semi-colons.
+        statements = []
+        start = 0
+        remainder = script
+        for token in sqlglot.tokenize(script):
+            if token.token_type == sqlglot.TokenType.SEMICOLON:
+                statement, start = script[start : token.start + 1], token.end 
+ 1
+                ast = sqlglot.parse_one(statement, dialect)
+                statements.append(cls(statement, engine, ast))
+                remainder = script[start:]
+
+        if remainder.strip():
+            ast = sqlglot.parse_one(remainder, dialect)
+            statements.append(cls(remainder, engine, ast))
+
+        return statements
 
     @classmethod
     def _parse_statement(
@@ -349,8 +375,23 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         """
         Pretty-format the SQL statement.
         """
-        write = Dialect.get_or_raise(self._dialect)
-        return write.generate(self._parsed, copy=False, comments=comments, 
pretty=True)
+        if self._dialect:
+            try:
+                write = Dialect.get_or_raise(self._dialect)
+                return write.generate(
+                    self._parsed,
+                    copy=False,
+                    comments=comments,
+                    pretty=True,
+                )
+            except ValueError:
+                pass
+
+        # Reformatting SQL using the generic sqlglot dialect is known to break 
queries.
+        # For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN 
(1,2)`, which
+        # breaks the query for Firebolt. To avoid this, we use sqlparse for 
formatting
+        # when the dialect is not known.
+        return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -456,7 +497,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
         for more information.
         """
-        return [cls(statement, engine) for statement in split_kql(script)]
+        return [
+            cls(statement, engine, statement.strip()) for statement in 
split_kql(script)
+        ]
 
     @classmethod
     def _parse_statement(
@@ -498,7 +541,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         """
         Pretty-format the SQL statement.
         """
-        return self._parsed
+        return self._sql.strip()
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -548,6 +591,9 @@ class SQLScript:
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL script.
+
+        Note that even though KQL is very different from SQL, multiple 
statements are
+        still separated by semi-colons.
         """
         return ";\n".join(statement.format(comments) for statement in 
self.statements)
 
diff --git a/tests/unit_tests/sql/parse_tests.py 
b/tests/unit_tests/sql/parse_tests.py
index ae5ebf89a8..338a97a71a 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -284,6 +284,28 @@ def test_extract_tables_show_tables_from() -> None:
     )
 
 
+def test_format_show_tables() -> None:
+    """
+    Test format when `ast.sql()` raises an exception.
+
+    In that case sqlparse should be used instead.
+    """
+    assert (
+        SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
+        == "SHOW TABLES FROM s1 LIKE '%order%'"
+    )
+
+
+def test_format_no_dialect() -> None:
+    """
+    Test format with an engine that has no corresponding dialect.
+    """
+    assert (
+        SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", 
"firebolt").format()
+        == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n                  2)"
+    )
+
+
 def test_extract_tables_show_columns_from() -> None:
     """
     Test `SHOW COLUMNS FROM`.

Reply via email to