This is an automated email from the ASF dual-hosted git repository. valentyn pushed a commit to branch valentyn/sigv4 in repository https://gitbox.apache.org/repos/asf/tinkerpop.git
commit 3cf5b24e2135ca6642e976cf21494e17ba442b80 Author: Valentyn Kahamlyk <[email protected]> AuthorDate: Sun May 12 23:12:33 2024 -0700 sigv4 auth client implementation --- gremlin-driver/pom.xml | 5 + .../org/apache/tinkerpop/gremlin/driver/Auth.java | 229 ++++++++++++++++++++- .../driver/handler/HttpGremlinRequestEncoder.java | 4 +- .../server/GremlinServerAuthIntegrateTest.java | 45 ++++ 4 files changed, 280 insertions(+), 3 deletions(-) diff --git a/gremlin-driver/pom.xml b/gremlin-driver/pom.xml index 84cf73fd7d..c6b6ff5d5b 100644 --- a/gremlin-driver/pom.xml +++ b/gremlin-driver/pom.xml @@ -96,6 +96,11 @@ limitations under the License. <artifactId>hamcrest</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>com.amazonaws</groupId> + <artifactId>aws-java-sdk-core</artifactId> + <version>1.12.241</version> + </dependency> </dependencies> <build> diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java index fa832aeca0..beb54bee2e 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java @@ -18,33 +18,258 @@ */ package org.apache.tinkerpop.gremlin.driver; +import com.amazonaws.DefaultRequest; +import com.amazonaws.SignableRequest; +import com.amazonaws.auth.AWS4Signer; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.BasicSessionCredentials; +import com.amazonaws.http.HttpMethodName; +import com.amazonaws.util.SdkHttpUtils; +import com.amazonaws.util.StringUtils; +import io.netty.buffer.ByteBuf; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import org.apache.http.entity.StringEntity; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.UnsupportedEncodingException; +import java.net.URI; +import java.util.ArrayList; import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.amazonaws.auth.internal.SignerConstants.AUTHORIZATION; +import static com.amazonaws.auth.internal.SignerConstants.HOST; +import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_DATE; +import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_SECURITY_TOKEN; public abstract class Auth implements RequestInterceptor { + private static final String NEPTUNE_SERVICE_NAME = "neptune-db"; + public static Auth basic(final String username, final String password) { return new Basic(username, password); } + public static Auth sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider) { + return new Sigv4(regionName, awsCredentialsProvider, NEPTUNE_SERVICE_NAME); + } + + public static Auth sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) { + return new Sigv4(regionName, awsCredentialsProvider, serviceName); + } + public static class Basic extends Auth { private final String username; private final String password; - private Basic(final String username, final String password ) { + private Basic(final String username, final String password) { this.username = username; this.password = password; } @Override - public FullHttpRequest apply(FullHttpRequest fullHttpRequest) { + public FullHttpRequest apply(final FullHttpRequest fullHttpRequest) { final String valueToEncode = username + ":" + password; fullHttpRequest.headers().add(HttpHeaderNames.AUTHORIZATION, "Basic " + Base64.getEncoder().encodeToString(valueToEncode.getBytes())); return fullHttpRequest; } } + + public static class Sigv4 extends Auth { + private final AWSCredentialsProvider awsCredentialsProvider; + private final AWS4Signer aws4Signer; + + private Sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) { + this.awsCredentialsProvider = awsCredentialsProvider; + + aws4Signer = new AWS4Signer(); + aws4Signer.setRegionName(regionName); + aws4Signer.setServiceName(serviceName); + } + + @Override + public FullHttpRequest apply(final FullHttpRequest fullHttpRequest) { + try { + // Convert Http request into an AWS SDK signable request + final SignableRequest<?> awsSignableRequest = toSignableRequest(fullHttpRequest); + + // Sign the AWS SDK signable request (which internally adds some HTTP headers) + final AWSCredentials credentials = awsCredentialsProvider.getCredentials(); + aws4Signer.sign(awsSignableRequest, credentials); + + // extract session token if temporary credentials are provided + String sessionToken = ""; + if ((credentials instanceof BasicSessionCredentials)) { + sessionToken = ((BasicSessionCredentials) credentials).getSessionToken(); + } + + // todo: confirm is needed to replace header `Host` with `host` + fullHttpRequest.headers().remove(HttpHeaderNames.HOST); + fullHttpRequest.headers().add(HOST, awsSignableRequest.getHeaders().get(HOST)); + fullHttpRequest.headers().add(X_AMZ_DATE, awsSignableRequest.getHeaders().get(X_AMZ_DATE)); + fullHttpRequest.headers().add(AUTHORIZATION, awsSignableRequest.getHeaders().get(AUTHORIZATION)); + + if (!sessionToken.isEmpty()) { + fullHttpRequest.headers().add(X_AMZ_SECURITY_TOKEN, sessionToken); + } + } catch (final Throwable t) { + throw new RuntimeException(t); + } + return fullHttpRequest; + } + + private SignableRequest<?> toSignableRequest(final FullHttpRequest request) + throws Exception { + + // make sure the request is not null and contains the minimal required set of information + checkNotNull(request, "The request must not be null"); + checkNotNull(request.uri(), "The request URI must not be null"); + checkNotNull(request.method(), "The request method must not be null"); + + // convert the headers to the internal API format + final HttpHeaders headers = request.headers(); + final Map<String, String> headersInternal = new HashMap<>(); + + String hostName = ""; + + // we don't want to add the Host header as the Signer always adds the host header. + for (String header : headers.names()) { + // Skip adding the Host header as the signing process will add one. + if (!header.equalsIgnoreCase(HOST)) { + headersInternal.put(header, headers.get(header)); + } else { + hostName = headers.get(header); + } + } + + // convert the parameters to the internal API format + final URI uri = URI.create(request.uri()); + + final String queryStr = uri.getQuery(); + final Map<String, List<String>> parametersInternal = new HashMap<>(extractParametersFromQueryString(queryStr)); + + // carry over the entity (or an empty entity, if no entity is provided) + final InputStream content; + final ByteBuf contentBuffer = request.content(); + boolean hasContent = false; + try { + if (contentBuffer != null && contentBuffer.isReadable()) { + hasContent = true; + contentBuffer.retain(); + byte[] bytes = new byte[contentBuffer.readableBytes()]; + contentBuffer.getBytes(contentBuffer.readerIndex(), bytes); + content = new ByteArrayInputStream(bytes); + } else { + content = new StringEntity("").getContent(); + } + } catch (UnsupportedEncodingException e) { + throw new Exception("Encoding of the input string failed", e); + } catch (IOException e) { + throw new Exception("IOException while accessing entity content", e); + } finally { + if (hasContent) { + contentBuffer.release(); + } + } + + if (StringUtils.isNullOrEmpty(hostName)) { + // try to extract hostname from the uri since hostname was not provided in the header. + final String authority = uri.getAuthority(); + if (authority == null) { + throw new Exception("Unable to identify host information," + + " either hostname should be provided in the uri or should be passed as a header"); + } + + hostName = authority; + } + + final URI endpointUri = URI.create("http://" + hostName); + + return convertToSignableRequest( + request.method().name(), + endpointUri, + uri.getPath(), + headersInternal, + parametersInternal, + content); + } + + private Map<String, List<String>> extractParametersFromQueryString(final String queryStr) { + + final Map<String, List<String>> parameters = new HashMap<>(); + + // convert the parameters to the internal API format + if (queryStr != null) { + for (final String queryParam : queryStr.split("&")) { + + if (!queryParam.isEmpty()) { + + final String[] keyValuePair = queryParam.split("=", 2); + + // parameters are encoded in the HTTP request, we need to decode them here + final String key = SdkHttpUtils.urlDecode(keyValuePair[0]); + final String value; + + if (keyValuePair.length == 2) { + value = SdkHttpUtils.urlDecode(keyValuePair[1]); + } else { + value = ""; + } + + // insert the parameter key into the map, if not yet present + if (!parameters.containsKey(key)) { + parameters.put(key, new ArrayList<>()); + } + + // append the parameter value to the list for the given key + parameters.get(key).add(value); + } + } + } + + return parameters; + } + + private SignableRequest<?> convertToSignableRequest( + final String httpMethodName, + final URI httpEndpointUri, + final String resourcePath, + final Map<String, String> httpHeaders, + final Map<String, List<String>> httpParameters, + final InputStream httpContent) throws Exception { + + checkNotNull(httpMethodName, "Http method name must not be null"); + checkNotNull(httpEndpointUri, "Http endpoint URI must not be null"); + checkNotNull(httpHeaders, "Http headers must not be null"); + checkNotNull(httpParameters, "Http parameters must not be null"); + checkNotNull(httpContent, "Http content name must not be null"); + + // create the HTTP AWS SDK Signable Request and carry over information + final DefaultRequest<?> awsRequest = new DefaultRequest<>(NEPTUNE_SERVICE_NAME); + awsRequest.setHttpMethod(HttpMethodName.fromValue(httpMethodName)); + awsRequest.setEndpoint(httpEndpointUri); + awsRequest.setResourcePath(resourcePath); + awsRequest.setHeaders(httpHeaders); + awsRequest.setParameters(httpParameters); + awsRequest.setContent(httpContent); + + return awsRequest; + } + + private void checkNotNull(final Object obj, final String errMsg) throws Exception { + if (obj == null) { + throw new Exception(errMsg); + } + } + } } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java index a004c899a7..23334f3ca6 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java @@ -35,6 +35,7 @@ import org.apache.tinkerpop.gremlin.util.MessageSerializerV4; import org.apache.tinkerpop.gremlin.util.message.RequestMessageV4; import org.apache.tinkerpop.gremlin.util.ser.SerTokens; +import java.net.InetSocketAddress; import java.util.List; import java.util.function.UnaryOperator; @@ -73,11 +74,12 @@ public final class HttpGremlinRequestEncoder extends MessageToMessageEncoder<Req request.headers().add(HttpHeaderNames.CONTENT_TYPE, mimeType); request.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); request.headers().add(HttpHeaderNames.ACCEPT, mimeType); + request.headers().add(HttpHeaderNames.HOST, ((InetSocketAddress) channelHandlerContext.channel().remoteAddress()).getAddress().getHostAddress()); if (userAgentEnabled) { request.headers().add(HttpHeaderNames.USER_AGENT, UserAgent.USER_AGENT); } - for (final UnaryOperator<FullHttpRequest> interceptor: interceptors ) { + for (final UnaryOperator<FullHttpRequest> interceptor : interceptors) { request = interceptor.apply(request); } objects.add(request); diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java index c33e42ba29..dfda66ee9a 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java @@ -18,6 +18,9 @@ */ package org.apache.tinkerpop.gremlin.server; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import io.netty.handler.codec.http.FullHttpRequest; import org.apache.tinkerpop.gremlin.driver.Client; import org.apache.tinkerpop.gremlin.driver.Cluster; import org.apache.tinkerpop.gremlin.driver.exception.ResponseException; @@ -26,16 +29,24 @@ import org.apache.tinkerpop.gremlin.util.ExceptionHelper; import org.ietf.jgss.GSSException; import org.junit.Test; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; import static org.apache.tinkerpop.gremlin.driver.Auth.basic; +import static org.apache.tinkerpop.gremlin.driver.Auth.sigv4; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.AnyOf.anyOf; import static org.hamcrest.core.IsInstanceOf.instanceOf; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * @author Stephen Mallette (http://stephen.genoprime.com) @@ -59,6 +70,9 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra final String nameOfTest = name.getMethodName(); switch (nameOfTest) { + case "shouldPassSigv4ToServer": + settings.authentication = new Settings.AuthenticationSettings(); + break; case "shouldAuthenticateOverSslWithPlainText": case "shouldFailIfSslEnabledOnServerButNotClient": final Settings.SslSettings sslConfig = new Settings.SslSettings(); @@ -72,6 +86,34 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra return settings; } + @Test + public void shouldPassSigv4ToServer() throws Exception { + final AWSCredentialsProvider credentialsProvider = mock(AWSCredentialsProvider.class); + final AWSCredentials credentials = mock(AWSCredentials.class); + when(credentialsProvider.getCredentials()).thenReturn(credentials); + when(credentials.getAWSAccessKeyId()).thenReturn("I am AWSAccessKeyId"); + when(credentials.getAWSSecretKey()).thenReturn("I am AWSSecretKey"); + + final AtomicReference<FullHttpRequest> fullHttpRequest = new AtomicReference<>(); + final Cluster cluster = TestClientFactory.build() + .auth(sigv4("us-west2", credentialsProvider)) + .requestInterceptor(r -> { + fullHttpRequest.set(r); + return r; + }) + .create(); + final Client client = cluster.connect(); + client.submit("1+1").all().get(); + + assertNotNull(fullHttpRequest.get().headers().get("X-Amz-Date")); + assertThat(fullHttpRequest.get().headers().get("Authorization"), + startsWith("AWS4-HMAC-SHA256 Credential=I am AWSAccessKeyId")); + assertThat(fullHttpRequest.get().headers().get("Authorization"), + containsString("/us-west2/neptune-db/aws4_request, SignedHeaders=accept;content-length;content-type;host;user-agent;x-amz-date, Signature=")); + + cluster.close(); + } + @Test public void shouldFailIfSslEnabledOnServerButNotClient() throws Exception { final Cluster cluster = TestClientFactory.open(); @@ -102,8 +144,11 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra final Cluster cluster = TestClientFactory.build() .enableSsl(true).sslSkipCertValidation(true) .auth(basic("stephen", "password")).create(); + final Client client = cluster.connect(); + client.submit("1+1").all().get(); + assertConnection(cluster, client); }
