This is an automated email from the ASF dual-hosted git repository.
liurenjie1024 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-rust.git
The following commit(s) were added to refs/heads/main by this push:
new d6703df feat: implement OAuth for catalog rest client (#254)
d6703df is described below
commit d6703df40b24477d0a5a36939746bb1b36cc6933
Author: TennyZhuang <[email protected]>
AuthorDate: Mon Mar 18 11:32:53 2024 +0800
feat: implement OAuth for catalog rest client (#254)
---
crates/catalog/rest/src/catalog.rs | 118 +++++++++++++++++++++++--
crates/catalog/rest/tests/rest_catalog_test.rs | 1 +
2 files changed, 114 insertions(+), 5 deletions(-)
diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs
index 812ac82..ae9ae10 100644
--- a/crates/catalog/rest/src/catalog.rs
+++ b/crates/catalog/rest/src/catalog.rs
@@ -38,7 +38,7 @@ use iceberg::{
use self::_serde::{
CatalogConfig, ErrorResponse, ListNamespaceResponse, ListTableResponse, NamespaceSerde,
- RenameTableRequest, NO_CONTENT, OK,
+ RenameTableRequest, TokenResponse, NO_CONTENT, OK,
};
const ICEBERG_REST_SPEC_VERSION: &str = "0.14.1";
@@ -96,9 +96,13 @@ impl RestCatalogConfig {
.join("/")
}
+ fn get_token_endpoint(&self) -> String {
+ [&self.uri, PATH_V1, "oauth", "tokens"].join("/")
+ }
+
fn try_create_rest_client(&self) -> Result<HttpClient> {
- //TODO: We will add oauth, ssl config, sigv4 later
- let headers = HeaderMap::from_iter([
+ // TODO: We will add ssl config, sigv4 later
+ let mut headers = HeaderMap::from_iter([
(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
@@ -113,6 +117,19 @@ impl RestCatalogConfig {
),
]);
+ if let Some(token) = self.props.get("token") {
+ headers.insert(
+ header::AUTHORIZATION,
+ HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| {
+ Error::new(
+ ErrorKind::DataInvalid,
+ "Invalid token received from catalog server!",
+ )
+ .with_source(e)
+ })?,
+ );
+ }
+
Ok(HttpClient(
Client::builder().default_headers(headers).build()?,
))
@@ -144,6 +161,7 @@ impl HttpClient {
.with_source(e)
})?)
} else {
+ let code = resp.status();
let text = resp.bytes().await?;
let e = serde_json::from_slice::<E>(&text).map_err(|e| {
Error::new(
@@ -151,6 +169,7 @@ impl HttpClient {
"Failed to parse response from rest catalog server!",
)
.with_context("json", String::from_utf8_lossy(&text))
+ .with_context("code", code.to_string())
.with_source(e)
})?;
Err(e.into())
@@ -497,13 +516,56 @@ impl RestCatalog {
client: config.try_create_rest_client()?,
config,
};
-
+ catalog.fetch_access_token().await?;
+ catalog.client = catalog.config.try_create_rest_client()?;
catalog.update_config().await?;
catalog.client = catalog.config.try_create_rest_client()?;
Ok(catalog)
}
+ async fn fetch_access_token(&mut self) -> Result<()> {
+ if self.config.props.contains_key("token") {
+ return Ok(());
+ }
+ if let Some(credential) = self.config.props.get("credential") {
+ let (client_id, client_secret) = if credential.contains(':') {
+ let (client_id, client_secret) = credential.split_once(':').unwrap();
+ (Some(client_id), client_secret)
+ } else {
+ (None, credential.as_str())
+ };
+ let mut params = HashMap::with_capacity(4);
+ params.insert("grant_type", "client_credentials");
+ if let Some(client_id) = client_id {
+ params.insert("client_id", client_id);
+ }
+ params.insert("client_secret", client_secret);
+ params.insert("scope", "catalog");
+ let req = self
+ .client
+ .0
+ .post(self.config.get_token_endpoint())
+ .form(&params)
+ .build()?;
+ let res = self
+ .client
+ .query::<TokenResponse, ErrorResponse, OK>(req)
+ .await
+ .map_err(|e| {
+ Error::new(
+ ErrorKind::Unexpected,
+ "Failed to fetch access token from catalog server!",
+ )
+ .with_source(e)
+ })?;
+ let token = res.access_token;
+ self.config.props.insert("token".to_string(), token);
+ }
+
+ Ok(())
+ }
+
async fn update_config(&mut self) -> Result<()> {
let mut request = self.client.0.get(self.config.config_endpoint());
@@ -626,6 +688,14 @@ mod _serde {
}
}
+ #[derive(Debug, Serialize, Deserialize)]
+ pub(super) struct TokenResponse {
+ pub(super) access_token: String,
+ pub(super) token_type: String,
+ pub(super) expires_in: Option<u64>,
+ pub(super) issued_token_type: Option<String>,
+ }
+
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct NamespaceSerde {
pub(super) namespace: Vec<String>,
@@ -778,6 +848,44 @@ mod tests {
.await
}
+ async fn create_oauth_mock(server: &mut ServerGuard) -> Mock {
+ server
+ .mock("POST", "/v1/oauth/tokens")
+ .with_status(200)
+ .with_body(
+ r#"{
+ "access_token": "ey000000000000",
+ "token_type": "Bearer",
+ "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
+ "expires_in": 86400
+ }"#,
+ )
+ .create_async()
+ .await
+ }
+
+ #[tokio::test]
+ async fn test_oauth() {
+ let mut server = Server::new_async().await;
+ let oauth_mock = create_oauth_mock(&mut server).await;
+ let config_mock = create_config_mock(&mut server).await;
+
+ let mut props = HashMap::new();
+ props.insert("credential".to_string(), "client1:secret1".to_string());
+
+ let _catalog = RestCatalog::new(
+ RestCatalogConfig::builder()
+ .uri(server.url())
+ .props(props)
+ .build(),
+ )
+ .await
+ .unwrap();
+
+ oauth_mock.assert_async().await;
+ config_mock.assert_async().await;
+ }
+
#[tokio::test]
async fn test_list_namespace() {
let mut server = Server::new_async().await;
@@ -1557,7 +1665,7 @@ mod tests {
"type": "NoSuchTableException",
"code": 404
}
-}
+}
"#,
)
.create_async()
diff --git a/crates/catalog/rest/tests/rest_catalog_test.rs b/crates/catalog/rest/tests/rest_catalog_test.rs
index a4d0795..205428d 100644
--- a/crates/catalog/rest/tests/rest_catalog_test.rs
+++ b/crates/catalog/rest/tests/rest_catalog_test.rs
@@ -66,6 +66,7 @@ async fn set_test_fixture(func: &str) -> TestFixture {
rest_catalog,
}
}
+
#[tokio::test]
async fn test_get_non_exist_namespace() {
let fixture = set_test_fixture("test_get_non_exist_namespace").await;