WICKET-6245: open up CsrfPreventionRequestCycleListener for extension

Project: http://git-wip-us.apache.org/repos/asf/wicket/repo
Commit: http://git-wip-us.apache.org/repos/asf/wicket/commit/6c40c919
Tree: http://git-wip-us.apache.org/repos/asf/wicket/tree/6c40c919
Diff: http://git-wip-us.apache.org/repos/asf/wicket/diff/6c40c919

Branch: refs/heads/master
Commit: 6c40c919f54fce610c584b9e4ec7925c14a5a19b
Parents: c04f2b0
Author: Emond Papegaaij <emond.papega...@topicus.nl>
Authored: Mon Sep 19 15:24:57 2016 +0200
Committer: Emond Papegaaij <emond.papega...@topicus.nl>
Committed: Mon Sep 19 15:25:21 2016 +0200

----------------------------------------------------------------------
 .../CsrfPreventionRequestCycleListener.java     | 182 +++++++++++--------
 1 file changed, 111 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/wicket/blob/6c40c919/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java
----------------------------------------------------------------------
diff --git 
a/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java
 
b/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java
index a2bf124..ce03862 100644
--- 
a/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java
+++ 
b/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java
@@ -27,7 +27,9 @@ import javax.servlet.http.HttpServletRequest;
 import org.apache.wicket.RestartResponseException;
 import org.apache.wicket.core.request.handler.IPageRequestHandler;
 import org.apache.wicket.core.request.handler.RenderPageRequestHandler;
+import org.apache.wicket.protocol.http.WebApplication;
 import org.apache.wicket.request.IRequestHandler;
+import org.apache.wicket.request.IRequestHandlerDelegate;
 import org.apache.wicket.request.component.IRequestablePage;
 import org.apache.wicket.request.cycle.AbstractRequestCycleListener;
 import org.apache.wicket.request.cycle.IRequestCycleListener;
@@ -39,9 +41,9 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Prevents CSRF attacks on Wicket components by checking the {@code Origin} 
HTTP header for cross
- * domain requests. By default only checks requests that try to perform an 
action on a component,
- * such as a form submit, or link click.
+ * Prevents CSRF attacks on Wicket components by checking the {@code Origin} 
and {@code Referer}
+ * HTTP headers for cross domain requests. By default only checks requests 
that try to perform an
+ * action on a component, such as a form submit, or link click.
  * <p>
  * <h3>Installation</h3>
  * <p>
@@ -60,18 +62,17 @@ import org.slf4j.LoggerFactory;
  * <p>
  * <h3>Configuration</h3>
  * <p>
