This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch release-1.3
in repository https://gitbox.apache.org/repos/asf/paimon.git

commit 909a165964f89be1f2614a0ef64dc9449eaf9684
Author: umi <[email protected]>
AuthorDate: Tue Sep 23 11:28:04 2025 +0800

    [python] Support reading data by splitting according to rows (#6274)
---
 .../pypaimon/manifest/manifest_file_manager.py     |   4 +-
 paimon-python/pypaimon/read/plan.py                |   6 +-
 .../pypaimon/read/reader/concat_batch_reader.py    |  27 ++
 paimon-python/pypaimon/read/split.py               |   2 +
 paimon-python/pypaimon/read/split_read.py          |   7 +-
 paimon-python/pypaimon/read/table_scan.py          | 181 ++++++---
 .../pypaimon/tests/py36/ao_simple_test.py          | 332 ++++++++++++++++
 .../pypaimon/tests/rest/rest_simple_test.py        | 419 +++++++++++++++++++--
 paimon-python/pypaimon/write/file_store_commit.py  |   2 +-
 9 files changed, 898 insertions(+), 82 deletions(-)
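
The new tests in this commit exercise the feature through the public read path. As a minimal usage sketch (not part of the commit), assuming an already-initialized pypaimon catalog such as the REST catalog the tests set up, and a placeholder table name:

    import pyarrow as pa

    # `catalog` is assumed to be an initialized pypaimon catalog (for example the
    # REST catalog created by RESTBaseTest); 'default.my_table' is a placeholder.
    table = catalog.get_table('default.my_table')
    read_builder = table.new_read_builder()
    table_read = read_builder.new_read()

    num_subtasks = 3
    shards = []
    for idx in range(num_subtasks):
        # with_shard(idx, n) assigns roughly 1/n of the rows to subtask idx;
        # the last subtask also picks up the remainder.
        splits = read_builder.new_scan().with_shard(idx, num_subtasks).plan().splits()
        shards.append(table_read.to_arrow(splits))

    # Concatenating all shards reproduces the full table contents.
    full_table = pa.concat_tables(shards)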

diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index aec8bc7ed0..f4b0ab0be3 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -40,7 +40,7 @@ class ManifestFileManager:
         self.primary_key_fields = self.table.table_schema.get_primary_key_fields()
         self.trimmed_primary_key_fields = self.table.table_schema.get_trimmed_primary_key_fields()
 
-    def read(self, manifest_file_name: str, shard_filter=None) -> List[ManifestEntry]:
+    def read(self, manifest_file_name: str, bucket_filter=None) -> List[ManifestEntry]:
         manifest_file_path = self.manifest_path / manifest_file_name
 
         entries = []
@@ -97,7 +97,7 @@ class ManifestFileManager:
                 total_buckets=record['_TOTAL_BUCKETS'],
                 file=file_meta
             )
-            if shard_filter is not None and not shard_filter(entry):
+            if bucket_filter is not None and not bucket_filter(entry):
                 continue
             entries.append(entry)
         return entries
diff --git a/paimon-python/pypaimon/read/plan.py b/paimon-python/pypaimon/read/plan.py
index 9a65fd6f12..8c69a41a9b 100644
--- a/paimon-python/pypaimon/read/plan.py
+++ b/paimon-python/pypaimon/read/plan.py
@@ -19,18 +19,14 @@
 from dataclasses import dataclass
 from typing import List
 
-from pypaimon.manifest.schema.manifest_entry import ManifestEntry
+
 from pypaimon.read.split import Split
 
 
 @dataclass
 class Plan:
     """Implementation of Plan for native Python reading."""
-    _files: List[ManifestEntry]
     _splits: List[Split]
 
-    def files(self) -> List[ManifestEntry]:
-        return self._files
-
     def splits(self) -> List[Split]:
         return self._splits
diff --git a/paimon-python/pypaimon/read/reader/concat_batch_reader.py b/paimon-python/pypaimon/read/reader/concat_batch_reader.py
index 76b9f10c71..a5a596e1ea 100644
--- a/paimon-python/pypaimon/read/reader/concat_batch_reader.py
+++ b/paimon-python/pypaimon/read/reader/concat_batch_reader.py
@@ -49,3 +49,30 @@ class ConcatBatchReader(RecordBatchReader):
             self.current_reader.close()
             self.current_reader = None
         self.queue.clear()
+
+
+class ShardBatchReader(ConcatBatchReader):
+
+    def __init__(self, readers, split_start_row, split_end_row):
+        super().__init__(readers)
+        self.split_start_row = split_start_row
+        self.split_end_row = split_end_row
+        self.cur_end = 0
+
+    def read_arrow_batch(self) -> Optional[RecordBatch]:
+        batch = super().read_arrow_batch()
+        if batch is None:
+            return None
+        if self.split_start_row is not None or self.split_end_row is not None:
+            cur_begin = self.cur_end  # begin idx of current batch based on the split
+            self.cur_end += batch.num_rows
+            # shard the first batch and the last batch
+            if self.split_start_row <= cur_begin < self.cur_end <= self.split_end_row:
+                return batch
+            elif cur_begin <= self.split_start_row < self.cur_end:
+                return batch.slice(self.split_start_row - cur_begin,
+                                   min(self.split_end_row, self.cur_end) - self.split_start_row)
+            elif cur_begin < self.split_end_row <= self.cur_end:
+                return batch.slice(0, self.split_end_row - cur_begin)
+        else:
+            return batch
diff --git a/paimon-python/pypaimon/read/split.py b/paimon-python/pypaimon/read/split.py
index 9b802d9880..f1ab5f3a5b 100644
--- a/paimon-python/pypaimon/read/split.py
+++ b/paimon-python/pypaimon/read/split.py
@@ -32,6 +32,8 @@ class Split:
     _file_paths: List[str]
     _row_count: int
     _file_size: int
+    split_start_row: int = None
+    split_end_row: int = None
     raw_convertible: bool = False
 
     @property
diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py
index 1fe0a89d0e..d9e9939c11 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -24,7 +24,7 @@ from typing import List, Optional, Tuple
 from pypaimon.common.predicate import Predicate
 from pypaimon.read.interval_partition import IntervalPartition, SortedRun
 from pypaimon.read.partition_info import PartitionInfo
-from pypaimon.read.reader.concat_batch_reader import ConcatBatchReader
+from pypaimon.read.reader.concat_batch_reader import ConcatBatchReader, ShardBatchReader
 from pypaimon.read.reader.concat_record_reader import ConcatRecordReader
 from pypaimon.read.reader.data_file_record_reader import DataFileBatchReader
 from pypaimon.read.reader.drop_delete_reader import DropDeleteRecordReader
@@ -249,7 +249,10 @@ class RawFileSplitRead(SplitRead):
 
         if not data_readers:
             return EmptyFileRecordReader()
-        concat_reader = ConcatBatchReader(data_readers)
+        if self.split.split_start_row is not None:
+            concat_reader = ShardBatchReader(data_readers, self.split.split_start_row, self.split.split_end_row)
+        else:
+            concat_reader = ConcatBatchReader(data_readers)
         # if the table is appendonly table, we don't need extra filter, all predicates has pushed down
         if self.table.is_primary_key_table and self.predicate:
             return FilterRecordReader(concat_reader, self.predicate)
diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py
index 0b9f97db4f..6a6ab9f3f8 100644
--- a/paimon-python/pypaimon/read/table_scan.py
+++ b/paimon-python/pypaimon/read/table_scan.py
@@ -33,7 +33,6 @@ from pypaimon.read.split import Split
 from pypaimon.schema.data_types import DataField
 from pypaimon.snapshot.snapshot_manager import SnapshotManager
 from pypaimon.table.bucket_mode import BucketMode
-from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor
 
 
 class TableScan:
@@ -71,9 +70,21 @@ class TableScan:
             self.table.options.get('bucket', -1)) == BucketMode.POSTPONE_BUCKET.value else False
 
     def plan(self) -> Plan:
+        file_entries = self.plan_files()
+        if not file_entries:
+            return Plan([])
+        if self.table.is_primary_key_table:
+            splits = self._create_primary_key_splits(file_entries)
+        else:
+            splits = self._create_append_only_splits(file_entries)
+
+        splits = self._apply_push_down_limit(splits)
+        return Plan(splits)
+
+    def plan_files(self) -> List[ManifestEntry]:
         latest_snapshot = self.snapshot_manager.get_latest_snapshot()
         if not latest_snapshot:
-            return Plan([], [])
+            return []
         manifest_files = self.manifest_list_manager.read_all(latest_snapshot)
 
         deleted_entries = set()
@@ -92,40 +103,98 @@ class TableScan:
             entry for entry in added_entries
             if (tuple(entry.partition.values), entry.bucket, entry.file.file_name) not in deleted_entries
         ]
-
         if self.predicate:
             file_entries = self._filter_by_predicate(file_entries)
-
-        partitioned_split = defaultdict(list)
-        for entry in file_entries:
-            partitioned_split[(tuple(entry.partition.values), entry.bucket)].append(entry)
-
-        splits = []
-        for key, values in partitioned_split.items():
-            if self.table.is_primary_key_table:
-                splits += self._create_primary_key_splits(values)
-            else:
-                splits += self._create_append_only_splits(values)
-
-        splits = self._apply_push_down_limit(splits)
-
-        return Plan(file_entries, splits)
+        return file_entries
 
     def with_shard(self, idx_of_this_subtask, number_of_para_subtasks) -> 'TableScan':
+        if idx_of_this_subtask >= number_of_para_subtasks:
+            raise Exception("idx_of_this_subtask must be less than 
number_of_para_subtasks")
         self.idx_of_this_subtask = idx_of_this_subtask
         self.number_of_para_subtasks = number_of_para_subtasks
         return self
 
+    def _append_only_filter_by_shard(self, partitioned_files: defaultdict) -> (defaultdict, int, int):
+        total_row = 0
+        # Sort by file creation time to ensure consistent sharding
+        for key, file_entries in partitioned_files.items():
+            for entry in file_entries:
+                total_row += entry.file.row_count
+
+        # Calculate number of rows this shard should process
+        # Last shard handles all remaining rows (handles non-divisible cases)
+        if self.idx_of_this_subtask == self.number_of_para_subtasks - 1:
+            num_row = total_row - total_row // self.number_of_para_subtasks * self.idx_of_this_subtask
+        else:
+            num_row = total_row // self.number_of_para_subtasks
+        # Calculate start row and end row position for current shard in all data
+        start_row = self.idx_of_this_subtask * (total_row // self.number_of_para_subtasks)
+        end_row = start_row + num_row
+
+        plan_start_row = 0
+        plan_end_row = 0
+        entry_end_row = 0  # end row position of current file in all data
+        splits_start_row = 0
+        filtered_partitioned_files = defaultdict(list)
+        # Iterate through all file entries to find files that overlap with current shard range
+        for key, file_entries in partitioned_files.items():
+            filtered_entries = []
+            for entry in file_entries:
+                entry_begin_row = entry_end_row  # Starting row position of current file in all data
+                entry_end_row += entry.file.row_count  # Update to row position after current file
+
+                # If current file is completely after shard range, stop iteration
+                if entry_begin_row >= end_row:
+                    break
+                # If current file is completely before shard range, skip it
+                if entry_end_row <= start_row:
+                    continue
+                if entry_begin_row <= start_row < entry_end_row:
+                    splits_start_row = entry_begin_row
+                    plan_start_row = start_row - entry_begin_row
+                # If shard end position is within current file, record relative end position
+                if entry_begin_row < end_row <= entry_end_row:
+                    plan_end_row = end_row - splits_start_row
+                # Add files that overlap with shard range to result
+                filtered_entries.append(entry)
+            if filtered_entries:
+                filtered_partitioned_files[key] = filtered_entries
+
+        return filtered_partitioned_files, plan_start_row, plan_end_row
+
+    def _compute_split_start_end_row(self, splits: List[Split], plan_start_row, plan_end_row):
+        file_end_row = 0  # end row position of current file in all data
+        for split in splits:
+            files = split.files
+            split_start_row = file_end_row
+            # Iterate through all file entries to find files that overlap with current shard range
+            for file in files:
+                file_begin_row = file_end_row  # Starting row position of current file in all data
+                file_end_row += file.row_count  # Update to row position after current file
+
+                # If shard start position is within current file, record actual start position and relative offset
+                if file_begin_row <= plan_start_row < file_end_row:
+                    split.split_start_row = plan_start_row - file_begin_row
+
+                # If shard end position is within current file, record relative end position
+                if file_begin_row < plan_end_row <= file_end_row:
+                    split.split_end_row = plan_end_row - split_start_row
+            if split.split_start_row is None:
+                split.split_start_row = 0
+            if split.split_end_row is None:
+                split.split_end_row = split.row_count
+
+    def _primary_key_filter_by_shard(self, file_entries: List[ManifestEntry]) -> List[ManifestEntry]:
+        filtered_entries = []
+        for entry in file_entries:
+            if entry.bucket % self.number_of_para_subtasks == self.idx_of_this_subtask:
+                filtered_entries.append(entry)
+        return filtered_entries
+
     def _bucket_filter(self, entry: Optional[ManifestEntry]) -> bool:
         bucket = entry.bucket
         if self.only_read_real_buckets and bucket < 0:
             return False
-        if self.idx_of_this_subtask is not None:
-            if self.table.is_primary_key_table:
-                return bucket % self.number_of_para_subtasks == self.idx_of_this_subtask
-            else:
-                file = entry.file.file_name
-                return FixedBucketRowKeyExtractor.hash(file) % self.number_of_para_subtasks == self.idx_of_this_subtask
         return True
 
     def _apply_push_down_limit(self, splits: List[Split]) -> List[Split]:
@@ -185,38 +254,60 @@ class TableScan:
         })
 
     def _create_append_only_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
-        if not file_entries:
-            return []
+        partitioned_files = defaultdict(list)
+        for entry in file_entries:
+            partitioned_files[(tuple(entry.partition.values), entry.bucket)].append(entry)
 
-        data_files: List[DataFileMeta] = [e.file for e in file_entries]
+        if self.idx_of_this_subtask is not None:
+            partitioned_files, plan_start_row, plan_end_row = self._append_only_filter_by_shard(partitioned_files)
 
         def weight_func(f: DataFileMeta) -> int:
             return max(f.file_size, self.open_file_cost)
 
-        packed_files: List[List[DataFileMeta]] = self._pack_for_ordered(data_files, weight_func, self.target_split_size)
-        return self._build_split_from_pack(packed_files, file_entries, False)
+        splits = []
+        for key, file_entries in partitioned_files.items():
+            if not file_entries:
+                return []
 
-    def _create_primary_key_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
-        if not file_entries:
-            return []
+            data_files: List[DataFileMeta] = [e.file for e in file_entries]
 
-        data_files: List[DataFileMeta] = [e.file for e in file_entries]
-        partition_sort_runs: List[List[SortedRun]] = IntervalPartition(data_files).partition()
-        sections: List[List[DataFileMeta]] = [
-            [file for s in sl for file in s.files]
-            for sl in partition_sort_runs
-        ]
+            packed_files: List[List[DataFileMeta]] = self._pack_for_ordered(data_files, weight_func,
+                                                                            self.target_split_size)
+            splits += self._build_split_from_pack(packed_files, file_entries, False)
+        if self.idx_of_this_subtask is not None:
+            self._compute_split_start_end_row(splits, plan_start_row, plan_end_row)
+        return splits
+
+    def _create_primary_key_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
+        if self.idx_of_this_subtask is not None:
+            file_entries = self._primary_key_filter_by_shard(file_entries)
+        partitioned_files = defaultdict(list)
+        for entry in file_entries:
+            partitioned_files[(tuple(entry.partition.values), entry.bucket)].append(entry)
 
         def weight_func(fl: List[DataFileMeta]) -> int:
             return max(sum(f.file_size for f in fl), self.open_file_cost)
 
-        packed_files: List[List[List[DataFileMeta]]] = self._pack_for_ordered(sections, weight_func,
-                                                                              self.target_split_size)
-        flatten_packed_files: List[List[DataFileMeta]] = [
-            [file for sub_pack in pack for file in sub_pack]
-            for pack in packed_files
-        ]
-        return self._build_split_from_pack(flatten_packed_files, file_entries, True)
+        splits = []
+        for key, file_entries in partitioned_files.items():
+            if not file_entries:
+                return []
+
+            data_files: List[DataFileMeta] = [e.file for e in file_entries]
+            partition_sort_runs: List[List[SortedRun]] = IntervalPartition(data_files).partition()
+            sections: List[List[DataFileMeta]] = [
+                [file for s in sl for file in s.files]
+                for sl in partition_sort_runs
+            ]
+
+            packed_files: List[List[List[DataFileMeta]]] = self._pack_for_ordered(sections, weight_func,
+                                                                                  self.target_split_size)
+            flatten_packed_files: List[List[DataFileMeta]] = [
+                [file for sub_pack in pack for file in sub_pack]
+                for pack in packed_files
+            ]
+            splits += self._build_split_from_pack(flatten_packed_files, file_entries, True)
+        return splits
 
     def _build_split_from_pack(self, packed_files, file_entries, for_primary_key_split: bool) -> List['Split']:
         splits = []
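
The shard arithmetic above can be summarized in a small standalone sketch (the function name and signature are illustrative only): each of n subtasks reads total_row // n rows, and the last subtask also absorbs the remainder, which matches the 2/2/3 distribution the new uneven-distribution tests expect for 7 rows and 3 subtasks.

    def shard_row_range(total_row: int, idx: int, n: int) -> range:
        # Row range [start, end) handled by subtask `idx` out of `n`,
        # mirroring _append_only_filter_by_shard in table_scan.py.
        base = total_row // n
        start_row = idx * base
        num_row = total_row - base * idx if idx == n - 1 else base
        return range(start_row, start_row + num_row)

    assert [len(shard_row_range(7, i, 3)) for i in range(3)] == [2, 2, 3]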
diff --git a/paimon-python/pypaimon/tests/py36/ao_simple_test.py b/paimon-python/pypaimon/tests/py36/ao_simple_test.py
new file mode 100644
index 0000000000..17ebf58be7
--- /dev/null
+++ b/paimon-python/pypaimon/tests/py36/ao_simple_test.py
@@ -0,0 +1,332 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import pyarrow as pa
+
+from pypaimon import Schema
+from pypaimon.tests.py36.pyarrow_compat import table_sort_by
+from pypaimon.tests.rest.rest_base_test import RESTBaseTest
+
+
+class AOSimpleTest(RESTBaseTest):
+    def setUp(self):
+        super().setUp()
+        self.pa_schema = pa.schema([
+            ('user_id', pa.int64()),
+            ('item_id', pa.int64()),
+            ('behavior', pa.string()),
+            ('dt', pa.string()),
+        ])
+        self.data = {
+            'user_id': [2, 4, 6, 8, 10],
+            'item_id': [1001, 1002, 1003, 1004, 1005],
+            'behavior': ['a', 'b', 'c', 'd', 'e'],
+            'dt': ['2000-10-10', '2025-08-10', '2025-08-11', '2025-08-12', '2025-08-13']
+        }
+        self.expected = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+
+    def test_with_shard_ao_unaware_bucket(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket', schema, False)
+        table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket')
+        write_builder = table.new_batch_write_builder()
+        # first write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014],
+            'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'],
+            'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+        # second write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'user_id': [5, 6, 7, 8, 18],
+            'item_id': [1005, 1006, 1007, 1008, 1018],
+            'behavior': ['e', 'f', 'g', 'h', 'z'],
+            'dt': ['p2', 'p1', 'p2', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().with_shard(2, 3).plan().splits()
+        actual = table_sort_by(table_read.to_arrow(splits), 'user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [5, 7, 7, 8, 9, 11, 13],
+            'item_id': [1005, 1007, 1007, 1008, 1009, 1011, 1013],
+            'behavior': ['e', 'f', 'g', 'h', 'h', 'j', 'l'],
+            'dt': ['p2', 'p2', 'p2', 'p2', 'p2', 'p2', 'p2'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+        # Get the three actual tables
+        splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
+        actual1 = table_sort_by(table_read.to_arrow(splits1), 'user_id')
+        splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits()
+        actual2 = table_sort_by(table_read.to_arrow(splits2), 'user_id')
+        splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
+        actual3 = table_sort_by(table_read.to_arrow(splits3), 'user_id')
+
+        # Concatenate the three tables
+        actual = table_sort_by(pa.concat_tables([actual1, actual2, actual3]), 'user_id')
+        expected = table_sort_by(self._read_test_table(read_builder), 'user_id')
+        self.assertEqual(actual, expected)
+
+    def test_with_shard_ao_fixed_bucket(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'],
+                                            options={'bucket': '5', 'bucket-key': 'item_id'})
+        self.rest_catalog.create_table('default.test_with_slice_ao_fixed_bucket', schema, False)
+        table = self.rest_catalog.get_table('default.test_with_slice_ao_fixed_bucket')
+        write_builder = table.new_batch_write_builder()
+        # first write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014],
+            'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'],
+            'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+        # second write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'user_id': [5, 6, 7, 8],
+            'item_id': [1005, 1006, 1007, 1008],
+            'behavior': ['e', 'f', 'g', 'h'],
+            'dt': ['p2', 'p1', 'p2', 'p2'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().with_shard(0, 3).plan().splits()
+        actual = table_sort_by(table_read.to_arrow(splits), 'user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [1, 2, 3, 5, 8, 12],
+            'item_id': [1001, 1002, 1003, 1005, 1008, 1012],
+            'behavior': ['a', 'b', 'c', 'd', 'g', 'k'],
+            'dt': ['p1', 'p1', 'p2', 'p2', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+        # Get the three actual tables
+        splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
+        actual1 = table_sort_by(table_read.to_arrow(splits1), 'user_id')
+        splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits()
+        actual2 = table_sort_by(table_read.to_arrow(splits2), 'user_id')
+        splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
+        actual3 = table_sort_by(table_read.to_arrow(splits3), 'user_id')
+
+        # Concatenate the three tables
+        actual = table_sort_by(pa.concat_tables([actual1, actual2, actual3]), 'user_id')
+        expected = table_sort_by(self._read_test_table(read_builder), 'user_id')
+        self.assertEqual(actual, expected)
+
+    def test_shard_single_partition(self):
+        """Test sharding with single partition - tests _filter_by_shard with 
simple data"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, 
partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_single_partition', 
schema, False)
+        table = 
self.rest_catalog.get_table('default.test_shard_single_partition')
+        write_builder = table.new_batch_write_builder()
+
+        # Write data with single partition
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4, 5, 6],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006],
+            'behavior': ['a', 'b', 'c', 'd', 'e', 'f'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        # Test first shard (0, 2) - should get first 3 rows
+        splits = read_builder.new_scan().with_shard(0, 2).plan().splits()
+        actual = table_sort_by(table_read.to_arrow(splits), 'user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [1, 2, 3],
+            'item_id': [1001, 1002, 1003],
+            'behavior': ['a', 'b', 'c'],
+            'dt': ['p1', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+        # Test second shard (1, 2) - should get last 3 rows
+        splits = read_builder.new_scan().with_shard(1, 2).plan().splits()
+        actual = table_sort_by(table_read.to_arrow(splits), 'user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [4, 5, 6],
+            'item_id': [1004, 1005, 1006],
+            'behavior': ['d', 'e', 'f'],
+            'dt': ['p1', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+    def test_shard_uneven_distribution(self):
+        """Test sharding with uneven row distribution across shards"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_uneven', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_uneven')
+        write_builder = table.new_batch_write_builder()
+
+        # Write data with 7 rows (not evenly divisible by 3)
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4, 5, 6, 7],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007],
+            'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        # Test sharding into 3 parts: 2, 2, 3 rows
+        splits = read_builder.new_scan().with_shard(0, 3).plan().splits()
+        actual1 = table_sort_by(table_read.to_arrow(splits), 'user_id')
+        expected1 = pa.Table.from_pydict({
+            'user_id': [1, 2],
+            'item_id': [1001, 1002],
+            'behavior': ['a', 'b'],
+            'dt': ['p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual1, expected1)
+
+        splits = read_builder.new_scan().with_shard(1, 3).plan().splits()
+        actual2 = table_sort_by(table_read.to_arrow(splits), 'user_id')
+        expected2 = pa.Table.from_pydict({
+            'user_id': [3, 4],
+            'item_id': [1003, 1004],
+            'behavior': ['c', 'd'],
+            'dt': ['p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual2, expected2)
+
+        splits = read_builder.new_scan().with_shard(2, 3).plan().splits()
+        actual3 = table_sort_by(table_read.to_arrow(splits), 'user_id')
+        expected3 = pa.Table.from_pydict({
+            'user_id': [5, 6, 7],
+            'item_id': [1005, 1006, 1007],
+            'behavior': ['e', 'f', 'g'],
+            'dt': ['p1', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual3, expected3)
+
+    def test_shard_many_small_shards(self):
+        """Test sharding with many small shards"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_many_small', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_many_small')
+        write_builder = table.new_batch_write_builder()
+
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4, 5, 6],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006],
+            'behavior': ['a', 'b', 'c', 'd', 'e', 'f'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        # Test with 6 shards (one row per shard)
+        for i in range(6):
+            splits = read_builder.new_scan().with_shard(i, 6).plan().splits()
+            actual = table_read.to_arrow(splits)
+            self.assertEqual(len(actual), 1)
+            self.assertEqual(actual['user_id'][0].as_py(), i + 1)
+
+    def test_shard_boundary_conditions(self):
+        """Test sharding boundary conditions with edge cases"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_boundary', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_boundary')
+        write_builder = table.new_batch_write_builder()
+
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4, 5],
+            'item_id': [1001, 1002, 1003, 1004, 1005],
+            'behavior': ['a', 'b', 'c', 'd', 'e'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        # Test first shard (0, 4) - should get 1 row (5//4 = 1)
+        splits = read_builder.new_scan().with_shard(0, 4).plan().splits()
+        actual = table_read.to_arrow(splits)
+        self.assertEqual(len(actual), 1)
+
+        # Test middle shard (1, 4) - should get 1 row
+        splits = read_builder.new_scan().with_shard(1, 4).plan().splits()
+        actual = table_read.to_arrow(splits)
+        self.assertEqual(len(actual), 1)
+
+        # Test last shard (3, 4) - should get 2 rows (remainder goes to last shard)
+        splits = read_builder.new_scan().with_shard(3, 4).plan().splits()
+        actual = table_read.to_arrow(splits)
+        self.assertEqual(len(actual), 2)
diff --git a/paimon-python/pypaimon/tests/rest/rest_simple_test.py b/paimon-python/pypaimon/tests/rest/rest_simple_test.py
index 95a20345b0..03685c317a 100644
--- a/paimon-python/pypaimon/tests/rest/rest_simple_test.py
+++ b/paimon-python/pypaimon/tests/rest/rest_simple_test.py
@@ -22,9 +22,7 @@ import pyarrow as pa
 
 from pypaimon import Schema
 from pypaimon.tests.rest.rest_base_test import RESTBaseTest
-from pypaimon.write.row_key_extractor import (DynamicBucketRowKeyExtractor,
-                                              FixedBucketRowKeyExtractor,
-                                              UnawareBucketRowKeyExtractor)
+from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor, DynamicBucketRowKeyExtractor
 
 
 class RESTSimpleTest(RESTBaseTest):
@@ -45,58 +43,425 @@ class RESTSimpleTest(RESTBaseTest):
         self.expected = pa.Table.from_pydict(self.data, schema=self.pa_schema)
 
     def test_with_shard_ao_unaware_bucket(self):
-        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'])
-        self.rest_catalog.create_table('default.test_with_shard', schema, False)
-        table = self.rest_catalog.get_table('default.test_with_shard')
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket', schema, False)
+        table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket')
+        write_builder = table.new_batch_write_builder()
+        # first write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014],
+            'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'],
+            'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+        # second write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'user_id': [5, 6, 7, 8, 18],
+            'item_id': [1005, 1006, 1007, 1008, 1018],
+            'behavior': ['e', 'f', 'g', 'h', 'z'],
+            'dt': ['p2', 'p1', 'p2', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().with_shard(2, 3).plan().splits()
+        actual = table_read.to_arrow(splits).sort_by('user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [5, 7, 7, 8, 9, 11, 13],
+            'item_id': [1005, 1007, 1007, 1008, 1009, 1011, 1013],
+            'behavior': ['e', 'f', 'g', 'h', 'h', 'j', 'l'],
+            'dt': ['p2', 'p2', 'p2', 'p2', 'p2', 'p2', 'p2'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+        # Get the three actual tables
+        splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
+        actual1 = table_read.to_arrow(splits1).sort_by('user_id')
+        splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits()
+        actual2 = table_read.to_arrow(splits2).sort_by('user_id')
+        splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
+        actual3 = table_read.to_arrow(splits3).sort_by('user_id')
+
+        # Concatenate the three tables
+        actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('user_id')
+        expected = self._read_test_table(read_builder).sort_by('user_id')
+        self.assertEqual(actual, expected)
 
+    def test_with_shard_ao_fixed_bucket(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'],
+                                            options={'bucket': '5', 'bucket-key': 'item_id'})
+        self.rest_catalog.create_table('default.test_with_slice_ao_fixed_bucket', schema, False)
+        table = self.rest_catalog.get_table('default.test_with_slice_ao_fixed_bucket')
         write_builder = table.new_batch_write_builder()
+        # first write
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014],
+            'behavior': ['a', 'b', 'c', None, 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm'],
+            'dt': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+        # second write
         table_write = write_builder.new_write()
         table_commit = write_builder.new_commit()
-        self.assertIsInstance(table_write.row_key_extractor, UnawareBucketRowKeyExtractor)
+        data2 = {
+            'user_id': [5, 6, 7, 8],
+            'item_id': [1005, 1006, 1007, 1008],
+            'behavior': ['e', 'f', 'g', 'h'],
+            'dt': ['p2', 'p1', 'p2', 'p2'],
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
 
-        pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+        splits = read_builder.new_scan().with_shard(0, 3).plan().splits()
+        actual = table_read.to_arrow(splits).sort_by('user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [1, 2, 3, 5, 8, 12],
+            'item_id': [1001, 1002, 1003, 1005, 1008, 1012],
+            'behavior': ['a', 'b', 'c', 'd', 'g', 'k'],
+            'dt': ['p1', 'p1', 'p2', 'p2', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+        # Get the three actual tables
+        splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
+        actual1 = table_read.to_arrow(splits1).sort_by('user_id')
+        splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits()
+        actual2 = table_read.to_arrow(splits2).sort_by('user_id')
+        splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
+        actual3 = table_read.to_arrow(splits3).sort_by('user_id')
+
+        # Concatenate the three tables
+        actual = pa.concat_tables([actual1, actual2, actual3]).sort_by('user_id')
+        expected = self._read_test_table(read_builder).sort_by('user_id')
+        self.assertEqual(actual, expected)
+
+    def test_shard_single_partition(self):
+        """Test sharding with single partition - tests _filter_by_shard with 
simple data"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, 
partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_single_partition', 
schema, False)
+        table = 
self.rest_catalog.get_table('default.test_shard_single_partition')
+        write_builder = table.new_batch_write_builder()
+
+        # Write data with single partition
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4, 5, 6],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006],
+            'behavior': ['a', 'b', 'c', 'd', 'e', 'f'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
         table_write.write_arrow(pa_table)
         table_commit.commit(table_write.prepare_commit())
         table_write.close()
         table_commit.close()
 
-        splits = []
         read_builder = table.new_read_builder()
-        splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits())
-        splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits())
-        splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits())
+        table_read = read_builder.new_read()
 
+        # Test first shard (0, 2) - should get first 3 rows
+        plan = read_builder.new_scan().with_shard(0, 2).plan()
+        actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [1, 2, 3],
+            'item_id': [1001, 1002, 1003],
+            'behavior': ['a', 'b', 'c'],
+            'dt': ['p1', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+        # Test second shard (1, 2) - should get last 3 rows
+        plan = read_builder.new_scan().with_shard(1, 2).plan()
+        actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+        expected = pa.Table.from_pydict({
+            'user_id': [4, 5, 6],
+            'item_id': [1004, 1005, 1006],
+            'behavior': ['d', 'e', 'f'],
+            'dt': ['p1', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+    def test_shard_uneven_distribution(self):
+        """Test sharding with uneven row distribution across shards"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_uneven', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_uneven')
+        write_builder = table.new_batch_write_builder()
+
+        # Write data with 7 rows (not evenly divisible by 3)
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4, 5, 6, 7],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006, 1007],
+            'behavior': ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        read_builder = table.new_read_builder()
         table_read = read_builder.new_read()
-        actual = table_read.to_arrow(splits)
 
-        self.assertEqual(actual.sort_by('user_id'), self.expected)
+        # Test sharding into 3 parts: 2, 2, 3 rows
+        plan1 = read_builder.new_scan().with_shard(0, 3).plan()
+        actual1 = table_read.to_arrow(plan1.splits()).sort_by('user_id')
+        expected1 = pa.Table.from_pydict({
+            'user_id': [1, 2],
+            'item_id': [1001, 1002],
+            'behavior': ['a', 'b'],
+            'dt': ['p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual1, expected1)
+
+        plan2 = read_builder.new_scan().with_shard(1, 3).plan()
+        actual2 = table_read.to_arrow(plan2.splits()).sort_by('user_id')
+        expected2 = pa.Table.from_pydict({
+            'user_id': [3, 4],
+            'item_id': [1003, 1004],
+            'behavior': ['c', 'd'],
+            'dt': ['p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual2, expected2)
+
+        plan3 = read_builder.new_scan().with_shard(2, 3).plan()
+        actual3 = table_read.to_arrow(plan3.splits()).sort_by('user_id')
+        expected3 = pa.Table.from_pydict({
+            'user_id': [5, 6, 7],
+            'item_id': [1005, 1006, 1007],
+            'behavior': ['e', 'f', 'g'],
+            'dt': ['p1', 'p1', 'p1'],
+        }, schema=self.pa_schema)
+        self.assertEqual(actual3, expected3)
+
+    def test_shard_single_shard(self):
+        """Test sharding with only one shard - should return all data"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_single', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_single')
+        write_builder = table.new_batch_write_builder()
 
-    def test_with_shard_ao_fixed_bucket(self):
-        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'],
-                                            options={'bucket': '5', 'bucket-key': 'item_id'})
-        self.rest_catalog.create_table('default.test_with_shard', schema, False)
-        table = self.rest_catalog.get_table('default.test_with_shard')
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4],
+            'item_id': [1001, 1002, 1003, 1004],
+            'behavior': ['a', 'b', 'c', 'd'],
+            'dt': ['p1', 'p1', 'p2', 'p2'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
 
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        # Test single shard (0, 1) - should get all data
+        plan = read_builder.new_scan().with_shard(0, 1).plan()
+        actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+        expected = pa.Table.from_pydict(data, schema=self.pa_schema)
+        self.assertEqual(actual, expected)
+
+    def test_shard_many_small_shards(self):
+        """Test sharding with many small shards"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_many_small', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_many_small')
         write_builder = table.new_batch_write_builder()
+
         table_write = write_builder.new_write()
         table_commit = write_builder.new_commit()
-        self.assertIsInstance(table_write.row_key_extractor, FixedBucketRowKeyExtractor)
+        data = {
+            'user_id': [1, 2, 3, 4, 5, 6],
+            'item_id': [1001, 1002, 1003, 1004, 1005, 1006],
+            'behavior': ['a', 'b', 'c', 'd', 'e', 'f'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
 
-        pa_table = pa.Table.from_pydict(self.data, schema=self.pa_schema)
+        read_builder = table.new_read_builder()
+        table_read = read_builder.new_read()
+
+        # Test with 6 shards (one row per shard)
+        for i in range(6):
+            plan = read_builder.new_scan().with_shard(i, 6).plan()
+            actual = table_read.to_arrow(plan.splits())
+            self.assertEqual(len(actual), 1)
+            self.assertEqual(actual['user_id'][0].as_py(), i + 1)
+
+    def test_shard_boundary_conditions(self):
+        """Test sharding boundary conditions with edge cases"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_boundary', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_boundary')
+        write_builder = table.new_batch_write_builder()
+
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data = {
+            'user_id': [1, 2, 3, 4, 5],
+            'item_id': [1001, 1002, 1003, 1004, 1005],
+            'behavior': ['a', 'b', 'c', 'd', 'e'],
+            'dt': ['p1', 'p1', 'p1', 'p1', 'p1'],
+        }
+        pa_table = pa.Table.from_pydict(data, schema=self.pa_schema)
         table_write.write_arrow(pa_table)
         table_commit.commit(table_write.prepare_commit())
         table_write.close()
         table_commit.close()
 
-        splits = []
         read_builder = table.new_read_builder()
-        splits.extend(read_builder.new_scan().with_shard(0, 3).plan().splits())
-        splits.extend(read_builder.new_scan().with_shard(1, 3).plan().splits())
-        splits.extend(read_builder.new_scan().with_shard(2, 3).plan().splits())
+        table_read = read_builder.new_read()
+
+        # Test first shard (0, 4) - should get 1 row (5//4 = 1)
+        plan = read_builder.new_scan().with_shard(0, 4).plan()
+        actual = table_read.to_arrow(plan.splits())
+        self.assertEqual(len(actual), 1)
+
+        # Test middle shard (1, 4) - should get 1 row
+        plan = read_builder.new_scan().with_shard(1, 4).plan()
+        actual = table_read.to_arrow(plan.splits())
+        self.assertEqual(len(actual), 1)
+
+        # Test last shard (3, 4) - should get 2 rows (remainder goes to last shard)
+        plan = read_builder.new_scan().with_shard(3, 4).plan()
+        actual = table_read.to_arrow(plan.splits())
+        self.assertEqual(len(actual), 2)
+
+    def test_with_shard_large_dataset(self):
+        """Test with_shard method using 50000 rows of data to verify 
performance and correctness"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, 
partition_keys=['dt'],
+                                            options={'bucket': '5', 
'bucket-key': 'item_id'})
+        
self.rest_catalog.create_table('default.test_with_shard_large_dataset', schema, 
False)
+        table = 
self.rest_catalog.get_table('default.test_with_shard_large_dataset')
+        write_builder = table.new_batch_write_builder()
 
+        # Generate 50000 rows of test data
+        num_rows = 50000
+        batch_size = 5000  # Write in batches to avoid memory issues
+
+        for batch_start in range(0, num_rows, batch_size):
+            batch_end = min(batch_start + batch_size, num_rows)
+            batch_data = {
+                'user_id': list(range(batch_start + 1, batch_end + 1)),
+                'item_id': [2000 + i for i in range(batch_start, batch_end)],
+                'behavior': [chr(ord('a') + (i % 26)) for i in range(batch_start, batch_end)],
+                'dt': [f'p{(i % 5) + 1}' for i in range(batch_start, batch_end)],
+            }
+
+            table_write = write_builder.new_write()
+            table_commit = write_builder.new_commit()
+            pa_table = pa.Table.from_pydict(batch_data, schema=self.pa_schema)
+            table_write.write_arrow(pa_table)
+            table_commit.commit(table_write.prepare_commit())
+            table_write.close()
+            table_commit.close()
+
+        read_builder = table.new_read_builder()
         table_read = read_builder.new_read()
-        actual = table_read.to_arrow(splits)
-        self.assertEqual(actual.sort_by("user_id"), self.expected)
+
+        # Test with 6 shards
+        num_shards = 6
+        shard_results = []
+        total_rows_from_shards = 0
+
+        for shard_idx in range(num_shards):
+            splits = read_builder.new_scan().with_shard(shard_idx, num_shards).plan().splits()
+            shard_result = table_read.to_arrow(splits)
+            shard_results.append(shard_result)
+            shard_rows = len(shard_result) if shard_result else 0
+            total_rows_from_shards += shard_rows
+            print(f"Shard {shard_idx}/{num_shards}: {shard_rows} rows")
+
+        # Verify that all shards together contain all the data
+        concatenated_result = pa.concat_tables(shard_results).sort_by('user_id')
+
+        # Read all data without sharding for comparison
+        all_splits = read_builder.new_scan().plan().splits()
+        all_data = table_read.to_arrow(all_splits).sort_by('user_id')
+
+        # Verify total row count
+        self.assertEqual(len(concatenated_result), len(all_data))
+        self.assertEqual(len(all_data), num_rows)
+        self.assertEqual(total_rows_from_shards, num_rows)
+
+        # Verify data integrity - check first and last few rows
+        self.assertEqual(concatenated_result['user_id'][0].as_py(), 1)
+        self.assertEqual(concatenated_result['user_id'][-1].as_py(), num_rows)
+        self.assertEqual(concatenated_result['item_id'][0].as_py(), 2000)
+        self.assertEqual(concatenated_result['item_id'][-1].as_py(), 2000 + num_rows - 1)
+
+        # Verify that concatenated result equals all data
+        self.assertEqual(concatenated_result, all_data)
+        # Test with different shard configurations
+        # Test with 10 shards
+        shard_10_results = []
+        for shard_idx in range(10):
+            splits = read_builder.new_scan().with_shard(shard_idx, 10).plan().splits()
+            shard_result = table_read.to_arrow(splits)
+            if shard_result:
+                shard_10_results.append(shard_result)
+
+        if shard_10_results:
+            concatenated_10_shards = pa.concat_tables(shard_10_results).sort_by('user_id')
+            self.assertEqual(len(concatenated_10_shards), num_rows)
+            self.assertEqual(concatenated_10_shards, all_data)
+
+        # Test with single shard (should return all data)
+        single_shard_splits = read_builder.new_scan().with_shard(0, 1).plan().splits()
+        single_shard_result = table_read.to_arrow(single_shard_splits).sort_by('user_id')
+        self.assertEqual(len(single_shard_result), num_rows)
+        self.assertEqual(single_shard_result, all_data)
+
+        print(f"Successfully tested with_shard method using {num_rows} rows of 
data")
+
+    def test_shard_parameter_validation(self):
+        """Test edge cases for parameter validation"""
+        schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['dt'])
+        self.rest_catalog.create_table('default.test_shard_validation_edge', schema, False)
+        table = self.rest_catalog.get_table('default.test_shard_validation_edge')
+
+        read_builder = table.new_read_builder()
+        # Test invalid case with number_of_para_subtasks = 1
+        with self.assertRaises(Exception) as context:
+            read_builder.new_scan().with_shard(1, 1).plan()
+        self.assertEqual(str(context.exception), "idx_of_this_subtask must be less than number_of_para_subtasks")
 
     def test_with_shard_pk_dynamic_bucket(self):
         schema = Schema.from_pyarrow_schema(self.pa_schema, partition_keys=['user_id'], primary_keys=['user_id', 'dt'])
diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py
index 03c4d034d0..5920f50ad8 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -101,7 +101,7 @@ class FileStoreCommit:
                                        f"in {msg.partition} does not belong to 
this partition")
 
         commit_entries = []
-        current_entries = TableScan(self.table, partition_filter, None, []).plan().files()
+        current_entries = TableScan(self.table, partition_filter, None, []).plan_files()
         for entry in current_entries:
             entry.kind = 1
             commit_entries.append(entry)
