This is an automated email from the ASF dual-hosted git repository.

diegopucci pushed a commit to branch geido/fix/rls-at-column-values
in repository https://gitbox.apache.org/repos/asf/superset.git

commit cc020ae9af132f16577c90c9c36095bd0cc2a953
Author: Diego Pucci <[email protected]>
AuthorDate: Tue Oct 1 16:31:14 2024 +0300

    fix(Explore): Apply RLS at column values
---
 superset/models/helpers.py                      |  6 ++-
 tests/integration_tests/datasource/api_tests.py | 29 ++++++++++++++
 tests/integration_tests/sqla_models_tests.py    | 26 ++++++++++++
 tests/unit_tests/models/helpers_test.py         | 53 +++++++++++++++++++++++++
 4 files changed, 113 insertions(+), 1 deletion(-)

diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 80e66f5027..65565ddcc6 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1314,7 +1314,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
             )
         return and_(*l)
 
-    def values_for_column(
+    def values_for_column(  # pylint: disable=too-many-locals
         self,
         column_name: str,
         limit: int = 10000,
@@ -1350,6 +1350,10 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         if self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))
 
+        rls_filters = self.get_sqla_row_level_filters(template_processor=tp)
+        if rls_filters:
+            qry = qry.where(and_(*rls_filters))
+
         with self.database.get_sqla_engine() as engine:
             sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
             sql = self._apply_cte(sql, cte)
diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py
index d9f3650793..c0db670964 100644
--- a/tests/integration_tests/datasource/api_tests.py
+++ b/tests/integration_tests/datasource/api_tests.py
@@ -18,6 +18,7 @@
 from unittest.mock import ANY, patch
 
 import pytest
+from sqlalchemy.sql.elements import TextClause
 
 from superset import db, security_manager
 from superset.connectors.sqla.models import SqlaTable
@@ -176,3 +177,31 @@ class TestDatasourceApi(SupersetTestCase):
         table.normalize_columns = False
         self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/")  # noqa: F841
         denormalize_name_mock.assert_called_with(ANY, "col2")
+
+    @pytest.mark.usefixtures("app_context", "virtual_dataset")
+    def test_get_column_values_with_rls(self):
+        self.login(ADMIN_USERNAME)
+        table = self.get_virtual_dataset()
+        with patch.object(
+            table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'b'")]
+        ):
+            rv = self.client.get(
+                f"api/v1/datasource/table/{table.id}/column/col1/values/"
+            )
+            self.assertEqual(rv.status_code, 200)
+            response = json.loads(rv.data.decode("utf-8"))
+            self.assertEqual(response["result"], [1])
+
+    @pytest.mark.usefixtures("app_context", "virtual_dataset")
+    def test_get_column_values_with_rls_no_values(self):
+        self.login(ADMIN_USERNAME)
+        table = self.get_virtual_dataset()
+        with patch.object(
+            table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'q'")]
+        ):
+            rv = self.client.get(
+                f"api/v1/datasource/table/{table.id}/column/col1/values/"
+            )
+            self.assertEqual(rv.status_code, 200)
+            response = json.loads(rv.data.decode("utf-8"))
+            self.assertEqual(response["result"], [])
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 4398d75c12..60602f1474 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -626,6 +626,32 @@ def test_values_for_column_on_text_column(text_column_table):
     assert len(with_null) == 8
 
 
+def test_values_for_column_on_text_column_with_rls(text_column_table):
+    with patch.object(
+        text_column_table,
+        "get_sqla_row_level_filters",
+        return_value=[
+            TextClause("foo = 'foo'"),
+        ],
+    ):
+        with_rls = text_column_table.values_for_column(column_name="foo", limit=10000)
+        assert with_rls == ["foo"]
+        assert len(with_rls) == 1
+
+
+def test_values_for_column_on_text_column_with_rls_no_values(text_column_table):
+    with patch.object(
+        text_column_table,
+        "get_sqla_row_level_filters",
+        return_value=[
+            TextClause("foo = 'bar'"),
+        ],
+    ):
+        with_rls = text_column_table.values_for_column(column_name="foo", limit=10000)
+        assert with_rls == []
+        assert len(with_rls) == 0
+
+
 def test_filter_on_text_column(text_column_table):
     table = text_column_table
     # null value should be replaced
diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py
index 009cff0adf..c87b217928 100644
--- a/tests/unit_tests/models/helpers_test.py
+++ b/tests/unit_tests/models/helpers_test.py
@@ -21,6 +21,7 @@ from __future__ import annotations
 
 from contextlib import contextmanager
 from typing import TYPE_CHECKING
+from unittest.mock import patch
 
 import pytest
 from pytest_mock import MockerFixture
@@ -85,6 +86,58 @@ def test_values_for_column(database: Database) -> None:
     assert table.values_for_column("a") == [1, None]
 
 
+def test_values_for_column_with_rls(database: Database) -> None:
+    """
+    Test the `values_for_column` method with RLS enabled.
+    """
+    from sqlalchemy.sql.elements import TextClause
+
+    from superset.connectors.sqla.models import SqlaTable, TableColumn
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+        columns=[
+            TableColumn(column_name="a"),
+        ],
+    )
+    with patch.object(
+        table,
+        "get_sqla_row_level_filters",
+        return_value=[
+            TextClause("a = 1"),
+        ],
+    ):
+        assert table.values_for_column("a") == [1]
+
+
+def test_values_for_column_with_rls_no_values(database: Database) -> None:
+    """
+    Test the `values_for_column` method with RLS enabled and no values.
+    """
+    from sqlalchemy.sql.elements import TextClause
+
+    from superset.connectors.sqla.models import SqlaTable, TableColumn
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+        columns=[
+            TableColumn(column_name="a"),
+        ],
+    )
+    with patch.object(
+        table,
+        "get_sqla_row_level_filters",
+        return_value=[
+            TextClause("a = 2"),
+        ],
+    ):
+        assert table.values_for_column("a") == []
+
+
 def test_values_for_column_calculated(
     mocker: MockerFixture,
     database: Database,

Reply via email to