This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 77f130bde0 [Python] Predicate Push Down for Scan / Read (#6166)
77f130bde0 is described below

commit 77f130bde03381e3b26b86e4f0e95cff5c1dc622
Author: ChengHui Chen <27797326+chenghuic...@users.noreply.github.com>
AuthorDate: Fri Aug 29 16:34:24 2025 +0800

    [Python] Predicate Push Down for Scan / Read (#6166)
---
 paimon-python/pypaimon/common/predicate.py         | 103 ++++++++++++-
 .../pypaimon/manifest/manifest_file_manager.py     |  32 ++--
 .../pypaimon/manifest/manifest_list_manager.py     |  12 +-
 .../pypaimon/manifest/schema/simple_stats.py       |   6 +-
 paimon-python/pypaimon/read/push_down_utils.py     |  72 +++++++++
 .../pypaimon/read/reader/format_avro_reader.py     |  24 +--
 .../pypaimon/read/reader/format_pyarrow_reader.py  |  50 +------
 paimon-python/pypaimon/read/split_read.py          |  12 +-
 paimon-python/pypaimon/read/table_read.py          |  35 ++++-
 paimon-python/pypaimon/read/table_scan.py          | 161 +++++++--------------
 .../pypaimon/tests/predicate_push_down_test.py     | 151 +++++++++++++++++++
 paimon-python/pypaimon/write/file_store_commit.py  |   6 +-
 paimon-python/pypaimon/write/writer/data_writer.py |  34 ++---
 13 files changed, 471 insertions(+), 227 deletions(-)

diff --git a/paimon-python/pypaimon/common/predicate.py 
b/paimon-python/pypaimon/common/predicate.py
index ee13aca99b..ba56713032 100644
--- a/paimon-python/pypaimon/common/predicate.py
+++ b/paimon-python/pypaimon/common/predicate.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from functools import reduce
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 
 import pyarrow
 from pyarrow import compute as pyarrow_compute
@@ -82,6 +82,107 @@ class Predicate:
         else:
             raise ValueError("Unsupported predicate method: 
{}".format(self.method))
 
+    def test_by_value(self, value: Any) -> bool:
+        """
+        Evaluate this predicate against one concrete value, e.g. a partition value.
+
+        'and' / 'or' predicates recurse into the child predicates stored in
+        ``literals``. A None value satisfies only 'isNull'; every other leaf method
+        is False for None (SQL NULL semantics) instead of raising TypeError
+        (``None < 5``) or answering True for 'notEqual' / 'notIn'.
+
+        :param value: the value to test, may be None.
+        :return: True if ``value`` satisfies this predicate.
+        :raises ValueError: if the predicate method is not supported.
+        """
+        if self.method == 'and':
+            return all(p.test_by_value(value) for p in self.literals)
+        if self.method == 'or':
+            return any(p.test_by_value(value) for p in self.literals)
+        if self.method == 'isNull':
+            return value is None
+        if self.method == 'isNotNull':
+            return value is not None
+
+        literals = self.literals
+        leaf_tests = {
+            'equal': lambda v: v == literals[0],
+            'notEqual': lambda v: v != literals[0],
+            'lessThan': lambda v: v < literals[0],
+            'lessOrEqual': lambda v: v <= literals[0],
+            'greaterThan': lambda v: v > literals[0],
+            'greaterOrEqual': lambda v: v >= literals[0],
+            # string predicates never match non-string values
+            'startsWith': lambda v: isinstance(v, str) and v.startswith(literals[0]),
+            'endsWith': lambda v: isinstance(v, str) and v.endswith(literals[0]),
+            'contains': lambda v: isinstance(v, str) and literals[0] in v,
+            'in': lambda v: v in literals,
+            'notIn': lambda v: v not in literals,
+            'between': lambda v: literals[0] <= v <= literals[1],
+        }
+        leaf_test = leaf_tests.get(self.method)
+        if leaf_test is None:
+            raise ValueError("Unsupported predicate method: {}".format(self.method))
+        # NULL is never equal, unequal, ordered, or a member of anything
+        return value is not None and leaf_test(value)
+
+    def test_by_stats(self, stat: Dict) -> bool:
+        """
+        Decide whether data summarised by ``stat`` may contain rows matching this predicate.
+
+        Every check here is a "may match" test: False means no row described by
+        ``stat`` can satisfy the predicate, so the result must be conservative and
+        is True whenever the statistics are missing or inconclusive (unknown null
+        count, min/max not collected, suffix/substring predicates).
+
+        :param stat: dict with "row_count" plus "null_counts", "min_values" and
+            "max_values", each looked up by field name (shape inferred from the
+            lookups below - TODO confirm against the code that builds ``stat``).
+        :return: False only if no row can match.
+        :raises RuntimeError: for 'startsWith' when min/max are not strings.
+        :raises ValueError: if the predicate method is not supported.
+        """
+        if self.method == 'and':
+            return all(p.test_by_stats(stat) for p in self.literals)
+        if self.method == 'or':
+            return any(p.test_by_stats(stat) for p in self.literals)
+
+        null_count = stat["null_counts"][self.field]
+        row_count = stat["row_count"]
+
+        if self.method == 'isNull':
+            # an unknown null count may hide nulls, so it cannot rule the data out
+            return null_count is None or null_count > 0
+        if self.method == 'isNotNull':
+            return null_count is None or row_count is None or null_count < row_count
+
+        # every remaining method is False for NULL, so an all-null column never matches
+        if null_count is not None and null_count == row_count:
+            return False
+
+        min_value = stat["min_values"][self.field]
+        max_value = stat["max_values"][self.field]
+        # min/max not available although non-null values exist: nothing can be ruled out
+        if min_value is None or max_value is None:
+            return True
+
+        if self.method == 'equal':
+            return min_value <= self.literals[0] <= max_value
+        if self.method == 'notEqual':
+            return not (min_value == self.literals[0] == max_value)
+        if self.method == 'lessThan':
+            return self.literals[0] > min_value
+        if self.method == 'lessOrEqual':
+            return self.literals[0] >= min_value
+        if self.method == 'greaterThan':
+            return self.literals[0] < max_value
+        if self.method == 'greaterOrEqual':
+            return self.literals[0] <= max_value
+        if self.method == 'startsWith':
+            if not isinstance(min_value, str) or not isinstance(max_value, str):
+                raise RuntimeError("startsWith predicate on non-str field")
+            prefix = self.literals[0]
+            # a value with this prefix can lie in [min, max] only if min does not sort
+            # after the prefixed range and max does not sort before the prefix
+            return ((min_value.startswith(prefix) or min_value < prefix)
+                    and (max_value.startswith(prefix) or max_value > prefix))
+        if self.method in ('endsWith', 'contains'):
+            # min/max say nothing about suffixes or substrings
+            return True
+        if self.method == 'in':
+            return any(min_value <= literal <= max_value for literal in self.literals)
+        if self.method == 'notIn':
+            # prunable only when all non-null values equal one excluded literal
+            return not any(min_value == literal == max_value for literal in self.literals)
+        if self.method == 'between':
+            return self.literals[0] <= max_value and self.literals[1] >= min_value
+        raise ValueError("Unsupported predicate method: {}".format(self.method))
+
     def to_arrow(self) -> pyarrow_compute.Expression | bool:
         if self.method == 'equal':
             return pyarrow_dataset.field(self.field) == self.literals[0]
diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py 
b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index 7c46b368d2..7c97f7b0ca 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -55,19 +55,19 @@ class ManifestFileManager:
             file_dict = dict(record['_FILE'])
             key_dict = dict(file_dict['_KEY_STATS'])
             key_stats = SimpleStats(
-                
min_value=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
-                                                           
self.trimmed_primary_key_fields),
-                
max_value=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
-                                                           
self.trimmed_primary_key_fields),
-                null_count=key_dict['_NULL_COUNTS'],
+                
min_values=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
+                                                            
self.trimmed_primary_key_fields),
+                
max_values=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
+                                                            
self.trimmed_primary_key_fields),
+                null_counts=key_dict['_NULL_COUNTS'],
             )
             value_dict = dict(file_dict['_VALUE_STATS'])
             value_stats = SimpleStats(
-                
min_value=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
-                                                           
self.table.table_schema.fields),
-                
max_value=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
-                                                           
self.table.table_schema.fields),
-                null_count=value_dict['_NULL_COUNTS'],
+                
min_values=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
+                                                            
self.table.table_schema.fields),
+                
max_values=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
+                                                            
self.table.table_schema.fields),
+                null_counts=value_dict['_NULL_COUNTS'],
             )
             file_meta = DataFileMeta(
                 file_name=file_dict['_FILE_NAME'],
@@ -118,14 +118,14 @@ class ManifestFileManager:
                         "_MIN_KEY": BinaryRowSerializer.to_bytes(file.min_key),
                         "_MAX_KEY": BinaryRowSerializer.to_bytes(file.max_key),
                         "_KEY_STATS": {
-                            "_MIN_VALUES": 
BinaryRowSerializer.to_bytes(file.key_stats.min_value),
-                            "_MAX_VALUES": 
BinaryRowSerializer.to_bytes(file.key_stats.max_value),
-                            "_NULL_COUNTS": file.key_stats.null_count,
+                            "_MIN_VALUES": 
BinaryRowSerializer.to_bytes(file.key_stats.min_values),
+                            "_MAX_VALUES": 
BinaryRowSerializer.to_bytes(file.key_stats.max_values),
+                            "_NULL_COUNTS": file.key_stats.null_counts,
                         },
                         "_VALUE_STATS": {
-                            "_MIN_VALUES": 
BinaryRowSerializer.to_bytes(file.value_stats.min_value),
-                            "_MAX_VALUES": 
BinaryRowSerializer.to_bytes(file.value_stats.max_value),
-                            "_NULL_COUNTS": file.value_stats.null_count,
+                            "_MIN_VALUES": 
BinaryRowSerializer.to_bytes(file.value_stats.min_values),
+                            "_MAX_VALUES": 
BinaryRowSerializer.to_bytes(file.value_stats.max_values),
+                            "_NULL_COUNTS": file.value_stats.null_counts,
                         },
                         "_MIN_SEQUENCE_NUMBER": file.min_sequence_number,
                         "_MAX_SEQUENCE_NUMBER": file.max_sequence_number,
diff --git a/paimon-python/pypaimon/manifest/manifest_list_manager.py 
b/paimon-python/pypaimon/manifest/manifest_list_manager.py
index 65fd2b21ac..dc9d5db44d 100644
--- a/paimon-python/pypaimon/manifest/manifest_list_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_list_manager.py
@@ -58,15 +58,15 @@ class ManifestListManager:
         for record in reader:
             stats_dict = dict(record['_PARTITION_STATS'])
             partition_stats = SimpleStats(
-                min_value=BinaryRowDeserializer.from_bytes(
+                min_values=BinaryRowDeserializer.from_bytes(
                     stats_dict['_MIN_VALUES'],
                     self.table.table_schema.get_partition_key_fields()
                 ),
-                max_value=BinaryRowDeserializer.from_bytes(
+                max_values=BinaryRowDeserializer.from_bytes(
                     stats_dict['_MAX_VALUES'],
                     self.table.table_schema.get_partition_key_fields()
                 ),
-                null_count=stats_dict['_NULL_COUNTS'],
+                null_counts=stats_dict['_NULL_COUNTS'],
             )
             manifest_file_meta = ManifestFileMeta(
                 file_name=record['_FILE_NAME'],
@@ -90,9 +90,9 @@ class ManifestListManager:
                 "_NUM_ADDED_FILES": meta.num_added_files,
                 "_NUM_DELETED_FILES": meta.num_deleted_files,
                 "_PARTITION_STATS": {
-                    "_MIN_VALUES": 
BinaryRowSerializer.to_bytes(meta.partition_stats.min_value),
-                    "_MAX_VALUES": 
BinaryRowSerializer.to_bytes(meta.partition_stats.max_value),
-                    "_NULL_COUNTS": meta.partition_stats.null_count,
+                    "_MIN_VALUES": 
BinaryRowSerializer.to_bytes(meta.partition_stats.min_values),
+                    "_MAX_VALUES": 
BinaryRowSerializer.to_bytes(meta.partition_stats.max_values),
+                    "_NULL_COUNTS": meta.partition_stats.null_counts,
                 },
                 "_SCHEMA_ID": meta.schema_id,
             }
diff --git a/paimon-python/pypaimon/manifest/schema/simple_stats.py 
b/paimon-python/pypaimon/manifest/schema/simple_stats.py
index 4a73d3eee4..55b2163e76 100644
--- a/paimon-python/pypaimon/manifest/schema/simple_stats.py
+++ b/paimon-python/pypaimon/manifest/schema/simple_stats.py
@@ -24,9 +24,9 @@ from pypaimon.table.row.binary_row import BinaryRow
 
 @dataclass
 class SimpleStats:
-    min_value: BinaryRow
-    max_value: BinaryRow
-    null_count: Optional[List[int]]
+    min_values: BinaryRow
+    max_values: BinaryRow
+    null_counts: Optional[List[int]]
 
 
 SIMPLE_STATS_SCHEMA = {
diff --git a/paimon-python/pypaimon/read/push_down_utils.py 
b/paimon-python/pypaimon/read/push_down_utils.py
new file mode 100644
index 0000000000..31e66973c6
--- /dev/null
+++ b/paimon-python/pypaimon/read/push_down_utils.py
@@ -0,0 +1,72 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from typing import Dict, List, Set
+
+from pypaimon.common.predicate import Predicate
+
+
+def extract_predicate_to_list(result: list, input_predicate: 'Predicate', keys: List[str]):
+    """
+    Append to ``result`` the parts of ``input_predicate`` that only involve ``keys``.
+
+    Children of an 'and' are examined one by one. An 'or' is kept whole, and only
+    when every field it references belongs to ``keys``. Any other predicate is kept
+    when its field is in ``keys``. Nothing is appended if the predicate or ``keys``
+    is empty.
+    """
+    if not input_predicate or not keys:
+        return
+
+    method = input_predicate.method
+    if method == 'and':
+        for child in input_predicate.literals:
+            extract_predicate_to_list(result, child, keys)
+    elif method == 'or':
+        # an 'or' can only be taken as a whole, so all of its fields must be keys
+        or_fields = _get_all_fields(input_predicate)
+        if or_fields and or_fields.issubset(keys):
+            result.append(input_predicate)
+    elif input_predicate.field in keys:
+        result.append(input_predicate)
+
+
+def _get_all_fields(predicate: 'Predicate') -> Set[str]:
+    """
+    Return the field names referenced by ``predicate``.
+
+    A predicate with its own field yields just that field; otherwise the result is
+    the union of the fields of the child predicates held in ``literals`` (empty when
+    there are none).
+    """
+    if predicate.field is not None:
+        return {predicate.field}
+    return set().union(*(_get_all_fields(child) for child in predicate.literals or ()))
+
+
+def extract_predicate_to_dict(result: Dict, input_predicate: 'Predicate', keys: List[str]):
+    """
+    Collect, grouped by key, the parts of ``input_predicate`` that constrain a single key.
+
+    ``result`` maps a key to a list of predicates and is appended to in place, so it
+    must create missing entries on access (e.g. ``defaultdict(list)``). Children of an
+    'and' are examined one by one; an 'or' is kept only when all of its children are
+    leaf predicates on one and the same key; any other predicate is kept when its field
+    is in ``keys``.
+    """
+    if not input_predicate or not keys:
+        return
+
+    method = input_predicate.method
+    if method == 'and':
+        for child in input_predicate.literals:
+            extract_predicate_to_dict(result, child, keys)
+    elif method == 'or':
+        children = input_predicate.literals
+        # nested and/or children carry no field of their own and are not supported
+        if not children or any(child.field is None for child in children):
+            return
+        fields = {child.field for child in children}
+        if len(fields) == 1:
+            (only_field,) = fields
+            if only_field in keys:
+                result[only_field].append(input_predicate)
+    elif input_predicate.field in keys:
+        result[input_predicate.field].append(input_predicate)
diff --git a/paimon-python/pypaimon/read/reader/format_avro_reader.py 
b/paimon-python/pypaimon/read/reader/format_avro_reader.py
index 83e90606e7..4ce7c04ed4 100644
--- a/paimon-python/pypaimon/read/reader/format_avro_reader.py
+++ b/paimon-python/pypaimon/read/reader/format_avro_reader.py
@@ -20,11 +20,11 @@ from typing import List, Optional
 
 import fastavro
 import pyarrow as pa
+import pyarrow.compute as pc
 import pyarrow.dataset as ds
 from pyarrow import RecordBatch
 
 from pypaimon.common.file_io import FileIO
-from pypaimon.common.predicate import Predicate
 from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
 from pypaimon.schema.data_types import DataField, PyarrowFieldParser
 
@@ -35,26 +35,18 @@ class FormatAvroReader(RecordBatchReader):
     provided predicate and projection, and converts Avro records to 
RecordBatch format.
     """
 
-    def __init__(self, file_io: FileIO, file_path: str, primary_keys: 
List[str],
-                 fields: List[str], full_fields: List[DataField], predicate: 
Predicate, batch_size: int = 4096):
+    def __init__(self, file_io: FileIO, file_path: str, read_fields: 
List[str], full_fields: List[DataField],
+                 push_down_predicate: pc.Expression | bool, batch_size: int = 
4096):
         self._file = file_io.filesystem.open_input_file(file_path)
         self._avro_reader = fastavro.reader(self._file)
         self._batch_size = batch_size
-        self._primary_keys = primary_keys
+        self._push_down_predicate = push_down_predicate
 
-        self._fields = fields
+        self._fields = read_fields
         full_fields_map = {field.name: field for field in full_fields}
-        projected_data_fields = [full_fields_map[name] for name in fields]
+        projected_data_fields = [full_fields_map[name] for name in read_fields]
         self._schema = 
PyarrowFieldParser.from_paimon_schema(projected_data_fields)
 
-        if primary_keys:
-            # TODO: utilize predicate to improve performance
-            predicate = None
-        if predicate is not None:
-            self._predicate = predicate.to_arrow()
-        else:
-            self._predicate = None
-
     def read_arrow_batch(self) -> Optional[RecordBatch]:
         pydict_data = {name: [] for name in self._fields}
         records_in_batch = 0
@@ -68,12 +60,12 @@ class FormatAvroReader(RecordBatchReader):
 
         if records_in_batch == 0:
             return None
-        if self._predicate is None:
+        if self._push_down_predicate is None:
             return pa.RecordBatch.from_pydict(pydict_data, self._schema)
         else:
             pa_batch = pa.Table.from_pydict(pydict_data, self._schema)
             dataset = ds.InMemoryDataset(pa_batch)
-            scanner = dataset.scanner(filter=self._predicate)
+            scanner = dataset.scanner(filter=self._push_down_predicate)
             combine_chunks = scanner.to_table().combine_chunks()
             if combine_chunks.num_rows > 0:
                 return combine_chunks.to_batches()[0]
diff --git a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py 
b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
index b01f8113b7..ecef589391 100644
--- a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
+++ b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
@@ -18,11 +18,11 @@
 
 from typing import List, Optional
 
+import pyarrow.compute as pc
 import pyarrow.dataset as ds
 from pyarrow import RecordBatch
 
 from pypaimon.common.file_io import FileIO
-from pypaimon.common.predicate import Predicate
 from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
 
 
@@ -32,19 +32,12 @@ class FormatPyArrowReader(RecordBatchReader):
     and filters it based on the provided predicate and projection.
     """
 
-    def __init__(self, file_io: FileIO, file_format: str, file_path: str, 
primary_keys: List[str],
-                 fields: List[str], predicate: Predicate, batch_size: int = 
4096):
-
-        if primary_keys:
-            # TODO: utilize predicate to improve performance
-            predicate = None
-        if predicate is not None:
-            predicate = predicate.to_arrow()
-
+    def __init__(self, file_io: FileIO, file_format: str, file_path: str, 
read_fields: List[str],
+                 push_down_predicate: pc.Expression | bool, batch_size: int = 
4096):
         self.dataset = ds.dataset(file_path, format=file_format, 
filesystem=file_io.filesystem)
         self.reader = self.dataset.scanner(
-            columns=fields,
-            filter=predicate,
+            columns=read_fields,
+            filter=push_down_predicate,
             batch_size=batch_size
         ).to_reader()
 
@@ -58,36 +51,3 @@ class FormatPyArrowReader(RecordBatchReader):
         if self.reader is not None:
             self.reader.close()
             self.reader = None
-
-
-def _filter_predicate_by_primary_keys(predicate: Predicate, primary_keys):
-    """
-    Filter out predicates that are not related to primary key fields.
-    """
-    if predicate is None or primary_keys is None:
-        return predicate
-
-    if predicate.method in ['and', 'or']:
-        filtered_literals = []
-        for literal in predicate.literals:
-            filtered = _filter_predicate_by_primary_keys(literal, primary_keys)
-            if filtered is not None:
-                filtered_literals.append(filtered)
-
-        if not filtered_literals:
-            return None
-
-        if len(filtered_literals) == 1:
-            return filtered_literals[0]
-
-        return Predicate(
-            method=predicate.method,
-            index=predicate.index,
-            field=predicate.field,
-            literals=filtered_literals
-        )
-
-    if predicate.field in primary_keys:
-        return predicate
-    else:
-        return None
diff --git a/paimon-python/pypaimon/read/split_read.py 
b/paimon-python/pypaimon/read/split_read.py
index 99f8a4da21..1fe0a89d0e 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -49,11 +49,13 @@ NULL_FIELD_INDEX = -1
 class SplitRead(ABC):
     """Abstract base class for split reading operations."""
 
-    def __init__(self, table, predicate: Optional[Predicate], read_type: 
List[DataField], split: Split):
+    def __init__(self, table, predicate: Optional[Predicate], 
push_down_predicate,
+                 read_type: List[DataField], split: Split):
         from pypaimon.table.file_store_table import FileStoreTable
 
         self.table: FileStoreTable = table
         self.predicate = predicate
+        self.push_down_predicate = push_down_predicate
         self.split = split
         self.value_arity = len(read_type)
 
@@ -72,11 +74,11 @@ class SplitRead(ABC):
 
         format_reader: RecordBatchReader
         if file_format == "avro":
-            format_reader = FormatAvroReader(self.table.file_io, file_path, 
self.table.primary_keys,
-                                             
self._get_final_read_data_fields(), self.read_fields, self.predicate)
+            format_reader = FormatAvroReader(self.table.file_io, file_path, 
self._get_final_read_data_fields(),
+                                             self.read_fields, 
self.push_down_predicate)
         elif file_format == "parquet" or file_format == "orc":
-            format_reader = FormatPyArrowReader(self.table.file_io, 
file_format, file_path, self.table.primary_keys,
-                                                
self._get_final_read_data_fields(), self.predicate)
+            format_reader = FormatPyArrowReader(self.table.file_io, 
file_format, file_path,
+                                                
self._get_final_read_data_fields(), self.push_down_predicate)
         else:
             raise ValueError(f"Unexpected file format: {file_format}")
 
diff --git a/paimon-python/pypaimon/read/table_read.py 
b/paimon-python/pypaimon/read/table_read.py
index b8a28c19d1..621549d832 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -19,8 +19,11 @@ from typing import Iterator, List, Optional
 
 import pandas
 import pyarrow
+import pyarrow.compute as pc
 
 from pypaimon.common.predicate import Predicate
+from pypaimon.common.predicate_builder import PredicateBuilder
+from pypaimon.read.push_down_utils import extract_predicate_to_list
 from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
 from pypaimon.read.split import Split
 from pypaimon.read.split_read import (MergeFileSplitRead, RawFileSplitRead,
@@ -37,6 +40,7 @@ class TableRead:
 
         self.table: FileStoreTable = table
         self.predicate = predicate
+        self.push_down_predicate = self._push_down_predicate()
         self.read_type = read_type
 
     def to_iterator(self, splits: List[Split]) -> Iterator:
@@ -78,12 +82,12 @@ class TableRead:
                             row_tuple_chunk.append(row.row_tuple[row.offset: 
row.offset + row.arity])
 
                             if len(row_tuple_chunk) >= chunk_size:
-                                batch = 
convert_rows_to_arrow_batch(row_tuple_chunk, schema)
+                                batch = 
self.convert_rows_to_arrow_batch(row_tuple_chunk, schema)
                                 yield batch
                                 row_tuple_chunk = []
 
                     if row_tuple_chunk:
-                        batch = convert_rows_to_arrow_batch(row_tuple_chunk, 
schema)
+                        batch = 
self.convert_rows_to_arrow_batch(row_tuple_chunk, schema)
                         yield batch
             finally:
                 reader.close()
@@ -105,11 +109,27 @@ class TableRead:
 
         return ray.data.from_arrow(self.to_arrow(splits))
 
+    def _push_down_predicate(self) -> 'pc.Expression | bool | None':
+        """
+        Build the filter handed to the file format readers.
+
+        For primary-key tables only the parts of the predicate on primary-key fields
+        are pushed down (an 'or' only when all of its fields are keys); presumably
+        because filtering on non-key fields before records with the same key are
+        merged could surface a stale version. Other tables push down the whole predicate.
+
+        :return: the result of ``Predicate.to_arrow()``, or None when there is no
+            predicate or nothing can be pushed down.
+        """
+        if self.predicate is None:
+            return None
+        if not self.table.is_primary_key_table:
+            return self.predicate.to_arrow()
+
+        pk_conditions = []
+        extract_predicate_to_list(pk_conditions, self.predicate, self.table.primary_keys)
+        if not pk_conditions:
+            return None
+        # the field index is unused for arrow field
+        return PredicateBuilder(self.table.fields).and_predicates(pk_conditions).to_arrow()
+
     def _create_split_read(self, split: Split) -> SplitRead:
         if self.table.is_primary_key_table and not split.raw_convertible:
             return MergeFileSplitRead(
                 table=self.table,
                 predicate=self.predicate,
+                push_down_predicate=self.push_down_predicate,
                 read_type=self.read_type,
                 split=split
             )
@@ -117,12 +137,13 @@ class TableRead:
             return RawFileSplitRead(
                 table=self.table,
                 predicate=self.predicate,
+                push_down_predicate=self.push_down_predicate,
                 read_type=self.read_type,
                 split=split
             )
 
-
-def convert_rows_to_arrow_batch(row_tuples: List[tuple], schema: 
pyarrow.Schema) -> pyarrow.RecordBatch:
-    columns_data = zip(*row_tuples)
-    pydict = {name: list(column) for name, column in zip(schema.names, 
columns_data)}
-    return pyarrow.RecordBatch.from_pydict(pydict, schema=schema)
+    @staticmethod
+    def convert_rows_to_arrow_batch(row_tuples: List[tuple], schema: pyarrow.Schema) -> pyarrow.RecordBatch:
+        """
+        Transpose row tuples into columns and build a RecordBatch typed by ``schema``.
+
+        Columns are paired with ``schema.names`` by position, so each tuple is
+        presumably laid out in schema order (one value per field).
+        """
+        transposed = zip(*row_tuples)
+        return pyarrow.RecordBatch.from_pydict(
+            dict(zip(schema.names, map(list, transposed))),
+            schema=schema,
+        )
diff --git a/paimon-python/pypaimon/read/table_scan.py 
b/paimon-python/pypaimon/read/table_scan.py
index d89ddd0bcb..1c2c4f33dc 100644
--- a/paimon-python/pypaimon/read/table_scan.py
+++ b/paimon-python/pypaimon/read/table_scan.py
@@ -20,12 +20,15 @@ from collections import defaultdict
 from typing import Callable, List, Optional
 
 from pypaimon.common.predicate import Predicate
+from pypaimon.common.predicate_builder import PredicateBuilder
 from pypaimon.manifest.manifest_file_manager import ManifestFileManager
 from pypaimon.manifest.manifest_list_manager import ManifestListManager
 from pypaimon.manifest.schema.data_file_meta import DataFileMeta
 from pypaimon.manifest.schema.manifest_entry import ManifestEntry
 from pypaimon.read.interval_partition import IntervalPartition, SortedRun
 from pypaimon.read.plan import Plan
+from pypaimon.read.push_down_utils import (extract_predicate_to_dict,
+                                           extract_predicate_to_list)
 from pypaimon.read.split import Split
 from pypaimon.schema.data_types import DataField
 from pypaimon.snapshot.snapshot_manager import SnapshotManager
@@ -49,15 +52,23 @@ class TableScan:
         self.manifest_list_manager = ManifestListManager(table)
         self.manifest_file_manager = ManifestFileManager(table)
 
-        self.partition_conditions = self._extract_partition_conditions()
+        pk_conditions = []
+        trimmed_pk = [field.name for field in 
self.table.table_schema.get_trimmed_primary_key_fields()]
+        extract_predicate_to_list(pk_conditions, self.predicate, trimmed_pk)
+        self.primary_key_predicate = 
PredicateBuilder(self.table.fields).and_predicates(pk_conditions)
+
+        partition_conditions = defaultdict(list)
+        extract_predicate_to_dict(partition_conditions, self.predicate, 
self.table.partition_keys)
+        self.partition_key_predicate = partition_conditions
+
         self.target_split_size = 128 * 1024 * 1024
         self.open_file_cost = 4 * 1024 * 1024
 
         self.idx_of_this_subtask = None
         self.number_of_para_subtasks = None
 
-        self.only_read_real_buckets = True if self.table.options.get('bucket',
-                                                                     -1) == 
BucketMode.POSTPONE_BUCKET.value else False
+        self.only_read_real_buckets = True \
+            if (self.table.options.get('bucket', -1) == 
BucketMode.POSTPONE_BUCKET.value) else False
 
     def plan(self) -> Plan:
         latest_snapshot = self.snapshot_manager.get_latest_snapshot()
@@ -129,7 +140,7 @@ class TableScan:
 
         filtered_files = []
         for file_entry in file_entries:
-            if self.partition_conditions and not 
self._filter_by_partition(file_entry):
+            if self.partition_key_predicate and not 
self._filter_by_partition(file_entry):
                 continue
             if not self._filter_by_stats(file_entry):
                 continue
@@ -138,98 +149,31 @@ class TableScan:
         return filtered_files
 
     def _filter_by_partition(self, file_entry: ManifestEntry) -> bool:
-        # TODO: refactor with a better solution
         partition_dict = file_entry.partition.to_dict()
-        for field_name, condition in self.partition_conditions.items():
+        for field_name, conditions in self.partition_key_predicate.items():
             partition_value = partition_dict[field_name]
-            if condition['op'] == '=':
-                if str(partition_value) != str(condition['value']):
-                    return False
-            elif condition['op'] == 'in':
-                if str(partition_value) not in [str(v) for v in 
condition['values']]:
-                    return False
-            elif condition['op'] == 'notIn':
-                if str(partition_value) in [str(v) for v in 
condition['values']]:
-                    return False
-            elif condition['op'] == '>':
-                if partition_value <= condition['values']:
-                    return False
-            elif condition['op'] == '>=':
-                if partition_value < condition['values']:
-                    return False
-            elif condition['op'] == '<':
-                if partition_value >= condition['values']:
-                    return False
-            elif condition['op'] == '<=':
-                if partition_value > condition['values']:
+            for predicate in conditions:
+                if not predicate.test_by_value(partition_value):
                     return False
         return True
 
     def _filter_by_stats(self, file_entry: ManifestEntry) -> bool:
-        # TODO: real support for filtering by stat
-        return True
-
-    def _extract_partition_conditions(self) -> dict:
-        if not self.predicate or not self.table.partition_keys:
-            return {}
-
-        conditions = {}
-        self._extract_conditions_from_predicate(self.predicate, conditions, 
self.table.partition_keys)
-        return conditions
-
-    def _extract_conditions_from_predicate(self, predicate: 'Predicate', 
conditions: dict,
-                                           partition_keys: List[str]):
-        if predicate.method == 'and':
-            for sub_predicate in predicate.literals:
-                self._extract_conditions_from_predicate(sub_predicate, 
conditions, partition_keys)
-            return
-        elif predicate.method == 'or':
-            all_partition_conditions = True
-            for sub_predicate in predicate.literals:
-                if sub_predicate.field not in partition_keys:
-                    all_partition_conditions = False
-                    break
-            if all_partition_conditions:
-                for sub_predicate in predicate.literals:
-                    self._extract_conditions_from_predicate(sub_predicate, 
conditions, partition_keys)
-            return
-
-        if predicate.field in partition_keys:
-            if predicate.method == 'equal':
-                conditions[predicate.field] = {
-                    'op': '=',
-                    'value': predicate.literals[0] if predicate.literals else 
None
-                }
-            elif predicate.method == 'in':
-                conditions[predicate.field] = {
-                    'op': 'in',
-                    'values': predicate.literals if predicate.literals else []
-                }
-            elif predicate.method == 'notIn':
-                conditions[predicate.field] = {
-                    'op': 'notIn',
-                    'values': predicate.literals if predicate.literals else []
-                }
-            elif predicate.method == 'greaterThan':
-                conditions[predicate.field] = {
-                    'op': '>',
-                    'value': predicate.literals[0] if predicate.literals else 
None
-                }
-            elif predicate.method == 'greaterOrEqual':
-                conditions[predicate.field] = {
-                    'op': '>=',
-                    'value': predicate.literals[0] if predicate.literals else 
None
-                }
-            elif predicate.method == 'lessThan':
-                conditions[predicate.field] = {
-                    'op': '<',
-                    'value': predicate.literals[0] if predicate.literals else 
None
-                }
-            elif predicate.method == 'lessOrEqual':
-                conditions[predicate.field] = {
-                    'op': '<=',
-                    'value': predicate.literals[0] if predicate.literals else 
None
-                }
+        if file_entry.kind != 0:
+            return False
+        if self.table.is_primary_key_table:
+            predicate = self.primary_key_predicate
+            stats = file_entry.file.key_stats
+        else:
+            predicate = self.predicate
+            stats = file_entry.file.value_stats
+        return predicate is None or predicate.test_by_stats({
+            "min_values": stats.min_values.to_dict(),
+            "max_values": stats.max_values.to_dict(),
+            "null_counts": {
+                stats.min_values.fields[i].name: stats.null_counts[i] for i in 
range(len(stats.min_values.fields))
+            },
+            "row_count": file_entry.file.row_count,
+        })
 
     def _create_append_only_splits(self, file_entries: List[ManifestEntry]) -> 
List['Split']:
         if not file_entries:
@@ -240,7 +184,7 @@ class TableScan:
         def weight_func(f: DataFileMeta) -> int:
             return max(f.file_size, self.open_file_cost)
 
-        packed_files: List[List[DataFileMeta]] = _pack_for_ordered(data_files, 
weight_func, self.target_split_size)
+        packed_files: List[List[DataFileMeta]] = 
self._pack_for_ordered(data_files, weight_func, self.target_split_size)
         return self._build_split_from_pack(packed_files, file_entries, False)
 
     def _create_primary_key_splits(self, file_entries: List[ManifestEntry]) -> 
List['Split']:
@@ -257,7 +201,8 @@ class TableScan:
         def weight_func(fl: List[DataFileMeta]) -> int:
             return max(sum(f.file_size for f in fl), self.open_file_cost)
 
-        packed_files: List[List[List[DataFileMeta]]] = 
_pack_for_ordered(sections, weight_func, self.target_split_size)
+        packed_files: List[List[List[DataFileMeta]]] = 
self._pack_for_ordered(sections, weight_func,
+                                                                              
self.target_split_size)
         flatten_packed_files: List[List[DataFileMeta]] = [
             [file for sub_pack in pack for file in sub_pack]
             for pack in packed_files
@@ -295,23 +240,23 @@ class TableScan:
                 splits.append(split)
         return splits
 
+    @staticmethod
+    def _pack_for_ordered(items: List, weight_func: Callable, target_weight: 
int) -> List[List]:
+        packed = []
+        bin_items = []
+        bin_weight = 0
 
-def _pack_for_ordered(items: List, weight_func: Callable, target_weight: int) 
-> List[List]:
-    packed = []
-    bin_items = []
-    bin_weight = 0
+        for item in items:
+            weight = weight_func(item)
+            if bin_weight + weight > target_weight and len(bin_items) > 0:
+                packed.append(bin_items)
+                bin_items = []
+                bin_weight = 0
 
-    for item in items:
-        weight = weight_func(item)
-        if bin_weight + weight > target_weight and len(bin_items) > 0:
-            packed.append(bin_items)
-            bin_items.clear()
-            bin_weight = 0
-
-        bin_weight += weight
-        bin_items.append(item)
+            bin_weight += weight
+            bin_items.append(item)
 
-    if len(bin_items) > 0:
-        packed.append(bin_items)
+        if len(bin_items) > 0:
+            packed.append(bin_items)
 
-    return packed
+        return packed
diff --git a/paimon-python/pypaimon/tests/predicate_push_down_test.py 
b/paimon-python/pypaimon/tests/predicate_push_down_test.py
new file mode 100644
index 0000000000..b5b403f674
--- /dev/null
+++ b/paimon-python/pypaimon/tests/predicate_push_down_test.py
@@ -0,0 +1,151 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon.catalog.catalog_factory import CatalogFactory
+from pypaimon.common.predicate_builder import PredicateBuilder
+from pypaimon.schema.schema import Schema
+
+
+class PredicatePushDownTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({
+            'warehouse': cls.warehouse
+        })
+        cls.catalog.create_database('default', False)
+
+        cls.pa_schema = pa.schema([
+            pa.field('key1', pa.int32(), nullable=False),
+            pa.field('key2', pa.string(), nullable=False),
+            ('behavior', pa.string()),
+            pa.field('dt1', pa.string(), nullable=False),
+            pa.field('dt2', pa.int32(), nullable=False)
+        ])
+        cls.expected = pa.Table.from_pydict({
+            'key1': [1, 2, 3, 4, 5, 7, 8],
+            'key2': ['h', 'g', 'f', 'e', 'd', 'b', 'a'],
+            'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'],
+            'dt1': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
+            'dt2': [2, 2, 1, 2, 2, 1, 2],
+        }, schema=cls.pa_schema)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def testPkReaderWithFilter(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema,
+                                            partition_keys=['dt1', 'dt2'],
+                                            primary_keys=['key1', 'key2'],
+                                            options={'bucket': '1'})
+        self.catalog.create_table('default.test_pk_filter', schema, False)
+        table = self.catalog.get_table('default.test_pk_filter')
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'key1': [1, 2, 3, 4],
+            'key2': ['h', 'g', 'f', 'e'],
+            'behavior': ['a', 'b', 'c', None],
+            'dt1': ['p1', 'p1', 'p2', 'p1'],
+            'dt2': [2, 2, 1, 2],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data2 = {
+            'key1': [5, 2, 7, 8],
+            'key2': ['d', 'g', 'b', 'a'],
+            'behavior': ['e', 'b-new', 'g', 'h'],
+            'dt1': ['p2', 'p1', 'p1', 'p2'],
+            'dt2': [2, 2, 1, 2]
+        }
+        pa_table = pa.Table.from_pydict(data2, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # test filter by partition
+        predicate_builder: PredicateBuilder = 
table.new_read_builder().new_predicate_builder()
+        p1 = predicate_builder.startswith('dt1', "p1")
+        p2 = predicate_builder.is_in('dt1', ["p2"])
+        p3 = predicate_builder.or_predicates([p1, p2])
+        p4 = predicate_builder.equal('dt2', 2)
+        g1 = predicate_builder.and_predicates([p3, p4])
+        # (dt1 startswith 'p1' or dt1 is_in ["p2"]) and dt2 == 2
+        read_builder = table.new_read_builder().with_filter(g1)
+        splits = read_builder.new_scan().plan().splits()
+        self.assertEqual(len(splits), 2)
+        self.assertEqual(splits[0].partition.to_dict()["dt2"], 2)
+        self.assertEqual(splits[1].partition.to_dict()["dt2"], 2)
+
+        # test filter by stats
+        predicate_builder: PredicateBuilder = 
table.new_read_builder().new_predicate_builder()
+        p1 = predicate_builder.equal('key1', 7)
+        p2 = predicate_builder.is_in('key2', ["e", "f"])
+        p3 = predicate_builder.or_predicates([p1, p2])
+        p4 = predicate_builder.greater_than('key1', 3)
+        g1 = predicate_builder.and_predicates([p3, p4])
+        # (key1 == 7 or key2 is_in ["e", "f"]) and key1 > 3
+        read_builder = table.new_read_builder().with_filter(g1)
+        splits = read_builder.new_scan().plan().splits()
+        # initial splits meta:
+        # p1, 2 -> 2g, 2g; 1e, 4h
+        # p2, 1 -> 3f, 3f
+        # p2, 2 -> 5a, 8d
+        # p1, 1 -> 7b, 7b
+        self.assertEqual(len(splits), 3)
+        # expect to filter out `p1, 2 -> 2g, 2g` and `p2, 1 -> 3f, 3f`
+        count = 0
+        for split in splits:
+            if split.partition.values == ["p1", 2]:
+                count += 1
+                self.assertEqual(len(split.files), 1)
+                min_values = split.files[0].value_stats.min_values.to_dict()
+                max_values = split.files[0].value_stats.max_values.to_dict()
+                self.assertTrue(min_values["key1"] == 1 and min_values["key2"] 
== "e"
+                                and max_values["key1"] == 4 and 
max_values["key2"] == "h")
+            elif split.partition.values == ["p2", 2]:
+                count += 1
+                min_values = split.files[0].value_stats.min_values.to_dict()
+                max_values = split.files[0].value_stats.max_values.to_dict()
+                self.assertTrue(min_values["key1"] == 5 and min_values["key2"] 
== "a"
+                                and max_values["key1"] == 8 and 
max_values["key2"] == "d")
+            elif split.partition.values == ["p1", 1]:
+                count += 1
+                min_values = split.files[0].value_stats.min_values.to_dict()
+                max_values = split.files[0].value_stats.max_values.to_dict()
+                self.assertTrue(min_values["key1"] == max_values["key1"] == 7
+                                and min_values["key2"] == max_values["key2"] == "b")
+        self.assertEqual(count, 3)
diff --git a/paimon-python/pypaimon/write/file_store_commit.py 
b/paimon-python/pypaimon/write/file_store_commit.py
index e7bf7ba534..efe8207ebe 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -81,15 +81,15 @@ class FileStoreCommit:
             num_added_files=sum(len(msg.new_files) for msg in commit_messages),
             num_deleted_files=0,
             partition_stats=SimpleStats(
-                min_value=BinaryRow(
+                min_values=BinaryRow(
                     values=partition_min_stats,
                     fields=self.table.table_schema.get_partition_key_fields(),
                 ),
-                max_value=BinaryRow(
+                max_values=BinaryRow(
                     values=partition_max_stats,
                     fields=self.table.table_schema.get_partition_key_fields(),
                 ),
-                null_count=partition_null_counts,
+                null_counts=partition_null_counts,
             ),
             schema_id=self.table.table_schema.id,
         )
diff --git a/paimon-python/pypaimon/write/writer/data_writer.py 
b/paimon-python/pypaimon/write/writer/data_writer.py
index 5d9641718c..bc4553847c 100644
--- a/paimon-python/pypaimon/write/writer/data_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_writer.py
@@ -19,7 +19,7 @@ import uuid
 from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import pyarrow as pa
 import pyarrow.compute as pc
@@ -128,13 +128,13 @@ class DataWriter(ABC):
             for field in self.table.table_schema.fields
         }
         all_fields = self.table.table_schema.fields
-        min_value_stats = [column_stats[field.name]['min_value'] for field in 
all_fields]
-        max_value_stats = [column_stats[field.name]['max_value'] for field in 
all_fields]
-        value_null_counts = [column_stats[field.name]['null_count'] for field 
in all_fields]
+        min_value_stats = [column_stats[field.name]['min_values'] for field in 
all_fields]
+        max_value_stats = [column_stats[field.name]['max_values'] for field in 
all_fields]
+        value_null_counts = [column_stats[field.name]['null_counts'] for field 
in all_fields]
         key_fields = self.trimmed_primary_key_fields
-        min_key_stats = [column_stats[field.name]['min_value'] for field in 
key_fields]
-        max_key_stats = [column_stats[field.name]['max_value'] for field in 
key_fields]
-        key_null_counts = [column_stats[field.name]['null_count'] for field in 
key_fields]
+        min_key_stats = [column_stats[field.name]['min_values'] for field in 
key_fields]
+        max_key_stats = [column_stats[field.name]['max_values'] for field in 
key_fields]
+        key_null_counts = [column_stats[field.name]['null_counts'] for field 
in key_fields]
         if not all(count == 0 for count in key_null_counts):
             raise RuntimeError("Primary key should not be null")
 
@@ -203,21 +203,21 @@ class DataWriter(ABC):
         return best_split
 
     @staticmethod
-    def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> 
dict:
+    def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> 
Dict:
         column_array = record_batch.column(column_name)
         if column_array.null_count == len(column_array):
             return {
-                "min_value": None,
-                "max_value": None,
-                "null_count": column_array.null_count,
+                "min_values": None,
+                "max_values": None,
+                "null_counts": column_array.null_count,
             }
-        min_value = pc.min(column_array).as_py()
-        max_value = pc.max(column_array).as_py()
-        null_count = column_array.null_count
+        min_values = pc.min(column_array).as_py()
+        max_values = pc.max(column_array).as_py()
+        null_counts = column_array.null_count
         return {
-            "min_value": min_value,
-            "max_value": max_value,
-            "null_count": null_count,
+            "min_values": min_values,
+            "max_values": max_values,
+            "null_counts": null_counts,
         }
 
 

Reply via email to