This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch remove-sqlparse-kusto
in repository https://gitbox.apache.org/repos/asf/superset.git

commit c7ad3d5fa8f0737861f1c2d6cb94c897f75ad74d
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Jan 24 10:02:47 2024 -0500

    feat: support for KQL in SQLQuery
---
 superset/sql_parse.py               | 389 ++++++++++++++++++++++++++++++------
 tests/unit_tests/sql_parse_tests.py |  88 +++++++-
 2 files changed, 408 insertions(+), 69 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index da82ffd60b..fda17d891e 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -17,12 +17,15 @@
 
 # pylint: disable=too-many-lines
 
+from __future__ import annotations
+
+import enum
 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional, Union
+from typing import Any, cast, Generic, TypeVar
 
 import sqlglot
 import sqlparse
@@ -138,7 +141,7 @@ class CtasMethod(StrEnum):
     VIEW = "VIEW"
 
 
-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
     """
     Extract limit clause from SQL statement.
 
@@ -159,9 +162,7 @@ def _extract_limit_from_query(statement: TokenList) -> 
Optional[int]:
     return None
 
 
-def extract_top_from_query(
-    statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> 
int | None:
     """
     Extract top clause value from SQL statement.
 
@@ -185,7 +186,7 @@ def extract_top_from_query(
     return top
 
 
-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     """
     parse the SQL and return the CTE and rest of the block to the caller
 
@@ -193,7 +194,7 @@ def get_cte_remainder_query(sql: str) -> 
tuple[Optional[str], str]:
     :return: CTE and remainder block to the caller
 
     """
-    cte: Optional[str] = None
+    cte: str | None = None
     remainder = sql
     stmt = sqlparse.parse(sql)[0]
 
@@ -211,7 +212,7 @@ def get_cte_remainder_query(sql: str) -> 
tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> 
str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -235,8 +236,8 @@ class Table:
     """
 
     table: str
-    schema: Optional[str] = None
-    catalog: Optional[str] = None
+    schema: str | None = None
+    catalog: str | None = None
 
     def __str__(self) -> str:
         """
@@ -255,7 +256,7 @@ class Table:
 
 def extract_tables_from_statement(
     statement: exp.Expression,
-    dialect: Optional[Dialects],
+    dialect: Dialects | None,
 ) -> set[Table]:
     """
     Extract all table references in a single statement.
@@ -323,82 +324,151 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
     return source.name in ctes_in_scope
 
 
-class SQLQuery:
+# To avoid unnecessary parsing/formatting of queries, the statement has the 
concept of
+# an "internal representation", which is the AST of the SQL statement. For 
most of the
+# engines supported by Superset this is `sqlglot.exp.Expression`, but there is 
a special
+# case: KustoKQL uses a different syntax and there is no Python parser, so we 
store the
+# AST as a string (the original query), and manipulate it with regular 
expressions.
+InternalRepresentation = TypeVar("InternalRepresentation")
+
+# The base type. This helps type checking the `split_query` method correctly, 
since each
+# derived class has a more specific return type (the class itself). This will 
no longer
+# be needed once Python 3.11 is the minimum version supported. See PEP 673 for 
more
+# information: https://peps.python.org/pep-0673/
+TBaseSQLStatement = TypeVar("TBaseSQLStatement")
+
+
+class BaseSQLStatement(Generic[InternalRepresentation]):
     """
-    A SQL query, with 0+ statements.
+    Base class for SQL statements.
+
+    The class can be instantiated with a string representation of the query 
or, for
+    efficiency reasons, with a pre-parsed AST. This is useful with 
`sqlglot.parse`,
+    which will split a query in multiple already parsed statements.
+
+    The `engine` parameter comes from the `engine` attribute in a Superset DB 
engine
+    spec.
     """
 
     def __init__(
         self,
-        query: str,
-        engine: Optional[str] = None,
+        statement: str | InternalRepresentation,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        self._parsed: InternalRepresentation = (
+            self._parse_statement(statement, engine)
+            if isinstance(statement, str)
+            else statement
+        )
+        self.engine = engine
+        self.tables = self._extract_tables_from_statement(self._parsed, 
self.engine)
 
-        self.statements = [
-            SQLStatement(statement, engine=engine)
-            for statement in parse(query, dialect=dialect)
-            if statement
-        ]
+    @classmethod
+    def split_query(
+        cls: type[TBaseSQLStatement],
+        query: str,
+        engine: str,
+    ) -> list[TBaseSQLStatement]:
+        """
+        Split a query into multiple instantiated statements.
+
+        This is a helper function to split a full SQL query into multiple
+        `BaseSQLStatement` instances. It's used by `SQLQuery` when 
instantiating the
+        statements within a query.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> InternalRepresentation:
+        """
+        Parse a string containing a single SQL statement, and return the 
parsed AST.
+
+        Derived classes should not assume that `statement` contains a single 
statement,
+        and MUST explicitly validate that. Since this validation is parser 
dependent the
+        responsibility is left to the child classes.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: InternalRepresentation,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Extract all table references in a given statement.
+        """
+        raise NotImplementedError()
 
     def format(self, comments: bool = True) -> str:
         """
-        Pretty-format the SQL query.
+        Format the statement, optionally omitting comments.
         """
-        return ";\n".join(statement.format(comments) for statement in 
self.statements)
+        raise NotImplementedError()
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
-        Return the settings for the SQL query.
+        Return any settings set by the statement.
 
-            >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'")
-            >>> statement.get_settings()
-            {"foo": "'baz'"}
+        For example, for this statement:
+
+            sql> SET foo = 'bar';
 
+        The method should return `{"foo": "'bar'"}`. Note the single quotes.
         """
-        settings: dict[str, str] = {}
-        for statement in self.statements:
-            settings.update(statement.get_settings())
+        raise NotImplementedError()
 
-        return settings
+    def __str__(self) -> str:
+        return self.format()
 
 
-class SQLStatement:
+class SQLStatement(BaseSQLStatement[exp.Expression]):
     """
     A SQL statement.
 
-    This class provides helper methods to manipulate and introspect SQL.
+    This class is used for all engines with dialects that can be parsed using 
sqlglot.
     """
 
     def __init__(
         self,
-        statement: Union[str, exp.Expression],
-        engine: Optional[str] = None,
+        statement: str | exp.Expression,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
-
-        if isinstance(statement, str):
-            try:
-                self._parsed = self._parse_statement(statement, dialect)
-            except ParseError as ex:
-                raise SupersetParseError(statement, engine) from ex
-        else:
-            self._parsed = statement
+        self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        super().__init__(statement, engine)
 
-        self._dialect = dialect
-        self.tables = extract_tables_from_statement(self._parsed, dialect)
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[SQLStatement]:
+        return [
+            cls(statement, engine)
+            for statement in sqlglot.parse(query, dialect=SQLGLOT_DIALECTS.get(engine))
+            if statement
+        ]
 
-    @staticmethod
+    @classmethod
     def _parse_statement(
-        sql_statement: str,
-        dialect: Optional[Dialects],
+        cls,
+        statement: str,
+        engine: str,
     ) -> exp.Expression:
         """
         Parse a single SQL statement.
         """
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        # We could parse with `sqlglot.parse_one` to get a single statement, 
but we need
+        # to verify that the string contains exactly one statement.
         statements = [
             statement
-            for statement in sqlglot.parse(sql_statement, dialect=dialect)
+            for statement in sqlglot.parse(statement, dialect=dialect)
             if statement
         ]
         if len(statements) != 1:
@@ -406,6 +476,18 @@ class SQLStatement:
 
         return statements[0]
 
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: exp.Expression,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Find all referenced tables.
+        """
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        return extract_tables_from_statement(parsed, dialect)
+
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
@@ -413,7 +495,7 @@ class SQLStatement:
         write = Dialect.get_or_raise(self._dialect)
         return write.generate(self._parsed, copy=False, comments=comments, 
pretty=True)
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
         Return the settings for the SQL statement.
 
@@ -429,12 +511,189 @@ class SQLStatement:
         }
 
 
