Author: mtredinnick
Date: 2007-10-13 21:16:38 -0500 (Sat, 13 Oct 2007)
New Revision: 6497

Modified:
   django/branches/queryset-refactor/django/db/models/query.py
   django/branches/queryset-refactor/django/db/models/sql/datastructures.py
   django/branches/queryset-refactor/django/db/models/sql/query.py
   django/branches/queryset-refactor/tests/regressiontests/queries/models.py
Log:
queryset-refactor: Made all the changes needed to have count() work properly
with ValuesQuerySet. This is the general case of #2939.

At this point, all the existing tests pass on the branch (except on
Oracle). It's a bit slower than before, though, and there are still a number
of known bugs that aren't covered by the tests (or are only exercised for some
backends).


Modified: django/branches/queryset-refactor/django/db/models/query.py
===================================================================
--- django/branches/queryset-refactor/django/db/models/query.py 2007-10-14 02:16:08 UTC (rev 6496)
+++ django/branches/queryset-refactor/django/db/models/query.py 2007-10-14 02:16:38 UTC (rev 6497)
@@ -251,7 +251,7 @@
     ##################################################
 
     def values(self, *fields):
-        return self._clone(klass=ValuesQuerySet, _fields=fields)
+        return self._clone(klass=ValuesQuerySet, setup=True, _fields=fields)
 
     def dates(self, field_name, kind, order='ASC'):
         """
@@ -266,8 +266,8 @@
         field = self.model._meta.get_field(field_name, many_to_many=False)
         assert isinstance(field, DateField), "%r isn't a DateField." \
                 % field_name
-        return self._clone(klass=DateQuerySet, _field=field, _kind=kind,
-                _order=order)
+        return self._clone(klass=DateQuerySet, setup=True, _field=field,
+                _kind=kind, _order=order)
 
     ##################################################################
     # PUBLIC METHODS THAT ALTER ATTRIBUTES AND RETURN A NEW QUERYSET #
@@ -363,13 +363,15 @@
     # PRIVATE METHODS #
     ###################
 
-    def _clone(self, klass=None, **kwargs):
+    def _clone(self, klass=None, setup=False, **kwargs):
         if klass is None:
             klass = self.__class__
         c = klass()
         c.model = self.model
         c.query = self.query.clone()
         c.__dict__.update(kwargs)
+        if setup and hasattr(c, '_setup_query'):
+            c._setup_query()
         return c
 
     def _get_data(self):
@@ -389,16 +391,33 @@
         # select_related isn't supported in values().
         self.query.select_related = False
 
+        # QuerySet._clone(setup=True) will also call _setup_query(), which
+        # sets field_names to the names of the fields to select.
+
     def iterator(self):
         extra_select = self.query.extra_select.keys()
         extra_select.sort()
+        # Use a local copy so that repeated iteration doesn't keep appending.
+        field_names = list(self.field_names) + extra_select
 
-        # Construct two objects -- fields and field_names.
-        # fields is a list of Field objects to fetch.
-        # field_names is a list of field names, which will be the keys in the
-        # resulting dictionaries.
+        for row in self.query.results_iter():
+            yield dict(zip(field_names, row))
+
+    def _setup_query(self):
+        """
+        Sets up any special features of the query attribute.
+
+        Called by the _clone() method after initialising the rest of the
+        instance.
+        """
+        # Construct two objects:
+        #   - fields is a list of Field objects to fetch.
+        #   - field_names is a list of field names, which will be the keys in
+        #   the resulting dictionaries.
+        # 'fields' is used to configure the query, whilst field_names is stored
+        # in this object for use by iterator().
         if self._fields:
-            if not extra_select:
+            if not self.query.extra_select:
                 fields = [self.model._meta.get_field(f, many_to_many=False)
                         for f in self._fields]
                 field_names = self._fields
@@ -418,30 +437,47 @@
             field_names = [f.attname for f in fields]
 
         self.query.add_local_columns([f.column for f in fields])
-        if extra_select:
-            field_names.extend([f for f in extra_select])
+        self.field_names = field_names
 
-        for row in self.query.results_iter():
-            yield dict(zip(field_names, row))
-
-    def _clone(self, klass=None, **kwargs):
+    def _clone(self, klass=None, setup=False, **kwargs):
+        """
+        Cloning a ValuesQuerySet preserves the current fields, unless new
+        ones are passed in as '_fields' (as values() does).
+        """
         c = super(ValuesQuerySet, self)._clone(klass, **kwargs)
-        c._fields = self._fields[:]
+        if '_fields' not in kwargs:
+            c._fields = self._fields[:]
+        c.field_names = self.field_names[:]
+        if setup and hasattr(c, '_setup_query'):
+            c._setup_query()
         return c
 
 class DateQuerySet(QuerySet):
     def iterator(self):
-        self.query = self.query.clone(klass=sql.DateQuery)
+        return self.query.results_iter()
+
+    def _setup_query(self):
+        """
+        Sets up any special features of the query attribute.
+
+        Called by the _clone() method after initialising the rest of the
+        instance.
+        """
+        self.query = self.query.clone(klass=sql.DateQuery, setup=True)
         self.query.select = []
         self.query.add_date_select(self._field.column, self._kind, self._order)
         if self._field.null:
             self.query.add_filter(('%s__isnull' % self._field.name, True))
-        return self.query.results_iter()
 
-    def _clone(self, klass=None, **kwargs):
-        c = super(DateQuerySet, self)._clone(klass, **kwargs)
-        c._field = self._field
-        c._kind = self._kind
+    def _clone(self, klass=None, setup=False, **kwargs):
+        c = super(DateQuerySet, self)._clone(klass, setup=False, **kwargs)
+        # Copy our settings across unless the caller (e.g. dates()) passed
+        # new ones in kwargs; _order is needed if _setup_query() reruns.
+        for attr in ('_field', '_kind', '_order'):
+            if attr not in kwargs:
+                setattr(c, attr, getattr(self, attr))
+        if setup and hasattr(c, '_setup_query'):
+            c._setup_query()
         return c
 
 class EmptyQuerySet(QuerySet):
@@ -455,14 +491,14 @@
     def delete(self):
         pass
 
-    def _clone(self, klass=None, **kwargs):
-        c = super(EmptyQuerySet, self)._clone(klass, **kwargs)
+    def _clone(self, klass=None, setup=False, **kwargs):
+        c = super(EmptyQuerySet, self)._clone(klass, setup=setup, **kwargs)
         c._result_cache = []
         return c
 
     def iterator(self):
         # This slightly odd construction is because we need an empty generator
-        # (it should raise StopIteration immediately).
+        # (it raises StopIteration immediately).
         yield iter([]).next()
 
 # QOperator, QAnd and QOr are temporarily retained for backwards compatibility.

Modified: django/branches/queryset-refactor/django/db/models/sql/datastructures.py
===================================================================
--- django/branches/queryset-refactor/django/db/models/sql/datastructures.py 2007-10-14 02:16:08 UTC (rev 6496)
+++ django/branches/queryset-refactor/django/db/models/sql/datastructures.py 2007-10-14 02:16:38 UTC (rev 6497)
@@ -32,12 +32,12 @@
     """
     Perform a count on the given column.
     """
-    def __init__(self, col=None, distinct=False):
+    def __init__(self, col='*', distinct=False):
         """
         Set the column to count on (defaults to '*') and set whether the count
         should be distinct or not.
         """
-        self.col = col and col or '*'
+        self.col = col or '*'
         self.distinct = distinct
 
     def relabel_aliases(self, change_map):
@@ -49,13 +49,13 @@
         if not quote_func:
             quote_func = lambda x: x
         if isinstance(self.col, (list, tuple)):
             col = '%s.%s' % tuple([quote_func(c) for c in self.col])
         else:
             col = self.col
         if self.distinct:
-            return 'COUNT(DISTINCT(%s))' % col
+            return 'COUNT(DISTINCT %s)' % col
         else:
             return 'COUNT(%s)' % col
 
 class Date(object):
     """

