From c9b5aad7b2f693b204c158aac7d0d710807593ce Mon Sep 17 00:00:00 2001
From: Fedor Indutny <fedor@indutny.com>
Date: Thu, 11 Sep 2014 01:13:38 +0100
Subject: [PATCH] ssl: SSL_MODE_ASYNC_KEY_EX

Support async RSA exchange by providing new SSL_want_sign() and
SSL_want_rsa_decrypt() API methods.

After getting one of these want values, the requested sign/decrypt
operation should be performed (e.g. by a remote key server) and its result
passed to SSL_supply_key_ex_data(), after which the handshake can continue.
---
 ssl/s3_srvr.c  | 378 +++++++++++++++++++++++++++++++++++++++------------------
 ssl/ssl.h      |  63 ++++++++++
 ssl/ssl3.h     |   6 +
 ssl/ssl_lib.c  |  58 ++++++++-
 ssl/ssl_locl.h |   2 +
 ssl/ssl_rsa.c  |  24 ++--
 ssl/ssltest.c  | 124 ++++++++++++++++++-
 test/testssl   |   6 +
 8 files changed, 525 insertions(+), 136 deletions(-)

diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index 440fc13..994ef2f 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -496,15 +496,41 @@ int ssl3_accept(SSL *s)
 			    )
 				{
 				ret=ssl3_send_server_key_exchange(s);
-				if (ret <= 0) goto end;
+				if (ret <= 0)
+					goto end;
+				else if (ret == 2)
+					{
+					s->state=SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY;
+					break;
+					}
+				else if (ret == 3)
+					{
+					s->state=SSL3_ST_SW_KEY_EXCH_SIGN_WAIT;
+					break;
+					}
 				}
 			else
 				skip=1;
+			/* Intentional fall through */
 
+		case SSL3_ST_SW_KEY_EXCH_C:
 			s->state=SSL3_ST_SW_CERT_REQ_A;
 			s->init_num=0;
 			break;
 
+		case SSL3_ST_SW_KEY_EXCH_SIGN_WAIT:
+			s->rwstate=SSL_SIGN;
+			ret = -1;
+			goto end;
+
+		case SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY:
+			ret=ssl3_cont_server_key_exchange(s);
+			if (ret != 1)
+				goto end;
+			s->rwstate=SSL_NOTHING;
+			s->state=SSL3_ST_SW_KEY_EXCH_C;
+			break;
+
 		case SSL3_ST_SW_CERT_REQ_A:
 		case SSL3_ST_SW_CERT_REQ_B:
 			if (/* don't request cert unless asked for it: */
@@ -607,7 +633,13 @@ int ssl3_accept(SSL *s)
 			ret=ssl3_get_client_key_exchange(s);
 			if (ret <= 0)
 				goto end;
-			if (ret == 2)
+			if (ret == 3)
+				{
+				s->state=(s->mode & SSL_MODE_ASYNC_KEY_EX) ?
+						SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_WAIT :
+						SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_SUPPLY;
+				}
+			else if (ret == 2)
 				{
 				/* For the ECDH ciphersuites when
 				 * the client sends its ECDH pub key in
@@ -627,7 +659,12 @@ int ssl3_accept(SSL *s)
 #endif
 				s->init_num = 0;
 				}
-			else if (SSL_USE_SIGALGS(s))
+			else
+				s->state=SSL3_ST_SR_KEY_EXCH_C;
+			break;
+
+		case SSL3_ST_SR_KEY_EXCH_C:
+			if (SSL_USE_SIGALGS(s))
 				{
 				s->state=SSL3_ST_SR_CERT_VRFY_A;
 				s->init_num=0;
@@ -678,6 +715,19 @@ int ssl3_accept(SSL *s)
 				}
 			break;
 
+		case SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_WAIT:
+			s->rwstate=SSL_RSA_DECRYPT;
+			ret = -1;
+			goto end;
+
+		case SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_SUPPLY:
+			ret=ssl3_cont_client_key_exchange(s);
+			if (ret != 1)
+				goto end;
+			s->rwstate=SSL_NOTHING;
+			s->state=SSL3_ST_SR_KEY_EXCH_C;
+			break;
+
 		case SSL3_ST_SR_CERT_VRFY_A:
 		case SSL3_ST_SR_CERT_VRFY_B:
 
@@ -1617,7 +1667,9 @@ int ssl3_send_server_key_exchange(SSL *s)
 	int nr[4],kn;
 	BUF_MEM *buf;
 	EVP_MD_CTX md_ctx;
+	int async;
 
+	async=0;
 	EVP_MD_CTX_init(&md_ctx);
 	if (s->state == SSL3_ST_SW_KEY_EXCH_A)
 		{
@@ -2009,14 +2061,32 @@ int ssl3_send_server_key_exchange(SSL *s)
 					q+=i;
 					j+=i;
 					}
-				if (RSA_sign(NID_md5_sha1, md_buf, j,
-					&(p[2]), &u, pkey->pkey.rsa) <= 0)
+				s->init_num=n;
+				s->init_off=&(p[2])-(unsigned char*)s->init_buf->data;
+				s->s3->tmp.reuse_message=1;
+				if ((s->mode & SSL_MODE_ASYNC_KEY_EX) == 0)
 					{
-					SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,ERR_LIB_RSA);
-					goto err;
+					if (RSA_sign(NID_md5_sha1, md_buf, j,
+						&(p[2]), &u, pkey->pkey.rsa) <= 0)
+						{
+						SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,ERR_LIB_RSA);
+						goto err;
+						}
+					s->key_ex.len=u;
+					}
+				else
+					{
+					/* Copy md_buf contents to init_buf */
+					s->key_ex.data=&(p[2]);
+					s->key_ex.md=NID_md5_sha1;
+					s->key_ex.type=pkey->type;
+					s->key_ex.len=j;
+					/* Update pointers after growth */
+					p=ssl_handshake_start(s) + (p - d);
+					d=ssl_handshake_start(s);
+					memcpy(s->key_ex.data, md_buf, s->key_ex.len);
+					async=1;
 					}
-				s2n(u,p);
-				n+=u+2;
 				}
 			else
 #endif
