kszucs commented on code in PR #7185:
URL: https://github.com/apache/opendal/pull/7185#discussion_r2791970511


##########
core/services/huggingface/src/core.rs:
##########
@@ -18,330 +18,442 @@
 use std::fmt::Debug;
 use std::sync::Arc;
 
+use bytes::Buf;
 use bytes::Bytes;
 use http::Request;
 use http::Response;
 use http::header;
-use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
 use serde::Deserialize;
 
-use super::backend::RepoType;
+#[cfg(feature = "xet")]
+use xet_utils::auth::TokenRefresher;
+
+use super::error::parse_error;
+use super::uri::HfRepo;
 use opendal_core::raw::*;
 use opendal_core::*;
 
-fn percent_encode_revision(revision: &str) -> String {
-    utf8_percent_encode(revision, NON_ALPHANUMERIC).to_string()
+/// API payload structures for commit operations
+#[derive(Debug, serde::Serialize)]
+pub(super) struct CommitFile {
+    pub path: String,
+    pub content: String,
+    pub encoding: String,
 }
 
-pub struct HuggingfaceCore {
-    pub info: Arc<AccessorInfo>,
+#[derive(Debug, serde::Serialize)]
+pub(super) struct LfsFile {
+    pub path: String,
+    pub oid: String,
+    pub algo: String,
+    pub size: u64,
+}
 
-    pub repo_type: RepoType,
-    pub repo_id: String,
-    pub revision: String,
-    pub root: String,
-    pub token: Option<String>,
-    pub endpoint: String,
+#[derive(Clone, Debug, serde::Serialize)]
+pub(super) struct DeletedFile {
+    pub path: String,
 }
 
-impl Debug for HuggingfaceCore {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("HuggingfaceCore")
-            .field("repo_type", &self.repo_type)
-            .field("repo_id", &self.repo_id)
-            .field("revision", &self.revision)
-            .field("root", &self.root)
-            .field("endpoint", &self.endpoint)
-            .finish_non_exhaustive()
-    }
+#[derive(serde::Serialize)]
+pub(super) struct MixedCommitPayload {
+    pub summary: String,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub files: Vec<CommitFile>,
+    #[serde(rename = "lfsFiles", skip_serializing_if = "Vec::is_empty")]
+    pub lfs_files: Vec<LfsFile>,
+    #[serde(rename = "deletedFiles", skip_serializing_if = "Vec::is_empty")]
+    pub deleted_files: Vec<DeletedFile>,
 }
 
-impl HuggingfaceCore {
-    pub async fn hf_path_info(&self, path: &str) -> Result<Response<Buffer>> {
-        let p = build_abs_path(&self.root, path)
-            .trim_end_matches('/')
-            .to_string();
-
-        let url = match self.repo_type {
-            RepoType::Model => format!(
-                "{}/api/models/{}/paths-info/{}",
-                &self.endpoint,
-                &self.repo_id,
-                percent_encode_revision(&self.revision)
-            ),
-            RepoType::Dataset => format!(
-                "{}/api/datasets/{}/paths-info/{}",
-                &self.endpoint,
-                &self.repo_id,
-                percent_encode_revision(&self.revision)
-            ),
-            RepoType::Space => format!(
-                "{}/api/spaces/{}/paths-info/{}",
-                &self.endpoint,
-                &self.repo_id,
-                percent_encode_revision(&self.revision)
-            ),
-        };
+// API response types
 
-        let mut req = Request::post(&url);
-        // Inject operation to the request.
-        req = req.extension(Operation::Stat);
-        if let Some(token) = &self.token {
-            let auth_header_content = format_authorization_by_bearer(token)?;
-            req = req.header(header::AUTHORIZATION, auth_header_content);
+#[derive(serde::Deserialize, Debug)]
+pub(super) struct CommitResponse {
+    #[serde(rename = "commitOid")]
+    pub commit_oid: Option<String>,
+    #[allow(dead_code)]
+    #[serde(rename = "commitUrl")]
+    pub commit_url: Option<String>,
+}
+
+#[derive(Deserialize, Eq, PartialEq, Debug)]
+#[serde(rename_all = "camelCase")]
+pub(super) struct PathInfo {
+    #[serde(rename = "type")]
+    pub type_: String,
+    pub oid: String,
+    pub size: u64,
+    #[serde(default)]
+    pub lfs: Option<LfsInfo>,
+    pub path: String,
+    #[serde(default)]
+    pub last_commit: Option<LastCommit>,
+}
+
+impl PathInfo {
+    pub fn entry_mode(&self) -> EntryMode {
+        match self.type_.as_str() {
+            "directory" => EntryMode::DIR,
+            "file" => EntryMode::FILE,
+            _ => EntryMode::Unknown,
         }
+    }
 
-        req = req.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded");
+    pub fn metadata(&self) -> Result<Metadata> {
+        let mode = self.entry_mode();
+        let mut meta = Metadata::new(mode);
 
-        let req_body = format!("paths={}&expand=True", percent_encode_path(&p));
+        if let Some(commit_info) = self.last_commit.as_ref() {
+            meta.set_last_modified(commit_info.date.parse::<Timestamp>()?);
+        }
 
-        let req = req
-            .body(Buffer::from(Bytes::from(req_body)))
-            .map_err(new_request_build_error)?;
+        if mode == EntryMode::FILE {
+            meta.set_content_length(self.size);
+            let etag = if let Some(lfs) = &self.lfs {
+                &lfs.oid
+            } else {
+                &self.oid
+            };
+            meta.set_etag(etag);
+        }
 
-        self.info.http_client().send(req).await
+        Ok(meta)
     }
+}
 
-    pub async fn hf_list(
-        &self,
-        path: &str,
-        recursive: bool,
-        cursor: Option<&str>,
-    ) -> Result<Response<Buffer>> {
-        let p = build_abs_path(&self.root, path)
-            .trim_end_matches('/')
-            .to_string();
-
-        let mut url = match self.repo_type {
-            RepoType::Model => format!(
-                "{}/api/models/{}/tree/{}/{}?expand=True",
-                &self.endpoint,
-                &self.repo_id,
-                percent_encode_revision(&self.revision),
-                percent_encode_path(&p)
-            ),
-            RepoType::Dataset => format!(
-                "{}/api/datasets/{}/tree/{}/{}?expand=True",
-                &self.endpoint,
-                &self.repo_id,
-                percent_encode_revision(&self.revision),
-                percent_encode_path(&p)
-            ),
-            RepoType::Space => format!(
-                "{}/api/spaces/{}/tree/{}/{}?expand=True",
-                &self.endpoint,
-                &self.repo_id,
-                percent_encode_revision(&self.revision),
-                percent_encode_path(&p)
-            ),
-        };
+#[derive(Deserialize, Eq, PartialEq, Debug)]
+pub(super) struct LfsInfo {
+    pub oid: String,
+}
 
-        if recursive {
-            url.push_str("&recursive=True");
-        }
+#[derive(Deserialize, Eq, PartialEq, Debug)]
+pub(super) struct LastCommit {
+    pub date: String,
+}
 
-        if let Some(cursor_val) = cursor {
-            url.push_str(&format!("&cursor={}", cursor_val));
-        }
+#[cfg(feature = "xet")]
+#[derive(Clone, Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub(super) struct XetToken {
+    pub access_token: String,
+    pub cas_url: String,
+    pub exp: u64,
+}
 
-        let mut req = Request::get(&url);
-        // Inject operation to the request.
-        req = req.extension(Operation::List);
-        if let Some(token) = &self.token {
-            let auth_header_content = format_authorization_by_bearer(token)?;
-            req = req.header(header::AUTHORIZATION, auth_header_content);
-        }
+// Core HuggingFace client that manages API interactions, authentication
+// and shared logic for reader/writer/lister.
+
+#[derive(Clone)]
+pub struct HfCore {
+    pub info: Arc<AccessorInfo>,
+
+    pub repo: HfRepo,
+    pub root: String,
+    pub token: Option<String>,
+    pub endpoint: String,
+    pub max_retries: usize,
+
+    // Whether XET storage protocol is enabled for reads. When true
+    // and the `xet` feature is compiled in, reads will check for
+    // XET-backed files and use the XET protocol for downloading.
+    #[cfg(feature = "xet")]
+    pub xet_enabled: bool,
+
+    /// HTTP client with redirects disabled, used by XET probes to
+    /// inspect headers on 302 responses.
+    #[cfg(feature = "xet")]
+    pub no_redirect_client: HttpClient,
+}
 
-        let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
+impl Debug for HfCore {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let mut s = f.debug_struct("HfCore");
+        s.field("repo", &self.repo)
+            .field("root", &self.root)
+            .field("endpoint", &self.endpoint);
+        #[cfg(feature = "xet")]
+        s.field("xet_enabled", &self.xet_enabled);
+        s.finish_non_exhaustive()
+    }
+}
+
+impl HfCore {
+    pub fn new(
+        info: Arc<AccessorInfo>,
+        repo: HfRepo,
+        root: String,
+        token: Option<String>,
+        endpoint: String,
+        max_retries: usize,
+        #[cfg(feature = "xet")] xet_enabled: bool,
+    ) -> Result<Self> {
+        // When xet is enabled at runtime, use dedicated reqwest clients instead
+        // of the global one. This avoids "dispatch task is gone" errors when
+        // multiple tokio runtimes exist (e.g. in tests) and ensures the
+        // no-redirect client shares the same runtime as the standard client.
+        // When xet is disabled, preserve whatever HTTP client is already set
+        // on `info` (important for mock-based unit tests).
+        #[cfg(feature = "xet")]
+        let no_redirect_client = if xet_enabled {
+            let standard = HttpClient::with(build_reqwest(reqwest::redirect::Policy::default())?);
+            let no_redirect = HttpClient::with(build_reqwest(reqwest::redirect::Policy::none())?);
+            info.update_http_client(|_| standard);
+            no_redirect
+        } else {
+            info.http_client()
+        };
 
-        self.info.http_client().send(req).await
+        Ok(Self {
+            info,
+            repo,
+            root,
+            token,
+            endpoint,
+            max_retries,
+            #[cfg(feature = "xet")]
+            xet_enabled,
+            #[cfg(feature = "xet")]
+            no_redirect_client,
+        })
     }
 
-    pub async fn hf_list_with_url(&self, url: &str) -> Result<Response<Buffer>> {
-        let mut req = Request::get(url);
-        // Inject operation to the request.
-        req = req.extension(Operation::List);
+    /// Build an authenticated HTTP request.
+    pub(super) fn request(
+        &self,
+        method: http::Method,
+        url: &str,
+        op: Operation,
+    ) -> http::request::Builder {
+        let mut req = Request::builder().method(method).uri(url).extension(op);
         if let Some(token) = &self.token {
-            let auth_header_content = format_authorization_by_bearer(token)?;
-            req = req.header(header::AUTHORIZATION, auth_header_content);
+            if let Ok(auth) = format_authorization_by_bearer(token) {
+                req = req.header(header::AUTHORIZATION, auth);
+            }
         }
+        req
+    }
+
+    pub(super) fn uri(&self, path: &str) -> super::uri::HfUri {
+        self.repo.uri(&self.root, path)
+    }
 
-        let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
+    /// Exponential backoff: 200ms, 400ms, 800ms, … capped at ~6s.
+    async fn backoff(attempt: usize) {
+        let millis = 200u64 * (1u64 << attempt.min(5));
+        tokio::time::sleep(std::time::Duration::from_millis(millis)).await;
+    }

Review Comment:
   Good idea, updating!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to