Modified: django/branches/queryset-refactor/django/db/models/sql/query.py
===================================================================
--- django/branches/queryset-refactor/django/db/models/sql/query.py 2007-10-14 02:16:08 UTC (rev 6496)
+++ django/branches/queryset-refactor/django/db/models/sql/query.py 2007-10-14 02:16:38 UTC (rev 6497)
@@ -147,15 +147,19 @@
 
     def get_count(self):
         """
-        Performs a COUNT() or COUNT(DISTINCT()) query, as appropriate, using
-        the current filter constraints.
+        Performs a COUNT() query using the current filter constraints.
         """
-        counter = self.clone()
-        counter.clear_ordering()
-        counter.clear_limits()
-        counter.select_related = False
-        counter.add_count_column()
-        data = counter.execute_sql(SINGLE)
+        obj = self.clone()
+        obj.clear_ordering()
+        obj.clear_limits()
+        obj.select_related = False
+        if obj.distinct and len(obj.select) > 1:
+            # Count the rows of a sub-select (see CountQuery). Clone 'obj', not
+            # 'self', so the outer query has no ordering, limits, etc.
+            obj = obj.clone(CountQuery, _query=obj, where=WhereNode(self),
+                    distinct=False)
+        obj.add_count_column()
+        data = obj.execute_sql(SINGLE)
         if not data:
             return 0
         number = data[0]
