Author: Amaury Forgeot d'Arc <[email protected]>
Branch: stdlib-2.7.9
Changeset: r75630:1db91f1d2dbe
Date: 2015-02-01 19:04 +0100
http://bitbucket.org/pypy/pypy/changeset/1db91f1d2dbe/

Log:    Implement SSLContext.set_servername_callback. No test so far.

diff --git a/pypy/module/_ssl/interp_ssl.py b/pypy/module/_ssl/interp_ssl.py
--- a/pypy/module/_ssl/interp_ssl.py
+++ b/pypy/module/_ssl/interp_ssl.py
@@ -1,9 +1,9 @@
-from rpython.rlib import rpoll, rsocket, rthread
+from rpython.rlib import rpoll, rsocket, rthread, rweakref
 from rpython.rlib.rarithmetic import intmask, widen, r_uint
 from rpython.rlib.ropenssl import *
 from rpython.rlib.rposix import get_errno, set_errno
 from rpython.rlib.rweakref import RWeakValueDictionary
-from rpython.rlib.objectmodel import specialize
+from rpython.rlib.objectmodel import specialize, compute_unique_id
 from rpython.rtyper.lltypesystem import lltype, rffi
 
 from pypy.interpreter.baseobjspace import W_Root
@@ -244,7 +244,7 @@
 
         self.socket_type = socket_type
         self.w_socket = w_sock
-        self.w_ssl_sock = None
+        self.ssl_sock_weakref_w = rweakref.ref(w_ssl_sock)
         return self
 
     def __del__(self):
@@ -1059,6 +1059,64 @@
         buf[i] = c
     return rffi.cast(rffi.INT, len(password))
 
+class ServernameCallback(object):
+    w_ctx = None
+    space = None
+SERVERNAME_CALLBACKS = RWeakValueDictionary(int, ServernameCallback)
+
+def _servername_callback(ssl, ad, arg):
+    struct = SERVERNAME_CALLBACKS.get(rffi.cast(lltype.Signed, arg))
+    w_ctx = struct.w_ctx
+    space = struct.space
+    servername = libssl_SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name)
+    if not w_ctx.w_set_hostname:
+        # The callback may have been cleared in the meantime; nothing to call.
+        return SSL_TLSEXT_ERR_OK
+    # The high-level ssl.SSLSocket object
+    w_ssl = libssl_SSL_get_app_data(ssl)
+    assert isinstance(w_ssl, _SSLSocket)
+    if w_ssl.ssl_sock_weakref_w:
+        w_ssl_socket = w_ssl.ssl_sock_weakref_w.get()
+    else:
+        w_ssl_socket = space.w_None
+    if space.is_none(w_ssl_socket):
+        ad[0] = SSL_AD_INTERNAL_ERROR
+        return SSL_TLSEXT_ERR_ALERT_FATAL
+
+    try:
+        if not servername:
+            w_result = space.call_function(w_ctx.w_set_hostname,
+                                           w_ssl_socket, space.w_None, w_ctx)
+
+        else:
+            try:
+                w_servername = space.wrap_bytes(rffi.charp2str(servername))
+                w_servername_idna = space.call_method(
+                    w_servername, 'decode', space.wrap('idna'))
+            except OperationError as e:
+                space.write_unraisable(e, "w_servername")
+                ad[0] = SSL_AD_INTERNAL_ERROR;
+                return SSL_TLSEXT_ERR_ALERT_FATAL
+
+            w_result = space.call_function(w_ctx.w_set_hostname,
+                                           w_ssl_socket,
+                                           w_servername_idna, w_ctx)
+    except OperationError as e:
+        space.write_unraisable(e, "ssl_ctx->set_hostname")
+        ad[0] = SSL_AD_HANDSHAKE_FAILURE
+        return SSL_TLSEXT_ERR_ALERT_FATAL
+
+    if space.is_none(w_result):
+        return SSL_TLSEXT_ERR_OK
+    else:
+        try:
+            ad[0] = space.int_w(w_result)
+        except OperationError as e:
+            space.write_unraisable(e, "w_result")
+            ad[0] = SSL_AD_INTERNAL_ERROR
+        return SSL_TLSEXT_ERR_ALERT_FATAL
+
+
 class _SSLContext(W_Root):
     @staticmethod
     @unwrap_spec(protocol=int)
@@ -1402,6 +1460,24 @@
                 rlist.append(_decode_certificate(space, cert))
         return space.newlist(rlist)
 
