This is an automated email from the ASF dual-hosted git repository.

yuzelin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-python.git


The following commit(s) were added to refs/heads/main by this push:
     new e3b56d8  #49 Fix: Python native read with PyArrow (#53)
e3b56d8 is described below

commit e3b56d8280509f7081eacf7c6c82e3a0f792bad0
Author: ChengHui Chen <[email protected]>
AuthorDate: Wed Jul 9 14:41:39 2025 +0800

    #49 Fix: Python native read with PyArrow (#53)
---
 pypaimon/py4j/java_implementation.py               |  16 ++-
 .../pynative/reader/core/columnar_row_iterator.py  |   4 +-
 .../pynative/reader/data_file_record_reader.py     |  96 +++++++++++++++++-
 pypaimon/pynative/reader/pyarrow_dataset_reader.py |  13 +--
 pypaimon/pynative/reader/sort_merge_reader.py      |  11 ++-
 pypaimon/pynative/tests/test_pynative_reader.py    |  89 ++++++++++++++++-
 pypaimon/pynative/util/reader_convert_func.py      | 107 +++++++++++++++++----
 pypaimon/pynative/util/reader_converter.py         |   5 +-
 8 files changed, 298 insertions(+), 43 deletions(-)

diff --git a/pypaimon/py4j/java_implementation.py 
b/pypaimon/py4j/java_implementation.py
index 43425b0..34bd1c8 100644
--- a/pypaimon/py4j/java_implementation.py
+++ b/pypaimon/py4j/java_implementation.py
@@ -82,8 +82,12 @@ class Table(table.Table):
             primary_keys = None
         else:
             primary_keys = [str(key) for key in self._j_table.primaryKeys()]
+        if self._j_table.partitionKeys().isEmpty():
+            partition_keys = None
+        else:
+            partition_keys = [str(key) for key in 
self._j_table.partitionKeys()]
         return ReadBuilder(j_read_builder, self._j_table.rowType(), 
self._catalog_options,
-                           primary_keys)
+                           primary_keys, partition_keys)
 
     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
@@ -93,11 +97,12 @@ class Table(table.Table):
 
 class ReadBuilder(read_builder.ReadBuilder):
 
-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, 
primary_keys: List[str]):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, 
primary_keys: List[str], partition_keys: List[str]):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
         self._primary_keys = primary_keys
+        self._partition_keys = partition_keys
         self._predicate = None
         self._projection = None
 
@@ -128,7 +133,7 @@ class ReadBuilder(read_builder.ReadBuilder):
     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead().executeFilter()
         return TableRead(j_table_read, self._j_read_builder.readType(), 
self._catalog_options,
-                         self._predicate, self._projection, self._primary_keys)
+                         self._predicate, self._projection, 
self._primary_keys, self._partition_keys)
 
     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -203,7 +208,7 @@ class Split(split.Split):
 class TableRead(table_read.TableRead):
 
     def __init__(self, j_table_read, j_read_type, catalog_options, predicate, 
projection,
-                 primary_keys: List[str]):
+                 primary_keys: List[str], partition_keys: List[str]):
         self._j_table_read = j_table_read
         self._j_read_type = j_read_type
         self._catalog_options = catalog_options
@@ -211,6 +216,7 @@ class TableRead(table_read.TableRead):
         self._predicate = predicate
         self._projection = projection
         self._primary_keys = primary_keys
+        self._partition_keys = partition_keys
 
         self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
         self._j_bytes_reader = 
