This is an automated email from the ASF dual-hosted git repository.

valentyn pushed a commit to branch valentyn/sigv4
in repository https://gitbox.apache.org/repos/asf/tinkerpop.git

commit 3cf5b24e2135ca6642e976cf21494e17ba442b80
Author: Valentyn Kahamlyk <[email protected]>
AuthorDate: Sun May 12 23:12:33 2024 -0700

    sigv4 auth client implementation
---
 gremlin-driver/pom.xml                             |   5 +
 .../org/apache/tinkerpop/gremlin/driver/Auth.java  | 229 ++++++++++++++++++++-
 .../driver/handler/HttpGremlinRequestEncoder.java  |   4 +-
 .../server/GremlinServerAuthIntegrateTest.java     |  45 ++++
 4 files changed, 280 insertions(+), 3 deletions(-)

diff --git a/gremlin-driver/pom.xml b/gremlin-driver/pom.xml
index 84cf73fd7d..c6b6ff5d5b 100644
--- a/gremlin-driver/pom.xml
+++ b/gremlin-driver/pom.xml
@@ -96,6 +96,11 @@ limitations under the License.
             <artifactId>hamcrest</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>com.amazonaws</groupId>
+            <artifactId>aws-java-sdk-core</artifactId>
+            <version>1.12.241</version>
+        </dependency>
     </dependencies>
 
     <build>
diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java
index fa832aeca0..beb54bee2e 100644
--- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java
+++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Auth.java
@@ -18,33 +18,287 @@
  */
 package org.apache.tinkerpop.gremlin.driver;
 
+import com.amazonaws.DefaultRequest;
+import com.amazonaws.SignableRequest;
+import com.amazonaws.auth.AWS4Signer;
+import com.amazonaws.auth.AWSCredentials;
+import com.amazonaws.auth.AWSCredentialsProvider;
+import com.amazonaws.auth.BasicSessionCredentials;
+import com.amazonaws.http.HttpMethodName;
+import com.amazonaws.util.SdkHttpUtils;
+import com.amazonaws.util.StringUtils;
+import io.netty.buffer.ByteBuf;
 import io.netty.handler.codec.http.FullHttpRequest;
 import io.netty.handler.codec.http.HttpHeaderNames;
+import io.netty.handler.codec.http.HttpHeaders;
+import org.apache.http.entity.StringEntity;
 
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.UnsupportedEncodingException;
+import java.net.URI;
+import java.util.ArrayList;
 import java.util.Base64;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.amazonaws.auth.internal.SignerConstants.AUTHORIZATION;