+    def set_servername_callback_w(self, space, w_callback):
+        if space.is_none(w_callback):
+            libssl_SSL_CTX_set_tlsext_servername_callback(
+                self.ctx, lltype.nullptr(servername_cb.TO))
+            return
+        if not space.is_true(space.callable(w_callback)):
+            raise oefmt(space.w_TypeError, "not a callable object")
+        self.w_set_hostname = w_callback
+        struct = ServernameCallback()
+        struct.space = space
+        struct.w_ctx = self
+        index = compute_unique_id(self)
+        SERVERNAME_CALLBACKS.set(index, struct)
+        libssl_SSL_CTX_set_tlsext_servername_callback(
+            self.ctx, _servername_callback)
+        libssl_SSL_CTX_set_tlsext_servername_arg(self.ctx,
+                                                 rffi.cast(rffi.VOIDP, index))
+
 _SSLContext.typedef = TypeDef(
     "_ssl._SSLContext",
     __new__=interp2app(_SSLContext.descr_new),
@@ -1414,6 +1490,7 @@
     set_default_verify_paths=interp2app(_SSLContext.descr_set_default_verify_paths),
     _set_npn_protocols=interp2app(_SSLContext.set_npn_protocols_w),
     get_ca_certs=interp2app(_SSLContext.get_ca_certs_w),
+    set_servername_callback=interp2app(_SSLContext.set_servername_callback_w),
 
     options=GetSetProperty(_SSLContext.descr_get_options,
                            _SSLContext.descr_set_options),
diff --git a/rpython/rlib/ropenssl.py b/rpython/rlib/ropenssl.py
--- a/rpython/rlib/ropenssl.py
+++ b/rpython/rlib/ropenssl.py
@@ -115,6 +115,12 @@
     SSL_MODE_AUTO_RETRY = rffi_platform.ConstantInteger("SSL_MODE_AUTO_RETRY")
     SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER = rffi_platform.ConstantInteger("SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER")
     SSL_TLSEXT_ERR_OK = rffi_platform.ConstantInteger("SSL_TLSEXT_ERR_OK")
+    SSL_TLSEXT_ERR_ALERT_FATAL = rffi_platform.ConstantInteger("SSL_TLSEXT_ERR_ALERT_FATAL")
+
+    SSL_AD_INTERNAL_ERROR = rffi_platform.ConstantInteger("SSL_AD_INTERNAL_ERROR")
+    SSL_AD_HANDSHAKE_FAILURE = rffi_platform.ConstantInteger("SSL_AD_HANDSHAKE_FAILURE")
+
+    TLSEXT_NAMETYPE_host_name = rffi_platform.ConstantInteger("TLSEXT_NAMETYPE_host_name")
 
     ERR_LIB_X509 = rffi_platform.ConstantInteger("ERR_LIB_X509")
     ERR_LIB_PEM = rffi_platform.ConstantInteger("ERR_LIB_PEM")
@@ -273,6 +279,11 @@
 pem_password_cb = lltype.Ptr(lltype.FuncType([rffi.CCHARP, rffi.INT, rffi.INT, rffi.VOIDP], rffi.INT))
 ssl_external('SSL_CTX_set_default_passwd_cb', [SSL_CTX, pem_password_cb], lltype.Void)
 ssl_external('SSL_CTX_set_default_passwd_cb_userdata', [SSL_CTX, rffi.VOIDP], lltype.Void)
+servername_cb = lltype.Ptr(lltype.FuncType([SSL, rffi.INTP, rffi.VOIDP], rffi.INT))
+ssl_external('SSL_CTX_set_tlsext_servername_callback', [SSL_CTX, servername_cb],
+             lltype.Void, macro=True)
+ssl_external('SSL_CTX_set_tlsext_servername_arg', [SSL_CTX, rffi.VOIDP], lltype.Void)
+
 SSL_CTX_STATS_NAMES = """
     number connect connect_good connect_renegotiate accept accept_good
     accept_renegotiate hits misses timeouts cache_full""".split()
@@ -303,6 +314,8 @@
 ssl_external('SSL_get_version', [SSL], rffi.CCHARP)
 
 ssl_external('SSL_get_peer_certificate', [SSL], X509)
+ssl_external('SSL_get_servername', [SSL, rffi.INT], rffi.CCHARP)
+ssl_external('SSL_get_app_data', [SSL], rffi.VOIDP, macro=True)
 ssl_external('X509_get_subject_name', [X509], X509_NAME)
 ssl_external('X509_get_issuer_name', [X509], X509_NAME)
 ssl_external('X509_NAME_oneline', [X509_NAME, rffi.CCHARP, rffi.INT], rffi.CCHARP)
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to