Add the missing support for ciphertext stealing to the implementation
of AES-XTS; it is part of the XTS specification but was omitted until
now because there was no need for it.

The asm helpers are updated so they can deal with any input size, as
long as the last full block and the final partial block are presented
at the same time. The glue code is updated so that the common case of
operating on a sector or page is handled mostly as before. When CTS is
needed, the walk is split into two pieces, unless the entire input is
covered by a single step.
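
For reference, the ciphertext stealing step for the final chunk (the
last full block plus the partial block, presented to the asm in one go)
boils down to the swap below. This is an illustrative C sketch only, not
part of the patch; xts_encrypt_block() is a hypothetical helper that
performs a single tweaked AES encryption in place.

  #include <stdint.h>
  #include <string.h>

  #define AES_BLOCK_SIZE 16

  /* hypothetical: encrypt one block in place under the given tweak */
  void xts_encrypt_block(uint8_t block[AES_BLOCK_SIZE],
                         const uint8_t tweak[AES_BLOCK_SIZE]);

  /* encrypt the trailing AES_BLOCK_SIZE + tail bytes, 0 < tail < 16 */
  static void xts_cts_encrypt_final(uint8_t *dst, const uint8_t *src,
                                    int tail,
                                    const uint8_t tweak0[AES_BLOCK_SIZE],
                                    const uint8_t tweak1[AES_BLOCK_SIZE])
  {
          uint8_t block[AES_BLOCK_SIZE];

          /* encrypt the last full plaintext block with its own tweak */
          memcpy(block, src, AES_BLOCK_SIZE);
          xts_encrypt_block(block, tweak0);

          /* its leading bytes become the final, partial ciphertext */
          memcpy(dst + AES_BLOCK_SIZE, block, tail);

          /* steal its trailing bytes to pad the partial plaintext */
          memcpy(block, src + AES_BLOCK_SIZE, tail);

          /*
           * encrypt the padded block with the next tweak and emit it as
           * the last full ciphertext block, i.e. the two blocks swap
           * positions in the output
           */
          xts_encrypt_block(block, tweak1);
          memcpy(dst, block, AES_BLOCK_SIZE);
  }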

Signed-off-by: Ard Biesheuvel <ard.biesheu...@linaro.org>
---
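For a concrete picture of how the glue code splits a request that needs
CTS, here is a worked example of the sizes involved (the 100-byte length
is just an example; the expressions match the code below):

  /* example: cryptlen = 100 bytes = 6 full blocks + a 4 byte tail */
  int tail       = 100 % AES_BLOCK_SIZE;                      /*  4 */
  int xts_blocks = DIV_ROUND_UP(100, AES_BLOCK_SIZE) - 2;     /*  5 */

  /* first walk: full blocks only, keeping the last two back   */
  int bulk  = xts_blocks * AES_BLOCK_SIZE;                     /* 80 */

  /* second walk: last full block plus the partial block (CTS) */
  int final = AES_BLOCK_SIZE + tail;                           /* 20 */
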
 arch/arm64/crypto/aes-glue.c  | 126 ++++++++++++++++++--
 arch/arm64/crypto/aes-modes.S |  99 ++++++++++++---
 2 files changed, 195 insertions(+), 30 deletions(-)

diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index 5ee980c5a5c2..eecb74fd2f61 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -90,10 +90,10 @@ asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 ctr[]);
 
 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
-                               int rounds, int blocks, u32 const rk2[], u8 iv[],
+                               int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
-                               int rounds, int blocks, u32 const rk2[], u8 iv[],
+                               int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
 
 asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
@@ -529,21 +529,71 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
+       int tail = req->cryptlen % AES_BLOCK_SIZE;
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct scatterlist *src, *dst;
        struct skcipher_walk walk;
-       unsigned int blocks;
+
+       if (req->cryptlen < AES_BLOCK_SIZE)
+               return -EINVAL;
 
        err = skcipher_walk_virt(&walk, req, false);
 
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+       if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
+               int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+                                             AES_BLOCK_SIZE) - 2;
+
+               skcipher_walk_abort(&walk);
+
+               skcipher_request_set_tfm(&subreq, tfm);
+               skcipher_request_set_callback(&subreq,
+                                             skcipher_request_flags(req),
+                                             NULL, NULL);
+               skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                          xts_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+               req = &subreq;
+               err = skcipher_walk_virt(&walk, req, false);
+       } else {
+               tail = 0;
+       }
+
+       for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+               int nbytes = walk.nbytes;
+
+               if (walk.nbytes < walk.total)
+                       nbytes &= ~(AES_BLOCK_SIZE - 1);
+
                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key1.key_enc, rounds, blocks,
+                               ctx->key1.key_enc, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }
 
-       return err;
+       if (err || likely(!tail))
+               return err;
+
+       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+       if (req->dst != req->src)
+               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, &subreq, false);
+       if (err)
+               return err;
+
+       kernel_neon_begin();
+       aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                       ctx->key1.key_enc, rounds, walk.nbytes,
+                       ctx->key2.key_enc, walk.iv, first);
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
 }
 
 static int __maybe_unused xts_decrypt(struct skcipher_request *req)
@@ -551,21 +601,72 @@ static int __maybe_unused xts_decrypt(struct skcipher_request *req)
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int err, first, rounds = 6 + ctx->key1.key_length / 4;
+       int tail = req->cryptlen % AES_BLOCK_SIZE;
+       struct scatterlist sg_src[2], sg_dst[2];
+       struct skcipher_request subreq;
+       struct scatterlist *src, *dst;
        struct skcipher_walk walk;
-       unsigned int blocks;
+
+       if (req->cryptlen < AES_BLOCK_SIZE)
+               return -EINVAL;
 
        err = skcipher_walk_virt(&walk, req, false);
 
-       for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+       if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
+               int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+                                             AES_BLOCK_SIZE) - 2;
+
+               skcipher_walk_abort(&walk);
+
+               skcipher_request_set_tfm(&subreq, tfm);
+               skcipher_request_set_callback(&subreq,
+                                             skcipher_request_flags(req),
+                                             NULL, NULL);
+               skcipher_request_set_crypt(&subreq, req->src, req->dst,
+                                          xts_blocks * AES_BLOCK_SIZE,
+                                          req->iv);
+               req = &subreq;
+               err = skcipher_walk_virt(&walk, req, false);
+       } else {
+               tail = 0;
+       }
+
+       for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+               int nbytes = walk.nbytes;
+
+               if (walk.nbytes < walk.total)
+                       nbytes &= ~(AES_BLOCK_SIZE - 1);
+
                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               ctx->key1.key_dec, rounds, blocks,
+                               ctx->key1.key_dec, rounds, nbytes,
                                ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
-               err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }
 
-       return err;
+       if (err || likely(!tail))
+               return err;
+
+       dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+       if (req->dst != req->src)
+               dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+       skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+                                  req->iv);
+
+       err = skcipher_walk_virt(&walk, &subreq, false);
+       if (err)
+               return err;
+
+
+       kernel_neon_begin();
+       aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+                       ctx->key1.key_dec, rounds, walk.nbytes,
+                       ctx->key2.key_enc, walk.iv, first);
+       kernel_neon_end();
+
+       return skcipher_walk_done(&walk, 0);
 }
 
 static struct skcipher_alg aes_algs[] = { {
@@ -646,6 +747,7 @@ static struct skcipher_alg aes_algs[] = { {
        .min_keysize    = 2 * AES_MIN_KEY_SIZE,
        .max_keysize    = 2 * AES_MAX_KEY_SIZE,
        .ivsize         = AES_BLOCK_SIZE,
+       .walksize       = 2 * AES_BLOCK_SIZE,
        .setkey         = xts_set_key,
        .encrypt        = xts_encrypt,
        .decrypt        = xts_decrypt,
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index 38cd5a2091a8..f2c2ba739f36 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -413,10 +413,10 @@ AES_ENDPROC(aes_ctr_encrypt)
 
 
        /*
+        * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
+        *                 int bytes, u8 const rk2[], u8 iv[], int first)
         * aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
-        *                 int blocks, u8 const rk2[], u8 iv[], int first)
-        * aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
-        *                 int blocks, u8 const rk2[], u8 iv[], int first)
+        *                 int bytes, u8 const rk2[], u8 iv[], int first)
         */
 
        .macro          next_tweak, out, in, tmp
@@ -451,7 +451,7 @@ AES_ENTRY(aes_xts_encrypt)
 .LxtsencloopNx:
        next_tweak      v4, v4, v8
 .LxtsencNx:
-       subs            w4, w4, #4
+       subs            w4, w4, #64
        bmi             .Lxtsenc1x
        ld1             {v0.16b-v3.16b}, [x1], #64      /* get 4 pt blocks */
        next_tweak      v5, v4, v8
@@ -468,33 +468,66 @@ AES_ENTRY(aes_xts_encrypt)
        eor             v2.16b, v2.16b, v6.16b
        st1             {v0.16b-v3.16b}, [x0], #64
        mov             v4.16b, v7.16b
-       cbz             w4, .Lxtsencout
+       cbz             w4, .Lxtsencret
        xts_reload_mask v8
        b               .LxtsencloopNx
 .Lxtsenc1x:
-       adds            w4, w4, #4
+       adds            w4, w4, #64
        beq             .Lxtsencout
+       subs            w4, w4, #16
+       bmi             .LxtsencctsNx
 .Lxtsencloop:
-       ld1             {v1.16b}, [x1], #16
-       eor             v0.16b, v1.16b, v4.16b
+       ld1             {v0.16b}, [x1], #16
+.Lxtsencctsout:
+       eor             v0.16b, v0.16b, v4.16b
        encrypt_block   v0, w3, x2, x8, w7
        eor             v0.16b, v0.16b, v4.16b
-       st1             {v0.16b}, [x0], #16
-       subs            w4, w4, #1
-       beq             .Lxtsencout
+       cbz             w4, .Lxtsencout
+       subs            w4, w4, #16
        next_tweak      v4, v4, v8
+       bmi             .Lxtsenccts
+       st1             {v0.16b}, [x0], #16
        b               .Lxtsencloop
 .Lxtsencout:
+       st1             {v0.16b}, [x0]
+.Lxtsencret:
        st1             {v4.16b}, [x6]
        ldp             x29, x30, [sp], #16
        ret
-AES_ENDPROC(aes_xts_encrypt)
 
+.LxtsencctsNx:
+       mov             v0.16b, v3.16b
+       sub             x0, x0, #16
+.Lxtsenccts:
+       adr_l           x8, .Lcts_permute_table
+
+       add             x1, x1, w4, sxtw        /* rewind input pointer */
+       add             w4, w4, #16             /* # bytes in final block */
+       add             x9, x8, #32
+       add             x8, x8, x4
+       sub             x9, x9, x4
+       add             x4, x0, x4              /* output address of final block */
+
+       ld1             {v1.16b}, [x1]          /* load final block */
+       ld1             {v2.16b}, [x8]
+       ld1             {v3.16b}, [x9]
+
+       tbl             v2.16b, {v0.16b}, v2.16b
+       tbx             v0.16b, {v1.16b}, v3.16b
+       st1             {v2.16b}, [x4]                  /* overlapping stores */
+       mov             w4, wzr
+       b               .Lxtsencctsout
+AES_ENDPROC(aes_xts_encrypt)
 
 AES_ENTRY(aes_xts_decrypt)
        stp             x29, x30, [sp, #-16]!
        mov             x29, sp
 
+       /* subtract 16 bytes if we are doing CTS */
+       sub             w8, w4, #0x10
+       tst             w4, #0xf
+       csel            w4, w4, w8, eq
+
        ld1             {v4.16b}, [x6]
        xts_load_mask   v8
        cbz             w7, .Lxtsdecnotfirst
@@ -509,7 +542,7 @@ AES_ENTRY(aes_xts_decrypt)
 .LxtsdecloopNx:
        next_tweak      v4, v4, v8
 .LxtsdecNx:
-       subs            w4, w4, #4
+       subs            w4, w4, #64
        bmi             .Lxtsdec1x
        ld1             {v0.16b-v3.16b}, [x1], #64      /* get 4 ct blocks */
        next_tweak      v5, v4, v8
@@ -530,22 +563,52 @@ AES_ENTRY(aes_xts_decrypt)
        xts_reload_mask v8
        b               .LxtsdecloopNx
 .Lxtsdec1x:
-       adds            w4, w4, #4
+       adds            w4, w4, #64
        beq             .Lxtsdecout
+       subs            w4, w4, #16
 .Lxtsdecloop:
-       ld1             {v1.16b}, [x1], #16
-       eor             v0.16b, v1.16b, v4.16b
+       ld1             {v0.16b}, [x1], #16
+       bmi             .Lxtsdeccts
+.Lxtsdecctsout:
+       eor             v0.16b, v0.16b, v4.16b
        decrypt_block   v0, w3, x2, x8, w7
        eor             v0.16b, v0.16b, v4.16b
        st1             {v0.16b}, [x0], #16
-       subs            w4, w4, #1
-       beq             .Lxtsdecout
+       cbz             w4, .Lxtsdecout
+       subs            w4, w4, #16
        next_tweak      v4, v4, v8
        b               .Lxtsdecloop
 .Lxtsdecout:
        st1             {v4.16b}, [x6]
        ldp             x29, x30, [sp], #16
        ret
+
+.Lxtsdeccts:
+       adr_l           x8, .Lcts_permute_table
+
+       add             x1, x1, w4, sxtw        /* rewind input pointer */
+       add             w4, w4, #16             /* # bytes in final block */
+       add             x9, x8, #32
+       add             x8, x8, x4
+       sub             x9, x9, x4
+       add             x4, x0, x4              /* output address of final block */
+
+       next_tweak      v5, v4, v8
+
+       ld1             {v1.16b}, [x1]          /* load final block */
+       ld1             {v2.16b}, [x8]
+       ld1             {v3.16b}, [x9]
+
+       eor             v0.16b, v0.16b, v5.16b
+       decrypt_block   v0, w3, x2, x8, w7
+       eor             v0.16b, v0.16b, v5.16b
+
+       tbl             v2.16b, {v0.16b}, v2.16b
+       tbx             v0.16b, {v1.16b}, v3.16b
+
+       st1             {v2.16b}, [x4]                  /* overlapping stores */
+       mov             w4, wzr
+       b               .Lxtsdecctsout
 AES_ENDPROC(aes_xts_decrypt)
 
        /*
-- 
2.17.1
