This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 2f68010729 fix: `is_select` (#25189)
2f68010729 is described below

commit 2f68010729453bdf29e31b7de29731d812e1668c
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Sep 6 11:54:25 2023 -0700

    fix: `is_select` (#25189)
---
 superset/sql_parse.py               | 74 ++++++++++++++++++++-----------------
 tests/unit_tests/sql_parse_tests.py |  7 ++++
 2 files changed, 47 insertions(+), 34 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 2a283b81f0..34fc354730 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -244,46 +244,52 @@ class ParsedQuery:
         # make sure we strip comments; prevents a bug with comments in the CTE
         parsed = sqlparse.parse(self.strip_comments())
 
-        # Check if this is a CTE
-        if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE:
-            if sqloxide_parse is not None:
-                try:
-                    if not self._check_cte_is_select(
-                        sqloxide_parse(self.strip_comments(), dialect="ansi")
-                    ):
-                        return False
-                except ValueError:
-                    # sqloxide was not able to parse the query, so let's continue with
-                    # sqlparse
-                    pass
-            inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or []
-            # Check if the inner CTE is a not a SELECT
-            if any(token.ttype == DDL for token in inner_cte) or any(
+        for statement in parsed:
+            # Check if this is a CTE
+            if statement.is_group and statement[0].ttype == Keyword.CTE:
+                if sqloxide_parse is not None:
+                    try:
+                        if not self._check_cte_is_select(
+                            sqloxide_parse(self.strip_comments(), dialect="ansi")
+                        ):
+                            return False
+                    except ValueError:
+                        # sqloxide was not able to parse the query, so let's continue with
+                        # sqlparse
+                        pass
+                inner_cte = self.get_inner_cte_expression(statement.tokens) or []
+                # Check if the inner CTE is a not a SELECT
+                if any(token.ttype == DDL for token in inner_cte) or any(
+                    token.ttype == DML and token.normalized != "SELECT"
+                    for token in inner_cte
+                ):
+                    return False
+
+            if statement.get_type() == "SELECT":
+                continue
+
+            if statement.get_type() != "UNKNOWN":
+                return False
+
+            # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
+            # and no DDL is allowed
+            if any(token.ttype == DDL for token in statement) or any(
                 token.ttype == DML and token.normalized != "SELECT"
-                for token in inner_cte
+                for token in statement
             ):
                 return False
 
-        if parsed[0].get_type() == "SELECT":
-            return True
-
-        if parsed[0].get_type() != "UNKNOWN":
-            return False
-
-        # for `UNKNOWN`, check all DDL/DML explicitly: only `SELECT` DML is allowed,
-        # and no DDL is allowed
-        if any(token.ttype == DDL for token in parsed[0]) or any(
-            token.ttype == DML and token.normalized != "SELECT" for token in parsed[0]
-        ):
-            return False
+            # return false on `EXPLAIN`, `SET`, `SHOW`, etc.
+            if statement[0].ttype == Keyword:
+                return False
 
-        # return false on `EXPLAIN`, `SET`, `SHOW`, etc.
-        if parsed[0][0].ttype == Keyword:
-            return False
+            if not any(
+                token.ttype == DML and token.normalized == "SELECT"
+                for token in statement
+            ):
+                return False
 
-        return any(
-            token.ttype == DML and token.normalized == "SELECT" for token in parsed[0]
-        )
+        return True
 
     def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
         for token in tokens:
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 7d8839198c..73074d3df6 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -1616,3 +1616,10 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
         Table(table="other_table", schema=None, catalog=None)
     }
     logger.warning.assert_not_called()
+
+
+def test_is_select() -> None:
+    """
+    Test `is_select`.
+    """
+    assert not ParsedQuery("SELECT 1; DROP DATABASE superset").is_select()

Reply via email to