@@ -176,7 +180,6 @@
         If 'with_limits' is False, any limit/offset information is not included
         in the query.
         """
-        qn = self.connection.ops.quote_name
         self.pre_sql_setup()
         result = ['SELECT']
         if self.distinct:
@@ -185,21 +188,12 @@
         result.append(', '.join(out_cols))
 
         result.append('FROM')
-        for alias in self.tables:
-            if not self.alias_map[alias][ALIAS_REFCOUNT]:
-                continue
-            name, alias, join_type, lhs, lhs_col, col = \
-                    self.alias_map[alias][ALIAS_JOIN]
-            alias_str = (alias != name and ' AS %s' % alias or '')
-            if join_type:
-                result.append('%s %s%s ON (%s.%s = %s.%s)'
-                        % (join_type, qn(name), alias_str, qn(lhs),
-                            qn(lhs_col), qn(alias), qn(col)))
-            else:
-                result.append('%s%s' % (qn(name), alias_str))
-        result.extend(self.extra_tables)
+        from_, f_params = self.get_from_clause()
+        result.extend(from_)
+        params = list(f_params)
 
-        where, params = self.where.as_sql()
+        where, w_params = self.where.as_sql()
+        params.extend(w_params)
         if where:
             result.append('WHERE %s' % where)
         if self.extra_where:
@@ -348,6 +342,30 @@
                 for alias, col in extra_select])
         return result
 
+    def get_from_clause(self):
+        """
+        Returns a list of strings that are joined together to go after the
+        "FROM" part of the query, as well as any extra parameters that need to
+        be included. Sub-classes can override this to create a from-clause via
+        a "select" (e.g. CountQuery).
+        """
+        result = []
+        qn = self.connection.ops.quote_name
+        for alias in self.tables:
+            if not self.alias_map[alias][ALIAS_REFCOUNT]:
+                continue
+            name, alias, join_type, lhs, lhs_col, col = \
+                    self.alias_map[alias][ALIAS_JOIN]
+            alias_str = (alias != name and ' AS %s' % alias or '')
+            if join_type:
+                result.append('%s %s%s ON (%s.%s = %s.%s)'
+                        % (join_type, qn(name), alias_str, qn(lhs),
+                            qn(lhs_col), qn(alias), qn(col)))
+            else:
+                result.append('%s%s' % (qn(name), alias_str))
+        result.extend(self.extra_tables)
+        return result, []
+
     def get_grouping(self):
         """
         Returns a tuple representing the SQL elements in the "group by" clause.
@@ -787,8 +805,17 @@
         if not self.distinct:
             select = Count()
         else:
-            select = Count((self.table_map[self.model._meta.db_table][0],
-                    self.model._meta.pk.column), True)
+            opts = self.model._meta
+            if not self.select:
+                select = Count((self.join((None, opts.db_table, None, None)),
+                        opts.pk.column), True)
+            else:
+                # Because of SQL portability issues, multi-column, distinct
+                # counts need a sub-query -- see get_count() for details.
+                assert len(self.select) == 1, \
+                        "Cannot add count col with multiple cols in 'select'."
+                select = Count(self.select[0], True)
+
             # Distinct handling is done in Count(), so don't do it at this
             # level.
             self.distinct = False
@@ -987,6 +1014,16 @@
         else:
             self.group_by = [select]
 
+class CountQuery(Query):
+    """
+    A CountQuery takes a normal query that selects over multiple distinct
+    columns and wraps it as a sub-select in the FROM clause, so the count
+    can be computed portably across a variety of backends.
+    """
+    def get_from_clause(self):
+        result, params = self._query.as_sql()
+        return ['(%s) AS A1' % result], params
+
 def find_field(name, field_list, related_query):
     """
     Finds a field with a specific name in a list of field instances.

Modified: django/branches/queryset-refactor/tests/regressiontests/queries/models.py
===================================================================
--- django/branches/queryset-refactor/tests/regressiontests/queries/models.py 2007-10-14 02:16:08 UTC (rev 6496)
+++ django/branches/queryset-refactor/tests/regressiontests/queries/models.py 2007-10-14 02:16:38 UTC (rev 6497)
@@ -111,11 +111,18 @@
 >>> Author.objects.filter(Q(name='a3') | Q(item__name='one'))
 [<Author: a1>, <Author: a3>]
 
-Bug #2939
-# FIXME: ValueQuerySets don't work yet.
-# >>> Item.objects.values('creator').distinct().count()
-# 2
+Bug #1878, #2939
+>>> Item.objects.values('creator').distinct().count()
+3
 
+# Create something with a duplicate 'name' so that we can test multi-column
+# cases (which require some tricky SQL transformations under the covers).
+>>> xx = Item(name='four', creator=a2)
+>>> xx.save()
+>>> Item.objects.exclude(name='two').values('creator', 'name').distinct().count()
+4
+>>> xx.delete()
+
 Bug #2253
 >>> q1 = Item.objects.order_by('name')
 >>> q2 = Item.objects.filter(id=i1.id)


--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups "Django updates" group.
To post to this group, send email to [email protected]
To unsubscribe from this group, send email to [EMAIL PROTECTED]
For more options, visit this group at http://groups.google.com/group/django-updates?hl=en
-~----------~----~----~----~------~----~------~--~---

Reply via email to