This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch rls-sqlglot
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 74745b5ba656876482ccb19bea3733d73a960e5c
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed May 22 22:48:42 2024 -0700

    WIP
---
 superset/connectors/sqla/models.py   |  49 +----
 superset/db_engine_specs/mssql.py    |   2 +-
 superset/db_engine_specs/teradata.py |   2 +-
 superset/models/helpers.py           |   6 +-
 superset/sql_lab.py                  |  70 ++-----
 superset/sql_parse.py                | 293 ++++++++++++++++++++++++++-
 tests/unit_tests/sql_parse_tests.py  | 379 +++++++++++++++++++++++++++--------
 7 files changed, 603 insertions(+), 198 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8c74dfd589..6dd0d76b69 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -22,7 +22,6 @@ import dataclasses
 import json
 import logging
 import re
-from collections import defaultdict
 from collections.abc import Hashable
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
@@ -72,7 +71,7 @@ from sqlalchemy.sql.elements import ColumnClause, TextClause
 from sqlalchemy.sql.expression import Label, TextAsFrom
 from sqlalchemy.sql.selectable import Alias, TableClause
 
-from superset import app, db, is_feature_enabled, security_manager
+from superset import app, db, security_manager
 from superset.commands.dataset.exceptions import DatasetNotFoundError
 from superset.common.db_query_status import QueryStatus
 from superset.connectors.sqla.utils import (
@@ -175,7 +174,9 @@ class DatasourceKind(StrEnum):
     PHYSICAL = "physical"
 
 
-class BaseDatasource(AuditMixinNullable, ImportExportMixin):  # pylint: disable=too-many-public-methods
+class BaseDatasource(
+    AuditMixinNullable, ImportExportMixin
+):  # pylint: disable=too-many-public-methods
     """A common interface to objects that are queryable
     (tables and datasources)"""
 
@@ -1605,48 +1606,6 @@ class SqlaTable(
             if is_alias_used_in_orderby(col):
                 col.name = f"{col.name}__"
 
-    def get_sqla_row_level_filters(
-        self,
-        template_processor: BaseTemplateProcessor,
-    ) -> list[TextClause]:
-        """
-        Return the appropriate row level security filters for this table and the
-        current user. A custom username can be passed when the user is not present in the
-        Flask global namespace.
-
-        :param template_processor: The template processor to apply to the filters.
-        :returns: A list of SQL clauses to be ANDed together.
-        """
-        all_filters: list[TextClause] = []
-        filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
-        try:
-            for filter_ in security_manager.get_rls_filters(self):
-                clause = self.text(
-                    f"({template_processor.process_template(filter_.clause)})"
-                )
-                if filter_.group_key:
-                    filter_groups[filter_.group_key].append(clause)
-                else:
-                    all_filters.append(clause)
-
-            if is_feature_enabled("EMBEDDED_SUPERSET"):
-                for rule in security_manager.get_guest_rls_filters(self):
-                    clause = self.text(
-                        f"({template_processor.process_template(rule['clause'])})"
-                    )
-                    all_filters.append(clause)
-
-            grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
-            all_filters.extend(grouped_filters)
-            return all_filters
-        except TemplateError as ex:
-            raise QueryObjectValidationError(
-                _(
-                    "Error in jinja expression in RLS filters: %(msg)s",
-                    msg=ex.message,
-                )
-            ) from ex
-
     def text(self, clause: str) -> TextClause:
         return self.db_engine_spec.get_text_clause(clause)
 
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index d5cc86c859..3521deeb51 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -49,7 +49,7 @@ CONNECTION_HOST_DOWN_REGEX = re.compile(
 class MssqlEngineSpec(BaseEngineSpec):
     engine = "mssql"
     engine_name = "Microsoft SQL Server"
-    limit_method = LimitMethod.WRAP_SQL
+    limit_method = LimitMethod.FORCE_LIMIT
     max_column_name_length = 128
     allows_cte_in_subquery = False
     allow_limit_clause = False
diff --git a/superset/db_engine_specs/teradata.py b/superset/db_engine_specs/teradata.py
index 887add24e9..910ac9461d 100644
--- a/superset/db_engine_specs/teradata.py
+++ b/superset/db_engine_specs/teradata.py
@@ -23,7 +23,7 @@ class TeradataEngineSpec(BaseEngineSpec):
 
     engine = "teradatasql"
     engine_name = "Teradata"
-    limit_method = LimitMethod.WRAP_SQL
+    limit_method = LimitMethod.FORCE_LIMIT
     max_column_name_length = 30  # since 14.10 this is 128
     allow_limit_clause = False
     select_keywords = {"SELECT", "SEL"}
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 5845a400c2..790ed7e76c 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -17,6 +17,8 @@
 # pylint: disable=too-many-lines
 """a collection of model-related helper classes and functions"""
 
+from __future__ import annotations
+
 import builtins
 import dataclasses
 import json
@@ -807,7 +809,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
 
     def get_sqla_row_level_filters(
         self,
-        template_processor: BaseTemplateProcessor,
+        template_processor: BaseTemplateProcessor | None,
     ) -> list[TextClause]:
         """
         Return the appropriate row level security filters for this table and the
@@ -817,6 +819,8 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         :param template_processor: The template processor to apply to the filters.
         :returns: A list of SQL clauses to be ANDed together.
         """
+        template_processor = template_processor or self.get_template_processor()
+
         all_filters: list[TextClause] = []
         filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
         try:
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 3f8c1cc737..4c67e84ae9 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -52,9 +52,8 @@ from superset.models.sql_lab import Query
 from superset.result_set import SupersetResultSet
 from superset.sql_parse import (
     CtasMethod,
-    insert_rls_as_subquery,
-    insert_rls_in_predicate,
     ParsedQuery,
+    SQLStatement,
     Table,
 )
 from superset.sqllab.limiting_factor import LimitingFactor
@@ -203,67 +202,49 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
     database: Database = query.database
     db_engine_spec = database.db_engine_spec
 
-    parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
-    if is_feature_enabled("RLS_IN_SQLLAB"):
-        # There are two ways to insert RLS: either replacing the table with a subquery
-        # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
-        # safer, but not supported in all databases.
-        insert_rls = (
-            insert_rls_as_subquery
-            if database.db_engine_spec.allows_subqueries
-            and database.db_engine_spec.allows_alias_in_select
-            else insert_rls_in_predicate
-        )
+    parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
 
-        # Insert any applicable RLS predicates
-        parsed_query = ParsedQuery(
-            str(
-                insert_rls(
-                    parsed_query._parsed[0],  # pylint: disable=protected-access
-                    database.id,
-                    query.schema,
-                )
-            ),
-            engine=db_engine_spec.engine,
-        )
-
-    sql = parsed_query.stripped()
-
-    # This is a test to see if the query is being
-    # limited by either the dropdown or the sql.
-    # We are testing to see if more rows exist than the limit.
-    increased_limit = None if query.limit is None else query.limit + 1
+    if is_feature_enabled("RLS_IN_SQLLAB"):
+        default_schema = database.get_default_schema_for_query(query)
+        parsed_statement = parsed_statement.apply_rls(query.catalog, default_schema)
 
-    if not db_engine_spec.is_readonly_query(parsed_query) and not database.allow_dml:
+    if parsed_statement.is_dml() and not database.allow_dml:
         raise SupersetErrorException(
             SupersetError(
-                message=__("Only SELECT statements are allowed against this database."),
+                message=__("Only read-only statements are allowed in this database."),
                 error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
                 level=ErrorLevel.ERROR,
             )
         )
+
     if apply_ctas:
         if not query.tmp_table_name:
             start_dttm = datetime.fromtimestamp(query.start_time)
             query.tmp_table_name = (
                 f'tmp_{query.user_id}_table_{start_dttm.strftime("%Y_%m_%d_%H_%M_%S")}'
             )
-        sql = parsed_query.as_create_table(
-            query.tmp_table_name,
-            schema_name=query.tmp_schema_name,
+        parsed_statement = parsed_statement.as_create_table(
+            Table(query.tmp_table_name, query.tmp_schema_name, query.catalog),
             method=query.ctas_method,
         )
         query.select_as_cta_used = True
 
+    increased_limit = None if query.limit is None else query.limit + 1
+
     # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true
-    if db_engine_spec.is_select_query(parsed_query) and not (
+    if parsed_statement.is_select() and not (
         query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT
     ):
         if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW):
             query.limit = SQL_MAX_ROW
-        sql = apply_limit_if_exists(database, increased_limit, query, sql)
+
+        if query.limit:
+            # Increase limit by one so we can test if there are more rows when the
+            # database returns exactly the number of rows requested by the user.
+            parsed_statement = parsed_statement.apply_limit(increased_limit)
 
     # Hook to allow environment-specific mutation (usually comments) to the SQL
+    sql = parsed_statement.format(strip=True)
     sql = database.mutate_sql_based_on_config(sql)
     try:
         query.executed_sql = sql
@@ -331,19 +312,6 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
     return SupersetResultSet(data, cursor_description, db_engine_spec)
 
 
-def apply_limit_if_exists(
-    database: Database, increased_limit: Optional[int], query: Query, sql: str
-) -> str:
-    if query.limit and increased_limit:
-        # We are fetching one more than the requested limit in order
-        # to test whether there are more rows than the limit. According to the DB
-        # Engine support it will choose top or limit parse
-        # Later, the extra row will be dropped before sending
-        # the results back to the user.
-        sql = database.apply_limit_to_sql(sql, increased_limit, force=True)
-    return sql
-
-
 def _serialize_payload(
     payload: dict[Any, Any], use_msgpack: Optional[bool] = False
 ) -> Union[bytes, str]:
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index f32647042b..24c279c3d1 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -32,7 +32,7 @@ import sqlparse
 from flask_babel import gettext as __
 from jinja2 import nodes
 from sqlalchemy import and_
-from sqlglot import exp, parse, parse_one
+from sqlglot import exp
 from sqlglot.dialects.dialect import Dialect, Dialects
 from sqlglot.errors import ParseError, SqlglotError
 from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -294,7 +294,7 @@ def extract_tables_from_statement(
             return set()
 
         try:
-            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=dialect)
+            pseudo_query = sqlglot.parse_one(f"SELECT {literal.this}", dialect=dialect)
         except ParseError:
             return set()
         sources = pseudo_query.find_all(exp.Table)
@@ -419,7 +419,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
-    def format(self, comments: bool = True) -> str:
+    def format(self, comments: bool = True, strip: bool = False) -> str:
         """
         Format the statement, optionally ommitting comments.
         """
@@ -437,10 +437,93 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
+    def apply_rls(
+        self,
+        catalog: str | None,
+        schema: str | None,
+    ) -> InternalRepresentation:
+        """
+        Apply Row Level Security to the SQL.
+
+        :param database: The database where the SQL will run
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :return: The SQL with RLS applied
+        """
+        raise NotImplementedError()
+
     def __str__(self) -> str:
         return self.format()
 
 
+class RLSAsPredicate:
+    """
+    Apply a Row Level Security rule as a predicate.
+
+    This transformer will apply any RLS predicates to the relevant tables. For example,
+    given the RLS rule:
+
+        table: some_table
+        clause: id = 42
+
+    If a user subject to the rule runs the following query:
+
+        SELECT foo FROM some_table WHERE bar = 'baz'
+
+    The query will be modified to:
+
+        SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
+
+    This approach is probably less secure than using subqueries, so it's only used for
+    databases without support for subqueries.
+    """
+
+    def __init__(self, rules: dict[Table, str]) -> None:
+        self.rules = rules
+
+    def __call__(self, node: exp.Expression) -> exp.Expression:
+        if not isinstance(node, exp.Select):
+            return node
+
+        table_node = node.find(exp.Table)
+        if not table_node:
+            return node
+
+        table = Table(
+            str(table_node.this),
+            str(table_node.db) if table_node.db else None,
+            str(table_node.catalog) if table_node.catalog else None,
+        )
+        if predicate := self.rules.get(table):
+            if where := node.args.get("where"):
+                predicate = exp.And(this=predicate, expression=where.this)
+
+            node.set("where", exp.Where(this=predicate))
+
+        return node
+
+
+class RLSAsSubquery:
+    def __init__(self, rules: dict[Table, str]) -> None:
+        self.rules = rules
+
+    def __call__(self, node: exp.Expression) -> exp.Expression:
+        if not isinstance(node, exp.Table):
+            return node
+
+        table = Table(
+            str(node.this),
+            str(node.db) if node.db else None,
+            str(node.catalog) if node.catalog else None,
+        )
+        if predicate := self.rules.get(table):
+            alias = node.alias
+            node.set("alias", None)
+            return f"(SELECT * FROM {node} WHERE {predicate}) AS {alias}"
+
+        return node
+
+
 class SQLStatement(BaseSQLStatement[exp.Expression]):
     """
     A SQL statement.
@@ -507,12 +590,19 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         dialect = SQLGLOT_DIALECTS.get(engine)
         return extract_tables_from_statement(parsed, dialect)
 
-    def format(self, comments: bool = True) -> str:
+    def format(self, comments: bool = True, strip: bool = False) -> str:
         """
         Pretty-format the SQL statement.
         """
         write = Dialect.get_or_raise(self._dialect)
-        return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
+        output = write.generate(
+            self._parsed,
+            copy=False,
+            comments=comments,
+            pretty=True,
+        )
+
+        return output.strip(" \t\r\n;") if strip else output
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -529,6 +619,186 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
             for eq in set_item.find_all(exp.EQ)
         }
 
+    def apply_rls(
+        self,
+        catalog: str | None,
+        schema: str | None,
+    ) -> SQLStatement:
+        """
+        Apply Row Level Security to the SQL.
+
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :return: The SQL with RLS applied
+        """
+        from superset.db_engine_specs import load_engine_specs
+
+        statement = self._parsed.copy()
+
+        # collect all relevant RLS rules
+        rules = {}
+        for table in self.tables:
+            if rls := self._get_rls_for_table(table, catalog, schema):
+                rules[table] = rls
+
+        if not rules:
+            return statement
+
+        use_subquery = all(
+            engine_spec.allows_subqueries
+            for engine_spec in load_engine_specs()
+            if engine_spec.engine == self.engine
+        )
+        transformer = RLSAsSubquery(rules) if use_subquery else RLSAsPredicate(rules)
+
+        return SQLStatement(statement.transform(transformer), self.engine)
+
+    def _get_rls_for_table(
+        self,
+        database: Database,
+        table: Table,
+        catalog: str | None,
+        schema: str | None,
+    ) -> exp.Expression | None:
+        """
+        Get the RLS for a table.
+
+        :param table: The table to get the RLS for
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :return: The RLS for the table
+        """
+        # pylint: disable=import-outside-toplevel
+        from superset import db
+        from superset.connectors.sqla.models import SqlaTable
+
+        dataset = db.session.query(SqlaTable).filter(
+            and_(
+                SqlaTable.database_id == database.id,
+                SqlaTable.catalog == table.catalog or catalog,
+                SqlaTable.schema == table.schema or schema,
+                SqlaTable.table_name == table.table,
+            ).one_or_none()
+        )
+        if not dataset:
+            return None
+
+        filters = dataset.get_sqla_row_level_filters()
+        if not filters:
+            return None
+
+        rls = and_(*filters).compile(
+            dialect=database.get_dialect(),
+            compile_kwargs={"literal_binds": True},
+        )
+
+        return sqlglot.parse_one(str(rls), dialect=self._dialect)
+
+    def is_dml(self) -> bool:
+        """
+        Check if the statement is DML.
+
+        :return: True if the statement is DML
+        """
+        for node in self._parsed.walk():
+            if isinstance(
+                node,
+                (
+                    exp.Insert,
+                    exp.Update,
+                    exp.Delete,
+                    exp.Merge,
+                    exp.Create,
+                    exp.Alter,
+                    exp.Drop,
+                    exp.TruncateTable,
+                ),
+            ):
+                return True
+
+        return False
+
+    def as_create_table(self, table: Table, method: CtasMethod) -> SQLStatement:
+        """
+        Convert the statement to a CREATE TABLE statement.
+        """
+        create_table = exp.Create(
+            this=sqlglot.parse_one(table, into=exp.Table),
+            kind=method.value,
+            expression=self._parsed.copy(),
+        )
+
+        return SQLStatement(create_table, self.engine)
+
+    def is_select(self) -> bool:
+        """
+        Check if the statement is a SELECT statement.
+
+        :return: True if the statement is a SELECT statement
+        """
+        return isinstance(self._parsed, exp.Select)
+
+    def apply_limit(self, limit: int, force: bool = False) -> SQLStatement:
+        """
+        Apply a limit to the SQL.
+
+        There are 3 strategies to limit queries, defined in the DB engine spec:
+
+            1. `FORCE_LIMIT`: a limit is added to the query, or the existing one is
+                replaced. This is the most efficient, since the database will produce at
+                most the number of rows that Superset will display.
+            2. `WRAP_SQL`: the query is wrapped in a subquery, and the limit is applied
+                to the outer query. This might be inefficient, since the database
+                optimizer might not be able to push the limit down to the inner query.
+            3. `FETCH_MANY`: no limit is applied, but only `LIMIT` rows are fetched from
+                the database. This is the least efficient, unless the database computes
+                rows as they are read by the cursor, which is unlikely.
+
+        :param limit: The limit to apply
+        :param force: Apply limit even when a lower one is present
+        :return: The SQL with the limit applied
+        """
+        from superset.db_engine_specs import load_engine_specs
+        from superset.db_engine_specs.base import LimitMethod
+
+        methods = {
+            engine_spec.limit_method
+            for engine_spec in load_engine_specs()
+            if engine_spec.engine == self.engine
+        }
+        if not methods:
+            methods = {LimitMethod.FETCH_MANY}
+
+        # When multiple methods are supported, we prefer the more generic one --
+        # usually less efficient.
+        preference = [
+            LimitMethod.FETCH_MANY,
+            LimitMethod.WRAP_SQL,
+            LimitMethod.FORCE_LIMIT,
+        ]
+        method = sorted(methods, key=preference.index)[0]
+
+        if not self.is_select() or method == LimitMethod.FETCH_MANY:
+            return SQLStatement(self._parsed.copy(), self.engine)
+
+        if method == LimitMethod.WRAP_SQL:
+            limited = exp.Select(
+                expressions=[exp.Star()],
+                from_=exp.Subquery(subquery=self._parsed.copy(), alias="inner_qry"),
+                limit=exp.Literal.number(limit),
+            )
+            return SQLStatement(limited, self.engine)
+
+        current_limit: int | None = None
+        for node in self._parsed.find_all(exp.Limit):
+            current_limit = int(node.expression.this)
+            break
+
+        if force or current_limit is None or limit < current_limit:
+            return SQLStatement(self._parsed.limit(limit), self.engine)
+
+        return SQLStatement(self._parsed.copy(), self.engine)
+
 
 class KQLSplitState(enum.Enum):
     """
@@ -652,11 +922,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
         )
         return set()
 
-    def format(self, comments: bool = True) -> str:
+    def format(self, comments: bool = True, strip: bool = False) -> str:
         """
         Pretty-format the SQL statement.
         """
-        return self._parsed
+        return self._parsed.strip(" \t\r\n;") if strip else self._parsed
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -698,7 +968,10 @@ class SQLScript:
         """
         Pretty-format the SQL query.
         """
-        return ";\n".join(statement.format(comments) for statement in self.statements)
+        return (
+            ";\n".join(statement.format(comments) for statement in self.statements)
+            + ";"
+        )
 
     def get_settings(self) -> dict[str, str | bool]:
         """
@@ -750,7 +1023,7 @@ class ParsedQuery:
         Note: this uses sqlglot, since it's better at catching more edge cases.
         """
         try:
-            statements = parse(self.stripped(), dialect=self._dialect)
+            statements = sqlglot.parse(self.stripped(), dialect=self._dialect)
         except SqlglotError as ex:
             logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
 
@@ -804,7 +1077,7 @@ class ParsedQuery:
                 return set()
 
             try:
-                pseudo_query = parse_one(
+                pseudo_query = sqlglot.parse_one(
                     f"SELECT {literal.this}",
                     dialect=self._dialect,
                 )
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 3b80b7e01d..a8e655773d 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -20,6 +20,7 @@ from typing import Optional
 from unittest.mock import Mock
 
 import pytest
+import sqlglot
 import sqlparse
 from pytest_mock import MockerFixture
 from sqlalchemy import text
@@ -40,6 +41,8 @@ from superset.sql_parse import (
     insert_rls_in_predicate,
     KustoKQLStatement,
     ParsedQuery,
+    RLSAsPredicate,
+    RLSAsSubquery,
     sanitize_clause,
     split_kql,
     SQLScript,
@@ -118,8 +121,9 @@ def test_extract_tables_subselect() -> None:
     """
     Test that tables inside subselects are parsed correctly.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT sub.*
 FROM (
     SELECT *
@@ -128,10 +132,13 @@ FROM (
     ) sub, s2.t2
 WHERE sub.resolution = 'NONE'
 """
-    ) == {Table("t1", "s1"), Table("t2", "s2")}
+        )
+        == {Table("t1", "s1"), Table("t2", "s2")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT sub.*
 FROM (
     SELECT *
@@ -140,10 +147,13 @@ FROM (
 ) sub
 WHERE sub.resolution = 'NONE'
 """
-    ) == {Table("t1", "s1")}
+        )
+        == {Table("t1", "s1")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT * FROM t1
 WHERE s11 > ANY (
     SELECT COUNT(*) /* no hint */ FROM t2
@@ -155,7 +165,9 @@ WHERE s11 > ANY (
     )
 )
 """
-    ) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
+        )
+        == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
+    )
 
 
 def test_extract_tables_select_in_expression() -> None:
@@ -226,24 +238,30 @@ def test_extract_tables_select_array() -> None:
     """
     Test that queries selecting arrays work as expected.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT ARRAY[1, 2, 3] AS my_array
 FROM t1 LIMIT 10
 """
-    ) == {Table("t1")}
+        )
+        == {Table("t1")}
+    )
 
 
 def test_extract_tables_select_if() -> None:
     """
     Test that queries with an ``IF`` work as expected.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
 FROM t1 LIMIT 10
 """
