This is an automated email from the ASF dual-hosted git repository.

yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ee070c  feat: datasource access to allow more granular access to tables on SQL Lab (#18064)
5ee070c is described below

commit 5ee070c40228d6abbb30e4a8f7888886cf35d7f1
Author: Victor Arbues <victor.arb...@hmrc.gov.uk>
AuthorDate: Wed Feb 9 14:05:25 2022 +0000

    feat: datasource access to allow more granular access to tables on SQL Lab (#18064)
---
 superset/databases/filters.py                 | 19 ++++++---
 superset/security/manager.py                  | 20 +++++----
 tests/integration_tests/core_tests.py         | 59 +++++++++++++++++++++++++++
 tests/integration_tests/datasets/api_tests.py |  6 ++-
 4 files changed, 89 insertions(+), 15 deletions(-)

diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 6fa9339..bee7d2c 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -25,21 +25,28 @@ from superset.views.base import BaseFilter
 
 class DatabaseFilter(BaseFilter):
     # TODO(bogdan): consider caching.
-    def schema_access_databases(self) -> Set[str]:  # noqa pylint: disable=no-self-use
+
+    def can_access_databases(  # noqa pylint: disable=no-self-use
+        self, view_menu_name: str,
+    ) -> Set[str]:
         return {
-            security_manager.unpack_schema_perm(vm)[0]
-            for vm in security_manager.user_view_menu_names("schema_access")
+            security_manager.unpack_database_and_schema(vm).database
+            for vm in security_manager.user_view_menu_names(view_menu_name)
         }
 
     def apply(self, query: Query, value: Any) -> Query:
         if security_manager.can_access_all_databases():
             return query
         database_perms = security_manager.user_view_menu_names("database_access")
-        # TODO(bogdan): consider adding datasource access here as well.
-        schema_access_databases = self.schema_access_databases()
+        schema_access_databases = self.can_access_databases("schema_access")
+
+        datasource_access_databases = self.can_access_databases("datasource_access")
+
         return query.filter(
             or_(
                 self.model.perm.in_(database_perms),
-                self.model.database_name.in_(schema_access_databases),
+                self.model.database_name.in_(
+                    [*schema_access_databases, *datasource_access_databases]
+                ),
             )
         )
diff --git a/superset/security/manager.py b/superset/security/manager.py
index d9206df..0bed447 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -26,9 +26,9 @@ from typing import (
     cast,
     Dict,
     List,
+    NamedTuple,
     Optional,
     Set,
-    Tuple,
     TYPE_CHECKING,
     Union,
 )
@@ -88,6 +88,11 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class DatabaseAndSchema(NamedTuple):
+    database: str
+    schema: str
+
+
 class SupersetSecurityListWidget(ListWidget):  # pylint: disable=too-few-public-methods
     """
     Redeclaring to avoid circular imports
@@ -263,13 +268,14 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
 
         return None
 
-    def unpack_schema_perm(  # pylint: disable=no-self-use
+    def unpack_database_and_schema(  # pylint: disable=no-self-use
         self, schema_permission: str
-    ) -> Tuple[str, str]:
-        # [database_name].[schema_name]
+    ) -> DatabaseAndSchema:
+        # [database_name].[schema|table]
+
         schema_name = schema_permission.split(".")[1][1:-1]
         database_name = schema_permission.split(".")[0][1:-1]
-        return database_name, schema_name
+        return DatabaseAndSchema(database_name, schema_name)
 
     def can_access(self, permission_name: str, view_name: str) -> bool:
         """
@@ -558,7 +564,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
 
         # schema_access
         accessible_schemas = {
-            self.unpack_schema_perm(s)[1]
+            self.unpack_database_and_schema(s).schema
             for s in self.user_view_menu_names("schema_access")
             if s.startswith(f"[{database}].")
         }
@@ -608,7 +614,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         )
         if schema:
             names = {d.table_name for d in user_datasources if d.schema == schema}
-            return [d for d in datasource_names if d in names]
+            return [d for d in datasource_names if d.table in names]
 
         full_names = {d.full_name for d in user_datasources}
         return [d for d in datasource_names if f"[{database}].[{d}]" in full_names]
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index 1c4682a..43288e0 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -163,6 +163,65 @@ class TestCore(SupersetTestCase):
         rv = self.client.get(uri)
         self.assertEqual(rv.status_code, 404)
 
+    @pytest.mark.usefixtures("load_energy_table_with_slice")
+    def test_get_superset_tables_allowed(self):
+        session = db.session
+        table_name = "energy_usage"
+        role_name = "dummy_role"
+        self.logout()
+        self.login(username="gamma")
+        gamma_user = security_manager.find_user(username="gamma")
+        security_manager.add_role(role_name)
+        dummy_role = security_manager.find_role(role_name)
+        gamma_user.roles.append(dummy_role)
+
+        tbl_id = self.table_ids.get(table_name)
+        table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id).first()
+        table_perm = table.perm
+
+        security_manager.add_permission_role(
+            dummy_role,
+            security_manager.find_permission_view_menu("datasource_access", table_perm),
+        )
+
+        session.commit()
+
+        example_db = utils.get_example_database()
+        schema_name = self.default_schema_backend_map[example_db.backend]
+        uri = f"superset/tables/{example_db.id}/{schema_name}/{table_name}/"
+        rv = self.client.get(uri)
+        self.assertEqual(rv.status_code, 200)
+
+        # cleanup
+        gamma_user = security_manager.find_user(username="gamma")
+        gamma_user.roles.remove(security_manager.find_role(role_name))
+        session.commit()
+
+    @pytest.mark.usefixtures("load_energy_table_with_slice")
+    def test_get_superset_tables_not_allowed_with_out_permissions(self):
+        session = db.session
+        table_name = "energy_usage"
+        role_name = "dummy_role_no_table_access"
+        self.logout()
+        self.login(username="gamma")
+        gamma_user = security_manager.find_user(username="gamma")
+        security_manager.add_role(role_name)
+        dummy_role = security_manager.find_role(role_name)
+        gamma_user.roles.append(dummy_role)
+
+        session.commit()
+
+        example_db = utils.get_example_database()
+        schema_name = self.default_schema_backend_map[example_db.backend]
+        uri = f"superset/tables/{example_db.id}/{schema_name}/{table_name}/"
+        rv = self.client.get(uri)
+        self.assertEqual(rv.status_code, 404)
+
+        # cleanup
+        gamma_user = security_manager.find_user(username="gamma")
+        gamma_user.roles.remove(security_manager.find_role(role_name))
+        session.commit()
+
     def test_get_superset_tables_substr(self):
         example_db = superset.utils.database.get_example_database()
         if example_db.backend in {"presto", "hive"}:
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 0e7606c..aaf7633 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -219,8 +219,10 @@ class TestDatasetApi(SupersetTestCase):
         rv = self.client.get(uri)
         assert rv.status_code == 200
         response = json.loads(rv.data.decode("utf-8"))
-        assert response["count"] == 0
-        assert response["result"] == []
+
+        assert response["count"] == 1
+        main_db = get_main_database()
+        assert filter(lambda x: x.text == main_db, response["result"]) != []
 
     @pytest.mark.usefixtures("load_energy_table_with_slice")
     def test_get_dataset_item(self):

Reply via email to