On Thu, 9 Oct 2025 20:25:17 +0800
Guan-Chun Wu <[email protected]> wrote:

...
> As Eric mentioned, the decoder in fs/crypto/ needs to reject invalid input.

(to avoid two different input buffers giving the same output)

Which is annoyingly reasonable.

> One possible solution I came up with is to first create a shared
> base64_rev_common lookup table as the base for all Base64 variants.
> Then, depending on the variant (e.g., BASE64_STD, BASE64_URLSAFE, etc.), we
> can dynamically adjust the character mappings for position 62 and position 63
> at runtime, based on the variant.
> 
> Here are the changes to the code:
> 
> static const s8 base64_rev_common[256] = {
>       [0 ... 255] = -1,
>       ['A'] =  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,
>               13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
>       ['a'] = 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
>               39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
>       ['0'] = 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
> };
> 
> static const struct {
>       char char62, char63;
> } base64_symbols[] = {
>       [BASE64_STD] = { '+', '/' },
>       [BASE64_URLSAFE] = { '-', '_' },
>       [BASE64_IMAP] = { '+', ',' },
> };
> 
> int base64_decode(const char *src, int srclen, u8 *dst, bool padding,
>                   enum base64_variant variant)
> {
>       u8 *bp = dst;
>       u8 pad_cnt = 0;
>       s8 input1, input2, input3, input4;
>       u32 val;
>       s8 base64_rev_tables[256];
> 
>       /* Validate the input length for padding */
>       if (unlikely(padding && (srclen & 0x03) != 0))
>               return -1;

There is no need for an early check.
Pick it up after the loop when 'srclen != 0'.

> 
>       memcpy(base64_rev_tables, base64_rev_common, sizeof(base64_rev_common));

Ugg - having a memcpy() here is not a good idea.
It really is better to have 3 arrays, but use a 'mostly common' initialiser.
Perhaps:
#define BASE64_REV_INIT(ch_62, ch_63) { \
        [0 ... 255] = -1, \
        ['A'] =  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, \
                13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, \
        ['a'] = 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, \
                39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, \
        ['0'] = 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, \
        [ch_62] = 62, [ch_63] = 63, \
}

static const s8 base64_rev_maps[][256] = {
        [BASE64_STD] = BASE64_REV_INIT('+', '/'),
        [BASE64_URLSAFE] = BASE64_REV_INIT('-', '_'),
        [BASE64_IMAP] = BASE64_REV_INIT('+', ',')
};

Then (after validating variant):
        const s8 *map = base64_rev_maps[variant];
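A userspace sketch of the three-table idea, for anyone who wants to try it outside the kernel. It relies on the GNU `[0 ... 255]` range designator (GCC and Clang accept it) and on later designators overriding earlier ones, which `-Woverride-init` may warn about. The `s8` typedef, the enum ordering and the `base64_rev_map()` helper are assumptions standing in for the kernel code:

```c
#include <stddef.h>
#include <stdint.h>

/* Userspace stand-in for the kernel type (assumption). */
typedef int8_t s8;

/* Assumed to mirror the variant enum in the patch under review. */
enum base64_variant { BASE64_STD, BASE64_URLSAFE, BASE64_IMAP };

/* All tables share letters and digits; only entries 62 and 63 differ. */
#define BASE64_REV_INIT(ch_62, ch_63) {					\
	[0 ... 255] = -1,						\
	['A'] =  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,	\
		13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,	\
	['a'] = 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,	\
		39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,	\
	['0'] = 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,			\
	[ch_62] = 62, [ch_63] = 63,					\
}

/* 3 x 256 bytes of read-only data: nothing is built on the stack. */
static const s8 base64_rev_maps[][256] = {
	[BASE64_STD]     = BASE64_REV_INIT('+', '/'),
	[BASE64_URLSAFE] = BASE64_REV_INIT('-', '_'),
	[BASE64_IMAP]    = BASE64_REV_INIT('+', ','),
};

/* Hypothetical helper: the map for a variant, or NULL if out of range. */
static const s8 *base64_rev_map(int variant)
{
	if (variant < BASE64_STD || variant > BASE64_IMAP)
		return NULL;
	return base64_rev_maps[variant];
}
```

The cost is 768 bytes of rodata instead of 256 bytes of stack plus a memcpy() on every call.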

> 
>       if (variant < BASE64_STD || variant > BASE64_IMAP)
>               return -1;
> 
>       base64_rev_tables[base64_symbols[variant].char62] = 62;
>       base64_rev_tables[base64_symbols[variant].char63] = 63;
> 
>       while (padding && srclen > 0 && src[srclen - 1] == '=') {
>               pad_cnt++;
>               srclen--;
>               if (pad_cnt > 2)
>                       return -1;
>       }

I'm not sure I'd do that there.
You are (in some sense) optimising for padding.
From what I remember, "abcd" gives 24 bits, "abc=" 16 and "ab==" 8.
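The bit accounting behind that, worked through for a final group of 4 characters containing p padding characters (p = 0, 1 or 2), where each data character supplies 6 bits and the output is whole bytes:

```latex
\begin{aligned}
\text{bits supplied}  &= 6(4 - p) \\
\text{bytes produced} &= 3 - p \\
\text{unused bits}    &= 6(4 - p) - 8(3 - p) = 2p
\end{aligned}
```

So "abc=" carries 18 bits of which 16 are used, and "ab==" carries 12 bits of which 8 are used. A strict decoder, as fs/crypto/ needs, should reject input where those 2p spare bits are non-zero; otherwise "QQ==" and "QR==" would both decode to "A".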

> 
>       while (srclen >= 4) {
>               /* Decode the next 4 characters */
>               input1 = base64_rev_tables[(u8)src[0]];
>               input2 = base64_rev_tables[(u8)src[1]];
>               input3 = base64_rev_tables[(u8)src[2]];
>               input4 = base64_rev_tables[(u8)src[3]];

I'd be tempted to make src[] unsigned - probably by assigning the parameter
to a local at the top of the function.

Also you have input3 = ... src[2]...
Perhaps they should be input[0..3] instead.
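A userspace sketch of what the loop body could look like with an unsigned source pointer and an `in[4]` array. `decode_quad()` and the typedefs are hypothetical. It ORs the lookups and tests the sign before shifting, because left-shifting a negative int is undefined in ISO C (GCC defines it as an extension, so the sign-after-shift test also works there):

```c
#include <stdint.h>

/* Userspace stand-ins for the kernel types (assumption). */
typedef uint8_t u8;
typedef int8_t s8;

/*
 * Hypothetical helper: decode one 4-character group with a reverse map
 * whose entries are -1 for invalid characters. Because 's' is unsigned,
 * the lookups need no (u8) casts. Returns 0 on success, -1 otherwise.
 */
static int decode_quad(const u8 *s, const s8 *map, u8 *out)
{
	int in[4], val, i;

	for (i = 0; i < 4; i++)
		in[i] = map[s[i]];

	/* Any -1 entry makes the OR negative; test before shifting. */
	if ((in[0] | in[1] | in[2] | in[3]) < 0)
		return -1;

	val = in[0] << 18 | in[1] << 12 | in[2] << 6 | in[3];
	out[0] = val >> 16;
	out[1] = val >> 8;
	out[2] = val;
	return 0;
}
```

With the standard alphabet, "TWFu" decodes to "Man".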

> 
>               val = (input1 << 18) |
>                     (input2 << 12) |
>                     (input3 << 6) |
>                     input4;

Four lines is excessive. C doesn't need the () because << binds more
tightly than |, and I'm not sure compilers complain about mixing << and |.
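A trivial check that dropping the parentheses leaves the packed value unchanged (both function names are made up for the demo):

```c
/* '<<' has higher precedence than '|', so these compute the same value. */
static int pack_parens(int a, int b, int c, int d)
{
	return (a << 18) | (b << 12) | (c << 6) | d;
}

static int pack_plain(int a, int b, int c, int d)
{
	return a << 18 | b << 12 | c << 6 | d;
}
```

For the 6-bit values of "TWFu" (19, 22, 5, 46) both return 0x4d616e, i.e. the bytes of "Man".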

> 
>               if (unlikely((s32)val < 0))
>                       return -1;

Make 'val' signed - then you don't need the cast.
You can pick up the padding check here, something like:
                        val = input1 << 18 | input2 << 12;
                        if (!padding || val < 0 || src[3] != '=')
                                return -1;
                        *bp++ = val >> 16;
                        if (src[2] == '=')
                                return bp - dst;
                        if (input3 < 0)
                                return -1;
                        val |= input3 << 6;
                        *bp++ = val >> 8;
                        return bp - dst;

Or, if you really want to use the code below the loop:
                        if (!padding || src[3] != '=')
                                return -1;
                        padding = 0;
                        srclen -= 1 + (src[2] == '=');
                        break;
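A userspace sketch of the first option as a standalone helper. It goes slightly beyond the snippet: it also rejects non-zero unused bits (which the fs/crypto/ requirement implies) and a padded group that is not the last one. The helper name, its parameters and the typedefs are hypothetical:

```c
#include <stdint.h>

/* Userspace stand-ins for the kernel types (assumption). */
typedef uint8_t u8;
typedef int8_t s8;

/*
 * Hypothetical helper for the in-loop padding path, called when a group
 * of 4 characters failed the normal lookup. 'srclen' counts the
 * characters left including this group. Writes 1 or 2 bytes to 'bp' and
 * returns that count, or -1 if the group is not valid final padding.
 */
static int decode_padded_quad(const u8 *s, int srclen, const s8 *map,
			      int padding, u8 *bp)
{
	int in0 = map[s[0]], in1 = map[s[1]], in2;

	/* Padding must be enabled, be in the last group, and end in '='. */
	if (!padding || srclen != 4 || s[3] != '=' || (in0 | in1) < 0)
		return -1;

	if (s[2] == '=') {			/* "xx==": 12 bits, 8 used */
		if (in1 & 0x0f)
			return -1;
		bp[0] = in0 << 2 | in1 >> 4;
		return 1;
	}

	in2 = map[s[2]];			/* "xxx=": 18 bits, 16 used */
	if (in2 < 0 || (in2 & 0x03))
		return -1;
	bp[0] = in0 << 2 | in1 >> 4;
	bp[1] = in1 << 4 | in2 >> 2;
	return 2;
}
```

With the standard alphabet, "QQ==" gives "A" and "QUI=" gives "AB", while "QR==" and "QUJ=" are rejected for their non-zero spare bits.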


> 
>               *bp++ = (u8)(val >> 16);
>               *bp++ = (u8)(val >> 8);
>               *bp++ = (u8)val;

You don't need those casts.
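Conversion to an unsigned 8-bit type keeps the value modulo 256 (the low 8 bits), so plain assignment already truncates. A minimal illustration, where `split24()` is a made-up name:

```c
#include <stdint.h>

/* Assignment to uint8_t keeps the low 8 bits; no explicit casts needed. */
static void split24(int32_t val, uint8_t out[3])
{
	out[0] = val >> 16;
	out[1] = val >> 8;
	out[2] = val;
}
```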

> 
>               src += 4;
>               srclen -= 4;
>       }
> 
>       /* Handle leftover characters when padding is not used */

You also get here with padding (the '=' were stripped before the loop).
I'm not sure what should happen without padding.
For a multi-line file decode I suspect the characters need adding to
the start of the next line (ie lines aren't required to contain
multiples of 4 characters - even though they almost always will).

>       if (srclen > 0) {
>               switch (srclen) {

You don't need an 'if' and a 'switch'.
srclen is likely to be zero, but perhaps write as:
        if (likely(!srclen))
                return bp - dst;
        if (padding || srclen == 1)
                return -1;

        val = base64_rev_tables[(u8)src[0]] << 12 |
              base64_rev_tables[(u8)src[1]] << 6;
        *bp++ = val >> 10;
        if (srclen == 2) {
                if (val & 0x800003ff)
                        return -1;
        } else {
                val |= base64_rev_tables[(u8)src[2]];
                if (val & 0x80000003)
                        return -1;
                *bp++ = val >> 2;
        }
        return bp - dst;
}
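A userspace sketch that pulls the suggestions in this reply together: per-variant static maps, an unsigned source pointer, a signed 'val', padding handled inside the loop (the second option above), the length check picked up after the loop, and strict rejection of non-zero unused bits. The typedefs, the enum ordering and the `int padding` flag are assumptions standing in for the kernel definitions. As in the earlier sketches, it tests the sign of the ORed lookups before shifting:

```c
#include <stdint.h>

/* Userspace stand-ins for the kernel types (assumption). */
typedef uint8_t u8;
typedef int8_t s8;

/* Assumed to mirror the variant enum in the patch under review. */
enum base64_variant { BASE64_STD, BASE64_URLSAFE, BASE64_IMAP };

#define BASE64_REV_INIT(ch_62, ch_63) {					\
	[0 ... 255] = -1,						\
	['A'] =  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,	\
		13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,	\
	['a'] = 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,	\
		39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,	\
	['0'] = 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,			\
	[ch_62] = 62, [ch_63] = 63,					\
}

static const s8 base64_rev_maps[][256] = {
	[BASE64_STD]     = BASE64_REV_INIT('+', '/'),
	[BASE64_URLSAFE] = BASE64_REV_INIT('-', '_'),
	[BASE64_IMAP]    = BASE64_REV_INIT('+', ','),
};

/* Returns the number of bytes written to dst, or -1 on invalid input. */
int base64_decode(const char *src, int srclen, u8 *dst, int padding,
		  enum base64_variant variant)
{
	const u8 *s = (const u8 *)src;	/* unsigned: no casts on lookups */
	u8 *bp = dst;
	const s8 *map;
	int in[4], val;

	if ((unsigned int)variant > BASE64_IMAP)
		return -1;
	map = base64_rev_maps[variant];

	for (; srclen >= 4; s += 4, srclen -= 4) {
		in[0] = map[s[0]];
		in[1] = map[s[1]];
		in[2] = map[s[2]];
		in[3] = map[s[3]];
		if ((in[0] | in[1] | in[2] | in[3]) >= 0) {
			val = in[0] << 18 | in[1] << 12 | in[2] << 6 | in[3];
			*bp++ = val >> 16;
			*bp++ = val >> 8;
			*bp++ = val;
			continue;
		}
		/* Only a final "xx==" or "xxx=" group may fail the lookup. */
		if (!padding || srclen != 4 || s[3] != '=')
			return -1;
		padding = 0;
		srclen -= 1 + (s[2] == '=');	/* leaves 2 or 3 characters */
		break;
	}

	if (srclen == 0)
		return bp - dst;
	if (padding || srclen < 2)	/* also rejects a negative srclen */
		return -1;

	/* 2 or 3 characters left: the unused low bits must be zero. */
	in[0] = map[s[0]];
	in[1] = map[s[1]];
	in[2] = srclen == 3 ? map[s[2]] : 0;
	if ((in[0] | in[1] | in[2]) < 0)
		return -1;
	if (srclen == 2 ? (in[1] & 0x0f) : (in[2] & 0x03))
		return -1;
	*bp++ = in[0] << 2 | in[1] >> 4;
	if (srclen == 3)
		*bp++ = in[1] << 4 | in[2] >> 2;
	return bp - dst;
}
```

For example, `base64_decode("QUI=", 4, buf, 1, BASE64_STD)` returns 2 with "AB" in `buf`, while "QUJ=" and a padded group followed by more data are both rejected.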

        David

>               case 2:
>                       input1 = base64_rev_tables[(u8)src[0]];
>                       input2 = base64_rev_tables[(u8)src[1]];
>                       val = (input1 << 6) | input2; /* 12 bits */
>                       if (unlikely((s32)val < 0 || val & 0x0F))
>                               return -1;
> 
>                       *bp++ = (u8)(val >> 4);
>                       break;
>               case 3:
>                       input1 = base64_rev_tables[(u8)src[0]];
>                       input2 = base64_rev_tables[(u8)src[1]];
>                       input3 = base64_rev_tables[(u8)src[2]];
> 
>                       val = (input1 << 12) |
>                             (input2 << 6) |
>                             input3; /* 18 bits */
>                       if (unlikely((s32)val < 0 || val & 0x03))
>                               return -1;
> 
>                       *bp++ = (u8)(val >> 10);
>                       *bp++ = (u8)(val >> 2);
>                       break;
>               default:
>                       return -1;
>               }
>       }
> 
>       return bp - dst;
> }
> Based on KUnit testing, the performance results are as follows:
>       base64_performance_tests: [64B] decode run : 40ns
>       base64_performance_tests: [1KB] decode run : 463ns
> 
> However, this approach introduces an issue. It uses 256 bytes of memory
> on the stack for base64_rev_tables, which might not be ideal. Does anyone
> have any thoughts or alternative suggestions to solve this issue, or is it
> not really a concern?
> 
> Best regards,
> Guan-Chun
> 
> > > 
> > > Best,
> > > Caleb  
> >   

