Author: violetagg
Date: Wed Jun 26 05:00:22 2013
New Revision: 1496732

URL: http://svn.apache.org/r1496732
Log:
When AsyncContext.dispatch(...) is invoked, do not cast the request and
response to HttpServletRequest/HttpServletResponse, because
ServletRequest.startAsync(ServletRequest,ServletResponse) can be invoked with
custom ServletRequest/ServletResponse instances.

Modified:
    tomcat/trunk/java/org/apache/catalina/core/AsyncContextImpl.java
    tomcat/trunk/test/org/apache/catalina/core/TestAsyncContextImpl.java

Modified: tomcat/trunk/java/org/apache/catalina/core/AsyncContextImpl.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/java/org/apache/catalina/core/AsyncContextImpl.java?rev=1496732&r1=1496731&r2=1496732&view=diff
==============================================================================
--- tomcat/trunk/java/org/apache/catalina/core/AsyncContextImpl.java (original)
+++ tomcat/trunk/java/org/apache/catalina/core/AsyncContextImpl.java Wed Jun 26 05:00:22 2013
@@ -168,9 +168,17 @@ public class AsyncContextImpl implements
     @Override
     public void dispatch() {
         check();
-        HttpServletRequest sr = (HttpServletRequest)getRequest();
-        String path = sr.getRequestURI();
-        String cpath = sr.getContextPath();
+        String path;
+        String cpath;
+        ServletRequest servletRequest = getRequest();
+        if (servletRequest instanceof HttpServletRequest) {
+            HttpServletRequest sr = (HttpServletRequest) servletRequest;
+            path = sr.getRequestURI();
+            cpath = sr.getContextPath();
+        } else {
+            path = request.getRequestURI();
+            cpath = request.getContextPath();
+        }
         if (cpath.length()>1) path = path.substring(cpath.length());
         dispatch(path);
     }
@@ -205,10 +213,8 @@ public class AsyncContextImpl implements
         }
         final AsyncDispatcher applicationDispatcher =
                 (AsyncDispatcher) requestDispatcher;
-        final HttpServletRequest servletRequest =
-                (HttpServletRequest) getRequest();
-        final HttpServletResponse servletResponse =
-                (HttpServletResponse) getResponse();
+        final ServletRequest servletRequest = getRequest();
+        final ServletResponse servletResponse = getResponse();
         Runnable run = new Runnable() {
             @Override
             public void run() {

Modified: tomcat/trunk/test/org/apache/catalina/core/TestAsyncContextImpl.java
URL: http://svn.apache.org/viewvc/tomcat/trunk/test/org/apache/catalina/core/TestAsyncContextImpl.java?rev=1496732&r1=1496731&r2=1496732&view=diff
==============================================================================
--- tomcat/trunk/test/org/apache/catalina/core/TestAsyncContextImpl.java (original)
+++ tomcat/trunk/test/org/apache/catalina/core/TestAsyncContextImpl.java Wed Jun 26 05:00:22 2013
@@ -30,11 +30,15 @@ import javax.servlet.AsyncContext;
 import javax.servlet.AsyncEvent;
 import javax.servlet.AsyncListener;
 import javax.servlet.DispatcherType;
+import javax.servlet.GenericServlet;
 import javax.servlet.RequestDispatcher;
 import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
 import javax.servlet.ServletRequestEvent;
 import javax.servlet.ServletRequestListener;
+import javax.servlet.ServletRequestWrapper;
 import javax.servlet.ServletResponse;
+import javax.servlet.ServletResponseWrapper;
 import javax.servlet.http.HttpServlet;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
@@ -1740,9 +1744,9 @@ public class TestAsyncContextImpl extend
         Wrapper wrapper = Tomcat.addServlet(ctx, "nonAsyncServlet",
                 nonAsyncServlet);
         wrapper.setAsyncSupported(true);
-        ctx.addServletMapping("/nonAsyncServlet", "nonAsyncServlet");
+        ctx.addServletMapping("/target", "nonAsyncServlet");
 
-        ForbiddenDispatchingServlet forbiddenDispatchingServlet = new ForbiddenDispatchingServlet();
+        DispatchingGenericServlet forbiddenDispatchingServlet = new DispatchingGenericServlet();
         Wrapper wrapper1 = Tomcat.addServlet(ctx,
                 "forbiddenDispatchingServlet", forbiddenDispatchingServlet);
         wrapper1.setAsyncSupported(true);
@@ -1766,20 +1770,37 @@ public class TestAsyncContextImpl extend
         assertTrue(body.toString().contains("NonAsyncServletGet"));
     }
 
-    private static class ForbiddenDispatchingServlet extends HttpServlet {
+    private static class DispatchingGenericServlet extends GenericServlet {
 
         private static final long serialVersionUID = 1L;
+        private static final String CUSTOM_REQ_RESP = "crr";
+        private static final String EMPTY_DISPATCH = "empty";
 
         @Override
-        protected void doGet(HttpServletRequest req, HttpServletResponse resp)
+        public void service(ServletRequest req, ServletResponse resp)
                 throws ServletException, IOException {
-            AsyncContext asyncContext = req.startAsync();
-            asyncContext.dispatch("/nonAsyncServlet");
-            try {
-                asyncContext.dispatch("/nonExistingServlet");
-                resp.getWriter().println("FAIL");
-            } catch (IllegalStateException e) {
-                resp.getWriter().println("OK");
+            if (DispatcherType.ASYNC != req.getDispatcherType()) {
+                AsyncContext asyncContext;
+                if ("y".equals(req.getParameter(CUSTOM_REQ_RESP))) {
+                    asyncContext = req.startAsync(
+                            new ServletRequestWrapper(req),
+                            new ServletResponseWrapper(resp));
+                } else {
+                    asyncContext = req.startAsync();
+                }
+                if ("y".equals(req.getParameter(EMPTY_DISPATCH))) {
+                    asyncContext.dispatch();
+                } else {
+                    asyncContext.dispatch("/target");
+                }
+                try {
+                    asyncContext.dispatch("/nonExistingServlet");
+                    resp.getWriter().print("FAIL");
+                } catch (IllegalStateException e) {
+                    resp.getWriter().print("OK");
+                }
+            } else {
+                resp.getWriter().print("ForbiddenDispatchingServletGet-");
             }
         }
     }
@@ -1855,4 +1876,59 @@ public class TestAsyncContextImpl extend
             }
         }
     }
