Author: Amaury Forgeot d'Arc <[email protected]>
Branch: py3k
Changeset: r50148:a0548becd3ac
Date: 2011-12-04 16:45 +0100
http://bitbucket.org/pypy/pypy/changeset/a0548becd3ac/

Log:    Progress in the _ssl module

diff --git a/pypy/module/_socket/interp_socket.py b/pypy/module/_socket/interp_socket.py
--- a/pypy/module/_socket/interp_socket.py
+++ b/pypy/module/_socket/interp_socket.py
@@ -108,6 +108,15 @@
         """
         return space.wrap(intmask(self.fd))
 
+    def detach_w(self, space):
+        """detach()
+
+        Close the socket object without closing the underlying file descriptor.
+        The object cannot be used after this call, but the file descriptor
+        can be reused for other purposes.  The file descriptor is returned."""
+        fd = self.detach()
+        return space.wrap(intmask(fd))
+
     def getpeername_w(self, space):
         """getpeername() -> address info
 
@@ -464,7 +473,7 @@
 # ____________________________________________________________
 
 socketmethodnames = """
-_accept bind close connect connect_ex dup fileno
+_accept bind close connect connect_ex dup fileno detach
 getpeername getsockname getsockopt gettimeout listen makefile
 recv recvfrom send sendall sendto setblocking
 setsockopt settimeout shutdown _reuse _drop recv_into recvfrom_into
diff --git a/pypy/module/_ssl/__init__.py b/pypy/module/_ssl/__init__.py
--- a/pypy/module/_ssl/__init__.py
+++ b/pypy/module/_ssl/__init__.py
@@ -5,8 +5,9 @@
     See the socket module for documentation."""
 
     interpleveldefs = {
-        'sslwrap': 'interp_ssl.sslwrap',
         'SSLError': 'interp_ssl.get_error(space)',
+        '_SSLSocket': 'interp_ssl.SSLSocket',
+        '_SSLContext': 'interp_ssl.SSLContext',
         '_test_decode_cert': 'interp_ssl._test_decode_cert',
     }
 
diff --git a/pypy/module/_ssl/interp_ssl.py b/pypy/module/_ssl/interp_ssl.py
--- a/pypy/module/_ssl/interp_ssl.py
+++ b/pypy/module/_ssl/interp_ssl.py
@@ -2,7 +2,7 @@
 from pypy.rpython.lltypesystem import rffi, lltype
 from pypy.interpreter.error import OperationError
 from pypy.interpreter.baseobjspace import Wrappable
-from pypy.interpreter.typedef import TypeDef
+from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec
 
 from pypy.rlib.rarithmetic import intmask
@@ -10,6 +10,7 @@
 from pypy.rlib.ropenssl import *
 
 from pypy.module._socket import interp_socket
+import weakref
 
 
 ## user defined constants
@@ -58,15 +59,28 @@
 constants["PROTOCOL_SSLv23"] = PY_SSL_VERSION_SSL23
 constants["PROTOCOL_TLSv1"]  = PY_SSL_VERSION_TLS1
 
-constants["OPENSSL_VERSION_NUMBER"] = OPENSSL_VERSION_NUMBER
-ver = OPENSSL_VERSION_NUMBER
-ver, status = divmod(ver, 16)
-ver, patch  = divmod(ver, 256)
-ver, fix    = divmod(ver, 256)
-ver, minor  = divmod(ver, 256)
-ver, major  = divmod(ver, 256)
-constants["OPENSSL_VERSION_INFO"] = (major, minor, fix, patch, status)
+# protocol options
+constants["OP_ALL"] = SSL_OP_ALL
+constants["OP_NO_SSLv2"] = SSL_OP_NO_SSLv2
+constants["OP_NO_SSLv3"] = SSL_OP_NO_SSLv3
+constants["OP_NO_TLSv1"] = SSL_OP_NO_TLSv1
+constants["HAS_SNI"] = HAS_SNI
+
+# OpenSSL version
+def _parse_version(ver):
+    ver, status = divmod(ver, 16)
+    ver, patch  = divmod(ver, 256)
+    ver, fix    = divmod(ver, 256)
+    ver, minor  = divmod(ver, 256)
+    ver, major  = divmod(ver, 256)
+    return (major, minor, fix, patch, status)
+# XXX use SSLeay() to get the version of the library linked against, which
+# could be different from the headers version.
+libver = OPENSSL_VERSION_NUMBER
+constants["OPENSSL_VERSION_NUMBER"] = libver
+constants["OPENSSL_VERSION_INFO"] = _parse_version(libver)
 constants["OPENSSL_VERSION"] = SSLEAY_VERSION
+constants["_OPENSSL_API_VERSION"] = _parse_version(libver)
 
 def ssl_error(space, msg, errno=0):
     w_exception_class = get_error(space)
@@ -74,6 +88,186 @@
                                       space.wrap(errno), space.wrap(msg))
     return OperationError(w_exception_class, w_exception)
 
