This patch adds support for RSA-PSS encoding, which is described in
RFC 8017.
Similar to rsa-pkcs1, we create a pss template; its parameters
(RSASSA-PSS-params) are parsed as described in RFC 4055.

Signed-off-by: Hongbo Li <herbert.tenc...@gmail.com>
---
 crypto/Makefile               |   7 ++-
 crypto/rsa.c                  |  14 ++---
 crypto/rsa_helper.c           | 127 ++++++++++++++++++++++++++++++++++++++++++
 crypto/rsapss_params.asn1     |  21 +++++++
 include/crypto/internal/rsa.h |  25 ++++++++-
 5 files changed, 185 insertions(+), 9 deletions(-)
 create mode 100644 crypto/rsapss_params.asn1

diff --git a/crypto/Makefile b/crypto/Makefile
index 10526d4..2c65744 100644
--- a/crypto/Makefile
+++ b/crypto/Makefile
@@ -33,13 +33,18 @@ obj-$(CONFIG_CRYPTO_DH) += dh_generic.o
 
 $(obj)/rsapubkey.asn1.o: $(obj)/rsapubkey.asn1.c $(obj)/rsapubkey.asn1.h
 $(obj)/rsaprivkey.asn1.o: $(obj)/rsaprivkey.asn1.c $(obj)/rsaprivkey.asn1.h
-$(obj)/rsa_helper.o: $(obj)/rsapubkey.asn1.h $(obj)/rsaprivkey.asn1.h
+$(obj)/rsapss_params.asn1.o: $(obj)/rsapss_params.asn1.c \
+                            $(obj)/rsapss_params.asn1.h
+$(obj)/rsa_helper.o: $(obj)/rsapubkey.asn1.h $(obj)/rsaprivkey.asn1.h \
+                    $(obj)/rsapss_params.asn1.h
 
 rsa_generic-y := rsapubkey.asn1.o
 rsa_generic-y += rsaprivkey.asn1.o
+rsa_generic-y += rsapss_params.asn1.o
 rsa_generic-y += rsa.o
 rsa_generic-y += rsa_helper.o
 rsa_generic-y += rsa-pkcs1pad.o
+rsa_generic-y += rsa-psspad.o
 obj-$(CONFIG_CRYPTO_RSA) += rsa_generic.o
 
 $(obj)/sm2signature.asn1.o: $(obj)/sm2signature.asn1.c $(obj)/sm2signature.asn1.h
diff --git a/crypto/rsa.c b/crypto/rsa.c
index 4cdbec9..adc9b2d2 100644
--- a/crypto/rsa.c
+++ b/crypto/rsa.c
@@ -6,18 +6,11 @@
  */
 
 #include <linux/module.h>
-#include <linux/mpi.h>
 #include <crypto/internal/rsa.h>
 #include <crypto/internal/akcipher.h>
 #include <crypto/akcipher.h>
 #include <crypto/algapi.h>
 
-struct rsa_mpi_key {
-       MPI n;
-       MPI e;
-       MPI d;
-};
-
 /*
  * RSAEP function [RFC3447 sec 5.1.1]
  * c = m^e mod n;
@@ -269,12 +262,22 @@ static int rsa_init(void)
                return err;
        }
 
+       err = crypto_register_template(&rsa_psspad_tmpl);
+       if (err) {
+               /* Unwind both earlier registrations, newest first. */
+               crypto_unregister_template(&rsa_pkcs1pad_tmpl);
+               crypto_unregister_akcipher(&rsa);
+               return err;
+       }
+
        return 0;
 }
 
 static void rsa_exit(void)
 {
+       /* Tear down in the reverse order of registration in rsa_init(). */
+       crypto_unregister_template(&rsa_psspad_tmpl);
        crypto_unregister_template(&rsa_pkcs1pad_tmpl);
        crypto_unregister_akcipher(&rsa);
 }
 
diff --git a/crypto/rsa_helper.c b/crypto/rsa_helper.c
index 94266f2..912d975 100644
--- a/crypto/rsa_helper.c
+++ b/crypto/rsa_helper.c
@@ -12,6 +12,7 @@
 #include <crypto/internal/rsa.h>
 #include "rsapubkey.asn1.h"
 #include "rsaprivkey.asn1.h"
+#include "rsapss_params.asn1.h"
 
 int rsa_get_n(void *context, size_t hdrlen, unsigned char tag,
              const void *value, size_t vlen)
