From 09d4332d860611ea493b0bbf25e99d77d822cd56 Mon Sep 17 00:00:00 2001
From: Fedor Indutny <fedor@indutny.com>
Date: Sat, 23 Aug 2014 21:38:52 +0400
Subject: [PATCH] ssl: introduce async sign/decrypt APIs

This patch is introducing `async_key_ex_cb` member of both `SSL_CTX` and
`SSL`, and `SSL_supply()`. If `async_key_ex_cb` is present:

* Server will ignore dummy RSA key, assuming that it is matching the
  certificate.
* Server will invoke this callback with either:
  * `SSL_KEY_EX_RSA`
  * `SSL_KEY_EX_RSA_SIGN`
  as a `type` argument, and some data for signature or decryption in
  `p`/`n` pair.

At that time the sign/decryption may be performed on any thread, or even
remotely, and the result should be supplied with `SSL_supply()`. Calling
`SSL_supply()` will continue the handshake process without even touching
the real private key.
---
 ssl/s3_srvr.c | 139 +++++++++++++++++++++++++++++++++++++++++++++++++++++++---
 ssl/ssl.h     |  21 +++++++++
 ssl/ssl3.h    |   6 +++
 ssl/ssl_lib.c |  35 +++++++++++++++
 ssl/ssl_rsa.c |  24 +++++-----
 5 files changed, 208 insertions(+), 17 deletions(-)

diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index 440fc13..f8006b4 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -445,6 +445,8 @@ int ssl3_accept(SSL *s)
 
 		case SSL3_ST_SW_KEY_EXCH_A:
 		case SSL3_ST_SW_KEY_EXCH_B:
+		case SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_A:
+		case SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_B:
 			alg_k = s->s3->tmp.new_cipher->algorithm_mkey;
 
 			/* clear this, it may get reset by
@@ -496,7 +498,18 @@ int ssl3_accept(SSL *s)
 			    )
 				{
 				ret=ssl3_send_server_key_exchange(s);
-				if (ret <= 0) goto end;
+				if (ret <= 0)
+					goto end;
+				else if (ret == 2)
+					{
+					s->state=SSL3_ST_SW_KEY_EXCH_RSA_WAIT_A;
+					goto end;
+					}
+				else if (ret == 3)
+					{
+					s->state=SSL3_ST_SW_KEY_EXCH_RSA_WAIT_B;
+					goto end;
+					}
 				}
 			else
 				skip=1;
@@ -604,6 +617,7 @@ int ssl3_accept(SSL *s)
 
 		case SSL3_ST_SR_KEY_EXCH_A:
 		case SSL3_ST_SR_KEY_EXCH_B:
+		case SSL3_ST_SR_KEY_EXCH_RSA_SUPPLY:
 			ret=ssl3_get_client_key_exchange(s);
 			if (ret <= 0)
 				goto end;
@@ -627,6 +641,10 @@ int ssl3_accept(SSL *s)
 #endif
 				s->init_num = 0;
 				}
+			else if (ret == 3)
+				{
+				s->state=SSL3_ST_SR_KEY_EXCH_RSA_WAIT;
+				}
 			else if (SSL_USE_SIGALGS(s))
 				{
 				s->state=SSL3_ST_SR_CERT_VRFY_A;
@@ -678,6 +696,15 @@ int ssl3_accept(SSL *s)
 				}
 			break;
 
+		case SSL3_ST_SR_KEY_EXCH_RSA_WAIT:
+		case SSL3_ST_SW_KEY_EXCH_RSA_WAIT_A:
+		case SSL3_ST_SW_KEY_EXCH_RSA_WAIT_B:
+			/* Just to return SSL_WANT_READ */
+			s->rwstate=SSL_READING;
+			BIO_set_flags(SSL_get_rbio(s), BIO_FLAGS_READ);
+			ret = -1;
+			goto end;
+
 		case SSL3_ST_SR_CERT_VRFY_A:
 		case SSL3_ST_SR_CERT_VRFY_B:
 
@@ -2009,7 +2036,35 @@ int ssl3_send_server_key_exchange(SSL *s)
 					q+=i;
 					j+=i;
 					}