@@ -2038,20 +2108,41 @@ int ssl3_send_server_key_exchange(SSL *s)
 				fprintf(stderr, "Using hash %s\n",
 							EVP_MD_name(md));
 #endif
-				EVP_SignInit_ex(&md_ctx, md, NULL);
-				EVP_SignUpdate(&md_ctx,&(s->s3->client_random[0]),SSL3_RANDOM_SIZE);
-				EVP_SignUpdate(&md_ctx,&(s->s3->server_random[0]),SSL3_RANDOM_SIZE);
-				EVP_SignUpdate(&md_ctx,d,n);
-				if (!EVP_SignFinal(&md_ctx,&(p[2]),
-					(unsigned int *)&i,pkey))
+				s->init_num=n;
+				s->init_off=&(p[2]) - (unsigned char*)s->init_buf->data;
+				s->s3->tmp.reuse_message=1;
+				if ((s->mode & SSL_MODE_ASYNC_KEY_EX) == 0)
 					{
-					SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,ERR_LIB_EVP);
-					goto err;
+					EVP_SignInit_ex(&md_ctx, md, NULL);
+					EVP_SignUpdate(&md_ctx,&(s->s3->client_random[0]),SSL3_RANDOM_SIZE);
+					EVP_SignUpdate(&md_ctx,&(s->s3->server_random[0]),SSL3_RANDOM_SIZE);
+					EVP_SignUpdate(&md_ctx,d,n);
+					if (!EVP_SignFinal(&md_ctx,&(p[2]),
+						(unsigned int *)&i,pkey))
+						{
+						SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,ERR_LIB_EVP);
+						goto err;
+						}
+					s->key_ex.len=i;
+					}
+				else
+					{
+					/* Copy digest inputs to init_buf */
+					s->key_ex.data=&p[2];
+					s->key_ex.len=2 * SSL3_RANDOM_SIZE + n;
+					s->key_ex.md=EVP_MD_nid(md);
+					s->key_ex.type=pkey->type;
+					/* Update pointers after growth */
+					p=ssl_handshake_start(s) + (p - d);
+					d=ssl_handshake_start(s);
+					q=s->key_ex.data;
+					memcpy(q, &(s->s3->client_random[0]), SSL3_RANDOM_SIZE);
+					memcpy(q + SSL3_RANDOM_SIZE,
+								 &(s->s3->server_random[0]),
+								 SSL3_RANDOM_SIZE);
+					memcpy(q + 2 * SSL3_RANDOM_SIZE, d, n);
+					async=1;
 					}
-				s2n(i,p);
-				n+=i+2;
-				if (SSL_USE_SIGALGS(s))
-					n+= 2;
 				}
 			else
 				{
@@ -2060,11 +2151,14 @@ int ssl3_send_server_key_exchange(SSL *s)
 				SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,SSL_R_UNKNOWN_PKEY_TYPE);
 				goto f_err;
 				}
+			/* Signature pending: 3 = wait for async supply, 2 = already in place */
+			EVP_MD_CTX_cleanup(&md_ctx);
+			return async ? 3 : 2;
 			}
 
 		ssl_set_handshake_header(s, SSL3_MT_SERVER_KEY_EXCHANGE, n);
 		}
 
 	s->state = SSL3_ST_SW_KEY_EXCH_B;
 	EVP_MD_CTX_cleanup(&md_ctx);
 	return ssl_do_write(s);
@@ -2079,6 +2171,37 @@ err:
 	return(-1);
 	}
 
