This is an automated email from the ASF dual-hosted git repository.

betodealmeida pushed a commit to branch rls-splice
in repository https://gitbox.apache.org/repos/asf/superset.git

commit af2d3babecc150dfebe39ae72213675bf99ddb1f
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri May 8 13:49:37 2026 -0400

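    Improve RLS predicate splicing

    Insert predicates before comments that precede the next clause, treat
    WINDOW/QUALIFY/CLUSTER BY/DISTRIBUTE BY/SORT BY/CONNECT BY/START WITH as
    clause boundaries, and reparse spliced SQL before applying a LIMIT so the
    injected predicates are kept.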
---
 superset/sql/parse.py               |  7 ++-
 superset/sql/rls_splice.py          | 42 ++++++++++++++--
 tests/unit_tests/sql/parse_tests.py | 95 +++++++++++++++++++++++++++++++++++++
 3 files changed, 139 insertions(+), 5 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index a45943ea463..d944dce5a2f 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -873,7 +873,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
         Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
         """
         # AST mutation invalidates any cached verbatim SQL (e.g. from splice).
-        self._raw_sql = None
+        # If we already have a rewritten SQL string, re-parse it first so further
+        # AST mutations (like LIMIT injection) preserve prior text-based rewrites.
+        if self._raw_sql is not None:
+            self._parsed = self._parse_statement(self._raw_sql, self.engine)
+            self._source_sql = self._raw_sql
+            self._raw_sql = None
         if method == LimitMethod.FORCE_LIMIT:
             self._parsed.args["limit"] = exp.Limit(
                 expression=exp.Literal(this=str(limit), is_string=False)
diff --git a/superset/sql/rls_splice.py b/superset/sql/rls_splice.py
index 0166a13fe49..48576f7ac71 100644
--- a/superset/sql/rls_splice.py
+++ b/superset/sql/rls_splice.py
@@ -62,8 +62,15 @@ _CLAUSE_ENDS = {
     TokenType.GROUP_BY,
     TokenType.HAVING,
     TokenType.ORDER_BY,
+    TokenType.WINDOW,
+    TokenType.QUALIFY,
     TokenType.LIMIT,
     TokenType.FETCH,
+    TokenType.CLUSTER_BY,
+    TokenType.DISTRIBUTE_BY,
+    TokenType.SORT_BY,
+    TokenType.CONNECT_BY,
+    TokenType.START_WITH,
     TokenType.UNION,
     TokenType.INTERSECT,
     TokenType.EXCEPT,
@@ -77,6 +84,70 @@ def _before_whitespace(sql: str, offset: int) -> int:
     return offset


+def _comment_spans(sql: str, end: int) -> list[tuple[int, int]]:
+    """
+    Return `(start, stop)` offsets of SQL comments that begin before *end*.
+
+    `--` comments stop at the next newline (exclusive) and `/* */` comments
+    stop just after the closing `*/`. Single-quoted strings and double-quoted
+    or backtick-quoted identifiers are skipped (a doubled quote is an escape),
+    so comment markers inside them, eg: `'a--b'`, are not treated as comments.
+
+    NOTE(review): backslash escapes (MySQL), dollar quoting and nested block
+    comments (Postgres) are not modelled; confirm this is acceptable.
+    """
+    spans: list[tuple[int, int]] = []
+    limit = min(end, len(sql))
+    i = 0
+    while i < limit:
+        if sql.startswith("--", i):
+            stop = sql.find("\n", i)
+            stop = len(sql) if stop == -1 else stop
+        elif sql.startswith("/*", i):
+            stop = sql.find("*/", i + 2)
+            stop = len(sql) if stop == -1 else stop + 2
+        elif sql[i] in "'\"`":
+            quote = sql[i]
+            close = sql.find(quote, i + 1)
+            while close != -1 and sql.startswith(quote * 2, close):
+                close = sql.find(quote, close + 2)
+            if close == -1:
+                # Unterminated literal: nothing after it can start a comment.
+                break
+            i = close + 1
+            continue
+        else:
+            i += 1
+            continue
+        spans.append((i, stop))
+        i = stop
+    return spans
+
+
+def _before_trivia(sql: str, offset: int) -> int:
+    """
+    Back up past whitespace and adjacent comments immediately before *offset*.
+
+    This ensures insertion points land before inline/block comments that appear
+    between `FROM`/`WHERE` and the next clause keyword. Comment markers inside
+    string literals or quoted identifiers are ignored (see `_comment_spans`),
+    so the offset is never moved into a literal or the middle of a comment.
+    """
+    spans = _comment_spans(sql, offset)
+    while True:
+        offset = _before_whitespace(sql, offset)
+        # offset may now sit strictly inside a `--` comment (before its
+        # trailing spaces), so test containment rather than equality;
+        # `start < offset` leaves a comment that begins at offset alone.
+        enclosing = next(
+            (start for start, stop in reversed(spans) if start < offset <= stop),
+            None,
+        )
+        if enclosing is None:
+            return offset
+        offset = enclosing
+
+
 def _table_from_node(
     node: exp.Table,
     catalog: str | None,
@@ -217,7 +251,7 @@ def _find_splice_point(
         if tok.token_type == TokenType.R_PAREN:
             if depth == 0:
                 # Closing paren of our subquery — insert just before it.
-                offset = _before_whitespace(sql, tok.start)
+                offset = _before_trivia(sql, tok.start)
                 text = f" AND {pred_sql}" if has_where else f" WHERE {pred_sql}"
                 return (offset, text)
             depth -= 1
@@ -231,7 +265,7 @@ def _find_splice_point(
 
         if not has_where and tok.token_type in _CLAUSE_ENDS:
             # Insert WHERE before this clause keyword.
-            return (_before_whitespace(sql, tok.start), f" WHERE {pred_sql}")
+            return (_before_trivia(sql, tok.start), f" WHERE {pred_sql}")
 
     # No clause boundary found — append at end of SQL.
     text = f" AND {pred_sql}" if has_where else f" WHERE {pred_sql}"
@@ -255,9 +289,9 @@ def _find_after_where(
             depth += 1
         elif tok.token_type == TokenType.R_PAREN:
             if depth == 0:
-                return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
+                return (_before_trivia(sql, tok.start), f" AND {pred_sql}")
             depth -= 1
         elif depth == 0 and tok.token_type in _CLAUSE_ENDS:
-            return (_before_whitespace(sql, tok.start), f" AND {pred_sql}")
+            return (_before_trivia(sql, tok.start), f" AND {pred_sql}")
         prev_end = tok.end
     return (prev_end + 1, f" AND {pred_sql}")
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index 7835840b04a..79a9cec6953 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -1704,6 +1704,21 @@ def test_set_limit_value(
     assert statement.format() == expected


+def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
+    """
+    Simulate cached splice-mode SQL by setting `_raw_sql` directly; setting a
+    limit must reparse that SQL so LIMIT is added on top of the predicate.
+    """
+    statement = SQLStatement("SELECT * FROM some_table", "postgresql")
+    statement._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"
+
+    statement.set_limit_value(10, LimitMethod.FORCE_LIMIT)
+    formatted = statement.format()
+
+    assert "tenant_id = 42" in formatted
+    assert "LIMIT 10" in formatted
+
+
 @pytest.mark.parametrize(
     "kql, limit, expected",
     [
@@ -2707,6 +2722,86 @@ def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
     )


+@pytest.mark.parametrize(
+    "sql, expected",
+    [
+        (
+            "SELECT * FROM some_table -- hi\nGROUP BY id",
+            "SELECT * FROM some_table WHERE tenant_id = 42 -- hi\nGROUP BY id",
+        ),
+        (
+            "SELECT * FROM some_table /* inline */ GROUP BY id",
+            "SELECT * FROM some_table WHERE tenant_id = 42 /* inline */ GROUP BY id",
+        ),
+    ],
+)
+def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
+    """
+    Splice mode should insert predicates before comments that precede the next
+    clause boundary, so comments do not swallow the injected SQL.
+    """
+    statement = SQLStatement(sql, engine="postgresql")
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    assert statement.format() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, engine, expected",
+    [
+        (
+            "SELECT * FROM some_table QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
+            "snowflake",
+            "SELECT * FROM some_table WHERE tenant_id = 42 QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
+        ),
+        (
+            "SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
+            "postgresql",
+            "SELECT sum(v) OVER () FROM some_table WHERE tenant_id = 42 WINDOW w AS (PARTITION BY id)",
+        ),
+    ],
+)
+def test_rls_predicate_splice_handles_additional_clause_boundaries(
+    sql: str,
+    engine: str,
+    expected: str,
+) -> None:
+    """
+    Splice mode should insert WHERE before clause types that can legally follow
+    FROM/WHERE (for example QUALIFY and WINDOW).
+    """
+    statement = SQLStatement(sql, engine=engine)
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    assert statement.format() == expected
+
+
+def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
+    """
+    LIMIT rewrites after splice-mode RLS should retain injected predicates.
+    """
+    statement = SQLStatement("SELECT * FROM some_table", engine="postgresql")
+    statement.apply_rls(
+        None,
+        None,
+        {Table("some_table"): ["tenant_id = 42"]},
+        RLSMethod.AS_PREDICATE_SPLICE,
+    )
+    statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)
+
+    formatted = statement.format()
+    assert "tenant_id = 42" in formatted
+    assert "LIMIT 101" in formatted
+
+
 @pytest.mark.parametrize(
     "sql, table, expected",
     [

Reply via email to