This is an automated email from the ASF dual-hosted git repository.

diegopucci pushed a commit to branch geido/feat/rls-digest
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 076af95460e2b9702b6a32b6c88dde1c8af3642c
Author: Diego Pucci <[email protected]>
AuthorDate: Thu Sep 19 17:39:58 2024 +0200

    feat(Digest): Add RLS at digest generation for Charts and Dashboards
---
 superset/dashboards/api.py                 |   4 +-
 superset/tasks/thumbnails.py               |   2 +-
 superset/thumbnails/digest.py              |  49 +++++++-
 tests/unit_tests/thumbnails/test_digest.py | 189 ++++++++++++++++++++---------
 4 files changed, 186 insertions(+), 58 deletions(-)

diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py
index 5ff5b17b24..d02442cc57 100644
--- a/superset/dashboards/api.py
+++ b/superset/dashboards/api.py
@@ -1046,14 +1046,14 @@ class DashboardRestApi(BaseSupersetModelRestApi):
             cache_dashboard_screenshot.delay(
                 username=get_current_user(),
                 guest_token=g.user.guest_token
-                if isinstance(g.user, GuestUser)
+                if hasattr(g, "user") and isinstance(g.user, GuestUser)
                 else None,
                 dashboard_id=dashboard.id,
                 dashboard_url=dashboard_url,
+                cache_key=cache_key,
                 force=True,
                 thumb_size=thumb_size,
                 window_size=window_size,
-                cache_key=cache_key,
             )
             return self.response(
                 202,
diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py
index 34c4fc7377..dd9b5065dc 100644
--- a/superset/tasks/thumbnails.py
+++ b/superset/tasks/thumbnails.py
@@ -114,10 +114,10 @@ def cache_dashboard_screenshot(  # pylint: disable=too-many-arguments
     dashboard_id: int,
     dashboard_url: str,
     force: bool = True,
+    cache_key: Optional[str] = None,
     guest_token: Optional[GuestToken] = None,
     thumb_size: Optional[WindowSize] = None,
     window_size: Optional[WindowSize] = None,
-    cache_key: Optional[str] = None,
 ) -> None:
     # pylint: disable=import-outside-toplevel
     from superset.models.dashboard import Dashboard
diff --git a/superset/thumbnails/digest.py b/superset/thumbnails/digest.py
index fb209fcd50..484d15b100 100644
--- a/superset/thumbnails/digest.py
+++ b/superset/thumbnails/digest.py
@@ -18,15 +18,18 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
 
 from flask import current_app
 
+from superset import security_manager
 from superset.tasks.types import ExecutorType
 from superset.tasks.utils import get_current_user, get_executor
+from superset.utils.core import override_user
 from superset.utils.hashing import md5_sha_from_str
 
 if TYPE_CHECKING:
+    from superset.connectors.sqla.models import BaseDatasource, SqlaTable
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
 
@@ -49,8 +52,42 @@ def _adjust_string_for_executor(
     return unique_string
 
 
+def _adjust_string_with_rls(
+    unique_string: str,
+    datasources: list[Optional[SqlaTable]] | set[BaseDatasource],
+    executor_type: ExecutorType,
+    executor: str,
+) -> str:
+    """
+    Add the RLS filters to the unique string based on current executor.
+    """
+    if executor_type != ExecutorType.SELENIUM:
+        user = (
+            security_manager.find_user(executor)
+            or security_manager.get_current_guest_user_if_guest()
+        )
+
+        if user:
+            stringified_rls = ""
+            with override_user(user):
+                for datasource in datasources:
+                    if (
+                        datasource
+                        and hasattr(datasource, "is_rls_supported")
+                        and datasource.is_rls_supported
+                    ):
+                        rls_filters = datasource.get_sqla_row_level_filters()
+                        stringified_rls += "".join([str(f) for f in rls_filters])
+
+            if stringified_rls != "":
+                unique_string = f"{unique_string}\n{stringified_rls}"
+
+    return unique_string
+
+
 def get_dashboard_digest(dashboard: Dashboard) -> str:
     config = current_app.config
+    datasources = dashboard.datasources
     executor_type, executor = get_executor(
         executor_types=config["THUMBNAIL_EXECUTE_AS"],
         model=dashboard,
@@ -65,19 +102,29 @@ def get_dashboard_digest(dashboard: Dashboard) -> str:
     )
 
     unique_string = _adjust_string_for_executor(unique_string, executor_type, executor)
+    unique_string = _adjust_string_with_rls(
+        unique_string, datasources, executor_type, executor
+    )
+
     return md5_sha_from_str(unique_string)
 
 
 def get_chart_digest(chart: Slice) -> str:
     config = current_app.config
+    datasource = chart.datasource
     executor_type, executor = get_executor(
         executor_types=config["THUMBNAIL_EXECUTE_AS"],
         model=chart,
         current_user=get_current_user(),
     )
+
     if func := config["THUMBNAIL_CHART_DIGEST_FUNC"]:
         return func(chart, executor_type, executor)
 
     unique_string = f"{chart.params or ''}.{executor}"
     unique_string = _adjust_string_for_executor(unique_string, executor_type, executor)
+    unique_string = _adjust_string_with_rls(
+        unique_string, [datasource], executor_type, executor
+    )
+
     return md5_sha_from_str(unique_string)
diff --git a/tests/unit_tests/thumbnails/test_digest.py b/tests/unit_tests/thumbnails/test_digest.py
index 987488ffe7..3e9602e4e0 100644
--- a/tests/unit_tests/thumbnails/test_digest.py
+++ b/tests/unit_tests/thumbnails/test_digest.py
@@ -18,14 +18,15 @@ from __future__ import annotations
 
 from contextlib import nullcontext
 from typing import Any, TYPE_CHECKING
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch, PropertyMock
 
 import pytest
 from flask_appbuilder.security.sqla.models import User
 
+from superset.connectors.sqla.models import BaseDatasource, SqlaTable
 from superset.tasks.exceptions import ExecutorNotFoundError
 from superset.tasks.types import ExecutorType
-from superset.utils.core import override_user
+from superset.utils.core import DatasourceType, override_user
 
 if TYPE_CHECKING:
     from superset.models.dashboard import Dashboard
@@ -62,14 +63,28 @@ def CUSTOM_CHART_FUNC(
     return f"{chart.id}.{executor_type.value}.{executor}"
 
 
+def prepare_datasource_mock(
+    datasource_conf: dict[str, Any], spec: type[BaseDatasource | SqlaTable]
+) -> BaseDatasource | SqlaTable:
+    datasource = MagicMock(spec=spec)
+    datasource.id = 1
+    datasource.type = DatasourceType.TABLE
+    datasource.is_rls_supported = datasource_conf.get("is_rls_supported", False)
+    datasource.get_sqla_row_level_filters = datasource_conf.get(
+        "get_sqla_row_level_filters", MagicMock(return_value=[])
+    )
+    return datasource
+
+
 @pytest.mark.parametrize(
-    "dashboard_overrides,execute_as,has_current_user,use_custom_digest,expected_result",
+    "dashboard_overrides,execute_as,has_current_user,use_custom_digest,rls_datasources,expected_result",
     [
         (
             None,
             [ExecutorType.SELENIUM],
             False,
             False,
+            [],
             "71452fee8ffbd8d340193d611bcd4559",
         ),
         (
@@ -77,99 +92,103 @@ def CUSTOM_CHART_FUNC(
             [ExecutorType.CURRENT_USER],
             True,
             False,
+            [],
             "209dc060ac19271b8708731e3b8280f5",
         ),
         (
-            {
-                "dashboard_title": "My Other Title",
-            },
-            [ExecutorType.CURRENT_USER],
-            True,
-            False,
-            "209dc060ac19271b8708731e3b8280f5",
-        ),
-        (
-            {
-                "id": 2,
-            },
-            [ExecutorType.CURRENT_USER],
-            True,
-            False,
-            "06a4144466dbd5ffad0c3c2225e96296",
-        ),
-        (
-            {
-                "slices": [{"id": 2, "slice_name": "My Other Chart"}],
-            },
-            [ExecutorType.CURRENT_USER],
-            True,
-            False,
-            "a823ece9563895ccb14f3d9095e84f7a",
-        ),
-        (
-            {
-                "position_json": {"b": "c"},
-            },
-            [ExecutorType.CURRENT_USER],
-            True,
-            False,
-            "33c5475f92a904925ab3ef493526e5b5",
-        ),
-        (
-            {
-                "css": "background-color: darkblue;",
-            },
+            None,
             [ExecutorType.CURRENT_USER],
             True,
             False,
-            "cec57345e6402c0d4b3caee5cfaa0a03",
+            [
+                {
+                    "is_rls_supported": True,
+                    "get_sqla_row_level_filters": MagicMock(return_value=["filter1"]),
+                }
+            ],
+            "f4f262a649225d62717cdb11a9f2b8ee",
         ),
         (
-            {
-                "json_metadata": {"d": "e"},
-            },
+            None,
             [ExecutorType.CURRENT_USER],
             True,
             False,
-            "5380dcbe94621a0759b09554404f3d02",
+            [
+                {
+                    "is_rls_supported": True,
+                    "get_sqla_row_level_filters": MagicMock(
+                        return_value=["filter1", "filter2"]
+                    ),
+                },
+                {
+                    "is_rls_supported": True,
+                    "get_sqla_row_level_filters": MagicMock(
+                        return_value=["filter3", "filter4"]
+                    ),
+                },
+            ],
+            "4903ad13cfe28a5fc41478147d87211a",
         ),
         (
             None,
             [ExecutorType.CURRENT_USER],
             True,
-            True,
-            "1.current_user.1",
+            False,
+            [
+                {
+                    "is_rls_supported": False,
+                    "get_sqla_row_level_filters": MagicMock(return_value=[]),
+                },
+                {
+                    "is_rls_supported": True,
+                    "get_sqla_row_level_filters": MagicMock(
+                        return_value=["filter1", "filter2"]
+                    ),
+                },
+            ],
+            "774b0c1ed98142ee8f60aae51f128000",
         ),
         (
             None,
             [ExecutorType.CURRENT_USER],
             False,
             False,
+            [],
             ExecutorNotFoundError(),
         ),
     ],
 )
-def test_dashboard_digest(
+def test_dashboard_digest_with_rls(
     dashboard_overrides: dict[str, Any] | None,
     execute_as: list[ExecutorType],
     has_current_user: bool,
     use_custom_digest: bool,
+    rls_datasources: list[dict[str, Any]],
     expected_result: str | Exception,
 ) -> None:
-    from superset import app
+    from superset import app, security_manager
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
     from superset.thumbnails.digest import get_dashboard_digest
 
+    # Prepare dashboard and slices
     kwargs = {
         **_DEFAULT_DASHBOARD_KWARGS,
         **(dashboard_overrides or {}),
     }
     slices = [Slice(**slice_kwargs) for slice_kwargs in kwargs.pop("slices")]
     dashboard = Dashboard(**kwargs, slices=slices)
+
+    # Mock datasources with RLS
+    datasources = []
+    for rls_source in rls_datasources:
+        datasource = prepare_datasource_mock(rls_source, BaseDatasource)
+        datasources.append(datasource)
+
     user: User | None = None
     if has_current_user:
         user = User(id=1, username="1")
+
     func = CUSTOM_DASHBOARD_FUNC if use_custom_digest else None
 
     with (
@@ -180,6 +199,13 @@ def test_dashboard_digest(
                 "THUMBNAIL_DASHBOARD_DIGEST_FUNC": func,
             },
         ),
+        patch.object(
+            type(dashboard),
+            "datasources",
+            new_callable=PropertyMock,
+            return_value=datasources,
+        ),
+        patch.object(security_manager, "find_user", return_value=user),
         override_user(user),
     ):
         cm = (
@@ -192,13 +218,14 @@ def test_dashboard_digest(
 
 
 @pytest.mark.parametrize(
-    "chart_overrides,execute_as,has_current_user,use_custom_digest,expected_result",
+    "chart_overrides,execute_as,has_current_user,use_custom_digest,rls_datasource,expected_result",
     [
         (
             None,
             [ExecutorType.SELENIUM],
             False,
             False,
+            None,
             "47d852b5c4df211c115905617bb722c1",
         ),
         (
@@ -206,6 +233,7 @@ def test_dashboard_digest(
             [ExecutorType.CURRENT_USER],
             True,
             False,
+            None,
             "4f8109d3761e766e650af514bb358f10",
         ),
         (
@@ -213,36 +241,82 @@ def test_dashboard_digest(
             [ExecutorType.CURRENT_USER],
             True,
             True,
+            None,
             "2.current_user.1",
         ),
         (
             None,
             [ExecutorType.CURRENT_USER],
+            True,
+            False,
+            {
+                "is_rls_supported": True,
+                "get_sqla_row_level_filters": MagicMock(return_value=["filter1"]),
+            },
+            "aa9a5f74de01407aaf329734e78ee7f9",
+        ),
+        (
+            None,
+            [ExecutorType.CURRENT_USER],
+            True,
+            False,
+            {
+                "is_rls_supported": True,
+                "get_sqla_row_level_filters": MagicMock(
+                    return_value=["filter1", "filter2"]
+                ),
+            },
+            "508aa7084728ef609469ad0c4410014a",
+        ),
+        (
+            None,
+            [ExecutorType.CURRENT_USER],
+            True,
             False,
+            {
+                "is_rls_supported": False,
+                "get_sqla_row_level_filters": MagicMock(return_value=[]),
+            },
+            "4f8109d3761e766e650af514bb358f10",
+        ),
+        (
+            None,
+            [ExecutorType.CURRENT_USER],
             False,
+            False,
+            None,
             ExecutorNotFoundError(),
         ),
     ],
 )
-def test_chart_digest(
+def test_chart_digest_with_rls(
     chart_overrides: dict[str, Any] | None,
     execute_as: list[ExecutorType],
     has_current_user: bool,
     use_custom_digest: bool,
+    rls_datasource: dict[str, Any] | None,
     expected_result: str | Exception,
 ) -> None:
-    from superset import app
+    from superset import app, security_manager
     from superset.models.slice import Slice
     from superset.thumbnails.digest import get_chart_digest
 
+    # Mock datasource with RLS if provided
+    datasource = None
+    if rls_datasource:
+        datasource = prepare_datasource_mock(rls_datasource, SqlaTable)
+
+    # Prepare chart with the datasource in the constructor
     kwargs = {
         **_DEFAULT_CHART_KWARGS,
         **(chart_overrides or {}),
     }
     chart = Slice(**kwargs)
+
     user: User | None = None
     if has_current_user:
         user = User(id=1, username="1")
+
     func = CUSTOM_CHART_FUNC if use_custom_digest else None
 
     with (
@@ -253,6 +327,13 @@ def test_chart_digest(
                 "THUMBNAIL_CHART_DIGEST_FUNC": func,
             },
         ),
+        patch.object(
+            type(chart),
+            "datasource",
+            new_callable=PropertyMock,
+            return_value=datasource,
+        ),
+        patch.object(security_manager, "find_user", return_value=user),
         override_user(user),
     ):
         cm = (

Reply via email to