This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 5eee48055b [python] Add per-partition bucket pruning for HASH_FIXED
tables (#7804)
5eee48055b is described below
commit 5eee48055baed6bee84a49887cbbb2f54fd215b5
Author: chaoyang <[email protected]>
AuthorDate: Mon May 11 16:11:45 2026 +0800
[python] Add per-partition bucket pruning for HASH_FIXED tables (#7804)
PR-5.4 (#7744) added bucket pruning for HASH_FIXED tables but only on
the bucket-key dimension. Predicates that mix a partition column and
a bucket column under a top-level OR — e.g.
``(part='a' AND id=1) OR (part='b' AND id=2)`` — couldn't be pruned:
the OR mixes two dimensions, so the existing logic gave up and read
every bucket in both partitions. PR-5.4 left this as a TODO in the
module docstring.
## Effect
The same query now reads exactly one bucket per partition (the bucket
holding ``id=1`` in ``part='a'``, the bucket holding ``id=2`` in
``part='b'``). The selector first evaluates the predicate against each
concrete partition value — partition leaves fold to true/false, so the
OR collapses to the single branch that matches that partition — and
bucket selection runs on that simplified form.
The soundness contract is unchanged: the selected bucket set remains a
superset of the buckets that contain matching rows; any error fails
open to "all buckets accept" and never drops a bucket with matches.
Two commits — helper + ``FileScanner`` wiring. 10 unit tests cover the
predicate-fold walker and the per-partition cache; one e2e test on a
2-partition × 4-bucket table proves the mixed-OR query reads ≤ 2
splits instead of one per (partition, bucket).
---
.../read/scanner/bucket_select_converter.py | 390 ++++++++++++++++-----
.../pypaimon/read/scanner/file_scanner.py | 31 +-
.../pypaimon/tests/pushdown_bucket_test.py | 271 ++++++++++++++
3 files changed, 588 insertions(+), 104 deletions(-)
diff --git a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
index da3f0e0f47..7b2e9e104f 100644
--- a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
+++ b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
@@ -51,38 +51,26 @@ Conservative scope (deliberately narrower than Java's
general flexibility):
* Total cartesian product capped at MAX_VALUES (1000), again matching
Java; above that, fall back to a full scan.
-Returns a callable ``selector(bucket: int, total_buckets: int) -> bool``.
-The callable is cached per ``total_buckets`` to handle the rare case
-where bucket count varies across snapshots (rescale).
-
-TODO: per-partition predicate pre-evaluation.
-
- Predicates of the form ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk
- IN (3,4))`` currently fall through to "no pruning" because the top-level
- OR mixes partition and bucket-key constraints. Java simplifies the
- predicate per concrete partition value first (replacing partition
- leaves with literal true/false and folding AND/OR), so each partition
- gets a tighter bucket-key predicate and the corresponding bucket set.
-
- Implementing this here would need three pieces:
-
- * a Predicate-replace walker that substitutes a partition's actual
- values into partition-column leaves (mirrors Java's
- ``paimon-common/.../predicate/PartitionValuePredicateVisitor.java``).
- * lifting ``_Selector`` to key its cache by
- ``(partition, total_buckets)`` instead of just ``total_buckets``.
- * threading the partition value into the early manifest filter
- ``FileScanner._build_early_bucket_filter`` (currently sees only
- ``(bucket, total_buckets)``).
+Returns a callable ``selector(partition, bucket: int, total_buckets: int)
+-> bool``. The callable is cached per ``(partition, total_buckets)`` to
+handle (a) bucket count variation across snapshots (rescale) and (b)
+per-partition predicate specialisation: predicates of the form
+``(part='a' AND bk IN (1,2)) OR (part='b' AND bk IN (3,4))`` are
+simplified per concrete partition value before bucket selection, so each
+partition gets its own tight bucket set.
+
+When ``partition`` is ``None`` (early manifest filter that has not yet
+deserialised the entry), the selector falls back to a partition-agnostic
+result — sound but possibly wider than the per-partition tight set.
"""
from itertools import product
-from typing import Any, Callable, Dict, FrozenSet, List, Optional
+from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple,
Union
from pypaimon.common.predicate import Predicate
from pypaimon.schema.data_types import DataField
from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer
-from pypaimon.table.row.internal_row import RowKind
+from pypaimon.table.row.internal_row import InternalRow, RowKind
from pypaimon.write.row_key_extractor import (_bucket_from_hash,
_hash_bytes_by_words)
@@ -177,78 +165,17 @@ def _extract_or_clause(or_pred: Predicate,
return None if slot is None else [slot, values]
-class _Selector:
- """Callable bucket filter, lazy + cached per ``total_buckets``."""
-
- __slots__ = ('_combinations', '_bucket_key_fields', '_cache')
-
- def __init__(self, combinations: List[List[Any]],
- bucket_key_fields: List[DataField]):
- self._combinations = combinations
- self._bucket_key_fields = bucket_key_fields
- self._cache: Dict[int, FrozenSet[int]] = {}
-
- def __call__(self, bucket: int, total_buckets: int) -> bool:
- # ``total_buckets <= 0`` shows up for postpone / legacy / special
- # entries and must NOT be pruned: returning False here would drop
- # rows the writer hashed under a different convention. Fail open.
- if total_buckets <= 0:
- return True
- try:
- return bucket in self._compute(total_buckets)
- except Exception:
- # Fail open on any hashing/serialization error (e.g. a literal
- # type that doesn't match the bucket-key column's atomic type:
- # ``pb.equal('id_bigint', 'foo')`` — GenericRowSerializer raises
- # struct.error trying to pack the string as int64). Crashing
- # the entire scan here would be worse than skipping pruning;
- # the soundness contract still forbids false-negatives.
- return True
-
- def _compute(self, total_buckets: int) -> FrozenSet[int]:
- cached = self._cache.get(total_buckets)
- if cached is not None:
- return cached
- result = set()
- for combo in self._combinations:
- row = GenericRow(list(combo), self._bucket_key_fields,
- RowKind.INSERT)
- serialized = GenericRowSerializer.to_bytes(row)
- # Skip the 4-byte length prefix — matches the writer's hash
- # input exactly (see RowKeyExtractor._binary_row_hash_code).
- h = _hash_bytes_by_words(serialized[4:])
- result.add(_bucket_from_hash(h, total_buckets))
- frozen = frozenset(result)
- self._cache[total_buckets] = frozen
- return frozen
-
- @property
- def bucket_combinations(self) -> int:
- """Number of (bucket-key) combinations used to compute the filter.
- Exposed for tests / observability."""
- return len(self._combinations)
-
-
-def create_bucket_selector(
- predicate: Optional[Predicate],
- bucket_key_fields: List[DataField]) -> Optional[Callable[[int, int],
bool]]:
- """Try to derive a bucket selector from ``predicate`` constrained to
- ``bucket_key_fields``.
+def _build_combinations(
+ predicate: Predicate,
+ bucket_key_fields: List[DataField]) -> Optional[List[List[Any]]]:
+ """Walk ``predicate`` for top-level AND clauses constraining bucket-key
+ columns by Equal/In, intersect repeated constraints, and return the
+ cartesian product of literal values (one row per combination).
- Returns:
- A callable ``(bucket, total_buckets) -> bool`` if the predicate
- pins down all bucket keys to a finite Equal/In set; otherwise None
- (caller must NOT prune by bucket).
+ Returns None when the predicate doesn't pin down every bucket-key
+ column or when the cartesian product exceeds ``MAX_VALUES`` — the
+ caller treats that as "no pruning, all buckets accept".
"""
- if predicate is None or not bucket_key_fields:
- return None
-
- # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key
- # column types are prone to writer/reader byte-level disagreement on
- # equal logical values. Fail open rather than risk false-negatives.
- if _has_unsafe_bucket_key_type(bucket_key_fields):
- return None
-
bk_name_to_slot: Dict[str, int] = {
f.name: i for i, f in enumerate(bucket_key_fields)
}
@@ -294,5 +221,274 @@ def create_bucket_selector(
if total > MAX_VALUES:
return None
- combinations = [list(combo) for combo in product(*slot_values)]
- return _Selector(combinations, bucket_key_fields)
+ return [list(combo) for combo in product(*slot_values)]
+
+
+def _hash_combinations(combinations: List[List[Any]],
+ bucket_key_fields: List[DataField],
+ total_buckets: int) -> FrozenSet[int]:
+ result = set()
+ for combo in combinations:
+ row = GenericRow(list(combo), bucket_key_fields, RowKind.INSERT)
+ serialized = GenericRowSerializer.to_bytes(row)
+ # Skip the 4-byte length prefix — matches the writer's hash
+ # input exactly (see RowKeyExtractor._binary_row_hash_code).
+ h = _hash_bytes_by_words(serialized[4:])
+ result.add(_bucket_from_hash(h, total_buckets))
+ return frozenset(result)
+
+
+def _predicate_touches_partition(predicate: Predicate,
+ partition_field_names: Set[str]) -> bool:
+ """True if ``predicate`` references any partition column directly or
+ inside an AND/OR/NOT subtree."""
+ if predicate.method in ('and', 'or', 'not'):
+ return any(_predicate_touches_partition(c, partition_field_names)
+ for c in (predicate.literals or []))
+ return predicate.field is not None and predicate.field in
partition_field_names
+
+
+def _evaluate_partition_leaf(predicate: Predicate,
+ partition_values: Dict[str, Any]) ->
Optional[bool]:
+ """Evaluate ``predicate`` (a leaf on a partition column) against the
+ concrete partition values. Returns True / False, or ``None`` if the
+ leaf isn't safely evaluable here (caller should keep the leaf
+ unchanged — bucket selection stays sound as long as we don't fold
+ away an evaluable False).
+ """
+ field_value = partition_values.get(predicate.field)
+ tester = Predicate.testers.get(predicate.method)
+ if tester is None:
+ return None
+ try:
+ return tester.test_by_value(field_value, predicate.literals)
+ except Exception:
+ return None
+
+
+_AlwaysFalse = False # sentinel: predicate always evaluates to False
+_AlwaysTrue = None # sentinel: predicate cleared (always True)
+
+
+def replace_partition_predicate(
+ predicate: Predicate,
+ partition_field_names: Set[str],
+ partition_values: Dict[str, Any]) -> Optional[Union[bool, Predicate]]:
+ """Substitute partition-column leaves with their concrete values and
+ fold away always-true / always-false sub-expressions.
+
+ Three-way return:
+
+ * ``None`` — predicate is unconditionally True after substitution
+ (no constraint left for this partition).
+ * ``False`` — predicate is unconditionally False (this partition
+ cannot contain matching rows).
+ * ``Predicate`` — the simplified predicate; partition leaves are
+ gone. The caller continues bucket-key extraction on this.
+ """
+ if predicate.method == 'and':
+ new_children: List[Predicate] = []
+ for child in (predicate.literals or []):
+ simplified = replace_partition_predicate(
+ child, partition_field_names, partition_values)
+ if simplified is _AlwaysFalse:
+ return _AlwaysFalse
+ if simplified is _AlwaysTrue:
+ continue
+ new_children.append(simplified)
+ if not new_children:
+ return _AlwaysTrue
+ if len(new_children) == 1:
+ return new_children[0]
+ return Predicate(method='and', index=None, field=None,
+ literals=new_children)
+
+ if predicate.method == 'or':
+ new_children = []
+ for child in (predicate.literals or []):
+ simplified = replace_partition_predicate(
+ child, partition_field_names, partition_values)
+ if simplified is _AlwaysTrue:
+ return _AlwaysTrue
+ if simplified is _AlwaysFalse:
+ continue
+ new_children.append(simplified)
+ if not new_children:
+ return _AlwaysFalse
+ if len(new_children) == 1:
+ return new_children[0]
+ return Predicate(method='or', index=None, field=None,
+ literals=new_children)
+
+ # Leaf predicate.
+ if predicate.field is not None and predicate.field in
partition_field_names:
+ truth = _evaluate_partition_leaf(predicate, partition_values)
+ if truth is True:
+ return _AlwaysTrue
+ if truth is False:
+ return _AlwaysFalse
+ # Couldn't safely evaluate — keep the leaf. Bucket selection
+ # stays sound: the leaf still gets ANDed in, just doesn't help
+ # narrow buckets for this partition.
+ return predicate
+
+ # Non-partition leaf: keep as-is.
+ return predicate
+
+
+def _partition_to_dict(partition: Optional[InternalRow],
+ partition_fields: List[DataField]) -> Dict[str, Any]:
+ """Pull each partition column's value out of ``partition`` keyed by
+ field name. Returns an empty dict when ``partition`` is None."""
+ if partition is None:
+ return {}
+ out: Dict[str, Any] = {}
+ for i, field in enumerate(partition_fields):
+ try:
+ out[field.name] = partition.get_field(i)
+ except Exception:
+ out[field.name] = None
+ return out
+
+
+def _partition_to_cache_key(partition: Optional[InternalRow],
+ partition_fields: List[DataField]
+ ) -> Optional[Tuple[Any, ...]]:
+ if partition is None or not partition_fields:
+ return None
+ try:
+ return tuple(partition.get_field(i) for i in
range(len(partition_fields)))
+ except Exception:
+ return None
+
+
+class _Selector:
+ """Callable bucket filter, lazy + cached per ``(partition,
total_buckets)``."""
+
+ __slots__ = ('_predicate', '_bucket_key_fields', '_partition_fields',
+ '_cache')
+
+ def __init__(self, predicate: Predicate,
+ bucket_key_fields: List[DataField],
+ partition_fields: Optional[List[DataField]] = None):
+ self._predicate = predicate
+ self._bucket_key_fields = bucket_key_fields
+ self._partition_fields = list(partition_fields or [])
+ self._cache: Dict[Tuple[Optional[Tuple[Any, ...]], int],
FrozenSet[int]] = {}
+
+ def __call__(self, *args) -> bool:
+ # Accept ``(bucket, total_buckets)`` (early manifest filter that
+ # hasn't deserialised the entry yet — partition unknown) or
+ # ``(partition, bucket, total_buckets)`` (late filter on a fully
+ # decoded ``ManifestEntry``). The two-arg form is partition-
+ # agnostic; partition substitution is skipped.
+ if len(args) == 2:
+ partition = None
+ bucket, total_buckets = args
+ elif len(args) == 3:
+ partition, bucket, total_buckets = args
+ else:
+ raise TypeError(
+ "_Selector expects 2 or 3 positional args, got %d" % len(args))
+ # ``total_buckets <= 0`` shows up for postpone / legacy / special
+ # entries and must NOT be pruned: returning False here would drop
+ # rows the writer hashed under a different convention. Fail open.
+ if total_buckets <= 0:
+ return True
+ try:
+ return bucket in self._compute(partition, total_buckets)
+ except Exception:
+ # Fail open on any hashing / serialization / specialisation
+ # error (e.g. a literal type that doesn't match the bucket-key
+ # column's atomic type). Crashing the entire scan here would
+ # be worse than skipping pruning; the soundness contract still
+ # forbids false-negatives.
+ return True
+
+ def _compute(self, partition, total_buckets: int) -> FrozenSet[int]:
+ cache_key = (_partition_to_cache_key(partition,
self._partition_fields),
+ total_buckets)
+ cached = self._cache.get(cache_key)
+ if cached is not None:
+ return cached
+
+ effective_predicate: Optional[Union[bool, Predicate]] = self._predicate
+ if partition is not None and self._partition_fields:
+ partition_values = _partition_to_dict(partition,
self._partition_fields)
+ partition_field_names = {f.name for f in self._partition_fields}
+ effective_predicate = replace_partition_predicate(
+ self._predicate, partition_field_names, partition_values)
+
+ if effective_predicate is _AlwaysFalse:
+ # No row in this partition can match — empty bucket set.
+ frozen: FrozenSet[int] = frozenset()
+ self._cache[cache_key] = frozen
+ return frozen
+
+ if effective_predicate is _AlwaysTrue:
+ # Predicate cleared after partition substitution — accept all
+ # buckets for this partition.
+ frozen = frozenset(range(total_buckets))
+ self._cache[cache_key] = frozen
+ return frozen
+
+ combinations = _build_combinations(effective_predicate,
+ self._bucket_key_fields)
+ if combinations is None:
+ # Couldn't pin down all bucket keys (or above MAX_VALUES) —
+ # fall back to "all buckets accept" for soundness.
+ frozen = frozenset(range(total_buckets))
+ self._cache[cache_key] = frozen
+ return frozen
+
+ frozen = _hash_combinations(combinations, self._bucket_key_fields,
+ total_buckets)
+ self._cache[cache_key] = frozen
+ return frozen
+
+
+def create_bucket_selector(
+ predicate: Optional[Predicate],
+ bucket_key_fields: List[DataField],
+ partition_fields: Optional[List[DataField]] = None,
+) -> Optional[Callable[[Any, int, int], bool]]:
+ """Try to derive a bucket selector from ``predicate`` constrained to
+ ``bucket_key_fields``.
+
+ Returns:
+ A callable ``(partition, bucket, total_buckets) -> bool``. When
+ ``partition_fields`` is given and the predicate references those
+ partition columns, the selector specialises the predicate per
+ partition value before hashing — this catches mixed forms like
+ ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk IN (3,4))`` that
+ would otherwise be unprunable. ``partition=None`` callsites
+ (early manifest filter that hasn't deserialised the entry yet)
+ simply get the partition-agnostic result.
+
+ Returns None when the predicate carries no usable bucket-key
+ constraint at all (caller must NOT prune by bucket).
+ """
+ if predicate is None or not bucket_key_fields:
+ return None
+
+ # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key
+ # column types are prone to writer/reader byte-level disagreement on
+ # equal logical values. Fail open rather than risk false-negatives.
+ if _has_unsafe_bucket_key_type(bucket_key_fields):
+ return None
+
+ # Sanity gate: if the predicate without any partition substitution
+ # already fails to pin down bucket keys AND it doesn't touch any
+ # partition columns, there's no point handing the caller a selector
+ # that always returns "all buckets" — preserve the original "return
+ # None for unprunable" contract so the caller can skip the wrap.
+ partition_names = {f.name for f in (partition_fields or [])}
+ touches_partition = (
+ bool(partition_names)
+ and _predicate_touches_partition(predicate, partition_names)
+ )
+ if not touches_partition:
+ if _build_combinations(predicate, bucket_key_fields) is None:
+ return None
+
+ return _Selector(predicate, bucket_key_fields, partition_fields)
diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py
b/paimon-python/pypaimon/read/scanner/file_scanner.py
index 70cfa6c978..ea0e8219ee 100755
--- a/paimon-python/pypaimon/read/scanner/file_scanner.py
+++ b/paimon-python/pypaimon/read/scanner/file_scanner.py
@@ -30,6 +30,7 @@ from pypaimon.manifest.manifest_list_manager import
ManifestListManager
from pypaimon.manifest.schema.manifest_entry import ManifestEntry
from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta
from pypaimon.manifest.simple_stats_evolutions import SimpleStatsEvolutions
+from pypaimon.schema.data_types import DataField
from pypaimon.read.plan import Plan
from pypaimon.read.push_down_utils import (_get_all_fields,
remove_row_id_filter,
@@ -356,11 +357,12 @@ class FileScanner:
"""Compose the (bucket, total_buckets) -> bool used by the manifest
reader to drop entries before deserialising ``_FILE`` / partition.
- Mirrors the BucketFilter applied at Java's InternalRow stage in
- ``ManifestEntryCache``. The signature is intentionally minimal:
- per-partition predicate pre-evaluation would also need
- ``(partition, bucket, total_buckets)``, but the current selector
- is partition-agnostic.
+ The selector is partition-aware now, but at this early stage the
+ partition field has not been deserialised yet, so callers stick
+ with the two-arg form. The selector internally falls back to a
+ partition-agnostic over-approximation; per-partition tightening
+ still happens later in ``_filter_manifest_entry`` once the entry
+ is fully decoded.
"""
only_real = self.only_read_real_buckets
selector = self._bucket_selector
@@ -479,7 +481,21 @@ class FileScanner:
return None
if not bucket_key_fields:
return None
- return create_bucket_selector(self.predicate, bucket_key_fields)
+ # Partition fields are passed so the selector can specialise
+ # the predicate per partition value at the late filter stage,
+ # turning ``(part='a' AND bk=1) OR (part='b' AND bk=2)`` into a
+ # precise bucket pick per partition instead of an over-scan.
+ partition_fields: Optional[List[DataField]] = None
+ if self.table.partition_keys:
+ partition_fields = [
+ self.table.field_dict[name]
+ for name in self.table.partition_keys
+ if name in self.table.field_dict
+ ]
+ return create_bucket_selector(
+ self.predicate, bucket_key_fields,
+ partition_fields=partition_fields,
+ )
def _filter_manifest_entry(self, entry: ManifestEntry) -> bool:
# Redundant safety net: the early filter in the manifest reader
@@ -489,7 +505,8 @@ class FileScanner:
return False
if (self._bucket_selector is not None
and entry.bucket >= 0
- and not self._bucket_selector(entry.bucket,
entry.total_buckets)):
+ and not self._bucket_selector(
+ entry.partition, entry.bucket, entry.total_buckets)):
return False
if self.partition_key_predicate and not
self.partition_key_predicate.test(entry.partition):
return False
diff --git a/paimon-python/pypaimon/tests/pushdown_bucket_test.py
b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
index b83283200e..80d3065ca3 100644
--- a/paimon-python/pypaimon/tests/pushdown_bucket_test.py
+++ b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
@@ -345,6 +345,217 @@ class BucketSelectConverterUnitTest(unittest.TestCase):
"not crash (bucket={}, total={})".format(b,
total))
+class PartitionAwareBucketSelectorUnitTest(unittest.TestCase):
+ """Unit tests for the per-partition predicate specialisation path.
+
+ Covers ``replace_partition_predicate`` (the AND/OR fold walker) and
+ the partition-aware ``_Selector.__call__(partition, bucket,
+ total_buckets)`` 3-arg form that ``FileScanner._filter_manifest_entry``
+ will use after wiring."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls.id_field = _bigint_field(0, 'id')
+ cls.part_field = DataField(2, 'part', AtomicType('STRING'))
+ cls.pb = PredicateBuilder([cls.id_field, cls.part_field])
+
+ # ----- replace_partition_predicate --------------------------------
+
+ def test_replace_partition_leaf_to_true_drops_constraint(self):
+ from pypaimon.read.scanner.bucket_select_converter import \
+ replace_partition_predicate
+ # ``part = 'a' AND id = 1`` against partition {part: 'a'} →
+ # part leaf becomes True → AND fold removes it → only ``id = 1``
+ pred = PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ])
+ result = replace_partition_predicate(
+ pred, {'part'}, {'part': 'a'})
+ self.assertTrue(isinstance(result, type(pred)),
+ "AND should fold to a remaining single leaf")
+ self.assertEqual(result.method, 'equal')
+ self.assertEqual(result.field, 'id')
+
+ def test_replace_partition_leaf_to_false_collapses_and(self):
+ from pypaimon.read.scanner.bucket_select_converter import \
+ replace_partition_predicate
+ # ``part = 'a' AND id = 1`` against partition {part: 'b'} →
+ # part leaf becomes False → AND collapses to AlwaysFalse (False).
+ pred = PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ])
+ result = replace_partition_predicate(
+ pred, {'part'}, {'part': 'b'})
+ self.assertIs(result, False)
+
+ def test_replace_partition_leaf_in_or_keeps_other_branch(self):
+ from pypaimon.read.scanner.bucket_select_converter import \
+ replace_partition_predicate
+ # ``(part='a' AND id=1) OR (part='b' AND id=2)`` against
+ # partition {part: 'a'} → first OR child becomes ``id=1``, second
+ # collapses to AlwaysFalse and is dropped. Result is just ``id=1``.
+ pred = PredicateBuilder.or_predicates([
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ]),
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'b'),
+ self.pb.equal('id', 2),
+ ]),
+ ])
+ result = replace_partition_predicate(
+ pred, {'part'}, {'part': 'a'})
+ # OR with a single surviving child unwraps to that child.
+ self.assertEqual(result.method, 'equal')
+ self.assertEqual(result.field, 'id')
+ self.assertEqual(result.literals, [1])
+
+ def test_replace_partition_leaf_in_or_other_partition(self):
+ from pypaimon.read.scanner.bucket_select_converter import \
+ replace_partition_predicate
+ # Same predicate, partition {part: 'b'} → second branch survives
+ # as ``id=2``.
+ pred = PredicateBuilder.or_predicates([
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ]),
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'b'),
+ self.pb.equal('id', 2),
+ ]),
+ ])
+ result = replace_partition_predicate(
+ pred, {'part'}, {'part': 'b'})
+ self.assertEqual(result.method, 'equal')
+ self.assertEqual(result.field, 'id')
+ self.assertEqual(result.literals, [2])
+
+ def test_replace_partition_leaf_unrelated_predicate_unchanged(self):
+ from pypaimon.read.scanner.bucket_select_converter import \
+ replace_partition_predicate
+ # No partition leaf → predicate returned as-is.
+ pred = self.pb.equal('id', 42)
+ result = replace_partition_predicate(
+ pred, {'part'}, {'part': 'a'})
+ self.assertIs(result, pred)
+
+ # ----- _Selector partition-aware path -----------------------------
+
+ def test_selector_3arg_specialises_per_partition(self):
+ # ``(part='a' AND id=1) OR (part='b' AND id=2)`` should hit
+ # bucket(1) only when partition='a' and bucket(2) only when
+ # partition='b'. Master without this fix returns "all buckets".
+ pred = PredicateBuilder.or_predicates([
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ]),
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'b'),
+ self.pb.equal('id', 2),
+ ]),
+ ])
+ sel = create_bucket_selector(
+ pred, [self.id_field], partition_fields=[self.part_field])
+ self.assertIsNotNone(sel)
+ bucket_for_1 = _hash_bucket([1], [self.id_field], total=8)
+ bucket_for_2 = _hash_bucket([2], [self.id_field], total=8)
+
+ part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
+ part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT)
+
+ for b in range(8):
+ self.assertEqual(sel(part_a, b, 8), b == bucket_for_1,
+ "partition a only keeps bucket %d" % bucket_for_1)
+ self.assertEqual(sel(part_b, b, 8), b == bucket_for_2,
+ "partition b only keeps bucket %d" % bucket_for_2)
+
+ def test_selector_falls_through_when_partition_unknown(self):
+ """Early manifest filter passes ``partition=None`` (or uses the
+ 2-arg form) — no specialisation runs, bucket set falls back to a
+ sound over-approximation: all buckets accept."""
+ pred = PredicateBuilder.or_predicates([
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ]),
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'b'),
+ self.pb.equal('id', 2),
+ ]),
+ ])
+ sel = create_bucket_selector(
+ pred, [self.id_field], partition_fields=[self.part_field])
+ self.assertIsNotNone(sel)
+ # 2-arg form (legacy callsite) — partition unknown, all buckets keep.
+ for b in range(8):
+ self.assertTrue(sel(b, 8),
+ "partition-unknown call must accept all buckets")
+ # 3-arg form with partition=None has the same semantics.
+ for b in range(8):
+ self.assertTrue(sel(None, b, 8))
+
+ def test_selector_partition_not_matching_returns_empty_bucket_set(self):
+ # ``part = 'a' AND id = 1`` on partition {part: 'c'} simplifies to
+ # AlwaysFalse — the selector returns False for every bucket since
+ # no row in this partition can possibly match. Sound: dropping a
+ # partition that *can't* contain matches doesn't lose data.
+ pred = PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ])
+ sel = create_bucket_selector(
+ pred, [self.id_field], partition_fields=[self.part_field])
+ self.assertIsNotNone(sel)
+ part_c = GenericRow(['c'], [self.part_field], RowKind.INSERT)
+ for b in range(8):
+ self.assertFalse(sel(part_c, b, 8),
+ "partition c can't satisfy part='a', "
+ "drop every bucket (b=%d)" % b)
+
+ def test_selector_partition_only_constraint_drops_partition(self):
+ # ``part='a' AND id IN (1,2)`` — same partition value 'a'
+ # specialises ``part='a'`` to True, leaving ``id IN (1,2)``.
+ pred = PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.is_in('id', [1, 2]),
+ ])
+ sel = create_bucket_selector(
+ pred, [self.id_field], partition_fields=[self.part_field])
+ self.assertIsNotNone(sel)
+ part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
+ expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)}
+ for b in range(8):
+ self.assertEqual(sel(part_a, b, 8), b in expected)
+
+ def test_selector_caches_per_partition(self):
+ pred = PredicateBuilder.or_predicates([
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'a'),
+ self.pb.equal('id', 1),
+ ]),
+ PredicateBuilder.and_predicates([
+ self.pb.equal('part', 'b'),
+ self.pb.equal('id', 2),
+ ]),
+ ])
+ sel = create_bucket_selector(
+ pred, [self.id_field], partition_fields=[self.part_field])
+ part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
+ part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT)
+ # Drive the cache.
+ for _ in range(5):
+ sel(part_a, 0, 8)
+ sel(part_b, 0, 8)
+ # Cache keyed by (partition tuple, total_buckets); two distinct
+ # partitions × one total → exactly two entries.
+ self.assertEqual(len(sel._cache), 2)
+
+
# ---------------------------------------------------------------------------
# Layer 2 — Integration: real tables, public API, assert correctness AND
# that pruning actually fired (otherwise we're not testing the optimisation,
@@ -606,6 +817,66 @@ class BucketPruningIntegrationTest(unittest.TestCase):
self.assertEqual(self._split_buckets(splits),
self._expected_buckets(table, [17]))
+ def test_per_partition_pruning_with_mixed_or(self):
+ """``(part='a' AND id=1) OR (part='b' AND id=2)``: each partition
+ sees only the bucket for its own ``id`` literal. Without
+ per-partition predicate specialisation this query falls through
+ to "all buckets in both partitions"."""
+ opts = {'bucket': '4', 'file.format': 'parquet'}
+ pa_schema = pa.schema([
+ pa.field('part', pa.string(), nullable=False),
+ pa.field('id', pa.int64(), nullable=False),
+ ('val', pa.int64()),
+ ])
+ schema = Schema.from_pyarrow_schema(
+ pa_schema, primary_keys=['part', 'id'],
+ partition_keys=['part'], options=opts)
+ identifier = 'default.per_part_mixed_or'
+ self.catalog.create_table(identifier, schema, False)
+ table = self.catalog.get_table(identifier)
+ # Two partitions × three id values each → up to 6 (part, bucket)
+ # combinations after the writer hashes.
+ rows = []
+ for p in ('a', 'b'):
+ for i in (1, 2, 3):
+ rows.append({'part': p, 'id': i, 'val': i * 7})
+ wb = table.new_batch_write_builder()
+ w = wb.new_write()
+ c = wb.new_commit()
+ try:
+ w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema))
+ c.commit(w.prepare_commit())
+ finally:
+ w.close()
+ c.close()
+
+ pb = table.new_read_builder().new_predicate_builder()
+ from pypaimon.common.predicate_builder import PredicateBuilder
+ mixed = PredicateBuilder.or_predicates([
+ PredicateBuilder.and_predicates([
+ pb.equal('part', 'a'),
+ pb.equal('id', 1),
+ ]),
+ PredicateBuilder.and_predicates([
+ pb.equal('part', 'b'),
+ pb.equal('id', 2),
+ ]),
+ ])
+ got, splits = self._read_with(table, mixed)
+ # Correctness: only the two matching rows.
+ got_sorted = sorted(got, key=lambda r: (r['part'], r['id']))
+ self.assertEqual(
+ got_sorted,
+ [{'part': 'a', 'id': 1, 'val': 7},
+ {'part': 'b', 'id': 2, 'val': 14}])
+ # Pruning effectiveness: across both partitions we should see at
+ # most two distinct (partition, bucket) splits — one per branch.
+ # Without per-partition pruning we'd see every (partition, bucket)
+ # combo that exists on disk for the predicate's id literals.
+ self.assertLessEqual(len(splits), 2,
+ "per-partition pruning should keep ≤ 2 splits, "
+ "got %d" % len(splits))
+
# ---------------------------------------------------------------------------
# Layer 3 — Property: random PK tables, random Equal/In predicates,