This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sqlglot-rls
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 178771430630a078e94a91b5033cd5079906b93e
Author: Beto Dealmeida <[email protected]>
AuthorDate: Mon May 19 14:28:44 2025 -0400

    feat: implement RLS in sqlglot
---
 superset/sql/parse.py               | 190 +++++++++++-
 tests/unit_tests/sql/parse_tests.py | 564 ++++++++++++++++++++++++++++++++++--
 2 files changed, 734 insertions(+), 20 deletions(-)

diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index 0cada5f0e5..1ca100975f 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -55,7 +55,7 @@ SQLGLOT_DIALECTS = {
     # "db2": ???
     # "dremio": ???
     "drill": Dialects.DRILL,
-    # "druid": ???
+    "druid": Dialects.DRUID,
     "duckdb": Dialects.DUCKDB,
     # "dynamodb": ???
     # "elasticsearch": ???
@@ -108,6 +108,150 @@ class LimitMethod(enum.Enum):
     FETCH_MANY = enum.auto()
 
 
+class RLSMethod(enum.Enum):
+    """
+    Methods for enforcing RLS.
+    """
+
+    AS_PREDICATE = enum.auto()
+    AS_SUBQUERY = enum.auto()
+
+
+class RLSTransformer:
+    """
+    AST transformer to apply RLS rules.
+    """
+
+    def __init__(
+        self,
+        catalog: str | None,
+        schema: str | None,
+        rules: dict[Table, list[exp.Expression]],
+    ) -> None:
+        self.catalog = catalog
+        self.schema = schema
+        self.rules = rules
+
+    def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
+        """
+        Get the combined RLS predicate for a table.
+        """
+        table = Table(
+            table_node.name,
+            table_node.db if table_node.db else self.schema,
+            table_node.catalog if table_node.catalog else self.catalog,
+        )
+        if predicates := self.rules.get(table):
+            return (
+                exp.And(
+                    this=predicates[0],
+                    expressions=predicates[1:],
+                )
+                if len(predicates) > 1
+                else predicates[0]
+            )
+
+        return None
+
+
+class RLSAsPredicateTransformer(RLSTransformer):
+    """
+    Apply Row Level Security role as a predicate.
+
+    This transformer will apply any RLS predicates to the relevant tables. For example,
+    given the RLS rule:
+
+        table: some_table
+        clause: id = 42
+
+    If a user subject to the rule runs the following query:
+
+        SELECT foo FROM some_table WHERE bar = 'baz'
+
+    The query will be modified to:
+
+        SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
+
+    This approach is probably less secure than using subqueries, so it's only used for
+    databases without support for subqueries.
+    """
+
+    def __call__(self, node: exp.Expression) -> exp.Expression:
+        if not isinstance(node, exp.Table):
+            return node
+
+        predicate = self.get_predicate(node)
+        if not predicate:
+            return node
+
+        # qualify columns with table name
+        for column in predicate.find_all(exp.Column):
+            column.set("table", node.alias or node.this)
+
+        if isinstance(node.parent, exp.From):
+            select = node.parent.parent
+            if where := select.args.get("where"):
+                predicate = exp.And(
+                    this=predicate,
+                    expression=exp.Paren(this=where.this),
+                )
+            select.set("where", exp.Where(this=predicate))
+
+        elif isinstance(node.parent, exp.Join):
+            join = node.parent
+            if on := join.args.get("on"):
+                predicate = exp.And(
+                    this=predicate,
+                    expression=exp.Paren(this=on),
+                )
+            join.set("on", predicate)
+
+        return node
+
+
+class RLSAsSubqueryTransformer(RLSTransformer):
+    """
+    Apply Row Level Security role as a subquery.
+
+    This transformer will apply any RLS predicates to the relevant tables. For example,
+    given the RLS rule:
+
+        table: some_table
+        clause: id = 42
+
+    If a user subject to the rule runs the following query:
+
+        SELECT foo FROM some_table WHERE bar = 'baz'
+
+    The query will be modified to:
+
+        SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table
+        WHERE bar = 'baz'
+
+    This approach is probably more secure than using predicates, but it doesn't work for
+    all databases.
+    """
+
+    def __call__(self, node: exp.Expression) -> exp.Expression:
+        if not isinstance(node, exp.Table):
+            return node
+
+        if predicate := self.get_predicate(node):
+            # use alias or name
+            alias = node.alias or node.sql()
+            node.set("alias", None)
+            node = exp.Subquery(
+                this=exp.Select(
+                    expressions=[exp.Star()],
+                    where=exp.Where(this=predicate),
+                    **{"from": exp.From(this=node.copy())},
+                ),
+                alias=alias,
+            )
+
+        return node
+
+
 @dataclass(eq=True, frozen=True)
 class Table:
     """
@@ -173,7 +317,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         elif statement:
             self._parsed = self._parse_statement(statement, engine)
         else:
-            raise SupersetParseError("Either statement or ast must be provided")
+            raise ValueError("Either statement or ast must be provided")
 
         self.engine = engine
         self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
@@ -293,6 +437,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
         """
         raise NotImplementedError()
 
+    def apply_rls(
+        self,
+        catalog: str | None,
+        schema: str | None,
+        predicates: dict[Table, list[InternalRepresentation]],
+        method: RLSMethod,
+    ) -> None:
+        """
+        Apply relevant RLS rules to the statement inplace.
+
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :param method: The method to use for applying the rules.
+        """
+        raise NotImplementedError()
+
     def __str__(self) -> str:
         return self.format()
 
@@ -573,6 +733,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
             engine=self.engine,
         )
 
+    def apply_rls(
+        self,
+        catalog: str | None,
+        schema: str | None,
+        predicates: dict[Table, list[exp.Expression]],
+        method: RLSMethod,
+    ) -> None:
+        """
+        Apply relevant RLS rules to the statement inplace.
+
+        :param catalog: The default catalog for non-qualified table names
+        :param schema: The default schema for non-qualified table names
+        :param method: The method to use for applying the rules.
+        """
+        transformers = {
+            RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
+            RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
+        }
+        if method not in transformers:
+            raise ValueError(f"Invalid RLS method: {method}")
+
+        transformer = transformers[method](catalog, schema, predicates)
+        self._parsed = self._parsed.transform(transformer)
+
 
 class KQLSplitState(enum.Enum):
     """
@@ -966,7 +1150,7 @@ def extract_tables_from_statement(
     """
     Extract all table references in a single statement.
 
-    Please not that this is not trivial; consider the following queries:
+    Please note that this is not trivial; consider the following queries:
 
         DESCRIBE some_table;
         SHOW PARTITIONS FROM some_table;
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index d750870c98..f9c0100c07 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -18,13 +18,14 @@
 
 
 import pytest
-from sqlglot import Dialects
+from sqlglot import Dialects, parse_one
 
 from superset.exceptions import SupersetParseError
 from superset.sql.parse import (
     extract_tables_from_statement,
     KustoKQLStatement,
     LimitMethod,
+    RLSMethod,
     split_kql,
     SQLGLOT_DIALECTS,
     SQLScript,
@@ -303,11 +304,13 @@ def test_format_no_dialect() -> None:
     """
     assert (
         SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
-        == """SELECT
+        == """
+SELECT
   col
 FROM t
 WHERE
-  NOT col IN (1, 2)"""
+  NOT col IN (1, 2)
+        """.strip()
     )
 
 
@@ -1118,7 +1121,8 @@ FROM some_table) AS anon_1
 WHERE anon_1.a > 1 AND anon_1.b = 2
     """
 
-    optimized = """SELECT
+    optimized = """
+SELECT
   anon_1.a,
   anon_1.b
 FROM (
@@ -1131,9 +1135,11 @@ FROM (
     some_table.a > 1 AND some_table.b = 2
 ) AS anon_1
 WHERE
-  TRUE AND TRUE"""
+  TRUE AND TRUE
+    """.strip()
 
-    not_optimized = """SELECT
+    not_optimized = """
+SELECT
   anon_1.a,
   anon_1.b
 FROM (
@@ -1144,7 +1150,8 @@ FROM (
   FROM some_table
 ) AS anon_1
 WHERE
-  anon_1.a > 1 AND anon_1.b = 2"""
+  anon_1.a > 1 AND anon_1.b = 2
+    """.strip()
 
     assert SQLStatement(sql, "sqlite").optimize().format() == optimized
     assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
@@ -1195,9 +1202,11 @@ def test_firebolt_old() -> None:
     sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"
     assert (
         SQLStatement(sql, "firebolt").format()
-        == """SELECT
+        == """
+SELECT
   *
-FROM t1 UNNEST(col1 AS foo)"""
+FROM t1 UNNEST(col1 AS foo)
+        """.strip()
     )
 
 
@@ -1216,9 +1225,11 @@ def test_firebolt_old_escape_string() -> None:
     # but they normalize to ''
     assert (
         SQLStatement(sql, "firebolt").format()
-        == """SELECT
+        == """
+SELECT
   'foo''bar',
-  'foo''bar'"""
+  'foo''bar'
+        """.strip()
     )
 
 
@@ -1410,7 +1421,8 @@ select TOP 100 * from currency
             "mssql",
             1000,
             LimitMethod.FORCE_LIMIT,
-            """WITH abc AS (
+            """
+WITH abc AS (
   SELECT
     *
   FROM test
@@ -1422,7 +1434,8 @@ select TOP 100 * from currency
 SELECT
 TOP 1000
   *
-FROM currency""",
+FROM currency
+            """.strip(),
         ),
         (
             "SELECT DISTINCT x from tbl",
@@ -1457,10 +1470,12 @@ FROM currency""",
             "postgresql",
             1000,
             LimitMethod.FORCE_LIMIT,
-            """SELECT
+            """
+SELECT
   *
 FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
-LIMIT 1000""",
+LIMIT 1000
+            """.strip(),
         ),
         (
             "SELECT * FROM birth_names LIMIT 555",
@@ -1602,7 +1617,8 @@ UNION ALL
 SELECT * FROM currency_2
             """,
             "postgresql",
-            """WITH currency AS (
+            """
+WITH currency AS (
   SELECT
     'INR' AS cur
 ), currency_2 AS (
@@ -1616,7 +1632,8 @@ SELECT * FROM currency_2
   SELECT
     *
   FROM currency_2
-)""",
+)
+            """.strip(),
         ),
     ],
 )
@@ -1625,3 +1642,516 @@ def test_as_cte(sql: str, engine: str, expected: str) -> None:
     Test that we can covert select to CTE.
     """
     assert SQLStatement(sql, engine).as_cte().format() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, rules, expected",
+    [
+        (
+            "SELECT t.foo FROM some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM (
+  SELECT
+    *
+  FROM some_table
+  WHERE
+    id = 42
+) AS t
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM (
+  SELECT
+    *
+  FROM some_table
+  WHERE
+    id = 42
+) AS t
+WHERE
+  bar = 'baz'
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM schema1.some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM (
+  SELECT
+    *
+  FROM schema1.some_table
+  WHERE
+    id = 42
+) AS t
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM schema1.some_table AS t",
+            {Table("some_table", "schema2"): "id = 42"},
+            "SELECT\n  t.foo\nFROM schema1.some_table AS t",
+        ),
+        (
+            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM (
+  SELECT
+    *
+  FROM catalog1.schema1.some_table
+  WHERE
+    id = 42
+) AS t
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM catalog1.schema1.some_table AS t",
+            {Table("some_table", "schema1", "catalog2"): "id = 42"},
+            "SELECT\n  t.foo\nFROM catalog1.schema1.some_table AS t",
+        ),
+        (
+            "SELECT * FROM some_table WHERE 1=1",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM (
+  SELECT
+    *
+  FROM some_table
+  WHERE
+    id = 42
+) AS some_table
+WHERE
+  1 = 1
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table WHERE 1=1",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM (
+  SELECT
+    *
+  FROM table
+  WHERE
+    id = 42
+) AS table
+WHERE
+  1 = 1
+            """.strip(),
+        ),
+        (
+            'SELECT * FROM "table" WHERE 1=1',
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM (
+  SELECT
+    *
+  FROM "table"
+  WHERE
+    id = 42
+) AS "table"
+WHERE
+  1 = 1
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table WHERE 1=1",
+            {Table("other_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+WHERE
+  1 = 1
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM other_table WHERE 1=1",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM other_table
+WHERE
+  1 = 1
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
+            {Table("other_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+JOIN (
+  SELECT
+    *
+  FROM other_table
+  WHERE
+    id = 42
+) AS other_table
+  ON table.id = other_table.id
+            """.strip(),
+        ),
+        (
+            'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id',
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM (
+  SELECT
+    *
+  FROM "table"
+  WHERE
+    id = 42
+) AS "table"
+JOIN other_table
+  ON "table".id = other_table.id
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM (SELECT * FROM some_table)",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM (
+  SELECT
+    *
+  FROM (
+    SELECT
+      *
+    FROM some_table
+    WHERE
+      id = 42
+  ) AS some_table
+)
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM (
+  SELECT
+    *
+  FROM table
+  WHERE
+    id = 42
+) AS table
+UNION ALL
+SELECT
+  *
+FROM other_table
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table UNION ALL SELECT * FROM other_table",
+            {Table("other_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+UNION ALL
+SELECT
+  *
+FROM (
+  SELECT
+    *
+  FROM other_table
+  WHERE
+    id = 42
+) AS other_table
+            """.strip(),
+        ),
+        (
+            "SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
+            {Table("tbl_a", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  a.*,
+  b.*
+FROM (
+  SELECT
+    *
+  FROM tbl_a
+  WHERE
+    id = 42
+) AS a
+INNER JOIN tbl_b AS b
+  ON a.col = b.col
+            """.strip(),
+        ),
+        (
+            "SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
+            {Table("tbl_a", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  a.*,
+  b.*
+FROM (
+  SELECT
+    *
+  FROM tbl_a
+  WHERE
+    id = 42
+) AS a
+INNER JOIN tbl_b AS b
+  ON a.col = b.col
+            """.strip(),
+        ),
+    ],
+)
+def test_rls_subquery_transformer(
+    sql: str,
+    rules: dict[Table, str],
+    expected: str,
+) -> None:
+    """
+    Test `RLSAsSubqueryTransformer`.
+    """
+    statement = SQLStatement(sql)
+    statement.apply_rls(
+        "catalog1",
+        "schema1",
+        {k: [parse_one(v)] for k, v in rules.items()},
+        RLSMethod.AS_SUBQUERY,
+    )
+    assert statement.format() == expected
+
+
+@pytest.mark.parametrize(
+    "sql, rules, expected",
+    [
+        (
+            "SELECT t.foo FROM some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM some_table AS t
+WHERE
+  t.id = 42
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM schema2.some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM schema2.some_table AS t
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM catalog2.schema1.some_table AS t",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM catalog2.schema1.some_table AS t
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM some_table AS t
+WHERE
+  t.id = 42 AND (
+    bar = 'baz'
+  )
+            """.strip(),
+        ),
+        (
+            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  t.foo
+FROM some_table AS t
+WHERE
+  t.id = 42 AND (
+    bar = 'baz' OR foo = 'qux'
+  )
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM some_table WHERE 1=1",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM some_table
+WHERE
+  some_table.id = 42 AND (
+    1 = 1
+  )
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM some_table WHERE TRUE OR FALSE",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM some_table
+WHERE
+  some_table.id = 42 AND (
+    TRUE OR FALSE
+  )
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table WHERE 1=1",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+WHERE
+  table.id = 42 AND (
+    1 = 1
+  )
+            """.strip(),
+        ),
+        (
+            'SELECT * FROM "table" WHERE 1=1',
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM "table"
+WHERE
+  "table".id = 42 AND (
+    1 = 1
+  )
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table WHERE 1=1",
+            {Table("other_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+WHERE
+  1 = 1
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM other_table WHERE 1=1",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM other_table
+WHERE
+  1 = 1
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+WHERE
+  table.id = 42
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM some_table",
+            {Table("some_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM some_table
+WHERE
+  some_table.id = 42
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table ORDER BY id",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+WHERE
+  table.id = 42
+ORDER BY
+  id
+            """.strip(),
+        ),
+        (
+            "SELECT * FROM table WHERE 1=1 AND table.id=42",
+            {Table("table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+WHERE
+  table.id = 42 AND (
+    1 = 1 AND table.id = 42
+  )
+            """.strip(),
+        ),
+        (
+            """
+SELECT * FROM table
+JOIN other_table
+ON table.id = other_table.id
+AND other_table.id=42
+            """,
+            {Table("other_table", "schema1", "catalog1"): "id = 42"},
+            """
+SELECT
+  *
+FROM table
+JOIN other_table
+  ON other_table.id = 42 AND (
+    table.id = other_table.id AND other_table.id = 42
+  )
+            """.strip(),
+        ),
+    ],
+)
+def test_rls_predicate_transformer(
+    sql: str,
+    rules: dict[Table, str],
+    expected: str,
+) -> None:
+    """
+    Test `RLSPredicateTransformer`.
+    """
+    statement = SQLStatement(sql)
+    statement.apply_rls(
+        "catalog1",
+        "schema1",
+        {k: [parse_one(v)] for k, v in rules.items()},
+        RLSMethod.AS_PREDICATE,
+    )
+    assert statement.format() == expected

Reply via email to