This is an automated email from the ASF dual-hosted git repository.
vavila pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 33b934cbb3 fix(Tags filter): Filter assets by tag ID (#29412)
33b934cbb3 is described below
commit 33b934cbb346b464260dc1e2f4218713595a07e1
Author: Vitor Avila <[email protected]>
AuthorDate: Thu Jul 11 12:26:36 2024 -0300
fix(Tags filter): Filter assets by tag ID (#29412)
---
superset-frontend/src/components/ListView/types.ts | 9 +-
superset-frontend/src/pages/ChartList/index.tsx | 2 +-
.../src/pages/DashboardList/index.tsx | 2 +-
.../src/pages/SavedQueryList/index.tsx | 2 +-
superset/charts/api.py | 5 +-
superset/charts/filters.py | 19 ++-
superset/dashboards/api.py | 5 +-
superset/dashboards/filters.py | 19 ++-
superset/queries/saved_queries/api.py | 15 +-
superset/queries/saved_queries/filters.py | 19 ++-
superset/tags/filters.py | 54 +++++++
superset/views/base_api.py | 25 ----
tests/integration_tests/base_tests.py | 16 +++
tests/integration_tests/charts/api_tests.py | 153 ++++++++++++++++----
tests/integration_tests/dashboards/api_tests.py | 159 +++++++++++++++++----
tests/integration_tests/fixtures/tags.py | 35 +++++
.../queries/saved_queries/api_tests.py | 121 ++++++++++++++++
tests/unit_tests/tags/filters_test.py | 85 +++++++++++
18 files changed, 636 insertions(+), 109 deletions(-)
diff --git a/superset-frontend/src/components/ListView/types.ts
b/superset-frontend/src/components/ListView/types.ts
index ca3a8b3c70..d7c7cd5117 100644
--- a/superset-frontend/src/components/ListView/types.ts
+++ b/superset-frontend/src/components/ListView/types.ts
@@ -117,7 +117,10 @@ export enum FilterOperator {
DatasetIsCertified = 'dataset_is_certified',
DashboardHasCreatedBy = 'dashboard_has_created_by',
ChartHasCreatedBy = 'chart_has_created_by',
- DashboardTags = 'dashboard_tags',
- ChartTags = 'chart_tags',
- SavedQueryTags = 'saved_query_tags',
+ DashboardTagByName = 'dashboard_tags',
+ DashboardTagById = 'dashboard_tag_id',
+ ChartTagByName = 'chart_tags',
+ ChartTagById = 'chart_tag_id',
+ SavedQueryTagByName = 'saved_query_tags',
+ SavedQueryTagById = 'saved_query_tag_id',
}
diff --git a/superset-frontend/src/pages/ChartList/index.tsx
b/superset-frontend/src/pages/ChartList/index.tsx
index 65ec54a40b..6650583534 100644
--- a/superset-frontend/src/pages/ChartList/index.tsx
+++ b/superset-frontend/src/pages/ChartList/index.tsx
@@ -614,7 +614,7 @@ function ChartList(props: ChartListProps) {
key: 'tags',
id: 'tags',
input: 'select',
- operator: FilterOperator.ChartTags,
+ operator: FilterOperator.ChartTagById,
unfilteredLabel: t('All'),
fetchSelects: loadTags,
},
diff --git a/superset-frontend/src/pages/DashboardList/index.tsx
b/superset-frontend/src/pages/DashboardList/index.tsx
index aa577749d4..8ffc51ce2a 100644
--- a/superset-frontend/src/pages/DashboardList/index.tsx
+++ b/superset-frontend/src/pages/DashboardList/index.tsx
@@ -547,7 +547,7 @@ function DashboardList(props: DashboardListProps) {
key: 'tags',
id: 'tags',
input: 'select',
- operator: FilterOperator.DashboardTags,
+ operator: FilterOperator.DashboardTagById,
unfilteredLabel: t('All'),
fetchSelects: loadTags,
},
diff --git a/superset-frontend/src/pages/SavedQueryList/index.tsx
b/superset-frontend/src/pages/SavedQueryList/index.tsx
index dd4506185c..72836d6059 100644
--- a/superset-frontend/src/pages/SavedQueryList/index.tsx
+++ b/superset-frontend/src/pages/SavedQueryList/index.tsx
@@ -501,7 +501,7 @@ function SavedQueryList({
id: 'tags',
key: 'tags',
input: 'select',
- operator: FilterOperator.SavedQueryTags,
+ operator: FilterOperator.SavedQueryTagById,
fetchSelects: loadTags,
},
]
diff --git a/superset/charts/api.py b/superset/charts/api.py
index d32f1f665a..d814d0fa02 100644
--- a/superset/charts/api.py
+++ b/superset/charts/api.py
@@ -39,7 +39,8 @@ from superset.charts.filters import (
ChartFilter,
ChartHasCreatedByFilter,
ChartOwnedCreatedFavoredByMeFilter,
- ChartTagFilter,
+ ChartTagIdFilter,
+ ChartTagNameFilter,
)
from superset.charts.schemas import (
CHART_SCHEMAS,
@@ -238,7 +239,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
],
"slice_name": [ChartAllTextFilter],
"created_by": [ChartHasCreatedByFilter, ChartCreatedByMeFilter],
- "tags": [ChartTagFilter],
+ "tags": [ChartTagNameFilter, ChartTagIdFilter],
}
# Will just affect _info endpoint
edit_columns = ["slice_name"]
diff --git a/superset/charts/filters.py b/superset/charts/filters.py
index a7543ba284..f9748dd0ec 100644
--- a/superset/charts/filters.py
+++ b/superset/charts/filters.py
@@ -26,10 +26,11 @@ from superset.connectors.sqla import models
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import FavStar
from superset.models.slice import Slice
+from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter
from superset.utils.core import get_user_id
from superset.utils.filters import get_dataset_access_filters
from superset.views.base import BaseFilter
-from superset.views.base_api import BaseFavoriteFilter, BaseTagFilter
+from superset.views.base_api import BaseFavoriteFilter
class ChartAllTextFilter(BaseFilter): # pylint: disable=too-few-public-methods
@@ -60,9 +61,10 @@ class ChartFavoriteFilter(BaseFavoriteFilter): # pylint:
disable=too-few-public
model = Slice
-class ChartTagFilter(BaseTagFilter): # pylint: disable=too-few-public-methods
+class ChartTagNameFilter(BaseTagNameFilter): # pylint:
disable=too-few-public-methods
"""
- Custom filter for the GET list that filters all dashboards that a user has
favored
+ Custom filter for the GET list that filters all charts associated with
+ a certain tag (by its name).
"""
arg_name = "chart_tags"
@@ -70,6 +72,17 @@ class ChartTagFilter(BaseTagFilter): # pylint:
disable=too-few-public-methods
model = Slice
+class ChartTagIdFilter(BaseTagIdFilter): # pylint:
disable=too-few-public-methods
+ """
+ Custom filter for the GET list that filters all charts associated with
+ a certain tag (by its ID).
+ """
+
+ arg_name = "chart_tag_id"
+ class_name = "slice"
+ model = Slice
+
+
class ChartCertifiedFilter(BaseFilter): # pylint:
disable=too-few-public-methods
"""
Custom filter for the GET list that filters all certified charts
diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py
index 2967fd1abd..716e4c416d 100644
--- a/superset/dashboards/api.py
+++ b/superset/dashboards/api.py
@@ -60,7 +60,8 @@ from superset.dashboards.filters import (
DashboardCreatedByMeFilter,
DashboardFavoriteFilter,
DashboardHasCreatedByFilter,
- DashboardTagFilter,
+ DashboardTagIdFilter,
+ DashboardTagNameFilter,
DashboardTitleOrSlugFilter,
FilterRelatedRoles,
)
@@ -244,7 +245,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
"dashboard_title": [DashboardTitleOrSlugFilter],
"id": [DashboardFavoriteFilter, DashboardCertifiedFilter],
"created_by": [DashboardCreatedByMeFilter,
DashboardHasCreatedByFilter],
- "tags": [DashboardTagFilter],
+ "tags": [DashboardTagIdFilter, DashboardTagNameFilter],
}
base_order = ("changed_on", "desc")
diff --git a/superset/dashboards/filters.py b/superset/dashboards/filters.py
index 0c7878d508..9a4c496b20 100644
--- a/superset/dashboards/filters.py
+++ b/superset/dashboards/filters.py
@@ -29,10 +29,11 @@ from superset.models.dashboard import Dashboard, is_uuid
from superset.models.embedded_dashboard import EmbeddedDashboard
from superset.models.slice import Slice
from superset.security.guest_token import GuestTokenResourceType, GuestUser
+from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter
from superset.utils.core import get_user_id
from superset.utils.filters import get_dataset_access_filters
from superset.views.base import BaseFilter
-from superset.views.base_api import BaseFavoriteFilter, BaseTagFilter
+from superset.views.base_api import BaseFavoriteFilter
class DashboardTitleOrSlugFilter(BaseFilter): # pylint:
disable=too-few-public-methods
@@ -78,9 +79,10 @@ class DashboardFavoriteFilter( # pylint:
disable=too-few-public-methods
model = Dashboard
-class DashboardTagFilter(BaseTagFilter): # pylint:
disable=too-few-public-methods
+class DashboardTagNameFilter(BaseTagNameFilter): # pylint:
disable=too-few-public-methods
"""
- Custom filter for the GET list that filters all dashboards that a user has
favored
+ Custom filter for the GET list that filters all dashboards associated with
+ a certain tag (by its name).
"""
arg_name = "dashboard_tags"
@@ -88,6 +90,17 @@ class DashboardTagFilter(BaseTagFilter): # pylint:
disable=too-few-public-metho
model = Dashboard
+class DashboardTagIdFilter(BaseTagIdFilter): # pylint:
disable=too-few-public-methods
+ """
+ Custom filter for the GET list that filters all dashboards associated with
+ a certain tag (by its ID).
+ """
+
+ arg_name = "dashboard_tag_id"
+ class_name = "Dashboard"
+ model = Dashboard
+
+
class DashboardAccessFilter(BaseFilter): # pylint:
disable=too-few-public-methods
"""
List dashboards with the following criteria:
diff --git a/superset/queries/saved_queries/api.py
b/superset/queries/saved_queries/api.py
index cd7b04193f..4e34a75039 100644
--- a/superset/queries/saved_queries/api.py
+++ b/superset/queries/saved_queries/api.py
@@ -25,7 +25,6 @@ from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import ngettext
-from superset import is_feature_enabled
from superset.commands.importers.exceptions import (
IncorrectFormatError,
NoValidFilesFoundError,
@@ -46,7 +45,8 @@ from superset.queries.saved_queries.filters import (
SavedQueryAllTextFilter,
SavedQueryFavoriteFilter,
SavedQueryFilter,
- SavedQueryTagFilter,
+ SavedQueryTagIdFilter,
+ SavedQueryTagNameFilter,
)
from superset.queries.saved_queries.schemas import (
get_delete_ids_schema,
@@ -124,9 +124,10 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
"schema",
"sql",
"sql_tables",
+ "tags.id",
+ "tags.name",
+ "tags.type",
]
- if is_feature_enabled("TAGGING_SYSTEM"):
- list_columns += ["tags.id", "tags.name", "tags.type"]
list_select_columns = list_columns + ["changed_by_fk", "changed_on"]
add_columns = [
"db_id",
@@ -161,15 +162,13 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
"schema",
"created_by",
"changed_by",
+ "tags",
]
- if is_feature_enabled("TAGGING_SYSTEM"):
- search_columns += ["tags"]
search_filters = {
"id": [SavedQueryFavoriteFilter],
"label": [SavedQueryAllTextFilter],
+ "tags": [SavedQueryTagNameFilter, SavedQueryTagIdFilter],
}
- if is_feature_enabled("TAGGING_SYSTEM"):
- search_filters["tags"] = [SavedQueryTagFilter]
apispec_parameter_schemas = {
"get_delete_ids_schema": get_delete_ids_schema,
diff --git a/superset/queries/saved_queries/filters.py
b/superset/queries/saved_queries/filters.py
index 90e356163f..821f42d6f1 100644
--- a/superset/queries/saved_queries/filters.py
+++ b/superset/queries/saved_queries/filters.py
@@ -23,8 +23,9 @@ from sqlalchemy import or_
from sqlalchemy.orm.query import Query
from superset.models.sql_lab import SavedQuery
+from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter
from superset.views.base import BaseFilter
-from superset.views.base_api import BaseFavoriteFilter, BaseTagFilter
+from superset.views.base_api import BaseFavoriteFilter
class SavedQueryAllTextFilter(BaseFilter): # pylint:
disable=too-few-public-methods
@@ -56,9 +57,10 @@ class SavedQueryFavoriteFilter(BaseFavoriteFilter): #
pylint: disable=too-few-p
model = SavedQuery
-class SavedQueryTagFilter(BaseTagFilter): # pylint:
disable=too-few-public-methods
+class SavedQueryTagNameFilter(BaseTagNameFilter): # pylint:
disable=too-few-public-methods
"""
- Custom filter for the GET list that filters all dashboards that a user has
favored
+ Custom filter for the GET list that filters all saved queries associated
with
+ a certain tag (by its name).
"""
arg_name = "saved_query_tags"
@@ -66,6 +68,17 @@ class SavedQueryTagFilter(BaseTagFilter): # pylint:
disable=too-few-public-meth
model = SavedQuery
+class SavedQueryTagIdFilter(BaseTagIdFilter): # pylint:
disable=too-few-public-methods
+ """
+ Custom filter for the GET list that filters all saved queries associated
with
+ a certain tag (by its ID).
+ """
+
+ arg_name = "saved_query_tag_id"
+ class_name = "query"
+ model = SavedQuery
+
+
class SavedQueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
"""
diff --git a/superset/tags/filters.py b/superset/tags/filters.py
index ff6be712d3..81df9fd7b9 100644
--- a/superset/tags/filters.py
+++ b/superset/tags/filters.py
@@ -14,9 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+from typing import Any
+
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import Query
+from superset.connectors.sqla.models import SqlaTable
+from superset.extensions import db
+from superset.models.dashboard import Dashboard
+from superset.models.slice import Slice
+from superset.sql_lab import Query as SqllabQuery
from superset.tags.models import Tag, TagType
from superset.views.base import BaseFilter
@@ -37,3 +46,48 @@ class UserCreatedTagTypeFilter(BaseFilter): # pylint:
disable=too-few-public-me
if value is False:
return query.filter(Tag.type != TagType.custom)
return query
+
+
+class BaseTagNameFilter(BaseFilter):  # pylint: disable=too-few-public-methods
+    """
+    Base custom filter for the GET list that filters all dashboards, slices
+    and saved queries associated with a tag (by the tag name).
+    """
+
+    name = _("Is tagged")
+    arg_name = ""
+    class_name = ""
+    """ The Tag class_name to use """
+    model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard
+    """ The SQLAlchemy model """
+
+    def apply(self, query: Query, value: Any) -> Query:
+        ilike_value = f"%{value}%"  # case-insensitive substring match on name
+        tags_query = (
+            db.session.query(self.model.id)
+            .join(self.model.tags)
+            .filter(Tag.name.ilike(ilike_value))
+        )
+        return query.filter(self.model.id.in_(tags_query))
+
+
+class BaseTagIdFilter(BaseFilter):  # pylint: disable=too-few-public-methods
+    """
+    Base custom filter for the GET list that filters all dashboards, slices
+    and saved queries associated with a tag (by the tag ID).
+    """
+
+    name = _("Is tagged")
+    arg_name = ""
+    class_name = ""
+    """ The Tag class_name to use """
+    model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard
+    """ The SQLAlchemy model """
+
+    def apply(self, query: Query, value: Any) -> Query:
+        tags_query = (
+            db.session.query(self.model.id)
+            .join(self.model.tags)
+            .filter(Tag.id == value)  # exact match on Tag.id (no substring)
+        )
+        return query.filter(self.model.id.in_(tags_query))
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 5c71147517..8240481ada 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -31,7 +31,6 @@ from marshmallow import fields, Schema
from sqlalchemy import and_, distinct, func
from sqlalchemy.orm.query import Query
-from superset.connectors.sqla.models import SqlaTable
from superset.exceptions import InvalidPayloadFormatError
from superset.extensions import db, event_logger, security_manager,
stats_logger_manager
from superset.models.core import FavStar
@@ -40,7 +39,6 @@ from superset.models.slice import Slice
from superset.schemas import error_payload_content
from superset.sql_lab import Query as SqllabQuery
from superset.superset_typing import FlaskResponse
-from superset.tags.models import Tag
from superset.utils.core import get_user_id, time_function
from superset.views.error_handling import handle_api_exception
@@ -168,29 +166,6 @@ class BaseFavoriteFilter(BaseFilter): # pylint:
disable=too-few-public-methods
return query.filter(and_(~self.model.id.in_(users_favorite_query)))
-class BaseTagFilter(BaseFilter): # pylint: disable=too-few-public-methods
- """
- Base Custom filter for the GET list that filters all dashboards, slices
- that a user has favored or not
- """
-
- name = _("Is tagged")
- arg_name = ""
- class_name = ""
- """ The Tag class_name to user """
- model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard
- """ The SQLAlchemy model """
-
- def apply(self, query: Query, value: Any) -> Query:
- ilike_value = f"%{value}%"
- tags_query = (
- db.session.query(self.model.id)
- .join(self.model.tags)
- .filter(Tag.name.ilike(ilike_value))
- )
- return query.filter(self.model.id.in_(tags_query))
-
-
class BaseSupersetApiMixin:
csrf_exempt = False
diff --git a/tests/integration_tests/base_tests.py
b/tests/integration_tests/base_tests.py
index 0e407b8657..b3a000c601 100644
--- a/tests/integration_tests/base_tests.py
+++ b/tests/integration_tests/base_tests.py
@@ -24,6 +24,7 @@ from typing import Any, Union, Optional
from unittest.mock import Mock, patch, MagicMock
import pandas as pd
+import prison
from flask import Response, g
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
@@ -33,6 +34,7 @@ from sqlalchemy.orm import Session # noqa: F401
from sqlalchemy.sql import func
from sqlalchemy.dialects.mysql import dialect
+from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.test_app import app, login
from superset.sql_parse import CtasMethod
from superset import db, security_manager
@@ -589,6 +591,20 @@ class SupersetTestCase(TestCase):
db.session.commit()
return dashboard
+    def get_list(
+        self,
+        asset_type: str,
+        filter: Optional[dict[str, Any]] = None,
+        username: str = ADMIN_USERNAME,
+    ) -> Response:
+        """
+        Get list of ``asset_type`` assets (admin by default), optionally filtered.
+        ``filter`` is rison-encoded into ``q``; ``None`` sends an empty filter.
+        """
+        self.login(username)
+        uri = f"api/v1/{asset_type}/?q={prison.dumps(filter or {})}"
+        response = self.get_assert_metric(uri, "get_list")
+        return response
+
@contextmanager
def db_insert_temp_object(obj: DeclarativeMeta):
diff --git a/tests/integration_tests/charts/api_tests.py
b/tests/integration_tests/charts/api_tests.py
index a9af7c12b3..0f5948ad7b 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -31,7 +31,7 @@ from sqlalchemy.sql import func
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.commands.chart.exceptions import ChartDataQueryFailedError
from superset.connectors.sqla.models import SqlaTable
-from superset.extensions import cache_manager, db, security_manager # noqa:
F401
+from superset.extensions import cache_manager, db, security_manager
from superset.models.core import Database, FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@@ -39,11 +39,8 @@ from superset.reports.models import ReportSchedule,
ReportScheduleType
from superset.tags.models import ObjectType, Tag, TaggedObject, TagType
from superset.utils import json
from superset.utils.core import get_example_default_schema
-from superset.utils.database import get_example_database # noqa: F401
-from superset.viz import viz_types # noqa: F401
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.integration_tests.base_tests import SupersetTestCase
-from tests.integration_tests.conftest import with_feature_flags # noqa: F401
from tests.integration_tests.constants import (
ADMIN_USERNAME,
ALPHA_USERNAME,
@@ -64,6 +61,10 @@ from tests.integration_tests.fixtures.importexport import (
dataset_config,
dataset_metadata_config,
)
+from tests.integration_tests.fixtures.tags import (
+ create_custom_tags, # noqa: F401
+ get_filter_params,
+)
from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_slice, # noqa: F401
load_unicode_data, # noqa: F401
@@ -200,27 +201,8 @@ class TestChartApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCase):
db.session.delete(self.chart)
db.session.commit()
- @pytest.fixture()
- def create_custom_tags(self):
- with self.create_app().app_context():
- tags: list[Tag] = []
- for tag_name in {"one_tag", "new_tag"}:
- tag = Tag(
- name=tag_name,
- type="custom",
- )
- db.session.add(tag)
- db.session.commit()
- tags.append(tag)
-
- yield tags
-
- for tags in tags:
- db.session.delete(tags)
- db.session.commit()
-
- @pytest.fixture()
- def create_chart_with_tag(self, create_custom_tags):
+ @pytest.fixture
+ def create_chart_with_tag(self, create_custom_tags): # noqa: F811
with self.create_app().app_context():
alpha_user = self.get_user(ALPHA_USERNAME)
@@ -230,7 +212,7 @@ class TestChartApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCase):
1,
)
- tag = db.session.query(Tag).filter(Tag.name == "one_tag").first()
+ tag = db.session.query(Tag).filter(Tag.name == "first_tag").first()
tag_association = TaggedObject(
object_id=chart.id,
object_type=ObjectType.chart,
@@ -247,6 +229,70 @@ class TestChartApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCase):
db.session.delete(chart)
db.session.commit()
+ @pytest.fixture
+ def create_charts_some_with_tags(self, create_custom_tags): # noqa: F811
+ """
+ Fixture that creates 4 charts:
+ - ``first_chart`` is associated with ``first_tag``
+ - ``second_chart`` is associated with ``second_tag``
+ - ``third_chart`` is associated with both ``first_tag`` and
``second_tag``
+ - ``fourth_chart`` is not associated with any tag
+
+ Relies on the ``create_custom_tags`` fixture for the tag creation.
+ """
+ with self.create_app().app_context():
+ admin_user = self.get_user(ADMIN_USERNAME)
+
+ tags = {
+ "first_tag": db.session.query(Tag)
+ .filter(Tag.name == "first_tag")
+ .first(),
+ "second_tag": db.session.query(Tag)
+ .filter(Tag.name == "second_tag")
+ .first(),
+ }
+
+ chart_names = ["first_chart", "second_chart", "third_chart",
"fourth_chart"]
+ charts = [
+ self.insert_chart(name, [admin_user.id], 1) for name in
chart_names
+ ]
+
+ tag_associations = [
+ TaggedObject(
+ object_id=charts[0].id,
+ object_type=ObjectType.chart,
+ tag=tags["first_tag"],
+ ),
+ TaggedObject(
+ object_id=charts[1].id,
+ object_type=ObjectType.chart,
+ tag=tags["second_tag"],
+ ),
+ TaggedObject(
+ object_id=charts[2].id,
+ object_type=ObjectType.chart,
+ tag=tags["first_tag"],
+ ),
+ TaggedObject(
+ object_id=charts[2].id,
+ object_type=ObjectType.chart,
+ tag=tags["second_tag"],
+ ),
+ ]
+
+ for association in tag_associations:
+ db.session.add(association)
+ db.session.commit()
+
+ yield charts
+
+ # rollback changes
+ for association in tag_associations:
+ db.session.delete(association)
+ for chart in charts:
+ db.session.delete(chart)
+ db.session.commit()
+
def test_info_security_chart(self):
"""
Chart API: Test info security
@@ -1131,6 +1177,55 @@ class TestChartApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCase):
assert len(result) == 1
assert result[0]["slice_name"] == self.chart.slice_name
+    @pytest.mark.usefixtures("create_charts_some_with_tags")
+    def test_get_charts_tag_filters(self):
+        """
+        Chart API: Test get charts with tag filters
+        """
+        # Get custom tags relationship
+        tags = {
+            "first_tag": db.session.query(Tag).filter(Tag.name == "first_tag").first(),
+            "second_tag": db.session.query(Tag)
+            .filter(Tag.name == "second_tag")
+            .first(),
+            "third_tag": db.session.query(Tag).filter(Tag.name == "third_tag").first(),
+        }
+        chart_tag_relationship = {
+            tag.name: db.session.query(Slice.id)
+            .join(Slice.tags)
+            .filter(Tag.id == tag.id)
+            .all()
+            for tag in tags.values()
+        }
+
+        # Validate API results for each tag
+        for tag_name, tag in tags.items():
+            expected_charts = chart_tag_relationship[tag_name]
+
+            # Filter by tag ID
+            filter_params = get_filter_params("chart_tag_id", tag.id)
+            response_by_id = self.get_list("chart", filter_params)
+            self.assertEqual(response_by_id.status_code, 200)
+            data_by_id = json.loads(response_by_id.data.decode("utf-8"))
+
+            # Filter by tag name
+            filter_params = get_filter_params("chart_tags", tag.name)
+            response_by_name = self.get_list("chart", filter_params)
+            self.assertEqual(response_by_name.status_code, 200)
+            data_by_name = json.loads(response_by_name.data.decode("utf-8"))
+
+            # Compare results
+            # (assertEqual's 3rd positional arg is ``msg``: compare pairwise)
+            expected_ids = {chart.id for chart in expected_charts}
+            self.assertEqual(data_by_id["count"], len(expected_charts))
+            self.assertEqual(data_by_name["count"], len(expected_charts))
+            self.assertEqual(
+                {chart["id"] for chart in data_by_id["result"]}, expected_ids
+            )
+            self.assertEqual(
+                {chart["id"] for chart in data_by_name["result"]}, expected_ids
+            )
+
def test_get_charts_changed_on(self):
"""
Dashboard API: Test get charts changed on
@@ -2059,7 +2154,7 @@ class TestChartApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCase):
chart = (
db.session.query(Slice).filter(Slice.slice_name == "chart with
tag").first()
)
- new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one()
+ new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one()
# get existing tag and add a new one
new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom]
@@ -2118,7 +2213,7 @@ class TestChartApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCase):
chart = (
db.session.query(Slice).filter(Slice.slice_name == "chart with
tag").first()
)
- new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one()
+ new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one()
# get existing tag and add a new one
new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom]
@@ -2183,7 +2278,7 @@ class TestChartApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCase):
chart = (
db.session.query(Slice).filter(Slice.slice_name == "chart with
tag").first()
)
- new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one()
+ new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one()
# get existing tag and add a new one
new_tags = [tag.id for tag in chart.tags if tag.type == TagType.custom]
diff --git a/tests/integration_tests/dashboards/api_tests.py
b/tests/integration_tests/dashboards/api_tests.py
index 99a784e95f..b4e2958ccd 100644
--- a/tests/integration_tests/dashboards/api_tests.py
+++ b/tests/integration_tests/dashboards/api_tests.py
@@ -30,7 +30,7 @@ import yaml
from freezegun import freeze_time
from sqlalchemy import and_
-from superset import app, db, security_manager # noqa: F401
+from superset import db, security_manager # noqa: F401
from superset.models.dashboard import Dashboard
from superset.models.core import FavStar, FavStarClassName
from superset.reports.models import ReportSchedule, ReportScheduleType
@@ -41,7 +41,6 @@ from superset.utils import json
from tests.integration_tests.base_api_tests import ApiOwnersTestCaseMixin
from tests.integration_tests.base_tests import SupersetTestCase
-from tests.integration_tests.conftest import with_feature_flags # noqa: F401
from tests.integration_tests.constants import (
ADMIN_USERNAME,
ALPHA_USERNAME,
@@ -56,6 +55,10 @@ from tests.integration_tests.fixtures.importexport import (
dataset_config,
dataset_metadata_config,
)
+from tests.integration_tests.fixtures.tags import (
+ create_custom_tags, # noqa: F401
+ get_filter_params,
+)
from tests.integration_tests.utils.get_dashboards import get_dashboards_ids
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
@@ -169,27 +172,8 @@ class TestDashboardApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCas
db.session.delete(dashboard)
db.session.commit()
- @pytest.fixture()
- def create_custom_tags(self):
- with self.create_app().app_context():
- tags: list[Tag] = []
- for tag_name in {"one_tag", "new_tag"}:
- tag = Tag(
- name=tag_name,
- type="custom",
- )
- db.session.add(tag)
- db.session.commit()
- tags.append(tag)
-
- yield tags
-
- for tags in tags:
- db.session.delete(tags)
- db.session.commit()
-
- @pytest.fixture()
- def create_dashboard_with_tag(self, create_custom_tags):
+ @pytest.fixture
+ def create_dashboard_with_tag(self, create_custom_tags): # noqa: F811
with self.create_app().app_context():
gamma = self.get_user("gamma")
@@ -198,7 +182,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCas
None,
[gamma.id],
)
- tag = db.session.query(Tag).filter(Tag.name == "one_tag").first()
+ tag = db.session.query(Tag).filter(Tag.name == "first_tag").first()
tag_association = TaggedObject(
object_id=dashboard.id,
object_type=ObjectType.dashboard,
@@ -215,6 +199,76 @@ class TestDashboardApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCas
db.session.delete(dashboard)
db.session.commit()
+    @pytest.fixture
+    def create_dashboards_some_with_tags(self, create_custom_tags):  # noqa: F811
+        """
+        Fixture that creates 4 dashboards (tagged as ``ObjectType.dashboard``):
+        - ``first_dashboard`` is associated with ``first_tag``
+        - ``second_dashboard`` is associated with ``second_tag``
+        - ``third_dashboard`` is associated with both ``first_tag`` and ``second_tag``
+        - ``fourth_dashboard`` is not associated with any tag
+
+        Relies on the ``create_custom_tags`` fixture for the tag creation.
+        """
+        with self.create_app().app_context():
+            admin_user = self.get_user(ADMIN_USERNAME)
+
+            tags = {
+                "first_tag": db.session.query(Tag)
+                .filter(Tag.name == "first_tag")
+                .first(),
+                "second_tag": db.session.query(Tag)
+                .filter(Tag.name == "second_tag")
+                .first(),
+            }
+
+            dashboard_names = [
+                "first_dashboard",
+                "second_dashboard",
+                "third_dashboard",
+                "fourth_dashboard",
+            ]
+            dashboards = [
+                self.insert_dashboard(name, None, [admin_user.id])
+                for name in dashboard_names
+            ]
+
+            tag_associations = [
+                TaggedObject(
+                    object_id=dashboards[0].id,
+                    object_type=ObjectType.dashboard,
+                    tag=tags["first_tag"],
+                ),
+                TaggedObject(
+                    object_id=dashboards[1].id,
+                    object_type=ObjectType.dashboard,
+                    tag=tags["second_tag"],
+                ),
+                TaggedObject(
+                    object_id=dashboards[2].id,
+                    object_type=ObjectType.dashboard,
+                    tag=tags["first_tag"],
+                ),
+                TaggedObject(
+                    object_id=dashboards[2].id,
+                    object_type=ObjectType.dashboard,
+                    tag=tags["second_tag"],
+                ),
+            ]
+
+            for association in tag_associations:
+                db.session.add(association)
+            db.session.commit()
+
+            yield dashboards
+
+            # rollback changes
+            for association in tag_associations:
+                db.session.delete(association)
+            for dashboard in dashboards:
+                db.session.delete(dashboard)
+            db.session.commit()
+
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_get_dashboard_datasets(self):
self.login(ADMIN_USERNAME)
@@ -710,6 +764,55 @@ class TestDashboardApi(ApiOwnersTestCaseMixin,
InsertChartMixin, SupersetTestCas
expected_model.dashboard_title ==
data["result"][i]["dashboard_title"]
)
+    @pytest.mark.usefixtures("create_dashboards_some_with_tags")
+    def test_get_dashboards_tag_filters(self):
+        """
+        Dashboard API: Test get dashboards with tag filters
+        """
+        # Get custom tags relationship
+        tags = {
+            "first_tag": db.session.query(Tag).filter(Tag.name == "first_tag").first(),
+            "second_tag": db.session.query(Tag)
+            .filter(Tag.name == "second_tag")
+            .first(),
+            "third_tag": db.session.query(Tag).filter(Tag.name == "third_tag").first(),
+        }
+        dashboard_tag_relationship = {
+            tag.name: db.session.query(Dashboard.id)
+            .join(Dashboard.tags)
+            .filter(Tag.id == tag.id)
+            .all()
+            for tag in tags.values()
+        }
+
+        # Validate API results for each tag
+        for tag_name, tag in tags.items():
+            expected_dashboards = dashboard_tag_relationship[tag_name]
+
+            # Filter by tag ID
+            filter_params = get_filter_params("dashboard_tag_id", tag.id)
+            response_by_id = self.get_list("dashboard", filter_params)
+            self.assertEqual(response_by_id.status_code, 200)
+            data_by_id = json.loads(response_by_id.data.decode("utf-8"))
+
+            # Filter by tag name
+            filter_params = get_filter_params("dashboard_tags", tag.name)
+            response_by_name = self.get_list("dashboard", filter_params)
+            self.assertEqual(response_by_name.status_code, 200)
+            data_by_name = json.loads(response_by_name.data.decode("utf-8"))
+
+            # Compare results
+            # (assertEqual's 3rd positional arg is ``msg``: compare pairwise)
+            expected_ids = {dashboard.id for dashboard in expected_dashboards}
+            self.assertEqual(data_by_id["count"], len(expected_dashboards))
+            self.assertEqual(data_by_name["count"], len(expected_dashboards))
+            self.assertEqual(
+                {dashboard["id"] for dashboard in data_by_id["result"]}, expected_ids
+            )
+            self.assertEqual(
+                {dashboard["id"] for dashboard in data_by_name["result"]}, expected_ids
+            )
+
@pytest.mark.usefixtures("create_dashboards")
def test_get_current_user_favorite_status(self):
"""
@@ -2504,7 +2607,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.filter(Dashboard.dashboard_title == "dash with tag")
.first()
)
- new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one()
+ new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one()
# get existing tag and add a new one
new_tags = [tag.id for tag in dashboard.tags if tag.type ==
TagType.custom]
@@ -2566,7 +2669,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.filter(Dashboard.dashboard_title == "dash with tag")
.first()
)
- new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one()
+ new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one()
# get existing tag and add a new one
new_tags = [tag.id for tag in dashboard.tags if tag.type ==
TagType.custom]
@@ -2580,7 +2683,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
# Clean up system tags
tag_list = [tag.id for tag in model.tags if tag.type == TagType.custom]
- self.assertEqual(tag_list, new_tags)
+ self.assertEqual(sorted(tag_list), sorted(new_tags))
security_manager.add_permission_role(gamma_role, write_tags_perm)
@@ -2635,7 +2738,7 @@ class TestDashboardApi(ApiOwnersTestCaseMixin, InsertChartMixin, SupersetTestCas
.filter(Dashboard.dashboard_title == "dash with tag")
.first()
)
- new_tag = db.session.query(Tag).filter(Tag.name == "new_tag").one()
+ new_tag = db.session.query(Tag).filter(Tag.name == "second_tag").one()
# get existing tag and add a new one
new_tags = [tag.id for tag in dashboard.tags if tag.type ==
TagType.custom]
diff --git a/tests/integration_tests/fixtures/tags.py b/tests/integration_tests/fixtures/tags.py
index 493b3295d8..90449957fa 100644
--- a/tests/integration_tests/fixtures/tags.py
+++ b/tests/integration_tests/fixtures/tags.py
@@ -17,7 +17,9 @@
import pytest
+from superset import db
from superset.tags.core import clear_sqla_event_listeners,
register_sqla_event_listeners
+from superset.tags.models import Tag
from tests.integration_tests.test_app import app
@@ -31,3 +33,36 @@ def with_tagging_system_feature():
yield
app.config["DEFAULT_FEATURE_FLAGS"]["TAGGING_SYSTEM"] = False
clear_sqla_event_listeners()
+
+
+@pytest.fixture
+def create_custom_tags():
+    """Create the ``first_tag``, ``second_tag`` and ``third_tag`` custom tags."""
+    with app.app_context():
+        tags: list[Tag] = []
+        # Iterate a tuple, not a set, so tags are created in a stable order
+        for tag_name in ("first_tag", "second_tag", "third_tag"):
+            tag = Tag(name=tag_name, type="custom")
+            db.session.add(tag)
+            db.session.commit()
+            tags.append(tag)
+
+        yield tags
+
+        # Teardown; use a distinct loop name so ``tags`` is not rebound
+        for tag in tags:
+            db.session.delete(tag)
+        db.session.commit()
+
+
+# Helper returning list-API query parameters that filter on the ``tags`` column
+def get_filter_params(opr, value):
+    """
+    Build a ``filters`` payload with a single ``tags`` filter using ``opr``.
+    """
+    tag_filter = {
+        "col": "tags",
+        "opr": opr,
+        "value": value,
+    }
+    return {"filters": [tag_filter]}
diff --git a/tests/integration_tests/queries/saved_queries/api_tests.py b/tests/integration_tests/queries/saved_queries/api_tests.py
index da203a7139..9b1184b1f7 100644
--- a/tests/integration_tests/queries/saved_queries/api_tests.py
+++ b/tests/integration_tests/queries/saved_queries/api_tests.py
@@ -32,6 +32,7 @@ from superset import db
from superset.models.core import Database
from superset.models.core import FavStar
from superset.models.sql_lab import SavedQuery
+from superset.tags.models import ObjectType, Tag, TaggedObject
from superset.utils.database import get_example_database
from superset.utils import json
@@ -42,6 +43,10 @@ from tests.integration_tests.fixtures.importexport import (
saved_queries_config,
saved_queries_metadata_config,
)
+from tests.integration_tests.fixtures.tags import (
+ create_custom_tags, # noqa: F401
+ get_filter_params,
+)
SAVED_QUERIES_FIXTURE_COUNT = 10
@@ -123,6 +128,73 @@ class TestSavedQueryApi(SupersetTestCase):
db.session.delete(fav_saved_query)
db.session.commit()
+    @pytest.fixture
+    def create_saved_queries_some_with_tags(self, create_custom_tags):  # noqa: F811
+        """
+        Fixture that creates 4 saved queries:
+        - ``first_query`` is associated with ``first_tag``
+        - ``second_query`` is associated with ``second_tag``
+        - ``third_query`` is associated with both ``first_tag`` and ``second_tag``
+        - ``fourth_query`` is not associated with any tag
+        Associations use ``ObjectType.query`` so saved query tag filters match them.
+        Relies on the ``create_custom_tags`` fixture for the tag creation.
+        """
+        with self.create_app().app_context():
+            tags = {
+                "first_tag": db.session.query(Tag)
+                .filter(Tag.name == "first_tag")
+                .first(),
+                "second_tag": db.session.query(Tag)
+                .filter(Tag.name == "second_tag")
+                .first(),
+            }
+
+            query_labels = [
+                "first_query",
+                "second_query",
+                "third_query",
+                "fourth_query",
+            ]
+            queries = [
+                self.insert_default_saved_query(label=name) for name in query_labels
+            ]
+
+            tag_associations = [
+                TaggedObject(
+                    object_id=queries[0].id,
+                    object_type=ObjectType.query,
+                    tag=tags["first_tag"],
+                ),
+                TaggedObject(
+                    object_id=queries[1].id,
+                    object_type=ObjectType.query,
+                    tag=tags["second_tag"],
+                ),
+                TaggedObject(
+                    object_id=queries[2].id,
+                    object_type=ObjectType.query,
+                    tag=tags["first_tag"],
+                ),
+                TaggedObject(
+                    object_id=queries[2].id,
+                    object_type=ObjectType.query,
+                    tag=tags["second_tag"],
+                ),
+            ]
+
+            for association in tag_associations:
+                db.session.add(association)
+            db.session.commit()
+
+            yield queries
+
+            # Teardown: delete the associations and the saved queries
+            for association in tag_associations:
+                db.session.delete(association)
+            for query in queries:
+                db.session.delete(query)
+            db.session.commit()
+
@pytest.mark.usefixtures("create_saved_queries")
def test_get_list_saved_query(self):
"""
@@ -366,6 +438,55 @@ class TestSavedQueryApi(SupersetTestCase):
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == len(all_queries)
+    @pytest.mark.usefixtures("create_saved_queries_some_with_tags")
+    def test_get_saved_queries_tag_filters(self):
+        """
+        Saved Query API: Test that tag ID and tag name filters match tagged queries
+        """
+        # Look up the three custom tags by name
+        tags = {
+            "first_tag": db.session.query(Tag).filter(Tag.name == "first_tag").first(),
+            "second_tag": db.session.query(Tag)
+            .filter(Tag.name == "second_tag")
+            .first(),
+            "third_tag": db.session.query(Tag).filter(Tag.name == "third_tag").first(),
+        }
+        saved_queries_tag_relationship = {
+            tag.name: db.session.query(SavedQuery.id)
+            .join(SavedQuery.tags)
+            .filter(Tag.id == tag.id)
+            .all()
+            for tag in tags.values()
+        }
+
+        # Both filter flavours must return exactly the saved queries tagged in the DB
+        for tag_name, tag in tags.items():
+            expected_saved_queries = saved_queries_tag_relationship[tag_name]
+
+            # Filter by tag ID
+            filter_params = get_filter_params("saved_query_tag_id", tag.id)
+            response_by_id = self.get_list("saved_query", filter_params)
+            self.assertEqual(response_by_id.status_code, 200)
+            data_by_id = json.loads(response_by_id.data.decode("utf-8"))
+
+            # Filter by tag name
+            filter_params = get_filter_params("saved_query_tags", tag.name)
+            response_by_name = self.get_list("saved_query", filter_params)
+            self.assertEqual(response_by_name.status_code, 200)
+            data_by_name = json.loads(response_by_name.data.decode("utf-8"))
+
+            # Compare each result on its own: assertEqual's 3rd arg is only ``msg``
+            expected_count = len(expected_saved_queries)
+            self.assertEqual(data_by_id["count"], expected_count)
+            self.assertEqual(data_by_name["count"], expected_count)
+            expected_ids = {query.id for query in expected_saved_queries}
+            self.assertEqual(
+                {query["id"] for query in data_by_id["result"]}, expected_ids
+            )
+            self.assertEqual(
+                {query["id"] for query in data_by_name["result"]}, expected_ids
+            )
+
@pytest.mark.usefixtures("create_saved_queries")
def test_get_saved_query_favorite_filter(self):
"""
diff --git a/tests/unit_tests/tags/filters_test.py b/tests/unit_tests/tags/filters_test.py
new file mode 100644
index 0000000000..fadf6216d3
--- /dev/null
+++ b/tests/unit_tests/tags/filters_test.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+from flask_appbuilder import Model
+from flask_appbuilder.models.sqla.interface import SQLAInterface
+from sqlalchemy.orm.session import Session
+
+from superset.models.dashboard import Dashboard
+from superset.models.slice import Slice
+from superset.models.sql_lab import SavedQuery
+from superset.tags.filters import BaseTagIdFilter, BaseTagNameFilter
+
+OBJECT_TYPES = {  # table name -> expected ``tagged_object.object_type`` in JOINs
+    "slices": "chart",
+    "dashboards": "dashboard",
+    "saved_query": "query",
+}
+FILTER_MODELS = [Slice, Dashboard, SavedQuery]
+
+
+@pytest.mark.parametrize("model", FILTER_MODELS)
+@pytest.mark.parametrize("name", ["my_tag", "test tag", "blaah"])
+def test_base_tag_filter_by_name(session: Session, model: Model, name: str) -> None:
+    """``BaseTagNameFilter`` joins via ``tagged_object`` and LIKE-matches tag names."""
+    table = model.__tablename__
+    query = session.query(model)
+    tag_filter = BaseTagNameFilter("tags", SQLAInterface(model))
+    final_query = tag_filter.apply(query, name)
+    compiled_query = final_query.statement.compile(
+        session.get_bind(),
+        compile_kwargs={"literal_binds": True},
+    )
+
+    # Assert the JOIN clause is correct
+    assert (
+        f"FROM {table} JOIN tagged_object AS tagged_object_1 ON {table}.id "
+        "= tagged_object_1.object_id AND tagged_object_1.object_type = "
+        f"'{OBJECT_TYPES.get(table)}' JOIN tag ON tagged_object_1.tag_id = tag.id"
+    ) in str(compiled_query)
+
+    # Assert the WHERE clause is a case-insensitive substring match on tag.name
+    assert str(compiled_query).endswith(
+        f"WHERE lower(tag.name) LIKE lower('%{name}%'))"
+    )
+
+
+@pytest.mark.parametrize("model", FILTER_MODELS)
+@pytest.mark.parametrize("tag_id", [3, 5, 8])
+def test_base_tag_filter_by_id(session: Session, model: Model, tag_id: int) -> None:
+    """``BaseTagIdFilter`` joins via ``tagged_object`` and matches on ``tag.id``."""
+    table = model.__tablename__
+    engine = session.get_bind()
+    query = session.query(model)
+    tag_filter = BaseTagIdFilter("tags", SQLAInterface(model))
+    tag_filter.id_based_filter = True
+    final_query = tag_filter.apply(query, tag_id)
+    compiled_query = final_query.statement.compile(
+        engine,
+        compile_kwargs={"literal_binds": True},
+    )
+
+    # Assert the JOIN clause is correct
+    assert (
+        f"FROM {table} JOIN tagged_object AS tagged_object_1 ON {table}.id "
+        "= tagged_object_1.object_id AND tagged_object_1.object_type = "
+        f"'{OBJECT_TYPES.get(table)}' JOIN tag ON tagged_object_1.tag_id = tag.id"
+    ) in str(compiled_query)
+
+    # Assert the WHERE clause is an exact match on tag.id
+    assert str(compiled_query).endswith(f"WHERE tag.id = {tag_id})")