@@ -148,6 +149,115 @@ int rsa_get_qinv(void *context, size_t hdrlen, unsigned char tag,
        return 0;
 }
 
+/*
+ * Map a digest OBJECT IDENTIFIER from the PSS parameters to a crypto API
+ * hash algorithm name, or NULL if the digest is not supported.  Shared by
+ * the hashAlgorithm and MGF1 hashAlgorithm actions.
+ */
+static const char *rsa_pss_hash_name(const void *value, size_t vlen)
+{
+       switch (look_up_OID(value, vlen)) {
+       case OID_sha1:
+               return "sha1";
+       case OID_sha224:
+               return "sha224";
+       case OID_sha256:
+               return "sha256";
+       case OID_sha384:
+               return "sha384";
+       case OID_sha512:
+               return "sha512";
+       default:
+               return NULL;
+       }
+}
+
+/*
+ * ASN.1 action for RSAPSS_Params.hashAlgorithm: record the message digest.
+ * Returns -EINVAL for an empty OID, -ENOPKG for an unsupported digest.
+ */
+int rsa_get_pss_hash(void *context, size_t hdrlen, unsigned char tag,
+                    const void *value, size_t vlen)
+{
+       struct rsa_pss_ctx *ctx = context;
+       const char *name;
+
+       if (!value || !vlen)
+               return -EINVAL;
+
+       name = rsa_pss_hash_name(value, vlen);
+       if (!name)
+               return -ENOPKG;
+
+       ctx->hash_algo = name;
+       return 0;
+}
+
+/*
+ * ASN.1 action for MaskGenAlgorithm.algorithm: only MGF1 is supported.
+ * Returns -EINVAL for an empty OID, -ENOPKG for any other mask function.
+ */
+int rsa_get_pss_mgf(void *context, size_t hdrlen, unsigned char tag,
+                   const void *value, size_t vlen)
+{
+       struct rsa_pss_ctx *ctx = context;
+
+       if (!value || !vlen)
+               return -EINVAL;
+
+       if (look_up_OID(value, vlen) != OID_rsa_mgf1)
+               return -ENOPKG;
+
+       ctx->mgf_algo = "mgf1";
+       return 0;
+}
+
+/*
+ * ASN.1 action for the MGF1 hashAlgorithm: record the mask digest.  Same
+ * digests and error codes as rsa_get_pss_hash().
+ */
+int rsa_get_pss_mgf_hash(void *context, size_t hdrlen, unsigned char tag,
+                        const void *value, size_t vlen)
+{
+       struct rsa_pss_ctx *ctx = context;
+       const char *name;
+
+       if (!value || !vlen)
+               return -EINVAL;
+
+       name = rsa_pss_hash_name(value, vlen);
+       if (!name)
+               return -ENOPKG;
+
+       ctx->mgf_hash_algo = name;
+       return 0;
+}
+
+/* ASN.1 action for RSAPSS_Params.saltLen: DER INTEGER of 1 or 2 octets. */
+int rsa_get_pss_saltlen(void *context, size_t hdrlen, unsigned char tag,
+                       const void *value, size_t vlen)
+{
+       struct rsa_pss_ctx *ctx = context;
+       const u8 *p = value;
+
+       /* A set top bit would make the INTEGER negative: no such salt. */
+       if (!value || vlen < 1 || vlen > 2 || (p[0] & 0x80))
+               return -EINVAL;
+
+       /* Decode byte-wise: value may be unaligned inside the BER buffer. */
+       ctx->salt_len = vlen == 1 ? p[0] : (((u32)p[0] << 8) | p[1]);
+       return 0;
+}
+
+/* ASN.1 action for RSAPSS_Params.trailerField: only trailerFieldBC (1). */
+int rsa_get_pss_trailerfield(void *context, size_t hdrlen, unsigned char tag,
+                            const void *value, size_t vlen)
+{
+       if (!value || vlen != 1 || *(const u8 *)value != 1)
+               return -EINVAL;
+       return 0;
+}
+
 /**
  * rsa_parse_pub_key() - decodes the BER encoded buffer and stores in the
 *                       provided struct rsa_key, pointers to the raw key as is,
@@ -184,3 +294,20 @@ int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key,
        return asn1_ber_decoder(&rsaprivkey_decoder, rsa_key, key, key_len);
 }
 EXPORT_SYMBOL_GPL(rsa_parse_priv_key);
+
+/**
+ * rsa_parse_pss_params() - decode BER-encoded RSASSA-PSS parameters
+ * @ctx:       pss padding context; absent fields are left untouched
+ * @params:    RSASSA-PSS-params in BER format
+ * @params_len:        length of @params in bytes
+ *
+ * Return:     0 on success or the error code from asn1_ber_decoder()
+ */
+int rsa_parse_pss_params(struct rsa_pss_ctx *ctx, const void *params,
+                        unsigned int params_len)
+{
+       const struct asn1_decoder *decoder = &rsapss_params_decoder;
+
+       return asn1_ber_decoder(decoder, ctx, params, params_len);
+}
+EXPORT_SYMBOL_GPL(rsa_parse_pss_params);
diff --git a/crypto/rsapss_params.asn1 b/crypto/rsapss_params.asn1
new file mode 100644
index 0000000..4d6b0ba
--- /dev/null
+++ b/crypto/rsapss_params.asn1
@@ -0,0 +1,25 @@
+-- RSASSA-PSS-params, rfc4055 section 3.1.
+-- Digest AlgorithmIdentifiers may carry absent or NULL parameters
+-- (rfc4055 section 2.1), so both encodings are accepted below.
+
+RSAPSS_Params ::= SEQUENCE {
+       hashAlgorithm           [0] HashAlgorithm OPTIONAL,
+       maskGenAlgorithm        [1] MaskGenAlgorithm OPTIONAL,
+       saltLen                 [2] INTEGER OPTIONAL ({ rsa_get_pss_saltlen }),
+       trailerField            [3] INTEGER OPTIONAL ({ rsa_get_pss_trailerfield })
+       }
+
+HashAlgorithm ::= SEQUENCE {
+       algorithm               OBJECT IDENTIFIER ({ rsa_get_pss_hash }),
+       parameters              ANY OPTIONAL
+       }
+
+MaskGenAlgorithm ::= SEQUENCE {
+       algorithm       OBJECT IDENTIFIER  ({ rsa_get_pss_mgf }),
+       hashAlgorithm   MgfHashAlgorithm
+       }
+
+MgfHashAlgorithm ::= SEQUENCE {
+       algorithm       OBJECT IDENTIFIER ({ rsa_get_pss_mgf_hash }),
+       parameters      ANY OPTIONAL
+       }
diff --git a/include/crypto/internal/rsa.h b/include/crypto/internal/rsa.h
index e870133..cfb0801 100644
--- a/include/crypto/internal/rsa.h
+++ b/include/crypto/internal/rsa.h
@@ -8,6 +8,12 @@
 #ifndef _RSA_HELPER_
 #define _RSA_HELPER_
 #include <linux/types.h>
