This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new b314118a0 Deferred config parsing (#4191) (#4192)
b314118a0 is described below
commit b314118a01250a06ef324f84abe50edd527e7023
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed May 10 18:44:30 2023 +0100
Deferred config parsing (#4191) (#4192)
---
object_store/src/aws/checksum.rs | 10 ++++
object_store/src/aws/mod.rs | 100 ++++++++++++++++++++++++---------------
object_store/src/azure/mod.rs | 29 ++++++------
object_store/src/client/mod.rs | 31 ++++++++++--
object_store/src/config.rs | 81 +++++++++++++++++++++++++++++++
object_store/src/lib.rs | 3 ++
object_store/src/util.rs | 9 ----
7 files changed, 198 insertions(+), 65 deletions(-)
diff --git a/object_store/src/aws/checksum.rs b/object_store/src/aws/checksum.rs
index 57762b641..a50bd2d18 100644
--- a/object_store/src/aws/checksum.rs
+++ b/object_store/src/aws/checksum.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use crate::config::Parse;
use ring::digest::{self, digest as ring_digest};
use std::str::FromStr;
@@ -66,3 +67,12 @@ impl TryFrom<&String> for Checksum {
value.parse()
}
}
+
+impl Parse for Checksum {
+ fn parse(v: &str) -> crate::Result<Self> {
+ v.parse().map_err(|_| crate::Error::Generic {
+ store: "Config",
+ source: format!("\"{v}\" is not a valid checksum algorithm").into(),
+ })
+ }
+}
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index 6ea24fb70..fe49471c4 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -53,8 +53,9 @@ use crate::aws::credential::{
AwsCredential, CredentialProvider, InstanceCredentialProvider,
StaticCredentialProvider, WebIdentityProvider,
};
+use crate::client::ClientConfigKey;
+use crate::config::ConfigValue;
use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl,
UploadPart};
-use crate::util::str_is_truthy;
use crate::{
ClientOptions, GetResult, ListResult, MultipartId, ObjectMeta,
ObjectStore, Path,
Result, RetryConfig, StreamExt,
@@ -103,9 +104,6 @@ enum Error {
source: std::num::ParseIntError,
},
- #[snafu(display("Invalid Checksum algorithm"))]
- InvalidChecksumAlgorithm,
-
#[snafu(display("Missing region"))]
MissingRegion,
@@ -461,13 +459,13 @@ pub struct AmazonS3Builder {
/// Retry config
retry_config: RetryConfig,
/// When set to true, fallback to IMDSv1
- imdsv1_fallback: bool,
+ imdsv1_fallback: ConfigValue<bool>,
/// When set to true, virtual hosted style request has to be used
- virtual_hosted_style_request: bool,
+ virtual_hosted_style_request: ConfigValue<bool>,
/// When set to true, unsigned payload option has to be used
- unsigned_payload: bool,
+ unsigned_payload: ConfigValue<bool>,
/// Checksum algorithm which has to be used for object integrity check
during upload
- checksum_algorithm: Option<String>,
+ checksum_algorithm: Option<ConfigValue<Checksum>>,
/// Metadata endpoint, see
<https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html>
metadata_endpoint: Option<String>,
/// Profile name, see
<https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-profiles.html>
@@ -709,8 +707,9 @@ impl AmazonS3Builder {
}
if let Ok(text) = std::env::var("AWS_ALLOW_HTTP") {
- builder.client_options =
- builder.client_options.with_allow_http(str_is_truthy(&text));
+ builder.client_options = builder
+ .client_options
+ .with_config(ClientConfigKey::AllowHttp, text);
}
builder
@@ -756,11 +755,9 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::Bucket => self.bucket_name = Some(value.into()),
AmazonS3ConfigKey::Endpoint => self.endpoint = Some(value.into()),
AmazonS3ConfigKey::Token => self.token = Some(value.into()),
- AmazonS3ConfigKey::ImdsV1Fallback => {
- self.imdsv1_fallback = str_is_truthy(&value.into())
- }
+ AmazonS3ConfigKey::ImdsV1Fallback => self.imdsv1_fallback.parse(value),
AmazonS3ConfigKey::VirtualHostedStyleRequest => {
- self.virtual_hosted_style_request = str_is_truthy(&value.into())
+ self.virtual_hosted_style_request.parse(value)
}
AmazonS3ConfigKey::DefaultRegion => {
self.region = self.region.or_else(|| Some(value.into()))
@@ -769,10 +766,10 @@ impl AmazonS3Builder {
self.metadata_endpoint = Some(value.into())
}
AmazonS3ConfigKey::Profile => self.profile = Some(value.into()),
- AmazonS3ConfigKey::UnsignedPayload => {
- self.unsigned_payload = str_is_truthy(&value.into())
+ AmazonS3ConfigKey::UnsignedPayload => self.unsigned_payload.parse(value),
+ AmazonS3ConfigKey::Checksum => {
+ self.checksum_algorithm = Some(ConfigValue::Deferred(value.into()))
}
- AmazonS3ConfigKey::Checksum => self.checksum_algorithm = Some(value.into()),
};
self
}
@@ -834,7 +831,9 @@ impl AmazonS3Builder {
AmazonS3ConfigKey::MetadataEndpoint =>
self.metadata_endpoint.clone(),
AmazonS3ConfigKey::Profile => self.profile.clone(),
AmazonS3ConfigKey::UnsignedPayload =>
Some(self.unsigned_payload.to_string()),
- AmazonS3ConfigKey::Checksum => self.checksum_algorithm.clone(),
+ AmazonS3ConfigKey::Checksum => {
+ self.checksum_algorithm.as_ref().map(ToString::to_string)
+ }
}
}
@@ -858,7 +857,7 @@ impl AmazonS3Builder {
Some((bucket, "s3", region, "amazonaws.com")) => {
self.bucket_name = Some(bucket.to_string());
self.region = Some(region.to_string());
- self.virtual_hosted_style_request = true;
+ self.virtual_hosted_style_request = true.into();
}
Some((account, "r2", "cloudflarestorage", "com")) => {
self.region = Some("auto".to_string());
@@ -944,7 +943,7 @@ impl AmazonS3Builder {
mut self,
virtual_hosted_style_request: bool,
) -> Self {
- self.virtual_hosted_style_request = virtual_hosted_style_request;
+ self.virtual_hosted_style_request = virtual_hosted_style_request.into();
self
}
@@ -967,7 +966,7 @@ impl AmazonS3Builder {
/// [SSRF attack]:
https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/
///
pub fn with_imdsv1_fallback(mut self) -> Self {
- self.imdsv1_fallback = true;
+ self.imdsv1_fallback = true.into();
self
}
@@ -976,7 +975,7 @@ impl AmazonS3Builder {
/// * false (default): Signed payload option is used, where the checksum
for the request body is computed and included when constructing a canonical
request.
/// * true: Unsigned payload option is used. `UNSIGNED-PAYLOAD` literal is
included when constructing a canonical request,
pub fn with_unsigned_payload(mut self, unsigned_payload: bool) -> Self {
- self.unsigned_payload = unsigned_payload;
+ self.unsigned_payload = unsigned_payload.into();
self
}
@@ -985,7 +984,7 @@ impl AmazonS3Builder {
/// [checksum algorithm]:
https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html
pub fn with_checksum_algorithm(mut self, checksum_algorithm: Checksum) ->
Self {
// Convert to String to enable deferred parsing of config
- self.checksum_algorithm = Some(checksum_algorithm.to_string());
+ self.checksum_algorithm = Some(checksum_algorithm.into());
self
}
@@ -1038,11 +1037,7 @@ impl AmazonS3Builder {
let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
let region = self.region.context(MissingRegionSnafu)?;
- let checksum = self
- .checksum_algorithm
- .map(|c| c.parse())
- .transpose()
- .map_err(|_| Error::InvalidChecksumAlgorithm)?;
+ let checksum = self.checksum_algorithm.map(|x| x.get()).transpose()?;
let credentials = match (self.access_key_id, self.secret_access_key,
self.token) {
(Some(key_id), Some(secret_key), token) => {
@@ -1103,7 +1098,7 @@ impl AmazonS3Builder {
cache: Default::default(),
client: client_options.client()?,
retry_config: self.retry_config.clone(),
- imdsv1_fallback: self.imdsv1_fallback,
+ imdsv1_fallback: self.imdsv1_fallback.get()?,
metadata_endpoint: self
.metadata_endpoint
.unwrap_or_else(|| METADATA_ENDPOINT.into()),
@@ -1119,7 +1114,7 @@ impl AmazonS3Builder {
// If `endpoint` is provided then its assumed to be consistent with
// `virtual_hosted_style_request`. i.e. if
`virtual_hosted_style_request` is true then
// `endpoint` should have bucket name included.
- if self.virtual_hosted_style_request {
+ if self.virtual_hosted_style_request.get()? {
endpoint = self
.endpoint
.unwrap_or_else(||
format!("https://{bucket}.s3.{region}.amazonaws.com"));
@@ -1139,7 +1134,7 @@ impl AmazonS3Builder {
credentials,
retry_config: self.retry_config,
client_options: self.client_options,
- sign_payload: !self.unsigned_payload,
+ sign_payload: !self.unsigned_payload.get()?,
checksum,
};
@@ -1315,10 +1310,10 @@ mod tests {
let metadata_uri =
format!("{METADATA_ENDPOINT}{container_creds_relative_uri}");
assert_eq!(builder.metadata_endpoint.unwrap(), metadata_uri);
assert_eq!(
- builder.checksum_algorithm.unwrap(),
- Checksum::SHA256.to_string()
+ builder.checksum_algorithm.unwrap().get().unwrap(),
+ Checksum::SHA256
);
- assert!(builder.unsigned_payload);
+ assert!(builder.unsigned_payload.get().unwrap());
}
#[test]
@@ -1351,10 +1346,10 @@ mod tests {
assert_eq!(builder.endpoint.unwrap(), aws_endpoint);
assert_eq!(builder.token.unwrap(), aws_session_token);
assert_eq!(
- builder.checksum_algorithm.unwrap(),
- Checksum::SHA256.to_string()
+ builder.checksum_algorithm.unwrap().get().unwrap(),
+ Checksum::SHA256
);
- assert!(builder.unsigned_payload);
+ assert!(builder.unsigned_payload.get().unwrap());
}
#[test]
@@ -1564,7 +1559,7 @@ mod tests {
.unwrap();
assert_eq!(builder.bucket_name, Some("bucket".to_string()));
assert_eq!(builder.region, Some("region".to_string()));
- assert!(builder.virtual_hosted_style_request);
+ assert!(builder.virtual_hosted_style_request.get().unwrap());
let mut builder = AmazonS3Builder::new();
builder
@@ -1591,6 +1586,35 @@ mod tests {
builder.parse_url(case).unwrap_err();
}
}
+
+ #[test]
+ fn test_invalid_config() {
+ let err = AmazonS3Builder::new()
+ .with_config(AmazonS3ConfigKey::ImdsV1Fallback, "enabled")
+ .with_bucket_name("bucket")
+ .with_region("region")
+ .build()
+ .unwrap_err()
+ .to_string();
+
+ assert_eq!(
+ err,
+ "Generic Config error: failed to parse \"enabled\" as boolean"
+ );
+
+ let err = AmazonS3Builder::new()
+ .with_config(AmazonS3ConfigKey::Checksum, "md5")
+ .with_bucket_name("bucket")
+ .with_region("region")
+ .build()
+ .unwrap_err()
+ .to_string();
+
+ assert_eq!(
+ err,
+ "Generic Config error: \"md5\" is not a valid checksum algorithm"
+ );
+ }
}
#[cfg(test)]
diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs
index 15033dca7..2b5b43ada 100644
--- a/object_store/src/azure/mod.rs
+++ b/object_store/src/azure/mod.rs
@@ -51,7 +51,9 @@ use std::{collections::BTreeSet, str::FromStr};
use tokio::io::AsyncWrite;
use url::Url;
-use crate::util::{str_is_truthy, RFC1123_FMT};
+use crate::client::ClientConfigKey;
+use crate::config::ConfigValue;
+use crate::util::RFC1123_FMT;
pub use credential::authority_hosts;
mod client;
@@ -417,7 +419,7 @@ pub struct MicrosoftAzureBuilder {
/// Url
url: Option<String>,
/// When set to true, azurite storage emulator has to be used
- use_emulator: bool,
+ use_emulator: ConfigValue<bool>,
/// Msi endpoint for acquiring managed identity token
msi_endpoint: Option<String>,
/// Object id for use with managed identity authentication
@@ -427,7 +429,7 @@ pub struct MicrosoftAzureBuilder {
/// File containing token for Azure AD workload identity federation
federated_token_file: Option<String>,
/// When set to true, azure cli has to be used for acquiring access token
- use_azure_cli: bool,
+ use_azure_cli: ConfigValue<bool>,
/// Retry config
retry_config: RetryConfig,
/// Client options
@@ -672,8 +674,9 @@ impl MicrosoftAzureBuilder {
}
if let Ok(text) = std::env::var("AZURE_ALLOW_HTTP") {
- builder.client_options =
- builder.client_options.with_allow_http(str_is_truthy(&text));
+ builder.client_options = builder
+ .client_options
+ .with_config(ClientConfigKey::AllowHttp, text)
}
if let Ok(text) = std::env::var(MSI_ENDPOINT_ENV_KEY) {
@@ -726,12 +729,8 @@ impl MicrosoftAzureBuilder {
AzureConfigKey::FederatedTokenFile => {
self.federated_token_file = Some(value.into())
}
- AzureConfigKey::UseAzureCli => {
- self.use_azure_cli = str_is_truthy(&value.into())
- }
- AzureConfigKey::UseEmulator => {
- self.use_emulator = str_is_truthy(&value.into())
- }
+ AzureConfigKey::UseAzureCli => self.use_azure_cli.parse(value),
+ AzureConfigKey::UseEmulator => self.use_emulator.parse(value),
};
self
}
@@ -898,7 +897,7 @@ impl MicrosoftAzureBuilder {
/// Set if the Azure emulator should be used (defaults to false)
pub fn with_use_emulator(mut self, use_emulator: bool) -> Self {
- self.use_emulator = use_emulator;
+ self.use_emulator = use_emulator.into();
self
}
@@ -956,7 +955,7 @@ impl MicrosoftAzureBuilder {
/// Set if the Azure Cli should be used for acquiring access token
///
<https://learn.microsoft.com/en-us/cli/azure/account?view=azure-cli-latest#az-account-get-access-token>
pub fn with_use_azure_cli(mut self, use_azure_cli: bool) -> Self {
- self.use_azure_cli = use_azure_cli;
+ self.use_azure_cli = use_azure_cli.into();
self
}
@@ -969,7 +968,7 @@ impl MicrosoftAzureBuilder {
let container = self.container_name.ok_or(Error::MissingContainerName
{})?;
- let (is_emulator, storage_url, auth, account) = if self.use_emulator {
+ let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? {
let account_name = self
.account_name
.unwrap_or_else(|| EMULATOR_ACCOUNT.to_string());
@@ -1022,7 +1021,7 @@ impl MicrosoftAzureBuilder {
credential::CredentialProvider::SASToken(query_pairs)
} else if let Some(sas) = self.sas_key {
credential::CredentialProvider::SASToken(split_sas(&sas)?)
- } else if self.use_azure_cli {
+ } else if self.use_azure_cli.get()? {
credential::CredentialProvider::TokenCredential(
TokenCache::default(),
Box::new(credential::AzureCliCredential::new()),
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index d019e8119..d7b0b86d9 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -26,8 +26,10 @@ pub mod retry;
#[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))]
pub mod token;
+use crate::config::ConfigValue;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, ClientBuilder, Proxy};
+use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
@@ -43,6 +45,14 @@ fn map_client_error(e: reqwest::Error) -> super::Error {
static DEFAULT_USER_AGENT: &str =
concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
+/// Configuration keys for [`ClientOptions`]
+#[derive(PartialEq, Eq, Hash, Clone, Debug, Copy, Deserialize, Serialize)]
+#[non_exhaustive]
+pub enum ClientConfigKey {
+ /// Allow non-TLS, i.e. non-HTTPS connections
+ AllowHttp,
+}
+
/// HTTP client configuration for remote object stores
#[derive(Debug, Clone, Default)]
pub struct ClientOptions {
@@ -51,7 +61,7 @@ pub struct ClientOptions {
default_content_type: Option<String>,
default_headers: Option<HeaderMap>,
proxy_url: Option<String>,
- allow_http: bool,
+ allow_http: ConfigValue<bool>,
allow_insecure: bool,
timeout: Option<Duration>,
connect_timeout: Option<Duration>,
@@ -70,6 +80,21 @@ impl ClientOptions {
Default::default()
}
+ /// Set an option by key
+ pub fn with_config(mut self, key: ClientConfigKey, value: impl Into<String>) -> Self {
+ match key {
+ ClientConfigKey::AllowHttp => self.allow_http.parse(value),
+ }
+ self
+ }
+
+ /// Get an option by key
+ pub fn get_config_value(&self, key: &ClientConfigKey) -> Option<String> {
+ match key {
+ ClientConfigKey::AllowHttp => Some(self.allow_http.to_string()),
+ }
+ }
+
/// Sets the User-Agent header to be used by this client
///
/// Default is based on the version of this crate
@@ -104,7 +129,7 @@ impl ClientOptions {
/// * false (default): Only HTTPS are allowed
/// * true: HTTP and HTTPS are allowed
pub fn with_allow_http(mut self, allow_http: bool) -> Self {
- self.allow_http = allow_http;
+ self.allow_http = allow_http.into();
self
}
/// Allows connections to invalid SSL certificates
@@ -280,7 +305,7 @@ impl ClientOptions {
}
builder
- .https_only(!self.allow_http)
+ .https_only(!self.allow_http.get()?)
.build()
.map_err(map_client_error)
}
diff --git a/object_store/src/config.rs b/object_store/src/config.rs
new file mode 100644
index 000000000..3ecce2e52
--- /dev/null
+++ b/object_store/src/config.rs
@@ -0,0 +1,81 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{Error, Result};
+use std::fmt::{Debug, Display, Formatter};
+
+/// Provides deferred parsing of a value
+///
+/// This allows builders to defer fallibility to build
+#[derive(Debug, Clone)]
+pub enum ConfigValue<T> {
+ Parsed(T),
+ Deferred(String),
+}
+
+impl<T: Display> Display for ConfigValue<T> {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::Parsed(v) => write!(f, "{v}"),
+ Self::Deferred(v) => write!(f, "{v}"),
+ }
+ }
+}
+
+impl<T> From<T> for ConfigValue<T> {
+ fn from(value: T) -> Self {
+ Self::Parsed(value)
+ }
+}
+
+impl<T: Parse + Clone> ConfigValue<T> {
+ pub fn parse(&mut self, v: impl Into<String>) {
+ *self = Self::Deferred(v.into())
+ }
+
+ pub fn get(&self) -> Result<T> {
+ match self {
+ Self::Parsed(v) => Ok(v.clone()),
+ Self::Deferred(v) => T::parse(v),
+ }
+ }
+}
+
+impl<T: Default> Default for ConfigValue<T> {
+ fn default() -> Self {
+ Self::Parsed(T::default())
+ }
+}
+
+/// A value that can be stored in [`ConfigValue`]
+pub trait Parse: Sized {
+ fn parse(v: &str) -> Result<Self>;
+}
+
+impl Parse for bool {
+ fn parse(v: &str) -> Result<Self> {
+ let lower = v.to_ascii_lowercase();
+ match lower.as_str() {
+ "1" | "true" | "on" | "yes" | "y" => Ok(true),
+ "0" | "false" | "off" | "no" | "n" => Ok(false),
+ _ => Err(Error::Generic {
+ store: "Config",
+ source: format!("failed to parse \"{v}\" as boolean").into(),
+ }),
+ }
+ }
+}
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index c31027c07..1390a0140 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -247,6 +247,9 @@ mod client;
#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature =
"http"))]
pub use client::{backoff::BackoffConfig, retry::RetryConfig};
+#[cfg(any(feature = "gcp", feature = "aws", feature = "azure", feature = "http"))]
+mod config;
+
#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))]
mod multipart;
mod util;
diff --git a/object_store/src/util.rs b/object_store/src/util.rs
index 1ec63f219..e5c701dd8 100644
--- a/object_store/src/util.rs
+++ b/object_store/src/util.rs
@@ -185,15 +185,6 @@ fn merge_ranges(
ret
}
-#[allow(dead_code)]
-pub(crate) fn str_is_truthy(val: &str) -> bool {
- val.eq_ignore_ascii_case("1")
- | val.eq_ignore_ascii_case("true")
- | val.eq_ignore_ascii_case("on")
- | val.eq_ignore_ascii_case("yes")
- | val.eq_ignore_ascii_case("y")
-}
-
#[cfg(test)]
mod tests {
use super::*;