- * A missing {@code Origin} HTTP header is (by default) handled as if it were 
a good request and
- * accepted. You can {@link #setNoOriginAction(CsrfAction) configure the 
specific action} to a
- * different value, suppressing or aborting the request when the {@code 
Origin} HTTP header is
- * missing.
+ * When the {@code Origin} or {@code Referer} HTTP header is present but doesn't match the requested
+ * URL, this listener will by default throw an HTTP error ({@code 400 BAD REQUEST}) and abort the
+ * request. You can {@link #setConflictingOriginAction(CsrfAction) configure} this specific action.
  * <p>
- * When the {@code Origin} HTTP header is present and has the value {@code 
null} it is considered to
- * be from a "privacy-sensitive" context and will trigger the conflicting 
origin action. You can
- * customize what happens in those actions by overriding the respective {@code 
onXXXX} methods.
+ * When both the {@code Origin} and {@code Referer} HTTP headers are missing, the request is by
+ * default handled as a bad request and rejected. You can {@link #setNoOriginAction(CsrfAction)
+ * configure the specific action} to suppress or allow such requests instead.
  * <p>
- * When the {@code Origin} HTTP header is present but doesn't match the 
requested URL this listener
- * will by default throw a HTTP error ( {@code 400 BAD REQUEST}) and abort the 
request. You can
- * {@link #setConflictingOriginAction(CsrfAction) configure} this specific 
action.
+ * When the {@code Origin} HTTP header is present and has the value {@code 
null} it is considered to
+ * be from a "privacy-sensitive" context and will trigger the no origin 
action. You can customize
+ * what happens in those actions by overriding the respective {@code onXXXX} 
methods.
  * <p>
  * When you want to accept certain cross domain request from a range of hosts, 
you can
  * {@link #addAcceptedOrigin(String) whitelist those domains}.
@@ -96,7 +97,7 @@ import org.slf4j.LoggerFactory;
  * {@link #isChecked(IRequestHandler)} to customize this behavior.
  * </p>
  * <p>
- * You can override the default actions that are performed by overriding the 
event handlers for
+ * You can customize the default actions that are performed by overriding the 
event handlers for
  * them:
  * <ul>
  * <li>{@link #onWhitelisted(HttpServletRequest, String, IRequestablePage)} 
when an origin was
@@ -119,7 +120,7 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
                .getLogger(CsrfPreventionRequestCycleListener.class);
 
        /**
-        * The action to perform when a missing or conflicting Origin header is 
detected.
+        * The action to perform when a missing or conflicting source URI is 
detected.
         */
        public enum CsrfAction {
                /** Aborts the request and throws an exception when a CSRF 
request is detected. */
@@ -155,7 +156,7 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
        /**
         * Action to perform when no Origin header is present in the request.
         */
-       private CsrfAction noOriginAction = CsrfAction.ALLOW;
+       private CsrfAction noOriginAction = CsrfAction.ABORT;
 
        /**
         * Action to perform when a conflicting Origin header is found.
@@ -271,8 +272,7 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
                {
                        HttpServletRequest containerRequest = 
(HttpServletRequest)cycle.getRequest()
                                .getContainerRequest();
-                       String origin = containerRequest.getHeader("Origin");
-                       log.debug("Request header Origin: {}", origin);
+                       log.debug("Request Source URI: {}", 
getSourceUri(containerRequest));
                }
        }
 
@@ -315,6 +315,21 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
                        !(handler instanceof RenderPageRequestHandler);
        }
 
+       /**
+        * Unwraps the handler, if it is an {@code IRequestHandlerDelegate}, down to the deepest
+        * nested handler.
+        *
+        * @param handler
+        *            The handler to unwrap
+        * @return the deepest handler that does not implement {@code 
IRequestHandlerDelegate}
+        */
+       protected final IRequestHandler unwrap(IRequestHandler handler)
+       {
+               while (handler instanceof IRequestHandlerDelegate)
+                       handler = 
((IRequestHandlerDelegate)handler).getDelegateHandler();
+               return handler;
+       }
+
        @Override
        public void onRequestHandlerResolved(RequestCycle cycle, 
IRequestHandler handler)
        {
@@ -324,6 +339,8 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
                        return;
                }
 
+               handler = unwrap(handler);
+
                // check if the request is targeted at a page
                if (isChecked(handler))
                {
@@ -331,112 +348,131 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
                        IRequestablePage targetedPage = prh.getPage();
                        HttpServletRequest containerRequest = 
(HttpServletRequest)cycle.getRequest()
                                .getContainerRequest();
-                       String origin = containerRequest.getHeader("Origin");
+                       String sourceUri = getSourceUri(containerRequest);
 
                        // Check if the page should be CSRF protected
                        if (isChecked(targetedPage))
                        {
                                // if so check the Origin HTTP header
-                               checkOrigin(containerRequest, origin, 
targetedPage);
+                               checkRequest(containerRequest, sourceUri, 
targetedPage);
                        }
                        else
                        {
                                log.debug("Targeted page {} was opted out of 
the CSRF origin checks, allowed",
                                        targetedPage.getClass().getName());
-                               allowHandler(containerRequest, origin, 
targetedPage);
+                               allowHandler(containerRequest, sourceUri, 
targetedPage);
                        }
                }
                else
                {
                        if (log.isTraceEnabled())
-                               log.trace("Resolved handler {} doesn't target a 
page, no CSRF check performed",
+                               log.trace(
+                                       "Resolved handler {} doesn't target an 
action on a page, no CSRF check performed",
                                        handler.getClass().getName());
                }
        }
 
        /**
-        * Performs the check of the {@code Origin} header that is targeted at 
the {@code page}.
+        * Resolves the source URI from the {@code Origin} header, or {@code Referer} if it is empty.
+        *
+        * @param containerRequest
+        *            the current container request
+        * @return the normalized source URI, or {@code null} when absent or not parseable.
+        */
+       protected String getSourceUri(HttpServletRequest containerRequest)
+       {
+               String sourceUri = containerRequest.getHeader("Origin");
+               if (Strings.isEmpty(sourceUri))
+               {
+                       sourceUri = containerRequest.getHeader("Referer");
+               }
+               return normalizeUri(sourceUri);
+       }
+
+       /**
+        * Performs the check of the {@code Origin} or {@code Referer} header 
that is targeted at the
+        * {@code page}.
         *
         * @param request
         *            the current container request
-        * @param origin
-        *            the {@code Origin} header
+        * @param sourceUri
+        *            the source URI
         * @param page
         *            the page that is the target of the request
         */
-       private void checkOrigin(HttpServletRequest request, String origin, 
IRequestablePage page)
+       protected void checkRequest(HttpServletRequest request, String 
sourceUri, IRequestablePage page)
        {
-               if (origin == null || origin.isEmpty())
+               if (sourceUri == null || sourceUri.isEmpty())
                {
-                       log.debug("Origin-header not present in request, {}", 
noOriginAction);
+                       log.debug("Source URI not present in request, {}", 
noOriginAction);
                        switch (noOriginAction)
                        {
                                case ALLOW :
-                                       allowHandler(request, origin, page);
+                                       allowHandler(request, sourceUri, page);
                                        break;
                                case SUPPRESS :
-                                       suppressHandler(request, origin, page);
+                                       suppressHandler(request, sourceUri, 
page);
                                        break;
                                case ABORT :
-                                       abortHandler(request, origin, page);
+                                       abortHandler(request, sourceUri, page);
                                        break;
                        }
                        return;
                }
-               origin = origin.toLowerCase();
+               sourceUri = sourceUri.toLowerCase();
 
                // if the origin is a know and trusted origin, don't check any 
further but allow the request
-               if (isWhitelistedOrigin(origin))
+               if (isWhitelistedHost(sourceUri))
                {
-                       whitelistedHandler(request, origin, page);
+                       whitelistedHandler(request, sourceUri, page);
                        return;
                }
 
                // check if the origin HTTP header matches the request URI
-               if (!isLocalOrigin(request, origin))
+               if (!isLocalOrigin(request, sourceUri))
                {
-                       log.debug("Origin-header conflicts with request origin, 
{}", conflictingOriginAction);
+                       log.debug("Source URI conflicts with request origin, 
{}", conflictingOriginAction);
                        switch (conflictingOriginAction)
                        {
                                case ALLOW :
-                                       allowHandler(request, origin, page);
+                                       allowHandler(request, sourceUri, page);
                                        break;
                                case SUPPRESS :
-                                       suppressHandler(request, origin, page);
+                                       suppressHandler(request, sourceUri, 
page);
                                        break;
                                case ABORT :
-                                       abortHandler(request, origin, page);
+                                       abortHandler(request, sourceUri, page);
                                        break;
                        }
                }
                else
                {
-                       matchingOrigin(request, origin, page);
+                       matchingOrigin(request, sourceUri, page);
                }
        }
 
        /**
-        * Checks whether the domain part of the {@code Origin} HTTP header is 
whitelisted.
+        * Checks whether the domain part of the {@code sourceUri} ({@code 
Origin} or {@code Referer}
+        * header) is whitelisted.
         *
-        * @param origin
-        *            the {@code Origin} HTTP header
-        * @return {@code true} when the origin domain was whitelisted
+        * @param sourceUri
+        *            the contents of the {@code Origin} or {@code Referer} 
HTTP header
+        * @return {@code true} when the source domain was whitelisted
         */
-       private boolean isWhitelistedOrigin(final String origin)
+       protected final boolean isWhitelistedHost(final String sourceUri)
        {
                try
                {
-                       final URI originUri = new URI(origin);
-                       final String originHost = originUri.getHost();
-                       if (Strings.isEmpty(originHost))
+                       final String sourceHost = new URI(sourceUri).getHost();
+                       if (Strings.isEmpty(sourceHost))
                                return false;
                        for (String whitelistedOrigin : acceptedOrigins)
                        {
-                               if 
(originHost.equalsIgnoreCase(whitelistedOrigin) ||
-                                       originHost.endsWith("." + 
whitelistedOrigin))
+                               if 
(sourceHost.equalsIgnoreCase(whitelistedOrigin) ||
+                                       sourceHost.endsWith("." + 
whitelistedOrigin))
                                {
-                                       log.trace("Origin {} matched 
whitelisted origin {}, request accepted", origin,
-                                               whitelistedOrigin);
+                                       log.trace("Origin {} matched 
whitelisted origin {}, request accepted",
+                                               sourceUri, whitelistedOrigin);
                                        return true;
                                }
                        }
@@ -444,7 +480,7 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
                catch (URISyntaxException e)
                {
                        log.debug("Origin: {} not parseable as an URI. 
Whitelisted-origin check skipped.",
-                               origin);
+                               sourceUri);
                }
 
                return false;
@@ -460,14 +496,14 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
         *            the contents of the {@code Origin} HTTP header
         * @return {@code true} when the origin of the request matches the 
{@code Origin} HTTP header
         */
-       private boolean isLocalOrigin(HttpServletRequest containerRequest, 
String originHeader)
+       protected boolean isLocalOrigin(HttpServletRequest containerRequest, 
String originHeader)
        {
                // Make comparable strings from Origin and Location
-               String origin = getOriginHeaderOrigin(originHeader);
+               String origin = normalizeUri(originHeader);
                if (origin == null)
                        return false;
 
-               String request = getLocationHeaderOrigin(containerRequest);
+               String request = getTargetUriFromRequest(containerRequest);
                if (request == null)
                        return false;
 
@@ -475,27 +511,27 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
        }
 
        /**
-        * Creates a RFC-6454 comparable origin from the {@code origin} string.
+        * Creates a RFC-6454 comparable URI from the {@code uri} string.
         *
-        * @param origin
-        *            the contents of the Origin HTTP header
-        * @return only the scheme://host[:port] part, or {@code null} when the 
origin string is not
+        * @param uri
+        *            the contents of the Origin or Referer HTTP header
+        * @return only the scheme://host[:port] part, or {@code null} when the 
URI string is not
         *         compliant
         */
-       private String getOriginHeaderOrigin(String origin)
+       protected final String normalizeUri(String uri)
        {
                // the request comes from a privacy sensitive context, flag as 
non-local origin. If
                // alternative action is required, an implementor can override 
any of the onAborted,
                // onSuppressed or onAllowed and implement such needed action.
 
-               if ("null".equals(origin))
+               if (Strings.isEmpty(uri) || "null".equals(uri))
                        return null;
 
                StringBuilder target = new StringBuilder();
 
                try
                {
-                       URI originUri = new URI(origin);
+                       URI originUri = new URI(uri);
                        String scheme = originUri.getScheme();
                        if (scheme == null)
                        {
@@ -530,20 +566,20 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
                }
                catch (URISyntaxException e)
                {
-                       log.debug("Invalid Origin header provided: {}, marked 
conflicting", origin);
+                       log.debug("Invalid URI provided: {}, marked 
conflicting", uri);
                        return null;
                }
        }
 
        /**
-        * Creates a RFC-6454 comparable origin from the {@code request} 
requested resource.
+        * Creates a RFC-6454 comparable URI from the {@code request} requested 
resource.
         *
         * @param request
         *            the incoming request
         * @return only the scheme://host[:port] part, or {@code null} when the 
origin string is not
         *         compliant
         */
-       private String getLocationHeaderOrigin(HttpServletRequest request)
+       protected final String getTargetUriFromRequest(HttpServletRequest 
request)
        {
                // Build scheme://host:port from request
                StringBuilder target = new StringBuilder();
@@ -587,7 +623,7 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
         * @param page
         *            the page that is targeted with this request
         */
-       private void whitelistedHandler(HttpServletRequest request, String 
origin,
+       protected final void whitelistedHandler(HttpServletRequest request, 
String origin,
                IRequestablePage page)
        {
                onWhitelisted(request, origin, page);
@@ -624,7 +660,8 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
         * @param page
         *            the page that is targeted with this request
         */
-       private void matchingOrigin(HttpServletRequest request, String origin, 
IRequestablePage page)
+       protected final void matchingOrigin(HttpServletRequest request, String 
origin,
+               IRequestablePage page)
        {
                onMatchingOrigin(request, origin, page);
                if (log.isDebugEnabled())
@@ -662,7 +699,8 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
         * @param page
         *            the page that is targeted with this request
         */
-       private void allowHandler(HttpServletRequest request, String origin, 
IRequestablePage page)
+       protected final void allowHandler(HttpServletRequest request, String 
origin,
+               IRequestablePage page)
        {
                onAllowed(request, origin, page);
                log.info("Possible CSRF attack, request URL: {}, Origin: {}, 
action: allowed",
@@ -697,7 +735,8 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
         * @param page
         *            the page that is targeted with this request
         */
-       private void suppressHandler(HttpServletRequest request, String origin, 
IRequestablePage page)
+       protected final void suppressHandler(HttpServletRequest request, String 
origin,
+               IRequestablePage page)
        {
                onSuppressed(request, origin, page);
                log.info("Possible CSRF attack, request URL: {}, Origin: {}, 
action: suppressed",
@@ -733,7 +772,8 @@ public class CsrfPreventionRequestCycleListener extends 
AbstractRequestCycleList
         * @param page
         *            the page that is targeted with this request
         */
-       private void abortHandler(HttpServletRequest request, String origin, 
IRequestablePage page)
+       protected final void abortHandler(HttpServletRequest request, String 
origin,
+               IRequestablePage page)
        {
                onAborted(request, origin, page);
                log.info(

Reply via email to