-    ) == {Table("t1")}
+        )
+        == {Table("t1")}
+    )
 
 
 def test_extract_tables_with_catalog() -> None:
@@ -311,29 +329,38 @@ def test_extract_tables_where_subquery() -> None:
     """
     Test that tables in a ``WHERE`` subquery are parsed correctly.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT name
 FROM t1
 WHERE regionkey = (SELECT max(regionkey) FROM t2)
 """
-    ) == {Table("t1"), Table("t2")}
+        )
+        == {Table("t1"), Table("t2")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT name
 FROM t1
 WHERE regionkey IN (SELECT regionkey FROM t2)
 """
-    ) == {Table("t1"), Table("t2")}
+        )
+        == {Table("t1"), Table("t2")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT name
 FROM t1
 WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
 """
-    ) == {Table("t1"), Table("t2")}
+        )
+        == {Table("t1"), Table("t2")}
+    )
 
 
 def test_extract_tables_describe() -> None:
@@ -347,12 +374,15 @@ def test_extract_tables_show_partitions() -> None:
     """
     Test ``SHOW PARTITIONS``.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SHOW PARTITIONS FROM orders
 WHERE ds >= '2013-01-01' ORDER BY ds DESC
 """
-    ) == {Table("orders")}
+        )
+        == {Table("orders")}
+    )
 
 
 def test_extract_tables_join() -> None:
@@ -364,8 +394,9 @@ def test_extract_tables_join() -> None:
         Table("t2"),
     }
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT a.date, b.name
 FROM left_table a
 JOIN (
@@ -376,10 +407,13 @@ JOIN (
 ) b
 ON a.date = b.date
 """
-    ) == {Table("left_table"), Table("right_table")}
+        )
+        == {Table("left_table"), Table("right_table")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT a.date, b.name
 FROM left_table a
 LEFT INNER JOIN (
@@ -390,10 +424,13 @@ LEFT INNER JOIN (
 ) b
 ON a.date = b.date
 """
-    ) == {Table("left_table"), Table("right_table")}
+        )
+        == {Table("left_table"), Table("right_table")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT a.date, b.name
 FROM left_table a
 RIGHT OUTER JOIN (
@@ -404,10 +441,13 @@ RIGHT OUTER JOIN (
 ) b
 ON a.date = b.date
 """
-    ) == {Table("left_table"), Table("right_table")}
+        )
+        == {Table("left_table"), Table("right_table")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT a.date, b.name
 FROM left_table a
 FULL OUTER JOIN (
@@ -418,15 +458,18 @@ FULL OUTER JOIN (
 ) b
 ON a.date = b.date
 """
-    ) == {Table("left_table"), Table("right_table")}
+        )
+        == {Table("left_table"), Table("right_table")}
+    )
 
 
 def test_extract_tables_semi_join() -> None:
     """
     Test ``LEFT SEMI JOIN``.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT a.date, b.name
 FROM left_table a
 LEFT SEMI JOIN (
@@ -437,15 +480,18 @@ LEFT SEMI JOIN (
 ) b
 ON a.data = b.date
 """