+class KustoKQLStatement(BaseSQLStatement[str]):
+    """
+    Special class for Kusto KQL.
+
+    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. 
Queries look
+    like this:
+
+        StormEvents
+        | summarize PropertyDamage = sum(DamageProperty) by State
+        | join kind=innerunique PopulationData on State
+        | project State, PropertyDamagePerCapita = PropertyDamage / Population
+        | sort by PropertyDamagePerCapita
+
+    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for 
more
+    details about it.
+    """
+
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[KustoKQLStatement]:
+        """
+        Split a query at semi-colons.
+
+        Since we don't have a parser, we use a simple state machine based 
function. See
+        
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
+        for more information.
+        """
+
+        class KQLSplitState(enum.Enum):
+            """
+            State machine for splitting a KQL query.
+
+            The state machine keeps track of whether we're inside a string or 
not, so we
+            don't split the query at a semi-colon that's part of a string.
+            """
+
+            OUTSIDE_STRING = enum.auto()
+            INSIDE_SINGLE_QUOTED_STRING = enum.auto()
+            INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
+            INSIDE_MULTILINE_STRING = enum.auto()
+
+        statements = []
+        state = KQLSplitState.OUTSIDE_STRING
+        statement_start = 0
+        query = query if query.endswith(";") else query + ";"
+        for i, character in enumerate(query):
+            if state == KQLSplitState.OUTSIDE_STRING:
+                if character == ";":
+                    statements.append(query[statement_start:i])
+                    statement_start = i + 1
+                elif character == "'":
+                    state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+                elif character == '"':
+                    state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+                elif character == "`" and query[i - 2 : i] == "``":
+                    state = KQLSplitState.INSIDE_MULTILINE_STRING
+
+            elif (
+                state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+                and character == "'"
+                and query[i - 1] != "\\"
+            ):
+                state = KQLSplitState.OUTSIDE_STRING
+
+            elif (
+                state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+                and character == '"'
+                and query[i - 1] != "\\"
+            ):
+                state = KQLSplitState.OUTSIDE_STRING
+
+            elif (
+                state == KQLSplitState.INSIDE_MULTILINE_STRING
+                and character == "`"
+                and query[i - 2 : i] == "``"
+            ):
+                state = KQLSplitState.OUTSIDE_STRING
+
+        return [cls(statement, engine) for statement in statements]
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> str:
+        if engine != "kustokql":
+            raise ValueError(f"Invalid engine: {engine}")
+
+        # TODO: check if it's just a single statement
+
+        return statement.strip()
+
+    @classmethod
+    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> 
set[Table]:
+        """
+        Extract all tables referenced in the statement.
+
+            StormEvents
+            | where InjuriesDirect + InjuriesIndirect > 50
+            | join (PopulationData) on State
+            | project State, Population, TotalInjuries = InjuriesDirect + 
InjuriesIndirect
+
+        """
+        logger.warning(
+            "Kusto KQL doesn't support table extraction. This means that data 
access "
+            "roles will not be enforced by Superset in the database."
+        )
+        return set()
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Return the KQL statement as-is; no formatting or comment removal is done.
+        """
+        return self._parsed
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the SQL statement.
+
+            >>> statement = KustoKQLStatement("set querytrace", "kustokql")
+            >>> statement.get_settings()
+            {"querytrace": True}
+
+        """
+        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
+        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
+            return {match.group("name"): match.group("value") or True}
+
+        return {}
+
+
+class SQLQuery:
+    """
+    A SQL query, with 0+ statements.
+    """
+
+    # Special engines that can't be parsed using sqlglot. Supporting non-SQL 
engines
+    # adds a lot of complexity to Superset, so we should avoid adding new 
engines to
+    # this data structure.
+    special_engines = {
+        "kustokql": KustoKQLStatement,
+    }
+
+    def __init__(
+        self,
+        query: str,
+        engine: str,
+    ):
+        statement_class = self.special_engines.get(engine, SQLStatement)
+        self.statements = statement_class.split_query(query, engine)
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL query.
+        """
+        return ";\n".join(statement.format(comments) for statement in 
self.statements)
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the SQL query.
+
+            >>> statement = SQLQuery("SET foo = 'bar'; SET foo = 'baz'", "sqlite")
+            >>> statement.get_settings()
+            {"foo": "'baz'"}
+
+        """
+        settings: dict[str, str | bool] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
 class ParsedQuery:
     def __init__(
         self,
         sql_statement: str,
         strip_comments: bool = False,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -443,7 +702,7 @@ class ParsedQuery:
         self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
-        self._limit: Optional[int] = None
+        self._limit: int | None = None
 
         logger.debug("Parsing with sqlparse statement: %s", self.sql)
         self._parsed = sqlparse.parse(self.stripped())
@@ -476,7 +735,7 @@ class ParsedQuery:
         }
 
     @property