+
+class SSLContext(Wrappable):
+    def __init__(self, method):
+        self.ctx = libssl_SSL_CTX_new(method)
+
+        # Defaults
+        libssl_SSL_CTX_set_verify(self.ctx, SSL_VERIFY_NONE, None)
+        libssl_SSL_CTX_set_options(self.ctx, SSL_OP_ALL)
+        libssl_SSL_CTX_set_session_id_context(self.ctx, "Python", len("Python"))
+
+    def __del__(self):
+        if self.ctx:
+            libssl_SSL_CTX_free(self.ctx)
+
+    @unwrap_spec(protocol=int)
+    def descr_new(space, w_subtype, protocol=PY_SSL_VERSION_SSL23):
+        self = space.allocate_instance(SSLContext, w_subtype)
+        if protocol == PY_SSL_VERSION_TLS1:
+            method = libssl_TLSv1_method()
+        elif protocol == PY_SSL_VERSION_SSL3:
+            method = libssl_SSLv3_method()
+        elif protocol == PY_SSL_VERSION_SSL2 and not OPENSSL_NO_SSL2:
+            method = libssl_SSLv2_method()
+        elif protocol == PY_SSL_VERSION_SSL23:
+            method = libssl_SSLv23_method()
+        else:
+            raise ssl_error(space, "invalid SSL protocol version")
+        self.__init__(method)
+        if not self.ctx:
+            raise ssl_error(space, "failed to allocate SSL context")
+        return space.wrap(self)
+
+    @unwrap_spec(cipherlist=str)
+    def set_ciphers_w(self, space, cipherlist):
+        ret = libssl_SSL_CTX_set_cipher_list(self.ctx, cipherlist)
+        if ret == 0:
+            # Clearing the error queue is necessary on some OpenSSL
+            # versions, otherwise the error will be reported again
+            # when another SSL call is done.
+            libssl_ERR_clear_error()
+            raise ssl_error(space, "No cipher can be selected.")
+
+    def get_verify_mode_w(self, space):
+        verify_mode = libssl_SSL_CTX_get_verify_mode(self.ctx)
+        if verify_mode == SSL_VERIFY_NONE:
+            return space.wrap(PY_SSL_CERT_NONE)
+        elif verify_mode == SSL_VERIFY_PEER:
+            return space.wrap(PY_SSL_CERT_OPTIONAL)
+        elif verify_mode == (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT):
+            return space.wrap(PY_SSL_CERT_REQUIRED)
+        else:
+            raise ssl_error(
+                space,  "invalid return value from SSL_CTX_get_verify_mode")
+
+    def set_verify_mode_w(self, space, w_mode):
+        mode = space.int_w(w_mode)
+        if mode == PY_SSL_CERT_NONE:
+            verify_mode = SSL_VERIFY_NONE
+        elif mode == PY_SSL_CERT_OPTIONAL:
+            verify_mode = SSL_VERIFY_PEER
+        elif mode == PY_SSL_CERT_REQUIRED:
+            verify_mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT
+        else:
+            raise OperationError(space.w_ValueError, space.wrap(
+                    "invalid value for verify_mode"))
+        libssl_SSL_CTX_set_verify(self.ctx, verify_mode, None)
+        
+    def get_options_w(self, space):
+        return space.wrap(libssl_SSL_CTX_get_options(self.ctx))
+
+    def set_options_w(self, space, value):
+        opts = libssl_SSL_CTX_get_options(self.ctx)
+        clear = opts & ~new_opts
+        set = ~opts & new_opts
+        if clear:
+            if HAVE_SSL_CTX_CLEAR_OPTIONS:
+                libssl_SSL_CTX_clear_options(self.ctx, clear)
+            else:
+                raise OperationError(space.w_ValueError, space.wrap(
+                        "can't clear options before OpenSSL 0.9.8m"))
+        if set:
+            libssl_SSL_CTX_set_options(self.ctx, set)
+
+    def load_cert_chain_w(self, space, w_certfile, w_keyfile=None):
+        if space.is_w(w_certfile, space.w_None):
+            certfile = None
+        else:
+            certfile = space.str_w(w_certfile)
+        if space.is_w(w_keyfile, space.w_None):
+            keyfile = certfile
+        else:
+            keyfile = space.str_w(w_keyfile)
+
+        ret = libssl_SSL_CTX_use_certificate_chain_file(self.ctx, certfile)
+        if ret != 1:
+            errno = get_errno()
+            if errno:
+                libssl_ERR_clear_error()
+                raise_from_errno(space.w_IOError, errno)
+            else:
+                raise _ssl_seterror(space, None, -1)
+
+        ret = libssl_SSL_CTX_use_PrivateKey_file(ss.ctx, key_file,
+                                                 SSL_FILETYPE_PEM)
+        if ret != 1:
+            errno = get_errno()
+            if errno:
+                libssl_ERR_clear_error()
+                raise_from_errno(space.w_IOError, errno)
+            else:
+                raise _ssl_seterror(space, None, -1)
+
+        ret = libssl_SSL_CTX_check_private_key(self.ctx)
+        if ret != 1:
+            raise _ssl_seterror(space, None, -1)
+
+    def load_verify_locations_w(self, space, w_cafile=None, w_capath=None):
+        if space.is_w(w_cafile, space.w_None):
+            cafile = None
+        else:
+            cafile = space.str_w(w_cafile)
+        if space.is_w(w_capath, space.w_None):
+            capath = None
+        else:
+            capath = space.str_w(w_capath)
+        if cafile is None and capath is None:
+            raise OperationError(space.w_ValueError, space.wrap(
+                    "cafile and capath cannot be both omitted"))
+        ret = libssl_SSL_CTX_load_verify_locations(
+            self.ctx, cafile, capath)
+        if ret != 1:
+            errno = get_errno()
+            if errno:
+                libssl_ERR_clear_error()
+                raise_from_errno(space.w_IOError, errno)
+            else:
+                raise _ssl_seterror(space, None, -1)
+
+    @unwrap_spec(server_side=int)
+    def wrap_socket_w(self, space, w_sock, server_side,
+                      w_server_hostname=None):
+        assert w_sock is not None
+        # server_hostname is either None (or absent), or to be encoded
+        # using the idna encoding.
+        if space.is_w(w_server_hostname, space.w_None):
+            hostname = None
+        else:
+            hostname = space.bytes_w(
+                space.call_method(w_server_hostname, "idna"))
+
+        if hostname and not HAS_SNI:
+            raise OperationError(space.w_ValueError,
+                                 space.wrap("server_hostname is not supported "
+                                            "by your OpenSSL library"))
+
+        return new_sslobject(space, self.ctx, w_sock, server_side, hostname)
+
+    def session_stats_w(self, space):
+        w_stats = space.newdict()
+        for name, ssl_func in SSL_CTX_STATS:
+            w_value = space.wrap(ssl_func(self.ctx))
+            space.setitem_str(w_stats, attr, w_value)
+        return w_stats
+
+    def set_default_verify_paths_w(self):
+        ret = libssl_SSL_CTX_set_default_verify_paths(self.ctx)
+        if ret != 1:
+            raise _ssl_seterror(space, None, -1)
+
+
+SSLContext.typedef = TypeDef(
+    "_SSLContext",
+    __new__ = interp2app(SSLContext.descr_new.im_func),
+    verify_mode = GetSetProperty(SSLContext.get_verify_mode_w,
+                                 SSLContext.set_verify_mode_w),
+    _wrap_socket = interp2app(SSLContext.wrap_socket_w),
+)
+
+    
+
 if HAVE_OPENSSL_RAND:
     # helper routines for seeding the SSL PRNG
     @unwrap_spec(string=str, entropy=float)
