From 271d1ba1b950bbe134cfc2cc6b8c95eb14e87ae2 Mon Sep 17 00:00:00 2001
From: Fedor Indutny <fedor@indutny.com>
Date: Sat, 13 Sep 2014 19:57:57 +0100
Subject: [PATCH 2/2] ssl: support non-RSA key signatures in key ex

---
 ssl/s3_srvr.c | 12 +++++++-----
 ssl/ssl.h     | 10 ++++++----
 ssl/ssl3.h    |  4 ++--
 ssl/ssl_lib.c |  8 ++++----
 ssl/ssltest.c | 40 +++++++++++++++++++++-------------------
 5 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index 2faf319..c8129d3 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -500,12 +500,12 @@ int ssl3_accept(SSL *s)
 					goto end;
 				else if (ret == 2)
 					{
-					s->state=SSL3_ST_SW_KEY_EXCH_RSA_SIGN_SUPPLY;
+					s->state=SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY;
 					break;
 					}
 				else if (ret == 3)
 					{
-					s->state=SSL3_ST_SW_KEY_EXCH_RSA_SIGN_WAIT;
+					s->state=SSL3_ST_SW_KEY_EXCH_SIGN_WAIT;
 					break;
 					}
 				}
@@ -518,12 +518,12 @@ int ssl3_accept(SSL *s)
 			s->init_num=0;
 			break;
 
-		case SSL3_ST_SW_KEY_EXCH_RSA_SIGN_WAIT:
-			s->rwstate=SSL_RSA_SIGN;
+		case SSL3_ST_SW_KEY_EXCH_SIGN_WAIT:
+			s->rwstate=SSL_SIGN;
 			ret = -1;
 			goto end;
 
-		case SSL3_ST_SW_KEY_EXCH_RSA_SIGN_SUPPLY:
+		case SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY:
 			ret=ssl3_cont_server_key_exchange(s);
 			if (ret != 1)
 				goto end;
@@ -2079,6 +2079,7 @@ int ssl3_send_server_key_exchange(SSL *s)
 					/* Copy md_buf contents to init_buf */
 					s->key_ex.data=&p[2];
 					s->key_ex.md=NID_md5_sha1;
+					s->key_ex.type=pkey->type;
 					s->key_ex.len=j;
 					if (!BUF_MEM_grow(s->init_buf,
 														(s->key_ex.data - d) + s->key_ex.len))
@@ -2136,6 +2137,7 @@ int ssl3_send_server_key_exchange(SSL *s)
 					s->key_ex.data=&p[2];
 					s->key_ex.len=2 * SSL3_RANDOM_SIZE + n;
 					s->key_ex.md=EVP_MD_nid(md);
