roeap commented on code in PR #3581:
URL: https://github.com/apache/arrow-rs/pull/3581#discussion_r1084447481


##########
object_store/src/azure/credential.rs:
##########
@@ -294,45 +306,241 @@ impl ClientSecretOAuthProvider {
     pub fn new(
         client_id: String,
         client_secret: String,
-        tenant_id: String,
+        tenant_id: impl AsRef<str>,
         authority_host: Option<String>,
     ) -> Self {
         let authority_host = authority_host
             .unwrap_or_else(|| authority_hosts::AZURE_PUBLIC_CLOUD.to_owned());
 
         Self {
-            scope: "https://storage.azure.com/.default".to_owned(),
-            token_url: format!("{}/{}/oauth2/v2.0/token", authority_host, 
tenant_id),
+            token_url: format!(
+                "{}/{}/oauth2/v2.0/token",
+                authority_host,
+                tenant_id.as_ref()
+            ),
             client_id,
             client_secret,
             cache: TokenCache::default(),
         }
     }
 
+    /// Fetch a fresh token
+    async fn fetch_token_inner(
+        &self,
+        client: &Client,
+        retry: &RetryConfig,
+    ) -> Result<TemporaryToken<String>> {
+        let response: TokenResponse = client
+            .request(Method::POST, &self.token_url)
+            .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
+            .form(&[
+                ("client_id", self.client_id.as_str()),
+                ("client_secret", self.client_secret.as_str()),
+                ("scope", AZURE_STORAGE_SCOPE),
+                ("grant_type", "client_credentials"),
+            ])
+            .send_retry(retry)
+            .await
+            .context(TokenRequestSnafu)?
+            .json()
+            .await
+            .context(TokenResponseBodySnafu)?;
+
+        let token = TemporaryToken {
+            token: response.access_token,
+            expiry: Instant::now() + Duration::from_secs(response.expires_in),
+        };
+
+        Ok(token)
+    }
+}
+
+#[async_trait::async_trait]
+impl TokenCredential for ClientSecretOAuthProvider {
     /// Fetch a token
-    pub async fn fetch_token(
+    async fn fetch_token(&self, client: &Client, retry: &RetryConfig) -> 
Result<String> {
+        self.cache
+            .get_or_insert_with(|| self.fetch_token_inner(client, retry))
+            .await
+    }
+}
+
+fn expires_in_string<'de, D>(deserializer: D) -> std::result::Result<u64, 
D::Error>
+where
+    D: serde::de::Deserializer<'de>,
+{
+    let v = String::deserialize(deserializer)?;
+    v.parse::<u64>().map_err(serde::de::Error::custom)
+}
+
+// NOTE: expires_on is a String version of unix epoch time, not an integer.
+// 
<https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
+#[derive(Debug, Clone, Deserialize)]
+struct MsiTokenResponse {
+    pub access_token: String,
+    #[serde(deserialize_with = "expires_in_string")]
+    pub expires_in: u64,
+}
+
+/// Attempts authentication using a managed identity that has been assigned to 
the deployment environment.
+///
+/// This authentication type works in Azure VMs, App Service and Azure 
Functions applications, as well as the Azure Cloud Shell
+/// 
<https://learn.microsoft.com/en-gb/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http>
+#[derive(Debug)]
+pub struct ImdsManagedIdentityOAuthProvider {
+    msi_endpoint: String,
+    client_id: Option<String>,
+    object_id: Option<String>,
+    msi_res_id: Option<String>,
+    cache: TokenCache<String>,
+}
+
+impl ImdsManagedIdentityOAuthProvider {
+    /// Create a new [`ImdsManagedIdentityOAuthProvider`] for an azure backed 
store
+    pub fn new(
+        client_id: Option<String>,
+        object_id: Option<String>,
+        msi_res_id: Option<String>,
+        msi_endpoint: Option<String>,
+    ) -> Self {
+        let msi_endpoint = msi_endpoint.unwrap_or_else(|| {
+            "http://169.254.169.254/metadata/identity/oauth2/token".to_owned()
+        });
+
+        Self {
+            msi_endpoint,
+            client_id,
+            object_id,
+            msi_res_id,
+            cache: TokenCache::default(),
+        }
+    }
+
+    /// Fetch a fresh token
+    async fn fetch_token_inner(
         &self,
         client: &Client,
         retry: &RetryConfig,
-    ) -> Result<String> {
+    ) -> Result<TemporaryToken<String>> {
+        let mut query_items = vec![
+            ("api-version", MSI_API_VERSION),
+            ("resource", AZURE_STORAGE_SCOPE),
+        ];
+
+        match (
+            self.object_id.as_ref(),
+            self.client_id.as_ref(),
+            self.msi_res_id.as_ref(),
+        ) {
+            (Some(object_id), None, None) => query_items.push(("object_id", 
object_id)),
+            (None, Some(client_id), None) => query_items.push(("client_id", 
client_id)),
+            (None, None, Some(msi_res_id)) => {
+                query_items.push(("msi_res_id", msi_res_id))
+            }
+            _ => (),

Review Comment:
   True — I went with a simple precedence hierarchy instead. Since object_id and msi_res_id
are more specific (they are only used by managed identity) and are less likely to just "end
up" in the environment, I gave them precedence over client_id.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to