This is an automated email from the ASF dual-hosted git repository.
xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opendal-reqsign.git
The following commit(s) were added to refs/heads/main by this push:
new c3b660e feat(aws): Add session policy support for sts (#670)
c3b660e is described below
commit c3b660eb42934a133881cabea8c42be01917271f
Author: Xuanwo <[email protected]>
AuthorDate: Fri Dec 26 19:21:58 2025 +0800
feat(aws): Add session policy support for sts (#670)
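This PR adds session policy support (`Policy` and `PolicyArns`) to the STS
`AssumeRole` and `AssumeRoleWithWebIdentity` credential providers. It also
adds `DurationSeconds` to the web identity provider and builds the
`AssumeRole` query string with `form_urlencoded::Serializer`, so parameter
values are URL-encoded.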
---
**Parts of this PR were drafted with assistance from Codex (with
`gpt-5.2`) and fully reviewed and edited by me. I take full
responsibility for all changes.**
---
.../aws-v4/src/provide_credential/assume_role.rs | 144 ++++++++++++++++-----
.../assume_role_with_web_identity.rs | 122 ++++++++++++++++-
2 files changed, 227 insertions(+), 39 deletions(-)
diff --git a/services/aws-v4/src/provide_credential/assume_role.rs
b/services/aws-v4/src/provide_credential/assume_role.rs
index 6b6a616..9a47212 100644
--- a/services/aws-v4/src/provide_credential/assume_role.rs
+++ b/services/aws-v4/src/provide_credential/assume_role.rs
@@ -21,10 +21,10 @@ use crate::credential::Credential;
use crate::provide_credential::utils::{parse_sts_error, sts_endpoint};
use async_trait::async_trait;
use bytes::Bytes;
+use form_urlencoded::Serializer;
use quick_xml::de;
use reqsign_core::{Context, Error, ProvideCredential, Result, Signer};
use serde::Deserialize;
-use std::fmt::Write;
/// AssumeRoleCredentialProvider will load credential via assume role.
#[derive(Debug)]
@@ -35,6 +35,8 @@ pub struct AssumeRoleCredentialProvider {
external_id: Option<String>,
duration_seconds: Option<u32>,
tags: Option<Vec<(String, String)>>,
+ policy: Option<String>,
+ policy_arns: Option<Vec<String>>,
// MFA configuration
serial_number: Option<String>,
@@ -57,6 +59,8 @@ impl AssumeRoleCredentialProvider {
external_id: None,
duration_seconds: Some(3600),
tags: None,
+ policy: None,
+ policy_arns: None,
serial_number: None,
token_code: None,
region: None,
@@ -83,6 +87,18 @@ impl AssumeRoleCredentialProvider {
self
}
+ /// Set the session policy.
+ pub fn with_policy(mut self, policy: String) -> Self {
+ self.policy = Some(policy);
+ self
+ }
+
+ /// Set the session policy ARNs.
+ pub fn with_policy_arns(mut self, policy_arns: Vec<String>) -> Self {
+ self.policy_arns = Some(policy_arns);
+ self
+ }
+
/// Set the tags.
pub fn with_tags(mut self, tags: Vec<(String, String)>) -> Self {
self.tags = Some(tags);
@@ -146,37 +162,18 @@ impl ProvideCredential for AssumeRoleCredentialProvider {
let endpoint = sts_endpoint(self.region.as_deref(),
self.use_regional_sts_endpoint)
.map_err(|e| e.with_context(format!("role_arn: {}",
self.role_arn)))?;
- // Construct request to AWS STS Service.
- let mut url = format!(
-
"https://{endpoint}/?Action=AssumeRole&RoleArn={}&Version=2011-06-15&RoleSessionName={}",
- self.role_arn, self.role_session_name
- );
- if let Some(external_id) = &self.external_id {
- write!(url, "&ExternalId={external_id}")
- .map_err(|e| Error::unexpected("failed to format
URL").with_source(e))?;
- }
- if let Some(duration_seconds) = &self.duration_seconds {
- write!(url, "&DurationSeconds={duration_seconds}")
- .map_err(|e| Error::unexpected("failed to format
URL").with_source(e))?;
- }
- if let Some(tags) = &self.tags {
- for (idx, (key, value)) in tags.iter().enumerate() {
- let tag_index = idx + 1;
- write!(
- url,
-
"&Tags.member.{tag_index}.Key={key}&Tags.member.{tag_index}.Value={value}"
- )
- .map_err(|e| Error::unexpected("failed to format
URL").with_source(e))?;
- }
- }
- if let Some(serial_number) = &self.serial_number {
- write!(url, "&SerialNumber={serial_number}")
- .map_err(|e| Error::unexpected("failed to format
URL").with_source(e))?;
- }
- if let Some(token_code) = &self.token_code {
- write!(url, "&TokenCode={token_code}")
- .map_err(|e| Error::unexpected("failed to format
URL").with_source(e))?;
- }
+ let query = build_assume_role_query(AssumeRoleQueryInput {
+ role_arn: &self.role_arn,
+ role_session_name: &self.role_session_name,
+ external_id: self.external_id.as_deref(),
+ duration_seconds: self.duration_seconds,
+ tags: self.tags.as_deref(),
+ policy: self.policy.as_deref(),
+ policy_arns: self.policy_arns.as_deref(),
+ serial_number: self.serial_number.as_deref(),
+ token_code: self.token_code.as_deref(),
+ });
+ let url = format!("https://{endpoint}/?{query}");
let req = http::request::Request::builder()
.method("GET")
@@ -249,6 +246,59 @@ impl ProvideCredential for AssumeRoleCredentialProvider {
}
}
+struct AssumeRoleQueryInput<'a> {
+ role_arn: &'a str,
+ role_session_name: &'a str,
+ external_id: Option<&'a str>,
+ duration_seconds: Option<u32>,
+ tags: Option<&'a [(String, String)]>,
+ policy: Option<&'a str>,
+ policy_arns: Option<&'a [String]>,
+ serial_number: Option<&'a str>,
+ token_code: Option<&'a str>,
+}
+
+fn build_assume_role_query(input: AssumeRoleQueryInput<'_>) -> String {
+ let mut serializer = Serializer::new(String::new());
+ serializer
+ .append_pair("Action", "AssumeRole")
+ .append_pair("RoleArn", input.role_arn)
+ .append_pair("Version", "2011-06-15")
+ .append_pair("RoleSessionName", input.role_session_name);
+
+ if let Some(external_id) = input.external_id {
+ serializer.append_pair("ExternalId", external_id);
+ }
+ if let Some(duration_seconds) = input.duration_seconds {
+ serializer.append_pair("DurationSeconds",
&duration_seconds.to_string());
+ }
+ if let Some(policy) = input.policy {
+ serializer.append_pair("Policy", policy);
+ }
+ if let Some(policy_arns) = input.policy_arns {
+ for (idx, arn) in policy_arns.iter().enumerate() {
+ let key = format!("PolicyArns.member.{}.arn", idx + 1);
+ serializer.append_pair(&key, arn);
+ }
+ }
+ if let Some(tags) = input.tags {
+ for (idx, (key, value)) in tags.iter().enumerate() {
+ let tag_index = idx + 1;
+ serializer
+ .append_pair(&format!("Tags.member.{tag_index}.Key"), key)
+ .append_pair(&format!("Tags.member.{tag_index}.Value"), value);
+ }
+ }
+ if let Some(serial_number) = input.serial_number {
+ serializer.append_pair("SerialNumber", serial_number);
+ }
+ if let Some(token_code) = input.token_code {
+ serializer.append_pair("TokenCode", token_code);
+ }
+
+ serializer.finish()
+}
+
#[derive(Default, Debug, Deserialize)]
#[serde(default, rename_all = "PascalCase")]
struct AssumeRoleResponse {
@@ -328,4 +378,34 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn test_assume_role_encodes_policy_and_policy_arns() {
+ let policy =
r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"*","Condition":{"StringEquals":{"s3:prefix":"a
b"}}}]}"#;
+ let policy_arns = vec![
+ "arn:aws:iam::aws:policy/ReadOnlyAccess".to_string(),
+ "arn:aws:iam::123456789012:policy/ExamplePolicy".to_string(),
+ ];
+ let query = build_assume_role_query(AssumeRoleQueryInput {
+ role_arn: "arn:aws:iam::123456789012:role/test-role",
+ role_session_name: "reqsign",
+ external_id: None,
+ duration_seconds: Some(3600),
+ tags: None,
+ policy: Some(policy),
+ policy_arns: Some(policy_arns.as_slice()),
+ serial_number: None,
+ token_code: None,
+ });
+
+ assert!(
+
query.contains("Policy=%7B%22Version%22%3A%222012-10-17%22%2C%22Statement%22%3A%5B%7B%22Effect%22%3A%22Allow%22%2C%22Action%22%3A%22s3%3AListBucket%22%2C%22Resource%22%3A%22*%22%2C%22Condition%22%3A%7B%22StringEquals%22%3A%7B%22s3%3Aprefix%22%3A%22a+b%22%7D%7D%7D%5D%7D")
+ );
+ assert!(query.contains(
+
"PolicyArns.member.1.arn=arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FReadOnlyAccess"
+ ));
+ assert!(query.contains(
+
"PolicyArns.member.2.arn=arn%3Aaws%3Aiam%3A%3A123456789012%3Apolicy%2FExamplePolicy"
+ ));
+ }
}
diff --git
a/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs
b/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs
index ac4c6e1..d3e73ae 100644
--- a/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs
+++ b/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs
@@ -37,6 +37,9 @@ pub struct AssumeRoleWithWebIdentityCredentialProvider {
role_arn: Option<String>,
role_session_name: Option<String>,
web_identity_token_file: Option<PathBuf>,
+ duration_seconds: Option<u32>,
+ policy: Option<String>,
+ policy_arns: Option<Vec<String>>,
// STS configuration
region: Option<String>,
@@ -55,6 +58,9 @@ impl AssumeRoleWithWebIdentityCredentialProvider {
role_arn: Some(role_arn),
role_session_name: None,
web_identity_token_file: Some(token_file),
+ duration_seconds: None,
+ policy: None,
+ policy_arns: None,
region: None,
use_regional_sts_endpoint: None,
}
@@ -78,6 +84,24 @@ impl AssumeRoleWithWebIdentityCredentialProvider {
self
}
+ /// Set the duration in seconds.
+ pub fn with_duration_seconds(mut self, seconds: u32) -> Self {
+ self.duration_seconds = Some(seconds);
+ self
+ }
+
+ /// Set the session policy.
+ pub fn with_policy(mut self, policy: String) -> Self {
+ self.policy = Some(policy);
+ self
+ }
+
+ /// Set the session policy ARNs.
+ pub fn with_policy_arns(mut self, policy_arns: Vec<String>) -> Self {
+ self.policy_arns = Some(policy_arns);
+ self
+ }
+
/// Set the region.
pub fn with_region(mut self, region: String) -> Self {
self.region = Some(region);
@@ -152,13 +176,29 @@ impl ProvideCredential for
AssumeRoleWithWebIdentityCredentialProvider {
.unwrap_or_else(|| "reqsign".to_string());
// Construct request to AWS STS Service.
- let query = Serializer::new(String::new())
- .append_pair("Action", "AssumeRoleWithWebIdentity")
- .append_pair("RoleArn", &role_arn)
- .append_pair("WebIdentityToken", &token)
- .append_pair("Version", "2011-06-15")
- .append_pair("RoleSessionName", &session_name)
- .finish();
+ let query = {
+ let mut serializer = Serializer::new(String::new());
+ serializer
+ .append_pair("Action", "AssumeRoleWithWebIdentity")
+ .append_pair("RoleArn", &role_arn)
+ .append_pair("WebIdentityToken", &token)
+ .append_pair("Version", "2011-06-15")
+ .append_pair("RoleSessionName", &session_name);
+
+ if let Some(duration_seconds) = self.duration_seconds {
+ serializer.append_pair("DurationSeconds",
&duration_seconds.to_string());
+ }
+ if let Some(policy) = self.policy.as_deref() {
+ serializer.append_pair("Policy", policy);
+ }
+ if let Some(policy_arns) = self.policy_arns.as_deref() {
+ for (idx, arn) in policy_arns.iter().enumerate() {
+
serializer.append_pair(&format!("PolicyArns.member.{}.arn", idx + 1), arn);
+ }
+ }
+
+ serializer.finish()
+ };
let url = format!("https://{endpoint}/?{query}");
let req = http::request::Request::builder()
.method("GET")
@@ -422,4 +462,72 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn
test_assume_role_with_web_identity_supports_policy_and_policy_arns() ->
Result<()> {
+ let _ = env_logger::builder().is_test(true).try_init();
+
+ let token_path = "/mock/token";
+ let raw_token = "header.payload+signature/\n";
+
+ let file_read = TestFileRead {
+ expected_path: token_path.to_string(),
+ content: raw_token.as_bytes().to_vec(),
+ };
+
+ let http_body = r#"<AssumeRoleWithWebIdentityResponse
xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
+ <AssumeRoleWithWebIdentityResult>
+ <Credentials>
+ <AccessKeyId>access_key_id</AccessKeyId>
+ <SecretAccessKey>secret_access_key</SecretAccessKey>
+ <SessionToken>session_token</SessionToken>
+ <Expiration>2124-05-25T11:45:17Z</Expiration>
+ </Credentials>
+ </AssumeRoleWithWebIdentityResult>
+</AssumeRoleWithWebIdentityResponse>"#;
+ let http_send = CaptureHttpSend::new(http_body);
+
+ let ctx = Context::new()
+ .with_file_read(file_read)
+ .with_http_send(http_send.clone())
+ .with_env(StaticEnv {
+ home_dir: None,
+ envs: HashMap::new(),
+ });
+
+ let policy =
r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"*","Condition":{"StringEquals":{"s3:prefix":"a
b"}}}]}"#;
+
+ let provider =
AssumeRoleWithWebIdentityCredentialProvider::with_config(
+ "arn:aws:iam::123456789012:role/test-role".to_string(),
+ token_path.into(),
+ )
+ .with_duration_seconds(900)
+ .with_policy(policy.to_string())
+ .with_policy_arns(vec![
+ "arn:aws:iam::aws:policy/ReadOnlyAccess".to_string(),
+ "arn:aws:iam::123456789012:policy/ExamplePolicy".to_string(),
+ ]);
+
+ let _ = provider
+ .provide_credential(&ctx)
+ .await?
+ .expect("credential must be loaded");
+
+ let recorded_uri = http_send
+ .uri()
+ .expect("http_send must capture outgoing uri");
+
+ assert!(recorded_uri.contains("DurationSeconds=900"));
+ assert!(
+
recorded_uri.contains("Policy=%7B%22Version%22%3A%222012-10-17%22%2C%22Statement%22%3A%5B%7B%22Effect%22%3A%22Allow%22%2C%22Action%22%3A%22s3%3AListBucket%22%2C%22Resource%22%3A%22*%22%2C%22Condition%22%3A%7B%22StringEquals%22%3A%7B%22s3%3Aprefix%22%3A%22a+b%22%7D%7D%7D%5D%7D")
+ );
+ assert!(recorded_uri.contains(
+
"PolicyArns.member.1.arn=arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FReadOnlyAccess"
+ ));
+ assert!(recorded_uri.contains(
+
"PolicyArns.member.2.arn=arn%3Aaws%3Aiam%3A%3A123456789012%3Apolicy%2FExamplePolicy"
+ ));
+
+ Ok(())
+ }
}