+#include <linux/mpi.h>
+#include <linux/oid_registry.h>
+#include <crypto/sha2.h>
+
+#define RSA_MAX_DIGEST_SIZE            SHA512_DIGEST_SIZE      /* largest digest the PSS parser accepts */
+#define RSA_PSS_DEFAULT_SALT_LEN       20      /* rfc4055 default saltLength */
 
 /**
  * rsa_key - RSA key structure
@@ -47,11 +53,28 @@ struct rsa_key {
        size_t qinv_sz;
 };
 
+struct rsa_mpi_key {   /* RSA key as MPIs (moved here from rsa.c) */
+       MPI n;          /* modulus (RSAEP: c = m^e mod n) */
+       MPI e;          /* exponent used by RSAEP */
+       MPI d;          /* presumably the private exponent */
+};
+
+struct rsa_pss_ctx {   /* pss padding context; absent params leave fields as-is */
+       struct crypto_akcipher *child;  /* not filled by rsa_parse_pss_params() */
+       unsigned int key_size;          /* not filled by rsa_parse_pss_params() */
+       const char *hash_algo;          /* digest name from hashAlgorithm */
+       const char *mgf_algo;           /* "mgf1", the only MGF accepted */
+       const char *mgf_hash_algo;      /* digest name from the MGF1 params */
+       u32 salt_len;                   /* saltLen value */
+};
+
 int rsa_parse_pub_key(struct rsa_key *rsa_key, const void *key,
                      unsigned int key_len);
 
 int rsa_parse_priv_key(struct rsa_key *rsa_key, const void *key,
                       unsigned int key_len);
-
+int rsa_parse_pss_params(struct rsa_pss_ctx *ctx, const void *params,
+                        unsigned int params_len);
 extern struct crypto_template rsa_pkcs1pad_tmpl;
+extern struct crypto_template rsa_psspad_tmpl;
 #endif
-- 
1.8.3.1

Reply via email to