This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 95cbca64e Add ObjectStore ClientConfig (#3252)
95cbca64e is described below

commit 95cbca64e1dc30360304a1522f07c58dc661ef6b
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Dec 2 09:49:02 2022 +0000

    Add ObjectStore ClientConfig (#3252)
    
    * Add ObjectStore ClientConfig
    
    * Fix default allow HTTP for GCP
    
    * Fix tests
    
    * Tweak error message
---
 object_store/src/aws/client.rs   | 22 +++-----------
 object_store/src/aws/mod.rs      | 65 +++++++++++++++++++--------------------
 object_store/src/azure/client.rs | 26 +++++-----------
 object_store/src/azure/mod.rs    | 32 +++++++++++--------
 object_store/src/client/mod.rs   | 50 ++++++++++++++++++++++++++++++
 object_store/src/gcp/mod.rs      | 66 ++++++++++++++++++++--------------------
 object_store/src/lib.rs          |  3 ++
 7 files changed, 147 insertions(+), 117 deletions(-)

diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs
index e51fe415c..ccc0a9c6b 100644
--- a/object_store/src/aws/client.rs
+++ b/object_store/src/aws/client.rs
@@ -23,7 +23,8 @@ use crate::multipart::UploadPart;
 use crate::path::DELIMITER;
 use crate::util::{format_http_range, format_prefix};
 use crate::{
-    BoxStream, ListResult, MultipartId, ObjectMeta, Path, Result, RetryConfig, StreamExt,
+    BoxStream, ClientOptions, ListResult, MultipartId, ObjectMeta, Path, Result,
+    RetryConfig, StreamExt,
 };
 use bytes::{Buf, Bytes};
 use chrono::{DateTime, Utc};
@@ -88,9 +89,6 @@ pub(crate) enum Error {
 
     #[snafu(display("Got invalid multipart response: {}", source))]
     InvalidMultipartResponse { source: quick_xml::de::DeError },
-
-    #[snafu(display("Unable to use proxy url: {}", source))]
-    ProxyUrl { source: reqwest::Error },
 }
 
 impl From<Error> for crate::Error {
@@ -203,8 +201,7 @@ pub struct S3Config {
     pub bucket_endpoint: String,
     pub credentials: Box<dyn CredentialProvider>,
     pub retry_config: RetryConfig,
-    pub allow_http: bool,
-    pub proxy_url: Option<String>,
+    pub client_options: ClientOptions,
 }
 
 impl S3Config {
@@ -221,18 +218,7 @@ pub(crate) struct S3Client {
 
 impl S3Client {
     pub fn new(config: S3Config) -> Result<Self> {
-        let builder = reqwest::ClientBuilder::new().https_only(!config.allow_http);
-        let client = match &config.proxy_url {
-            Some(ref url) => {
-                let pr = reqwest::Proxy::all(url)
-                    .map_err(|source| Error::ProxyUrl { source })?;
-                builder.proxy(pr)
-            }
-            _ => builder,
-        }
-        .build()
-        .unwrap();
-
+        let client = config.client_options.client()?;
         Ok(Self { config, client })
     }
 
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index cf7a5542e..c92b8c29a 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -36,7 +36,6 @@ use bytes::Bytes;
 use chrono::{DateTime, Utc};
 use futures::stream::BoxStream;
 use futures::TryStreamExt;
-use reqwest::{Client, Proxy};
 use snafu::{OptionExt, ResultExt, Snafu};
 use std::collections::BTreeSet;
 use std::ops::Range;
@@ -51,8 +50,8 @@ use crate::aws::credential::{
 };
 use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
 use crate::{
-    GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result,
-    RetryConfig, StreamExt,
+    ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path,
+    Result, RetryConfig, StreamExt,
 };
 
 mod client;
@@ -120,9 +119,6 @@ enum Error {
 
     #[snafu(display("Error reading token file: {}", source))]
     ReadTokenFile { source: std::io::Error },
-
-    #[snafu(display("Unable to use proxy url: {}", source))]
-    ProxyUrl { source: reqwest::Error },
 }
 
 impl From<Error> for super::Error {
@@ -361,12 +357,11 @@ pub struct AmazonS3Builder {
     endpoint: Option<String>,
     token: Option<String>,
     retry_config: RetryConfig,
-    allow_http: bool,
     imdsv1_fallback: bool,
     virtual_hosted_style_request: bool,
     metadata_endpoint: Option<String>,
     profile: Option<String>,
-    proxy_url: Option<String>,
+    client_options: ClientOptions,
 }
 
 impl AmazonS3Builder {
@@ -431,7 +426,8 @@ impl AmazonS3Builder {
         }
 
         if let Ok(text) = std::env::var("AWS_ALLOW_HTTP") {
-            builder.allow_http = text == "true";
+            builder.client_options =
+                builder.client_options.with_allow_http(text == "true");
         }
 
         builder
@@ -487,7 +483,7 @@ impl AmazonS3Builder {
     /// * false (default):  Only HTTPS are allowed
     /// * true:  HTTP and HTTPS are allowed
     pub fn with_allow_http(mut self, allow_http: bool) -> Self {
-        self.allow_http = allow_http;
+        self.client_options = self.client_options.with_allow_http(allow_http);
         self
     }
 
@@ -543,7 +539,13 @@ impl AmazonS3Builder {
 
     /// Set the proxy_url to be used by the underlying client
     pub fn with_proxy_url(mut self, proxy_url: impl Into<String>) -> Self {
-        self.proxy_url = Some(proxy_url.into());
+        self.client_options = self.client_options.with_proxy_url(proxy_url);
+        self
+    }
+
+    /// Sets the client options, overriding any already set
+    pub fn with_client_options(mut self, options: ClientOptions) -> Self {
+        self.client_options = options;
         self
     }
 
@@ -571,14 +573,6 @@ impl AmazonS3Builder {
         let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
         let region = self.region.context(MissingRegionSnafu)?;
 
-        let clientbuilder = match self.proxy_url {
-            Some(ref url) => {
-                let pr: Proxy =
-                    Proxy::all(url).map_err(|source| Error::ProxyUrl { source })?;
-                Client::builder().proxy(pr)
-            }
-            None => Client::builder(),
-        };
         let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
             (Some(key_id), Some(secret_key), token) => {
                 info!("Using Static credential provider");
@@ -608,7 +602,11 @@ impl AmazonS3Builder {
                     let endpoint = format!("https://sts.{}.amazonaws.com", region);
 
                     // Disallow non-HTTPs requests
-                    let client = clientbuilder.https_only(true).build().unwrap();
+                    let client = self
+                        .client_options
+                        .clone()
+                        .with_allow_http(false)
+                        .client()?;
 
                     Box::new(WebIdentityProvider {
                         cache: Default::default(),
@@ -629,11 +627,12 @@ impl AmazonS3Builder {
                         info!("Using Instance credential provider");
 
                         // The instance metadata endpoint is access over HTTP
-                        let client = clientbuilder.https_only(false).build().unwrap();
+                        let client_options =
+                            self.client_options.clone().with_allow_http(true);
 
                         Box::new(InstanceCredentialProvider {
                             cache: Default::default(),
-                            client,
+                            client: client_options.client()?,
                             retry_config: self.retry_config.clone(),
                             imdsv1_fallback: self.imdsv1_fallback,
                             metadata_endpoint: self
@@ -670,11 +669,10 @@ impl AmazonS3Builder {
             bucket_endpoint,
             credentials,
             retry_config: self.retry_config,
-            allow_http: self.allow_http,
-            proxy_url: self.proxy_url,
+            client_options: self.client_options,
         };
 
-        let client = Arc::new(S3Client::new(config).unwrap());
+        let client = Arc::new(S3Client::new(config)?);
 
         Ok(AmazonS3 { client })
     }
@@ -931,21 +929,20 @@ mod tests {
 
         assert!(s3.is_ok());
 
-        let s3 = AmazonS3Builder::new()
+        let err = AmazonS3Builder::new()
             .with_access_key_id("access_key_id")
             .with_secret_access_key("secret_access_key")
             .with_region("region")
             .with_bucket_name("bucket_name")
             .with_allow_http(true)
             .with_proxy_url("asdf://example.com")
-            .build();
+            .build()
+            .unwrap_err()
+            .to_string();
 
-        assert!(match s3 {
-            Err(crate::Error::Generic { source, .. }) => matches!(
-                source.downcast_ref(),
-                Some(crate::aws::Error::ProxyUrl { .. })
-            ),
-            _ => false,
-        })
+        assert_eq!(
+            "Generic HTTP client error: builder error: unknown proxy scheme",
+            err
+        );
     }
 }
diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs
index d8cfdd1c7..b537f5edf 100644
--- a/object_store/src/azure/client.rs
+++ b/object_store/src/azure/client.rs
@@ -21,13 +21,16 @@ use crate::client::pagination::stream_paginated;
 use crate::client::retry::RetryExt;
 use crate::path::DELIMITER;
 use crate::util::{format_http_range, format_prefix};
-use crate::{BoxStream, ListResult, ObjectMeta, Path, Result, RetryConfig, StreamExt};
+use crate::{
+    BoxStream, ClientOptions, ListResult, ObjectMeta, Path, Result, RetryConfig,
+    StreamExt,
+};
 use bytes::{Buf, Bytes};
 use chrono::{DateTime, TimeZone, Utc};
 use itertools::Itertools;
 use reqwest::{
     header::{HeaderValue, CONTENT_LENGTH, IF_NONE_MATCH, RANGE},
-    Client as ReqwestClient, Method, Proxy, Response, StatusCode,
+    Client as ReqwestClient, Method, Response, StatusCode,
 };
 use serde::{Deserialize, Deserializer, Serialize};
 use snafu::{ResultExt, Snafu};
@@ -82,9 +85,6 @@ pub(crate) enum Error {
     Authorization {
         source: crate::azure::credential::Error,
     },
-
-    #[snafu(display("Unable to use proxy url: {}", source))]
-    ProxyUrl { source: reqwest::Error },
 }
 
 impl From<Error> for crate::Error {
@@ -124,10 +124,9 @@ pub struct AzureConfig {
     pub container: String,
     pub credentials: CredentialProvider,
     pub retry_config: RetryConfig,
-    pub allow_http: bool,
     pub service: Url,
     pub is_emulator: bool,
-    pub proxy_url: Option<String>,
+    pub client_options: ClientOptions,
 }
 
 impl AzureConfig {
@@ -153,18 +152,7 @@ pub(crate) struct AzureClient {
 impl AzureClient {
     /// create a new instance of [AzureClient]
     pub fn new(config: AzureConfig) -> Result<Self> {
-        let builder = ReqwestClient::builder();
-
-        let client = if let Some(url) = config.proxy_url.as_ref() {
-            let pr = Proxy::all(url).map_err(|source| Error::ProxyUrl { source });
-            builder.proxy(pr.unwrap())
-        } else {
-            builder
-        }
-        .https_only(!config.allow_http)
-        .build()
-        .unwrap();
-
+        let client = config.client_options.client()?;
         Ok(Self { config, client })
     }
 
diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs
index 060b4b2d2..4b7131ea8 100644
--- a/object_store/src/azure/mod.rs
+++ b/object_store/src/azure/mod.rs
@@ -30,7 +30,8 @@ use self::client::{BlockId, BlockList};
 use crate::{
     multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
     path::Path,
-    GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig,
+    ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
+    RetryConfig,
 };
 use async_trait::async_trait;
 use bytes::Bytes;
@@ -359,8 +360,7 @@ pub struct MicrosoftAzureBuilder {
     authority_host: Option<String>,
     use_emulator: bool,
     retry_config: RetryConfig,
-    allow_http: bool,
-    proxy_url: Option<String>,
+    client_options: ClientOptions,
 }
 
 impl Debug for MicrosoftAzureBuilder {
@@ -480,10 +480,10 @@ impl MicrosoftAzureBuilder {
     }
 
     /// Sets what protocol is allowed. If `allow_http` is :
-    /// * false (default):  Only HTTPS is allowed
+    /// * false (default):  Only HTTPS are allowed
     /// * true:  HTTP and HTTPS are allowed
     pub fn with_allow_http(mut self, allow_http: bool) -> Self {
-        self.allow_http = allow_http;
+        self.client_options = self.client_options.with_allow_http(allow_http);
         self
     }
 
@@ -503,7 +503,13 @@ impl MicrosoftAzureBuilder {
 
     /// Set the proxy_url to be used by the underlying client
     pub fn with_proxy_url(mut self, proxy_url: impl Into<String>) -> Self {
-        self.proxy_url = Some(proxy_url.into());
+        self.client_options = self.client_options.with_proxy_url(proxy_url);
+        self
+    }
+
+    /// Sets the client options, overriding any already set
+    pub fn with_client_options(mut self, options: ClientOptions) -> Self {
+        self.client_options = options;
         self
     }
 
@@ -521,14 +527,13 @@ impl MicrosoftAzureBuilder {
             sas_query_pairs,
             use_emulator,
             retry_config,
-            allow_http,
             authority_host,
-            proxy_url,
+            mut client_options,
         } = self;
 
         let container = container_name.ok_or(Error::MissingContainerName {})?;
 
-        let (is_emulator, allow_http, storage_url, auth, account) = if use_emulator {
+        let (is_emulator, storage_url, auth, account) = if use_emulator {
             let account_name =
                 account_name.unwrap_or_else(|| EMULATOR_ACCOUNT.to_string());
             // Allow overriding defaults. Values taken from
@@ -537,7 +542,9 @@ impl MicrosoftAzureBuilder {
             let account_key =
                 access_key.unwrap_or_else(|| EMULATOR_ACCOUNT_KEY.to_string());
             let credential = credential::CredentialProvider::AccessKey(account_key);
-            (true, true, url, credential, account_name)
+
+            client_options = client_options.with_allow_http(true);
+            (true, url, credential, account_name)
         } else {
             let account_name = account_name.ok_or(Error::MissingAccount {})?;
             let account_url = format!("https://{}.blob.core.windows.net", &account_name);
@@ -564,18 +571,17 @@ impl MicrosoftAzureBuilder {
             } else {
                 Err(Error::MissingCredentials {})
             }?;
-            (false, allow_http, url, credential, account_name)
+            (false, url, credential, account_name)
         };
 
         let config = client::AzureConfig {
             account,
-            allow_http,
             retry_config,
             service: storage_url,
             container,
             credentials: auth,
             is_emulator,
-            proxy_url,
+            client_options,
         };
 
         let client = Arc::new(client::AzureClient::new(config)?);
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index c93c68a1f..2b58a77f2 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -23,3 +23,53 @@ pub mod mock_server;
 pub mod pagination;
 pub mod retry;
 pub mod token;
+
+use reqwest::{Client, ClientBuilder, Proxy};
+
+fn map_client_error(e: reqwest::Error) -> super::Error {
+    super::Error::Generic {
+        store: "HTTP client",
+        source: Box::new(e),
+    }
+}
+
+/// HTTP client configuration for remote object stores
+#[derive(Debug, Clone, Default)]
+pub struct ClientOptions {
+    proxy_url: Option<String>,
+    allow_http: bool,
+}
+
+impl ClientOptions {
+    /// Create a new [`ClientOptions`] with default values
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// Sets what protocol is allowed. If `allow_http` is :
+    /// * false (default):  Only HTTPS are allowed
+    /// * true:  HTTP and HTTPS are allowed
+    pub fn with_allow_http(mut self, allow_http: bool) -> Self {
+        self.allow_http = allow_http;
+        self
+    }
+
+    /// Set an HTTP proxy to use for requests
+    pub fn with_proxy_url(mut self, proxy_url: impl Into<String>) -> Self {
+        self.proxy_url = Some(proxy_url.into());
+        self
+    }
+
+    pub(crate) fn client(&self) -> super::Result<Client> {
+        let mut builder = ClientBuilder::new();
+        if let Some(proxy) = &self.proxy_url {
+            let proxy = Proxy::all(proxy).map_err(map_client_error)?;
+            builder = builder.proxy(proxy);
+        }
+
+        builder
+            .https_only(!self.allow_http)
+            .build()
+            .map_err(map_client_error)
+    }
+}
diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs
index 0da92fdbe..41d6696c1 100644
--- a/object_store/src/gcp/mod.rs
+++ b/object_store/src/gcp/mod.rs
@@ -41,7 +41,6 @@ use chrono::{DateTime, Utc};
 use futures::{stream::BoxStream, StreamExt, TryStreamExt};
 use percent_encoding::{percent_encode, NON_ALPHANUMERIC};
 use reqwest::header::RANGE;
-use reqwest::Proxy;
 use reqwest::{header, Client, Method, Response, StatusCode};
 use snafu::{ResultExt, Snafu};
 use tokio::io::AsyncWrite;
@@ -53,7 +52,8 @@ use crate::{
     multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
     path::{Path, DELIMITER},
     util::{format_http_range, format_prefix},
-    GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result, RetryConfig,
+    ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
+    RetryConfig,
 };
 
 use credential::OAuthProvider;
@@ -123,9 +123,6 @@ enum Error {
 
     #[snafu(display("GCP credential error: {}", source))]
     Credential { source: credential::Error },
-
-    #[snafu(display("Unable to use proxy url: {}", source))]
-    ProxyUrl { source: reqwest::Error },
 }
 
 impl From<Error> for super::Error {
@@ -739,13 +736,23 @@ fn reader_credentials_file(
 ///  .with_bucket_name(BUCKET_NAME)
 ///  .build();
 /// ```
-#[derive(Debug, Default)]
+#[derive(Debug)]
 pub struct GoogleCloudStorageBuilder {
     bucket_name: Option<String>,
     service_account_path: Option<String>,
-    client: Option<Client>,
     retry_config: RetryConfig,
-    proxy_url: Option<String>,
+    client_options: ClientOptions,
+}
+
+impl Default for GoogleCloudStorageBuilder {
+    fn default() -> Self {
+        Self {
+            bucket_name: None,
+            service_account_path: None,
+            retry_config: Default::default(),
+            client_options: ClientOptions::new().with_allow_http(true),
+        }
+    }
 }
 
 impl GoogleCloudStorageBuilder {
@@ -787,9 +794,15 @@ impl GoogleCloudStorageBuilder {
         self
     }
 
-    /// Set proxy url used for connection
+    /// Set the proxy_url to be used by the underlying client
     pub fn with_proxy_url(mut self, proxy_url: impl Into<String>) -> Self {
-        self.proxy_url = Some(proxy_url.into());
+        self.client_options = self.client_options.with_proxy_url(proxy_url);
+        self
+    }
+
+    /// Sets the client options, overriding any already set
+    pub fn with_client_options(mut self, options: ClientOptions) -> Self {
+        self.client_options = options;
         self
     }
 
@@ -799,27 +812,15 @@ impl GoogleCloudStorageBuilder {
         let Self {
             bucket_name,
             service_account_path,
-            client,
             retry_config,
-            proxy_url,
+            client_options,
         } = self;
 
         let bucket_name = bucket_name.ok_or(Error::MissingBucketName {})?;
         let service_account_path =
             service_account_path.ok_or(Error::MissingServiceAccountPath)?;
 
-        let client = match (proxy_url, client) {
-            (_, Some(client)) => client,
-            (Some(url), None) => {
-                let pr = Proxy::all(&url).map_err(|source| Error::ProxyUrl { source })?;
-                Client::builder()
-                    .proxy(pr)
-                    .build()
-                    .map_err(|source| Error::ProxyUrl { source })?
-            }
-            (None, None) => Client::new(),
-        };
-
+        let client = client_options.client()?;
         let credentials = reader_credentials_file(service_account_path)?;
 
         // TODO: https://cloud.google.com/storage/docs/authentication#oauth-scopes
@@ -1054,18 +1055,17 @@ mod test {
             .build();
         assert!(dbg!(gcs).is_ok());
 
-        let gcs = GoogleCloudStorageBuilder::new()
+        let err = GoogleCloudStorageBuilder::new()
             .with_service_account_path(service_account_path.to_str().unwrap())
             .with_bucket_name("foo")
             .with_proxy_url("asdf://example.com")
-            .build();
+            .build()
+            .unwrap_err()
+            .to_string();
 
-        assert!(match gcs {
-            Err(ObjectStoreError::Generic { source, .. }) => matches!(
-                source.downcast_ref(),
-                Some(crate::gcp::Error::ProxyUrl { .. })
-            ),
-            _ => false,
-        })
+        assert_eq!(
+            "Generic HTTP client error: builder error: unknown proxy scheme",
+            err
+        );
     }
 }
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index a36bb5fb8..ec41f3812 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -197,6 +197,9 @@ use std::io::{Read, Seek, SeekFrom};
 use std::ops::Range;
 use tokio::io::AsyncWrite;
 
+#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))]
+pub use client::ClientOptions;
+
 /// An alias for a dynamically dispatched object store implementation.
 pub type DynObjectStore = dyn ObjectStore;
 

Reply via email to