-				if (RSA_sign(NID_md5_sha1, md_buf, j,
+
+				/* Use supplied data */
+				if (s->state == SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_A)
+					{
+async_RSA_cont_a:
+					d=(unsigned char *)s->init_buf->data;
+					p=s->async_key_ex_int;
+					n=s->async_key_ex_int_n;
+					u=s->async_key_ex_len;
+					memcpy(&(p[2]), s->async_key_ex_data, u);
+					}
+				/* Switch to a wait state */
+				else if (s->async_key_ex_cb != NULL)
+					{
+					s->async_key_ex_data = NULL;
+					s->async_key_ex_int=p;
+					s->async_key_ex_int_n=n;
+					s->async_key_ex_cb(s,
+										 SSL_KEY_EX_RSA_SIGN,
+										 EVP_MD_name(md),
+										 md_buf,
+										 j);
+
+					/* Synchronous execution */
+					if (s->async_key_ex_data != NULL)
+						goto async_RSA_cont_a;
+					return 2;
+					}
+				else if (RSA_sign(NID_md5_sha1, md_buf, j,
 					&(p[2]), &u, pkey->pkey.rsa) <= 0)
 					{
 					SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,ERR_LIB_RSA);
@@ -2038,6 +2093,47 @@ int ssl3_send_server_key_exchange(SSL *s)
 				fprintf(stderr, "Using hash %s\n",
 							EVP_MD_name(md));
 #endif
+				/* Use supplied data */
+				if (s->state == SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_B)
+					{
+async_RSA_cont_b:
+					d=(unsigned char *)s->init_buf->data;
+					p=s->async_key_ex_int;
+					n=s->async_key_ex_int_n;
+					i=s->async_key_ex_len;
+					memcpy(&(p[2]), s->async_key_ex_data, i);
+					}
+				/* Switch to a wait state */
+				else if (s->async_key_ex_cb != NULL)
+					{
+					unsigned int j=SSL3_RANDOM_SIZE * 2 + n;
+					unsigned char q[SSL3_RANDOM_SIZE * 2 + 256];
+					if (j > sizeof(q)) {
+						/* Should never happen */
+						al=SSL_AD_INTERNAL_ERROR;
+						SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,ERR_R_INTERNAL_ERROR);
+						goto f_err;
+					}
+					memcpy(q, &(s->s3->client_random[0]), SSL3_RANDOM_SIZE);
+					memcpy(q + SSL3_RANDOM_SIZE,
+								 &(s->s3->server_random[0]),
+								 SSL3_RANDOM_SIZE);
+					memcpy(q + 2 * SSL3_RANDOM_SIZE,
+								 &(d[4]),
+								 n);
+
+					s->async_key_ex_data = NULL;
+					s->async_key_ex_int=p;
+					s->async_key_ex_int_n=n;
+					s->async_key_ex_cb(s, SSL_KEY_EX_RSA_SIGN, EVP_MD_name(md), q, j);
+
+					/* Synchronous execution */
+					if (s->async_key_ex_data != NULL)
+						goto async_RSA_cont_b;
+					return 3;
+				}
+			else
+				{
 				EVP_SignInit_ex(&md_ctx, md, NULL);
 				EVP_SignUpdate(&md_ctx,&(s->s3->client_random[0]),SSL3_RANDOM_SIZE);
 				EVP_SignUpdate(&md_ctx,&(s->s3->server_random[0]),SSL3_RANDOM_SIZE);
@@ -2048,6 +2144,7 @@ int ssl3_send_server_key_exchange(SSL *s)
 					SSLerr(SSL_F_SSL3_SEND_SERVER_KEY_EXCHANGE,ERR_LIB_EVP);
 					goto err;
 					}
+				}
 				s2n(i,p);
 				n+=i+2;
 				if (SSL_USE_SIGALGS(s))
@@ -2061,9 +2158,16 @@ int ssl3_send_server_key_exchange(SSL *s)
 				goto f_err;
 				}
 			}
-
 		ssl_set_handshake_header(s, SSL3_MT_SERVER_KEY_EXCHANGE, n);
 		}
