Author: Ronan Lamy <ronan.l...@gmail.com>
Branch: py3.5
Changeset: r93235:e6985c577de2
Date: 2017-12-01 21:16 +0000
http://bitbucket.org/pypy/pypy/changeset/e6985c577de2/

Log:    Fix handling of arguments containing null bytes in zipimporter
        methods

diff --git a/pypy/module/imp/importing.py b/pypy/module/imp/importing.py
--- a/pypy/module/imp/importing.py
+++ b/pypy/module/imp/importing.py
@@ -292,7 +292,7 @@
               ext + PYC_TAG + '.pyc')
     return result
 
-#@signature(types.str0(), returns=types.str0())
+@signature(types.str0(), returns=types.any())
 def make_source_pathname(pathname):
     "Given the path to a .pyc file, return the path to its .py file."
     # (...)/__pycache__/foo.<tag>.pyc -> (...)/foo.py
diff --git a/pypy/module/zipimport/interp_zipimport.py b/pypy/module/zipimport/interp_zipimport.py
--- a/pypy/module/zipimport/interp_zipimport.py
+++ b/pypy/module/zipimport/interp_zipimport.py
@@ -1,6 +1,15 @@
 import os
 import stat
 
+from rpython.annotator.model import s_Str0
+from rpython.rlib.objectmodel import enforceargs
+from rpython.rlib.unroll import unrolling_iterable
+from rpython.rlib.rzipfile import RZipFile, BadZipfile
+from rpython.rlib.rzlib import RZlibError
+from rpython.rlib.rstring import assert_str0
+from rpython.rlib.signature import signature, finishsigs
+from rpython.rlib import types
+
 from pypy.interpreter.baseobjspace import W_Root
 from pypy.interpreter.error import OperationError, oefmt
 from pypy.interpreter.gateway import interp2app, unwrap_spec
@@ -8,9 +17,6 @@
 from pypy.interpreter.module import Module
 from pypy.module.imp import importing
 from pypy.module.zlib.interp_zlib import zlib_error
-from rpython.rlib.unroll import unrolling_iterable
-from rpython.rlib.rzipfile import RZipFile, BadZipfile
-from rpython.rlib.rzlib import RZlibError
 
 ZIPSEP = '/'
 # note that zipfiles always use slash, but for OSes with other
@@ -116,6 +122,7 @@
 
 zip_cache = W_ZipCache()
 
+@finishsigs
 class W_ZipImporter(W_Root):
     def __init__(self, space, name, filename, zip_file, prefix):
         self.space = space
@@ -138,12 +145,14 @@
             filename = filename.replace(os.path.sep, ZIPSEP)
         return filename
 
+    @signature(types.self(), types.str0(), returns=types.str0())
     def corr_zname(self, fname):
         if ZIPSEP != os.path.sep:
             return fname.replace(ZIPSEP, os.path.sep)
         else:
             return fname
 
+    @enforceargs(filename=s_Str0, typecheck=False)
     def import_py_file(self, space, modname, filename, buf, pkgpath):
         w_mod = Module(space, space.newtext(modname))
         real_name = self.filename + os.path.sep + self.corr_zname(filename)
@@ -194,20 +203,21 @@
             return False
         return True
 
+    @enforceargs(filename=s_Str0, typecheck=False)
     def import_pyc_file(self, space, modname, filename, buf, pkgpath):
         magic = importing._get_long(buf[:4])
         timestamp = importing._get_long(buf[4:8])
         if not self.can_use_pyc(space, filename, magic, timestamp):
             return None
         # zipimport ignores the size field
-        buf = buf[12:] # XXX ugly copy, should use sequential read instead
+        buf = buf[12:]  # XXX ugly copy, should use sequential read instead
         w_mod = Module(space, space.newtext(modname))
         real_name = self.filename + os.path.sep + self.corr_zname(filename)
         space.setattr(w_mod, space.newtext('__loader__'), self)
         importing._prepare_module(space, w_mod, real_name, pkgpath)
-        result = importing.load_compiled_module(space, space.newtext(modname), w_mod,
-                                                real_name, magic, timestamp,
-                                                buf)
+        result = importing.load_compiled_module(
+            space, space.newtext(modname),
+            w_mod, real_name, magic, timestamp, buf)
         return result
 
     def have_modulefile(self, space, filename):
@@ -227,14 +237,14 @@
                 return self
 
     def make_filename(self, fullname):
-        startpos = fullname.rfind('.') + 1 # 0 when not found
+        startpos = fullname.rfind('.') + 1  # 0 when not found
         assert startpos >= 0
         subname = fullname[startpos:]
         if ZIPSEP == os.path.sep:
             return self.prefix + subname.replace('.', '/')
         else:
