roeap commented on code in PR #3581:
URL: https://github.com/apache/arrow-rs/pull/3581#discussion_r1084405653
##########
object_store/src/azure/credential.rs:
##########
@@ -294,45 +306,241 @@ impl ClientSecretOAuthProvider {
pub fn new(
client_id: String,
client_secret: String,
- tenant_id: String,
+ tenant_id: impl AsRef<str>,
authority_host: Option<String>,
) -> Self {
let authority_host = authority_host
.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
Self {
- scope: "https://storage.azure.com/.default".to_owned(),
- token_url: format!("{}/{}/oauth2/v2.0/token", authority_host,
tenant_id),
+ token_url: format!(
+ "{}/{}/oauth2/v2.0/token",
+ authority_host,
+ tenant_id.as_ref()
+ ),
client_id,
client_secret,
cache: TokenCache::default(),
}
}
+ /// Fetch a fresh token
+ async fn fetch_token_inner(
+ &self,
+ client: &Client,
+ retry: &RetryConfig,
+ ) -> Result<TemporaryToken<String>> {
+ let response: TokenResponse = client
+ .request(Method::POST, &self.token_url)
+ .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
+ .form(&[
+ ("client_id", self.client_id.as_str()),
+ ("client_secret", self.client_secret.as_str()),
+ ("scope", AZURE_STORAGE_SCOPE),
+ ("grant_type", "client_credentials"),
+ ])
+ .send_retry(retry)
+ .await
+ .context(TokenRequestSnafu)?
+ .json()
+ .await
+ .context(TokenResponseBodySnafu)?;
+
+ let token = TemporaryToken {
+ token: response.access_token,
+ expiry: Instant::now() + Duration::from_secs(response.expires_in),
+ };
+
+ Ok(token)
+ }
+}
+
+#[async_trait::async_trait]
+impl TokenCredential for ClientSecretOAuthProvider {
/// Fetch a token
- pub async fn fetch_token(
+ async fn fetch_token(&self, client: &Client, retry: &RetryConfig) ->
Result<String> {
+ self.cache
Review Comment:
makes absolute sense - done.
##########
object_store/src/azure/credential.rs:
##########
@@ -294,45 +306,241 @@ impl ClientSecretOAuthProvider {
pub fn new(
client_id: String,
client_secret: String,
- tenant_id: String,
+ tenant_id: impl AsRef<str>,
authority_host: Option<String>,
) -> Self {
let authority_host = authority_host
.unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
Self {
- scope: "https://storage.azure.com/.default".to_owned(),
- token_url: format!("{}/{}/oauth2/v2.0/token", authority_host,
tenant_id),
+ token_url: format!(
+ "{}/{}/oauth2/v2.0/token",
+ authority_host,
+ tenant_id.as_ref()
+ ),
client_id,
client_secret,
cache: TokenCache::default(),
}
}
+ /// Fetch a fresh token
+ async fn fetch_token_inner(
+ &self,
+ client: &Client,
+ retry: &RetryConfig,
+ ) -> Result<TemporaryToken<String>> {
+ let response: TokenResponse = client
+ .request(Method::POST, &self.token_url)
+ .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
+ .form(&[
+ ("client_id", self.client_id.as_str()),
+ ("client_secret", self.client_secret.as_str()),
+ ("scope", AZURE_STORAGE_SCOPE),
+ ("grant_type", "client_credentials"),
+ ])
+ .send_retry(retry)
+ .await
+ .context(TokenRequestSnafu)?
+ .json()
+ .await
+ .context(TokenResponseBodySnafu)?;
+
+ let token = TemporaryToken {
+ token: response.access_token,
+ expiry: Instant::now() + Duration::from_secs(response.expires_in),
+ };
+
+ Ok(token)
+ }
+}
+
+#[async_trait::async_trait]
+impl TokenCredential for ClientSecretOAuthProvider {
/// Fetch a token
- pub async fn fetch_token(
+ async fn fetch_token(&self, client: &Client, retry: &RetryConfig) ->
Result<String> {
+ self.cache
+ .get_or_insert_with(|| self.fetch_token_inner(client, retry))
+ .await
+ }
+}
+
+fn expires_in_string<'de, D>(deserializer: D) -> std::result::Result<u64,
D::Error>
+where
+ D: serde::de::Deserializer<'de>,
+{
+ let v = String::deserialize(deserializer)?;
+ v.parse::<u64>().map_err(serde::de::Error::custom)
+}
+
+// NOTE: expires_on is a String version of unix epoch time, not an integer.
+//
<https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
+#[derive(Debug, Clone, Deserialize)]
+struct MsiTokenResponse {
+ pub access_token: String,
+ #[serde(deserialize_with = "expires_in_string")]
+ pub expires_in: u64,
+}
+
+/// Attempts authentication using a managed identity that has been assigned to
the deployment environment.
+///
+/// This authentication type works in Azure VMs, App Service and Azure
Functions applications, as well as the Azure Cloud Shell
+///
<https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
+#[derive(Debug)]
+pub struct ImdsManagedIdentityOAuthProvider {
+ msi_endpoint: String,
+ client_id: Option<String>,
+ object_id: Option<String>,
+ msi_res_id: Option<String>,
+ cache: TokenCache<String>,
+}
+
+impl ImdsManagedIdentityOAuthProvider {
+ /// Create a new [`ImdsManagedIdentityOAuthProvider`] for an azure backed
store
+ pub fn new(
+ client_id: Option<String>,
+ object_id: Option<String>,
+ msi_res_id: Option<String>,
+ msi_endpoint: Option<String>,
+ ) -> Self {
+ let msi_endpoint = msi_endpoint.unwrap_or_else(|| {
+ "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()
+ });
+
+ Self {
+ msi_endpoint,
+ client_id,
+ object_id,
+ msi_res_id,
+ cache: TokenCache::default(),
+ }
+ }
+
+ /// Fetch a fresh token
+ async fn fetch_token_inner(
&self,
client: &Client,
retry: &RetryConfig,
- ) -> Result<String> {
+ ) -> Result<TemporaryToken<String>> {
+ let mut query_items = vec![
+ ("api-version", MSI_API_VERSION),
+ ("resource", AZURE_STORAGE_SCOPE),
+ ];
+
+ match (
+ self.object_id.as_ref(),
+ self.client_id.as_ref(),
+ self.msi_res_id.as_ref(),
+ ) {
+ (Some(object_id), None, None) => query_items.push(("object_id",
object_id)),
+ (None, Some(client_id), None) => query_items.push(("client_id",
client_id)),
+ (None, None, Some(msi_res_id)) => {
+ query_items.push(("msi_res_id", msi_res_id))
+ }
+ _ => (),
+ }
+
+ let mut builder = client
+ .request(Method::GET, &self.msi_endpoint)
+ .header("metadata", "true")
+ .query(&query_items);
+
+ if let Ok(val) = std::env::var(MSI_SECRET_ENV_KEY) {
Review Comment:
I was not sure about this value. From what I understood, it's [rotated at
runtime](https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Cdotnet#rest-endpoint-reference).
##########
object_store/src/azure/credential.rs:
##########
@@ -350,3 +558,131 @@ impl ClientSecretOAuthProvider {
Ok(token)
}
}
+
+#[async_trait::async_trait]
+impl TokenCredential for WorkloadIdentityOAuthProvider {
+ /// Fetch a token
+ async fn fetch_token(&self, client: &Client, retry: &RetryConfig) ->
Result<String> {
+ self.cache
+ .get_or_insert_with(|| self.fetch_token_inner(client, retry))
+ .await
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::client::mock_server::MockServer;
+ use crate::local::LocalFileSystem;
+ use crate::ObjectStore;
+ use hyper::{Body, Response};
+ use reqwest::{Client, Method};
+ use tempfile::TempDir;
+
+ #[tokio::test]
+ async fn test_managed_identity() {
+ let server = MockServer::new();
+
+ std::env::set_var(MSI_SECRET_ENV_KEY, "env-secret");
+
+ let endpoint = server.url();
+ let client = Client::new();
+ let retry_config = RetryConfig::default();
+
+ // Test IMDS
+ server.push_fn(|req| {
+ assert_eq!(req.uri().path(), "/metadata/identity/oauth2/token");
+
assert!(req.uri().query().unwrap().contains("client_id=client_id"));
+ assert_eq!(req.method(), &Method::GET);
+ let t = req
+ .headers()
+ .get("x-identity-header")
+ .unwrap()
+ .to_str()
+ .unwrap();
+ assert_eq!(t, "env-secret");
+ let t = req.headers().get("metadata").unwrap().to_str().unwrap();
+ assert_eq!(t, "true");
+ Response::new(Body::from(
+ r#"
+ {
+ "access_token": "TOKEN",
+ "refresh_token": "",
+ "expires_in": "3599",
+ "expires_on": "1506484173",
+ "not_before": "1506480273",
+ "resource": "https://management.azure.com/",
+ "token_type": "Bearer"
+ }
+ "#,
+ ))
+ });
+
+ let credential = ImdsManagedIdentityOAuthProvider::new(
+ Some("client_id".into()),
+ None,
+ None,
+ Some(format!("{}/metadata/identity/oauth2/token", endpoint)),
+ );
+
+ let token = credential
+ .fetch_token(&client, &retry_config)
+ .await
+ .unwrap();
+
+ assert_eq!(&token, "TOKEN");
+ }
+
+ #[tokio::test]
+ async fn test_workload_identity() {
+ let server = MockServer::new();
+
+ let root = TempDir::new().unwrap();
+ let fs = LocalFileSystem::new_with_prefix(root.path()).unwrap();
+ let tenant = "tenant";
+ let tokenfile = root.path().join("tokenfile");
+ fs.put(
+ &crate::path::Path::from("tokenfile"),
+ bytes::Bytes::from("federated-token"),
+ )
+ .await
+ .unwrap();
Review Comment:
good point :)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]