Author: guido.van.rossum
Date: Sat Aug 25 00:33:45 2007
New Revision: 57422

Modified:
   python/branches/py3k/Lib/test/test_dict.py
   python/branches/py3k/Objects/dictobject.c
Log:
Patch with Keir Mierle: add rich comparisons between dict views and sets,
at least for .keys() and .items() (not .values(), whose view is not set-like).


Modified: python/branches/py3k/Lib/test/test_dict.py
==============================================================================
--- python/branches/py3k/Lib/test/test_dict.py  (original)
+++ python/branches/py3k/Lib/test/test_dict.py  Sat Aug 25 00:33:45 2007
@@ -398,6 +398,66 @@
         else:
             self.fail("< didn't raise Exc")
 
+    def test_keys_contained(self):
+        # Rich comparisons between dict key views must give the same results
+        # as the corresponding comparisons between sets of those keys.
+        empty = dict()
+        empty2 = dict()
+        smaller = {1:1, 2:2}  # keys are a proper subset of larger's
+        larger = {1:1, 2:2, 3:3}
+        larger2 = {1:1, 2:2, 3:3}  # same keys as larger
+        larger3 = {4:1, 2:2, 3:3}  # same size as larger, different keys
+
+        self.assertTrue(smaller.keys() <  larger.keys())
+        self.assertTrue(smaller.keys() <= larger.keys())
+        self.assertTrue(larger.keys() >  smaller.keys())
+        self.assertTrue(larger.keys() >= smaller.keys())
+
+        self.assertFalse(smaller.keys() >= larger.keys())
+        self.assertFalse(smaller.keys() >  larger.keys())
+        self.assertFalse(larger.keys()  <= smaller.keys())
+        self.assertFalse(larger.keys()  <  smaller.keys())
+
+        self.assertFalse(smaller.keys() <  larger3.keys())  # 1 not in larger3
+        self.assertFalse(smaller.keys() <= larger3.keys())
+        self.assertFalse(larger3.keys() >  smaller.keys())
+        self.assertFalse(larger3.keys() >= smaller.keys())
+
+        # Equal key sets: <= and >= hold, but the strict < and > do not.
+        self.assertTrue(larger2.keys() >= larger.keys())
+        self.assertTrue(larger2.keys() <= larger.keys())
+        self.assertFalse(larger2.keys() > larger.keys())
+        self.assertFalse(larger2.keys() < larger.keys())
+
+        self.assertTrue(larger.keys() == larger2.keys())
+        self.assertTrue(smaller.keys() != larger.keys())
+
+        # Empty views are equal to each other and unequal to non-empty views.
+        self.assertTrue(empty.keys() == empty2.keys())
+        self.assertFalse(empty.keys() != empty2.keys())
+        self.assertFalse(empty.keys() == smaller.keys())
+        self.assertTrue(empty.keys() != smaller.keys())
+
+        # Same size but different keys: only the element check can tell.
+        self.assertTrue(larger.keys() != larger3.keys())
+        self.assertFalse(larger.keys() == larger3.keys())
+
+        # XXX add the same tests for .items() views.
+
+    def test_errors_in_view_containment_check(self):
+        class C:  # every == comparison on an instance raises
+            def __eq__(self, other):
+                raise RuntimeError
+        d1 = {1: C()}  # same size as d2, so size alone decides nothing
+        d2 = {1: C()}
+        self.assertRaises(RuntimeError, lambda: d1.items() == d2.items())
+        self.assertRaises(RuntimeError, lambda: d1.items() != d2.items())
+        self.assertRaises(RuntimeError, lambda: d1.items() <= d2.items())
+        self.assertRaises(RuntimeError, lambda: d1.items() >= d2.items())
+        d3 = {1: C(), 2: C()}  # bigger than d2, so < and > still compare items
+        self.assertRaises(RuntimeError, lambda: d2.items() < d3.items())
+        self.assertRaises(RuntimeError, lambda: d3.items() > d2.items())
+
     def test_missing(self):
         # Make sure dict doesn't have a __missing__ method
         self.assertEqual(hasattr(dict, "__missing__"), False)

