Signed-off-by: Yanis Kurganov <YKurganov@ptsecurity.com>

---
 include/libssh/crypto.h  |   6 +-
 include/libssh/dh.h      |   4 +-
 include/libssh/session.h |   1 +
 src/client.c             |  19 +++-
 src/dh.c                 | 281 +++++++++++++++++++++++++++++++----------------
 src/kex.c                |   6 +-
 src/packet.c             |   2 +-
 src/packet_cb.c          |  17 ++-
 src/wrapper.c            |   2 +
 9 files changed, 233 insertions(+), 105 deletions(-)

diff --git a/include/libssh/crypto.h b/include/libssh/crypto.h
index eaff2ff..2aca915 100644
--- a/include/libssh/crypto.h
+++ b/include/libssh/crypto.h
@@ -51,6 +51,10 @@ enum ssh_key_exchange_e {
   SSH_KEX_DH_GROUP1_SHA1=1,
   /* diffie-hellman-group14-sha1 */
   SSH_KEX_DH_GROUP14_SHA1,
+  /* diffie-hellman-group-exchange-sha1 */
+  SSH_KEX_DH_GROUP_SHA1,
+  /* diffie-hellman-group-exchange-sha256 */
+  SSH_KEX_DH_GROUP_SHA256,
   /* ecdh-sha2-nistp256 */
   SSH_KEX_ECDH_SHA2_NISTP256,
   /* curve25519-sha256@libssh.org */
@@ -58,7 +62,7 @@ enum ssh_key_exchange_e {
 };
 
 struct ssh_crypto_struct {
-    bignum e,f,x,k,y;
+    bignum p,g,e,f,x,k,y;
 #ifdef HAVE_ECDH
     EC_KEY *ecdh_privkey;
     ssh_string ecdh_client_pubkey;
diff --git a/include/libssh/dh.h b/include/libssh/dh.h
index e1039e2..a579a3d 100644
--- a/include/libssh/dh.h
+++ b/include/libssh/dh.h
@@ -40,7 +40,9 @@ int dh_import_f(ssh_session session,ssh_string f_string);
 int dh_import_e(ssh_session session, ssh_string e_string);
 void dh_import_pubkey(ssh_session session,ssh_string pubkey_string);
 int dh_build_k(ssh_session session);
-int ssh_client_dh_init(ssh_session session);
+int ssh_client_dh_group_init(ssh_session session);
+int ssh_client_dh_gex_init(ssh_session session);
+int ssh_client_dh_gex_reply(ssh_session session, ssh_buffer packet);
 int ssh_client_dh_reply(ssh_session session, ssh_buffer packet);
 
 int make_sessionid(ssh_session session);
diff --git a/include/libssh/session.h b/include/libssh/session.h
index c360a70..977d2e9 100644
--- a/include/libssh/session.h
+++ b/include/libssh/session.h
@@ -45,6 +45,7 @@ enum ssh_session_state_e {
 
 enum ssh_dh_state_e {
   DH_STATE_INIT=0,
+  DH_STATE_GEX_REQUEST_SENT,
   DH_STATE_INIT_SENT,
   DH_STATE_NEWKEYS_SENT,
   DH_STATE_FINISHED
diff --git a/src/client.c b/src/client.c
index cb41f1c..229c962 100644
--- a/src/client.c
+++ b/src/client.c
@@ -194,7 +194,7 @@ end:
  * completed
  */
 static int dh_handshake(ssh_session session) {
-
+  enum ssh_dh_state_e dh_handshake_state = DH_STATE_INIT_SENT;
   int rc = SSH_AGAIN;
 
   switch (session->dh_handshake_state) {
@@ -202,7 +202,12 @@ static int dh_handshake(ssh_session session) {
       switch(session->next_crypto->kex_type){
         case SSH_KEX_DH_GROUP1_SHA1:
         case SSH_KEX_DH_GROUP14_SHA1:
-          rc = ssh_client_dh_init(session);
+          rc = ssh_client_dh_group_init(session);
+          break;
+        case SSH_KEX_DH_GROUP_SHA1:
+        case SSH_KEX_DH_GROUP_SHA256:
+          rc = ssh_client_dh_gex_init(session);
+          dh_handshake_state = DH_STATE_GEX_REQUEST_SENT;
           break;
 #ifdef HAVE_ECDH
         case SSH_KEX_ECDH_SHA2_NISTP256:
@@ -222,14 +227,22 @@ static int dh_handshake(ssh_session session) {
           return SSH_ERROR;
       }
 
-      session->dh_handshake_state = DH_STATE_INIT_SENT;
+      session->dh_handshake_state = dh_handshake_state;
+      break;
+    case DH_STATE_GEX_REQUEST_SENT:
+        SSH_LOG(SSH_LOG_TRACE,"case DH_STATE_GEX_REQUEST_SENT");
+        /* wait until ssh_packet_dh_reply is called */
+    	break;
     case DH_STATE_INIT_SENT:
+        SSH_LOG(SSH_LOG_TRACE,"case DH_STATE_INIT_SENT");
     	/* wait until ssh_packet_dh_reply is called */
     	break;
     case DH_STATE_NEWKEYS_SENT:
+        SSH_LOG(SSH_LOG_TRACE,"case DH_STATE_NEWKEYS_SENT");
     	/* wait until ssh_packet_newkeys is called */
     	break;
     case DH_STATE_FINISHED:
+      SSH_LOG(SSH_LOG_TRACE,"case DH_STATE_FINISHED");
       return SSH_OK;
     default:
       ssh_set_error(session, SSH_FATAL, "Invalid state in dh_handshake(): %d",
diff --git a/src/dh.c b/src/dh.c
index 3c2e5ad..3e07961 100644
--- a/src/dh.c
+++ b/src/dh.c
@@ -81,8 +81,8 @@ static unsigned char p_group1_value[] = {
         0xEE, 0x38, 0x6B, 0xFB, 0x5A, 0x89, 0x9F, 0xA5, 0xAE, 0x9F, 0x24, 0x11,
         0x7C, 0x4B, 0x1F, 0xE6, 0x49, 0x28, 0x66, 0x51, 0xEC, 0xE6, 0x53, 0x81,
         0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF};
-#define P_GROUP1_LEN 128	/* Size in bytes of the p number */
 
+#define P_GROUP1_LEN 128	/* Size in bytes of the p number */
 
 static unsigned char p_group14_value[] = {
         0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xC9, 0x0F, 0xDA, 0xA2,
@@ -110,15 +110,9 @@ static unsigned char p_group14_value[] = {
 
 #define P_GROUP14_LEN 256 /* Size in bytes of the p number for group 14 */
 
-static unsigned long g_int = 2 ;	/* G is defined as 2 by the ssh2 standards */
-static bignum g;
-static bignum p_group1;
-static bignum p_group14;
-static int ssh_crypto_initialized;
+#define PREFERRED_GROUP_BITS 1048576 /* preferred size in bits of the group the server will send */
 
-static bignum select_p(enum ssh_key_exchange_e type) {
-    return type == SSH_KEX_DH_GROUP14_SHA1 ? p_group14 : p_group1;
-}
+static int ssh_crypto_initialized;
 
 int ssh_get_random(void *where, int len, int strong){
 
@@ -141,11 +135,6 @@ int ssh_get_random(void *where, int len, int strong){
   return 1;
 }
 
-
-/*
- * This inits the values g and p which are used for DH key agreement
- * FIXME: Make the function thread safe by adding a semaphore or mutex.
- */
 int ssh_crypto_init(void) {
   if (ssh_crypto_initialized == 0) {
 #ifdef HAVE_LIBGCRYPT
@@ -155,52 +144,9 @@ int ssh_crypto_init(void) {
       gcry_control(GCRYCTL_INITIALIZATION_FINISHED,0);
     }
 #endif
-
-    g = bignum_new();
-    if (g == NULL) {
-      return -1;
-    }
-    bignum_set_word(g,g_int);
-
-#ifdef HAVE_LIBGCRYPT
-    bignum_bin2bn(p_group1_value, P_GROUP1_LEN, &p_group1);
-    if (p_group1 == NULL) {
-      bignum_free(g);
-      g = NULL;
-      return -1;
-    }
-    bignum_bin2bn(p_group14_value, P_GROUP14_LEN, &p_group14);
-    if (p_group14 == NULL) {
-      bignum_free(g);
-      bignum_free(p_group1);
-      g = NULL;
-      p_group1 = NULL;
-      return -1;
-    }
-
-#elif defined HAVE_LIBCRYPTO
-    p_group1 = bignum_new();
-    if (p_group1 == NULL) {
-      bignum_free(g);
-      g = NULL;
-      return -1;
-    }
-    bignum_bin2bn(p_group1_value, P_GROUP1_LEN, p_group1);
-
-    p_group14 = bignum_new();
-    if (p_group14 == NULL) {
-      bignum_free(g);
-      bignum_free(p_group1);
-      g = NULL;
-      p_group1 = NULL;
-      return -1;
-    }
-    bignum_bin2bn(p_group14_value, P_GROUP14_LEN, p_group14);
-
+#ifdef HAVE_LIBCRYPTO
     OpenSSL_add_all_algorithms();
-
 #endif
-
     ssh_crypto_initialized = 1;
   }
 
@@ -209,12 +155,6 @@ int ssh_crypto_init(void) {
 
 void ssh_crypto_finalize(void) {
   if (ssh_crypto_initialized) {
-    bignum_free(g);
-    g = NULL;
-    bignum_free(p_group1);
-    p_group1 = NULL;
-    bignum_free(p_group14);
-    p_group14 = NULL;
 #ifdef HAVE_LIBGCRYPT
     gcry_control(GCRYCTL_TERM_SECMEM);
 #elif defined HAVE_LIBCRYPTO
@@ -298,11 +238,11 @@ int dh_generate_e(ssh_session session) {
   }
 
 #ifdef HAVE_LIBGCRYPT
-  bignum_mod_exp(session->next_crypto->e, g, session->next_crypto->x,
-      select_p(session->next_crypto->kex_type));
+  bignum_mod_exp(session->next_crypto->e, session->next_crypto->g,
+      session->next_crypto->x, session->next_crypto->p);
 #elif defined HAVE_LIBCRYPTO
-  bignum_mod_exp(session->next_crypto->e, g, session->next_crypto->x,
-      select_p(session->next_crypto->kex_type), ctx);
+  bignum_mod_exp(session->next_crypto->e, session->next_crypto->g,
+      session->next_crypto->x, session->next_crypto->p, ctx);
 #endif
 
 #ifdef DEBUG_CRYPTO
@@ -333,11 +273,11 @@ int dh_generate_f(ssh_session session) {
   }
 
 #ifdef HAVE_LIBGCRYPT
-  bignum_mod_exp(session->next_crypto->f, g, session->next_crypto->y,
-      select_p(session->next_crypto->kex_type));
+  bignum_mod_exp(session->next_crypto->f, session->next_crypto->g,
+      session->next_crypto->y, session->next_crypto->p);
 #elif defined HAVE_LIBCRYPTO
-  bignum_mod_exp(session->next_crypto->f, g, session->next_crypto->y,
-      select_p(session->next_crypto->kex_type), ctx);
+  bignum_mod_exp(session->next_crypto->f, session->next_crypto->g,
+      session->next_crypto->y, session->next_crypto->p, ctx);
 #endif
 
 #ifdef DEBUG_CRYPTO
@@ -447,6 +387,65 @@ int dh_import_e(ssh_session session, ssh_string e_string) {
   return 0;
 }
 
+/* p number */
+static int dh_import_p_string(ssh_session session, ssh_string p_string) {
+  session->next_crypto->p = make_string_bn(p_string);
+  if (session->next_crypto->p == NULL) {
+    return -1;
+  }
+
+#ifdef DEBUG_CRYPTO
+    ssh_print_bignum("p",session->next_crypto->p);
+#endif
+
+  return 0;
+}
+
+static int dh_import_p_value(ssh_session session, const unsigned char* p_value, size_t p_size) {
+  ssh_string p_string = ssh_string_new(p_size);
+  int rc;
+
+  if (p_string == NULL) {
+    return SSH_ERROR;
+  }
+
+  ssh_string_fill(p_string, p_value, p_size);
+  rc = dh_import_p_string(session, p_string);
+
+  ssh_string_burn(p_string);
+  ssh_string_free(p_string);
+
+  return rc;
+}
+
+/* g number */
+static int dh_import_g_string(ssh_session session, ssh_string g_string) {
+  session->next_crypto->g = make_string_bn(g_string);
+  if (session->next_crypto->g == NULL) {
+    return -1;
+  }
+
+#ifdef DEBUG_CRYPTO
+    ssh_print_bignum("g",session->next_crypto->g);
+#endif
+
+  return 0;
+}
+
+static int dh_import_g_value(ssh_session session, unsigned int g_value) {
+  session->next_crypto->g = bignum_new();
+  if (session->next_crypto->g == NULL) {
+    return -1;
+  }
+  bignum_set_word(session->next_crypto->g, g_value);
+
+#ifdef DEBUG_CRYPTO
+  ssh_print_bignum("g",session->next_crypto->g);
+#endif
+
+  return 0;
+}
+
 int dh_build_k(ssh_session session) {
 #ifdef HAVE_LIBCRYPTO
   bignum_CTX ctx = bignum_ctx_new();
@@ -467,18 +466,18 @@ int dh_build_k(ssh_session session) {
 #ifdef HAVE_LIBGCRYPT
   if(session->client) {
     bignum_mod_exp(session->next_crypto->k, session->next_crypto->f,
-        session->next_crypto->x, select_p(session->next_crypto->kex_type));
+        session->next_crypto->x, session->next_crypto->p);
   } else {
     bignum_mod_exp(session->next_crypto->k, session->next_crypto->e,
-        session->next_crypto->y, select_p(session->next_crypto->kex_type));
+        session->next_crypto->y, session->next_crypto->p);
   }
 #elif defined HAVE_LIBCRYPTO
   if (session->client) {
     bignum_mod_exp(session->next_crypto->k, session->next_crypto->f,
-        session->next_crypto->x, select_p(session->next_crypto->kex_type), ctx);
+        session->next_crypto->x, session->next_crypto->p, ctx);
   } else {
     bignum_mod_exp(session->next_crypto->k, session->next_crypto->e,
-        session->next_crypto->y, select_p(session->next_crypto->kex_type), ctx);
+        session->next_crypto->y, session->next_crypto->p, ctx);
   }
 #endif
 
@@ -498,15 +497,10 @@ int dh_build_k(ssh_session session) {
 }
 
 /** @internal
- * @brief Starts diffie-hellman-group1 key exchange
+ * @brief Starts diffie-hellman key exchange
  */
-int ssh_client_dh_init(ssh_session session){
+static int ssh_client_dh_init(ssh_session session){
   ssh_string e = NULL;
-  int rc;
-
-  if (buffer_add_u8(session->out_buffer, SSH2_MSG_KEXDH_INIT) < 0) {
-    goto error;
-  }
 
   if (dh_generate_x(session) < 0) {
     goto error;
@@ -527,8 +521,7 @@ int ssh_client_dh_init(ssh_session session){
   ssh_string_free(e);
   e=NULL;
 
-  rc = packet_send(session);
-  return rc;
+  return packet_send(session);
   error:
   if(e != NULL){
     ssh_string_burn(e);
@@ -538,6 +531,74 @@ int ssh_client_dh_init(ssh_session session){
   return SSH_ERROR;
 }
 
+int ssh_client_dh_group_init(ssh_session session){
+  const unsigned int g_value = 2; /* G is defined as 2 by the ssh2 standards */
+
+  if(dh_import_p_value(session,
+     session->next_crypto->kex_type == SSH_KEX_DH_GROUP1_SHA1 ? p_group1_value : p_group14_value,
+     session->next_crypto->kex_type == SSH_KEX_DH_GROUP1_SHA1 ? P_GROUP1_LEN : P_GROUP14_LEN) < 0) {
+    ssh_set_error(session, SSH_FATAL, "Cannot import p number");
+    return SSH_ERROR;
+  }
+
+  if(dh_import_g_value(session, g_value) < 0) {
+    ssh_set_error(session, SSH_FATAL, "Cannot import g number");
+    return SSH_ERROR;
+  }
+
+  if(buffer_add_u8(session->out_buffer, SSH2_MSG_KEXDH_INIT) < 0) {
+    return SSH_ERROR;
+  }
+
+  return ssh_client_dh_init(session);
+}
+
+int ssh_client_dh_gex_init(ssh_session session){
+  if (buffer_add_u8(session->out_buffer, SSH2_MSG_KEX_DH_GEX_REQUEST_OLD) < 0) {
+    return SSH_ERROR;
+  }
+  if (buffer_add_u32(session->out_buffer, PREFERRED_GROUP_BITS) < 0) {
+    return SSH_ERROR;
+  }
+  return packet_send(session);
+}
+
+int ssh_client_dh_gex_reply(ssh_session session, ssh_buffer packet){
+  ssh_string s;
+  int rc;
+
+  s = buffer_get_ssh_string(packet);
+  if (s == NULL) {
+    ssh_set_error(session,SSH_FATAL, "No p number in packet");
+    return SSH_ERROR;
+  }
+  rc = dh_import_p_string(session, s);
+  ssh_string_burn(s);
+  ssh_string_free(s);
+  if (rc < 0) {
+    ssh_set_error(session, SSH_FATAL, "Cannot import p number");
+    return SSH_ERROR;
+  }
+
+  s = buffer_get_ssh_string(packet);
+  if (s == NULL) {
+    ssh_set_error(session,SSH_FATAL, "No g number in packet");
+    return SSH_ERROR;
+  }
+  rc = dh_import_g_string(session, s);
+  ssh_string_burn(s);
+  ssh_string_free(s);
+  if (rc < 0) {
+    ssh_set_error(session, SSH_FATAL, "Cannot import g number");
+    return SSH_ERROR;
+  }
+
+  if (buffer_add_u8(session->out_buffer, SSH2_MSG_KEX_DH_GEX_INIT) < 0) {
+    return SSH_ERROR;
+  }
+  return ssh_client_dh_init(session);
+}
+
 int ssh_client_dh_reply(ssh_session session, ssh_buffer packet){
   ssh_string f;
   ssh_string pubkey = NULL;
@@ -546,48 +607,45 @@ int ssh_client_dh_reply(ssh_session session, ssh_buffer packet){
   pubkey = buffer_get_ssh_string(packet);
   if (pubkey == NULL){
     ssh_set_error(session,SSH_FATAL, "No public key in packet");
-    goto error;
+    return SSH_ERROR;
   }
   dh_import_pubkey(session, pubkey);
 
   f = buffer_get_ssh_string(packet);
   if (f == NULL) {
     ssh_set_error(session,SSH_FATAL, "No F number in packet");
-    goto error;
+    return SSH_ERROR;
   }
   rc = dh_import_f(session, f);
   ssh_string_burn(f);
   ssh_string_free(f);
   if (rc < 0) {
     ssh_set_error(session, SSH_FATAL, "Cannot import f number");
-    goto error;
+    return SSH_ERROR;
   }
 
   signature = buffer_get_ssh_string(packet);
   if (signature == NULL) {
     ssh_set_error(session, SSH_FATAL, "No signature in packet");
-    goto error;
+    return SSH_ERROR;
   }
   session->next_crypto->dh_server_signature = signature;
   signature=NULL; /* ownership changed */
   if (dh_build_k(session) < 0) {
     ssh_set_error(session, SSH_FATAL, "Cannot build k number");
-    goto error;
+    return SSH_ERROR;
   }
 
   /* Send the MSG_NEWKEYS */
   if (buffer_add_u8(session->out_buffer, SSH2_MSG_NEWKEYS) < 0) {
-    goto error;
+    return SSH_ERROR;
   }
 
   rc=packet_send(session);
   SSH_LOG(SSH_LOG_PROTOCOL, "SSH_MSG_NEWKEYS sent");
   return rc;
-error:
-  return SSH_ERROR;
 }
 
-
 /*
 static void sha_add(ssh_string str,SHACTX ctx){
     sha1_update(ctx,str,string_len(str)+4);
@@ -674,7 +732,38 @@ int make_sessionid(ssh_session session) {
     goto error;
   }
   if(session->next_crypto->kex_type == SSH_KEX_DH_GROUP1_SHA1 ||
-     session->next_crypto->kex_type == SSH_KEX_DH_GROUP14_SHA1) {
+     session->next_crypto->kex_type == SSH_KEX_DH_GROUP14_SHA1 ||
+     session->next_crypto->kex_type == SSH_KEX_DH_GROUP_SHA1 ||
+     session->next_crypto->kex_type == SSH_KEX_DH_GROUP_SHA256) {
+
+    if(session->next_crypto->kex_type == SSH_KEX_DH_GROUP_SHA1 ||
+       session->next_crypto->kex_type == SSH_KEX_DH_GROUP_SHA256) {
+      if (buffer_add_u32(buf, PREFERRED_GROUP_BITS) < 0) {
+        goto error;
+      }
+
+      num = make_bignum_string(session->next_crypto->p);
+      if (num == NULL) {
+        goto error;
+      }
+
+      len = ssh_string_len(num) + 4;
+      if (buffer_add_data(buf, num, len) < 0) {
+        goto error;
+      }
+      ssh_string_free(num);
+
+      num = make_bignum_string(session->next_crypto->g);
+      if (num == NULL) {
+        goto error;
+      }
+
+      len = ssh_string_len(num) + 4;
+      if (buffer_add_data(buf, num, len) < 0) {
+        goto error;
+      }
+      ssh_string_free(num);
+    }
 
     num = make_bignum_string(session->next_crypto->e);
     if (num == NULL) {
@@ -744,6 +833,7 @@ int make_sessionid(ssh_session session) {
   switch(session->next_crypto->kex_type){
     case SSH_KEX_DH_GROUP1_SHA1:
     case SSH_KEX_DH_GROUP14_SHA1:
+    case SSH_KEX_DH_GROUP_SHA1:
       session->next_crypto->digest_len = SHA_DIGEST_LENGTH;
       session->next_crypto->mac_type = SSH_MAC_SHA1;
       session->next_crypto->secret_hash = malloc(session->next_crypto->digest_len);
@@ -754,6 +844,7 @@ int make_sessionid(ssh_session session) {
       sha1(buffer_get_rest(buf), buffer_get_rest_len(buf),
                 session->next_crypto->secret_hash);
       break;
+    case SSH_KEX_DH_GROUP_SHA256:
     case SSH_KEX_ECDH_SHA2_NISTP256:
     case SSH_KEX_CURVE25519_SHA256_LIBSSH_ORG:
       session->next_crypto->digest_len = SHA256_DIGEST_LENGTH;
diff --git a/src/kex.c b/src/kex.c
index f19beb8..ceadbbb 100644
--- a/src/kex.c
+++ b/src/kex.c
@@ -79,7 +79,7 @@
 #define ECDH ""
 #endif
 
-#define KEY_EXCHANGE CURVE25519 ECDH "diffie-hellman-group14-sha1,diffie-hellman-group1-sha1"
+#define KEY_EXCHANGE CURVE25519 ECDH "diffie-hellman-group14-sha1,diffie-hellman-group1-sha1,diffie-hellman-group-exchange-sha1,diffie-hellman-group-exchange-sha256"
 #define KEX_METHODS_SIZE 10
 
 /* NOTE: This is a fixed API and the index is defined by ssh_kex_types_e */
@@ -473,6 +473,10 @@ int ssh_kex_select_methods (ssh_session session){
       session->next_crypto->kex_type=SSH_KEX_DH_GROUP1_SHA1;
     } else if(strcmp(session->next_crypto->kex_methods[SSH_KEX], "diffie-hellman-group14-sha1") == 0){
       session->next_crypto->kex_type=SSH_KEX_DH_GROUP14_SHA1;
+    } else if(strcmp(session->next_crypto->kex_methods[SSH_KEX], "diffie-hellman-group-exchange-sha1") == 0){
+      session->next_crypto->kex_type=SSH_KEX_DH_GROUP_SHA1;
+    } else if(strcmp(session->next_crypto->kex_methods[SSH_KEX], "diffie-hellman-group-exchange-sha256") == 0){
+      session->next_crypto->kex_type=SSH_KEX_DH_GROUP_SHA256;
     } else if(strcmp(session->next_crypto->kex_methods[SSH_KEX], "ecdh-sha2-nistp256") == 0){
       session->next_crypto->kex_type=SSH_KEX_ECDH_SHA2_NISTP256;
     } else if(strcmp(session->next_crypto->kex_methods[SSH_KEX], "curve25519-sha256@libssh.org") == 0){
diff --git a/src/packet.c b/src/packet.c
index 96f6d10..4296a74 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -76,7 +76,7 @@ static ssh_packet_callback default_packet_handlers[]= {
   ssh_packet_dh_reply,                     // SSH2_MSG_KEXDH_REPLY                31
                                            // SSH2_MSG_KEX_DH_GEX_GROUP           31
   NULL,                                    // SSH2_MSG_KEX_DH_GEX_INIT            32
-  NULL,                                    // SSH2_MSG_KEX_DH_GEX_REPLY           33
+  ssh_packet_dh_reply,                     // SSH2_MSG_KEX_DH_GEX_REPLY           33
   NULL,                                    // SSH2_MSG_KEX_DH_GEX_REQUEST         34
   NULL, NULL, NULL, NULL, NULL, NULL,	NULL,
   NULL, NULL, NULL, NULL, NULL, NULL, NULL,
diff --git a/src/packet_cb.c b/src/packet_cb.c
index f5d4f05..aa7c599 100644
--- a/src/packet_cb.c
+++ b/src/packet_cb.c
@@ -84,12 +84,14 @@ SSH_PACKET_CALLBACK(ssh_packet_ignore_callback){
 }
 
 SSH_PACKET_CALLBACK(ssh_packet_dh_reply){
+  enum ssh_dh_state_e dh_handshake_state = DH_STATE_NEWKEYS_SENT;
   int rc;
   (void)type;
   (void)user;
   SSH_LOG(SSH_LOG_PROTOCOL,"Received SSH_KEXDH_REPLY");
   if(session->session_state!= SSH_SESSION_STATE_DH &&
-		session->dh_handshake_state != DH_STATE_INIT_SENT){
+		(session->dh_handshake_state != DH_STATE_GEX_REQUEST_SENT ||
+         session->dh_handshake_state != DH_STATE_INIT_SENT)){
 	ssh_set_error(session,SSH_FATAL,"ssh_packet_dh_reply called in wrong state : %d:%d",
 			session->session_state,session->dh_handshake_state);
 	goto error;
@@ -97,7 +99,16 @@ SSH_PACKET_CALLBACK(ssh_packet_dh_reply){
   switch(session->next_crypto->kex_type){
     case SSH_KEX_DH_GROUP1_SHA1:
     case SSH_KEX_DH_GROUP14_SHA1:
-      rc=ssh_client_dh_reply(session, packet);
+      rc = ssh_client_dh_reply(session, packet);
+      break;
+    case SSH_KEX_DH_GROUP_SHA1:
+    case SSH_KEX_DH_GROUP_SHA256:
+      if(session->dh_handshake_state == DH_STATE_GEX_REQUEST_SENT) {
+        rc = ssh_client_dh_gex_reply(session, packet);
+        dh_handshake_state = DH_STATE_INIT_SENT;
+      } else {
+        rc = ssh_client_dh_reply(session, packet);
+      }
       break;
 #ifdef HAVE_ECDH
     case SSH_KEX_ECDH_SHA2_NISTP256:
@@ -114,7 +125,7 @@ SSH_PACKET_CALLBACK(ssh_packet_dh_reply){
       goto error;
   }
   if(rc==SSH_OK) {
-    session->dh_handshake_state = DH_STATE_NEWKEYS_SENT;
+    session->dh_handshake_state = dh_handshake_state;
     return SSH_PACKET_USED;
   }
 error:
diff --git a/src/wrapper.c b/src/wrapper.c
index 94175d0..a51d943 100644
--- a/src/wrapper.c
+++ b/src/wrapper.c
@@ -108,6 +108,8 @@ void crypto_free(struct ssh_crypto_struct *crypto){
   cipher_free(crypto->in_cipher);
   cipher_free(crypto->out_cipher);
 
+  bignum_free(crypto->p);
+  bignum_free(crypto->g);
   bignum_free(crypto->e);
   bignum_free(crypto->f);
   bignum_free(crypto->x);
-- 
1.8.1.msysgit.1

