This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 11d2fe390 Expose credential provider (#4235)
11d2fe390 is described below

commit 11d2fe390b7d3ba8c23ac33545dbf75933be9f8b
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed May 17 21:13:25 2023 +0100

    Expose credential provider (#4235)
---
 object_store/src/aws/mod.rs        | 159 ++++++++++++++++++++-----------------
 object_store/src/azure/mod.rs      |  25 +++++-
 object_store/src/client/mod.rs     |   2 +
 object_store/src/gcp/credential.rs |   1 +
 object_store/src/gcp/mod.rs        |  30 +++++--
 object_store/src/lib.rs            |   2 +-
 6 files changed, 137 insertions(+), 82 deletions(-)

diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index ddb9dc799..a10561ba6 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -47,9 +47,7 @@ use url::Url;
 
 pub use crate::aws::checksum::Checksum;
 use crate::aws::client::{S3Client, S3Config};
-use crate::aws::credential::{
-    AwsCredential, InstanceCredentialProvider, WebIdentityProvider,
-};
+use crate::aws::credential::{InstanceCredentialProvider, WebIdentityProvider};
 use crate::client::header::header_meta;
 use crate::client::{
     ClientConfigKey, CredentialProvider, StaticCredentialProvider,
@@ -85,7 +83,9 @@ const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.rem
 
 const STORE: &str = "S3";
 
-type AwsCredentialProvider = Arc<dyn CredentialProvider<Credential = AwsCredential>>;
+/// [`CredentialProvider`] for [`AmazonS3`]
+pub type AwsCredentialProvider = Arc<dyn CredentialProvider<Credential = AwsCredential>>;
+pub use credential::AwsCredential;
 
 /// Default metadata endpoint
 static METADATA_ENDPOINT: &str = "http://169.254.169.254";
@@ -209,6 +209,13 @@ impl std::fmt::Display for AmazonS3 {
     }
 }
 
+impl AmazonS3 {
+    /// Returns the [`AwsCredentialProvider`] used by [`AmazonS3`]
+    pub fn credentials(&self) -> &AwsCredentialProvider {
+        &self.client.config().credentials
+    }
+}
+
 #[async_trait]
 impl ObjectStore for AmazonS3 {
     async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
@@ -424,6 +431,8 @@ pub struct AmazonS3Builder {
     profile: Option<String>,
     /// Client options
     client_options: ClientOptions,
+    /// Credentials
+    credentials: Option<AwsCredentialProvider>,
 }
 
 /// Configuration keys for [`AmazonS3Builder`]
@@ -879,6 +888,12 @@ impl AmazonS3Builder {
         self
     }
 
+    /// Set the credential provider overriding any other options
+    pub fn with_credentials(mut self, credentials: AwsCredentialProvider) -> Self {
+        self.credentials = Some(credentials);
+        self
+    }
+
     /// Sets what protocol is allowed. If `allow_http` is :
     /// * false (default):  Only HTTPS are allowed
     /// * true:  HTTP and HTTPS are allowed
@@ -992,7 +1007,7 @@ impl AmazonS3Builder {
             self.parse_url(&url)?;
         }
 
-        let region = match (self.region.clone(), self.profile.clone()) {
+        let region = match (self.region, self.profile.clone()) {
             (Some(region), _) => Some(region),
             (None, Some(profile)) => profile_region(profile),
             (None, None) => None,
@@ -1002,76 +1017,74 @@ impl AmazonS3Builder {
         let region = region.context(MissingRegionSnafu)?;
         let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;
 
-        let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
-            (Some(key_id), Some(secret_key), token) => {
-                info!("Using Static credential provider");
-                let credential = AwsCredential {
-                    key_id,
-                    secret_key,
-                    token,
-                };
-                Arc::new(StaticCredentialProvider::new(credential)) as _
-            }
-            (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
-            (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
-            // TODO: Replace with `AmazonS3Builder::credentials_from_env`
-            _ => match (
-                std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"),
-                std::env::var("AWS_ROLE_ARN"),
-            ) {
-                (Ok(token_path), Ok(role_arn)) => {
-                    info!("Using WebIdentity credential provider");
-
-                    let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
-                        .unwrap_or_else(|_| "WebIdentitySession".to_string());
-
-                    let endpoint = format!("https://sts.{region}.amazonaws.com");
-
-                    // Disallow non-HTTPs requests
-                    let client = self
-                        .client_options
-                        .clone()
-                        .with_allow_http(false)
-                        .client()?;
-
-                    let token = WebIdentityProvider {
-                        token_path,
-                        session_name,
-                        role_arn,
-                        endpoint,
-                    };
-
-                    Arc::new(TokenCredentialProvider::new(
+        let credentials = if let Some(credentials) = self.credentials {
+            credentials
+        } else if self.access_key_id.is_some() || self.secret_access_key.is_some() {
+            match (self.access_key_id, self.secret_access_key, self.token) {
+                (Some(key_id), Some(secret_key), token) => {
+                    info!("Using Static credential provider");
+                    let credential = AwsCredential {
+                        key_id,
+                        secret_key,
                         token,
-                        client,
-                        self.retry_config.clone(),
-                    )) as _
+                    };
+                    Arc::new(StaticCredentialProvider::new(credential)) as _
                 }
-                _ => match self.profile {
-                    Some(profile) => {
-                        info!("Using profile \"{}\" credential provider", profile);
-                        profile_credentials(profile, region.clone())?
-                    }
-                    None => {
-                        info!("Using Instance credential provider");
-
-                        let token = InstanceCredentialProvider {
-                            cache: Default::default(),
-                            imdsv1_fallback: self.imdsv1_fallback.get()?,
-                            metadata_endpoint: self
-                                .metadata_endpoint
-                                .unwrap_or_else(|| METADATA_ENDPOINT.into()),
-                        };
-
-                        Arc::new(TokenCredentialProvider::new(
-                            token,
-                            // The instance metadata endpoint is access over HTTP
-                            self.client_options.clone().with_allow_http(true).client()?,
-                            self.retry_config.clone(),
-                        )) as _
-                    }
-                },
-            },
+                (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
+                (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
+                (None, None, _) => unreachable!(),
+            }
+        } else if let (Ok(token_path), Ok(role_arn)) = (
+            std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"),
+            std::env::var("AWS_ROLE_ARN"),
+        ) {
+            // TODO: Replace with `AmazonS3Builder::credentials_from_env`
+            info!("Using WebIdentity credential provider");
+
+            let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
+                .unwrap_or_else(|_| "WebIdentitySession".to_string());
+
+            let endpoint = format!("https://sts.{region}.amazonaws.com");
+
+            // Disallow non-HTTPs requests
+            let client = self
+                .client_options
+                .clone()
+                .with_allow_http(false)
+                .client()?;
+
+            let token = WebIdentityProvider {
+                token_path,
+                session_name,
+                role_arn,
+                endpoint,
+            };
+
+            Arc::new(TokenCredentialProvider::new(
+                token,
+                client,
+                self.retry_config.clone(),
+            )) as _
+        } else if let Some(profile) = self.profile {
+            info!("Using profile \"{}\" credential provider", profile);
+            profile_credentials(profile, region.clone())?
+        } else {
+            info!("Using Instance credential provider");
+
+            let token = InstanceCredentialProvider {
+                cache: Default::default(),
+                imdsv1_fallback: self.imdsv1_fallback.get()?,
+                metadata_endpoint: self
+                    .metadata_endpoint
+                    .unwrap_or_else(|| METADATA_ENDPOINT.into()),
+            };
+
+            Arc::new(TokenCredentialProvider::new(
+                token,
+                // The instance metadata endpoint is access over HTTP
+                self.client_options.clone().with_allow_http(true).client()?,
+                self.retry_config.clone(),
+            )) as _
         };
 
         let endpoint: String;
diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs
index 6dc14cfb5..069b033d1 100644
--- a/object_store/src/azure/mod.rs
+++ b/object_store/src/azure/mod.rs
@@ -48,7 +48,6 @@ use std::{collections::BTreeSet, str::FromStr};
 use tokio::io::AsyncWrite;
 use url::Url;
 
-use crate::azure::credential::AzureCredential;
 use crate::client::header::header_meta;
 use crate::client::{
     ClientConfigKey, CredentialProvider, StaticCredentialProvider,
@@ -60,7 +59,10 @@ pub use credential::authority_hosts;
 mod client;
 mod credential;
 
-type AzureCredentialProvider = Arc<dyn CredentialProvider<Credential = AzureCredential>>;
+/// [`CredentialProvider`] for [`MicrosoftAzure`]
+pub type AzureCredentialProvider =
+    Arc<dyn CredentialProvider<Credential = AzureCredential>>;
+pub use credential::AzureCredential;
 
 const STORE: &str = "MicrosoftAzure";
 
@@ -153,6 +155,13 @@ pub struct MicrosoftAzure {
     client: Arc<client::AzureClient>,
 }
 
+impl MicrosoftAzure {
+    /// Returns the [`AzureCredentialProvider`] used by [`MicrosoftAzure`]
+    pub fn credentials(&self) -> &AzureCredentialProvider {
+        &self.client.config().credentials
+    }
+}
+
 impl std::fmt::Display for MicrosoftAzure {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(
@@ -374,6 +383,8 @@ pub struct MicrosoftAzureBuilder {
     retry_config: RetryConfig,
     /// Client options
     client_options: ClientOptions,
+    /// Credentials
+    credentials: Option<AzureCredentialProvider>,
 }
 
 /// Configuration keys for [`MicrosoftAzureBuilder`]
@@ -840,6 +851,12 @@ impl MicrosoftAzureBuilder {
         self
     }
 
+    /// Set the credential provider overriding any other options
+    pub fn with_credentials(mut self, credentials: AzureCredentialProvider) -> Self {
+        self.credentials = Some(credentials);
+        self
+    }
+
     /// Set if the Azure emulator should be used (defaults to false)
     pub fn with_use_emulator(mut self, use_emulator: bool) -> Self {
         self.use_emulator = use_emulator.into();
@@ -937,7 +954,9 @@ impl MicrosoftAzureBuilder {
             let url = Url::parse(&account_url)
                 .context(UnableToParseUrlSnafu { url: account_url })?;
 
-            let credential = if let Some(bearer_token) = self.bearer_token {
+            let credential = if let Some(credential) = self.credentials {
+                credential
+            } else if let Some(bearer_token) = self.bearer_token {
                 static_creds(AzureCredential::BearerToken(bearer_token))
             } else if let Some(access_key) = self.access_key {
                 static_creds(AzureCredential::AccessKey(access_key))
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index 292e4678f..8c2357699 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -509,8 +509,10 @@ impl GetOptionsExt for RequestBuilder {
 /// Provides credentials for use when signing requests
 #[async_trait]
 pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
+    /// The type of credential returned by this provider
     type Credential;
 
+    /// Return a credential
     async fn get_credential(&self) -> Result<Arc<Self::Credential>>;
 }
 
diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs
index ad12855e1..205b80594 100644
--- a/object_store/src/gcp/credential.rs
+++ b/object_store/src/gcp/credential.rs
@@ -82,6 +82,7 @@ impl From<Error> for crate::Error {
     }
 }
 
+/// A Google Cloud Storage Credential
 #[derive(Debug, Eq, PartialEq)]
 pub struct GcpCredential {
     /// An HTTP bearer token
diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs
index 6813bbf6e..21ba1588f 100644
--- a/object_store/src/gcp/mod.rs
+++ b/object_store/src/gcp/mod.rs
@@ -52,7 +52,6 @@ use crate::client::{
     ClientConfigKey, CredentialProvider, GetOptionsExt, StaticCredentialProvider,
     TokenCredentialProvider,
 };
-use crate::gcp::credential::{application_default_credentials, GcpCredential};
 use crate::{
     multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
     path::{Path, DELIMITER},
@@ -61,15 +60,18 @@ use crate::{
     ObjectStore, Result, RetryConfig,
 };
 
-use self::credential::{
-    default_gcs_base_url, InstanceCredentialProvider, ServiceAccountCredentials,
+use credential::{
+    application_default_credentials, default_gcs_base_url, InstanceCredentialProvider,
+    ServiceAccountCredentials,
 };
 
 mod credential;
 
 const STORE: &str = "GCS";
 
-type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
+/// [`CredentialProvider`] for [`GoogleCloudStorage`]
+pub type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
+pub use credential::GcpCredential;
 
 #[derive(Debug, Snafu)]
 enum Error {
@@ -205,6 +207,13 @@ impl std::fmt::Display for GoogleCloudStorage {
     }
 }
 
+impl GoogleCloudStorage {
+    /// Returns the [`GcpCredentialProvider`] used by [`GoogleCloudStorage`]
+    pub fn credentials(&self) -> &GcpCredentialProvider {
+        &self.client.credentials
+    }
+}
+
 #[derive(Debug)]
 struct GoogleCloudStorageClient {
     client: Client,
@@ -696,6 +705,8 @@ pub struct GoogleCloudStorageBuilder {
     retry_config: RetryConfig,
     /// Client options
     client_options: ClientOptions,
+    /// Credentials
+    credentials: Option<GcpCredentialProvider>,
 }
 
 /// Configuration keys for [`GoogleCloudStorageBuilder`]
@@ -794,6 +805,7 @@ impl Default for GoogleCloudStorageBuilder {
             retry_config: Default::default(),
             client_options: ClientOptions::new().with_allow_http(true),
             url: None,
+            credentials: None,
         }
     }
 }
@@ -1006,6 +1018,12 @@ impl GoogleCloudStorageBuilder {
         self
     }
 
+    /// Set the credential provider overriding any other options
+    pub fn with_credentials(mut self, credentials: GcpCredentialProvider) -> Self {
+        self.credentials = Some(credentials);
+        self
+    }
+
     /// Set the retry configuration
     pub fn with_retry(mut self, retry_config: RetryConfig) -> Self {
         self.retry_config = retry_config;
@@ -1072,7 +1090,9 @@ impl GoogleCloudStorageBuilder {
         let scope = "https://www.googleapis.com/auth/devstorage.full_control";
         let audience = "https://www.googleapis.com/oauth2/v4/token";
 
-        let credentials = if disable_oauth {
+        let credentials = if let Some(credentials) = self.credentials {
+            credentials
+        } else if disable_oauth {
             Arc::new(StaticCredentialProvider::new(GcpCredential {
                 bearer: "".to_string(),
             })) as _
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index 0f3ed809e..7116a8732 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -245,7 +245,7 @@ pub mod throttle;
 mod client;
 
 #[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))]
-pub use client::{backoff::BackoffConfig, retry::RetryConfig};
+pub use client::{backoff::BackoffConfig, retry::RetryConfig, CredentialProvider};
 
 #[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))]
 mod config;

Reply via email to