This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 11d2fe390 Expose credential provider (#4235)
11d2fe390 is described below
commit 11d2fe390b7d3ba8c23ac33545dbf75933be9f8b
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed May 17 21:13:25 2023 +0100
Expose credential provider (#4235)
---
object_store/src/aws/mod.rs | 159 ++++++++++++++++++++-----------------
object_store/src/azure/mod.rs | 25 +++++-
object_store/src/client/mod.rs | 2 +
object_store/src/gcp/credential.rs | 1 +
object_store/src/gcp/mod.rs | 30 +++++--
object_store/src/lib.rs | 2 +-
6 files changed, 137 insertions(+), 82 deletions(-)
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index ddb9dc799..a10561ba6 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -47,9 +47,7 @@ use url::Url;
pub use crate::aws::checksum::Checksum;
use crate::aws::client::{S3Client, S3Config};
-use crate::aws::credential::{
- AwsCredential, InstanceCredentialProvider, WebIdentityProvider,
-};
+use crate::aws::credential::{InstanceCredentialProvider, WebIdentityProvider};
use crate::client::header::header_meta;
use crate::client::{
ClientConfigKey, CredentialProvider, StaticCredentialProvider,
@@ -85,7 +83,9 @@ const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.rem
const STORE: &str = "S3";
-type AwsCredentialProvider = Arc<dyn CredentialProvider<Credential = AwsCredential>>;
+/// [`CredentialProvider`] for [`AmazonS3`]
+pub type AwsCredentialProvider = Arc<dyn CredentialProvider<Credential = AwsCredential>>;
+pub use credential::AwsCredential;
/// Default metadata endpoint
static METADATA_ENDPOINT: &str = "http://169.254.169.254";
@@ -209,6 +209,13 @@ impl std::fmt::Display for AmazonS3 {
}
}
+impl AmazonS3 {
+ /// Returns the [`AwsCredentialProvider`] used by [`AmazonS3`]
+ pub fn credentials(&self) -> &AwsCredentialProvider {
+ &self.client.config().credentials
+ }
+}
+
#[async_trait]
impl ObjectStore for AmazonS3 {
async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
@@ -424,6 +431,8 @@ pub struct AmazonS3Builder {
profile: Option<String>,
/// Client options
client_options: ClientOptions,
+ /// Credentials
+ credentials: Option<AwsCredentialProvider>,
}
/// Configuration keys for [`AmazonS3Builder`]
@@ -879,6 +888,12 @@ impl AmazonS3Builder {
self
}
+ /// Set the credential provider overriding any other options
+ pub fn with_credentials(mut self, credentials: AwsCredentialProvider) -> Self {
+ self.credentials = Some(credentials);
+ self
+ }
+
/// Sets what protocol is allowed. If `allow_http` is :
/// * false (default): Only HTTPS are allowed
/// * true: HTTP and HTTPS are allowed
@@ -992,7 +1007,7 @@ impl AmazonS3Builder {
self.parse_url(&url)?;
}
- let region = match (self.region.clone(), self.profile.clone()) {
+ let region = match (self.region, self.profile.clone()) {
(Some(region), _) => Some(region),
(None, Some(profile)) => profile_region(profile),
(None, None) => None,
@@ -1002,76 +1017,74 @@ impl AmazonS3Builder {
let region = region.context(MissingRegionSnafu)?;
let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;
- let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
- (Some(key_id), Some(secret_key), token) => {
- info!("Using Static credential provider");
- let credential = AwsCredential {
- key_id,
- secret_key,
- token,
- };
- Arc::new(StaticCredentialProvider::new(credential)) as _
- }
- (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
- (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
- // TODO: Replace with `AmazonS3Builder::credentials_from_env`
- _ => match (
- std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"),
- std::env::var("AWS_ROLE_ARN"),
- ) {
- (Ok(token_path), Ok(role_arn)) => {
- info!("Using WebIdentity credential provider");
-
- let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
- .unwrap_or_else(|_| "WebIdentitySession".to_string());
-
- let endpoint = format!("https://sts.{region}.amazonaws.com");
-
- // Disallow non-HTTPs requests
- let client = self
- .client_options
- .clone()
- .with_allow_http(false)
- .client()?;
-
- let token = WebIdentityProvider {
- token_path,
- session_name,
- role_arn,
- endpoint,
- };
-
- Arc::new(TokenCredentialProvider::new(
+ let credentials = if let Some(credentials) = self.credentials {
+ credentials
+ } else if self.access_key_id.is_some() || self.secret_access_key.is_some() {
+ match (self.access_key_id, self.secret_access_key, self.token) {
+ (Some(key_id), Some(secret_key), token) => {
+ info!("Using Static credential provider");
+ let credential = AwsCredential {
+ key_id,
+ secret_key,
token,
- client,
- self.retry_config.clone(),
- )) as _
+ };
+ Arc::new(StaticCredentialProvider::new(credential)) as _
}
- _ => match self.profile {
- Some(profile) => {
- info!("Using profile \"{}\" credential provider", profile);
- profile_credentials(profile, region.clone())?
- }
- None => {
- info!("Using Instance credential provider");
-
- let token = InstanceCredentialProvider {
- cache: Default::default(),
- imdsv1_fallback: self.imdsv1_fallback.get()?,
- metadata_endpoint: self
- .metadata_endpoint
- .unwrap_or_else(|| METADATA_ENDPOINT.into()),
- };
-
- Arc::new(TokenCredentialProvider::new(
- token,
- // The instance metadata endpoint is access over HTTP
- self.client_options.clone().with_allow_http(true).client()?,
- self.retry_config.clone(),
- )) as _
- }
- },
- },
+ (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
+ (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
+ (None, None, _) => unreachable!(),
+ }
+ } else if let (Ok(token_path), Ok(role_arn)) = (
+ std::env::var("AWS_WEB_IDENTITY_TOKEN_FILE"),
+ std::env::var("AWS_ROLE_ARN"),
+ ) {
+ // TODO: Replace with `AmazonS3Builder::credentials_from_env`
+ info!("Using WebIdentity credential provider");
+
+ let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
+ .unwrap_or_else(|_| "WebIdentitySession".to_string());
+
+ let endpoint = format!("https://sts.{region}.amazonaws.com");
+
+ // Disallow non-HTTPs requests
+ let client = self
+ .client_options
+ .clone()
+ .with_allow_http(false)
+ .client()?;
+
+ let token = WebIdentityProvider {
+ token_path,
+ session_name,
+ role_arn,
+ endpoint,
+ };
+
+ Arc::new(TokenCredentialProvider::new(
+ token,
+ client,
+ self.retry_config.clone(),
+ )) as _
+ } else if let Some(profile) = self.profile {
+ info!("Using profile \"{}\" credential provider", profile);
+ profile_credentials(profile, region.clone())?
+ } else {
+ info!("Using Instance credential provider");
+
+ let token = InstanceCredentialProvider {
+ cache: Default::default(),
+ imdsv1_fallback: self.imdsv1_fallback.get()?,
+ metadata_endpoint: self
+ .metadata_endpoint
+ .unwrap_or_else(|| METADATA_ENDPOINT.into()),
+ };
+
+ Arc::new(TokenCredentialProvider::new(
+ token,
+ // The instance metadata endpoint is access over HTTP
+ self.client_options.clone().with_allow_http(true).client()?,
+ self.retry_config.clone(),
+ )) as _
};
let endpoint: String;
diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs
index 6dc14cfb5..069b033d1 100644
--- a/object_store/src/azure/mod.rs
+++ b/object_store/src/azure/mod.rs
@@ -48,7 +48,6 @@ use std::{collections::BTreeSet, str::FromStr};
use tokio::io::AsyncWrite;
use url::Url;
-use crate::azure::credential::AzureCredential;
use crate::client::header::header_meta;
use crate::client::{
ClientConfigKey, CredentialProvider, StaticCredentialProvider,
@@ -60,7 +59,10 @@ pub use credential::authority_hosts;
mod client;
mod credential;
-type AzureCredentialProvider = Arc<dyn CredentialProvider<Credential = AzureCredential>>;
+/// [`CredentialProvider`] for [`MicrosoftAzure`]
+pub type AzureCredentialProvider =
+ Arc<dyn CredentialProvider<Credential = AzureCredential>>;
+pub use credential::AzureCredential;
const STORE: &str = "MicrosoftAzure";
@@ -153,6 +155,13 @@ pub struct MicrosoftAzure {
client: Arc<client::AzureClient>,
}
+impl MicrosoftAzure {
+ /// Returns the [`AzureCredentialProvider`] used by [`MicrosoftAzure`]
+ pub fn credentials(&self) -> &AzureCredentialProvider {
+ &self.client.config().credentials
+ }
+}
+
impl std::fmt::Display for MicrosoftAzure {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
@@ -374,6 +383,8 @@ pub struct MicrosoftAzureBuilder {
retry_config: RetryConfig,
/// Client options
client_options: ClientOptions,
+ /// Credentials
+ credentials: Option<AzureCredentialProvider>,
}
/// Configuration keys for [`MicrosoftAzureBuilder`]
@@ -840,6 +851,12 @@ impl MicrosoftAzureBuilder {
self
}
+ /// Set the credential provider overriding any other options
+ pub fn with_credentials(mut self, credentials: AzureCredentialProvider) -> Self {
+ self.credentials = Some(credentials);
+ self
+ }
+
/// Set if the Azure emulator should be used (defaults to false)
pub fn with_use_emulator(mut self, use_emulator: bool) -> Self {
self.use_emulator = use_emulator.into();
@@ -937,7 +954,9 @@ impl MicrosoftAzureBuilder {
let url = Url::parse(&account_url)
.context(UnableToParseUrlSnafu { url: account_url })?;
- let credential = if let Some(bearer_token) = self.bearer_token {
+ let credential = if let Some(credential) = self.credentials {
+ credential
+ } else if let Some(bearer_token) = self.bearer_token {
static_creds(AzureCredential::BearerToken(bearer_token))
} else if let Some(access_key) = self.access_key {
static_creds(AzureCredential::AccessKey(access_key))
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index 292e4678f..8c2357699 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -509,8 +509,10 @@ impl GetOptionsExt for RequestBuilder {
/// Provides credentials for use when signing requests
#[async_trait]
pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
+ /// The type of credential returned by this provider
type Credential;
+ /// Return a credential
async fn get_credential(&self) -> Result<Arc<Self::Credential>>;
}
diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs
index ad12855e1..205b80594 100644
--- a/object_store/src/gcp/credential.rs
+++ b/object_store/src/gcp/credential.rs
@@ -82,6 +82,7 @@ impl From<Error> for crate::Error {
}
}
+/// A Google Cloud Storage Credential
#[derive(Debug, Eq, PartialEq)]
pub struct GcpCredential {
/// An HTTP bearer token
diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs
index 6813bbf6e..21ba1588f 100644
--- a/object_store/src/gcp/mod.rs
+++ b/object_store/src/gcp/mod.rs
@@ -52,7 +52,6 @@ use crate::client::{
ClientConfigKey, CredentialProvider, GetOptionsExt, StaticCredentialProvider,
TokenCredentialProvider,
};
-use crate::gcp::credential::{application_default_credentials, GcpCredential};
use crate::{
multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
path::{Path, DELIMITER},
@@ -61,15 +60,18 @@ use crate::{
ObjectStore, Result, RetryConfig,
};
-use self::credential::{
- default_gcs_base_url, InstanceCredentialProvider, ServiceAccountCredentials,
+use credential::{
+ application_default_credentials, default_gcs_base_url, InstanceCredentialProvider,
+ ServiceAccountCredentials,
};
mod credential;
const STORE: &str = "GCS";
-type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
+/// [`CredentialProvider`] for [`GoogleCloudStorage`]
+pub type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
+pub use credential::GcpCredential;
#[derive(Debug, Snafu)]
enum Error {
@@ -205,6 +207,13 @@ impl std::fmt::Display for GoogleCloudStorage {
}
}
+impl GoogleCloudStorage {
+ /// Returns the [`GcpCredentialProvider`] used by [`GoogleCloudStorage`]
+ pub fn credentials(&self) -> &GcpCredentialProvider {
+ &self.client.credentials
+ }
+}
+
#[derive(Debug)]
struct GoogleCloudStorageClient {
client: Client,
@@ -696,6 +705,8 @@ pub struct GoogleCloudStorageBuilder {
retry_config: RetryConfig,
/// Client options
client_options: ClientOptions,
+ /// Credentials
+ credentials: Option<GcpCredentialProvider>,
}
/// Configuration keys for [`GoogleCloudStorageBuilder`]
@@ -794,6 +805,7 @@ impl Default for GoogleCloudStorageBuilder {
retry_config: Default::default(),
client_options: ClientOptions::new().with_allow_http(true),
url: None,
+ credentials: None,
}
}
}
@@ -1006,6 +1018,12 @@ impl GoogleCloudStorageBuilder {
self
}
+ /// Set the credential provider overriding any other options
+ pub fn with_credentials(mut self, credentials: GcpCredentialProvider) -> Self {
+ self.credentials = Some(credentials);
+ self
+ }
+
/// Set the retry configuration
pub fn with_retry(mut self, retry_config: RetryConfig) -> Self {
self.retry_config = retry_config;
@@ -1072,7 +1090,9 @@ impl GoogleCloudStorageBuilder {
let scope = "https://www.googleapis.com/auth/devstorage.full_control";
let audience = "https://www.googleapis.com/oauth2/v4/token";
- let credentials = if disable_oauth {
+ let credentials = if let Some(credentials) = self.credentials {
+ credentials
+ } else if disable_oauth {
Arc::new(StaticCredentialProvider::new(GcpCredential {
bearer: "".to_string(),
})) as _
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index 0f3ed809e..7116a8732 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -245,7 +245,7 @@ pub mod throttle;
mod client;
#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))]
-pub use client::{backoff::BackoffConfig, retry::RetryConfig};
+pub use client::{backoff::BackoffConfig, retry::RetryConfig, CredentialProvider};
#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))]
mod config;