This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 38051b40c9 [python][ray] Honor partition overwrite in write_ray (#8088)
38051b40c9 is described below

commit 38051b40c97963a0b5950bbb72c84c11eadeecb1
Author: QuakeWang <[email protected]>
AuthorDate: Wed Jun 3 08:58:10 2026 +0800

    [python][ray] Honor partition overwrite in write_ray (#8088)
    
    `TableWrite.write_ray()` previously did not carry builder-level
    overwrite partitions into the Ray datasink. As a result,
    `table.new_batch_write_builder().overwrite({...}).new_write().write_ray(...)`
    wrote through Ray without the configured partition overwrite contract,
    while `overwrite=True` only supported full-table overwrite.
    
    This PR carries the builder static partition into `TableWrite`, forwards
    it to `PaimonDatasink`, and applies the same overwrite partition on both
    Ray write tasks and the driver-side commit path.
---
 docs/docs/pypaimon/ray-data.md                     |  46 +++++---
 .../pypaimon/tests/ray_integration_test.py         |  49 +++++++++
 paimon-python/pypaimon/tests/ray_sink_test.py      | 122 +++++++++++++++++++++
 paimon-python/pypaimon/write/ray_datasink.py       |  19 +++-
 paimon-python/pypaimon/write/table_write.py        |  22 +++-
 paimon-python/pypaimon/write/write_builder.py      |   4 +-
 6 files changed, 233 insertions(+), 29 deletions(-)

diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index cdbc7939ac..658a1098ae 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -279,11 +279,8 @@ import ray
 
 table = catalog.get_table('database_name.table_name')
 
-# 1. Create table write and commit (commit is only needed for non-Ray writes
-#    on the same table_write instance — see below).
-write_builder = table.new_batch_write_builder()
-table_write = write_builder.new_write()
-table_commit = write_builder.new_commit()
+# 1. Create table write.
+table_write = table.new_batch_write_builder().new_write()
 
 # 2. Write Ray Dataset
 ray_dataset = ray.data.read_json("/path/to/data.jsonl")
@@ -292,6 +289,7 @@ table_write.write_ray(
     overwrite=False,
     concurrency=2,
     hash_fixed_precluster="auto",
+    static_partition=None,
 )
 # Parameters:
 #   - dataset: Ray Dataset to write
@@ -300,28 +298,42 @@ table_write.write_ray(
 #   - ray_remote_args: Optional kwargs passed to ray.remote() (e.g., {"num_cpus": 2})
 #   - hash_fixed_precluster: Same HASH_FIXED modes and primary-key safety
 #     checks as write_paimon()
+#   - static_partition: Optional partition spec to overwrite. When set,
+#     write_ray() runs in overwrite mode for this partition.
 
-# 3. Commit data (required for write_pandas/write_arrow/write_arrow_batch only)
-commit_messages = table_write.prepare_commit()
-table_commit.commit(commit_messages)
-
-# 4. Close resources
+# 3. Close resources
 table_write.close()
-table_commit.close()
 ```
 
-### Overwrite at builder level
+### Overwrite
+
+The top-level `write_paimon()` API supports whole-table overwrite with the
+`overwrite=True` flag above. With the lower-level `write_ray()` API, you can
+use `overwrite=True` for whole-table overwrite and `static_partition={...}` for
+partition overwrite:
+
+```python
+table_write.write_ray(ray_dataset, overwrite=True)
+table_write.write_ray(ray_dataset, static_partition={'dt': '2024-01-01'})
+```
 
-The recommended way to overwrite via `write_paimon` is the `overwrite=True`
-flag above. When using the lower-level builder API, you can also configure
-overwrite mode on the write builder itself:
+When using the lower-level builder API, you can also configure overwrite mode
+on the write builder itself. The resulting `table_write` carries the overwrite
+partition into `write_ray()`. A `static_partition` argument passed directly to
+`write_ray()` overrides the builder-level partition:
 
 ```python
 # overwrite whole table
-write_builder = table.new_batch_write_builder().overwrite()
+table_write = table.new_batch_write_builder().overwrite().new_write()
+table_write.write_ray(ray_dataset)
 
 # overwrite partition 'dt=2024-01-01'
-write_builder = table.new_batch_write_builder().overwrite({'dt': '2024-01-01'})
+table_write = (
+    table.new_batch_write_builder()
+    .overwrite({'dt': '2024-01-01'})
+    .new_write()
+)
+table_write.write_ray(ray_dataset)
 ```
 
 ## Merge Into
diff --git a/paimon-python/pypaimon/tests/ray_integration_test.py b/paimon-python/pypaimon/tests/ray_integration_test.py
index 2ad95f610b..275d810dd3 100644
--- a/paimon-python/pypaimon/tests/ray_integration_test.py
+++ b/paimon-python/pypaimon/tests/ray_integration_test.py
@@ -333,6 +333,55 @@ class RayIntegrationTest(unittest.TestCase):
         result = read_paimon(identifier, self.catalog_options)
         self.assertEqual(result.count(), 0)
 
+    def test_table_write_ray_builder_partition_overwrite(self):
+        """Builder-level partition overwrite is honored by write_ray()."""
+        from pypaimon.ray import read_paimon
+
+        pa_schema = pa.schema([
+            ('id', pa.int32()),
+            ('val', pa.string()),
+            ('dt', pa.string()),
+        ])
+        identifier = 'default.test_write_ray_partition_overwrite'
+        catalog = CatalogFactory.create(self.catalog_options)
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            partition_keys=['dt'],
+            options={'dynamic-partition-overwrite': 'false'},
+        )
+        catalog.create_table(identifier, schema, False)
+        table = catalog.get_table(identifier)
+
+        initial = pa.Table.from_pydict(
+            {
+                'id': [1, 2, 3],
+                'val': ['old-p1-a', 'old-p1-b', 'old-p2'],
+                'dt': ['p1', 'p1', 'p2'],
+            },
+            schema=pa_schema,
+        )
+        write_builder = table.new_batch_write_builder()
+        writer = write_builder.new_write()
+        writer.write_arrow(initial)
+        write_builder.new_commit().commit(writer.prepare_commit())
+        writer.close()
+
+        replacement = ray.data.from_arrow(
+            pa.Table.from_pydict(
+                {'id': [4], 'val': ['new-p1'], 'dt': ['p1']},
+                schema=pa_schema,
+            )
+        )
+        writer = table.new_batch_write_builder().overwrite({'dt': 'p1'}).new_write()
+        writer.write_ray(replacement, concurrency=1)
+        writer.close()
+
+        result = read_paimon(identifier, self.catalog_options)
+        df = result.to_pandas().sort_values('id').reset_index(drop=True)
+        self.assertEqual(list(df['id']), [3, 4])
+        self.assertEqual(list(df['val']), ['old-p2', 'new-p1'])
+        self.assertEqual(list(df['dt']), ['p2', 'p1'])
+
     def test_read_paimon_primary_key(self):
         """read_paimon() merges PK rows correctly after an upsert."""
         from pypaimon.ray import read_paimon
diff --git a/paimon-python/pypaimon/tests/ray_sink_test.py b/paimon-python/pypaimon/tests/ray_sink_test.py
index ca51b05d1b..a6d761df5a 100644
--- a/paimon-python/pypaimon/tests/ray_sink_test.py
+++ b/paimon-python/pypaimon/tests/ray_sink_test.py
@@ -26,6 +26,7 @@ from ray.data._internal.execution.interfaces import TaskContext
 from pypaimon import CatalogFactory, Schema
 from pypaimon.write.ray_datasink import PaimonDatasink
 from pypaimon.write.commit_message import CommitMessage
+from pypaimon.write.table_write import TableWrite
 
 
 class RaySinkTest(unittest.TestCase):
@@ -69,23 +70,34 @@ class RaySinkTest(unittest.TestCase):
         datasink = PaimonDatasink(self.table, overwrite=False)
         self.assertEqual(datasink.table, self.table)
         self.assertFalse(datasink.overwrite)
+        self.assertIsNone(datasink.static_partition)
         self.assertIsNone(datasink._writer_builder)
         self.assertEqual(datasink._table_name, "test_db.test_table")
 
         datasink_overwrite = PaimonDatasink(self.table, overwrite=True)
         self.assertTrue(datasink_overwrite.overwrite)
 
+        datasink_partition_overwrite = PaimonDatasink(
+            self.table, static_partition={'dt': '2024-01-01'})
+        self.assertFalse(datasink_partition_overwrite.overwrite)
+        self.assertEqual(
+            datasink_partition_overwrite.static_partition,
+            {'dt': '2024-01-01'},
+        )
+
         # Test serialization
         datasink._writer_builder = Mock()
         state = datasink.__getstate__()
         self.assertIn('table', state)
         self.assertIn('overwrite', state)
+        self.assertIn('static_partition', state)
         self.assertIn('_writer_builder', state)
 
         new_datasink = PaimonDatasink.__new__(PaimonDatasink)
         new_datasink.__setstate__(state)
         self.assertEqual(new_datasink.table, self.table)
         self.assertFalse(new_datasink.overwrite)
+        self.assertIsNone(new_datasink.static_partition)
 
     def test_table_and_writer_builder_serializable(self):
         import pickle
@@ -120,6 +132,29 @@ class RaySinkTest(unittest.TestCase):
         except Exception as e:
             self.fail(f"Overwrite WriterBuilder is not serializable: {e}")
 
+    def test_write_builder_new_write_carries_static_partition(self):
+        batch_write = (
+            self.table
+            .new_batch_write_builder()
+            .overwrite({'dt': '2024-01-01'})
+            .new_write()
+        )
+        try:
+            self.assertEqual(batch_write.static_partition, {'dt': '2024-01-01'})
+        finally:
+            batch_write.close()
+
+        stream_write = (
+            self.table
+            .new_stream_write_builder()
+            .overwrite({'dt': '2024-01-01'})
+            .new_write()
+        )
+        try:
+            self.assertEqual(stream_write.static_partition, {'dt': '2024-01-01'})
+        finally:
+            stream_write.close()
+
     def test_on_write_start(self):
         """Test on_write_start with normal and overwrite modes."""
         datasink = PaimonDatasink(self.table, overwrite=False)
@@ -131,6 +166,14 @@ class RaySinkTest(unittest.TestCase):
         datasink_overwrite.on_write_start()
         self.assertIsNotNone(datasink_overwrite._writer_builder.static_partition)
 
+        datasink_partition_overwrite = PaimonDatasink(
+            self.table, static_partition={'dt': '2024-01-01'})
+        datasink_partition_overwrite.on_write_start()
+        self.assertEqual(
+            datasink_partition_overwrite._writer_builder.static_partition,
+            {'dt': '2024-01-01'},
+        )
+
     def test_write(self):
         """Test write method: empty blocks, multiple blocks, error handling, and resource cleanup."""
         datasink = PaimonDatasink(self.table, overwrite=False)
@@ -189,6 +232,25 @@ class RaySinkTest(unittest.TestCase):
             datasink.write([data_table], ctx)
             mock_builder.assert_called_once()
 
+        partition_datasink = PaimonDatasink(
+            self.table, static_partition={'dt': '2024-01-01'})
+        with patch.object(self.table, 'new_batch_write_builder') as mock_builder:
+            mock_write_builder = Mock()
+            mock_write_builder.overwrite.return_value = mock_write_builder
+            mock_write = Mock()
+            mock_write.prepare_commit.return_value = []
+            mock_write_builder.new_write.return_value = mock_write
+            mock_builder.return_value = mock_write_builder
+
+            data_table = pa.table({
+                'id': [1],
+                'name': ['Alice'],
+                'value': [1.1]
+            })
+            partition_datasink.write([data_table], ctx)
+            mock_write_builder.overwrite.assert_called_once_with(
+                {'dt': '2024-01-01'})
+
         invalid_table = pa.table({
             'wrong_column': [1, 2, 3]
         })
@@ -241,6 +303,20 @@ class RaySinkTest(unittest.TestCase):
         mock_commit.commit.assert_called_once_with([])
         mock_commit.close.assert_called_once()
 
+        datasink = PaimonDatasink(self.table, static_partition={'dt': '2024-01-01'})
+        datasink.on_write_start()
+        write_result = WriteResult(
+            num_rows=0,
+            size_bytes=0,
+            write_returns=[[], []]
+        )
+        mock_commit = Mock()
+        datasink._writer_builder.new_commit = Mock(return_value=mock_commit)
+        datasink.on_write_complete(write_result)
+
+        mock_commit.commit.assert_called_once_with([])
+        mock_commit.close.assert_called_once()
+
         # Test with messages and filtering empty messages
         datasink = PaimonDatasink(self.table, overwrite=False)
         datasink.on_write_start()
@@ -308,6 +384,52 @@ class RaySinkTest(unittest.TestCase):
             datasink.on_write_complete(write_result)
         self.assertEqual(len(datasink._pending_commit_messages), 1)
 
+    def test_table_write_ray_forwards_static_partition(self):
+        dataset = Mock()
+        table_write = TableWrite.__new__(TableWrite)
+        table_write.table = self.table
+        table_write.static_partition = {'dt': '2024-01-01'}
+
+        with patch('pypaimon.ray.shuffle.maybe_apply_repartition') as mock_repartition, \
+                patch('pypaimon.write.ray_datasink.PaimonDatasink') as mock_datasink_cls:
+            mock_repartition.return_value = dataset
+            datasink = mock_datasink_cls.return_value
+
+            table_write.write_ray(dataset, concurrency=2)
+
+            mock_repartition.assert_called_once_with(dataset, self.table, 'auto')
+            mock_datasink_cls.assert_called_once_with(
+                self.table,
+                overwrite=False,
+                static_partition={'dt': '2024-01-01'},
+            )
+            dataset.write_datasink.assert_called_once_with(
+                datasink,
+                concurrency=2,
+                ray_remote_args=None,
+            )
+
+    def test_table_write_ray_static_partition_argument_overrides_builder(self):
+        dataset = Mock()
+        table_write = TableWrite.__new__(TableWrite)
+        table_write.table = self.table
+        table_write.static_partition = {'dt': '2024-01-01'}
+
+        with patch('pypaimon.ray.shuffle.maybe_apply_repartition') as mock_repartition, \
+                patch('pypaimon.write.ray_datasink.PaimonDatasink') as mock_datasink_cls:
+            mock_repartition.return_value = dataset
+
+            table_write.write_ray(
+                dataset,
+                static_partition={'dt': '2024-01-02'},
+            )
+
+            mock_datasink_cls.assert_called_once_with(
+                self.table,
+                overwrite=False,
+                static_partition={'dt': '2024-01-02'},
+            )
+
     def test_on_write_failed(self):
         # Test without pending messages (on_write_complete() never called)
         datasink = PaimonDatasink(self.table, overwrite=False)
diff --git a/paimon-python/pypaimon/write/ray_datasink.py b/paimon-python/pypaimon/write/ray_datasink.py
index 60edbc855e..6d48906f9f 100644
--- a/paimon-python/pypaimon/write/ray_datasink.py
+++ b/paimon-python/pypaimon/write/ray_datasink.py
@@ -20,7 +20,7 @@ Module to write a Paimon table from a Ray Dataset, by using the Ray Datasink API
 """
 
 import logging
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
 
 from ray.data.datasource.datasink import Datasink
 
@@ -72,13 +72,18 @@ class PaimonDatasink(_DatasinkBase):
         self,
         table: "Table",
         overwrite: bool = False,
+        static_partition: Optional[Dict[str, Any]] = None,
     ):
         self.table = table
         self.overwrite = overwrite
+        self.static_partition = static_partition
         self._table_name = table.identifier.get_full_name()
         self._writer_builder: Optional["WriteBuilder"] = None
         self._pending_commit_messages: List["CommitMessage"] = []
 
+    def _is_overwrite(self) -> bool:
+        return self.overwrite or self.static_partition is not None
+
     def __getstate__(self) -> dict:
         state = self.__dict__.copy()
         return state
@@ -90,13 +95,15 @@ class PaimonDatasink(_DatasinkBase):
             self._writer_builder = None
         if not hasattr(self, '_table_name'):
             self._table_name = self.table.identifier.get_full_name()
+        if not hasattr(self, 'static_partition'):
+            self.static_partition = None
 
     def on_write_start(self, schema=None) -> None:
         logger.info(f"Starting write job for table {self._table_name}")
 
         self._writer_builder = self.table.new_batch_write_builder()
-        if self.overwrite:
-            self._writer_builder = self._writer_builder.overwrite()
+        if self._is_overwrite():
+            self._writer_builder = self._writer_builder.overwrite(self.static_partition)
 
     def write(
         self,
@@ -108,8 +115,8 @@ class PaimonDatasink(_DatasinkBase):
 
         try:
             writer_builder = self.table.new_batch_write_builder()
-            if self.overwrite:
-                writer_builder = writer_builder.overwrite()
+            if self._is_overwrite():
+                writer_builder = writer_builder.overwrite(self.static_partition)
             
             table_write = writer_builder.new_write()
 
@@ -167,7 +174,7 @@ class PaimonDatasink(_DatasinkBase):
 
             self._pending_commit_messages = non_empty_messages
 
-            if not non_empty_messages and not self.overwrite:
+            if not non_empty_messages and not self._is_overwrite():
                 logger.info("No data to commit (all commit messages are empty)")
                 self._pending_commit_messages = []
                 return
diff --git a/paimon-python/pypaimon/write/table_write.py b/paimon-python/pypaimon/write/table_write.py
index 1eb63793d0..1eeeb5e846 100644
--- a/paimon-python/pypaimon/write/table_write.py
+++ b/paimon-python/pypaimon/write/table_write.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING:
 
 
 class TableWrite:
-    def __init__(self, table, commit_user):
+    def __init__(self, table, commit_user, static_partition: Optional[dict] = None):
         from pypaimon.table.file_store_table import FileStoreTable
 
         self.table: FileStoreTable = table
@@ -38,6 +38,7 @@ class TableWrite:
         self.file_store_write = FileStoreWrite(self.table, commit_user)
         self.row_key_extractor = self.table.create_row_key_extractor()
         self.commit_user = commit_user
+        self.static_partition = static_partition
 
     def write_arrow(self, table: pa.Table):
         batches_iterator = table.to_batches()
@@ -78,6 +79,7 @@ class TableWrite:
         concurrency: Optional[int] = None,
         ray_remote_args: Optional[Dict[str, Any]] = None,
         hash_fixed_precluster: str = "auto",
+        static_partition: Optional[dict] = None,
     ) -> None:
         """
         Write a Ray Dataset to Paimon table.
@@ -86,6 +88,7 @@ class TableWrite:
             dataset: Ray Dataset to write. This is a distributed data collection
                 from Ray Data (ray.data.Dataset).
             overwrite: Whether to overwrite existing data. Defaults to False.
+                Builder-level or static_partition overwrite mode takes precedence.
             concurrency: Optional max number of Ray tasks to run concurrently.
                 By default, dynamically decided based on available resources.
             ray_remote_args: Optional kwargs passed to :func:`ray.remote` in write tasks.
@@ -95,6 +98,9 @@ class TableWrite:
                 and reject HASH_FIXED primary-key tables. ``"map_groups"``
                 preserves the legacy small-file optimization and its single
                 group memory bound for HASH_FIXED primary-key tables.
+            static_partition: Optional partition spec to overwrite. When set,
+                the Ray write runs in overwrite mode for this partition and
+                overrides any builder-level partition spec.
         """
         from pypaimon.ray.shuffle import maybe_apply_repartition
         from pypaimon.write.ray_datasink import PaimonDatasink
@@ -102,7 +108,15 @@ class TableWrite:
         dataset = maybe_apply_repartition(
             dataset, self.table, hash_fixed_precluster)
 
-        datasink = PaimonDatasink(self.table, overwrite=overwrite)
+        overwrite_partition = self.static_partition
+        if static_partition is not None:
+            overwrite_partition = static_partition
+
+        datasink = PaimonDatasink(
+            self.table,
+            overwrite=overwrite,
+            static_partition=overwrite_partition,
+        )
         dataset.write_datasink(
             datasink,
             concurrency=concurrency,
@@ -141,8 +155,8 @@ class TableWrite:
 
 
 class BatchTableWrite(TableWrite):
-    def __init__(self, table, commit_user):
-        super().__init__(table, commit_user)
+    def __init__(self, table, commit_user, static_partition: Optional[dict] = None):
+        super().__init__(table, commit_user, static_partition)
         self.batch_committed = False
 
     def prepare_commit(self) -> List[CommitMessage]:
diff --git a/paimon-python/pypaimon/write/write_builder.py b/paimon-python/pypaimon/write/write_builder.py
index 724e5d7a3f..f7a0459305 100644
--- a/paimon-python/pypaimon/write/write_builder.py
+++ b/paimon-python/pypaimon/write/write_builder.py
@@ -59,7 +59,7 @@ class WriteBuilder(ABC):
 class BatchWriteBuilder(WriteBuilder):
 
     def new_write(self) -> BatchTableWrite:
-        return BatchTableWrite(self.table, self.commit_user)
+        return BatchTableWrite(self.table, self.commit_user, self.static_partition)
 
     def new_update(self) -> BatchTableUpdate:
         return BatchTableUpdate(self.table, self.commit_user)
@@ -72,7 +72,7 @@ class BatchWriteBuilder(WriteBuilder):
 class StreamWriteBuilder(WriteBuilder):
 
     def new_write(self) -> StreamTableWrite:
-        return StreamTableWrite(self.table, self.commit_user)
+        return StreamTableWrite(self.table, self.commit_user, self.static_partition)
 
     def new_update(self) -> StreamTableUpdate:
         return StreamTableUpdate(self.table, self.commit_user)

Reply via email to