This is an automated email from the ASF dual-hosted git repository. michaelsmolina pushed a commit to branch 3.1 in repository https://gitbox.apache.org/repos/asf/superset.git
commit 82895af3ce7988548bf8f727e643b38ec789c1e2 Author: Michael S. Molina <[email protected]> AuthorDate: Fri Mar 8 17:00:38 2024 -0300 Revert "feat(sqlparse): improve table parsing (#26476)" This reverts commit 1d9cfdabd1816c71f716c8d7e213558d5b7ff05e. --- requirements/base.txt | 15 +- requirements/testing.txt | 4 + setup.py | 1 - superset/commands/dataset/duplicate.py | 5 +- superset/commands/sql_lab/export.py | 5 +- superset/connectors/sqla/models.py | 2 +- superset/connectors/sqla/utils.py | 2 +- superset/db_engine_specs/base.py | 12 +- superset/db_engine_specs/bigquery.py | 2 +- superset/models/helpers.py | 2 +- superset/models/sql_lab.py | 6 +- superset/security/manager.py | 5 +- superset/sql_lab.py | 11 +- superset/sql_parse.py | 246 +++++++++++---------------------- superset/sql_validators/presto_db.py | 4 +- superset/sqllab/query_render.py | 6 +- tests/unit_tests/sql_parse_tests.py | 55 ++------ 17 files changed, 120 insertions(+), 263 deletions(-) diff --git a/requirements/base.txt b/requirements/base.txt index de25938a01..fbb82b46f5 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -141,9 +141,7 @@ geographiclib==1.52 geopy==2.2.0 # via apache-superset greenlet==2.0.2 - # via - # shillelagh - # sqlalchemy + # via shillelagh gunicorn==21.2.0 # via apache-superset hashids==1.3.1 @@ -157,10 +155,7 @@ idna==3.2 # email-validator # requests importlib-metadata==6.6.0 - # via - # apache-superset - # flask - # shillelagh + # via apache-superset importlib-resources==5.12.0 # via limits isodate==0.6.0 @@ -335,8 +330,6 @@ sqlalchemy-utils==0.38.3 # via # apache-superset # flask-appbuilder -sqlglot==20.8.0 - # via apache-superset sqlparse==0.4.4 # via apache-superset sshtunnel==0.4.0 @@ -387,9 +380,7 @@ wtforms-json==0.3.5 xlsxwriter==3.0.7 # via apache-superset zipp==3.15.0 - # via - # importlib-metadata - # importlib-resources + # via importlib-metadata # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/requirements/testing.txt b/requirements/testing.txt index fce953f8e4..725df3ac3c 100644 --- a/requirements/testing.txt +++ b/requirements/testing.txt @@ -24,6 +24,10 @@ db-dtypes==1.1.1 # via pandas-gbq docker==6.1.1 # via -r requirements/testing.in +exceptiongroup==1.1.1 + # via pytest +ephem==4.1.4 + # via lunarcalendar flask-testing==0.8.1 # via -r requirements/testing.in fonttools==4.39.4 diff --git a/setup.py b/setup.py index 7050d7b497..15ac83417b 100644 --- a/setup.py +++ b/setup.py @@ -126,7 +126,6 @@ setup( "slack_sdk>=3.19.0, <4", "sqlalchemy>=1.4, <2", "sqlalchemy-utils>=0.38.3, <0.39", - "sqlglot>=20,<21", "sqlparse>=0.4.4, <0.5", "tabulate>=0.8.9, <0.9", "typing-extensions>=4, <5", diff --git a/superset/commands/dataset/duplicate.py b/superset/commands/dataset/duplicate.py index 850290422e..0ae47c35bc 100644 --- a/superset/commands/dataset/duplicate.py +++ b/superset/commands/dataset/duplicate.py @@ -70,10 +70,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand): table.normalize_columns = self._base_model.normalize_columns table.always_filter_main_dttm = self._base_model.always_filter_main_dttm table.is_sqllab_view = True - table.sql = ParsedQuery( - self._base_model.sql, - engine=database.db_engine_spec.engine, - ).stripped() + table.sql = ParsedQuery(self._base_model.sql).stripped() db.session.add(table) cols = [] for config_ in self._base_model.columns: diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py index aa6050f27f..1b9b0e0344 100644 --- a/superset/commands/sql_lab/export.py +++ b/superset/commands/sql_lab/export.py @@ -115,10 +115,7 @@ class SqlResultExportCommand(BaseCommand): limit = None else: sql = self._query.executed_sql - limit = ParsedQuery( - sql, - engine=self._query.database.db_engine_spec.engine, - ).limit + limit = ParsedQuery(sql).limit if limit is not None and self._query.limiting_factor in { LimitingFactor.QUERY, LimitingFactor.DROPDOWN, diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index bd54032d5d..598bc6741b 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1458,7 +1458,7 @@ class SqlaTable( return self.get_sqla_table(), None from_sql = self.get_rendered_sql(template_processor) - parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) + parsed_query = ParsedQuery(from_sql) if not ( parsed_query.is_unknown() or self.db_engine_spec.is_readonly_query(parsed_query) diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 688be53515..66594084c8 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: sql = dataset.get_template_processor().process_template( dataset.sql, **dataset.template_params_dict ) - parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine) + parsed_query = ParsedQuery(sql) if not db_engine_spec.is_readonly_query(parsed_query): raise SupersetSecurityException( SupersetError( diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 66293ccf52..ce67cb448c 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -900,7 +900,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return database.compile_sqla_query(qry) if cls.limit_method == LimitMethod.FORCE_LIMIT: - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) + parsed_query = sql_parse.ParsedQuery(sql) sql = parsed_query.set_or_update_query_limit(limit, force=force) return sql @@ -981,7 +981,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param sql: SQL query :return: Value of limit clause in query """ - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) + parsed_query = sql_parse.ParsedQuery(sql) return parsed_query.limit @classmethod @@ -993,7 +993,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param limit: New limit to insert/replace into query :return: Query with new limit """ - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) + parsed_query = sql_parse.ParsedQuery(sql) return parsed_query.set_or_update_query_limit(limit) @classmethod @@ -1490,7 +1490,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param database: Database instance :return: Dictionary with different costs """ - parsed_query = ParsedQuery(statement, engine=cls.engine) + parsed_query = ParsedQuery(statement) sql = parsed_query.stripped() sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"] mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"] @@ -1525,7 +1525,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "Database does not support cost estimation" ) - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) + parsed_query = sql_parse.ParsedQuery(sql) statements = parsed_query.get_statements() costs = [] @@ -1586,7 +1586,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :return: """ if not cls.allows_sql_comments: - query = sql_parse.strip_comments_from_sql(query, engine=cls.engine) + query = sql_parse.strip_comments_from_sql(query) if cls.arraysize: cursor.arraysize = cls.arraysize diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index a8d834276e..8e7ed0bf7d 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -435,7 +435,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met if not cls.get_allow_cost_estimate(extra): raise SupersetException("Database does not support cost estimation") - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) + parsed_query = sql_parse.ParsedQuery(sql) statements = parsed_query.get_statements() costs = [] for statement in statements: diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 4ff206882e..1dc5a57da5 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1094,7 +1094,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods """ from_sql = self.get_rendered_sql(template_processor) - parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine) + parsed_query = ParsedQuery(from_sql) if not ( parsed_query.is_unknown() or self.db_engine_spec.is_readonly_query(parsed_query) diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index aff8c1ce3d..7e63e984df 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -183,7 +183,7 @@ class Query( @property def sql_tables(self) -> list[Table]: - return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables) + return list(ParsedQuery(self.sql).tables) @property def columns(self) -> list["TableColumn"]: @@ -427,9 +427,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): @property def sql_tables(self) -> list[Table]: - return list( - ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables - ) + return list(ParsedQuery(self.sql).tables) @property def last_run_humanized(self) -> str: diff --git a/superset/security/manager.py b/superset/security/manager.py index e6eb77e645..501c8cf6a6 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1909,10 +1909,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods default_schema = database.get_default_schema_for_query(query) tables = { Table(table_.table, table_.schema or default_schema) - for table_ in sql_parse.ParsedQuery( - query.sql, - engine=database.db_engine_spec.engine, - ).tables + for table_ in sql_parse.ParsedQuery(query.sql).tables } elif table: tables = {table} diff --git a/superset/sql_lab.py b/superset/sql_lab.py index e9b4d406f8..efbef6560a 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -208,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local database: Database = query.database db_engine_spec = database.db_engine_spec - parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine) + parsed_query = ParsedQuery(sql_statement) if is_feature_enabled("RLS_IN_SQLLAB"): # There are two ways to insert RLS: either replacing the table with a subquery # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is @@ -228,8 +228,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local database.id, query.schema, ) - ), - engine=db_engine_spec.engine, + ) ) sql = parsed_query.stripped() @@ -420,11 +419,7 @@ def execute_sql_statements( ) # Breaking down into multiple statements - parsed_query = ParsedQuery( - rendered_query, - strip_comments=True, - engine=db_engine_spec.engine, - ) + parsed_query = ParsedQuery(rendered_query, strip_comments=True) if not db_engine_spec.run_multiple_statements_as_one: statements = parsed_query.get_statements() logger.info( diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 7b89ab8f0e..b9af21c8c3 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -14,22 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -# pylint: disable=too-many-lines - import logging import re -import urllib.parse -from collections.abc import Iterable, Iterator +from collections.abc import Iterator from dataclasses import dataclass from typing import Any, cast, Optional +from urllib import parse import sqlparse from sqlalchemy import and_ -from sqlglot import exp, parse, parse_one -from sqlglot.dialects import Dialects -from sqlglot.errors import ParseError -from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope from sqlparse import keywords from sqlparse.lexer import Lexer from sqlparse.sql import ( @@ -60,7 +53,7 @@ from superset.utils.backports import StrEnum try: from sqloxide import parse_sql as sqloxide_parse -except (ImportError, ModuleNotFoundError): +except: # pylint: disable=bare-except sqloxide_parse = None RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"} @@ -79,59 +72,6 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String. lex.set_SQL_REGEX(sqlparser_sql_regex) -# mapping between DB engine specs and sqlglot dialects -SQLGLOT_DIALECTS = { - "ascend": Dialects.HIVE, - "awsathena": Dialects.PRESTO, - "bigquery": Dialects.BIGQUERY, - "clickhouse": Dialects.CLICKHOUSE, - "clickhousedb": Dialects.CLICKHOUSE, - "cockroachdb": Dialects.POSTGRES, - # "crate": ??? - # "databend": ??? - "databricks": Dialects.DATABRICKS, - # "db2": ??? - # "dremio": ??? - "drill": Dialects.DRILL, - # "druid": ??? - "duckdb": Dialects.DUCKDB, - # "dynamodb": ??? - # "elasticsearch": ??? - # "exa": ??? - # "firebird": ??? - # "firebolt": ??? - "gsheets": Dialects.SQLITE, - "hana": Dialects.POSTGRES, - "hive": Dialects.HIVE, - # "ibmi": ??? - # "impala": ??? - # "kustokql": ??? - # "kylin": ??? - # "mssql": ??? - "mysql": Dialects.MYSQL, - "netezza": Dialects.POSTGRES, - # "ocient": ??? - # "odelasticsearch": ??? - "oracle": Dialects.ORACLE, - # "pinot": ??? - "postgresql": Dialects.POSTGRES, - "presto": Dialects.PRESTO, - "pydoris": Dialects.DORIS, - "redshift": Dialects.REDSHIFT, - # "risingwave": ??? - # "rockset": ??? - "shillelagh": Dialects.SQLITE, - "snowflake": Dialects.SNOWFLAKE, - # "solr": ??? - "sqlite": Dialects.SQLITE, - "starrocks": Dialects.STARROCKS, - "superset": Dialects.SQLITE, - "teradatasql": Dialects.TERADATA, - "trino": Dialects.TRINO, - "vertica": Dialects.POSTGRES, -} - - class CtasMethod(StrEnum): TABLE = "TABLE" VIEW = "VIEW" @@ -210,7 +150,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: return cte, remainder -def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str: +def strip_comments_from_sql(statement: str) -> str: """ Strips comments from a SQL statement, does a simple test first to avoid always instantiating the expensive ParsedQuery constructor @@ -220,11 +160,7 @@ def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str :param statement: A string with the SQL statement :return: SQL statement without comments """ - return ( - ParsedQuery(statement, engine=engine).strip_comments() - if "--" in statement - else statement - ) + return ParsedQuery(statement).strip_comments() if "--" in statement else statement @dataclass(eq=True, frozen=True) @@ -243,7 +179,7 @@ class Table: """ return ".".join( - urllib.parse.quote(part, safe="").replace(".", "%2E") + parse.quote(part, safe="").replace(".", "%2E") for part in [self.catalog, self.schema, self.table] if part ) @@ -253,17 +189,11 @@ class Table: class ParsedQuery: - def __init__( - self, - sql_statement: str, - strip_comments: bool = False, - engine: Optional[str] = None, - ): + def __init__(self, sql_statement: str, strip_comments: bool = False): if strip_comments: sql_statement = sqlparse.format(sql_statement, strip_comments=True) self.sql: str = sql_statement - self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None self._tables: set[Table] = set() self._alias_names: set[str] = set() self._limit: Optional[int] = None @@ -276,93 +206,13 @@ class ParsedQuery: @property def tables(self) -> set[Table]: if not self._tables: - self._tables = self._extract_tables_from_sql() - return self._tables - - def _extract_tables_from_sql(self) -> set[Table]: - """ - Extract all table references in a query. - - Note: this uses sqlglot, since it's better at catching more edge cases. - """ - try: - statements = parse(self.stripped(), dialect=self._dialect) - except ParseError: - logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql) - return set() - - return { - table - for statement in statements - for table in self._extract_tables_from_statement(statement) - if statement - } - - def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]: - """ - Extract all table references in a single statement. - - Please not that this is not trivial; consider the following queries: - - DESCRIBE some_table; - SHOW PARTITIONS FROM some_table; - WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name; - - See the unit tests for other tricky cases. - """ - sources: Iterable[exp.Table] - - if isinstance(statement, exp.Describe): - # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly - # query for all tables. - sources = statement.find_all(exp.Table) - elif isinstance(statement, exp.Command): - # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a - # `SELECT` statetement in order to extract tables. - literal = statement.find(exp.Literal) - if not literal: - return set() - - pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect) - sources = pseudo_query.find_all(exp.Table) - else: - sources = [ - source - for scope in traverse_scope(statement) - for source in scope.sources.values() - if isinstance(source, exp.Table) and not self._is_cte(source, scope) - ] - - return { - Table( - source.name, - source.db if source.db != "" else None, - source.catalog if source.catalog != "" else None, - ) - for source in sources - } - - def _is_cte(self, source: exp.Table, scope: Scope) -> bool: - """ - Is the source a CTE? + for statement in self._parsed: + self._extract_from_token(statement) - CTEs in the parent scope look like tables (and are represented by - exp.Table objects), but should not be considered as such; - otherwise a user with access to table `foo` could access any table - with a query like this: - - WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo - - """ - parent_sources = scope.parent.sources if scope.parent else {} - ctes_in_scope = { - name - for name, parent_scope in parent_sources.items() - if isinstance(parent_scope, Scope) - and parent_scope.scope_type == ScopeType.CTE - } - - return source.name in ctes_in_scope + self._tables = { + table for table in self._tables if str(table) not in self._alias_names + } + return self._tables @property def limit(self) -> Optional[int]: @@ -543,6 +393,28 @@ class ParsedQuery: def _is_identifier(token: Token) -> bool: return isinstance(token, (IdentifierList, Identifier)) + def _process_tokenlist(self, token_list: TokenList) -> None: + """ + Add table names to table set + + :param token_list: TokenList to be processed + """ + # exclude subselects + if "(" not in str(token_list): + table = self.get_table(token_list) + if table and not table.table.startswith(CTE_PREFIX): + self._tables.add(table) + return + + # store aliases + if token_list.has_alias(): + self._alias_names.add(token_list.get_alias()) + + # some aliases are not parsed properly + if token_list.tokens[0].ttype == Name: + self._alias_names.add(token_list.tokens[0].value) + self._extract_from_token(token_list) + def as_create_table( self, table_name: str, @@ -569,6 +441,50 @@ class ParsedQuery: exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}" return exec_sql + def _extract_from_token(self, token: Token) -> None: + """ + <Identifier> store a list of subtokens and <IdentifierList> store lists of + subtoken list. + + It extracts <IdentifierList> and <Identifier> from :param token: and loops + through all subtokens recursively. It finds table_name_preceding_token and + passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate + self._tables. + + :param token: instance of Token or child class, e.g. TokenList, to be processed + """ + if not hasattr(token, "tokens"): + return + + table_name_preceding_token = False + + for item in token.tokens: + if item.is_group and ( + not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis) + ): + self._extract_from_token(item) + + if item.ttype in Keyword and ( + item.normalized in PRECEDES_TABLE_NAME + or item.normalized.endswith(" JOIN") + ): + table_name_preceding_token = True + continue + + if item.ttype in Keyword: + table_name_preceding_token = False + continue + if table_name_preceding_token: + if isinstance(item, Identifier): + self._process_tokenlist(item) + elif isinstance(item, IdentifierList): + for token2 in item.get_identifiers(): + if isinstance(token2, TokenList): + self._process_tokenlist(token2) + elif isinstance(item, IdentifierList): + if any(not self._is_identifier(token2) for token2 in item.tokens): + self._extract_from_token(item) + def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str: """Returns the query with the specified limit. @@ -965,7 +881,7 @@ def insert_rls_in_predicate( # mapping between sqloxide and SQLAlchemy dialects -SQLOXIDE_DIALECTS = { +SQLOXITE_DIALECTS = { "ansi": {"trino", "trinonative", "presto"}, "hive": {"hive", "databricks"}, "ms": {"mssql"}, @@ -998,7 +914,7 @@ def extract_table_references( tree = None if sqloxide_parse: - for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items(): + for dialect, sqla_dialects in SQLOXITE_DIALECTS.items(): if sqla_dialect in sqla_dialects: break sql_text = RE_JINJA_BLOCK.sub(" ", sql_text) diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index fed1ff3bfa..c01b938671 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): ) -> Optional[SQLValidationAnnotation]: # pylint: disable=too-many-locals db_engine_spec = database.db_engine_spec - parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine) + parsed_query = ParsedQuery(statement) sql = parsed_query.stripped() # Hook to allow environment-specific mutation (usually comments) to the SQL @@ -154,7 +154,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable. """ - parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine) + parsed_query = ParsedQuery(sql) statements = parsed_query.get_statements() logger.info("Validating %i statement(s)", len(statements)) diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py index 5597bcb086..f4c1c26c6e 100644 --- a/superset/sqllab/query_render.py +++ b/superset/sqllab/query_render.py @@ -58,11 +58,7 @@ class SqlQueryRenderImpl(SqlQueryRender): database=query_model.database, query=query_model ) - parsed_query = ParsedQuery( - query_model.sql, - strip_comments=True, - engine=query_model.database.db_engine_spec.engine, - ) + parsed_query = ParsedQuery(query_model.sql, strip_comments=True) rendered_query = sql_template_processor.process_template( parsed_query.stripped(), **execution_context.template_params ) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index f650b77734..efd8838101 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -40,11 +40,11 @@ from superset.sql_parse import ( ) -def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]: +def extract_tables(query: str) -> set[Table]: """ Helper function to extract tables referenced in a query. """ - return ParsedQuery(query, engine=engine).tables + return ParsedQuery(query).tables def test_table() -> None: @@ -96,13 +96,8 @@ def test_extract_tables() -> None: Table("left_table") } - assert extract_tables( - "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;" - ) == {Table("forbidden_table")} - - assert extract_tables( - "select * from (select * from forbidden_table) forbidden_table" - ) == {Table("forbidden_table")} + # reverse select + assert extract_tables("FROM t1 SELECT field") == {Table("t1")} def test_extract_tables_subselect() -> None: @@ -268,16 +263,14 @@ def test_extract_tables_illdefined() -> None: assert extract_tables("SELECT * FROM schemaname.") == set() assert extract_tables("SELECT * FROM catalogname.schemaname.") == set() assert extract_tables("SELECT * FROM catalogname..") == set() - assert extract_tables("SELECT * FROM catalogname..tbname") == { - Table(table="tbname", schema=None, catalog="catalogname") - } + assert extract_tables("SELECT * FROM catalogname..tbname") == set() def test_extract_tables_show_tables_from() -> None: """ Test ``SHOW TABLES FROM``. """ - assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set() + assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set() def test_extract_tables_show_columns_from() -> None: @@ -318,7 +311,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2) """ SELECT name FROM t1 -WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey); +WHERE regionkey EXISTS (SELECT regionkey FROM t2) """ ) == {Table("t1"), Table("t2")} @@ -533,18 +526,6 @@ select * from (select key from q1) a == {Table("src")} ) - # weird query with circular dependency - assert ( - extract_tables( - """ -with src as ( select key from q2 where key = '5'), -q2 as ( select key from src where key = '5') -select * from (select key from src) a -""" - ) - == set() - ) - def test_extract_tables_multistatement() -> None: """ @@ -684,8 +665,7 @@ def test_extract_tables_nested_select() -> None: select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME) from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA like "%bi%"),0x7e))); -""", - "mysql", +""" ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} ) @@ -696,8 +676,7 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e))); select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME) from INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME="bi_achievement_daily"),0x7e))); -""", - "mysql", +""" ) == {Table("COLUMNS", "INFORMATION_SCHEMA")} ) @@ -1327,14 +1306,6 @@ def test_sqlparse_issue_652(): "(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)", True, ), - ( - "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;", - True, - ), - ( - "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table", - True, - ), ], ) def test_has_table_query(sql: str, expected: bool) -> None: @@ -1819,17 +1790,13 @@ def test_extract_table_references(mocker: MockerFixture) -> None: assert extract_table_references( sql, "trino", - ) == { - Table(table="table", schema=None, catalog=None), - Table(table="other_table", schema=None, catalog=None), - } + ) == {Table(table="other_table", schema=None, catalog=None)} logger.warning.assert_called_once() logger = mocker.patch("superset.migrations.shared.utils.logger") sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table" assert extract_table_references(sql, "trino", show_warning=False) == { - Table(table="table", schema=None, catalog=None), - Table(table="other_table", schema=None, catalog=None), + Table(table="other_table", schema=None, catalog=None) } logger.warning.assert_not_called()
