singhpk234 commented on code in PR #3327:
URL: https://github.com/apache/polaris/pull/3327#discussion_r2653725307
##########
polaris-core/src/main/java/org/apache/polaris/core/storage/aws/AwsCredentialsStorageIntegration.java:
##########
@@ -118,6 +124,19 @@ public StorageAccessConfig getSubscopedCreds(
accountId)
.toJson())
.durationSeconds(storageCredentialDurationSeconds);
+
+ // Add session tags when the feature is enabled
+ if (includeSessionTags) {
+ List<Tag> sessionTags =
+ buildSessionTags(polarisPrincipal.getName(),
credentialVendingContext);
Review Comment:
We can pass PolarisPrincipal as an argument here and extract both
principalName and `getRoles` from the principal, rather than making the roles
part of credentialVendingContext?
##########
polaris-core/src/test/java/org/apache/polaris/service/storage/aws/AwsCredentialsStorageIntegrationTest.java:
##########
@@ -965,4 +981,390 @@ public void testGetSubscopedCredsLongPrincipalName() {
private static @Nonnull String s3Path(String bucket, String keyPrefix) {
return "s3://" + bucket + "/" + keyPrefix;
}
+
+ // Tests for AWS STS Session Tags functionality
+
+ @Test
+ public void testSessionTagsIncludedWhenFeatureEnabled() {
+ StsClient stsClient = Mockito.mock(StsClient.class);
+ String roleARN = "arn:aws:iam::012345678901:role/jdoe";
+ String externalId = "externalId";
+ String bucket = "bucket";
+ String warehouseKeyPrefix = "path/to/warehouse";
+
+ // Create a realm config with session tags enabled
+ RealmConfig sessionTagsEnabledConfig =
+ new RealmConfigImpl(
+ new PolarisConfigurationStore() {
+ @SuppressWarnings("unchecked")
+ @Override
+ public String getConfiguration(@Nonnull RealmContext ctx, String
configName) {
+ if (configName.equals(
+
FeatureConfiguration.INCLUDE_SESSION_TAGS_IN_SUBSCOPED_CREDENTIAL.key())) {
+ return "true";
+ }
+ return null;
+ }
+ },
+ () -> "realm");
+
+ ArgumentCaptor<AssumeRoleRequest> requestCaptor =
+ ArgumentCaptor.forClass(AssumeRoleRequest.class);
+
Mockito.when(stsClient.assumeRole(requestCaptor.capture())).thenReturn(ASSUME_ROLE_RESPONSE);
+
+ CredentialVendingContext context =
+ CredentialVendingContext.builder()
+ .catalogName(Optional.of("test-catalog"))
+ .namespace(Optional.of("db.schema"))
+ .tableName(Optional.of("my_table"))
+ .activatedRoles(Optional.of("admin,reader"))
+ .build();
+
+ new AwsCredentialsStorageIntegration(
+ AwsStorageConfigurationInfo.builder()
+ .addAllowedLocation(s3Path(bucket, warehouseKeyPrefix))
+ .roleARN(roleARN)
+ .externalId(externalId)
+ .build(),
+ stsClient)
+ .getSubscopedCreds(
+ sessionTagsEnabledConfig,
+ true,
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ POLARIS_PRINCIPAL,
+ Optional.empty(),
+ context);
+
+ AssumeRoleRequest capturedRequest = requestCaptor.getValue();
+ Assertions.assertThat(capturedRequest.tags()).isNotEmpty();
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:catalog") &&
tag.value().equals("test-catalog"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:namespace") &&
tag.value().equals("db.schema"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:table") &&
tag.value().equals("my_table"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(
+ tag -> tag.key().equals("polaris:principal") &&
tag.value().equals("test-principal"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:roles") &&
tag.value().equals("admin,reader"));
+
+ // Verify transitive tag keys are set
+ Assertions.assertThat(capturedRequest.transitiveTagKeys())
+ .containsExactlyInAnyOrder(
+ "polaris:catalog",
+ "polaris:namespace",
+ "polaris:table",
+ "polaris:principal",
+ "polaris:roles");
+ }
+
+ @Test
+ public void testSessionTagsNotIncludedWhenFeatureDisabled() {
+ StsClient stsClient = Mockito.mock(StsClient.class);
+ String roleARN = "arn:aws:iam::012345678901:role/jdoe";
+ String externalId = "externalId";
+ String bucket = "bucket";
+ String warehouseKeyPrefix = "path/to/warehouse";
+
+ ArgumentCaptor<AssumeRoleRequest> requestCaptor =
+ ArgumentCaptor.forClass(AssumeRoleRequest.class);
+
Mockito.when(stsClient.assumeRole(requestCaptor.capture())).thenReturn(ASSUME_ROLE_RESPONSE);
+
+ CredentialVendingContext context =
+ CredentialVendingContext.builder()
+ .catalogName(Optional.of("test-catalog"))
+ .namespace(Optional.of("db.schema"))
+ .tableName(Optional.of("my_table"))
+ .build();
+
+ // Use EMPTY_REALM_CONFIG which has session tags disabled by default
+ new AwsCredentialsStorageIntegration(
+ AwsStorageConfigurationInfo.builder()
+ .addAllowedLocation(s3Path(bucket, warehouseKeyPrefix))
+ .roleARN(roleARN)
+ .externalId(externalId)
+ .build(),
+ stsClient)
+ .getSubscopedCreds(
+ EMPTY_REALM_CONFIG,
+ true,
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ POLARIS_PRINCIPAL,
+ Optional.empty(),
+ context);
+
+ AssumeRoleRequest capturedRequest = requestCaptor.getValue();
+ // Tags should be empty when feature is disabled
+ Assertions.assertThat(capturedRequest.tags()).isEmpty();
+ Assertions.assertThat(capturedRequest.transitiveTagKeys()).isEmpty();
+ }
+
+ @Test
+ public void testSessionTagsWithPartialContext() {
+ StsClient stsClient = Mockito.mock(StsClient.class);
+ String roleARN = "arn:aws:iam::012345678901:role/jdoe";
+ String externalId = "externalId";
+ String bucket = "bucket";
+ String warehouseKeyPrefix = "path/to/warehouse";
+
+ RealmConfig sessionTagsEnabledConfig =
+ new RealmConfigImpl(
+ new PolarisConfigurationStore() {
+ @SuppressWarnings("unchecked")
+ @Override
+ public String getConfiguration(@Nonnull RealmContext ctx, String
configName) {
+ if (configName.equals(
+
FeatureConfiguration.INCLUDE_SESSION_TAGS_IN_SUBSCOPED_CREDENTIAL.key())) {
+ return "true";
+ }
+ return null;
+ }
+ },
+ () -> "realm");
+
+ ArgumentCaptor<AssumeRoleRequest> requestCaptor =
+ ArgumentCaptor.forClass(AssumeRoleRequest.class);
+
Mockito.when(stsClient.assumeRole(requestCaptor.capture())).thenReturn(ASSUME_ROLE_RESPONSE);
+
+ // Only provide catalog name, no namespace/table
+ CredentialVendingContext context =
+
CredentialVendingContext.builder().catalogName(Optional.of("test-catalog")).build();
+
+ new AwsCredentialsStorageIntegration(
+ AwsStorageConfigurationInfo.builder()
+ .addAllowedLocation(s3Path(bucket, warehouseKeyPrefix))
+ .roleARN(roleARN)
+ .externalId(externalId)
+ .build(),
+ stsClient)
+ .getSubscopedCreds(
+ sessionTagsEnabledConfig,
+ true,
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ POLARIS_PRINCIPAL,
+ Optional.empty(),
+ context);
+
+ AssumeRoleRequest capturedRequest = requestCaptor.getValue();
+ // All 5 tags are always included; missing values use "unknown" placeholder
+ Assertions.assertThat(capturedRequest.tags()).hasSize(5);
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:catalog") &&
tag.value().equals("test-catalog"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(
+ tag -> tag.key().equals("polaris:principal") &&
tag.value().equals("test-principal"));
+ // Absent values should be "unknown"
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:namespace") &&
tag.value().equals("unknown"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:table") &&
tag.value().equals("unknown"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:roles") &&
tag.value().equals("unknown"));
+ }
+
+ @Test
+ public void testSessionTagsWithLongValues() {
+ StsClient stsClient = Mockito.mock(StsClient.class);
+ String roleARN = "arn:aws:iam::012345678901:role/jdoe";
+ String externalId = "externalId";
+ String bucket = "bucket";
+ String warehouseKeyPrefix = "path/to/warehouse";
+
+ RealmConfig sessionTagsEnabledConfig =
+ new RealmConfigImpl(
+ new PolarisConfigurationStore() {
+ @SuppressWarnings("unchecked")
+ @Override
+ public String getConfiguration(@Nonnull RealmContext ctx, String
configName) {
+ if (configName.equals(
+
FeatureConfiguration.INCLUDE_SESSION_TAGS_IN_SUBSCOPED_CREDENTIAL.key())) {
+ return "true";
+ }
+ return null;
+ }
+ },
+ () -> "realm");
+
+ ArgumentCaptor<AssumeRoleRequest> requestCaptor =
+ ArgumentCaptor.forClass(AssumeRoleRequest.class);
+
Mockito.when(stsClient.assumeRole(requestCaptor.capture())).thenReturn(ASSUME_ROLE_RESPONSE);
+
+ // Create context with very long namespace (over 256 chars)
+ String longNamespace = "db." + "a".repeat(300) + ".schema";
+ CredentialVendingContext context =
+ CredentialVendingContext.builder()
+ .catalogName(Optional.of("test-catalog"))
+ .namespace(Optional.of(longNamespace))
+ .build();
+
+ new AwsCredentialsStorageIntegration(
+ AwsStorageConfigurationInfo.builder()
+ .addAllowedLocation(s3Path(bucket, warehouseKeyPrefix))
+ .roleARN(roleARN)
+ .externalId(externalId)
+ .build(),
+ stsClient)
+ .getSubscopedCreds(
+ sessionTagsEnabledConfig,
+ true,
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ POLARIS_PRINCIPAL,
+ Optional.empty(),
+ context);
+
+ AssumeRoleRequest capturedRequest = requestCaptor.getValue();
+ // Verify namespace tag is truncated to 256 characters
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(
+ tag ->
+ tag.key().equals("polaris:namespace")
+ && tag.value().length() == 256
+ && tag.value().startsWith("db."));
+ }
+
+ @Test
+ public void testSessionTagsWithEmptyContext() {
+ StsClient stsClient = Mockito.mock(StsClient.class);
+ String roleARN = "arn:aws:iam::012345678901:role/jdoe";
+ String externalId = "externalId";
+ String bucket = "bucket";
+ String warehouseKeyPrefix = "path/to/warehouse";
+
+ RealmConfig sessionTagsEnabledConfig =
+ new RealmConfigImpl(
+ new PolarisConfigurationStore() {
+ @SuppressWarnings("unchecked")
+ @Override
+ public String getConfiguration(@Nonnull RealmContext ctx, String
configName) {
+ if (configName.equals(
+
FeatureConfiguration.INCLUDE_SESSION_TAGS_IN_SUBSCOPED_CREDENTIAL.key())) {
+ return "true";
+ }
+ return null;
+ }
+ },
+ () -> "realm");
+
+ ArgumentCaptor<AssumeRoleRequest> requestCaptor =
+ ArgumentCaptor.forClass(AssumeRoleRequest.class);
+
Mockito.when(stsClient.assumeRole(requestCaptor.capture())).thenReturn(ASSUME_ROLE_RESPONSE);
+
+ // Use empty context
+ new AwsCredentialsStorageIntegration(
+ AwsStorageConfigurationInfo.builder()
+ .addAllowedLocation(s3Path(bucket, warehouseKeyPrefix))
+ .roleARN(roleARN)
+ .externalId(externalId)
+ .build(),
+ stsClient)
+ .getSubscopedCreds(
+ sessionTagsEnabledConfig,
+ true,
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ Set.of(s3Path(bucket, warehouseKeyPrefix)),
+ POLARIS_PRINCIPAL,
+ Optional.empty(),
+ CredentialVendingContext.empty());
+
+ AssumeRoleRequest capturedRequest = requestCaptor.getValue();
+ // All 5 tags are always included; missing values use "unknown" placeholder
+ Assertions.assertThat(capturedRequest.tags()).hasSize(5);
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(
+ tag -> tag.key().equals("polaris:principal") &&
tag.value().equals("test-principal"));
+ // All context tags should be "unknown" when context is empty
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:catalog") &&
tag.value().equals("unknown"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:namespace") &&
tag.value().equals("unknown"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:table") &&
tag.value().equals("unknown"));
+ Assertions.assertThat(capturedRequest.tags())
+ .anyMatch(tag -> tag.key().equals("polaris:roles") &&
tag.value().equals("unknown"));
+ }
+
+ /**
+ * Tests graceful error handling when STS throws an exception due to missing
sts:TagSession
+ * permission. When the IAM role's trust policy doesn't allow
sts:TagSession, the assumeRole call
+ * should fail and the exception should be propagated appropriately.
+ *
+ * <p>NOTE: Full integration tests with LocalStack or real AWS to verify
sts:TagSession permission
+ * behavior are recommended but out of scope for unit tests.
+ */
+ @Test
+ public void testSessionTagsAccessDeniedGracefulHandling() {
+ StsClient stsClient = Mockito.mock(StsClient.class);
+ String roleARN = "arn:aws:iam::012345678901:role/jdoe";
+ String externalId = "externalId";
+ String bucket = "bucket";
+ String warehouseKeyPrefix = "path/to/warehouse";
+
+ RealmConfig sessionTagsEnabledConfig =
+ new RealmConfigImpl(
+ new PolarisConfigurationStore() {
+ @SuppressWarnings("unchecked")
+ @Override
+ public String getConfiguration(@Nonnull RealmContext ctx, String
configName) {
+ if (configName.equals(
+
FeatureConfiguration.INCLUDE_SESSION_TAGS_IN_SUBSCOPED_CREDENTIAL.key())) {
+ return "true";
+ }
+ return null;
+ }
+ },
+ () -> "realm");
+
+ // Simulate STS throwing AccessDeniedException when sts:TagSession is not
allowed
+ // In AWS SDK v2, this is represented as StsException with error code
"AccessDenied"
+ software.amazon.awssdk.services.sts.model.StsException
accessDeniedException =
+ (software.amazon.awssdk.services.sts.model.StsException)
+ software.amazon.awssdk.services.sts.model.StsException.builder()
Review Comment:
nit: can we remove the inline fully-qualified class names? It would be better
to import StsException normally and use the simple name.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]