This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tomcat.git

commit 374ff41eea1b7212d51a233494c7bd44678c3931
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Fri Jun 4 12:53:14 2021 +0100

    Fix HEAD response for reset() and resetBuffer()
---
 java/jakarta/servlet/http/HttpServlet.java        | 210 ++++++++++++++++++++--
 java/jakarta/servlet/http/LocalStrings.properties |   1 +
 test/jakarta/servlet/http/TestHttpServlet.java    |  91 ++++++++++
 webapps/docs/changelog.xml                        |   6 +
 4 files changed, 296 insertions(+), 12 deletions(-)

diff --git a/java/jakarta/servlet/http/HttpServlet.java b/java/jakarta/servlet/http/HttpServlet.java
index 86a5a7b..e4c0d3a 100644
--- a/java/jakarta/servlet/http/HttpServlet.java
+++ b/java/jakarta/servlet/http/HttpServlet.java
@@ -20,6 +20,7 @@ import java.io.IOException;
 import java.io.OutputStreamWriter;
 import java.io.PrintWriter;
 import java.io.UnsupportedEncodingException;
+import java.io.Writer;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.text.MessageFormat;
@@ -810,7 +811,7 @@ public abstract class HttpServlet extends GenericServlet {
      */
     private static class NoBodyResponse extends HttpServletResponseWrapper {
         private final NoBodyOutputStream noBody;
-        private PrintWriter writer;
+        private NoBodyPrintWriter writer;
         private boolean didSetContentLength;
 
         private NoBodyResponse(HttpServletResponse r) {
@@ -823,7 +824,7 @@ public abstract class HttpServlet extends GenericServlet {
                 if (writer != null) {
                     writer.flush();
                 }
-                super.setContentLengthLong(noBody.getContentLength());
+                super.setContentLengthLong(noBody.getWrittenByteCount());
             }
         }
 
@@ -879,13 +880,24 @@ public abstract class HttpServlet extends GenericServlet {
         public PrintWriter getWriter() throws UnsupportedEncodingException {
 
             if (writer == null) {
-                OutputStreamWriter w;
-
-                w = new OutputStreamWriter(noBody, getCharacterEncoding());
-                writer = new PrintWriter(w);
+                writer = new NoBodyPrintWriter(noBody, getCharacterEncoding());
             }
             return writer;
         }
+
+        @Override
+        public void reset() {
+            super.reset();
+            resetBuffer();
+        }
+
+        @Override
+        public void resetBuffer() {
+            noBody.resetBuffer();
+            if (writer != null) {
+                writer.resetBuffer();
+            }
+        }
     }
 
 
@@ -899,19 +911,19 @@ public abstract class HttpServlet extends GenericServlet {
 
         private final HttpServletResponse response;
         private boolean flushed = false;
-        private long contentLength = 0;
+        private long writtenByteCount = 0;
 
         private NoBodyOutputStream(HttpServletResponse response) {
             this.response = response;
         }
 
-        private long getContentLength() {
-            return contentLength;
+        private long getWrittenByteCount() {
+            return writtenByteCount;
         }
 
         @Override
         public void write(int b) throws IOException {
-            contentLength++;
+            writtenByteCount++;
             checkCommit();
         }
 
@@ -932,7 +944,7 @@ public abstract class HttpServlet extends GenericServlet {
                 throw new IndexOutOfBoundsException(msg);
             }
 
-            contentLength += len;
+            writtenByteCount += len;
             checkCommit();
         }
 
@@ -948,10 +960,184 @@ public abstract class HttpServlet extends GenericServlet {
         }
 
         private void checkCommit() throws IOException {
-            if (!flushed && contentLength > response.getBufferSize()) {
+            if (!flushed && writtenByteCount > response.getBufferSize()) {
                 response.flushBuffer();
                 flushed = true;
             }
         }
+
+        private void resetBuffer() {
+            if (flushed) {
+                throw new IllegalStateException(lStrings.getString("err.state.commit"));
+            }
+            writtenByteCount = 0;
+        }
+    }
+
+
+    private static class NoBodyPrintWriter extends PrintWriter {
+
+        private final NoBodyOutputStream out;
+        private final String encoding;
+        private PrintWriter pw;
+
+        public NoBodyPrintWriter(NoBodyOutputStream out, String encoding) throws UnsupportedEncodingException {
+            super(out);
+            this.out = out;
+            this.encoding = encoding;
+
+            Writer osw = new OutputStreamWriter(out, encoding);
+            pw = new PrintWriter(osw);
+        }
+
+        private void resetBuffer() {
+            out.resetBuffer();
+
+            Writer osw = null;
+            try {
+                osw = new OutputStreamWriter(out, encoding);
+            } catch (UnsupportedEncodingException e) {
+                // Impossible.
+                // The same values were used in the constructor. If this method
+                // gets called then the constructor must have succeeded so the
+                // above call must also succeed.
+            }
+            pw = new PrintWriter(osw);
+        }
+
+        @Override
+        public void flush() {
+            pw.flush();
+        }
+
+        @Override
+        public void close() {
+            pw.close();
+        }
+
+        @Override
+        public boolean checkError() {
+            return pw.checkError();
+        }
+
+        @Override
+        public void write(int c) {
+            pw.write(c);
+        }
+
+        @Override
+        public void write(char[] buf, int off, int len) {
+            pw.write(buf, off, len);
+        }
+
+        @Override
+        public void write(char[] buf) {
+            pw.write(buf);
+        }
+
+        @Override
+        public void write(String s, int off, int len) {
+            pw.write(s, off, len);
+        }
+
+        @Override
+        public void write(String s) {
+            pw.write(s);
+        }
+
+        @Override
+        public void print(boolean b) {
+            pw.print(b);
+        }
+
+        @Override
+        public void print(char c) {
+            pw.print(c);
+        }
+
+        @Override
+        public void print(int i) {
+            pw.print(i);
+        }
+
+        @Override
+        public void print(long l) {
+            pw.print(l);
+        }
+
+        @Override
+        public void print(float f) {
+            pw.print(f);
+        }
+
+        @Override
+        public void print(double d) {
+            pw.print(d);
+        }
+
+        @Override
+        public void print(char[] s) {
+            pw.print(s);
+        }
+
+        @Override
+        public void print(String s) {
+            pw.print(s);
+        }
+
+        @Override
+        public void print(Object obj) {
+            pw.print(obj);
+        }
+
+        @Override
+        public void println() {
+            pw.println();
+        }
+
+        @Override
+        public void println(boolean x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(char x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(int x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(long x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(float x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(double x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(char[] x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(String x) {
+            pw.println(x);
+        }
+
+        @Override
+        public void println(Object x) {
+            pw.println(x);
+        }
     }
 }
diff --git a/java/jakarta/servlet/http/LocalStrings.properties b/java/jakarta/servlet/http/LocalStrings.properties
index 4d0e8d8..f9fbd6c 100644
--- a/java/jakarta/servlet/http/LocalStrings.properties
+++ b/java/jakarta/servlet/http/LocalStrings.properties
@@ -21,6 +21,7 @@ err.cookie_name_is_token=Cookie name [{0}] is a reserved token
 err.io.indexOutOfBounds=Invalid offset [{0}] and / or length [{1}] specified for array of size [{2}]
 err.io.nullArray=Null passed for byte array in write method
 err.io.short_read=Short Read
+err.state.commit=Not permitted once response has been committed
 
 http.method_delete_not_supported=HTTP method DELETE is not supported by this URL
 http.method_get_not_supported=HTTP method GET is not supported by this URL
diff --git a/test/jakarta/servlet/http/TestHttpServlet.java b/test/jakarta/servlet/http/TestHttpServlet.java
index ef034f8..c0cf95b 100644
--- a/test/jakarta/servlet/http/TestHttpServlet.java
+++ b/test/jakarta/servlet/http/TestHttpServlet.java
@@ -24,6 +24,7 @@ import java.util.Map;
 
 import jakarta.servlet.Servlet;
 import jakarta.servlet.ServletException;
+import jakarta.servlet.ServletOutputStream;
 
 import org.junit.Assert;
 import org.junit.Test;
@@ -122,6 +123,30 @@ public class TestHttpServlet extends TomcatBaseTest {
     }
 
 
+    @Test
+    public void testHeadWithResetBufferWriter() throws Exception {
+        doTestHead(new ResetBufferServlet(true));
+    }
+
+
+    @Test
+    public void testHeadWithResetBufferStream() throws Exception {
+        doTestHead(new ResetBufferServlet(false));
+    }
+
+
+    @Test
+    public void testHeadWithResetWriter() throws Exception {
+        doTestHead(new ResetServlet(true));
+    }
+
+
+    @Test
+    public void testHeadWithResetStream() throws Exception {
+        doTestHead(new ResetServlet(false));
+    }
+
+
     private void doTestHead(Servlet servlet) throws Exception {
         Tomcat tomcat = getTomcatInstance();
 
@@ -346,6 +371,72 @@ public class TestHttpServlet extends TomcatBaseTest {
     }
 
 
+    private static class ResetBufferServlet extends HttpServlet {
+
+        private static final long serialVersionUID = 1L;
+
+        private final boolean useWriter;
+
+        public ResetBufferServlet(boolean useWriter) {
+            this.useWriter = useWriter;
+        }
+
+        @Override
+        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
+                throws ServletException, IOException {
+            resp.setContentType("text/plain");
+            resp.setCharacterEncoding("UTF-8");
+
+            if (useWriter) {
+                PrintWriter pw = resp.getWriter();
+                pw.write(new char[4 * 1024]);
+                resp.resetBuffer();
+                pw.write(new char[4 * 1024]);
+            } else {
+                ServletOutputStream sos = resp.getOutputStream();
+                sos.write(new byte [4 * 1024]);
+                resp.resetBuffer();
+                sos.write(new byte [4 * 1024]);
+            }
+        }
+    }
+
+
+    private static class ResetServlet extends HttpServlet {
+
+        private static final long serialVersionUID = 1L;
+
+        private final boolean useWriter;
+
+        public ResetServlet(boolean useWriter) {
+            this.useWriter = useWriter;
+        }
+
+        @Override
+        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
+                throws ServletException, IOException {
+            resp.setContentType("text/plain");
+            resp.setCharacterEncoding("UTF-8");
+            // Use reset(), not resetBuffer(): it must also clear header aaa
+            if (useWriter) {
+                PrintWriter pw = resp.getWriter();
+                resp.addHeader("aaa", "bbb");
+                pw.write(new char[4 * 1024]);
+                resp.reset();
+                resp.addHeader("ccc", "ddd");
+                pw.write(new char[4 * 1024]);
+            } else {
+                ServletOutputStream sos = resp.getOutputStream();
+                resp.addHeader("aaa", "bbb");
+                sos.write(new byte [4 * 1024]);
+                resp.reset();
+                resp.addHeader("ccc", "ddd");
+                sos.write(new byte [4 * 1024]);
+            }
+        }
+    }
+
+
     private static class OptionsServlet extends HttpServlet {
 
         private static final long serialVersionUID = 1L;
diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml
index 8154ef8..118e7a0 100644
--- a/webapps/docs/changelog.xml
+++ b/webapps/docs/changelog.xml
@@ -174,6 +174,12 @@
         Avoid synchronization on roles verification for the memory
         <code>UserDatabase</code>. (remm)
       </fix>
+      <fix>
+        Fix the default <code>doHead()</code> implementation in
+        <code>HttpServlet</code> to correctly handle responses where the Servlet
+        calls <code>ServletResponse.reset()</code> and/or
+        <code>ServletResponse.resetBuffer()</code>. (markt)
+      </fix>
     </changelog>
   </subsection>
   <subsection name="Coyote">

---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to