This is an automated email from the ASF dual-hosted git repository.

liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new d6703df  feat: implement OAuth for catalog rest client (#254)
d6703df is described below

commit d6703df40b24477d0a5a36939746bb1b36cc6933
Author: TennyZhuang <[email protected]>
AuthorDate: Mon Mar 18 11:32:53 2024 +0800

    feat: implement OAuth for catalog rest client (#254)
---
 crates/catalog/rest/src/catalog.rs             | 118 +++++++++++++++++++++++--
 crates/catalog/rest/tests/rest_catalog_test.rs |   1 +
 2 files changed, 114 insertions(+), 5 deletions(-)

diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs
index 812ac82..ae9ae10 100644
--- a/crates/catalog/rest/src/catalog.rs
+++ b/crates/catalog/rest/src/catalog.rs
@@ -38,7 +38,7 @@ use iceberg::{
 
 use self::_serde::{
     CatalogConfig, ErrorResponse, ListNamespaceResponse, ListTableResponse, NamespaceSerde,
-    RenameTableRequest, NO_CONTENT, OK,
+    RenameTableRequest, TokenResponse, NO_CONTENT, OK,
 };
 
 const ICEBERG_REST_SPEC_VERSION: &str = "0.14.1";
@@ -96,9 +96,13 @@ impl RestCatalogConfig {
         .join("/")
     }
 
+    fn get_token_endpoint(&self) -> String {
+        [&self.uri, PATH_V1, "oauth", "tokens"].join("/")
+    }
+
     fn try_create_rest_client(&self) -> Result<HttpClient> {
-        //TODO: We will add oauth, ssl config, sigv4 later
-        let headers = HeaderMap::from_iter([
+        // TODO: We will add ssl config, sigv4 later
+        let mut headers = HeaderMap::from_iter([
             (
                 header::CONTENT_TYPE,
                 HeaderValue::from_static("application/json"),
@@ -113,6 +117,19 @@ impl RestCatalogConfig {
             ),
         ]);
 
+        if let Some(token) = self.props.get("token") {
+            headers.insert(
+                header::AUTHORIZATION,
+                HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| {
+                    Error::new(
+                        ErrorKind::DataInvalid,
+                        "Invalid token received from catalog server!",
+                    )
+                    .with_source(e)
+                })?,
+            );
+        }
+
         Ok(HttpClient(
             Client::builder().default_headers(headers).build()?,
         ))
@@ -144,6 +161,7 @@ impl HttpClient {
                 .with_source(e)
             })?)
         } else {
+            let code = resp.status();
             let text = resp.bytes().await?;
             let e = serde_json::from_slice::<E>(&text).map_err(|e| {
                 Error::new(
@@ -151,6 +169,7 @@ impl HttpClient {
                     "Failed to parse response from rest catalog server!",
                 )
                 .with_context("json", String::from_utf8_lossy(&text))
+                .with_context("code", code.to_string())
                 .with_source(e)
             })?;
             Err(e.into())
@@ -497,13 +516,56 @@ impl RestCatalog {
             client: config.try_create_rest_client()?,
             config,
         };
-
+        catalog.fetch_access_token().await?;
+        catalog.client = catalog.config.try_create_rest_client()?;
         catalog.update_config().await?;
         catalog.client = catalog.config.try_create_rest_client()?;
 
         Ok(catalog)
     }
 
+    async fn fetch_access_token(&mut self) -> Result<()> {
+        if self.config.props.contains_key("token") {
+            return Ok(());
+        }
+        if let Some(credential) = self.config.props.get("credential") {
+            let (client_id, client_secret) = if credential.contains(':') {
+                let (client_id, client_secret) = credential.split_once(':').unwrap();
+                (Some(client_id), client_secret)
+            } else {
+                (None, credential.as_str())
+            };
+            let mut params = HashMap::with_capacity(4);
+            params.insert("grant_type", "client_credentials");
+            if let Some(client_id) = client_id {
+                params.insert("client_id", client_id);
+            }
+            params.insert("client_secret", client_secret);
+            params.insert("scope", "catalog");
+            let req = self
+                .client
+                .0
+                .post(self.config.get_token_endpoint())
+                .form(&params)
+                .build()?;
+            let res = self
+                .client
+                .query::<TokenResponse, ErrorResponse, OK>(req)
+                .await
+                .map_err(|e| {
+                    Error::new(
+                        ErrorKind::Unexpected,
+                        "Failed to fetch access token from catalog server!",
+                    )
+                    .with_source(e)
+                })?;
+            let token = res.access_token;
+            self.config.props.insert("token".to_string(), token);
+        }
+
+        Ok(())
+    }
+
     async fn update_config(&mut self) -> Result<()> {
         let mut request = self.client.0.get(self.config.config_endpoint());
 
@@ -626,6 +688,14 @@ mod _serde {
         }
     }
 
+    #[derive(Debug, Serialize, Deserialize)]
+    pub(super) struct TokenResponse {
+        pub(super) access_token: String,
+        pub(super) token_type: String,
+        pub(super) expires_in: Option<u64>,
+        pub(super) issued_token_type: Option<String>,
+    }
+
     #[derive(Debug, Serialize, Deserialize)]
     pub(super) struct NamespaceSerde {
         pub(super) namespace: Vec<String>,
@@ -778,6 +848,44 @@ mod tests {
             .await
     }
 
+    async fn create_oauth_mock(server: &mut ServerGuard) -> Mock {
+        server
+            .mock("POST", "/v1/oauth/tokens")
+            .with_status(200)
+            .with_body(
+                r#"{
+                "access_token": "ey000000000000",
+                "token_type": "Bearer",
+                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+                "expires_in": 86400
+                }"#,
+            )
+            .create_async()
+            .await
+    }
+
+    #[tokio::test]
+    async fn test_oauth() {
+        let mut server = Server::new_async().await;
+        let oauth_mock = create_oauth_mock(&mut server).await;
+        let config_mock = create_config_mock(&mut server).await;
+
+        let mut props = HashMap::new();
+        props.insert("credential".to_string(), "client1:secret1".to_string());
+
+        let _catalog = RestCatalog::new(
+            RestCatalogConfig::builder()
+                .uri(server.url())
+                .props(props)
+                .build(),
+        )
+        .await
+        .unwrap();
+
+        oauth_mock.assert_async().await;
+        config_mock.assert_async().await;
+    }
+
     #[tokio::test]
     async fn test_list_namespace() {
         let mut server = Server::new_async().await;
@@ -1557,7 +1665,7 @@ mod tests {
         "type": "NoSuchTableException",
         "code": 404
     }
-}      
+}
             "#,
             )
             .create_async()
diff --git a/crates/catalog/rest/tests/rest_catalog_test.rs b/crates/catalog/rest/tests/rest_catalog_test.rs
index a4d0795..205428d 100644
--- a/crates/catalog/rest/tests/rest_catalog_test.rs
+++ b/crates/catalog/rest/tests/rest_catalog_test.rs
@@ -66,6 +66,7 @@ async fn set_test_fixture(func: &str) -> TestFixture {
         rest_catalog,
     }
 }
+
 #[tokio::test]
 async fn test_get_non_exist_namespace() {
     let fixture = set_test_fixture("test_get_non_exist_namespace").await;

Reply via email to