-    ) == {Table("left_table"), Table("right_table")}
+        )
+        == {Table("left_table"), Table("right_table")}
+    )
 
 
 def test_extract_tables_combinations() -> None:
     """
     Test a complex case with nested queries.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT * FROM t1
 WHERE s11 > ANY (
     SELECT * FROM t1 UNION ALL SELECT * FROM (
@@ -459,10 +505,13 @@ WHERE s11 > ANY (
     )
 )
 """
-    ) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
+        )
+        == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT * FROM (
     SELECT * FROM (
         SELECT * FROM (
@@ -471,45 +520,56 @@ SELECT * FROM (
     ) AS S2
 ) AS S3
 """
-    ) == {Table("EmployeeS")}
+        )
+        == {Table("EmployeeS")}
+    )
 
 
 def test_extract_tables_with() -> None:
     """
     Test ``WITH``.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 WITH
     x AS (SELECT a FROM t1),
     y AS (SELECT a AS b FROM t2),
     z AS (SELECT b AS c FROM t3)
 SELECT c FROM z
 """
-    ) == {Table("t1"), Table("t2"), Table("t3")}
+        )
+        == {Table("t1"), Table("t2"), Table("t3")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 WITH
     x AS (SELECT a FROM t1),
     y AS (SELECT a AS b FROM x),
     z AS (SELECT b AS c FROM y)
 SELECT c FROM z
 """
-    ) == {Table("t1")}
+        )
+        == {Table("t1")}
+    )
 
 
 def test_extract_tables_reusing_aliases() -> None:
     """
     Test that the parser follows aliases.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 with q1 as ( select key from q2 where key = '5'),
 q2 as ( select key from src where key = '5')
 select * from (select key from q1) a
 """
