On Thu, Dec 15, 2022 at 04:59:52AM +0900, Michael Paquier wrote:
> However, that's only half of the picture.  The key length and the hash
> type (or just the hash type to know what's the digest/key length to
> use but that's more invasive) still need to be sent across the
> internal routines of SCRAM and attached to the state data of the
> frontend and the backend or we won't be able to do the hash and HMAC
> computations dependent on that.

Attached is a patch to do exactly that, and as a result v2 is half the
size of v1:
- SCRAM_KEY_LEN is renamed to SCRAM_MAX_KEY_LEN, with a note that it
must be kept in sync with the largest digest size among the supported
hash methods.  It is used to size all the internal buffers of the
SCRAM routines.
- SCRAM_SHA_256_KEY_LEN tracks the key length of SCRAM-SHA-256, and is
the value used to initialize the state data.
- No changes to the internal logic; the buffers are simply sized based
on the defined maximum.

I'd like to move forward with this in the next couple of days (I still
need to study the other areas of the code further to see what else
could be made more pluggable), so let me know if there are any objections.
--
Michael
From 0e01ec72ebfdf71bafd7434ea19c2dcb17164f1d Mon Sep 17 00:00:00 2001
From: Michael Paquier <mich...@paquier.xyz>
Date: Sat, 17 Dec 2022 12:06:37 +0900
Subject: [PATCH v2] Remove dependency on hash type and key length in internal
 SCRAM code

SCRAM_KEY_LEN had a hard dependency on SHA-256, making it difficult to
add more hash methods to SCRAM, as many internal buffers were sized
based on it.  A second issue is that SHA-256 was always assumed to be
the computation method.

This commit renames SCRAM_KEY_LEN to the more generic SCRAM_MAX_KEY_LEN,
which is used to size the buffers of the internal SCRAM routines and
centrally tracks the maximum size needed by all the supported hash
methods.  A second change is that the key length (SHA digest length)
and the hash type are now tracked in the state data of the backend and
the frontend, with the common routines extended to take them as
arguments.
---
 src/include/common/scram-common.h    |  31 ++++--
 src/include/libpq/scram.h            |   6 +-
 src/backend/libpq/auth-scram.c       | 137 ++++++++++++++++-----------
 src/backend/libpq/crypt.c            |  10 +-
 src/common/scram-common.c            |  83 +++++++++-------
 src/interfaces/libpq/fe-auth-scram.c |  65 ++++++++-----
 6 files changed, 201 insertions(+), 131 deletions(-)

diff --git a/src/include/common/scram-common.h b/src/include/common/scram-common.h
index 4acf2a78ad..953d30ac54 100644
--- a/src/include/common/scram-common.h
+++ b/src/include/common/scram-common.h
@@ -21,7 +21,13 @@
 #define SCRAM_SHA_256_PLUS_NAME "SCRAM-SHA-256-PLUS"	/* with channel binding */
 
 /* Length of SCRAM keys (client and server) */
-#define SCRAM_KEY_LEN				PG_SHA256_DIGEST_LENGTH
+#define SCRAM_SHA_256_KEY_LEN				PG_SHA256_DIGEST_LENGTH
+
+/*
+ * Size of buffers used internally by SCRAM routines; this must be the
+ * maximum of SCRAM_SHA_*_KEY_LEN among the supported hash methods.
+ */
+#define SCRAM_MAX_KEY_LEN					SCRAM_SHA_256_KEY_LEN
 
 /*
  * Size of random nonce generated in the authentication exchange.  This
@@ -43,17 +49,22 @@
  */
 #define SCRAM_DEFAULT_ITERATIONS	4096
 
