This is an automated email from the ASF dual-hosted git repository.

etudenhoefner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/main by this push:
     new 9e895cb6df AWS: Refresh vended credentials (#11389)
9e895cb6df is described below

commit 9e895cb6dff9dcd2a117a4f5e197f0235047ff54
Author: Eduard Tudenhoefner <[email protected]>
AuthorDate: Wed Oct 30 10:29:45 2024 +0100

    AWS: Refresh vended credentials (#11389)
---
 .../apache/iceberg/aws/AwsClientProperties.java    |  34 ++-
 .../apache/iceberg/aws/s3/S3FileIOProperties.java  |   7 +
 .../iceberg/aws/s3/VendedCredentialsProvider.java  | 138 +++++++++
 .../iceberg/aws/AwsClientPropertiesTest.java       |  29 ++
 .../aws/s3/TestVendedCredentialsProvider.java      | 323 +++++++++++++++++++++
 5 files changed, 526 insertions(+), 5 deletions(-)

diff --git a/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java 
b/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java
index 0c91f8685a..4f2d4d6a5a 100644
--- a/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java
+++ b/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java
@@ -20,6 +20,7 @@ package org.apache.iceberg.aws;
 
 import java.io.Serializable;
 import java.util.Map;
+import org.apache.iceberg.aws.s3.VendedCredentialsProvider;
 import org.apache.iceberg.common.DynClasses;
 import org.apache.iceberg.common.DynMethods;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
@@ -66,14 +67,27 @@ public class AwsClientProperties implements Serializable {
    */
   public static final String CLIENT_REGION = "client.region";
 
+  /**
+   * When set, the {@link VendedCredentialsProvider} will be used to fetch and refresh vended
+   * credentials from this endpoint. It takes precedence over any static access key, secret key and
+   * session token, unless {@link #REFRESH_CREDENTIALS_ENABLED} is set to false.
+   */
+  public static final String REFRESH_CREDENTIALS_ENDPOINT = "client.refresh-credentials-endpoint";
+
+  /**
+   * Controls whether vended credentials should be refreshed or not. Defaults to true. Has no
+   * effect unless {@link #REFRESH_CREDENTIALS_ENDPOINT} is also set.
+   */
+  public static final String REFRESH_CREDENTIALS_ENABLED = "client.refresh-credentials-enabled";
+
   private String clientRegion;
   private final String clientCredentialsProvider;
   private final Map<String, String> clientCredentialsProviderProperties;
+  // Value of REFRESH_CREDENTIALS_ENDPOINT; null when the property is absent.
+  private final String refreshCredentialsEndpoint;
+  // Value of REFRESH_CREDENTIALS_ENABLED (default true); only consulted when an endpoint is set.
+  private final boolean refreshCredentialsEnabled;
 
+  /**
+   * Creates an instance with no region, no custom credentials provider and no refresh endpoint.
+   * Refreshing is nominally enabled, but has no effect because no endpoint is configured.
+   */
   public AwsClientProperties() {
     this.clientRegion = null;
     this.clientCredentialsProvider = null;
     this.clientCredentialsProviderProperties = null;
+    this.refreshCredentialsEndpoint = null;
+    this.refreshCredentialsEnabled = true;
   }
 
   public AwsClientProperties(Map<String, String> properties) {
@@ -81,6 +95,9 @@ public class AwsClientProperties implements Serializable {
     this.clientCredentialsProvider = 
properties.get(CLIENT_CREDENTIALS_PROVIDER);
     this.clientCredentialsProviderProperties =
         PropertyUtil.propertiesWithPrefix(properties, 
CLIENT_CREDENTIAL_PROVIDER_PREFIX);
+    this.refreshCredentialsEndpoint = 
properties.get(REFRESH_CREDENTIALS_ENDPOINT);
+    this.refreshCredentialsEnabled =
+        PropertyUtil.propertyAsBoolean(properties, 
REFRESH_CREDENTIALS_ENABLED, true);
   }
 
   public String clientRegion() {
@@ -122,11 +139,12 @@ public class AwsClientProperties implements Serializable {
   }
 
   /**
-   * Returns a credentials provider instance. If params were set, we return a 
new credentials
-   * instance. If none of the params are set, we try to dynamically load the 
provided credentials
-   * provider class. Upon loading the class, we try to invoke {@code 
create(Map<String, String>)}
-   * static method. If that fails, we fall back to {@code create()}. If 
credential provider class
-   * wasn't set, we fall back to default credentials provider.
+   * Returns a credentials provider instance. If {@link 
#refreshCredentialsEndpoint} is set, an
+   * instance of {@link VendedCredentialsProvider} is returned. If params were 
set, we return a new
+   * credentials instance. If none of the params are set, we try to 
dynamically load the provided
+   * credentials provider class. Upon loading the class, we try to invoke 
{@code create(Map<String,
+   * String>)} static method. If that fails, we fall back to {@code create()}. 
If credential
+   * provider class wasn't set, we fall back to default credentials provider.
    *
    * @param accessKeyId the AWS access key ID
    * @param secretAccessKey the AWS secret access key
@@ -136,6 +154,12 @@ public class AwsClientProperties implements Serializable {
   @SuppressWarnings("checkstyle:HiddenField")
   public AwsCredentialsProvider credentialsProvider(
       String accessKeyId, String secretAccessKey, String sessionToken) {
+    if (refreshCredentialsEnabled && 
!Strings.isNullOrEmpty(refreshCredentialsEndpoint)) {
+      clientCredentialsProviderProperties.put(
+          VendedCredentialsProvider.URI, refreshCredentialsEndpoint);
+      return credentialsProvider(VendedCredentialsProvider.class.getName());
+    }
+
     if (!Strings.isNullOrEmpty(accessKeyId) && 
!Strings.isNullOrEmpty(secretAccessKey)) {
       if (Strings.isNullOrEmpty(sessionToken)) {
         return StaticCredentialsProvider.create(
diff --git 
a/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIOProperties.java 
b/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIOProperties.java
index 5da758704a..8d97b9d1bf 100644
--- a/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIOProperties.java
+++ b/aws/src/main/java/org/apache/iceberg/aws/s3/S3FileIOProperties.java
@@ -225,6 +225,13 @@ public class S3FileIOProperties implements Serializable {
    */
   public static final String SESSION_TOKEN = "s3.session-token";
 
+  /**
+   * Configure the expiration time of the static session token used to access S3FileIO, expressed
+   * as milliseconds since the Unix epoch. This expiration time is currently only used in {@link
+   * VendedCredentialsProvider} for refreshing vended credentials.
+   */
+  static final String SESSION_TOKEN_EXPIRES_AT_MS = "s3.session-token-expires-at-ms";
+
   /**
    * Enable to make S3FileIO, to make cross-region call to the region 
specified in the ARN of an
    * access point.
diff --git 
a/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java 
b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java
new file mode 100644
index 0000000000..e249d3ff1d
--- /dev/null
+++ b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.aws.s3;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.rest.ErrorHandlers;
+import org.apache.iceberg.rest.HTTPClient;
+import org.apache.iceberg.rest.RESTClient;
+import org.apache.iceberg.rest.auth.OAuth2Properties;
+import org.apache.iceberg.rest.auth.OAuth2Util;
+import org.apache.iceberg.rest.credentials.Credential;
+import org.apache.iceberg.rest.responses.LoadCredentialsResponse;
+import software.amazon.awssdk.auth.credentials.AwsCredentials;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
+import software.amazon.awssdk.utils.IoUtils;
+import software.amazon.awssdk.utils.SdkAutoCloseable;
+import software.amazon.awssdk.utils.cache.CachedSupplier;
+import software.amazon.awssdk.utils.cache.RefreshResult;
+
+/**
+ * An {@link AwsCredentialsProvider} that loads S3 credentials from a REST credentials endpoint and
+ * caches them.
+ *
+ * <p>The endpoint is taken from the {@link #URI} property and queried with a GET request whose
+ * auth headers are built from the {@code OAuth2Properties.TOKEN} property. The response must
+ * contain exactly one credential whose prefix starts with {@code s3}. That credential must carry an
+ * access key id, a secret access key, a session token and an expiration time (see {@link
+ * S3FileIOProperties}).
+ *
+ * <p>The cached value's prefetch time is five minutes before the reported expiration, and its
+ * stale time is the expiration itself.
+ */
+public class VendedCredentialsProvider implements AwsCredentialsProvider, SdkAutoCloseable {
+  /** Property holding the endpoint from which vended credentials are loaded. Required. */
+  public static final String URI = "credentials.uri";
+
+  // Built lazily on first fetch; volatile because httpClient() uses double-checked locking.
+  private volatile HTTPClient client;
+  private final Map<String, String> properties;
+  private final CachedSupplier<AwsCredentials> credentialCache;
+
+  private VendedCredentialsProvider(Map<String, String> properties) {
+    Preconditions.checkArgument(null != properties, "Invalid properties: null");
+    Preconditions.checkArgument(null != properties.get(URI), "Invalid URI: null");
+    this.properties = properties;
+    this.credentialCache =
+        CachedSupplier.builder(this::refreshCredential)
+            .cachedValueName(VendedCredentialsProvider.class.getName())
+            .build();
+  }
+
+  /**
+   * Returns credentials from the cache. The cache invokes {@link #refreshCredential()} according
+   * to the prefetch and stale times recorded with the previous result.
+   */
+  @Override
+  public AwsCredentials resolveCredentials() {
+    return credentialCache.get();
+  }
+
+  /**
+   * Quietly closes the HTTP client, which is still null if no credentials were ever fetched, and
+   * then closes the credential cache.
+   */
+  @Override
+  public void close() {
+    IoUtils.closeQuietly(client, null);
+    credentialCache.close();
+  }
+
+  /**
+   * Creates a provider for the given properties. No request is sent until credentials are first
+   * resolved.
+   *
+   * @param properties provider configuration; must be non-null and contain {@link #URI}
+   * @return a new provider
+   * @throws IllegalArgumentException if {@code properties} is null or has no {@link #URI} entry
+   */
+  public static VendedCredentialsProvider create(Map<String, String> properties) {
+    return new VendedCredentialsProvider(properties);
+  }
+
+  // Lazily builds the HTTP client from the provider properties, using double-checked locking.
+  private RESTClient httpClient() {
+    if (null == client) {
+      synchronized (this) {
+        if (null == client) {
+          client = HTTPClient.builder(properties).uri(properties.get(URI)).build();
+        }
+      }
+    }
+
+    return client;
+  }
+
+  // Sends a GET to the configured URI and deserializes the body as a LoadCredentialsResponse.
+  private LoadCredentialsResponse fetchCredentials() {
+    return httpClient()
+        .get(
+            properties.get(URI),
+            null,
+            LoadCredentialsResponse.class,
+            OAuth2Util.authHeaders(properties.get(OAuth2Properties.TOKEN)),
+            ErrorHandlers.defaultErrorHandler());
+  }
+
+  /**
+   * Fetches credentials and converts the single S3 credential into session credentials.
+   *
+   * @return the session credentials together with their stale and prefetch times
+   * @throws IllegalStateException if there is not exactly one credential with an {@code s3}
+   *     prefix, or if that credential lacks a required property
+   */
+  private RefreshResult<AwsCredentials> refreshCredential() {
+    LoadCredentialsResponse response = fetchCredentials();
+
+    List<Credential> s3Credentials =
+        response.credentials().stream()
+            .filter(c -> c.prefix().startsWith("s3"))
+            .collect(Collectors.toList());
+
+    Preconditions.checkState(!s3Credentials.isEmpty(), "Invalid S3 Credentials: empty");
+    Preconditions.checkState(
+        s3Credentials.size() == 1, "Invalid S3 Credentials: only one S3 credential should exist");
+
+    Credential s3Credential = s3Credentials.get(0);
+    checkCredential(s3Credential, S3FileIOProperties.ACCESS_KEY_ID);
+    checkCredential(s3Credential, S3FileIOProperties.SECRET_ACCESS_KEY);
+    checkCredential(s3Credential, S3FileIOProperties.SESSION_TOKEN);
+    checkCredential(s3Credential, S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS);
+
+    String accessKeyId = s3Credential.config().get(S3FileIOProperties.ACCESS_KEY_ID);
+    String secretAccessKey = s3Credential.config().get(S3FileIOProperties.SECRET_ACCESS_KEY);
+    String sessionToken = s3Credential.config().get(S3FileIOProperties.SESSION_TOKEN);
+    String tokenExpiresAtMillis =
+        s3Credential.config().get(S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS);
+    Instant expiresAt = Instant.ofEpochMilli(Long.parseLong(tokenExpiresAtMillis));
+    // Prefetch five minutes ahead of expiry; the value is considered stale at expiresAt.
+    Instant prefetchAt = expiresAt.minus(5, ChronoUnit.MINUTES);
+
+    // The cast fixes the builder's type parameter to AwsCredentials to match the return type.
+    return RefreshResult.builder(
+            (AwsCredentials)
+                AwsSessionCredentials.builder()
+                    .accessKeyId(accessKeyId)
+                    .secretAccessKey(secretAccessKey)
+                    .sessionToken(sessionToken)
+                    .expirationTime(expiresAt)
+                    .build())
+        .staleTime(expiresAt)
+        .prefetchTime(prefetchAt)
+        .build();
+  }
+
+  // Fails with IllegalStateException when the credential's config lacks the given property.
+  private void checkCredential(Credential credential, String property) {
+    Preconditions.checkState(
+        credential.config().containsKey(property), "Invalid S3 Credentials: %s not set", property);
+  }
+}
diff --git 
a/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java 
b/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java
index c318538d95..5cf9dd810c 100644
--- a/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java
+++ b/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java
@@ -21,6 +21,8 @@ package org.apache.iceberg.aws;
 import static org.assertj.core.api.Assertions.assertThat;
 
 import java.util.Map;
+import org.apache.iceberg.aws.s3.VendedCredentialsProvider;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.junit.jupiter.api.Test;
 import org.mockito.ArgumentCaptor;
@@ -29,6 +31,7 @@ import 
software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
 import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
 import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
 import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
 import software.amazon.awssdk.regions.Region;
 import software.amazon.awssdk.services.s3.S3ClientBuilder;
 
@@ -111,4 +114,30 @@ public class AwsClientPropertiesTest {
         .as("The secret access key should be the same as the one set by tag 
SECRET_ACCESS_KEY")
         .isEqualTo("secret");
   }
+
+  @Test
+  public void refreshCredentialsEndpoint() {
+    AwsClientProperties awsClientProperties =
+        new AwsClientProperties(
+            ImmutableMap.of(
+                AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT,
+                "http://localhost:1234/v1/credentials"));
+
+    // refreshing is enabled by default, so the endpoint wins over the static key/secret/token
+    assertThat(awsClientProperties.credentialsProvider("key", "secret", "token"))
+        .isInstanceOf(VendedCredentialsProvider.class);
+  }
+
+  @Test
+  public void refreshCredentialsEndpointSetButRefreshDisabled() {
+    AwsClientProperties awsClientProperties =
+        new AwsClientProperties(
+            ImmutableMap.of(
+                AwsClientProperties.REFRESH_CREDENTIALS_ENABLED,
+                "false",
+                AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT,
+                "http://localhost:1234/v1/credentials"));
+
+    // with refreshing disabled, the endpoint is ignored and static session credentials are used
+    assertThat(awsClientProperties.credentialsProvider("key", "secret", "token"))
+        .isInstanceOf(StaticCredentialsProvider.class);
+  }
 }
diff --git 
a/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java
 
b/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java
new file mode 100644
index 0000000000..67cd1cb552
--- /dev/null
+++ 
b/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.aws.s3;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockserver.integration.ClientAndServer.startClientAndServer;
+import static org.mockserver.model.HttpRequest.request;
+import static org.mockserver.model.HttpResponse.response;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import org.apache.iceberg.exceptions.RESTException;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.rest.HttpMethod;
+import org.apache.iceberg.rest.credentials.Credential;
+import org.apache.iceberg.rest.credentials.ImmutableCredential;
+import org.apache.iceberg.rest.responses.ImmutableLoadCredentialsResponse;
+import org.apache.iceberg.rest.responses.LoadCredentialsResponse;
+import org.apache.iceberg.rest.responses.LoadCredentialsResponseParser;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockserver.integration.ClientAndServer;
+import org.mockserver.model.HttpRequest;
+import org.mockserver.model.HttpResponse;
+import org.mockserver.verify.VerificationTimes;
+import software.amazon.awssdk.auth.credentials.AwsCredentials;
+import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
+
+/**
+ * Tests for {@link VendedCredentialsProvider}. A MockServer instance on a fixed local port stands
+ * in for the REST credentials endpoint.
+ */
+public class TestVendedCredentialsProvider {
+
+  private static final int PORT = 3232;
+  private static final String URI = String.format("http://127.0.0.1:%d/v1/credentials", PORT);
+  private static ClientAndServer mockServer;
+
+  @BeforeAll
+  public static void beforeAll() {
+    mockServer = startClientAndServer(PORT);
+  }
+
+  @AfterAll
+  public static void stopServer() {
+    mockServer.stop();
+  }
+
+  @BeforeEach
+  public void before() {
+    // drop expectations and recorded requests left over from the previous test
+    mockServer.reset();
+  }
+
+  @Test
+  public void invalidOrMissingUri() {
+    assertThatThrownBy(() -> VendedCredentialsProvider.create(null))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Invalid properties: null");
+    assertThatThrownBy(() -> VendedCredentialsProvider.create(ImmutableMap.of()))
+        .isInstanceOf(IllegalArgumentException.class)
+        .hasMessage("Invalid URI: null");
+
+    // a malformed URI is only detected once credentials are actually requested
+    try (VendedCredentialsProvider provider =
+        VendedCredentialsProvider.create(
+            ImmutableMap.of(VendedCredentialsProvider.URI, "invalid uri"))) {
+      assertThatThrownBy(provider::resolveCredentials)
+          .isInstanceOf(RESTException.class)
+          .hasMessageStartingWith("Failed to create request URI from base invalid uri");
+    }
+  }
+
+  @Test
+  public void noS3Credentials() {
+    HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name());
+
+    HttpResponse mockResponse =
+        response(
+                LoadCredentialsResponseParser.toJson(
+                    ImmutableLoadCredentialsResponse.builder().build()))
+            .withStatusCode(200);
+    mockServer.when(mockRequest).respond(mockResponse);
+
+    try (VendedCredentialsProvider provider =
+        VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) {
+      assertThatThrownBy(provider::resolveCredentials)
+          .isInstanceOf(IllegalStateException.class)
+          .hasMessage("Invalid S3 Credentials: empty");
+    }
+  }
+
+  @Test
+  public void accessKeyIdAndSecretAccessKeyWithoutToken() {
+    HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name());
+    LoadCredentialsResponse response =
+        ImmutableLoadCredentialsResponse.builder()
+            .addCredentials(
+                ImmutableCredential.builder()
+                    .prefix("s3")
+                    .config(
+                        ImmutableMap.of(
+                            S3FileIOProperties.ACCESS_KEY_ID,
+                            "randomAccessKey",
+                            S3FileIOProperties.SECRET_ACCESS_KEY,
+                            "randomSecretAccessKey"))
+                    .build())
+            .build();
+
+    HttpResponse mockResponse =
+        response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200);
+    mockServer.when(mockRequest).respond(mockResponse);
+
+    try (VendedCredentialsProvider provider =
+        VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) {
+      assertThatThrownBy(provider::resolveCredentials)
+          .isInstanceOf(IllegalStateException.class)
+          .hasMessage("Invalid S3 Credentials: s3.session-token not set");
+    }
+  }
+
+  @Test
+  public void expirationNotSet() {
+    HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name());
+    LoadCredentialsResponse response =
+        ImmutableLoadCredentialsResponse.builder()
+            .addCredentials(
+                ImmutableCredential.builder()
+                    .prefix("s3")
+                    .config(
+                        ImmutableMap.of(
+                            S3FileIOProperties.ACCESS_KEY_ID,
+                            "randomAccessKey",
+                            S3FileIOProperties.SECRET_ACCESS_KEY,
+                            "randomSecretAccessKey",
+                            S3FileIOProperties.SESSION_TOKEN,
+                            "sessionToken"))
+                    .build())
+            .build();
+
+    HttpResponse mockResponse =
+        response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200);
+    mockServer.when(mockRequest).respond(mockResponse);
+
+    try (VendedCredentialsProvider provider =
+        VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) {
+      assertThatThrownBy(provider::resolveCredentials)
+          .isInstanceOf(IllegalStateException.class)
+          .hasMessage("Invalid S3 Credentials: s3.session-token-expires-at-ms not set");
+    }
+  }
+
+  @Test
+  public void nonExpiredToken() {
+    HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name());
+    Credential credential =
+        ImmutableCredential.builder()
+            .prefix("s3")
+            .config(
+                ImmutableMap.of(
+                    S3FileIOProperties.ACCESS_KEY_ID,
+                    "randomAccessKey",
+                    S3FileIOProperties.SECRET_ACCESS_KEY,
+                    "randomSecretAccessKey",
+                    S3FileIOProperties.SESSION_TOKEN,
+                    "sessionToken",
+                    S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS,
+                    Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli())))
+            .build();
+    LoadCredentialsResponse response =
+        ImmutableLoadCredentialsResponse.builder().addCredentials(credential).build();
+
+    HttpResponse mockResponse =
+        response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200);
+    mockServer.when(mockRequest).respond(mockResponse);
+
+    try (VendedCredentialsProvider provider =
+        VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) {
+      AwsCredentials awsCredentials = provider.resolveCredentials();
+
+      verifyCredentials(awsCredentials, credential);
+
+      for (int i = 0; i < 5; i++) {
+        // resolving credentials multiple times should not hit the credentials endpoint again
+        assertThat(provider.resolveCredentials()).isSameAs(awsCredentials);
+      }
+    }
+
+    mockServer.verify(mockRequest, VerificationTimes.once());
+  }
+
+  @Test
+  public void expiredToken() {
+    HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name());
+    // the token expired a minute ago, so every cached value is already stale
+    Credential credential =
+        ImmutableCredential.builder()
+            .prefix("s3")
+            .config(
+                ImmutableMap.of(
+                    S3FileIOProperties.ACCESS_KEY_ID,
+                    "randomAccessKey",
+                    S3FileIOProperties.SECRET_ACCESS_KEY,
+                    "randomSecretAccessKey",
+                    S3FileIOProperties.SESSION_TOKEN,
+                    "sessionToken",
+                    S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS,
+                    Long.toString(Instant.now().minus(1, ChronoUnit.MINUTES).toEpochMilli())))
+            .build();
+    LoadCredentialsResponse response =
+        ImmutableLoadCredentialsResponse.builder().addCredentials(credential).build();
+
+    HttpResponse mockResponse =
+        response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200);
+    mockServer.when(mockRequest).respond(mockResponse);
+
+    try (VendedCredentialsProvider provider =
+        VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) {
+      AwsCredentials awsCredentials = provider.resolveCredentials();
+      verifyCredentials(awsCredentials, credential);
+
+      // resolving credentials multiple times should hit the credentials endpoint again
+      AwsCredentials refreshedCredentials = provider.resolveCredentials();
+      assertThat(refreshedCredentials).isNotSameAs(awsCredentials);
+      verifyCredentials(refreshedCredentials, credential);
+    }
+
+    mockServer.verify(mockRequest, VerificationTimes.exactly(2));
+  }
+
+  @Test
+  public void multipleS3Credentials() {
+    HttpRequest mockRequest = request("/v1/credentials").withMethod(HttpMethod.GET.name());
+    // only credentialTwo and credentialThree have an s3 prefix; two matches are rejected
+    Credential credentialOne =
+        ImmutableCredential.builder()
+            .prefix("gcs")
+            .config(
+                ImmutableMap.of(
+                    S3FileIOProperties.ACCESS_KEY_ID,
+                    "randomAccessKey1",
+                    S3FileIOProperties.SECRET_ACCESS_KEY,
+                    "randomSecretAccessKey1",
+                    S3FileIOProperties.SESSION_TOKEN,
+                    "sessionToken1",
+                    S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS,
+                    Long.toString(Instant.now().plus(1, ChronoUnit.HOURS).toEpochMilli())))
+            .build();
+    Credential credentialTwo =
+        ImmutableCredential.builder()
+            .prefix("s3://custom-uri/longest-prefix")
+            .config(
+                ImmutableMap.of(
+                    S3FileIOProperties.ACCESS_KEY_ID,
+                    "randomAccessKey2",
+                    S3FileIOProperties.SECRET_ACCESS_KEY,
+                    "randomSecretAccessKey2",
+                    S3FileIOProperties.SESSION_TOKEN,
+                    "sessionToken2",
+                    S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS,
+                    Long.toString(Instant.now().plus(2, ChronoUnit.HOURS).toEpochMilli())))
+            .build();
+    Credential credentialThree =
+        ImmutableCredential.builder()
+            .prefix("s3://custom-uri/long")
+            .config(
+                ImmutableMap.of(
+                    S3FileIOProperties.ACCESS_KEY_ID,
+                    "randomAccessKey3",
+                    S3FileIOProperties.SECRET_ACCESS_KEY,
+                    "randomSecretAccessKey3",
+                    S3FileIOProperties.SESSION_TOKEN,
+                    "sessionToken3",
+                    S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS,
+                    Long.toString(Instant.now().plus(3, ChronoUnit.HOURS).toEpochMilli())))
+            .build();
+    LoadCredentialsResponse response =
+        ImmutableLoadCredentialsResponse.builder()
+            .addCredentials(credentialOne, credentialTwo, credentialThree)
+            .build();
+
+    HttpResponse mockResponse =
+        response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200);
+    mockServer.when(mockRequest).respond(mockResponse);
+
+    try (VendedCredentialsProvider provider =
+        VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) {
+      assertThatThrownBy(provider::resolveCredentials)
+          .isInstanceOf(IllegalStateException.class)
+          .hasMessage("Invalid S3 Credentials: only one S3 credential should exist");
+    }
+  }
+
+  /**
+   * Asserts that {@code awsCredentials} are session credentials whose key, secret, token and
+   * expiration match the config of the vended {@code credential}.
+   */
+  private void verifyCredentials(AwsCredentials awsCredentials, Credential credential) {
+    assertThat(awsCredentials).isInstanceOf(AwsSessionCredentials.class);
+    AwsSessionCredentials creds = (AwsSessionCredentials) awsCredentials;
+
+    assertThat(creds.accessKeyId())
+        .isEqualTo(credential.config().get(S3FileIOProperties.ACCESS_KEY_ID));
+    assertThat(creds.secretAccessKey())
+        .isEqualTo(credential.config().get(S3FileIOProperties.SECRET_ACCESS_KEY));
+    assertThat(creds.sessionToken())
+        .isEqualTo(credential.config().get(S3FileIOProperties.SESSION_TOKEN));
+    assertThat(creds.expirationTime())
+        .isPresent()
+        .get()
+        .extracting(Instant::toEpochMilli)
+        .isEqualTo(
+            Long.parseLong(
+                credential.config().get(S3FileIOProperties.SESSION_TOKEN_EXPIRES_AT_MS)));
+  }
+}

Reply via email to