This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 5da710b56 feat(rust/driver/snowflake)!: return a `Result` from
`Builder::from_env` when parsing fails (#2334)
5da710b56 is described below
commit 5da710b56ed9c3f55ace8bdadd437b0aa3e2e5bb
Author: Matthijs Brobbel <[email protected]>
AuthorDate: Mon Nov 25 02:00:21 2024 +0100
feat(rust/driver/snowflake)!: return a `Result` from `Builder::from_env`
when parsing fails (#2334)
As suggested in
https://github.com/apache/arrow-adbc/pull/2207#discussion_r1853206066
the `Builder::from_env` methods should return a `Result` when parsing
an environment variable fails, instead of silently ignoring the invalid value.
---
rust/driver/snowflake/README.md | 4 +-
rust/driver/snowflake/src/builder.rs | 36 +++++
rust/driver/snowflake/src/connection/builder.rs | 21 +--
rust/driver/snowflake/src/database/builder.rs | 182 +++++++++++++-----------
rust/driver/snowflake/src/driver/builder.rs | 42 ++++--
rust/driver/snowflake/src/duration.rs | 6 +-
rust/driver/snowflake/tests/driver.rs | 6 +-
7 files changed, 184 insertions(+), 113 deletions(-)
diff --git a/rust/driver/snowflake/README.md b/rust/driver/snowflake/README.md
index 5f8c143f0..114603662 100644
--- a/rust/driver/snowflake/README.md
+++ b/rust/driver/snowflake/README.md
@@ -41,10 +41,10 @@ use arrow_array::{cast::AsArray, types::Decimal128Type};
let mut driver = Driver::try_load()?;
// Construct a database using environment variables
-let mut database = database::Builder::from_env().build(&mut driver)?;
+let mut database = database::Builder::from_env()?.build(&mut driver)?;
// Create a connection to the database
-let mut connection = connection::Builder::from_env().build(&mut database)?;
+let mut connection = connection::Builder::from_env()?.build(&mut database)?;
// Construct a statement to execute a query
let mut statement = connection.new_statement()?;
diff --git a/rust/driver/snowflake/src/builder.rs
b/rust/driver/snowflake/src/builder.rs
index 766cb6bed..28e65dfb2 100644
--- a/rust/driver/snowflake/src/builder.rs
+++ b/rust/driver/snowflake/src/builder.rs
@@ -20,7 +20,11 @@
//!
use std::iter::{Chain, Flatten};
+#[cfg(feature = "env")]
+use std::{env, error::Error as StdError};
+#[cfg(feature = "env")]
+use adbc_core::error::{Error, Status};
use adbc_core::options::OptionValue;
/// An iterator over the builder options.
@@ -48,3 +52,35 @@ impl<T, const COUNT: usize> Iterator for BuilderIter<T,
COUNT> {
self.0.next()
}
}
+
+#[cfg(feature = "env")]
+/// Attempt to read the environment variable with the given `key`, parsing it
+/// using the provided `parse` function.
+///
+/// Returns
+///
+/// - `Ok(None)` when the env variable is not set.
+/// - `Ok(Some(T))` when the env variable is set and the parser succeeds.
+/// - `Err(Error)` when the env variable is set and the parse fails; the
+///   parser's error is returned unchanged.
+/// - `Err(Error)` with [`Status::InvalidArguments`] when the env variable is
+///   set but its value is not valid Unicode (it cannot be handed to `parse`
+///   as a `&str`).
+pub(crate) fn env_parse<T>(
+    key: &str,
+    parse: impl FnOnce(&str) -> Result<T, Error>,
+) -> Result<Option<T>, Error> {
+    match env::var(key) {
+        Ok(value) => parse(&value).map(Some),
+        Err(env::VarError::NotPresent) => Ok(None),
+        // `env::var(key).ok()` would report a set-but-non-Unicode variable as
+        // unset, silently ignoring a misconfiguration. The raw value is not
+        // included in the message in case it holds something sensitive.
+        Err(env::VarError::NotUnicode(_)) => Err(Error::with_message_and_status(
+            format!("invalid value for {key}: not valid Unicode"),
+            Status::InvalidArguments,
+        )),
+    }
+}
+
+#[cfg(feature = "env")]
+/// Attempt to read the environment variable with the given `key`, parsing it
+/// using the provided `parse` function, mapping a parse failure to an
+/// [`Error`] with [`Status::InvalidArguments`] whose message names `key`.
+///
+/// Reading the variable (including the unset case, which yields `Ok(None)`)
+/// is delegated to [`env_parse`].
+pub(crate) fn env_parse_map_err<T, E: StdError>(
+    key: &str,
+    parse: impl FnOnce(&str) -> Result<T, E>,
+) -> Result<Option<T>, Error> {
+    env_parse(key, |value| {
+        parse(value).map_err(|err| {
+            // Generic parser errors (e.g. for `bool` or integers) do not say
+            // which setting was wrong, so prefix the variable name. The value
+            // itself is deliberately omitted: some variables (such as a URI)
+            // may embed credentials.
+            Error::with_message_and_status(
+                format!("invalid value for {key}: {err}"),
+                Status::InvalidArguments,
+            )
+        })
+    })
+}
diff --git a/rust/driver/snowflake/src/connection/builder.rs
b/rust/driver/snowflake/src/connection/builder.rs
index 8b6f0e857..d1b8422db 100644
--- a/rust/driver/snowflake/src/connection/builder.rs
+++ b/rust/driver/snowflake/src/connection/builder.rs
@@ -19,8 +19,6 @@
//!
//!
-#[cfg(feature = "env")]
-use std::env;
use std::fmt;
use adbc_core::{
@@ -30,7 +28,7 @@ use adbc_core::{
};
#[cfg(feature = "env")]
-use crate::database;
+use crate::{builder::env_parse_map_err, database};
use crate::{builder::BuilderIter, Connection, Database};
/// A builder for [`Connection`].
@@ -61,18 +59,21 @@ impl Builder {
/// Construct a builder, setting values based on values of the
/// configuration environment variables.
- pub fn from_env() -> Self {
+ ///
+ /// # Errors
+ ///
+ /// Returns an error when environment variables are set but their values
+ /// fail to parse.
+ pub fn from_env() -> Result<Self> {
#[cfg(feature = "dotenv")]
let _ = dotenvy::dotenv();
- let use_high_precision = env::var(Self::USE_HIGH_PRECISION_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- Self {
+ let use_high_precision =
env_parse_map_err(Self::USE_HIGH_PRECISION_ENV, str::parse)?;
+
+ Ok(Self {
use_high_precision,
..Default::default()
- }
+ })
}
}
diff --git a/rust/driver/snowflake/src/database/builder.rs
b/rust/driver/snowflake/src/database/builder.rs
index 3279927bc..db57c547a 100644
--- a/rust/driver/snowflake/src/database/builder.rs
+++ b/rust/driver/snowflake/src/database/builder.rs
@@ -31,9 +31,12 @@ use adbc_core::{
};
use url::{Host, Url};
-#[cfg(feature = "env")]
-use crate::duration::parse_duration;
use crate::{builder::BuilderIter, Database, Driver};
+#[cfg(feature = "env")]
+use crate::{
+ builder::{env_parse, env_parse_map_err},
+ duration::parse_duration,
+};
/// Authentication types.
#[derive(Copy, Clone, Debug, Default)]
@@ -83,7 +86,15 @@ impl str::FromStr for AuthType {
"auth_jwt" => Ok(Self::Jwt),
"auth_mfa" => Ok(Self::UsernamePasswordMFA),
_ => Err(Error::with_message_and_status(
- format!("invalid auth type: {s}"),
+ format!(
+ "invalid auth type: {s} (possible values: {}, {}, {}, {},
{}, {})",
+ Self::Snowflake,
+ Self::OAuth,
+ Self::ExternalBrowser,
+ Self::Okta,
+ Self::Jwt,
+ Self::UsernamePasswordMFA
+ ),
Status::InvalidArguments,
)),
}
@@ -122,7 +133,11 @@ impl str::FromStr for Protocol {
"https" | "HTTPS" => Ok(Self::Https),
"http" | "HTTP" => Ok(Self::Http),
_ => Err(Error::with_message_and_status(
- format!("invalid protocol type: {s}"),
+ format!(
+ "invalid protocol type: {s} (possible values: {}, {})",
+ Self::Https,
+ Self::Http
+ ),
Status::InvalidArguments,
)),
}
@@ -181,7 +196,16 @@ impl str::FromStr for LogLevel {
"fatal" => Ok(Self::Fatal),
"off" => Ok(Self::Off),
_ => Err(Error::with_message_and_status(
- format!("invalid log level: {s}"),
+ format!(
+ "invalid log level: {s} (possible values: {}, {}, {}, {},
{}, {}, {})",
+ Self::Trace,
+ Self::Debug,
+ Self::Info,
+ Self::Warn,
+ Self::Error,
+ Self::Fatal,
+ Self::Off
+ ),
Status::InvalidArguments,
)),
}
@@ -453,14 +477,16 @@ impl Builder {
/// Construct a builder, setting values based on values of the
/// configuration environment variables.
- pub fn from_env() -> Self {
+ ///
+ /// # Errors
+ ///
+ /// Returns an error when environment variables are set but their values
+ /// fail to parse.
+ pub fn from_env() -> Result<Self> {
#[cfg(feature = "dotenv")]
let _ = dotenvy::dotenv();
- let uri = env::var(Self::URI_ENV)
- .ok()
- .as_deref()
- .and_then(|value| Url::parse(value).ok());
+ let uri = env_parse_map_err(Self::URI_ENV, Url::parse)?;
let username = env::var(Self::USERNAME_ENV).ok();
let password = env::var(Self::PASSWORD_ENV).ok();
let database = env::var(Self::DATABASE_ENV).ok();
@@ -469,86 +495,34 @@ impl Builder {
let role = env::var(Self::ROLE_ENV).ok();
let region = env::var(Self::REGION_ENV).ok();
let account = env::var(Self::ACCOUNT_ENV).ok();
- let protocol = env::var(Self::PROTOCOL_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- let port = env::var(Self::PORT_ENV)
- .ok()
- .and_then(|value| value.parse().ok());
- let host = env::var(Self::HOST_ENV)
- .ok()
- .as_deref()
- .and_then(|value| Host::parse(value).ok());
- let auth_type = env::var(Self::AUTH_TYPE_ENV)
- .ok()
- .and_then(|value| value.parse().ok());
- let login_timeout = env::var(Self::LOGIN_TIMEOUT_ENV)
- .ok()
- .as_deref()
- .and_then(|value| parse_duration(value).ok());
- let request_timeout = env::var(Self::REQUEST_TIMEOUT_ENV)
- .ok()
- .as_deref()
- .and_then(|value| parse_duration(value).ok());
- let jwt_expire_timeout = env::var(Self::JWT_EXPIRE_TIMEOUT_ENV)
- .ok()
- .as_deref()
- .and_then(|value| parse_duration(value).ok());
- let client_timeout = env::var(Self::CLIENT_TIMEOUT_ENV)
- .ok()
- .as_deref()
- .and_then(|value| parse_duration(value).ok());
- let use_high_precision = env::var(Self::USE_HIGH_PRECISION_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
+ let protocol = env_parse(Self::PROTOCOL_ENV, str::parse)?;
+ let port = env_parse_map_err(Self::PORT_ENV, str::parse)?;
+ let host = env_parse_map_err(Self::HOST_ENV, Host::parse)?;
+ let auth_type = env_parse(Self::AUTH_TYPE_ENV, str::parse)?;
+ let login_timeout = env_parse(Self::LOGIN_TIMEOUT_ENV,
parse_duration)?;
+ let request_timeout = env_parse(Self::REQUEST_TIMEOUT_ENV,
parse_duration)?;
+ let jwt_expire_timeout = env_parse(Self::JWT_EXPIRE_TIMEOUT_ENV,
parse_duration)?;
+ let client_timeout = env_parse(Self::CLIENT_TIMEOUT_ENV,
parse_duration)?;
+ let use_high_precision =
env_parse_map_err(Self::USE_HIGH_PRECISION_ENV, str::parse)?;
let application_name = env::var(Self::APPLICATION_NAME_ENV).ok();
- let ssl_skip_verify = env::var(Self::SSL_SKIP_VERIFY_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- let ocsp_fail_open_mode = env::var(Self::OCSP_FAIL_OPEN_MODE_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
+ let ssl_skip_verify = env_parse_map_err(Self::SSL_SKIP_VERIFY_ENV,
str::parse)?;
+ let ocsp_fail_open_mode =
env_parse_map_err(Self::OCSP_FAIL_OPEN_MODE_ENV, str::parse)?;
let auth_token = env::var(Self::AUTH_TOKEN_ENV).ok();
- let auth_okta_url = env::var(Self::AUTH_OKTA_URL_ENV)
- .ok()
- .as_deref()
- .and_then(|value| Url::parse(value).ok());
- let keep_session_alive = env::var(Self::OCSP_FAIL_OPEN_MODE_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- let jwt_private_key = env::var(Self::JWT_PRIVATE_KEY_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
+ let auth_okta_url = env_parse_map_err(Self::AUTH_OKTA_URL_ENV,
Url::parse)?;
+ let keep_session_alive =
env_parse_map_err(Self::KEEP_SESSION_ALIVE_ENV, str::parse)?;
+ let jwt_private_key = env_parse_map_err(Self::JWT_PRIVATE_KEY_ENV,
str::parse)?;
let jwt_private_key_pkcs8_value =
env::var(Self::JWT_PRIVATE_KEY_PKCS8_VALUE_ENV).ok();
let jwt_private_key_pkcs8_password =
env::var(Self::JWT_PRIVATE_KEY_PKCS8_PASSWORD_ENV).ok();
- let disable_telemetry = env::var(Self::DISABLE_TELEMETRY_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- let log_tracing = env::var(Self::LOG_TRACING_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- let client_config_file = env::var(Self::CLIENT_CONFIG_FILE_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- let client_cache_mfa_token = env::var(Self::CLIENT_CACHE_MFA_TOKEN_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- let client_store_temp_creds =
env::var(Self::CLIENT_STORE_TEMP_CREDS_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- Self {
+ let disable_telemetry = env_parse_map_err(Self::DISABLE_TELEMETRY_ENV,
str::parse)?;
+ let log_tracing = env_parse(Self::LOG_TRACING_ENV, str::parse)?;
+ let client_config_file =
env_parse_map_err(Self::CLIENT_CONFIG_FILE_ENV, str::parse)?;
+ let client_cache_mfa_token =
+ env_parse_map_err(Self::CLIENT_CACHE_MFA_TOKEN_ENV, str::parse)?;
+ let client_store_temp_creds =
+ env_parse_map_err(Self::CLIENT_STORE_TEMP_CREDS_ENV, str::parse)?;
+
+ Ok(Self {
uri,
username,
password,
@@ -582,7 +556,7 @@ impl Builder {
client_cache_mfa_token,
client_store_temp_creds,
..Default::default()
- }
+ })
}
}
@@ -1076,3 +1050,39 @@ impl IntoIterator for Builder {
)
}
}
+
+#[cfg(test)]
+#[cfg(feature = "env")]
+mod tests {
+    use std::{env, ffi::OsString};
+
+    use adbc_core::error::Status;
+
+    use super::*;
+
+    /// Snapshot of environment variables that is restored when dropped.
+    ///
+    /// `Builder::from_env` reads the process-wide environment, so a test that
+    /// sets variables must put them back for other tests in the same binary.
+    /// Restoring in `Drop` also covers an assertion failing midway.
+    struct EnvGuard(Vec<(&'static str, Option<OsString>)>);
+
+    impl EnvGuard {
+        fn new(keys: &[&'static str]) -> Self {
+            Self(keys.iter().map(|&key| (key, env::var_os(key))).collect())
+        }
+    }
+
+    impl Drop for EnvGuard {
+        fn drop(&mut self) {
+            for (key, value) in &self.0 {
+                match value {
+                    Some(value) => env::set_var(key, value),
+                    None => env::remove_var(key),
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn from_env_parse_error() {
+        let _guard = EnvGuard::new(&[Builder::LOG_TRACING_ENV, Builder::LOGIN_TIMEOUT_ENV]);
+
+        // Set a value that fails to parse to a LogLevel
+        env::set_var(Builder::LOG_TRACING_ENV, "warning");
+        assert_eq!(
+            Builder::from_env().unwrap_err(),
+            Error::with_message_and_status(
+                "invalid log level: warning (possible values: trace, debug, info, warn, error, fatal, off)",
+                Status::InvalidArguments
+            )
+        );
+        // Fix it to move on
+        env::set_var(Builder::LOG_TRACING_ENV, "warn");
+
+        // Set a value that fails to parse to a duration
+        env::set_var(Builder::LOGIN_TIMEOUT_ENV, "forever");
+        assert_eq!(
+            Builder::from_env().unwrap_err(),
+            Error::with_message_and_status(
+                "invalid duration (valid durations are a sequence of decimal numbers, each with optional fraction and a unit suffix, such as 300ms, 1.5h, 2h45m, valid time units are ns, us, ms, s, m, h)",
+                Status::InvalidArguments
+            )
+        );
+    }
+}
diff --git a/rust/driver/snowflake/src/driver/builder.rs
b/rust/driver/snowflake/src/driver/builder.rs
index d1137cc3e..0ef852bd9 100644
--- a/rust/driver/snowflake/src/driver/builder.rs
+++ b/rust/driver/snowflake/src/driver/builder.rs
@@ -19,14 +19,13 @@
//!
//!
-#[cfg(feature = "env")]
-use std::env;
-
use adbc_core::{
error::{Error, Result},
options::AdbcVersion,
};
+#[cfg(feature = "env")]
+use crate::builder::env_parse;
use crate::Driver;
/// A builder for [`Driver`].
@@ -47,15 +46,18 @@ impl Builder {
/// Construct a builder, setting values based on values of the
/// configuration environment variables.
- pub fn from_env() -> Self {
+ ///
+ /// # Errors
+ ///
+ /// Returns an error when environment variables are set but their values
+ /// fail to parse.
+ pub fn from_env() -> Result<Self> {
#[cfg(feature = "dotenv")]
let _ = dotenvy::dotenv();
- let adbc_version = env::var(Self::ADBC_VERSION_ENV)
- .ok()
- .as_deref()
- .and_then(|value| value.parse().ok());
- Self { adbc_version }
+ let adbc_version = env_parse(Self::ADBC_VERSION_ENV, str::parse)?;
+
+ Ok(Self { adbc_version })
}
}
@@ -79,3 +81,25 @@ impl TryFrom<Builder> for Driver {
value.try_load()
}
}
+
+#[cfg(test)]
+#[cfg(feature = "env")]
+mod tests {
+    use std::{env, ffi::OsString};
+
+    use adbc_core::error::Status;
+
+    use super::*;
+
+    /// Snapshot of environment variables that is restored when dropped.
+    ///
+    /// `Builder::from_env` reads the process-wide environment, so a test that
+    /// sets variables must put them back for other tests in the same binary.
+    /// Restoring in `Drop` also covers an assertion failing midway.
+    struct EnvGuard(Vec<(&'static str, Option<OsString>)>);
+
+    impl EnvGuard {
+        fn new(keys: &[&'static str]) -> Self {
+            Self(keys.iter().map(|&key| (key, env::var_os(key))).collect())
+        }
+    }
+
+    impl Drop for EnvGuard {
+        fn drop(&mut self) {
+            for (key, value) in &self.0 {
+                match value {
+                    Some(value) => env::set_var(key, value),
+                    None => env::remove_var(key),
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn from_env_parse_error() {
+        let _guard = EnvGuard::new(&[Builder::ADBC_VERSION_ENV]);
+
+        // Set a value that fails to parse to an AdbcVersion
+        env::set_var(Builder::ADBC_VERSION_ENV, "?");
+        assert_eq!(
+            Builder::from_env().unwrap_err(),
+            Error::with_message_and_status("Unknown ADBC version: ?", Status::InvalidArguments)
+        );
+    }
+}
diff --git a/rust/driver/snowflake/src/duration.rs
b/rust/driver/snowflake/src/duration.rs
index 642851e65..8afaca22c 100644
--- a/rust/driver/snowflake/src/duration.rs
+++ b/rust/driver/snowflake/src/duration.rs
@@ -39,11 +39,11 @@ fn arg_err(err: impl StdError) -> Error {
}
+/// Error for a duration whose value overflows. The message names the
+/// duration so it is self-explanatory when it surfaces outside this module
+/// (e.g. from a builder's `from_env`).
fn overflow() -> Error {
- invalid_arg("overflow")
+ invalid_arg("duration overflow")
}
+/// `Err` for input that is not a valid duration. The message spells out the
+/// accepted syntax so callers can surface it to users unchanged.
fn bad_input<T>() -> Result<T> {
- Err(invalid_arg("bad input"))
+ Err(invalid_arg("invalid duration (valid durations are a sequence of
decimal numbers, each with optional fraction and a unit suffix, such as 300ms,
1.5h, 2h45m, valid time units are ns, us, ms, s, m, h)"))
}
/// Parse the given string to a [`Duration`], returning an eror when parsing
@@ -260,7 +260,7 @@ mod tests {
+ Duration::from_secs(48)
+ Duration::from_nanos(372539828))
);
- let bad_input = Err(invalid_arg("bad input"));
+ let bad_input = Err(invalid_arg("invalid duration (valid durations are
a sequence of decimal numbers, each with optional fraction and a unit suffix,
such as 300ms, 1.5h, 2h45m, valid time units are ns, us, ms, s, m, h)"));
assert_eq!(parse_duration(""), bad_input);
assert_eq!(parse_duration("3"), bad_input);
assert_eq!(parse_duration("-"), bad_input);
diff --git a/rust/driver/snowflake/tests/driver.rs
b/rust/driver/snowflake/tests/driver.rs
index 8f2a28d72..300bc3225 100644
--- a/rust/driver/snowflake/tests/driver.rs
+++ b/rust/driver/snowflake/tests/driver.rs
@@ -47,16 +47,16 @@ mod tests {
const ADBC_VERSION: AdbcVersion = AdbcVersion::V110;
static DRIVER: LazyLock<Result<Driver>> = LazyLock::new(|| {
- driver::Builder::from_env()
+ // `from_env` fails when a configuration variable is set but unparseable;
+ // `?` makes that error the cached value of this static for every access.
+ driver::Builder::from_env()?
.with_adbc_version(ADBC_VERSION)
.try_load()
});
static DATABASE: LazyLock<Result<Database>> =
- LazyLock::new(|| database::Builder::from_env().build(&mut
DRIVER.deref().clone()?));
+ // As for `DRIVER`: an environment parse error becomes the cached `Err`.
+ LazyLock::new(|| database::Builder::from_env()?.build(&mut
DRIVER.deref().clone()?));
static CONNECTION: LazyLock<Result<Connection>> =
- LazyLock::new(|| connection::Builder::from_env().build(&mut
DATABASE.deref().clone()?));
+ // As for `DRIVER`: an environment parse error becomes the cached `Err`.
+ LazyLock::new(|| connection::Builder::from_env()?.build(&mut
DATABASE.deref().clone()?));
fn with_database(func: impl FnOnce(Database) -> Result<()>) -> Result<()> {
DATABASE.deref().clone().and_then(func)