+	else if (s->state == SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_A)
+		{
+		goto async_RSA_cont_a;
+		}
+	else if (s->state == SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_B)
+		{
+		goto async_RSA_cont_b;
+		}
 
 	s->state = SSL3_ST_SW_KEY_EXCH_B;
 	EVP_MD_CTX_cleanup(&md_ctx);
@@ -2207,6 +2311,10 @@ int ssl3_get_client_key_exchange(SSL *s)
 	BN_CTX *bn_ctx = NULL; 
 #endif
 
+	/* Jump point */
+	if (s->state == SSL3_ST_SR_KEY_EXCH_RSA_SUPPLY)
+		goto async_RSA_cont;
+
 	n=s->method->ssl_get_message(s,
 		SSL3_ST_SR_KEY_EXCH_A,
 		SSL3_ST_SR_KEY_EXCH_B,
@@ -2284,8 +2392,29 @@ int ssl3_get_client_key_exchange(SSL *s)
 		if (RAND_pseudo_bytes(rand_premaster_secret,
 				      sizeof(rand_premaster_secret)) <= 0)
 			goto err;
-		decrypt_len = RSA_private_decrypt((int)n,p,p,rsa,RSA_PKCS1_PADDING);
-		ERR_clear_error();
+		/* Use supplied data */
+		if (s->state == SSL3_ST_SR_KEY_EXCH_RSA_SUPPLY)
+			{
+async_RSA_cont:
+				p=s->async_key_ex_data;
+				decrypt_len=s->async_key_ex_len;
+			}
+		/* Switch to a wait state */
+		else if (s->async_key_ex_cb != NULL)
+			{
+				s->async_key_ex_data = NULL;
+				s->async_key_ex_cb(s, SSL_KEY_EX_RSA, NULL, p, n);
+
+				/* Synchronous execution */
+				if (s->async_key_ex_data != NULL)
+					goto async_RSA_cont;
+				return 3;
+			}
+		else
+			{
+			decrypt_len = RSA_private_decrypt((int)n,p,p,rsa,RSA_PKCS1_PADDING);
+			ERR_clear_error();
+			}
 
 		/* decrypt_len should be SSL_MAX_MASTER_KEY_LENGTH.
 		 * decrypt_good_mask will be zero if so and non-zero otherwise. */
diff --git a/ssl/ssl.h b/ssl/ssl.h
index 5f542d1..a24e046 100644
--- a/ssl/ssl.h
+++ b/ssl/ssl.h
@@ -1146,6 +1146,11 @@ struct ssl_ctx_st
 	size_t tlsext_ellipticcurvelist_length;
 	unsigned char *tlsext_ellipticcurvelist;
 # endif /* OPENSSL_NO_EC */
+	void (*async_key_ex_cb)(SSL* s,
+													int type,
+													const char* md,
+													unsigned char* p,
+													long n);
 	};
 
 #endif
@@ -1570,6 +1575,17 @@ struct ssl_st
 	/* Callback for disabling session caching and ticket support
 	 * on a session basis, depending on the chosen cipher. */
 	int (*not_resumable_session_cb)(SSL *ssl, int is_forward_secure);
+
+	void (*async_key_ex_cb)(SSL* s,
+													int type,
+													const char* md,
+													unsigned char* p,
+													long n);
+	unsigned char* async_key_ex_data;
+	long async_key_ex_len;
+	/* Internal data, don't touch it */
+	void* async_key_ex_int;
+	unsigned int async_key_ex_int_n;
 	};
 
 #endif
@@ -1657,6 +1673,11 @@ size_t SSL_get_peer_finished(const SSL *s, void *buf, size_t count);
 #define OpenSSL_add_ssl_algorithms()	SSL_library_init()
 #define SSLeay_add_ssl_algorithms()	SSL_library_init()
 
+/* Async key exchange params */
+#define SSL_KEY_EX_RSA 0x0
+#define SSL_KEY_EX_RSA_SIGN 0x1
+int SSL_supply(SSL* s, unsigned char* data, long len);
+
 /* this is for backward compatibility */
 #if 0 /* NEW_SSLEAY */
 #define SSL_CTX_set_default_verify(a,b,c) SSL_CTX_set_verify(a,b,c)
