This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new d727503e4 object_store: Support Azure Fabric OAuth Provider (#6382)
d727503e4 is described below

commit d727503e4d52d40d7215a42398b02d58e964ed01
Author: Robin Lin <[email protected]>
AuthorDate: Sat Sep 21 18:15:26 2024 +0800

    object_store: Support Azure Fabric OAuth Provider (#6382)
    
    * Update Azure dependencies and add support for Fabric token authentication
    
    * Refactor Azure credential provider to support Fabric token authentication
    
    * Refactor Azure credential provider to remove unnecessary print statements and improve token handling
    
    * Bump object_store version to 0.11.0
    
    * Refactor Azure credential provider to remove unnecessary print statements and improve token handling
---
 object_store/src/azure/builder.rs    |  88 ++++++++++++++++++++++++++-
 object_store/src/azure/credential.rs | 114 ++++++++++++++++++++++++++++++++++-
 2 files changed, 199 insertions(+), 3 deletions(-)

diff --git a/object_store/src/azure/builder.rs b/object_store/src/azure/builder.rs
index 0208073e8..35cedeafc 100644
--- a/object_store/src/azure/builder.rs
+++ b/object_store/src/azure/builder.rs
@@ -17,8 +17,8 @@
 
 use crate::azure::client::{AzureClient, AzureConfig};
 use crate::azure::credential::{
-    AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, ImdsManagedIdentityProvider,
-    WorkloadIdentityOAuthProvider,
+    AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, FabricTokenOAuthProvider,
+    ImdsManagedIdentityProvider, WorkloadIdentityOAuthProvider,
 };
 use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE};
 use crate::client::TokenCredentialProvider;
@@ -172,6 +172,14 @@ pub struct MicrosoftAzureBuilder {
     use_fabric_endpoint: ConfigValue<bool>,
     /// When set to true, skips tagging objects
     disable_tagging: ConfigValue<bool>,
+    /// Fabric token service url
+    fabric_token_service_url: Option<String>,
+    /// Fabric workload host
+    fabric_workload_host: Option<String>,
+    /// Fabric session token
+    fabric_session_token: Option<String>,
+    /// Fabric cluster identifier
+    fabric_cluster_identifier: Option<String>,
 }
 
 /// Configuration keys for [`MicrosoftAzureBuilder`]
@@ -336,6 +344,34 @@ pub enum AzureConfigKey {
     /// - `disable_tagging`
     DisableTagging,
 
+    /// Fabric token service url
+    ///
+    /// Supported keys:
+    /// - `azure_fabric_token_service_url`
+    /// - `fabric_token_service_url`
+    FabricTokenServiceUrl,
+
+    /// Fabric workload host
+    ///
+    /// Supported keys:
+    /// - `azure_fabric_workload_host`
+    /// - `fabric_workload_host`
+    FabricWorkloadHost,
+
+    /// Fabric session token
+    ///
+    /// Supported keys:
+    /// - `azure_fabric_session_token`
+    /// - `fabric_session_token`
+    FabricSessionToken,
+
+    /// Fabric cluster identifier
+    ///
+    /// Supported keys:
+    /// - `azure_fabric_cluster_identifier`
+    /// - `fabric_cluster_identifier`
+    FabricClusterIdentifier,
+
     /// Client options
     Client(ClientConfigKey),
 }
@@ -361,6 +397,10 @@ impl AsRef<str> for AzureConfigKey {
             Self::SkipSignature => "azure_skip_signature",
             Self::ContainerName => "azure_container_name",
             Self::DisableTagging => "azure_disable_tagging",
+            Self::FabricTokenServiceUrl => "azure_fabric_token_service_url",
+            Self::FabricWorkloadHost => "azure_fabric_workload_host",
+            Self::FabricSessionToken => "azure_fabric_session_token",
+            Self::FabricClusterIdentifier => "azure_fabric_cluster_identifier",
             Self::Client(key) => key.as_ref(),
         }
     }
