This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 2093598e1f [python] Fix upsert row_id validation failure on tables with row_id holes (#8092)
2093598e1f is described below

commit 2093598e1f532bc905b485bd73478a92265325a9
Author: XiaoHongbo <[email protected]>
AuthorDate: Wed Jun 3 16:39:18 2026 +0800

    [python] Fix upsert row_id validation failure on tables with row_id holes (#8092)
---
 .../pypaimon/ray/data_evolution_merge_join.py      | 23 ++++++-----
 paimon-python/pypaimon/tests/table_update_test.py  | 44 +++++++++++++++++++++-
 .../pypaimon/tests/table_upsert_by_key_test.py     | 42 +++++++++++++++++++++
 .../pypaimon/write/table_update_by_row_id.py       | 33 +++++++---------
 4 files changed, 110 insertions(+), 32 deletions(-)

diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
index 40e7994b2d..f1b814a861 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
@@ -169,8 +169,9 @@ def distributed_update_apply(
     frid_col = "_FIRST_ROW_ID"
     captured_sorted = sorted_first_row_ids
     captured_sorted_arr = np.asarray(captured_sorted, dtype=np.int64)
-    first = captured_sorted_arr[0]
-    total_row_count = planner.total_row_count
+    valid_ranges = planner.valid_row_id_ranges
+    range_starts = np.asarray([r.from_ for r in valid_ranges], dtype=np.int64)
+    range_ends = np.asarray([r.to for r in valid_ranges], dtype=np.int64)
 
     def _assign_frid(batch: pa.Table) -> pa.Table:
         if batch.num_rows == 0:
@@ -184,15 +185,17 @@ def distributed_update_apply(
                 "or matched rows come from a different table."
             )
         rids = rid_col.to_numpy(zero_copy_only=False)
-        # Out-of-range _ROW_IDs would silently map via searchsorted wrap-around.
-        out_of_range = (rids < first) | (rids >= total_row_count)
-        if out_of_range.any():
-            bad = rids[out_of_range][0]
+        # Check each row_id belongs to a valid range (vectorized).
+        in_range = np.zeros(len(rids), dtype=bool)
+        for s, e in zip(range_starts, range_ends):
+            in_range |= (rids >= s) & (rids <= e)
+        if not in_range.all():
+            bad = rids[~in_range][0]
             raise ValueError(
-                f"_ROW_ID {bad} is out of valid range "
-                f"[{first}, {total_row_count}); planner snapshot "
-                f"is stale or matched rows come from a different "
-                f"table."
+                f"_ROW_ID {bad} does not belong to any valid range "
+                f"{[f'[{r.from_}, {r.to}]' for r in valid_ranges]}; "
+                f"planner snapshot is stale or matched rows come "
+                f"from a different table."
             )
         idx = np.searchsorted(
             captured_sorted_arr, rids, side="right"
diff --git a/paimon-python/pypaimon/tests/table_update_test.py b/paimon-python/pypaimon/tests/table_update_test.py
index 181cf7a11b..57ae605703 100644
--- a/paimon-python/pypaimon/tests/table_update_test.py
+++ b/paimon-python/pypaimon/tests/table_update_test.py
@@ -427,7 +427,7 @@ class _TableUpdateTestBase(DataEvolutionTestBase):
         self.assertIn('_ROW_ID column', str(ctx.exception))
 
     def test_invalid_row_id_raises(self):
-        """row_id outside [0, total_row_count) (both directions) raises."""
+        """row_id outside valid row_id ranges raises."""
         table = self._create_seeded_table()
         cases = [
             ('out_of_range_high', [0, 10], [26, 100]),
@@ -440,7 +440,7 @@ class _TableUpdateTestBase(DataEvolutionTestBase):
                 bad = pa.Table.from_pydict({'_ROW_ID': row_ids, 'age': ages})
                 with self.assertRaises(ValueError) as ctx:
                     self._apply_update(tu, bad, self._next_commit_id())
-                self.assertIn('out of valid range', str(ctx.exception))
+                self.assertIn('does not belong to any valid range', str(ctx.exception))
 
     def test_duplicate_row_id_raises(self):
         table = self._create_seeded_table()
@@ -457,6 +457,46 @@ class _TableUpdateTestBase(DataEvolutionTestBase):
             )
         self.assertIn('duplicate _ROW_ID', str(ctx.exception))
 
+    def test_update_deleted_row_id_raises(self):
+        """Updating a row_id that fell into a hole after truncate raises."""
+        partitioned_schema = pa.schema([
+            ('id', pa.int32()),
+            ('name', pa.string()),
+            ('age', pa.int32()),
+            ('region', pa.string()),
+        ])
+        table = self._create_table(
+            pa_schema=partitioned_schema,
+            partition_keys=['region'],
+        )
+        self._write_arrow(table, pa.Table.from_pydict({
+            'id': pa.array([1, 2, 3], type=pa.int32()),
+            'name': ['A', 'B', 'C'],
+            'age': pa.array([10, 20, 30], type=pa.int32()),
+            'region': ['US', 'US', 'US'],
+        }, schema=partitioned_schema))
+
+        self._write_arrow(table, pa.Table.from_pydict({
+            'id': pa.array([4, 5], type=pa.int32()),
+            'name': ['D', 'E'],
+            'age': pa.array([40, 50], type=pa.int32()),
+            'region': ['EU', 'EU'],
+        }, schema=partitioned_schema))
+
+        wb = table.new_batch_write_builder()
+        tc = wb.new_commit()
+        tc.truncate_partitions([{'region': 'US'}])
+
+        wb = self._make_write_builder(table)
+        tu = wb.new_update().with_update_type(['age'])
+        with self.assertRaises(ValueError) as ctx:
+            self._apply_update(
+                tu,
+                pa.Table.from_pydict({'_ROW_ID': [0], 'age': [99]}),
+                self._next_commit_id(),
+            )
+        self.assertIn('does not belong to any valid range', str(ctx.exception))
+
     # ------------------------------------------------------------------
     # Concurrency tests
     # ------------------------------------------------------------------
diff --git a/paimon-python/pypaimon/tests/table_upsert_by_key_test.py b/paimon-python/pypaimon/tests/table_upsert_by_key_test.py
index 1d85ac0b99..f276e44ab8 100644
--- a/paimon-python/pypaimon/tests/table_upsert_by_key_test.py
+++ b/paimon-python/pypaimon/tests/table_upsert_by_key_test.py
@@ -436,6 +436,48 @@ class _TableUpsertByKeyTestBase(DataEvolutionTestBase):
         self.assertEqual('Carol', names[idx3])
         self.assertEqual('US',    regions[idx3])
 
+    def test_upsert_after_truncate_partition(self):
+        table = self._create_table(
+            pa_schema=self.partitioned_pa_schema,
+            partition_keys=['region'],
+        )
+        self._write_arrow(table, pa.Table.from_pydict({
+            'id': pa.array([1, 2, 3], type=pa.int32()),
+            'name': ['A', 'B', 'C'],
+            'age': pa.array([10, 20, 30], type=pa.int32()),
+            'region': ['US', 'US', 'US'],
+        }, schema=self.partitioned_pa_schema))
+
+        self._write_arrow(table, pa.Table.from_pydict({
+            'id': pa.array([4, 5], type=pa.int32()),
+            'name': ['D', 'E'],
+            'age': pa.array([40, 50], type=pa.int32()),
+            'region': ['EU', 'EU'],
+        }, schema=self.partitioned_pa_schema))
+
+        wb = table.new_batch_write_builder()
+        tc = wb.new_commit()
+        tc.truncate_partitions([{'region': 'US'}])
+
+        upsert_data = pa.Table.from_pydict({
+            'id': pa.array([4], type=pa.int32()),
+            'name': ['D_v2'],
+            'age': pa.array([41], type=pa.int32()),
+            'region': ['EU'],
+        }, schema=self.partitioned_pa_schema)
+        self._upsert(table, upsert_data, upsert_keys=['id'])
+
+        result = self._read_all(table)
+        self.assertEqual(2, result.num_rows)
+        rows = sorted(zip(
+            result['id'].to_pylist(),
+            result['name'].to_pylist(),
+            result['age'].to_pylist(),
+            result['region'].to_pylist(),
+        ))
+        self.assertEqual((4, 'D_v2', 41, 'EU'), rows[0])
+        self.assertEqual((5, 'E', 50, 'EU'), rows[1])
+
     # ==================================================================
     # update_cols partial update (non-partitioned)
     # ==================================================================
diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py b/paimon-python/pypaimon/write/table_update_by_row_id.py
index 7220f1718d..34816d6ffb 100644
--- a/paimon-python/pypaimon/write/table_update_by_row_id.py
+++ b/paimon-python/pypaimon/write/table_update_by_row_id.py
@@ -47,7 +47,7 @@ class _FilesInfo:
     first_row_id_index: Dict[int, Tuple[DataSplit, List[DataFileMeta]]] = (
         field(default_factory=dict)
     )
-    total_row_count: int = 0
+    valid_row_id_ranges: List[Range] = field(default_factory=list)
 
 
 class TableUpdateByRowId:
@@ -74,7 +74,7 @@ class TableUpdateByRowId:
         self.snapshot_id = info.snapshot_id
         self.first_row_ids = info.first_row_ids
         self._first_row_id_index = info.first_row_id_index
-        self.total_row_count = info.total_row_count
+        self.valid_row_id_ranges = info.valid_row_id_ranges
 
         self.commit_messages: List[CommitMessage] = []
 
@@ -84,7 +84,7 @@ class TableUpdateByRowId:
             snapshot_id=self.snapshot_id,
             first_row_ids=self.first_row_ids,
             first_row_id_index=self._first_row_id_index,
-            total_row_count=self.total_row_count,
+            valid_row_id_ranges=self.valid_row_id_ranges,
         )
 
     def _load_existing_files_info(self) -> _FilesInfo:
@@ -131,21 +131,17 @@ class TableUpdateByRowId:
                         if target_file.file_name not in existing_names
                     )
 
-        # Multiple physical files may share the same first_row_id (data evolution);
-        # summing row_count per file would over-count logical rows and widen
-        # the _ROW_ID validation range incorrectly.
         if row_id_ranges:
             merged = Range.sort_and_merge_overlap(row_id_ranges, True, True)
-            total_row_count = sum(r.count() for r in merged)
         else:
-            total_row_count = 0
+            merged = []
 
         snapshot_id = plan.snapshot_id if plan.snapshot_id is not None else -1
         return _FilesInfo(
             snapshot_id=snapshot_id,
             first_row_ids=sorted(index.keys()),
             first_row_id_index=index,
-            total_row_count=total_row_count,
+            valid_row_id_ranges=merged,
         )
 
     @staticmethod
@@ -183,8 +179,8 @@ class TableUpdateByRowId:
     def _calculate_first_row_id(self, data: pa.Table) -> pa.Table:
         """Append ``_FIRST_ROW_ID`` to *data* by looking up each ``_ROW_ID``.
 
-        Validates that every input ``_ROW_ID`` is unique and falls in
-        ``[0, total_row_count)``. Supports partial / non-consecutive updates.
+        Validates that every input ``_ROW_ID`` is unique and belongs to
+        a valid row_id range. Supports partial / non-consecutive updates.
         """
         row_id_arr = data[SpecialFields.ROW_ID.name]
         row_ids = row_id_arr.to_pylist()
@@ -197,15 +193,12 @@ class TableUpdateByRowId:
                 self.FIRST_ROW_ID_COLUMN, pa.array([], type=pa.int64()),
             )
 
-        # Vectorised range check (avoids a Python-level per-row loop).
-        min_id = pc.min(row_id_arr).as_py()
-        max_id = pc.max(row_id_arr).as_py()
-        if min_id < 0 or max_id >= self.total_row_count:
-            offending = min_id if min_id < 0 else max_id
-            raise ValueError(
-                f"Row ID {offending} is out of valid range "
-                f"[0, {self.total_row_count})"
-            )
+        for row_id in row_ids:
+            if not any(r.contains(row_id) for r in self.valid_row_id_ranges):
+                raise ValueError(
+                    f"Row ID {row_id} does not belong to any valid range "
+                    f"{[f'[{r.from_}, {r.to}]' for r in self.valid_row_id_ranges]}"
+                )
 
         if not self.first_row_ids:
             raise ValueError("The input sorted sequence is empty.")

Reply via email to