This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 47c1e09c75 fix: `sqlparse` fallback for formatting queries (#30578)
47c1e09c75 is described below

commit 47c1e09c755b0dd93d46a3a203ee6bf644c66ea1
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Oct 11 15:45:40 2024 -0400

    fix: `sqlparse` fallback for formatting queries (#30578)
---
 superset/sql/parse.py                         | 108 +++++++++++++++++++++-----
 tests/integration_tests/sql_lab/api_tests.py  |   2 +-
 tests/unit_tests/db_engine_specs/test_base.py |  16 +---
 tests/unit_tests/sql/parse_tests.py           |  34 ++++++++
 4 files changed, 125 insertions(+), 35 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 377411b944..33ed76473f 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -26,6 +26,8 @@ from dataclasses import dataclass
 from typing import Any, Generic, TypeVar
 
 import sqlglot
+import sqlparse
+from deprecation import deprecated
 from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, Dialects
 from sqlglot.errors import ParseError
@@ -138,9 +140,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
     """
     Base class for SQL statements.
 
-    The class can be instantiated with a string representation of the script 
or, for
-    efficiency reasons, with a pre-parsed AST. This is useful with 
`sqlglot.parse`,
-    which will split a script in multiple already parsed statements.
+    The class should be instantiated with a string representation of the 
script and, for
+    efficiency reasons, optionally with a pre-parsed AST. This is useful with
+    `sqlglot.parse`, which will split a script in multiple already parsed 
statements.
 
     The `engine` parameters comes from the `engine` attribute in a Superset DB 
engine
     spec.
@@ -148,14 +150,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
 
     def __init__(
         self,
-        statement: str | InternalRepresentation,
+        statement: str,
         engine: str,
+        ast: InternalRepresentation | None = None,
     ):
-        self._parsed: InternalRepresentation = (
-            self._parse_statement(statement, engine)
-            if isinstance(statement, str)
-            else statement
-        )
+        self._sql = statement
+        self._parsed = ast or self._parse_statement(statement, engine)
         self.engine = engine
         self.tables = self._extract_tables_from_statement(self._parsed, 
self.engine)
 
@@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
 
     def __init__(
         self,
-        statement: str | exp.Expression,
+        statement: str,
         engine: str,
+        ast: exp.Expression | None = None,
     ):
         self._dialect = SQLGLOT_DIALECTS.get(engine)
-        super().__init__(statement, engine)
+        super().__init__(statement, engine, ast)
 
     @classmethod
     def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -275,11 +276,47 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         script: str,
         engine: str,
     ) -> list[SQLStatement]:
-        return [
-            cls(statement, engine)
-            for statement in cls._parse(script, engine)
-            if statement
-        ]
+        if engine in SQLGLOT_DIALECTS:
+            try:
+                return [
+                    cls(ast.sql(), engine, ast)
+                    for ast in cls._parse(script, engine)
+                    if ast
+                ]
+            except ValueError:
+                # `ast.sql()` might raise an error on some cases (eg, `SHOW 
TABLES
+                # FROM`). In this case, we rely on the tokenizer to generate 
the
+                # statements.
+                pass
+
+        # When we don't have a sqlglot dialect we can't rely on `ast.sql()` to 
correctly
+        # generate the SQL of each statement, so we tokenize the script and 
split it
+        # based on the location of semi-colons.
+        statements = []
+        start = 0
+        remainder = script
+
+        try:
+            tokens = sqlglot.tokenize(script)
+        except sqlglot.errors.TokenError as ex:
+            raise SupersetParseError(
+                script,
+                engine,
+                message="Unable to tokenize script",
+            ) from ex
+
+        for token in tokens:
+            if token.token_type == sqlglot.TokenType.SEMICOLON:
+                statement, start = script[start : token.start], token.end + 1
+                ast = cls._parse(statement, engine)[0]
+                statements.append(cls(statement.strip(), engine, ast))
+                remainder = script[start:]
+
+        if remainder.strip():
+            ast = cls._parse(remainder, engine)[0]
+            statements.append(cls(remainder.strip(), engine, ast))
+
+        return statements
 
     @classmethod
     def _parse_statement(
@@ -349,8 +386,34 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         """
         Pretty-format the SQL statement.
         """
-        write = Dialect.get_or_raise(self._dialect)
-        return write.generate(self._parsed, copy=False, comments=comments, 
pretty=True)
+        if self._dialect:
+            try:
+                write = Dialect.get_or_raise(self._dialect)
+                return write.generate(
+                    self._parsed,
+                    copy=False,
+                    comments=comments,
+                    pretty=True,
+                )
+            except ValueError:
+                pass
+
+        return self._fallback_formatting()
+
+    @deprecated(deprecated_in="4.0", removed_in="5.0")
+    def _fallback_formatting(self) -> str:
+        """
+        Format SQL without a specific dialect.
+
+        Reformatting SQL using the generic sqlglot dialect is known to break 
queries.
+        For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, 
which
+        breaks the query for Firebolt. To avoid this, we use sqlparse for 
formatting
+        when the dialect is not known.
+
+        In 5.0 we should remove `sqlparse`, and the method should return the 
query
+        unmodified.
+        """
+        return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -456,7 +519,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
         for more information.
         """
-        return [cls(statement, engine) for statement in split_kql(script)]
+        return [
+            cls(statement, engine, statement.strip()) for statement in 
split_kql(script)
+        ]
 
     @classmethod
     def _parse_statement(
@@ -498,7 +563,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         """
         Pretty-format the SQL statement.
         """
-        return self._parsed
+        return self._sql.strip()
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -548,6 +613,9 @@ class SQLScript:
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL script.
+
+        Note that even though KQL is very different from SQL, multiple 
statements are
+        still separated by semi-colons.
         """
         return ";\n".join(statement.format(comments) for statement in 
self.statements)
 
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index 19d6e56fb6..cf1e190bbb 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -281,7 +281,7 @@ class TestSqlLabApi(SupersetTestCase):
             "/api/v1/sqllab/format_sql/",
             json=data,
         )
-        success_resp = {"result": "SELECT\n  1\nFROM my_table"}
+        success_resp = {"result": "SELECT 1\nFROM my_table"}
         resp_data = json.loads(rv.data.decode("utf-8"))
         self.assertDictEqual(resp_data, success_resp)  # noqa: PT009
         assert rv.status_code == 200
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index a3af155815..d8e632ce09 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -241,14 +241,7 @@ def test_select_star(mocker: MockerFixture) -> None:
         latest_partition=False,
         cols=cols,
     )
-    assert (
-        sql
-        == """SELECT
-  a
-FROM my_table
-LIMIT ?
-OFFSET ?"""
-    )
+    assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?"
 
     sql = NoLimitDBEngineSpec.select_star(
         database=database,
@@ -260,12 +253,7 @@ OFFSET ?"""
         latest_partition=False,
         cols=cols,
     )
-    assert (
-        sql
-        == """SELECT
-  a
-FROM my_table"""
-    )
+    assert sql == "SELECT a\nFROM my_table"
 
 
 def test_extra_table_metadata(mocker: MockerFixture) -> None:
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index ae5ebf89a8..ada6314457 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -284,6 +284,40 @@ def test_extract_tables_show_tables_from() -> None:
     )
 
 
+def test_format_show_tables() -> None:
+    """
+    Test format when `ast.sql()` raises an exception.
+
+    In that case sqlparse should be used instead.
+    """
+    assert (
+        SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
+        == "SHOW TABLES FROM s1 LIKE '%order%'"
+    )
+
+
+def test_format_no_dialect() -> None:
+    """
+    Test format with an engine that has no corresponding dialect.
+    """
+    assert (
+        SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", 
"firebolt").format()
+        == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n                  2)"
+    )
+
+
+def test_split_no_dialect() -> None:
+    """
+    Test the statement split when the engine has no corresponding dialect.
+    """
+    sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT 
foo"
+    statements = SQLScript(sql, "firebolt").statements
+    assert len(statements) == 3
+    assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
+    assert statements[1]._sql == "SELECT * FROM t"
+    assert statements[2]._sql == "SELECT foo"
+
+
 def test_extract_tables_show_columns_from() -> None:
     """
     Test `SHOW COLUMNS FROM`.

Reply via email to