Modified: python/branches/py3k/Objects/dictobject.c
==============================================================================
--- python/branches/py3k/Objects/dictobject.c   (original)
+++ python/branches/py3k/Objects/dictobject.c   Sat Aug 25 00:33:45 2007
@@ -2371,6 +2371,8 @@
 # define PyDictViewSet_Check(obj) \
        (PyDictKeys_Check(obj) || PyDictItems_Check(obj))
 
+/* Return 1 if self is a subset of other, iterating over self;
+   0 if not; -1 if an error occurred. */
 static int
 all_contained_in(PyObject *self, PyObject *other)
 {
@@ -2398,41 +2400,63 @@
 static PyObject *
 dictview_richcompare(PyObject *self, PyObject *other, int op)
 {
+       Py_ssize_t len_self, len_other;
+       int ok;             /* 1 = true, 0 = false, -1 = error */
+       PyObject *result;
+       /* Set-style comparison of a keys/items view with a set or view. */
        assert(self != NULL);
        assert(PyDictViewSet_Check(self));
        assert(other != NULL);
-       if ((op == Py_EQ || op == Py_NE) &&
-           (PyAnySet_Check(other) || PyDictViewSet_Check(other)))
-       {
-               Py_ssize_t len_self, len_other;
-               int ok;
-               PyObject *result;
-
-               len_self = PyObject_Size(self);
-               if (len_self < 0)
-                       return NULL;
-               len_other = PyObject_Size(other);
-               if (len_other < 0)
-                       return NULL;
-               if (len_self != len_other)
-                       ok = 0;
-               else if (len_self == 0)
-                       ok = 1;
-               else
-                       ok = all_contained_in(self, other);
-               if (ok < 0)
-                       return NULL;
-               if (ok == (op == Py_EQ))
-                       result = Py_True;
-               else
-                       result = Py_False;
-               Py_INCREF(result);
-               return result;
-       }
-       else {
+       /* Not a set or set-like view: defer to the other operand. */
+       if (!PyAnySet_Check(other) && !PyDictViewSet_Check(other)) {
                Py_INCREF(Py_NotImplemented);
                return Py_NotImplemented;
        }
+
+       len_self = PyObject_Size(self);
+       if (len_self < 0)
+               return NULL;
+       len_other = PyObject_Size(other);
+       if (len_other < 0)
+               return NULL;
+       /* Sizes alone settle most cases; run a subset test only if needed. */
+       ok = 0;
+       switch(op) {
+
+       case Py_NE:
+       case Py_EQ:     /* same size and subset means equal */
+               if (len_self == len_other)
+                       ok = all_contained_in(self, other);
+               if (op == Py_NE && ok >= 0)  /* keep -1 (error) intact */
+                       ok = !ok;
+               break;
+
+       case Py_LT:     /* proper subset */
+               if (len_self < len_other)
+                       ok = all_contained_in(self, other);
+               break;
+
+       case Py_LE:     /* subset */
+               if (len_self <= len_other)
+                       ok = all_contained_in(self, other);
+               break;
+
+       case Py_GT:     /* proper superset, i.e. other < self */
+               if (len_self > len_other)
+                       ok = all_contained_in(other, self);
+               break;
+
+       case Py_GE:     /* superset, i.e. other <= self */
+               if (len_self >= len_other)
+                       ok = all_contained_in(other, self);
+               break;
+
+       }
+       if (ok < 0)     /* all_contained_in() failed */
+               return NULL;
+       result = ok ? Py_True : Py_False;
+       Py_INCREF(result);
+       return result;
 }
 
 /*** dict_keys ***/
_______________________________________________
Python-3000-checkins mailing list
python-3000-checkins@python.org
http://mail.python.org/mailman/listinfo/python-3000-checkins

Reply via email to