This is an automated email from the ASF dual-hosted git repository.

betodealmeida pushed a commit to branch rls-splice
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 3fe2b2505f3d35009d7f7958b33da78933b24f75
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri May 8 13:33:42 2026 -0400

    feat: new splice RLSMethod
---
 superset/db_engine_specs/base.py              |  22 ++-
 superset/sql/parse.py                         | 117 +++++++++++-
 superset/sql/rls_splice.py                    | 263 ++++++++++++++++++++++++++
 superset/utils/rls.py                         |  25 ++-
 tests/unit_tests/db_engine_specs/test_base.py |  43 ++++-
 tests/unit_tests/sql/parse_tests.py           | 159 ++++++++++++++++
 6 files changed, 615 insertions(+), 14 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 4d26ca8517a..78fc2e78afa 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -557,6 +557,14 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
     # if True, database will be listed as option in the upload file form
     supports_file_upload = True
 
+    # Optional override for the RLS method used by ``get_rls_method``. When 
set,
+    # the engine spec opts into a specific strategy regardless of the
+    # ``allows_subqueries`` / ``allows_alias_in_select`` defaults. Use
+    # ``RLSMethod.AS_PREDICATE_SPLICE`` for engines whose sqlglot dialect can
+    # parse but not faithfully regenerate the SQL — splice mode rewrites the
+    # original query string instead of round-tripping through the generator.
+    rls_method: RLSMethod | None = None
+
     # Is the DB engine spec able to change the default schema? This requires 
implementing  # noqa: E501
     # a custom `adjust_engine_params` method.
     supports_dynamic_schema = False
@@ -623,10 +631,18 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         """
         Returns the RLS method to be used for this engine.
 
-        There are two ways to insert RLS: either replacing the table with a 
subquery
-        that has the RLS, or appending the RLS to the ``WHERE`` clause. The 
former is
-        safer, but not supported in all databases.
+        There are three ways to insert RLS: replacing the table with a subquery
+        that has the RLS (safest, but not supported in all databases), 
appending
+        the RLS to the ``WHERE`` clause via AST transformation, or splicing the
+        RLS into the original SQL string (preserves dialect-specific syntax 
that
+        the sqlglot generator would otherwise transpile).
+
+        Engine specs can opt into a specific strategy by setting the 
class-level
+        ``rls_method`` attribute; otherwise the choice falls back to subquery
+        when supported, and predicate otherwise.
         """
+        if cls.rls_method is not None:
+            return cls.rls_method
         return (
             RLSMethod.AS_SUBQUERY
             if cls.allows_subqueries and cls.allows_alias_in_select
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index bb3ef5e1c4b..a45943ea463 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
 
     AS_PREDICATE = enum.auto()
     AS_SUBQUERY = enum.auto()
+    AS_PREDICATE_SPLICE = enum.auto()
 
 
 class RLSTransformer:
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         statement: str | None = None,
         engine: str = "base",
         ast: InternalRepresentation | None = None,
+        source: str | None = None,
     ):
         if ast:
             self._parsed = ast
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
 
         self.engine = engine
         self.tables = self._extract_tables_from_statement(self._parsed, 
self.engine)
+        # Original SQL substring for this statement, when known. Used by the
+        # splice-mode RLS path which rewrites this string instead of 
regenerating
+        # SQL from the AST. ``None`` means the statement was constructed from 
an
+        # AST without an associated source string (splice mode falls back).
+        self._source_sql: str | None = source if source is not None else 
statement
+        # Verbatim SQL to return from ``format()``. Set by string-rewriting
+        # operations (e.g. splice-mode RLS) that produce a final SQL string and
+        # need to bypass the dialect generator. Cleared by AST-mutating methods
+        # since those invalidate this cached text.
+        self._raw_sql: str | None = None
 
     @classmethod
     def split_script(
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         statement: str | None = None,
         engine: str = "base",
         ast: exp.Expression | None = None,
+        source: str | None = None,
     ):
         self._dialect = SQLGLOT_DIALECTS.get(engine)
