http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java index 802019b,0000000..077fa05 mode 100644,000000..100644 --- a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java +++ b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java @@@ -1,278 -1,0 +1,278 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.provider.federation.jwt.filter; + +import java.io.IOException; +import java.security.Principal; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.security.interfaces.RSAPublicKey; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import javax.security.auth.Subject; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.knox.gateway.audit.api.Action; +import org.apache.knox.gateway.audit.api.ActionOutcome; +import org.apache.knox.gateway.audit.api.AuditContext; +import org.apache.knox.gateway.audit.api.AuditService; +import org.apache.knox.gateway.audit.api.AuditServiceFactory; +import org.apache.knox.gateway.audit.api.Auditor; +import org.apache.knox.gateway.audit.api.ResourceType; +import org.apache.knox.gateway.audit.log4j.audit.AuditConstants; +import org.apache.knox.gateway.filter.AbstractGatewayFilter; +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.apache.knox.gateway.provider.federation.jwt.JWTMessages; +import org.apache.knox.gateway.security.PrimaryPrincipal; +import org.apache.knox.gateway.services.GatewayServices; +import org.apache.knox.gateway.services.security.token.JWTokenAuthority; +import org.apache.knox.gateway.services.security.token.TokenServiceException; +import org.apache.knox.gateway.services.security.token.impl.JWTToken; + +/** + * + */ +public abstract class AbstractJWTFilter implements Filter { + /** + * If specified, this configuration property refers to a value which the issuer of a received + * token must match. Otherwise, the default value "KNOXSSO" is used + */ + public static final String JWT_EXPECTED_ISSUER = "jwt.expected.issuer"; + public static final String JWT_DEFAULT_ISSUER = "KNOXSSO"; + + static JWTMessages log = MessagesFactory.get( JWTMessages.class ); + private static AuditService auditService = AuditServiceFactory.getAuditService(); + private static Auditor auditor = auditService.getAuditor( + AuditConstants.DEFAULT_AUDITOR_NAME, AuditConstants.KNOX_SERVICE_NAME, + AuditConstants.KNOX_COMPONENT_NAME ); + + protected List<String> audiences; + protected JWTokenAuthority authority; + protected RSAPublicKey publicKey = null; + private String expectedIssuer; + + public abstract void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException; + + /** + * + */ + public AbstractJWTFilter() { + super(); + } + + @Override + public void init( FilterConfig filterConfig ) throws ServletException { + ServletContext context = filterConfig.getServletContext(); + if (context != null) { + GatewayServices services = (GatewayServices) context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE); + if (services != null) { + authority = (JWTokenAuthority) services.getService(GatewayServices.TOKEN_SERVICE); + } + } + } + + protected void configureExpectedIssuer(FilterConfig filterConfig) { + expectedIssuer = filterConfig.getInitParameter(JWT_EXPECTED_ISSUER);; + if (expectedIssuer == null) { + expectedIssuer = JWT_DEFAULT_ISSUER; + } + } + + /** + * @param expectedAudiences + * @return + */ + protected List<String> parseExpectedAudiences(String expectedAudiences) { + ArrayList<String> audList = null; + // setup the list of valid audiences for token validation + if (expectedAudiences != null) { + // parse into the list + String[] audArray = expectedAudiences.split(","); + audList = new ArrayList<String>(); + for (String a : audArray) { - audList.add(a); ++ audList.add(a.trim()); + } + } + return audList; + } + + protected boolean tokenIsStillValid(JWTToken jwtToken) { + // if there is no expiration date then the lifecycle is tied entirely to + // the cookie validity - otherwise ensure that the current time is before + // the designated expiration time + Date expires = jwtToken.getExpiresDate(); + return (expires == null || expires != null && new Date().before(expires)); + } + + /** + * Validate whether any of the accepted audience claims is present in the + * issued token claims list for audience. Override this method in subclasses + * in order to customize the audience validation behavior. + * + * @param jwtToken + * the JWT token where the allowed audiences will be found + * @return true if an expected audience is present, otherwise false + */ + protected boolean validateAudiences(JWTToken jwtToken) { + boolean valid = false; + + String[] tokenAudienceList = jwtToken.getAudienceClaims(); + // if there were no expected audiences configured then just + // consider any audience acceptable + if (audiences == null) { + valid = true; + } else { + // if any of the configured audiences is found then consider it + // acceptable + if (tokenAudienceList != null) { + for (String aud : tokenAudienceList) { + if (audiences.contains(aud)) { + log.jwtAudienceValidated(); + valid = true; + break; + } + } + } + } + return valid; + } + + protected void continueWithEstablishedSecurityContext(Subject subject, final HttpServletRequest request, final HttpServletResponse response, final FilterChain chain) throws IOException, ServletException { + Principal principal = (Principal) subject.getPrincipals(PrimaryPrincipal.class).toArray()[0]; + AuditContext context = auditService.getContext(); + if (context != null) { + context.setUsername( principal.getName() ); + String sourceUri = (String)request.getAttribute( AbstractGatewayFilter.SOURCE_REQUEST_CONTEXT_URL_ATTRIBUTE_NAME ); + if (sourceUri != null) { + auditor.audit( Action.AUTHENTICATION , sourceUri, ResourceType.URI, ActionOutcome.SUCCESS ); + } + } + + try { + Subject.doAs( + subject, + new PrivilegedExceptionAction<Object>() { + @Override + public Object run() throws Exception { + chain.doFilter(request, response); + return null; + } + } + ); + } + catch (PrivilegedActionException e) { + Throwable t = e.getCause(); + if (t instanceof IOException) { + throw (IOException) t; + } + else if (t instanceof ServletException) { + throw (ServletException) t; + } + else { + throw new ServletException(t); + } + } + } + + protected Subject createSubjectFromToken(JWTToken token) { + final String principal = token.getSubject(); + + @SuppressWarnings("rawtypes") + HashSet emptySet = new HashSet(); + Set<Principal> principals = new HashSet<>(); + Principal p = new PrimaryPrincipal(principal); + principals.add(p); + + // The newly constructed Sets check whether this Subject has been set read-only + // before permitting subsequent modifications. The newly created Sets also prevent + // illegal modifications by ensuring that callers have sufficient permissions. + // + // To modify the Principals Set, the caller must have AuthPermission("modifyPrincipals"). + // To modify the public credential Set, the caller must have AuthPermission("modifyPublicCredentials"). + // To modify the private credential Set, the caller must have AuthPermission("modifyPrivateCredentials"). + javax.security.auth.Subject subject = new javax.security.auth.Subject(true, principals, emptySet, emptySet); + return subject; + } + + protected boolean validateToken(HttpServletRequest request, HttpServletResponse response, + FilterChain chain, JWTToken token) + throws IOException, ServletException { + boolean verified = false; + try { + if (publicKey == null) { + verified = authority.verifyToken(token); + } + else { + verified = authority.verifyToken(token, publicKey); + } + } catch (TokenServiceException e) { + log.unableToVerifyToken(e); + } + + if (verified) { + // confirm that issue matches intended target + if (expectedIssuer.equals(token.getIssuer())) { + // if there is no expiration data then the lifecycle is tied entirely to + // the cookie validity - otherwise ensure that the current time is before + // the designated expiration time + if (tokenIsStillValid(token)) { + boolean audValid = validateAudiences(token); + if (audValid) { + return true; + } + else { + log.failedToValidateAudience(); + handleValidationError(request, response, HttpServletResponse.SC_BAD_REQUEST, + "Bad request: missing required token audience"); + } + } + else { + log.tokenHasExpired(); + handleValidationError(request, response, HttpServletResponse.SC_BAD_REQUEST, + "Bad request: token has expired"); + } + } + else { + handleValidationError(request, response, HttpServletResponse.SC_UNAUTHORIZED, null); + } + } + else { + log.failedToVerifyTokenSignature(); + handleValidationError(request, response, HttpServletResponse.SC_UNAUTHORIZED, null); + } + + return false; + } + + protected abstract void handleValidationError(HttpServletRequest request, HttpServletResponse response, int status, + String error) throws IOException; + +}
http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java index 361a1ff,0000000..9888eab mode 100644,000000..100644 --- a/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java +++ b/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java @@@ -1,636 -1,0 +1,667 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.provider.federation; + +import static org.junit.Assert.fail; + +import java.io.IOException; +import java.net.InetAddress; +import java.security.AccessController; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PublicKey; +import java.security.cert.Certificate; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.text.MessageFormat; +import java.util.Enumeration; +import java.util.List; +import java.util.ArrayList; +import java.util.Properties; +import java.util.Date; +import java.util.Set; + +import javax.security.auth.Subject; +import javax.servlet.FilterChain; +import javax.servlet.FilterConfig; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.apache.commons.codec.binary.Base64; +import org.apache.knox.gateway.provider.federation.jwt.filter.AbstractJWTFilter; +import org.apache.knox.gateway.provider.federation.jwt.filter.SSOCookieFederationFilter; +import org.apache.knox.gateway.security.PrimaryPrincipal; +import org.apache.knox.gateway.services.security.impl.X509CertificateUtil; +import org.apache.knox.gateway.services.security.token.JWTokenAuthority; +import org.apache.knox.gateway.services.security.token.TokenServiceException; +import org.apache.knox.gateway.services.security.token.impl.JWT; +import org.apache.knox.gateway.services.security.token.impl.JWTToken; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.nimbusds.jose.*; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.jose.crypto.RSASSASigner; +import com.nimbusds.jose.crypto.RSASSAVerifier; + +public abstract class AbstractJWTFilterTest { + private static final String SERVICE_URL = "https://localhost:8888/resource"; + private static final String dnTemplate = "CN={0},OU=Test,O=Hadoop,L=Test,ST=Test,C=US"; + + protected AbstractJWTFilter handler = null; + protected static RSAPublicKey publicKey = null; + protected static RSAPrivateKey privateKey = null; + protected static String pem = null; + + protected abstract void setTokenOnRequest(HttpServletRequest request, SignedJWT jwt); + protected abstract void setGarbledTokenOnRequest(HttpServletRequest request, SignedJWT jwt); + protected abstract String getAudienceProperty(); + protected abstract String getVerificationPemProperty(); + + private static String buildDistinguishedName(String hostname) { + MessageFormat headerFormatter = new MessageFormat(dnTemplate); + String[] paramArray = new String[1]; + paramArray[0] = hostname; + String dn = headerFormatter.format(paramArray); + return dn; + } + + @BeforeClass + public static void generateKeys() throws Exception, NoSuchAlgorithmException { + KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA"); + kpg.initialize(2048); + KeyPair KPair = kpg.generateKeyPair(); + String dn = buildDistinguishedName(InetAddress.getLocalHost().getHostName()); + Certificate cert = X509CertificateUtil.generateCertificate(dn, KPair, 365, "SHA1withRSA"); + byte[] data = cert.getEncoded(); + Base64 encoder = new Base64( 76, "\n".getBytes( "ASCII" ) ); + pem = new String(encoder.encodeToString( data ).getBytes( "ASCII" )).trim(); + + publicKey = (RSAPublicKey) KPair.getPublic(); + privateKey = (RSAPrivateKey) KPair.getPrivate(); + } + + @After + public void teardown() throws Exception { + handler.destroy(); + } + + @Test + public void testValidJWT() throws Exception { + try { + Properties props = getProperties(); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("alice", new Date(new Date().getTime() + 5000), privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be false.", chain.doFilterCalled ); + Set<PrimaryPrincipal> principals = chain.subject.getPrincipals(PrimaryPrincipal.class); + Assert.assertTrue("No PrimaryPrincipal", !principals.isEmpty()); + Assert.assertEquals("Not the expected principal", "alice", ((Principal)principals.toArray()[0]).getName()); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testValidAudienceJWT() throws Exception { + try { + Properties props = getProperties(); + props.put(getAudienceProperty(), "bar"); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("alice", new Date(new Date().getTime() + 5000), privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be false.", chain.doFilterCalled ); + Set<PrimaryPrincipal> principals = chain.subject.getPrincipals(PrimaryPrincipal.class); + Assert.assertTrue("No PrimaryPrincipal", !principals.isEmpty()); + Assert.assertEquals("Not the expected principal", "alice", ((Principal)principals.toArray()[0]).getName()); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testInvalidAudienceJWT() throws Exception { + try { + Properties props = getProperties(); + props.put(getAudienceProperty(), "foo"); + props.put("sso.authentication.provider.url", "https://localhost:8443/gateway/knoxsso/api/v1/websso"); + + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("alice", new Date(new Date().getTime() + 5000), privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be true.", !chain.doFilterCalled); + Assert.assertTrue("No Subject should be returned.", chain.subject == null); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test ++ public void testValidAudienceJWTWhitespace() throws Exception { ++ try { ++ Properties props = getProperties(); ++ props.put(getAudienceProperty(), " foo, bar "); ++ handler.init(new TestFilterConfig(props)); ++ ++ SignedJWT jwt = getJWT("alice", new Date(new Date().getTime() + 5000), privateKey, props); ++ ++ HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); ++ setTokenOnRequest(request, jwt); ++ ++ EasyMock.expect(request.getRequestURL()).andReturn( ++ new StringBuffer(SERVICE_URL)).anyTimes(); ++ EasyMock.expect(request.getQueryString()).andReturn(null); ++ HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); ++ EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( ++ SERVICE_URL); ++ EasyMock.replay(request); ++ ++ TestFilterChain chain = new TestFilterChain(); ++ handler.doFilter(request, response, chain); ++ Assert.assertTrue("doFilterCalled should not be false.", chain.doFilterCalled ); ++ Set<PrimaryPrincipal> principals = chain.subject.getPrincipals(PrimaryPrincipal.class); ++ Assert.assertTrue("No PrimaryPrincipal", !principals.isEmpty()); ++ Assert.assertEquals("Not the expected principal", "alice", ((Principal)principals.toArray()[0]).getName()); ++ } catch (ServletException se) { ++ fail("Should NOT have thrown a ServletException."); ++ } ++ } ++ ++ @Test + public void testValidVerificationPEM() throws Exception { + try { + Properties props = getProperties(); + +// System.out.println("+" + pem + "+"); + + props.put(getAudienceProperty(), "bar"); + props.put("sso.authentication.provider.url", "https://localhost:8443/gateway/knoxsso/api/v1/websso"); + props.put(getVerificationPemProperty(), pem); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("alice", new Date(new Date().getTime() + 50000), privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be false.", chain.doFilterCalled ); + Set<PrimaryPrincipal> principals = chain.subject.getPrincipals(PrimaryPrincipal.class); + Assert.assertTrue("No PrimaryPrincipal", !principals.isEmpty()); + Assert.assertEquals("Not the expected principal", "alice", ((Principal)principals.toArray()[0]).getName()); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testExpiredJWT() throws Exception { + try { + Properties props = getProperties(); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("alice", new Date(new Date().getTime() - 1000), privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be false.", !chain.doFilterCalled); + Assert.assertTrue("No Subject should be returned.", chain.subject == null); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testValidJWTNoExpiration() throws Exception { + try { + Properties props = getProperties(); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("alice", null, privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL).anyTimes(); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be false.", chain.doFilterCalled ); + Set<PrimaryPrincipal> principals = chain.subject.getPrincipals(PrimaryPrincipal.class); + Assert.assertTrue("No PrimaryPrincipal", !principals.isEmpty()); + Assert.assertEquals("Not the expected principal", "alice", ((Principal)principals.toArray()[0]).getName()); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testUnableToParseJWT() throws Exception { + try { + Properties props = getProperties(); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("bob", new Date(new Date().getTime() + 5000), privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setGarbledTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL).anyTimes(); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be true.", !chain.doFilterCalled); + Assert.assertTrue("No Subject should be returned.", chain.subject == null); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testFailedSignatureValidationJWT() throws Exception { + try { + // Create a private key to sign the token + KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA"); + kpg.initialize(1024); + + KeyPair kp = kpg.genKeyPair(); + + Properties props = getProperties(); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("bob", new Date(new Date().getTime() + 5000), + (RSAPrivateKey)kp.getPrivate(), props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL).anyTimes(); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be true.", !chain.doFilterCalled); + Assert.assertTrue("No Subject should be returned.", chain.subject == null); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testInvalidVerificationPEM() throws Exception { + try { + Properties props = getProperties(); + + KeyPairGenerator kpg = KeyPairGenerator.getInstance("RSA"); + kpg.initialize(1024); + + KeyPair KPair = kpg.generateKeyPair(); + String dn = buildDistinguishedName(InetAddress.getLocalHost().getHostName()); + Certificate cert = X509CertificateUtil.generateCertificate(dn, KPair, 365, "SHA1withRSA"); + byte[] data = cert.getEncoded(); + Base64 encoder = new Base64( 76, "\n".getBytes( "ASCII" ) ); + String failingPem = new String(encoder.encodeToString( data ).getBytes( "ASCII" )).trim(); + + props.put(getAudienceProperty(), "bar"); + props.put(getVerificationPemProperty(), failingPem); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("alice", new Date(new Date().getTime() + 50000), privateKey, props); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn(SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be true.", chain.doFilterCalled == false); + Assert.assertTrue("No Subject should be returned.", chain.subject == null); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testInvalidIssuer() throws Exception { + try { + Properties props = getProperties(); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("new-issuer", "alice", new Date(new Date().getTime() + 5000), privateKey); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be true.", !chain.doFilterCalled); + Assert.assertTrue("No Subject should be returned.", chain.subject == null); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + @Test + public void testValidIssuerViaConfig() throws Exception { + try { + Properties props = getProperties(); + props.setProperty(AbstractJWTFilter.JWT_EXPECTED_ISSUER, "new-issuer"); + handler.init(new TestFilterConfig(props)); + + SignedJWT jwt = getJWT("new-issuer", "alice", new Date(new Date().getTime() + 5000), privateKey); + + HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class); + setTokenOnRequest(request, jwt); + + EasyMock.expect(request.getRequestURL()).andReturn( + new StringBuffer(SERVICE_URL)).anyTimes(); + EasyMock.expect(request.getQueryString()).andReturn(null); + HttpServletResponse response = EasyMock.createNiceMock(HttpServletResponse.class); + EasyMock.expect(response.encodeRedirectURL(SERVICE_URL)).andReturn( + SERVICE_URL); + EasyMock.replay(request); + + TestFilterChain chain = new TestFilterChain(); + handler.doFilter(request, response, chain); + Assert.assertTrue("doFilterCalled should not be false.", chain.doFilterCalled); + Set<PrimaryPrincipal> principals = chain.subject.getPrincipals(PrimaryPrincipal.class); + Assert.assertTrue("No PrimaryPrincipal", principals.size() > 0); + Assert.assertEquals("Not the expected principal", "alice", ((Principal)principals.toArray()[0]).getName()); + } catch (ServletException se) { + fail("Should NOT have thrown a ServletException."); + } + } + + protected Properties getProperties() { + Properties props = new Properties(); + props.setProperty( + SSOCookieFederationFilter.SSO_AUTHENTICATION_PROVIDER_URL, + "https://localhost:8443/authserver"); + return props; + } + + protected SignedJWT getJWT(String sub, Date expires, RSAPrivateKey privateKey, + Properties props) throws Exception { + return getJWT(AbstractJWTFilter.JWT_DEFAULT_ISSUER, sub, expires, privateKey); + } + + protected SignedJWT getJWT(String issuer, String sub, Date expires, RSAPrivateKey privateKey) + throws Exception { + List<String> aud = new ArrayList<String>(); + aud.add("bar"); + + JWTClaimsSet claims = new JWTClaimsSet.Builder() + .issuer(issuer) + .subject(sub) + .audience(aud) + .expirationTime(expires) + .claim("scope", "openid") + .build(); + + JWSHeader header = new JWSHeader.Builder(JWSAlgorithm.RS256).build(); + + SignedJWT signedJWT = new SignedJWT(header, claims); + JWSSigner signer = new RSASSASigner(privateKey); + + signedJWT.sign(signer); + + return signedJWT; + } + + protected static class TestFilterConfig implements FilterConfig { + Properties props = null; + + public TestFilterConfig(Properties props) { + this.props = props; + } + + @Override + public String getFilterName() { + return null; + } + + /* (non-Javadoc) + * @see javax.servlet.FilterConfig#getServletContext() + */ + @Override + public ServletContext getServletContext() { +// JWTokenAuthority authority = EasyMock.createNiceMock(JWTokenAuthority.class); +// GatewayServices services = EasyMock.createNiceMock(GatewayServices.class); +// EasyMock.expect(services.getService("TokenService").andReturn(authority)); +// ServletContext context = EasyMock.createNiceMock(ServletContext.class); +// EasyMock.expect(context.getAttribute(GatewayServices.GATEWAY_SERVICES_ATTRIBUTE).andReturn(new DefaultGatewayServices())); + return null; + } + + /* (non-Javadoc) + * @see javax.servlet.FilterConfig#getInitParameter(java.lang.String) + */ + @Override + public String getInitParameter(String name) { + return props.getProperty(name, null); + } + + /* (non-Javadoc) + * @see javax.servlet.FilterConfig#getInitParameterNames() + */ + @Override + public Enumeration<String> getInitParameterNames() { + return null; + } + + } + + protected static class TestJWTokenAuthority implements JWTokenAuthority { + + private PublicKey verifyingKey; + + public TestJWTokenAuthority(PublicKey verifyingKey) { + this.verifyingKey = verifyingKey; + } + + /* (non-Javadoc) + * @see JWTokenAuthority#issueToken(javax.security.auth.Subject, java.lang.String) + */ + @Override + public JWT issueToken(Subject subject, String algorithm) + throws TokenServiceException { + // TODO Auto-generated method stub + return null; + } + + /* (non-Javadoc) + * @see JWTokenAuthority#issueToken(java.security.Principal, java.lang.String) + */ + @Override + public JWT issueToken(Principal p, String algorithm) + throws TokenServiceException { + // TODO Auto-generated method stub + return null; + } + + /* (non-Javadoc) + * @see JWTokenAuthority#issueToken(java.security.Principal, java.lang.String, java.lang.String) + */ + @Override + public JWT issueToken(Principal p, String audience, String algorithm) + throws TokenServiceException { + return null; + } + + /* (non-Javadoc) + * @see org.apache.knox.gateway.services.security.token.JWTokenAuthority#verifyToken(org.apache.knox.gateway.services.security.token.impl.JWT) + */ + @Override + public boolean verifyToken(JWT token) throws TokenServiceException { + JWSVerifier verifier = new RSASSAVerifier((RSAPublicKey) verifyingKey); + return token.verify(verifier); + } + + /* (non-Javadoc) + * @see JWTokenAuthority#issueToken(java.security.Principal, java.lang.String, java.lang.String, long) + */ + @Override + public JWT issueToken(Principal p, String audience, String algorithm, + long expires) throws TokenServiceException { + return null; + } + + @Override + public JWT issueToken(Principal p, List<String> audiences, String algorithm, + long expires) throws TokenServiceException { + return null; + } + + /* (non-Javadoc) + * @see JWTokenAuthority#issueToken(java.security.Principal, java.lang.String, long) + */ + @Override + public JWT issueToken(Principal p, String algorithm, long expires) + throws TokenServiceException { + // TODO Auto-generated method stub + return null; + } + + @Override + public boolean verifyToken(JWT token, RSAPublicKey publicKey) throws TokenServiceException { + JWSVerifier verifier = new RSASSAVerifier(publicKey); + return token.verify(verifier); + } + + } + + protected static class TestFilterChain implements FilterChain { + boolean doFilterCalled = false; + Subject subject = null; + + /* (non-Javadoc) + * @see javax.servlet.FilterChain#doFilter(javax.servlet.ServletRequest, javax.servlet.ServletResponse) + */ + @Override + public void doFilter(ServletRequest request, ServletResponse response) + throws IOException, ServletException { + doFilterCalled = true; + + subject = Subject.getSubject( AccessController.getContext() ); + } + + } +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/PicketlinkMessages.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/PicketlinkMessages.java index 86f2854,0000000..e69de29 mode 100644,000000..100644 --- a/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/PicketlinkMessages.java +++ b/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/PicketlinkMessages.java http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkConf.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkConf.java index 5b3b6e0,0000000..e69de29 mode 100644,000000..100644 --- a/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkConf.java +++ b/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkConf.java http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkFederationProviderContributor.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkFederationProviderContributor.java index d13bdaa,0000000..e69de29 mode 100644,000000..100644 --- a/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkFederationProviderContributor.java +++ b/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/deploy/PicketlinkFederationProviderContributor.java http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/CaptureOriginalURLFilter.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/CaptureOriginalURLFilter.java index b062013,0000000..e69de29 mode 100644,000000..100644 --- a/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/CaptureOriginalURLFilter.java +++ b/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/CaptureOriginalURLFilter.java http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/PicketlinkIdentityAdapter.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/PicketlinkIdentityAdapter.java index e3811b4,0000000..e69de29 mode 100644,000000..100644 --- a/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/PicketlinkIdentityAdapter.java +++ b/gateway-provider-security-picketlink/src/main/java/org/apache/knox/gateway/picketlink/filter/PicketlinkIdentityAdapter.java http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-provider-security-picketlink/src/test/java/org/apache/knox/gateway/picketlink/PicketlinkTest.java ---------------------------------------------------------------------- diff --cc gateway-provider-security-picketlink/src/test/java/org/apache/knox/gateway/picketlink/PicketlinkTest.java index a0cd7be,0000000..e69de29 mode 100644,000000..100644 --- a/gateway-provider-security-picketlink/src/test/java/org/apache/knox/gateway/picketlink/PicketlinkTest.java +++ b/gateway-provider-security-picketlink/src/test/java/org/apache/knox/gateway/picketlink/PicketlinkTest.java http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-release/pom.xml ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-server/src/main/java/org/apache/knox/gateway/services/registry/impl/DefaultServiceRegistryService.java ---------------------------------------------------------------------- diff --cc gateway-server/src/main/java/org/apache/knox/gateway/services/registry/impl/DefaultServiceRegistryService.java index 84330c7,0000000..075eda1 mode 100644,000000..100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/services/registry/impl/DefaultServiceRegistryService.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/services/registry/impl/DefaultServiceRegistryService.java @@@ -1,207 -1,0 +1,207 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.services.registry.impl; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.codec.binary.Base64; +import org.apache.commons.io.FileUtils; +import org.apache.knox.gateway.GatewayMessages; +import org.apache.knox.gateway.config.GatewayConfig; +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.apache.knox.gateway.services.Service; +import org.apache.knox.gateway.services.ServiceLifecycleException; +import org.apache.knox.gateway.services.registry.ServiceRegistry; +import org.apache.knox.gateway.services.security.CryptoService; + +import java.io.File; +import java.io.IOException; ++import java.security.SecureRandom; +import java.util.HashMap; +import java.util.List; +import java.util.Map; - import java.util.Random; + +public class DefaultServiceRegistryService implements ServiceRegistry, Service { + private static GatewayMessages LOG = MessagesFactory.get( GatewayMessages.class ); - ++ + protected char[] chars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', + 'h', 'j', 'k', 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', + 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', + 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + '2', '3', '4', '5', '6', '7', '8', '9',}; + + private CryptoService crypto; + private Registry registry = new Registry(); + + private String registryFileName; - ++ + public DefaultServiceRegistryService() { + } - ++ + public void setCryptoService(CryptoService crypto) { + this.crypto = crypto; + } - ++ + public String getRegistrationCode(String clusterName) { + String code = generateRegCode(16); + byte[] signature = crypto.sign("SHA256withRSA","gateway-identity",code); + String encodedSig = Base64.encodeBase64URLSafeString(signature); - ++ + return code + "::" + encodedSig; + } - ++ + private String generateRegCode(int length) { - StringBuffer sb = new StringBuffer(); - Random r = new Random(); ++ StringBuilder sb = new StringBuilder(); ++ SecureRandom r = new SecureRandom(); + for (int i = 0; i < length; i++) { + sb.append(chars[r.nextInt(chars.length)]); + } + return sb.toString(); + } - ++ + public void removeClusterServices(String clusterName) { + registry.remove(clusterName); + } + + public boolean registerService(String regCode, String clusterName, String serviceName, List<String> urls) { + boolean rc = false; + // verify the signature of the regCode + if (regCode == null) { + throw new IllegalArgumentException("Registration Code must not be null."); + } + String[] parts = regCode.split("::"); - ++ + // part one is the code and part two is the signature + boolean verified = crypto.verify("SHA256withRSA", "gateway-identity", parts[0], Base64.decodeBase64(parts[1])); + if (verified) { + HashMap<String,RegEntry> clusterServices = registry.get(clusterName); + if (clusterServices == null) { + synchronized(this) { + clusterServices = new HashMap<>(); + registry.put(clusterName, clusterServices); + } + } + RegEntry regEntry = new RegEntry(); + regEntry.setClusterName(clusterName); + regEntry.setServiceName(serviceName); + regEntry.setUrls(urls); + clusterServices.put(serviceName , regEntry); + String json = renderAsJsonString(registry); + try { + FileUtils.write(new File(registryFileName), json); + rc = true; + } catch (IOException e) { + // log appropriately + e.printStackTrace(); //TODO: I18N + } + } - ++ + return rc; + } - ++ + private String renderAsJsonString(HashMap<String,HashMap<String,RegEntry>> registry) { + String json = null; + ObjectMapper mapper = new ObjectMapper(); - ++ + try { + // write JSON to a file + json = mapper.writeValueAsString((Object)registry); - ++ + } catch ( JsonProcessingException e ) { + e.printStackTrace(); //TODO: I18N + } + return json; + } - ++ + @Override + public String lookupServiceURL(String clusterName, String serviceName) { + List<String> urls = lookupServiceURLs( clusterName, serviceName ); + if ( urls != null && !urls.isEmpty() ) { + return urls.get( 0 ); + } + return null; + } + + @Override + public List<String> lookupServiceURLs( String clusterName, String serviceName ) { + RegEntry entry = null; - HashMap clusterServices = registry.get(clusterName); ++ HashMap<String, RegEntry> clusterServices = registry.get(clusterName); + if (clusterServices != null) { - entry = (RegEntry) clusterServices.get(serviceName); ++ entry = clusterServices.get(serviceName); + if( entry != null ) { + return entry.getUrls(); + } + } + return null; + } - ++ + private HashMap<String, HashMap<String,RegEntry>> getMapFromJsonString(String json) { + Registry map = null; - JsonFactory factory = new JsonFactory(); - ObjectMapper mapper = new ObjectMapper(factory); - TypeReference<Registry> typeRef - = new TypeReference<Registry>() {}; ++ JsonFactory factory = new JsonFactory(); ++ ObjectMapper mapper = new ObjectMapper(factory); ++ TypeReference<Registry> typeRef ++ = new TypeReference<Registry>() {}; + try { + map = mapper.readValue(json, typeRef); + } catch (JsonParseException e) { + LOG.failedToGetMapFromJsonString( json, e ); + } catch (JsonMappingException e) { + LOG.failedToGetMapFromJsonString( json, e ); + } catch (IOException e) { + LOG.failedToGetMapFromJsonString( json, e ); - } ++ } + return map; - } ++ } + + @Override + public void init(GatewayConfig config, Map<String, String> options) + throws ServiceLifecycleException { + String securityDir = config.getGatewaySecurityDir(); + String filename = "registry"; + setupRegistryFile(securityDir, filename); + } + + protected void setupRegistryFile(String securityDir, String filename) throws ServiceLifecycleException { + File registryFile = new File(securityDir, filename); + if (registryFile.exists()) { + try { + String json = FileUtils.readFileToString(registryFile); + Registry reg = (Registry) getMapFromJsonString(json); + if (reg != null) { + registry = reg; + } + } catch (Exception e) { + throw new ServiceLifecycleException("Unable to load the persisted registry.", e); + } + } + registryFileName = registryFile.getAbsolutePath(); + } + + @Override + public void start() throws ServiceLifecycleException { + } + + @Override + public void stop() throws ServiceLifecycleException { + } + +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-server/src/main/java/org/apache/knox/gateway/services/security/impl/DefaultAliasService.java ---------------------------------------------------------------------- diff --cc gateway-server/src/main/java/org/apache/knox/gateway/services/security/impl/DefaultAliasService.java index f52a7b3,0000000..b5e62ab mode 100644,000000..100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/services/security/impl/DefaultAliasService.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/services/security/impl/DefaultAliasService.java @@@ -1,217 -1,0 +1,217 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.services.security.impl; + +import java.security.KeyStore; +import java.security.KeyStoreException; ++import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.util.ArrayList; +import java.util.Enumeration; +import java.util.List; +import java.util.Map; - import java.util.Random; + +import org.apache.knox.gateway.GatewayMessages; +import org.apache.knox.gateway.config.GatewayConfig; +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.apache.knox.gateway.services.ServiceLifecycleException; +import org.apache.knox.gateway.services.security.AliasService; +import org.apache.knox.gateway.services.security.AliasServiceException; +import org.apache.knox.gateway.services.security.KeystoreService; +import org.apache.knox.gateway.services.security.KeystoreServiceException; +import org.apache.knox.gateway.services.security.MasterService; + +public class DefaultAliasService implements AliasService { + private static final GatewayMessages LOG = MessagesFactory.get( GatewayMessages.class ); + - private static final String GATEWAY_IDENTITY_PASSPHRASE = "gateway-identity-passphrase"; ++ private static final String GATEWAY_IDENTITY_PASSPHRASE = "gateway-identity-passphrase"; + + protected char[] chars = { 'a', 'b', 'c', 'd', 'e', 'f', 'g', + 'h', 'j', 'k', 'm', 'n', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', + 'x', 'y', 'z', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', + 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + '2', '3', '4', '5', '6', '7', '8', '9',}; + + private KeystoreService keystoreService; + private MasterService masterService; + + @Override + public void init(GatewayConfig config, Map<String, String> options) + throws ServiceLifecycleException { + } + + @Override + public void start() throws ServiceLifecycleException { + } + + @Override + public void stop() throws ServiceLifecycleException { + } + + @Override + public char[] getGatewayIdentityPassphrase() throws AliasServiceException { + char[] passphrase = getPasswordFromAliasForGateway(GATEWAY_IDENTITY_PASSPHRASE); + if (passphrase == null) { + passphrase = masterService.getMasterSecret(); + } + return passphrase; + } + + /* (non-Javadoc) + * @see org.apache.knox.gateway.services.security.impl.AliasService#getAliasForCluster(java.lang.String, java.lang.String) + */ + @Override + public char[] getPasswordFromAliasForCluster(String clusterName, String alias) + throws AliasServiceException { + return getPasswordFromAliasForCluster(clusterName, alias, false); + } + + /* (non-Javadoc) + * @see org.apache.knox.gateway.services.security.impl.AliasService#getAliasForCluster(java.lang.String, java.lang.String, boolean) + */ + @Override + public char[] getPasswordFromAliasForCluster(String clusterName, String alias, boolean generate) + throws AliasServiceException { + char[] credential = null; + try { + credential = keystoreService.getCredentialForCluster(clusterName, alias); + if (credential == null) { + if (generate) { + generateAliasForCluster(clusterName, alias); + credential = keystoreService.getCredentialForCluster(clusterName, alias); + } + } + } catch (KeystoreServiceException e) { + LOG.failedToGetCredentialForCluster(clusterName, e); + throw new AliasServiceException(e); + } + return credential; + } + + private String generatePassword(int length) { - StringBuffer sb = new StringBuffer(); - Random r = new Random(); ++ StringBuilder sb = new StringBuilder(); ++ SecureRandom r = new SecureRandom(); + for (int i = 0; i < length; i++) { + sb.append(chars[r.nextInt(chars.length)]); + } + return sb.toString(); + } - ++ + public void setKeystoreService(KeystoreService ks) { + this.keystoreService = ks; + } + + public void setMasterService(MasterService ms) { + this.masterService = ms; - ++ + } + + @Override + public void generateAliasForCluster(String clusterName, String alias) + throws AliasServiceException { + try { + keystoreService.getCredentialStoreForCluster(clusterName); + } catch (KeystoreServiceException e) { + LOG.failedToGenerateAliasForCluster(clusterName, e); + throw new AliasServiceException(e); + } + String passwordString = generatePassword(16); + addAliasForCluster(clusterName, alias, passwordString); + } + + /* (non-Javadoc) + * @see org.apache.knox.gateway.services.security.impl.AliasService#addAliasForCluster(java.lang.String, java.lang.String, java.lang.String) + */ + @Override + public void addAliasForCluster(String clusterName, String alias, String value) { + try { + keystoreService.addCredentialForCluster(clusterName, alias, value); + } catch (KeystoreServiceException e) { + LOG.failedToAddCredentialForCluster(clusterName, e); + } + } + + @Override + public void removeAliasForCluster(String clusterName, String alias) + throws AliasServiceException { + try { + keystoreService.removeCredentialForCluster(clusterName, alias); + } catch (KeystoreServiceException e) { + throw new AliasServiceException(e); + } + } + + @Override + public char[] getPasswordFromAliasForGateway(String alias) + throws AliasServiceException { + return getPasswordFromAliasForCluster("__gateway", alias); + } + + @Override + public void generateAliasForGateway(String alias) + throws AliasServiceException { + generateAliasForCluster("__gateway", alias); + } + + /* (non-Javadoc) + * @see AliasService#getCertificateForGateway(java.lang.String) + */ + @Override + public Certificate getCertificateForGateway(String alias) { + Certificate cert = null; + try { + cert = this.keystoreService.getKeystoreForGateway().getCertificate(alias); + } catch (KeyStoreException e) { + LOG.unableToRetrieveCertificateForGateway(e); + // should we throw an exception? + } catch (KeystoreServiceException e) { + LOG.unableToRetrieveCertificateForGateway(e); + } + return cert; + } + + /* (non-Javadoc) + * @see AliasService#getAliasesForCluster(java.lang.String) + */ + @Override + public List<String> getAliasesForCluster(String clusterName) { + ArrayList<String> list = new ArrayList<String>(); + KeyStore keyStore; + try { + keyStore = keystoreService.getCredentialStoreForCluster(clusterName); + if (keyStore != null) { + String alias = null; + try { + Enumeration<String> e = keyStore.aliases(); + while (e.hasMoreElements()) { + alias = e.nextElement(); + // only include the metadata key names in the list of names + if (!alias.contains("@")) { + list.add(alias); + } + } + } catch (KeyStoreException e) { + LOG.failedToGetCredentialForCluster(clusterName, e); + } + } + } catch (KeystoreServiceException kse) { + LOG.failedToGetCredentialForCluster(clusterName, kse); + } + return list; + } +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandler.java ---------------------------------------------------------------------- diff --cc gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandler.java index c4a3914,0000000..16d5b81 mode 100644,000000..100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandler.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorHandler.java @@@ -1,187 -1,0 +1,234 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.knox.gateway.topology.simple; + +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.apache.knox.gateway.services.Service; +import org.apache.knox.gateway.topology.discovery.DefaultServiceDiscoveryConfig; +import org.apache.knox.gateway.topology.discovery.ServiceDiscovery; +import org.apache.knox.gateway.topology.discovery.ServiceDiscoveryFactory; ++import java.io.BufferedWriter; ++import java.io.File; ++import java.io.FileInputStream; ++import java.io.FileWriter; ++import java.io.InputStreamReader; ++import java.io.IOException; ++ ++import java.net.URI; ++import java.net.URISyntaxException; ++ ++import java.util.ArrayList; ++import java.util.Collections; ++import java.util.HashMap; ++import java.util.List; ++import java.util.Map; + - import java.io.*; - import java.util.*; + + +/** + * Processes simple topology descriptors, producing full topology files, which can subsequently be deployed to the + * gateway. + */ +public class SimpleDescriptorHandler { + + private static final Service[] NO_GATEWAY_SERVICES = new Service[]{}; + + private static final SimpleDescriptorMessages log = MessagesFactory.get(SimpleDescriptorMessages.class); + + public static Map<String, File> handle(File desc) throws IOException { + return handle(desc, NO_GATEWAY_SERVICES); + } + + public static Map<String, File> handle(File desc, Service...gatewayServices) throws IOException { + return handle(desc, desc.getParentFile(), gatewayServices); + } + + public static Map<String, File> handle(File desc, File destDirectory) throws IOException { + return handle(desc, destDirectory, NO_GATEWAY_SERVICES); + } + + public static Map<String, File> handle(File desc, File destDirectory, Service...gatewayServices) throws IOException { + return handle(SimpleDescriptorFactory.parse(desc.getAbsolutePath()), desc.getParentFile(), destDirectory, gatewayServices); + } + + public static Map<String, File> handle(SimpleDescriptor desc, File srcDirectory, File destDirectory) { + return handle(desc, srcDirectory, destDirectory, NO_GATEWAY_SERVICES); + } + + public static Map<String, File> handle(SimpleDescriptor desc, File srcDirectory, File destDirectory, Service...gatewayServices) { + Map<String, File> result = new HashMap<>(); + + File topologyDescriptor; + + DefaultServiceDiscoveryConfig sdc = new DefaultServiceDiscoveryConfig(desc.getDiscoveryAddress()); + sdc.setUser(desc.getDiscoveryUser()); + sdc.setPasswordAlias(desc.getDiscoveryPasswordAlias()); - ServiceDiscovery sd = ServiceDiscoveryFactory - .get(desc.getDiscoveryType(), gatewayServices); ++ ServiceDiscovery sd = ServiceDiscoveryFactory.get(desc.getDiscoveryType(), gatewayServices); + ServiceDiscovery.Cluster cluster = sd.discover(sdc, desc.getClusterName()); + + Map<String, List<String>> serviceURLs = new HashMap<>(); + + if (cluster != null) { + for (SimpleDescriptor.Service descService : desc.getServices()) { + String serviceName = descService.getName(); + + List<String> descServiceURLs = descService.getURLs(); + if (descServiceURLs == null || descServiceURLs.isEmpty()) { + descServiceURLs = cluster.getServiceURLs(serviceName); + } + - // If there is at least one URL associated with the service, then add it to the map ++ // Validate the discovered service URLs ++ List<String> validURLs = new ArrayList<>(); + if (descServiceURLs != null && !descServiceURLs.isEmpty()) { - serviceURLs.put(serviceName, descServiceURLs); ++ // Validate the URL(s) ++ for (String descServiceURL : descServiceURLs) { ++ if (validateURL(serviceName, descServiceURL)) { ++ validURLs.add(descServiceURL); ++ } ++ } ++ } ++ ++ // If there is at least one valid URL associated with the service, then add it to the map ++ if (!validURLs.isEmpty()) { ++ serviceURLs.put(serviceName, validURLs); + } else { + log.failedToDiscoverClusterServiceURLs(serviceName, cluster.getName()); - throw new IllegalStateException("ServiceDiscovery failed to resolve any URLs for " + serviceName + - ". Topology update aborted!"); + } + } + } else { + log.failedToDiscoverClusterServices(desc.getClusterName()); + } + ++ BufferedWriter fw = null; + topologyDescriptor = null; + File providerConfig = null; + try { + // Verify that the referenced provider configuration exists before attempting to reading it + providerConfig = resolveProviderConfigurationReference(desc.getProviderConfig(), srcDirectory); + if (providerConfig == null) { + log.failedToResolveProviderConfigRef(desc.getProviderConfig()); + throw new IllegalArgumentException("Unresolved provider configuration reference: " + + desc.getProviderConfig() + " ; Topology update aborted!"); + } + result.put("reference", providerConfig); + + // TODO: Should the contents of the provider config be validated before incorporating it into the topology? + + String topologyFilename = desc.getName(); + if (topologyFilename == null) { + topologyFilename = desc.getClusterName(); + } + topologyDescriptor = new File(destDirectory, topologyFilename + ".xml"); - FileWriter fw = new FileWriter(topologyDescriptor); ++ fw = new BufferedWriter(new FileWriter(topologyDescriptor)); + + fw.write("<topology>\n"); + + // Copy the externalized provider configuration content into the topology descriptor in-line + InputStreamReader policyReader = new InputStreamReader(new FileInputStream(providerConfig)); + char[] buffer = new char[1024]; + int count; + while ((count = policyReader.read(buffer)) > 0) { + fw.write(buffer, 0, count); + } + policyReader.close(); + ++ // Sort the service names to write the services alphabetically ++ List<String> serviceNames = new ArrayList<>(serviceURLs.keySet()); ++ Collections.sort(serviceNames); ++ + // Write the service declarations - for (String serviceName : serviceURLs.keySet()) { ++ for (String serviceName : serviceNames) { + fw.write(" <service>\n"); + fw.write(" <role>" + serviceName + "</role>\n"); + for (String url : serviceURLs.get(serviceName)) { + fw.write(" <url>" + url + "</url>\n"); + } + fw.write(" </service>\n"); + } + + fw.write("</topology>\n"); + + fw.flush(); - fw.close(); + } catch (IOException e) { + log.failedToGenerateTopologyFromSimpleDescriptor(topologyDescriptor.getName(), e); + topologyDescriptor.delete(); ++ } finally { ++ if (fw != null) { ++ try { ++ fw.close(); ++ } catch (IOException e) { ++ // ignore ++ } ++ } + } + + result.put("topology", topologyDescriptor); + return result; + } + ++ private static boolean validateURL(String serviceName, String url) { ++ boolean result = false; ++ ++ if (url != null && !url.isEmpty()) { ++ try { ++ new URI(url); ++ result = true; ++ } catch (URISyntaxException e) { ++ log.serviceURLValidationFailed(serviceName, url, e); ++ } ++ } ++ ++ return result; ++ } + + private static File resolveProviderConfigurationReference(String reference, File srcDirectory) { + File providerConfig; + + // If the reference includes a path + if (reference.contains(File.separator)) { + // Check if it's an absolute path + providerConfig = new File(reference); + if (!providerConfig.exists()) { + // If it's not an absolute path, try treating it as a relative path + providerConfig = new File(srcDirectory, reference); + if (!providerConfig.exists()) { + providerConfig = null; + } + } + } else { // No file path, just a name + // Check if it's co-located with the referencing descriptor + providerConfig = new File(srcDirectory, reference); + if (!providerConfig.exists()) { + // Check the shared-providers config location + File sharedProvidersDir = new File(srcDirectory, "../shared-providers"); + if (sharedProvidersDir.exists()) { + providerConfig = new File(sharedProvidersDir, reference); + if (!providerConfig.exists()) { + // Check if it's a valid name without the extension + providerConfig = new File(sharedProvidersDir, reference + ".xml"); + if (!providerConfig.exists()) { + providerConfig = null; + } + } + } + } + } + + return providerConfig; + } + +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorMessages.java ---------------------------------------------------------------------- diff --cc gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorMessages.java index eb9d887,0000000..07c4350 mode 100644,000000..100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorMessages.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/topology/simple/SimpleDescriptorMessages.java @@@ -1,44 -1,0 +1,50 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with this + * work for additional information regarding copyright ownership. The ASF + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * <p> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.knox.gateway.topology.simple; + +import org.apache.knox.gateway.i18n.messages.Message; +import org.apache.knox.gateway.i18n.messages.MessageLevel; +import org.apache.knox.gateway.i18n.messages.Messages; +import org.apache.knox.gateway.i18n.messages.StackTrace; + +@Messages(logger="org.apache.gateway.topology.simple") +public interface SimpleDescriptorMessages { + + @Message(level = MessageLevel.ERROR, + text = "Service discovery for cluster {0} failed.") + void failedToDiscoverClusterServices(final String cluster); + + @Message(level = MessageLevel.ERROR, - text = "No URLs were discovered for {0} in the {1} cluster.") ++ text = "No valid URLs were discovered for {0} in the {1} cluster.") + void failedToDiscoverClusterServiceURLs(final String serviceName, final String clusterName); + + @Message(level = MessageLevel.ERROR, + text = "Failed to resolve the referenced provider configuration {0}.") + void failedToResolveProviderConfigRef(final String providerConfigRef); + + @Message(level = MessageLevel.ERROR, ++ text = "URL validation failed for {0} URL {1} : {2}") ++ void serviceURLValidationFailed(final String serviceName, ++ final String url, ++ @StackTrace( level = MessageLevel.DEBUG ) Exception e ); ++ ++ @Message(level = MessageLevel.ERROR, + text = "Error generating topology {0} from simple descriptor: {1}") + void failedToGenerateTopologyFromSimpleDescriptor(final String topologyFile, + @StackTrace( level = MessageLevel.DEBUG ) Exception e ); + +} http://git-wip-us.apache.org/repos/asf/knox/blob/8affbc02/gateway-server/src/main/java/org/apache/knox/gateway/websockets/GatewayWebsocketHandler.java ---------------------------------------------------------------------- diff --cc gateway-server/src/main/java/org/apache/knox/gateway/websockets/GatewayWebsocketHandler.java index 3ddd311,0000000..69634a7 mode 100644,000000..100644 --- a/gateway-server/src/main/java/org/apache/knox/gateway/websockets/GatewayWebsocketHandler.java +++ b/gateway-server/src/main/java/org/apache/knox/gateway/websockets/GatewayWebsocketHandler.java @@@ -1,241 -1,0 +1,266 @@@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.knox.gateway.websockets; + +import java.io.File; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; ++import java.util.List; ++import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.commons.lang3.StringUtils; +import org.apache.knox.gateway.config.GatewayConfig; +import org.apache.knox.gateway.i18n.messages.MessagesFactory; +import org.apache.knox.gateway.service.definition.ServiceDefinition; +import org.apache.knox.gateway.services.GatewayServices; +import org.apache.knox.gateway.services.registry.ServiceDefEntry; +import org.apache.knox.gateway.services.registry.ServiceDefinitionRegistry; +import org.apache.knox.gateway.services.registry.ServiceRegistry; +import org.apache.knox.gateway.util.ServiceDefinitionsLoader; +import org.eclipse.jetty.websocket.server.WebSocketHandler; +import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest; +import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse; +import org.eclipse.jetty.websocket.servlet.WebSocketCreator; +import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory; + ++import javax.websocket.ClientEndpointConfig; ++ +/** + * Websocket handler that will handle websocket connection request. This class + * is responsible for creating a proxy socket for inbound and outbound + * connections. This is also where the http to websocket handoff happens. - * ++ * + * @since 0.10 + */ +public class GatewayWebsocketHandler extends WebSocketHandler + implements WebSocketCreator { + + private static final WebsocketLogMessages LOG = MessagesFactory + .get(WebsocketLogMessages.class); + + public static final String WEBSOCKET_PROTOCOL_STRING = "ws://"; + + public static final String SECURE_WEBSOCKET_PROTOCOL_STRING = "wss://"; + + static final String REGEX_SPLIT_CONTEXT = "^((?:[^/]*/){2}[^/]*)"; + + final static String REGEX_SPLIT_SERVICE_PATH = "^((?:[^/]*/){3}[^/]*)"; + + private static final int POOL_SIZE = 10; + + /** + * Manage the threads that are spawned + * @since 0.13 + */ + private final ExecutorService pool; + + final GatewayConfig config; + final GatewayServices services; + + /** + * Create an instance - * ++ * + * @param config + * @param services + */ + public GatewayWebsocketHandler(final GatewayConfig config, + final GatewayServices services) { + super(); + + this.config = config; + this.services = services; + pool = Executors.newFixedThreadPool(POOL_SIZE); + + } + + /* + * (non-Javadoc) - * ++ * + * @see + * org.eclipse.jetty.websocket.server.WebSocketHandler#configure(org.eclipse. + * jetty.websocket.servlet.WebSocketServletFactory) + */ + @Override + public void configure(final WebSocketServletFactory factory) { + factory.setCreator(this); + factory.getPolicy() + .setMaxTextMessageSize(config.getWebsocketMaxTextMessageSize()); + factory.getPolicy() + .setMaxBinaryMessageSize(config.getWebsocketMaxBinaryMessageSize()); + + factory.getPolicy().setMaxBinaryMessageBufferSize( + config.getWebsocketMaxBinaryMessageBufferSize()); + factory.getPolicy().setMaxTextMessageBufferSize( + config.getWebsocketMaxTextMessageBufferSize()); + + factory.getPolicy() + .setInputBufferSize(config.getWebsocketInputBufferSize()); + + factory.getPolicy() + .setAsyncWriteTimeout(config.getWebsocketAsyncWriteTimeout()); + factory.getPolicy().setIdleTimeout(config.getWebsocketIdleTimeout()); + + } + + /* + * (non-Javadoc) - * ++ * + * @see + * org.eclipse.jetty.websocket.servlet.WebSocketCreator#createWebSocket(org. + * eclipse.jetty.websocket.servlet.ServletUpgradeRequest, + * org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse) + */ + @Override + public Object createWebSocket(ServletUpgradeRequest req, + ServletUpgradeResponse resp) { + + try { + final URI requestURI = req.getRequestURI(); + final String path = requestURI.getPath(); + + /* URL used to connect to websocket backend */ + final String backendURL = getMatchedBackendURL(path); + + /* Upgrade happens here */ - return new ProxyWebSocketAdapter(URI.create(backendURL), pool); ++ return new ProxyWebSocketAdapter(URI.create(backendURL), pool, getClientEndpointConfig(req)); + } catch (final Exception e) { + LOG.failedCreatingWebSocket(e); + throw e; + } + } + + /** ++ * Returns a {@link ClientEndpointConfig} config that contains the headers ++ * to be passed to the backend. ++ * @since 0.14.0 ++ * @param req ++ * @return ++ */ ++ private ClientEndpointConfig getClientEndpointConfig(final ServletUpgradeRequest req) { ++ ++ return ClientEndpointConfig.Builder.create().configurator( new ClientEndpointConfig.Configurator() { ++ ++ @Override ++ public void beforeRequest(final Map<String, List<String>> headers) { ++ ++ /* Add request headers */ ++ req.getHeaders().forEach(headers::putIfAbsent); ++ ++ } ++ }).build(); ++ } ++ ++ /** + * This method looks at the context path and returns the backend websocket + * url. If websocket url is found it is used as is, or we default to + * ws://{host}:{port} which might or might not be right. - * - * @param The context path ++ * ++ * @param + * @return Websocket backend url + */ + private synchronized String getMatchedBackendURL(final String path) { + + final ServiceRegistry serviceRegistryService = services + .getService(GatewayServices.SERVICE_REGISTRY_SERVICE); + + final ServiceDefinitionRegistry serviceDefinitionService = services + .getService(GatewayServices.SERVICE_DEFINITION_REGISTRY); + + /* Filter out the /cluster/topology to get the context we want */ + String[] pathInfo = path.split(REGEX_SPLIT_CONTEXT); + + final ServiceDefEntry entry = serviceDefinitionService + .getMatchingService(pathInfo[1]); + + if (entry == null) { + throw new RuntimeException( + String.format("Cannot find service for the given path: %s", path)); + } + + /* Filter out /cluster/topology/service to get endpoint */ + String[] pathService = path.split(REGEX_SPLIT_SERVICE_PATH); + + final File servicesDir = new File(config.getGatewayServicesDir()); + + final Set<ServiceDefinition> serviceDefs = ServiceDefinitionsLoader + .getServiceDefinitions(servicesDir); + + /* URL used to connect to websocket backend */ + String backendURL = urlFromServiceDefinition(serviceDefs, + serviceRegistryService, entry, path); + + StringBuffer backend = new StringBuffer(); + try { + + /* if we do not find websocket URL we default to HTTP */ + if (!StringUtils.containsAny(backendURL, WEBSOCKET_PROTOCOL_STRING, SECURE_WEBSOCKET_PROTOCOL_STRING)) { + URL serviceUrl = new URL(backendURL); + + /* Use http host:port if ws url not configured */ + final String protocol = (serviceUrl.getProtocol() == "ws" + || serviceUrl.getProtocol() == "wss") ? serviceUrl.getProtocol() + : "ws"; + backend.append(protocol).append("://"); + backend.append(serviceUrl.getHost()).append(":"); + backend.append(serviceUrl.getPort()).append("/"); + backend.append(serviceUrl.getPath()); + } + else { + URI serviceUri = new URI(backendURL); + backend.append(serviceUri); + /* Avoid Zeppelin Regression - as this would require ambari changes and break current knox websocket use case*/ - if (!StringUtils.endsWith(backend.toString(), "/ws") && pathService[1] != null) { ++ if (!StringUtils.endsWith(backend.toString(), "/ws") && pathService.length > 0 && pathService[1] != null) { + backend.append(pathService[1]); + } + } + backendURL = backend.toString(); + + } catch (MalformedURLException e){ + LOG.badUrlError(e); + throw new RuntimeException(e.toString()); + } catch (Exception e1) { + LOG.failedCreatingWebSocket(e1); + throw new RuntimeException(e1.toString()); + } + + return backendURL; + } + + private static String urlFromServiceDefinition( + final Set<ServiceDefinition> serviceDefs, + final ServiceRegistry serviceRegistry, final ServiceDefEntry entry, + final String path) { + + final String[] contexts = path.split("/"); + + final String serviceURL = serviceRegistry.lookupServiceURL(contexts[2], + entry.getName().toUpperCase()); + + /* + * we have a match, if ws:// is present it is returned else http:// is + * returned + */ + return serviceURL; + + } + +}