-extern int	scram_SaltedPassword(const char *password, const char *salt,
-								 int saltlen, int iterations, uint8 *result,
-								 const char **errstr);
-extern int	scram_H(const uint8 *input, int len, uint8 *result,
+extern int	scram_SaltedPassword(const char *password,
+								 pg_cryptohash_type hash_type, int key_length,
+								 const char *salt, int saltlen, int iterations,
+								 uint8 *result, const char **errstr);
+extern int	scram_H(const uint8 *input, pg_cryptohash_type hash_type,
+					int key_length, uint8 *result,
 					const char **errstr);
-extern int	scram_ClientKey(const uint8 *salted_password, uint8 *result,
-							const char **errstr);
-extern int	scram_ServerKey(const uint8 *salted_password, uint8 *result,
-							const char **errstr);
+extern int	scram_ClientKey(const uint8 *salted_password,
+							pg_cryptohash_type hash_type, int key_length,
+							uint8 *result, const char **errstr);
+extern int	scram_ServerKey(const uint8 *salted_password,
+							pg_cryptohash_type hash_type, int key_length,
+							uint8 *result, const char **errstr);
 
-extern char *scram_build_secret(const char *salt, int saltlen, int iterations,
+extern char *scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+								const char *salt, int saltlen, int iterations,
 								const char *password, const char **errstr);
 
 #endif							/* SCRAM_COMMON_H */
diff --git a/src/include/libpq/scram.h b/src/include/libpq/scram.h
index c51e848c24..b29501ef96 100644
--- a/src/include/libpq/scram.h
+++ b/src/include/libpq/scram.h
@@ -13,6 +13,7 @@
 #ifndef PG_SCRAM_H
 #define PG_SCRAM_H
 
+#include "common/cryptohash.h"
 #include "lib/stringinfo.h"
 #include "libpq/libpq-be.h"
 #include "libpq/sasl.h"
@@ -22,7 +23,10 @@ extern PGDLLIMPORT const pg_be_sasl_mech pg_be_scram_mech;
 
 /* Routines to handle and check SCRAM-SHA-256 secret */
 extern char *pg_be_scram_build_secret(const char *password);
-extern bool parse_scram_secret(const char *secret, int *iterations, char **salt,
+extern bool parse_scram_secret(const char *secret,
+							   int *iterations,
+							   pg_cryptohash_type *hash_type,
+							   int *key_length, char **salt,
 							   uint8 *stored_key, uint8 *server_key);
 extern bool scram_verify_plain_password(const char *username,
 										const char *password, const char *secret);
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
index c9bab85e82..0e4bbfc4f1 100644
--- a/src/backend/libpq/auth-scram.c
+++ b/src/backend/libpq/auth-scram.c
@@ -141,10 +141,14 @@ typedef struct
 	Port	   *port;
 	bool		channel_binding_in_use;
 
+	/* State data depending on the hash type */
+	pg_cryptohash_type hash_type;
+	int			key_length;
+
 	int			iterations;
 	char	   *salt;			/* base64-encoded */
-	uint8		StoredKey[SCRAM_KEY_LEN];
-	uint8		ServerKey[SCRAM_KEY_LEN];
+	uint8		StoredKey[SCRAM_MAX_KEY_LEN];
+	uint8		ServerKey[SCRAM_MAX_KEY_LEN];
 
 	/* Fields of the first message from client */
 	char		cbind_flag;
@@ -155,7 +159,7 @@ typedef struct
 	/* Fields from the last message from client */
 	char	   *client_final_message_without_proof;
 	char	   *client_final_nonce;
-	char		ClientProof[SCRAM_KEY_LEN];
+	char		ClientProof[SCRAM_MAX_KEY_LEN];
 
 	/* Fields generated in the server */
 	char	   *server_first_message;
@@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
 static char *build_server_final_message(scram_state *state);
 static bool verify_client_proof(scram_state *state);
 static bool verify_final_nonce(scram_state *state);
-static void mock_scram_secret(const char *username, int *iterations,
-							  char **salt, uint8 *stored_key, uint8 *server_key);
+static void mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+							  int *iterations, int *key_length, char **salt,
+							  uint8 *stored_key, uint8 *server_key);
 static bool is_scram_printable(char *p);
 static char *sanitize_char(char c);
 static char *sanitize_str(const char *s);
-static char *scram_mock_salt(const char *username);
+static char *scram_mock_salt(const char *username,
+							 pg_cryptohash_type hash_type,
+							 int key_length);
 
 /*
  * Get a list of SASL mechanisms that this module supports.
@@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 
 		if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
 		{
-			if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt,
-								   state->StoredKey, state->ServerKey))
+			if (parse_scram_secret(shadow_pass, &state->iterations,
+								   &state->hash_type, &state->key_length,
+								   &state->salt,
+								   state->StoredKey,
+								   state->ServerKey))
 				got_secret = true;
 			else
 			{
@@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
 	 */
 	if (!got_secret)
 	{
-		mock_scram_secret(state->port->user_name, &state->iterations,
-						  &state->salt, state->StoredKey, state->ServerKey);
+		mock_scram_secret(state->port->user_name, &state->hash_type,
+						  &state->iterations, &state->key_length,
+						  &state->salt,
+						  state->StoredKey, state->ServerKey);
 		state->doomed = true;
 	}
 
@@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
 				(errcode(ERRCODE_INTERNAL_ERROR),
 				 errmsg("could not generate random salt")));
 
-	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+	result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN,
+								saltbuf, SCRAM_DEFAULT_SALT_LEN,
 								SCRAM_DEFAULT_ITERATIONS, password,
 								&errstr);
 
@@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
 	char	   *salt;
 	int			saltlen;
 	int			iterations;