@@ -406,6 +446,14 @@ impl FromStr for AzureConfigKey {
             "azure_skip_signature" | "skip_signature" => Ok(Self::SkipSignature),
             "azure_container_name" | "container_name" => Ok(Self::ContainerName),
             "azure_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging),
+            "azure_fabric_token_service_url" | "fabric_token_service_url" => {
+                Ok(Self::FabricTokenServiceUrl)
+            }
+            "azure_fabric_workload_host" | "fabric_workload_host" => Ok(Self::FabricWorkloadHost),
+            "azure_fabric_session_token" | "fabric_session_token" => Ok(Self::FabricSessionToken),
+            "azure_fabric_cluster_identifier" | "fabric_cluster_identifier" => {
+                Ok(Self::FabricClusterIdentifier)
+            }
             // Backwards compatibility
             "azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)),
             _ => match s.strip_prefix("azure_").unwrap_or(s).parse() {
@@ -525,6 +573,14 @@ impl MicrosoftAzureBuilder {
             }
             AzureConfigKey::ContainerName => self.container_name = Some(value.into()),
             AzureConfigKey::DisableTagging => self.disable_tagging.parse(value),
+            AzureConfigKey::FabricTokenServiceUrl => {
+                self.fabric_token_service_url = Some(value.into())
+            }
+            AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host = Some(value.into()),
+            AzureConfigKey::FabricSessionToken => self.fabric_session_token = Some(value.into()),
+            AzureConfigKey::FabricClusterIdentifier => {
+                self.fabric_cluster_identifier = Some(value.into())
+            }
         };
         self
     }
@@ -561,6 +617,10 @@ impl MicrosoftAzureBuilder {
             AzureConfigKey::Client(key) => self.client_options.get_config_value(key),
             AzureConfigKey::ContainerName => self.container_name.clone(),
             AzureConfigKey::DisableTagging => Some(self.disable_tagging.to_string()),
+            AzureConfigKey::FabricTokenServiceUrl => self.fabric_token_service_url.clone(),
+            AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host.clone(),
+            AzureConfigKey::FabricSessionToken => self.fabric_session_token.clone(),
+            AzureConfigKey::FabricClusterIdentifier => self.fabric_cluster_identifier.clone(),
         }
     }
 
