This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 090ae64dfa feat(tag): fast follow for Tags flatten api + update client with generator + some bug fixes (#25309)
090ae64dfa is described below

commit 090ae64dfa3dab8c71a0ffbbdfc69e4ef4a73210
Author: Hugh A. Miles II <[email protected]>
AuthorDate: Mon Sep 18 14:56:08 2023 -0400

    feat(tag): fast follow for Tags flatten api + update client with generator + some bug fixes (#25309)
---
 .../src/features/tags/BulkTagModal.tsx             | 12 +++++---
 superset/daos/tag.py                               |  1 -
 superset/tags/api.py                               |  5 +++-
 superset/tags/commands/create.py                   | 22 ++++++++++++--
 superset/tags/commands/utils.py                    | 18 +++++++++++
 superset/tags/schemas.py                           | 22 +++++---------
 tests/integration_tests/tags/api_tests.py          | 24 +++++++++++----
 tests/unit_tests/dao/tag_test.py                   |  3 --
 tests/unit_tests/tags/commands/create_test.py      | 19 ++++++++++--
 tests/unit_tests/tags/commands/update_test.py      | 35 +++++++++++++++++++---
 10 files changed, 124 insertions(+), 37 deletions(-)

diff --git a/superset-frontend/src/features/tags/BulkTagModal.tsx b/superset-frontend/src/features/tags/BulkTagModal.tsx
index adacef1f47..3fff056f41 100644
--- a/superset-frontend/src/features/tags/BulkTagModal.tsx
+++ b/superset-frontend/src/features/tags/BulkTagModal.tsx
@@ -45,13 +45,19 @@ const BulkTagModal: React.FC<BulkTagModalProps> = ({
   addDangerToast,
 }) => {
   useEffect(() => {}, []);
+  const [tags, setTags] = useState<TaggableResourceOption[]>([]);
 
   const onSave = async () => {
     await SupersetClient.post({
       endpoint: `/api/v1/tag/bulk_create`,
       jsonPayload: {
-        tags: tags.map(tag => tag.value),
-        objects_to_tag: selected.map(item => [resourceName, 
+item.original.id]),
+        tags: tags.map(tag => ({
+          name: tag.value,
+          objects_to_tag: selected.map(item => [
+            resourceName,
+            +item.original.id,
+          ]),
+        })),
       },
     })
       .then(({ json = {} }) => {
@@ -66,8 +72,6 @@ const BulkTagModal: React.FC<BulkTagModalProps> = ({
     setTags([]);
   };
 
-  const [tags, setTags] = useState<TaggableResourceOption[]>([]);
-
   return (
     <Modal
       title={t('Bulk tag')}
diff --git a/superset/daos/tag.py b/superset/daos/tag.py
index b55397a325..c063657ea0 100644
--- a/superset/daos/tag.py
+++ b/superset/daos/tag.py
@@ -412,4 +412,3 @@ class TagDAO(BaseDAO[Tag]):
                 )
 
         db.session.add_all(tagged_objects)
-        db.session.commit()
diff --git a/superset/tags/api.py b/superset/tags/api.py
index 50807ec078..7224d1423a 100644
--- a/superset/tags/api.py
+++ b/superset/tags/api.py
@@ -260,7 +260,10 @@ class TagRestApi(BaseSupersetModelRestApi):
         try:
             for tag in item.get("tags"):
                 tagged_item: dict[str, Any] = self.add_model_schema.load(
-                    {"name": tag, "objects_to_tag": item.get("objects_to_tag")}
+                    {
+                        "name": tag.get("name"),
+                        "objects_to_tag": tag.get("objects_to_tag"),
+                    }
                 )
                 CreateCustomTagWithRelationshipsCommand(
                     tagged_item, bulk_create=True
diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py
index 3f05ccd23e..e8311ad520 100644
--- a/superset/tags/commands/create.py
+++ b/superset/tags/commands/create.py
@@ -17,12 +17,13 @@
 import logging
 from typing import Any
 
-from superset import db
+from superset import db, security_manager
 from superset.commands.base import BaseCommand, CreateMixin
 from superset.daos.exceptions import DAOCreateFailedError
 from superset.daos.tag import TagDAO
+from superset.exceptions import SupersetSecurityException
 from superset.tags.commands.exceptions import TagCreateFailedError, 
TagInvalidError
-from superset.tags.commands.utils import to_object_type
+from superset.tags.commands.utils import to_object_model, to_object_type
 from superset.tags.models import ObjectTypes, TagTypes
 
 logger = logging.getLogger(__name__)
@@ -73,6 +74,7 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
 
     def run(self) -> None:
         self.validate()
+
         try:
             tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom)
             if self._objects_to_tag:
@@ -84,7 +86,8 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
 
             if self._description:
                 tag.description = self._description
-                db.session.commit()
+
+            db.session.commit()
 
         except DAOCreateFailedError as ex:
             logger.exception(ex.exception)
@@ -98,12 +101,25 @@ class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
                 exceptions.append(TagInvalidError())
 
             # Validate object type
+            skipped_tagged_objects: list[tuple[str, int]] = []
             for obj_type, obj_id in self._objects_to_tag:
+                skipped_tagged_objects = []
                 object_type = to_object_type(obj_type)
+
                 if not object_type:
                     exceptions.append(
                         TagInvalidError(f"invalid object type {object_type}")
                     )
+                try:
+                    model = to_object_model(object_type, obj_id)  # type: 
ignore
+                    security_manager.raise_for_ownership(model)
+                except SupersetSecurityException:
+                    # skip the object if the user doesn't have access
+                    skipped_tagged_objects.append((obj_type, obj_id))
+
+            self._objects_to_tag = set(self._objects_to_tag) - set(
+                skipped_tagged_objects
+            )
 
         if exceptions:
             raise TagInvalidError(exceptions=exceptions)
diff --git a/superset/tags/commands/utils.py b/superset/tags/commands/utils.py
index 2993365b7a..028465d83a 100644
--- a/superset/tags/commands/utils.py
+++ b/superset/tags/commands/utils.py
@@ -17,6 +17,12 @@
 
 from typing import Optional, Union
 
+from superset.daos.chart import ChartDAO
+from superset.daos.dashboard import DashboardDAO
+from superset.daos.query import SavedQueryDAO
+from superset.models.dashboard import Dashboard
+from superset.models.slice import Slice
+from superset.models.sql_lab import SavedQuery
 from superset.tags.models import ObjectTypes
 
 
@@ -27,3 +33,15 @@ def to_object_type(object_type: Union[ObjectTypes, int, str]) -> Optional[Object
         if object_type in [type_.value, type_.name]:
             return type_
     return None
+
+
+def to_object_model(
+    object_type: ObjectTypes, object_id: int
+) -> Optional[Union[Dashboard, SavedQuery, Slice]]:
+    if ObjectTypes.dashboard == object_type:
+        return DashboardDAO.find_by_id(object_id)
+    if ObjectTypes.query == object_type:
+        return SavedQueryDAO.find_by_id(object_id)
+    if ObjectTypes.chart == object_type:
+        return ChartDAO.find_by_id(object_id)
+    return None
diff --git a/superset/tags/schemas.py b/superset/tags/schemas.py
index 8aafbb76b5..571a2a03c9 100644
--- a/superset/tags/schemas.py
+++ b/superset/tags/schemas.py
@@ -54,27 +54,21 @@ class TagGetResponseSchema(Schema):
     type = fields.String()
 
 
-class TagPostSchema(Schema):
+class TagObjectSchema(Schema):
     name = fields.String()
     description = fields.String(required=False, allow_none=True)
-    # resource id's to tag with tag
     objects_to_tag = fields.List(
         fields.Tuple((fields.String(), fields.Int())), required=False
     )
 
 
 class TagPostBulkSchema(Schema):
-    tags = fields.List(fields.String())
-    # resource id's to tag with tag
-    objects_to_tag = fields.List(
-        fields.Tuple((fields.String(), fields.Int())), required=False
-    )
+    tags = fields.List(fields.Nested(TagObjectSchema))
 
 
-class TagPutSchema(Schema):
-    name = fields.String()
-    description = fields.String(required=False, allow_none=True)
-    # resource id's to tag with tag
-    objects_to_tag = fields.List(
-        fields.Tuple((fields.String(), fields.Int())), required=False
-    )
+class TagPostSchema(TagObjectSchema):
+    pass
+
+
+class TagPutSchema(TagObjectSchema):
+    pass
diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py
index 06e4a73e19..444d52078e 100644
--- a/tests/integration_tests/tags/api_tests.py
+++ b/tests/integration_tests/tags/api_tests.py
@@ -530,8 +530,23 @@ class TestTagApi(SupersetTestCase):
         rv = self.client.post(
             uri,
             json={
-                "tags": ["tag1", "tag2", "tag3"],
-                "objects_to_tag": [["dashboard", dashboard.id], ["chart", 
chart.id]],
+                "tags": [
+                    {
+                        "name": "tag1",
+                        "objects_to_tag": [
+                            ["dashboard", dashboard.id],
+                            ["chart", chart.id],
+                        ],
+                    },
+                    {
+                        "name": "tag2",
+                        "objects_to_tag": [["dashboard", dashboard.id]],
+                    },
+                    {
+                        "name": "tag3",
+                        "objects_to_tag": [["chart", chart.id]],
+                    },
+                ]
             },
         )
 
@@ -547,11 +562,10 @@ class TestTagApi(SupersetTestCase):
             TaggedObject.object_id == dashboard.id,
             TaggedObject.object_type == ObjectTypes.dashboard,
         )
-        assert tagged_objects.count() == 3
+        assert tagged_objects.count() == 2
 
         tagged_objects = db.session.query(TaggedObject).filter(
-            # TaggedObject.tag_id.in_([tag.id for tag in tags]),
             TaggedObject.object_id == chart.id,
             TaggedObject.object_type == ObjectTypes.chart,
         )
-        assert tagged_objects.count() == 3
+        assert tagged_objects.count() == 2
diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py
index 476c51e45d..065ed75662 100644
--- a/tests/unit_tests/dao/tag_test.py
+++ b/tests/unit_tests/dao/tag_test.py
@@ -169,6 +169,3 @@ def test_create_tag_relationship(mocker):
     # Verify that the correct number of TaggedObjects are added to the session
     assert mock_session.add_all.call_count == 1
     assert len(mock_session.add_all.call_args[0][0]) == len(objects_to_tag)
-
-    # Verify that commit is called
-    mock_session.commit.assert_called_once()
diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py
index a188625b40..639372a70f 100644
--- a/tests/unit_tests/tags/commands/create_test.py
+++ b/tests/unit_tests/tags/commands/create_test.py
@@ -1,4 +1,5 @@
 import pytest
+from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
 from superset.utils.core import DatasourceType
@@ -47,7 +48,7 @@ def session_with_data(session: Session):
     yield session
 
 
-def test_create_command_success(session_with_data: Session):
+def test_create_command_success(session_with_data: Session, mocker: 
MockFixture):
     from superset.connectors.sqla.models import SqlaTable
     from superset.daos.tag import TagDAO
     from superset.models.dashboard import Dashboard
@@ -61,6 +62,12 @@ def test_create_command_success(session_with_data: Session):
     chart = session_with_data.query(Slice).first()
     dashboard = session_with_data.query(Dashboard).first()
 
+    mocker.patch(
+        "superset.security.SupersetSecurityManager.is_admin", return_value=True
+    )
+    mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
+    mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", 
return_value=query)
+
     objects_to_tag = [
         (ObjectTypes.query, query.id),
         (ObjectTypes.chart, chart.id),
@@ -84,7 +91,9 @@ def test_create_command_success(session_with_data: Session):
         )
 
 
-def test_create_command_failed_validate(session_with_data: Session):
+def test_create_command_failed_validate(
+    session_with_data: Session, mocker: MockFixture
+):
     from superset.connectors.sqla.models import SqlaTable
     from superset.daos.tag import TagDAO
     from superset.models.dashboard import Dashboard
@@ -98,6 +107,12 @@ def test_create_command_failed_validate(session_with_data: Session):
     chart = session_with_data.query(Slice).first()
     dashboard = session_with_data.query(Dashboard).first()
 
+    mocker.patch(
+        "superset.security.SupersetSecurityManager.is_admin", return_value=True
+    )
+    mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=query)
+    mocker.patch("superset.daos.query.SavedQueryDAO.find_by_id", 
return_value=chart)
+
     objects_to_tag = [
         (ObjectTypes.query, query.id),
         (ObjectTypes.chart, chart.id),
diff --git a/tests/unit_tests/tags/commands/update_test.py b/tests/unit_tests/tags/commands/update_test.py
index 2c2454547e..84007fbb68 100644
--- a/tests/unit_tests/tags/commands/update_test.py
+++ b/tests/unit_tests/tags/commands/update_test.py
@@ -1,4 +1,5 @@
 import pytest
+from pytest_mock import MockFixture
 from sqlalchemy.orm.session import Session
 
 from superset.utils.core import DatasourceType
@@ -56,13 +57,19 @@ def session_with_data(session: Session):
     yield session
 
 
-def test_update_command_success(session_with_data: Session):
+def test_update_command_success(session_with_data: Session, mocker: 
MockFixture):
     from superset.daos.tag import TagDAO
     from superset.models.dashboard import Dashboard
     from superset.tags.commands.update import UpdateTagCommand
     from superset.tags.models import ObjectTypes, TaggedObject
 
     dashboard = session_with_data.query(Dashboard).first()
+    mocker.patch(
+        "superset.security.SupersetSecurityManager.is_admin", return_value=True
+    )
+    mocker.patch(
+        "superset.daos.dashboard.DashboardDAO.find_by_id", 
return_value=dashboard
+    )
 
     objects_to_tag = [
         (ObjectTypes.dashboard, dashboard.id),
@@ -84,7 +91,9 @@ def test_update_command_success(session_with_data: Session):
     assert len(session_with_data.query(TaggedObject).all()) == 
len(objects_to_tag)
 
 
-def test_update_command_success_duplicates(session_with_data: Session):
+def test_update_command_success_duplicates(
+    session_with_data: Session, mocker: MockFixture
+):
     from superset.daos.tag import TagDAO
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
@@ -95,6 +104,14 @@ def test_update_command_success_duplicates(session_with_data: Session):
     dashboard = session_with_data.query(Dashboard).first()
     chart = session_with_data.query(Slice).first()
 
+    mocker.patch(
+        "superset.security.SupersetSecurityManager.is_admin", return_value=True
+    )
+    mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
+    mocker.patch(
+        "superset.daos.dashboard.DashboardDAO.find_by_id", 
return_value=dashboard
+    )
+
     objects_to_tag = [
         (ObjectTypes.dashboard, dashboard.id),
     ]
@@ -124,14 +141,16 @@ def test_update_command_success_duplicates(session_with_data: Session):
     assert changed_model.objects[0].object_id == chart.id
 
 
-def test_update_command_failed_validation(session_with_data: Session):
+def test_update_command_failed_validation(
+    session_with_data: Session, mocker: MockFixture
+):
     from superset.daos.tag import TagDAO
     from superset.models.dashboard import Dashboard
     from superset.models.slice import Slice
     from superset.tags.commands.create import 
CreateCustomTagWithRelationshipsCommand
     from superset.tags.commands.exceptions import TagInvalidError
     from superset.tags.commands.update import UpdateTagCommand
-    from superset.tags.models import ObjectTypes, TaggedObject
+    from superset.tags.models import ObjectTypes
 
     dashboard = session_with_data.query(Dashboard).first()
     chart = session_with_data.query(Slice).first()
@@ -139,6 +158,14 @@ def test_update_command_failed_validation(session_with_data: Session):
         (ObjectTypes.chart, chart.id),
     ]
 
+    mocker.patch(
+        "superset.security.SupersetSecurityManager.is_admin", return_value=True
+    )
+    mocker.patch("superset.daos.chart.ChartDAO.find_by_id", return_value=chart)
+    mocker.patch(
+        "superset.daos.dashboard.DashboardDAO.find_by_id", 
return_value=dashboard
+    )
+
     CreateCustomTagWithRelationshipsCommand(
         data={"name": "test_tag", "objects_to_tag": objects_to_tag}
     ).run()

Reply via email to