Author: Armin Rigo <[email protected]>
Branch: py3.5
Changeset: r89073:4a3854e5e543
Date: 2016-12-15 16:05 +0100
http://bitbucket.org/pypy/pypy/changeset/4a3854e5e543/

Log:    bytes(x) should return x.__bytes__() even if that is a subclass of
        'bytes'

diff --git a/pypy/objspace/std/bytesobject.py b/pypy/objspace/std/bytesobject.py
--- a/pypy/objspace/std/bytesobject.py
+++ b/pypy/objspace/std/bytesobject.py
@@ -544,8 +544,16 @@
                     w_item = space.getitem(w_source, space.wrap(0))
                     value = getbytevalue(space, w_item)
                     return W_BytesObject(value)
-        #
-        value = newbytesdata_w(space, w_source, encoding, errors)
+            else:
+                # special-case 'bytes(X)' if X has a __bytes__() method:
+                # we must return the result unmodified even if it is a
+                # subclass of bytes
+                w_result = invoke_bytes_method(space, w_source)
+                if w_result is not None:
+                    return w_result
+            value = newbytesdata_w_tail(space, w_source)
+        else:
+            value = newbytesdata_w(space, w_source, encoding, errors)
         w_obj = space.allocate_instance(W_BytesObject, w_stringtype)
         W_BytesObject.__init__(w_obj, value)
         return w_obj
@@ -699,6 +707,16 @@
         raise oefmt(space.w_ValueError, "byte must be in range(0, 256)")
     return chr(value)
 
+def invoke_bytes_method(space, w_source):
+    w_bytes_method = space.lookup(w_source, "__bytes__")
+    if w_bytes_method is not None:
+        w_bytes = space.get_and_call_function(w_bytes_method, w_source)
+        if not space.isinstance_w(w_bytes, space.w_bytes):
+            raise oefmt(space.w_TypeError,
+                        "__bytes__ returned non-bytes (type '%T')", w_bytes)
+        return w_bytes
+    return None
+
 def newbytesdata_w(space, w_source, encoding, errors):
     # None value
     if w_source is None:
@@ -725,16 +743,18 @@
             raise oefmt(space.w_TypeError,
                 "string argument without an encoding")
     # Fast-path for bytes
-    if space.isinstance_w(w_source, space.w_str):
+    if space.type(w_source) is space.w_bytes:
         return space.bytes_w(w_source)
     # Some other object with a __bytes__ special method (could be str subclass)
-    w_bytes_method = space.lookup(w_source, "__bytes__")
-    if w_bytes_method is not None:
-        w_bytes = space.get_and_call_function(w_bytes_method, w_source)
-        if not space.isinstance_w(w_bytes, space.w_bytes):
-            raise oefmt(space.w_TypeError,
-                        "__bytes__ returned non-bytes (type '%T')", w_bytes)
-        return space.bytes_w(w_bytes)
+    w_result = invoke_bytes_method(space, w_source)
+    if w_result is not None:
+        return space.bytes_w(w_result)
+
+    return newbytesdata_w_tail(space, w_source)
+
+def newbytesdata_w_tail(space, w_source):
+    # converts rare case of bytes constructor arguments: we don't have
+    # any encodings/errors, and the argument does not have __bytes__()
     if space.isinstance_w(w_source, space.w_unicode):
         raise oefmt(space.w_TypeError, "string argument without an encoding")
 
diff --git a/pypy/objspace/std/test/test_bytesobject.py b/pypy/objspace/std/test/test_bytesobject.py
--- a/pypy/objspace/std/test/test_bytesobject.py
+++ b/pypy/objspace/std/test/test_bytesobject.py
@@ -969,3 +969,30 @@
     def test_constructor_typeerror(self):
         raises(TypeError, bytes, b'', 'ascii')
         raises(TypeError, bytes, '')
+
+    def test_constructor_subclass(self):
+        class Sub(bytes):
+            pass
+        class X:
+            def __bytes__(self):
+                return Sub(b'foo')
+        assert type(bytes(X())) is Sub
+
+    def test_constructor_subclass_2(self):
+        class Sub(bytes):
+            pass
+        class X(bytes):
+            def __bytes__(self):
+                return Sub(b'foo')
+        assert type(bytes(X())) is Sub
+
+    def test_constructor_subclass_3(self):
+        class Sub(bytes):
+            pass
+        class X(bytes):
+            def __bytes__(self):
+                return Sub(b'foo')
+        class Sub1(bytes):
+            pass
+        assert type(Sub1(X())) is Sub1
+        assert Sub1(X()) == b'foo'
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to