WICKET-6245: open up CsrfPreventionRequestCycleListener for extension
Project: http://git-wip-us.apache.org/repos/asf/wicket/repo Commit: http://git-wip-us.apache.org/repos/asf/wicket/commit/6c40c919 Tree: http://git-wip-us.apache.org/repos/asf/wicket/tree/6c40c919 Diff: http://git-wip-us.apache.org/repos/asf/wicket/diff/6c40c919 Branch: refs/heads/master Commit: 6c40c919f54fce610c584b9e4ec7925c14a5a19b Parents: c04f2b0 Author: Emond Papegaaij <emond.papega...@topicus.nl> Authored: Mon Sep 19 15:24:57 2016 +0200 Committer: Emond Papegaaij <emond.papega...@topicus.nl> Committed: Mon Sep 19 15:25:21 2016 +0200 ---------------------------------------------------------------------- .../CsrfPreventionRequestCycleListener.java | 182 +++++++++++-------- 1 file changed, 111 insertions(+), 71 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/wicket/blob/6c40c919/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java ---------------------------------------------------------------------- diff --git a/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java b/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java index a2bf124..ce03862 100644 --- a/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java +++ b/wicket-core/src/main/java/org/apache/wicket/protocol/http/CsrfPreventionRequestCycleListener.java @@ -27,7 +27,9 @@ import javax.servlet.http.HttpServletRequest; import org.apache.wicket.RestartResponseException; import org.apache.wicket.core.request.handler.IPageRequestHandler; import org.apache.wicket.core.request.handler.RenderPageRequestHandler; +import org.apache.wicket.protocol.http.WebApplication; import org.apache.wicket.request.IRequestHandler; +import org.apache.wicket.request.IRequestHandlerDelegate; import org.apache.wicket.request.component.IRequestablePage; import 
org.apache.wicket.request.cycle.AbstractRequestCycleListener; import org.apache.wicket.request.cycle.IRequestCycleListener; @@ -39,9 +41,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Prevents CSRF attacks on Wicket components by checking the {@code Origin} HTTP header for cross - * domain requests. By default only checks requests that try to perform an action on a component, - * such as a form submit, or link click. + * Prevents CSRF attacks on Wicket components by checking the {@code Origin} and {@code Referer} + * HTTP headers for cross domain requests. By default only checks requests that try to perform an + * action on a component, such as a form submit, or link click. * <p> * <h3>Installation</h3> * <p> @@ -60,18 +62,17 @@ import org.slf4j.LoggerFactory; * <p> * <h3>Configuration</h3> * <p> - * A missing {@code Origin} HTTP header is (by default) handled as if it were a good request and - * accepted. You can {@link #setNoOriginAction(CsrfAction) configure the specific action} to a - * different value, suppressing or aborting the request when the {@code Origin} HTTP header is - * missing. + * When the {@code Origin} or {@code Referer} HTTP header is present but doesn't match the requested + * URL, this listener will by default throw an HTTP error ({@code 400 BAD REQUEST}) and abort the + * request. You can {@link #setConflictingOriginAction(CsrfAction) configure} this specific action. * <p> - * When the {@code Origin} HTTP header is present and has the value {@code null} it is considered to - * be from a "privacy-sensitive" context and will trigger the conflicting origin action. You can - * customize what happens in those actions by overriding the respective {@code onXXXX} methods. + * A request that carries neither an {@code Origin} nor a {@code Referer} HTTP header is handled as + * if it were a bad request and rejected.
You can {@link #setNoOriginAction(CsrfAction) configure the specific action} to a + * different value, suppressing or allowing the request when the HTTP headers are missing. * <p> - * When the {@code Origin} HTTP header is present but doesn't match the requested URL this listener - * will by default throw a HTTP error ( {@code 400 BAD REQUEST}) and abort the request. You can - * {@link #setConflictingOriginAction(CsrfAction) configure} this specific action. + * When the {@code Origin} HTTP header is present and has the value {@code null} it is considered to + * be from a "privacy-sensitive" context and will trigger the no origin action. You can customize + * what happens in those actions by overriding the respective {@code onXXXX} methods. * <p> * When you want to accept certain cross domain request from a range of hosts, you can * {@link #addAcceptedOrigin(String) whitelist those domains}. @@ -96,7 +97,7 @@ import org.slf4j.LoggerFactory; * {@link #isChecked(IRequestHandler)} to customize this behavior. * </p> * <p> - * You can override the default actions that are performed by overriding the event handlers for + * You can customize the default actions that are performed by overriding the event handlers for * them: * <ul> * <li>{@link #onWhitelisted(HttpServletRequest, String, IRequestablePage)} when an origin was @@ -119,7 +120,7 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList .getLogger(CsrfPreventionRequestCycleListener.class); /** - * The action to perform when a missing or conflicting Origin header is detected. + * The action to perform when a missing or conflicting source URI is detected. */ public enum CsrfAction { /** Aborts the request and throws an exception when a CSRF request is detected. */ @@ -155,7 +156,7 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList /** * Action to perform when no Origin header is present in the request. 
*/ - private CsrfAction noOriginAction = CsrfAction.ALLOW; + private CsrfAction noOriginAction = CsrfAction.ABORT; /** * Action to perform when a conflicting Origin header is found. @@ -271,8 +272,7 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList { HttpServletRequest containerRequest = (HttpServletRequest)cycle.getRequest() .getContainerRequest(); - String origin = containerRequest.getHeader("Origin"); - log.debug("Request header Origin: {}", origin); + log.debug("Request Source URI: {}", getSourceUri(containerRequest)); } } @@ -315,6 +315,21 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList !(handler instanceof RenderPageRequestHandler); } + /** + * Unwraps the handler if it is a {@code IRequestHandlerDelegate} down to the deepest nested + * handler. + * + * @param handler + * The handler to unwrap + * @return the deepest handler that does not implement {@code IRequestHandlerDelegate} + */ + protected final IRequestHandler unwrap(IRequestHandler handler) + { + while (handler instanceof IRequestHandlerDelegate) + handler = ((IRequestHandlerDelegate)handler).getDelegateHandler(); + return handler; + } + @Override public void onRequestHandlerResolved(RequestCycle cycle, IRequestHandler handler) { @@ -324,6 +339,8 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList return; } + handler = unwrap(handler); + // check if the request is targeted at a page if (isChecked(handler)) { @@ -331,112 +348,131 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList IRequestablePage targetedPage = prh.getPage(); HttpServletRequest containerRequest = (HttpServletRequest)cycle.getRequest() .getContainerRequest(); - String origin = containerRequest.getHeader("Origin"); + String sourceUri = getSourceUri(containerRequest); // Check if the page should be CSRF protected if (isChecked(targetedPage)) { // if so check the Origin HTTP header - 
checkOrigin(containerRequest, origin, targetedPage); + checkRequest(containerRequest, sourceUri, targetedPage); } else { log.debug("Targeted page {} was opted out of the CSRF origin checks, allowed", targetedPage.getClass().getName()); - allowHandler(containerRequest, origin, targetedPage); + allowHandler(containerRequest, sourceUri, targetedPage); } } else { if (log.isTraceEnabled()) - log.trace("Resolved handler {} doesn't target a page, no CSRF check performed", + log.trace( + "Resolved handler {} doesn't target an action on a page, no CSRF check performed", handler.getClass().getName()); } } /** - * Performs the check of the {@code Origin} header that is targeted at the {@code page}. + * Resolves the source URI from the request headers ({@code Origin} or {@code Referer}). + * + * @param containerRequest + * the current container request + * @return the normalized source URI. + */ + protected String getSourceUri(HttpServletRequest containerRequest) + { + String sourceUri = containerRequest.getHeader("Origin"); + if (Strings.isEmpty(sourceUri)) + { + sourceUri = containerRequest.getHeader("Referer"); + } + return normalizeUri(sourceUri); + } + + /** + * Performs the check of the {@code Origin} or {@code Referer} header that is targeted at the + * {@code page}. 
* * @param request * the current container request - * @param origin - * the {@code Origin} header + * @param sourceUri + * the source URI * @param page * the page that is the target of the request */ - private void checkOrigin(HttpServletRequest request, String origin, IRequestablePage page) + protected void checkRequest(HttpServletRequest request, String sourceUri, IRequestablePage page) { - if (origin == null || origin.isEmpty()) + if (sourceUri == null || sourceUri.isEmpty()) { - log.debug("Origin-header not present in request, {}", noOriginAction); + log.debug("Source URI not present in request, {}", noOriginAction); switch (noOriginAction) { case ALLOW : - allowHandler(request, origin, page); + allowHandler(request, sourceUri, page); break; case SUPPRESS : - suppressHandler(request, origin, page); + suppressHandler(request, sourceUri, page); break; case ABORT : - abortHandler(request, origin, page); + abortHandler(request, sourceUri, page); break; } return; } - origin = origin.toLowerCase(); + sourceUri = sourceUri.toLowerCase(); // if the origin is a know and trusted origin, don't check any further but allow the request - if (isWhitelistedOrigin(origin)) + if (isWhitelistedHost(sourceUri)) { - whitelistedHandler(request, origin, page); + whitelistedHandler(request, sourceUri, page); return; } // check if the origin HTTP header matches the request URI - if (!isLocalOrigin(request, origin)) + if (!isLocalOrigin(request, sourceUri)) { - log.debug("Origin-header conflicts with request origin, {}", conflictingOriginAction); + log.debug("Source URI conflicts with request origin, {}", conflictingOriginAction); switch (conflictingOriginAction) { case ALLOW : - allowHandler(request, origin, page); + allowHandler(request, sourceUri, page); break; case SUPPRESS : - suppressHandler(request, origin, page); + suppressHandler(request, sourceUri, page); break; case ABORT : - abortHandler(request, origin, page); + abortHandler(request, sourceUri, page); break; } } else { - 
matchingOrigin(request, origin, page); + matchingOrigin(request, sourceUri, page); } } /** - * Checks whether the domain part of the {@code Origin} HTTP header is whitelisted. + * Checks whether the domain part of the {@code sourceUri} ({@code Origin} or {@code Referer} + * header) is whitelisted. * - * @param origin - * the {@code Origin} HTTP header - * @return {@code true} when the origin domain was whitelisted + * @param sourceUri + * the contents of the {@code Origin} or {@code Referer} HTTP header + * @return {@code true} when the source domain was whitelisted */ - private boolean isWhitelistedOrigin(final String origin) + protected final boolean isWhitelistedHost(final String sourceUri) { try { - final URI originUri = new URI(origin); - final String originHost = originUri.getHost(); - if (Strings.isEmpty(originHost)) + final String sourceHost = new URI(sourceUri).getHost(); + if (Strings.isEmpty(sourceHost)) return false; for (String whitelistedOrigin : acceptedOrigins) { - if (originHost.equalsIgnoreCase(whitelistedOrigin) || - originHost.endsWith("." + whitelistedOrigin)) + if (sourceHost.equalsIgnoreCase(whitelistedOrigin) || + sourceHost.endsWith("." + whitelistedOrigin)) { - log.trace("Origin {} matched whitelisted origin {}, request accepted", origin, - whitelistedOrigin); + log.trace("Origin {} matched whitelisted origin {}, request accepted", + sourceUri, whitelistedOrigin); return true; } } @@ -444,7 +480,7 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList catch (URISyntaxException e) { log.debug("Origin: {} not parseable as an URI. 
Whitelisted-origin check skipped.", - origin); + sourceUri); } return false; @@ -460,14 +496,14 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList * the contents of the {@code Origin} HTTP header * @return {@code true} when the origin of the request matches the {@code Origin} HTTP header */ - private boolean isLocalOrigin(HttpServletRequest containerRequest, String originHeader) + protected boolean isLocalOrigin(HttpServletRequest containerRequest, String originHeader) { // Make comparable strings from Origin and Location - String origin = getOriginHeaderOrigin(originHeader); + String origin = normalizeUri(originHeader); if (origin == null) return false; - String request = getLocationHeaderOrigin(containerRequest); + String request = getTargetUriFromRequest(containerRequest); if (request == null) return false; @@ -475,27 +511,27 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList } /** - * Creates a RFC-6454 comparable origin from the {@code origin} string. + * Creates a RFC-6454 comparable URI from the {@code uri} string. * - * @param origin - * the contents of the Origin HTTP header - * @return only the scheme://host[:port] part, or {@code null} when the origin string is not + * @param uri + * the contents of the Origin or Referer HTTP header + * @return only the scheme://host[:port] part, or {@code null} when the URI string is not * compliant */ - private String getOriginHeaderOrigin(String origin) + protected final String normalizeUri(String uri) { // the request comes from a privacy sensitive context, flag as non-local origin. If // alternative action is required, an implementor can override any of the onAborted, // onSuppressed or onAllowed and implement such needed action. 
- if ("null".equals(origin)) + if (Strings.isEmpty(uri) || "null".equals(uri)) return null; StringBuilder target = new StringBuilder(); try { - URI originUri = new URI(origin); + URI originUri = new URI(uri); String scheme = originUri.getScheme(); if (scheme == null) { @@ -530,20 +566,20 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList } catch (URISyntaxException e) { - log.debug("Invalid Origin header provided: {}, marked conflicting", origin); + log.debug("Invalid URI provided: {}, marked conflicting", uri); return null; } } /** - * Creates a RFC-6454 comparable origin from the {@code request} requested resource. + * Creates a RFC-6454 comparable URI from the {@code request} requested resource. * * @param request * the incoming request * @return only the scheme://host[:port] part, or {@code null} when the origin string is not * compliant */ - private String getLocationHeaderOrigin(HttpServletRequest request) + protected final String getTargetUriFromRequest(HttpServletRequest request) { // Build scheme://host:port from request StringBuilder target = new StringBuilder(); @@ -587,7 +623,7 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList * @param page * the page that is targeted with this request */ - private void whitelistedHandler(HttpServletRequest request, String origin, + protected final void whitelistedHandler(HttpServletRequest request, String origin, IRequestablePage page) { onWhitelisted(request, origin, page); @@ -624,7 +660,8 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList * @param page * the page that is targeted with this request */ - private void matchingOrigin(HttpServletRequest request, String origin, IRequestablePage page) + protected final void matchingOrigin(HttpServletRequest request, String origin, + IRequestablePage page) { onMatchingOrigin(request, origin, page); if (log.isDebugEnabled()) @@ -662,7 +699,8 @@ public class 
CsrfPreventionRequestCycleListener extends AbstractRequestCycleList * @param page * the page that is targeted with this request */ - private void allowHandler(HttpServletRequest request, String origin, IRequestablePage page) + protected final void allowHandler(HttpServletRequest request, String origin, + IRequestablePage page) { onAllowed(request, origin, page); log.info("Possible CSRF attack, request URL: {}, Origin: {}, action: allowed", @@ -697,7 +735,8 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList * @param page * the page that is targeted with this request */ - private void suppressHandler(HttpServletRequest request, String origin, IRequestablePage page) + protected final void suppressHandler(HttpServletRequest request, String origin, + IRequestablePage page) { onSuppressed(request, origin, page); log.info("Possible CSRF attack, request URL: {}, Origin: {}, action: suppressed", @@ -733,7 +772,8 @@ public class CsrfPreventionRequestCycleListener extends AbstractRequestCycleList * @param page * the page that is targeted with this request */ - private void abortHandler(HttpServletRequest request, String origin, IRequestablePage page) + protected final void abortHandler(HttpServletRequest request, String origin, + IRequestablePage page) { onAborted(request, origin, page); log.info(