+import static com.amazonaws.auth.internal.SignerConstants.HOST;
+import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_DATE;
+import static com.amazonaws.auth.internal.SignerConstants.X_AMZ_SECURITY_TOKEN;
 
 public abstract class Auth implements RequestInterceptor {
 
+    private static final String NEPTUNE_SERVICE_NAME = "neptune-db";
+
     public static Auth basic(final String username, final String password) {
         return new Basic(username, password);
     }
 
+    /**
+     * Creates an {@link Auth} that signs each request with AWS Signature Version 4 using the
+     * {@code neptune-db} service name.
+     *
+     * @param regionName the AWS region used in the signature scope, e.g. {@code us-east-1}
+     * @param awsCredentialsProvider queried for credentials every time a request is signed
+     */
+    public static Auth sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider) {
+        return new Sigv4(regionName, awsCredentialsProvider, NEPTUNE_SERVICE_NAME);
+    }
+
+    /**
+     * Creates an {@link Auth} that signs each request with AWS Signature Version 4 for the given service.
+     *
+     * @param regionName the AWS region used in the signature scope, e.g. {@code us-east-1}
+     * @param awsCredentialsProvider queried for credentials every time a request is signed
+     * @param serviceName the service name used in the signature scope
+     */
+    public static Auth sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) {
+        return new Sigv4(regionName, awsCredentialsProvider, serviceName);
+    }
+
     public static class Basic extends Auth {
 
         private final String username;
         private final String password;
 
-        private Basic(final String username, final String password ) {
+        private Basic(final String username, final String password) {
             this.username = username;
             this.password = password;
         }
 
         @Override
-        public FullHttpRequest apply(FullHttpRequest fullHttpRequest) {
+        public FullHttpRequest apply(final FullHttpRequest fullHttpRequest) {
             final String valueToEncode = username + ":" + password;
             fullHttpRequest.headers().add(HttpHeaderNames.AUTHORIZATION,
                     "Basic " + Base64.getEncoder().encodeToString(valueToEncode.getBytes()));
             return fullHttpRequest;
         }
     }
+
+    /**
+     * {@link Auth} that signs each request with AWS Signature Version 4 using the AWS SDK {@link AWS4Signer}. The
+     * request needs a {@code Host} header or an absolute URI so that the host to sign can be determined.
+     */
+    public static class Sigv4 extends Auth {
+        private final AWSCredentialsProvider awsCredentialsProvider;
+        private final AWS4Signer aws4Signer;
+
+        private Sigv4(final String regionName, final AWSCredentialsProvider awsCredentialsProvider, final String serviceName) {
+            this.awsCredentialsProvider = awsCredentialsProvider;
+
+            aws4Signer = new AWS4Signer();
+            aws4Signer.setRegionName(regionName);
+            aws4Signer.setServiceName(serviceName);
+        }
+
+        /**
+         * Signs the request with AWS Signature Version 4 and copies the headers added by the signer onto it.
+         *
+         * @throws RuntimeException if the request cannot be converted to a signable request or signed
+         */
+        @Override
+        public FullHttpRequest apply(final FullHttpRequest fullHttpRequest) {
+            try {
+                // Convert Http request into an AWS SDK signable request
+                final SignableRequest<?> awsSignableRequest = toSignableRequest(fullHttpRequest);
+
+                // Sign the AWS SDK signable request (which internally adds some HTTP headers)
+                final AWSCredentials credentials = awsCredentialsProvider.getCredentials();
+                aws4Signer.sign(awsSignableRequest, credentials);
+                final Map<String, String> signedHeaders = awsSignableRequest.getHeaders();
+
+                // toSignableRequest() leaves Host for the signer to add, so replace the original Host header with
+                // the exact value that was signed
+                copySignedHeader(signedHeaders, HOST, fullHttpRequest);
+                copySignedHeader(signedHeaders, X_AMZ_DATE, fullHttpRequest);
+                // take the session token from the signed request rather than checking for one credentials class, so
+                // whatever the signer included in the signature is also sent
+                copySignedHeader(signedHeaders, X_AMZ_SECURITY_TOKEN, fullHttpRequest);
+                copySignedHeader(signedHeaders, AUTHORIZATION, fullHttpRequest);
+            } catch (final Exception e) {
+                throw new RuntimeException("Unable to sign the request with SigV4", e);
+            }
+            return fullHttpRequest;
+        }
+
+        /**
+         * Sets header {@code name} on {@code request} to the value the signer produced, replacing any existing value.
+         * Headers the signer did not add are skipped rather than sent with a {@code null} value.
+         */
+        private static void copySignedHeader(final Map<String, String> signedHeaders, final String name,
+                                             final FullHttpRequest request) {
+            final String value = signedHeaders.get(name);
+            if (value != null) {
+                request.headers().set(name, value);
+            }
+        }
+
+        private SignableRequest<?> toSignableRequest(final FullHttpRequest request)
+                throws Exception {
+
+            // make sure the request is not null and contains the minimal required set of information
+            checkNotNull(request, "The request must not be null");
+            checkNotNull(request.uri(), "The request URI must not be null");
+            checkNotNull(request.method(), "The request method must not be null");
+
+            // convert the headers to the internal API format
+            final HttpHeaders headers = request.headers();
+            final Map<String, String> headersInternal = new HashMap<>();
+
+            String hostName = "";
+
+            // we don't want to add the Host header as the Signer always adds the host header.
+            for (String header : headers.names()) {
+                // Skip adding the Host header as the signing process will add one.
+                if (!header.equalsIgnoreCase(HOST)) {
+                    headersInternal.put(header, headers.get(header));
+                } else {
+                    hostName = headers.get(header);
+                }
+            }
+
+            // convert the parameters to the internal API format
+            final URI uri = URI.create(request.uri());
+
+            final String queryStr = uri.getQuery();
+            final Map<String, List<String>> parametersInternal = new HashMap<>(extractParametersFromQueryString(queryStr));
+
+            // carry over the entity (or an empty entity, if no entity is provided)
+            final InputStream content;
+            final ByteBuf contentBuffer = request.content();
+            boolean hasContent = false;
+            try {
+                if (contentBuffer != null && contentBuffer.isReadable()) {
+                    hasContent = true;
+                    contentBuffer.retain();
+                    byte[] bytes = new byte[contentBuffer.readableBytes()];
+                    contentBuffer.getBytes(contentBuffer.readerIndex(), bytes);
+                    content = new ByteArrayInputStream(bytes);
+                } else {
+                    content = new StringEntity("").getContent();
+                }
+            } catch (UnsupportedEncodingException e) {
+                throw new Exception("Encoding of the input string failed", e);
+            } catch (IOException e) {
+                throw new Exception("IOException while accessing entity content", e);
+            } finally {
+                if (hasContent) {
+                    contentBuffer.release();
+                }
+            }
+
+            if (StringUtils.isNullOrEmpty(hostName)) {
+                // try to extract hostname from the uri since hostname was not provided in the header.
+                final String authority = uri.getAuthority();
+                if (authority == null) {
+                    throw new Exception("Unable to identify host information,"
+                            + " either hostname should be provided in the uri or should be passed as a header");
+                }
+
+                hostName = authority;
+            }
+
+            final URI endpointUri = URI.create("http://" + hostName);
+
+            return convertToSignableRequest(
+                    request.method().name(),
+                    endpointUri,
+                    uri.getPath(),
+                    headersInternal,
+                    parametersInternal,
+                    content);
+        }
+
+        private Map<String, List<String>> extractParametersFromQueryString(final String queryStr) {
+
+            final Map<String, List<String>> parameters = new HashMap<>();
+
+            // convert the parameters to the internal API format
+            if (queryStr != null) {
+                for (final String queryParam : queryStr.split("&")) {
+
+                    if (!queryParam.isEmpty()) {
+
+                        final String[] keyValuePair = queryParam.split("=", 2);
+
+                        // parameters are encoded in the HTTP request, we need to decode them here
+                        final String key = SdkHttpUtils.urlDecode(keyValuePair[0]);
+                        final String value;
+
+                        if (keyValuePair.length == 2) {
+                            value = SdkHttpUtils.urlDecode(keyValuePair[1]);
+                        } else {
+                            value = "";
+                        }
+
+                        // insert the parameter key into the map, if not yet present
+                        if (!parameters.containsKey(key)) {
+                            parameters.put(key, new ArrayList<>());
+                        }
+
+                        // append the parameter value to the list for the given key
+                        parameters.get(key).add(value);
+                    }
+                }
+            }
+
+            return parameters;
+        }
+
+        private SignableRequest<?> convertToSignableRequest(
+                final String httpMethodName,
+                final URI httpEndpointUri,
+                final String resourcePath,
+                final Map<String, String> httpHeaders,
+                final Map<String, List<String>> httpParameters,
+                final InputStream httpContent) throws Exception {
+
+            checkNotNull(httpMethodName, "Http method name must not be null");
+            checkNotNull(httpEndpointUri, "Http endpoint URI must not be null");
+            checkNotNull(httpHeaders, "Http headers must not be null");
+            checkNotNull(httpParameters, "Http parameters must not be null");
+            checkNotNull(httpContent, "Http content name must not be null");
+
+            // create the HTTP AWS SDK Signable Request and carry over information
+            final DefaultRequest<?> awsRequest = new DefaultRequest<>(NEPTUNE_SERVICE_NAME);
+            awsRequest.setHttpMethod(HttpMethodName.fromValue(httpMethodName));
+            awsRequest.setEndpoint(httpEndpointUri);
+            awsRequest.setResourcePath(resourcePath);
+            awsRequest.setHeaders(httpHeaders);
+            awsRequest.setParameters(httpParameters);
+            awsRequest.setContent(httpContent);
+
+            return awsRequest;
+        }
+
+        private void checkNotNull(final Object obj, final String errMsg) throws Exception {
+            if (obj == null) {
+                throw new Exception(errMsg);
+            }
+        }
+    }
 }
diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java
index a004c899a7..23334f3ca6 100644
--- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java
+++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java
@@ -35,6 +35,7 @@ import org.apache.tinkerpop.gremlin.util.MessageSerializerV4;
 import org.apache.tinkerpop.gremlin.util.message.RequestMessageV4;
 import org.apache.tinkerpop.gremlin.util.ser.SerTokens;
 
+import java.net.InetSocketAddress;
 import java.util.List;
 import java.util.function.UnaryOperator;
 
@@ -73,11 +74,18 @@ public final class HttpGremlinRequestEncoder extends MessageToMessageEncoder<Req
             request.headers().add(HttpHeaderNames.CONTENT_TYPE, mimeType);
             request.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes());
             request.headers().add(HttpHeaderNames.ACCEPT, mimeType);
