This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new b314118a0 Deffered config parsing (#4191) (#4192)
b314118a0 is described below

commit b314118a01250a06ef324f84abe50edd527e7023
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed May 10 18:44:30 2023 +0100

    Deffered config parsing (#4191) (#4192)
---
 object_store/src/aws/checksum.rs |  10 ++++
 object_store/src/aws/mod.rs      | 100 ++++++++++++++++++++++++---------------
 object_store/src/azure/mod.rs    |  29 ++++++------
 object_store/src/client/mod.rs   |  31 ++++++++++--
 object_store/src/config.rs       |  81 +++++++++++++++++++++++++++++++
 object_store/src/lib.rs          |   3 ++
 object_store/src/util.rs         |   9 ----
 7 files changed, 198 insertions(+), 65 deletions(-)

diff --git a/object_store/src/aws/checksum.rs b/object_store/src/aws/checksum.rs
index 57762b641..a50bd2d18 100644
--- a/object_store/src/aws/checksum.rs
+++ b/object_store/src/aws/checksum.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::config::Parse;
 use ring::digest::{self, digest as ring_digest};
 use std::str::FromStr;
 
@@ -66,3 +67,12 @@ impl TryFrom<&String> for Checksum {
         value.parse()
     }
 }
+
+impl Parse for Checksum {
+    fn parse(v: &str) -> crate::Result<Self> {
+        v.parse().map_err(|_| crate::Error::Generic {
+            store: "Config",
+            source: format!("\"{v}\" is not a valid checksum algorithm").into(),
+        })
+    }
+}
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index 6ea24fb70..fe49471c4 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -53,8 +53,9 @@ use crate::aws::credential::{
     AwsCredential, CredentialProvider, InstanceCredentialProvider,
     StaticCredentialProvider, WebIdentityProvider,
 };
+use crate::client::ClientConfigKey;
+use crate::config::ConfigValue;
 use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, 
UploadPart};
-use crate::util::str_is_truthy;
 use crate::{
     ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta, 
ObjectStore, Path,
     Result, RetryConfig, StreamExt,
@@ -103,9 +104,6 @@ enum Error {
         source: std::num::ParseIntError,
     },
 
-    #[snafu(display("Invalid Checksum algorithm"))]
-    InvalidChecksumAlgorithm,
-
     #[snafu(display("Missing region"))]
     MissingRegion,
 
@@ -461,13 +459,13 @@ pub struct AmazonS3Builder {
     /// Retry config
     retry_config: RetryConfig,
     /// When set to true, fallback to IMDSv1
-    imdsv1_fallback: bool,
+    imdsv1_fallback: ConfigValue<bool>,
     /// When set to true, virtual hosted style request has to be used
-    virtual_hosted_style_request: bool,
+    virtual_hosted_style_request: ConfigValue<bool>,
     /// When set to true, unsigned payload option has to be used
-    unsigned_payload: bool,
+    unsigned_payload: ConfigValue<bool>,
     /// Checksum algorithm which has to be used for object integrity check 
during upload
-    checksum_algorithm: Option<String>,
+    checksum_algorithm: Option<ConfigValue<Checksum>>,
     /// Metadata endpoint, see 
<https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html>
     metadata_endpoint: Option<String>,
     /// Profile name, see 
<https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html>
@@ -709,8 +707,9 @@ impl AmazonS3Builder {
         }
 
         if let Ok(text) = std::env::var("AWS_ALLOW_HTTP") {
-            builder.client_options =
-                builder.client_options.with_allow_http(str_is_truthy(&text));
+            builder.client_options = builder
+                .client_options
+                .with_config(ClientConfigKey::AllowHttp, text);
         }
 
         builder
@@ -756,11 +755,9 @@ impl AmazonS3Builder {
             AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()),
             AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()),
             AmazonS3ConfigKey::Token => self.token = Some(value.into()),
-            AmazonS3ConfigKey::ImdsV1Fallback => {
-                self.imdsv1_fallback = str_is_truthy(&value.into())
-            }
+            AmazonS3ConfigKey::ImdsV1Fallback => 
self.imdsv1_fallback.parse(value),
             AmazonS3ConfigKey::VirtualHostedStyleRequest => {
-                self.virtual_hosted_style_request = 
str_is_truthy(&value.into())
+                self.virtual_hosted_style_request.parse(value)
             }
             AmazonS3ConfigKey::DefaultRegion => {
                 self.region = self.region.or_else(|| Some(value.into()))
@@ -769,10 +766,10 @@ impl AmazonS3Builder {
                 self.metadata_endpoint = Some(value.into())
             }
             AmazonS3ConfigKey::Profile => self.profile = Some(value.into()),
-            AmazonS3ConfigKey::UnsignedPayload => {
-                self.unsigned_payload = str_is_truthy(&value.into())
+            AmazonS3ConfigKey::UnsignedPayload => 
self.unsigned_payload.parse(value),
+            AmazonS3ConfigKey::Checksum => {
+                self.checksum_algorithm = 
Some(ConfigValue::Deferred(value.into()))
             }
-            AmazonS3ConfigKey::Checksum => self.checksum_algorithm = 
Some(value.into()),
         };
         self
     }
@@ -834,7 +831,9 @@ impl AmazonS3Builder {
             AmazonS3ConfigKey::MetadataEndpoint => 
self.metadata_endpoint.clone(),
             AmazonS3ConfigKey::Profile => self.profile.clone(),
             AmazonS3ConfigKey::UnsignedPayload => 
Some(self.unsigned_payload.to_string()),
-            AmazonS3ConfigKey::Checksum => self.checksum_algorithm.clone(),
+            AmazonS3ConfigKey::Checksum => {
+                self.checksum_algorithm.as_ref().map(ToString::to_string)
+            }
         }
     }
 
