This is an automated email from the ASF dual-hosted git repository. diegopucci pushed a commit to branch geido/fix/rls-at-column-values in repository https://gitbox.apache.org/repos/asf/superset.git
commit cc020ae9af132f16577c90c9c36095bd0cc2a953 Author: Diego Pucci <[email protected]> AuthorDate: Tue Oct 1 16:31:14 2024 +0300 fix(Explore): Apply RLS at column values --- superset/models/helpers.py | 6 ++- tests/integration_tests/datasource/api_tests.py | 29 ++++++++++++++ tests/integration_tests/sqla_models_tests.py | 26 ++++++++++++ tests/unit_tests/models/helpers_test.py | 53 +++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 1 deletion(-) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 80e66f5027..65565ddcc6 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1314,7 +1314,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) return and_(*l) - def values_for_column( + def values_for_column( # pylint: disable=too-many-locals self, column_name: str, limit: int = 10000, @@ -1350,6 +1350,10 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if self.fetch_values_predicate: qry = qry.where(self.get_fetch_values_predicate(template_processor=tp)) + rls_filters = self.get_sqla_row_level_filters(template_processor=tp) + if rls_filters: + qry = qry.where(and_(*rls_filters)) + with self.database.get_sqla_engine() as engine: sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True})) sql = self._apply_cte(sql, cte) diff --git a/tests/integration_tests/datasource/api_tests.py b/tests/integration_tests/datasource/api_tests.py index d9f3650793..c0db670964 100644 --- a/tests/integration_tests/datasource/api_tests.py +++ b/tests/integration_tests/datasource/api_tests.py @@ -18,6 +18,7 @@ from unittest.mock import ANY, patch import pytest +from sqlalchemy.sql.elements import TextClause from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable @@ -176,3 +177,31 @@ class TestDatasourceApi(SupersetTestCase): table.normalize_columns = False self.client.get(f"api/v1/datasource/table/{table.id}/column/col2/values/") # noqa: F841 denormalize_name_mock.assert_called_with(ANY, "col2") + + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_column_values_with_rls(self): + self.login(ADMIN_USERNAME) + table = self.get_virtual_dataset() + with patch.object( + table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'b'")] + ): + rv = self.client.get( + f"api/v1/datasource/table/{table.id}/column/col1/values/" + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["result"], [1]) + + @pytest.mark.usefixtures("app_context", "virtual_dataset") + def test_get_column_values_with_rls_no_values(self): + self.login(ADMIN_USERNAME) + table = self.get_virtual_dataset() + with patch.object( + table, "get_sqla_row_level_filters", return_value=[TextClause("col2 = 'q'")] + ): + rv = self.client.get( + f"api/v1/datasource/table/{table.id}/column/col1/values/" + ) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(response["result"], []) diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 4398d75c12..60602f1474 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -626,6 +626,32 @@ def test_values_for_column_on_text_column(text_column_table): assert len(with_null) == 8 +def test_values_for_column_on_text_column_with_rls(text_column_table): + with patch.object( + text_column_table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("foo = 'foo'"), + ], + ): + with_rls = text_column_table.values_for_column(column_name="foo", limit=10000) + assert with_rls == ["foo"] + assert len(with_rls) == 1 + + +def test_values_for_column_on_text_column_with_rls_no_values(text_column_table): + with patch.object( + text_column_table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("foo = 'bar'"), + ], + ): + with_rls = text_column_table.values_for_column(column_name="foo", limit=10000) + assert with_rls == [] + assert len(with_rls) == 0 + + def test_filter_on_text_column(text_column_table): table = text_column_table # null value should be replaced diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py index 009cff0adf..c87b217928 100644 --- a/tests/unit_tests/models/helpers_test.py +++ b/tests/unit_tests/models/helpers_test.py @@ -21,6 +21,7 @@ from __future__ import annotations from contextlib import contextmanager from typing import TYPE_CHECKING +from unittest.mock import patch import pytest from pytest_mock import MockerFixture @@ -85,6 +86,58 @@ def test_values_for_column(database: Database) -> None: assert table.values_for_column("a") == [1, None] +def test_values_for_column_with_rls(database: Database) -> None: + """ + Test the `values_for_column` method with RLS enabled. + """ + from sqlalchemy.sql.elements import TextClause + + from superset.connectors.sqla.models import SqlaTable, TableColumn + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + columns=[ + TableColumn(column_name="a"), + ], + ) + with patch.object( + table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("a = 1"), + ], + ): + assert table.values_for_column("a") == [1] + + +def test_values_for_column_with_rls_no_values(database: Database) -> None: + """ + Test the `values_for_column` method with RLS enabled and no values. + """ + from sqlalchemy.sql.elements import TextClause + + from superset.connectors.sqla.models import SqlaTable, TableColumn + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + columns=[ + TableColumn(column_name="a"), + ], + ) + with patch.object( + table, + "get_sqla_row_level_filters", + return_value=[ + TextClause("a = 2"), + ], + ): + assert table.values_for_column("a") == [] + + def test_values_for_column_calculated( mocker: MockerFixture, database: Database,
