This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 933d3489dc Lower GCP token min_ttl to 4 minutes and add backoff to token refresh logic (#6638)
933d3489dc is described below

commit 933d3489dcdb4d00ecc5cbec618d0f8f8aadd4e4
Author: Micah Wylde <mi...@micahw.com>
AuthorDate: Wed Oct 30 00:22:22 2024 -0700

    Lower GCP token min_ttl to 4 minutes and add backoff to token refresh logic (#6638)
---
 object_store/src/client/mod.rs   |  2 +-
 object_store/src/client/token.rs | 89 ++++++++++++++++++++++++++++++++++++----
 object_store/src/gcp/builder.rs  | 25 +++++++----
 3 files changed, 99 insertions(+), 17 deletions(-)

diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index b65fea7436..76d1c1f22f 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -774,7 +774,7 @@ mod cloud {
         }
 
         /// Override the minimum remaining TTL for a cached token to be used
-        #[cfg(feature = "aws")]
+        #[cfg(any(feature = "aws", feature = "gcp"))]
         pub(crate) fn with_min_ttl(mut self, min_ttl: Duration) -> Self {
             self.cache = self.cache.with_min_ttl(min_ttl);
             self
diff --git a/object_store/src/client/token.rs b/object_store/src/client/token.rs
index f7294190f5..81ffc110ac 100644
--- a/object_store/src/client/token.rs
+++ b/object_store/src/client/token.rs
@@ -33,8 +33,9 @@ pub(crate) struct TemporaryToken<T> {
 /// [`TemporaryToken`] based on its expiry
 #[derive(Debug)]
 pub(crate) struct TokenCache<T> {
-    cache: Mutex<Option<TemporaryToken<T>>>,
+    cache: Mutex<Option<(TemporaryToken<T>, Instant)>>,
     min_ttl: Duration,
+    fetch_backoff: Duration,
 }
 
 impl<T> Default for TokenCache<T> {
@@ -42,13 +43,16 @@ impl<T> Default for TokenCache<T> {
         Self {
             cache: Default::default(),
             min_ttl: Duration::from_secs(300),
+            // How long to wait before re-attempting a token fetch after receiving one that
+            // is still within the min-ttl
+            fetch_backoff: Duration::from_millis(100),
         }
     }
 }
 
 impl<T: Clone + Send> TokenCache<T> {
     /// Override the minimum remaining TTL for a cached token to be used
-    #[cfg(feature = "aws")]
+    #[cfg(any(feature = "aws", feature = "gcp"))]
     pub(crate) fn with_min_ttl(self, min_ttl: Duration) -> Self {
         Self { min_ttl, ..self }
     }
@@ -61,20 +65,91 @@ impl<T: Clone + Send> TokenCache<T> {
         let now = Instant::now();
         let mut locked = self.cache.lock().await;
 
-        if let Some(cached) = locked.as_ref() {
+        if let Some((cached, fetched_at)) = locked.as_ref() {
             match cached.expiry {
-                Some(ttl) if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl => {
-                    return Ok(cached.token.clone());
+                Some(ttl) => {
+                    if ttl.checked_duration_since(now).unwrap_or_default() > self.min_ttl ||
+                        // if we've recently attempted to fetch this token and it's not actually
+                        // expired, we'll wait to re-fetch it and return the cached one
+                        (fetched_at.elapsed() < self.fetch_backoff && ttl.checked_duration_since(now).is_some())
+                    {
+                        return Ok(cached.token.clone());
+                    }
                 }
                 None => return Ok(cached.token.clone()),
-                _ => (),
             }
         }
 
         let cached = f().await?;
         let token = cached.token.clone();
-        *locked = Some(cached);
+        *locked = Some((cached, Instant::now()));
 
         Ok(token)
     }
 }
+
+#[cfg(test)]
+mod test {
+    use crate::client::token::{TemporaryToken, TokenCache};
+    use std::sync::atomic::{AtomicU32, Ordering};
+    use std::time::{Duration, Instant};
+
+    // Helper function to create a token with a specific expiry duration from now
+    fn create_token(expiry_duration: Option<Duration>) -> TemporaryToken<String> {
+        TemporaryToken {
+            token: "test_token".to_string(),
+            expiry: expiry_duration.map(|d| Instant::now() + d),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_expired_token_is_refreshed() {
+        let cache = TokenCache::default();
+        static COUNTER: AtomicU32 = AtomicU32::new(0);
+
+        async fn get_token() -> Result<TemporaryToken<String>, String> {
+            COUNTER.fetch_add(1, Ordering::SeqCst);
+            Ok::<_, String>(create_token(Some(Duration::from_secs(0))))
+        }
+
+        // Should fetch initial token
+        let _ = cache.get_or_insert_with(get_token).await.unwrap();
+        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
+
+        tokio::time::sleep(Duration::from_millis(2)).await;
+
+        // Token is expired, so should fetch again
+        let _ = cache.get_or_insert_with(get_token).await.unwrap();
+        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
+    }
+
+    #[tokio::test]
+    async fn test_min_ttl_causes_refresh() {
+        let cache = TokenCache {
+            cache: Default::default(),
+            min_ttl: Duration::from_secs(1),
+            fetch_backoff: Duration::from_millis(1),
+        };
+
+        static COUNTER: AtomicU32 = AtomicU32::new(0);
+
+        async fn get_token() -> Result<TemporaryToken<String>, String> {
+            COUNTER.fetch_add(1, Ordering::SeqCst);
+            Ok::<_, String>(create_token(Some(Duration::from_millis(100))))
+        }
+
+        // Initial fetch
+        let _ = cache.get_or_insert_with(get_token).await.unwrap();
+        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
+
+        // Should not fetch again since not expired and within fetch_backoff
+        let _ = cache.get_or_insert_with(get_token).await.unwrap();
+        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
+
+        tokio::time::sleep(Duration::from_millis(2)).await;
+
+        // Should fetch, since we've passed fetch_backoff
+        let _ = cache.get_or_insert_with(get_token).await.unwrap();
+        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
+    }
+}
diff --git a/object_store/src/gcp/builder.rs b/object_store/src/gcp/builder.rs
index 26cc8211d2..fac923c4b9 100644
--- a/object_store/src/gcp/builder.rs
+++ b/object_store/src/gcp/builder.rs
@@ -30,10 +30,13 @@ use serde::{Deserialize, Serialize};
 use snafu::{OptionExt, ResultExt, Snafu};
 use std::str::FromStr;
 use std::sync::Arc;
+use std::time::Duration;
 use url::Url;
 
 use super::credential::{AuthorizedUserSigningCredentials, InstanceSigningCredentialProvider};
 
+const TOKEN_MIN_TTL: Duration = Duration::from_secs(4 * 60);
+
 #[derive(Debug, Snafu)]
 enum Error {
     #[snafu(display("Missing bucket name"))]
@@ -463,13 +466,14 @@ impl GoogleCloudStorageBuilder {
             )) as _
         } else if let Some(credentials) = application_default_credentials.clone() {
             match credentials {
-                ApplicationDefaultCredentials::AuthorizedUser(token) => {
-                    Arc::new(TokenCredentialProvider::new(
+                ApplicationDefaultCredentials::AuthorizedUser(token) => Arc::new(
+                    TokenCredentialProvider::new(
                         token,
                         self.client_options.client()?,
                         self.retry_config.clone(),
-                    )) as _
-                }
+                    )
+                    .with_min_ttl(TOKEN_MIN_TTL),
+                ) as _,
                 ApplicationDefaultCredentials::ServiceAccount(token) => {
                     Arc::new(TokenCredentialProvider::new(
                         token.token_provider()?,
@@ -479,11 +483,14 @@ impl GoogleCloudStorageBuilder {
                 }
             }
         } else {
-            Arc::new(TokenCredentialProvider::new(
-                InstanceCredentialProvider::default(),
-                self.client_options.metadata_client()?,
-                self.retry_config.clone(),
-            )) as _
+            Arc::new(
+                TokenCredentialProvider::new(
+                    InstanceCredentialProvider::default(),
+                    self.client_options.metadata_client()?,
+                    self.retry_config.clone(),
+                )
+                .with_min_ttl(TOKEN_MIN_TTL),
+            ) as _
         };
 
         let signing_credentials = if let Some(signing_credentials) = self.signing_credentials {

Reply via email to