-	uint8		salted_password[SCRAM_KEY_LEN];
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
-	uint8		computed_key[SCRAM_KEY_LEN];
+	int			key_length = 0;
+	pg_cryptohash_type hash_type;
+	uint8		salted_password[SCRAM_MAX_KEY_LEN];
+	uint8		stored_key[SCRAM_MAX_KEY_LEN];
+	uint8		server_key[SCRAM_MAX_KEY_LEN];
+	uint8		computed_key[SCRAM_MAX_KEY_LEN];
 	char	   *prep_password;
 	pg_saslprep_rc rc;
 	const char *errstr = NULL;
 
-	if (!parse_scram_secret(secret, &iterations, &encoded_salt,
-							stored_key, server_key))
+	if (!parse_scram_secret(secret, &iterations, &hash_type, &key_length,
+							&encoded_salt, stored_key, server_key))
 	{
 		/*
 		 * The password looked like a SCRAM secret, but could not be parsed.
@@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password,
 		password = prep_password;
 
 	/* Compute Server Key based on the user-supplied plaintext password */
-	if (scram_SaltedPassword(password, salt, saltlen, iterations,
+	if (scram_SaltedPassword(password, hash_type, key_length,
+							 salt, saltlen, iterations,
 							 salted_password, &errstr) < 0 ||
-		scram_ServerKey(salted_password, computed_key, &errstr) < 0)
+		scram_ServerKey(salted_password, hash_type, key_length,
+						computed_key, &errstr) < 0)
 	{
 		elog(ERROR, "could not compute server key: %s", errstr);
 	}
@@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password,
 	 * Compare the secret's Server Key with the one computed from the
 	 * user-supplied password.
 	 */
-	return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
+	return memcmp(computed_key, server_key, key_length) == 0;
 }
 
 
@@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password,
  * On success, the iteration count, salt, stored key, and server key are
  * extracted from the secret, and returned to the caller.  For 'stored_key'
  * and 'server_key', the caller must pass pre-allocated buffers of size
- * SCRAM_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
+ * SCRAM_MAX_KEY_LEN.  Salt is returned as a base64-encoded, null-terminated
  * string.  The buffer for the salt is palloc'd by this function.
  *
  * Returns true if the SCRAM secret has been parsed, and false otherwise.
  */
 bool
-parse_scram_secret(const char *secret, int *iterations, char **salt,
-				   uint8 *stored_key, uint8 *server_key)
+parse_scram_secret(const char *secret, int *iterations,
+				   pg_cryptohash_type *hash_type, int *key_length,
+				   char **salt, uint8 *stored_key, uint8 *server_key)
 {
 	char	   *v;
 	char	   *p;
@@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
 	/* Parse the fields */
 	if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
 		goto invalid_secret;
+	*hash_type = PG_SHA256;
+	*key_length = SCRAM_SHA_256_KEY_LEN;
 
 	errno = 0;
 	*iterations = strtol(iterations_str, &p, 10);
@@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
 	decoded_stored_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
 								decoded_stored_buf, decoded_len);
-	if (decoded_len != SCRAM_KEY_LEN)
+	if (decoded_len != *key_length)
 		goto invalid_secret;
-	memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
+	memcpy(stored_key, decoded_stored_buf, *key_length);
 
 	decoded_len = pg_b64_dec_len(strlen(serverkey_str));
 	decoded_server_buf = palloc(decoded_len);
 	decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
 								decoded_server_buf, decoded_len);
-	if (decoded_len != SCRAM_KEY_LEN)
+	if (decoded_len != *key_length)
 		goto invalid_secret;
-	memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
+	memcpy(server_key, decoded_server_buf, *key_length);
 
 	return true;
 
@@ -655,20 +675,25 @@ invalid_secret:
  *
  * In a normal authentication, these are extracted from the secret
  * stored in the server.  This function generates values that look
- * realistic, for when there is no stored secret.
+ * realistic, for when there is no stored secret, using SCRAM-SHA-256.
  *
  * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
- * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
+ * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN, and
  * the buffer for the salt is palloc'd by this function.
  */
 static void