@@ -119,11 +313,10 @@
             raise ssl_error(space, msg)
         return space.wrap(bytes)
 
-class SSLObject(Wrappable):
+
+class SSLSocket(Wrappable):
     def __init__(self, space):
-        self.space = space
         self.w_socket = None
-        self.ctx = lltype.nullptr(SSL_CTX.TO)
         self.ssl = lltype.nullptr(SSL.TO)
         self.peer_cert = lltype.nullptr(X509.TO)
         self._server = lltype.malloc(rffi.CCHARP.TO, X509_NAME_MAXLEN, flavor='raw')
@@ -132,43 +325,35 @@
         self._issuer[0] = '\0'
         self.shutdown_seen_zero = False
 
-    def server(self):
-        return self.space.wrap(rffi.charp2str(self._server))
+    def server(self, space):
+        return space.wrap(rffi.charp2str(self._server))
 
-    def issuer(self):
-        return self.space.wrap(rffi.charp2str(self._issuer))
+    def issuer(self, space):
+        return space.wrap(rffi.charp2str(self._issuer))
 
     def __del__(self):
-        self.enqueue_for_destruction(self.space, SSLObject.destructor,
-                                     '__del__() method of ')
-
-    def destructor(self):
-        assert isinstance(self, SSLObject)
         if self.peer_cert:
             libssl_X509_free(self.peer_cert)
         if self.ssl:
             libssl_SSL_free(self.ssl)
-        if self.ctx:
-            libssl_SSL_CTX_free(self.ctx)
         lltype.free(self._server, flavor='raw')
         lltype.free(self._issuer, flavor='raw')
 
     @unwrap_spec(data='bufferstr')
-    def write(self, data):
+    def write(self, space, data):
         """write(s) -> len
 
         Writes the string s into the SSL object.  Returns the number
         of bytes written."""
-        self._refresh_nonblocking(self.space)
+        w_socket = self._get_socket(space)
 
-        sockstate = check_socket_and_wait_for_timeout(self.space,
-            self.w_socket, True)
+        sockstate = check_socket_and_wait_for_timeout(space, w_socket, True)
         if sockstate == SOCKET_HAS_TIMED_OUT:
-            raise ssl_error(self.space, "The write operation timed out")
+            raise ssl_error(space, "The write operation timed out")
         elif sockstate == SOCKET_HAS_BEEN_CLOSED:
-            raise ssl_error(self.space, "Underlying socket has been closed.")
+            raise ssl_error(space, "Underlying socket has been closed.")
         elif sockstate == SOCKET_TOO_LARGE_FOR_SELECT:
-            raise ssl_error(self.space, "Underlying socket too large for select().")
+            raise ssl_error(space, "Underlying socket too large for select().")
 
         num_bytes = 0
         while True:
@@ -178,18 +363,18 @@
             err = libssl_SSL_get_error(self.ssl, num_bytes)
 
             if err == SSL_ERROR_WANT_READ:
-                sockstate = check_socket_and_wait_for_timeout(self.space,
-                    self.w_socket, False)
+                sockstate = check_socket_and_wait_for_timeout(
+                    space, w_socket, False)
             elif err == SSL_ERROR_WANT_WRITE:
-                sockstate = check_socket_and_wait_for_timeout(self.space,
-                    self.w_socket, True)
+                sockstate = check_socket_and_wait_for_timeout(
+                    space, w_socket, True)
             else:
                 sockstate = SOCKET_OPERATION_OK
 
             if sockstate == SOCKET_HAS_TIMED_OUT:
