This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 38051b40c9 [python][ray] Honor partition overwrite in write_ray (#8088)
38051b40c9 is described below
commit 38051b40c97963a0b5950bbb72c84c11eadeecb1
Author: QuakeWang <[email protected]>
AuthorDate: Wed Jun 3 08:58:10 2026 +0800
[python][ray] Honor partition overwrite in write_ray (#8088)
`TableWrite.write_ray()` previously did not carry the builder-level
overwrite partition into the Ray datasink. As a result,
`table.new_batch_write_builder().overwrite({...}).new_write().write_ray(...)`
wrote through Ray without applying the configured partition overwrite,
and the only overwrite option on `write_ray()` itself, `overwrite=True`,
supported full-table overwrite only.
This PR carries the builder's static partition into `TableWrite`, adds a
`static_partition` argument to `write_ray()` that overrides it, forwards
the resulting partition to `PaimonDatasink`, and applies the same overwrite
partition on both the Ray write tasks and the driver-side commit path.
---
docs/docs/pypaimon/ray-data.md | 46 +++++---
.../pypaimon/tests/ray_integration_test.py | 49 +++++++++
paimon-python/pypaimon/tests/ray_sink_test.py | 122 +++++++++++++++++++++
paimon-python/pypaimon/write/ray_datasink.py | 19 +++-
paimon-python/pypaimon/write/table_write.py | 22 +++-
paimon-python/pypaimon/write/write_builder.py | 4 +-
6 files changed, 233 insertions(+), 29 deletions(-)
diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index cdbc7939ac..658a1098ae 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -279,11 +279,8 @@ import ray
table = catalog.get_table('database_name.table_name')
-# 1. Create table write and commit (commit is only needed for non-Ray writes
-# on the same table_write instance — see below).
-write_builder = table.new_batch_write_builder()
-table_write = write_builder.new_write()
-table_commit = write_builder.new_commit()
+# 1. Create table write.
+table_write = table.new_batch_write_builder().new_write()
# 2. Write Ray Dataset
ray_dataset = ray.data.read_json("/path/to/data.jsonl")
@@ -292,6 +289,7 @@ table_write.write_ray(
overwrite=False,
concurrency=2,
hash_fixed_precluster="auto",
+ static_partition=None,
)
# Parameters:
# - dataset: Ray Dataset to write
@@ -300,28 +298,42 @@ table_write.write_ray(
# - ray_remote_args: Optional kwargs passed to ray.remote() (e.g.,
{"num_cpus": 2})
# - hash_fixed_precluster: Same HASH_FIXED modes and primary-key safety
# checks as write_paimon()
+# - static_partition: Optional partition spec to overwrite. When set,
+# write_ray() runs in overwrite mode for this partition.
-# 3. Commit data (required for write_pandas/write_arrow/write_arrow_batch only)
-commit_messages = table_write.prepare_commit()
-table_commit.commit(commit_messages)
-
-# 4. Close resources
+# 3. Close resources
table_write.close()
-table_commit.close()
```
-### Overwrite at builder level
+### Overwrite
+
+The top-level `write_paimon()` API supports whole-table overwrite with the
+`overwrite=True` flag above. With the lower-level `write_ray()` API, you can
+use `overwrite=True` for whole-table overwrite and `static_partition={...}` for
+partition overwrite:
+
+```python
+table_write.write_ray(ray_dataset, overwrite=True)
+table_write.write_ray(ray_dataset, static_partition={'dt': '2024-01-01'})
+```
-The recommended way to overwrite via `write_paimon` is the `overwrite=True`
-flag above. When using the lower-level builder API, you can also configure
-overwrite mode on the write builder itself:
+When using the lower-level builder API, you can also configure overwrite mode
+on the write builder itself. The resulting `table_write` carries the overwrite
+partition into `write_ray()`. A `static_partition` argument passed directly to
+`write_ray()` overrides the builder-level partition:
```python
# overwrite whole table
-write_builder = table.new_batch_write_builder().overwrite()
+table_write = table.new_batch_write_builder().overwrite().new_write()
+table_write.write_ray(ray_dataset)
# overwrite partition 'dt=2024-01-01'
-write_builder = table.new_batch_write_builder().overwrite({'dt': '2024-01-01'})
+table_write = (
+ table.new_batch_write_builder()
+ .overwrite({'dt': '2024-01-01'})
+ .new_write()
+)
+table_write.write_ray(ray_dataset)
```
## Merge Into
diff --git a/paimon-python/pypaimon/tests/ray_integration_test.py b/paimon-python/pypaimon/tests/ray_integration_test.py
index 2ad95f610b..275d810dd3 100644
--- a/paimon-python/pypaimon/tests/ray_integration_test.py
+++ b/paimon-python/pypaimon/tests/ray_integration_test.py
@@ -333,6 +333,55 @@ class RayIntegrationTest(unittest.TestCase):
result = read_paimon(identifier, self.catalog_options)
self.assertEqual(result.count(), 0)
+ def test_table_write_ray_builder_partition_overwrite(self):
+ """Builder-level partition overwrite is honored by write_ray()."""
+ from pypaimon.ray import read_paimon
+
+ pa_schema = pa.schema([
+ ('id', pa.int32()),
+ ('val', pa.string()),
+ ('dt', pa.string()),
+ ])
+ identifier = 'default.test_write_ray_partition_overwrite'
+ catalog = CatalogFactory.create(self.catalog_options)
+ schema = Schema.from_pyarrow_schema(
+ pa_schema,
+ partition_keys=['dt'],
+ options={'dynamic-partition-overwrite': 'false'},
+ )
+ catalog.create_table(identifier, schema, False)
+ table = catalog.get_table(identifier)
+
+ initial = pa.Table.from_pydict(
+ {
+ 'id': [1, 2, 3],
+ 'val': ['old-p1-a', 'old-p1-b', 'old-p2'],
+ 'dt': ['p1', 'p1', 'p2'],
+ },
+ schema=pa_schema,
+ )
+ write_builder = table.new_batch_write_builder()
+ writer = write_builder.new_write()
+ writer.write_arrow(initial)
+ write_builder.new_commit().commit(writer.prepare_commit())
+ writer.close()
+
+ replacement = ray.data.from_arrow(
+ pa.Table.from_pydict(
+ {'id': [4], 'val': ['new-p1'], 'dt': ['p1']},
+ schema=pa_schema,
+ )
+ )
+ writer = table.new_batch_write_builder().overwrite({'dt':
'p1'}).new_write()
+ writer.write_ray(replacement, concurrency=1)
+ writer.close()
+
+ result = read_paimon(identifier, self.catalog_options)
+ df = result.to_pandas().sort_values('id').reset_index(drop=True)
+ self.assertEqual(list(df['id']), [3, 4])
+ self.assertEqual(list(df['val']), ['old-p2', 'new-p1'])
+ self.assertEqual(list(df['dt']), ['p2', 'p1'])
+
def test_read_paimon_primary_key(self):
"""read_paimon() merges PK rows correctly after an upsert."""
from pypaimon.ray import read_paimon
diff --git a/paimon-python/pypaimon/tests/ray_sink_test.py b/paimon-python/pypaimon/tests/ray_sink_test.py
index ca51b05d1b..a6d761df5a 100644
--- a/paimon-python/pypaimon/tests/ray_sink_test.py
+++ b/paimon-python/pypaimon/tests/ray_sink_test.py
@@ -26,6 +26,7 @@ from ray.data._internal.execution.interfaces import TaskContext
from pypaimon import CatalogFactory, Schema
from pypaimon.write.ray_datasink import PaimonDatasink
from pypaimon.write.commit_message import CommitMessage
+from pypaimon.write.table_write import TableWrite
class RaySinkTest(unittest.TestCase):
@@ -69,23 +70,34 @@ class RaySinkTest(unittest.TestCase):
datasink = PaimonDatasink(self.table, overwrite=False)
self.assertEqual(datasink.table, self.table)
self.assertFalse(datasink.overwrite)
+ self.assertIsNone(datasink.static_partition)
self.assertIsNone(datasink._writer_builder)
self.assertEqual(datasink._table_name, "test_db.test_table")
datasink_overwrite = PaimonDatasink(self.table, overwrite=True)
self.assertTrue(datasink_overwrite.overwrite)
+ datasink_partition_overwrite = PaimonDatasink(
+ self.table, static_partition={'dt': '2024-01-01'})
+ self.assertFalse(datasink_partition_overwrite.overwrite)
+ self.assertEqual(
+ datasink_partition_overwrite.static_partition,
+ {'dt': '2024-01-01'},
+ )
+
# Test serialization
datasink._writer_builder = Mock()
state = datasink.__getstate__()
self.assertIn('table', state)
self.assertIn('overwrite', state)
+ self.assertIn('static_partition', state)
self.assertIn('_writer_builder', state)
new_datasink = PaimonDatasink.__new__(PaimonDatasink)
new_datasink.__setstate__(state)
self.assertEqual(new_datasink.table, self.table)
self.assertFalse(new_datasink.overwrite)
+ self.assertIsNone(new_datasink.static_partition)
def test_table_and_writer_builder_serializable(self):
import pickle
@@ -120,6 +132,29 @@ class RaySinkTest(unittest.TestCase):
except Exception as e:
self.fail(f"Overwrite WriterBuilder is not serializable: {e}")
+ def test_write_builder_new_write_carries_static_partition(self):
+ batch_write = (
+ self.table
+ .new_batch_write_builder()
+ .overwrite({'dt': '2024-01-01'})
+ .new_write()
+ )
+ try:
+ self.assertEqual(batch_write.static_partition, {'dt':
'2024-01-01'})
+ finally:
+ batch_write.close()
+
+ stream_write = (
+ self.table
+ .new_stream_write_builder()
+ .overwrite({'dt': '2024-01-01'})
+ .new_write()
+ )
+ try:
+ self.assertEqual(stream_write.static_partition, {'dt':
'2024-01-01'})
+ finally:
+ stream_write.close()
+
def test_on_write_start(self):
"""Test on_write_start with normal and overwrite modes."""
datasink = PaimonDatasink(self.table, overwrite=False)
@@ -131,6 +166,14 @@ class RaySinkTest(unittest.TestCase):
datasink_overwrite.on_write_start()
self.assertIsNotNone(datasink_overwrite._writer_builder.static_partition)
+ datasink_partition_overwrite = PaimonDatasink(
+ self.table, static_partition={'dt': '2024-01-01'})
+ datasink_partition_overwrite.on_write_start()
+ self.assertEqual(
+ datasink_partition_overwrite._writer_builder.static_partition,
+ {'dt': '2024-01-01'},
+ )
+
def test_write(self):
"""Test write method: empty blocks, multiple blocks, error handling,
and resource cleanup."""
datasink = PaimonDatasink(self.table, overwrite=False)
@@ -189,6 +232,25 @@ class RaySinkTest(unittest.TestCase):
datasink.write([data_table], ctx)
mock_builder.assert_called_once()
+ partition_datasink = PaimonDatasink(
+ self.table, static_partition={'dt': '2024-01-01'})
+ with patch.object(self.table, 'new_batch_write_builder') as
mock_builder:
+ mock_write_builder = Mock()
+ mock_write_builder.overwrite.return_value = mock_write_builder
+ mock_write = Mock()
+ mock_write.prepare_commit.return_value = []
+ mock_write_builder.new_write.return_value = mock_write
+ mock_builder.return_value = mock_write_builder
+
+ data_table = pa.table({
+ 'id': [1],
+ 'name': ['Alice'],
+ 'value': [1.1]
+ })
+ partition_datasink.write([data_table], ctx)
+ mock_write_builder.overwrite.assert_called_once_with(
+ {'dt': '2024-01-01'})
+
invalid_table = pa.table({
'wrong_column': [1, 2, 3]
})
@@ -241,6 +303,20 @@ class RaySinkTest(unittest.TestCase):
mock_commit.commit.assert_called_once_with([])
mock_commit.close.assert_called_once()
+ datasink = PaimonDatasink(self.table, static_partition={'dt':
'2024-01-01'})
+ datasink.on_write_start()
+ write_result = WriteResult(
+ num_rows=0,
+ size_bytes=0,
+ write_returns=[[], []]
+ )
+ mock_commit = Mock()
+ datasink._writer_builder.new_commit = Mock(return_value=mock_commit)
+ datasink.on_write_complete(write_result)
+
+ mock_commit.commit.assert_called_once_with([])
+ mock_commit.close.assert_called_once()
+
# Test with messages and filtering empty messages
datasink = PaimonDatasink(self.table, overwrite=False)
datasink.on_write_start()
@@ -308,6 +384,52 @@ class RaySinkTest(unittest.TestCase):
datasink.on_write_complete(write_result)
self.assertEqual(len(datasink._pending_commit_messages), 1)
+ def test_table_write_ray_forwards_static_partition(self):
+ dataset = Mock()
+ table_write = TableWrite.__new__(TableWrite)
+ table_write.table = self.table
+ table_write.static_partition = {'dt': '2024-01-01'}
+
+ with patch('pypaimon.ray.shuffle.maybe_apply_repartition') as
mock_repartition, \
+ patch('pypaimon.write.ray_datasink.PaimonDatasink') as
mock_datasink_cls:
+ mock_repartition.return_value = dataset
+ datasink = mock_datasink_cls.return_value
+
+ table_write.write_ray(dataset, concurrency=2)
+
+ mock_repartition.assert_called_once_with(dataset, self.table,
'auto')
+ mock_datasink_cls.assert_called_once_with(
+ self.table,
+ overwrite=False,
+ static_partition={'dt': '2024-01-01'},
+ )
+ dataset.write_datasink.assert_called_once_with(
+ datasink,
+ concurrency=2,
+ ray_remote_args=None,
+ )
+
+ def test_table_write_ray_static_partition_argument_overrides_builder(self):
+ dataset = Mock()
+ table_write = TableWrite.__new__(TableWrite)
+ table_write.table = self.table
+ table_write.static_partition = {'dt': '2024-01-01'}
+
+ with patch('pypaimon.ray.shuffle.maybe_apply_repartition') as
mock_repartition, \
+ patch('pypaimon.write.ray_datasink.PaimonDatasink') as
mock_datasink_cls:
+ mock_repartition.return_value = dataset
+
+ table_write.write_ray(
+ dataset,
+ static_partition={'dt': '2024-01-02'},
+ )
+
+ mock_datasink_cls.assert_called_once_with(
+ self.table,
+ overwrite=False,
+ static_partition={'dt': '2024-01-02'},
+ )
+
def test_on_write_failed(self):
# Test without pending messages (on_write_complete() never called)
datasink = PaimonDatasink(self.table, overwrite=False)
diff --git a/paimon-python/pypaimon/write/ray_datasink.py b/paimon-python/pypaimon/write/ray_datasink.py
index 60edbc855e..6d48906f9f 100644
--- a/paimon-python/pypaimon/write/ray_datasink.py
+++ b/paimon-python/pypaimon/write/ray_datasink.py
@@ -20,7 +20,7 @@ Module to write a Paimon table from a Ray Dataset, by using the Ray Datasink API
"""
import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
from ray.data.datasource.datasink import Datasink
@@ -72,13 +72,18 @@ class PaimonDatasink(_DatasinkBase):
self,
table: "Table",
overwrite: bool = False,
+ static_partition: Optional[Dict[str, Any]] = None,
):
self.table = table
self.overwrite = overwrite
+ self.static_partition = static_partition
self._table_name = table.identifier.get_full_name()
self._writer_builder: Optional["WriteBuilder"] = None
self._pending_commit_messages: List["CommitMessage"] = []
+ def _is_overwrite(self) -> bool:
+ return self.overwrite or self.static_partition is not None
+
def __getstate__(self) -> dict:
state = self.__dict__.copy()
return state
@@ -90,13 +95,15 @@ class PaimonDatasink(_DatasinkBase):
self._writer_builder = None
if not hasattr(self, '_table_name'):
self._table_name = self.table.identifier.get_full_name()
+ if not hasattr(self, 'static_partition'):
+ self.static_partition = None
def on_write_start(self, schema=None) -> None:
logger.info(f"Starting write job for table {self._table_name}")
self._writer_builder = self.table.new_batch_write_builder()
- if self.overwrite:
- self._writer_builder = self._writer_builder.overwrite()
+ if self._is_overwrite():
+ self._writer_builder =
self._writer_builder.overwrite(self.static_partition)
def write(
self,
@@ -108,8 +115,8 @@ class PaimonDatasink(_DatasinkBase):
try:
writer_builder = self.table.new_batch_write_builder()
- if self.overwrite:
- writer_builder = writer_builder.overwrite()
+ if self._is_overwrite():
+ writer_builder =
writer_builder.overwrite(self.static_partition)
table_write = writer_builder.new_write()
@@ -167,7 +174,7 @@ class PaimonDatasink(_DatasinkBase):
self._pending_commit_messages = non_empty_messages
- if not non_empty_messages and not self.overwrite:
+ if not non_empty_messages and not self._is_overwrite():
logger.info("No data to commit (all commit messages are
empty)")
self._pending_commit_messages = []
return
diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py
index 1eb63793d0..1eeeb5e846 100644
--- a/paimon-python/pypaimon/write/table_write.py
+++ b/paimon-python/pypaimon/write/table_write.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
class TableWrite:
- def __init__(self, table, commit_user):
+ def __init__(self, table, commit_user, static_partition: Optional[dict] =
None):
from pypaimon.table.file_store_table import FileStoreTable
self.table: FileStoreTable = table
@@ -38,6 +38,7 @@ class TableWrite:
self.file_store_write = FileStoreWrite(self.table, commit_user)
self.row_key_extractor = self.table.create_row_key_extractor()
self.commit_user = commit_user
+ self.static_partition = static_partition
def write_arrow(self, table: pa.Table):
batches_iterator = table.to_batches()
@@ -78,6 +79,7 @@ class TableWrite:
concurrency: Optional[int] = None,
ray_remote_args: Optional[Dict[str, Any]] = None,
hash_fixed_precluster: str = "auto",
+ static_partition: Optional[dict] = None,
) -> None:
"""
Write a Ray Dataset to Paimon table.
@@ -86,6 +88,7 @@ class TableWrite:
dataset: Ray Dataset to write. This is a distributed data
collection
from Ray Data (ray.data.Dataset).
overwrite: Whether to overwrite existing data. Defaults to False.
+ Builder-level or static_partition overwrite mode takes
precedence.
concurrency: Optional max number of Ray tasks to run concurrently.
By default, dynamically decided based on available resources.
ray_remote_args: Optional kwargs passed to :func:`ray.remote` in
write tasks.
@@ -95,6 +98,9 @@ class TableWrite:
and reject HASH_FIXED primary-key tables. ``"map_groups"``
preserves the legacy small-file optimization and its single
group memory bound for HASH_FIXED primary-key tables.
+ static_partition: Optional partition spec to overwrite. When set,
+ the Ray write runs in overwrite mode for this partition and
+ overrides any builder-level partition spec.
"""
from pypaimon.ray.shuffle import maybe_apply_repartition
from pypaimon.write.ray_datasink import PaimonDatasink
@@ -102,7 +108,15 @@ class TableWrite:
dataset = maybe_apply_repartition(
dataset, self.table, hash_fixed_precluster)
- datasink = PaimonDatasink(self.table, overwrite=overwrite)
+ overwrite_partition = self.static_partition
+ if static_partition is not None:
+ overwrite_partition = static_partition
+
+ datasink = PaimonDatasink(
+ self.table,
+ overwrite=overwrite,
+ static_partition=overwrite_partition,
+ )
dataset.write_datasink(
datasink,
concurrency=concurrency,
@@ -141,8 +155,8 @@ class TableWrite:
class BatchTableWrite(TableWrite):
- def __init__(self, table, commit_user):
- super().__init__(table, commit_user)
+ def __init__(self, table, commit_user, static_partition: Optional[dict] =
None):
+ super().__init__(table, commit_user, static_partition)
self.batch_committed = False
def prepare_commit(self) -> List[CommitMessage]:
diff --git a/paimon-python/pypaimon/write/write_builder.py b/paimon-python/pypaimon/write/write_builder.py
index 724e5d7a3f..f7a0459305 100644
--- a/paimon-python/pypaimon/write/write_builder.py
+++ b/paimon-python/pypaimon/write/write_builder.py
@@ -59,7 +59,7 @@ class WriteBuilder(ABC):
class BatchWriteBuilder(WriteBuilder):
def new_write(self) -> BatchTableWrite:
- return BatchTableWrite(self.table, self.commit_user)
+ return BatchTableWrite(self.table, self.commit_user,
self.static_partition)
def new_update(self) -> BatchTableUpdate:
return BatchTableUpdate(self.table, self.commit_user)
@@ -72,7 +72,7 @@ class BatchWriteBuilder(WriteBuilder):
class StreamWriteBuilder(WriteBuilder):
def new_write(self) -> StreamTableWrite:
- return StreamTableWrite(self.table, self.commit_user)
+ return StreamTableWrite(self.table, self.commit_user,
self.static_partition)
def new_update(self) -> StreamTableUpdate:
return StreamTableUpdate(self.table, self.commit_user)