diff --git a/ssl/ssl3.h b/ssl/ssl3.h
index d3167cf..e4f7d83 100644
--- a/ssl/ssl3.h
+++ b/ssl/ssl3.h
@@ -689,6 +689,10 @@ typedef struct ssl3_state_st
 #define SSL3_ST_SW_CERT_B		(0x141|SSL_ST_ACCEPT)
 #define SSL3_ST_SW_KEY_EXCH_A		(0x150|SSL_ST_ACCEPT)
 #define SSL3_ST_SW_KEY_EXCH_B		(0x151|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_RSA_WAIT_A		(0x152|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_A		(0x153|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_RSA_WAIT_B		(0x154|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_B		(0x155|SSL_ST_ACCEPT)
 #define SSL3_ST_SW_CERT_REQ_A		(0x160|SSL_ST_ACCEPT)
 #define SSL3_ST_SW_CERT_REQ_B		(0x161|SSL_ST_ACCEPT)
 #define SSL3_ST_SW_SRVR_DONE_A		(0x170|SSL_ST_ACCEPT)
@@ -698,6 +702,8 @@ typedef struct ssl3_state_st
 #define SSL3_ST_SR_CERT_B		(0x181|SSL_ST_ACCEPT)
 #define SSL3_ST_SR_KEY_EXCH_A		(0x190|SSL_ST_ACCEPT)
 #define SSL3_ST_SR_KEY_EXCH_B		(0x191|SSL_ST_ACCEPT)
+#define SSL3_ST_SR_KEY_EXCH_RSA_WAIT		(0x192|SSL_ST_ACCEPT)
+#define SSL3_ST_SR_KEY_EXCH_RSA_SUPPLY		(0x193|SSL_ST_ACCEPT)
 #define SSL3_ST_SR_CERT_VRFY_A		(0x1A0|SSL_ST_ACCEPT)
 #define SSL3_ST_SR_CERT_VRFY_B		(0x1A1|SSL_ST_ACCEPT)
 #define SSL3_ST_SR_CHANGE_A		(0x1B0|SSL_ST_ACCEPT)
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 98f4018..5f791ee 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -413,6 +413,7 @@ SSL *SSL_new(SSL_CTX *ctx)
 	s->psk_client_callback=ctx->psk_client_callback;
 	s->psk_server_callback=ctx->psk_server_callback;
 #endif
+	s->async_key_ex_cb=ctx->async_key_ex_cb;
 
 	return(s);
 err:
@@ -809,6 +810,40 @@ size_t SSL_get_peer_finished(const SSL *s, void *buf, size_t count)
 	}
 
 
+int SSL_supply(SSL* s, unsigned char* data, long len)
+	{
+	int async;
+	if (s->s3 == NULL)
+		return 0;
+
+	async = 1;
+	switch (s->state) {
+		case SSL3_ST_SR_KEY_EXCH_RSA_WAIT:
+			s->state=SSL3_ST_SR_KEY_EXCH_RSA_SUPPLY;
+			break;
+		case SSL3_ST_SW_KEY_EXCH_RSA_WAIT_A:
+			s->state=SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_A;
+			break;
+		case SSL3_ST_SW_KEY_EXCH_RSA_WAIT_B:
+			s->state=SSL3_ST_SW_KEY_EXCH_RSA_SUPPLY_B;
+			break;
+		default:
+			async = 0;
+			break;
+	}
+	if (async) {
+		s->rwstate=SSL_NOTHING;
+		BIO_set_flags(SSL_get_rbio(s), 0);
+	}
+	s->async_key_ex_data=data;
+	s->async_key_ex_len=len;
+	if (async)
+		return SSL_accept(s);
+	else
+		return 1;
+	}
+
+
 int SSL_get_verify_mode(const SSL *s)
 	{
 	return(s->verify_mode);
diff --git a/ssl/ssl_rsa.c b/ssl/ssl_rsa.c
index c76a2a3..a300667 100644
--- a/ssl/ssl_rsa.c
+++ b/ssl/ssl_rsa.c
@@ -64,8 +64,8 @@
 #include <openssl/x509.h>
 #include <openssl/pem.h>
 
-static int ssl_set_cert(CERT *c, X509 *x509);
-static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey);
+static int ssl_set_cert(CERT *c, X509 *x509, int force);
+static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey, int force);
 int SSL_use_certificate(SSL *ssl, X509 *x)
 	{
 	int rv;
@@ -86,7 +86,7 @@ int SSL_use_certificate(SSL *ssl, X509 *x)
 		SSLerr(SSL_F_SSL_USE_CERTIFICATE,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	return(ssl_set_cert(ssl->cert,x));
+	return(ssl_set_cert(ssl->cert,x,ssl->async_key_ex_cb != NULL));
 	}
 
 #ifndef OPENSSL_NO_STDIO
@@ -181,13 +181,13 @@ int SSL_use_RSAPrivateKey(SSL *ssl, RSA *rsa)
 	RSA_up_ref(rsa);
 	EVP_PKEY_assign_RSA(pkey,rsa);
 
-	ret=ssl_set_pkey(ssl->cert,pkey);
+	ret=ssl_set_pkey(ssl->cert,pkey,ssl->async_key_ex_cb != NULL);
 	EVP_PKEY_free(pkey);
 	return(ret);
 	}
 #endif
 
