This is an automated email from the ASF dual-hosted git repository.
yzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/polaris-tools.git
The following commit(s) were added to refs/heads/main by this push:
new 19c56b7 [MCP] feat: Add realm support (#83)
19c56b7 is described below
commit 19c56b70903c04df995f6f8b0441997c13380d95
Author: Yong Zheng <[email protected]>
AuthorDate: Wed Dec 3 19:07:02 2025 -0600
[MCP] feat: Add realm support (#83)
---
mcp-server/README.md | 7 +-
mcp-server/polaris_mcp/authorization.py | 116 +++++++++-----
mcp-server/polaris_mcp/rest.py | 9 +-
mcp-server/polaris_mcp/server.py | 39 +++--
mcp-server/polaris_mcp/tools/catalog.py | 4 +
mcp-server/polaris_mcp/tools/catalog_role.py | 4 +
mcp-server/polaris_mcp/tools/namespace.py | 4 +
mcp-server/polaris_mcp/tools/policy.py | 4 +
mcp-server/polaris_mcp/tools/principal.py | 4 +
mcp-server/polaris_mcp/tools/principal_role.py | 4 +
mcp-server/polaris_mcp/tools/table.py | 4 +
mcp-server/tests/test_authorization.py | 210 ++++++++++++++++++++++---
mcp-server/tests/test_rest_tool.py | 65 ++++++++
mcp-server/tests/test_server.py | 5 +-
14 files changed, 400 insertions(+), 79 deletions(-)
diff --git a/mcp-server/README.md b/mcp-server/README.md
index 98f7998..a6c5fd4 100644
--- a/mcp-server/README.md
+++ b/mcp-server/README.md
@@ -157,6 +157,11 @@ uv run client.py http://localhost:8000/mcp \
| `POLARIS_CLIENT_SECRET` | OAuth
client secret. | _unset_
|
| `POLARIS_TOKEN_SCOPE` | OAuth scope
string. | _unset_
|
| `POLARIS_TOKEN_URL` | Optional
override for the token endpoint URL. |
`${POLARIS_BASE_URL}api/catalog/v1/oauth/tokens` |
+| `POLARIS_REALM_{realm}_CLIENT_ID` | OAuth
client ID for a specific realm. | _unset_
|
+| `POLARIS_REALM_{realm}_CLIENT_SECRET` | OAuth
client secret for a specific realm. | _unset_
|
+| `POLARIS_REALM_{realm}_TOKEN_SCOPE` | OAuth scope
for a specific realm. | _unset_
|
+| `POLARIS_REALM_{realm}_TOKEN_URL` | Token
endpoint URL for a specific realm. | _unset_
|
+| `POLARIS_REALM_CONTEXT_HEADER_NAME` | Name of the
HTTP header that carries the realm on API and token requests. | `Polaris-Realm`
|
| `POLARIS_TOKEN_REFRESH_BUFFER_SECONDS` | Minimum
remaining token lifetime before refreshing in seconds. | `60.0`
|
| `POLARIS_HTTP_TIMEOUT_SECONDS` | Default
timeout in seconds for all HTTP requests. | `30.0`
|
| `POLARIS_HTTP_CONNECT_TIMEOUT_SECONDS` | Timeout in
seconds for establishing HTTP connections. | `30.0`
|
@@ -166,8 +171,8 @@ uv run client.py http://localhost:8000/mcp \
| `POLARIS_CONFIG_FILE` | Path to a
configuration file containing configuration variables. | `.polaris_mcp.env` in
current working directory |
-
When OAuth variables are supplied, the server automatically acquires and
refreshes tokens using the client credentials flow; otherwise a static bearer
token is used if provided.
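+Realm-specific variables (e.g., `POLARIS_REALM_{realm}_CLIENT_ID`) replace
the global client ID, client secret, token scope, and token URL for requests
that target that realm; global values are never inherited. If a realm's client
ID or client secret is missing, the server does not fall back to the global
credentials for that realm. If a realm's token URL is unset, it defaults to
`${POLARIS_BASE_URL}api/catalog/v1/oauth/tokens` rather than `POLARIS_TOKEN_URL`.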
## Tools
diff --git a/mcp-server/polaris_mcp/authorization.py
b/mcp-server/polaris_mcp/authorization.py
index 2b61894..6f38076 100644
--- a/mcp-server/polaris_mcp/authorization.py
+++ b/mcp-server/polaris_mcp/authorization.py
@@ -22,11 +22,12 @@
from __future__ import annotations
import json
+import os
import threading
import time
from abc import ABC, abstractmethod
from typing import Optional
-from urllib.parse import urlencode
+from urllib.parse import urlencode, urljoin
import urllib3
@@ -35,7 +36,7 @@ class AuthorizationProvider(ABC):
"""Return Authorization header values for outgoing requests."""
@abstractmethod
- def authorization_header(self) -> Optional[str]: ...
+ def authorization_header(self, realm: Optional[str] = None) ->
Optional[str]: ...
class StaticAuthorizationProvider(AuthorizationProvider):
@@ -45,7 +46,7 @@ class StaticAuthorizationProvider(AuthorizationProvider):
value = (token or "").strip()
self._header = f"Bearer {value}" if value else None
- def authorization_header(self) -> Optional[str]:
+ def authorization_header(self, realm: Optional[str] = None) ->
Optional[str]:
return self._header
@@ -54,59 +55,100 @@ class
ClientCredentialsAuthorizationProvider(AuthorizationProvider):
def __init__(
self,
- token_endpoint: str,
- client_id: str,
- client_secret: str,
- scope: Optional[str],
+ base_url: str,
http: urllib3.PoolManager,
refresh_buffer_seconds: float,
timeout: urllib3.Timeout,
) -> None:
- self._token_endpoint = token_endpoint
- self._client_id = client_id
- self._client_secret = client_secret
- self._scope = scope
+ self._base_url = base_url
self._http = http
+ self._refresh_buffer_seconds = max(refresh_buffer_seconds, 0.0)
self._timeout = timeout
self._lock = threading.Lock()
- self._cached: Optional[tuple[str, float]] = None # (token,
expires_at_epoch)
- self._refresh_buffer_seconds = max(refresh_buffer_seconds, 0.0)
+ # {realm: (token, expires_at_epoch)}
+ self._cached: dict[str, tuple[str, float]] = {}
- def authorization_header(self) -> Optional[str]:
- token = self._current_token()
+ def authorization_header(self, realm: Optional[str] = None) ->
Optional[str]:
+ token = self._get_token_from_realm(realm)
return f"Bearer {token}" if token else None
- def _current_token(self) -> Optional[str]:
- now = time.time()
- cached = self._cached
- if not cached or cached[1] - self._refresh_buffer_seconds <= now:
- with self._lock:
- cached = self._cached
- if (
- not cached
- or cached[1] - self._refresh_buffer_seconds <= time.time()
- ):
- self._cached = cached = self._fetch_token()
- return cached[0] if cached else None
-
- def _fetch_token(self) -> tuple[str, float]:
+ def _get_token_from_realm(self, realm: Optional[str]) -> Optional[str]:
+ def needs_refresh(cached):
+ return (
+ cached is None
+ or cached[1] - self._refresh_buffer_seconds <= time.time()
+ )
+
+ cache_key = realm or ""
+ token = self._cached.get(cache_key)
+ # Token not expired
+ if not needs_refresh(token):
+ return token[0]
+ # Acquire lock and verify again if token expired
+ with self._lock:
+ token = self._cached.get(cache_key)
+ if needs_refresh(token):
+ credentials = self._get_credentials_from_realm(realm)
+ if not credentials:
+ return None
+ token = self._fetch_token(realm, credentials)
+ self._cached[cache_key] = token
+ return token[0] if token else None
+
+ def _get_credentials_from_realm(
+ self, realm: Optional[str]
+ ) -> Optional[dict[str, str]]:
+ def get_env(key: str) -> Optional[str]:
+ val = os.getenv(key)
+ return val.strip() or None if val else None
+
+ def load_creds(realm: Optional[str] = None) -> dict[str,
Optional[str]]:
+ prefix = f"POLARIS_REALM_{realm}_" if realm else "POLARIS_"
+ return {
+ "client_id": get_env(f"{prefix}CLIENT_ID"),
+ "client_secret": get_env(f"{prefix}CLIENT_SECRET"),
+ "scope": get_env(f"{prefix}TOKEN_SCOPE"),
+ "token_url": get_env(f"{prefix}TOKEN_URL"),
+ }
+
+ # Only use realm-specific credentials
+ if realm:
+ creds = load_creds(realm)
+ if creds["client_id"] and creds["client_secret"]:
+ return creds
+ return None
+ # No realm specified, use global credentials
+ creds = load_creds()
+ if creds["client_id"] and creds["client_secret"]:
+ return creds
+ return None
+
+ def _fetch_token(
+ self, realm: Optional[str], credentials: dict[str, str]
+ ) -> tuple[str, float]:
+ token_url = credentials.get("token_url") or urljoin(
+ self._base_url, "api/catalog/v1/oauth/tokens"
+ )
payload = {
"grant_type": "client_credentials",
- "client_id": self._client_id,
- "client_secret": self._client_secret,
+ "client_id": credentials["client_id"],
+ "client_secret": credentials["client_secret"],
}
- if self._scope:
- payload["scope"] = self._scope
+ if credentials.get("scope"):
+ payload["scope"] = credentials["scope"]
encoded = urlencode(payload)
+ header_name = os.getenv("POLARIS_REALM_CONTEXT_HEADER_NAME",
"Polaris-Realm")
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
+ if realm:
+ headers[header_name] = realm
response = self._http.request(
"POST",
- self._token_endpoint,
+ token_url,
body=encoded,
- headers={"Content-Type": "application/x-www-form-urlencoded"},
+ headers=headers,
timeout=self._timeout,
)
-
if response.status != 200:
raise RuntimeError(
f"OAuth token endpoint returned {response.status}:
{response.data.decode('utf-8', errors='ignore')}"
@@ -132,7 +174,7 @@ class
ClientCredentialsAuthorizationProvider(AuthorizationProvider):
class _NoneAuthorizationProvider(AuthorizationProvider):
- def authorization_header(self) -> Optional[str]:
+ def authorization_header(self, realm: Optional[str] = None) ->
Optional[str]:
return None
diff --git a/mcp-server/polaris_mcp/rest.py b/mcp-server/polaris_mcp/rest.py
index 8c9728d..6e02633 100644
--- a/mcp-server/polaris_mcp/rest.py
+++ b/mcp-server/polaris_mcp/rest.py
@@ -22,6 +22,7 @@
from __future__ import annotations
import json
+import os
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlencode, urljoin, urlsplit, urlunsplit, quote
@@ -230,6 +231,7 @@ class PolarisRestTool:
query_params = arguments.get("query")
headers_param = arguments.get("headers")
body_node = arguments.get("body")
+ realm = arguments.get("realm")
query = query_params if isinstance(query_params, dict) else None
headers = headers_param if isinstance(headers_param, dict) else None
@@ -238,9 +240,14 @@ class PolarisRestTool:
header_values = _merge_headers(headers)
if not any(name.lower() == "authorization" for name in header_values):
- token = self._authorization.authorization_header()
+ token = self._authorization.authorization_header(realm)
if token:
header_values["Authorization"] = token
+ header_name = os.getenv("POLARIS_REALM_CONTEXT_HEADER_NAME",
"Polaris-Realm")
+ if realm and not any(
+ name.lower() == header_name.lower() for name in header_values
+ ):
+ header_values[header_name] = realm
body_text = _serialize_body(body_node)
if body_text is not None and not any(
diff --git a/mcp-server/polaris_mcp/server.py b/mcp-server/polaris_mcp/server.py
index a338612..7ebb812 100644
--- a/mcp-server/polaris_mcp/server.py
+++ b/mcp-server/polaris_mcp/server.py
@@ -27,7 +27,7 @@ import logging.config
import argparse
import os
from typing import Any, Mapping, MutableMapping, Sequence, Optional
-from urllib.parse import urljoin, urlparse
+from urllib.parse import urlparse
import urllib3
from fastmcp import FastMCP
@@ -164,6 +164,7 @@ def create_server() -> FastMCP:
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
+ realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
table_tool,
@@ -177,6 +178,7 @@ def create_server() -> FastMCP:
"query": query,
"headers": headers,
"body": body,
+ "realm": realm,
},
transforms={
"namespace": _normalize_namespace,
@@ -198,6 +200,7 @@ def create_server() -> FastMCP:
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
+ realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
namespace_tool,
@@ -210,6 +213,7 @@ def create_server() -> FastMCP:
"query": query,
"headers": headers,
"body": body,
+ "realm": realm,
},
transforms={
"namespace": _normalize_namespace,
@@ -231,6 +235,7 @@ def create_server() -> FastMCP:
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
+ realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
principal_tool,
@@ -241,6 +246,7 @@ def create_server() -> FastMCP:
"query": query,
"headers": headers,
"body": body,
+ "realm": realm,
},
transforms={
"query": _copy_mapping,
@@ -262,6 +268,7 @@ def create_server() -> FastMCP:
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
+ realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
principal_role_tool,
@@ -273,6 +280,7 @@ def create_server() -> FastMCP:
"query": query,
"headers": headers,
"body": body,
+ "realm": realm,
},
transforms={
"query": _copy_mapping,
@@ -293,6 +301,7 @@ def create_server() -> FastMCP:
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
+ realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
catalog_role_tool,
@@ -305,6 +314,7 @@ def create_server() -> FastMCP:
"query": query,
"headers": headers,
"body": body,
+ "realm": realm,
},
transforms={
"query": _copy_mapping,
@@ -326,6 +336,7 @@ def create_server() -> FastMCP:
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
+ realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
policy_tool,
@@ -339,6 +350,7 @@ def create_server() -> FastMCP:
"query": query,
"headers": headers,
"body": body,
+ "realm": realm,
},
transforms={
"namespace": _normalize_namespace,
@@ -359,6 +371,7 @@ def create_server() -> FastMCP:
query: Mapping[str, str | Sequence[str]] | None = None,
headers: Mapping[str, str | Sequence[str]] | None = None,
body: Any | None = None,
+ realm: str | None = None,
) -> FastMcpToolResult:
return _call_tool(
catalog_tool,
@@ -368,6 +381,7 @@ def create_server() -> FastMCP:
"query": query,
"headers": headers,
"body": body,
+ "realm": realm,
},
transforms={
"query": _copy_mapping,
@@ -482,23 +496,21 @@ def _resolve_http_timeout() -> urllib3.Timeout:
def _resolve_authorization_provider(
- base_url: str, http: urllib3.PoolManager, timeout: urllib3.Timeout
+ base_url: str,
+ http: urllib3.PoolManager,
+ timeout: urllib3.Timeout,
) -> AuthorizationProvider:
token = _resolve_token()
if token:
return StaticAuthorizationProvider(token)
- client_id = _first_non_blank(
- os.getenv("POLARIS_CLIENT_ID"),
- )
- client_secret = _first_non_blank(
- os.getenv("POLARIS_CLIENT_SECRET"),
+ client_id = _first_non_blank(os.getenv("POLARIS_CLIENT_ID"))
+ client_secret = _first_non_blank(os.getenv("POLARIS_CLIENT_SECRET"))
+ has_realm_credentials = any(
+ key.startswith("POLARIS_REALM_") for key in os.environ.keys()
)
- if client_id and client_secret:
- scope = _first_non_blank(os.getenv("POLARIS_TOKEN_SCOPE"))
- token_url = _first_non_blank(os.getenv("POLARIS_TOKEN_URL"))
- endpoint = token_url or urljoin(base_url,
"api/catalog/v1/oauth/tokens")
+ if client_id and client_secret or has_realm_credentials:
refresh_buffer_seconds = DEFAULT_TOKEN_REFRESH_BUFFER_SECONDS
refresh_buffer_seconds_str =
os.getenv("POLARIS_TOKEN_REFRESH_BUFFER_SECONDS")
if refresh_buffer_seconds_str:
@@ -507,10 +519,7 @@ def _resolve_authorization_provider(
except ValueError:
pass
return ClientCredentialsAuthorizationProvider(
- token_endpoint=endpoint,
- client_id=client_id,
- client_secret=client_secret,
- scope=scope,
+ base_url=base_url,
http=http,
refresh_buffer_seconds=refresh_buffer_seconds,
timeout=timeout,
diff --git a/mcp-server/polaris_mcp/tools/catalog.py
b/mcp-server/polaris_mcp/tools/catalog.py
index cdbd909..5c0dc8f 100644
--- a/mcp-server/polaris_mcp/tools/catalog.py
+++ b/mcp-server/polaris_mcp/tools/catalog.py
@@ -106,6 +106,10 @@ class PolarisCatalogTool(McpTool):
copy_if_object(arguments.get("query"), delegate_args, "query")
copy_if_object(arguments.get("headers"), delegate_args, "headers")
+ realm = arguments.get("realm")
+ if isinstance(realm, str) and realm.strip():
+ delegate_args["realm"] = realm
+
if normalized == "list":
delegate_args["method"] = "GET"
delegate_args["path"] = "catalogs"
diff --git a/mcp-server/polaris_mcp/tools/catalog_role.py
b/mcp-server/polaris_mcp/tools/catalog_role.py
index eeb0111..8d13ec2 100644
--- a/mcp-server/polaris_mcp/tools/catalog_role.py
+++ b/mcp-server/polaris_mcp/tools/catalog_role.py
@@ -127,6 +127,10 @@ class PolarisCatalogRoleTool(McpTool):
copy_if_object(arguments.get("query"), delegate_args, "query")
copy_if_object(arguments.get("headers"), delegate_args, "headers")
+ realm = arguments.get("realm")
+ if isinstance(realm, str) and realm.strip():
+ delegate_args["realm"] = realm
+
base_path = f"catalogs/{catalog}/catalog-roles"
if normalized == "list":
diff --git a/mcp-server/polaris_mcp/tools/namespace.py
b/mcp-server/polaris_mcp/tools/namespace.py
index 02f312d..7d34070 100644
--- a/mcp-server/polaris_mcp/tools/namespace.py
+++ b/mcp-server/polaris_mcp/tools/namespace.py
@@ -132,6 +132,10 @@ class PolarisNamespaceTool(McpTool):
copy_if_object(arguments.get("query"), delegate_args, "query")
copy_if_object(arguments.get("headers"), delegate_args, "headers")
+ realm = arguments.get("realm")
+ if isinstance(realm, str) and realm.strip():
+ delegate_args["realm"] = realm
+
if normalized == "list":
self._handle_list(delegate_args, catalog)
elif normalized == "get":
diff --git a/mcp-server/polaris_mcp/tools/policy.py
b/mcp-server/polaris_mcp/tools/policy.py
index 5463aa8..f8eee13 100644
--- a/mcp-server/polaris_mcp/tools/policy.py
+++ b/mcp-server/polaris_mcp/tools/policy.py
@@ -143,6 +143,10 @@ class PolarisPolicyTool(McpTool):
copy_if_object(arguments.get("query"), delegate_args, "query")
copy_if_object(arguments.get("headers"), delegate_args, "headers")
+ realm = arguments.get("realm")
+ if isinstance(realm, str) and realm.strip():
+ delegate_args["realm"] = realm
+
if normalized == "list":
self._require_namespace(namespace, "list")
self._handle_list(delegate_args, catalog, namespace)
diff --git a/mcp-server/polaris_mcp/tools/principal.py
b/mcp-server/polaris_mcp/tools/principal.py
index 60470fb..9bea911 100644
--- a/mcp-server/polaris_mcp/tools/principal.py
+++ b/mcp-server/polaris_mcp/tools/principal.py
@@ -129,6 +129,10 @@ class PolarisPrincipalTool(McpTool):
copy_if_object(arguments.get("query"), delegate_args, "query")
copy_if_object(arguments.get("headers"), delegate_args, "headers")
+ realm = arguments.get("realm")
+ if isinstance(realm, str) and realm.strip():
+ delegate_args["realm"] = realm
+
if normalized == "list":
self._handle_list(delegate_args)
elif normalized == "create":
diff --git a/mcp-server/polaris_mcp/tools/principal_role.py
b/mcp-server/polaris_mcp/tools/principal_role.py
index 2941769..b23cce8 100644
--- a/mcp-server/polaris_mcp/tools/principal_role.py
+++ b/mcp-server/polaris_mcp/tools/principal_role.py
@@ -135,6 +135,10 @@ class PolarisPrincipalRoleTool(McpTool):
copy_if_object(arguments.get("query"), delegate_args, "query")
copy_if_object(arguments.get("headers"), delegate_args, "headers")
+ realm = arguments.get("realm")
+ if isinstance(realm, str) and realm.strip():
+ delegate_args["realm"] = realm
+
if normalized == "list":
delegate_args["method"] = "GET"
delegate_args["path"] = "principal-roles"
diff --git a/mcp-server/polaris_mcp/tools/table.py
b/mcp-server/polaris_mcp/tools/table.py
index a92a40b..d021a2f 100644
--- a/mcp-server/polaris_mcp/tools/table.py
+++ b/mcp-server/polaris_mcp/tools/table.py
@@ -123,6 +123,10 @@ class PolarisTableTool(McpTool):
copy_if_object(arguments.get("query"), delegate_args, "query")
copy_if_object(arguments.get("headers"), delegate_args, "headers")
+ realm = arguments.get("realm")
+ if isinstance(realm, str) and realm.strip():
+ delegate_args["realm"] = realm
+
if normalized == "list":
self._handle_list(delegate_args, catalog, namespace)
elif normalized == "get":
diff --git a/mcp-server/tests/test_authorization.py
b/mcp-server/tests/test_authorization.py
index dfb0945..3c89d32 100644
--- a/mcp-server/tests/test_authorization.py
+++ b/mcp-server/tests/test_authorization.py
@@ -51,6 +51,9 @@ def test_none_authorization_provider_returns_none() -> None:
def test_client_credentials_fetches_and_caches_tokens(
monkeypatch: pytest.MonkeyPatch,
) -> None:
+ monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+ monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
+ monkeypatch.setenv("POLARIS_TOKEN_URL", "https://auth/token")
http = mock.Mock()
now = time.time()
response = SimpleNamespace(
@@ -60,12 +63,9 @@ def test_client_credentials_fetches_and_caches_tokens(
http.request.return_value = response
provider = ClientCredentialsAuthorizationProvider(
- token_endpoint="https://auth/token",
- client_id="client",
- client_secret="secret",
- scope=None,
+ base_url="https://polaris/",
http=http,
- refresh_buffer_seconds=0.0,
+ refresh_buffer_seconds=60.0,
timeout=mock.sentinel.timeout,
)
@@ -77,7 +77,9 @@ def test_client_credentials_fetches_and_caches_tokens(
assert header2 == "Bearer abc"
http.request.assert_called_once()
- body = http.request.call_args.kwargs["body"]
+ args, kwargs = http.request.call_args
+ assert args[1] == "https://auth/token"
+ body = kwargs["body"]
assert "grant_type=client_credentials" in body
assert "client_id=client" in body
assert "client_secret=secret" in body
@@ -96,7 +98,44 @@ def test_client_credentials_fetches_and_caches_tokens(
http.request.assert_called_once()
-def test_client_credentials_refresh_buffer() -> None:
+def test_client_credentials_fetches_and_caches_realm_specific_token(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ http_mock = mock.Mock()
+ provider = ClientCredentialsAuthorizationProvider(
+ base_url="https://polaris/",
+ http=http_mock,
+ refresh_buffer_seconds=60.0,
+ timeout=mock.sentinel.timeout,
+ )
+ monkeypatch.setenv("POLARIS_REALM_TEST_REALM_CLIENT_ID", "realm_client")
+ monkeypatch.setenv("POLARIS_REALM_TEST_REALM_CLIENT_SECRET",
"realm_secret")
+ monkeypatch.setenv("POLARIS_REALM_TEST_REALM_TOKEN_URL",
"https://realm-auth/token")
+
+ now = time.time()
+ response = SimpleNamespace(
+ status=200,
+ data=json.dumps({"access_token": "realm_token", "expires_in":
3600}).encode(
+ "utf-8"
+ ),
+ )
+ http_mock.request.return_value = response
+
+ with mock.patch("time.time", return_value=now):
+ header = provider.authorization_header(realm="TEST_REALM")
+
+ assert header == "Bearer realm_token"
+ http_mock.request.assert_called_once()
+ args, kwargs = http_mock.request.call_args
+ assert args[1] == "https://realm-auth/token"
+ assert "client_id=realm_client" in kwargs["body"]
+ assert "Polaris-Realm" in kwargs["headers"]
+ assert kwargs["headers"]["Polaris-Realm"] == "TEST_REALM"
+
+
+def test_client_credentials_refresh_buffer(monkeypatch: pytest.MonkeyPatch) ->
None:
+ monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+ monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
http = mock.Mock()
now = time.time()
expires_in = 120
@@ -111,10 +150,7 @@ def test_client_credentials_refresh_buffer() -> None:
http.request.return_value = response
provider = ClientCredentialsAuthorizationProvider(
- token_endpoint="https://auth/token",
- client_id="client",
- client_secret="secret",
- scope=None,
+ base_url="https://polaris/",
http=http,
refresh_buffer_seconds=refresh_buffer,
timeout=mock.sentinel.timeout,
@@ -154,6 +190,18 @@ def test_client_credentials_refresh_buffer() -> None:
http.request.assert_not_called()
+def test_client_credentials_returns_none_if_no_credentials() -> None:
+ http_mock = mock.Mock()
+ provider = ClientCredentialsAuthorizationProvider(
+ base_url="https://polaris/",
+ http=http_mock,
+ refresh_buffer_seconds=60.0,
+ timeout=mock.sentinel.timeout,
+ )
+ assert provider.authorization_header() is None
+ assert provider.authorization_header(realm="foo") is None
+
+
@pytest.mark.parametrize(
"payload,expected_message",
[
@@ -163,8 +211,10 @@ def test_client_credentials_refresh_buffer() -> None:
],
)
def test_client_credentials_rejects_invalid_responses(
- payload: object, expected_message: str
+ payload: object, expected_message: str, monkeypatch: pytest.MonkeyPatch
) -> None:
+ monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+ monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
http = mock.Mock()
if isinstance(payload, str):
data = payload.encode("utf-8")
@@ -173,10 +223,7 @@ def test_client_credentials_rejects_invalid_responses(
http.request.return_value = SimpleNamespace(status=200, data=data)
provider = ClientCredentialsAuthorizationProvider(
- token_endpoint="https://auth/token",
- client_id="client",
- client_secret="secret",
- scope=None,
+ base_url="https://polaris/",
http=http,
refresh_buffer_seconds=0.0,
timeout=mock.sentinel.timeout,
@@ -186,15 +233,16 @@ def test_client_credentials_rejects_invalid_responses(
provider.authorization_header()
-def test_client_credentials_errors_on_non_200_status() -> None:
+def test_client_credentials_errors_on_non_200_status(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.setenv("POLARIS_CLIENT_ID", "client")
+ monkeypatch.setenv("POLARIS_CLIENT_SECRET", "secret")
http = mock.Mock()
http.request.return_value = SimpleNamespace(status=500, data=b"boom")
provider = ClientCredentialsAuthorizationProvider(
- token_endpoint="https://auth/token",
- client_id="client",
- client_secret="secret",
- scope=None,
+ base_url="https://polaris/",
http=http,
refresh_buffer_seconds=0.0,
timeout=mock.sentinel.timeout,
@@ -202,3 +250,123 @@ def test_client_credentials_errors_on_non_200_status() ->
None:
with pytest.raises(RuntimeError, match="500"):
provider.authorization_header()
+
+
+def test_client_credentials_caches_tokens_separately_for_each_realm(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ http_mock = mock.Mock()
+ provider = ClientCredentialsAuthorizationProvider(
+ base_url="https://polaris/",
+ http=http_mock,
+ refresh_buffer_seconds=60.0,
+ timeout=mock.sentinel.timeout,
+ )
+ # Global creds
+ monkeypatch.setenv("POLARIS_CLIENT_ID", "global_client")
+ monkeypatch.setenv("POLARIS_CLIENT_SECRET", "global_secret")
+ # Realm creds
+ monkeypatch.setenv("POLARIS_REALM_realm1_CLIENT_ID", "realm1_client")
+ monkeypatch.setenv("POLARIS_REALM_realm1_CLIENT_SECRET", "realm1_secret")
+
+ # First call for global
+ http_mock.request.return_value = SimpleNamespace(
+ status=200,
+ data=json.dumps({"access_token": "global_token"}).encode("utf-8"),
+ )
+ assert provider.authorization_header() == "Bearer global_token"
+ http_mock.request.assert_called_once()
+ assert "client_id=global_client" in
http_mock.request.call_args.kwargs["body"]
+
+ # First call for realm1
+ http_mock.request.return_value = SimpleNamespace(
+ status=200,
+ data=json.dumps({"access_token": "realm1_token"}).encode("utf-8"),
+ )
+ assert provider.authorization_header(realm="realm1") == "Bearer
realm1_token"
+ assert http_mock.request.call_count == 2
+ assert "client_id=realm1_client" in
http_mock.request.call_args.kwargs["body"]
+
+ # Second call for global should hit cache
+ assert provider.authorization_header() == "Bearer global_token"
+ assert http_mock.request.call_count == 2
+
+ # Second call for realm1 should hit cache
+ assert provider.authorization_header(realm="realm1") == "Bearer
realm1_token"
+ assert http_mock.request.call_count == 2
+
+
+def test_with_realm_header(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ realm_name = "TEST_REALM"
+ http = mock.Mock()
+ http.request.return_value = SimpleNamespace(
+ status=200,
+ data=json.dumps({"access_token": "token", "expires_in":
3600}).encode("utf-8"),
+ )
+ monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_ID", "client")
+ monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_SECRET", "secret")
+ provider = ClientCredentialsAuthorizationProvider(
+ base_url="https://polaris/",
+ http=http,
+ refresh_buffer_seconds=60.0,
+ timeout=mock.sentinel.timeout,
+ )
+ provider.authorization_header(realm=realm_name)
+ call_args = http.request.call_args
+ headers = call_args[1]["headers"]
+ assert headers["Polaris-Realm"] == realm_name
+
+
+def test_with_custom_realm_header(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ realm_name = "TEST_REALM"
+ monkeypatch.setenv("POLARIS_REALM_CONTEXT_HEADER_NAME", "X-Polaris-Realm")
+ http = mock.Mock()
+ http.request.return_value = SimpleNamespace(
+ status=200,
+ data=json.dumps({"access_token": "token", "expires_in":
3600}).encode("utf-8"),
+ )
+ monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_ID", "client")
+ monkeypatch.setenv(f"POLARIS_REALM_{realm_name}_CLIENT_SECRET", "secret")
+ provider = ClientCredentialsAuthorizationProvider(
+ base_url="https://polaris/",
+ http=http,
+ refresh_buffer_seconds=60.0,
+ timeout=mock.sentinel.timeout,
+ )
+ provider.authorization_header(realm=realm_name)
+ call_args = http.request.call_args
+ headers = call_args[1]["headers"]
+ assert headers["X-Polaris-Realm"] == realm_name
+
+
+def test_two_realms_one_incomplete(monkeypatch: pytest.MonkeyPatch) -> None:
+ realm1_name = "TEST_REALM"
+ realm2_name = "TEST2_REALM"
+ http = mock.Mock()
+ provider = ClientCredentialsAuthorizationProvider(
+ base_url="https://polaris/",
+ http=http,
+ refresh_buffer_seconds=60.0,
+ timeout=mock.sentinel.timeout,
+ )
+ # Realm 1 – complete credentials
+ monkeypatch.setenv(f"POLARIS_REALM_{realm1_name}_CLIENT_ID", "client")
+ monkeypatch.setenv(f"POLARIS_REALM_{realm1_name}_CLIENT_SECRET", "secret")
+ # Realm 2 – missing secret
+ monkeypatch.setenv(f"POLARIS_REALM_{realm2_name}_CLIENT_ID", "client2")
+ # Mock response for realm 1
+ http.request.return_value = SimpleNamespace(
+ status=200,
+ data=json.dumps({"access_token": "token", "expires_in":
3600}).encode("utf-8"),
+ )
+ # Realm 1 should succeed
+ assert provider.authorization_header(realm=f"{realm1_name}") == "Bearer
token"
+ assert http.request.call_count == 1
+ # Realm 2 should return None and not trigger an HTTP request
+ http.request.reset_mock()
+ assert provider.authorization_header(realm=f"{realm2_name}") is None
+ assert http.request.call_count == 0
diff --git a/mcp-server/tests/test_rest_tool.py
b/mcp-server/tests/test_rest_tool.py
index cbd3682..07de663 100644
--- a/mcp-server/tests/test_rest_tool.py
+++ b/mcp-server/tests/test_rest_tool.py
@@ -169,3 +169,68 @@ def test_call_requires_non_empty_path() -> None:
tool.call({"method": "GET"})
http.request.assert_not_called()
+
+
+def test_call_with_realm() -> None:
+ tool, http, auth = _create_tool()
+ http.request.return_value = _build_response(status=200, body="{}")
+ tool.call(
+ {
+ "method": "GET",
+ "path": "namespace",
+ "realm": "realm1",
+ }
+ )
+ auth.authorization_header.assert_called_once_with("realm1")
+ call_args = http.request.call_args
+ headers = call_args[1]["headers"]
+ assert headers["Polaris-Realm"] == "realm1"
+
+
+def test_call_with_existed_realm() -> None:
+ tool, http, auth = _create_tool()
+ http.request.return_value = _build_response(status=200, body="{}")
+ tool.call(
+ {
+ "method": "GET",
+ "path": "namespace",
+ "headers": {"Polaris-Realm": "existing_realm"},
+ "realm": "realm1",
+ }
+ )
+ call_args = http.request.call_args
+ headers = call_args[1]["headers"]
+ assert headers["Polaris-Realm"] == "existing_realm"
+
+
+def test_call_with_custom_realm_header(
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ monkeypatch.setenv("POLARIS_REALM_CONTEXT_HEADER_NAME", "X-Polaris-Realm")
+ tool, http, auth = _create_tool()
+ http.request.return_value = _build_response(status=200, body="{}")
+ tool.call(
+ {
+ "method": "GET",
+ "path": "namespace",
+ "realm": "realm1",
+ }
+ )
+ call_args = http.request.call_args
+ headers = call_args[1]["headers"]
+ assert "Polaris-Realm" not in headers
+ assert headers["X-Polaris-Realm"] == "realm1"
+
+
+def test_call_without_provide_realm() -> None:
+ tool, http, auth = _create_tool()
+ http.request.return_value = _build_response(status=200, body="{}")
+ tool.call(
+ {
+ "method": "GET",
+ "path": "namespace",
+ }
+ )
+ call_args = http.request.call_args
+ headers = call_args[1]["headers"]
+ assert "Polaris-Realm" not in headers
diff --git a/mcp-server/tests/test_server.py b/mcp-server/tests/test_server.py
index 0ba55e0..5336b7d 100644
--- a/mcp-server/tests/test_server.py
+++ b/mcp-server/tests/test_server.py
@@ -242,10 +242,7 @@ class TestAuthorizationProviderResolution:
assert provider is fake_provider
mock_factory.assert_called_once_with(
- token_endpoint="https://oauth/token",
- client_id="client",
- client_secret="secret",
- scope="scope",
+ base_url="https://base/",
http=fake_http,
refresh_buffer_seconds=60.0,
timeout=mock.sentinel.timeout,