-        super().__init__(statement, engine, ast)
+        super().__init__(statement, engine, ast, source)
 
     @classmethod
     def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -626,10 +639,55 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         script: str,
         engine: str,
     ) -> list[SQLStatement]:
+        asts = [ast for ast in cls._parse(script, engine) if ast]
+        sources = cls._split_source(script, engine, len(asts))
         return [
-            cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) 
if ast
+            cls(ast=ast, engine=engine, source=source)
+            for ast, source in zip(asts, sources, strict=False)
         ]
 
+    @classmethod
+    def _split_source(
+        cls,
+        script: str,
+        engine: str,
+        expected_count: int,
+    ) -> list[str | None]:
+        """
+        Slice ``script`` into per-statement substrings using top-level 
semicolon
+        positions from the tokenizer. Returns a list of length 
``expected_count``;
+        any entry is ``None`` if the slicing didn't yield a usable substring.
+
+        The returned substrings preserve the original byte content of the 
script
+        for each statement — necessary for splice-mode RLS, which rewrites the
+        original SQL rather than regenerating from the AST.
+        """
+        none_result: list[str | None] = [None] * expected_count
+        dialect = SQLGLOT_DIALECTS.get(engine)
+        try:
+            tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
+        except sqlglot.errors.SqlglotError:
+            return none_result
+
+        # Top-level semicolon offsets (depth 0).
+        boundaries: list[int] = []
+        depth = 0
+        for tok in tokens:
+            if tok.token_type == sqlglot.tokens.TokenType.L_PAREN:
+                depth += 1
+            elif tok.token_type == sqlglot.tokens.TokenType.R_PAREN:
+                depth -= 1
+            elif tok.token_type == sqlglot.tokens.TokenType.SEMICOLON and 
depth == 0:
+                boundaries.append(tok.start)
+
+        starts = [0, *(b + 1 for b in boundaries)]
+        ends = [*boundaries, len(script)]
+        sources = [script[s:e].strip() for s, e in zip(starts, ends, 
strict=False)]
+        sources = [s for s in sources if s]
+        if len(sources) != expected_count:
+            return none_result
+        return list(sources)
+
     @classmethod
     def _parse_statement(
         cls,
@@ -722,7 +780,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
+
+        When a string-rewriting operation (e.g. splice-mode RLS) has cached a
+        verbatim result in ``_raw_sql``, return it as-is — the whole point of
+        those operations is to avoid the dialect generator round-trip.
         """
+        if self._raw_sql is not None:
+            return self._raw_sql
         return Dialect.get_or_raise(self._dialect).generate(
             self._parsed,
             copy=True,
@@ -808,6 +872,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         """
         Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
         """
+        # AST mutation invalidates any cached verbatim SQL (e.g. from splice).
+        self._raw_sql = None
         if method == LimitMethod.FORCE_LIMIT:
             self._parsed.args["limit"] = exp.Limit(
                 expression=exp.Literal(this=str(limit), is_string=False)
@@ -902,7 +968,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         self,
         catalog: str | None,
         schema: str | None,
-        predicates: dict[Table, list[exp.Expression]],
+        predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]],
         method: RLSMethod,
     ) -> None:
         """
@@ -910,11 +976,18 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
 
         :param catalog: The default catalog for non-qualified table names
         :param schema: The default schema for non-qualified table names
+        :param predicates: Mapping of fully qualified ``Table`` to predicates.
+            For ``AS_PREDICATE`` and ``AS_SUBQUERY`` the predicates are sqlglot
+            expressions. For ``AS_PREDICATE_SPLICE`` they are raw SQL strings.
         :param method: The method to use for applying the rules.
         """
         if not predicates:
             return
 
+        if method == RLSMethod.AS_PREDICATE_SPLICE:
+            self._apply_rls_splice(catalog, schema, predicates)
+            return
+
         transformers = {
             RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
             RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
@@ -925,6 +998,44 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         transformer = transformers[method](catalog, schema, predicates)
         self._parsed = self._parsed.transform(transformer)
 
+    def _apply_rls_splice(
+        self,
+        catalog: str | None,
+        schema: str | None,
+        predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]],
+    ) -> None:
+        """
+        Apply RLS via text splicing on the original SQL.
+
+        Requires the source SQL substring to be available. Raises 
``ValueError``
+        if it isn't — the caller must ensure the statement was constructed from
+        a source string (the standard ``SQLScript`` path does this).
+        """
+        from superset.sql.rls_splice import apply_rls_splice
+
+        if self._source_sql is None:
+            raise ValueError(
+                "Splice-mode RLS requires the source SQL string; "
+                "this SQLStatement was constructed without one."
+            )
+
+        # Splice operates on raw predicate strings; coerce expressions if 
needed.
+        string_predicates: dict[Table, list[str]] = {
+            table: [
+                pred if isinstance(pred, str) else 
pred.sql(dialect=self._dialect)
+                for pred in preds
+            ]
+            for table, preds in predicates.items()
+        }
+        spliced = apply_rls_splice(
+            self._source_sql,
+            catalog,
+            schema,
+            string_predicates,
+            dialect=self._dialect,
+        )
+        self._raw_sql = spliced
+
 
 class KQLSplitState(enum.Enum):
     """
diff --git a/superset/sql/rls_splice.py b/superset/sql/rls_splice.py
new file mode 100644
index 00000000000..0166a13fe49
--- /dev/null
+++ b/superset/sql/rls_splice.py
@@ -0,0 +1,263 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+RLS predicate injection via text splicing.
+
+Instead of round-tripping through sqlglot's generator (which transpiles
+dialect-specific functions like ``LAST_DAY`` into something else), this 
approach:
+
+  1. Parses the SQL with sqlglot — only to understand structure (scope tree).
+  2. Uses sqlglot's tokenizer to get byte-accurate positions for every token
+     in the original SQL string.
+  3. For each ``SELECT`` scope that references a table with an RLS predicate,
+     finds the exact byte offset to inject at — either the end of an existing
+     ``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING``
+     / ``LIMIT`` / the closing paren of a subquery.
+  4. Splices the predicate text directly into the original string at that
+     offset — never calling ``.sql()``, so the generator never runs.
+
+Result: everything outside the splice points is the original SQL, byte for
+byte. Dialect-specific functions, comments, and formatting are all preserved
+exactly.
+
+Known limitations:
+  - SQL that fails to parse under the chosen dialect raises a ``ParseError``.
+    A thin dialect subclass is still required for parsing — but only for
+    parsing, not generation.
+  - Predicate strings are spliced in as raw SQL. They must come from a trusted
+    source (the RLS config), not user input.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import sqlglot
+from sqlglot import exp
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.tokens import Token, TokenType
+
+if TYPE_CHECKING:
+    from superset.sql.parse import Table
+
+
+# Token types that end a WHERE clause / FROM section at the current paren 
depth,
+# indicating where a new predicate must be inserted just before.
+_CLAUSE_ENDS = {
+    TokenType.GROUP_BY,
+    TokenType.HAVING,
+    TokenType.ORDER_BY,
+    TokenType.LIMIT,
+    TokenType.FETCH,
+    TokenType.UNION,
+    TokenType.INTERSECT,
+    TokenType.EXCEPT,
+}
+
+
+def _before_whitespace(sql: str, offset: int) -> int:
+    """Back up past any whitespace immediately before *offset*."""
+    while offset > 0 and sql[offset - 1] in (" ", "\t", "\n", "\r"):
+        offset -= 1
+    return offset
+
+
+def _table_from_node(
+    node: exp.Table,
+    catalog: str | None,
+    schema: str | None,
+) -> Table:
+    """
+    Build a fully qualified ``Table`` from a sqlglot ``exp.Table`` node, 
defaulting
+    unqualified parts to the supplied catalog/schema.
+    """
+    # Imported lazily to avoid a circular import with ``superset.sql.parse``.
+    from superset.sql.parse import Table
+
+    return Table(
+        table=node.name,
+        schema=node.db if node.db else schema,
+        catalog=node.catalog if node.catalog else catalog,
+    )
+
+
+def apply_rls_splice(
+    sql: str,
+    catalog: str | None,
+    schema: str | None,
+    predicates: dict[Table, list[str]],
+    dialect: str | None = None,
+) -> str:
+    """
+    Inject RLS predicates into ``sql`` by splicing text at the right positions.
+
+    :param sql: The original SQL query. Returned unchanged except at splice 
points.
+    :param catalog: The default catalog for non-qualified table names.
+    :param schema: The default schema for non-qualified table names.
+    :param predicates: Mapping of ``Table`` to predicate SQL strings. Each 
entry
+        maps a fully qualified table to one or more raw predicate strings to
+        ``AND`` together when that table is referenced in a SELECT scope.
+    :param dialect: The sqlglot dialect used for *parsing only* — to understand
+        scope structure and locate token positions. The generator is never
+        called, so this does not affect output formatting.
+    :return: The query with RLS predicates injected into every relevant SELECT
+        scope.
+    """
+    if not predicates or not any(predicates.values()):
+        return sql
+
+    resolved_dialect = sqlglot.Dialect.get_or_raise(dialect)
+    tokens = list(resolved_dialect.tokenize(sql))
+    tree = sqlglot.parse_one(sql, dialect=dialect)
+
+    splices: list[tuple[int, str]] = []
+    for scope in traverse_scope(tree):
+        splice = _splice_for_scope(sql, tokens, scope, predicates, catalog, 
schema)
+        if splice is not None:
+            splices.append(splice)
+
+    # Apply splices in reverse offset order so earlier positions stay valid.
+    splices.sort(key=lambda item: item[0], reverse=True)
+    result = sql
+    for offset, text in splices:
+        result = result[:offset] + text + result[offset:]
+    return result
+
+
+def _splice_for_scope(
+    sql: str,
+    tokens: list[Token],
+    scope: object,
+    predicates: dict[Table, list[str]],
+    catalog: str | None,
+    schema: str | None,
+) -> tuple[int, str] | None:
+    """
+    Compute the (offset, text) splice for a single SELECT scope, or ``None`` if
+    the scope has no matching predicates or no usable anchor.
+    """
+    scope_preds = _collect_scope_predicates(scope, predicates, catalog, schema)
+    if not scope_preds:
+        return None
+
+    # Anchor: rightmost character position among the table-name identifiers
+    # directly owned by this scope. Used to skip past tokens that belong to
+    # earlier parts of the query (projections, JOIN ON clauses, etc.).
+    table_ends = [
+        ident._meta["end"]
+        for source in scope.sources.values()  # type: ignore[attr-defined]
+        if isinstance(source, exp.Table)
+        for ident in [source.find(exp.Identifier)]
+        if ident and getattr(ident, "_meta", None)
+    ]
+    if not table_ends:
+        return None
+
+    has_where = scope.expression.args.get("where") is not None  # type: 
ignore[attr-defined]
+    pred_sql = " AND ".join(scope_preds)
+    return _find_splice_point(sql, tokens, max(table_ends), has_where, 
pred_sql)
+
+
+def _collect_scope_predicates(
+    scope: object,
+    predicates: dict[Table, list[str]],
+    catalog: str | None,
+    schema: str | None,
+) -> list[str]:
+    """
+    Collect the predicates that apply to direct Table sources in ``scope``,
+    deduped while preserving order.
+    """
+    scope_preds: list[str] = []
+    for source in scope.sources.values():  # type: ignore[attr-defined]
+        if not isinstance(source, exp.Table):
+            continue
+        table = _table_from_node(source, catalog, schema)
+        for predicate in predicates.get(table, []):
+            if predicate and predicate not in scope_preds:
+                scope_preds.append(predicate)
+    return scope_preds
+
+
+def _find_splice_point(
+    sql: str,
+    tokens: list[Token],
+    anchor: int,
+    has_where: bool,
+    pred_sql: str,
+) -> tuple[int, str] | None:
+    """
+    Scan tokens forward from ``anchor``, tracking paren depth, to find where to
+    insert the RLS predicate for a single scope.
+    """
+    depth = 0
+    for i, tok in enumerate(tokens):
+        if tok.start <= anchor:
+            continue
+
+        if tok.token_type == TokenType.L_PAREN:
+            depth += 1
+            continue
+
+        if tok.token_type == TokenType.R_PAREN:
+            if depth == 0:
+                # Closing paren of our subquery — insert just before it.
+                offset = _before_whitespace(sql, tok.start)
+                text = f" AND {pred_sql}" if has_where else f" WHERE 
{pred_sql}"
+                return (offset, text)
+            depth -= 1
+            continue
+
+        if depth > 0:
+            continue
+
+        if has_where and tok.token_type == TokenType.WHERE:
+            return _find_after_where(sql, tokens, i, pred_sql)
+
+        if not has_where and tok.token_type in _CLAUSE_ENDS:
+            # Insert WHERE before this clause keyword.
+            return (_before_whitespace(sql, tok.start), f" WHERE {pred_sql}")
+
+    # No clause boundary found — append at end of SQL.
+    text = f" AND {pred_sql}" if has_where else f" WHERE {pred_sql}"
+    return (len(sql), text)
+
+
+def _find_after_where(
+    sql: str,
+    tokens: list[Token],
+    where_index: int,
+    pred_sql: str,
+) -> tuple[int, str] | None:
+    """
+    Given the index of a ``WHERE`` token in ``tokens``, find the offset just
+    after the WHERE clause body where ``AND <pred>`` should be inserted.
+    """
+    depth = 0
+    prev_end = tokens[where_index].end
+    for tok in tokens[where_index + 1 :]:
+        if tok.token_type == TokenType.L_PAREN:
+            depth += 1
+        elif tok.token_type == TokenType.R_PAREN:
+            if depth == 0:
+                return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
+            depth -= 1
+        elif depth == 0 and tok.token_type in _CLAUSE_ENDS:
+            return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
+        prev_end = tok.end
+    return (prev_end + 1, f" AND {pred_sql}")
diff --git a/superset/utils/rls.py b/superset/utils/rls.py
index 7e6cdf2aee7..456b589e365 100644
--- a/superset/utils/rls.py
+++ b/superset/utils/rls.py
@@ -22,7 +22,7 @@ from typing import Any, TYPE_CHECKING
 from sqlalchemy import and_, or_
 
 from superset import db
