This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch remove-sqlparse-kusto in repository https://gitbox.apache.org/repos/asf/superset.git
commit c7ad3d5fa8f0737861f1c2d6cb94c897f75ad74d Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Jan 24 10:02:47 2024 -0500 feat: support for KQL in SQLQuery --- superset/sql_parse.py | 389 ++++++++++++++++++++++++++++++------ tests/unit_tests/sql_parse_tests.py | 88 +++++++- 2 files changed, 408 insertions(+), 69 deletions(-) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index da82ffd60b..fda17d891e 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -17,12 +17,15 @@ # pylint: disable=too-many-lines +from __future__ import annotations + +import enum import logging import re import urllib.parse from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, cast, Optional, Union +from typing import Any, cast, Generic, TypeVar import sqlglot import sqlparse @@ -138,7 +141,7 @@ class CtasMethod(StrEnum): VIEW = "VIEW" -def _extract_limit_from_query(statement: TokenList) -> Optional[int]: +def _extract_limit_from_query(statement: TokenList) -> int | None: """ Extract limit clause from SQL statement. @@ -159,9 +162,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]: return None -def extract_top_from_query( - statement: TokenList, top_keywords: set[str] -) -> Optional[int]: +def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None: """ Extract top clause value from SQL statement. 
@@ -185,7 +186,7 @@ def extract_top_from_query( return top -def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: +def get_cte_remainder_query(sql: str) -> tuple[str | None, str]: """ parse the SQL and return the CTE and rest of the block to the caller @@ -193,7 +194,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: :return: CTE and remainder block to the caller """ - cte: Optional[str] = None + cte: str | None = None remainder = sql stmt = sqlparse.parse(sql)[0] @@ -211,7 +212,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: return cte, remainder -def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str: +def strip_comments_from_sql(statement: str, engine: str | None = None) -> str: """ Strips comments from a SQL statement, does a simple test first to avoid always instantiating the expensive ParsedQuery constructor @@ -235,8 +236,8 @@ class Table: """ table: str - schema: Optional[str] = None - catalog: Optional[str] = None + schema: str | None = None + catalog: str | None = None def __str__(self) -> str: """ @@ -255,7 +256,7 @@ class Table: def extract_tables_from_statement( statement: exp.Expression, - dialect: Optional[Dialects], + dialect: Dialects | None, ) -> set[Table]: """ Extract all table references in a single statement. @@ -323,82 +324,151 @@ def is_cte(source: exp.Table, scope: Scope) -> bool: return source.name in ctes_in_scope -class SQLQuery: +# To avoid unnecessary parsing/formatting of queries, the statement has the concept of +# an "internal representation", which is the AST of the SQL statement. For most of the +# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special +# case: KustoKQL uses a different syntax and there are no Python parsed, so we store the +# AST as a string (the original query), and manipulate it with regular expressions. +InternalRepresentation = TypeVar("InternalRepresentation") + +# The base type. 
This helps type checking the `split_query` method correctly, since each +# derived class has a more specific return type (the class itself). This will no longer +# be needed once Python 3.11 is the smallest version supported. See PEP 673 for more +# information: https://peps.python.org/pep-0673/ +TBaseSQLStatement = TypeVar("TBaseSQLStatement") + + +class BaseSQLStatement(Generic[InternalRepresentation]): """ - A SQL query, with 0+ statements. + Base class for SQL statements. + + The class can be instantiated with a string representation of the query or, for + efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`, + which will split a query into multiple already parsed statements. + + The `engine` parameter comes from the `engine` attribute in a Superset DB engine + spec. """ def __init__( self, - query: str, - engine: Optional[str] = None, + statement: str | InternalRepresentation, + engine: str, ): - dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + self._parsed: InternalRepresentation = ( + self._parse_statement(statement, engine) + if isinstance(statement, str) + else statement + ) + self.engine = engine + self.tables = self._extract_tables_from_statement(self._parsed, self.engine) - self.statements = [ - SQLStatement(statement, engine=engine) - for statement in parse(query, dialect=dialect) - if statement - ] + @classmethod + def split_query( + cls: type[TBaseSQLStatement], + query: str, + engine: str, + ) -> list[TBaseSQLStatement]: + """ + Split a query into multiple instantiated statements. + + This is a helper function to split a full SQL query into multiple + `BaseSQLStatement` instances. It's used by `SQLQuery` when instantiating the + statements within a query. + """ + raise NotImplementedError() + + @classmethod + def _parse_statement( + cls, + statement: str, + engine: str, + ) -> InternalRepresentation: + """ + Parse a string containing a single SQL statement, and return the parsed AST.
+ + Derived classes should not assume that `statement` contains a single statement, + and MUST explicitly validate that. Since this validation is parser dependent the + responsibility is left to the child classes. + """ + raise NotImplementedError() + + @classmethod + def _extract_tables_from_statement( + cls, + parsed: InternalRepresentation, + engine: str, + ) -> set[Table]: + """ + Extract all table references in a given statement. + """ + raise NotImplementedError() def format(self, comments: bool = True) -> str: """ - Pretty-format the SQL query. + Format the statement, optionally omitting comments. """ - return ";\n".join(statement.format(comments) for statement in self.statements) + raise NotImplementedError() - def get_settings(self) -> dict[str, str]: + def get_settings(self) -> dict[str, str | bool]: """ - Return the settings for the SQL query. + Return any settings set by the statement. - >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'") - >>> statement.get_settings() - {"foo": "'baz'"} + For example, for this statement: + + sql> SET foo = 'bar'; + The method should return `{"foo": "'bar'"}`. Note the single quotes. """ - settings: dict[str, str] = {} - for statement in self.statements: - settings.update(statement.get_settings()) + raise NotImplementedError() - return settings + def __str__(self) -> str: + return self.format() -class SQLStatement: +class SQLStatement(BaseSQLStatement[exp.Expression]): """ A SQL statement. - This class provides helper methods to manipulate and introspect SQL. + This class is used for all engines with dialects that can be parsed using sqlglot.
""" def __init__( self, - statement: Union[str, exp.Expression], - engine: Optional[str] = None, + statement: str | exp.Expression, + engine: str, ): - dialect = SQLGLOT_DIALECTS.get(engine) if engine else None - - if isinstance(statement, str): - try: - self._parsed = self._parse_statement(statement, dialect) - except ParseError as ex: - raise SupersetParseError(statement, engine) from ex - else: - self._parsed = statement + self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + super().__init__(statement, engine) - self._dialect = dialect - self.tables = extract_tables_from_statement(self._parsed, dialect) + @classmethod + def split_query( + cls, + query: str, + engine: str, + ) -> list[SQLStatement]: + return [ + cls(statement, engine) + for statement in sqlglot.parse(query, engine) + if statement + ] - @staticmethod + @classmethod def _parse_statement( - sql_statement: str, - dialect: Optional[Dialects], + cls, + statement: str, + engine: str, ) -> exp.Expression: """ Parse a single SQL statement. """ + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + + # We could parse with `sqlglot.parse_one` to get a single statement, but we need + # to verify that the string contains exactly one statement. statements = [ statement - for statement in sqlglot.parse(sql_statement, dialect=dialect) + for statement in sqlglot.parse(statement, dialect=dialect) if statement ] if len(statements) != 1: @@ -406,6 +476,18 @@ class SQLStatement: return statements[0] + @classmethod + def _extract_tables_from_statement( + cls, + parsed: exp.Expression, + engine: str, + ) -> set[Table]: + """ + Find all referenced tables. + """ + dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + return extract_tables_from_statement(parsed, dialect) + def format(self, comments: bool = True) -> str: """ Pretty-format the SQL statement. 
@@ -413,7 +495,7 @@ class SQLStatement: write = Dialect.get_or_raise(self._dialect) return write.generate(self._parsed, copy=False, comments=comments, pretty=True) - def get_settings(self) -> dict[str, str]: + def get_settings(self) -> dict[str, str | bool]: """ Return the settings for the SQL statement. @@ -429,12 +511,189 @@ class SQLStatement: } +class KustoKQLStatement(BaseSQLStatement[str]): + """ + Special class for Kusto KQL. + + Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look + like this: + + StormEvents + | summarize PropertyDamage = sum(DamageProperty) by State + | join kind=innerunique PopulationData on State + | project State, PropertyDamagePerCapita = PropertyDamage / Population + | sort by PropertyDamagePerCapita + + See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more + details about it. + """ + + @classmethod + def split_query( + cls, + query: str, + engine: str, + ) -> list[KustoKQLStatement]: + """ + Split a query at semi-colons. + + Since we don't have a parser, we use a simple state machine based function. See + https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string + for more information. + """ + + class KQLSplitState(enum.Enum): + """ + State machine for splitting a KQL query. + + The state machine keeps track of whether we're inside a string or not, so we + don't split the query in a semi-colon that's part of a string. 
+ """ + + OUTSIDE_STRING = enum.auto() + INSIDE_SINGLE_QUOTED_STRING = enum.auto() + INSIDE_DOUBLE_QUOTED_STRING = enum.auto() + INSIDE_MULTILINE_STRING = enum.auto() + + statements = [] + state = KQLSplitState.OUTSIDE_STRING + statement_start = 0 + query = query if query.endswith(";") else query + ";" + for i, character in enumerate(query): + if state == KQLSplitState.OUTSIDE_STRING: + if character == ";": + statements.append(query[statement_start:i]) + statement_start = i + 1 + elif character == "'": + state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + elif character == '"': + state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + elif character == "`" and query[i - 2 : i] == "``": + state = KQLSplitState.INSIDE_MULTILINE_STRING + + elif ( + state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + and character == "'" + and query[i - 1] != "\\" + ): + state = KQLSplitState.OUTSIDE_STRING + + elif ( + state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + and character == '"' + and query[i - 1] != "\\" + ): + state = KQLSplitState.OUTSIDE_STRING + + elif ( + state == KQLSplitState.INSIDE_MULTILINE_STRING + and character == "`" + and query[i - 2 : i] == "``" + ): + state = KQLSplitState.OUTSIDE_STRING + + return [cls(statement, engine) for statement in statements] + + @classmethod + def _parse_statement( + cls, + statement: str, + engine: str, + ) -> str: + if engine != "kustokql": + raise ValueError(f"Invalid engine: {engine}") + + # TODO: check if it's just a single statement + + return statement.strip() + + @classmethod + def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]: + """ + Extract all tables referenced in the statement. + + StormEvents + | where InjuriesDirect + InjuriesIndirect > 50 + | join (PopulationData) on State + | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect + + """ + logger.warning( + "Kusto KQL doesn't support table extraction. 
This means that data access " + "roles will not be enforced by Superset in the database." + ) + return set() + + def format(self, comments: bool = True) -> str: + """ + Return the stored KQL statement as-is; no pretty-printing is applied. + """ + return self._parsed + + def get_settings(self) -> dict[str, str | bool]: + """ + Return the settings for the SQL statement. + + >>> statement = KustoKQLStatement("set querytrace", "kustokql") + >>> statement.get_settings() + {"querytrace": True} + + """ + set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$" + if match := re.match(set_regex, self._parsed, re.IGNORECASE): + return {match.group("name"): match.group("value") or True} + + return {} + + +class SQLQuery: + """ + A SQL query, with 0+ statements. + """ + + # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines + # adds a lot of complexity to Superset, so we should avoid adding new engines to + # this data structure. + special_engines = { + "kustokql": KustoKQLStatement, + } + + def __init__( + self, + query: str, + engine: str, + ): + statement_class = self.special_engines.get(engine, SQLStatement) + self.statements = statement_class.split_query(query, engine) + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL query. + """ + return ";\n".join(statement.format(comments) for statement in self.statements) + + def get_settings(self) -> dict[str, str | bool]: + """ + Return the settings for the SQL query.
+ + >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'", "sqlite") + >>> statement.get_settings() + {"foo": "'baz'"} + + """ + settings: dict[str, str | bool] = {} + for statement in self.statements: + settings.update(statement.get_settings()) + + return settings + + class ParsedQuery: def __init__( self, sql_statement: str, strip_comments: bool = False, - engine: Optional[str] = None, + engine: str | None = None, ): if strip_comments: sql_statement = sqlparse.format(sql_statement, strip_comments=True) @@ -443,7 +702,7 @@ class ParsedQuery: self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None self._tables: set[Table] = set() self._alias_names: set[str] = set() - self._limit: Optional[int] = None + self._limit: int | None = None logger.debug("Parsing with sqlparse statement: %s", self.sql) self._parsed = sqlparse.parse(self.stripped()) @@ -476,7 +735,7 @@ class ParsedQuery: } @property - def limit(self) -> Optional[int]: + def limit(self) -> int | None: return self._limit def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]: @@ -557,7 +816,7 @@ class ParsedQuery: return True - def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]: + def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None: for token in tokens: if self._is_identifier(token): for identifier_token in token.tokens: @@ -621,7 +880,7 @@ class ParsedQuery: return statements @staticmethod - def get_table(tlist: TokenList) -> Optional[Table]: + def get_table(tlist: TokenList) -> Table | None: """ Return the table if valid, i.e., conforms to the [[catalog.]schema.]table construct.
@@ -657,7 +916,7 @@ class ParsedQuery: def as_create_table( self, table_name: str, - schema_name: Optional[str] = None, + schema_name: str | None = None, overwrite: bool = False, method: CtasMethod = CtasMethod.TABLE, ) -> str: @@ -817,8 +1076,8 @@ def add_table_name(rls: TokenList, table: str) -> None: def get_rls_for_table( candidate: Token, database_id: int, - default_schema: Optional[str], -) -> Optional[TokenList]: + default_schema: str | None, +) -> TokenList | None: """ Given a table name, return any associated RLS predicates. """ @@ -864,7 +1123,7 @@ def get_rls_for_table( def insert_rls_as_subquery( token_list: TokenList, database_id: int, - default_schema: Optional[str], + default_schema: str | None, ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. @@ -880,7 +1139,7 @@ def insert_rls_as_subquery( This method is safer than ``insert_rls_in_predicate``, but doesn't work in all databases. """ - rls: Optional[TokenList] = None + rls: TokenList | None = None state = InsertRLSState.SCANNING for token in token_list.tokens: # Recurse into child token list @@ -956,7 +1215,7 @@ def insert_rls_as_subquery( def insert_rls_in_predicate( token_list: TokenList, database_id: int, - default_schema: Optional[str], + default_schema: str | None, ) -> TokenList: """ Update a statement inplace applying any associated RLS predicates. @@ -967,7 +1226,7 @@ def insert_rls_in_predicate( after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42 """ - rls: Optional[TokenList] = None + rls: TokenList | None = None state = InsertRLSState.SCANNING for token in token_list.tokens: # Recurse into child token list @@ -1101,7 +1360,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}") def extract_table_references( sql_text: str, sqla_dialect: str, show_warning: bool = True -) -> set["Table"]: +) -> set[Table]: """ Return all the dependencies from a SQL sql_text. 
""" diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 268e482919..793e6076a5 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -33,6 +33,7 @@ from superset.sql_parse import ( has_table_query, insert_rls_as_subquery, insert_rls_in_predicate, + KustoKQLStatement, ParsedQuery, sanitize_clause, SQLQuery, @@ -1857,21 +1858,31 @@ def test_sqlquery() -> None: """ Test the `SQLQuery` class. """ - query = SQLQuery("SELECT 1; SELECT 2;") + query = SQLQuery("SELECT 1; SELECT 2;", "sqlite") assert len(query.statements) == 2 assert query.format() == "SELECT\n 1;\nSELECT\n 2" assert query.statements[0].format() == "SELECT\n 1" - query = SQLQuery("SET a=1; SET a=2; SELECT 3;") + query = SQLQuery("SET a=1; SET a=2; SELECT 3;", "sqlite") assert query.get_settings() == {"a": "2"} + query = SQLQuery( + """set querytrace; +Events | take 100""", + "kustokql", + ) + assert query.get_settings() == {"querytrace": True} + def test_sqlstatement() -> None: """ Test the `SQLStatement` class. """ - statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2") + statement = SQLStatement( + "SELECT * FROM table1 UNION ALL SELECT * FROM table2", + "sqlite", + ) assert statement.tables == { Table(table="table1", schema=None, catalog=None), @@ -1882,5 +1893,74 @@ def test_sqlstatement() -> None: == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2" ) - statement = SQLStatement("SET a=1") + statement = SQLStatement("SET a=1", "sqlite") assert statement.get_settings() == {"a": "1"} + + +def test_kustokqlstatement() -> None: + """ + Test the `KustoKQLStatement` class. 
+ """ + statements = KustoKQLStatement.split_query( + """ +let totalPagesPerDay = PageViews +| summarize by Page, Day = startofday(Timestamp) +| summarize count() by Day; +let materializedScope = PageViews +| summarize by Page, Day = startofday(Timestamp); +let cachedResult = materialize(materializedScope); +cachedResult +| project Page, Day1 = Day +| join kind = inner +( + cachedResult + | project Page, Day2 = Day +) +on Page +| where Day2 > Day1 +| summarize count() by Day1, Day2 +| join kind = inner + totalPagesPerDay +on $left.Day1 == $right.Day +| project Day1, Day2, Percentage = count_*100.0/count_1 + """, + "kustokql", + ) + assert len(statements) == 4 + + statements = KustoKQLStatement.split_query( + """ +print program = ``` + public class Program { + public static void Main() { + System.Console.WriteLine("Hello!"); + } + }``` + """, + "kustokql", + ) + assert len(statements) == 1 + + statements = KustoKQLStatement.split_query( + """ +set querytrace; +Events | take 100 + """, + "kustokql", + ) + assert len(statements) == 2 + assert statements[0].format() == "set querytrace" + assert statements[1].format() == "Events | take 100" + + [email protected]( + "kql,statements", + [ + ('print banner=strcat("Hello", ", ", "World!")', 1), + (r"print 'O\'Malley\'s'", 1), + (r"print 'O\'Mal;ley\'s'", 1), + ("print ```foo;\nbar;\nbaz;```\n", 1), + ], +) +def test_kustokql_statement_split_special(kql: str, statements: int) -> None: + assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements
