This is an automated email from the ASF dual-hosted git repository.

lilykuang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new fba66c6250 fix: Use RLS clause instead of ID for cache key (#25229)
fba66c6250 is described below

commit fba66c6250c38944639cfc1f95a67ef00c66629c
Author: Jack Fragassi <[email protected]>
AuthorDate: Mon Sep 18 11:37:35 2023 -0700

    fix: Use RLS clause instead of ID for cache key (#25229)
---
 superset/security/manager.py                       | 26 ++++++++++++----------
 .../security/row_level_security_tests.py           | 15 +++++++++++++
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/superset/security/manager.py b/superset/security/manager.py
index ef0f9c975a..2935e1eb98 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -79,7 +79,7 @@ from superset.utils.urls import get_url_host
 if TYPE_CHECKING:
     from superset.common.query_context import QueryContext
     from superset.connectors.base.models import BaseDatasource
-    from superset.connectors.sqla.models import SqlaTable
+    from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
     from superset.models.core import Database
     from superset.models.dashboard import Dashboard
     from superset.models.sql_lab import Query
@@ -2083,28 +2083,30 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )
         return query.all()
 
-    def get_rls_ids(self, table: "BaseDatasource") -> list[int]:
+    def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]:
         """
-        Retrieves the appropriate row level security filters IDs for the current user
-        and the passed table.
+        Retrieves a list RLS filters sorted by ID for
+        the current user and the passed table.
 
         :param table: The table to check against
-        :returns: A list of IDs
+        :returns: A list RLS filters
         """
-        ids = [f.id for f in self.get_rls_filters(table)]
-        ids.sort()  # Combinations rather than permutations
-        return ids
+        filters = self.get_rls_filters(table)
+        filters.sort(key=lambda f: f.id)
+        return filters
 
     def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
         return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]
 
     def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
-        rls_ids = []
+        rls_clauses_with_group_key = []
         if datasource.is_rls_supported:
-            rls_ids = self.get_rls_ids(datasource)
-        rls_str = [str(rls_id) for rls_id in rls_ids]
+            rls_clauses_with_group_key = [
+                f"{f.clause}-{f.group_key or ''}"
+                for f in self.get_rls_sorted(datasource)
+            ]
         guest_rls = self.get_guest_rls_filters_str(datasource)
-        return guest_rls + rls_str
+        return guest_rls + rls_clauses_with_group_key
 
     @staticmethod
     def _get_current_epoch_time() -> float:
diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py
index c29ebe9afe..41ca0d5e79 100644
--- a/tests/integration_tests/security/row_level_security_tests.py
+++ b/tests/integration_tests/security/row_level_security_tests.py
@@ -305,6 +305,21 @@ class TestRowLevelSecurity(SupersetTestCase):
         assert not self.NAMES_Q_REGEX.search(sql)
         assert not self.BASE_FILTER_REGEX.search(sql)
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_get_rls_cache_key(self):
+        g.user = self.get_user(username="admin")
+        tbl = self.get_table(name="birth_names")
+        clauses = security_manager.get_rls_cache_key(tbl)
+        assert clauses == []
+
+        g.user = self.get_user(username="gamma")
+        clauses = security_manager.get_rls_cache_key(tbl)
+        assert clauses == [
+            "name like 'A%' or name like 'B%'-name",
+            "name like 'Q%'-name",
+            "gender = 'boy'-gender",
+        ]
+
 
 class TestRowLevelSecurityCreateAPI(SupersetTestCase):
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")

Reply via email to