get_gateway().jvm.InvocationUtil.createParallelBytesReader(
@@ -259,7 +265,7 @@ class TableRead(table_read.TableRead):
         try:
             j_splits = list(s.to_j_split() for s in splits)
             j_reader = 
get_gateway().jvm.InvocationUtil.createReader(self._j_table_read, j_splits)
-            converter = ReaderConverter(self._predicate, self._projection, 
self._primary_keys)
+            converter = ReaderConverter(self._predicate, self._projection, 
self._primary_keys, self._partition_keys)
             pynative_reader = converter.convert_java_reader(j_reader)
 
             def _record_generator():
diff --git a/pypaimon/pynative/reader/core/columnar_row_iterator.py 
b/pypaimon/pynative/reader/core/columnar_row_iterator.py
index b42b96c..124e4af 100644
--- a/pypaimon/pynative/reader/core/columnar_row_iterator.py
+++ b/pypaimon/pynative/reader/core/columnar_row_iterator.py
@@ -32,7 +32,7 @@ class ColumnarRowIterator(FileRecordIterator[InternalRow]):
 
     def __init__(self, file_path: str, record_batch: pa.RecordBatch):
         self.file_path = file_path
-        self._record_batch = record_batch
+        self.record_batch = record_batch
         self._row = ColumnarRow(record_batch)
 
         self.num_rows = record_batch.num_rows
@@ -58,4 +58,4 @@ class ColumnarRowIterator(FileRecordIterator[InternalRow]):
         self.next_file_pos = next_file_pos
 
     def release_batch(self):
-        del self._record_batch
+        del self.record_batch
diff --git a/pypaimon/pynative/reader/data_file_record_reader.py 
b/pypaimon/pynative/reader/data_file_record_reader.py
index 0b161fe..a8b28eb 100644
--- a/pypaimon/pynative/reader/data_file_record_reader.py
+++ b/pypaimon/pynative/reader/data_file_record_reader.py
@@ -16,12 +16,89 @@
 # limitations under the License.
 
################################################################################
 
-from typing import Optional
+from typing import Optional, List, Any
+import pyarrow as pa
 
+from pypaimon.pynative.common.exception import PyNativeNotImplementedError
 from pypaimon.pynative.common.row.internal_row import InternalRow
 from pypaimon.pynative.reader.core.file_record_iterator import 
FileRecordIterator
 from pypaimon.pynative.reader.core.file_record_reader import FileRecordReader
 from pypaimon.pynative.reader.core.record_reader import RecordReader
+from pypaimon.pynative.reader.core.columnar_row_iterator import 
ColumnarRowIterator
+
+
+class PartitionInfo:
+    """
+    Partition information about how the row mapping of outer row.
+    """
+
+    def __init__(self, mapping: List[int], partition_values: List[Any]):
+        self.mapping = mapping  # Mapping array similar to Java version
+        self.partition_values = partition_values  # Partition values to be 
injected
+
+    def size(self) -> int:
+        return len(self.mapping) - 1
+
+    def in_partition_row(self, pos: int) -> bool:
+        return self.mapping[pos] < 0
+
+    def get_real_index(self, pos: int) -> int:
+        return abs(self.mapping[pos]) - 1
+
+    def get_partition_value(self, pos: int) -> Any:
+        real_index = self.get_real_index(pos)
+        return self.partition_values[real_index] if real_index < 
len(self.partition_values) else None
+
+
+class MappedColumnarRowIterator(ColumnarRowIterator):
+    """
+    ColumnarRowIterator with mapping support for partition and index mapping.
+    """
+
+    def __init__(self, file_path: str, record_batch: pa.RecordBatch,
+                 partition_info: Optional[PartitionInfo] = None,
+                 index_mapping: Optional[List[int]] = None):
+        mapped_batch = self._apply_mappings(record_batch, partition_info, 
index_mapping)
+        super().__init__(file_path, mapped_batch)
+
+    def _apply_mappings(self, record_batch: pa.RecordBatch,
+                        partition_info: Optional[PartitionInfo],
+                        index_mapping: Optional[List[int]]) -> pa.RecordBatch:
+        arrays = []
+        names = []
+
+        if partition_info is not None:
+            for i in range(partition_info.size()):
+                if partition_info.in_partition_row(i):
+                    partition_value = partition_info.get_partition_value(i)
+                    const_array = pa.array([partition_value] * 
record_batch.num_rows)
+                    arrays.append(const_array)
+                    names.append(f"partition_field_{i}")
+                else:
+                    real_index = partition_info.get_real_index(i)
+                    if real_index < record_batch.num_columns:
+                        arrays.append(record_batch.column(real_index))
+                        names.append(record_batch.column_names[real_index])
+        else:
+            arrays = [record_batch.column(i) for i in 
range(record_batch.num_columns)]
+            names = record_batch.column_names[:]
+
+        if index_mapping is not None:
+            mapped_arrays = []
+            mapped_names = []
+            for i, real_index in enumerate(index_mapping):
+                if real_index >= 0 and real_index < len(arrays):
+                    mapped_arrays.append(arrays[real_index])
+                    mapped_names.append(names[real_index] if real_index < 
len(names) else f"field_{i}")
+                else:
+                    null_array = pa.array([None] * record_batch.num_rows)
+                    mapped_arrays.append(null_array)
+                    mapped_names.append(f"null_field_{i}")
+            arrays = mapped_arrays
+            names = mapped_names
+
+        final_batch = pa.RecordBatch.from_arrays(arrays, names=names)
+        return final_batch
 
 
 class DataFileRecordReader(FileRecordReader[InternalRow]):
@@ -29,15 +106,28 @@ class DataFileRecordReader(FileRecordReader[InternalRow]):
     Reads InternalRow from data files.
     """
 
-    def __init__(self, wrapped_reader: RecordReader):
+    def __init__(self, wrapped_reader: RecordReader,
+                 index_mapping: Optional[List[int]] = None,
+                 partition_info: Optional[PartitionInfo] = None):
         self.wrapped_reader = wrapped_reader
+        self.index_mapping = index_mapping
+        self.partition_info = partition_info
 
     def read_batch(self) -> Optional[FileRecordIterator['InternalRow']]:
         iterator = self.wrapped_reader.read_batch()
         if iterator is None:
             return None
 
-        # TODO: Handle partition_info, index_mapping, and cast_mapping
+        if isinstance(iterator, ColumnarRowIterator):
+            if self.partition_info is not None or self.index_mapping is not 
None:
+                iterator = MappedColumnarRowIterator(
+                    iterator.file_path,
+                    iterator.record_batch,
+                    self.partition_info,
+                    self.index_mapping
+                )
+        else:
+            raise PyNativeNotImplementedError("partition_info & index_mapping 
for non ColumnarRowIterator")
 
         return iterator
 
diff --git a/pypaimon/pynative/reader/pyarrow_dataset_reader.py 
b/pypaimon/pynative/reader/pyarrow_dataset_reader.py
index 2f3bc85..07ed9f7 100644
--- a/pypaimon/pynative/reader/pyarrow_dataset_reader.py
+++ b/pypaimon/pynative/reader/pyarrow_dataset_reader.py
@@ -35,16 +35,9 @@ class PyArrowDatasetReader(FileRecordReader[InternalRow]):
     """
 
     def __init__(self, format, file_path, batch_size, projection,
-                 predicate: Predicate, primary_keys: List[str]):
+                 predicate: Predicate, primary_keys: List[str], fields: 
List[str]):
+
         if primary_keys is not None:
-            if projection is not None:
-                key_columns = []
-                for pk in primary_keys:
-                    key_column = f"_KEY_{pk}"
-                    if key_column not in projection:
-                        key_columns.append(key_column)
-                system_columns = ["_SEQUENCE_NUMBER", "_VALUE_KIND"]
-                projection = key_columns + system_columns + projection
             # TODO: utilize predicate to improve performance
             predicate = None
 
@@ -54,7 +47,7 @@ class PyArrowDatasetReader(FileRecordReader[InternalRow]):
         self._file_path = file_path
         self.dataset = ds.dataset(file_path, format=format)
         self.scanner = self.dataset.scanner(
-            columns=projection,
+            columns=fields,
             filter=predicate,
             batch_size=batch_size
         )
diff --git a/pypaimon/pynative/reader/sort_merge_reader.py 
b/pypaimon/pynative/reader/sort_merge_reader.py
index 896eb50..30757b2 100644
--- a/pypaimon/pynative/reader/sort_merge_reader.py
+++ b/pypaimon/pynative/reader/sort_merge_reader.py
@@ -196,11 +196,18 @@ class SortMergeIterator(RecordIterator):
 
 
 class SortMergeReader:
-    def __init__(self, readers, primary_keys):
+    def __init__(self, readers, primary_keys, partition_keys):
         self.next_batch_readers = list(readers)
         self.merge_function = DeduplicateMergeFunction(False)
 
