This is an automated email from the ASF dual-hosted git repository.
lilykuang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new fba66c6250 fix: Use RLS clause instead of ID for cache key (#25229)
fba66c6250 is described below
commit fba66c6250c38944639cfc1f95a67ef00c66629c
Author: Jack Fragassi <[email protected]>
AuthorDate: Mon Sep 18 11:37:35 2023 -0700
fix: Use RLS clause instead of ID for cache key (#25229)
---
superset/security/manager.py | 26 ++++++++++++----------
.../security/row_level_security_tests.py | 15 +++++++++++++
2 files changed, 29 insertions(+), 12 deletions(-)
diff --git a/superset/security/manager.py b/superset/security/manager.py
index ef0f9c975a..2935e1eb98 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -79,7 +79,7 @@ from superset.utils.urls import get_url_host
if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
- from superset.connectors.sqla.models import SqlaTable
+ from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.sql_lab import Query
@@ -2083,28 +2083,30 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
return query.all()
- def get_rls_ids(self, table: "BaseDatasource") -> list[int]:
+ def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]:
"""
- Retrieves the appropriate row level security filters IDs for the current user
- and the passed table.
+ Retrieves a list RLS filters sorted by ID for
+ the current user and the passed table.
:param table: The table to check against
- :returns: A list of IDs
+ :returns: A list RLS filters
"""
- ids = [f.id for f in self.get_rls_filters(table)]
- ids.sort() # Combinations rather than permutations
- return ids
+ filters = self.get_rls_filters(table)
+ filters.sort(key=lambda f: f.id)
+ return filters
def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]
def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
- rls_ids = []
+ rls_clauses_with_group_key = []
if datasource.is_rls_supported:
- rls_ids = self.get_rls_ids(datasource)
- rls_str = [str(rls_id) for rls_id in rls_ids]
+ rls_clauses_with_group_key = [
+ f"{f.clause}-{f.group_key or ''}"
+ for f in self.get_rls_sorted(datasource)
+ ]
guest_rls = self.get_guest_rls_filters_str(datasource)
- return guest_rls + rls_str
+ return guest_rls + rls_clauses_with_group_key
@staticmethod
def _get_current_epoch_time() -> float:
diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py
index c29ebe9afe..41ca0d5e79 100644
--- a/tests/integration_tests/security/row_level_security_tests.py
+++ b/tests/integration_tests/security/row_level_security_tests.py
@@ -305,6 +305,21 @@ class TestRowLevelSecurity(SupersetTestCase):
assert not self.NAMES_Q_REGEX.search(sql)
assert not self.BASE_FILTER_REGEX.search(sql)
+ @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+ def test_get_rls_cache_key(self):
+ g.user = self.get_user(username="admin")
+ tbl = self.get_table(name="birth_names")
+ clauses = security_manager.get_rls_cache_key(tbl)
+ assert clauses == []
+
+ g.user = self.get_user(username="gamma")
+ clauses = security_manager.get_rls_cache_key(tbl)
+ assert clauses == [
+ "name like 'A%' or name like 'B%'-name",
+ "name like 'Q%'-name",
+ "gender = 'boy'-gender",
+ ]
+
class TestRowLevelSecurityCreateAPI(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")