+
+    @Test
+    public void testDispatchWithCustomRequestResponse() throws Exception {
+        // Setup Tomcat instance
+        Tomcat tomcat = getTomcatInstance();
+
+        // Must have a real docBase - just use temp
+        File docBase = new File(System.getProperty("java.io.tmpdir"));
+
+        Context ctx = tomcat.addContext("", docBase.getAbsolutePath());
+
+        DispatchingGenericServlet dispatch = new DispatchingGenericServlet();
+        Wrapper wrapper = Tomcat.addServlet(ctx, "dispatch", dispatch);
+        wrapper.setAsyncSupported(true);
+        ctx.addServletMapping("/dispatch", "dispatch");
+
+        CustomGenericServlet customGeneric = new CustomGenericServlet();
+        Wrapper wrapper2 = Tomcat.addServlet(ctx, "customGeneric",
+                customGeneric);
+        wrapper2.setAsyncSupported(true);
+        ctx.addServletMapping("/target", "customGeneric");
+
+        tomcat.start();
+
+        ByteChunk res = getUrl("http://localhost:" + getPort()
+                + "/dispatch?crr=y");
+
+        StringBuilder expected = new StringBuilder();
+        expected.append("OK");
+        expected.append("CustomGenericServletGet-");
+        assertEquals(expected.toString(), res.toString());
+
+        res = getUrl("http://localhost:" + getPort()
+                + "/dispatch?crr=y&empty=y");
+
+        expected = new StringBuilder();
+        expected.append("OK");
+        expected.append("ForbiddenDispatchingServletGet-");
+        assertEquals(expected.toString(), res.toString());
+    }
+
+    private static class CustomGenericServlet extends GenericServlet {
+
+        private static final long serialVersionUID = 1L;
+
+        @Override
+        public void service(ServletRequest req, ServletResponse res)
+                throws ServletException, IOException {
+            if (req instanceof ServletRequestWrapper
+                    && res instanceof ServletResponseWrapper) {
+                res.getWriter().print("CustomGenericServletGet-");
+            }
+        }
+
+    }
 }



---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to