@@ -856,6 +916,30 @@ impl MicrosoftAzureBuilder {
 
             let credential = if let Some(credential) = self.credentials {
                 credential
+            } else if let (
+                Some(fabric_token_service_url),
+                Some(fabric_workload_host),
+                Some(fabric_session_token),
+                Some(fabric_cluster_identifier),
+            ) = (
+                &self.fabric_token_service_url,
+                &self.fabric_workload_host,
+                &self.fabric_session_token,
+                &self.fabric_cluster_identifier,
+            ) {
+                // This case should precede the bearer token case because it is more specific and will utilize the bearer token.
+                let fabric_credential = FabricTokenOAuthProvider::new(
+                    fabric_token_service_url,
+                    fabric_workload_host,
+                    fabric_session_token,
+                    fabric_cluster_identifier,
+                    self.bearer_token.clone(),
+                );
+                Arc::new(TokenCredentialProvider::new(
+                    fabric_credential,
+                    self.client_options.client()?,
+                    self.retry_config.clone(),
+                )) as _
             } else if let Some(bearer_token) = self.bearer_token {
                 static_creds(AzureCredential::BearerToken(bearer_token))
             } else if let Some(access_key) = self.access_key {
diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs
index 7808c7c4a..6b5fa19d1 100644
--- a/object_store/src/azure/credential.rs
+++ b/object_store/src/azure/credential.rs
@@ -22,7 +22,7 @@ use crate::client::{CredentialProvider, TokenProvider};
 use crate::util::hmac_sha256;
 use crate::RetryConfig;
 use async_trait::async_trait;
-use base64::prelude::BASE64_STANDARD;
+use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
 use base64::Engine;
 use chrono::{DateTime, SecondsFormat, Utc};
 use reqwest::header::{
@@ -51,10 +51,15 @@ pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-typ
 pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots");
 pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
 static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
+static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token");
+static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier");
+static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker");
+static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host");
 pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
 const CONTENT_TYPE_JSON: &str = "application/json";
 const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
 const MSI_API_VERSION: &str = "2019-08-01";
+const TOKEN_MIN_TTL: u64 = 300;
 
 /// OIDC scope used when interacting with OAuth2 APIs
 ///
@@ -934,6 +939,113 @@ impl AzureCliCredential {
     }
 }
 
+/// Encapsulates the logic to perform an OAuth token challenge for Fabric
+#[derive(Debug)]
+pub struct FabricTokenOAuthProvider {
+    fabric_token_service_url: String,
+    fabric_workload_host: String,
+    fabric_session_token: String,
+    fabric_cluster_identifier: String,
+    storage_access_token: Option<String>,
+    token_expiry: Option<u64>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Claims {
+    exp: u64,
+}
+
+impl FabricTokenOAuthProvider {
+    /// Create a new [`FabricTokenOAuthProvider`] for an azure backed store
+    pub fn new(
+        fabric_token_service_url: impl Into<String>,
+        fabric_workload_host: impl Into<String>,
+        fabric_session_token: impl Into<String>,
+        fabric_cluster_identifier: impl Into<String>,
+        storage_access_token: Option<String>,
+    ) -> Self {
+        let (storage_access_token, token_expiry) = match storage_access_token {
+            Some(token) => match Self::validate_and_get_expiry(&token) {
+                Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
+                    (Some(token), Some(expiry))
+                }
+                _ => (None, None),
+            },
+            None => (None, None),
+        };
+
+        Self {
+            fabric_token_service_url: fabric_token_service_url.into(),
+            fabric_workload_host: fabric_workload_host.into(),
+            fabric_session_token: fabric_session_token.into(),
+            fabric_cluster_identifier: fabric_cluster_identifier.into(),
+            storage_access_token,
+            token_expiry,
+        }
+    }
+
+    fn validate_and_get_expiry(token: &str) -> Option<u64> {
+        let payload = token.split('.').nth(1)?;
+        let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
+        let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
+        let claims: Claims = serde_json::from_str(decoded_str).ok()?;
+        Some(claims.exp)
+    }
+
+    fn get_current_timestamp() -> u64 {
+        SystemTime::now()
+            .duration_since(SystemTime::UNIX_EPOCH)
+            .map_or(0, |d| d.as_secs())
+    }
+}
+
+#[async_trait::async_trait]
+impl TokenProvider for FabricTokenOAuthProvider {
+    type Credential = AzureCredential;
+
+    /// Fetch a token
+    async fn fetch_token(
+        &self,
+        client: &Client,
+        retry: &RetryConfig,
+    ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
+        if let Some(storage_access_token) = &self.storage_access_token {
+            if let Some(expiry) = self.token_expiry {
+                let exp_in = expiry - Self::get_current_timestamp();
+                if exp_in > TOKEN_MIN_TTL {
+                    return Ok(TemporaryToken {
+                        token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
+                        expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
+                    });
+                }
+            }
+        }
+
+        let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)];
+        let access_token: String = client
+            .request(Method::GET, &self.fabric_token_service_url)
+            .header(&PARTNER_TOKEN, self.fabric_session_token.as_str())
+            .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str())
+            .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str())
+            .header(&PROXY_HOST, self.fabric_workload_host.as_str())
+            .query(&query_items)
+            .retryable(retry)
+            .idempotent(true)
+            .send()
+            .await
+            .context(TokenRequestSnafu)?
+            .text()
+            .await
+            .context(TokenResponseBodySnafu)?;
+        let exp_in = Self::validate_and_get_expiry(&access_token)
+            .map_or(3600, |expiry| expiry - Self::get_current_timestamp());
+        Ok(TemporaryToken {
+            token: Arc::new(AzureCredential::BearerToken(access_token)),
+            expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
+        })
+    }
+}
+
 #[async_trait]
 impl CredentialProvider for AzureCliCredential {
     type Credential = AzureCredential;

Reply via email to