Author: Tyler Wade <way...@gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71879:7372138b89d1
Date: 2014-06-01 16:26 -0500
http://bitbucket.org/pypy/pypy/changeset/7372138b89d1/

Log:    Use [] and len() for buffers in rpython.rlib.rstring

diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py
--- a/rpython/rlib/rstring.py
+++ b/rpython/rlib/rstring.py
@@ -18,20 +18,9 @@
 
 @specialize.argtype(0, 1)
 def _get_access_functions(value, other):
-    if (isinstance(other, str) or  isinstance(other, unicode) or
-        isinstance(other, list)):
-        def getitem(obj, i):
-            return obj[i]
-        def getlength(obj):
-            return len(obj)
-    else:
-        assert isinstance(other, Buffer)
-        def getitem(obj, i):
-            return obj.getitem(i)
-        def getlength(obj):
-            return obj.getlength()
+    if (not (isinstance(value, str) or isinstance(value, unicode)) or
+        not (isinstance(other, str) or isinstance(other, unicode))):
 
-    if isinstance(value, list) or isinstance(other, Buffer):
         def find(obj, other, start, end):
             return search(obj, other, start, end, SEARCH_FIND)
         def rfind(obj, other, start, end):
@@ -39,8 +28,8 @@
         def count(obj, other, start, end):
             return search(obj, other, start, end, SEARCH_COUNT)
     else:
-        assert isinstance(value, str) or  isinstance(value, unicode)
-        assert isinstance(other, str) or  isinstance(other, unicode)
+        assert isinstance(value, str) or isinstance(value, unicode)
+        assert isinstance(other, str) or isinstance(other, unicode)
         def find(obj, other, start, end):
             return obj.find(other, start, end)
         def rfind(obj, other, start, end):
@@ -48,7 +37,7 @@
         def count(obj, other, start, end):
             return obj.count(other, start, end)
 
-    return getitem, getlength, find, rfind, count
+    return find, rfind, count
 
 @specialize.argtype(0)
 def _isspace(char):
@@ -59,7 +48,7 @@
         return unicodedb.isspace(ord(char))
 
 
-@specialize.argtype(0)
+@specialize.argtype(0, 1)
 def split(value, by=None, maxsplit=-1):
     if by is None:
         length = len(value)
@@ -90,11 +79,7 @@
             i = j + 1
         return res
 
-    if isinstance(value, list) or isinstance(value, str):
-        assert isinstance(by, str)
-    else:
-        assert isinstance(by, unicode)
-    _, _, find, _, count = _get_access_functions(value, by)
+    find, _, count = _get_access_functions(value, by)
     bylen = len(by)
     if bylen == 0:
         raise ValueError("empty separator")
@@ -133,7 +118,7 @@
     return res
 
 
-@specialize.argtype(0)
+@specialize.argtype(0, 1)
 def rsplit(value, by=None, maxsplit=-1):
     if by is None:
         res = []
@@ -169,15 +154,11 @@
         res.reverse()
         return res
 
-    if isinstance(value, list) or isinstance(value, str):
-        assert isinstance(by, str)
-    else:
-        assert isinstance(by, unicode)
     if maxsplit > 0:
         res = newlist_hint(min(maxsplit + 1, len(value)))
     else:
         res = []
-    _, _, _, rfind, _ = _get_access_functions(value, by)
+    _, rfind, _ = _get_access_functions(value, by)
     end = len(value)
     bylen = len(by)
     if bylen == 0:
@@ -196,27 +177,20 @@
     return res
 
 
-@specialize.argtype(0)
+@specialize.argtype(0, 1)
 @jit.elidable
 def replace(input, sub, by, maxsplit=-1):
     if isinstance(input, str):
-        assert isinstance(sub, str)
-        assert isinstance(by, str)
         Builder = StringBuilder
     elif isinstance(input, unicode):
-        assert isinstance(sub, unicode)
-        assert isinstance(by, unicode)
         Builder = UnicodeBuilder
     else:
         assert isinstance(input, list)
-        assert isinstance(sub, str)
-        assert isinstance(by, str)
-        # TODO: ????
         Builder = ByteListBuilder
     if maxsplit == 0:
         return input
 
-    _, _, find, _, count = _get_access_functions(input, sub)
+    find, _, count = _get_access_functions(input, sub)
 
     if not sub:
         upper = len(input)
@@ -280,7 +254,7 @@
         end = length
     return start, end
 
-@specialize.argtype(0)
+@specialize.argtype(0, 1)
 @jit.elidable
 def startswith(u_self, prefix, start=0, end=sys.maxint):
     length = len(u_self)
@@ -293,7 +267,7 @@
             return False
     return True
 
-@specialize.argtype(0)
+@specialize.argtype(0, 1)
 @jit.elidable
 def endswith(u_self, suffix, start=0, end=sys.maxint):
     length = len(u_self)
