This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new bda3eb031b [python] with_shard should be evenly distributed for data evolution mode (#7271)
bda3eb031b is described below

commit bda3eb031b829f2d4dce377fa64e3a4b09606c44
Author: Jingsong Lee <[email protected]>
AuthorDate: Thu Feb 12 10:50:43 2026 +0800

    [python] with_shard should be evenly distributed for data evolution mode (#7271)
---
 .../read/scanner/data_evolution_split_generator.py | 281 ++++++---------------
 .../pypaimon/read/scanner/split_generator.py       |  12 +-
 paimon-python/pypaimon/read/split.py               |  54 +++-
 paimon-python/pypaimon/tests/blob_table_test.py    |  22 +-
 .../pypaimon/tests/data_evolution_test.py          |  65 +++++
 paimon-python/pypaimon/write/table_update.py       |  61 ++++-
 6 files changed, 256 insertions(+), 239 deletions(-)
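
    The core of this change is that with_shard no longer assigns whole merged
    row-id ranges to subtasks (which could leave some subtasks empty) but
    slices the total row count evenly. A minimal, self-contained sketch of the
    arithmetic added to _divide_ranges_by_position in the first file below;
    the helper name shard_bounds is illustrative and not part of the patch:

        def shard_bounds(total_row_count, idx, num_subtasks):
            # Evenly split [0, total_row_count) into num_subtasks contiguous slices.
            rows_per_task = total_row_count // num_subtasks
            remainder = total_row_count % num_subtasks
            start_pos = idx * rows_per_task
            if idx < remainder:
                # The first `remainder` subtasks each take one extra row.
                start_pos += idx
                end_pos = start_pos + rows_per_task + 1
            else:
                start_pos += remainder
                end_pos = start_pos + rows_per_task
            return start_pos, end_pos

        # For 10 rows over 3 subtasks: (0, 4), (4, 7), (7, 10).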

diff --git a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
index 241966134f..4ac154fed0 100644
--- a/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
+++ b/paimon-python/pypaimon/read/scanner/data_evolution_split_generator.py
@@ -16,7 +16,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 from collections import defaultdict
-from typing import List, Optional, Dict, Tuple
+from typing import List, Optional, Tuple
 
 from pypaimon.globalindex.indexed_split import IndexedSplit
 from pypaimon.globalindex.range import Range
@@ -66,15 +66,11 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
 
         slice_row_ranges = None  # Row ID ranges for slice-based filtering
 
-        if self.start_pos_of_this_subtask is not None:
+        if self.start_pos_of_this_subtask is not None or self.idx_of_this_subtask is not None:
             # Calculate Row ID range for slice-based filtering
             slice_row_ranges = self._calculate_slice_row_ranges(partitioned_files)
             # Filter files by Row ID range
             partitioned_files = self._filter_files_by_row_ranges(partitioned_files, slice_row_ranges)
-        elif self.idx_of_this_subtask is not None:
-            partitioned_files = self._filter_by_shard(
-                partitioned_files, self.idx_of_this_subtask, self.number_of_para_subtasks
-            )
 
         def weight_func(file_list: List[DataFileMeta]) -> int:
             return max(sum(f.file_size for f in file_list), self.open_file_cost)
@@ -133,21 +129,14 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
             pack = packed_files[i] if i < len(packed_files) else []
             raw_convertible = all(len(sub_pack) == 1 for sub_pack in pack)
 
-            file_paths = []
-            total_file_size = 0
-            total_record_count = 0
-
             for data_file in file_group:
                 data_file.set_file_path(
                     self.table.table_path,
                     file_entries[0].partition,
                     file_entries[0].bucket
                 )
-                file_paths.append(data_file.file_path)
-                total_file_size += data_file.file_size
-                total_record_count += data_file.row_count
 
-            if file_paths:
+            if file_group:
                 # Get deletion files for this split
                 data_deletion_files = None
                 if self.deletion_files_map:
@@ -161,9 +150,6 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
                     files=file_group,
                     partition=file_entries[0].partition,
                     bucket=file_entries[0].bucket,
-                    file_paths=file_paths,
-                    row_count=total_record_count,
-                    file_size=total_file_size,
                     raw_convertible=raw_convertible,
                     data_deletion_files=data_deletion_files
                 )
@@ -196,13 +182,31 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
     def _divide_ranges_by_position(self, sorted_ranges: List[Range]) -> Tuple[Optional[Range], Optional[Range]]:
         """
         Divide ranges by position (start_pos, end_pos) to get the Row ID range for this slice.
+        If idx_of_this_subtask exists, divide total rows by number_of_para_subtasks.
         """
         if not sorted_ranges:
             return None, None
 
         total_row_count = sum(r.count() for r in sorted_ranges)
-        start_pos = self.start_pos_of_this_subtask
-        end_pos = self.end_pos_of_this_subtask
+        
+        # If idx_of_this_subtask exists, calculate start_pos and end_pos based on number_of_para_subtasks
+        if self.idx_of_this_subtask is not None:
+            # Calculate shard boundaries based on total row count
+            rows_per_task = total_row_count // self.number_of_para_subtasks
+            remainder = total_row_count % self.number_of_para_subtasks
+            
+            start_pos = self.idx_of_this_subtask * rows_per_task
+            # Distribute remainder rows across first 'remainder' tasks
+            if self.idx_of_this_subtask < remainder:
+                start_pos += self.idx_of_this_subtask
+                end_pos = start_pos + rows_per_task + 1
+            else:
+                start_pos += remainder
+                end_pos = start_pos + rows_per_task
+        else:
+            # Use existing start_pos and end_pos
+            start_pos = self.start_pos_of_this_subtask
+            end_pos = self.end_pos_of_this_subtask
 
         if start_pos >= total_row_count:
             return None, None
@@ -239,201 +243,86 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
     def _filter_files_by_row_ranges(partitioned_files: defaultdict, row_ranges: List[Range]) -> defaultdict:
         """
         Filter files by Row ID ranges. Keep files that overlap with the given ranges.
+        Blob files are only included if they overlap with non-blob files that match the ranges.
         """
         filtered_partitioned_files = defaultdict(list)
 
         for key, file_entries in partitioned_files.items():
-            filtered_entries = []
-
+            # Separate blob and non-blob files
+            non_blob_entries = []
+            blob_entries = []
+            
             for entry in file_entries:
+                if DataFileMeta.is_blob_file(entry.file.file_name):
+                    blob_entries.append(entry)
+                else:
+                    non_blob_entries.append(entry)
+            
+            # First, filter non-blob files based on row ranges
+            filtered_non_blob_entries = []
+            non_blob_ranges = []
+            for entry in non_blob_entries:
                 first_row_id = entry.file.first_row_id
                 file_range = Range(first_row_id, first_row_id + entry.file.row_count - 1)
-
+                
                 # Check if file overlaps with any of the row ranges
                 overlaps = False
                 for r in row_ranges:
                     if r.overlaps(file_range):
                         overlaps = True
                         break
-
+                
                 if overlaps:
-                    filtered_entries.append(entry)
-
+                    filtered_non_blob_entries.append(entry)
+                    non_blob_ranges.append(file_range)
+            
+            # Then, filter blob files based on row ID range of non-blob files
+            filtered_blob_entries = []
+            non_blob_ranges = Range.sort_and_merge_overlap(non_blob_ranges, True, True)
+            # Only keep blob files that overlap with merged non-blob ranges
+            for entry in blob_entries:
+                first_row_id = entry.file.first_row_id
+                blob_range = Range(first_row_id, first_row_id + entry.file.row_count - 1)
+                # Check if blob file overlaps with any merged range
+                for merged_range in non_blob_ranges:
+                    if merged_range.overlaps(blob_range):
+                        filtered_blob_entries.append(entry)
+                        break
+            
+            # Combine filtered non-blob and blob files
+            filtered_entries = filtered_non_blob_entries + filtered_blob_entries
+            
             if filtered_entries:
                 filtered_partitioned_files[key] = filtered_entries
 
         return filtered_partitioned_files
 
-    def _filter_by_shard(self, partitioned_files: defaultdict, sub_task_id: int, total_tasks: int) -> defaultdict:
-        list_ranges = []
-        for file_entries in partitioned_files.values():
-            for entry in file_entries:
-                first_row_id = entry.file.first_row_id
-                if first_row_id is None:
-                    raise ValueError("Found None first row id in files")
-                # Range is inclusive [from_, to], so use row_count - 1
-                list_ranges.append(Range(first_row_id, first_row_id + entry.file.row_count - 1))
-
-        sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
-
-        start_range, end_range = self._divide_ranges(sorted_ranges, sub_task_id, total_tasks)
-        if start_range is None or end_range is None:
-            return defaultdict(list)
-        start_first_row_id = start_range.from_
-        end_first_row_id = end_range.to
-
-        filtered_partitioned_files = {
-            k: [x for x in v if x.file.first_row_id >= start_first_row_id and x.file.first_row_id <= end_first_row_id]
-            for k, v in partitioned_files.items()
-        }
-
-        filtered_partitioned_files = {k: v for k, v in filtered_partitioned_files.items() if v}
-        return defaultdict(list, filtered_partitioned_files)
-
     @staticmethod
-    def _divide_ranges(
-        sorted_ranges: List[Range], sub_task_id: int, total_tasks: int
-    ) -> Tuple[Optional[Range], Optional[Range]]:
-        if not sorted_ranges:
-            return None, None
-
-        num_ranges = len(sorted_ranges)
-
-        # If more tasks than ranges, some tasks get nothing
-        if sub_task_id >= num_ranges:
-            return None, None
-
-        # Calculate balanced distribution of ranges across tasks
-        base_ranges_per_task = num_ranges // total_tasks
-        remainder = num_ranges % total_tasks
-
-        # Each of the first 'remainder' tasks gets one extra range
-        if sub_task_id < remainder:
-            num_ranges_for_task = base_ranges_per_task + 1
-            start_idx = sub_task_id * (base_ranges_per_task + 1)
-        else:
-            num_ranges_for_task = base_ranges_per_task
-            start_idx = (
-                remainder * (base_ranges_per_task + 1) +
-                (sub_task_id - remainder) * base_ranges_per_task
-            )
-        end_idx = start_idx + num_ranges_for_task - 1
-        return sorted_ranges[start_idx], sorted_ranges[end_idx]
-
-    def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
+    def _split_by_row_id(files: List[DataFileMeta]) -> List[List[DataFileMeta]]:
         """
         Split files by row ID for data evolution tables.
+        Files are grouped by their overlapping row ID ranges.
         """
-        split_by_row_id = []
-
-        # Filter blob files to only include those within the row ID range of non-blob files
-        sorted_files = self._filter_blob(files)
-
-        # Split files by firstRowId
-        last_row_id = -1
-        check_row_id_start = 0
-        current_split = []
-
-        for file in sorted_files:
+        list_ranges = []
+        for file in files:
             first_row_id = file.first_row_id
-            if first_row_id is None:
-                # Files without firstRowId are treated as individual splits
-                split_by_row_id.append([file])
-                continue
-
-            if not DataFileMeta.is_blob_file(file.file_name) and first_row_id != last_row_id:
-                if current_split:
-                    split_by_row_id.append(current_split)
-
-                # Validate that files don't overlap
-                if first_row_id < check_row_id_start:
-                    file_names = [f.file_name for f in sorted_files]
-                    raise ValueError(
-                        f"There are overlapping files in the split: 
{file_names}, "
-                        f"the wrong file is: {file.file_name}"
-                    )
-
-                current_split = []
-                last_row_id = first_row_id
-                check_row_id_start = first_row_id + file.row_count
+            list_ranges.append(Range(first_row_id, first_row_id + file.row_count - 1))
 
-            current_split.append(file)
-
-        if current_split:
-            split_by_row_id.append(current_split)
-
-        return split_by_row_id
-
-    def _compute_slice_split_file_idx_map(
-        self,
-        plan_start_pos: int,
-        plan_end_pos: int,
-        split: Split,
-        file_end_pos: int
-    ) -> Dict[str, Tuple[int, int]]:
-        """
-        Compute file index map for a split, determining which rows to read from each file.
-        For data files, the range is calculated based on the file's position in the cumulative row space.
-        For blob files (which may be rolled), the range is calculated based on each file's first_row_id.
-        """
-        shard_file_idx_map = {}
-        
-        # First pass: data files only. Compute range and apply directly to avoid second-pass lookup.
-        current_pos = file_end_pos
-        data_file_infos = []
-        for file in split.files:
-            if DataFileMeta.is_blob_file(file.file_name):
-                continue
-            file_begin_pos = current_pos
-            current_pos += file.row_count
-            data_file_range = self._compute_file_range(
-                plan_start_pos, plan_end_pos, file_begin_pos, file.row_count
-            )
-            data_file_infos.append((file, data_file_range))
-            if data_file_range is not None:
-                shard_file_idx_map[file.file_name] = data_file_range
-
-        if not data_file_infos:
-            # No data file, skip this split
-            shard_file_idx_map[self.NEXT_POS_KEY] = file_end_pos
-            return shard_file_idx_map
+        if not list_ranges:
+            return []
 
-        next_pos = current_pos
+        sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
 
-        # Second pass: only blob files (data files already in shard_file_idx_map from first pass)
-        for file in split.files:
-            if not DataFileMeta.is_blob_file(file.file_name):
-                continue
-            blob_first_row_id = file.first_row_id if file.first_row_id is not None else 0
-            data_file_range = None
-            data_file_first_row_id = None
-            for df, fr in data_file_infos:
-                df_first = df.first_row_id if df.first_row_id is not None else 0
-                if df_first <= blob_first_row_id < df_first + df.row_count:
-                    data_file_range = fr
-                    data_file_first_row_id = df_first
+        range_to_files = {}
+        for file in files:
+            first_row_id = file.first_row_id
+            file_range = Range(first_row_id, first_row_id + file.row_count - 1)
+            for r in sorted_ranges:
+                if r.overlaps(file_range):
+                    range_to_files.setdefault(r, []).append(file)
                     break
-            if data_file_range is None:
-                continue
-            if data_file_range == (-1, -1):
-                shard_file_idx_map[file.file_name] = (-1, -1)
-                continue
-            blob_rel_start = blob_first_row_id - data_file_first_row_id
-            blob_rel_end = blob_rel_start + file.row_count
-            shard_start, shard_end = data_file_range
-            intersect_start = max(blob_rel_start, shard_start)
-            intersect_end = min(blob_rel_end, shard_end)
-            if intersect_start >= intersect_end:
-                shard_file_idx_map[file.file_name] = (-1, -1)
-            elif intersect_start == blob_rel_start and intersect_end == blob_rel_end:
-                pass
-            else:
-                local_start = intersect_start - blob_rel_start
-                local_end = intersect_end - blob_rel_start
-                shard_file_idx_map[file.file_name] = (local_start, local_end)
 
-        shard_file_idx_map[self.NEXT_POS_KEY] = next_pos
-        return shard_file_idx_map
+        return list(range_to_files.values())
 
     def _wrap_to_indexed_splits(self, splits: List[Split], row_ranges: List[Range]) -> List[Split]:
         """
@@ -478,25 +367,3 @@ class DataEvolutionSplitGenerator(AbstractSplitGenerator):
             indexed_splits.append(IndexedSplit(split, expected, scores))
 
         return indexed_splits
-
-    @staticmethod
-    def _filter_blob(files: List[DataFileMeta]) -> List[DataFileMeta]:
-        """
-        Filter blob files to only include those within row ID range of non-blob files.
-        """
-        result = []
-        row_id_start = -1
-        row_id_end = -1
-
-        for file in files:
-            if not DataFileMeta.is_blob_file(file.file_name):
-                if file.first_row_id is not None:
-                    row_id_start = file.first_row_id
-                    row_id_end = file.first_row_id + file.row_count
-                result.append(file)
-            else:
-                if file.first_row_id is not None and row_id_start != -1:
-                    if row_id_start <= file.first_row_id < row_id_end:
-                        result.append(file)
-
-        return result
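
The rewritten _split_by_row_id above groups files by the merged row-id ranges
they overlap instead of walking firstRowId boundaries by hand. A rough
standalone sketch of that grouping, using a minimal stand-in for
pypaimon.globalindex.range.Range (inclusive bounds) and a simplified merge in
place of Range.sort_and_merge_overlap, whose exact flags are not reproduced
here:

    from collections import namedtuple
    from typing import List

    Range = namedtuple("Range", ["from_", "to"])  # stand-in, inclusive bounds

    def merge_overlapping(ranges: List[Range]) -> List[Range]:
        # Simplified merge of overlapping ranges, sorted by start position.
        merged = []
        for r in sorted(ranges, key=lambda x: x.from_):
            if merged and r.from_ <= merged[-1].to:
                merged[-1] = Range(merged[-1].from_, max(merged[-1].to, r.to))
            else:
                merged.append(r)
        return merged

    def group_files_by_row_range(files):
        # files: objects exposing first_row_id and row_count, e.g. DataFileMeta.
        ranges = [Range(f.first_row_id, f.first_row_id + f.row_count - 1) for f in files]
        merged = merge_overlapping(ranges)
        groups = {}
        for f, fr in zip(files, ranges):
            for m in merged:
                if m.from_ <= fr.to and fr.from_ <= m.to:  # inclusive overlap check
                    groups.setdefault(m, []).append(f)
                    break
        return list(groups.values())
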
diff --git a/paimon-python/pypaimon/read/scanner/split_generator.py b/paimon-python/pypaimon/read/scanner/split_generator.py
index f4f2cebbe7..f11240a8a9 100644
--- a/paimon-python/pypaimon/read/scanner/split_generator.py
+++ b/paimon-python/pypaimon/read/scanner/split_generator.py
@@ -98,21 +98,14 @@ class AbstractSplitGenerator(ABC):
             else:
                 raw_convertible = True
 
-            file_paths = []
-            total_file_size = 0
-            total_record_count = 0
-
             for data_file in file_group:
                 data_file.set_file_path(
                     self.table.table_path,
                     file_entries[0].partition,
                     file_entries[0].bucket
                 )
-                file_paths.append(data_file.file_path)
-                total_file_size += data_file.file_size
-                total_record_count += data_file.row_count
 
-            if file_paths:
+            if file_group:
                 # Get deletion files for this split
                 data_deletion_files = None
                 if self.deletion_files_map:
@@ -126,9 +119,6 @@ class AbstractSplitGenerator(ABC):
                     files=file_group,
                     partition=file_entries[0].partition,
                     bucket=file_entries[0].bucket,
-                    file_paths=file_paths,
-                    row_count=total_record_count,
-                    file_size=total_file_size,
                     raw_convertible=raw_convertible,
                     data_deletion_files=data_deletion_files
                 )
diff --git a/paimon-python/pypaimon/read/split.py b/paimon-python/pypaimon/read/split.py
index 12d20c0947..5bd63d8f52 100644
--- a/paimon-python/pypaimon/read/split.py
+++ b/paimon-python/pypaimon/read/split.py
@@ -17,7 +17,7 @@
 
################################################################################
 
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Callable
 
 from pypaimon.manifest.schema.data_file_meta import DataFileMeta
 from pypaimon.table.row.generic_row import GenericRow
@@ -77,18 +77,12 @@ class DataSplit(Split):
         files: List[DataFileMeta],
         partition: GenericRow,
         bucket: int,
-        file_paths: List[str],
-        row_count: int,
-        file_size: int,
         raw_convertible: bool = False,
         data_deletion_files: Optional[List[DeletionFile]] = None
     ):
         self._files = files
         self._partition = partition
         self._bucket = bucket
-        self._file_paths = file_paths
-        self._row_count = row_count
-        self._file_size = file_size
         self.raw_convertible = raw_convertible
         self.data_deletion_files = data_deletion_files
 
@@ -96,6 +90,40 @@ class DataSplit(Split):
     def files(self) -> List[DataFileMeta]:
         return self._files
 
+    def filter_file(self, func: Callable[[DataFileMeta], bool]) -> Optional['DataSplit']:
+        """
+        Filter files based on a predicate function and create a new DataSplit.
+        
+        Args:
+            func: A function that takes a DataFileMeta and returns True if the file should be kept
+        
+        Returns:
+            A new DataSplit with filtered files, adjusted data_deletion_files
+        """
+        # Filter files based on the predicate
+        filtered_files = [f for f in self._files if func(f)]
+        
+        # If no files match, return None
+        if not filtered_files:
+            return None
+        
+        # Find indices of filtered files to adjust data_deletion_files
+        filtered_indices = [i for i, f in enumerate(self._files) if func(f)]
+        
+        # Filter data_deletion_files to match filtered files
+        filtered_data_deletion_files = None
+        if self.data_deletion_files is not None:
+            filtered_data_deletion_files = [self.data_deletion_files[i] for i in filtered_indices]
+        
+        # Create new DataSplit with filtered data
+        return DataSplit(
+            files=filtered_files,
+            partition=self._partition,
+            bucket=self._bucket,
+            raw_convertible=self.raw_convertible,
+            data_deletion_files=filtered_data_deletion_files
+        )
+
     @property
     def partition(self) -> GenericRow:
         return self._partition
@@ -106,18 +134,18 @@ class DataSplit(Split):
 
     @property
     def row_count(self) -> int:
-        return self._row_count
+        """Calculate total row count from all files."""
+        return sum(f.row_count for f in self._files)
 
     @property
     def file_size(self) -> int:
-        return self._file_size
+        """Calculate total file size from all files."""
+        return sum(f.file_size for f in self._files)
 
     @property
     def file_paths(self) -> List[str]:
-        return self._file_paths
-
-    def set_row_count(self, row_count: int) -> None:
-        self._row_count = row_count
+        """Get file paths from all files."""
+        return [f.file_path for f in self._files if f.file_path is not None]
 
     def merged_row_count(self) -> Optional[int]:
         """
diff --git a/paimon-python/pypaimon/tests/blob_table_test.py b/paimon-python/pypaimon/tests/blob_table_test.py
index 7670ec9447..b1236d764b 100755
--- a/paimon-python/pypaimon/tests/blob_table_test.py
+++ b/paimon-python/pypaimon/tests/blob_table_test.py
@@ -1455,7 +1455,7 @@ class DataBlobWriterTest(unittest.TestCase):
         splits = table_scan.plan().splits()
         result = table_read.to_arrow(splits)
 
-        self.assertEqual(sum([s._row_count for s in splits]), 40 * 2)
+        self.assertEqual(sum([s.row_count for s in splits]), 40 * 2)
 
         # Verify the data
         self.assertEqual(result.num_rows, 40, "Should have 40 rows")
@@ -2204,12 +2204,12 @@ class DataBlobWriterTest(unittest.TestCase):
         result = table_read.to_arrow(table_scan.plan().splits())
 
         # Verify the data
-        self.assertEqual(result.num_rows, 80, "Should have 54 rows")
+        self.assertEqual(result.num_rows, 54, "Should have 54 rows")
         self.assertEqual(result.num_columns, 4, "Should have 4 columns")
 
         # Verify blob data integrity
         blob_data = result.column('large_blob').to_pylist()
-        self.assertEqual(len(blob_data), 80, "Should have 54 blob records")
+        self.assertEqual(len(blob_data), 54, "Should have 54 blob records")
         # Verify each blob
         for i, blob in enumerate(blob_data):
             self.assertEqual(len(blob), len(large_blob_data), f"Blob {i + 1} should be {large_blob_size:,} bytes")
@@ -2500,8 +2500,18 @@ class DataBlobWriterTest(unittest.TestCase):
         result = table_read.to_arrow(splits)
 
         # Verify the data was read back correctly
-        # Just one file, so split 0 occupied the whole records
-        self.assertEqual(result.num_rows, 5, "Should have 2 rows")
+        self.assertEqual(result.num_rows, 3, "Should have 3 rows")
+        self.assertEqual(result.num_columns, 3, "Should have 3 columns")
+
+        # Read data back using table API
+        read_builder = table.new_read_builder()
+        table_scan = read_builder.new_scan().with_shard(1, 2)
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result = table_read.to_arrow(splits)
+
+        # Verify the data was read back correctly
+        self.assertEqual(result.num_rows, 2, "Should have 2 rows")
         self.assertEqual(result.num_columns, 3, "Should have 3 columns")
 
     def test_blob_write_read_large_data_volume_rolling_with_shard(self):
@@ -2770,7 +2780,7 @@ class DataBlobWriterTest(unittest.TestCase):
         table_read = read_builder.new_read()
         splits = table_scan.plan().splits()
 
-        total_split_row_count = sum([s._row_count for s in splits])
+        total_split_row_count = sum([s.row_count for s in splits])
         self.assertEqual(total_split_row_count, num_rows * 2,
                          f"Total split row count should be {num_rows}, got 
{total_split_row_count}")
         
diff --git a/paimon-python/pypaimon/tests/data_evolution_test.py b/paimon-python/pypaimon/tests/data_evolution_test.py
index cfb09a0caf..9f95672a19 100644
--- a/paimon-python/pypaimon/tests/data_evolution_test.py
+++ b/paimon-python/pypaimon/tests/data_evolution_test.py
@@ -24,6 +24,7 @@ import pyarrow as pa
 
 from pypaimon import CatalogFactory, Schema
 from pypaimon.manifest.manifest_list_manager import ManifestListManager
+from pypaimon.read.read_builder import ReadBuilder
 from pypaimon.snapshot.snapshot_manager import SnapshotManager
 
 
@@ -238,6 +239,70 @@ class DataEvolutionTest(unittest.TestCase):
             % len(result_oob),
         )
 
+    def test_with_slice_partitioned_table(self):
+        pa_schema = pa.schema([
+            ("pt", pa.int64()),
+            ("b", pa.int32()),
+            ("c", pa.int32()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            partition_keys=["pt"],
+            options={
+                "row-tracking.enabled": "true",
+                "data-evolution.enabled": "true",
+                "source.split.target-size": "512m",
+            },
+        )
+        table_name = "default.test_with_slice_partitioned_table"
+        self.catalog.create_table(table_name, schema, ignore_if_exists=True)
+        table = self.catalog.get_table(table_name)
+
+        for batch in [
+            {"pt": [1, 1], "b": [10, 20], "c": [100, 200]},
+            {"pt": [2, 2], "b": [1011, 2011], "c": [1001, 2001]},
+            {"pt": [2, 2], "b": [-10, -20], "c": [-100, -200]},
+        ]:
+            wb = table.new_batch_write_builder()
+            tw = wb.new_write()
+            tc = wb.new_commit()
+            tw.write_arrow(pa.Table.from_pydict(batch, schema=pa_schema))
+            tc.commit(tw.prepare_commit())
+            tw.close()
+            tc.close()
+
+        rb: ReadBuilder = table.new_read_builder()
+        full_splits = rb.new_scan().plan().splits()
+        full_result = rb.new_read().to_pandas(full_splits)
+        self.assertEqual(
+            len(full_result),
+            6,
+            "Full scan should return 6 rows",
+        )
+
+        predicate_builder = rb.new_predicate_builder()
+        rb.with_filter(predicate_builder.equal("pt", 2))
+
+        # 0 to 2
+        scan_oob = rb.new_scan().with_slice(0, 2)
+        splits_oob = scan_oob.plan().splits()
+        result_oob = rb.new_read().to_pandas(splits_oob)
+        self.assertEqual(
+            sorted(result_oob["b"].tolist()),
+            [1011, 2011],
+            "Full set b mismatch",
+        )
+
+        # 2 to 4
+        scan_oob = rb.new_scan().with_slice(2, 4)
+        splits_oob = scan_oob.plan().splits()
+        result_oob = rb.new_read().to_pandas(splits_oob)
+        self.assertEqual(
+            sorted(result_oob["b"].tolist()),
+            [-20, -10],
+            "Full set b mismatch",
+        )
+
     def test_multiple_appends(self):
         simple_pa_schema = pa.schema([
             ('f0', pa.int32()),
diff --git a/paimon-python/pypaimon/write/table_update.py b/paimon-python/pypaimon/write/table_update.py
index 8e3f91bde4..747085a980 100644
--- a/paimon-python/pypaimon/write/table_update.py
+++ b/paimon-python/pypaimon/write/table_update.py
@@ -31,6 +31,61 @@ from pypaimon.write.writer.data_writer import DataWriter
 from pypaimon.write.writer.append_only_data_writer import AppendOnlyDataWriter
 
 
+def _filter_by_whole_file_shard(splits: List[DataSplit], sub_task_id: int, total_tasks: int) -> List[DataSplit]:
+    list_ranges = []
+    for split in splits:
+        for file in split.files:
+            first_row_id = file.first_row_id
+            list_ranges.append(Range(first_row_id, first_row_id + file.row_count - 1))
+
+    sorted_ranges = Range.sort_and_merge_overlap(list_ranges, True, False)
+
+    start_range, end_range = _divide_ranges(sorted_ranges, sub_task_id, total_tasks)
+    if start_range is None or end_range is None:
+        return []
+    start_first_row_id = start_range.from_
+    end_first_row_id = end_range.to
+
+    def filter_data_file(f: DataFileMeta) -> bool:
+        return start_first_row_id <= f.first_row_id <= end_first_row_id
+
+    filtered_splits = []
+
+    for split in splits:
+        split = split.filter_file(filter_data_file)
+        if split is not None:
+            filtered_splits.append(split)
+
+    return filtered_splits
+
+
+def _divide_ranges(
+        sorted_ranges: List[Range], sub_task_id: int, total_tasks: int
+) -> Tuple[Optional[Range], Optional[Range]]:
+    if not sorted_ranges:
+        return None, None
+
+    num_ranges = len(sorted_ranges)
+
+    # If more tasks than ranges, some tasks get nothing
+    if sub_task_id >= num_ranges:
+        return None, None
+
+    # Calculate balanced distribution of ranges across tasks
+    base_ranges_per_task = num_ranges // total_tasks
+    remainder = num_ranges % total_tasks
+
+    # Each of the first 'remainder' tasks gets one extra range
+    if sub_task_id < remainder:
+        num_ranges_for_task = base_ranges_per_task + 1
+        start_idx = sub_task_id * (base_ranges_per_task + 1)
+    else:
+        num_ranges_for_task = base_ranges_per_task
+        start_idx = (remainder * (base_ranges_per_task + 1) + (sub_task_id - remainder) * base_ranges_per_task)
+    end_idx = start_idx + num_ranges_for_task - 1
+    return sorted_ranges[start_idx], sorted_ranges[end_idx]
+
+
 class TableUpdate:
     def __init__(self, table, commit_user):
         from pypaimon.table.file_store_table import FileStoreTable
@@ -97,8 +152,10 @@ class ShardTableUpdator:
         self.writer: Optional[SingleWriter] = None
         self.dict = defaultdict(list)
 
-        scanner = self.table.new_read_builder().new_scan().with_shard(shard_num, total_shard_count)
-        self.splits = scanner.plan().splits()
+        scanner = self.table.new_read_builder().new_scan()
+        splits = scanner.plan().splits()
+        splits = _filter_by_whole_file_shard(splits, shard_num, total_shard_count)
+        self.splits = splits
 
         self.row_ranges: List[(Tuple, Range)] = []
         for split in self.splits:
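
For reference, the index arithmetic in _divide_ranges above spreads N merged
ranges over the subtasks as evenly as possible; a standalone sketch of the
same arithmetic (the helper name range_index_bounds is illustrative only):

    def range_index_bounds(num_ranges, sub_task_id, total_tasks):
        if sub_task_id >= num_ranges:
            return None  # more tasks than ranges: this task gets nothing
        base = num_ranges // total_tasks
        remainder = num_ranges % total_tasks
        if sub_task_id < remainder:
            start_idx = sub_task_id * (base + 1)
            count = base + 1
        else:
            start_idx = remainder * (base + 1) + (sub_task_id - remainder) * base
            count = base
        return start_idx, start_idx + count - 1

    # 7 ranges over 3 tasks -> task 0: (0, 2), task 1: (3, 4), task 2: (5, 6)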

Reply via email to