This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 8a310149c8 [python] Fix token cache pollution in Python
`RESTTokenFileIO` by aligning with the Java implementation. (#7562)
8a310149c8 is described below
commit 8a310149c8875006e900a9be455a9730afe62bb0
Author: shyjsarah <[email protected]>
AuthorDate: Tue Mar 31 00:52:57 2026 -0700
[python] Fix token cache pollution in Python `RESTTokenFileIO` by aligning
with the Java implementation. (#7562)
Fix token cache pollution in Python `RESTTokenFileIO` by aligning with
the Java implementation.
**Problem:**
Python's `RESTTokenFileIO` had a class-level `_TOKEN_CACHE` (keyed by
table identifier string) that caused token sharing across different
catalog instances. When two catalogs with different AK/SK credentials
operated on the same table within one process, the second catalog would
reuse the first catalog's token — leading to permission errors (e.g.,
read AK/SK token used for write operations).
**Root Cause:**
Java's `RESTTokenFileIO` has **no token cache** — each instance manages
its own `token` field independently. Python added an extra
`_TOKEN_CACHE` class-level dict that Java never had.
**Changes:**
- Remove class-level `_TOKEN_CACHE`, `_TOKEN_LOCKS`, `_TOKEN_LOCKS_LOCK`
and their associated methods (`_get_cached_token`, `_set_cached_token`,
`_get_global_token_lock`, `_is_token_expired`)
- Simplify `try_to_refresh_token()` to use instance-level lock with
double-check pattern, aligned with Java's `tryToRefreshToken()`
- Merge `should_refresh()` and `_is_token_expired()` into a single
`_should_refresh()` method
- Add system table identifier handling in `refresh_token()` (strip the
system-table suffix, e.g. `$snapshots`, before requesting the token),
aligned with Java
The `FILE_IO_CACHE` (keyed by `RESTToken` object) is kept unchanged — it
correctly isolates different credentials since different tokens produce
different cache keys.
---
.../pypaimon/catalog/rest/rest_token_file_io.py | 95 ++++--------
.../pypaimon/tests/rest/rest_token_file_io_test.py | 164 ++++++++++++++++++---
2 files changed, 177 insertions(+), 82 deletions(-)
diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
index 7bf984d1f5..16b16e6972 100644
--- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
+++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
@@ -27,7 +27,7 @@ from pypaimon.api.rest_util import RESTUtil
from pypaimon.catalog.rest.rest_token import RESTToken
from pypaimon.common.file_io import FileIO
from pypaimon.filesystem.pyarrow_file_io import PyArrowFileIO
-from pypaimon.common.identifier import Identifier
+from pypaimon.common.identifier import Identifier, SYSTEM_TABLE_SPLITTER
from pypaimon.common.options import Options
from pypaimon.common.options.config import CatalogOptions, OssOptions
from pypaimon.common.uri_reader import UriReaderFactory
@@ -37,17 +37,13 @@ class RESTTokenFileIO(FileIO):
"""
A FileIO to support getting token from REST Server.
"""
-
+
_FILE_IO_CACHE_MAXSIZE = 1000
_FILE_IO_CACHE_TTL = 36000 # 10 hours in seconds
-
+
_FILE_IO_CACHE: TTLCache = None
_FILE_IO_CACHE_LOCK = threading.Lock()
-
- _TOKEN_CACHE: dict = {}
- _TOKEN_LOCKS: dict = {}
- _TOKEN_LOCKS_LOCK = threading.Lock()
-
+
@classmethod
def _get_file_io_cache(cls) -> TTLCache:
if cls._FILE_IO_CACHE is None:
@@ -58,7 +54,7 @@ class RESTTokenFileIO(FileIO):
ttl=cls._FILE_IO_CACHE_TTL
)
return cls._FILE_IO_CACHE
-
+
def __init__(self, identifier: Identifier, path: str,
catalog_options: Optional[Union[dict, Options]] = None):
self.identifier = identifier
@@ -99,26 +95,26 @@ class RESTTokenFileIO(FileIO):
if self.token is None:
return FileIO.get(self.path, self.catalog_options or Options({}))
-
+
cache_key = self.token
cache = self._get_file_io_cache()
-
+
file_io = cache.get(cache_key)
if file_io is not None:
return file_io
-
+
with self._FILE_IO_CACHE_LOCK:
self.try_to_refresh_token()
-
+
if self.token is None:
return FileIO.get(self.path, self.catalog_options or
Options({}))
-
+
cache_key = self.token
cache = self._get_file_io_cache()
file_io = cache.get(cache_key)
if file_io is not None:
return file_io
-
+
merged_properties = RESTUtil.merge(
self.catalog_options.to_map() if self.catalog_options else {},
self.token.token
@@ -128,7 +124,7 @@ class RESTTokenFileIO(FileIO):
if dlf_oss_endpoint and dlf_oss_endpoint.strip():
merged_properties[OssOptions.OSS_ENDPOINT.key()] =
dlf_oss_endpoint
merged_options = Options(merged_properties)
-
+
file_io = PyArrowFileIO(self.path, merged_options)
cache[cache_key] = file_io
return file_io
@@ -198,7 +194,7 @@ class RESTTokenFileIO(FileIO):
if self._uri_reader_factory_cache is None:
catalog_options = self.catalog_options or Options({})
self._uri_reader_factory_cache = UriReaderFactory(catalog_options)
-
+
return self._uri_reader_factory_cache
@property
@@ -206,66 +202,35 @@ class RESTTokenFileIO(FileIO):
return self.file_io().filesystem
def try_to_refresh_token(self):
- identifier_str = str(self.identifier)
-
- if self.token is not None and not self._is_token_expired(self.token):
- return
-
- cached_token = self._get_cached_token(identifier_str)
- if cached_token and not self._is_token_expired(cached_token):
- self.token = cached_token
- return
-
- global_lock = self._get_global_token_lock(identifier_str)
-
- with global_lock:
- cached_token = self._get_cached_token(identifier_str)
- if cached_token and not self._is_token_expired(cached_token):
- self.token = cached_token
- return
-
- token_to_check = cached_token if cached_token else self.token
- if token_to_check is None or
self._is_token_expired(token_to_check):
- self.refresh_token()
- self._set_cached_token(identifier_str, self.token)
-
- def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]:
- with self._TOKEN_LOCKS_LOCK:
- return self._TOKEN_CACHE.get(identifier_str)
-
- def _set_cached_token(self, identifier_str: str, token: RESTToken):
- with self._TOKEN_LOCKS_LOCK:
- self._TOKEN_CACHE[identifier_str] = token
-
- def _is_token_expired(self, token: Optional[RESTToken]) -> bool:
- if token is None:
- return True
- current_time = int(time.time() * 1000)
- return (token.expire_at_millis - current_time) <
RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
-
- def _get_global_token_lock(self, identifier_str: str) -> threading.Lock:
- with self._TOKEN_LOCKS_LOCK:
- if identifier_str not in self._TOKEN_LOCKS:
- self._TOKEN_LOCKS[identifier_str] = threading.Lock()
- return self._TOKEN_LOCKS[identifier_str]
-
- def should_refresh(self):
+ if self._should_refresh():
+ with self.lock:
+ if self._should_refresh():
+ self.refresh_token()
+
+ def _should_refresh(self):
if self.token is None:
return True
current_time = int(time.time() * 1000)
- time_until_expiry = self.token.expire_at_millis - current_time
- return time_until_expiry < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+ return (self.token.expire_at_millis - current_time) <
RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
def refresh_token(self):
self.log.info(f"begin refresh data token for identifier
[{self.identifier}]")
if self.api_instance is None:
self.api_instance = RESTApi(self.properties, False)
- response = self.api_instance.load_table_token(self.identifier)
+ table_identifier = self.identifier
+ if SYSTEM_TABLE_SPLITTER in self.identifier.get_object_name():
+ base_table =
self.identifier.get_object_name().split(SYSTEM_TABLE_SPLITTER)[0]
+ table_identifier = Identifier(
+ database=self.identifier.get_database_name(),
+ object=base_table,
+ branch=self.identifier.get_branch_name())
+
+ response = self.api_instance.load_table_token(table_identifier)
self.log.info(
f"end refresh data token for identifier [{self.identifier}]
expiresAtMillis [{response.expires_at_millis}]"
)
-
+
merged_token_dict =
self._merge_token_with_catalog_options(response.token)
new_token = RESTToken(merged_token_dict, response.expires_at_millis)
self.token = new_token
diff --git a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py
b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py
index 47ea8e6cb6..cdcd5ed36c 100644
--- a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py
+++ b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py
@@ -18,10 +18,12 @@
import os
import pickle
import tempfile
+import time
import unittest
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
from pypaimon.catalog.rest.rest_token_file_io import RESTTokenFileIO
+from pypaimon.catalog.rest.rest_token import RESTToken
from pypaimon.common.identifier import Identifier
from pypaimon.common.options import Options
@@ -101,13 +103,13 @@ class RESTTokenFileIOTest(unittest.TestCase):
target_dir = os.path.join(self.temp_dir, "target_dir")
os.makedirs(target_dir)
-
+
result = file_io.try_to_write_atomic(f"file://{target_dir}", "test
content")
self.assertFalse(result, "try_to_write_atomic should return False
when target is a directory")
-
+
self.assertTrue(os.path.isdir(target_dir))
self.assertEqual(len(os.listdir(target_dir)), 0, "No file should
be created inside the directory")
-
+
normal_file = os.path.join(self.temp_dir, "normal_file.txt")
result = file_io.try_to_write_atomic(f"file://{normal_file}",
"test content")
self.assertTrue(result, "try_to_write_atomic should succeed for a
normal file path")
@@ -223,35 +225,35 @@ class RESTTokenFileIOTest(unittest.TestCase):
CatalogOptions.URI.key(): "http://test-uri",
"custom.key": "custom.value"
})
-
+
catalog_options_copy = Options(original_catalog_options.to_map())
-
+
with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
file_io = RESTTokenFileIO(
self.identifier,
self.warehouse_path,
original_catalog_options
)
-
+
token_dict = {
OssOptions.OSS_ACCESS_KEY_ID.key(): "token-access-key",
OssOptions.OSS_ACCESS_KEY_SECRET.key(): "token-secret-key",
OssOptions.OSS_ENDPOINT.key(): "token-endpoint"
}
-
+
merged_token =
file_io._merge_token_with_catalog_options(token_dict)
-
+
self.assertEqual(
original_catalog_options.to_map(),
catalog_options_copy.to_map(),
"Original catalog_options should not be modified"
)
-
+
merged_properties = RESTUtil.merge(
original_catalog_options.to_map(),
merged_token
)
-
+
self.assertIn("custom.key", merged_properties)
self.assertEqual(merged_properties["custom.key"], "custom.value")
self.assertIn(OssOptions.OSS_ACCESS_KEY_ID.key(),
merged_properties)
@@ -264,11 +266,11 @@ class RESTTokenFileIOTest(unittest.TestCase):
self.warehouse_path,
self.catalog_options
)
-
+
self.assertTrue(hasattr(file_io, 'filesystem'), "RESTTokenFileIO
should have filesystem property")
filesystem = file_io.filesystem
self.assertIsNotNone(filesystem, "filesystem should not be None")
-
+
self.assertTrue(hasattr(filesystem, 'open_input_file'),
"filesystem should support open_input_file method")
@@ -279,12 +281,12 @@ class RESTTokenFileIOTest(unittest.TestCase):
self.warehouse_path,
self.catalog_options
)
-
+
self.assertTrue(hasattr(file_io, 'uri_reader_factory'),
"RESTTokenFileIO should have uri_reader_factory
property")
uri_reader_factory = file_io.uri_reader_factory
self.assertIsNotNone(uri_reader_factory, "uri_reader_factory
should not be None")
-
+
self.assertTrue(hasattr(uri_reader_factory, 'create'),
"uri_reader_factory should support create method")
@@ -295,15 +297,143 @@ class RESTTokenFileIOTest(unittest.TestCase):
self.warehouse_path,
self.catalog_options
)
-
+
pickled = pickle.dumps(original_file_io)
restored_file_io = pickle.loads(pickled)
-
+
self.assertIsNotNone(restored_file_io.filesystem,
"filesystem should work after
deserialization")
self.assertIsNotNone(restored_file_io.uri_reader_factory,
"uri_reader_factory should work after
deserialization")
+ def test_should_refresh_when_token_is_none(self):
+ """_should_refresh() returns True when token is None."""
+ with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
+ file_io = RESTTokenFileIO(
+ self.identifier, self.warehouse_path, self.catalog_options)
+ self.assertIsNone(file_io.token)
+ self.assertTrue(file_io._should_refresh())
+
+ def test_should_refresh_when_token_not_expired(self):
+ """_should_refresh() returns False when token is far from expiry."""
+ with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
+ file_io = RESTTokenFileIO(
+ self.identifier, self.warehouse_path, self.catalog_options)
+ # Token that expires 2 hours from now (well beyond the 1-hour safe
margin)
+ future_millis = int(time.time() * 1000) + 7200_000
+ file_io.token = RESTToken({'ak': 'v'}, future_millis)
+ self.assertFalse(file_io._should_refresh())
+
+ def test_should_refresh_when_token_expired(self):
+ """_should_refresh() returns True when token is already expired."""
+ with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
+ file_io = RESTTokenFileIO(
+ self.identifier, self.warehouse_path, self.catalog_options)
+ # Token that expired 1 second ago
+ past_millis = int(time.time() * 1000) - 1000
+ file_io.token = RESTToken({'ak': 'v'}, past_millis)
+ self.assertTrue(file_io._should_refresh())
+
+ def test_try_to_refresh_token_calls_refresh_once(self):
+ """try_to_refresh_token() calls refresh_token() exactly once via
double-check."""
+ file_io = RESTTokenFileIO(
+ self.identifier, self.warehouse_path, self.catalog_options)
+ self.assertIsNone(file_io.token)
+
+ mock_response = MagicMock()
+ mock_response.token = {'ak': 'test-ak'}
+ mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000
+
+ mock_api = MagicMock()
+ mock_api.load_table_token.return_value = mock_response
+ file_io.api_instance = mock_api
+
+ file_io.try_to_refresh_token()
+
+ mock_api.load_table_token.assert_called_once()
+ self.assertIsNotNone(file_io.token)
+
+ # Second call should NOT trigger refresh again (token is valid)
+ file_io.try_to_refresh_token()
+ mock_api.load_table_token.assert_called_once()
+
+ def test_refresh_token_strips_system_table_suffix(self):
+ """refresh_token() strips $snapshots suffix before requesting token."""
+ system_identifier = Identifier.create("db", "my_table$snapshots")
+ file_io = RESTTokenFileIO(
+ system_identifier, self.warehouse_path, self.catalog_options)
+
+ mock_response = MagicMock()
+ mock_response.token = {'ak': 'test-ak'}
+ mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000
+
+ mock_api = MagicMock()
+ mock_api.load_table_token.return_value = mock_response
+ file_io.api_instance = mock_api
+
+ file_io.refresh_token()
+
+ # Verify load_table_token was called with base table identifier (no
$snapshots)
+ called_identifier = mock_api.load_table_token.call_args[0][0]
+ self.assertEqual(called_identifier.get_database_name(), "db")
+ self.assertEqual(called_identifier.get_object_name(), "my_table")
+
+ def test_refresh_token_keeps_normal_identifier(self):
+ """refresh_token() does not modify normal (non-system) identifiers."""
+ normal_identifier = Identifier.create("db", "my_table")
+ file_io = RESTTokenFileIO(
+ normal_identifier, self.warehouse_path, self.catalog_options)
+
+ mock_response = MagicMock()
+ mock_response.token = {'ak': 'test-ak'}
+ mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000
+
+ mock_api = MagicMock()
+ mock_api.load_table_token.return_value = mock_response
+ file_io.api_instance = mock_api
+
+ file_io.refresh_token()
+
+ called_identifier = mock_api.load_table_token.call_args[0][0]
+ self.assertEqual(called_identifier.get_object_name(), "my_table")
+
+ def test_different_instances_do_not_share_token(self):
+ """Two instances with same identifier get independent tokens (no
class-level cache)."""
+ same_identifier = Identifier.from_string("db.shared_table")
+
+ file_io_a = RESTTokenFileIO(
+ same_identifier, self.warehouse_path, self.catalog_options)
+ file_io_b = RESTTokenFileIO(
+ same_identifier, self.warehouse_path, self.catalog_options)
+
+ token_a = RESTToken({'ak': 'ak-A'}, int(time.time() * 1000) + 7200_000)
+ token_b = RESTToken({'ak': 'ak-B'}, int(time.time() * 1000) + 7200_000)
+
+ mock_response_a = MagicMock()
+ mock_response_a.token = token_a.token
+ mock_response_a.expires_at_millis = token_a.expire_at_millis
+
+ mock_response_b = MagicMock()
+ mock_response_b.token = token_b.token
+ mock_response_b.expires_at_millis = token_b.expire_at_millis
+
+ mock_api_a = MagicMock()
+ mock_api_a.load_table_token.return_value = mock_response_a
+ file_io_a.api_instance = mock_api_a
+
+ mock_api_b = MagicMock()
+ mock_api_b.load_table_token.return_value = mock_response_b
+ file_io_b.api_instance = mock_api_b
+
+ # Refresh both
+ file_io_a.try_to_refresh_token()
+ file_io_b.try_to_refresh_token()
+
+ # Each instance should hold its own token
+ self.assertEqual(file_io_a.token.token['ak'], 'ak-A')
+ self.assertEqual(file_io_b.token.token['ak'], 'ak-B')
+ self.assertIsNot(file_io_a.token, file_io_b.token)
+
if __name__ == '__main__':
unittest.main()