This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 6ff3f4010d [python] Support update partial file by row id (#7307)
6ff3f4010d is described below
commit 6ff3f4010d56b1e96b7ba0a8345eee74c66b691d
Author: umi <[email protected]>
AuthorDate: Fri Feb 27 10:26:01 2026 +0800
[python] Support update partial file by row id (#7307)
---
docs/content/pypaimon/data-evolution.md | 2 -
paimon-python/pypaimon/tests/table_update_test.py | 533 ++++++++++++++++++++-
paimon-python/pypaimon/write/file_store_commit.py | 1 -
.../pypaimon/write/table_update_by_row_id.py | 155 +++++-
4 files changed, 649 insertions(+), 42 deletions(-)
diff --git a/docs/content/pypaimon/data-evolution.md
b/docs/content/pypaimon/data-evolution.md
index e35e40fe9d..91714e52c7 100644
--- a/docs/content/pypaimon/data-evolution.md
+++ b/docs/content/pypaimon/data-evolution.md
@@ -45,8 +45,6 @@ its corresponding `first_row_id`, then groups rows with the
same `first_row_id`
**Requirements for `_ROW_ID` updates**
-- **All rows required**: the input table must contain **exactly the full table
row count** (one row per existing row).
-- **Row id coverage**: after sorting by `_ROW_ID`, it must be **0..N-1** (no
duplicates, no gaps).
- **Update columns only**: include `_ROW_ID` plus the columns you want to
update (partial schema is OK).
```python
diff --git a/paimon-python/pypaimon/tests/table_update_test.py
b/paimon-python/pypaimon/tests/table_update_test.py
index a3fd3022e2..ad9158e9fe 100644
--- a/paimon-python/pypaimon/tests/table_update_test.py
+++ b/paimon-python/pypaimon/tests/table_update_test.py
@@ -313,8 +313,78 @@ class TableUpdateTest(unittest.TestCase):
self.assertEqual(ages, expected_ages, "Age column was not updated
correctly")
self.assertEqual(cities, expected_cities, "City column was not updated
correctly")
- def test_wrong_total_row_count(self):
- """Test that wrong total row count raises an error."""
+ def test_update_partial_file(self):
+ """Test updating a single row of an existing column by row id."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution table update
+ write_builder = table.new_batch_write_builder()
+ batch_write = write_builder.new_write()
+
+ # Prepare update data (sorted by row_id)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [1],
+ 'age': [31]
+ })
+
+ # Update the age column
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+ batch_write.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that only row 1's age was updated; other rows unchanged
+ ages = result['age'].to_pylist()
+ expected_ages = [25, 31, 35, 40, 45]
+ self.assertEqual(ages, expected_ages)
+
+ def test_partial_rows_update_single_file(self):
+ """Test updating only some rows within a single file."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution table update
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ # Update only row 0 in the first file (which contains rows 0-1)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [0],
+ 'age': [100]
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that only row 0 was updated, others remain unchanged
+ ages = result['age'].to_pylist()
+ expected_ages = [100, 30, 35, 40, 45] # Only first row updated
+ self.assertEqual(ages, expected_ages)
+
+ def test_partial_rows_update_multiple_files(self):
+ """Test updating partial rows across multiple files."""
# Create table with initial data
table = self._create_table()
@@ -322,20 +392,186 @@ class TableUpdateTest(unittest.TestCase):
write_builder = table.new_batch_write_builder()
table_update = write_builder.new_update().with_update_type(['age'])
- # Prepare update data with wrong row count (only 3 rows instead of 5)
+ # Update row 1 from first file and row 2 from second file
update_data = pa.Table.from_pydict({
- '_ROW_ID': [0, 1, 2],
- 'age': [26, 31, 36]
+ '_ROW_ID': [1, 2],
+ 'age': [200, 300]
})
- # Should raise ValueError for total row count mismatch
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that only specified rows were updated
+ ages = result['age'].to_pylist()
+ expected_ages = [25, 200, 300, 40, 45] # Rows 1 and 2 updated
+ self.assertEqual(ages, expected_ages)
+
+ def test_partial_rows_non_consecutive(self):
+ """Test updating non-consecutive rows."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution table update
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ # Update rows 0, 2, 4 (non-consecutive)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [0, 2, 4],
+ 'age': [100, 300, 500]
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that only specified rows were updated
+ ages = result['age'].to_pylist()
+ expected_ages = [100, 30, 300, 40, 500] # Rows 0, 2, 4 updated
+ self.assertEqual(ages, expected_ages)
+
+ def test_update_preserves_other_columns(self):
+ """Test that updating one column preserves other columns."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution table update
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ # Update only row 1
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [1],
+ 'age': [999]
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+
+ # Commit the changes
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that age was updated for row 1
+ ages = result['age'].to_pylist()
+ expected_ages = [25, 999, 35, 40, 45]
+ self.assertEqual(ages, expected_ages)
+
+ # Check that other columns remain unchanged
+ names = result['name'].to_pylist()
+ expected_names = ['Alice', 'Bob', 'Charlie', 'David', 'Eve']
+ self.assertEqual(names, expected_names)
+
+ cities = result['city'].to_pylist()
+ expected_cities = ['NYC', 'LA', 'Chicago', 'Houston', 'Phoenix']
+ self.assertEqual(cities, expected_cities)
+
+ def test_sequential_partial_updates(self):
+ """Test multiple sequential partial updates on the same table."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # First update: Update row 0
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ update_data_1 = pa.Table.from_pydict({
+ '_ROW_ID': [0],
+ 'age': [100]
+ })
+
+ commit_messages =
table_update.update_by_arrow_with_row_id(update_data_1)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Second update: Update row 2
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ update_data_2 = pa.Table.from_pydict({
+ '_ROW_ID': [2],
+ 'age': [300]
+ })
+
+ commit_messages =
table_update.update_by_arrow_with_row_id(update_data_2)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Third update: Update row 4
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ update_data_3 = pa.Table.from_pydict({
+ '_ROW_ID': [4],
+ 'age': [500]
+ })
+
+ commit_messages =
table_update.update_by_arrow_with_row_id(update_data_3)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify the final data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Check that all updates were applied correctly
+ ages = result['age'].to_pylist()
+ expected_ages = [100, 30, 300, 40, 500]
+ self.assertEqual(expected_ages, ages)
+
+ def test_row_id_out_of_range(self):
+ """Test that row_id out of valid range raises an error."""
+ # Create table with initial data
+ table = self._create_table()
+
+ # Create data evolution table update
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ # Prepare update data with row_id out of range (table has 5 rows: 0-4)
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [0, 10], # 10 is out of range
+ 'age': [26, 100]
+ })
+
+ # Should raise ValueError for row_id out of range
with self.assertRaises(ValueError) as context:
table_update.update_by_arrow_with_row_id(update_data)
- self.assertIn("does not match table total row count",
str(context.exception))
+ self.assertIn("out of valid range", str(context.exception))
- def test_wrong_first_row_id_row_count(self):
- """Test that wrong row count for a first_row_id raises an error."""
+ def test_negative_row_id(self):
+ """Test that negative row_id raises an error."""
# Create table with initial data
table = self._create_table()
@@ -343,17 +579,286 @@ class TableUpdateTest(unittest.TestCase):
write_builder = table.new_batch_write_builder()
table_update = write_builder.new_update().with_update_type(['age'])
- # Prepare update data with duplicate row_id (violates monotonically
increasing)
+ # Prepare update data with negative row_id
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [-1, 0],
+ 'age': [100, 26]
+ })
+
+ # Should raise ValueError for negative row_id
+ with self.assertRaises(ValueError) as context:
+ table_update.update_by_arrow_with_row_id(update_data)
+
+ self.assertIn("out of valid range", str(context.exception))
+
+ def test_duplicate_row_id(self):
+ """Test that duplicate _ROW_ID values raise an error."""
+ table = self._create_table()
+
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
update_data = pa.Table.from_pydict({
- '_ROW_ID': [0, 1, 1, 4, 5],
- 'age': [26, 31, 36, 37, 38]
+ '_ROW_ID': [0, 0, 1],
+ 'age': [100, 200, 300]
})
- # Should raise ValueError for row ID validation
with self.assertRaises(ValueError) as context:
table_update.update_by_arrow_with_row_id(update_data)
- self.assertIn("Row IDs are not monotonically increasing",
str(context.exception))
+ self.assertIn("duplicate _ROW_ID values", str(context.exception))
+
+ def test_large_table_partial_column_updates(self):
+ """Test partial column updates on a large table with 4 columns.
+
+ This test covers:
+ 1. Update first column (id), update 1 row, verify result
+ 2. Update first and second columns (id, name), update 2 rows, verify
result
+ 3. Update second column (name), update 1 row, verify result
+ 4. Update third column (age), update multiple rows, verify result
+ """
+ import uuid
+
+ # Create table with 4 columns and 2000 rows in 2 files
+ table_name = f'test_large_table_{uuid.uuid4().hex[:8]}'
+ schema = Schema.from_pyarrow_schema(self.pa_schema,
options=self.table_options)
+ self.catalog.create_table(f'default.{table_name}', schema, False)
+ table = self.catalog.get_table(f'default.{table_name}')
+
+ # Write batch-1
+ num_row = 1000
+ write_builder = table.new_batch_write_builder()
+ batch1_data = pa.Table.from_pydict({
+ 'id': list(range(num_row)),
+ 'name': [f'Name_{i}' for i in range(num_row)],
+ 'age': [20 + i for i in range(num_row)],
+ 'city': [f'City_{i}' for i in range(num_row)]
+ }, schema=self.pa_schema)
+
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(batch1_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Write batch-2
+ batch2_data = pa.Table.from_pydict({
+ 'id': list(range(num_row, num_row * 2)),
+ 'name': [f'Name_{i}' for i in range(num_row, num_row * 2)],
+ 'age': [20 + num_row + i for i in range(num_row)],
+ 'city': [f'City_{i}' for i in range(num_row, num_row * 2)]
+ }, schema=self.pa_schema)
+
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(batch2_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # --- Test 1: Update first column (id), update 1 row ---
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['id'])
+
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [5],
+ 'id': [999]
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify Test 1
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ ids = result['id'].to_pylist()
+ expected_ids = list(range(num_row * 2))
+ expected_ids[5] = 999
+ self.assertEqual(expected_ids, ids, "Test 1 failed: Update first
column for 1 row")
+
+ # --- Test 2: Update first and second columns (id, name), update 2
rows ---
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['id',
'name'])
+
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [3, 15],
+ 'id': [888, 777],
+ 'name': ['Updated_Name_3', 'Updated_Name_15']
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify Test 2
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ ids = result['id'].to_pylist()
+ expected_ids[3] = 888
+ expected_ids[15] = 777
+ self.assertEqual(expected_ids, ids, "Test 2 failed: Update id column
for 2 rows")
+
+ names = result['name'].to_pylist()
+ expected_names = [f'Name_{i}' for i in range(num_row * 2)]
+ expected_names[3] = 'Updated_Name_3'
+ expected_names[15] = 'Updated_Name_15'
+ self.assertEqual(expected_names, names, "Test 2 failed: Update name
column for 2 rows")
+
+ # --- Test 3: Update second column (name), update 1 row ---
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['name'])
+
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [12],
+ 'name': ['NewName_12']
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify Test 3
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ names = result['name'].to_pylist()
+ expected_names[12] = 'NewName_12'
+ self.assertEqual(expected_names, names, "Test 3 failed: Update name
column for 1 row")
+
+ # --- Test 4: Update third column (age), update multiple rows across
both files ---
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age'])
+
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [0, 5, 10, 15],
+ 'age': [100, 105, 110, 115]
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify Test 4
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ ages = result['age'].to_pylist()
+ expected_ages = [20 + i for i in range(num_row)] + [20 + num_row + i
for i in range(num_row)]
+ expected_ages[0] = 100
+ expected_ages[5] = 105
+ expected_ages[10] = 110
+ expected_ages[15] = 115
+ self.assertEqual(expected_ages, ages, "Test 4 failed: Update age
column for multiple rows")
+
+ # Verify other columns remain unchanged after all updates
+ cities = result['city'].to_pylist()
+ expected_cities = [f'City_{i}' for i in range(num_row * 2)]
+ self.assertEqual(expected_cities, cities, "City column should remain
unchanged")
+
+ def test_update_partial_rows_across_two_files(self):
+ """Test updating partial rows across two data files in a single update
operation.
+
+ This test creates a table with 2 commits (2 data files), then performs
a single update
+ that modifies partial rows from both files simultaneously.
+ """
+ import uuid
+
+ # Create table
+ table_name = f'test_two_files_{uuid.uuid4().hex[:8]}'
+ schema = Schema.from_pyarrow_schema(self.pa_schema,
options=self.table_options)
+ self.catalog.create_table(f'default.{table_name}', schema, False)
+ table = self.catalog.get_table(f'default.{table_name}')
+
+ # Commit 1: Write first batch of data (row_id 0-4)
+ write_builder = table.new_batch_write_builder()
+ batch1_data = pa.Table.from_pydict({
+ 'id': [1, 2, 3, 4, 5],
+ 'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
+ 'age': [20, 25, 30, 35, 40],
+ 'city': ['NYC', 'LA', 'Chicago', 'Houston', 'Phoenix']
+ }, schema=self.pa_schema)
+
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(batch1_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Commit 2: Write second batch of data (row_id 5-9)
+ batch2_data = pa.Table.from_pydict({
+ 'id': [6, 7, 8, 9, 10],
+ 'name': ['Frank', 'Grace', 'Henry', 'Ivy', 'Jack'],
+ 'age': [45, 50, 55, 60, 65],
+ 'city': ['Seattle', 'Boston', 'Denver', 'Miami', 'Atlanta']
+ }, schema=self.pa_schema)
+
+ table_write = write_builder.new_write()
+ table_commit = write_builder.new_commit()
+ table_write.write_arrow(batch2_data)
+ table_commit.commit(table_write.prepare_commit())
+ table_write.close()
+ table_commit.close()
+
+ # Single update: Update partial rows from both files
+ # Update rows 1, 3 from file 1 and rows 6, 8 from file 2
+ write_builder = table.new_batch_write_builder()
+ table_update = write_builder.new_update().with_update_type(['age',
'name'])
+
+ update_data = pa.Table.from_pydict({
+ '_ROW_ID': [1, 3, 6, 8],
+ 'age': [100, 200, 300, 400],
+ 'name': ['Updated_Bob', 'Updated_David', 'Updated_Grace',
'Updated_Ivy']
+ })
+
+ commit_messages = table_update.update_by_arrow_with_row_id(update_data)
+ table_commit = write_builder.new_commit()
+ table_commit.commit(commit_messages)
+ table_commit.close()
+
+ # Verify the updated data
+ read_builder = table.new_read_builder()
+ table_read = read_builder.new_read()
+ splits = read_builder.new_scan().plan().splits()
+ result = table_read.to_arrow(splits)
+
+ # Verify ages: rows 1, 3, 6, 8 should be updated
+ ages = result['age'].to_pylist()
+ expected_ages = [20, 100, 30, 200, 40, 45, 300, 55, 400, 65]
+ self.assertEqual(expected_ages, ages, "Ages not updated correctly
across two files")
+
+ # Verify names: rows 1, 3, 6, 8 should be updated
+ names = result['name'].to_pylist()
+ expected_names = ['Alice', 'Updated_Bob', 'Charlie', 'Updated_David',
'Eve',
+ 'Frank', 'Updated_Grace', 'Henry', 'Updated_Ivy',
'Jack']
+ self.assertEqual(expected_names, names, "Names not updated correctly
across two files")
+
+ # Verify other columns remain unchanged
+ ids = result['id'].to_pylist()
+ expected_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ self.assertEqual(expected_ids, ids, "IDs should remain unchanged")
+
+ cities = result['city'].to_pylist()
+ expected_cities = ['NYC', 'LA', 'Chicago', 'Houston', 'Phoenix',
+ 'Seattle', 'Boston', 'Denver', 'Miami', 'Atlanta']
+ self.assertEqual(expected_cities, cities, "Cities should remain
unchanged")
+
if __name__ == '__main__':
unittest.main()
diff --git a/paimon-python/pypaimon/write/file_store_commit.py
b/paimon-python/pypaimon/write/file_store_commit.py
index 8aa9107ef8..80eb858087 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -633,7 +633,6 @@ class FileStoreCommit:
def _assign_row_tracking_meta(self, first_row_id_start: int,
commit_entries: List[ManifestEntry]):
"""
Assign row tracking metadata (first_row_id) to new files.
- This follows the Java implementation logic from
FileStoreCommitImpl.assignRowTrackingMeta.
"""
if not commit_entries:
return commit_entries, first_row_id_start
diff --git a/paimon-python/pypaimon/write/table_update_by_row_id.py
b/paimon-python/pypaimon/write/table_update_by_row_id.py
index fb854dbccc..b027564ff3 100644
--- a/paimon-python/pypaimon/write/table_update_by_row_id.py
+++ b/paimon-python/pypaimon/write/table_update_by_row_id.py
@@ -21,6 +21,9 @@ from typing import Dict, List, Optional
import pyarrow as pa
import pyarrow.compute as pc
+from pypaimon.read.split import DataSplit
+from pypaimon.read.table_read import TableRead
+from pypaimon.schema.data_types import DataField
from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
from pypaimon.table.row.generic_row import GenericRow
from pypaimon.table.special_fields import SpecialFields
@@ -47,7 +50,8 @@ class TableUpdateByRowId:
(self.first_row_ids,
self.first_row_id_to_partition_map,
self.first_row_id_to_row_count_map,
- self.total_row_count) = self._load_existing_files_info()
+ self.total_row_count,
+ self.splits) = self._load_existing_files_info()
# Collect commit messages
self.commit_messages = []
@@ -72,8 +76,11 @@ class TableUpdateByRowId:
total_row_count = sum(first_row_id_to_row_count_map.values())
- return sorted(list(set(first_row_ids))
- ), first_row_id_to_partition_map,
first_row_id_to_row_count_map, total_row_count
+ return (sorted(list(set(first_row_ids))),
+ first_row_id_to_partition_map,
+ first_row_id_to_row_count_map,
+ total_row_count,
+ splits)
def update_columns(self, data: pa.Table, column_names: List[str]) -> List:
"""
@@ -100,11 +107,6 @@ class TableUpdateByRowId:
if col_name not in self.table.field_names:
raise ValueError(f"Column {col_name} not found in table
schema")
- # Validate data row count matches total row count
- if data.num_rows != self.total_row_count:
- raise ValueError(
- f"Input data row count ({data.num_rows}) does not match table
total row count ({self.total_row_count})")
-
# Sort data by _ROW_ID column
sorted_data = data.sort_by([(SpecialFields.ROW_ID.name, "ascending")])
@@ -117,14 +119,20 @@ class TableUpdateByRowId:
return self.commit_messages
def _calculate_first_row_id(self, data: pa.Table) -> pa.Table:
- """Calculate _first_row_id for each row based on _ROW_ID."""
+ """Calculate _first_row_id for each row based on _ROW_ID.
+
+ Supports partial row updates - row_ids don't need to be consecutive.
+ """
row_ids = data[SpecialFields.ROW_ID.name].to_pylist()
- # Validate that row_ids are monotonically increasing starting from 0
- expected_row_ids = list(range(len(row_ids)))
- if row_ids != expected_row_ids:
- raise ValueError(f"Row IDs are not monotonically increasing
starting from 0. "
- f"Expected: {expected_row_ids}")
+ # Validate row_ids have no duplicates
+ if len(row_ids) != len(set(row_ids)):
+ raise ValueError("Input data contains duplicate _ROW_ID values")
+
+ # Validate row_ids are within valid range
+ for row_id in row_ids:
+ if row_id < 0 or row_id >= self.total_row_count:
+ raise ValueError(f"Row ID {row_id} is out of valid range [0,
{self.total_row_count})")
# Calculate first_row_id for each row_id
first_row_id_values = []
@@ -171,31 +179,128 @@ class TableUpdateByRowId:
"""Find the partition for a given first_row_id using pre-built
partition map."""
return self.first_row_id_to_partition_map.get(first_row_id)
+ def _read_original_file_data(self, first_row_id: int, column_names:
List[str]) -> Optional[pa.Table]:
+ """Read original file data for the given first_row_id.
+
+ Only reads columns that exist in the original file and need to be
updated.
+ In Data Evolution mode, uses the table's read API to get the latest
data
+ for the specified columns, which handles merging multiple files
correctly.
+
+ Args:
+ first_row_id: The first_row_id of the file to read
+ column_names: The column names to update
+
+ Returns:
+ PyArrow Table containing the original data for columns that exist
in the file,
+ or None if no columns need to be read from the original file.
+ """
+
+ # Build read type for the columns we need to read
+ read_fields: List[DataField] = []
+ for field in self.table.fields:
+ if field.name in column_names:
+ read_fields.append(field)
+
+ if not read_fields:
+ return None
+
+ # Find the split that contains files with this first_row_id
+ target_split = None
+ target_files = []
+ for split in self.splits:
+ for file_idx, file in enumerate(split.files):
+ if file.first_row_id == first_row_id:
+ target_files.append(file)
+ if target_split is None:
+ target_split = split
+ if target_split is not None:
+ break
+
+ if not target_files:
+ raise ValueError(f"No file found for first_row_id {first_row_id}")
+
+ # Create a DataSplit containing all files with this first_row_id
+ origin_split = DataSplit(
+ files=target_files,
+ partition=target_split.partition,
+ bucket=target_split.bucket,
+ raw_convertible=True
+ )
+
+ # Create TableRead and read the data
+ table_read = TableRead(self.table, predicate=None,
read_type=read_fields)
+ origin_data = table_read.to_arrow([origin_split])
+
+ return origin_data
+
+ def _merge_update_with_original(self, original_data: Optional[pa.Table],
update_data: pa.Table,
+ column_names: List[str], first_row_id:
int) -> pa.Table:
+ """Merge update data with original data, preserving row order.
+
+ For rows that have updates, use the update values.
+ For rows without updates, use the original values (if available).
+
+ Args:
+ original_data: Original data from the file (may be None if no
columns need to be read)
+ update_data: Update data (may contain only partial rows)
+ column_names: Column names being updated
+ first_row_id: The first_row_id of this file group
+
+ Returns:
+ Merged PyArrow Table with all rows
+ """
+
+ # Get the _ROW_ID values from update_data to determine which rows are
being updated
+ relative_indices = pc.subtract(
+ update_data[SpecialFields.ROW_ID.name],
+ pa.scalar(first_row_id, type=pa.int64())
+ ).cast(pa.int64())
+
+ # Build a boolean mask: True at positions that need to be updated
+ all_indices = pa.array(range(original_data.num_rows), type=pa.int64())
+ mask = pc.is_in(all_indices, relative_indices)
+
+ # Build the merged table column by column
+ merged_columns = {}
+ for col_name in column_names:
+ update_col = update_data[col_name].combine_chunks()
+ original_col = original_data[col_name].combine_chunks()
+ # replace_with_mask fills mask=True positions with update values
in order
+ merged_columns[col_name] = pc.replace_with_mask(
+ original_col, mask, update_col.cast(original_col.type)
+ )
+
+ # Create the merged table
+ merged_table = pa.table(merged_columns)
+
+ return merged_table
+
def _write_group(self, partition: GenericRow, first_row_id: int,
data: pa.Table, column_names: List[str]):
- """Write a group of data with the same first_row_id."""
+ """Write a group of data with the same first_row_id.
+
+ This method reads the original file data, merges it with the update
data,
+ and writes out the complete merged data. Supports partial row updates.
+ """
+
+ # Read original file data for the columns being updated
+ original_data = self._read_original_file_data(first_row_id,
column_names)
- # Validate data row count matches the first_row_id's row count
- expected_row_count =
self.first_row_id_to_row_count_map.get(first_row_id, 0)
- if data.num_rows != expected_row_count:
- raise ValueError(
- f"Data row count ({data.num_rows}) does not match expected row
count ({expected_row_count}) "
- f"for first_row_id {first_row_id}")
+ # Merge update data with original data
+ merged_data = self._merge_update_with_original(original_data, data,
column_names, first_row_id)
# Create a file store write for this partition
file_store_write = FileStoreWrite(self.table, self.commit_user)
# Set write columns to only update specific columns
- # Note: _ROW_ID is metadata column, not part of schema
write_cols = column_names
file_store_write.write_cols = write_cols
# Convert partition to tuple for hashing
partition_tuple = tuple(partition.values)
- # Write data - convert Table to RecordBatch
- data_to_write = data.select(write_cols)
- for batch in data_to_write.to_batches():
+ # Write merged data - convert Table to RecordBatch
+ for batch in merged_data.to_batches():
file_store_write.write(partition_tuple, 0, batch)
# Prepare commit and assign first_row_id