-static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey)
+static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey, int force)
 	{
 	int i;
 	/* Special case for DH: check two DH certificate types for a match.
@@ -229,7 +229,7 @@ static int ssl_set_pkey(CERT *c, EVP_PKEY *pkey)
 			;
 		else
 #endif
-		if (!X509_check_private_key(c->pkeys[i].x509,pkey))
+		if (!force && !X509_check_private_key(c->pkeys[i].x509,pkey))
 			{
 			X509_free(c->pkeys[i].x509);
 			c->pkeys[i].x509 = NULL;
@@ -329,7 +329,7 @@ int SSL_use_PrivateKey(SSL *ssl, EVP_PKEY *pkey)
 		SSLerr(SSL_F_SSL_USE_PRIVATEKEY,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	ret=ssl_set_pkey(ssl->cert,pkey);
+	ret=ssl_set_pkey(ssl->cert,pkey,ssl->async_key_ex_cb != NULL);
 	return(ret);
 	}
 
@@ -418,10 +418,10 @@ int SSL_CTX_use_certificate(SSL_CTX *ctx, X509 *x)
 		SSLerr(SSL_F_SSL_CTX_USE_CERTIFICATE,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	return(ssl_set_cert(ctx->cert, x));
+	return(ssl_set_cert(ctx->cert, x,ctx->async_key_ex_cb != NULL));
 	}
 
-static int ssl_set_cert(CERT *c, X509 *x)
+static int ssl_set_cert(CERT *c, X509 *x, int force)
 	{
 	EVP_PKEY *pkey;
 	int i;
@@ -455,7 +455,7 @@ static int ssl_set_cert(CERT *c, X509 *x)
 			 ;
 		else
 #endif /* OPENSSL_NO_RSA */
-		if (!X509_check_private_key(x,c->pkeys[i].privatekey))
+		if (!force && !X509_check_private_key(x,c->pkeys[i].privatekey))
 			{
 			/* don't fail for a cert/key mismatch, just free
 			 * current private key (when switching to a different
@@ -572,7 +572,7 @@ int SSL_CTX_use_RSAPrivateKey(SSL_CTX *ctx, RSA *rsa)
 	RSA_up_ref(rsa);
 	EVP_PKEY_assign_RSA(pkey,rsa);
 
-	ret=ssl_set_pkey(ctx->cert, pkey);
+	ret=ssl_set_pkey(ctx->cert, pkey,ctx->async_key_ex_cb != NULL);
 	EVP_PKEY_free(pkey);
 	return(ret);
 	}
@@ -656,7 +656,7 @@ int SSL_CTX_use_PrivateKey(SSL_CTX *ctx, EVP_PKEY *pkey)
 		SSLerr(SSL_F_SSL_CTX_USE_PRIVATEKEY,ERR_R_MALLOC_FAILURE);
 		return(0);
 		}
-	return(ssl_set_pkey(ctx->cert,pkey));
+	return(ssl_set_pkey(ctx->cert,pkey,ctx->async_key_ex_cb != NULL));
 	}
 
 #ifndef OPENSSL_NO_STDIO
-- 
2.0.2