+					s->key_ex.type=pkey->type;
 					if (!BUF_MEM_grow(s->init_buf,
 														(s->key_ex.data - d) + s->key_ex.len))
 						{
diff --git a/ssl/ssl.h b/ssl/ssl.h
index 239dde2..0dde6f9 100644
--- a/ssl/ssl.h
+++ b/ssl/ssl.h
@@ -685,7 +685,7 @@ struct ssl_session_st
  */
 #define SSL_MODE_SEND_CLIENTHELLO_TIME 0x00000020L
 #define SSL_MODE_SEND_SERVERHELLO_TIME 0x00000040L
-/* If set - SSL_ERROR_WANT_RSA_DECRYPT/SSL_ERROR_WANT_RSA_SIGN may be returned
+/* If set - SSL_ERROR_WANT_RSA_DECRYPT/SSL_ERROR_WANT_SIGN may be returned
  * by SSL_read()/SSL_write(), when the key exchange requires the use of the
  * RSA private key.
  */
@@ -1281,7 +1281,7 @@ int SSL_extension_supported(unsigned int ext_type);
 #define SSL_READING	3
 #define SSL_X509_LOOKUP	4
 #define SSL_RSA_DECRYPT	5
-#define SSL_RSA_SIGN	6
+#define SSL_SIGN	6
 
 /* These will only be used when doing non-blocking IO */
 #define SSL_want_nothing(s)	(SSL_want(s) == SSL_NOTHING)
@@ -1289,7 +1289,7 @@ int SSL_extension_supported(unsigned int ext_type);
 #define SSL_want_write(s)	(SSL_want(s) == SSL_WRITING)
 #define SSL_want_x509_lookup(s)	(SSL_want(s) == SSL_X509_LOOKUP)
 #define SSL_want_rsa_decrypt(s)	(SSL_want(s) == SSL_RSA_DECRYPT)
-#define SSL_want_rsa_sign(s)	(SSL_want(s) == SSL_RSA_SIGN)
+#define SSL_want_sign(s)	(SSL_want(s) == SSL_SIGN)
 
 #define SSL_MAC_FLAG_READ_MAC_STREAM 1
 #define SSL_MAC_FLAG_WRITE_MAC_STREAM 2
@@ -1560,6 +1560,7 @@ struct ssl_st
 		unsigned char* data;
 		long len;
 		int md;
+		int type;
 
 		/* Internal */
 		long recoff;
@@ -1656,6 +1657,7 @@ size_t SSL_get_peer_finished(const SSL *s, void *buf, size_t count);
 #define SSL_get_key_ex_data(s) ((s)->key_ex.data)
 #define SSL_get_key_ex_len(s) ((s)->key_ex.len)
 #define SSL_get_key_ex_md(s) ((s)->key_ex.md)
+#define SSL_get_key_ex_type(s) ((s)->key_ex.type)
 
 int SSL_supply_key_ex_data(SSL* s, unsigned char* data, long len);
 
@@ -1730,7 +1732,7 @@ DECLARE_PEM_rw(SSL_SESSION, SSL_SESSION)
 #define SSL_ERROR_WANT_CONNECT		7
 #define SSL_ERROR_WANT_ACCEPT		8
 #define SSL_ERROR_WANT_RSA_DECRYPT		9
-#define SSL_ERROR_WANT_RSA_SIGN		10
+#define SSL_ERROR_WANT_SIGN		10
 
 #define SSL_CTRL_NEED_TMP_RSA			1
 #define SSL_CTRL_SET_TMP_RSA			2
diff --git a/ssl/ssl3.h b/ssl/ssl3.h
index f3dd3c3..f4484b3 100644
--- a/ssl/ssl3.h
+++ b/ssl/ssl3.h
@@ -628,8 +628,8 @@ typedef struct ssl3_state_st
 #define SSL3_ST_CR_KEY_EXCH_A		(0x140|SSL_ST_CONNECT)
 #define SSL3_ST_CR_KEY_EXCH_B		(0x141|SSL_ST_CONNECT)
 #define SSL3_ST_SW_KEY_EXCH_C		(0x152|SSL_ST_ACCEPT)
-#define SSL3_ST_SW_KEY_EXCH_RSA_SIGN_WAIT		(0x153|SSL_ST_ACCEPT)
-#define SSL3_ST_SW_KEY_EXCH_RSA_SIGN_SUPPLY		(0x154|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_SIGN_WAIT		(0x153|SSL_ST_ACCEPT)
+#define SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY		(0x154|SSL_ST_ACCEPT)
 #define SSL3_ST_CR_CERT_REQ_A		(0x150|SSL_ST_CONNECT)
 #define SSL3_ST_CR_CERT_REQ_B		(0x151|SSL_ST_CONNECT)
 #define SSL3_ST_CR_SRVR_DONE_A		(0x160|SSL_ST_CONNECT)
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index abec25e..c36ee4b 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -818,8 +818,8 @@ int SSL_supply_key_ex_data(SSL* s, unsigned char* data, long len)
 		case SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_WAIT:
 			s->state=SSL3_ST_SR_KEY_EXCH_RSA_DECRYPT_SUPPLY;
 			break;
-		case SSL3_ST_SW_KEY_EXCH_RSA_SIGN_WAIT:
-			s->state=SSL3_ST_SW_KEY_EXCH_RSA_SIGN_SUPPLY;
+		case SSL3_ST_SW_KEY_EXCH_SIGN_WAIT:
+			s->state=SSL3_ST_SW_KEY_EXCH_SIGN_SUPPLY;
 			break;
 		default:
 			return 0;
@@ -2783,9 +2783,9 @@ int SSL_get_error(const SSL *s,int i)
 		{
 		return(SSL_ERROR_WANT_RSA_DECRYPT);
 		}
-	if ((i < 0) && SSL_want_rsa_sign(s))
+	if ((i < 0) && SSL_want_sign(s))
 		{
-		return(SSL_ERROR_WANT_RSA_SIGN);
+		return(SSL_ERROR_WANT_SIGN);
 		}
 	if (i == 0)
 		{
diff --git a/ssl/ssltest.c b/ssl/ssltest.c
index d70fc7a..2bbbfa4 100644
--- a/ssl/ssltest.c
+++ b/ssl/ssltest.c
@@ -3269,28 +3269,30 @@ static int handle_async_key_ex(SSL *s)
 		return 0;
 		}
 
-	if (SSL_want_rsa_sign(s) && SSL_get_key_ex_md(s) == NID_md5_sha1)
+	if (SSL_want_sign(s) && SSL_get_key_ex_md(s) == NID_md5_sha1)
 		{
-			unsigned int len;
-			if (pkey->type != EVP_PKEY_RSA)
-				{
-				fprintf(stderr, "async key ex: non-rsa private key for sign\n");
-				return 0;
-				}
-			if (RSA_sign(NID_md5_sha1,
-									 SSL_get_key_ex_data(s),
-									 SSL_get_key_ex_len(s),
-									 buf,
-									 &len,
-									 pkey->pkey.rsa) <= 0)
-				{
-				fprintf(stderr, "async key ex: rsa sign failure\n");
-				return 0;
-				}
-			return SSL_supply_key_ex_data(s, buf, len);
+		assert(SSL_get_key_ex_type(s) == EVP_PKEY_RSA);
+		unsigned int len;
+		if (pkey->type != EVP_PKEY_RSA)
+			{
+			fprintf(stderr, "async key ex: non-rsa private key for sign\n");
+			return 0;
+			}
+		if (RSA_sign(NID_md5_sha1,
+								 SSL_get_key_ex_data(s),
+								 SSL_get_key_ex_len(s),
+								 buf,
+								 &len,
+								 pkey->pkey.rsa) <= 0)
+			{
+			fprintf(stderr, "async key ex: rsa sign failure\n");
+			return 0;
+			}
+		return SSL_supply_key_ex_data(s, buf, len);
 		}
-	if (SSL_want_rsa_sign(s) && SSL_get_key_ex_md(s) != NID_md5_sha1)
+	if (SSL_want_sign(s) && SSL_get_key_ex_md(s) != NID_md5_sha1)
 		{
+		assert(SSL_get_key_ex_type(s) == pkey->type);
 		EVP_MD_CTX md_ctx;
 		const EVP_MD *md = NULL;
 		unsigned int len;
-- 
2.1.0