-mock_scram_secret(const char *username, int *iterations, char **salt,
+mock_scram_secret(const char *username, pg_cryptohash_type *hash_type,
+				  int *iterations, int *key_length, char **salt,
 				  uint8 *stored_key, uint8 *server_key)
 {
 	char	   *raw_salt;
 	char	   *encoded_salt;
 	int			encoded_len;
 
+	/* Enforce the use of SHA-256, which would be realistic enough */
+	*hash_type = PG_SHA256;
+	*key_length = SCRAM_SHA_256_KEY_LEN;
+
 	/*
 	 * Generate deterministic salt.
 	 *
@@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
 	 * as the salt generated for mock authentication uses the cluster's nonce
 	 * value.
 	 */
-	raw_salt = scram_mock_salt(username);
+	raw_salt = scram_mock_salt(username, *hash_type, *key_length);
 	if (raw_salt == NULL)
 		elog(ERROR, "could not encode salt");
 
@@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
 	*iterations = SCRAM_DEFAULT_ITERATIONS;
 
 	/* StoredKey and ServerKey are not used in a doomed authentication */
-	memset(stored_key, 0, SCRAM_KEY_LEN);
-	memset(server_key, 0, SCRAM_KEY_LEN);
+	memset(stored_key, 0, SCRAM_MAX_KEY_LEN);
+	memset(server_key, 0, SCRAM_MAX_KEY_LEN);
 }
 
 /*
@@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state)
 static bool
 verify_client_proof(scram_state *state)
 {
-	uint8		ClientSignature[SCRAM_KEY_LEN];
-	uint8		ClientKey[SCRAM_KEY_LEN];
-	uint8		client_StoredKey[SCRAM_KEY_LEN];
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	uint8		ClientSignature[SCRAM_MAX_KEY_LEN];
+	uint8		ClientKey[SCRAM_MAX_KEY_LEN];
+	uint8		client_StoredKey[SCRAM_MAX_KEY_LEN];
+	pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
 	int			i;
 	const char *errstr = NULL;
 
@@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state)
 	 * here even when processing the calculations as this could involve a mock
 	 * authentication.
 	 */
-	if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, state->StoredKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state)
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+		pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
 	{
 		elog(ERROR, "could not calculate client signature: %s",
 			 pg_hmac_error(ctx));
@@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state)
 	pg_hmac_free(ctx);
 
 	/* Extract the ClientKey that the client calculated from the proof */
-	for (i = 0; i < SCRAM_KEY_LEN; i++)
+	for (i = 0; i < state->key_length; i++)
 		ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
 
 	/* Hash it one more time, and compare with StoredKey */
-	if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey, &errstr) < 0)
+	if (scram_H(ClientKey, state->hash_type, state->key_length,
+				client_StoredKey, &errstr) < 0)
 		elog(ERROR, "could not hash stored key: %s", errstr);
 
-	if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
+	if (memcmp(client_StoredKey, state->StoredKey, state->key_length) != 0)
 		return false;
 
 	return true;
@@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input)
 	client_proof_len = pg_b64_dec_len(strlen(value));
 	client_proof = palloc(client_proof_len);
 	if (pg_b64_decode(value, strlen(value), client_proof,
-					  client_proof_len) != SCRAM_KEY_LEN)
+					  client_proof_len) != state->key_length)
 		ereport(ERROR,
 				(errcode(ERRCODE_PROTOCOL_VIOLATION),
 				 errmsg("malformed SCRAM message"),
 				 errdetail("Malformed proof in client-final-message.")));
-	memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
+	memcpy(state->ClientProof, client_proof, state->key_length);
 	pfree(client_proof);
 
 	if (*p != '\0')
@@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input)
 static char *
 build_server_final_message(scram_state *state)
 {
-	uint8		ServerSignature[SCRAM_KEY_LEN];
+	uint8		ServerSignature[SCRAM_MAX_KEY_LEN];
 	char	   *server_signature_base64;
 	int			siglen;
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(state->hash_type);
 
 	/* calculate ServerSignature */
-	if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, state->ServerKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -1392,7 +1418,7 @@ build_server_final_message(scram_state *state)
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
+		pg_hmac_final(ctx, ServerSignature, state->key_length) < 0)
 	{
 		elog(ERROR, "could not calculate server signature: %s",
 			 pg_hmac_error(ctx));
@@ -1400,11 +1426,11 @@ build_server_final_message(scram_state *state)
 
 	pg_hmac_free(ctx);
 
-	siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+	siglen = pg_b64_enc_len(state->key_length);
 	/* don't forget the zero-terminator */
 	server_signature_base64 = palloc(siglen + 1);
 	siglen = pg_b64_encode((const char *) ServerSignature,
-						   SCRAM_KEY_LEN, server_signature_base64,
+						   state->key_length, server_signature_base64,
 						   siglen);
 	if (siglen < 0)
 		elog(ERROR, "could not encode server signature");
@@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state)
  * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
  */
 static char *
