Author: Armin Rigo <[email protected]>
Branch: 
Changeset: r89184:10018e2abec8
Date: 2016-12-19 16:53 +0100
http://bitbucket.org/pypy/pypy/changeset/10018e2abec8/

Log:    Fix rzlib to support arbitrarily large input strings (> 2**32 on
        64-bit)

diff --git a/rpython/rlib/rzlib.py b/rpython/rlib/rzlib.py
--- a/rpython/rlib/rzlib.py
+++ b/rpython/rlib/rzlib.py
@@ -76,6 +76,10 @@
     DEF_MEM_LEVEL = MAX_MEM_LEVEL
 
 OUTPUT_BUFFER_SIZE = 32*1024
+INPUT_BUFFER_MAX = 2047*1024*1024
+# Note: we assume that zlib never outputs less than OUTPUT_BUFFER_SIZE
+# from an input of INPUT_BUFFER_MAX bytes.  This holds by a large margin:
+# deflate compresses at most ~1032:1, i.e. ~2MB out for 2047MB in.
 
 
 class ComplexCConfig:
@@ -366,10 +370,10 @@
     """Common code for compress() and decompress().
     """
     # Prepare the input buffer for the stream
-    assert data is not None # XXX seems to be sane assumption, however not for sure
+    assert data is not None
     with rffi.scoped_nonmovingbuffer(data) as inbuf:
         stream.c_next_in = rffi.cast(Bytefp, inbuf)
-        rffi.setintfield(stream, 'c_avail_in', len(data))
+        end_inbuf = rffi.ptradd(stream.c_next_in, len(data))
 
         # Prepare the output buffer
         with lltype.scoped_alloc(rffi.CCHARP.TO, OUTPUT_BUFFER_SIZE) as outbuf:
@@ -379,6 +383,11 @@
             result = StringBuilder()
 
             while True:
+                avail_in = ptrdiff(end_inbuf, stream.c_next_in)
+                if avail_in > INPUT_BUFFER_MAX:
+                    avail_in = INPUT_BUFFER_MAX
+                rffi.setintfield(stream, 'c_avail_in', avail_in)
+
                 stream.c_next_out = rffi.cast(Bytefp, outbuf)
                 bufsize = OUTPUT_BUFFER_SIZE
                 if max_length < bufsize:
@@ -388,7 +397,9 @@
                     bufsize = max_length
                 max_length -= bufsize
                 rffi.setintfield(stream, 'c_avail_out', bufsize)
+
                 err = cfunc(stream, flush)
+
                 if err == Z_NEED_DICT and zdict is not None:
                     inflateSetDictionary(stream, zdict)
                     # repeat the call to inflate
@@ -422,6 +433,9 @@
     # When decompressing, if the compressed stream of data was truncated,
     # then the zlib simply returns Z_OK and waits for more.  If it is
     # complete it returns Z_STREAM_END.
-    return (result.build(),
-            err,
-            rffi.cast(lltype.Signed, stream.c_avail_in))
+    avail_in = ptrdiff(end_inbuf, stream.c_next_in)
+    return (result.build(), err, avail_in)
+
+def ptrdiff(p, q):  # byte distance p - q, computed in Unsigned
+    delta = rffi.cast(lltype.Unsigned, p) - rffi.cast(lltype.Unsigned, q)
+    return rffi.cast(lltype.Signed, delta)  # reinterpret as Signed
diff --git a/rpython/rlib/test/test_rzlib.py b/rpython/rlib/test/test_rzlib.py
--- a/rpython/rlib/test/test_rzlib.py
+++ b/rpython/rlib/test/test_rzlib.py
@@ -3,7 +3,7 @@
 Tests for the rzlib module.
 """
 
-import py
+import py, sys
 from rpython.rlib import rzlib
 from rpython.rlib.rarithmetic import r_uint
 from rpython.rlib import clibffi # for side effect of testing lib_c_name on win32
@@ -274,3 +274,36 @@
 def test_zlibVersion():
     runtime_version = rzlib.zlibVersion()
     assert runtime_version[0] == rzlib.ZLIB_VERSION[0]
+
+def test_translate_and_large_input():
+    """Compiled rzlib.compress() x3 must equal CPython zlib output."""
+    from rpython.translator.c.test.test_genc import compile
+    def f(i):
+        data = "s" * i
+        for _ in range(3):
+            stream = rzlib.deflateInit()
+            data = rzlib.compress(stream, data, rzlib.Z_FINISH)
+            rzlib.deflateEnd(stream)
+        return data
+
+    fc = compile(f, [int])
+    chunk = "s" * (1024*1024)    # 1MB piece, reused for the reference stream
+    test_list = [1, 2, 3, 5, 8, 87, 876, 8765, 87654, 876543, 8765432,
+                 127329129]       # up to ~128MB
+    if sys.maxint > 2**32:
+        test_list.append(2971215073)    # 3GB (greater than INPUT_BUFFER_MAX)
+    for a in test_list:
+        print 'Testing compression of "s" * %d' % a
+        z = zlib.compressobj()
+        count = a
+        pieces = []
+        while count > len(chunk):
+            pieces.append(z.compress(chunk))
+            count -= len(chunk)
+        pieces.append(z.compress("s" * count))
+        pieces.append(z.flush(zlib.Z_FINISH))
+        expected = ''.join(pieces)
+        del pieces
+        expected = zlib.compress(expected)   # 2nd and 3rd passes: one-shot
+        expected = zlib.compress(expected)
+        assert fc(a) == expected
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to