Author: hsaputra Date: Sat May 21 10:24:36 2011 New Revision: 1125661 URL: http://svn.apache.org/viewvc?rev=1125661&view=rev Log: Add injection of the auth realm and set the auth header when an InvalidAuthenticationException occurs while sending a 401 status. Update unit test.
CR at https://reviews.apache.org/r/760/ Modified: shindig/trunk/java/common/src/main/java/org/apache/shindig/auth/AuthenticationServletFilter.java shindig/trunk/java/common/src/test/java/org/apache/shindig/auth/AuthenticationServletFilterTest.java Modified: shindig/trunk/java/common/src/main/java/org/apache/shindig/auth/AuthenticationServletFilter.java URL: http://svn.apache.org/viewvc/shindig/trunk/java/common/src/main/java/org/apache/shindig/auth/AuthenticationServletFilter.java?rev=1125661&r1=1125660&r2=1125661&view=diff ============================================================================== --- shindig/trunk/java/common/src/main/java/org/apache/shindig/auth/AuthenticationServletFilter.java (original) +++ shindig/trunk/java/common/src/main/java/org/apache/shindig/auth/AuthenticationServletFilter.java Sat May 21 10:24:36 2011 @@ -21,9 +21,11 @@ import com.google.common.base.Charsets; import com.google.common.base.Preconditions; import com.google.inject.Inject; +import org.apache.shindig.common.Nullable; import org.apache.shindig.common.logging.i18n.MessageKeys; import org.apache.shindig.common.servlet.InjectedFilter; +import com.google.inject.name.Named; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -53,17 +55,20 @@ import javax.servlet.http.HttpServletRes * additional handler. 
*/ public class AuthenticationServletFilter extends InjectedFilter { - public static final String AUTH_TYPE_OAUTH = "OAuth"; - - // At some point change this to a container specific realm - private static final String REALM = "shindig"; - - private List<AuthenticationHandler> handlers; + public static final String WWW_AUTHENTICATE_HEADER = "WWW-Authenticate"; //class name for logging purpose private static final String CLASSNAME = AuthenticationServletFilter.class.getName(); private static final Logger LOG = Logger.getLogger(CLASSNAME, MessageKeys.MESSAGES); - + + private String realm = "shindig"; + private List<AuthenticationHandler> handlers; + + @Inject(optional = true) + public void setAuthenticationRealm(@Named("shindig.authentication.realm") String realm) { + this.realm = realm; + } + @Inject public void setAuthenticationHandlers(List<AuthenticationHandler> handlers) { this.handlers = handlers; @@ -71,18 +76,19 @@ public class AuthenticationServletFilter public void destroy() { } - public void doFilter(ServletRequest request, ServletResponse response, - FilterChain chain) throws IOException, ServletException { - + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { if (!(request instanceof HttpServletRequest && response instanceof HttpServletResponse)) { throw new ServletException("Auth filter can only handle HTTP"); } HttpServletRequest req = (HttpServletRequest) request; HttpServletResponse resp = (HttpServletResponse) response; + String authHeader = null; try { for (AuthenticationHandler handler : handlers) { + authHeader = handler.getWWWAuthenticateHeader(getRealm(req)); SecurityToken token = handler.getSecurityTokenFromRequest(req); if (token != null) { AuthInfoUtil.setAuthTypeForRequest(req, handler.getName()); @@ -90,10 +96,8 @@ public class AuthenticationServletFilter callChain(chain, req, resp); return; } else { - String authHeader = 
handler.getWWWAuthenticateHeader(REALM); - if (authHeader != null) { - resp.addHeader("WWW-Authenticate", authHeader); - } + // Set auth header + setAuthHeader(authHeader, resp); } } @@ -105,7 +109,7 @@ public class AuthenticationServletFilter if (LOG.isLoggable(Level.INFO)) { LOG.logp(Level.INFO, CLASSNAME, "doFilter", MessageKeys.ERROR_PARSING_SECURE_TOKEN, cause); } - + if (iae.getAdditionalHeaders() != null) { for (Map.Entry<String,String> entry : iae.getAdditionalHeaders().entrySet()) { resp.addHeader(entry.getKey(), entry.getValue()); @@ -114,6 +118,9 @@ public class AuthenticationServletFilter if (iae.getRedirect() != null) { resp.sendRedirect(iae.getRedirect()); } else { + // Set auth header + setAuthHeader(authHeader, resp); + // For now append the cause message if set, this allows us to send any underlying oauth errors String message = (cause==null) ? iae.getMessage() : iae.getMessage() + cause.getMessage(); @@ -122,6 +129,20 @@ public class AuthenticationServletFilter } } + /** + * Override this to return container server specific realm. + * @return The authentication realm for this server. 
+ */ + protected String getRealm(HttpServletRequest request) { + return realm; + } + + private void setAuthHeader(@Nullable String authHeader, HttpServletResponse response) { + if (authHeader != null) { + response.addHeader(WWW_AUTHENTICATE_HEADER, authHeader); + } + } + private void callChain(FilterChain chain, HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { if (request.getAttribute(AuthenticationHandler.STASHED_BODY) != null) { @@ -132,12 +153,10 @@ public class AuthenticationServletFilter } private static class StashedBodyRequestwrapper extends HttpServletRequestWrapper { - final InputStream rawStream; ServletInputStream stream; BufferedReader reader; - StashedBodyRequestwrapper(HttpServletRequest wrapped) { super(wrapped); rawStream = new ByteArrayInputStream( @@ -146,7 +165,8 @@ public class AuthenticationServletFilter @Override public ServletInputStream getInputStream() throws IOException { - Preconditions.checkState(reader == null, "The methods getInputStream() and getReader() are mutually exclusive."); + Preconditions.checkState(reader == null, + "The methods getInputStream() and getReader() are mutually exclusive."); if (stream == null) { stream = new ServletInputStream() { @@ -160,7 +180,8 @@ public class AuthenticationServletFilter @Override public BufferedReader getReader() throws IOException { - Preconditions.checkState(stream == null, "The methods getInputStream() and getReader() are mutually exclusive."); + Preconditions.checkState(stream == null, + "The methods getInputStream() and getReader() are mutually exclusive."); if (reader == null) { Charset charset = Charset.forName(getCharacterEncoding()); Modified: shindig/trunk/java/common/src/test/java/org/apache/shindig/auth/AuthenticationServletFilterTest.java URL: http://svn.apache.org/viewvc/shindig/trunk/java/common/src/test/java/org/apache/shindig/auth/AuthenticationServletFilterTest.java?rev=1125661&r1=1125660&r2=1125661&view=diff 
============================================================================== --- shindig/trunk/java/common/src/test/java/org/apache/shindig/auth/AuthenticationServletFilterTest.java (original) +++ shindig/trunk/java/common/src/test/java/org/apache/shindig/auth/AuthenticationServletFilterTest.java Sat May 21 10:24:36 2011 @@ -17,19 +17,38 @@ */ package org.apache.shindig.auth; +import org.apache.shindig.common.EasyMockTestCase; +import org.apache.shindig.common.servlet.HttpServletResponseRecorder; + import com.google.common.collect.ImmutableList; +import javax.servlet.FilterChain; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import static org.easymock.EasyMock.expect; import org.junit.Before; import org.junit.Test; import javax.servlet.ServletException; -public class AuthenticationServletFilterTest { - AuthenticationServletFilter filter; +public class AuthenticationServletFilterTest extends EasyMockTestCase { + private static final String TEST_AUTH_HEADER = "Test Authentication Header"; + + private AuthenticationServletFilter filter; + + private HttpServletRequest request; + private HttpServletResponse response; + private HttpServletResponseRecorder recorder; + private FilterChain chain; + private AuthenticationHandler nullStHandler; @Before public void setup() { + request = mock(HttpServletRequest.class); + response = mock(HttpServletResponse.class); + recorder = new HttpServletResponseRecorder(response); + chain = mock(FilterChain.class); filter = new AuthenticationServletFilter(); - filter.setAuthenticationHandlers(ImmutableList.<AuthenticationHandler>of()); + nullStHandler = new NullSecurityTokenAuthenticationHandler(); } @Test(expected = ServletException.class) @@ -37,5 +56,29 @@ public class AuthenticationServletFilter filter.doFilter(null, null, null); } + @Test + public void testNullSecurityToken() throws Exception { + 
filter.setAuthenticationHandlers(ImmutableList.<AuthenticationHandler>of(nullStHandler)); + filter.doFilter(request, recorder, chain); + assertEquals(TEST_AUTH_HEADER, + recorder.getHeader(AuthenticationServletFilter.WWW_AUTHENTICATE_HEADER)); + } + private static class NullSecurityTokenAuthenticationHandler implements AuthenticationHandler { + @Override + public String getName() { + return "TestAuth"; + } + + @Override + public SecurityToken getSecurityTokenFromRequest(HttpServletRequest request) + throws InvalidAuthenticationException { + return null; + } + + @Override + public String getWWWAuthenticateHeader(String realm) { + return TEST_AUTH_HEADER; + } + } }
