This is an automated email from the ASF dual-hosted git repository.

xuanwo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/opendal-reqsign.git


The following commit(s) were added to refs/heads/main by this push:
     new c3b660e  feat(aws): Add session policy support for sts (#670)
c3b660e is described below

commit c3b660eb42934a133881cabea8c42be01917271f
Author: Xuanwo <[email protected]>
AuthorDate: Fri Dec 26 19:21:58 2025 +0800

    feat(aws): Add session policy support for sts (#670)
    
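    This PR adds STS session policy support: `AssumeRoleCredentialProvider`
    and `AssumeRoleWithWebIdentityCredentialProvider` can now send an inline
    session policy (`Policy`) and policy ARNs (`PolicyArns`). The web identity
    provider also gains a `DurationSeconds` option, and the `AssumeRole` query
    string is now built with `form_urlencoded`, so every parameter value is
    URL-encoded instead of being spliced into the URL raw.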
    
    ---
    
    **Parts of this PR were drafted with assistance from Codex (with
    `gpt-5.2`) and fully reviewed and edited by me. I take full
    responsibility for all changes.**
---
 .../aws-v4/src/provide_credential/assume_role.rs   | 144 ++++++++++++++++-----
 .../assume_role_with_web_identity.rs               | 122 ++++++++++++++++-
 2 files changed, 227 insertions(+), 39 deletions(-)

diff --git a/services/aws-v4/src/provide_credential/assume_role.rs b/services/aws-v4/src/provide_credential/assume_role.rs
index 6b6a616..9a47212 100644
--- a/services/aws-v4/src/provide_credential/assume_role.rs
+++ b/services/aws-v4/src/provide_credential/assume_role.rs
@@ -21,10 +21,10 @@ use crate::credential::Credential;
 use crate::provide_credential::utils::{parse_sts_error, sts_endpoint};
 use async_trait::async_trait;
 use bytes::Bytes;
+use form_urlencoded::Serializer;
 use quick_xml::de;
 use reqsign_core::{Context, Error, ProvideCredential, Result, Signer};
 use serde::Deserialize;
-use std::fmt::Write;
 
 /// AssumeRoleCredentialProvider will load credential via assume role.
 #[derive(Debug)]
@@ -35,6 +35,8 @@ pub struct AssumeRoleCredentialProvider {
     external_id: Option<String>,
     duration_seconds: Option<u32>,
     tags: Option<Vec<(String, String)>>,
+    policy: Option<String>,
+    policy_arns: Option<Vec<String>>,
 
     // MFA configuration
     serial_number: Option<String>,
@@ -57,6 +59,8 @@ impl AssumeRoleCredentialProvider {
             external_id: None,
             duration_seconds: Some(3600),
             tags: None,
+            policy: None,
+            policy_arns: None,
             serial_number: None,
             token_code: None,
             region: None,
@@ -83,6 +87,18 @@ impl AssumeRoleCredentialProvider {
         self
     }
 
+    /// Set an inline session policy, sent to STS as the `Policy` parameter.
+    pub fn with_policy(self, policy: String) -> Self {
+        let policy = Some(policy);
+        Self { policy, ..self }
+    }
+
+    /// Set policy ARNs, sent to STS as `PolicyArns.member.N.arn` parameters.
+    pub fn with_policy_arns(self, policy_arns: Vec<String>) -> Self {
+        let policy_arns = Some(policy_arns);
+        Self { policy_arns, ..self }
+    }
+
     /// Set the tags.
     pub fn with_tags(mut self, tags: Vec<(String, String)>) -> Self {
         self.tags = Some(tags);
@@ -146,37 +162,18 @@ impl ProvideCredential for AssumeRoleCredentialProvider {
         let endpoint = sts_endpoint(self.region.as_deref(), 
self.use_regional_sts_endpoint)
             .map_err(|e| e.with_context(format!("role_arn: {}", 
self.role_arn)))?;
 
-        // Construct request to AWS STS Service.
-        let mut url = format!(
-            
"https://{endpoint}/?Action=AssumeRole&RoleArn={}&Version=2011-06-15&RoleSessionName={}";,
-            self.role_arn, self.role_session_name
-        );
-        if let Some(external_id) = &self.external_id {
-            write!(url, "&ExternalId={external_id}")
-                .map_err(|e| Error::unexpected("failed to format 
URL").with_source(e))?;
-        }
-        if let Some(duration_seconds) = &self.duration_seconds {
-            write!(url, "&DurationSeconds={duration_seconds}")
-                .map_err(|e| Error::unexpected("failed to format 
URL").with_source(e))?;
-        }
-        if let Some(tags) = &self.tags {
-            for (idx, (key, value)) in tags.iter().enumerate() {
-                let tag_index = idx + 1;
-                write!(
-                    url,
-                    
"&Tags.member.{tag_index}.Key={key}&Tags.member.{tag_index}.Value={value}"
-                )
-                .map_err(|e| Error::unexpected("failed to format 
URL").with_source(e))?;
-            }
-        }
-        if let Some(serial_number) = &self.serial_number {
-            write!(url, "&SerialNumber={serial_number}")
-                .map_err(|e| Error::unexpected("failed to format 
URL").with_source(e))?;
-        }
-        if let Some(token_code) = &self.token_code {
-            write!(url, "&TokenCode={token_code}")
-                .map_err(|e| Error::unexpected("failed to format 
URL").with_source(e))?;
-        }
+        let query = build_assume_role_query(AssumeRoleQueryInput {
+            role_arn: &self.role_arn,
+            role_session_name: &self.role_session_name,
+            external_id: self.external_id.as_deref(),
+            duration_seconds: self.duration_seconds,
+            tags: self.tags.as_deref(),
+            policy: self.policy.as_deref(),
+            policy_arns: self.policy_arns.as_deref(),
+            serial_number: self.serial_number.as_deref(),
+            token_code: self.token_code.as_deref(),
+        });
+        let url = format!("https://{endpoint}/?{query}";);
 
         let req = http::request::Request::builder()
             .method("GET")
@@ -249,6 +246,59 @@ impl ProvideCredential for AssumeRoleCredentialProvider {
     }
 }
 
+/// Borrowed parameters for [`build_assume_role_query`].
+///
+/// A `None` field means the corresponding STS parameter is left out of the query.
+struct AssumeRoleQueryInput<'a> {
+    // Always sent.
+    role_arn: &'a str,
+    role_session_name: &'a str,
+    // Optional scalar parameters.
+    external_id: Option<&'a str>,
+    duration_seconds: Option<u32>,
+    serial_number: Option<&'a str>,
+    token_code: Option<&'a str>,
+    policy: Option<&'a str>,
+    // Optional lists, serialized as 1-based `member.N` entries.
+    policy_arns: Option<&'a [String]>,
+    tags: Option<&'a [(String, String)]>,
+}
+
+/// Build the query string for an STS `AssumeRole` GET request.
+///
+/// Every key and value goes through `form_urlencoded`, so ARNs, session policy
+/// JSON and tag values are escaped (spaces become `+`). Optional parameters are
+/// appended only when set; list parameters use 1-based `member.N` indices.
+fn build_assume_role_query(input: AssumeRoleQueryInput<'_>) -> String {
+    let mut query = Serializer::new(String::new());
+
+    // Required parameters, always in this order.
+    query.append_pair("Action", "AssumeRole");
+    query.append_pair("RoleArn", input.role_arn);
+    query.append_pair("Version", "2011-06-15");
+    query.append_pair("RoleSessionName", input.role_session_name);
+
+    if let Some(id) = input.external_id {
+        query.append_pair("ExternalId", id);
+    }
+    if let Some(seconds) = input.duration_seconds {
+        query.append_pair("DurationSeconds", &seconds.to_string());
+    }
+    if let Some(policy) = input.policy {
+        query.append_pair("Policy", policy);
+    }
+    // An absent list behaves like an empty one: nothing is appended.
+    for (n, arn) in (1usize..).zip(input.policy_arns.unwrap_or_default()) {
+        query.append_pair(&format!("PolicyArns.member.{n}.arn"), arn);
+    }
+    for (n, (key, value)) in (1usize..).zip(input.tags.unwrap_or_default()) {
+        query.append_pair(&format!("Tags.member.{n}.Key"), key);
+        query.append_pair(&format!("Tags.member.{n}.Value"), value);
+    }
+    if let Some(serial) = input.serial_number {
+        query.append_pair("SerialNumber", serial);
+    }
+    if let Some(code) = input.token_code {
+        query.append_pair("TokenCode", code);
+    }
+
+    query.finish()
+}
+
 #[derive(Default, Debug, Deserialize)]
 #[serde(default, rename_all = "PascalCase")]
 struct AssumeRoleResponse {
@@ -328,4 +378,34 @@ mod tests {
 
         Ok(())
     }
+
+    /// Session policy JSON and policy ARNs must be form-urlencoded, not spliced raw into the URL.
+    #[test]
+    fn test_assume_role_encodes_policy_and_policy_arns() {
+        let policy = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"*","Condition":{"StringEquals":{"s3:prefix":"a b"}}}]}"#;
+        let policy_arns = vec![
+            "arn:aws:iam::aws:policy/ReadOnlyAccess".to_string(),
+            "arn:aws:iam::123456789012:policy/ExamplePolicy".to_string(),
+        ];
+        let query = build_assume_role_query(AssumeRoleQueryInput {
+            role_arn: "arn:aws:iam::123456789012:role/test-role",
+            role_session_name: "reqsign",
+            external_id: None,
+            duration_seconds: Some(3600),
+            tags: None,
+            policy: Some(policy),
+            policy_arns: Some(policy_arns.as_slice()),
+            serial_number: None,
+            token_code: None,
+        });
+
+        // Required parameters come first, in a fixed order, with the role ARN escaped.
+        assert!(query.starts_with(
+            "Action=AssumeRole&RoleArn=arn%3Aaws%3Aiam%3A%3A123456789012%3Arole%2Ftest-role&Version=2011-06-15&RoleSessionName=reqsign&DurationSeconds=3600&"
+        ));
+        assert!(
+            query.contains("Policy=%7B%22Version%22%3A%222012-10-17%22%2C%22Statement%22%3A%5B%7B%22Effect%22%3A%22Allow%22%2C%22Action%22%3A%22s3%3AListBucket%22%2C%22Resource%22%3A%22*%22%2C%22Condition%22%3A%7B%22StringEquals%22%3A%7B%22s3%3Aprefix%22%3A%22a+b%22%7D%7D%7D%5D%7D")
+        );
+        assert!(query.contains(
+            "PolicyArns.member.1.arn=arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FReadOnlyAccess"
+        ));
+        assert!(query.contains(
+            "PolicyArns.member.2.arn=arn%3Aaws%3Aiam%3A%3A123456789012%3Apolicy%2FExamplePolicy"
+        ));
+    }
 }
diff --git a/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs b/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs
index ac4c6e1..d3e73ae 100644
--- a/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs
+++ b/services/aws-v4/src/provide_credential/assume_role_with_web_identity.rs
@@ -37,6 +37,9 @@ pub struct AssumeRoleWithWebIdentityCredentialProvider {
     role_arn: Option<String>,
     role_session_name: Option<String>,
     web_identity_token_file: Option<PathBuf>,
+    duration_seconds: Option<u32>,
+    policy: Option<String>,
+    policy_arns: Option<Vec<String>>,
 
     // STS configuration
     region: Option<String>,
@@ -55,6 +58,9 @@ impl AssumeRoleWithWebIdentityCredentialProvider {
             role_arn: Some(role_arn),
             role_session_name: None,
             web_identity_token_file: Some(token_file),
+            duration_seconds: None,
+            policy: None,
+            policy_arns: None,
             region: None,
             use_regional_sts_endpoint: None,
         }
@@ -78,6 +84,24 @@ impl AssumeRoleWithWebIdentityCredentialProvider {
         self
     }
 
+    /// Set the requested session duration, sent to STS as `DurationSeconds`.
+    pub fn with_duration_seconds(self, seconds: u32) -> Self {
+        let duration_seconds = Some(seconds);
+        Self { duration_seconds, ..self }
+    }
+
+    /// Set an inline session policy, sent to STS as the `Policy` parameter.
+    pub fn with_policy(self, policy: String) -> Self {
+        let policy = Some(policy);
+        Self { policy, ..self }
+    }
+
+    /// Set policy ARNs, sent to STS as `PolicyArns.member.N.arn` parameters.
+    pub fn with_policy_arns(self, policy_arns: Vec<String>) -> Self {
+        let policy_arns = Some(policy_arns);
+        Self { policy_arns, ..self }
+    }
+
     /// Set the region.
     pub fn with_region(mut self, region: String) -> Self {
         self.region = Some(region);
@@ -152,13 +176,29 @@ impl ProvideCredential for AssumeRoleWithWebIdentityCredentialProvider {
             .unwrap_or_else(|| "reqsign".to_string());
 
         // Construct request to AWS STS Service.
-        let query = Serializer::new(String::new())
-            .append_pair("Action", "AssumeRoleWithWebIdentity")
-            .append_pair("RoleArn", &role_arn)
-            .append_pair("WebIdentityToken", &token)
-            .append_pair("Version", "2011-06-15")
-            .append_pair("RoleSessionName", &session_name)
-            .finish();
+        let query = {
+            let mut serializer = Serializer::new(String::new());
+            serializer
+                .append_pair("Action", "AssumeRoleWithWebIdentity")
+                .append_pair("RoleArn", &role_arn)
+                .append_pair("WebIdentityToken", &token)
+                .append_pair("Version", "2011-06-15")
+                .append_pair("RoleSessionName", &session_name);
+
+            if let Some(duration_seconds) = self.duration_seconds {
+                serializer.append_pair("DurationSeconds", 
&duration_seconds.to_string());
+            }
+            if let Some(policy) = self.policy.as_deref() {
+                serializer.append_pair("Policy", policy);
+            }
+            if let Some(policy_arns) = self.policy_arns.as_deref() {
+                for (idx, arn) in policy_arns.iter().enumerate() {
+                    
serializer.append_pair(&format!("PolicyArns.member.{}.arn", idx + 1), arn);
+                }
+            }
+
+            serializer.finish()
+        };
         let url = format!("https://{endpoint}/?{query}";);
         let req = http::request::Request::builder()
             .method("GET")
@@ -422,4 +462,72 @@ mod tests {
 
         Ok(())
     }
+
+    /// DurationSeconds, Policy and PolicyArns must reach the STS request URI form-urlencoded.
+    #[tokio::test]
+    async fn test_assume_role_with_web_identity_supports_policy_and_policy_arns() -> Result<()> {
+        let _ = env_logger::builder().is_test(true).try_init();
+
+        let token_path = "/mock/token";
+        let raw_token = "header.payload+signature/\n";
+
+        let file_read = TestFileRead {
+            expected_path: token_path.to_string(),
+            content: raw_token.as_bytes().to_vec(),
+        };
+
+        let http_body = r#"<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
+  <AssumeRoleWithWebIdentityResult>
+    <Credentials>
+      <AccessKeyId>access_key_id</AccessKeyId>
+      <SecretAccessKey>secret_access_key</SecretAccessKey>
+      <SessionToken>session_token</SessionToken>
+      <Expiration>2124-05-25T11:45:17Z</Expiration>
+    </Credentials>
+  </AssumeRoleWithWebIdentityResult>
+</AssumeRoleWithWebIdentityResponse>"#;
+        let http_send = CaptureHttpSend::new(http_body);
+
+        let ctx = Context::new()
+            .with_file_read(file_read)
+            .with_http_send(http_send.clone())
+            .with_env(StaticEnv {
+                home_dir: None,
+                envs: HashMap::new(),
+            });
+
+        let policy = r#"{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":"s3:ListBucket","Resource":"*","Condition":{"StringEquals":{"s3:prefix":"a b"}}}]}"#;
+
+        let provider = AssumeRoleWithWebIdentityCredentialProvider::with_config(
+            "arn:aws:iam::123456789012:role/test-role".to_string(),
+            token_path.into(),
+        )
+        .with_duration_seconds(900)
+        .with_policy(policy.to_string())
+        .with_policy_arns(vec![
+            "arn:aws:iam::aws:policy/ReadOnlyAccess".to_string(),
+            "arn:aws:iam::123456789012:policy/ExamplePolicy".to_string(),
+        ]);
+
+        let _ = provider
+            .provide_credential(&ctx)
+            .await?
+            .expect("credential must be loaded");
+
+        let recorded_uri = http_send
+            .uri()
+            .expect("http_send must capture outgoing uri");
+
+        // The new optional parameters must not displace the required action.
+        assert!(recorded_uri.contains("Action=AssumeRoleWithWebIdentity"));
+        assert!(recorded_uri.contains("DurationSeconds=900"));
+        assert!(
+            recorded_uri.contains("Policy=%7B%22Version%22%3A%222012-10-17%22%2C%22Statement%22%3A%5B%7B%22Effect%22%3A%22Allow%22%2C%22Action%22%3A%22s3%3AListBucket%22%2C%22Resource%22%3A%22*%22%2C%22Condition%22%3A%7B%22StringEquals%22%3A%7B%22s3%3Aprefix%22%3A%22a+b%22%7D%7D%7D%5D%7D")
+        );
+        assert!(recorded_uri.contains(
+            "PolicyArns.member.1.arn=arn%3Aaws%3Aiam%3A%3Aaws%3Apolicy%2FReadOnlyAccess"
+        ));
+        assert!(recorded_uri.contains(
+            "PolicyArns.member.2.arn=arn%3Aaws%3Aiam%3A%3A123456789012%3Apolicy%2FExamplePolicy"
+        ));
+
+        Ok(())
+    }
 }

Reply via email to