This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 5da710b56 feat(rust/driver/snowflake)!: return a `Result` from `Builder::from_env` when parsing fails (#2334)
5da710b56 is described below

commit 5da710b56ed9c3f55ace8bdadd437b0aa3e2e5bb
Author: Matthijs Brobbel <[email protected]>
AuthorDate: Mon Nov 25 02:00:21 2024 +0100

    feat(rust/driver/snowflake)!: return a `Result` from `Builder::from_env` when parsing fails (#2334)
    
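    As suggested in
    https://github.com/apache/arrow-adbc/pull/2207#discussion_r1853206066,
    the `Builder::from_env` methods should return a `Result`, so that an
    environment variable that is set but fails to parse is reported as an
    error instead of being silently ignored.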
---
 rust/driver/snowflake/README.md                 |   4 +-
 rust/driver/snowflake/src/builder.rs            |  36 +++++
 rust/driver/snowflake/src/connection/builder.rs |  21 +--
 rust/driver/snowflake/src/database/builder.rs   | 182 +++++++++++++-----------
 rust/driver/snowflake/src/driver/builder.rs     |  42 ++++--
 rust/driver/snowflake/src/duration.rs           |   6 +-
 rust/driver/snowflake/tests/driver.rs           |   6 +-
 7 files changed, 184 insertions(+), 113 deletions(-)

diff --git a/rust/driver/snowflake/README.md b/rust/driver/snowflake/README.md
index 5f8c143f0..114603662 100644
--- a/rust/driver/snowflake/README.md
+++ b/rust/driver/snowflake/README.md
@@ -41,10 +41,10 @@ use arrow_array::{cast::AsArray, types::Decimal128Type};
 let mut driver = Driver::try_load()?;
 
 // Construct a database using environment variables
-let mut database = database::Builder::from_env().build(&mut driver)?;
+let mut database = database::Builder::from_env()?.build(&mut driver)?;
 
 // Create a connection to the database
-let mut connection = connection::Builder::from_env().build(&mut database)?;
+let mut connection = connection::Builder::from_env()?.build(&mut database)?;
 
 // Construct a statement to execute a query
 let mut statement = connection.new_statement()?;
diff --git a/rust/driver/snowflake/src/builder.rs 
b/rust/driver/snowflake/src/builder.rs
index 766cb6bed..28e65dfb2 100644
--- a/rust/driver/snowflake/src/builder.rs
+++ b/rust/driver/snowflake/src/builder.rs
@@ -20,7 +20,11 @@
 //!
 
 use std::iter::{Chain, Flatten};
+#[cfg(feature = "env")]
+use std::{env, error::Error as StdError};
 
+#[cfg(feature = "env")]
+use adbc_core::error::{Error, Status};
 use adbc_core::options::OptionValue;
 
 /// An iterator over the builder options.
@@ -48,3 +52,35 @@ impl<T, const COUNT: usize> Iterator for BuilderIter<T, 
COUNT> {
         self.0.next()
     }
 }
+
+#[cfg(feature = "env")]
+/// Attempt to read the environment variable with the given `key`, parsing it
+/// using the provided `parse` function.
+///
+/// Returns
+///
+/// - `Ok(None)` when the env variable is not set.
+/// - `Ok(Some(T))` when the env variable is set and the parser succeeds.
+/// - `Err(Error)` when the env variable is set and the parse fails.
+pub(crate) fn env_parse<T>(
+    key: &str,
+    parse: impl FnOnce(&str) -> Result<T, Error>,
+) -> Result<Option<T>, Error> {
+    env::var(key).ok().as_deref().map(parse).transpose()
+}
+
+#[cfg(feature = "env")]
+/// Attempt to read the environment variable with the given `key`, parsing it
+/// using the provided `parse` function, mapping the parse result to an
+/// [`Error`] with [`Status::InvalidArguments`].
+pub(crate) fn env_parse_map_err<T, E: StdError>(
+    key: &str,
+    parse: impl FnOnce(&str) -> Result<T, E>,
+) -> Result<Option<T>, Error> {
+    env::var(key)
+        .ok()
+        .as_deref()
+        .map(parse)
+        .transpose()
+        .map_err(|err| Error::with_message_and_status(err.to_string(), 
Status::InvalidArguments))
+}
diff --git a/rust/driver/snowflake/src/connection/builder.rs 
b/rust/driver/snowflake/src/connection/builder.rs
index 8b6f0e857..d1b8422db 100644
--- a/rust/driver/snowflake/src/connection/builder.rs
+++ b/rust/driver/snowflake/src/connection/builder.rs
@@ -19,8 +19,6 @@
 //!
 //!
 
-#[cfg(feature = "env")]
-use std::env;
 use std::fmt;
 
 use adbc_core::{
@@ -30,7 +28,7 @@ use adbc_core::{
 };
 
 #[cfg(feature = "env")]
-use crate::database;
+use crate::{builder::env_parse_map_err, database};
 use crate::{builder::BuilderIter, Connection, Database};
 
 /// A builder for [`Connection`].
@@ -61,18 +59,21 @@ impl Builder {
 
     /// Construct a builder, setting values based on values of the
     /// configuration environment variables.
-    pub fn from_env() -> Self {
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when environment variables are set but their values
+    /// fail to parse.
+    pub fn from_env() -> Result<Self> {
         #[cfg(feature = "dotenv")]
         let _ = dotenvy::dotenv();
 
-        let use_high_precision = env::var(Self::USE_HIGH_PRECISION_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        Self {
+        let use_high_precision = 
env_parse_map_err(Self::USE_HIGH_PRECISION_ENV, str::parse)?;
+
+        Ok(Self {
             use_high_precision,
             ..Default::default()
-        }
+        })
     }
 }
 
diff --git a/rust/driver/snowflake/src/database/builder.rs 
b/rust/driver/snowflake/src/database/builder.rs
index 3279927bc..db57c547a 100644
--- a/rust/driver/snowflake/src/database/builder.rs
+++ b/rust/driver/snowflake/src/database/builder.rs
@@ -31,9 +31,12 @@ use adbc_core::{
 };
 use url::{Host, Url};
 
-#[cfg(feature = "env")]
-use crate::duration::parse_duration;
 use crate::{builder::BuilderIter, Database, Driver};
+#[cfg(feature = "env")]
+use crate::{
+    builder::{env_parse, env_parse_map_err},
+    duration::parse_duration,
+};
 
 /// Authentication types.
 #[derive(Copy, Clone, Debug, Default)]
@@ -83,7 +86,15 @@ impl str::FromStr for AuthType {
             "auth_jwt" => Ok(Self::Jwt),
             "auth_mfa" => Ok(Self::UsernamePasswordMFA),
             _ => Err(Error::with_message_and_status(
-                format!("invalid auth type: {s}"),
+                format!(
+                    "invalid auth type: {s} (possible values: {}, {}, {}, {}, 
{}, {})",
+                    Self::Snowflake,
+                    Self::OAuth,
+                    Self::ExternalBrowser,
+                    Self::Okta,
+                    Self::Jwt,
+                    Self::UsernamePasswordMFA
+                ),
                 Status::InvalidArguments,
             )),
         }
@@ -122,7 +133,11 @@ impl str::FromStr for Protocol {
             "https" | "HTTPS" => Ok(Self::Https),
             "http" | "HTTP" => Ok(Self::Http),
             _ => Err(Error::with_message_and_status(
-                format!("invalid protocol type: {s}"),
+                format!(
+                    "invalid protocol type: {s} (possible values: {}, {})",
+                    Self::Https,
+                    Self::Http
+                ),
                 Status::InvalidArguments,
             )),
         }
@@ -181,7 +196,16 @@ impl str::FromStr for LogLevel {
             "fatal" => Ok(Self::Fatal),
             "off" => Ok(Self::Off),
             _ => Err(Error::with_message_and_status(
-                format!("invalid log level: {s}"),
+                format!(
+                    "invalid log level: {s} (possible values: {}, {}, {}, {}, 
{}, {}, {})",
+                    Self::Trace,
+                    Self::Debug,
+                    Self::Info,
+                    Self::Warn,
+                    Self::Error,
+                    Self::Fatal,
+                    Self::Off
+                ),
                 Status::InvalidArguments,
             )),
         }
@@ -453,14 +477,16 @@ impl Builder {
 
     /// Construct a builder, setting values based on values of the
     /// configuration environment variables.
-    pub fn from_env() -> Self {
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when environment variables are set but their values
+    /// fail to parse.
+    pub fn from_env() -> Result<Self> {
         #[cfg(feature = "dotenv")]
         let _ = dotenvy::dotenv();
 
-        let uri = env::var(Self::URI_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| Url::parse(value).ok());
+        let uri = env_parse_map_err(Self::URI_ENV, Url::parse)?;
         let username = env::var(Self::USERNAME_ENV).ok();
         let password = env::var(Self::PASSWORD_ENV).ok();
         let database = env::var(Self::DATABASE_ENV).ok();
@@ -469,86 +495,34 @@ impl Builder {
         let role = env::var(Self::ROLE_ENV).ok();
         let region = env::var(Self::REGION_ENV).ok();
         let account = env::var(Self::ACCOUNT_ENV).ok();
-        let protocol = env::var(Self::PROTOCOL_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        let port = env::var(Self::PORT_ENV)
-            .ok()
-            .and_then(|value| value.parse().ok());
-        let host = env::var(Self::HOST_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| Host::parse(value).ok());
-        let auth_type = env::var(Self::AUTH_TYPE_ENV)
-            .ok()
-            .and_then(|value| value.parse().ok());
-        let login_timeout = env::var(Self::LOGIN_TIMEOUT_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| parse_duration(value).ok());
-        let request_timeout = env::var(Self::REQUEST_TIMEOUT_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| parse_duration(value).ok());
-        let jwt_expire_timeout = env::var(Self::JWT_EXPIRE_TIMEOUT_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| parse_duration(value).ok());
-        let client_timeout = env::var(Self::CLIENT_TIMEOUT_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| parse_duration(value).ok());
-        let use_high_precision = env::var(Self::USE_HIGH_PRECISION_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
+        let protocol = env_parse(Self::PROTOCOL_ENV, str::parse)?;
+        let port = env_parse_map_err(Self::PORT_ENV, str::parse)?;
+        let host = env_parse_map_err(Self::HOST_ENV, Host::parse)?;
+        let auth_type = env_parse(Self::AUTH_TYPE_ENV, str::parse)?;
+        let login_timeout = env_parse(Self::LOGIN_TIMEOUT_ENV, 
parse_duration)?;
+        let request_timeout = env_parse(Self::REQUEST_TIMEOUT_ENV, 
parse_duration)?;
+        let jwt_expire_timeout = env_parse(Self::JWT_EXPIRE_TIMEOUT_ENV, 
parse_duration)?;
+        let client_timeout = env_parse(Self::CLIENT_TIMEOUT_ENV, 
parse_duration)?;
+        let use_high_precision = 
env_parse_map_err(Self::USE_HIGH_PRECISION_ENV, str::parse)?;
         let application_name = env::var(Self::APPLICATION_NAME_ENV).ok();
-        let ssl_skip_verify = env::var(Self::SSL_SKIP_VERIFY_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        let ocsp_fail_open_mode = env::var(Self::OCSP_FAIL_OPEN_MODE_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
+        let ssl_skip_verify = env_parse_map_err(Self::SSL_SKIP_VERIFY_ENV, 
str::parse)?;
+        let ocsp_fail_open_mode = 
env_parse_map_err(Self::OCSP_FAIL_OPEN_MODE_ENV, str::parse)?;
         let auth_token = env::var(Self::AUTH_TOKEN_ENV).ok();
-        let auth_okta_url = env::var(Self::AUTH_OKTA_URL_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| Url::parse(value).ok());
-        let keep_session_alive = env::var(Self::OCSP_FAIL_OPEN_MODE_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        let jwt_private_key = env::var(Self::JWT_PRIVATE_KEY_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
+        let auth_okta_url = env_parse_map_err(Self::AUTH_OKTA_URL_ENV, 
Url::parse)?;
+        let keep_session_alive = 
env_parse_map_err(Self::OCSP_FAIL_OPEN_MODE_ENV, str::parse)?;
+        let jwt_private_key = env_parse_map_err(Self::JWT_PRIVATE_KEY_ENV, 
str::parse)?;
         let jwt_private_key_pkcs8_value = 
env::var(Self::JWT_PRIVATE_KEY_PKCS8_VALUE_ENV).ok();
         let jwt_private_key_pkcs8_password =
             env::var(Self::JWT_PRIVATE_KEY_PKCS8_PASSWORD_ENV).ok();
-        let disable_telemetry = env::var(Self::DISABLE_TELEMETRY_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        let log_tracing = env::var(Self::LOG_TRACING_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        let client_config_file = env::var(Self::CLIENT_CONFIG_FILE_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        let client_cache_mfa_token = env::var(Self::CLIENT_CACHE_MFA_TOKEN_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        let client_store_temp_creds = 
env::var(Self::CLIENT_STORE_TEMP_CREDS_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        Self {
+        let disable_telemetry = env_parse_map_err(Self::DISABLE_TELEMETRY_ENV, 
str::parse)?;
+        let log_tracing = env_parse(Self::LOG_TRACING_ENV, str::parse)?;
+        let client_config_file = 
env_parse_map_err(Self::CLIENT_CONFIG_FILE_ENV, str::parse)?;
+        let client_cache_mfa_token =
+            env_parse_map_err(Self::CLIENT_CACHE_MFA_TOKEN_ENV, str::parse)?;
+        let client_store_temp_creds =
+            env_parse_map_err(Self::CLIENT_STORE_TEMP_CREDS_ENV, str::parse)?;
+
+        Ok(Self {
             uri,
             username,
             password,
@@ -582,7 +556,7 @@ impl Builder {
             client_cache_mfa_token,
             client_store_temp_creds,
             ..Default::default()
-        }
+        })
     }
 }
 
@@ -1076,3 +1050,39 @@ impl IntoIterator for Builder {
         )
     }
 }
+
+#[cfg(test)]
+#[cfg(feature = "env")]
+mod tests {
+    use std::env;
+
+    use adbc_core::error::Status;
+
+    use super::*;
+
+    #[test]
+    fn from_env_parse_error() {
+        // Set a value that fails to parse to a LogLevel
+        env::set_var(Builder::LOG_TRACING_ENV, "warning");
+        let result = Builder::from_env();
+        assert!(result.is_err());
+        assert_eq!(
+            result.unwrap_err(),
+            Error::with_message_and_status(
+                "invalid log level: warning (possible values: trace, debug, 
info, warn, error, fatal, off)",
+                Status::InvalidArguments
+            )
+        );
+        // Fix it to move on
+        env::set_var(Builder::LOG_TRACING_ENV, "warn");
+
+        // Set a value that fails to parse to a duration
+        env::set_var(Builder::LOGIN_TIMEOUT_ENV, "forever");
+        let result = Builder::from_env();
+        assert!(result.is_err());
+        assert_eq!(
+            result.unwrap_err(),
+            Error::with_message_and_status("invalid duration (valid durations 
are a sequence of decimal numbers, each with optional fraction and a unit 
suffix, such as 300ms, 1.5h, 2h45m, valid time units are ns, us, ms, s, m, h)", 
Status::InvalidArguments)
+        );
+    }
+}
diff --git a/rust/driver/snowflake/src/driver/builder.rs 
b/rust/driver/snowflake/src/driver/builder.rs
index d1137cc3e..0ef852bd9 100644
--- a/rust/driver/snowflake/src/driver/builder.rs
+++ b/rust/driver/snowflake/src/driver/builder.rs
@@ -19,14 +19,13 @@
 //!
 //!
 
-#[cfg(feature = "env")]
-use std::env;
-
 use adbc_core::{
     error::{Error, Result},
     options::AdbcVersion,
 };
 
+#[cfg(feature = "env")]
+use crate::builder::env_parse;
 use crate::Driver;
 
 /// A builder for [`Driver`].
@@ -47,15 +46,18 @@ impl Builder {
 
     /// Construct a builder, setting values based on values of the
     /// configuration environment variables.
-    pub fn from_env() -> Self {
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when environment variables are set but their values
+    /// fail to parse.
+    pub fn from_env() -> Result<Self> {
         #[cfg(feature = "dotenv")]
         let _ = dotenvy::dotenv();
 
-        let adbc_version = env::var(Self::ADBC_VERSION_ENV)
-            .ok()
-            .as_deref()
-            .and_then(|value| value.parse().ok());
-        Self { adbc_version }
+        let adbc_version = env_parse(Self::ADBC_VERSION_ENV, str::parse)?;
+
+        Ok(Self { adbc_version })
     }
 }
 
@@ -79,3 +81,25 @@ impl TryFrom<Builder> for Driver {
         value.try_load()
     }
 }
+
+#[cfg(test)]
+#[cfg(feature = "env")]
+mod tests {
+    use std::env;
+
+    use adbc_core::error::Status;
+
+    use super::*;
+
+    #[test]
+    fn from_env_parse_error() {
+        // Set a value that fails to parse to an AdbcVersion
+        env::set_var(Builder::ADBC_VERSION_ENV, "?");
+        let result = Builder::from_env();
+        assert!(result.is_err());
+        assert_eq!(
+            result.unwrap_err(),
+            Error::with_message_and_status("Unknown ADBC version: ?", 
Status::InvalidArguments)
+        );
+    }
+}
diff --git a/rust/driver/snowflake/src/duration.rs 
b/rust/driver/snowflake/src/duration.rs
index 642851e65..8afaca22c 100644
--- a/rust/driver/snowflake/src/duration.rs
+++ b/rust/driver/snowflake/src/duration.rs
@@ -39,11 +39,11 @@ fn arg_err(err: impl StdError) -> Error {
 }
 
 fn overflow() -> Error {
-    invalid_arg("overflow")
+    invalid_arg("duration overflow")
 }
 
 fn bad_input<T>() -> Result<T> {
-    Err(invalid_arg("bad input"))
+    Err(invalid_arg("invalid duration (valid durations are a sequence of 
decimal numbers, each with optional fraction and a unit suffix, such as 300ms, 
1.5h, 2h45m, valid time units are ns, us, ms, s, m, h)"))
 }
 
 /// Parse the given string to a [`Duration`], returning an error when parsing
@@ -260,7 +260,7 @@ mod tests {
                 + Duration::from_secs(48)
                 + Duration::from_nanos(372539828))
         );
-        let bad_input = Err(invalid_arg("bad input"));
+        let bad_input = Err(invalid_arg("invalid duration (valid durations are 
a sequence of decimal numbers, each with optional fraction and a unit suffix, 
such as 300ms, 1.5h, 2h45m, valid time units are ns, us, ms, s, m, h)"));
         assert_eq!(parse_duration(""), bad_input);
         assert_eq!(parse_duration("3"), bad_input);
         assert_eq!(parse_duration("-"), bad_input);
diff --git a/rust/driver/snowflake/tests/driver.rs 
b/rust/driver/snowflake/tests/driver.rs
index 8f2a28d72..300bc3225 100644
--- a/rust/driver/snowflake/tests/driver.rs
+++ b/rust/driver/snowflake/tests/driver.rs
@@ -47,16 +47,16 @@ mod tests {
     const ADBC_VERSION: AdbcVersion = AdbcVersion::V110;
 
     static DRIVER: LazyLock<Result<Driver>> = LazyLock::new(|| {
-        driver::Builder::from_env()
+        driver::Builder::from_env()?
             .with_adbc_version(ADBC_VERSION)
             .try_load()
     });
 
     static DATABASE: LazyLock<Result<Database>> =
-        LazyLock::new(|| database::Builder::from_env().build(&mut 
DRIVER.deref().clone()?));
+        LazyLock::new(|| database::Builder::from_env()?.build(&mut 
DRIVER.deref().clone()?));
 
     static CONNECTION: LazyLock<Result<Connection>> =
-        LazyLock::new(|| connection::Builder::from_env().build(&mut 
DATABASE.deref().clone()?));
+        LazyLock::new(|| connection::Builder::from_env()?.build(&mut 
DATABASE.deref().clone()?));
 
     fn with_database(func: impl FnOnce(Database) -> Result<()>) -> Result<()> {
         DATABASE.deref().clone().and_then(func)

Reply via email to