Hi Raffaello,

I'm still not sure it needed this much of a re-write to fix the bug;
it will take some time to look at the changes.

Regardless, OpenJDK conventions call for following the style of the existing code.
Your new comments follow neither the existing convention of "//..." comments
nor the other prevalent comment form, /*... */, which uses consistent indentation
and "* " on continuation lines.
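
For illustration, the two comment forms described above might look as follows. This is only a sketch: the class CommentStyleDemo and its members are invented for this example and are not part of Base64.java.

```java
class CommentStyleDemo {
    private int bits;               // 24-bit buffer, trailing "//" comment form

    /*
     * Block comment form: consistent indentation,
     * with "* " at the start of each continuation line.
     */
    int lowByte() {
        return bits & 0xff;
    }

    void fill(int value) {
        bits = value & 0xffffff;    // keep only the low 24 bits
    }
}
```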

Regards, Roger


On 7/3/20 11:48 AM, Raffaello Giulietti wrote:
Hello,

after Roger's notes, which escaped my attention before, I've withdrawn all the changes but for DecInputStream, *except* that I couldn't resist simplifying the maths in encodedOutLength(), while still using xxxExact() arithmetic.
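
One way to gain confidence in the simplified maths is to compare the removed and the added expressions for the unpadded case over a range of inputs. The sketch below copies both expressions from the patch; the class and method names (EncodedLengthCheck, oldUnpadded, newUnpadded, agreeUpTo) are hypothetical and exist only for this check.

```java
class EncodedLengthCheck {
    // Unpadded length as computed by the lines the patch removes.
    static int oldUnpadded(int srclen) {
        int n = srclen % 3;
        return Math.addExact(Math.multiplyExact(4, srclen / 3), n == 0 ? 0 : n + 1);
    }

    // Unpadded length as computed by the lines the patch adds.
    static int newUnpadded(int srclen) {
        return Math.addExact(Math.addExact(srclen, 2) / 3, srclen);
    }

    // Returns true if both formulas agree for every srclen in [0, max].
    static boolean agreeUpTo(int max) {
        for (int srclen = 0; srclen <= max; srclen++) {
            if (oldUnpadded(srclen) != newUnpadded(srclen)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(agreeUpTo(100_000));
    }
}
```

Writing srclen = 3q + r with r in {0, 1, 2}, both expressions reduce to 4q, 4q + 2 and 4q + 3 respectively, which is why the check passes.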

Sorry for the confusion
Raffaello



Hi Raffaello,

There is way more code changed here than is needed to fix the bug.
General enhancements should be separated from bug fixes.
That makes it easier to review that the bug was fixed
and easier to separately review the other code to see that there are no unexpected changes.

If some of the changes are motivated by expected performance improvements,
there should be JMH tests comparing the before and after.
The change to use byte arrays seems useful, but even using char[]
there is little danger of cache thrashing.

If the code using xxxExact was correct, don't change it.
Those methods are intrinsified and perform as well or better than using long.
Usually, it is better to leave code alone and not risk breaking it.
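
As a rough illustration of the two styles being compared, the sketch below computes the padded length once with the xxxExact methods, which throw ArithmeticException on int overflow, and once with long arithmetic plus an explicit range check. The class and method names are hypothetical, and the long variant is only an assumption about what such code could look like, not the code from the earlier patch.

```java
class ExactArithmeticDemo {
    // Padded Base64 length with exact int arithmetic;
    // Math.addExact/multiplyExact throw ArithmeticException on overflow.
    static int paddedLengthExact(int srclen) {
        return Math.multiplyExact(4, Math.addExact(srclen, 2) / 3);
    }

    // The same computation widened to long, with a manual overflow check.
    static int paddedLengthLong(int srclen) {
        long len = 4L * ((srclen + 2L) / 3);
        if (len > Integer.MAX_VALUE) {
            throw new ArithmeticException("integer overflow");
        }
        return (int) len;
    }

    // Returns true if the exact variant reports an overflow for srclen.
    static boolean overflows(int srclen) {
        try {
            paddedLengthExact(srclen);
            return false;
        } catch (ArithmeticException e) {
            return true;
        }
    }
}
```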

Special care needs to be taken when changing a method that is intrinsified.
The optimized version may use fields of the object and stop
working if they are changed.

In the test, the range of buffer sizes tested seems to waste a lot
of cycles on sizes greater than the encoded size of the input.
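
One possible reading of this remark, sketched below, is to bound the sweep by the decoded length rather than the encoded length: a read can never return more than orig.length bytes, so sizes beyond orig.length + 1 should exercise no new paths. The class BufferSizeSweep and its methods are hypothetical helpers, not part of TestBase64.java, and the round trip is only expected to succeed for every size on a JDK that contains the fix for JDK-8222187.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;

class BufferSizeSweep {
    // Decodes 'encoded' through Decoder.wrap(), reading with the given buffer size.
    static byte[] decodeWith(byte[] encoded, int bufferSize) {
        try (InputStream in = Base64.getDecoder().wrap(
                new ByteArrayInputStream(encoded))) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buffer = new byte[bufferSize];
            int read;
            while ((read = in.read(buffer, 0, bufferSize)) >= 0) {
                out.write(buffer, 0, read);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Returns the buffer sizes in [1, orig.length + 1] for which
    // stream decoding does not reproduce the original bytes.
    static List<Integer> failingSizes(byte[] orig) {
        byte[] encoded = Base64.getEncoder().encode(orig);
        List<Integer> failing = new ArrayList<>();
        for (int bufferSize = 1; bufferSize <= orig.length + 1; bufferSize++) {
            if (!Arrays.equals(orig, decodeWith(encoded, bufferSize))) {
                failing.add(bufferSize);
            }
        }
        return failing;
    }
}
```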

Regards, Roger


---------------------

# HG changeset patch
# User lello
# Date 1593790152 -7200
#      Fri Jul 03 17:29:12 2020 +0200
# Node ID 73370832de1d2bf3b450930f5cb2467e10528c69
# Parent  a7c0307232406c7b0c1a4b8c2de111077848203d
8222187: java.util.Base64.Decoder stream adds unexpected null bytes at the end
Reviewed-by: TBD
Contributed-by: Raffaello Giulietti <raffaello.giulie...@gmail.com>

diff --git a/src/java.base/share/classes/java/util/Base64.java b/src/java.base/share/classes/java/util/Base64.java
--- a/src/java.base/share/classes/java/util/Base64.java
+++ b/src/java.base/share/classes/java/util/Base64.java
@@ -255,14 +255,11 @@
          *
          */
         private final int encodedOutLength(int srclen, boolean throwOOME) {
-            int len = 0;
+            int len;
             try {
-                if (doPadding) {
-                    len = Math.multiplyExact(4, (Math.addExact(srclen, 2) / 3));
-                } else {
-                    int n = srclen % 3;
-                    len = Math.addExact(Math.multiplyExact(4, (srclen / 3)), (n == 0 ? 0 : n + 1));
-                }
+                len = doPadding
+                        ? Math.multiplyExact(4, (Math.addExact(srclen, 2) / 3))
+                        : Math.addExact((Math.addExact(srclen, 2) / 3), srclen);
                 if (linemax > 0) { // line separators
                     len = Math.addExact(len, (len - 1) / linemax * newline.length);
                 }
@@ -961,14 +958,15 @@

         private final InputStream is;
         private final boolean isMIME;
-        private final int[] base64;      // base64 -> byte mapping
-        private int bits = 0;            // 24-bit buffer for decoding
-        private int nextin = 18;         // next available "off" in "bits" for input;
-                                         // -> 18, 12, 6, 0
-        private int nextout = -8;        // next available "off" in "bits" for output;
-                                         // -> 8, 0, -8 (no byte for output)
-        private boolean eof = false;
-        private boolean closed = false;
+        private final int[] base64;     // mapping from alphabet to values
+        private int bits;               // 24 bit buffer for decoding
+        private int wpos;               // writing bit pos inside bits
+        // one of 24 (left, msb), 18, 12, 6, 0
+        private int rpos;               // reading bit pos inside bits
+        // one of 24 (left, msb), 16, 8, 0
+        private boolean eos;
+        private boolean closed;
+        private byte[] onebuf = new byte[1];

         DecInputStream(InputStream is, int[] base64, boolean isMIME) {
             this.is = is;
@@ -976,114 +974,158 @@
             this.isMIME = isMIME;
         }

-        private byte[] sbBuf = new byte[1];
-
         @Override
         public int read() throws IOException {
-            return read(sbBuf, 0, 1) == -1 ? -1 : sbBuf[0] & 0xff;
+            return read(onebuf, 0, 1) >= 0 ? onebuf[0] & 0xff : -1;
+        }
+
+        private int leftovers(byte[] b, int off, int pos, int limit) {
+            eos = true;
+            /*
+            We use a loop here, as this code is executed only once.
+            Unrolling the loop would probably not contribute much here.
+             */
+            while (rpos - 8 >= wpos && pos != limit) {
+                b[pos++] = (byte) (bits >> (rpos -= 8));
+            }
+            return pos - off != 0 || rpos - 8 >= wpos ? pos - off : -1;
         }

-        private int eof(byte[] b, int off, int len, int oldOff)
-            throws IOException
-        {
-            eof = true;
-            if (nextin != 18) {
-                if (nextin == 12)
-                    throw new IOException("Base64 stream has one un-decoded dangling byte.");
-                // treat ending xx/xxx without padding character legal.
-                // same logic as v == '=' below
-                b[off++] = (byte)(bits >> (16));
-                if (nextin == 0) {           // only one padding byte
-                    if (len == 1) {          // no enough output space
-                        bits >>= 8;          // shift to lowest byte
-                        nextout = 0;
-                    } else {
-                        b[off++] = (byte) (bits >>  8);
-                    }
-                }
+        private int eos(byte[] b, int off, int pos, int limit) throws IOException {
+            /*
+            wpos == 18: x     dangling single x, invalid unit
+            accept ending xx or xxx without padding characters
+             */
+            if (wpos == 18) {
+                throw new IOException("Base64 stream has one un-decoded dangling byte");
             }
-            return off == oldOff ? -1 : off - oldOff;
+            rpos = 24;
+            return leftovers(b, off, pos, limit);
         }

-        private int padding(byte[] b, int off, int len, int oldOff)
-            throws IOException
-        {
-            // =     shiftto==18 unnecessary padding
-            // x=    shiftto==12 dangling x, invalid unit
-            // xx=   shiftto==6 && missing last '='
-            // xx=y  or last is not '='
-            if (nextin == 18 || nextin == 12 ||
-                nextin == 6 && is.read() != '=') {
-                throw new IOException("Illegal base64 ending sequence:" + nextin);
+        private int padding(byte[] b, int off, int pos, int limit) throws IOException {
+            /*
+            wpos == 24: =    (unnecessary padding)
+            wpos == 18: x=   (dangling single x, invalid unit)
+            wpos == 12 and missing last '=': xx=  (invalid padding)
+            wpos == 12 and last is not '=': xx=x (invalid padding)
+             */
+            if (wpos >= 18 || wpos == 12 && is.read() != '=') {
+                throw new IOException("Illegal base64 ending sequence");
             }
-            b[off++] = (byte)(bits >> (16));
-            if (nextin == 0) {           // only one padding byte
-                if (len == 1) {          // no enough output space
-                    bits >>= 8;          // shift to lowest byte
-                    nextout = 0;
-                } else {
-                    b[off++] = (byte) (bits >>  8);
-                }
-            }
-            eof = true;
-            return off - oldOff;
+            rpos = 24;
+            return leftovers(b, off, pos, limit);
         }

         @Override
         public int read(byte[] b, int off, int len) throws IOException {
-            if (closed)
+            if (closed) {
                 throw new IOException("Stream is closed");
-            if (eof && nextout < 0)    // eof and no leftover
-                return -1;
-            if (off < 0 || len < 0 || len > b.length - off)
-                throw new IndexOutOfBoundsException();
-            int oldOff = off;
-            while (nextout >= 0) {       // leftover output byte(s) in bits buf
-                if (len == 0)
-                    return off - oldOff;
-                b[off++] = (byte)(bits >> nextout);
-                len--;
-                nextout -= 8;
+            }
+            Objects.checkFromIndexSize(off, len, b.length);
+            if (len == 0) {
+                return 0;
             }
-            bits = 0;
-            while (len > 0) {
-                int v = is.read();
-                if (v == -1) {
-                    return eof(b, off, len, oldOff);
+
+            /*
+            Rather than keeping 2 running vars (e.g., off and len), we only
+            keep one (pos), while definitely fixing the boundaries of the range
+            [off, limit).
+
+            Note that limit can overflow to Integer.MIN_VALUE. However, as long
+            as comparisons with pos are done as coded, there's no harm.
+
+            In addition, limit - off (= len) is used from here on, so the
+            location for len can be reallocated to other vars (e.g., limit).
+             */
+            int pos = off;
+            int limit = off + len;
+            if (eos) {
+                return leftovers(b, off, pos, limit);
+            }
+
+            /*
+            leftovers from previous invocation; here, wpos = 0.
+            There can be at most 2 leftover bytes (rpos <= 16).
+            Further, the buffer b has at least one free place.
+
+            The logic could be coded as a loop, (as in method leftovers())
+            but the explicit "unrolling" makes it possible to generate better
+            byte extraction code.
+             */
+            if (rpos == 16) {
+                b[pos++] = (byte) (bits >> 8);
+                rpos = 8;
+                if (pos == limit) {
+                    return limit - off;
                 }
-                if ((v = base64[v]) < 0) {
-                    if (v == -2) {       // padding byte(s)
-                        return padding(b, off, len, oldOff);
-                    }
-                    if (v == -1) {
-                        if (!isMIME)
-                            throw new IOException("Illegal base64 character " +
-                                Integer.toString(v, 16));
-                        continue;        // skip if for rfc2045
-                    }
-                    // neve be here
-                }
-                bits |= (v << nextin);
-                if (nextin == 0) {
-                    nextin = 18;         // clear for next in
-                    b[off++] = (byte)(bits >> 16);
-                    if (len == 1) {
-                        nextout = 8;    // 2 bytes left in bits
-                        break;
-                    }
-                    b[off++] = (byte)(bits >> 8);
-                    if (len == 2) {
-                        nextout = 0;    // 1 byte left in bits
-                        break;
-                    }
-                    b[off++] = (byte)bits;
-                    len -= 3;
-                    bits = 0;
-                } else {
-                    nextin -= 6;
+            }
+            if (rpos == 8) {
+                b[pos++] = (byte) bits;
+                rpos = 0;
+                if (pos == limit) {
+                    return limit - off;
                 }
             }
-            return off - oldOff;
+
+            bits = 0;
+            wpos = 24;
+            for (;;) {
+                // Here, pos != limit
+                int i = is.read();
+                if (i < 0) {
+                    return eos(b, off, pos, limit);
+                }
+                int v = base64[i];
+                if (v < 0) {
+                    /*
+                    i not in alphabet, thus
+                        v = -2: i is '=', the padding
+                        v = -1: i is something else, e.g., CR or LF
+                     */
+                    if (v == -1) {
+                        if (isMIME) {
+                            continue;
+                        }
+                        throw new IOException("Illegal base64 byte 0x" +
+                                Integer.toHexString(i));
+                    }
+                    return padding(b, off, pos, limit);
+                }
+                bits |= (v << (wpos -= 6));
+                if (wpos != 0) {
+                    continue;
+                }
+                if (limit - pos >= 3) {
+                    // frequently taken fast path, no need to track rpos
+                    b[pos++] = (byte) (bits >> 16);
+                    b[pos++] = (byte) (bits >> 8);
+                    b[pos++] = (byte) bits;
+                    bits = 0;
+                    wpos = 24;
+                    if (pos == limit) {
+                        return limit - off;
+                    }
+                    continue;
+                }
+                /*
+                Here, buffer b has either 1 or 2 free places, that is,
+                limit - pos = 1 or limit - pos = 2.
+
+                As above, this could be coded as a loop. But since the
+                shift lengths are explicit multiples of 8, better code can be
+                probably generated.
+                 */
+                b[pos++] = (byte) (bits >> 16);
+                if (pos == limit) {
+                    rpos = 16;
+                    return limit - off;
+                }
+                b[pos++] = (byte) (bits >> 8);
+                // Here, pos = limit, no need for an if.
+                rpos = 8;
+                return limit - off;
+            }
         }

         @Override
diff --git a/test/jdk/java/util/Base64/TestBase64.java b/test/jdk/java/util/Base64/TestBase64.java
--- a/test/jdk/java/util/Base64/TestBase64.java
+++ b/test/jdk/java/util/Base64/TestBase64.java
@@ -144,6 +144,10 @@
         testDecoderKeepsAbstinence(Base64.getDecoder());
         testDecoderKeepsAbstinence(Base64.getUrlDecoder());
         testDecoderKeepsAbstinence(Base64.getMimeDecoder());
+
+        // tests patch addressing JDK-8222187
+        // https://bugs.openjdk.java.net/browse/JDK-8222187
+        testJDK_8222187();
     }

     private static void test(Base64.Encoder enc, Base64.Decoder dec,
@@ -607,4 +611,26 @@
             }
         }
     }
+
+    private static void testJDK_8222187() throws Throwable {
+        byte[] orig = "12345678".getBytes("US-ASCII");
+        byte[] encoded = Base64.getEncoder().encode(orig);
+        // decode using different buffer sizes
+        for (int bufferSize = 1; bufferSize <= encoded.length + 1; bufferSize++) {
+            try (
+                    InputStream in = Base64.getDecoder().wrap(
+                            new ByteArrayInputStream(encoded));
+                    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+            ) {
+                byte[] buffer = new byte[bufferSize];
+                int read;
+                while ((read = in.read(buffer, 0, bufferSize)) >= 0) {
+                    baos.write(buffer, 0, read);
+                }
+                // compare result, output info if lengths do not match
+                byte[] decoded = baos.toByteArray();
+                checkEqual(decoded, orig, "Base64 stream decoding failed!");
+            }
+        }
+    }
 }