-                raise ssl_error(self.space, "The write operation timed out")
+                raise ssl_error(space, "The write operation timed out")
             elif sockstate == SOCKET_HAS_BEEN_CLOSED:
-                raise ssl_error(self.space, "Underlying socket has been closed.")
+                raise ssl_error(space, "Underlying socket has been closed.")
             elif sockstate == SOCKET_IS_NONBLOCKING:
                 break
 
@@ -199,38 +384,39 @@
                 break
 
         if num_bytes > 0:
-            return self.space.wrap(num_bytes)
+            return space.wrap(num_bytes)
         else:
-            raise _ssl_seterror(self.space, self, num_bytes)
+            raise _ssl_seterror(space, self, num_bytes)
 
-    def pending(self):
+    def pending(self, space):
         """pending() -> count
 
         Returns the number of already decrypted bytes available for read,
         pending on the connection."""
         count = libssl_SSL_pending(self.ssl)
         if count < 0:
-            raise _ssl_seterror(self.space, self, count)
-        return self.space.wrap(count)
+            raise _ssl_seterror(space, self, count)
+        return space.wrap(count)
 
     @unwrap_spec(num_bytes=int)
-    def read(self, num_bytes=1024):
+    def read(self, space, num_bytes=1024):
         """read([len]) -> string
 
         Read up to len bytes from the SSL socket."""
+        w_socket = self._get_socket(space)
 
         count = libssl_SSL_pending(self.ssl)
         if not count:
-            sockstate = check_socket_and_wait_for_timeout(self.space,
-                self.w_socket, False)
+            sockstate = check_socket_and_wait_for_timeout(
+                space, w_socket, False)
             if sockstate == SOCKET_HAS_TIMED_OUT:
-                raise ssl_error(self.space, "The read operation timed out")
+                raise ssl_error(space, "The read operation timed out")
             elif sockstate == SOCKET_TOO_LARGE_FOR_SELECT:
-                raise ssl_error(self.space, "Underlying socket too large for select().")
+                raise ssl_error(space, "Underlying socket too large for select().")
             elif sockstate == SOCKET_HAS_BEEN_CLOSED:
                 if libssl_SSL_get_shutdown(self.ssl) == SSL_RECEIVED_SHUTDOWN:
-                    return self.space.wrap('')
-                raise ssl_error(self.space, "Socket closed without SSL shutdown handshake")
+                    return space.wrapbytes('')
+                raise ssl_error(space, "Socket closed without SSL shutdown handshake")
 
         raw_buf, gc_buf = rffi.alloc_buffer(num_bytes)
         while True:
@@ -240,19 +426,19 @@
             err = libssl_SSL_get_error(self.ssl, count)
 
             if err == SSL_ERROR_WANT_READ:
-                sockstate = check_socket_and_wait_for_timeout(self.space,
-                    self.w_socket, False)
+                sockstate = check_socket_and_wait_for_timeout(
+                    space, w_socket, False)
             elif err == SSL_ERROR_WANT_WRITE:
-                sockstate = check_socket_and_wait_for_timeout(self.space,
-                    self.w_socket, True)
+                sockstate = check_socket_and_wait_for_timeout(
+                    space, w_socket, True)
             elif (err == SSL_ERROR_ZERO_RETURN and
                   libssl_SSL_get_shutdown(self.ssl) == SSL_RECEIVED_SHUTDOWN):
-                return self.space.wrap("")
+                return space.wrapbytes('')
             else:
                 sockstate = SOCKET_OPERATION_OK
 
             if sockstate == SOCKET_HAS_TIMED_OUT:
-                raise ssl_error(self.space, "The read operation timed out")
+                raise ssl_error(space, "The read operation timed out")
             elif sockstate == SOCKET_IS_NONBLOCKING:
                 break
 
@@ -262,21 +448,27 @@
                 break
 
         if count <= 0:
-            raise _ssl_seterror(self.space, self, count)
+            raise _ssl_seterror(space, self, count)
 
         result = rffi.str_from_buffer(raw_buf, gc_buf, num_bytes, count)
         rffi.keep_buffer_alive_until_here(raw_buf, gc_buf)
-        return self.space.wrap(result)
+        return space.wrapbytes(result)
 
-    def _refresh_nonblocking(self, space):
+    def _get_socket(self, space):
+        w_socket = self.w_socket()
+        if w_socket is None:
+            raise ssl_error(space, "Underlying socket connection gone")
+
         # just in case the blocking state of the socket has been changed
-        w_timeout = space.call_method(self.w_socket, "gettimeout")
+        w_timeout = space.call_method(w_socket, "gettimeout")
         nonblocking = not space.is_w(w_timeout, space.w_None)
         libssl_BIO_set_nbio(libssl_SSL_get_rbio(self.ssl), nonblocking)
         libssl_BIO_set_nbio(libssl_SSL_get_wbio(self.ssl), nonblocking)
 
+        return w_socket
+
     def do_handshake(self, space):
-        self._refresh_nonblocking(space)
+        w_socket = self._get_socket(space)
 
         # Actually negotiate SSL connection
         # XXX If SSL_do_handshake() returns 0, it's also a failure.