-    ) == {Table("src")}
+        )
+        == {Table("src")}
+    )
 
     # weird query with circular dependency
     assert (
@@ -546,8 +606,9 @@ def test_extract_tables_complex() -> None:
     """
     Test a few complex queries.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT sum(m_examples) AS "sum__m_example"
 FROM (
     SELECT
@@ -568,23 +629,29 @@ FROM (
 ORDER BY "sum__m_example" DESC
 LIMIT 10;
 """
-    ) == {
-        Table("my_l_table"),
-        Table("my_b_table"),
-        Table("my_t_table"),
-        Table("inner_table"),
-    }
+        )
+        == {
+            Table("my_l_table"),
+            Table("my_b_table"),
+            Table("my_t_table"),
+            Table("inner_table"),
+        }
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT *
 FROM table_a AS a, table_b AS b, table_c as c
 WHERE a.id = b.id and b.id = c.id
 """
-    ) == {Table("table_a"), Table("table_b"), Table("table_c")}
+        )
+        == {Table("table_a"), Table("table_b"), Table("table_c")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT somecol AS somecol
 FROM (
     WITH bla AS (
@@ -628,51 +695,63 @@ FROM (
     LIMIT 50000
 )
 """
-    ) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
+        )
+        == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
+    )
 
 
 def test_extract_tables_mixed_from_clause() -> None:
     """
     Test that the parser handles a ``FROM`` clause with table and subselect.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 SELECT *
 FROM table_a AS a, (select * from table_b) AS b, table_c as c
 WHERE a.id = b.id and b.id = c.id
 """
