This is an automated email from the ASF dual-hosted git repository.
junhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 2bacfdbf71 [python] Fix tryCommit failed. (#6969)
2bacfdbf71 is described below
commit 2bacfdbf712827c1630b399db7b5f470211aed2d
Author: umi <[email protected]>
AuthorDate: Thu Jan 8 12:10:49 2026 +0800
[python] Fix tryCommit failed. (#6969)
---
paimon-python/pypaimon/common/file_io.py | 18 +-
.../pypaimon/common/options/core_options.py | 48 +++
.../pypaimon/common/options/options_utils.py | 13 +
paimon-python/pypaimon/common/time_utils.py | 81 +++++
.../pypaimon/snapshot/snapshot_manager.py | 47 ++-
paimon-python/pypaimon/tests/blob_table_test.py | 135 ++++++++
.../pypaimon/tests/reader_append_only_test.py | 106 +++++++
.../pypaimon/tests/reader_primary_key_test.py | 104 +++++++
.../pypaimon/tests/schema_evolution_read_test.py | 1 +
paimon-python/pypaimon/write/file_store_commit.py | 343 ++++++++++++++++-----
10 files changed, 810 insertions(+), 86 deletions(-)
diff --git a/paimon-python/pypaimon/common/file_io.py b/paimon-python/pypaimon/common/file_io.py
index 497d711a49..2ec1909306 100644
--- a/paimon-python/pypaimon/common/file_io.py
+++ b/paimon-python/pypaimon/common/file_io.py
@@ -18,13 +18,15 @@
import logging
import os
import subprocess
+import threading
+import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import splitport, urlparse
import pyarrow
from packaging.version import parse
-from pyarrow._fs import FileSystem
+from pyarrow._fs import FileSystem, LocalFileSystem
from pypaimon.common.options import Options
from pypaimon.common.options.config import OssOptions, S3Options
@@ -37,6 +39,8 @@ from pypaimon.write.blob_format_writer import BlobFormatWriter
class FileIO:
+ rename_lock = threading.Lock()
+
def __init__(self, path: str, catalog_options: Options):
self.properties = catalog_options
self.logger = logging.getLogger(__name__)
@@ -251,7 +255,15 @@ class FileIO:
self.mkdirs(str(dst_parent))
src_str = self.to_filesystem_path(src)
- self.filesystem.move(src_str, dst_str)
+ if isinstance(self.filesystem, LocalFileSystem):
+ if self.exists(dst):
+ return False
+ with FileIO.rename_lock:
+ if self.exists(dst):
+ return False
+ self.filesystem.move(src_str, dst_str)
+ else:
+ self.filesystem.move(src_str, dst_str)
return True
except Exception as e:
self.logger.warning(f"Failed to rename {src} to {dst}: {e}")
@@ -303,7 +315,7 @@ class FileIO:
return input_stream.read().decode('utf-8')
def try_to_write_atomic(self, path: str, content: str) -> bool:
- temp_path = path + ".tmp"
+ temp_path = path + str(uuid.uuid4()) + ".tmp"
success = False
try:
self.write_file(temp_path, content, False)
diff --git a/paimon-python/pypaimon/common/options/core_options.py b/paimon-python/pypaimon/common/options/core_options.py
index 4ab5a253d7..49230240a7 100644
--- a/paimon-python/pypaimon/common/options/core_options.py
+++ b/paimon-python/pypaimon/common/options/core_options.py
@@ -15,9 +15,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
+import sys
from enum import Enum
from typing import Dict
+from datetime import timedelta
+
from pypaimon.common.memory_size import MemorySize
from pypaimon.common.options import Options
from pypaimon.common.options.config_options import ConfigOptions
@@ -239,6 +242,34 @@ class CoreOptions:
.with_description("The prefix for commit user.")
)
+ COMMIT_MAX_RETRIES: ConfigOption[int] = (
+ ConfigOptions.key("commit.max-retries")
+ .int_type()
+ .default_value(10)
+ .with_description("Maximum number of retries for commit operations.")
+ )
+
+ COMMIT_TIMEOUT: ConfigOption[timedelta] = (
+ ConfigOptions.key("commit.timeout")
+ .duration_type()
+ .no_default_value()
+ .with_description("Timeout for commit operations (e.g., '10s', '5m'). If not set, effectively unlimited.")
+ )
+
+ COMMIT_MIN_RETRY_WAIT: ConfigOption[timedelta] = (
+ ConfigOptions.key("commit.min-retry-wait")
+ .duration_type()
+ .default_value(timedelta(milliseconds=10))
+ .with_description("Minimum wait time between commit retries (e.g., '10ms', '100ms').")
+ )
+
+ COMMIT_MAX_RETRY_WAIT: ConfigOption[timedelta] = (
+ ConfigOptions.key("commit.max-retry-wait")
+ .duration_type()
+ .default_value(timedelta(seconds=10))
+ .with_description("Maximum wait time between commit retries (e.g., '1s', '10s').")
+ )
+
ROW_TRACKING_ENABLED: ConfigOption[bool] = (
ConfigOptions.key("row-tracking.enabled")
.boolean_type()
@@ -390,3 +421,20 @@ class CoreOptions:
def data_file_external_paths_specific_fs(self, default=None):
return self.options.get(CoreOptions.DATA_FILE_EXTERNAL_PATHS_SPECIFIC_FS, default)
+
+ def commit_max_retries(self) -> int:
+ return self.options.get(CoreOptions.COMMIT_MAX_RETRIES)
+
+ def commit_timeout(self) -> int:
+ timeout = self.options.get(CoreOptions.COMMIT_TIMEOUT)
+ if timeout is None:
+ return sys.maxsize
+ return int(timeout.total_seconds() * 1000)
+
+ def commit_min_retry_wait(self) -> int:
+ wait = self.options.get(CoreOptions.COMMIT_MIN_RETRY_WAIT)
+ return int(wait.total_seconds() * 1000)
+
+ def commit_max_retry_wait(self) -> int:
+ wait = self.options.get(CoreOptions.COMMIT_MAX_RETRY_WAIT)
+ return int(wait.total_seconds() * 1000)
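
The accessors above convert everything to milliseconds, with an unset commit.timeout treated as effectively unlimited (sys.maxsize). A hypothetical example of supplying these options when creating a table, mirroring how the tests below pass options to Schema.from_pyarrow_schema (the string values go through the int/duration converters):

import pyarrow as pa
from pypaimon import Schema

pa_schema = pa.schema([('id', pa.int32()), ('name', pa.string())])
schema = Schema.from_pyarrow_schema(
    pa_schema,
    options={
        'commit.max-retries': '20',       # default 10
        'commit.timeout': '5m',           # unset -> effectively unlimited
        'commit.min-retry-wait': '50ms',  # default 10ms
        'commit.max-retry-wait': '5s',    # default 10s
    },
)
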
diff --git a/paimon-python/pypaimon/common/options/options_utils.py b/paimon-python/pypaimon/common/options/options_utils.py
index f48f549df4..9938e87e74 100644
--- a/paimon-python/pypaimon/common/options/options_utils.py
+++ b/paimon-python/pypaimon/common/options/options_utils.py
@@ -16,10 +16,12 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+from datetime import timedelta
from enum import Enum
from typing import Any, Type
from pypaimon.common.memory_size import MemorySize
+from pypaimon.common.time_utils import parse_duration
class OptionsUtils:
@@ -63,6 +65,8 @@ class OptionsUtils:
return OptionsUtils.convert_to_double(value)
elif target_type == MemorySize:
return OptionsUtils.convert_to_memory_size(value)
+ elif target_type == timedelta:
+ return OptionsUtils.convert_to_duration(value)
else:
raise ValueError(f"Unsupported type: {target_type}")
@@ -125,6 +129,15 @@ class OptionsUtils:
return MemorySize.parse(value)
raise ValueError(f"Cannot convert {type(value)} to MemorySize")
+ @staticmethod
+ def convert_to_duration(value: Any) -> timedelta:
+ if isinstance(value, timedelta):
+ return value
+ if isinstance(value, str):
+ milliseconds = parse_duration(value)
+ return timedelta(milliseconds=milliseconds)
+ raise ValueError(f"Cannot convert {type(value)} to timedelta")
+
@staticmethod
def convert_to_enum(value: Any, enum_class: Type[Enum]) -> Enum:
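
With the timedelta branch above, duration-typed options accept either a timedelta or a string that is parsed via parse_duration; for example:

from datetime import timedelta
from pypaimon.common.options.options_utils import OptionsUtils

assert OptionsUtils.convert_to_duration('30s') == timedelta(seconds=30)
assert OptionsUtils.convert_to_duration(timedelta(minutes=1)) == timedelta(minutes=1)
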
diff --git a/paimon-python/pypaimon/common/time_utils.py b/paimon-python/pypaimon/common/time_utils.py
new file mode 100644
index 0000000000..1a02bd4398
--- /dev/null
+++ b/paimon-python/pypaimon/common/time_utils.py
@@ -0,0 +1,81 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+
+def parse_duration(text: str) -> int:
+ if text is None:
+ raise ValueError("text cannot be None")
+
+ trimmed = text.strip().lower()
+ if not trimmed:
+ raise ValueError("argument is an empty- or whitespace-only string")
+
+ pos = 0
+ while pos < len(trimmed) and trimmed[pos].isdigit():
+ pos += 1
+
+ number_str = trimmed[:pos]
+ unit_str = trimmed[pos:].strip()
+
+ if not number_str:
+ raise ValueError("text does not start with a number")
+
+ try:
+ value = int(number_str)
+ except ValueError:
+ raise ValueError(
+ f"The value '{number_str}' cannot be represented as a 64 bit number (numeric overflow)."
+ )
+
+ if not unit_str:
+ result_ms = value
+ elif unit_str in ('ns', 'nano', 'nanosecond', 'nanoseconds'):
+ result_ms = value / 1_000_000
+ elif unit_str in ('µs', 'micro', 'microsecond', 'microseconds'):
+ result_ms = value / 1_000
+ elif unit_str in ('ms', 'milli', 'millisecond', 'milliseconds'):
+ result_ms = value
+ elif unit_str in ('s', 'sec', 'second', 'seconds'):
+ result_ms = value * 1000
+ elif unit_str in ('m', 'min', 'minute', 'minutes'):
+ result_ms = value * 60 * 1000
+ elif unit_str in ('h', 'hour', 'hours'):
+ result_ms = value * 60 * 60 * 1000
+ elif unit_str in ('d', 'day', 'days'):
+ result_ms = value * 24 * 60 * 60 * 1000
+ else:
+ supported_units = (
+ 'DAYS: (d | day | days), '
+ 'HOURS: (h | hour | hours), '
+ 'MINUTES: (m | min | minute | minutes), '
+ 'SECONDS: (s | sec | second | seconds), '
+ 'MILLISECONDS: (ms | milli | millisecond | milliseconds), '
+ 'MICROSECONDS: (µs | micro | microsecond | microseconds), '
+ 'NANOSECONDS: (ns | nano | nanosecond | nanoseconds)'
+ )
+ raise ValueError(
+ f"Time interval unit label '{unit_str}' does not match any of the recognized units: "
+ f"{supported_units}"
+ )
+
+ result_ms_int = int(round(result_ms))
+
+ if result_ms_int < 0:
+ raise ValueError(f"Duration cannot be negative: {text}")
+
+ return result_ms_int
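
A few examples of what the parser above returns (all values in milliseconds; a bare number is already milliseconds, and sub-millisecond units are rounded to an int):

from pypaimon.common.time_utils import parse_duration

assert parse_duration('10') == 10            # no unit -> milliseconds
assert parse_duration('100ms') == 100
assert parse_duration('30 s') == 30_000      # whitespace between number and unit is fine
assert parse_duration('5min') == 300_000
assert parse_duration('2h') == 7_200_000
assert parse_duration('1day') == 86_400_000
assert parse_duration('1000000ns') == 1
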
diff --git a/paimon-python/pypaimon/snapshot/snapshot_manager.py b/paimon-python/pypaimon/snapshot/snapshot_manager.py
index 0d96563057..8291d9cf2c 100644
--- a/paimon-python/pypaimon/snapshot/snapshot_manager.py
+++ b/paimon-python/pypaimon/snapshot/snapshot_manager.py
@@ -38,7 +38,7 @@ class SnapshotManager:
if not self.file_io.exists(self.latest_file):
return None
- latest_content = self.file_io.read_file_utf8(self.latest_file)
+ latest_content = self.read_latest_file()
latest_snapshot_id = int(latest_content.strip())
snapshot_file = f"{self.snapshot_dir}/snapshot-{latest_snapshot_id}"
@@ -48,6 +48,51 @@ class SnapshotManager:
snapshot_content = self.file_io.read_file_utf8(snapshot_file)
return JSON.from_json(snapshot_content, Snapshot)
+ def read_latest_file(self, max_retries: int = 5):
+ """
+ Read the latest snapshot ID from LATEST file with retry mechanism.
+ If the file doesn't exist or is empty after retries, scan the snapshot directory for the max ID.
+ """
+ import re
+ import time
+
+ # Try to read LATEST file with retries
+ for retry_count in range(max_retries):
+ try:
+ if self.file_io.exists(self.latest_file):
+ content = self.file_io.read_file_utf8(self.latest_file)
+ if content and content.strip():
+ return content.strip()
+
+ # File doesn't exist or is empty, wait a bit before retry
+ if retry_count < max_retries - 1:
+ time.sleep(0.001)
+
+ except Exception:
+ # On exception, wait and retry
+ if retry_count < max_retries - 1:
+ time.sleep(0.001)
+
+ # List all files in snapshot directory
+ file_infos = self.file_io.list_status(self.snapshot_dir)
+
+ max_snapshot_id = None
+ snapshot_pattern = re.compile(r'^snapshot-(\d+)$')
+
+ for file_info in file_infos:
+ # Get filename from path
+ filename = file_info.path.split('/')[-1]
+ match = snapshot_pattern.match(filename)
+ if match:
+ snapshot_id = int(match.group(1))
+ if max_snapshot_id is None or snapshot_id > max_snapshot_id:
+ max_snapshot_id = snapshot_id
+
+ if not max_snapshot_id:
+ raise RuntimeError(f"No snapshot content found in {self.snapshot_dir}")
+
+ return str(max_snapshot_id)
+
def get_snapshot_path(self, snapshot_id: int) -> str:
"""
Get the path for a snapshot file.
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py
index de59d0398e..f87f73ded7 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -2567,6 +2567,141 @@ class DataBlobWriterTest(unittest.TestCase):
self.assertEqual(actual, expected)
+ def test_concurrent_blob_writes_with_retry(self):
+ """Test concurrent blob writes to verify retry mechanism works correctly."""
+ import threading
+ from pypaimon import Schema
+ from pypaimon.snapshot.snapshot_manager import SnapshotManager
+
+ # Run the test multiple times to verify stability
+ iter_num = 2
+ for test_iteration in range(iter_num):
+ # Create a unique table for each iteration
+ table_name = f'test_db.blob_concurrent_writes_{test_iteration}'
+
+ # Create schema with blob column
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('thread_id', pa.int32()),
+ ('metadata', pa.string()),
+ ('blob_data', pa.large_binary()),
+ ])
+
+ schema = Schema.from_pyarrow_schema(
+ pa_schema,
+ options={
+ 'row-tracking.enabled': 'true',
+ 'data-evolution.enabled': 'true'
+ }
+ )
+ self.catalog.create_table(table_name, schema, False)
+ table = self.catalog.get_table(table_name)
+
+ write_results = []
+ write_errors = []
+
+ # Create blob pattern for testing
+ blob_size = 5 * 1024 # 5KB
+ blob_pattern = b'BLOB_PATTERN_' + b'X' * 1024
+ pattern_size = len(blob_pattern)
+ repetitions = blob_size // pattern_size
+ base_blob_data = blob_pattern * repetitions
+
+ def write_blob_data(thread_id, start_id):
+ """Write blob data in a separate thread."""
+ try:
+ threading.current_thread().name = f"Iter{test_iteration}-Thread-{thread_id}"
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ # Create unique blob data for this thread
+ data = {
+ 'id': list(range(start_id, start_id + 5)),
+ 'thread_id': [thread_id] * 5,
+ 'metadata': [f'thread{thread_id}_blob_{i}' for i in range(5)],
+ 'blob_data': [i.to_bytes(2, byteorder='little') + base_blob_data for i in range(5)]
+ }
+ pa_table = pa.Table.from_pydict(data, schema=pa_schema)
+
+ table_write.write_arrow(pa_table)
+ commit_messages = table_write.prepare_commit()
+
+ table_commit.commit(commit_messages)
+ table_write.close()
+ table_commit.close()
+
+ write_results.append({
+ 'thread_id': thread_id,
+ 'start_id': start_id,
+ 'success': True
+ })
+ except Exception as e:
+ write_errors.append({
+ 'thread_id': thread_id,
+ 'error': str(e)
+ })
+
+ # Create and start multiple threads
+ threads = []
+ num_threads = 100
+ for i in range(num_threads):
+ thread = threading.Thread(
+ target=write_blob_data,
+ args=(i, i * 10)
+ )
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ # Verify all writes succeeded (retry mechanism should handle conflicts)
+ self.assertEqual(num_threads, len(write_results),
+ f"Iteration {test_iteration}: Expected {num_threads} successful writes, "
+ f"got {len(write_results)}. Errors: {write_errors}")
+ self.assertEqual(0, len(write_errors),
+ f"Iteration {test_iteration}: Expected no errors, but got: {write_errors}")
+
+ read_builder = table.new_read_builder()
+ table_scan = read_builder.new_scan()
+ table_read = read_builder.new_read()
+ actual = table_read.to_arrow(table_scan.plan().splits()).sort_by('id')
+
+ # Verify data rows
+ self.assertEqual(num_threads * 5, actual.num_rows,
+ f"Iteration {test_iteration}: Expected {num_threads * 5} rows")
+
+ # Verify id column
+ ids = actual.column('id').to_pylist()
+ expected_ids = []
+ for i in range(num_threads):
+ expected_ids.extend(range(i * 10, i * 10 + 5))
+ expected_ids.sort()
+
+ self.assertEqual(ids, expected_ids,
+ f"Iteration {test_iteration}: IDs mismatch")
+
+ # Verify blob data integrity (spot check)
+ blob_data_list = actual.column('blob_data').to_pylist()
+ for i in range(0, len(blob_data_list), 100):  # Check every 100th blob
+ blob = blob_data_list[i]
+ self.assertGreater(len(blob), 2, f"Blob {i} should have data")
+ # Verify blob contains the pattern
+ self.assertIn(b'BLOB_PATTERN_', blob, f"Blob {i} should contain pattern")
+
+ # Verify snapshot count (should have num_threads snapshots)
+ snapshot_manager = SnapshotManager(table)
+ latest_snapshot = snapshot_manager.get_latest_snapshot()
+ self.assertIsNotNone(latest_snapshot,
+ f"Iteration {test_iteration}: Latest snapshot should not be None")
+ self.assertEqual(latest_snapshot.id, num_threads,
+ f"Iteration {test_iteration}: Expected snapshot ID {num_threads}, "
+ f"got {latest_snapshot.id}")
+
+ print(f"✓ Blob Table Iteration {test_iteration + 1}/{iter_num} completed successfully")
+
if __name__ == '__main__':
unittest.main()
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py b/paimon-python/pypaimon/tests/reader_append_only_test.py
index 2661723919..b47f5d1f67 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -17,6 +17,7 @@
################################################################################
import os
+import shutil
import tempfile
import time
import unittest
@@ -53,6 +54,10 @@ class AoReaderTest(unittest.TestCase):
'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p2'],
}, schema=cls.pa_schema)
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
def test_parquet_ao_reader(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
self.catalog.create_table('default.test_append_only_parquet', schema, False)
@@ -410,3 +415,104 @@ class AoReaderTest(unittest.TestCase):
table_read = read_builder.new_read()
splits = read_builder.new_scan().plan().splits()
return table_read.to_arrow(splits)
+
+ def test_concurrent_writes_with_retry(self):
+ """Test concurrent writes to verify retry mechanism works correctly."""
+ import threading
+
+ # Run the test multiple times to verify stability
+ iter_num = 5
+ for test_iteration in range(iter_num):
+ # Create a unique table for each iteration
+ table_name = f'default.test_concurrent_writes_{test_iteration}'
+ schema = Schema.from_pyarrow_schema(self.pa_schema)
+ self.catalog.create_table(table_name, schema, False)
+ table = self.catalog.get_table(table_name)
+
+ write_results = []
+ write_errors = []
+
+ def write_data(thread_id, start_user_id):
+ """Write data in a separate thread."""
+ try:
+ threading.current_thread().name = f"Iter{test_iteration}-Thread-{thread_id}"
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ # Create unique data for this thread
+ data = {
+ 'user_id': list(range(start_user_id, start_user_id + 5)),
+ 'item_id': [1000 + i for i in range(start_user_id, start_user_id + 5)],
+ 'behavior': [f'thread{thread_id}_{i}' for i in range(5)],
+ 'dt': ['p1' if i % 2 == 0 else 'p2' for i in range(5)],
+ }
+ pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+
+ table_write.write_arrow(pa_table)
+ commit_messages = table_write.prepare_commit()
+
+ table_commit.commit(commit_messages)
+ table_write.close()
+ table_commit.close()
+
+ write_results.append({
+ 'thread_id': thread_id,
+ 'start_user_id': start_user_id,
+ 'success': True
+ })
+ except Exception as e:
+ write_errors.append({
+ 'thread_id': thread_id,
+ 'error': str(e)
+ })
+
+ # Create and start multiple threads
+ threads = []
+ num_threads = 100
+ for i in range(num_threads):
+ thread = threading.Thread(
+ target=write_data,
+ args=(i, i * 10)
+ )
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ # Verify all writes succeeded (retry mechanism should handle conflicts)
+ self.assertEqual(num_threads, len(write_results),
+ f"Iteration {test_iteration}: Expected {num_threads} successful writes, "
+ f"got {len(write_results)}. Errors: {write_errors}")
+ self.assertEqual(0, len(write_errors),
+ f"Iteration {test_iteration}: Expected no errors, but got: {write_errors}")
+
+ read_builder = table.new_read_builder()
+ actual = self._read_test_table(read_builder).sort_by('user_id')
+
+ # Verify data rows
+ self.assertEqual(num_threads * 5, actual.num_rows,
+ f"Iteration {test_iteration}: Expected {num_threads * 5} rows")
+
+ # Verify user_id
+ user_ids = actual.column('user_id').to_pylist()
+ expected_user_ids = []
+ for i in range(num_threads):
+ expected_user_ids.extend(range(i * 10, i * 10 + 5))
+ expected_user_ids.sort()
+
+ self.assertEqual(user_ids, expected_user_ids,
+ f"Iteration {test_iteration}: User IDs mismatch")
+
+ # Verify snapshot count (should have num_threads snapshots)
+ snapshot_manager = SnapshotManager(table)
+ latest_snapshot = snapshot_manager.get_latest_snapshot()
+ self.assertIsNotNone(latest_snapshot,
+ f"Iteration {test_iteration}: Latest snapshot should not be None")
+ self.assertEqual(latest_snapshot.id, num_threads,
+ f"Iteration {test_iteration}: Expected snapshot ID {num_threads}, "
+ f"got {latest_snapshot.id}")
+
+ print(f"✓ Iteration {test_iteration + 1}/{iter_num} completed successfully")
diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py b/paimon-python/pypaimon/tests/reader_primary_key_test.py
index 7077b2fd44..731203385d 100644
--- a/paimon-python/pypaimon/tests/reader_primary_key_test.py
+++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -422,3 +422,107 @@ class PkReaderTest(unittest.TestCase):
table_read = read_builder.new_read()
splits = read_builder.new_scan().plan().splits()
return table_read.to_arrow(splits)
+
+ def test_concurrent_writes_with_retry(self):
+ """Test concurrent writes to verify retry mechanism works correctly for PK tables."""
+ import threading
+
+ # Run the test 3 times to verify stability
+ iter_num = 3
+ for test_iteration in range(iter_num):
+ # Create a unique table for each iteration
+ table_name = f'default.test_pk_concurrent_writes_{test_iteration}'
+ schema = Schema.from_pyarrow_schema(self.pa_schema,
+ partition_keys=['dt'],
+ primary_keys=['user_id', 'dt'],
+ options={'bucket': '2'})
+ self.catalog.create_table(table_name, schema, False)
+ table = self.catalog.get_table(table_name)
+
+ write_results = []
+ write_errors = []
+
+ def write_data(thread_id, start_user_id):
+ """Write data in a separate thread."""
+ try:
+ threading.current_thread().name = f"Iter{test_iteration}-Thread-{thread_id}"
+ write_builder = table.new_batch_write_builder()
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+
+ # Create unique data for this thread
+ data = {
+ 'user_id': list(range(start_user_id, start_user_id + 5)),
+ 'item_id': [1000 + i for i in range(start_user_id, start_user_id + 5)],
+ 'behavior': [f'thread{thread_id}_{i}' for i in range(5)],
+ 'dt': ['p1' if i % 2 == 0 else 'p2' for i in range(5)],
+ }
+ pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+
+ table_write.write_arrow(pa_table)
+ commit_messages = table_write.prepare_commit()
+
+ table_commit.commit(commit_messages)
+ table_write.close()
+ table_commit.close()
+
+ write_results.append({
+ 'thread_id': thread_id,
+ 'start_user_id': start_user_id,
+ 'success': True
+ })
+ except Exception as e:
+ write_errors.append({
+ 'thread_id': thread_id,
+ 'error': str(e)
+ })
+
+ # Create and start multiple threads
+ threads = []
+ num_threads = 100
+ for i in range(num_threads):
+ thread = threading.Thread(
+ target=write_data,
+ args=(i, i * 10)
+ )
+ threads.append(thread)
+ thread.start()
+
+ # Wait for all threads to complete
+ for thread in threads:
+ thread.join()
+
+ # Verify all writes succeeded (retry mechanism should handle conflicts)
+ self.assertEqual(num_threads, len(write_results),
+ f"Iteration {test_iteration}: Expected {num_threads} successful writes, "
+ f"got {len(write_results)}. Errors: {write_errors}")
+ self.assertEqual(0, len(write_errors),
+ f"Iteration {test_iteration}: Expected no errors, but got: {write_errors}")
+
+ read_builder = table.new_read_builder()
+ actual = self._read_test_table(read_builder).sort_by('user_id')
+
+ # Verify data rows (PK table should have unique user_id+dt combinations)
+ self.assertEqual(num_threads * 5, actual.num_rows,
+ f"Iteration {test_iteration}: Expected {num_threads * 5} rows")
+
+ # Verify user_id
+ user_ids = actual.column('user_id').to_pylist()
+ expected_user_ids = []
+ for i in range(num_threads):
+ expected_user_ids.extend(range(i * 10, i * 10 + 5))
+ expected_user_ids.sort()
+
+ self.assertEqual(user_ids, expected_user_ids,
+ f"Iteration {test_iteration}: User IDs mismatch")
+
+ # Verify snapshot count (should have num_threads snapshots)
+ snapshot_manager = SnapshotManager(table)
+ latest_snapshot = snapshot_manager.get_latest_snapshot()
+ self.assertIsNotNone(latest_snapshot,
+ f"Iteration {test_iteration}: Latest snapshot should not be None")
+ self.assertEqual(latest_snapshot.id, num_threads,
+ f"Iteration {test_iteration}: Expected snapshot ID {num_threads}, "
+ f"got {latest_snapshot.id}")
+
+ print(f"✓ PK Table Iteration {test_iteration + 1}/{iter_num} completed successfully")
diff --git a/paimon-python/pypaimon/tests/schema_evolution_read_test.py b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
index f5dafaae35..a67a927a5e 100644
--- a/paimon-python/pypaimon/tests/schema_evolution_read_test.py
+++ b/paimon-python/pypaimon/tests/schema_evolution_read_test.py
@@ -322,6 +322,7 @@ class SchemaEvolutionReadTest(unittest.TestCase):
# write schema-0 and schema-1 to table2
schema_manager = SchemaManager(table2.file_io, table2.table_path)
+ schema_manager.file_io.delete_quietly(table2.table_path + "/schema/schema-0")
schema_manager.commit(TableSchema.from_schema(schema_id=0, schema=schema))
schema_manager.commit(TableSchema.from_schema(schema_id=1, schema=schema2))
diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py
index a5b9fd9693..e55e25f7c8 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -16,9 +16,11 @@
# limitations under the License.
################################################################################
+import logging
+import random
import time
import uuid
-from typing import List
+from typing import List, Optional
from pypaimon.common.predicate_builder import PredicateBuilder
from pypaimon.manifest.manifest_file_manager import ManifestFileManager
@@ -35,6 +37,33 @@ from pypaimon.table.row.generic_row import GenericRow
from pypaimon.table.row.offset_row import OffsetRow
from pypaimon.write.commit_message import CommitMessage
+logger = logging.getLogger(__name__)
+
+
+class CommitResult:
+ """Base class for commit results."""
+
+ def is_success(self) -> bool:
+ """Returns True if commit was successful."""
+ raise NotImplementedError
+
+
+class SuccessResult(CommitResult):
+ """Result indicating successful commit."""
+
+ def is_success(self) -> bool:
+ return True
+
+
+class RetryResult(CommitResult):
+
+ def __init__(self, latest_snapshot, exception: Optional[Exception] = None):
+ self.latest_snapshot = latest_snapshot
+ self.exception = exception
+
+ def is_success(self) -> bool:
+ return False
+
class FileStoreCommit:
"""
@@ -58,6 +87,11 @@ class FileStoreCommit:
self.manifest_target_size = 8 * 1024 * 1024
self.manifest_merge_min_count = 30
+ self.commit_max_retries = table.options.commit_max_retries()
+ self.commit_timeout = table.options.commit_timeout()
+ self.commit_min_retry_wait = table.options.commit_min_retry_wait()
+ self.commit_max_retry_wait = table.options.commit_max_retry_wait()
+
def commit(self, commit_messages: List[CommitMessage], commit_identifier: int):
"""Commit the given commit messages in normal append mode."""
if not commit_messages:
@@ -99,27 +133,81 @@ class FileStoreCommit:
raise RuntimeError(f"Trying to overwrite partition {overwrite_partition}, but the changes "
f"in {msg.partition} does not belong to this partition")
- commit_entries = []
- current_entries = FullStartingScanner(self.table, partition_filter, None).plan_files()
- for entry in current_entries:
- entry.kind = 1
- commit_entries.append(entry)
- for msg in commit_messages:
- partition = GenericRow(list(msg.partition), self.table.partition_keys_fields)
- for file in msg.new_files:
- commit_entries.append(ManifestEntry(
- kind=0,
- partition=partition,
- bucket=msg.bucket,
- total_buckets=self.table.total_buckets,
- file=file
- ))
+ self._overwrite_partition_filter = partition_filter
+ self._overwrite_commit_messages = commit_messages
- self._try_commit(commit_kind="OVERWRITE",
- commit_entries=commit_entries,
- commit_identifier=commit_identifier)
+ self._try_commit(
+ commit_kind="OVERWRITE",
+ commit_entries=None,  # Will be generated in _try_commit based on latest snapshot
+ commit_identifier=commit_identifier
+ )
def _try_commit(self, commit_kind, commit_entries, commit_identifier):
+ import threading
+
+ retry_count = 0
+ retry_result = None
+ start_time_ms = int(time.time() * 1000)
+ thread_id = threading.current_thread().name
+ while True:
+ latest_snapshot = self.snapshot_manager.get_latest_snapshot()
+
+ if commit_kind == "OVERWRITE":
+ commit_entries = self._generate_overwrite_entries()
+
+ result = self._try_commit_once(
+ retry_result=retry_result,
+ commit_kind=commit_kind,
+ commit_entries=commit_entries,
+ commit_identifier=commit_identifier,
+ latest_snapshot=latest_snapshot
+ )
+
+ if result.is_success():
+ logger.warning(
+ f"Thread {thread_id}: commit success {latest_snapshot.id + 1 if latest_snapshot else 1} "
+ f"after {retry_count} retries"
+ )
+ break
+
+ retry_result = result
+
+ elapsed_ms = int(time.time() * 1000) - start_time_ms
+ if elapsed_ms > self.commit_timeout or retry_count >= self.commit_max_retries:
+ error_msg = (
+ f"Commit failed {latest_snapshot.id + 1 if latest_snapshot else 1} "
+ f"after {elapsed_ms} millis with {retry_count} retries, "
+ f"there may be commit conflicts between multiple jobs."
+ )
+ if retry_result.exception:
+ raise RuntimeError(error_msg) from retry_result.exception
+ else:
+ raise RuntimeError(error_msg)
+
+ self._commit_retry_wait(retry_count)
+ retry_count += 1
+
+ def _try_commit_once(self, retry_result: Optional[RetryResult], commit_kind: str,
+ commit_entries: List[ManifestEntry], commit_identifier: int,
+ latest_snapshot: Optional[Snapshot]) -> CommitResult:
+ start_time_ms = int(time.time() * 1000)
+
+ if retry_result is not None and latest_snapshot is not None:
+ start_check_snapshot_id = 1 # Snapshot.FIRST_SNAPSHOT_ID
+ if retry_result.latest_snapshot is not None:
+ start_check_snapshot_id = retry_result.latest_snapshot.id + 1
+
+ for snapshot_id in range(start_check_snapshot_id, latest_snapshot.id + 2):
+ snapshot = self.snapshot_manager.get_snapshot_by_id(snapshot_id)
+ if (snapshot and snapshot.commit_user == self.commit_user and
+ snapshot.commit_identifier == commit_identifier and
+ snapshot.commit_kind == commit_kind):
+ logger.info(
+ f"Commit already completed (snapshot {snapshot_id}), "
+ f"user: {self.commit_user}, identifier: {commit_identifier}"
+ )
+ return SuccessResult()
+
unique_id = uuid.uuid4()
base_manifest_list = f"manifest-list-{unique_id}-0"
delta_manifest_list = f"manifest-list-{unique_id}-1"
@@ -130,7 +218,6 @@ class FileStoreCommit:
deleted_file_count = 0
delta_record_count = 0
# process snapshot
- latest_snapshot = self.snapshot_manager.get_latest_snapshot()
new_snapshot_id = latest_snapshot.id + 1 if latest_snapshot else 1
# Check if row tracking is enabled
@@ -143,7 +230,7 @@ class FileStoreCommit:
commit_entries = self._assign_snapshot_id(new_snapshot_id, commit_entries)
# Get the next row ID start from the latest snapshot
- first_row_id_start = self._get_next_row_id_start()
+ first_row_id_start = self._get_next_row_id_start(latest_snapshot)
# Assign row IDs to new files and get the next row ID for the snapshot
commit_entries, next_row_id = self._assign_row_tracking_meta(first_row_id_start, commit_entries)
@@ -155,71 +242,164 @@ class FileStoreCommit:
else:
deleted_file_count += 1
delta_record_count -= entry.file.row_count
- self.manifest_file_manager.write(new_manifest_file, commit_entries)
- # TODO: implement noConflictsOrFail logic
- partition_columns = list(zip(*(entry.partition.values for entry in commit_entries)))
- partition_min_stats = [min(col) for col in partition_columns]
- partition_max_stats = [max(col) for col in partition_columns]
- partition_null_counts = [sum(value == 0 for value in col) for col in partition_columns]
- if not all(count == 0 for count in partition_null_counts):
- raise RuntimeError("Partition value should not be null")
- manifest_file_path = f"{self.manifest_file_manager.manifest_path}/{new_manifest_file}"
- new_manifest_list = ManifestFileMeta(
- file_name=new_manifest_file,
- file_size=self.table.file_io.get_file_size(manifest_file_path),
- num_added_files=added_file_count,
- num_deleted_files=deleted_file_count,
- partition_stats=SimpleStats(
- min_values=GenericRow(
- values=partition_min_stats,
- fields=self.table.partition_keys_fields
- ),
- max_values=GenericRow(
- values=partition_max_stats,
- fields=self.table.partition_keys_fields
+
+ try:
+ self.manifest_file_manager.write(new_manifest_file, commit_entries)
+
+ # TODO: implement noConflictsOrFail logic
+ partition_columns = list(zip(*(entry.partition.values for entry in commit_entries)))
+ partition_min_stats = [min(col) for col in partition_columns]
+ partition_max_stats = [max(col) for col in partition_columns]
+ partition_null_counts = [sum(value == 0 for value in col) for col in partition_columns]
+ if not all(count == 0 for count in partition_null_counts):
+ raise RuntimeError("Partition value should not be null")
+
+ manifest_file_path = f"{self.manifest_file_manager.manifest_path}/{new_manifest_file}"
+ file_size = self.table.file_io.get_file_size(manifest_file_path)
+
+ new_manifest_file_meta = ManifestFileMeta(
+ file_name=new_manifest_file,
+ file_size=file_size,
+ num_added_files=added_file_count,
+ num_deleted_files=deleted_file_count,
+ partition_stats=SimpleStats(
+ min_values=GenericRow(
+ values=partition_min_stats,
+ fields=self.table.partition_keys_fields
+ ),
+ max_values=GenericRow(
+ values=partition_max_stats,
+ fields=self.table.partition_keys_fields
+ ),
+ null_counts=partition_null_counts,
),
- null_counts=partition_null_counts,
- ),
- schema_id=self.table.table_schema.id,
+ schema_id=self.table.table_schema.id,
+ )
+
+ self.manifest_list_manager.write(delta_manifest_list, [new_manifest_file_meta])
+
+ # process existing_manifest
+ total_record_count = 0
+ if latest_snapshot:
+ existing_manifest_files = self.manifest_list_manager.read_all(latest_snapshot)
+ previous_record_count = latest_snapshot.total_record_count
+ if previous_record_count:
+ total_record_count += previous_record_count
+ else:
+ existing_manifest_files = []
+
+ self.manifest_list_manager.write(base_manifest_list, existing_manifest_files)
+ total_record_count += delta_record_count
+ snapshot_data = Snapshot(
+ version=3,
+ id=new_snapshot_id,
+ schema_id=self.table.table_schema.id,
+ base_manifest_list=base_manifest_list,
+ delta_manifest_list=delta_manifest_list,
+ total_record_count=total_record_count,
+ delta_record_count=delta_record_count,
+ commit_user=self.commit_user,
+ commit_identifier=commit_identifier,
+ commit_kind=commit_kind,
+ time_millis=int(time.time() * 1000),
+ next_row_id=next_row_id,
+ )
+ # Generate partition statistics for the commit
+ statistics = self._generate_partition_statistics(commit_entries)
+ except Exception as e:
+ self._cleanup_preparation_failure(new_manifest_file, delta_manifest_list,
+ base_manifest_list)
+ logger.warning(f"Exception occurred while preparing snapshot: {e}", exc_info=True)
+ raise RuntimeError(f"Failed to prepare snapshot: {e}")
+
+ # Use SnapshotCommit for atomic commit
+ try:
+ with self.snapshot_commit:
+ success = self.snapshot_commit.commit(snapshot_data, self.table.current_branch(), statistics)
+ if not success:
+ # Commit failed, clean up temporary files and retry
+ commit_time_sec = (int(time.time() * 1000) - start_time_ms) / 1000
+ logger.warning(
+ f"Atomic commit failed for snapshot #{new_snapshot_id} "
+ f"by user {self.commit_user} "
+ f"with identifier {commit_identifier} and kind {commit_kind} after {commit_time_sec}s. "
+ f"Clean up and try again."
+ )
+ self._cleanup_preparation_failure(new_manifest_file, delta_manifest_list,
+ base_manifest_list)
+ return RetryResult(latest_snapshot, None)
+ except Exception as e:
+ # The commit threw an exception; the outcome is unknown, so do not clean up the files
+ logger.warning("Retry commit for exception")
+ return RetryResult(latest_snapshot, e)
+
+ logger.warning(
+ f"Successfully committed snapshot {new_snapshot_id} to table {self.table.identifier} "
+ f"for snapshot-{new_snapshot_id} by user {self.commit_user} "
+ + f"with identifier {commit_identifier} and kind {commit_kind}."
)
- self.manifest_list_manager.write(delta_manifest_list, [new_manifest_list])
-
- # process existing_manifest
- total_record_count = 0
- if latest_snapshot:
- existing_manifest_files = self.manifest_list_manager.read_all(latest_snapshot)
- previous_record_count = latest_snapshot.total_record_count
- if previous_record_count:
- total_record_count += previous_record_count
- else:
- existing_manifest_files = []
- self.manifest_list_manager.write(base_manifest_list, existing_manifest_files)
+ return SuccessResult()
- # process snapshot
- total_record_count += delta_record_count
- snapshot_data = Snapshot(
- version=3,
- id=new_snapshot_id,
- schema_id=self.table.table_schema.id,
- base_manifest_list=base_manifest_list,
- delta_manifest_list=delta_manifest_list,
- total_record_count=total_record_count,
- delta_record_count=delta_record_count,
- commit_user=self.commit_user,
- commit_identifier=commit_identifier,
- commit_kind=commit_kind,
- time_millis=int(time.time() * 1000),
- next_row_id=next_row_id,
+ def _generate_overwrite_entries(self):
+ """Generate commit entries for OVERWRITE mode based on latest snapshot."""
+ entries = []
+ current_entries = FullStartingScanner(self.table, self._overwrite_partition_filter, None).plan_files()
+ for entry in current_entries:
+ entry.kind = 1 # DELETE
+ entries.append(entry)
+ for msg in self._overwrite_commit_messages:
+ partition = GenericRow(list(msg.partition), self.table.partition_keys_fields)
+ for file in msg.new_files:
+ entries.append(ManifestEntry(
+ kind=0, # ADD
+ partition=partition,
+ bucket=msg.bucket,
+ total_buckets=self.table.total_buckets,
+ file=file
+ ))
+ return entries
+
+ def _commit_retry_wait(self, retry_count: int):
+ import threading
+ thread_id = threading.get_ident()
+
+ retry_wait_ms = min(
+ self.commit_min_retry_wait * (2 ** retry_count),
+ self.commit_max_retry_wait
)
- # Generate partition statistics for the commit
- statistics = self._generate_partition_statistics(commit_entries)
+ jitter_ms = random.randint(0, max(1, int(retry_wait_ms * 0.2)))
+ total_wait_ms = retry_wait_ms + jitter_ms
- # Use SnapshotCommit for atomic commit
- with self.snapshot_commit:
- success = self.snapshot_commit.commit(snapshot_data, self.table.current_branch(), statistics)
- if not success:
- raise RuntimeError(f"Failed to commit snapshot {new_snapshot_id}")
+ logger.debug(
+ f"Thread {thread_id}: Waiting {total_wait_ms}ms before retry (base: {retry_wait_ms}ms, "
+ f"jitter: {jitter_ms}ms)"
+ )
+ time.sleep(total_wait_ms / 1000.0)
+
+ def _cleanup_preparation_failure(self, manifest_file: Optional[str],
+ delta_manifest_list: Optional[str],
+ base_manifest_list: Optional[str]):
+ try:
+ manifest_path = self.manifest_list_manager.manifest_path
+
+ if delta_manifest_list:
+ manifest_files = self.manifest_list_manager.read(delta_manifest_list)
+ for manifest_meta in manifest_files:
+ manifest_file_path = f"{self.manifest_file_manager.manifest_path}/{manifest_meta.file_name}"
+ self.table.file_io.delete_quietly(manifest_file_path)
+ delta_path = f"{manifest_path}/{delta_manifest_list}"
+ self.table.file_io.delete_quietly(delta_path)
+
+ if base_manifest_list:
+ base_path = f"{manifest_path}/{base_manifest_list}"
+ self.table.file_io.delete_quietly(base_path)
+
+ if manifest_file:
+ manifest_file_path = f"{self.manifest_file_manager.manifest_path}/{manifest_file}"
+ self.table.file_io.delete_quietly(manifest_file_path)
+ except Exception as e:
+ logger.warning(f"Failed to clean up temporary files during preparation failure: {e}", exc_info=True)
def abort(self, commit_messages: List[CommitMessage]):
"""Abort commit and delete files. Uses external_path if available to ensure proper scheme handling."""
@@ -332,9 +512,8 @@ class FileStoreCommit:
"""Assign snapshot ID to all commit entries."""
return [entry.assign_sequence_number(snapshot_id, snapshot_id) for entry in commit_entries]
- def _get_next_row_id_start(self) -> int:
+ def _get_next_row_id_start(self, latest_snapshot) -> int:
"""Get the next row ID start from the latest snapshot."""
- latest_snapshot = self.snapshot_manager.get_latest_snapshot()
if latest_snapshot and hasattr(latest_snapshot, 'next_row_id') and latest_snapshot.next_row_id is not None:
return latest_snapshot.next_row_id
return 0