This is an automated email from the ASF dual-hosted git repository. xuanwo pushed a commit to branch xuanwo/sts-session-policy in repository https://gitbox.apache.org/repos/asf/opendal-reqsign.git
commit 0281b28e9c7e1083ddb2c65ac4d46c31653bef4e Author: Xuanwo <[email protected]> AuthorDate: Fri Dec 26 18:56:19 2025 +0800 feat(aws): Add session policy support for sts --- .../aws-v4/src/provide_credential/assume_role.rs | 136 ++++++++++++++++----- .../assume_role_with_web_identity.rs | 122 ++++++++++++++++-- 2 files changed, 220 insertions(+), 38 deletions(-) diff --git a/services/aws-v4/src/provide_credential/assume_role.rs b/services/aws-v4/src/provide_credential/assume_role.rs index 6b6a616..89b5e1e 100644 --- a/services/aws-v4/src/provide_credential/assume_role.rs +++ b/services/aws-v4/src/provide_credential/assume_role.rs @@ -21,10 +21,10 @@ use crate::credential::Credential; use crate::provide_credential::utils::{parse_sts_error, sts_endpoint}; use async_trait::async_trait; use bytes::Bytes; +use form_urlencoded::Serializer; use quick_xml::de; use reqsign_core::{Context, Error, ProvideCredential, Result, Signer}; use serde::Deserialize; -use std::fmt::Write; /// AssumeRoleCredentialProvider will load credential via assume role. #[derive(Debug)] @@ -35,6 +35,8 @@ pub struct AssumeRoleCredentialProvider { external_id: Option<String>, duration_seconds: Option<u32>, tags: Option<Vec<(String, String)>>, + policy: Option<String>, + policy_arns: Option<Vec<String>>, // MFA configuration serial_number: Option<String>, @@ -57,6 +59,8 @@ impl AssumeRoleCredentialProvider { external_id: None, duration_seconds: Some(3600), tags: None, + policy: None, + policy_arns: None, serial_number: None, token_code: None, region: None, @@ -83,6 +87,18 @@ impl AssumeRoleCredentialProvider { self } + /// Set the session policy. + pub fn with_policy(mut self, policy: String) -> Self { + self.policy = Some(policy); + self + } + + /// Set the session policy ARNs. + pub fn with_policy_arns(mut self, policy_arns: Vec<String>) -> Self { + self.policy_arns = Some(policy_arns); + self + } + /// Set the tags. pub fn with_tags(mut self, tags: Vec<(String, String)>) -> Self { self.tags = Some(tags); @@ -146,37 +162,18 @@ impl ProvideCredential for AssumeRoleCredentialProvider { let endpoint = sts_endpoint(self.region.as_deref(), self.use_regional_sts_endpoint) .map_err(|e| e.with_context(format!("role_arn: {}", self.role_arn)))?; - // Construct request to AWS STS Service. - let mut url = format!( - "https://{endpoint}/?Action=AssumeRole&RoleArn={}&Version=2011-06-15&RoleSessionName={}", - self.role_arn, self.role_session_name + let query = build_assume_role_query( + &self.role_arn, + &self.role_session_name, + self.external_id.as_deref(), + self.duration_seconds, + self.tags.as_deref(), + self.policy.as_deref(), + self.policy_arns.as_deref(), + self.serial_number.as_deref(), + self.token_code.as_deref(), ); - if let Some(external_id) = &self.external_id { - write!(url, "&ExternalId={external_id}") - .map_err(|e| Error::unexpected("failed to format URL").with_source(e))?; - } - if let Some(duration_seconds) = &self.duration_seconds { - write!(url, "&DurationSeconds={duration_seconds}") - .map_err(|e| Error::unexpected("failed to format URL").with_source(e))?; - } - if let Some(tags) = &self.tags { - for (idx, (key, value)) in tags.iter().enumerate() { - let tag_index = idx + 1; - write!( - url, - "&Tags.member.{tag_index}.Key={key}&Tags.member.{tag_index}.Value={value}" - ) - .map_err(|e| Error::unexpected("failed to format URL").with_source(e))?; - } - } - if let Some(serial_number) = &self.serial_number { - write!(url, "&SerialNumber={serial_number}") - .map_err(|e| Error::unexpected("failed to format URL").with_source(e))?; - } - if let Some(token_code) = &self.token_code { - write!(url, "&TokenCode={token_code}") - .map_err(|e| Error::unexpected("failed to format URL").with_source(e))?; - } + let url = format!("https://{endpoint}/?{query}"); let req = http::request::Request::builder() .method("GET") @@ -249,6 +246,57 @@ impl ProvideCredential for AssumeRoleCredentialProvider { } } +fn build_assume_role_query( + role_arn: &str, + role_session_name: &str, + external_id: Option<&str>, + duration_seconds: Option<u32>, + tags: Option<&[(String, String)]>, + policy: Option<&str>, + policy_arns: Option<&[String]>, + serial_number: Option<&str>, + token_code: Option<&str>, +) -> String { + let mut serializer = Serializer::new(String::new()); + serializer + .append_pair("Action", "AssumeRole") + .append_pair("RoleArn", role_arn) + .append_pair("Version", "2011-06-15") + .append_pair("RoleSessionName", role_session_name); + + if let Some(external_id) = external_id { + serializer.append_pair("ExternalId", external_id); + } + if let Some(duration_seconds) = duration_seconds { + serializer.append_pair("DurationSeconds", &duration_seconds.to_string()); + } + if let Some(policy) = policy { + serializer.append_pair("Policy", policy); + } + if let Some(policy_arns) = policy_arns { + for (idx, arn) in policy_arns.iter().enumerate() { + let key = format!("PolicyArns.member.{}.arn", idx + 1); + serializer.append_pair(&key, arn); + } + } + if let Some(tags) = tags { + for (idx, (key, value)) in tags.iter().enumerate() { + let tag_index = idx + 1; + serializer + .append_pair(&format!("Tags.member.{tag_index}.Key"), key) + .append_pair(&format!("Tags.member.{tag_index}.Value"), value); + } + } + if let Some(serial_number) = serial_number { + serializer.append_pair("SerialNumber", serial_number); + } + if let Some(token_code) = token_code { + serializer.append_pair("TokenCode", token_code); + } + + serializer.finish() +} + #[derive(Default, Debug, Deserialize)] #[serde(default, rename_all = "PascalCase")] struct AssumeRoleResponse { @@ -328,4 +376,30 @@ mod tests { Ok(()) } + + #[test] + fn test_assume_role_encodes_policy_and_policy_arns() { + let policy = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"*","Condition":{"StringEquals":{"s3:prefix":"a b"}}}]}"#; + let policy_arns = vec![ + "arn:aws:iam::aws:policy/ReadOnlyAccess".to_string(), + "arn:aws:iam::123456789012:policy/ExamplePolicy".to_string(), + ]; + let query = build_assume_role_query( + "arn:aws:iam::123456789012:role/test-role", + "reqsign", + None, + Some(3600), + None, + Some(policy), + Some(policy_arns.as_slice()), + None, + None, + ); + + assert!( + query.contains("Policy=%7B%22Version%22%3A%222012-10-17%22%2C%22Statement%22%3A%5B%7B%22Effect%22%3A%22Allow%22%2C%22Action%22%3A%22s3%3AListBucket%22%2C%22Resource%22%3A%22*%22%2C%22Condition%22%3A%7B%22StringEquals%22%3A%7B%22s3%3Aprefix%22%3A%22a+b%22%7D%7D%7D%5D%7D") + ); + assert!(query.contains("PolicyArns.member.1.arn=arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FReadOnlyAccess")); + assert!(query.contains("PolicyArns.member.2.arn=arn%3Aaws%3Aiam%3A%3A123456789012%3Apolicy%2FExamplePolicy")); + } } diff --git a/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs b/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs index ac4c6e1..d3e73ae 100644 --- a/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs +++ b/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs @@ -37,6 +37,9 @@ pub struct AssumeRoleWithWebIdentityCredentialProvider { role_arn: Option<String>, role_session_name: Option<String>, web_identity_token_file: Option<PathBuf>, + duration_seconds: Option<u32>, + policy: Option<String>, + policy_arns: Option<Vec<String>>, // STS configuration region: Option<String>, @@ -55,6 +58,9 @@ impl AssumeRoleWithWebIdentityCredentialProvider { role_arn: Some(role_arn), role_session_name: None, web_identity_token_file: Some(token_file), + duration_seconds: None, + policy: None, + policy_arns: None, region: None, use_regional_sts_endpoint: None, } @@ -78,6 +84,24 @@ impl AssumeRoleWithWebIdentityCredentialProvider { self } + /// Set the duration in seconds. + pub fn with_duration_seconds(mut self, seconds: u32) -> Self { + self.duration_seconds = Some(seconds); + self + } + + /// Set the session policy. + pub fn with_policy(mut self, policy: String) -> Self { + self.policy = Some(policy); + self + } + + /// Set the session policy ARNs. + pub fn with_policy_arns(mut self, policy_arns: Vec<String>) -> Self { + self.policy_arns = Some(policy_arns); + self + } + /// Set the region. pub fn with_region(mut self, region: String) -> Self { self.region = Some(region); @@ -152,13 +176,29 @@ impl ProvideCredential for AssumeRoleWithWebIdentityCredentialProvider { .unwrap_or_else(|| "reqsign".to_string()); // Construct request to AWS STS Service. - let query = Serializer::new(String::new()) - .append_pair("Action", "AssumeRoleWithWebIdentity") - .append_pair("RoleArn", &role_arn) - .append_pair("WebIdentityToken", &token) - .append_pair("Version", "2011-06-15") - .append_pair("RoleSessionName", &session_name) - .finish(); + let query = { + let mut serializer = Serializer::new(String::new()); + serializer + .append_pair("Action", "AssumeRoleWithWebIdentity") + .append_pair("RoleArn", &role_arn) + .append_pair("WebIdentityToken", &token) + .append_pair("Version", "2011-06-15") + .append_pair("RoleSessionName", &session_name); + + if let Some(duration_seconds) = self.duration_seconds { + serializer.append_pair("DurationSeconds", &duration_seconds.to_string()); + } + if let Some(policy) = self.policy.as_deref() { + serializer.append_pair("Policy", policy); + } + if let Some(policy_arns) = self.policy_arns.as_deref() { + for (idx, arn) in policy_arns.iter().enumerate() { + serializer.append_pair(&format!("PolicyArns.member.{}.arn", idx + 1), arn); + } + } + + serializer.finish() + }; let url = format!("https://{endpoint}/?{query}"); let req = http::request::Request::builder() .method("GET") @@ -422,4 +462,72 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_assume_role_with_web_identity_supports_policy_and_policy_arns() -> Result<()> { + let _ = env_logger::builder().is_test(true).try_init(); + + let token_path = "/mock/token"; + let raw_token = "header.payload+signature/\n"; + + let file_read = TestFileRead { + expected_path: token_path.to_string(), + content: raw_token.as_bytes().to_vec(), + }; + + let http_body = r#"<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/"> + <AssumeRoleWithWebIdentityResult> + <Credentials> + <AccessKeyId>access_key_id</AccessKeyId> + <SecretAccessKey>secret_access_key</SecretAccessKey> + <SessionToken>session_token</SessionToken> + <Expiration>2124-05-25T11:45:17Z</Expiration> + </Credentials> + </AssumeRoleWithWebIdentityResult> +</AssumeRoleWithWebIdentityResponse>"#; + let http_send = CaptureHttpSend::new(http_body); + + let ctx = Context::new() + .with_file_read(file_read) + .with_http_send(http_send.clone()) + .with_env(StaticEnv { + home_dir: None, + envs: HashMap::new(), + }); + + let policy = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"*","Condition":{"StringEquals":{"s3:prefix":"a b"}}}]}"#; + + let provider = AssumeRoleWithWebIdentityCredentialProvider::with_config( + "arn:aws:iam::123456789012:role/test-role".to_string(), + token_path.into(), + ) + .with_duration_seconds(900) + .with_policy(policy.to_string()) + .with_policy_arns(vec![ + "arn:aws:iam::aws:policy/ReadOnlyAccess".to_string(), + "arn:aws:iam::123456789012:policy/ExamplePolicy".to_string(), + ]); + + let _ = provider + .provide_credential(&ctx) + .await? + .expect("credential must be loaded"); + + let recorded_uri = http_send + .uri() + .expect("http_send must capture outgoing uri"); + + assert!(recorded_uri.contains("DurationSeconds=900")); + assert!( + recorded_uri.contains("Policy=%7B%22Version%22%3A%222012-10-17%22%2C%22Statement%22%3A%5B%7B%22Effect%22%3A%22Allow%22%2C%22Action%22%3A%22s3%3AListBucket%22%2C%22Resource%22%3A%22*%22%2C%22Condition%22%3A%7B%22StringEquals%22%3A%7B%22s3%3Aprefix%22%3A%22a+b%22%7D%7D%7D%5D%7D") + ); + assert!(recorded_uri.contains( + "PolicyArns.member.1.arn=arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FReadOnlyAccess" + )); + assert!(recorded_uri.contains( + "PolicyArns.member.2.arn=arn%3Aaws%3Aiam%3A%3A123456789012%3Apolicy%2FExamplePolicy" + )); + + Ok(()) + } }
