These were all doing the same thing. Make things more generic. We also (inadvertently) speed up the tests by using the 'patch_id' attribute of the 'Check' model rather than 'patch.id', thus avoiding the JOIN.
Signed-off-by: Stephen Finucane <step...@that.guru> --- patchwork/api/base.py | 52 +++++++++---------------------- patchwork/api/check.py | 10 ++++-- patchwork/api/embedded.py | 28 +++++++++++++---- patchwork/tests/api/test_event.py | 2 +- 4 files changed, 45 insertions(+), 47 deletions(-) diff --git patchwork/api/base.py patchwork/api/base.py index 3ed4182c..0f5c44a2 100644 --- patchwork/api/base.py +++ patchwork/api/base.py @@ -9,8 +9,8 @@ from django.conf import settings from django.shortcuts import get_object_or_404 from rest_framework import permissions from rest_framework.pagination import PageNumberPagination +from rest_framework.relations import HyperlinkedIdentityField from rest_framework.response import Response -from rest_framework.serializers import HyperlinkedIdentityField from rest_framework.serializers import HyperlinkedModelSerializer from rest_framework.utils.urls import replace_query_param @@ -122,52 +122,28 @@ class MultipleFieldLookupMixin(object): return get_object_or_404(queryset, **filter_kwargs) -class CheckHyperlinkedIdentityField(HyperlinkedIdentityField): - def get_url(self, obj, view_name, request, format): - # Unsaved objects will not yet have a valid URL. - if obj.pk is None: - return None - - return self.reverse( - view_name, - kwargs={ - 'patch_id': obj.patch.id, - 'check_id': obj.id, - }, - request=request, - format=format, - ) +class NestedHyperlinkedIdentityField(HyperlinkedIdentityField): + """A variant of HyperlinkedIdentityField that supports nested resources.""" + def __init__(self, view_name, lookup_field_mapping, **kwargs): + self.lookup_field_mapping = lookup_field_mapping + super().__init__(view_name, **kwargs) -class CoverCommentHyperlinkedIdentityField(HyperlinkedIdentityField): def get_url(self, obj, view_name, request, format): # Unsaved objects will not yet have a valid URL. 
- if obj.pk is None: + if hasattr(obj, 'pk') and obj.pk in (None, ''): return None - return self.reverse( - view_name, - kwargs={ - 'cover_id': obj.cover.id, - 'comment_id': obj.id, - }, - request=request, - format=format, - ) - - -class PatchCommentHyperlinkedIdentityField(HyperlinkedIdentityField): - def get_url(self, obj, view_name, request, format): - # Unsaved objects will not yet have a valid URL. - if obj.pk is None: - return None + kwargs = {} + for ( + lookup_url_kwarg, + lookup_field, + ) in self.lookup_field_mapping.items(): + kwargs[lookup_url_kwarg] = getattr(obj, lookup_field) return self.reverse( view_name, - kwargs={ - 'patch_id': obj.patch.id, - 'comment_id': obj.id, - }, + kwargs=kwargs, request=request, format=format, ) diff --git patchwork/api/check.py patchwork/api/check.py index c28d89f7..f5461fc6 100644 --- patchwork/api/check.py +++ patchwork/api/check.py @@ -14,8 +14,8 @@ from rest_framework.serializers import HiddenField from rest_framework.serializers import HyperlinkedModelSerializer from rest_framework.serializers import ValidationError -from patchwork.api.base import CheckHyperlinkedIdentityField from patchwork.api.base import MultipleFieldLookupMixin +from patchwork.api.base import NestedHyperlinkedIdentityField from patchwork.api.base import CurrentPatchDefault from patchwork.api.embedded import UserSerializer from patchwork.api.filters import CheckFilterSet @@ -25,7 +25,13 @@ from patchwork.models import Patch class CheckSerializer(HyperlinkedModelSerializer): - url = CheckHyperlinkedIdentityField('api-check-detail') + url = NestedHyperlinkedIdentityField( + 'api-check-detail', + lookup_field_mapping={ + 'patch_id': 'patch_id', + 'check_id': 'id', + }, + ) patch = HiddenField(default=CurrentPatchDefault()) user = UserSerializer(default=CurrentUserDefault()) diff --git patchwork/api/embedded.py patchwork/api/embedded.py index 485ed6f7..7105da08 100644 --- patchwork/api/embedded.py +++ patchwork/api/embedded.py @@ -16,9 +16,7 @@ from 
rest_framework.serializers import PrimaryKeyRelatedField from rest_framework.serializers import SerializerMethodField from patchwork.api.base import BaseHyperlinkedModelSerializer -from patchwork.api.base import CheckHyperlinkedIdentityField -from patchwork.api.base import CoverCommentHyperlinkedIdentityField -from patchwork.api.base import PatchCommentHyperlinkedIdentityField +from patchwork.api.base import NestedHyperlinkedIdentityField from patchwork import models @@ -82,7 +80,13 @@ class WebURLMixin(BaseHyperlinkedModelSerializer): class CheckSerializer(SerializedRelatedField): class _Serializer(BaseHyperlinkedModelSerializer): - url = CheckHyperlinkedIdentityField('api-check-detail') + url = NestedHyperlinkedIdentityField( + 'api-check-detail', + lookup_field_mapping={ + 'patch_id': 'patch_id', + 'check_id': 'id', + }, + ) def to_representation(self, instance): data = super(CheckSerializer._Serializer, self).to_representation( @@ -130,7 +134,13 @@ class CoverSerializer(SerializedRelatedField): class CoverCommentSerializer(SerializedRelatedField): class _Serializer(MboxMixin, WebURLMixin, BaseHyperlinkedModelSerializer): - url = CoverCommentHyperlinkedIdentityField('api-cover-comment-detail') + url = NestedHyperlinkedIdentityField( + 'api-cover-comment-detail', + lookup_field_mapping={ + 'cover_id': 'cover_id', + 'comment_id': 'id', + }, + ) class Meta: model = models.CoverComment @@ -182,7 +192,13 @@ class PatchSerializer(SerializedRelatedField): class PatchCommentSerializer(SerializedRelatedField): class _Serializer(MboxMixin, WebURLMixin, BaseHyperlinkedModelSerializer): - url = PatchCommentHyperlinkedIdentityField('api-patch-comment-detail') + url = NestedHyperlinkedIdentityField( + 'api-patch-comment-detail', + lookup_field_mapping={ + 'patch_id': 'patch_id', + 'comment_id': 'id', + }, + ) class Meta: model = models.PatchComment diff --git patchwork/tests/api/test_event.py patchwork/tests/api/test_event.py index 7ca09c2e..1a0d811d 100644 --- 
patchwork/tests/api/test_event.py +++ patchwork/tests/api/test_event.py @@ -200,7 +200,7 @@ class TestEventAPI(APITestCase): for _ in range(3): self._create_events() - with self.assertNumQueries(33): + with self.assertNumQueries(30): self.client.get(self.api_url()) def test_order_by_date_default(self): -- 2.37.3 _______________________________________________ Patchwork mailing list Patchwork@lists.ozlabs.org https://lists.ozlabs.org/listinfo/patchwork