-    ) == {Table("table_a"), Table("table_b"), Table("table_c")}
+        )
+        == {Table("table_a"), Table("table_b"), Table("table_c")}
+    )
 
 
 def test_extract_tables_nested_select() -> None:
     """
     Test that the parser handles selects inside functions.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
 """,
-        "mysql",
-    ) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+            "mysql",
+        )
+        == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+    )
 
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
 """,
-        "mysql",
-    ) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+            "mysql",
+        )
+        == {Table("COLUMNS", "INFORMATION_SCHEMA")}
+    )
 
 
 def test_extract_tables_complex_cte_with_prefix() -> None:
     """
     Test that the parser handles CTEs with prefixes.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
 AS (
     SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
@@ -684,21 +763,26 @@ FROM CTE__test
 GROUP BY SalesYear, SalesPersonID
 ORDER BY SalesPersonID, SalesYear;
 """
-    ) == {Table("SalesOrderHeader")}
+        )
+        == {Table("SalesOrderHeader")}
+    )
 
 
 def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
     """
     Test that aliases that are keywords are parsed correctly.
     """
-    assert extract_tables(
-        """
+    assert (
+        extract_tables(
+            """
 WITH
     f AS (SELECT * FROM foo),
     match AS (SELECT * FROM f)
 SELECT * FROM match
 """
-    ) == {Table("foo")}
+        )
+        == {Table("foo")}
+    )
 
 
 def test_update() -> None:
@@ -1815,7 +1899,7 @@ def test_sqlquery() -> None:
     script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
 
     assert len(script.statements) == 2
-    assert script.format() == "SELECT\n  1;\nSELECT\n  2"
+    assert script.format() == "SELECT\n  1;\nSELECT\n  2;"
     assert script.statements[0].format() == "SELECT\n  1"
 
     script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
@@ -2032,3 +2116,120 @@ on $left.Day1 == $right.Day
 | project Day1, Day2, Percentage = count_*100.0/count_1
     """,
     ]
+
+
+@pytest.mark.parametrize(
+    "sql,rules,expected",
+    [
+        (
+            "SELECT t.foo FROM some_table AS t",
+            {Table("some_table"): "id = 42"},
+            "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
+        ),
+        (
+            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
+            {Table("some_table"): "id = 42"},
+            (
+                "SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
+                "WHERE bar = 'baz'"
+            ),
+        ),
+        (
+            "SELECT t.foo FROM schema1.some_table AS t",
+            {Table("some_table", "schema1"): "id = 42"},
+            "SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
+        ),
+        (
+            "SELECT t.foo FROM schema1.some_table AS t",
+            {Table("some_table", "schema2"): "id = 42"},
+            "SELECT t.foo FROM schema1.some_table AS t",
+        ),
+        (
+            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            "SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
+        ),
+        (
+            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+            {Table("some_table", "schema1", "catalog2"): "id = 42"},
+            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+        ),
+    ],
+)
+def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
+    """
+    Test the `RLSAsSubquery` transformer.
+    """
+    statement = sqlglot.parse_one(sql)
+    transformer = RLSAsSubquery(rules)
+    assert str(statement.transform(transformer)) == expected
+
+
+@pytest.mark.parametrize(
+    "sql,rules,expected",
+    [
+        (
+            "SELECT t.foo FROM some_table AS t",
+            {Table("some_table"): "id = 42"},
+            "SELECT t.foo FROM some_table AS t WHERE id = 42",
+        ),
+        (
+            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
+            {Table("some_table"): "id = 42"},
+            "SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
+        ),
+    ],
+)
+def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
+    """
+    Test the `RLSAsPredicate` transformer.
+    """
+    statement = sqlglot.parse_one(sql)
+    transformer = RLSAsPredicate(rules)
+    assert str(statement.transform(transformer)) == expected
+
+
+@pytest.mark.parametrize(
+    "sql,engine,limit,force,expected",
+    [
+        (
+            "SELECT TOP 10 * FROM Customers",
+            "teradatasql",
+            5,
+            False,
+            "SELECT\nTOP 5\n  *\nFROM Customers",
+        ),
+        (
+            "SELECT TOP 10 * FROM Customers",
+            "teradatasql",
+            15,
+            False,
+            "SELECT\nTOP 10\n  *\nFROM Customers",
+        ),
+        (
+            "SELECT TOP 10 * FROM Customers",
+            "teradatasql",
+            15,
+            True,
+            "SELECT\nTOP 15\n  *\nFROM Customers",
+        ),
+        (
+            "SELECT TOP 10 * FROM Customers",
+            "mssql",
+            15,
+            True,
+            "SELECT\nTOP 15\n  *\nFROM Customers",
+        ),
+    ],
+)
+def test_apply_limit(
+    sql: str,
+    engine: str,
+    limit: int,
+    force: bool,
+    expected: str,
+) -> None:
+    """
+    Test the `apply_limit` function.
+    """
+    assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected


Reply via email to