+/* Finish the ServerKeyExchange message once its signature is in init_buf
+ * at init_off and its length is in key_ex.len: write the 2-byte signature
+ * length in front of the signature, set the handshake header and start
+ * writing the message. Returns the result of ssl_do_write(). */
+int ssl3_cont_server_key_exchange(SSL *s)
+	{
+	unsigned char *p,*d;
+	int n;
+
+	d=ssl_handshake_start(s);
+	/* NOTE: init_off is the offset of the signature from init_buf->data;
+	 * step back 2 bytes to the length field in front of it */
+	p=d + s->init_off - (d - (unsigned char*) s->init_buf->data) - 2;
+	/* Message length before the signature, saved by
+	 * ssl3_send_server_key_exchange() */
+	n=s->init_num;
+	s->init_num=0;
+	s->s3->tmp.reuse_message=0;
+
+	s2n(s->key_ex.len,p);
+	n+=s->key_ex.len+2;
+
+	/* Signature/Hash algorithms */
+	if (SSL_USE_SIGALGS(s))
+		n+= 2;
+	ssl_set_handshake_header(s, SSL3_MT_SERVER_KEY_EXCHANGE, n);
+
+	s->state=SSL3_ST_SW_KEY_EXCH_B;
+	return ssl_do_write(s);
+	}
+
 int ssl3_send_certificate_request(SSL *s)
 	{
 	unsigned char *p,*d;
@@ -2222,10 +2338,6 @@ int ssl3_get_client_key_exchange(SSL *s)
 #ifndef OPENSSL_NO_RSA
 	if (alg_k & SSL_kRSA)
 		{
-		unsigned char rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
-		int decrypt_len, decrypt_good_mask;
-		unsigned char version_good;
-
 		/* FIX THIS UP EAY EAY EAY EAY */
 		if (s->s3->tmp.use_rsa_tmp)
 			{
@@ -2273,99 +2385,21 @@ int ssl3_get_client_key_exchange(SSL *s)
 				n=i;
 			}
 
-		/* We must not leak whether a decryption failure occurs because
-		 * of Bleichenbacher's attack on PKCS #1 v1.5 RSA padding (see
-		 * RFC 2246, section 7.4.7.1). The code follows that advice of
-		 * the TLS RFC and generates a random premaster secret for the
-		 * case that the decrypt fails. See
-		 * https://tools.ietf.org/html/rfc5246#section-7.4.7.1 */
-
-		/* should be RAND_bytes, but we cannot work around a failure. */
-		if (RAND_pseudo_bytes(rand_premaster_secret,
-				      sizeof(rand_premaster_secret)) <= 0)
-			goto err;
-		decrypt_len = RSA_private_decrypt((int)n,p,p,rsa,RSA_PKCS1_PADDING);
-		ERR_clear_error();
-
-		/* decrypt_len should be SSL_MAX_MASTER_KEY_LENGTH.
-		 * decrypt_good_mask will be zero if so and non-zero otherwise. */
-		decrypt_good_mask = decrypt_len ^ SSL_MAX_MASTER_KEY_LENGTH;
-
-		/* If the version in the decrypted pre-master secret is correct
-		 * then version_good will be zero. The Klima-Pokorny-Rosa
-		 * extension of Bleichenbacher's attack
-		 * (http://eprint.iacr.org/2003/052/) exploits the version
-		 * number check as a "bad version oracle". Thus version checks
-		 * are done in constant time and are treated like any other
-		 * decryption error. */
-		version_good = p[0] ^ (s->client_version>>8);
-		version_good |= p[1] ^ (s->client_version&0xff);
-
-		/* The premaster secret must contain the same version number as
-		 * the ClientHello to detect version rollback attacks
-		 * (strangely, the protocol does not offer such protection for
-		 * DH ciphersuites). However, buggy clients exist that send the
-		 * negotiated protocol version instead if the server does not
-		 * support the requested protocol version. If
-		 * SSL_OP_TLS_ROLLBACK_BUG is set, tolerate such clients. */
-		if (s->options & SSL_OP_TLS_ROLLBACK_BUG)
-			{
-			unsigned char workaround_mask = version_good;
-			unsigned char workaround;
-
-			/* workaround_mask will be 0xff if version_good is
-			 * non-zero (i.e. the version match failed). Otherwise
-			 * it'll be 0x00. */
-			workaround_mask |= workaround_mask >> 4;
-			workaround_mask |= workaround_mask >> 2;
-			workaround_mask |= workaround_mask >> 1;
-			workaround_mask = ~((workaround_mask & 1) - 1);
-
-			workaround = p[0] ^ (s->version>>8);
-			workaround |= p[1] ^ (s->version&0xff);
-
-			/* If workaround_mask is 0xff (i.e. there was a version
-			 * mismatch) then we copy the value of workaround over
-			 * version_good. */
-			version_good = (workaround & workaround_mask) |
-				       (version_good & ~workaround_mask);
-			}
-
-		/* If any bits in version_good are set then they'll poision
-		 * decrypt_good_mask and cause rand_premaster_secret to be
-		 * used. */
-		decrypt_good_mask |= version_good;
-
-		/* decrypt_good_mask will be zero iff decrypt_len ==
-		 * SSL_MAX_MASTER_KEY_LENGTH and the version check passed. We
-		 * fold the bottom 32 bits of it with an OR so that the LSB
-		 * will be zero iff everything is good. This assumes that we'll
-		 * never decrypt a value > 2**31 bytes, which seems safe. */
-		decrypt_good_mask |= decrypt_good_mask >> 16;
-		decrypt_good_mask |= decrypt_good_mask >> 8;
-		decrypt_good_mask |= decrypt_good_mask >> 4;
-		decrypt_good_mask |= decrypt_good_mask >> 2;
-		decrypt_good_mask |= decrypt_good_mask >> 1;
-		/* Now select only the LSB and subtract one. If decrypt_len ==
-		 * SSL_MAX_MASTER_KEY_LENGTH and the version check passed then
-		 * decrypt_good_mask will be all ones. Otherwise it'll be all
-		 * zeros. */
-		decrypt_good_mask &= 1;
-		decrypt_good_mask--;
-
-		/* Now copy rand_premaster_secret over p using
-		 * decrypt_good_mask. */
-		for (i = 0; i < (int) sizeof(rand_premaster_secret); i++)
-			{
-			p[i] = (p[i] & decrypt_good_mask) |
-			       (rand_premaster_secret[i] & ~decrypt_good_mask);
+		s->init_off=p-(unsigned char *)s->init_msg;
+		s->s3->tmp.reuse_message=1;
+		if ((s->mode & SSL_MODE_ASYNC_KEY_EX) == 0)
+			{
+			int decrypt_len;
+			decrypt_len = RSA_private_decrypt((int)n,p,p,rsa,RSA_PKCS1_PADDING);
+			ERR_clear_error();
+			s->key_ex.len=decrypt_len;
 			}
-
-		s->session->master_key_length=
-			s->method->ssl3_enc->generate_master_secret(s,
-				s->session->master_key,
-				p,i);
-		OPENSSL_cleanse(p,i);
+		else
+			{
+			s->key_ex.data=p;
+			s->key_ex.len=n;
+			}
+		return 3;
 		}
 	else
 #endif
@@ -3044,6 +3078,119 @@ err:
 	return(-1);
 	}
 
+/* Second half of RSA ClientKeyExchange processing. Expects the result of
+ * the RSA decryption at init_msg + init_off and its length in key_ex.len
+ * (ssl3_get_client_key_exchange() decrypts in place when
+ * SSL_MODE_ASYNC_KEY_EX is not set). Checks length and version, replaces
+ * the premaster secret with random bytes if either is wrong, and derives
+ * the master secret from it.
+ * Returns 1 on success, -1 if RAND_pseudo_bytes() fails. */
+int ssl3_cont_client_key_exchange(SSL *s)
+	{
+	unsigned char rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
+	int decrypt_len, decrypt_good_mask;
+	unsigned char version_good;
+	int i;
+	unsigned char* p;
+
+	p=(unsigned char*)s->init_msg + s->init_off;
+	decrypt_len=s->key_ex.len;
+	s->s3->tmp.reuse_message=0;
+
+	/* We must not leak whether a decryption failure occurs because
+	 * of Bleichenbacher's attack on PKCS #1 v1.5 RSA padding (see
+	 * RFC 2246, section 7.4.7.1). The code follows that advice of
+	 * the TLS RFC and generates a random premaster secret for the
+	 * case that the decrypt fails. See
+	 * https://tools.ietf.org/html/rfc5246#section-7.4.7.1 */
+
+	/* should be RAND_bytes, but we cannot work around a failure. */
+	if (RAND_pseudo_bytes(rand_premaster_secret,
+						sizeof(rand_premaster_secret)) <= 0)
+		return -1;
+
+	/* decrypt_len should be SSL_MAX_MASTER_KEY_LENGTH.
+	 * decrypt_good_mask will be zero if so and non-zero otherwise. */
+	decrypt_good_mask = decrypt_len ^ SSL_MAX_MASTER_KEY_LENGTH;
+
+	/* If the version in the decrypted pre-master secret is correct
+	 * then version_good will be zero. The Klima-Pokorny-Rosa
+	 * extension of Bleichenbacher's attack
+	 * (http://eprint.iacr.org/2003/052/) exploits the version
+	 * number check as a "bad version oracle". Thus version checks
+	 * are done in constant time and are treated like any other
+	 * decryption error. */
+	version_good = p[0] ^ (s->client_version>>8);
+	version_good |= p[1] ^ (s->client_version&0xff);
+
+	/* The premaster secret must contain the same version number as
+	 * the ClientHello to detect version rollback attacks
+	 * (strangely, the protocol does not offer such protection for
+	 * DH ciphersuites). However, buggy clients exist that send the
+	 * negotiated protocol version instead if the server does not
+	 * support the requested protocol version. If
+	 * SSL_OP_TLS_ROLLBACK_BUG is set, tolerate such clients. */
+	if (s->options & SSL_OP_TLS_ROLLBACK_BUG)
+		{
+		unsigned char workaround_mask = version_good;
+		unsigned char workaround;
+
+		/* workaround_mask will be 0xff if version_good is
+		 * non-zero (i.e. the version match failed). Otherwise
+		 * it'll be 0x00. */
+		workaround_mask |= workaround_mask >> 4;
+		workaround_mask |= workaround_mask >> 2;
+		workaround_mask |= workaround_mask >> 1;
+		workaround_mask = ~((workaround_mask & 1) - 1);
+
+		workaround = p[0] ^ (s->version>>8);
+		workaround |= p[1] ^ (s->version&0xff);
+
+		/* If workaround_mask is 0xff (i.e. there was a version
+		 * mismatch) then we copy the value of workaround over
+		 * version_good. */
+		version_good = (workaround & workaround_mask) |
+						 (version_good & ~workaround_mask);
+		}
+
+	/* If any bits in version_good are set then they'll poison
+	 * decrypt_good_mask and cause rand_premaster_secret to be
+	 * used. */
+	decrypt_good_mask |= version_good;
+
+	/* decrypt_good_mask will be zero iff decrypt_len ==
+	 * SSL_MAX_MASTER_KEY_LENGTH and the version check passed. We
+	 * fold the bottom 32 bits of it with an OR so that the LSB
+	 * will be zero iff everything is good. This assumes that we'll
+	 * never decrypt a value > 2**31 bytes, which seems safe. */
+	decrypt_good_mask |= decrypt_good_mask >> 16;
+	decrypt_good_mask |= decrypt_good_mask >> 8;
+	decrypt_good_mask |= decrypt_good_mask >> 4;
+	decrypt_good_mask |= decrypt_good_mask >> 2;
+	decrypt_good_mask |= decrypt_good_mask >> 1;
+	/* Now select only the LSB and subtract one. If decrypt_len ==
+	 * SSL_MAX_MASTER_KEY_LENGTH and the version check passed then
+	 * decrypt_good_mask will be all ones. Otherwise it'll be all
+	 * zeros. */
+	decrypt_good_mask &= 1;
+	decrypt_good_mask--;
+
+	/* Now copy rand_premaster_secret over p using
+	 * decrypt_good_mask. */
+	for (i = 0; i < (int) sizeof(rand_premaster_secret); i++)
+		{
+		p[i] = (p[i] & decrypt_good_mask) |
+					 (rand_premaster_secret[i] & ~decrypt_good_mask);
+		}
+
+	s->session->master_key_length=
+		s->method->ssl3_enc->generate_master_secret(s,
+			s->session->master_key,
+			p,i);
+	OPENSSL_cleanse(p,i);
+  return 1;
+	}
+
 int ssl3_get_cert_verify(SSL *s)
 	{
 	EVP_PKEY *pkey=NULL;
diff --git a/ssl/ssl.h b/ssl/ssl.h
index a1a3e13..586b69f 100644
--- a/ssl/ssl.h
+++ b/ssl/ssl.h
@@ -685,6 +685,32 @@ struct ssl_session_st
  */
 #define SSL_MODE_SEND_CLIENTHELLO_TIME 0x00000020L
 #define SSL_MODE_SEND_SERVERHELLO_TIME 0x00000040L
+/* If set - SSL_get_error() may return SSL_ERROR_WANT_RSA_DECRYPT or
+ * SSL_ERROR_WANT_SIGN after SSL_read()/SSL_write()/SSL_accept() fails, when
+ * the key exchange requires the use of the private key.
+ *
+ * In such case following functions should be called in order to get input
+ * data and input length:
+ *
+ *     SSL_get_key_ex_data()
+ *     SSL_get_key_ex_len()
+ *
+ * In case of SSL_ERROR_WANT_SIGN - SSL_get_key_ex_type() will return the
+ * type of the private key that should be used for signature.
+ * SSL_get_key_ex_md() - for getting the nid of digest. Note that it may return
+ * NID_md5_sha1, which could only be supplied to RSA_sign() and is not accepted
+ * by general EVP methods.
+ *
+ * After performing this operation (decrypt / sign), the output data should be
+ * supplied via:
+ *
+ *     SSL_supply_key_ex_data()
+ *
+ * Next either of SSL_accept()/SSL_read()/SSL_write() may be called to continue
+ * the handshake.
+ */
+
+#define SSL_MODE_ASYNC_KEY_EX          0x00000080L
 
 /* Cert related flags */
 /* Many implementations ignore some aspects of the TLS standards such as
@@ -1275,12 +1301,16 @@ int SSL_extension_supported(unsigned int ext_type);
 #define SSL_WRITING	2
 #define SSL_READING	3
 #define SSL_X509_LOOKUP	4
+#define SSL_RSA_DECRYPT	5
+#define SSL_SIGN	6
 
 /* These will only be used when doing non-blocking IO */
 #define SSL_want_nothing(s)	(SSL_want(s) == SSL_NOTHING)
 #define SSL_want_read(s)	(SSL_want(s) == SSL_READING)
 #define SSL_want_write(s)	(SSL_want(s) == SSL_WRITING)
 #define SSL_want_x509_lookup(s)	(SSL_want(s) == SSL_X509_LOOKUP)
+#define SSL_want_rsa_decrypt(s)	(SSL_want(s) == SSL_RSA_DECRYPT)
+#define SSL_want_sign(s)	(SSL_want(s) == SSL_SIGN)
 
 #define SSL_MAC_FLAG_READ_MAC_STREAM 1
 #define SSL_MAC_FLAG_WRITE_MAC_STREAM 2
@@ -1546,6 +1576,17 @@ struct ssl_st
 	/* Callback for disabling session caching and ticket support
 	 * on a session basis, depending on the chosen cipher. */
 	int (*not_resumable_session_cb)(SSL *ssl, int is_forward_secure);
+	struct
+		{
+		/* Input data for key exchange */
+		unsigned char* data;
+		/* Input length */
+		long len;
+		/* Digest type for signature */
+		int md;
+		/* EVP_PKEY type */
+		int type;
+		} key_ex;
 	};
 
 #endif
@@ -1633,6 +1674,26 @@ size_t SSL_get_peer_finished(const SSL *s, void *buf, size_t count);
 #define OpenSSL_add_ssl_algorithms()	SSL_library_init()
 #define SSLeay_add_ssl_algorithms()	SSL_library_init()
 
+/*
+ * Async key exchange methods
+ * See SSL_MODE_ASYNC_KEY_EX for detailed documentation
+ */
+
+/* Input data for async key exchange */
+const unsigned char* SSL_get_key_ex_data(const SSL *s);
+
+/* Input length for async key exchange */
+long SSL_get_key_ex_len(const SSL *s);
+
+/* Signature digest algorithm NID */
+int SSL_get_key_ex_md(const SSL *s);
+
+/* Signature private key type, EVP_PKEY_RSA/EVP_PKEY_EC/... */
+int SSL_get_key_ex_type(const SSL *s);
+
+/* Supply the key exchange data */
+int SSL_supply_key_ex_data(SSL *s, unsigned char *data, long len);
+
 /* this is for backward compatibility */
 #if 0 /* NEW_SSLEAY */
 #define SSL_CTX_set_default_verify(a,b,c) SSL_CTX_set_verify(a,b,c)
@@ -1703,6 +1764,8 @@ DECLARE_PEM_rw(SSL_SESSION, SSL_SESSION)
 #define SSL_ERROR_ZERO_RETURN		6
 #define SSL_ERROR_WANT_CONNECT		7
 #define SSL_ERROR_WANT_ACCEPT		8
+#define SSL_ERROR_WANT_RSA_DECRYPT		9
+#define SSL_ERROR_WANT_SIGN		10
 
 #define SSL_CTRL_NEED_TMP_RSA			1
 #define SSL_CTRL_SET_TMP_RSA			2
diff --git a/ssl/ssl3.h b/ssl/ssl3.h
index 29cb184..f4484b3 100644
--- a/ssl/ssl3.h
+++ b/ssl/ssl3.h
@@ -627,6 +627,9 @@ typedef struct ssl3_state_st
 #define SSL3_ST_CR_CERT_B		(0x131|SSL_ST_CONNECT)
 #define SSL3_ST_CR_KEY_EXCH_A		(0x140|SSL_ST_CONNECT)
 #define SSL3_ST_CR_KEY_EXCH_B		(0x141|SSL_ST_CONNECT)
+#define SSL3_ST_SW_KEY_EXCH_C		(0x152|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_SIGN_WAIT		(0x153|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY		(0x154|SSL_ST_ACCEPT)
 #define SSL3_ST_CR_CERT_REQ_A		(0x150|SSL_ST_CONNECT)
 #define SSL3_ST_CR_CERT_REQ_B		(0x151|SSL_ST_CONNECT)
 #define SSL3_ST_CR_SRVR_DONE_A		(0x160|SSL_ST_CONNECT)
@@ -638,6 +641,9 @@ typedef struct ssl3_state_st
 #define SSL3_ST_CW_CERT_D		(0x173|SSL_ST_CONNECT)
 #define SSL3_ST_CW_KEY_EXCH_A		(0x180|SSL_ST_CONNECT)
 #define SSL3_ST_CW_KEY_EXCH_B		(0x181|SSL_ST_CONNECT)
+#define SSL3_ST_SR_KEY_EXCH_C		(0x192|SSL_ST_ACCEPT)
+#define SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_WAIT		(0x193|SSL_ST_ACCEPT)
+#define SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_SUPPLY		(0x194|SSL_ST_ACCEPT)
 #define SSL3_ST_CW_CERT_VRFY_A		(0x190|SSL_ST_CONNECT)
 #define SSL3_ST_CW_CERT_VRFY_B		(0x191|SSL_ST_CONNECT)
 #define SSL3_ST_CW_CHANGE_A		(0x1A0|SSL_ST_CONNECT)
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 6a33b9d..47961af 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -809,6 +809,89 @@ size_t SSL_get_peer_finished(const SSL *s, void *buf, size_t count)
 	}
 
 
