This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 5eee48055b [python] Add per-partition bucket pruning for HASH_FIXED tables (#7804)
5eee48055b is described below

commit 5eee48055baed6bee84a49887cbbb2f54fd215b5
Author: chaoyang <[email protected]>
AuthorDate: Mon May 11 16:11:45 2026 +0800

    [python] Add per-partition bucket pruning for HASH_FIXED tables (#7804)
    
    PR-5.4 (#7744) added bucket pruning for HASH_FIXED tables but only on
    the bucket-key dimension. Predicates that mix a partition column and
    a bucket column under a top-level OR — e.g.
    ``(part='a' AND id=1) OR (part='b' AND id=2)`` — couldn't be pruned:
    the OR mixes two dimensions, so the existing logic gave up and read
    every bucket in both partitions. PR-5.4 left this as a TODO in the
    module docstring.
    
    ## Effect
    
    Same query now reads exactly one bucket per partition (the bucket
    holding ``id=1`` in ``part='a'``, the bucket holding ``id=2`` in
    ``part='b'``). The selector first substitutes each concrete partition
    value into the predicate — partition leaves fold to true/false, so the
    OR collapses to the surviving branch's bucket-key constraint (e.g.
    ``id=1`` for ``part='a'``) — and bucket selection runs on that
    simplified form.
    
    Soundness contract is unchanged: the bucket set remains a superset
    of the buckets that contain matching rows; any error falls open to
    "all buckets accept", never drops a bucket with matches.
    
    Two commits — the helper and the ``FileScanner`` wiring. 10 unit tests
    cover the predicate-fold walker (``replace_partition_predicate``) and
    the partition-aware selector, including its per-partition cache; one
    e2e test on a 2-partition × 4-bucket table checks that the mixed-OR
    query reads ≤ 2 splits instead of one per (partition, bucket).
---
 .../read/scanner/bucket_select_converter.py        | 390 ++++++++++++++++-----
 .../pypaimon/read/scanner/file_scanner.py          |  31 +-
 .../pypaimon/tests/pushdown_bucket_test.py         | 271 ++++++++++++++
 3 files changed, 588 insertions(+), 104 deletions(-)

diff --git a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py 
b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
index da3f0e0f47..7b2e9e104f 100644
--- a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
+++ b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
@@ -51,38 +51,26 @@ Conservative scope (deliberately narrower than Java's 
general flexibility):
   * Total cartesian product capped at MAX_VALUES (1000), again matching
     Java; above that, fall back to a full scan.
 
-Returns a callable ``selector(bucket: int, total_buckets: int) -> bool``.
-The callable is cached per ``total_buckets`` to handle the rare case
-where bucket count varies across snapshots (rescale).
-
-TODO: per-partition predicate pre-evaluation.
-
-  Predicates of the form ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk
-  IN (3,4))`` currently fall through to "no pruning" because the top-level
-  OR mixes partition and bucket-key constraints. Java simplifies the
-  predicate per concrete partition value first (replacing partition
-  leaves with literal true/false and folding AND/OR), so each partition
-  gets a tighter bucket-key predicate and the corresponding bucket set.
-
-  Implementing this here would need three pieces:
-
-    * a Predicate-replace walker that substitutes a partition's actual
-      values into partition-column leaves (mirrors Java's
-      ``paimon-common/.../predicate/PartitionValuePredicateVisitor.java``).
-    * lifting ``_Selector`` to key its cache by
-      ``(partition, total_buckets)`` instead of just ``total_buckets``.
-    * threading the partition value into the early manifest filter
-      ``FileScanner._build_early_bucket_filter`` (currently sees only
-      ``(bucket, total_buckets)``).
+Returns a callable ``selector(partition, bucket: int, total_buckets: int)
+-> bool``. The callable is cached per ``(partition, total_buckets)`` to
+handle (a) bucket count variation across snapshots (rescale) and (b)
+per-partition predicate specialisation: predicates of the form
+``(part='a' AND bk IN (1,2)) OR (part='b' AND bk IN (3,4))`` are
+simplified per concrete partition value before bucket selection, so each
+partition gets its own tight bucket set.
+
+When ``partition`` is ``None`` (early manifest filter that has not yet
+deserialised the entry), the selector falls back to a partition-agnostic
+result — sound but possibly wider than the per-partition tight set.
 """
 
 from itertools import product
-from typing import Any, Callable, Dict, FrozenSet, List, Optional
+from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, 
Union
 
 from pypaimon.common.predicate import Predicate
 from pypaimon.schema.data_types import DataField
 from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer
-from pypaimon.table.row.internal_row import RowKind
+from pypaimon.table.row.internal_row import InternalRow, RowKind
 from pypaimon.write.row_key_extractor import (_bucket_from_hash,
                                               _hash_bytes_by_words)
 
@@ -177,78 +165,17 @@ def _extract_or_clause(or_pred: Predicate,
     return None if slot is None else [slot, values]
 
 
-class _Selector:
-    """Callable bucket filter, lazy + cached per ``total_buckets``."""
-
-    __slots__ = ('_combinations', '_bucket_key_fields', '_cache')
-
-    def __init__(self, combinations: List[List[Any]],
-                 bucket_key_fields: List[DataField]):
-        self._combinations = combinations
-        self._bucket_key_fields = bucket_key_fields
-        self._cache: Dict[int, FrozenSet[int]] = {}
-
-    def __call__(self, bucket: int, total_buckets: int) -> bool:
-        # ``total_buckets <= 0`` shows up for postpone / legacy / special
-        # entries and must NOT be pruned: returning False here would drop
-        # rows the writer hashed under a different convention. Fail open.
-        if total_buckets <= 0:
-            return True
-        try:
-            return bucket in self._compute(total_buckets)
-        except Exception:
-            # Fail open on any hashing/serialization error (e.g. a literal
-            # type that doesn't match the bucket-key column's atomic type:
-            # ``pb.equal('id_bigint', 'foo')`` — GenericRowSerializer raises
-            # struct.error trying to pack the string as int64). Crashing
-            # the entire scan here would be worse than skipping pruning;
-            # the soundness contract still forbids false-negatives.
-            return True
-
-    def _compute(self, total_buckets: int) -> FrozenSet[int]:
-        cached = self._cache.get(total_buckets)
-        if cached is not None:
-            return cached
-        result = set()
-        for combo in self._combinations:
-            row = GenericRow(list(combo), self._bucket_key_fields,
-                             RowKind.INSERT)
-            serialized = GenericRowSerializer.to_bytes(row)
-            # Skip the 4-byte length prefix — matches the writer's hash
-            # input exactly (see RowKeyExtractor._binary_row_hash_code).
-            h = _hash_bytes_by_words(serialized[4:])
-            result.add(_bucket_from_hash(h, total_buckets))
-        frozen = frozenset(result)
-        self._cache[total_buckets] = frozen
-        return frozen
-
-    @property
-    def bucket_combinations(self) -> int:
-        """Number of (bucket-key) combinations used to compute the filter.
-        Exposed for tests / observability."""
-        return len(self._combinations)
-
-
-def create_bucket_selector(
-        predicate: Optional[Predicate],
-        bucket_key_fields: List[DataField]) -> Optional[Callable[[int, int], 
bool]]:
-    """Try to derive a bucket selector from ``predicate`` constrained to
-    ``bucket_key_fields``.
+def _build_combinations(
+        predicate: Predicate,
+        bucket_key_fields: List[DataField]) -> Optional[List[List[Any]]]:
+    """Walk ``predicate`` for top-level AND clauses constraining bucket-key
+    columns by Equal/In, intersect repeated constraints, and return the
+    cartesian product of literal values (one row per combination).
 
-    Returns:
-      A callable ``(bucket, total_buckets) -> bool`` if the predicate
-      pins down all bucket keys to a finite Equal/In set; otherwise None
-      (caller must NOT prune by bucket).
+    Returns None when the predicate doesn't pin down every bucket-key
+    column or when the cartesian product exceeds ``MAX_VALUES`` — the
+    caller treats that as "no pruning, all buckets accept".
     """
-    if predicate is None or not bucket_key_fields:
-        return None
-
-    # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key
-    # column types are prone to writer/reader byte-level disagreement on
-    # equal logical values. Fail open rather than risk false-negatives.
-    if _has_unsafe_bucket_key_type(bucket_key_fields):
-        return None
-
     bk_name_to_slot: Dict[str, int] = {
         f.name: i for i, f in enumerate(bucket_key_fields)
     }
@@ -294,5 +221,274 @@ def create_bucket_selector(
         if total > MAX_VALUES:
             return None
 
-    combinations = [list(combo) for combo in product(*slot_values)]
-    return _Selector(combinations, bucket_key_fields)
+    return [list(combo) for combo in product(*slot_values)]
+
+
+def _hash_combinations(combinations: List[List[Any]],
+                       bucket_key_fields: List[DataField],
+                       total_buckets: int) -> FrozenSet[int]:
+    result = set()
+    for combo in combinations:
+        row = GenericRow(list(combo), bucket_key_fields, RowKind.INSERT)
+        serialized = GenericRowSerializer.to_bytes(row)
+        # Skip the 4-byte length prefix — matches the writer's hash
+        # input exactly (see RowKeyExtractor._binary_row_hash_code).
+        h = _hash_bytes_by_words(serialized[4:])
+        result.add(_bucket_from_hash(h, total_buckets))
+    return frozenset(result)
+
+
+def _predicate_touches_partition(predicate: Predicate,
+                                 partition_field_names: Set[str]) -> bool:
+    """True if ``predicate`` references any partition column directly or
+    inside an AND/OR/NOT subtree."""
+    if predicate.method in ('and', 'or', 'not'):
+        return any(_predicate_touches_partition(c, partition_field_names)
+                   for c in (predicate.literals or []))
+    return predicate.field is not None and predicate.field in 
partition_field_names
+
+
+def _evaluate_partition_leaf(predicate: Predicate,
+                             partition_values: Dict[str, Any]) -> 
Optional[bool]:
+    """Evaluate ``predicate`` (a leaf on a partition column) against the
+    concrete partition values. Returns True / False, or ``None`` if the
+    leaf isn't safely evaluable here (caller should keep the leaf
+    unchanged — bucket selection stays sound as long as we don't fold
+    away an evaluable False).
+    """
+    field_value = partition_values.get(predicate.field)
+    tester = Predicate.testers.get(predicate.method)
+    if tester is None:
+        return None
+    try:
+        return tester.test_by_value(field_value, predicate.literals)
+    except Exception:
+        return None
+
+
+_AlwaysFalse = False  # sentinel: predicate always evaluates to False
+_AlwaysTrue = None    # sentinel: predicate cleared (always True)
+
+
+def replace_partition_predicate(
+        predicate: Predicate,
+        partition_field_names: Set[str],
+        partition_values: Dict[str, Any]) -> Optional[Union[bool, Predicate]]:
+    """Substitute partition-column leaves with their concrete values and
+    fold away always-true / always-false sub-expressions.
+
+    Three-way return:
+
+    * ``None`` — predicate is unconditionally True after substitution
+      (no constraint left for this partition).
+    * ``False`` — predicate is unconditionally False (this partition
+      cannot contain matching rows).
+    * ``Predicate`` — the simplified predicate; partition leaves are
+      gone. The caller continues bucket-key extraction on this.
+    """
+    if predicate.method == 'and':
+        new_children: List[Predicate] = []
+        for child in (predicate.literals or []):
+            simplified = replace_partition_predicate(
+                child, partition_field_names, partition_values)
+            if simplified is _AlwaysFalse:
+                return _AlwaysFalse
+            if simplified is _AlwaysTrue:
+                continue
+            new_children.append(simplified)
+        if not new_children:
+            return _AlwaysTrue
+        if len(new_children) == 1:
+            return new_children[0]
+        return Predicate(method='and', index=None, field=None,
+                         literals=new_children)
+
+    if predicate.method == 'or':
+        new_children = []
+        for child in (predicate.literals or []):
+            simplified = replace_partition_predicate(
+                child, partition_field_names, partition_values)
+            if simplified is _AlwaysTrue:
+                return _AlwaysTrue
+            if simplified is _AlwaysFalse:
+                continue
+            new_children.append(simplified)
+        if not new_children:
+            return _AlwaysFalse
+        if len(new_children) == 1:
+            return new_children[0]
+        return Predicate(method='or', index=None, field=None,
+                         literals=new_children)
+
+    # Leaf predicate.
+    if predicate.field is not None and predicate.field in 
partition_field_names:
+        truth = _evaluate_partition_leaf(predicate, partition_values)
+        if truth is True:
+            return _AlwaysTrue
+        if truth is False:
+            return _AlwaysFalse
+        # Couldn't safely evaluate — keep the leaf. Bucket selection
+        # stays sound: the leaf still gets ANDed in, just doesn't help
+        # narrow buckets for this partition.
+        return predicate
+
+    # Non-partition leaf: keep as-is.
+    return predicate
+
+
+def _partition_to_dict(partition: Optional[InternalRow],
+                       partition_fields: List[DataField]) -> Dict[str, Any]:
+    """Pull each partition column's value out of ``partition`` keyed by
+    field name. Returns an empty dict when ``partition`` is None."""
+    if partition is None:
+        return {}
+    out: Dict[str, Any] = {}
+    for i, field in enumerate(partition_fields):
+        try:
+            out[field.name] = partition.get_field(i)
+        except Exception:
+            out[field.name] = None
+    return out
+
+
+def _partition_to_cache_key(partition: Optional[InternalRow],
+                            partition_fields: List[DataField]
+                            ) -> Optional[Tuple[Any, ...]]:
+    if partition is None or not partition_fields:
+        return None
+    try:
+        return tuple(partition.get_field(i) for i in 
range(len(partition_fields)))
+    except Exception:
+        return None
+
+
+class _Selector:
+    """Callable bucket filter, lazy + cached per ``(partition, 
total_buckets)``."""
+
+    __slots__ = ('_predicate', '_bucket_key_fields', '_partition_fields',
+                 '_cache')
+
+    def __init__(self, predicate: Predicate,
+                 bucket_key_fields: List[DataField],
+                 partition_fields: Optional[List[DataField]] = None):
+        self._predicate = predicate
+        self._bucket_key_fields = bucket_key_fields
+        self._partition_fields = list(partition_fields or [])
+        self._cache: Dict[Tuple[Optional[Tuple[Any, ...]], int], 
FrozenSet[int]] = {}
+
+    def __call__(self, *args) -> bool:
+        # Accept ``(bucket, total_buckets)`` (early manifest filter that
+        # hasn't deserialised the entry yet — partition unknown) or
+        # ``(partition, bucket, total_buckets)`` (late filter on a fully
+        # decoded ``ManifestEntry``). The two-arg form is partition-
+        # agnostic; partition substitution is skipped.
+        if len(args) == 2:
+            partition = None
+            bucket, total_buckets = args
+        elif len(args) == 3:
+            partition, bucket, total_buckets = args
+        else:
+            raise TypeError(
+                "_Selector expects 2 or 3 positional args, got %d" % len(args))
+        # ``total_buckets <= 0`` shows up for postpone / legacy / special
+        # entries and must NOT be pruned: returning False here would drop
+        # rows the writer hashed under a different convention. Fail open.
+        if total_buckets <= 0:
+            return True
+        try:
+            return bucket in self._compute(partition, total_buckets)
+        except Exception:
+            # Fail open on any hashing / serialization / specialisation
+            # error (e.g. a literal type that doesn't match the bucket-key
+            # column's atomic type). Crashing the entire scan here would
+            # be worse than skipping pruning; the soundness contract still
+            # forbids false-negatives.
+            return True
+
+    def _compute(self, partition, total_buckets: int) -> FrozenSet[int]:
+        cache_key = (_partition_to_cache_key(partition, 
self._partition_fields),
+                     total_buckets)
+        cached = self._cache.get(cache_key)
+        if cached is not None:
+            return cached
+
+        effective_predicate: Optional[Union[bool, Predicate]] = self._predicate
+        if partition is not None and self._partition_fields:
+            partition_values = _partition_to_dict(partition, 
self._partition_fields)
+            partition_field_names = {f.name for f in self._partition_fields}
+            effective_predicate = replace_partition_predicate(
+                self._predicate, partition_field_names, partition_values)
+
+        if effective_predicate is _AlwaysFalse:
+            # No row in this partition can match — empty bucket set.
+            frozen: FrozenSet[int] = frozenset()
+            self._cache[cache_key] = frozen
+            return frozen
+
+        if effective_predicate is _AlwaysTrue:
+            # Predicate cleared after partition substitution — accept all
+            # buckets for this partition.
+            frozen = frozenset(range(total_buckets))
+            self._cache[cache_key] = frozen
+            return frozen
+
+        combinations = _build_combinations(effective_predicate,
+                                           self._bucket_key_fields)
+        if combinations is None:
+            # Couldn't pin down all bucket keys (or above MAX_VALUES) —
+            # fall back to "all buckets accept" for soundness.
+            frozen = frozenset(range(total_buckets))
+            self._cache[cache_key] = frozen
+            return frozen
+
+        frozen = _hash_combinations(combinations, self._bucket_key_fields,
+                                    total_buckets)
+        self._cache[cache_key] = frozen
+        return frozen
+
+
+def create_bucket_selector(
+        predicate: Optional[Predicate],
+        bucket_key_fields: List[DataField],
+        partition_fields: Optional[List[DataField]] = None,
+) -> Optional[Callable[[Any, int, int], bool]]:
+    """Try to derive a bucket selector from ``predicate`` constrained to
+    ``bucket_key_fields``.
+
+    Returns:
+      A callable ``(partition, bucket, total_buckets) -> bool``. When
+      ``partition_fields`` is given and the predicate references those
+      partition columns, the selector specialises the predicate per
+      partition value before hashing — this catches mixed forms like
+      ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk IN (3,4))`` that
+      would otherwise be unprunable. ``partition=None`` callsites
+      (early manifest filter that hasn't deserialised the entry yet)
+      simply get the partition-agnostic result.
+
+      Returns None when the predicate carries no usable bucket-key
+      constraint at all (caller must NOT prune by bucket).
+    """
+    if predicate is None or not bucket_key_fields:
+        return None
+
+    # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key
+    # column types are prone to writer/reader byte-level disagreement on
+    # equal logical values. Fail open rather than risk false-negatives.
+    if _has_unsafe_bucket_key_type(bucket_key_fields):
+        return None
+
+    # Sanity gate: if the predicate without any partition substitution
+    # already fails to pin down bucket keys AND it doesn't touch any
+    # partition columns, there's no point handing the caller a selector
+    # that always returns "all buckets" — preserve the original "return
+    # None for unprunable" contract so the caller can skip the wrap.
+    partition_names = {f.name for f in (partition_fields or [])}
+    touches_partition = (
+        bool(partition_names)
+        and _predicate_touches_partition(predicate, partition_names)
+    )
+    if not touches_partition:
+        if _build_combinations(predicate, bucket_key_fields) is None:
+            return None
+
+    return _Selector(predicate, bucket_key_fields, partition_fields)
diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py 
b/paimon-python/pypaimon/read/scanner/file_scanner.py
index 70cfa6c978..ea0e8219ee 100755
--- a/paimon-python/pypaimon/read/scanner/file_scanner.py
+++ b/paimon-python/pypaimon/read/scanner/file_scanner.py
@@ -30,6 +30,7 @@ from pypaimon.manifest.manifest_list_manager import 
ManifestListManager
 from pypaimon.manifest.schema.manifest_entry import ManifestEntry
 from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta
 from pypaimon.manifest.simple_stats_evolutions import SimpleStatsEvolutions
+from pypaimon.schema.data_types import DataField
 from pypaimon.read.plan import Plan
 from pypaimon.read.push_down_utils import (_get_all_fields,
                                            remove_row_id_filter,
@@ -356,11 +357,12 @@ class FileScanner:
         """Compose the (bucket, total_buckets) -> bool used by the manifest
         reader to drop entries before deserialising ``_FILE`` / partition.
 
-        Mirrors the BucketFilter applied at Java's InternalRow stage in
-        ``ManifestEntryCache``. The signature is intentionally minimal:
-        per-partition predicate pre-evaluation would also need
-        ``(partition, bucket, total_buckets)``, but the current selector
-        is partition-agnostic.
+        The selector is partition-aware now, but at this early stage the
+        partition field has not been deserialised yet, so callers stick
+        with the two-arg form. The selector internally falls back to a
+        partition-agnostic over-approximation; per-partition tightening
+        still happens later in ``_filter_manifest_entry`` once the entry
+        is fully decoded.
         """
         only_real = self.only_read_real_buckets
         selector = self._bucket_selector
@@ -479,7 +481,21 @@ class FileScanner:
             return None
         if not bucket_key_fields:
             return None
-        return create_bucket_selector(self.predicate, bucket_key_fields)
+        # Partition fields are passed so the selector can specialise
+        # the predicate per partition value at the late filter stage,
+        # turning ``(part='a' AND bk=1) OR (part='b' AND bk=2)`` into a
+        # precise bucket pick per partition instead of an over-scan.
+        partition_fields: Optional[List[DataField]] = None
+        if self.table.partition_keys:
+            partition_fields = [
+                self.table.field_dict[name]
+                for name in self.table.partition_keys
+                if name in self.table.field_dict
+            ]
+        return create_bucket_selector(
+            self.predicate, bucket_key_fields,
+            partition_fields=partition_fields,
+        )
 
     def _filter_manifest_entry(self, entry: ManifestEntry) -> bool:
         # Redundant safety net: the early filter in the manifest reader
@@ -489,7 +505,8 @@ class FileScanner:
             return False
         if (self._bucket_selector is not None
                 and entry.bucket >= 0
-                and not self._bucket_selector(entry.bucket, 
entry.total_buckets)):
+                and not self._bucket_selector(
+                    entry.partition, entry.bucket, entry.total_buckets)):
             return False
         if self.partition_key_predicate and not 
self.partition_key_predicate.test(entry.partition):
             return False
diff --git a/paimon-python/pypaimon/tests/pushdown_bucket_test.py 
b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
index b83283200e..80d3065ca3 100644
--- a/paimon-python/pypaimon/tests/pushdown_bucket_test.py
+++ b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
@@ -345,6 +345,217 @@ class BucketSelectConverterUnitTest(unittest.TestCase):
                                 "not crash (bucket={}, total={})".format(b, 
total))
 
 
+class PartitionAwareBucketSelectorUnitTest(unittest.TestCase):
+    """Unit tests for the per-partition predicate specialisation path.
+
+    Covers ``replace_partition_predicate`` (the AND/OR fold walker) and
+    the partition-aware ``_Selector.__call__(partition, bucket,
+    total_buckets)`` 3-arg form that ``FileScanner._filter_manifest_entry``
+    will use after wiring."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.id_field = _bigint_field(0, 'id')
+        cls.part_field = DataField(2, 'part', AtomicType('STRING'))
+        cls.pb = PredicateBuilder([cls.id_field, cls.part_field])
+
+    # ----- replace_partition_predicate --------------------------------
+
+    def test_replace_partition_leaf_to_true_drops_constraint(self):
+        from pypaimon.read.scanner.bucket_select_converter import \
+            replace_partition_predicate
+        # ``part = 'a' AND id = 1`` against partition {part: 'a'} →
+        # part leaf becomes True → AND fold removes it → only ``id = 1``
+        pred = PredicateBuilder.and_predicates([
+            self.pb.equal('part', 'a'),
+            self.pb.equal('id', 1),
+        ])
+        result = replace_partition_predicate(
+            pred, {'part'}, {'part': 'a'})
+        self.assertTrue(isinstance(result, type(pred)),
+                        "AND should fold to a remaining single leaf")
+        self.assertEqual(result.method, 'equal')
+        self.assertEqual(result.field, 'id')
+
+    def test_replace_partition_leaf_to_false_collapses_and(self):
+        from pypaimon.read.scanner.bucket_select_converter import \
+            replace_partition_predicate
+        # ``part = 'a' AND id = 1`` against partition {part: 'b'} →
+        # part leaf becomes False → AND collapses to AlwaysFalse (False).
+        pred = PredicateBuilder.and_predicates([
+            self.pb.equal('part', 'a'),
+            self.pb.equal('id', 1),
+        ])
+        result = replace_partition_predicate(
+            pred, {'part'}, {'part': 'b'})
+        self.assertIs(result, False)
+
+    def test_replace_partition_leaf_in_or_keeps_other_branch(self):
+        from pypaimon.read.scanner.bucket_select_converter import \
+            replace_partition_predicate
+        # ``(part='a' AND id=1) OR (part='b' AND id=2)`` against
+        # partition {part: 'a'} → first OR child becomes ``id=1``, second
+        # collapses to AlwaysFalse and is dropped. Result is just ``id=1``.
+        pred = PredicateBuilder.or_predicates([
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'a'),
+                self.pb.equal('id', 1),
+            ]),
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'b'),
+                self.pb.equal('id', 2),
+            ]),
+        ])
+        result = replace_partition_predicate(
+            pred, {'part'}, {'part': 'a'})
+        # OR with a single surviving child unwraps to that child.
+        self.assertEqual(result.method, 'equal')
+        self.assertEqual(result.field, 'id')
+        self.assertEqual(result.literals, [1])
+
+    def test_replace_partition_leaf_in_or_other_partition(self):
+        from pypaimon.read.scanner.bucket_select_converter import \
+            replace_partition_predicate
+        # Same predicate, partition {part: 'b'} → second branch survives
+        # as ``id=2``.
+        pred = PredicateBuilder.or_predicates([
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'a'),
+                self.pb.equal('id', 1),
+            ]),
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'b'),
+                self.pb.equal('id', 2),
+            ]),
+        ])
+        result = replace_partition_predicate(
+            pred, {'part'}, {'part': 'b'})
+        self.assertEqual(result.method, 'equal')
+        self.assertEqual(result.field, 'id')
+        self.assertEqual(result.literals, [2])
+
+    def test_replace_partition_leaf_unrelated_predicate_unchanged(self):
+        from pypaimon.read.scanner.bucket_select_converter import \
+            replace_partition_predicate
+        # No partition leaf → predicate returned as-is.
+        pred = self.pb.equal('id', 42)
+        result = replace_partition_predicate(
+            pred, {'part'}, {'part': 'a'})
+        self.assertIs(result, pred)
+
+    # ----- _Selector partition-aware path -----------------------------
+
+    def test_selector_3arg_specialises_per_partition(self):
+        # ``(part='a' AND id=1) OR (part='b' AND id=2)`` should hit
+        # bucket(1) only when partition='a' and bucket(2) only when
+        # partition='b'. Master without this fix returns "all buckets".
+        pred = PredicateBuilder.or_predicates([
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'a'),
+                self.pb.equal('id', 1),
+            ]),
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'b'),
+                self.pb.equal('id', 2),
+            ]),
+        ])
+        sel = create_bucket_selector(
+            pred, [self.id_field], partition_fields=[self.part_field])
+        self.assertIsNotNone(sel)
+        bucket_for_1 = _hash_bucket([1], [self.id_field], total=8)
+        bucket_for_2 = _hash_bucket([2], [self.id_field], total=8)
+
+        part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
+        part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT)
+
+        for b in range(8):
+            self.assertEqual(sel(part_a, b, 8), b == bucket_for_1,
+                             "partition a only keeps bucket %d" % bucket_for_1)
+            self.assertEqual(sel(part_b, b, 8), b == bucket_for_2,
+                             "partition b only keeps bucket %d" % bucket_for_2)
+
+    def test_selector_falls_through_when_partition_unknown(self):
+        """Early manifest filter passes ``partition=None`` (or uses the
+        2-arg form) — no specialisation runs, bucket set falls back to a
+        sound over-approximation: all buckets accept."""
+        pred = PredicateBuilder.or_predicates([
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'a'),
+                self.pb.equal('id', 1),
+            ]),
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'b'),
+                self.pb.equal('id', 2),
+            ]),
+        ])
+        sel = create_bucket_selector(
+            pred, [self.id_field], partition_fields=[self.part_field])
+        self.assertIsNotNone(sel)
+        # 2-arg form (legacy callsite) — partition unknown, all buckets keep.
+        for b in range(8):
+            self.assertTrue(sel(b, 8),
+                            "partition-unknown call must accept all buckets")
+        # 3-arg form with partition=None has the same semantics.
+        for b in range(8):
+            self.assertTrue(sel(None, b, 8))
+
+    def test_selector_partition_not_matching_returns_empty_bucket_set(self):
+        # ``part = 'a' AND id = 1`` on partition {part: 'c'} simplifies to
+        # AlwaysFalse — the selector returns False for every bucket since
+        # no row in this partition can possibly match. Sound: dropping a
+        # partition that *can't* contain matches doesn't lose data.
+        pred = PredicateBuilder.and_predicates([
+            self.pb.equal('part', 'a'),
+            self.pb.equal('id', 1),
+        ])
+        sel = create_bucket_selector(
+            pred, [self.id_field], partition_fields=[self.part_field])
+        self.assertIsNotNone(sel)
+        part_c = GenericRow(['c'], [self.part_field], RowKind.INSERT)
+        for b in range(8):
+            self.assertFalse(sel(part_c, b, 8),
+                             "partition c can't satisfy part='a', "
+                             "drop every bucket (b=%d)" % b)
+
+    def test_selector_partition_only_constraint_drops_partition(self):
+        # ``part='a' AND id IN (1,2)`` — same partition value 'a'
+        # specialises ``part='a'`` to True, leaving ``id IN (1,2)``.
+        pred = PredicateBuilder.and_predicates([
+            self.pb.equal('part', 'a'),
+            self.pb.is_in('id', [1, 2]),
+        ])
+        sel = create_bucket_selector(
+            pred, [self.id_field], partition_fields=[self.part_field])
+        self.assertIsNotNone(sel)
+        part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
+        expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)}
+        for b in range(8):
+            self.assertEqual(sel(part_a, b, 8), b in expected)
+
+    def test_selector_caches_per_partition(self):
+        pred = PredicateBuilder.or_predicates([
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'a'),
+                self.pb.equal('id', 1),
+            ]),
+            PredicateBuilder.and_predicates([
+                self.pb.equal('part', 'b'),
+                self.pb.equal('id', 2),
+            ]),
+        ])
+        sel = create_bucket_selector(
+            pred, [self.id_field], partition_fields=[self.part_field])
+        part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
+        part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT)
+        # Drive the cache.
+        for _ in range(5):
+            sel(part_a, 0, 8)
+            sel(part_b, 0, 8)
+        # Cache keyed by (partition tuple, total_buckets); two distinct
+        # partitions × one total → exactly two entries.
+        self.assertEqual(len(sel._cache), 2)
+
+
 # ---------------------------------------------------------------------------
 # Layer 2 — Integration: real tables, public API, assert correctness AND
 # that pruning actually fired (otherwise we're not testing the optimisation,
@@ -606,6 +817,66 @@ class BucketPruningIntegrationTest(unittest.TestCase):
         self.assertEqual(self._split_buckets(splits),
                          self._expected_buckets(table, [17]))
 
+    def test_per_partition_pruning_with_mixed_or(self):
+        """``(part='a' AND id=1) OR (part='b' AND id=2)``: each partition
+        sees only the bucket for its own ``id`` literal. Without
+        per-partition predicate specialisation this query falls through
+        to "all buckets in both partitions"."""
+        opts = {'bucket': '4', 'file.format': 'parquet'}
+        pa_schema = pa.schema([
+            pa.field('part', pa.string(), nullable=False),
+            pa.field('id', pa.int64(), nullable=False),
+            ('val', pa.int64()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema, primary_keys=['part', 'id'],
+            partition_keys=['part'], options=opts)
+        identifier = 'default.per_part_mixed_or'
+        self.catalog.create_table(identifier, schema, False)
+        table = self.catalog.get_table(identifier)
+        # Two partitions × three id values each → up to 6 (part, bucket)
+        # combinations after the writer hashes.
+        rows = []
+        for p in ('a', 'b'):
+            for i in (1, 2, 3):
+                rows.append({'part': p, 'id': i, 'val': i * 7})
+        wb = table.new_batch_write_builder()
+        w = wb.new_write()
+        c = wb.new_commit()
+        try:
+            w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema))
+            c.commit(w.prepare_commit())
+        finally:
+            w.close()
+            c.close()
+
+        pb = table.new_read_builder().new_predicate_builder()
+        from pypaimon.common.predicate_builder import PredicateBuilder
+        mixed = PredicateBuilder.or_predicates([
+            PredicateBuilder.and_predicates([
+                pb.equal('part', 'a'),
+                pb.equal('id', 1),
+            ]),
+            PredicateBuilder.and_predicates([
+                pb.equal('part', 'b'),
+                pb.equal('id', 2),
+            ]),
+        ])
+        got, splits = self._read_with(table, mixed)
+        # Correctness: only the two matching rows.
+        got_sorted = sorted(got, key=lambda r: (r['part'], r['id']))
+        self.assertEqual(
+            got_sorted,
+            [{'part': 'a', 'id': 1, 'val': 7},
+             {'part': 'b', 'id': 2, 'val': 14}])
+        # Pruning effectiveness: across both partitions we should see at
+        # most two distinct (partition, bucket) splits — one per branch.
+        # Without per-partition pruning we'd see every (partition, bucket)
+        # combo that exists on disk for the predicate's id literals.
+        self.assertLessEqual(len(splits), 2,
+                             "per-partition pruning should keep ≤ 2 splits, "
+                             "got %d" % len(splits))
+
 
 # ---------------------------------------------------------------------------
 # Layer 3 — Property: random PK tables, random Equal/In predicates,


Reply via email to