This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch reformat-sqlglot
in repository https://gitbox.apache.org/repos/asf/superset.git

commit ffb97529e7d62308f3723c4531b4b2fbf8c9b6df
Author: Beto Dealmeida <[email protected]>
AuthorDate: Thu Oct 10 17:34:54 2024 -0400

    WIP
---
 superset/sql/parse.py | 68 +++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 50 insertions(+), 18 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 377411b944..c2d95b9e24 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -26,6 +26,7 @@ from dataclasses import dataclass
 from typing import Any, Generic, TypeVar
 
 import sqlglot
+import sqlparse
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, Dialects
 from sqlglot.errors import ParseError
@@ -138,9 +139,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
     """
     Base class for SQL statements.
 
-    The class can be instantiated with a string representation of the script or, for
-    efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
-    which will split a script in multiple already parsed statements.
+    The class should be instantiated with a string representation of the script and, for
+    efficiency reasons, optionally with a pre-parsed AST. This is useful with
+    `sqlglot.parse`, which will split a script in multiple already parsed statements.
 
     The `engine` parameters comes from the `engine` attribute in a Superset DB engine
     spec.
@@ -148,14 +149,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
 
     def __init__(
         self,
-        statement: str | InternalRepresentation,
+        statement: str,
         engine: str,
+        ast: InternalRepresentation | None = None,
     ):
-        self._parsed: InternalRepresentation = (
-            self._parse_statement(statement, engine)
-            if isinstance(statement, str)
-            else statement
-        )
+        self._sql = statement
+        self._parsed = ast or self._parse_statement(statement, engine)
         self.engine = engine
         self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
 
@@ -239,11 +238,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
 
     def __init__(
         self,
-        statement: str | exp.Expression,
+        statement: str,
         engine: str,
+        ast: exp.Expression | None = None,
     ):
         self._dialect = SQLGLOT_DIALECTS.get(engine)
-        super().__init__(statement, engine)
+        super().__init__(statement, engine, ast)
 
     @classmethod
     def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -275,11 +275,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         script: str,
         engine: str,
     ) -> list[SQLStatement]:
-        return [
-            cls(statement, engine)
-            for statement in cls._parse(script, engine)
-            if statement
-        ]
+        if engine in SQLGLOT_DIALECTS:
+            try:
+                return [
+                    cls(ast.sql(), engine, ast)
+                    for ast in cls._parse(script, engine)
+                    if ast
+                ]
+            except ValueError:
+                # `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
+                # FROM`). In this case, we rely on the tokenizer to generate the
+                # statements.
+                pass
+
+        # When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
+        # generate the SQL of each statement, so we tokenize the script and split it
+        # based on the location of semi-colons.
+        statements = []
+        for token in sqlglot.tokenize(script):
+            if token.token_type == sqlglot.TokenType.SEMICOLON:
+                statement, script = script[token.start + 1], script[token.end + 1 :]
+                ast = cls._parse_statement(statement, engine)[0]
+                statements.append(cls(statement, engine, ast))
+
+        return statements
 
     @classmethod
     def _parse_statement(
@@ -349,6 +368,14 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         """
         Pretty-format the SQL statement.
         """
+        # Reformatting SQL using the generic sqlglot dialect is known to break queries.
+        # For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
+        # breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
+        # when the dialect is not known.
+        if self._dialect is None:
+            return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
+
+        # XXX: this might fail even if the dialect is supported! SHOW TABLES
         write = Dialect.get_or_raise(self._dialect)
         return write.generate(self._parsed, copy=False, comments=comments, 
pretty=True)
 
@@ -456,7 +483,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
         for more information.
         """
-        return [cls(statement, engine) for statement in split_kql(script)]
+        return [
+            cls(statement, engine, statement.strip()) for statement in split_kql(script)
+        ]
 
     @classmethod
     def _parse_statement(
@@ -498,7 +527,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         """
         Pretty-format the SQL statement.
         """
-        return self._parsed
+        return self._sql.strip()
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -548,6 +577,9 @@ class SQLScript:
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL script.
+
+        Note that even though KQL is very different from SQL, multiple statements are
+        still separated by semi-colons.
         """
         return ";\n".join(statement.format(comments) for statement in self.statements)
 

Reply via email to