@@ -286,10 +478,10 @@
             # XXX PyErr_CheckSignals()
             if err == SSL_ERROR_WANT_READ:
                 sockstate = check_socket_and_wait_for_timeout(
-                    space, self.w_socket, False)
+                    space, w_socket, False)
             elif err == SSL_ERROR_WANT_WRITE:
                 sockstate = check_socket_and_wait_for_timeout(
-                    space, self.w_socket, True)
+                    space, w_socket, True)
             else:
                 sockstate = SOCKET_OPERATION_OK
             if sockstate == SOCKET_HAS_TIMED_OUT:
@@ -321,13 +513,13 @@
                 self._issuer, X509_NAME_MAXLEN)
 
     def shutdown(self, space):
+        w_socket = self._get_socket(space)
+
         # Guard against closed socket
-        w_fileno = space.call_method(self.w_socket, "fileno")
+        w_fileno = space.call_method(w_socket, "fileno")
         if space.int_w(w_fileno) < 0:
             raise ssl_error(space, "Underlying socket has been closed")
 
-        self._refresh_nonblocking(space)
-
         zeros = 0
 
         while True:
@@ -360,18 +552,18 @@
             ssl_err = libssl_SSL_get_error(self.ssl, ret)
             if ssl_err == SSL_ERROR_WANT_READ:
                 sockstate = check_socket_and_wait_for_timeout(
-                    self.space, self.w_socket, False)
+                    space, w_socket, False)
             elif ssl_err == SSL_ERROR_WANT_WRITE:
                 sockstate = check_socket_and_wait_for_timeout(
-                    self.space, self.w_socket, True)
+                    space, w_socket, True)
             else:
                 break
 
             if sockstate == SOCKET_HAS_TIMED_OUT:
                 if ssl_err == SSL_ERROR_WANT_READ:
-                    raise ssl_error(self.space, "The read operation timed out")
+                    raise ssl_error(space, "The read operation timed out")
                 else:
-                    raise ssl_error(self.space, "The write operation timed out")
+                    raise ssl_error(space, "The write operation timed out")
             elif sockstate == SOCKET_TOO_LARGE_FOR_SELECT:
                 raise ssl_error(space, "Underlying socket too large for select().")
             elif sockstate != SOCKET_OPERATION_OK:
@@ -381,7 +573,7 @@
         if ret < 0:
             raise _ssl_seterror(space, self, ret)
 
-        return self.w_socket
+        return w_socket
 
     def cipher(self, space):
         if not self.ssl:
@@ -409,7 +601,7 @@
         return space.newtuple([w_name, w_proto, w_bits])
 
     @unwrap_spec(der=bool)
-    def peer_certificate(self, der=False):
+    def peer_certificate(self, space, der=False):
         """peer_certificate([der=False]) -> certificate
 
         Returns the certificate for the peer.  If no certificate was provided,
@@ -421,7 +613,7 @@
         peer certificate, or None if no certificate was provided.  This will
         return the certificate even if it wasn't validated."""
         if not self.peer_cert:
-            return self.space.w_None
+            return space.w_None
 
         if der:
             # return cert in DER-encoded format
@@ -429,19 +621,19 @@
                 buf_ptr[0] = lltype.nullptr(rffi.CCHARP.TO)
                 length = libssl_i2d_X509(self.peer_cert, buf_ptr)
                 if length < 0:
-                    raise _ssl_seterror(self.space, self, length)
+                    raise _ssl_seterror(space, self, length)
                 try:
                     # this is actually an immutable bytes sequence
-                    return self.space.wrap(rffi.charp2str(buf_ptr[0]))
+                    return space.wrap(rffi.charp2str(buf_ptr[0]))
                 finally:
                     libssl_OPENSSL_free(buf_ptr[0])
         else:
             verification = libssl_SSL_CTX_get_verify_mode(
                 libssl_SSL_get_SSL_CTX(self.ssl))
             if not verification & SSL_VERIFY_PEER:
-                return self.space.newdict()
+                return space.newdict()
             else:
-                return _decode_certificate(self.space, self.peer_cert)
+                return _decode_certificate(space, self.peer_cert)
 
 def _decode_certificate(space, certificate, verbose=False):
     w_retval = space.newdict()
@@ -625,22 +817,21 @@
 
     return space.newtuple([w_name, w_value])
 
