This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 47c1e09c75 fix: `sqlparse` fallback for formatting queries (#30578)
47c1e09c75 is described below
commit 47c1e09c755b0dd93d46a3a203ee6bf644c66ea1
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Oct 11 15:45:40 2024 -0400
fix: `sqlparse` fallback for formatting queries (#30578)
---
superset/sql/parse.py | 108 +++++++++++++++++++++-----
tests/integration_tests/sql_lab/api_tests.py | 2 +-
tests/unit_tests/db_engine_specs/test_base.py | 16 +---
tests/unit_tests/sql/parse_tests.py | 34 ++++++++
4 files changed, 125 insertions(+), 35 deletions(-)
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 377411b944..33ed76473f 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -26,6 +26,8 @@ from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import sqlglot
+import sqlparse
+from deprecation import deprecated
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
@@ -138,9 +140,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
Base class for SQL statements.
- The class can be instantiated with a string representation of the script or, for
- efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
- which will split a script in multiple already parsed statements.
+ The class should be instantiated with a string representation of the script and, for
+ efficiency reasons, optionally with a pre-parsed AST. This is useful with
+ `sqlglot.parse`, which will split a script in multiple already parsed statements.
The `engine` parameters comes from the `engine` attribute in a Superset DB engine
spec.
@@ -148,14 +150,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
def __init__(
self,
- statement: str | InternalRepresentation,
+ statement: str,
engine: str,
+ ast: InternalRepresentation | None = None,
):
- self._parsed: InternalRepresentation = (
- self._parse_statement(statement, engine)
- if isinstance(statement, str)
- else statement
- )
+ self._sql = statement
+ self._parsed = ast or self._parse_statement(statement, engine)
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
@@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def __init__(
self,
- statement: str | exp.Expression,
+ statement: str,
engine: str,
+ ast: exp.Expression | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
- super().__init__(statement, engine)
+ super().__init__(statement, engine, ast)
@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -275,11 +276,47 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str,
engine: str,
) -> list[SQLStatement]:
- return [
- cls(statement, engine)
- for statement in cls._parse(script, engine)
- if statement
- ]
+ if engine in SQLGLOT_DIALECTS:
+ try:
+ return [
+ cls(ast.sql(), engine, ast)
+ for ast in cls._parse(script, engine)
+ if ast
+ ]
+ except ValueError:
+ # `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
+ # FROM`). In this case, we rely on the tokenizer to generate the
+ # statements.
+ pass
+
+ # When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
+ # generate the SQL of each statement, so we tokenize the script and split it
+ # based on the location of semi-colons.
+ statements = []
+ start = 0
+ remainder = script
+
+ try:
+ tokens = sqlglot.tokenize(script)
+ except sqlglot.errors.TokenError as ex:
+ raise SupersetParseError(
+ script,
+ engine,
+ message="Unable to tokenize script",
+ ) from ex
+
+ for token in tokens:
+ if token.token_type == sqlglot.TokenType.SEMICOLON:
+ statement, start = script[start : token.start], token.end + 1
+ ast = cls._parse(statement, engine)[0]
+ statements.append(cls(statement.strip(), engine, ast))
+ remainder = script[start:]
+
+ if remainder.strip():
+ ast = cls._parse(remainder, engine)[0]
+ statements.append(cls(remainder.strip(), engine, ast))
+
+ return statements
@classmethod
def _parse_statement(
@@ -349,8 +386,34 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Pretty-format the SQL statement.
"""
- write = Dialect.get_or_raise(self._dialect)
- return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
+ if self._dialect:
+ try:
+ write = Dialect.get_or_raise(self._dialect)
+ return write.generate(
+ self._parsed,
+ copy=False,
+ comments=comments,
+ pretty=True,
+ )
+ except ValueError:
+ pass
+
+ return self._fallback_formatting()
+
+ @deprecated(deprecated_in="4.0", removed_in="5.0")
+ def _fallback_formatting(self) -> str:
+ """
+ Format SQL without a specific dialect.
+
+ Reformatting SQL using the generic sqlglot dialect is known to break queries.
+ For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
+ breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
+ when the dialect is not known.
+
+ In 5.0 we should remove `sqlparse`, and the method should return the query
+ unmodified.
+ """
+ return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -456,7 +519,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
for more information.
"""
- return [cls(statement, engine) for statement in split_kql(script)]
+ return [
+ cls(statement, engine, statement.strip()) for statement in split_kql(script)
+ ]
@classmethod
def _parse_statement(
@@ -498,7 +563,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
"""
Pretty-format the SQL statement.
"""
- return self._parsed
+ return self._sql.strip()
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -548,6 +613,9 @@ class SQLScript:
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL script.
+
+ Note that even though KQL is very different from SQL, multiple statements are
+ still separated by semi-colons.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index 19d6e56fb6..cf1e190bbb 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -281,7 +281,7 @@ class TestSqlLabApi(SupersetTestCase):
"/api/v1/sqllab/format_sql/",
json=data,
)
- success_resp = {"result": "SELECT\n 1\nFROM my_table"}
+ success_resp = {"result": "SELECT 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index a3af155815..d8e632ce09 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -241,14 +241,7 @@ def test_select_star(mocker: MockerFixture) -> None:
latest_partition=False,
cols=cols,
)
- assert (
- sql
- == """SELECT
- a
-FROM my_table
-LIMIT ?
-OFFSET ?"""
- )
+ assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?"
sql = NoLimitDBEngineSpec.select_star(
database=database,
@@ -260,12 +253,7 @@ OFFSET ?"""
latest_partition=False,
cols=cols,
)
- assert (
- sql
- == """SELECT
- a
-FROM my_table"""
- )
+ assert sql == "SELECT a\nFROM my_table"
def test_extra_table_metadata(mocker: MockerFixture) -> None:
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index ae5ebf89a8..ada6314457 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -284,6 +284,40 @@ def test_extract_tables_show_tables_from() -> None:
)
+def test_format_show_tables() -> None:
+ """
+ Test format when `ast.sql()` raises an exception.
+
+ In that case sqlparse should be used instead.
+ """
+ assert (
+ SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
+ == "SHOW TABLES FROM s1 LIKE '%order%'"
+ )
+
+
+def test_format_no_dialect() -> None:
+ """
+ Test format with an engine that has no corresponding dialect.
+ """
+ assert (
+ SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format()
+ == "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
+ )
+
+
+def test_split_no_dialect() -> None:
+ """
+ Test the statement split when the engine has no corresponding dialect.
+ """
+ sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
+ statements = SQLScript(sql, "firebolt").statements
+ assert len(statements) == 3
+ assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
+ assert statements[1]._sql == "SELECT * FROM t"
+ assert statements[2]._sql == "SELECT foo"
+
+
def test_extract_tables_show_columns_from() -> None:
"""
Test `SHOW COLUMNS FROM`.