+/* Input for the pending async key exchange operation: the data to sign
+ * (SSL_want_sign()) or to decrypt (SSL_want_rsa_decrypt()). NULL again once
+ * SSL_supply_key_ex_data() has accepted the result. */
+const unsigned char* SSL_get_key_ex_data(const SSL *s)
+	{
+	return s->key_ex.data;
+	}
+
+
+/* Length in bytes of the SSL_get_key_ex_data() input */
+long SSL_get_key_ex_len(const SSL *s)
+	{
+	return s->key_ex.len;
+	}
+
+
+/* NID of the digest requested for the pending signature */
+int SSL_get_key_ex_md(const SSL *s)
+	{
+	return s->key_ex.md;
+	}
+
+
+/* EVP_PKEY type of the key requested for the pending signature */
+int SSL_get_key_ex_type(const SSL *s)
+	{
+	return s->key_ex.type;
+	}
+
+
+/* Hand the result of the sign/decrypt operation back to the handshake.
+ * Only valid while the SSL waits for it (SSL_want_sign() or
+ * SSL_want_rsa_decrypt()); the data is copied into the handshake message
+ * buffer and the next SSL_accept()/SSL_read()/SSL_write() continues.
+ * A failed RSA decryption may be reported with len == 0, which
+ * ssl3_cont_client_key_exchange() treats like any other bad premaster.
+ * Returns 1 on success, 0 if no operation is pending or the data is
+ * missing, negative in length, or too large for the buffer. */
+int SSL_supply_key_ex_data(SSL *s, unsigned char *data, long len)
+	{
+	unsigned char *dst;
+	long room;
+	int next_state;
+
+	if (s->s3 == NULL || s->init_buf == NULL)
+		return 0;
+	if (len < 0 || (data == NULL && len != 0))
+		return 0;
+
+	switch (s->state) {
+		case SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_WAIT:
+			/* init_off was stored relative to init_msg, which is
+			 * also where ssl3_cont_client_key_exchange() reads
+			 * the plaintext. It overwrites the ciphertext, whose
+			 * length is still in key_ex.len. */
+			dst=(unsigned char *)s->init_msg + s->init_off;
+			room=s->key_ex.len;
+			next_state=SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_SUPPLY;
+			break;
+		case SSL3_ST_SW_KEY_EXCH_SIGN_WAIT:
+			/* init_off was stored relative to init_buf->data */
+			dst=(unsigned char *)s->init_buf->data + s->init_off;
+			room=(long)s->init_buf->max - (long)s->init_off;
+			next_state=SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY;
+			break;
+		default:
+			return 0;
+	}
+	if (s->init_off < 0 || len > room)
+		return 0;
+
+	s->state=next_state;
+	BIO_set_flags(SSL_get_rbio(s), 0);
+	/* Copy the data right into the message */
+	if (len > 0)
+		memcpy(dst, data, len);
+	s->key_ex.data=NULL;
+	/* The length is needed for RSA_DECRYPT case */
+	s->key_ex.len=len;
+	return 1;
+	}
+
+
 int SSL_get_verify_mode(const SSL *s)
 	{
 	return(s->verify_mode);
@@ -2757,7 +2806,14 @@ int SSL_get_error(const SSL *s,int i)
 		{
 		return(SSL_ERROR_WANT_X509_LOOKUP);
 		}
-
+	if ((i < 0) && SSL_want_rsa_decrypt(s))
+		{
+		return(SSL_ERROR_WANT_RSA_DECRYPT);
+		}
+	if ((i < 0) && SSL_want_sign(s))
+		{
+		return(SSL_ERROR_WANT_SIGN);
+		}
 	if (i == 0)
 		{
 		if (s->version == SSL2_VERSION)
diff --git a/ssl/ssl_locl.h b/ssl/ssl_locl.h
index 3f87da7..8a58be0 100644
--- a/ssl/ssl_locl.h
+++ b/ssl/ssl_locl.h
@@ -1248,11 +1248,13 @@ int ssl3_get_client_hello(SSL *s);
 int ssl3_send_server_hello(SSL *s);
 int ssl3_send_hello_request(SSL *s);
 int ssl3_send_server_key_exchange(SSL *s);
+int ssl3_cont_server_key_exchange(SSL *s);
 int ssl3_send_certificate_request(SSL *s);
 int ssl3_send_server_done(SSL *s);
 int ssl3_check_client_hello(SSL *s);
 int ssl3_get_client_certificate(SSL *s);
 int ssl3_get_client_key_exchange(SSL *s);
+int ssl3_cont_client_key_exchange(SSL *s);
 int ssl3_get_cert_verify(SSL *s);
 #ifndef OPENSSL_NO_NEXTPROTONEG
 int ssl3_get_next_proto(SSL *s);
diff --git a/ssl/ssl_rsa.c b/ssl/ssl_rsa.c
index 6f9337e..3b8f2cf 100644
--- a/ssl/ssl_rsa.c
+++ b/ssl/ssl_rsa.c
@@ -64,8 +64,8 @@
 #include <openssl/x509.h>
 #include <openssl/pem.h>
 
-static int ssl_set_cert(CERT *c, X509 *x509);
-static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey);
+static int ssl_set_cert(CERT *c, X509 *x509, int force);
+static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey, int force);
 int SSL_use_certificate(SSL *ssl, X509 *x)
 	{
 	int rv;
@@ -86,7 +86,7 @@ int SSL_use_certificate(SSL *ssl, X509 *x)
 		SSLerr(SSL_F_SSL_USE_CERTIFICATE,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	return(ssl_set_cert(ssl->cert,x));
+	return(ssl_set_cert(ssl->cert,x,ssl->mode & SSL_MODE_ASYNC_KEY_EX));
 	}
 
 #ifndef OPENSSL_NO_STDIO
@@ -181,13 +181,13 @@ int SSL_use_RSAPrivateKey(SSL *ssl, RSA *rsa)
 	RSA_up_ref(rsa);
 	EVP_PKEY_assign_RSA(pkey,rsa);
 
-	ret=ssl_set_pkey(ssl->cert,pkey);
+	ret=ssl_set_pkey(ssl->cert,pkey,ssl->mode & SSL_MODE_ASYNC_KEY_EX);
 	EVP_PKEY_free(pkey);
 	return(ret);
 	}
 #endif
 