-from superset.sql.parse import Table
+from superset.sql.parse import RLSMethod, Table
 
 if TYPE_CHECKING:
     from superset.models.core import Database
@@ -40,17 +40,23 @@ def apply_rls(
 
     :returns: True if any RLS predicates were actually applied, False 
otherwise.
     """
-    # There are two ways to insert RLS: either replacing the table with a 
subquery
-    # that has the RLS, or appending the RLS to the ``WHERE`` clause. The 
former is
-    # safer, but not supported in all databases.
+    # There are three ways to insert RLS:
+    #   - replace the table with a subquery containing the RLS (safest, but not
+    #     supported in all databases)
+    #   - append the RLS to the ``WHERE`` clause via AST transformation
+    #   - splice the RLS into the original SQL string (preserves 
dialect-specific
+    #     syntax that the sqlglot generator would otherwise transpile)
     method = database.db_engine_spec.get_rls_method()
 
-    # collect all RLS predicates for all tables in the query
+    # In splice mode predicates stay as raw SQL strings and are inserted 
verbatim
+    # into the source query — re-parsing them would force a generator 
round-trip
+    # later and defeat the purpose.
+    use_splice = method == RLSMethod.AS_PREDICATE_SPLICE
     predicates: dict[Table, list[Any]] = {}
     for table in parsed_statement.tables:
         table = table.qualify(catalog=catalog, schema=schema)
-        predicates[table] = [
-            parsed_statement.parse_predicate(predicate)
+        raw_predicates = [
+            predicate
             for predicate in get_predicates_for_table(
                 table,
                 database,
@@ -58,6 +64,11 @@ def apply_rls(
             )
             if predicate
         ]
+        predicates[table] = (
+            raw_predicates
+            if use_splice
+            else [parsed_statement.parse_predicate(p) for p in raw_predicates]
+        )
 
     has_predicates = any(predicates.values())
     parsed_statement.apply_rls(catalog, schema, predicates, method)
diff --git a/tests/unit_tests/db_engine_specs/test_base.py 
b/tests/unit_tests/db_engine_specs/test_base.py
index 5eae41458de..1b58809b3b2 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -36,7 +36,7 @@ from sqlalchemy.sql import sqltypes
 from superset.db_engine_specs.base import BaseEngineSpec, 
convert_inspector_columns
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import OAuth2RedirectError
-from superset.sql.parse import Table
+from superset.sql.parse import RLSMethod, Table
 from superset.superset_typing import (
     OAuth2ClientConfig,
     OAuth2State,
@@ -1283,3 +1283,44 @@ def 
test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
     error = exc_info.value.error
 
     assert error.extra["redirect_uri"] == fallback_uri
+
+
+def test_get_rls_method_default_subquery() -> None:
+    """
+    By default, an engine that supports subqueries and aliases-in-select
+    uses the safer subquery RLS strategy.
+    """
+
+    class _Spec(BaseEngineSpec):
+        allows_subqueries = True
+        allows_alias_in_select = True
+
+    assert _Spec.get_rls_method() == RLSMethod.AS_SUBQUERY
+
+
+def test_get_rls_method_default_predicate_when_no_subqueries() -> None:
+    """
+    Engines without subquery / alias-in-select support fall back to the
+    AST predicate strategy.
+    """
+
+    class _Spec(BaseEngineSpec):
+        allows_subqueries = False
+        allows_alias_in_select = True
+
+    assert _Spec.get_rls_method() == RLSMethod.AS_PREDICATE
+
+
+def test_get_rls_method_class_attribute_override() -> None:
+    """
+    Setting ``rls_method`` on an engine spec opts the engine into a specific
+    strategy regardless of the subquery/alias defaults — used by engines whose
+    sqlglot dialect can parse but not faithfully regenerate SQL.
+    """
+
+    class _SpliceSpec(BaseEngineSpec):
+        allows_subqueries = True
+        allows_alias_in_select = True
+        rls_method = RLSMethod.AS_PREDICATE_SPLICE
+
+    assert _SpliceSpec.get_rls_method() == RLSMethod.AS_PREDICATE_SPLICE
diff --git a/tests/unit_tests/sql/parse_tests.py 
b/tests/unit_tests/sql/parse_tests.py
index 78b00f4487d..7835840b04a 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -2548,6 +2548,165 @@ def test_rls_predicate_transformer(
     assert statement.format() == expected
 
 
[email protected](
+    "sql, rules, expected",
+    [
+        # Simple — no WHERE clause to extend.
+        (
+            "SELECT LAST_DAY(d) FROM some_table",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42",
+        ),
+        # Append to an existing WHERE clause.
+        (
+            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
+            "AND tenant_id = 42",
+        ),
+        # WHERE precedes GROUP BY: predicate goes before GROUP BY.
+        (
+            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY 
d",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
+            "AND tenant_id = 42 GROUP BY d",
+        ),
+        # No WHERE, but GROUP BY and ORDER BY are present.
+        (
+            "SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42 "
+            "GROUP BY d ORDER BY d",
+        ),
+        # JOIN — predicate scoped to one of the tables.
+        (
+            "SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = 
l.id",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT o.id FROM some_table o JOIN locations l "
+            "ON o.loc_id = l.id WHERE tenant_id = 42",
+        ),
+        # JOIN — different predicate per table, both spliced into one WHERE.
+        (
+            "SELECT * FROM some_table JOIN events ON some_table.id = 
events.order_id",
+            {
+                Table("some_table", "schema1", "catalog1"): "tenant_id = 42",
+                Table("events", "schema1", "catalog1"): "user_id = 99",
+            },
+            "SELECT * FROM some_table JOIN events "
+            "ON some_table.id = events.order_id "
+            "WHERE tenant_id = 42 AND user_id = 99",
+        ),
+        # Subquery in FROM — splice into the inner SELECT.
+        (
+            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
+            "WHERE tenant_id = 42) sub",
+        ),
+        # CTE — splice into the CTE body.
+        (
+            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM 
cte",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
+            "WHERE tenant_id = 42) SELECT * FROM cte",
+        ),
+        # Dialect-specific function (LAST_DAY) preserved verbatim.
+        (
+            "SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 
'US'",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT id, LAST_DAY(created_at) FROM some_table "
+            "WHERE region = 'US' AND tenant_id = 42",
+        ),
+        # Multiline + inline comment preserved exactly.
+        (
+            "SELECT LAST_DAY(created_at) -- last day of month\n"
+            "FROM some_table\n"
+            "WHERE region = 'US'",
+            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
+            "SELECT LAST_DAY(created_at) -- last day of month\n"
+            "FROM some_table\n"
+            "WHERE region = 'US' AND tenant_id = 42",
+        ),
+        # Schema-qualified table name (no default schema match) — no predicate.
+        (
+            "SELECT t.foo FROM schema2.some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            "SELECT t.foo FROM schema2.some_table AS t",
+        ),
+    ],
+)
+def test_rls_predicate_splice(
+    sql: str,
+    rules: dict[Table, str],
+    expected: str,
+) -> None:
+    """
+    Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``.
+
+    Splice mode rewrites the original SQL string instead of re-rendering the
+    AST through the dialect generator, so byte-level fidelity (including
+    dialect-specific functions, comments, and whitespace) is preserved.
+    """
+    statement = SQLStatement(sql)
+    statement.apply_rls(
+        "catalog1",
+        "schema1",
+        {k: [v] for k, v in rules.items()},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    assert statement.format() == expected
+
+
+def test_rls_predicate_splice_requires_source() -> None:
+    """
+    Splice mode requires the original SQL substring; constructing a statement
+    purely from an AST should make splice mode raise.
+    """
+    ast = parse_one("SELECT * FROM some_table")
+    statement = SQLStatement(ast=ast, engine="postgresql")
+    with pytest.raises(ValueError, match="Splice-mode RLS requires the source 
SQL"):
+        statement.apply_rls(
+            "catalog1",
+            "schema1",
+            {Table("some_table", "schema1", "catalog1"): ["id = 42"]},
+            RLSMethod.AS_PREDICATE_SPLICE,
+        )
+
+
+def test_rls_predicate_splice_preserves_dialect_function() -> None:
+    """
+    Splice mode must NOT round-trip through the sqlglot generator. ``LAST_DAY``
+    on the postgres dialect would otherwise be transpiled by the generator.
+    """
+    sql = "SELECT LAST_DAY(d) FROM some_table"
+    statement = SQLStatement(sql, engine="postgresql")
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    assert "LAST_DAY(d)" in statement.format()
+
+
+def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
+    """
+    Splice mode accepts predicate strings directly — no ``parse_predicate`` is
+    needed at the call site.
+    """
+    sql = "SELECT * FROM some_table"
+    statement = SQLStatement(sql, engine="postgresql")
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42 AND active"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    assert statement.format() == (
+        "SELECT * FROM some_table WHERE tenant_id = 42 AND active"
+    )
+
+
 @pytest.mark.parametrize(
     "sql, table, expected",
     [

Reply via email to