@@ -858,7 +857,7 @@ impl AmazonS3Builder {
                 Some((bucket, "s3", region, "amazonaws.com")) => {
                     self.bucket_name = Some(bucket.to_string());
                     self.region = Some(region.to_string());
-                    self.virtual_hosted_style_request = true;
+                    self.virtual_hosted_style_request = true.into();
                 }
                 Some((account, "r2", "cloudflarestorage", "com")) => {
                     self.region = Some("auto".to_string());
@@ -944,7 +943,7 @@ impl AmazonS3Builder {
         mut self,
         virtual_hosted_style_request: bool,
     ) -> Self {
-        self.virtual_hosted_style_request = virtual_hosted_style_request;
+        self.virtual_hosted_style_request = 
virtual_hosted_style_request.into();
         self
     }
 
@@ -967,7 +966,7 @@ impl AmazonS3Builder {
     /// [SSRF attack]: 
https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/
     ///
     pub fn with_imdsv1_fallback(mut self) -> Self {
-        self.imdsv1_fallback = true;
+        self.imdsv1_fallback = ConfigValue::Parsed(true);
         self
     }
 
@@ -976,7 +975,7 @@ impl AmazonS3Builder {
     /// * false (default): Signed payload option is used, where the checksum 
for the request body is computed and included when constructing a canonical 
request.
     /// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is 
included when constructing a canonical request,
     pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self {
-        self.unsigned_payload = unsigned_payload;
+        self.unsigned_payload = ConfigValue::Parsed(unsigned_payload);
         self
     }
 
@@ -985,7 +984,7 @@ impl AmazonS3Builder {
     /// [checksum algorithm]: 
https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
     pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) -> Self {
-        // Convert to String to enable deferred parsing of config
-        self.checksum_algorithm = Some(checksum_algorithm.to_string());
+        // Store the typed value as parsed; only string config values are deferred
+        self.checksum_algorithm = Some(checksum_algorithm.into());
         self
     }
 
@@ -1038,11 +1037,7 @@ impl AmazonS3Builder {
 
         let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
         let region = self.region.context(MissingRegionSnafu)?;
-        let checksum = self
-            .checksum_algorithm
-            .map(|c| c.parse())
-            .transpose()
-            .map_err(|_| Error::InvalidChecksumAlgorithm)?;
+        let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;
 
         let credentials = match (self.access_key_id, self.secret_access_key, 
self.token) {
             (Some(key_id), Some(secret_key), token) => {
@@ -1103,7 +1098,7 @@ impl AmazonS3Builder {
                             cache: Default::default(),
                             client: client_options.client()?,
                             retry_config: self.retry_config.clone(),
-                            imdsv1_fallback: self.imdsv1_fallback,
+                            imdsv1_fallback: self.imdsv1_fallback.get()?,
                             metadata_endpoint: self
                                 .metadata_endpoint
                                 .unwrap_or_else(|| METADATA_ENDPOINT.into()),
@@ -1119,7 +1114,7 @@ impl AmazonS3Builder {
         // If `endpoint` is provided then its assumed to be consistent with
         // `virtual_hosted_style_request`. i.e. if 
`virtual_hosted_style_request` is true then
         // `endpoint` should have bucket name included.
-        if self.virtual_hosted_style_request {
+        if self.virtual_hosted_style_request.get()? {
             endpoint = self
                 .endpoint
                 .unwrap_or_else(|| 
format!("https://{bucket}.s3.{region}.amazonaws.com";));
@@ -1139,7 +1134,7 @@ impl AmazonS3Builder {
             credentials,
             retry_config: self.retry_config,
             client_options: self.client_options,
-            sign_payload: !self.unsigned_payload,
+            sign_payload: !self.unsigned_payload.get()?,
             checksum,
         };
 
@@ -1315,10 +1310,10 @@ mod tests {
         let metadata_uri = 
format!("{METADATA_ENDPOINT}{container_creds_relative_uri}");
         assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri);
         assert_eq!(
-            builder.checksum_algorithm.unwrap(),
-            Checksum::SHA256.to_string()
+            builder.checksum_algorithm.unwrap().get().unwrap(),
+            Checksum::SHA256
         );
-        assert!(builder.unsigned_payload);
+        assert!(builder.unsigned_payload.get().unwrap());
     }
 
     #[test]
@@ -1351,10 +1346,10 @@ mod tests {
         assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
         assert_eq!(builder.token.unwrap(), aws_session_token);
         assert_eq!(
-            builder.checksum_algorithm.unwrap(),
-            Checksum::SHA256.to_string()
+            builder.checksum_algorithm.unwrap().get().unwrap(),
+            Checksum::SHA256
         );
-        assert!(builder.unsigned_payload);
+        assert!(builder.unsigned_payload.get().unwrap());
     }
 
     #[test]
@@ -1564,7 +1559,7 @@ mod tests {
             .unwrap();
         assert_eq!(builder.bucket_name, Some("bucket".to_string()));
         assert_eq!(builder.region, Some("region".to_string()));
-        assert!(builder.virtual_hosted_style_request);
+        assert!(builder.virtual_hosted_style_request.get().unwrap());
 
         let mut builder = AmazonS3Builder::new();
         builder
@@ -1591,6 +1586,35 @@ mod tests {
             builder.parse_url(case).unwrap_err();
         }
     }
+
+    #[test]
+    fn test_invalid_config() {
+        // Malformed values are accepted by `with_config` and only rejected by `build`
+        let cases = [
+            (
+                AmazonS3ConfigKey::ImdsV1Fallback,
+                "enabled",
+                "Generic Config error: failed to parse \"enabled\" as boolean",
+            ),
+            (
+                AmazonS3ConfigKey::Checksum,
+                "md5",
+                "Generic Config error: \"md5\" is not a valid checksum algorithm",
+            ),
+        ];
+
+        for (key, value, expected) in cases {
+            let err = AmazonS3Builder::new()
+                .with_config(key, value)
+                .with_bucket_name("bucket")
+                .with_region("region")
+                .build()
+                .unwrap_err()
+                .to_string();
+
+            assert_eq!(err, expected, "unexpected error for {value:?}");
+        }
+    }
 }
 
 #[cfg(test)]
diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs
index 15033dca7..2b5b43ada 100644
--- a/object_store/src/azure/mod.rs
+++ b/object_store/src/azure/mod.rs
@@ -51,7 +51,9 @@ use std::{collections::BTreeSet, str::FromStr};
 use tokio::io::AsyncWrite;
 use url::Url;
 
-use crate::util::{str_is_truthy, RFC1123_FMT};
+use crate::client::ClientConfigKey;
+use crate::config::ConfigValue;
+use crate::util::RFC1123_FMT;
 pub use credential::authority_hosts;
 
 mod client;
@@ -417,7 +419,7 @@ pub struct MicrosoftAzureBuilder {
     /// Url
     url: Option<String>,
     /// When set to true, azurite storage emulator has to be used
-    use_emulator: bool,
+    use_emulator: ConfigValue<bool>,
     /// Msi endpoint for acquiring managed identity token
     msi_endpoint: Option<String>,
     /// Object id for use with managed identity authentication
@@ -427,7 +429,7 @@ pub struct MicrosoftAzureBuilder {
     /// File containing token for Azure AD workload identity federation
     federated_token_file: Option<String>,
     /// When set to true, azure cli has to be used for acquiring access token
-    use_azure_cli: bool,
+    use_azure_cli: ConfigValue<bool>,
     /// Retry config
     retry_config: RetryConfig,
     /// Client options
@@ -672,8 +674,9 @@ impl MicrosoftAzureBuilder {
         }
 
         if let Ok(text) = std::env::var("AZURE_ALLOW_HTTP") {
-            builder.client_options =
-                builder.client_options.with_allow_http(str_is_truthy(&text));
+            builder.client_options = builder
+                .client_options
+                .with_config(ClientConfigKey::AllowHttp, text)
         }
 
         if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) {
@@ -726,12 +729,8 @@ impl MicrosoftAzureBuilder {
             AzureConfigKey::FederatedTokenFile => {
                 self.federated_token_file = Some(value.into())
             }
-            AzureConfigKey::UseAzureCli => {
-                self.use_azure_cli = str_is_truthy(&value.into())
-            }
-            AzureConfigKey::UseEmulator => {
-                self.use_emulator = str_is_truthy(&value.into())
-            }
+            AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value),
+            AzureConfigKey::UseEmulator => self.use_emulator.parse(value),
         };
         self
     }
@@ -898,7 +897,7 @@ impl MicrosoftAzureBuilder {
 
     /// Set if the Azure emulator should be used (defaults to false)
     pub fn with_use_emulator(mut self, use_emulator: bool) -> Self {
-        self.use_emulator = use_emulator;
+        self.use_emulator = ConfigValue::Parsed(use_emulator);
         self
     }
 
@@ -956,7 +955,7 @@ impl MicrosoftAzureBuilder {
     /// Set if the Azure Cli should be used for acquiring access token
     /// <https://learn.microsoft.com/en-us/cli/azure/account?view=azure-cli-latest#az-account-get-access-token>
     pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self {
-        self.use_azure_cli = use_azure_cli;
+        self.use_azure_cli = use_azure_cli.into();
         self
     }
 
@@ -969,7 +968,7 @@ impl MicrosoftAzureBuilder {
 
         let container = self.container_name.ok_or(Error::MissingContainerName 
{})?;
 
-        let (is_emulator, storage_url, auth, account) = if self.use_emulator {
+        let (is_emulator, storage_url, auth, account) = if 
self.use_emulator.get()? {
             let account_name = self
                 .account_name
                 .unwrap_or_else(|| EMULATOR_ACCOUNT.to_string());
@@ -1022,7 +1021,7 @@ impl MicrosoftAzureBuilder {
                 credential::CredentialProvider::SASToken(query_pairs)
             } else if let Some(sas) = self.sas_key {
                 credential::CredentialProvider::SASToken(split_sas(&sas)?)
-            } else if self.use_azure_cli {
+            } else if self.use_azure_cli.get()? {
                 credential::CredentialProvider::TokenCredential(
                     TokenCache::default(),
                     Box::new(credential::AzureCliCredential::new()),
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index d019e8119..d7b0b86d9 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -26,8 +26,10 @@ pub mod retry;
 #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
 pub mod token;
 
+use crate::config::ConfigValue;
 use reqwest::header::{HeaderMap, HeaderValue};
 use reqwest::{Client, ClientBuilder, Proxy};
+use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::time::Duration;
 
@@ -43,6 +45,14 @@ fn map_client_error(e: reqwest::Error) -> super::Error {
 static DEFAULT_USER_AGENT: &str =
     concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
 
+/// Keys accepted by [`ClientOptions::with_config`] and [`ClientOptions::get_config_value`]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[non_exhaustive]
+pub enum ClientConfigKey {
+    /// Permit plain-HTTP (non-TLS) connections, see [`ClientOptions::with_allow_http`]
+    AllowHttp,
+}
+
 /// HTTP client configuration for remote object stores
 #[derive(Debug, Clone, Default)]
 pub struct ClientOptions {
@@ -51,7 +61,7 @@ pub struct ClientOptions {
     default_content_type: Option<String>,
     default_headers: Option<HeaderMap>,
     proxy_url: Option<String>,
-    allow_http: bool,
+    allow_http: ConfigValue<bool>,
     allow_insecure: bool,
     timeout: Option<Duration>,
     connect_timeout: Option<Duration>,
@@ -70,6 +80,21 @@ impl ClientOptions {
         Default::default()
     }
 
+    /// Set an option by key; the value is only validated when the HTTP client is built
+    pub fn with_config(mut self, key: ClientConfigKey, value: impl Into<String>) -> Self {
+        match key {
+            ClientConfigKey::AllowHttp => self.allow_http.parse(value),
+        }
+        self
+    }
+
+    /// Get an option by key, returning strings set via `with_config` verbatim (unvalidated)
+    pub fn get_config_value(&self, key: &ClientConfigKey) -> Option<String> {
+        match *key {
+            ClientConfigKey::AllowHttp => Some(self.allow_http.to_string()),
+        }
+    }
+
     /// Sets the User-Agent header to be used by this client
     ///
     /// Default is based on the version of this crate
@@ -104,7 +129,7 @@ impl ClientOptions {
     /// * false (default):  Only HTTPS are allowed
     /// * true:  HTTP and HTTPS are allowed
     pub fn with_allow_http(mut self, allow_http: bool) -> Self {
-        self.allow_http = allow_http;
+        self.allow_http = ConfigValue::Parsed(allow_http);
         self
     }
     /// Allows connections to invalid SSL certificates
@@ -280,7 +305,7 @@ impl ClientOptions {
         }
 
         builder
-            .https_only(!self.allow_http)
+            .https_only(!self.allow_http.get()?)
             .build()
             .map_err(map_client_error)
     }
diff --git a/object_store/src/config.rs b/object_store/src/config.rs
new file mode 100644
index 000000000..3ecce2e52
--- /dev/null
+++ b/object_store/src/config.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{Error, Result};
+use std::fmt::{Debug, Display, Formatter};
+
+/// A config value that is either already typed, or a raw string parsed on demand
+///
+/// This lets builder setters stay infallible, deferring errors until the value is used
+#[derive(Debug, Clone)]
+pub enum ConfigValue<T> {
+    Parsed(T),        // typed value, e.g. from a `with_*` setter or `Default`
+    Deferred(String), // raw string, run through `Parse::parse` by every `get`
+}
+
+impl<T: Display> Display for ConfigValue<T> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Deferred(raw) => f.write_str(raw),
+            Self::Parsed(value) => write!(f, "{value}"),
+        }
+    }
+}
+
+impl<T> From<T> for ConfigValue<T> {
+    fn from(parsed: T) -> Self {
+        ConfigValue::Parsed(parsed)
+    }
+}
+
+impl<T: Parse + Clone> ConfigValue<T> {
+    pub fn parse(&mut self, v: impl Into<String>) {
+        *self = Self::Deferred(v.into()) // stored unparsed; validated by `get`
+    }
+
+    pub fn get(&self) -> Result<T> {
+        match self {
+            Self::Parsed(v) => Ok(v.clone()),
+            Self::Deferred(v) => T::parse(v), // re-parsed on each call, not cached
+        }
+    }
+}
+
+impl<T: Default> Default for ConfigValue<T> {
+    fn default() -> Self {
+        Self::from(T::default())
+    }
+}
+
+/// A type that can be parsed from the raw string held by [`ConfigValue::Deferred`]
+pub trait Parse: Sized {
+    fn parse(v: &str) -> Result<Self>; // called by `ConfigValue::get` on deferred input
+}
+
+impl Parse for bool {
+    fn parse(v: &str) -> Result<Self> {
+        // Case-insensitive; anything outside these two sets is an error, not `false`
+        match v.to_ascii_lowercase().as_str() {
+            "1" | "true" | "on" | "yes" | "y" => Ok(true),
+            "0" | "false" | "off" | "no" | "n" => Ok(false),
+            _ => Err(Error::Generic {
+                store: "Config",
+                source: format!("failed to parse \"{v}\" as boolean").into(),
+            }),
+        }
+    }
+}
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index c31027c07..1390a0140 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -247,6 +247,9 @@ mod client;
 #[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = 
"http"))]
 pub use client::{backoff::BackoffConfig, retry::RetryConfig};
 
+#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = 
"http"))]
+mod config;
+
 #[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))]
 mod multipart;
 mod util;
diff --git a/object_store/src/util.rs b/object_store/src/util.rs
index 1ec63f219..e5c701dd8 100644
--- a/object_store/src/util.rs
+++ b/object_store/src/util.rs
@@ -185,15 +185,6 @@ fn merge_ranges(
     ret
 }
 
-#[allow(dead_code)]
-pub(crate) fn str_is_truthy(val: &str) -> bool {
-    val.eq_ignore_ascii_case("1")
-        | val.eq_ignore_ascii_case("true")
-        | val.eq_ignore_ascii_case("on")
-        | val.eq_ignore_ascii_case("yes")
-        | val.eq_ignore_ascii_case("y")
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

Reply via email to