-        key_columns = [f"_KEY_{pk}" for pk in primary_keys]
+        if partition_keys:
+            trimmed_primary_keys = [pk for pk in primary_keys if pk not in 
partition_keys]
+            if not trimmed_primary_keys:
+                raise ValueError(f"Primary key constraint {primary_keys} same 
with partition fields")
+        else:
+            trimmed_primary_keys = primary_keys
+
+        key_columns = [f"_KEY_{pk}" for pk in trimmed_primary_keys]
         key_schema = pa.schema([pa.field(column, pa.string()) for column in 
key_columns])
         self.user_key_comparator = built_comparator(key_schema)
 
diff --git a/pypaimon/pynative/tests/test_pynative_reader.py 
b/pypaimon/pynative/tests/test_pynative_reader.py
index 76667a0..fe9efb3 100644
--- a/pypaimon/pynative/tests/test_pynative_reader.py
+++ b/pypaimon/pynative/tests/test_pynative_reader.py
@@ -38,6 +38,12 @@ class NativeReaderTest(PypaimonTestBase):
             ('f1', pa.string()),
             ('f2', pa.string())
         ])
+        cls.partition_pk_pa_schema = pa.schema([
+            ('user_id', pa.int32(), False),
+            ('item_id', pa.int32()),
+            ('behavior', pa.string()),
+            ('dt', pa.string(), False)
+        ])
         cls._expected_full_data = pd.DataFrame({
             'f0': [1, 2, 3, 4, 5, 6, 7, 8],
             'f1': ['a', 'b', 'c', None, 'e', 'f', 'g', 'h'],
@@ -201,7 +207,7 @@ class NativeReaderTest(PypaimonTestBase):
         actual = self._read_test_table(read_builder)
         self.assertEqual(actual, self.expected_full_pk)
 
-    def testPkOrcReader(self):
+    def skip_testPkOrcReader(self):
         schema = Schema(self.pk_pa_schema, primary_keys=['f0'], options={
             'bucket': '1',
             'file.format': 'orc'
@@ -214,7 +220,7 @@ class NativeReaderTest(PypaimonTestBase):
         actual = self._read_test_table(read_builder)
         self.assertEqual(actual, self.expected_full_pk)
 
-    def testPkAvroReader(self):
+    def skip_testPkAvroReader(self):
         schema = Schema(self.pk_pa_schema, primary_keys=['f0'], options={
             'bucket': '1',
             'file.format': 'avro'
@@ -263,6 +269,51 @@ class NativeReaderTest(PypaimonTestBase):
         expected = self.expected_full_pk.select(['f0', 'f2'])
         self.assertEqual(actual, expected)
 
+    def testPartitionPkParquetReader(self):
+        schema = Schema(self.partition_pk_pa_schema,
+                        partition_keys=['dt'],
+                        primary_keys=['dt', 'user_id'],
+                        options={
+                            'bucket': '2'
+                        })
+        self.catalog.create_table('default.test_partition_pk_parquet', schema, 
False)
+        table = self.catalog.get_table('default.test_partition_pk_parquet')
+        self._write_partition_test_table(table)
+
+        read_builder = table.new_read_builder()
+        actual = self._read_test_table(read_builder)
+        expected = pa.Table.from_pandas(
+            pd.DataFrame({
+                'user_id': [1, 2, 3, 4, 5, 7, 8],
+                'item_id': [1, 2, 3, 4, 5, 7, 8],
+                'behavior': ["b-1", "b-2-new", "b-3", None, "b-5", "b-7", 
None],
+                'dt': ["p-1", "p-1", "p-1", "p-1", "p-2", "p-1", "p-2"]
+            }),
+            schema=self.partition_pk_pa_schema)
+        self.assertEqual(actual.sort_by('user_id'), expected)
+
+    def testPartitionPkParquetReaderWriteOnce(self):
+        schema = Schema(self.partition_pk_pa_schema,
+                        partition_keys=['dt'],
+                        primary_keys=['dt', 'user_id'],
+                        options={
+                            'bucket': '1'
+                        })
+        self.catalog.create_table('default.test_partition_pk_parquet2', 
schema, False)
+        table = self.catalog.get_table('default.test_partition_pk_parquet2')
+        self._write_partition_test_table(table, write_once=True)
+
+        read_builder = table.new_read_builder()
+        actual = self._read_test_table(read_builder)
+        expected = pa.Table.from_pandas(
+            pd.DataFrame({
+                'user_id': [1, 2, 3, 4],
+                'item_id': [1, 2, 3, 4],
+                'behavior': ['b-1', 'b-2', 'b-3', None],
+                'dt': ['p-1', 'p-1', 'p-1', 'p-1']
+            }), schema=self.partition_pk_pa_schema)
+        self.assertEqual(actual, expected)
+
     def _write_test_table(self, table, for_pk=False):
         write_builder = table.new_batch_write_builder()
 
@@ -301,6 +352,40 @@ class NativeReaderTest(PypaimonTestBase):
         table_write.close()
         table_commit.close()
 
+    def _write_partition_test_table(self, table, write_once=False):
+        write_builder = table.new_batch_write_builder()
+
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [1, 2, 3, 4],
+            'item_id': [1, 2, 3, 4],
+            'behavior': ['b-1', 'b-2', 'b-3', None],
+            'dt': ['p-1', 'p-1', 'p-1', 'p-1']
+        }
+        pa_table = pa.Table.from_pydict(data1, 
schema=self.partition_pk_pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        if write_once:
+            return
+
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'user_id': [5, 2, 7, 8],
+            'item_id': [5, 2, 7, 8],
+            'behavior': ['b-5', 'b-2-new', 'b-7', None],
+            'dt': ['p-2', 'p-1', 'p-1', 'p-2']
+        }
+        pa_table = pa.Table.from_pydict(data1, 
schema=self.partition_pk_pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
     def _read_test_table(self, read_builder):
         table_read = read_builder.new_read()
         splits = read_builder.new_scan().plan().splits()
diff --git a/pypaimon/pynative/util/reader_convert_func.py 
b/pypaimon/pynative/util/reader_convert_func.py
index 00b1a14..0ccae0f 100644
--- a/pypaimon/pynative/util/reader_convert_func.py
+++ b/pypaimon/pynative/util/reader_convert_func.py
@@ -17,7 +17,7 @@
 
################################################################################
 
 
-def create_concat_record_reader(j_reader, converter, predicate, projection, 
primary_keys):
+def create_concat_record_reader(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     from pypaimon.pynative.reader.concat_record_reader import 
ConcatRecordReader
     reader_class = j_reader.getClass()
     queue_field = reader_class.getDeclaredField("queue")
@@ -26,17 +26,27 @@ def create_concat_record_reader(j_reader, converter, 
predicate, projection, prim
     return ConcatRecordReader(converter, j_supplier_queue)
 
 
-def create_data_file_record_reader(j_reader, converter, predicate, projection, 
primary_keys):
+def create_data_file_record_reader(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     from pypaimon.pynative.reader.data_file_record_reader import 
DataFileRecordReader
     reader_class = j_reader.getClass()
     wrapped_reader_field = reader_class.getDeclaredField("reader")
     wrapped_reader_field.setAccessible(True)
     j_wrapped_reader = wrapped_reader_field.get(j_reader)
     wrapped_reader = converter.convert_java_reader(j_wrapped_reader)
-    return DataFileRecordReader(wrapped_reader)
 
+    index_mapping_field = reader_class.getDeclaredField("indexMapping")
+    index_mapping_field.setAccessible(True)
+    index_mapping = index_mapping_field.get(j_reader)
 
-def create_filter_reader(j_reader, converter, predicate, projection, 
primary_keys):
+    partition_info_field = reader_class.getDeclaredField("partitionInfo")
+    partition_info_field.setAccessible(True)
+    j_partition_info = partition_info_field.get(j_reader)
+    partition_info = convert_partition_info(j_partition_info)
+
+    return DataFileRecordReader(wrapped_reader, index_mapping, partition_info)
+
+
+def create_filter_reader(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     from pypaimon.pynative.reader.filter_record_reader import 
FilterRecordReader
     reader_class = j_reader.getClass()
     wrapped_reader_field = reader_class.getDeclaredField("val$thisReader")
@@ -49,7 +59,7 @@ def create_filter_reader(j_reader, converter, predicate, 
projection, primary_key
         return wrapped_reader
 
 
-def create_pyarrow_reader_for_parquet(j_reader, converter, predicate, 
projection, primary_keys):
+def create_pyarrow_reader_for_parquet(j_reader, converter, predicate, 
projection, primary_keys, partition_keys):
     from pypaimon.pynative.reader.pyarrow_dataset_reader import 
PyArrowDatasetReader
 
     reader_class = j_reader.getClass()
@@ -70,11 +80,17 @@ def create_pyarrow_reader_for_parquet(j_reader, converter, 
predicate, projection
     j_input_file = input_file_field.get(j_file_reader)
     file_path = j_input_file.getPath().toUri().toString()
 
+    fields_field = reader_class.getDeclaredField("fields")
+    fields_field.setAccessible(True)
+    fields = fields_field.get(j_reader)
+    if fields is not None:
+        fields = [str(field.getDescriptor().getPrimitiveType().getName()) for 
field in fields]
+
     return PyArrowDatasetReader('parquet', file_path, batch_size, projection,
-                                predicate, primary_keys)
+                                predicate, primary_keys, fields)
 
 
-def create_pyarrow_reader_for_orc(j_reader, converter, predicate, projection, 
primary_keys):
+def create_pyarrow_reader_for_orc(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     from pypaimon.pynative.reader.pyarrow_dataset_reader import 
PyArrowDatasetReader
 
     reader_class = j_reader.getClass()
@@ -90,10 +106,10 @@ def create_pyarrow_reader_for_orc(j_reader, converter, 
predicate, projection, pr
     # TODO: Temporarily hard-coded to 1024 as we cannot reflectively obtain 
this value yet
     batch_size = 1024
 
-    return PyArrowDatasetReader('orc', file_path, batch_size, projection, 
predicate, primary_keys)
+    return PyArrowDatasetReader('orc', file_path, batch_size, projection, 
predicate, primary_keys, None)
 
 
-def create_avro_format_reader(j_reader, converter, predicate, projection, 
primary_keys):
+def create_avro_format_reader(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     from pypaimon.pynative.reader.avro_format_reader import AvroFormatReader
 
     reader_class = j_reader.getClass()
@@ -108,7 +124,7 @@ def create_avro_format_reader(j_reader, converter, 
predicate, projection, primar
     return AvroFormatReader(file_path, batch_size, None)
 
 
-def create_key_value_unwrap_reader(j_reader, converter, predicate, projection, 
primary_keys):
+def create_key_value_unwrap_reader(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     from pypaimon.pynative.reader.key_value_unwrap_reader import 
KeyValueUnwrapReader
     reader_class = j_reader.getClass()
     wrapped_reader_field = reader_class.getDeclaredField("val$reader")
@@ -118,7 +134,7 @@ def create_key_value_unwrap_reader(j_reader, converter, 
predicate, projection, p
     return KeyValueUnwrapReader(wrapped_reader)
 
 
-def create_transform_reader(j_reader, converter, predicate, projection, 
primary_keys):
+def create_transform_reader(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     reader_class = j_reader.getClass()
     wrapped_reader_field = reader_class.getDeclaredField("val$thisReader")
     wrapped_reader_field.setAccessible(True)
@@ -127,7 +143,7 @@ def create_transform_reader(j_reader, converter, predicate, 
projection, primary_
     return converter.convert_java_reader(j_wrapped_reader)
 
 
-def create_drop_delete_reader(j_reader, converter, predicate, projection, 
primary_keys):
+def create_drop_delete_reader(j_reader, converter, predicate, projection, 
primary_keys, partition_keys):
     from pypaimon.pynative.reader.drop_delete_reader import DropDeleteReader
     reader_class = j_reader.getClass()
     wrapped_reader_field = reader_class.getDeclaredField("reader")
@@ -137,7 +153,7 @@ def create_drop_delete_reader(j_reader, converter, 
predicate, projection, primar
     return DropDeleteReader(wrapped_reader)
 
 
-def create_sort_merge_reader_minhep(j_reader, converter, predicate, 
projection, primary_keys):
+def create_sort_merge_reader_minhep(j_reader, converter, predicate, 
projection, primary_keys, partition_keys):
     from pypaimon.pynative.reader.sort_merge_reader import SortMergeReader
     j_reader_class = j_reader.getClass()
     batch_readers_field = j_reader_class.getDeclaredField("nextBatchReaders")
@@ -146,10 +162,10 @@ def create_sort_merge_reader_minhep(j_reader, converter, 
predicate, projection,
     readers = []
     for next_reader in j_batch_readers:
         readers.append(converter.convert_java_reader(next_reader))
-    return SortMergeReader(readers, primary_keys)
+    return SortMergeReader(readers, primary_keys, partition_keys)
 
 
-def create_sort_merge_reader_loser_tree(j_reader, converter, predicate, 
projection, primary_keys):
+def create_sort_merge_reader_loser_tree(j_reader, converter, predicate, 
projection, primary_keys, partition_keys):
     from pypaimon.pynative.reader.sort_merge_reader import SortMergeReader
     j_reader_class = j_reader.getClass()
     loser_tree_field = j_reader_class.getDeclaredField("loserTree")
@@ -166,10 +182,10 @@ def create_sort_merge_reader_loser_tree(j_reader, 
converter, predicate, projecti
         j_leaf_reader_field.setAccessible(True)
         j_leaf_reader = j_leaf_reader_field.get(j_leaf)
         readers.append(converter.convert_java_reader(j_leaf_reader))
-    return SortMergeReader(readers, primary_keys)
+    return SortMergeReader(readers, primary_keys, partition_keys)
 
 
-def create_key_value_wrap_record_reader(j_reader, converter, predicate, 
projection, primary_keys):
+def create_key_value_wrap_record_reader(j_reader, converter, predicate, 
projection, primary_keys, partition_keys):
     from pypaimon.pynative.reader.key_value_wrap_reader import 
KeyValueWrapReader
     reader_class = j_reader.getClass()
 
@@ -198,3 +214,60 @@ def create_key_value_wrap_record_reader(j_reader, 
converter, predicate, projecti
     arity_field.setAccessible(True)
     value_arity = arity_field.get(j_reused_value)
     return KeyValueWrapReader(wrapped_reader, level, key_arity, value_arity)
+
+
+def convert_partition_info(j_partition_info):
+    if j_partition_info is None:
+        return None
+
+    partition_info_class = j_partition_info.getClass()
+
+    map_field = partition_info_class.getDeclaredField("map")
+    map_field.setAccessible(True)
+    j_mapping = map_field.get(j_partition_info)
+    mapping = list(j_mapping) if j_mapping is not None else []
+
+    partition_field = partition_info_class.getDeclaredField("partition")
+    partition_field.setAccessible(True)
+    j_binary_row = partition_field.get(j_partition_info)
+
+    partition_type_field = 
partition_info_class.getDeclaredField("partitionType")
+    partition_type_field.setAccessible(True)
+    j_partition_type = partition_type_field.get(j_partition_info)
+
+    partition_values = []
+    if j_binary_row is not None and j_partition_type is not None:
+        field_count = j_binary_row.getFieldCount()
+        for i in range(field_count):
+            if j_binary_row.isNullAt(i):
+                partition_values.append(None)
+            else:
+                field_type = j_partition_type.getTypeAt(i)
+                type_info = field_type.getTypeRoot().toString()
+
+                if "INTEGER" in type_info:
+                    partition_values.append(j_binary_row.getInt(i))
+                elif "BIGINT" in type_info:
+                    partition_values.append(j_binary_row.getLong(i))
+                elif "VARCHAR" in type_info or "CHAR" in type_info:
+                    binary_string = j_binary_row.getString(i)
+                    partition_values.append(str(binary_string) if 
binary_string is not None else None)
+                elif "BOOLEAN" in type_info:
+                    partition_values.append(j_binary_row.getBoolean(i))
+                elif "DOUBLE" in type_info:
+                    partition_values.append(j_binary_row.getDouble(i))
+                elif "FLOAT" in type_info:
+                    partition_values.append(j_binary_row.getFloat(i))
+                elif "DATE" in type_info:
+                    partition_values.append(j_binary_row.getInt(i))  # Date 
stored as int
+                elif "TIMESTAMP" in type_info:
+                    timestamp = j_binary_row.getTimestamp(i, 3)  # precision=3 
for millis
+                    partition_values.append(timestamp.getMillisecond() if 
timestamp is not None else None)
+                else:
+                    try:
+                        partition_values.append(str(j_binary_row.getString(i) 
or ""))
+                    except:
+                        partition_values.append(None)
+
+    from pypaimon.pynative.reader.data_file_record_reader import PartitionInfo
+    return PartitionInfo(mapping, partition_values)
diff --git a/pypaimon/pynative/util/reader_converter.py 
b/pypaimon/pynative/util/reader_converter.py
index ef9bbb0..92c8ddf 100644
--- a/pypaimon/pynative/util/reader_converter.py
+++ b/pypaimon/pynative/util/reader_converter.py
@@ -72,11 +72,12 @@ class ReaderConverter:
     # Convert Java RecordReader to Python RecordReader
     """
 
-    def __init__(self, predicate, projection, primary_keys: List[str]):
+    def __init__(self, predicate, projection, primary_keys: List[str], 
partition_keys: List[str]):
         self.reader_mapping = reader_mapping
         self._predicate = predicate
         self._projection = projection
         self._primary_keys = primary_keys
+        self._partition_keys = partition_keys or []
 
     def convert_java_reader(self, java_reader: JavaObject) -> RecordReader:
         java_class_name = java_reader.getClass().getName()
@@ -84,6 +85,6 @@ class ReaderConverter:
             if os.environ.get(constants.PYPAIMON4J_TEST_MODE) == "true":
                 print("converting Java reader: " + str(java_class_name))
             return reader_mapping[java_class_name](java_reader, self, 
self._predicate,
-                                                   self._projection, 
self._primary_keys)
+                                                   self._projection, 
self._primary_keys, self._partition_keys)
         else:
             raise PyNativeNotImplementedError(f"Unsupported RecordReader type: 
{java_class_name}")

Reply via email to