-SSLObject.typedef = TypeDef("SSLObject",
-    server = interp2app(SSLObject.server),
-    issuer = interp2app(SSLObject.issuer),
-    write = interp2app(SSLObject.write),
-    pending = interp2app(SSLObject.pending),
-    read = interp2app(SSLObject.read),
-    do_handshake = interp2app(SSLObject.do_handshake),
-    shutdown = interp2app(SSLObject.shutdown),
-    cipher = interp2app(SSLObject.cipher),
-    peer_certificate = interp2app(SSLObject.peer_certificate),
+SSLSocket.typedef = TypeDef("_SSLSocket",
+    server = interp2app(SSLSocket.server),
+    issuer = interp2app(SSLSocket.issuer),
+    write = interp2app(SSLSocket.write),
+    pending = interp2app(SSLSocket.pending),
+    read = interp2app(SSLSocket.read),
+    do_handshake = interp2app(SSLSocket.do_handshake),
+    shutdown = interp2app(SSLSocket.shutdown),
+    cipher = interp2app(SSLSocket.cipher),
+    peer_certificate = interp2app(SSLSocket.peer_certificate),
 )
 
 
-def new_sslobject(space, w_sock, side, w_key_file, w_cert_file,
-                  cert_mode, protocol, w_cacerts_file, w_ciphers):
-    ss = SSLObject(space)
+def new_sslobject(space, ctx, w_sock, side, server_hostname):
+    ss = SSLSocket(space)
 
     sock_fd = space.int_w(space.call_method(w_sock, "fileno"))
     w_timeout = space.call_method(w_sock, "gettimeout")
@@ -648,93 +839,27 @@
         has_timeout = False
     else:
         has_timeout = True
-    if space.is_w(w_key_file, space.w_None):
-        key_file = None
-    else:
-        key_file = space.str_w(w_key_file)
-    if space.is_w(w_cert_file, space.w_None):
-        cert_file = None
-    else:
-        cert_file = space.str_w(w_cert_file)
-    if space.is_w(w_cacerts_file, space.w_None):
-        cacerts_file = None
-    else:
-        cacerts_file = space.str_w(w_cacerts_file)
-    if space.is_w(w_ciphers, space.w_None):
-        ciphers = None
-    else:
-        ciphers = space.str_w(w_ciphers)
 
-    if side == PY_SSL_SERVER and (not key_file or not cert_file):
-        raise ssl_error(space, "Both the key & certificate files "
-                        "must be specified for server-side operation")
-
-    # set up context
-    if protocol == PY_SSL_VERSION_TLS1:
-        method = libssl_TLSv1_method()
-    elif protocol == PY_SSL_VERSION_SSL3:
-        method = libssl_SSLv3_method()
-    elif protocol == PY_SSL_VERSION_SSL2 and not OPENSSL_NO_SSL2:
-        method = libssl_SSLv2_method()
-    elif protocol == PY_SSL_VERSION_SSL23:
-        method = libssl_SSLv23_method()
-    else:
-        raise ssl_error(space, "Invalid SSL protocol variant specified")
-    ss.ctx = libssl_SSL_CTX_new(method)
-    if not ss.ctx:
-        raise ssl_error(space, "Could not create SSL context")
-
-    if ciphers:
-        ret = libssl_SSL_CTX_set_cipher_list(ss.ctx, ciphers)
-        if ret == 0:
-            raise ssl_error(space, "No cipher can be selected.")
-
-    if cert_mode != PY_SSL_CERT_NONE:
-        if not cacerts_file:
-            raise ssl_error(space,
-                            "No root certificates specified for "
-                            "verification of other-side certificates.")
-        ret = libssl_SSL_CTX_load_verify_locations(ss.ctx, cacerts_file, None)
-        if ret != 1:
-            raise _ssl_seterror(space, None, 0)
-
-    if key_file:
-        ret = libssl_SSL_CTX_use_PrivateKey_file(ss.ctx, key_file,
-                                                 SSL_FILETYPE_PEM)
-        if ret < 1:
-            raise ssl_error(space, "SSL_CTX_use_PrivateKey_file error")
-
-        ret = libssl_SSL_CTX_use_certificate_chain_file(ss.ctx, cert_file)
-        if ret < 1:
-            raise ssl_error(space, "SSL_CTX_use_certificate_chain_file error")
-
-    # ssl compatibility
-    libssl_SSL_CTX_set_options(ss.ctx, SSL_OP_ALL)
-
-    verification_mode = SSL_VERIFY_NONE
-    if cert_mode == PY_SSL_CERT_OPTIONAL:
-        verification_mode = SSL_VERIFY_PEER
-    elif cert_mode == PY_SSL_CERT_REQUIRED:
-        verification_mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT
-    libssl_SSL_CTX_set_verify(ss.ctx, verification_mode, None)
-    ss.ssl = libssl_SSL_new(ss.ctx) # new ssl struct
+    ss.ssl = libssl_SSL_new(ctx) # new ssl struct
     libssl_SSL_set_fd(ss.ssl, sock_fd) # set the socket for SSL
     libssl_SSL_set_mode(ss.ssl, SSL_MODE_AUTO_RETRY)
 
+    if server_hostname:
+        libssl_SSL_set_tlsext_host_name(ss.ssl, server_hostname);
+
     # If the socket is in non-blocking mode or timeout mode, set the BIO
     # to non-blocking mode (blocking is the default)
     if has_timeout:
         # Set both the read and write BIO's to non-blocking mode
         libssl_BIO_set_nbio(libssl_SSL_get_rbio(ss.ssl), 1)
         libssl_BIO_set_nbio(libssl_SSL_get_wbio(ss.ssl), 1)
-    libssl_SSL_set_connect_state(ss.ssl)
 
     if side == PY_SSL_CLIENT:
         libssl_SSL_set_connect_state(ss.ssl)
     else:
         libssl_SSL_set_accept_state(ss.ssl)
 
-    ss.w_socket = w_sock
+    ss.w_socket = weakref.ref(w_sock)
     return ss
 
 def check_socket_and_wait_for_timeout(space, w_sock, writing):
@@ -812,7 +937,7 @@
     elif err == SSL_ERROR_SYSCALL:
         e = libssl_ERR_get_error()
         if e == 0:
-            if ret == 0 or space.is_w(ss.w_socket, space.w_None):
+            if ret == 0 or ss.w_socket() is None:
                 errstr = "EOF occurred in violation of protocol"
                 errval = PY_SSL_ERROR_EOF
             elif ret == -1:
@@ -839,16 +964,6 @@
     return ssl_error(space, errstr, errval)
 
 
-@unwrap_spec(side=int, cert_mode=int, protocol=int)
-def sslwrap(space, w_socket, side, w_key_file=None, w_cert_file=None,
-            cert_mode=PY_SSL_CERT_NONE, protocol=PY_SSL_VERSION_SSL23,
-            w_cacerts_file=None, w_ciphers=None):
-    """sslwrap(socket, side, [keyfile, certfile]) -> sslobject"""
-    return space.wrap(new_sslobject(
-        space, w_socket, side, w_key_file, w_cert_file,
-        cert_mode, protocol,
-        w_cacerts_file, w_ciphers))
-
 class Cache:
     def __init__(self, space):
         w_socketerror = interp_socket.get_error(space, "error")
diff --git a/pypy/module/_ssl/test/test_ssl.py b/pypy/module/_ssl/test/test_ssl.py
--- a/pypy/module/_ssl/test/test_ssl.py
+++ b/pypy/module/_ssl/test/test_ssl.py
@@ -63,29 +63,11 @@
         _ssl.RAND_egd("entropy")
 
     def test_sslwrap(self):
-        import _ssl, _socket, sys, gc
+        import ssl, _socket, sys, gc
         if sys.platform == 'darwin':
             skip("hangs indefinitely on OSX (also on CPython)")
         s = _socket.socket()
-        ss = _ssl.sslwrap(s, 0)
-        exc = raises(_socket.error, ss.do_handshake)
-        if sys.platform == 'win32':
-            assert exc.value.errno == 10057 # WSAENOTCONN
-        else:
-            assert exc.value.errno == 32 # Broken pipe
-        del exc, ss, s
-        gc.collect()     # force the destructor() to be called now
-
-    def test_async_closed(self):
-        import _ssl, _socket, gc
-        s = _socket.socket()
-        s.settimeout(3)
-        ss = _ssl.sslwrap(s, 0)
-        s.close()
-        exc = raises(_ssl.SSLError, ss.write, "data")
-        assert exc.value.strerror == "Underlying socket has been closed."
-        del exc, ss, s
-        gc.collect()     # force the destructor() to be called now
+        ss = ssl.wrap_socket(s)
 
 
 class AppTestConnectedSSL:
@@ -108,65 +90,65 @@
             """)
 
     def test_connect(self):
-        import socket, gc
-        ss = socket.ssl(self.s)
+        import ssl, gc
+        ss = ssl.wrap_socket(self.s)
         self.s.close()
         del ss; gc.collect()
 
     def test_server(self):
-        import socket, gc
-        ss = socket.ssl(self.s)
+        import ssl, gc
+        ss = ssl.wrap_socket(self.s)
         assert isinstance(ss.server(), str)
         self.s.close()
         del ss; gc.collect()
 
     def test_issuer(self):
-        import socket, gc
-        ss = socket.ssl(self.s)
+        import ssl, gc
+        ss = ssl.wrap_socket(self.s)
         assert isinstance(ss.issuer(), str)
         self.s.close()
         del ss; gc.collect()
 
     def test_write(self):
-        import socket, gc
-        ss = socket.ssl(self.s)
+        import ssl, gc
+        ss = ssl.wrap_socket(self.s)
         raises(TypeError, ss.write, 123)
-        num_bytes = ss.write("hello\n")
+        num_bytes = ss.write(b"hello\n")
         assert isinstance(num_bytes, int)
         assert num_bytes >= 0
         self.s.close()
         del ss; gc.collect()
 
     def test_read(self):
-        import socket, gc
-        ss = socket.ssl(self.s)
-        raises(TypeError, ss.read, "foo")
-        ss.write("hello\n")
+        import ssl, gc
+        ss = ssl.wrap_socket(self.s)
+        raises(TypeError, ss.read, b"foo")
+        ss.write(b"hello\n")
         data = ss.read()
-        assert isinstance(data, str)
+        assert isinstance(data, bytes)
         self.s.close()
         del ss; gc.collect()
 
     def test_read_upto(self):
-        import socket, gc
-        ss = socket.ssl(self.s)
-        raises(TypeError, ss.read, "foo")
-        ss.write("hello\n")
+        import ssl, gc
+        ss = ssl.wrap_socket(self.s)
+        raises(TypeError, ss.read, b"foo")
+        ss.write(b"hello\n")
         data = ss.read(10)
-        assert isinstance(data, str)
+        assert isinstance(data, bytes)
         assert len(data) == 10
         assert ss.pending() > 50 # many more bytes to read
         self.s.close()
         del ss; gc.collect()
 
     def test_shutdown(self):
-        import socket, ssl, sys, gc
+        import ssl, sys, gc
         if sys.platform == 'darwin':
             skip("get also on CPython: error: [Errno 0]")
-        ss = socket.ssl(self.s)
-        ss.write("hello\n")
+        ss = ssl.wrap_socket(self.s)
+        ss.write(b"hello\n")
         assert ss.shutdown() is self.s._sock
-        raises(ssl.SSLError, ss.write, "hello\n")
+        raises(ssl.SSLError, ss.write, b"hello\n")
         del ss; gc.collect()
 
 class AppTestConnectedSSL_Timeout(AppTestConnectedSSL):
diff --git a/pypy/rlib/ropenssl.py b/pypy/rlib/ropenssl.py
--- a/pypy/rlib/ropenssl.py
+++ b/pypy/rlib/ropenssl.py
@@ -2,6 +2,7 @@
 from pypy.rpython.tool import rffi_platform
 from pypy.translator.platform import platform
 from pypy.translator.tool.cbuild import ExternalCompilationInfo
+from pypy.rlib.unroll import unrolling_iterable
 
 import sys
 
@@ -66,6 +67,10 @@
     OPENSSL_NO_SSL2 = rffi_platform.Defined("OPENSSL_NO_SSL2")
     SSL_FILETYPE_PEM = rffi_platform.ConstantInteger("SSL_FILETYPE_PEM")
     SSL_OP_ALL = rffi_platform.ConstantInteger("SSL_OP_ALL")
+    SSL_OP_NO_SSLv2 = rffi_platform.ConstantInteger("SSL_OP_NO_SSLv2")
+    SSL_OP_NO_SSLv3 = rffi_platform.ConstantInteger("SSL_OP_NO_SSLv3")
+    SSL_OP_NO_TLSv1 = rffi_platform.ConstantInteger("SSL_OP_NO_TLSv1")
+    HAS_SNI = rffi_platform.Defined("SSL_CTRL_SET_TLSEXT_HOSTNAME")
     SSL_VERIFY_NONE = rffi_platform.ConstantInteger("SSL_VERIFY_NONE")
     SSL_VERIFY_PEER = rffi_platform.ConstantInteger("SSL_VERIFY_PEER")
     SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 
rffi_platform.ConstantInteger("SSL_VERIFY_FAIL_IF_NO_PEER_CERT")
@@ -186,6 +191,14 @@
 ssl_external('SSL_CTX_get_verify_mode', [SSL_CTX], rffi.INT)
 ssl_external('SSL_CTX_set_cipher_list', [SSL_CTX, rffi.CCHARP], rffi.INT)
 ssl_external('SSL_CTX_load_verify_locations', [SSL_CTX, rffi.CCHARP, 
rffi.CCHARP], rffi.INT)
+ssl_external('SSL_CTX_set_session_id_context', [SSL_CTX, rffi.CCHARP, 
rffi.UINT], rffi.INT)
+SSL_CTX_STATS_NAMES = """
+    number connect connect_good connect_renegotiate accept accept_good
+    accept_renegotiate hits misses timeouts cache_full""".split()
+SSL_CTX_STATS = unrolling_iterable(  # (name, SSL_CTX_sess_<name> getter) pairs
+    (name, external('SSL_CTX_sess_' + name, [SSL_CTX], rffi.LONG, macro=True))
+    for name in SSL_CTX_STATS_NAMES)  # SSL_CTX_sess_* are C macros in ssl.h
+
 ssl_external('SSL_new', [SSL_CTX], SSL)
 ssl_external('SSL_set_fd', [SSL, rffi.INT], rffi.INT)
 ssl_external('SSL_set_mode', [SSL, rffi.INT], rffi.INT, macro=True)
@@ -201,6 +214,7 @@
 ssl_external('SSL_get_error', [SSL, rffi.INT], rffi.INT)
 ssl_external('SSL_get_shutdown', [SSL], rffi.INT)
 ssl_external('SSL_set_read_ahead', [SSL, rffi.INT], lltype.Void)
+ssl_external('SSL_set_tlsext_host_name', [SSL, rffi.CCHARP], rffi.INT, 
macro=True)
 
 ssl_external('SSL_get_peer_certificate', [SSL], X509)
 ssl_external('X509_get_subject_name', [X509], X509_NAME)
@@ -211,7 +225,7 @@
 ssl_external('X509_NAME_ENTRY_get_object', [X509_NAME_ENTRY], ASN1_OBJECT)
 ssl_external('X509_NAME_ENTRY_get_data', [X509_NAME_ENTRY], ASN1_STRING)
 ssl_external('i2d_X509', [X509, rffi.CCHARPP], rffi.INT)
-ssl_external('X509_free', [X509], lltype.Void)
+ssl_external('X509_free', [X509], lltype.Void, threadsafe=False)
 ssl_external('X509_get_notBefore', [X509], ASN1_TIME, macro=True)
 ssl_external('X509_get_notAfter', [X509], ASN1_TIME, macro=True)
 ssl_external('X509_get_serialNumber', [X509], ASN1_INTEGER)
@@ -246,8 +260,8 @@
 ssl_external('ERR_get_error', [], rffi.INT)
 ssl_external('ERR_error_string', [rffi.ULONG, rffi.CCHARP], rffi.CCHARP)
 
-ssl_external('SSL_free', [SSL], lltype.Void)
-ssl_external('SSL_CTX_free', [SSL_CTX], lltype.Void)
+ssl_external('SSL_free', [SSL], lltype.Void, threadsafe=False)
+ssl_external('SSL_CTX_free', [SSL_CTX], lltype.Void, threadsafe=False)
 ssl_external('CRYPTO_free', [rffi.VOIDP], lltype.Void)
 libssl_OPENSSL_free = libssl_CRYPTO_free
 
diff --git a/pypy/rlib/rsocket.py b/pypy/rlib/rsocket.py
--- a/pypy/rlib/rsocket.py
+++ b/pypy/rlib/rsocket.py
@@ -744,6 +744,11 @@
             if res != 0:
                 raise self.error_handler()
 
+    def detach(self):
+        fd = self.fd
+        self.fd = _c.INVALID_SOCKET
+        return fd
+
     if _c.WIN32:
         def _connect(self, address):
             """Connect the socket to a remote address."""
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to