-static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey)
+static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey, int force)
 	{
 	int i;
 	/* Special case for DH: check two DH certificate types for a match.
@@ -229,7 +229,7 @@ static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey)
 			;
 		else
 #endif
-		if (!X509_check_private_key(c->pkeys[i].x509,pkey))
+		if (!force && !X509_check_private_key(c->pkeys[i].x509,pkey))
 			{
 			X509_free(c->pkeys[i].x509);
 			c->pkeys[i].x509 = NULL;
@@ -329,7 +329,7 @@ int SSL_use_PrivateKey(SSL *ssl, EVP_PKEY *pkey)
 		SSLerr(SSL_F_SSL_USE_PRIVATEKEY,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	ret=ssl_set_pkey(ssl->cert,pkey);
+	ret=ssl_set_pkey(ssl->cert,pkey,ssl->mode & SSL_MODE_ASYNC_KEY_EX);
 	return(ret);
 	}
 
@@ -418,10 +418,10 @@ int SSL_CTX_use_certificate(SSL_CTX *ctx, X509 *x)
 		SSLerr(SSL_F_SSL_CTX_USE_CERTIFICATE,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	return(ssl_set_cert(ctx->cert, x));
+	return(ssl_set_cert(ctx->cert, x, ctx->mode & SSL_MODE_ASYNC_KEY_EX));
 	}
 
-static int ssl_set_cert(CERT *c, X509 *x)
+static int ssl_set_cert(CERT *c, X509 *x, int force)
 	{
 	EVP_PKEY *pkey;
 	int i;
@@ -455,7 +455,7 @@ static int ssl_set_cert(CERT *c, X509 *x)
 			 ;
 		else
 #endif /* OPENSSL_NO_RSA */
-		if (!X509_check_private_key(x,c->pkeys[i].privatekey))
+		if (!force && !X509_check_private_key(x,c->pkeys[i].privatekey))
 			{
 			/* don't fail for a cert/key mismatch, just free
 			 * current private key (when switching to a different
@@ -572,7 +572,7 @@ int SSL_CTX_use_RSAPrivateKey(SSL_CTX *ctx, RSA *rsa)
 	RSA_up_ref(rsa);
 	EVP_PKEY_assign_RSA(pkey,rsa);
 
-	ret=ssl_set_pkey(ctx->cert, pkey);
+	ret=ssl_set_pkey(ctx->cert, pkey, ctx->mode & SSL_MODE_ASYNC_KEY_EX);
 	EVP_PKEY_free(pkey);
 	return(ret);
 	}
@@ -656,7 +656,7 @@ int SSL_CTX_use_PrivateKey(SSL_CTX *ctx, EVP_PKEY *pkey)
 		SSLerr(SSL_F_SSL_CTX_USE_PRIVATEKEY,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	return(ssl_set_pkey(ctx->cert,pkey));
+	return(ssl_set_pkey(ctx->cert,pkey,ctx->mode & SSL_MODE_ASYNC_KEY_EX));
 	}
 
 #ifndef OPENSSL_NO_STDIO
diff --git a/ssl/ssltest.c b/ssl/ssltest.c
index 9b4a320..3c0c47e 100644
--- a/ssl/ssltest.c
+++ b/ssl/ssltest.c
@@ -224,6 +224,7 @@ static void free_tmp_rsa(void);
 #endif
 static int MS_CALLBACK app_verify_callback(X509_STORE_CTX *ctx, void *arg);
 #define APP_CALLBACK_STRING "Test Callback Argument"
+static int handle_async_key_ex(SSL *s);
 struct app_verify_arg
 	{
 	char *string;
@@ -812,6 +813,7 @@ static void sv_usage(void)
 	fprintf(stderr," -alpn_client <string> - have client side offer ALPN\n");
 	fprintf(stderr," -alpn_server <string> - have server side offer ALPN\n");
 	fprintf(stderr," -alpn_expected <string> - the ALPN protocol that should be negotiated\n");
+	fprintf(stderr," -async_key_ex - use SSL_MODE_ASYNC_KEY_EX\n");
 	}
 
 static void print_details(SSL *c_ssl, const char *prefix)
@@ -992,6 +994,7 @@ int main(int argc, char *argv[])
 #ifdef OPENSSL_FIPS
 	int fips_mode=0;
 #endif
+	int async_key_ex=0;
 
 	verbose = 0;
 	debug = 0;
@@ -1256,6 +1259,8 @@ int main(int argc, char *argv[])
 			if (--argc < 1) goto bad;
 			alpn_expected = *(++argv);
 			}
+		else if	(strcmp(*argv,"-async_key_ex") == 0)
+			async_key_ex=1;
 		else
 			{
 			fprintf(stderr,"unknown option %s\n",*argv);
@@ -1454,11 +1459,19 @@ bad:
 		SSL_CTX_set_tmp_ecdh(s_ctx, ecdh);
 		SSL_CTX_set_options(s_ctx, SSL_OP_SINGLE_ECDH_USE);
 		EC_KEY_free(ecdh);
+
 		}
 #else
 	(void)no_ecdhe;
 #endif
 
+	if (async_key_ex)
+		{
+		long mode;
+		mode = SSL_CTX_get_mode(s_ctx);
+		SSL_CTX_set_mode(s_ctx, mode | SSL_MODE_ASYNC_KEY_EX);
+		}
+
 #ifndef OPENSSL_NO_RSA
 	SSL_CTX_set_tmp_rsa_callback(s_ctx,tmp_rsa_cb);
 #endif
@@ -1922,7 +1935,9 @@ int doit_biopair(SSL *s_ssl, SSL *c_ssl, long count,
 					i = sizeof sbuf;
 				else
 					i = (int)sw_num;
-				r = BIO_write(s_ssl_bio, sbuf, i);
+				do
+					r = BIO_write(s_ssl_bio, sbuf, i);
+				while (r < 0 && handle_async_key_ex(s_ssl));
 				if (r < 0)
 					{
 					if (!BIO_should_retry(s_ssl_bio))
@@ -1949,7 +1964,9 @@ int doit_biopair(SSL *s_ssl, SSL *c_ssl, long count,
 				{
 				/* Read from client. */
 
-				r = BIO_read(s_ssl_bio, sbuf, sizeof(sbuf));
+				do
+					r = BIO_read(s_ssl_bio, sbuf, sizeof(sbuf));
+				while (r < 0 && handle_async_key_ex(s_ssl));
 				if (r < 0)
 					{
 					if (!BIO_should_retry(s_ssl_bio))
@@ -2349,7 +2366,9 @@ int doit(SSL *s_ssl, SSL *c_ssl, long count)
 			{
 			if (!s_write)
 				{
-				i=BIO_read(s_bio,sbuf,bufsiz);
+				do
+					i=BIO_read(s_bio,sbuf,bufsiz);
+				while (i < 0 && handle_async_key_ex(s_ssl));
 				if (i < 0)
 					{
 					s_r=0;
@@ -2396,7 +2415,9 @@ int doit(SSL *s_ssl, SSL *c_ssl, long count)
 				{
 				j = (sw_num > bufsiz) ?
 					(int)bufsiz : (int)sw_num;
-				i=BIO_write(s_bio,sbuf,j);
+				do
+					i=BIO_write(s_bio,sbuf,j);
+				while (i < 0 && handle_async_key_ex(s_ssl));
 				if (i < 0)
 					{
 					s_r=0;
@@ -3230,3 +3251,125 @@ static int do_test_cipherlist(void)
 
 	return 1;
 	}
+
+/* Scratch space for the signature / plaintext handed back to the SSL */
+static unsigned char buf[1024];
+
+/* Plays the external key holder for the SSL_MODE_ASYNC_KEY_EX tests.
+ * Called after a failed BIO operation on the server side: if the server SSL
+ * is waiting for a signature or an RSA decryption, performs it with the
+ * private key of its SSL_CTX, supplies the result and re-enters SSL_accept().
+ * Returns 0 when nothing was done or the operation failed; otherwise returns
+ * the result of SSL_accept(), and the caller retries its BIO operation when
+ * that is non-zero. */
+static int handle_async_key_ex(SSL *s)
+	{
+	EVP_PKEY* pkey;
+	if (!(SSL_get_mode(s) & SSL_MODE_ASYNC_KEY_EX))
+		return 0;
+
+	fprintf(stderr, "performing async key ex\n");
+
+	pkey = SSL_CTX_get0_privatekey(s->ctx);
+	if (pkey == NULL)
+		{
+		fprintf(stderr, "async key ex: no private key\n");
+		return 0;
+		}
+
+	if (SSL_want_sign(s) && SSL_get_key_ex_md(s) == NID_md5_sha1)
+		{
+		unsigned int len;
+
+		assert(SSL_get_key_ex_type(s) == EVP_PKEY_RSA);
+		if (pkey->type != EVP_PKEY_RSA)
+			{
+			fprintf(stderr, "async key ex: non-rsa private key for sign\n");
+			return 0;
+			}
+		if ((size_t)RSA_size(pkey->pkey.rsa) > sizeof(buf))
+			{
+			fprintf(stderr, "async key ex: signature does not fit\n");
+			return 0;
+			}
+		if (RSA_sign(NID_md5_sha1,
+				SSL_get_key_ex_data(s),
+				SSL_get_key_ex_len(s),
+				buf,
+				&len,
+				pkey->pkey.rsa) <= 0)
+			{
+			fprintf(stderr, "async key ex: rsa sign failure\n");
+			return 0;
+			}
+		if (!SSL_supply_key_ex_data(s, buf, len))
+			return 0;
+		return SSL_accept(s);
+		}
+	if (SSL_want_sign(s) && SSL_get_key_ex_md(s) != NID_md5_sha1)
+		{
+		EVP_MD_CTX md_ctx;
+		const EVP_MD *md = NULL;
+		unsigned int len = 0;
+		int ok;
+
+		assert(SSL_get_key_ex_type(s) == pkey->type);
+		md = EVP_get_digestbynid(SSL_get_key_ex_md(s));
+		if (md == NULL)
+			{
+			fprintf(stderr, "async key ex: md not found\n");
+			return 0;
+			}
+		if ((size_t)EVP_PKEY_size(pkey) > sizeof(buf))
+			{
+			fprintf(stderr, "async key ex: signature does not fit\n");
+			return 0;
+			}
+
+		EVP_MD_CTX_init(&md_ctx);
+		ok = EVP_SignInit_ex(&md_ctx, md, NULL) &&
+			EVP_SignUpdate(&md_ctx, SSL_get_key_ex_data(s),
+				SSL_get_key_ex_len(s)) &&
+			EVP_SignFinal(&md_ctx, buf, &len, pkey);
+		/* Release the digest state on success and failure alike */
+		EVP_MD_CTX_cleanup(&md_ctx);
+		if (!ok)
+			{
+			fprintf(stderr, "async key ex: sign failure\n");
+			return 0;
+			}
+
+		if (!SSL_supply_key_ex_data(s, buf, len))
+			return 0;
+		return SSL_accept(s);
+		}
+	if (SSL_want_rsa_decrypt(s))
+		{
+		int len;
+		if (pkey->type != EVP_PKEY_RSA)
+			{
+			fprintf(stderr, "async key ex: non-rsa private key for decryption\n");
+			return 0;
+			}
+		if ((size_t)RSA_size(pkey->pkey.rsa) > sizeof(buf))
+			{
+			fprintf(stderr, "async key ex: plaintext does not fit\n");
+			return 0;
+			}
+		len = RSA_private_decrypt(SSL_get_key_ex_len(s),
+				SSL_get_key_ex_data(s),
+				buf,
+				pkey->pkey.rsa,
+				RSA_PKCS1_PADDING);
+		if (len == -1)
+			{
+			fprintf(stderr, "async key ex: rsa decrypt failed\n");
+			return 0;
+			}
+
+		if (!SSL_supply_key_ex_data(s, buf, len))
+			return 0;
+		return SSL_accept(s);
+		}
+	return 0;
+	}
diff --git a/test/testssl b/test/testssl
index 3a63e5d..4a2692b 100644
--- a/test/testssl
+++ b/test/testssl
@@ -51,6 +51,9 @@ fi
 echo test sslv3
 $ssltest -ssl3 $extra || exit 1
 
+echo test sslv3 with async key ex
+$ssltest -ssl3 -async_key_ex $extra || exit 1
+
 echo test sslv3 with server authentication
 $ssltest -ssl3 -server_auth $CA $extra || exit 1
 
@@ -167,6 +170,9 @@ $ssltest -tls1 -cipher PSK -psk abc123 $extra || exit 1
 echo test tls1 with PSK via BIO pair
 $ssltest -bio_pair -tls1 -cipher PSK -psk abc123 $extra || exit 1
 
+echo test tls1 with async key ex
+$ssltest -tls1 -async_key_ex $extra || exit 1
+
 #############################################################################
 # Next Protocol Negotiation Tests
 
-- 
2.1.0

