This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new c0b57bd1c3 feat(sqlparse): improve table parsing (#26476)
c0b57bd1c3 is described below

commit c0b57bd1c3d487a661315a1944aca9f9ce728d51
Author: Beto Dealmeida <robe...@dealmeida.net>
AuthorDate: Mon Jan 22 11:16:50 2024 -0500

    feat(sqlparse): improve table parsing (#26476)
---
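
In short: ``ParsedQuery`` (and ``strip_comments_from_sql``) now accept an
optional ``engine`` argument, which is mapped to a sqlglot dialect when
extracting tables. A minimal sketch of the new call shape (the engine and
table names here are illustrative, not taken from the commit):

    from superset.sql_parse import ParsedQuery, Table

    parsed = ParsedQuery("SELECT * FROM schema1.my_table", engine="postgresql")
    assert parsed.tables == {Table("my_table", "schema1")}
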
 requirements/base.txt                  |  15 +-
 requirements/testing.txt               |   6 +-
 setup.py                               |   1 +
 superset/commands/dataset/duplicate.py |   5 +-
 superset/commands/sql_lab/export.py    |   5 +-
 superset/connectors/sqla/models.py     |   2 +-
 superset/connectors/sqla/utils.py      |   2 +-
 superset/db_engine_specs/base.py       |  12 +-
 superset/db_engine_specs/bigquery.py   |   2 +-
 superset/models/helpers.py             |   2 +-
 superset/models/sql_lab.py             |   6 +-
 superset/security/manager.py           |   5 +-
 superset/sql_lab.py                    |  11 +-
 superset/sql_parse.py                  | 246 ++++++++++++++++++++++-----------
 superset/sql_validators/presto_db.py   |   4 +-
 superset/sqllab/query_render.py        |   6 +-
 tests/unit_tests/sql_parse_tests.py    |  55 ++++++--
 17 files changed, 265 insertions(+), 120 deletions(-)

diff --git a/requirements/base.txt b/requirements/base.txt
index 98d2a8094e..b198a5c9ce 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -141,7 +141,9 @@ geographiclib==1.52
 geopy==2.2.0
     # via apache-superset
 greenlet==2.0.2
-    # via shillelagh
+    # via
+    #   shillelagh
+    #   sqlalchemy
 gunicorn==21.2.0
     # via apache-superset
 hashids==1.3.1
@@ -155,7 +157,10 @@ idna==3.2
     #   email-validator
     #   requests
 importlib-metadata==6.6.0
-    # via apache-superset
+    # via
+    #   apache-superset
+    #   flask
+    #   shillelagh
 importlib-resources==5.12.0
     # via limits
 isodate==0.6.0
@@ -327,6 +332,8 @@ sqlalchemy-utils==0.38.3
     # via
     #   apache-superset
     #   flask-appbuilder
+sqlglot==20.8.0
+    # via apache-superset
 sqlparse==0.4.4
     # via apache-superset
 sshtunnel==0.4.0
@@ -376,7 +383,9 @@ wtforms-json==0.3.5
 xlsxwriter==3.0.7
     # via apache-superset
 zipp==3.15.0
-    # via importlib-metadata
+    # via
+    #   importlib-metadata
+    #   importlib-resources
 
 # The following packages are considered to be unsafe in a requirements file:
 # setuptools
diff --git a/requirements/testing.txt b/requirements/testing.txt
index 3bf3c78d03..b40497c8fc 100644
--- a/requirements/testing.txt
+++ b/requirements/testing.txt
@@ -24,10 +24,6 @@ db-dtypes==1.1.1
     # via pandas-gbq
 docker==6.1.1
     # via -r requirements/testing.in
-exceptiongroup==1.1.1
-    # via pytest
-ephem==4.1.4
-    # via lunarcalendar
 flask-testing==0.8.1
     # via -r requirements/testing.in
 fonttools==4.39.4
@@ -121,6 +117,8 @@ pyee==9.0.4
     # via playwright
 pyfakefs==5.2.2
     # via -r requirements/testing.in
+pyhive[presto]==0.7.0
+    # via apache-superset
 pytest==7.3.1
     # via
     #   -r requirements/testing.in
diff --git a/setup.py b/setup.py
index fd2a9f8c80..cb02a7f490 100644
--- a/setup.py
+++ b/setup.py
@@ -125,6 +125,7 @@ setup(
         "slack_sdk>=3.19.0, <4",
         "sqlalchemy>=1.4, <2",
         "sqlalchemy-utils>=0.38.3, <0.39",
+        "sqlglot>=20,<21",
         "sqlparse>=0.4.4, <0.5",
         "tabulate>=0.8.9, <0.9",
         "typing-extensions>=4, <5",
diff --git a/superset/commands/dataset/duplicate.py b/superset/commands/dataset/duplicate.py
index 0ae47c35bc..850290422e 100644
--- a/superset/commands/dataset/duplicate.py
+++ b/superset/commands/dataset/duplicate.py
@@ -70,7 +70,10 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
             table.normalize_columns = self._base_model.normalize_columns
             table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
             table.is_sqllab_view = True
-            table.sql = ParsedQuery(self._base_model.sql).stripped()
+            table.sql = ParsedQuery(
+                self._base_model.sql,
+                engine=database.db_engine_spec.engine,
+            ).stripped()
             db.session.add(table)
             cols = []
             for config_ in self._base_model.columns:
diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py
index 1b9b0e0344..aa6050f27f 100644
--- a/superset/commands/sql_lab/export.py
+++ b/superset/commands/sql_lab/export.py
@@ -115,7 +115,10 @@ class SqlResultExportCommand(BaseCommand):
                 limit = None
             else:
                 sql = self._query.executed_sql
-                limit = ParsedQuery(sql).limit
+                limit = ParsedQuery(
+                    sql,
+                    engine=self._query.database.db_engine_spec.engine,
+                ).limit
             if limit is not None and self._query.limiting_factor in {
                 LimitingFactor.QUERY,
                 LimitingFactor.DROPDOWN,
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 624eb2ce5a..08dc923c21 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1457,7 +1457,7 @@ class SqlaTable(
             return self.get_sqla_table(), None
 
         from_sql = self.get_rendered_sql(template_processor)
-        parsed_query = ParsedQuery(from_sql)
+        parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
         if not (
             parsed_query.is_unknown()
             or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 66594084c8..688be53515 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
-    parsed_query = ParsedQuery(sql)
+    parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
     if not db_engine_spec.is_readonly_query(parsed_query):
         raise SupersetSecurityException(
             SupersetError(
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 48e44064ac..3b8bb2bd33 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -899,7 +899,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             return database.compile_sqla_query(qry)
 
         if cls.limit_method == LimitMethod.FORCE_LIMIT:
-            parsed_query = sql_parse.ParsedQuery(sql)
+            parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
             sql = parsed_query.set_or_update_query_limit(limit, force=force)
 
         return sql
@@ -980,7 +980,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param sql: SQL query
         :return: Value of limit clause in query
         """
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         return parsed_query.limit
 
     @classmethod
@@ -992,7 +992,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param limit: New limit to insert/replace into query
         :return: Query with new limit
         """
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         return parsed_query.set_or_update_query_limit(limit)
 
     @classmethod
@@ -1487,7 +1487,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param database: Database instance
         :return: Dictionary with different costs
         """
-        parsed_query = ParsedQuery(statement)
+        parsed_query = ParsedQuery(statement, engine=cls.engine)
         sql = parsed_query.stripped()
         sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
         mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
@@ -1522,7 +1522,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
                 "Database does not support cost estimation"
             )
 
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         statements = parsed_query.get_statements()
 
         costs = []
@@ -1583,7 +1583,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :return:
         """
         if not cls.allows_sql_comments:
-            query = sql_parse.strip_comments_from_sql(query)
+            query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
 
         if cls.arraysize:
             cursor.arraysize = cls.arraysize
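
With the engine name threaded through ``cls.engine``, limit handling is
unchanged but the parser can now pick a matching dialect. A small sketch
(the engine name and query are illustrative):

    from superset import sql_parse

    parsed = sql_parse.ParsedQuery("SELECT * FROM t LIMIT 10", engine="mysql")
    assert parsed.limit == 10
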
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 8e7ed0bf7d..a8d834276e 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -435,7 +435,7 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         if not cls.get_allow_cost_estimate(extra):
             raise SupersetException("Database does not support cost 
estimation")
 
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         statements = parsed_query.get_statements()
         costs = []
         for statement in statements:
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index a6d879b785..9c8e83147e 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1093,7 +1093,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         """
 
         from_sql = self.get_rendered_sql(template_processor)
-        parsed_query = ParsedQuery(from_sql)
+        parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
         if not (
             parsed_query.is_unknown()
             or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index ca530ff8b9..a0e9fa6b6e 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -183,7 +183,7 @@ class Query(
 
     @property
     def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql).tables)
+        return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
 
     @property
     def columns(self) -> list["TableColumn"]:
@@ -427,7 +427,9 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
 
     @property
     def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql).tables)
+        return list(
+            ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
+        )
 
     @property
     def last_run_humanized(self) -> str:
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 618a7e2808..7e1b697840 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -1876,7 +1876,10 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 default_schema = database.get_default_schema_for_query(query)
                 tables = {
                     Table(table_.table, table_.schema or default_schema)
-                    for table_ in sql_parse.ParsedQuery(query.sql).tables
+                    for table_ in sql_parse.ParsedQuery(
+                        query.sql,
+                        engine=database.db_engine_spec.engine,
+                    ).tables
                 }
             elif table:
                 tables = {table}
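
The ``schema or default_schema`` fallback above means an unqualified table
name is resolved against the query's default schema; roughly (names are
illustrative):

    from superset.sql_parse import Table

    tables = {Table("t1"), Table("t2", "other")}  # what the parser returns
    default_schema = "public"
    resolved = {Table(t.table, t.schema or default_schema) for t in tables}
    assert resolved == {Table("t1", "public"), Table("t2", "other")}
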
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 1029ff402c..1b883a77cf 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -199,7 +199,7 @@ def execute_sql_statement(
     database: Database = query.database
     db_engine_spec = database.db_engine_spec
 
-    parsed_query = ParsedQuery(sql_statement)
+    parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
     if is_feature_enabled("RLS_IN_SQLLAB"):
         # There are two ways to insert RLS: either replacing the table with a subquery
         # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
@@ -219,7 +219,8 @@ def execute_sql_statement(
                     database.id,
                     query.schema,
                 )
-            )
+            ),
+            engine=db_engine_spec.engine,
         )
 
     sql = parsed_query.stripped()
@@ -409,7 +410,11 @@ def execute_sql_statements(
         )
 
     # Breaking down into multiple statements
-    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
+    parsed_query = ParsedQuery(
+        rendered_query,
+        strip_comments=True,
+        engine=db_engine_spec.engine,
+    )
     if not db_engine_spec.run_multiple_statements_as_one:
         statements = parsed_query.get_statements()
         logger.info(
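
For context, the two RLS strategies mentioned in the comment above produce
query shapes roughly like these (the table name and predicate are made up):

    # RLS as a subquery, replacing the table reference:
    rls_as_subquery = (
        "SELECT * FROM (SELECT * FROM some_table WHERE region = 'EU') AS some_table"
    )
    # RLS appended to the WHERE clause:
    rls_in_predicate = "SELECT * FROM some_table WHERE region = 'EU'"
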
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index cecd673276..07704171de 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -14,15 +14,22 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+# pylint: disable=too-many-lines
+
 import logging
 import re
-from collections.abc import Iterator
+import urllib.parse
+from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
 from typing import Any, cast, Optional
-from urllib import parse
 
 import sqlparse
 from sqlalchemy import and_
+from sqlglot import exp, parse, parse_one
+from sqlglot.dialects import Dialects
+from sqlglot.errors import ParseError
+from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
 from sqlparse import keywords
 from sqlparse.lexer import Lexer
 from sqlparse.sql import (
@@ -53,7 +60,7 @@ from superset.utils.backports import StrEnum
 
 try:
     from sqloxide import parse_sql as sqloxide_parse
-except:  # pylint: disable=bare-except
+except (ImportError, ModuleNotFoundError):
     sqloxide_parse = None
 
 RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@@ -72,6 +79,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
 lex.set_SQL_REGEX(sqlparser_sql_regex)
 
 
+# mapping between DB engine specs and sqlglot dialects
+SQLGLOT_DIALECTS = {
+    "ascend": Dialects.HIVE,
+    "awsathena": Dialects.PRESTO,
+    "bigquery": Dialects.BIGQUERY,
+    "clickhouse": Dialects.CLICKHOUSE,
+    "clickhousedb": Dialects.CLICKHOUSE,
+    "cockroachdb": Dialects.POSTGRES,
+    # "crate": ???
+    # "databend": ???
+    "databricks": Dialects.DATABRICKS,
+    # "db2": ???
+    # "dremio": ???
+    "drill": Dialects.DRILL,
+    # "druid": ???
+    "duckdb": Dialects.DUCKDB,
+    # "dynamodb": ???
+    # "elasticsearch": ???
+    # "exa": ???
+    # "firebird": ???
+    # "firebolt": ???
+    "gsheets": Dialects.SQLITE,
+    "hana": Dialects.POSTGRES,
+    "hive": Dialects.HIVE,
+    # "ibmi": ???
+    # "impala": ???
+    # "kustokql": ???
+    # "kylin": ???
+    # "mssql": ???
+    "mysql": Dialects.MYSQL,
+    "netezza": Dialects.POSTGRES,
+    # "ocient": ???
+    # "odelasticsearch": ???
+    "oracle": Dialects.ORACLE,
+    # "pinot": ???
+    "postgresql": Dialects.POSTGRES,
+    "presto": Dialects.PRESTO,
+    "pydoris": Dialects.DORIS,
+    "redshift": Dialects.REDSHIFT,
+    # "risingwave": ???
+    # "rockset": ???
+    "shillelagh": Dialects.SQLITE,
+    "snowflake": Dialects.SNOWFLAKE,
+    # "solr": ???
+    "sqlite": Dialects.SQLITE,
+    "starrocks": Dialects.STARROCKS,
+    "superset": Dialects.SQLITE,
+    "teradatasql": Dialects.TERADATA,
+    "trino": Dialects.TRINO,
+    "vertica": Dialects.POSTGRES,
+}
+
+
 class CtasMethod(StrEnum):
     TABLE = "TABLE"
     VIEW = "VIEW"
@@ -150,7 +210,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str) -> str:
+def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -160,7 +220,11 @@ def strip_comments_from_sql(statement: str) -> str:
     :param statement: A string with the SQL statement
     :return: SQL statement without comments
     """
-    return ParsedQuery(statement).strip_comments() if "--" in statement else statement
+    return (
+        ParsedQuery(statement, engine=engine).strip_comments()
+        if "--" in statement
+        else statement
+    )
 
 
 @dataclass(eq=True, frozen=True)
@@ -179,7 +243,7 @@ class Table:
         """
 
         return ".".join(
-            parse.quote(part, safe="").replace(".", "%2E")
+            urllib.parse.quote(part, safe="").replace(".", "%2E")
             for part in [self.catalog, self.schema, self.table]
             if part
         )
@@ -189,11 +253,17 @@ class Table:
 
 
 class ParsedQuery:
-    def __init__(self, sql_statement: str, strip_comments: bool = False):
+    def __init__(
+        self,
+        sql_statement: str,
+        strip_comments: bool = False,
+        engine: Optional[str] = None,
+    ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
 
         self.sql: str = sql_statement
+        self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
         self._limit: Optional[int] = None
@@ -206,14 +276,94 @@ class ParsedQuery:
     @property
     def tables(self) -> set[Table]:
         if not self._tables:
-            for statement in self._parsed:
-                self._extract_from_token(statement)
-
-            self._tables = {
-                table for table in self._tables if str(table) not in self._alias_names
-            }
+            self._tables = self._extract_tables_from_sql()
         return self._tables
 
+    def _extract_tables_from_sql(self) -> set[Table]:
+        """
+        Extract all table references in a query.
+
+        Note: this uses sqlglot, since it catches more edge cases.
+        """
+        try:
+            statements = parse(self.sql, dialect=self._dialect)
+        except ParseError:
+            logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
+            return set()
+
+        return {
+            table
+            for statement in statements
+            for table in self._extract_tables_from_statement(statement)
+            if statement
+        }
+
+    def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
+        """
+        Extract all table references in a single statement.
+
+        Please note that this is not trivial; consider the following queries:
+
+            DESCRIBE some_table;
+            SHOW PARTITIONS FROM some_table;
+            WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
+
+        See the unit tests for other tricky cases.
+        """
+        sources: Iterable[exp.Table]
+
+        if isinstance(statement, exp.Describe):
+            # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
+            # query for all tables.
+            sources = statement.find_all(exp.Table)
+        elif isinstance(statement, exp.Command):
+            # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+            # `SELECT` statement in order to extract tables.
+            literal = statement.find(exp.Literal)
+            if not literal:
+                return set()
+
+            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
+            sources = pseudo_query.find_all(exp.Table)
+        else:
+            sources = [
+                source
+                for scope in traverse_scope(statement)
+                for source in scope.sources.values()
+                if isinstance(source, exp.Table) and not self._is_cte(source, scope)
+            ]
+
+        return {
+            Table(
+                source.name,
+                source.db if source.db != "" else None,
+                source.catalog if source.catalog != "" else None,
+            )
+            for source in sources
+        }
+
+    def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
+        """
+        Is the source a CTE?
+
+        CTEs in the parent scope look like tables (and are represented by
+        exp.Table objects), but should not be considered as such;
+        otherwise a user with access to table `foo` could access any table
+        with a query like this:
+
+            WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
+
+        """
+        parent_sources = scope.parent.sources if scope.parent else {}
+        ctes_in_scope = {
+            name
+            for name, parent_scope in parent_sources.items()
+            if isinstance(parent_scope, Scope)
+            and parent_scope.scope_type == ScopeType.CTE
+        }
+
+        return source.name in ctes_in_scope
+
     @property
     def limit(self) -> Optional[int]:
         return self._limit
@@ -393,28 +543,6 @@ class ParsedQuery:
     def _is_identifier(token: Token) -> bool:
         return isinstance(token, (IdentifierList, Identifier))
 
-    def _process_tokenlist(self, token_list: TokenList) -> None:
-        """
-        Add table names to table set
-
-        :param token_list: TokenList to be processed
-        """
-        # exclude subselects
-        if "(" not in str(token_list):
-            table = self.get_table(token_list)
-            if table and not table.table.startswith(CTE_PREFIX):
-                self._tables.add(table)
-            return
-
-        # store aliases
-        if token_list.has_alias():
-            self._alias_names.add(token_list.get_alias())
-
-        # some aliases are not parsed properly
-        if token_list.tokens[0].ttype == Name:
-            self._alias_names.add(token_list.tokens[0].value)
-        self._extract_from_token(token_list)
-
     def as_create_table(
         self,
         table_name: str,
@@ -441,50 +569,6 @@ class ParsedQuery:
         exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
         return exec_sql
 
-    def _extract_from_token(self, token: Token) -> None:
-        """
-        <Identifier> store a list of subtokens and <IdentifierList> store lists of
-        subtoken list.
-
-        It extracts <IdentifierList> and <Identifier> from :param token: and loops
-        through all subtokens recursively. It finds table_name_preceding_token and
-        passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
-        self._tables.
-
-        :param token: instance of Token or child class, e.g. TokenList, to be processed
-        """
-        if not hasattr(token, "tokens"):
-            return
-
-        table_name_preceding_token = False
-
-        for item in token.tokens:
-            if item.is_group and (
-                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
-            ):
-                self._extract_from_token(item)
-
-            if item.ttype in Keyword and (
-                item.normalized in PRECEDES_TABLE_NAME
-                or item.normalized.endswith(" JOIN")
-            ):
-                table_name_preceding_token = True
-                continue
-
-            if item.ttype in Keyword:
-                table_name_preceding_token = False
-                continue
-            if table_name_preceding_token:
-                if isinstance(item, Identifier):
-                    self._process_tokenlist(item)
-                elif isinstance(item, IdentifierList):
-                    for token2 in item.get_identifiers():
-                        if isinstance(token2, TokenList):
-                            self._process_tokenlist(token2)
-            elif isinstance(item, IdentifierList):
-                if any(not self._is_identifier(token2) for token2 in item.tokens):
-                    self._extract_from_token(item)
-
     def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
         """Returns the query with the specified limit.
 
@@ -881,7 +965,7 @@ def insert_rls_in_predicate(
 
 
 # mapping between sqloxide and SQLAlchemy dialects
-SQLOXITE_DIALECTS = {
+SQLOXIDE_DIALECTS = {
     "ansi": {"trino", "trinonative", "presto"},
     "hive": {"hive", "databricks"},
     "ms": {"mssql"},
@@ -914,7 +998,7 @@ def extract_table_references(
     tree = None
 
     if sqloxide_parse:
-        for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
+        for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
             if sqla_dialect in sqla_dialects:
                 break
         sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
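
The CTE handling above is security-relevant: an alias must not shadow a real
table, or the permission check would miss the underlying source. A sketch of
the intended behavior, using the names from the ``_is_cte`` docstring:

    from superset.sql_parse import ParsedQuery, Table

    sql = "WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo"
    # `foo` is a CTE, not a table; only the real source is reported.
    assert ParsedQuery(sql).tables == {Table("target_table")}
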
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index c01b938671..fed1ff3bfa 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
     ) -> Optional[SQLValidationAnnotation]:
         # pylint: disable=too-many-locals
         db_engine_spec = database.db_engine_spec
-        parsed_query = ParsedQuery(statement)
+        parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
         sql = parsed_query.stripped()
 
         # Hook to allow environment-specific mutation (usually comments) to the SQL
@@ -154,7 +154,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
         For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
         VALIDATE) SELECT 1 FROM default.mytable.
         """
-        parsed_query = ParsedQuery(sql)
+        parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
         statements = parsed_query.get_statements()
 
         logger.info("Validating %i statement(s)", len(statements))
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index f4c1c26c6e..5597bcb086 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -58,7 +58,11 @@ class SqlQueryRenderImpl(SqlQueryRender):
                 database=query_model.database, query=query_model
             )
 
-            parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
+            parsed_query = ParsedQuery(
+                query_model.sql,
+                strip_comments=True,
+                engine=query_model.database.db_engine_spec.engine,
+            )
             rendered_query = sql_template_processor.process_template(
                 parsed_query.stripped(), **execution_context.template_params
             )
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index efd8838101..f650b77734 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -40,11 +40,11 @@ from superset.sql_parse import (
 )
 
 
-def extract_tables(query: str) -> set[Table]:
+def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
     """
     Helper function to extract tables referenced in a query.
     """
-    return ParsedQuery(query).tables
+    return ParsedQuery(query, engine=engine).tables
 
 
 def test_table() -> None:
@@ -96,8 +96,13 @@ def test_extract_tables() -> None:
         Table("left_table")
     }
 
-    # reverse select
-    assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
+    assert extract_tables(
+        "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
+    ) == {Table("forbidden_table")}
+
+    assert extract_tables(
+        "select * from (select * from forbidden_table) forbidden_table"
+    ) == {Table("forbidden_table")}
 
 
 def test_extract_tables_subselect() -> None:
@@ -263,14 +268,16 @@ def test_extract_tables_illdefined() -> None:
     assert extract_tables("SELECT * FROM schemaname.") == set()
     assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
     assert extract_tables("SELECT * FROM catalogname..") == set()
-    assert extract_tables("SELECT * FROM catalogname..tbname") == set()
+    assert extract_tables("SELECT * FROM catalogname..tbname") == {
+        Table(table="tbname", schema=None, catalog="catalogname")
+    }
 
 
 def test_extract_tables_show_tables_from() -> None:
     """
     Test ``SHOW TABLES FROM``.
     """
-    assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
+    assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == 
set()
 
 
 def test_extract_tables_show_columns_from() -> None:
@@ -311,7 +318,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
             """
 SELECT name
 FROM t1
-WHERE regionkey EXISTS (SELECT regionkey FROM t2)
+WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
 """
         )
         == {Table("t1"), Table("t2")}
@@ -526,6 +533,18 @@ select * from (select key from q1) a
         == {Table("src")}
     )
 
+    # weird query with circular dependency
+    assert (
+        extract_tables(
+            """
+with src as ( select key from q2 where key = '5'),
+q2 as ( select key from src where key = '5')
+select * from (select key from src) a
+"""
+        )
+        == set()
+    )
+
 
 def test_extract_tables_multistatement() -> None:
     """
@@ -665,7 +684,8 @@ def test_extract_tables_nested_select() -> None:
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
-"""
+""",
+            "mysql",
         )
         == {Table("COLUMNS", "INFORMATION_SCHEMA")}
     )
@@ -676,7 +696,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
-"""
+""",
+            "mysql",
         )
         == {Table("COLUMNS", "INFORMATION_SCHEMA")}
     )
@@ -1306,6 +1327,14 @@ def test_sqlparse_issue_652():
             "(SELECT table_name FROM /**/ information_schema.tables WHERE 
table_name LIKE '%user%' LIMIT 1)",
             True,
         ),
+        (
+            "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
+            True,
+        ),
+        (
+            "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
+            True,
+        ),
     ],
 )
 def test_has_table_query(sql: str, expected: bool) -> None:
@@ -1790,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
     assert extract_table_references(
         sql,
         "trino",
-    ) == {Table(table="other_table", schema=None, catalog=None)}
+    ) == {
+        Table(table="table", schema=None, catalog=None),
+        Table(table="other_table", schema=None, catalog=None),
+    }
     logger.warning.assert_called_once()
 
     logger = mocker.patch("superset.migrations.shared.utils.logger")
     sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
     assert extract_table_references(sql, "trino", show_warning=False) == {
-        Table(table="other_table", schema=None, catalog=None)
+        Table(table="table", schema=None, catalog=None),
+        Table(table="other_table", schema=None, catalog=None),
     }
     logger.warning.assert_not_called()
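
Several tests now pass an explicit engine because the statements only parse
under a specific dialect; for example, double-quoted literals like "%bi%" are
strings in MySQL but identifiers in ANSI SQL. Roughly (illustrative):

    assert extract_tables('SELECT 1 FROM t WHERE x LIKE "%bi%"', "mysql") == {
        Table("t")
    }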
 
