Author: Lukas Diekmann <[email protected]>
Branch: set-strategies
Changeset: r49204:bb83301f7ae1
Date: 2011-10-04 13:40 +0200
http://bitbucket.org/pypy/pypy/changeset/bb83301f7ae1/

Log:    refactored intersection for sets

diff --git a/pypy/objspace/std/setobject.py b/pypy/objspace/std/setobject.py
--- a/pypy/objspace/std/setobject.py
+++ b/pypy/objspace/std/setobject.py
@@ -477,33 +477,47 @@
         w_set.strategy = strategy
         w_set.sstorage = storage
 
+    def _intersect_base(self, w_set, w_other):  # returns (storage, strategy)
+        if w_set.strategy is w_other.strategy:
+            strategy = w_set.strategy
+            storage = strategy._intersect_unwrapped(w_set, w_other)
+        else:
+            strategy = self.space.fromcache(ObjectSetStrategy)  # mixed strategies: result uses the object strategy
+            storage = strategy._intersect_wrapped(w_set, w_other)  # NOTE(review): self in this call is ObjectSetStrategy, not w_set.strategy; confirm its cast_from_void_star accepts w_set.sstorage
+        return storage, strategy
+
+    def _intersect_wrapped(self, w_set, w_other):  # keys of w_set also present in w_other, stored wrapped
+        result = self.get_empty_dict()
+        items = self.cast_from_void_star(w_set.sstorage).keys()  # assumes w_set.sstorage was erased by self's strategy -- TODO confirm at call sites
+        for key in items:
+            w_key = self.wrap(key)  # box the key so w_other.has_key can probe it
+            if w_other.has_key(w_key):
+                result[w_key] = None  # result is keyed by wrapped objects
+        return self.cast_to_void_star(result)
+
+    def _intersect_unwrapped(self, w_set, w_other):  # both storages must belong to self's strategy
+        result = self.get_empty_dict()
+        d_this = self.cast_from_void_star(w_set.sstorage)
+        d_other = self.cast_from_void_star(w_other.sstorage)
+        for key in d_this:  # raw keys compared directly, no wrapping
+            if key in d_other:
+                result[key] = None
+        return self.cast_to_void_star(result)
+
     def intersect(self, w_set, w_other):
         if w_set.length() > w_other.length():
             return w_other.intersect(w_set)
 
-        result = w_set._newobj(self.space, None)
-        items = self.cast_from_void_star(w_set.sstorage).keys()
-        #XXX do it without wrapping when strategies are equal
-        for key in items:
-            w_key = self.wrap(key)
-            if w_other.has_key(w_key):
-                result.add(w_key)
-        return result
+        storage, strategy = self._intersect_base(w_set, w_other)  # w_set is the smaller set at this point
+        return w_set.from_storage_and_strategy(storage, strategy)
 
     def intersect_update(self, w_set, w_other):
         if w_set.length() > w_other.length():
-            return w_other.intersect(w_set)
-
-        setdata = newset(self.space)
-        items = self.cast_from_void_star(w_set.sstorage).keys()
-        for key in items:
-            w_key = self.wrap(key)
-            if w_other.has_key(w_key):
-                setdata[w_key] = None
-
-        # do not switch strategy here if other items match
-        w_set.strategy = strategy = self.space.fromcache(ObjectSetStrategy)
-        w_set.sstorage = strategy.cast_to_void_star(setdata)
+            storage, strategy = self._intersect_base(w_other, w_set)  # iterate the smaller set (w_other)
+        else:
+            storage, strategy = self._intersect_base(w_set, w_other)
+        w_set.strategy = strategy  # replace w_set's contents in place
+        w_set.sstorage = storage
         return w_set
 
     def intersect_multiple(self, w_set, others_w):
@@ -514,7 +528,6 @@
                 #XXX this creates setobject again
                 result = result.intersect(w_other)
             else:
-                #XXX directly give w_other as argument to result2
                 result2 = w_set._newobj(self.space, None)
                 for w_key in self.space.listview(w_other):
                     if result.has_key(w_key):
@@ -1084,7 +1097,6 @@
     return
 
 def inplace_and__Set_Set(space, w_left, w_other):
-    #XXX why do we need to return here?
     return w_left.intersect_update(w_other)
 
 inplace_and__Set_Frozenset = inplace_and__Set_Set
diff --git a/pypy/objspace/std/test/test_setobject.py b/pypy/objspace/std/test/test_setobject.py
--- a/pypy/objspace/std/test/test_setobject.py
+++ b/pypy/objspace/std/test/test_setobject.py
@@ -493,6 +493,24 @@
         assert s.intersection() == s
         assert s.intersection() is not s
 
+    def test_intersection_swap(self):  # len(s1) > len(s2) exercises the swapped branch of &=
+        s1 = s3 = set([1,2,3,4,5])  # s3 aliases s1 to check that &= mutates in place
+        s2 = set([2,3,6,7])
+        s1 &= s2
+        assert s1 == set([2,3])
+        assert s3 == set([2,3])
+
+    def test_intersection_generator(self):  # non-set iterable argument
+        def foo():
+            for i in range(5):
+                yield i
+
+        s1 = s2 = set([1,2,3,4,5,6])  # s2 aliases s1 to observe the in-place update
+        assert s1.intersection(foo()) == set([1,2,3,4])
+        s1.intersection_update(foo())
+        assert s1 == set([1,2,3,4])
+        assert s2 == set([1,2,3,4])
+
     def test_difference(self):
         assert set([1,2,3]).difference(set([2,3,4])) == set([1])
         assert set([1,2,3]).difference(frozenset([2,3,4])) == set([1])
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to