+            // Host = uri-host [ ":" port ] (RFC 7230, section 5.4). getHostString() keeps the host name without a
+            // reverse lookup (instead of the resolved IP) so request signers such as SigV4 sign the expected authority
+            final InetSocketAddress remoteSocketAddress = (InetSocketAddress) channelHandlerContext.channel().remoteAddress();
+            final String remoteHostString = remoteSocketAddress.getHostString();
+            final String hostHeaderValue = (remoteHostString.indexOf(':') >= 0 ? "[" + remoteHostString + "]" : remoteHostString)
+                    + ":" + remoteSocketAddress.getPort();
+            request.headers().add(HttpHeaderNames.HOST, hostHeaderValue);
             if (userAgentEnabled) {
                 request.headers().add(HttpHeaderNames.USER_AGENT, UserAgent.USER_AGENT);
             }
 
-            for (final UnaryOperator<FullHttpRequest> interceptor: interceptors ) {
+            for (final UnaryOperator<FullHttpRequest> interceptor : interceptors) {
                 request = interceptor.apply(request);
             }
             objects.add(request);
diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
index c33e42ba29..dfda66ee9a 100644
--- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
+++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java
@@ -18,6 +18,9 @@
  */
 package org.apache.tinkerpop.gremlin.server;
 
+import com.amazonaws.auth.AWSCredentials;
+import com.amazonaws.auth.AWSCredentialsProvider;
+import io.netty.handler.codec.http.FullHttpRequest;
 import org.apache.tinkerpop.gremlin.driver.Client;
 import org.apache.tinkerpop.gremlin.driver.Cluster;
 import org.apache.tinkerpop.gremlin.driver.exception.ResponseException;
@@ -26,16 +29,24 @@ import org.apache.tinkerpop.gremlin.util.ExceptionHelper;
 import org.ietf.jgss.GSSException;
 import org.junit.Test;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.apache.tinkerpop.gremlin.driver.Auth.basic;
+import static org.apache.tinkerpop.gremlin.driver.Auth.sigv4;
+import static org.hamcrest.CoreMatchers.containsString;
+import static org.hamcrest.CoreMatchers.startsWith;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.core.AnyOf.anyOf;
 import static org.hamcrest.core.IsInstanceOf.instanceOf;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 /**
  * @author Stephen Mallette (http://stephen.genoprime.com)
@@ -59,6 +70,9 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra
 
         final String nameOfTest = name.getMethodName();
         switch (nameOfTest) {
+            case "shouldPassSigv4ToServer":
+                settings.authentication = new Settings.AuthenticationSettings();
+                break;
             case "shouldAuthenticateOverSslWithPlainText":
             case "shouldFailIfSslEnabledOnServerButNotClient":
                 final Settings.SslSettings sslConfig = new Settings.SslSettings();
@@ -72,6 +86,36 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra
         return settings;
     }
 
+    @Test
+    public void shouldPassSigv4ToServer() throws Exception {
+        final AWSCredentialsProvider credentialsProvider = mock(AWSCredentialsProvider.class);
+        final AWSCredentials credentials = mock(AWSCredentials.class);
+        when(credentialsProvider.getCredentials()).thenReturn(credentials);
+        when(credentials.getAWSAccessKeyId()).thenReturn("I am AWSAccessKeyId");
+        when(credentials.getAWSSecretKey()).thenReturn("I am AWSSecretKey");
+
+        final AtomicReference<FullHttpRequest> fullHttpRequest = new AtomicReference<>();
+        final Cluster cluster = TestClientFactory.build()
+                .auth(sigv4("us-west-2", credentialsProvider))
+                .requestInterceptor(r -> {
+                    fullHttpRequest.set(r);
+                    return r;
+                })
+                .create();
+        try {
+            final Client client = cluster.connect();
+            client.submit("1+1").all().get();
+
+            assertNotNull(fullHttpRequest.get().headers().get("X-Amz-Date"));
+            assertThat(fullHttpRequest.get().headers().get("Authorization"),
+                    startsWith("AWS4-HMAC-SHA256 Credential=I am AWSAccessKeyId"));
+            assertThat(fullHttpRequest.get().headers().get("Authorization"),
+                    containsString("/us-west-2/neptune-db/aws4_request, SignedHeaders=accept;content-length;content-type;host;user-agent;x-amz-date, Signature="));
+        } finally {
+            cluster.close();
+        }
+    }
+
     @Test
     public void shouldFailIfSslEnabledOnServerButNotClient() throws Exception {
         final Cluster cluster = TestClientFactory.open();
@@ -102,8 +146,11 @@ public class GremlinServerAuthIntegrateTest extends AbstractGremlinServerIntegra
         final Cluster cluster = TestClientFactory.build()
                 .enableSsl(true).sslSkipCertValidation(true)
                 .auth(basic("stephen", "password")).create();
+
         final Client client = cluster.connect();
 
+        client.submit("1+1").all().get();
+
         assertConnection(cluster, client);
     }
 

Reply via email to