-    def limit(self) -> Optional[int]:
+    def limit(self) -> int | None:
         return self._limit
 
     def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -557,7 +816,7 @@ class ParsedQuery:
 
         return True
 
-    def get_inner_cte_expression(self, tokens: TokenList) -> 
Optional[TokenList]:
+    def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
             if self._is_identifier(token):
                 for identifier_token in token.tokens:
@@ -621,7 +880,7 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Table | None:
         """
         Return the table if valid, i.e., conforms to the 
[[catalog.]schema.]table
         construct.
@@ -657,7 +916,7 @@ class ParsedQuery:
     def as_create_table(
         self,
         table_name: str,
-        schema_name: Optional[str] = None,
+        schema_name: str | None = None,
         overwrite: bool = False,
         method: CtasMethod = CtasMethod.TABLE,
     ) -> str:
@@ -817,8 +1076,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
 def get_rls_for_table(
     candidate: Token,
     database_id: int,
-    default_schema: Optional[str],
-) -> Optional[TokenList]:
+    default_schema: str | None,
+) -> TokenList | None:
     """
     Given a table name, return any associated RLS predicates.
     """
@@ -864,7 +1123,7 @@ def get_rls_for_table(
 def insert_rls_as_subquery(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -880,7 +1139,7 @@ def insert_rls_as_subquery(
     This method is safer than ``insert_rls_in_predicate``, but doesn't work in 
all
     databases.
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -956,7 +1215,7 @@ def insert_rls_as_subquery(
 def insert_rls_in_predicate(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -967,7 +1226,7 @@ def insert_rls_in_predicate(
         after:  SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
 
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1101,7 +1360,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
 
 def extract_table_references(
     sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> set["Table"]:
+) -> set[Table]:
     """
     Return all the dependencies from a SQL sql_text.
     """
diff --git a/tests/unit_tests/sql_parse_tests.py 
b/tests/unit_tests/sql_parse_tests.py
index 268e482919..793e6076a5 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -33,6 +33,7 @@ from superset.sql_parse import (
     has_table_query,
     insert_rls_as_subquery,
     insert_rls_in_predicate,
+    KustoKQLStatement,
     ParsedQuery,
     sanitize_clause,
     SQLQuery,
@@ -1857,21 +1858,31 @@ def test_sqlquery() -> None:
     """
     Test the `SQLQuery` class.
     """
-    query = SQLQuery("SELECT 1; SELECT 2;")
+    query = SQLQuery("SELECT 1; SELECT 2;", "sqlite")
 
     assert len(query.statements) == 2
     assert query.format() == "SELECT\n  1;\nSELECT\n  2"
     assert query.statements[0].format() == "SELECT\n  1"
 
-    query = SQLQuery("SET a=1; SET a=2; SELECT 3;")
+    query = SQLQuery("SET a=1; SET a=2; SELECT 3;", "sqlite")
     assert query.get_settings() == {"a": "2"}
 
+    query = SQLQuery(
+        """set querytrace;
+Events | take 100""",
+        "kustokql",
+    )
+    assert query.get_settings() == {"querytrace": True}
+
 
 def test_sqlstatement() -> None:
     """
     Test the `SQLStatement` class.
     """
-    statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM 
table2")
+    statement = SQLStatement(
+        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
+        "sqlite",
+    )
 
     assert statement.tables == {
         Table(table="table1", schema=None, catalog=None),
@@ -1882,5 +1893,74 @@ def test_sqlstatement() -> None:
         == "SELECT\n  *\nFROM table1\nUNION ALL\nSELECT\n  *\nFROM table2"
     )
 
-    statement = SQLStatement("SET a=1")
+    statement = SQLStatement("SET a=1", "sqlite")
     assert statement.get_settings() == {"a": "1"}
+
+
+def test_kustokqlstatement() -> None:
+    """
+    Test the `KustoKQLStatement` class.
+    """
+    statements = KustoKQLStatement.split_query(
+        """
+let totalPagesPerDay = PageViews
+| summarize by Page, Day = startofday(Timestamp)
+| summarize count() by Day;
+let materializedScope = PageViews
+| summarize by Page, Day = startofday(Timestamp);
+let cachedResult = materialize(materializedScope);
+cachedResult
+| project Page, Day1 = Day
+| join kind = inner
+(
+    cachedResult
+    | project Page, Day2 = Day
+)
+on Page
+| where Day2 > Day1
+| summarize count() by Day1, Day2
+| join kind = inner
+    totalPagesPerDay
+on $left.Day1 == $right.Day
+| project Day1, Day2, Percentage = count_*100.0/count_1
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 4
+
+    statements = KustoKQLStatement.split_query(
+        """
+print program = ```
+  public class Program {
+    public static void Main() {
+      System.Console.WriteLine("Hello!");
+    }
+  }```
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 1
+
+    statements = KustoKQLStatement.split_query(
+        """
+set querytrace;
+Events | take 100
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 2
+    assert statements[0].format() == "set querytrace"
+    assert statements[1].format() == "Events | take 100"
+
+
+@pytest.mark.parametrize(
+    "kql,statements",
+    [
+        ('print banner=strcat("Hello", ", ", "World!")', 1),
+        (r"print 'O\'Malley\'s'", 1),
+        (r"print 'O\'Mal;ley\'s'", 1),
+        ("print ```foo;\nbar;\nbaz;```\n", 1),
+    ],
+)
+def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
+    assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements

Reply via email to