Generate a "mask" byte array of mask_len bytes by concatenating
digests, where each digest is computed over the input seed followed
by a 32-bit big-endian running counter, and the final digest is
truncated to fit - as described in RFC8017 appendix B.2.1, "MGF1".

The mask is used by the RSA signing and verification processes with
the RSASSA-PSS encoding.

Reference: https://tools.ietf.org/html/rfc8017#appendix-B.2.1
Signed-off-by: Varad Gautam <varad.gau...@suse.com>
---
 crypto/rsa-psspad.c | 54 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/crypto/rsa-psspad.c b/crypto/rsa-psspad.c
index eec303bb55b2d..ed5374c381513 100644
--- a/crypto/rsa-psspad.c
+++ b/crypto/rsa-psspad.c
@@ -51,6 +51,60 @@ static int psspad_set_sig_params(struct crypto_akcipher *tfm,
        return 0;
 }
 
+/*
+ * MGF1 per RFC8017 B.2.1: fill mask with mask_len bytes of
+ * Hash(seed || C(0)) || Hash(seed || C(1)) || ..., hashed via desc, where
+ * C(i) is i as a 4-byte big-endian string; the last digest is truncated.
+ * Returns 0 on success (also for mask_len == 0), -EINVAL for a zero digest
+ * size, -ENOMEM if the scratch buffer allocation fails, or a shash error.
+ */
+static int pkcs1_mgf1(u8 *seed, unsigned int seed_len,
+                     struct shash_desc *desc,
+                     u8 *mask, unsigned int mask_len)
+{
+       unsigned int pos, h_len, i;
+       __be32 c;       /* counter in wire order; typed so sparse can check it */
+       u8 *tmp;
+       int ret = 0;
+
+       h_len = crypto_shash_digestsize(desc->tfm);
+       if (!h_len)
+               return -EINVAL; /* a zero-size digest can never fill mask */
+
+       for (pos = 0, i = 0; pos < mask_len; i++) {
+               c = cpu_to_be32(i);
+               ret = crypto_shash_init(desc);
+               if (ret < 0)
+                       return ret;
+               ret = crypto_shash_update(desc, seed, seed_len);
+               if (ret < 0)
+                       return ret;
+               ret = crypto_shash_update(desc, (u8 *)&c, sizeof(c));
+               if (ret < 0)
+                       return ret;
+               if (mask_len - pos >= h_len) {
+                       /* A whole digest fits: write it straight into mask. */
+                       ret = crypto_shash_final(desc, mask + pos);
+                       if (ret < 0)
+                               return ret;
+                       pos += h_len;
+                       continue;
+               }
+
+               /* Tail: digest into scratch, copy only the bytes that fit. */
+               tmp = kzalloc(h_len, GFP_KERNEL);
+               if (!tmp)
+                       return -ENOMEM;
+               ret = crypto_shash_final(desc, tmp);
+               if (ret >= 0)
+                       memcpy(mask + pos, tmp, mask_len - pos);
+               kfree(tmp);
+               break;  /* the mask is complete (or ret holds the error) */
+       }
+
+       return ret;
+}
+
 static int psspad_s_v_e_d(struct akcipher_request *req)
 {
        return -EOPNOTSUPP;
-- 
2.30.2

Reply via email to