-scram_mock_salt(const char *username)
+scram_mock_salt(const char *username, pg_cryptohash_type hash_type,
+				int key_length)
 {
 	pg_cryptohash_ctx *ctx;
-	static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
+	static uint8 sha_digest[SCRAM_MAX_KEY_LEN];
 	char	   *mock_auth_nonce = GetMockAuthenticationNonce();
 
 	/*
@@ -1446,11 +1473,13 @@ scram_mock_salt(const char *username)
 	StaticAssertDecl(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
 					 "salt length greater than SHA256 digest length");
 
-	ctx = pg_cryptohash_create(PG_SHA256);
+	Assert(hash_type == PG_SHA256);
+
+	ctx = pg_cryptohash_create(hash_type);
 	if (pg_cryptohash_init(ctx) < 0 ||
 		pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
 		pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
-		pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0)
+		pg_cryptohash_final(ctx, sha_digest, key_length) < 0)
 	{
 		pg_cryptohash_free(ctx);
 		return NULL;
diff --git a/src/backend/libpq/crypt.c b/src/backend/libpq/crypt.c
index 1ff8b0507d..a81af0749a 100644
--- a/src/backend/libpq/crypt.c
+++ b/src/backend/libpq/crypt.c
@@ -90,15 +90,17 @@ get_password_type(const char *shadow_pass)
 {
 	char	   *encoded_salt;
 	int			iterations;
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
+	int			key_length = 0;
+	pg_cryptohash_type hash_type;
+	uint8		stored_key[SCRAM_MAX_KEY_LEN];
+	uint8		server_key[SCRAM_MAX_KEY_LEN];
 
 	if (strncmp(shadow_pass, "md5", 3) == 0 &&
 		strlen(shadow_pass) == MD5_PASSWD_LEN &&
 		strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
 		return PASSWORD_TYPE_MD5;
-	if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt,
-						   stored_key, server_key))
+	if (parse_scram_secret(shadow_pass, &iterations, &hash_type, &key_length,
+						   &encoded_salt, stored_key, server_key))
 		return PASSWORD_TYPE_SCRAM_SHA_256;
 	return PASSWORD_TYPE_PLAINTEXT;
 }
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
index 1268625929..4f59910dea 100644
--- a/src/common/scram-common.c
+++ b/src/common/scram-common.c
@@ -33,6 +33,7 @@
  */
 int
 scram_SaltedPassword(const char *password,
+					 pg_cryptohash_type hash_type, int key_length,
 					 const char *salt, int saltlen, int iterations,
 					 uint8 *result, const char **errstr)
 {
@@ -40,9 +41,9 @@ scram_SaltedPassword(const char *password,
 	uint32		one = pg_hton32(1);
 	int			i,
 				j;
-	uint8		Ui[SCRAM_KEY_LEN];
-	uint8		Ui_prev[SCRAM_KEY_LEN];
-	pg_hmac_ctx *hmac_ctx = pg_hmac_create(PG_SHA256);
+	uint8		Ui[SCRAM_MAX_KEY_LEN];
+	uint8		Ui_prev[SCRAM_MAX_KEY_LEN];
+	pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
 
 	if (hmac_ctx == NULL)
 	{
@@ -60,30 +61,30 @@ scram_SaltedPassword(const char *password,
 	if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
 		pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
 		pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
-		pg_hmac_final(hmac_ctx, Ui_prev, sizeof(Ui_prev)) < 0)
+		pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(hmac_ctx);
 		pg_hmac_free(hmac_ctx);
 		return -1;
 	}
 
-	memcpy(result, Ui_prev, SCRAM_KEY_LEN);
+	memcpy(result, Ui_prev, key_length);
 
 	/* Subsequent iterations */
 	for (i = 2; i <= iterations; i++)
 	{
 		if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
-			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, SCRAM_KEY_LEN) < 0 ||
-			pg_hmac_final(hmac_ctx, Ui, sizeof(Ui)) < 0)
+			pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
+			pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
 		{
 			*errstr = pg_hmac_error(hmac_ctx);
 			pg_hmac_free(hmac_ctx);
 			return -1;
 		}
 
-		for (j = 0; j < SCRAM_KEY_LEN; j++)
+		for (j = 0; j < key_length; j++)
 			result[j] ^= Ui[j];
-		memcpy(Ui_prev, Ui, SCRAM_KEY_LEN);
+		memcpy(Ui_prev, Ui, key_length);
 	}
 
 	pg_hmac_free(hmac_ctx);
@@ -92,16 +93,17 @@ scram_SaltedPassword(const char *password,
 
 
 /*
- * Calculate SHA-256 hash for a NULL-terminated string. (The NULL terminator is
+ * Calculate hash of 'key_length' bytes of 'input' (any NULL terminator is
  * not included in the hash).  Returns 0 on success, -1 on failure with *errstr
  * pointing to a message about the error details.
  */
 int
-scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
+scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
+		uint8 *result, const char **errstr)
 {
 	pg_cryptohash_ctx *ctx;
 
-	ctx = pg_cryptohash_create(PG_SHA256);
+	ctx = pg_cryptohash_create(hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_cryptohash_error(NULL);	/* returns OOM */
@@ -109,8 +111,8 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
 	}
 
 	if (pg_cryptohash_init(ctx) < 0 ||
-		pg_cryptohash_update(ctx, input, len) < 0 ||
-		pg_cryptohash_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_cryptohash_update(ctx, input, key_length) < 0 ||
+		pg_cryptohash_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_cryptohash_error(ctx);
 		pg_cryptohash_free(ctx);
@@ -126,10 +128,11 @@ scram_H(const uint8 *input, int len, uint8 *result, const char **errstr)
  * pointing to a message about the error details.
  */
 int
-scram_ClientKey(const uint8 *salted_password, uint8 *result,
-				const char **errstr)
+scram_ClientKey(const uint8 *salted_password,
+				pg_cryptohash_type hash_type, int key_length,
+				uint8 *result, const char **errstr)
 {
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
 
 	if (ctx == NULL)
 	{
@@ -137,9 +140,9 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
 		return -1;
 	}
 
-	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
-		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_hmac_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
@@ -155,10 +158,11 @@ scram_ClientKey(const uint8 *salted_password, uint8 *result,
  * pointing to a message about the error details.
  */
 int
-scram_ServerKey(const uint8 *salted_password, uint8 *result,
-				const char **errstr)
+scram_ServerKey(const uint8 *salted_password,
+				pg_cryptohash_type hash_type, int key_length,
+				uint8 *result, const char **errstr)
 {
-	pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+	pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
 
 	if (ctx == NULL)
 	{
@@ -166,9 +170,9 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
 		return -1;
 	}
 
-	if (pg_hmac_init(ctx, salted_password, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
 		pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
-		pg_hmac_final(ctx, result, SCRAM_KEY_LEN) < 0)
+		pg_hmac_final(ctx, result, key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
@@ -192,12 +196,13 @@ scram_ServerKey(const uint8 *salted_password, uint8 *result,
  * error details.
  */
 char *
-scram_build_secret(const char *salt, int saltlen, int iterations,
+scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+				   const char *salt, int saltlen, int iterations,
 				   const char *password, const char **errstr)
 {
-	uint8		salted_password[SCRAM_KEY_LEN];
-	uint8		stored_key[SCRAM_KEY_LEN];
-	uint8		server_key[SCRAM_KEY_LEN];
+	uint8		salted_password[SCRAM_MAX_KEY_LEN];
+	uint8		stored_key[SCRAM_MAX_KEY_LEN];
+	uint8		server_key[SCRAM_MAX_KEY_LEN];
 	char	   *result;
 	char	   *p;
 	int			maxlen;
@@ -206,15 +211,21 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	int			encoded_server_len;
 	int			encoded_result;
 
+	Assert(hash_type == PG_SHA256);
+
 	if (iterations <= 0)
 		iterations = SCRAM_DEFAULT_ITERATIONS;
 
 	/* Calculate StoredKey and ServerKey */
-	if (scram_SaltedPassword(password, salt, saltlen, iterations,
+	if (scram_SaltedPassword(password, hash_type, key_length,
+							 salt, saltlen, iterations,
 							 salted_password, errstr) < 0 ||
-		scram_ClientKey(salted_password, stored_key, errstr) < 0 ||
-		scram_H(stored_key, SCRAM_KEY_LEN, stored_key, errstr) < 0 ||
-		scram_ServerKey(salted_password, server_key, errstr) < 0)
+		scram_ClientKey(salted_password, hash_type, key_length,
+						stored_key, errstr) < 0 ||
+		scram_H(stored_key, hash_type, key_length,
+				stored_key, errstr) < 0 ||
+		scram_ServerKey(salted_password, hash_type, key_length,
+						server_key, errstr) < 0)
 	{
 		/* errstr is filled already here */
 #ifdef FRONTEND
@@ -231,8 +242,8 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	 *----------
 	 */
 	encoded_salt_len = pg_b64_enc_len(saltlen);
-	encoded_stored_len = pg_b64_enc_len(SCRAM_KEY_LEN);
-	encoded_server_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_stored_len = pg_b64_enc_len(key_length);
+	encoded_server_len = pg_b64_enc_len(key_length);
 
 	maxlen = strlen("SCRAM-SHA-256") + 1
 		+ 10 + 1				/* iteration count */
@@ -269,7 +280,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	*(p++) = '$';
 
 	/* stored key */
-	encoded_result = pg_b64_encode((char *) stored_key, SCRAM_KEY_LEN, p,
+	encoded_result = pg_b64_encode((char *) stored_key, key_length, p,
 								   encoded_stored_len);
 	if (encoded_result < 0)
 	{
@@ -286,7 +297,7 @@ scram_build_secret(const char *salt, int saltlen, int iterations,
 	*(p++) = ':';
 
 	/* server key */
-	encoded_result = pg_b64_encode((char *) server_key, SCRAM_KEY_LEN, p,
+	encoded_result = pg_b64_encode((char *) server_key, key_length, p,
 								   encoded_server_len);
 	if (encoded_result < 0)
 	{
diff --git a/src/interfaces/libpq/fe-auth-scram.c b/src/interfaces/libpq/fe-auth-scram.c
index c500bea9e7..7410d5ba52 100644
--- a/src/interfaces/libpq/fe-auth-scram.c
+++ b/src/interfaces/libpq/fe-auth-scram.c
@@ -58,8 +58,12 @@ typedef struct
 	char	   *password;
 	char	   *sasl_mechanism;
 
+	/* State data depending on the hash type */
+	pg_cryptohash_type hash_type;
+	int			key_length;
+
 	/* We construct these */
-	uint8		SaltedPassword[SCRAM_KEY_LEN];
+	uint8		SaltedPassword[SCRAM_MAX_KEY_LEN];
 	char	   *client_nonce;
 	char	   *client_first_message_bare;
 	char	   *client_final_message_without_proof;
@@ -73,7 +77,7 @@ typedef struct
 
 	/* These come from the server-final message */
 	char	   *server_final_message;
-	char		ServerSignature[SCRAM_KEY_LEN];
+	char		ServerSignature[SCRAM_MAX_KEY_LEN];
 } fe_scram_state;
 
 static bool read_server_first_message(fe_scram_state *state, char *input);
@@ -106,8 +110,10 @@ scram_init(PGconn *conn,
 	memset(state, 0, sizeof(fe_scram_state));
 	state->conn = conn;
 	state->state = FE_SCRAM_INIT;
-	state->sasl_mechanism = strdup(sasl_mechanism);
+	state->key_length = SCRAM_SHA_256_KEY_LEN;
+	state->hash_type = PG_SHA256;
 
+	state->sasl_mechanism = strdup(sasl_mechanism);
 	if (!state->sasl_mechanism)
 	{
 		free(state);
@@ -450,7 +456,7 @@ build_client_final_message(fe_scram_state *state)
 {
 	PQExpBufferData buf;
 	PGconn	   *conn = state->conn;
-	uint8		client_proof[SCRAM_KEY_LEN];
+	uint8		client_proof[SCRAM_MAX_KEY_LEN];
 	char	   *result;
 	int			encoded_len;
 	const char *errstr = NULL;
@@ -565,11 +571,11 @@ build_client_final_message(fe_scram_state *state)
 	}
 
 	appendPQExpBufferStr(&buf, ",p=");
-	encoded_len = pg_b64_enc_len(SCRAM_KEY_LEN);
+	encoded_len = pg_b64_enc_len(state->key_length);
 	if (!enlargePQExpBuffer(&buf, encoded_len))
 		goto oom_error;
 	encoded_len = pg_b64_encode((char *) client_proof,
-								SCRAM_KEY_LEN,
+								state->key_length,
 								buf.data + buf.len,
 								encoded_len);
 	if (encoded_len < 0)
@@ -738,13 +744,14 @@ read_server_final_message(fe_scram_state *state, char *input)
 										 strlen(encoded_server_signature),
 										 decoded_server_signature,
 										 server_signature_len);
-	if (server_signature_len != SCRAM_KEY_LEN)
+	if (server_signature_len != state->key_length)
 	{
 		free(decoded_server_signature);
 		libpq_append_conn_error(conn, "malformed SCRAM message (invalid server signature)");
 		return false;
 	}
-	memcpy(state->ServerSignature, decoded_server_signature, SCRAM_KEY_LEN);
+	memcpy(state->ServerSignature, decoded_server_signature,
+		   state->key_length);
 	free(decoded_server_signature);
 
 	return true;
@@ -760,13 +767,13 @@ calculate_client_proof(fe_scram_state *state,
 					   const char *client_final_message_without_proof,
 					   uint8 *result, const char **errstr)
 {
-	uint8		StoredKey[SCRAM_KEY_LEN];
-	uint8		ClientKey[SCRAM_KEY_LEN];
-	uint8		ClientSignature[SCRAM_KEY_LEN];
+	uint8		StoredKey[SCRAM_MAX_KEY_LEN];
+	uint8		ClientKey[SCRAM_MAX_KEY_LEN];
+	uint8		ClientSignature[SCRAM_MAX_KEY_LEN];
 	int			i;
 	pg_hmac_ctx *ctx;
 
-	ctx = pg_hmac_create(PG_SHA256);
+	ctx = pg_hmac_create(state->hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_hmac_error(NULL);	/* returns OOM */
@@ -777,18 +784,21 @@ calculate_client_proof(fe_scram_state *state,
 	 * Calculate SaltedPassword, and store it in 'state' so that we can reuse
 	 * it later in verify_server_signature.
 	 */
-	if (scram_SaltedPassword(state->password, state->salt, state->saltlen,
+	if (scram_SaltedPassword(state->password, state->hash_type,
+							 state->key_length, state->salt, state->saltlen,
 							 state->iterations, state->SaltedPassword,
 							 errstr) < 0 ||
-		scram_ClientKey(state->SaltedPassword, ClientKey, errstr) < 0 ||
-		scram_H(ClientKey, SCRAM_KEY_LEN, StoredKey, errstr) < 0)
+		scram_ClientKey(state->SaltedPassword, state->hash_type,
+						state->key_length, ClientKey, errstr) < 0 ||
+		scram_H(ClientKey, state->hash_type, state->key_length,
+				StoredKey, errstr) < 0)
 	{
 		/* errstr is already filled here */
 		pg_hmac_free(ctx);
 		return false;
 	}
 
-	if (pg_hmac_init(ctx, StoredKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, StoredKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -800,14 +810,14 @@ calculate_client_proof(fe_scram_state *state,
 		pg_hmac_update(ctx,
 					   (uint8 *) client_final_message_without_proof,
 					   strlen(client_final_message_without_proof)) < 0 ||
-		pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+		pg_hmac_final(ctx, ClientSignature, state->key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
 		return false;
 	}
 
-	for (i = 0; i < SCRAM_KEY_LEN; i++)
+	for (i = 0; i < state->key_length; i++)
 		result[i] = ClientKey[i] ^ ClientSignature[i];
 
 	pg_hmac_free(ctx);
@@ -825,18 +835,19 @@ static bool
 verify_server_signature(fe_scram_state *state, bool *match,
 						const char **errstr)
 {
-	uint8		expected_ServerSignature[SCRAM_KEY_LEN];
-	uint8		ServerKey[SCRAM_KEY_LEN];
+	uint8		expected_ServerSignature[SCRAM_MAX_KEY_LEN];
+	uint8		ServerKey[SCRAM_MAX_KEY_LEN];
 	pg_hmac_ctx *ctx;
 
-	ctx = pg_hmac_create(PG_SHA256);
+	ctx = pg_hmac_create(state->hash_type);
 	if (ctx == NULL)
 	{
 		*errstr = pg_hmac_error(NULL);	/* returns OOM */
 		return false;
 	}
 
-	if (scram_ServerKey(state->SaltedPassword, ServerKey, errstr) < 0)
+	if (scram_ServerKey(state->SaltedPassword, state->hash_type,
+						state->key_length, ServerKey, errstr) < 0)
 	{
 		/* errstr is filled already */
 		pg_hmac_free(ctx);
@@ -844,7 +855,7 @@ verify_server_signature(fe_scram_state *state, bool *match,
 	}
 
 	/* calculate ServerSignature */
-	if (pg_hmac_init(ctx, ServerKey, SCRAM_KEY_LEN) < 0 ||
+	if (pg_hmac_init(ctx, ServerKey, state->key_length) < 0 ||
 		pg_hmac_update(ctx,
 					   (uint8 *) state->client_first_message_bare,
 					   strlen(state->client_first_message_bare)) < 0 ||
@@ -857,7 +868,7 @@ verify_server_signature(fe_scram_state *state, bool *match,
 					   (uint8 *) state->client_final_message_without_proof,
 					   strlen(state->client_final_message_without_proof)) < 0 ||
 		pg_hmac_final(ctx, expected_ServerSignature,
-					  sizeof(expected_ServerSignature)) < 0)
+					  state->key_length) < 0)
 	{
 		*errstr = pg_hmac_error(ctx);
 		pg_hmac_free(ctx);
@@ -867,7 +878,8 @@ verify_server_signature(fe_scram_state *state, bool *match,
 	pg_hmac_free(ctx);
 
 	/* signature processed, so now check after it */
-	if (memcmp(expected_ServerSignature, state->ServerSignature, SCRAM_KEY_LEN) != 0)
+	if (memcmp(expected_ServerSignature, state->ServerSignature,
+			   state->key_length) != 0)
 		*match = false;
 	else
 		*match = true;
@@ -912,7 +924,8 @@ pg_fe_scram_build_secret(const char *password, const char **errstr)
 		return NULL;
 	}
 
-	result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+	result = scram_build_secret(PG_SHA256, SCRAM_SHA_256_KEY_LEN, saltbuf,
+								SCRAM_DEFAULT_SALT_LEN,
 								SCRAM_DEFAULT_ITERATIONS, password,
 								errstr);
 
-- 
2.38.1

Attachment: signature.asc
Description: PGP signature

Reply via email to