-            return self.prefix.replace(os.path.sep, ZIPSEP) + \
-                    subname.replace('.', '/')
+            return (self.prefix.replace(os.path.sep, ZIPSEP) +
+                    subname.replace('.', '/'))
 
     def make_co_filename(self, filename):
         """
@@ -248,6 +258,12 @@
         fullname = space.text_w(w_fullname)
         filename = self.make_filename(fullname)
         for compiled, is_package, ext in ENUMERATE_EXTS:
+            if '\x00' in filename:
+                # Special case to make the annotator happy:
+                # filenames inside ZIPs shouldn't contain NULs so no module can
+                # possibly be found in this case
+                break
+            filename = assert_str0(filename)
             fname = filename + ext
             try:
                 buf = self.zip_file.read(fname)
@@ -302,6 +318,12 @@
         fullname = space.text_w(w_fullname)
         filename = self.make_filename(fullname)
         for compiled, _, ext in ENUMERATE_EXTS:
+            if '\x00' in filename:
+                # Special case to make the annotator happy:
+                # filenames inside ZIPs shouldn't contain NULs so no module can
+                # possibly be found in this case
+                break
+            filename = assert_str0(filename)
             if self.have_modulefile(space, filename + ext):
                 w_source = self.get_data(space, filename + ext)
                 source = space.bytes_w(w_source)
@@ -328,6 +350,12 @@
         filename = self.make_filename(fullname)
         found = False
         for compiled, _, ext in ENUMERATE_EXTS:
+            if '\x00' in filename:
+                # Special case to make the annotator happy:
+                # filenames inside ZIPs shouldn't contain NULs so no module can
+                # possibly be found in this case
+                break
+            filename = assert_str0(filename)
             fname = filename + ext
             if self.have_modulefile(space, fname):
                 if not compiled:
@@ -349,6 +377,12 @@
         fullname = space.text_w(w_fullname)
         filename = self.make_filename(fullname)
         for _, is_package, ext in ENUMERATE_EXTS:
+            if '\x00' in filename:
+                # Special case to make the annotator happy:
+                # filenames inside ZIPs shouldn't contain NULs so no module can
+                # possibly be found in this case
+                break
+            filename = assert_str0(filename)
             if self.have_modulefile(space, filename + ext):
                 return space.newfilename(self.filename + os.path.sep +
                                             self.corr_zname(filename + ext))
@@ -361,6 +395,12 @@
         fullname = space.text_w(w_fullname)
         filename = self.make_filename(fullname)
         for _, is_package, ext in ENUMERATE_EXTS:
+            if '\x00' in filename:
+                # Special case to make the annotator happy:
+                # filenames inside ZIPs shouldn't contain NULs so no module can
+                # possibly be found in this case
+                break
+            filename = assert_str0(filename)
             if self.have_modulefile(space, filename + ext):
                 return space.newbool(is_package)
         raise oefmt(get_error(space),
@@ -373,7 +413,14 @@
         return space.newfilename(self.filename)
 
     def _find_loader(self, space, fullname):
+        if '\x00' in fullname:
+            # Special case to make the annotator happy:
+            # filenames inside ZIPs shouldn't contain NULs so no module can
+            # possibly be found in this case
+            return False, None
+        fullname = assert_str0(fullname)
         filename = self.make_filename(fullname)
+        filename = assert_str0(filename)
         for _, _, ext in ENUMERATE_EXTS:
             if self.have_modulefile(space, filename + ext):
                 return True, None
diff --git a/pypy/module/zipimport/test/test_zipimport.py b/pypy/module/zipimport/test/test_zipimport.py
--- a/pypy/module/zipimport/test/test_zipimport.py
+++ b/pypy/module/zipimport/test/test_zipimport.py
@@ -394,10 +394,19 @@
         assert z.get_code('ä')
         raises(ImportError, z.get_code, 'xx')
         mod = z.load_module('ä')
+        #assert z.load_module('ä') is mod
         assert z.get_filename('ä') == mod.__file__
         raises(ImportError, z.load_module, 'xx')
         raises(ImportError, z.get_filename, 'xx')
         assert z.archive == self.zipfile
+        # PyPy fix: check null byte behavior:
+        import sys
+        if '__pypy__' in sys.builtin_module_names:
+            raises(ImportError, z.is_package, 'ä\0 b')
+            raises(ImportError, z.get_source, 'ä\0 b')
+            raises(ImportError, z.get_code, 'ä\0 b')
+            raises(ImportError, z.load_module, 'ä\0 b')
+            raises(ImportError, z.get_filename, 'ä\0 b')
 
     def test_co_filename(self):
         self.writefile('mymodule.py', """
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to