Author: Tyler Wade <way...@gmail.com>
Branch: fix-bytearray-complexity
Changeset: r71879:7372138b89d1
Date: 2014-06-01 16:26 -0500
http://bitbucket.org/pypy/pypy/changeset/7372138b89d1/
Log: Use [] and len() for buffers in rpython.rlib.rstring diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py --- a/rpython/rlib/rstring.py +++ b/rpython/rlib/rstring.py @@ -18,20 +18,9 @@ @specialize.argtype(0, 1) def _get_access_functions(value, other): - if (isinstance(other, str) or isinstance(other, unicode) or - isinstance(other, list)): - def getitem(obj, i): - return obj[i] - def getlength(obj): - return len(obj) - else: - assert isinstance(other, Buffer) - def getitem(obj, i): - return obj.getitem(i) - def getlength(obj): - return obj.getlength() + if (not (isinstance(value, str) or isinstance(value, unicode)) or + not (isinstance(other, str) or isinstance(other, unicode))): - if isinstance(value, list) or isinstance(other, Buffer): def find(obj, other, start, end): return search(obj, other, start, end, SEARCH_FIND) def rfind(obj, other, start, end): @@ -39,8 +28,8 @@ def count(obj, other, start, end): return search(obj, other, start, end, SEARCH_COUNT) else: - assert isinstance(value, str) or isinstance(value, unicode) - assert isinstance(other, str) or isinstance(other, unicode) + assert isinstance(value, str) or isinstance(value, unicode) + assert isinstance(other, str) or isinstance(other, unicode) def find(obj, other, start, end): return obj.find(other, start, end) def rfind(obj, other, start, end): @@ -48,7 +37,7 @@ def count(obj, other, start, end): return obj.count(other, start, end) - return getitem, getlength, find, rfind, count + return find, rfind, count @specialize.argtype(0) def _isspace(char): @@ -59,7 +48,7 @@ return unicodedb.isspace(ord(char)) -@specialize.argtype(0) +@specialize.argtype(0, 1) def split(value, by=None, maxsplit=-1): if by is None: length = len(value) @@ -90,11 +79,7 @@ i = j + 1 return res - if isinstance(value, list) or isinstance(value, str): - assert isinstance(by, str) - else: - assert isinstance(by, unicode) - _, _, find, _, count = _get_access_functions(value, by) + find, _, count = 
_get_access_functions(value, by) bylen = len(by) if bylen == 0: raise ValueError("empty separator") @@ -133,7 +118,7 @@ return res -@specialize.argtype(0) +@specialize.argtype(0, 1) def rsplit(value, by=None, maxsplit=-1): if by is None: res = [] @@ -169,15 +154,11 @@ res.reverse() return res - if isinstance(value, list) or isinstance(value, str): - assert isinstance(by, str) - else: - assert isinstance(by, unicode) if maxsplit > 0: res = newlist_hint(min(maxsplit + 1, len(value))) else: res = [] - _, _, _, rfind, _ = _get_access_functions(value, by) + _, rfind, _ = _get_access_functions(value, by) end = len(value) bylen = len(by) if bylen == 0: @@ -196,27 +177,20 @@ return res -@specialize.argtype(0) +@specialize.argtype(0, 1) @jit.elidable def replace(input, sub, by, maxsplit=-1): if isinstance(input, str): - assert isinstance(sub, str) - assert isinstance(by, str) Builder = StringBuilder elif isinstance(input, unicode): - assert isinstance(sub, unicode) - assert isinstance(by, unicode) Builder = UnicodeBuilder else: assert isinstance(input, list) - assert isinstance(sub, str) - assert isinstance(by, str) - # TODO: ???? 
Builder = ByteListBuilder if maxsplit == 0: return input - _, _, find, _, count = _get_access_functions(input, sub) + find, _, count = _get_access_functions(input, sub) if not sub: upper = len(input) @@ -280,7 +254,7 @@ end = length return start, end -@specialize.argtype(0) +@specialize.argtype(0, 1) @jit.elidable def startswith(u_self, prefix, start=0, end=sys.maxint): length = len(u_self) @@ -293,7 +267,7 @@ return False return True -@specialize.argtype(0) +@specialize.argtype(0, 1) @jit.elidable def endswith(u_self, suffix, start=0, end=sys.maxint): length = len(u_self) @@ -321,7 +295,6 @@ @specialize.argtype(0, 1) def search(value, other, start, end, mode): - getitem, getlength, _, _, _ = _get_access_functions(value, other) if start < 0: start = 0 if end > len(value): @@ -331,7 +304,7 @@ count = 0 n = end - start - m = getlength(other) + m = len(other) if m == 0: if mode == SEARCH_COUNT: @@ -352,17 +325,17 @@ if mode != SEARCH_RFIND: for i in range(mlast): - mask = bloom_add(mask, getitem(other, i)) - if getitem(other, i) == getitem(other, mlast): + mask = bloom_add(mask, other[i]) + if other[i] == other[mlast]: skip = mlast - i - 1 - mask = bloom_add(mask, getitem(other, mlast)) + mask = bloom_add(mask, other[mlast]) i = start - 1 while i + 1 <= start + w: i += 1 - if value[i + m - 1] == getitem(other, m - 1): + if value[i + m - 1] == other[m - 1]: for j in range(mlast): - if value[i + j] != getitem(other, j): + if value[i + j] != other[j]: break else: if mode != SEARCH_COUNT: @@ -387,18 +360,18 @@ if not bloom(mask, c): i += m else: - mask = bloom_add(mask, getitem(other, 0)) + mask = bloom_add(mask, other[0]) for i in range(mlast, 0, -1): - mask = bloom_add(mask, getitem(other, i)) - if getitem(other, i) == getitem(other, 0): + mask = bloom_add(mask, other[i]) + if other[i] == other[0]: skip = i - 1 i = start + w + 1 while i - 1 >= start: i -= 1 - if value[i] == getitem(other, 0): + if value[i] == other[0]: for j in xrange(mlast, 0, -1): - if value[i + j] != 
getitem(other, j): + if value[i + j] != other[j]: break else: return i diff --git a/rpython/rlib/test/test_rstring.py b/rpython/rlib/test/test_rstring.py --- a/rpython/rlib/test/test_rstring.py +++ b/rpython/rlib/test/test_rstring.py @@ -6,10 +6,17 @@ from rpython.rtyper.test.tool import BaseRtypingTest def test_split(): - def check_split(value, *args, **kwargs): + def check_split(value, sub, *args, **kwargs): result = kwargs['res'] - assert split(value, *args) == result - assert split(list(value), *args) == [list(i) for i in result] + assert split(value, sub, *args) == result + assert split(value, buffer(sub), *args) == result + + list_result = [list(i) for i in result] + assert split(list(value), sub, *args) == list_result + assert split(list(value), buffer(sub), *args) == list_result + + assert split(buffer(value), sub, *args) == result + assert split(buffer(value), buffer(sub), *args) == result check_split("", 'x', res=['']) check_split("a", "a", 1, res=['', '']) @@ -39,10 +46,17 @@ py.test.raises(ValueError, split, u'abc', u'') def test_rsplit(): - def check_rsplit(value, *args, **kwargs): + def check_rsplit(value, sub, *args, **kwargs): result = kwargs['res'] - assert rsplit(value, *args) == result - assert rsplit(list(value), *args) == [list(i) for i in result] + assert rsplit(value, sub, *args) == result + assert rsplit(value, buffer(sub), *args) == result + + list_result = [list(i) for i in result] + assert rsplit(list(value), sub, *args) == list_result + assert rsplit(list(value), buffer(sub), *args) == list_result + + assert rsplit(buffer(value), sub, *args) == result + assert rsplit(buffer(value), buffer(sub), *args) == result check_rsplit("a", "a", 1, res=['', '']) check_rsplit(" ", " ", 1, res=['', '']) @@ -69,10 +83,13 @@ py.test.raises(ValueError, rsplit, u"abc", u'') def test_string_replace(): - def check_replace(value, *args, **kwargs): + def check_replace(value, sub, *args, **kwargs): result = kwargs['res'] - assert replace(value, *args) == 
result - assert replace(list(value), *args) == list(result) + assert replace(value, sub, *args) == result + assert replace(value, buffer(sub), *args) == result + + assert replace(list(value), sub, *args) == list(result) + assert replace(list(value), buffer(sub), *args) == list(result) check_replace('one!two!three!', '!', '@', 1, res='one@two!three!') check_replace('one!two!three!', '!', '', res='onetwothree') _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit