This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit b477abd1ca8dde4ef6d8d8a19c6bdb4d51cbe8eb
Author: John Bodley <[email protected]>
AuthorDate: Thu May 2 11:25:55 2024 -0700

    fix: Ignore USE SQL keyword when determining SELECT statement (#28279)
    
    (cherry picked from commit 27952e705754802fa5127e467012607bd892cef2)
---
 superset/sql_parse.py               | 9 +++++++--
 tests/unit_tests/sql_parse_tests.py | 3 +++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 6eab11de3d..657fde1982 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -429,6 +429,7 @@ class ParsedQuery:
     def is_select(self) -> bool:
         # make sure we strip comments; prevents a bug with comments in the CTE
         parsed = sqlparse.parse(self.strip_comments())
+        seen_select = False
 
         for statement in parsed:
             # Check if this is a CTE
@@ -452,6 +453,7 @@ class ParsedQuery:
                     return False
 
             if statement.get_type() == "SELECT":
+                seen_select = True
                 continue
 
             if statement.get_type() != "UNKNOWN":
@@ -465,8 +467,11 @@ class ParsedQuery:
             ):
                 return False
 
+            if imt(statement.tokens[0], m=(Keyword, "USE")):
+                continue
+
             # return false on `EXPLAIN`, `SET`, `SHOW`, etc.
-            if statement[0].ttype == Keyword:
+            if imt(statement.tokens[0], t=Keyword):
                 return False
 
             if not any(
@@ -475,7 +480,7 @@ class ParsedQuery:
             ):
                 return False
 
-        return True
+        return seen_select
 
     def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 9ca39f787d..3b5e801652 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -1881,6 +1881,9 @@ WITH t AS (
 )
 SELECT * FROM t"""
     ).is_select()
+    assert not ParsedQuery("").is_select()
+    assert not ParsedQuery("USE foo").is_select()
+    assert ParsedQuery("USE foo; SELECT * FROM bar").is_select()
 
 
 @pytest.mark.parametrize(

Reply via email to