@@ -321,7 +295,6 @@
 
 @specialize.argtype(0, 1)
 def search(value, other, start, end, mode):
-    getitem, getlength, _, _, _ = _get_access_functions(value, other)
     if start < 0:
         start = 0
     if end > len(value):
@@ -331,7 +304,7 @@
 
     count = 0
     n = end - start
-    m = getlength(other)
+    m = len(other)
 
     if m == 0:
         if mode == SEARCH_COUNT:
@@ -352,17 +325,17 @@
 
     if mode != SEARCH_RFIND:
         for i in range(mlast):
-            mask = bloom_add(mask, getitem(other, i))
-            if getitem(other, i) == getitem(other, mlast):
+            mask = bloom_add(mask, other[i])
+            if other[i] == other[mlast]:
                 skip = mlast - i - 1
-        mask = bloom_add(mask, getitem(other, mlast))
+        mask = bloom_add(mask, other[mlast])
 
         i = start - 1
         while i + 1 <= start + w:
             i += 1
-            if value[i + m - 1] == getitem(other, m - 1):
+            if value[i + m - 1] == other[m - 1]:
                 for j in range(mlast):
-                    if value[i + j] != getitem(other, j):
+                    if value[i + j] != other[j]:
                         break
                 else:
                     if mode != SEARCH_COUNT:
@@ -387,18 +360,18 @@
                 if not bloom(mask, c):
                     i += m
     else:
-        mask = bloom_add(mask, getitem(other, 0))
+        mask = bloom_add(mask, other[0])
         for i in range(mlast, 0, -1):
-            mask = bloom_add(mask, getitem(other, i))
-            if getitem(other, i) == getitem(other, 0):
+            mask = bloom_add(mask, other[i])
+            if other[i] == other[0]:
                 skip = i - 1
 
         i = start + w + 1
         while i - 1 >= start:
             i -= 1
-            if value[i] == getitem(other, 0):
+            if value[i] == other[0]:
                 for j in xrange(mlast, 0, -1):
-                    if value[i + j] != getitem(other, j):
+                    if value[i + j] != other[j]:
                         break
                 else:
                     return i
diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py
--- a/rpython/rlib/test/test_rstring.py
+++ b/rpython/rlib/test/test_rstring.py
@@ -6,10 +6,17 @@
 from rpython.rtyper.test.tool import BaseRtypingTest
 
 def test_split():
-    def check_split(value, *args, **kwargs):
+    def check_split(value, sub, *args, **kwargs):
         result = kwargs['res']
-        assert split(value, *args) == result
-        assert split(list(value), *args) == [list(i) for i in result]
+        assert split(value, sub, *args) == result
+        assert split(value, buffer(sub), *args) == result
+
+        list_result = [list(i) for i in result]
+        assert split(list(value), sub, *args) == list_result
+        assert split(list(value), buffer(sub), *args) == list_result
+
+        assert split(buffer(value), sub, *args) == result
+        assert split(buffer(value), buffer(sub), *args) == result
 
     check_split("", 'x', res=[''])
     check_split("a", "a", 1, res=['', ''])
@@ -39,10 +46,17 @@
     py.test.raises(ValueError, split, u'abc', u'')
 
 def test_rsplit():
-    def check_rsplit(value, *args, **kwargs):
+    def check_rsplit(value, sub, *args, **kwargs):
         result = kwargs['res']
-        assert rsplit(value, *args) == result
-        assert rsplit(list(value), *args) == [list(i) for i in result]
+        assert rsplit(value, sub, *args) == result
+        assert rsplit(value, buffer(sub), *args) == result
+
+        list_result = [list(i) for i in result]
+        assert rsplit(list(value), sub, *args) == list_result
+        assert rsplit(list(value), buffer(sub), *args) == list_result
+
+        assert rsplit(buffer(value), sub, *args) == result
+        assert rsplit(buffer(value), buffer(sub), *args) == result
 
     check_rsplit("a", "a", 1, res=['', ''])
     check_rsplit(" ", " ", 1, res=['', ''])
@@ -69,10 +83,13 @@
     py.test.raises(ValueError, rsplit, u"abc", u'')
 
 def test_string_replace():
-    def check_replace(value, *args, **kwargs):
+    def check_replace(value, sub, *args, **kwargs):
         result = kwargs['res']
-        assert replace(value, *args) == result
-        assert replace(list(value), *args) == list(result)
+        assert replace(value, sub, *args) == result
+        assert replace(value, buffer(sub), *args) == result
+
+        assert replace(list(value), sub, *args) == list(result)
+        assert replace(list(value), buffer(sub), *args) == list(result)
         
     check_replace('one!two!three!', '!', '@', 1, res='one@two!three!')
     check_replace('one!two!three!', '!', '', res='onetwothree')
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to