This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 81fd80f63a [python] Add predicate-driven bucket pruning for HASH_FIXED tables (#7744)
81fd80f63a is described below

commit 81fd80f63aeb68652082aae12d795bbd8386dc70
Author: chaoyang <[email protected]>
AuthorDate: Sat May 9 21:04:43 2026 +0800

    [python] Add predicate-driven bucket pruning for HASH_FIXED tables (#7744)
---
 .../pypaimon/manifest/manifest_file_manager.py     |  26 +-
 .../read/scanner/bucket_select_converter.py        | 297 +++++++++
 .../pypaimon/read/scanner/file_scanner.py          |  87 ++-
 paimon-python/pypaimon/schema/table_schema.py      |  38 ++
 .../pypaimon/tests/pushdown_bucket_test.py         | 739 +++++++++++++++++++++
 paimon-python/pypaimon/tests/table_schema_test.py  |  93 +++
 paimon-python/pypaimon/write/row_key_extractor.py  |  16 +-
 7 files changed, 1278 insertions(+), 18 deletions(-)

diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index 0ed5091825..308dc13a73 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -17,7 +17,7 @@
 
################################################################################
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
-from typing import List
+from typing import Callable, List, Optional
 
 import fastavro
 
@@ -48,10 +48,13 @@ class ManifestFileManager:
         self.trimmed_primary_keys_fields = 
self.table.trimmed_primary_keys_fields
 
     def read_entries_parallel(self, manifest_files: List[ManifestFileMeta], 
manifest_entry_filter=None,
-                              drop_stats=True, max_workers=8) -> 
List[ManifestEntry]:
+                              drop_stats=True, max_workers=8,
+                              early_entry_filter: Optional[Callable[[int, 
int], bool]] = None
+                              ) -> List[ManifestEntry]:
 
         def _process_single_manifest(manifest_file: ManifestFileMeta) -> 
List[ManifestEntry]:
-            return self.read(manifest_file.file_name, manifest_entry_filter, 
drop_stats)
+            return self.read(manifest_file.file_name, manifest_entry_filter, 
drop_stats,
+                             early_entry_filter=early_entry_filter)
 
         def _entry_identifier(e: ManifestEntry) -> tuple:
             return (
@@ -81,7 +84,19 @@ class ManifestFileManager:
         ]
         return final_entries
 
-    def read(self, manifest_file_name: str, manifest_entry_filter=None, 
drop_stats=True) -> List[ManifestEntry]:
+    def read(self, manifest_file_name: str, manifest_entry_filter=None, 
drop_stats=True,
+             early_entry_filter: Optional[Callable[[int, int], bool]] = None
+             ) -> List[ManifestEntry]:
+        """
+        early_entry_filter: optional ``(bucket, total_buckets) -> bool``
+        called immediately after the avro record is parsed. Mirrors
+        Java ``BucketFilter`` applied at the InternalRow stage in
+        ``ManifestEntryCache``: when it returns False, the entry's
+        ``_FILE`` block / partition / stats are never deserialized.
+        Caller is responsible for soundness (any non-pruning rule must
+        return True). The full ``manifest_entry_filter`` still runs on
+        the survivors.
+        """
         manifest_file_path = f"{self.manifest_path}/{manifest_file_name}"
 
         entries = []
@@ -91,6 +106,9 @@ class ManifestFileManager:
         reader = fastavro.reader(buffer)
 
         for record in reader:
+            if early_entry_filter is not None and not early_entry_filter(
+                    record['_BUCKET'], record['_TOTAL_BUCKETS']):
+                continue
             file_dict = dict(record['_FILE'])
             key_dict = dict(file_dict['_KEY_STATS'])
             key_stats = SimpleStats(
diff --git a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
new file mode 100644
index 0000000000..e0c3b6bfa4
--- /dev/null
+++ b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py
@@ -0,0 +1,297 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+################################################################################
+
+"""
+Predicate-driven bucket pruning for HASH_FIXED tables.
+
+Mirrors Java's ``org.apache.paimon.operation.BucketSelectConverter``:
+walk the predicate, isolate AND clauses that constrain bucket-key fields
+with Equal/In, take the cartesian product of literal values, hash each
+combination using the writer's hash routine, and produce the set of
+buckets the query can possibly hit. All other entries are safely dropped.
+
+Hard correctness contract: the bucket set this returns is a *superset* of
+the buckets that contain any matching rows. False-positive (over-keep)
+allowed; false-negative (drop a bucket that has matching rows) MUST never
+happen — that would be silent data loss.
+
+The hashing routine reuses ``RowKeyExtractor._hash_bytes_by_words`` /
+``_bucket_from_hash`` from ``pypaimon.write.row_key_extractor`` — the same
+code path the writer uses to assign rows to buckets. Reusing it (rather
+than copying) is what guarantees read/write hash agreement in the face of
+future routine changes.
+
+Conservative scope (deliberately narrower than Java's general flexibility):
+
+  * Only HASH_FIXED tables (caller's responsibility to gate; this module
+    does not look at the bucket mode itself).
+  * All bucket-key fields must be constrained, with Equal or In, in a
+    single AND-of-OR-of-literals shape. If any bucket-key column is
+    unconstrained, return None — the caller must scan all buckets.
+  * Repeated constraints on the same bucket-key column under top-level
+    AND (e.g. ``id IN (1,2,3) AND id IN (2,3,4)``) intersect their
+    literal sets (mirrors Java ``BucketSelector.retainAll``). An empty
+    intersection means the predicate is unsatisfiable, and we return
+    None.
+  * Total cartesian product capped at MAX_VALUES (1000), again matching
+    Java; above that, fall back to a full scan.
+
+Returns a callable ``selector(bucket: int, total_buckets: int) -> bool``.
+The callable is cached per ``total_buckets`` to handle the rare case
+where bucket count varies across snapshots (rescale).
+
+TODO: per-partition predicate pre-evaluation.
+
+  Predicates of the form ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk
+  IN (3,4))`` currently fall through to "no pruning" because the top-level
+  OR mixes partition and bucket-key constraints. Java simplifies the
+  predicate per concrete partition value first (replacing partition
+  leaves with literal true/false and folding AND/OR), so each partition
+  gets a tighter bucket-key predicate and the corresponding bucket set.
+
+  Implementing this here would need three pieces:
+
+    * a Predicate-replace walker that substitutes a partition's actual
+      values into partition-column leaves (mirrors Java's
+      ``paimon-common/.../predicate/PartitionValuePredicateVisitor.java``).
+    * lifting ``_Selector`` to key its cache by
+      ``(partition, total_buckets)`` instead of just ``total_buckets``.
+    * threading the partition value into the early manifest filter
+      ``FileScanner._build_early_bucket_filter`` (currently sees only
+      ``(bucket, total_buckets)``).
+"""
+
+from itertools import product
+from typing import Any, Callable, Dict, FrozenSet, List, Optional
+
+from pypaimon.common.predicate import Predicate
+from pypaimon.schema.data_types import DataField
+from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer
+from pypaimon.table.row.internal_row import RowKind
+from pypaimon.write.row_key_extractor import (_bucket_from_hash,
+                                              _hash_bytes_by_words)
+
+MAX_VALUES = 1000
+
+# Bucket-key column types where the Python serializer is not byte-aligned
+# with the writer's logical value, or with Java's ``BinaryRow`` byte layout.
+# A divergent hash is silent data loss (false-negative), so the selector
+# refuses to build at all when a bucket-key field has one of these types.
+#
+# Two reasons something gets blacklisted:
+#
+#   1. Locale / precision drift between writer and reader for equal logical
+#      values (DECIMAL via float-vs-Decimal, TIMESTAMP via naive datetime
+#      timezone interpretation).
+#   2. Composite / nested types whose ``GenericRowSerializer`` byte layout
+#      hasn't been cross-validated against Java's ``BinaryRow`` (ARRAY,
+#      MAP, ROW, MULTISET, VARIANT, BLOB). Until that validation lands,
+#      treating them as safe risks a hash divergence.
+#
+# Matching is a prefix match on the upper-cased type name (text before any
+# '('), so e.g. ``TIMESTAMP`` also covers ``TIMESTAMP(6) WITH LOCAL TIME
+# ZONE``.
+_UNSAFE_BUCKET_KEY_TYPES = (
+    'DECIMAL',
+    'TIMESTAMP',
+    'ARRAY',
+    'MAP',
+    'ROW',
+    'MULTISET',
+    'VARIANT',
+    'BLOB',
+)
+
+
+def _has_unsafe_bucket_key_type(bucket_key_fields: List[DataField]) -> bool:
+    """Return True if any bucket-key field makes hash-based pruning unsafe.
+
+    A field is unsafe when its type name starts with one of
+    ``_UNSAFE_BUCKET_KEY_TYPES``, *or* when no usable type name can be read
+    from it at all. The second rule matters: the type name is read from
+    ``field.type.type``, which composite type objects may not expose. Treating
+    an unreadable type as safe would let exactly the nested types the
+    blacklist exists for slip through. Pruning is only an optimisation, so
+    "type unknown" must mean "do not prune".
+    """
+    for f in bucket_key_fields:
+        type_name = getattr(getattr(f, 'type', None), 'type', None)
+        if not isinstance(type_name, str) or not type_name.strip():
+            # Unknown / non-atomic type: fail closed (refuse to prune).
+            return True
+        head = type_name.split('(')[0].strip().upper()
+        if head.startswith(_UNSAFE_BUCKET_KEY_TYPES):
+            return True
+    return False
+
+
+def _flatten(p: Predicate, method: str) -> List[Predicate]:
+    """Expand nested ``method`` nodes ('and' / 'or') of ``p`` into leaves.
+
+    Children of a matching node are visited left to right; any node whose
+    method differs (including ``p`` itself) is returned as a leaf.
+    """
+    leaves: List[Predicate] = []
+    pending = [p]
+    while pending:
+        node = pending.pop()
+        if node.method != method:
+            leaves.append(node)
+            continue
+        # Push children reversed so they are popped in declaration order.
+        pending.extend(reversed(list(node.literals or [])))
+    return leaves
+
+
+def _split_and(p: Predicate) -> List[Predicate]:
+    """Top-level conjuncts of ``p`` (``p`` itself if it is not an AND)."""
+    return _flatten(p, 'and')
+
+
+def _split_or(p: Predicate) -> List[Predicate]:
+    """Top-level disjuncts of ``p`` (``p`` itself if it is not an OR)."""
+    return _flatten(p, 'or')
+
+
+def _extract_or_clause(or_pred: Predicate,
+                       bk_name_to_slot: Dict[str, int]) -> Optional[List[Any]]:
+    """Interpret one top-level AND conjunct as a bucket-key constraint.
+
+    Returns ``[slot_index, [literal, ...]]`` when every disjunct of
+    ``or_pred`` is an Equal/In leaf on the *same* bucket-key column.
+    Otherwise it returns None: a non-Equal/In operator, a field that is not
+    a bucket key, or disjuncts spanning several columns all disqualify the
+    whole conjunct, and the caller simply ignores it.
+
+    Null literals are dropped, as Java does (``NULL = NULL`` is UNKNOWN in
+    SQL). If every literal was null, the slot gets an empty value list,
+    which collapses the cartesian product to "match nothing".
+    """
+    slot: Optional[int] = None
+    values: List[Any] = []
+    for clause in _split_or(or_pred):
+        field = clause.field
+        if clause.method not in ('equal', 'in') or field is None:
+            return None
+        this_slot = bk_name_to_slot.get(field)
+        if this_slot is None:
+            return None
+        if slot is not None and this_slot != slot:
+            return None
+        slot = this_slot
+        values.extend(
+            lit for lit in (clause.literals or []) if lit is not None)
+    if slot is None:
+        return None
+    return [slot, values]
+
+
+class _Selector:
+    """Callable bucket filter, lazy + cached per ``total_buckets``.
+
+    ``selector(bucket, total_buckets)`` returns True when the entry may hold
+    matching rows. On first use for a given ``total_buckets``, the bucket set
+    is computed by hashing every bucket-key combination with the writer's
+    routine, then memoised. Hashing failures are memoised as well. A literal
+    that cannot be serialised therefore costs one attempt per
+    ``total_buckets`` value, not one full re-hash per manifest entry.
+    """
+
+    __slots__ = ('_combinations', '_bucket_key_fields', '_cache', '_failed')
+
+    def __init__(self, combinations: List[List[Any]],
+                 bucket_key_fields: List[DataField]):
+        self._combinations = combinations
+        self._bucket_key_fields = bucket_key_fields
+        # total_buckets -> set of bucket ids any combination hashes to.
+        self._cache: Dict[int, FrozenSet[int]] = {}
+        # total_buckets values whose hashing raised; these never prune.
+        self._failed = set()
+
+    def __call__(self, bucket: int, total_buckets: int) -> bool:
+        # ``total_buckets <= 0`` shows up for postpone / legacy / special
+        # entries and must NOT be pruned: returning False here would drop
+        # rows the writer hashed under a different convention. Fail open.
+        if total_buckets <= 0 or total_buckets in self._failed:
+            return True
+        try:
+            buckets = self._compute(total_buckets)
+        except Exception:
+            # Fail open on any hashing/serialization error (e.g. a literal
+            # type that doesn't match the bucket-key column's atomic type:
+            # ``pb.equal('id_bigint', 'foo')`` — GenericRowSerializer raises
+            # struct.error trying to pack the string as int64). Crashing
+            # the entire scan here would be worse than skipping pruning;
+            # the soundness contract still forbids false-negatives.
+            # Remember the failure so later entries with the same
+            # ``total_buckets`` skip the (up to MAX_VALUES) re-hash.
+            self._failed.add(total_buckets)
+            return True
+        return bucket in buckets
+
+    def _compute(self, total_buckets: int) -> FrozenSet[int]:
+        """Bucket ids (for ``total_buckets``) hit by any combination.
+
+        Propagates serialization/hashing errors; ``__call__`` turns them
+        into fail-open.
+        """
+        cached = self._cache.get(total_buckets)
+        if cached is not None:
+            return cached
+        result = set()
+        for combo in self._combinations:
+            row = GenericRow(list(combo), self._bucket_key_fields,
+                             RowKind.INSERT)
+            serialized = GenericRowSerializer.to_bytes(row)
+            # Skip the 4-byte length prefix — matches the writer's hash
+            # input exactly (see RowKeyExtractor._binary_row_hash_code).
+            h = _hash_bytes_by_words(serialized[4:])
+            result.add(_bucket_from_hash(h, total_buckets))
+        frozen = frozenset(result)
+        self._cache[total_buckets] = frozen
+        return frozen
+
+    @property
+    def bucket_combinations(self) -> int:
+        """Number of (bucket-key) combinations used to compute the filter.
+        Exposed for tests / observability."""
+        return len(self._combinations)
+
+
+def _dedupe_literals(values: List[Any]) -> List[Any]:
+    """Return ``values`` without repeats, keeping first-occurrence order.
+
+    Repeated literals (``id IN (1, 1, 2)``) would otherwise inflate the
+    cartesian product and could push it past ``MAX_VALUES``, forcing a
+    needless full scan. Unhashable literals cannot be deduplicated via a
+    dict; in that case the list is returned unchanged.
+    """
+    try:
+        return list(dict.fromkeys(values))
+    except TypeError:
+        return list(values)
+
+
+def _intersect_literals(kept: List[Any], new: List[Any]) -> List[Any]:
+    """Values of ``kept`` that also appear in ``new``, in ``kept`` order.
+
+    Mirrors Java ``BucketSelector.retainAll``. If a literal is unhashable,
+    this falls back to an equality scan instead of letting ``set()`` raise
+    and abort scan planning.
+    """
+    try:
+        new_set = set(new)
+    except TypeError:
+        return [v for v in kept if v in new]
+    return [v for v in kept if v in new_set]
+
+
+def create_bucket_selector(
+        predicate: Optional[Predicate],
+        bucket_key_fields: List[DataField],
+) -> Optional[Callable[[int, int], bool]]:
+    """Try to derive a bucket selector from ``predicate`` constrained to
+    ``bucket_key_fields``.
+
+    Returns:
+      A callable ``(bucket, total_buckets) -> bool`` if the predicate
+      pins down all bucket keys to a finite Equal/In set; otherwise None
+      (caller must NOT prune by bucket).
+    """
+    if predicate is None or not bucket_key_fields:
+        return None
+
+    # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key
+    # column types are prone to writer/reader byte-level disagreement on
+    # equal logical values. Fail open rather than risk false-negatives.
+    if _has_unsafe_bucket_key_type(bucket_key_fields):
+        return None
+
+    bk_name_to_slot: Dict[str, int] = {
+        f.name: i for i, f in enumerate(bucket_key_fields)
+    }
+    slot_values: List[Optional[List[Any]]] = [None] * len(bucket_key_fields)
+
+    for and_child in _split_and(predicate):
+        extracted = _extract_or_clause(and_child, bk_name_to_slot)
+        if extracted is None:
+            # Not a bucket-key constraint — that's fine, just skip it. The
+            # remaining predicate still describes a SUPERSET of matching
+            # rows; bucket pruning stays sound as long as we don't *add*
+            # constraints that aren't actually true.
+            continue
+        slot, values = extracted
+        values = _dedupe_literals(values)
+        prior = slot_values[slot]
+        if prior is None:
+            slot_values[slot] = values
+            continue
+        # Same bucket-key column constrained twice in top-level AND
+        # (e.g. ``id IN (1,2,3) AND id IN (2,3,4)``): keep the
+        # intersection, bail only when it is empty (unsatisfiable).
+        intersection = _intersect_literals(prior, values)
+        if not intersection:
+            return None
+        slot_values[slot] = intersection
+
+    # Every bucket-key column must be constrained.
+    if any(v is None for v in slot_values):
+        return None
+
+    # Cartesian-product cap. Above the cap the bucket set is essentially
+    # all buckets anyway; punting saves the hash computation. An empty slot
+    # (all literals were null) collapses the product to 0 — observable
+    # behaviour: empty bucket set, drop everything. Mirrors Java.
+    total = 1
+    for v in slot_values:
+        total *= len(v)
+        if total > MAX_VALUES:
+            return None
+
+    combinations = [list(combo) for combo in product(*slot_values)]
+    return _Selector(combinations, bucket_key_fields)
diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py
index 78e48f85e5..795e76a160 100755
--- a/paimon-python/pypaimon/read/scanner/file_scanner.py
+++ b/paimon-python/pypaimon/read/scanner/file_scanner.py
@@ -36,6 +36,8 @@ from pypaimon.read.push_down_utils import (_get_all_fields,
                                            trim_and_transform_predicate)
 from pypaimon.read.scanner.append_table_split_generator import \
     AppendTableSplitGenerator
+from pypaimon.read.scanner.bucket_select_converter import \
+    create_bucket_selector
 from pypaimon.read.scanner.data_evolution_split_generator import \
     DataEvolutionSplitGenerator
 from pypaimon.read.scanner.primary_key_table_split_generator import \
@@ -208,6 +210,12 @@ class FileScanner:
         self._scanned_snapshot = None
         self._scanned_snapshot_id = None
 
+        # Predicate-driven bucket pruning (HASH_FIXED only). Mirrors Java
+        # BucketSelectConverter. Set on demand and reused across all
+        # _filter_manifest_entry calls; the inner _Selector caches the
+        # bucket set per ``total_buckets`` value.
+        self._bucket_selector = self._init_bucket_selector()
+
         def schema_fields_func(schema_id: int):
             return self.table.schema_manager.get_schema(schema_id).fields
 
@@ -340,9 +348,36 @@ class FileScanner:
         return self.manifest_file_manager.read_entries_parallel(
             manifest_files,
             self._filter_manifest_entry,
-            max_workers=max_workers
+            max_workers=max_workers,
+            early_entry_filter=self._build_early_bucket_filter(),
         )
 
+    def _build_early_bucket_filter(self):
+        """Return the ``(bucket, total_buckets) -> bool`` pre-filter for the
+        manifest reader, or None when there is nothing to filter on.
+
+        The reader calls it before deserialising ``_FILE`` / partition, so
+        rejected entries are never decoded. Two rules are combined:
+
+        * ``only_read_real_buckets`` rejects every negative bucket;
+        * the predicate-driven selector (when present) may reject
+          non-negative buckets — negative buckets are never handed to it.
+
+        Mirrors the BucketFilter applied at Java's InternalRow stage in
+        ``ManifestEntryCache``. The selector is partition-agnostic, hence
+        no partition argument; per-partition pre-evaluation would need
+        ``(partition, bucket, total_buckets)``.
+        """
+        selector = self._bucket_selector
+        drop_virtual = self.only_read_real_buckets
+        if selector is None and not drop_virtual:
+            return None
+
+        def _filter(bucket: int, total_buckets: int) -> bool:
+            if bucket < 0:
+                return not drop_virtual
+            return selector is None or bool(selector(bucket, total_buckets))
+
+        return _filter
+
     def with_shard(self, idx_of_this_subtask: int, number_of_para_subtasks: 
int) -> 'FileScanner':
         if idx_of_this_subtask >= number_of_para_subtasks:
             raise ValueError("idx_of_this_subtask must be less than 
number_of_para_subtasks")
@@ -401,9 +436,55 @@ class FileScanner:
             file.partition_stats,
             file.num_added_files + file.num_deleted_files)
 
+    def _init_bucket_selector(self):
+        """Build the predicate-driven bucket selector if (and only if) the
+        table is in HASH_FIXED mode and the predicate pins all bucket-key
+        fields to Equal/In literals. Anything else returns None — the
+        caller treats None as "no bucket-level pruning".
+
+        Bucket-key fields come from ``TableSchema.logical_bucket_key_fields``
+        — the same source the writer's ``FixedBucketRowKeyExtractor`` reads
+        from, which is what makes the read/write hash agreement a property
+        of the schema rather than of any particular extractor instance.
+
+        Sound across rescale: ``_Selector`` caches per ``total_buckets``,
+        which can vary between manifest entries after a bucket rescale.
+
+        Every step fails open (returns None) on error: pruning is an
+        optimisation, never a correctness requirement, so it must not be
+        able to abort the scan.
+        """
+        if self.predicate is None:
+            return None
+        # ``bucket_mode()`` returns HASH_FIXED only when ``options.bucket()
+        # > 0``; other modes (DYNAMIC / POSTPONE / UNAWARE / CROSS_PARTITION)
+        # have no fixed hash → bucket mapping at write time and must NOT
+        # be pruned here.
+        try:
+            if self.table.bucket_mode() != BucketMode.HASH_FIXED:
+                return None
+        except Exception:
+            # Defensive: any catalog/proxy table that fails the mode check
+            # falls back to no pruning rather than crashing the scan.
+            return None
+        try:
+            schema = self.table.table_schema
+            bucket_key_fields = schema.logical_bucket_key_fields
+        except Exception:
+            # ``bucket_keys`` raises on misconfigured ``bucket-key`` (e.g.
+            # references an unknown column). The previous extractor-based
+            # path failed open here; preserve that.
+            return None
+        if not bucket_key_fields:
+            return None
+        try:
+            return create_bucket_selector(self.predicate, bucket_key_fields)
+        except Exception:
+            # Selector construction walks user-supplied predicate literals;
+            # an unexpected literal must degrade to "scan every bucket",
+            # not take the whole scan down.
+            return None
+
     def _filter_manifest_entry(self, entry: ManifestEntry) -> bool:
-        if self.only_read_real_buckets and entry.bucket < 0:
-            return False
+        # NOTE: bucket-level filtering (``only_read_real_buckets`` + the
+        # predicate-driven selector) is enforced in the manifest reader's
+        # early filter (see ``_build_early_bucket_filter``) so rejected
+        # entries skip ``_FILE`` / partition decoding entirely. This
+        # method assumes that early filter has already run; a caller that
+        # bypasses ``read_entries_parallel`` and invokes this directly on
+        # raw entries MUST still apply ``_build_early_bucket_filter`` (or
+        # otherwise enforce ``bucket >= 0`` on POSTPONE tables) — this
+        # function alone is not sound on its own.
         if self.partition_key_predicate and not 
self.partition_key_predicate.test(entry.partition):
             return False
         # Get SimpleStatsEvolution for this schema
diff --git a/paimon-python/pypaimon/schema/table_schema.py b/paimon-python/pypaimon/schema/table_schema.py
index 53ddccfefc..789cf4e34c 100644
--- a/paimon-python/pypaimon/schema/table_schema.py
+++ b/paimon-python/pypaimon/schema/table_schema.py
@@ -64,6 +64,44 @@ class TableSchema:
         # Return True if they don't contain all (cross-partition update)
         return not all(pk in self.primary_keys for pk in self.partition_keys)
 
+    @property
+    def bucket_keys(self) -> List[str]:
+        """Effective bucket-key column names.
+
+        Same resolution rule as Java ``TableSchema.bucketKeys()``. An
+        explicit, non-blank ``bucket-key`` option wins; its comma-separated
+        entries are stripped, and empty entries are ignored. Without it,
+        the primary keys minus the partition keys are used, matching the
+        writers' convention.
+
+        Validation is deliberately narrower than Java's
+        ``originalBucketKeys()``: only unknown column names are rejected
+        here (``ValueError``). Java's extra rules, ``bucket-key`` ⊄
+        partition keys and ``bucket-key`` ⊆ primary keys when PKs exist,
+        are enforced once at schema construction. Re-checking them in a
+        property would add per-read overhead and could surface errors on
+        tables already in the catalog. The typo check is enough to fail
+        fast.
+        """
+        raw = self.options.get(CoreOptions.BUCKET_KEY.key())
+        if not raw or not raw.strip():
+            return [pk for pk in self.primary_keys
+                    if pk not in self.partition_keys]
+        keys = list(filter(None, (part.strip() for part in raw.split(','))))
+        known = {field.name for field in self.fields}
+        missing = [k for k in keys if k not in known]
+        if missing:
+            raise ValueError(
+                "bucket-key references unknown columns: {}".format(missing))
+        return keys
+
+    @property
+    def logical_bucket_key_fields(self) -> List[DataField]:
+        """``DataField`` objects for ``bucket_keys``, in bucket-key order.
+
+        Names with no matching field are skipped. Mirrors Java
+        ``TableSchema.logicalBucketKeyType()``.
+        """
+        fields_by_name = {field.name: field for field in self.fields}
+        resolved = map(fields_by_name.get, self.bucket_keys)
+        return [field for field in resolved if field is not None]
+
     def to_schema(self) -> Schema:
         return Schema(
             fields=self.fields,
diff --git a/paimon-python/pypaimon/tests/pushdown_bucket_test.py b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
new file mode 100644
index 0000000000..b83283200e
--- /dev/null
+++ b/paimon-python/pypaimon/tests/pushdown_bucket_test.py
@@ -0,0 +1,739 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+################################################################################
+
+"""
+Three-layer correctness tests for predicate-driven bucket pruning.
+
+Mirrors Java's ``BucketSelectConverter`` contract: PK Equal/In queries on
+HASH_FIXED tables must touch only the bucket(s) the writer would have
+placed those keys in. Two correctness obligations:
+
+  1. Sound: the buckets retained by the selector together contain a
+     superset of the matching rows. Buckets that DO contain matching
+     rows are NEVER dropped — false-negative-free.
+  2. Hash-consistent with writers: ``RowKeyExtractor`` (writer) and
+     ``BucketSelectConverter`` (reader) must agree on every literal.
+     This is what makes ``pk = 'X'`` read the bucket holding 'X'.
+
+Layered:
+  * Unit       — direct calls to ``create_bucket_selector`` with crafted
+                 predicates, asserting selector behaviour.
+  * Integration — real PK tables with multiple buckets; queries; assert
+                 (a) result correctness, (b) bucket pruning happened.
+  * Property   — randomly-seeded PK tables, random Equal/In predicates,
+                 result == oracle. No hypothesis dependency (keeps
+                 Python 3.6 compat).
+"""
+
+import os
+import random
+import shutil
+import tempfile
+import unittest
+from typing import Any, Dict, List
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+from pypaimon.common.predicate_builder import PredicateBuilder
+from pypaimon.read.scanner.bucket_select_converter import (
+    MAX_VALUES, create_bucket_selector)
+from pypaimon.schema.data_types import AtomicType, DataField
+from pypaimon.write.row_key_extractor import (FixedBucketRowKeyExtractor,
+                                              _bucket_from_hash,
+                                              _hash_bytes_by_words)
+from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer
+from pypaimon.table.row.internal_row import RowKind
+
+
+def _bigint_field(idx: int, name: str) -> DataField:
+    """Build a non-nullable BIGINT ``DataField`` (fixture helper)."""
+    bigint = AtomicType('BIGINT', nullable=False)
+    return DataField(idx, name, bigint)
+
+
+def _field(idx: int, name: str, type_name: str) -> DataField:
+    """Build a non-nullable ``DataField`` of atomic type ``type_name``."""
+    atomic = AtomicType(type_name, nullable=False)
+    return DataField(idx, name, atomic)
+
+
+def _hash_bucket(values: List[Any], fields: List[DataField],
+                 total: int) -> int:
+    """Re-implement the writer's hash so unit tests can compute the
+    expected bucket without spinning up a real table.
+
+    Serialises ``values`` as an INSERT ``GenericRow`` over ``fields``,
+    hashes the serialised bytes after the first four, and maps the hash
+    onto ``total`` buckets.
+    """
+    row = GenericRow(values, fields, RowKind.INSERT)
+    serialized = GenericRowSerializer.to_bytes(row)
+    # Skip the 4-byte prefix ahead of the row payload — presumably a
+    # length header the writer also excludes before hashing; confirm
+    # against ``FixedBucketRowKeyExtractor``.
+    h = _hash_bytes_by_words(serialized[4:])
+    return _bucket_from_hash(h, total)
+
+
+# ---------------------------------------------------------------------------
+# Layer 1 — Unit: drive ``create_bucket_selector`` with crafted predicates.
+# ---------------------------------------------------------------------------
+class BucketSelectConverterUnitTest(unittest.TestCase):
+    """Drive ``create_bucket_selector`` directly with hand-built
+    predicates and compare its answers against ``_hash_bucket``.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        cls.id_field = _bigint_field(0, 'id')
+        cls.val_field = _bigint_field(1, 'val')
+        cls.k1 = _bigint_field(0, 'k1')
+        cls.k2 = _bigint_field(1, 'k2')
+        cls.pb_id_val = PredicateBuilder([cls.id_field, cls.val_field])
+        cls.pb_k1_k2 = PredicateBuilder([cls.k1, cls.k2])
+
+    # -- Equal / In on single bucket key ---------------------------------
+    def test_equal_on_single_bucket_key_yields_single_bucket(self):
+        sel = create_bucket_selector(
+            self.pb_id_val.equal('id', 42), [self.id_field])
+        self.assertIsNotNone(sel, "PK Equal must produce a selector")
+        expected = _hash_bucket([42], [self.id_field], total=8)
+        for b in range(8):
+            self.assertEqual(
+                sel(b, 8), b == expected,
+                "only bucket {} must be kept (got {})".format(expected, b))
+
+    def test_in_on_single_bucket_key_unions_buckets(self):
+        sel = create_bucket_selector(
+            self.pb_id_val.is_in('id', [1, 2, 3, 100]), [self.id_field])
+        expected = {_hash_bucket([v], [self.id_field], 8)
+                    for v in (1, 2, 3, 100)}
+        for b in range(8):
+            self.assertEqual(sel(b, 8), b in expected)
+
+    def test_or_of_equals_on_same_field_unions_buckets(self):
+        # ``id = 1 OR id = 2`` must equal ``id IN (1, 2)``.
+        pred = PredicateBuilder.or_predicates([
+            self.pb_id_val.equal('id', 1),
+            self.pb_id_val.equal('id', 2),
+        ])
+        sel = create_bucket_selector(pred, [self.id_field])
+        expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)}
+        for b in range(8):
+            self.assertEqual(sel(b, 8), b in expected)
+
+    # -- Composite bucket keys ------------------------------------------
+    def test_composite_bucket_key_intersects_via_cartesian(self):
+        pred = PredicateBuilder.and_predicates([
+            self.pb_k1_k2.is_in('k1', [1, 2]),
+            self.pb_k1_k2.equal('k2', 99),
+        ])
+        sel = create_bucket_selector(pred, [self.k1, self.k2])
+        expected = {
+            _hash_bucket([k1, 99], [self.k1, self.k2], 4)
+            for k1 in (1, 2)
+        }
+        for b in range(4):
+            self.assertEqual(sel(b, 4), b in expected)
+
+    def test_composite_bucket_key_missing_one_field_returns_none(self):
+        pred = self.pb_k1_k2.equal('k1', 1)  # k2 unconstrained
+        sel = create_bucket_selector(pred, [self.k1, self.k2])
+        self.assertIsNone(sel,
+                          "all bucket keys must be constrained or fall back")
+
+    # -- Predicates that can't be reduced -------------------------------
+    def test_non_bucket_key_predicate_returns_none(self):
+        sel = create_bucket_selector(
+            self.pb_id_val.equal('val', 5), [self.id_field])
+        self.assertIsNone(sel, "predicate not on bucket key -> no selector")
+
+    def test_range_predicate_on_bucket_key_returns_none(self):
+        sel = create_bucket_selector(
+            self.pb_id_val.greater_than('id', 100), [self.id_field])
+        self.assertIsNone(
+            sel, "ranges can't be turned into a finite bucket set")
+
+    def test_or_with_non_bucket_key_returns_none(self):
+        # ``id = 1 OR val = 5`` — ``val`` isn't a bucket key, so the OR
+        # is not a pure bucket-key constraint.
+        pred = PredicateBuilder.or_predicates([
+            self.pb_id_val.equal('id', 1),
+            self.pb_id_val.equal('val', 5),
+        ])
+        sel = create_bucket_selector(pred, [self.id_field])
+        self.assertIsNone(sel)
+
+    def test_repeated_equal_on_same_key_with_empty_intersection_returns_none(self):
+        # ``id = 1 AND id = 2``: literal sets {1} and {2} intersect to
+        # empty; Java's ``retainAll`` would also bail here, since the
+        # predicate is unsatisfiable.
+        pred = PredicateBuilder.and_predicates([
+            self.pb_id_val.equal('id', 1),
+            self.pb_id_val.equal('id', 2),
+        ])
+        sel = create_bucket_selector(pred, [self.id_field])
+        self.assertIsNone(sel)
+
+    def test_repeated_in_on_same_key_intersects_literals(self):
+        # ``id IN (1,2,3) AND id IN (2,3,4)`` should now keep the
+        # intersection {2, 3} and prune to those buckets only. Used to
+        # bail with no selector before the Java parity fix.
+        pred = PredicateBuilder.and_predicates([
+            self.pb_id_val.is_in('id', [1, 2, 3]),
+            self.pb_id_val.is_in('id', [2, 3, 4]),
+        ])
+        sel = create_bucket_selector(pred, [self.id_field])
+        self.assertIsNotNone(sel)
+        expected = {_hash_bucket([v], [self.id_field], 8) for v in (2, 3)}
+        for b in range(8):
+            self.assertEqual(sel(b, 8), b in expected)
+
+    def test_and_with_unrelated_clause_is_unaffected(self):
+        # ``id = 7 AND val > 100`` — the ``val > 100`` part doesn't
+        # constrain buckets, but mustn't disqualify the ``id = 7`` part.
+        pred = PredicateBuilder.and_predicates([
+            self.pb_id_val.equal('id', 7),
+            self.pb_id_val.greater_than('val', 100),
+        ])
+        sel = create_bucket_selector(pred, [self.id_field])
+        self.assertIsNotNone(sel)
+        expected = _hash_bucket([7], [self.id_field], 4)
+        for b in range(4):
+            self.assertEqual(sel(b, 4), b == expected)
+
+    # -- Cap & degenerate edge cases ------------------------------------
+    def test_cartesian_above_max_values_returns_none(self):
+        # Two columns of size > sqrt(MAX_VALUES) → product > MAX_VALUES.
+        size = 33  # 33 * 33 = 1089 > 1000
+        pred = PredicateBuilder.and_predicates([
+            self.pb_k1_k2.is_in('k1', list(range(size))),
+            self.pb_k1_k2.is_in('k2', list(range(size))),
+        ])
+        self.assertGreater(size * size, MAX_VALUES)
+        sel = create_bucket_selector(pred, [self.k1, self.k2])
+        self.assertIsNone(sel)
+
+    def test_null_only_literal_drops_everything(self):
+        # ``id IN (NULL)`` after null-stripping has zero literals; the
+        # cartesian product is empty → selector matches no buckets. Same
+        # behaviour as Java.
+        pred = self.pb_id_val.is_in('id', [None])
+        sel = create_bucket_selector(pred, [self.id_field])
+        self.assertIsNotNone(sel)
+        for b in range(4):
+            self.assertFalse(sel(b, 4),
+                             "all-null literal collapses bucket set to empty")
+
+    def test_no_predicate_returns_none(self):
+        self.assertIsNone(create_bucket_selector(None, [self.id_field]))
+
+    def test_no_bucket_keys_returns_none(self):
+        self.assertIsNone(
+            create_bucket_selector(self.pb_id_val.equal('id', 1), []))
+
+    # -- Selector cache + rescale -------------------------------------
+    def test_selector_caches_per_total_buckets(self):
+        """Selector must answer correctly when the same query applies to
+        different ``total_buckets`` values (the rescale scenario)."""
+        sel = create_bucket_selector(
+            self.pb_id_val.equal('id', 42), [self.id_field])
+        for total in (4, 8, 16, 32):
+            expected = _hash_bucket([42], [self.id_field], total)
+            self.assertTrue(sel(expected, total))
+            other = (expected + 1) % total
+            self.assertFalse(sel(other, total))
+
+    def test_non_positive_total_buckets_fails_open(self):
+        """Manifest entries can carry ``total_buckets <= 0`` for legacy /
+        special bucket modes. Pruning MUST fail open — returning False
+        would silently drop rows the writer placed in those entries.
+        This is correctness, not performance: the soundness contract
+        forbids false-negatives."""
+        sel = create_bucket_selector(
+            self.pb_id_val.equal('id', 1), [self.id_field])
+        for total in (0, -1, -2):
+            self.assertTrue(
+                sel(0, total),
+                "total_buckets={} must be kept (fail open)".format(total))
+            self.assertTrue(sel(-1, total))
+            self.assertTrue(sel(99, total))
+
+    # -- Bucket-key column types beyond BIGINT --------------------------
+    def test_string_bucket_key_yields_correct_bucket(self):
+        """STRING uses a different ``GenericRowSerializer`` path (utf-8
+        encode + variable-part offset) — verify writer/reader agree on
+        its byte layout independent of the BIGINT happy path."""
+        sf = _field(0, 'sk', 'STRING')
+        vf = _bigint_field(1, 'val')
+        pb = PredicateBuilder([sf, vf])
+        sel = create_bucket_selector(pb.equal('sk', 'hello'), [sf])
+        self.assertIsNotNone(sel)
+        expected = _hash_bucket(['hello'], [sf], total=8)
+        for b in range(8):
+            self.assertEqual(sel(b, 8), b == expected)
+
+    def test_int_bucket_key_yields_correct_bucket(self):
+        """INT (32-bit) and BIGINT (64-bit) hit different struct.pack
+        paths in the serializer — guard the smaller width."""
+        intf = _field(0, 'i', 'INT')
+        vf = _bigint_field(1, 'val')
+        pb = PredicateBuilder([intf, vf])
+        sel = create_bucket_selector(pb.equal('i', 7), [intf])
+        self.assertIsNotNone(sel)
+        expected = _hash_bucket([7], [intf], total=4)
+        for b in range(4):
+            self.assertEqual(sel(b, 4), b == expected)
+
+    # -- Hash-divergence-prone types refuse to build a selector --------
+    def test_decimal_bucket_key_disables_pruning(self):
+        """DECIMAL columns risk silent hash divergence between writer
+        (Decimal) and reader-supplied ``float`` literals. Soundness
+        contract demands fail-open: refuse to build a selector at all."""
+        df = _field(0, 'd', 'DECIMAL(10, 2)')
+        vf = _bigint_field(1, 'val')
+        pb = PredicateBuilder([df, vf])
+        from decimal import Decimal
+        sel = create_bucket_selector(pb.equal('d', Decimal('1.50')), [df])
+        self.assertIsNone(
+            sel, "DECIMAL bucket-key column must disable pruning")
+
+    def test_array_bucket_key_disables_pruning(self):
+        """Composite types (ARRAY/MAP/ROW/MULTISET/VARIANT/BLOB) have no
+        cross-validated byte alignment with Java's ``BinaryRow`` — until
+        that exists, refuse to prune on them."""
+        # Hand-roll a DataField whose AtomicType reports an ARRAY type
+        # name; the converter inspects ``field.type.type`` only.
+        af = DataField(0, 'arr', AtomicType('ARRAY<BIGINT>'))
+        vf = _bigint_field(1, 'val')
+        pb = PredicateBuilder([af, vf])
+        sel = create_bucket_selector(pb.equal('arr', [1]), [af])
+        self.assertIsNone(
+            sel, "ARRAY bucket-key column must disable pruning")
+
+    def test_timestamp_bucket_key_disables_pruning(self):
+        """TIMESTAMP columns serialise via ``value.timestamp()`` whose
+        result depends on the process timezone for naive datetimes —
+        writer and reader running in different TZs would disagree."""
+        tf = _field(0, 't', 'TIMESTAMP(3)')
+        vf = _bigint_field(1, 'val')
+        pb = PredicateBuilder([tf, vf])
+        from datetime import datetime
+        sel = create_bucket_selector(
+            pb.equal('t', datetime(2026, 1, 1)), [tf])
+        self.assertIsNone(
+            sel, "TIMESTAMP bucket-key column must disable pruning")
+
+    def test_type_mismatched_literal_fails_open_not_crash(self):
+        """If the user constructs a predicate whose literal type doesn't
+        match the bucket-key column's atomic type — e.g. a STRING literal
+        on a BIGINT column — ``GenericRowSerializer`` raises during the
+        deferred hash inside ``_Selector``. The selector MUST swallow the
+        exception and fail open (return True for every bucket) rather
+        than propagate it. Crashing the entire scan with an opaque
+        ``struct.error`` is a worse user experience than silently
+        skipping bucket pruning, and the soundness contract still
+        forbids false-negatives."""
+        sel = create_bucket_selector(
+            self.pb_id_val.equal('id', 'not-an-int'), [self.id_field])
+        # Construction itself succeeds (no eager hashing).
+        self.assertIsNotNone(sel)
+        # Calling the selector must NOT raise; instead it returns True
+        # for every (bucket, total_buckets), preserving soundness.
+        for total in (4, 8):
+            for b in range(total):
+                self.assertTrue(
+                    sel(b, total),
+                    "type-mismatched literal must fail open, "
+                    "not crash (bucket={}, total={})".format(b, total))
+
+# ---------------------------------------------------------------------------
+# Layer 2 — Integration: real tables, public API, assert correctness AND
+# that pruning actually fired (otherwise we're not testing the optimisation,
+# only that we didn't break full-scan).
+# ---------------------------------------------------------------------------
+class BucketPruningIntegrationTest(unittest.TestCase):
+    """End-to-end pruning checks on real multi-bucket PK tables: tests
+    assert the returned rows and, where pruning is expected, exactly
+    which buckets the planner touched.
+    """
+
+    NUM_BUCKETS = 8
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', False)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def _create_pk_table(self, name: str, num_buckets: int = NUM_BUCKETS,
+                         bucket_key: str = None) -> Any:
+        opts = {'bucket': str(num_buckets), 'file.format': 'parquet'}
+        if bucket_key is not None:
+            opts['bucket-key'] = bucket_key
+        pa_schema = pa.schema([
+            pa.field('id', pa.int64(), nullable=False),
+            ('val', pa.int64()),
+        ])
+        schema = Schema.from_pyarrow_schema(
+            pa_schema, primary_keys=['id'], options=opts)
+        full = 'default.{}'.format(name)
+        self.catalog.create_table(full, schema, False)
+        return self.catalog.get_table(full)
+
+    def _write(self, table, rows: List[Dict]):
+        pa_schema = pa.schema([
+            pa.field('id', pa.int64(), nullable=False),
+            ('val', pa.int64()),
+        ])
+        wb = table.new_batch_write_builder()
+        w = wb.new_write()
+        c = wb.new_commit()
+        try:
+            w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema))
+            c.commit(w.prepare_commit())
+        finally:
+            w.close()
+            c.close()
+
+    def _read_with(self, table, predicate=None):
+        rb = table.new_read_builder()
+        if predicate is not None:
+            rb = rb.with_filter(predicate)
+        splits = rb.new_scan().plan().splits()
+        if not splits:
+            return [], splits
+        return rb.new_read().to_arrow(splits).to_pylist(), splits
+
+    @staticmethod
+    def _split_buckets(splits) -> set:
+        """Collect the distinct bucket numbers actually returned in a plan."""
+        return {s.bucket for s in splits}
+
+    @staticmethod
+    def _expected_buckets(table, ids, value_field='val') -> set:
+        """Use the writer's RowKeyExtractor to compute the bucket(s) the
+        rows for ``ids`` were written into. Cross-check against the
+        reader's selector — divergence indicates read/write hash drift."""
+        ext = FixedBucketRowKeyExtractor(table.table_schema)
+        pa_schema = pa.schema([
+            pa.field('id', pa.int64(), nullable=False),
+            (value_field, pa.int64()),
+        ])
+        out = set()
+        for i in ids:
+            arr = pa.RecordBatch.from_pylist(
+                [{'id': i, value_field: 0}], schema=pa_schema)
+            out.update(ext._extract_buckets_batch(arr))
+        return out
+
+    # -- Equal on PK -----------------------------------------------------
+    def test_pk_equal_only_reads_target_bucket(self):
+        table = self._create_pk_table('int_eq')
+        rows = [{'id': i, 'val': i * 11} for i in range(100)]
+        self._write(table, rows)
+
+        target_id = 42
+        pred = table.new_read_builder().new_predicate_builder().equal(
+            'id', target_id)
+        got, splits = self._read_with(table, pred)
+
+        # Correctness: row for id=42 returned (and only that).
+        self.assertEqual(got, [{'id': 42, 'val': 42 * 11}])
+
+        # Pruning effectiveness AND hash correctness: the touched bucket
+        # must equal the bucket the writer placed id=42 into. Asserting
+        # only ``len == 1`` would mask a hash drift that picks the wrong
+        # single bucket.
+        self.assertEqual(self._split_buckets(splits),
+                         self._expected_buckets(table, [target_id]),
+                         "PK equal must touch exactly the writer's bucket")
+
+    def test_pk_in_reads_only_target_buckets(self):
+        table = self._create_pk_table('int_in')
+        rows = [{'id': i, 'val': i * 7} for i in range(200)]
+        self._write(table, rows)
+
+        ids = [3, 17, 99, 150]
+        pred = table.new_read_builder().new_predicate_builder().is_in(
+            'id', ids)
+        got, splits = self._read_with(table, pred)
+        got_sorted = sorted(got, key=lambda r: r['id'])
+        self.assertEqual(got_sorted,
+                         [{'id': i, 'val': i * 7} for i in sorted(ids)])
+
+        actual = self._split_buckets(splits)
+        expected_buckets = self._expected_buckets(table, ids)
+        # Equality (not subset): under the single-commit setup every
+        # target bucket actually has a file, so the planner must produce
+        # exactly the writer's bucket set. ``issubset`` would mask a
+        # selector that's overly aggressive on a subset of the IN list.
+        self.assertEqual(
+            actual, expected_buckets,
+            "got {}, expected {}".format(actual, expected_buckets))
+
+    # -- Predicates that should NOT prune -------------------------------
+    def test_value_only_predicate_falls_back_to_full_scan(self):
+        """``val < X`` doesn't constrain the PK → selector must be None
+        and no bucket pruning may fire. Both checked: result correctness
+        AND the explicit "selector is None" property."""
+        table = self._create_pk_table('val_only')
+        rows = [{'id': i, 'val': i} for i in range(100)]
+        self._write(table, rows)
+
+        pred = table.new_read_builder().new_predicate_builder().less_than(
+            'val', 30)
+        got, splits = self._read_with(table, pred)
+        self.assertEqual(sorted([r['id'] for r in got]), list(range(30)))
+
+        # Inspect the scanner's bucket selector to prove pruning DIDN'T
+        # fire — without this assertion the test would also pass under a
+        # buggy selector that prunes wrongly but happens to keep the
+        # rows we picked.
+        rb = table.new_read_builder().with_filter(pred)
+        scan = rb.new_scan()
+        self.assertIsNone(scan.file_scanner._bucket_selector,
+                          "value-only predicate must NOT produce a selector")
+
+    def test_range_on_pk_falls_back_to_full_scan(self):
+        """``id >= X`` is a range, not Equal/In, so cannot derive a bucket
+        set. Selector returns None — result must still be exact."""
+        table = self._create_pk_table('pk_range')
+        rows = [{'id': i, 'val': i} for i in range(50)]
+        self._write(table, rows)
+
+        pb = table.new_read_builder().new_predicate_builder()
+        pred = pb.greater_or_equal('id', 40)
+        got, _ = self._read_with(table, pred)
+        self.assertEqual(sorted([r['id'] for r in got]), list(range(40, 50)))
+
+        # Check the "selector returns None" claim explicitly, as the
+        # value-only test does: a selector that wrongly prunes a range
+        # but happens to keep ids 40..49 must not slip through.
+        scan = table.new_read_builder().with_filter(pred).new_scan()
+        self.assertIsNone(scan.file_scanner._bucket_selector,
+                          "range predicate on PK must NOT produce a selector")
+
+    # -- Mixed predicate: Equal on PK AND range on val ------------------
+    def test_pk_equal_with_unrelated_value_predicate_still_prunes(self):
+        table = self._create_pk_table('int_eq_with_val')
+        rows = [{'id': i, 'val': i} for i in range(50)]
+        self._write(table, rows)
+
+        pb = table.new_read_builder().new_predicate_builder()
+        pred = pb.and_predicates([
+            pb.equal('id', 25),
+            pb.greater_than('val', 20),
+        ])
+        got, splits = self._read_with(table, pred)
+        self.assertEqual(got, [{'id': 25, 'val': 25}])
+        self.assertEqual(self._split_buckets(splits),
+                         self._expected_buckets(table, [25]),
+                         "Equal on PK still narrows to the writer's bucket "
+                         "even when AND'd with a non-bucket-key predicate")
+
+    def test_early_filter_skips_full_entry_decode_for_pruned_buckets(self):
+        """Entries the bucket selector rejects must never reach
+        ``GenericRowDeserializer.from_bytes`` for their partition / key
+        stats. Without the early filter the count would scale with the
+        manifest entry count; with it, only the surviving entries pay
+        the deserialisation cost."""
+        from unittest import mock
+
+        from pypaimon.table.row import generic_row
+
+        table = self._create_pk_table('early_filter')
+        # 8 separate single-row commits → 8 manifest entries. Ids 0..7
+        # may hash into overlapping buckets; ``pk = 0`` should only
+        # decode the entries that share id 0's bucket.
+        for i in range(self.NUM_BUCKETS):
+            self._write(table, [{'id': i, 'val': i * 11}])
+
+        pred = table.new_read_builder().new_predicate_builder().equal('id', 0)
+        rb = table.new_read_builder().with_filter(pred)
+
+        real_from_bytes = generic_row.GenericRowDeserializer.from_bytes
+        calls = {'n': 0}
+
+        def counting(*args, **kwargs):
+            calls['n'] += 1
+            return real_from_bytes(*args, **kwargs)
+
+        with mock.patch.object(generic_row.GenericRowDeserializer,
+                               'from_bytes',
+                               side_effect=counting):
+            splits = rb.new_scan().plan().splits()
+            got = rb.new_read().to_arrow(splits).to_pylist() if splits else []
+
+        self.assertEqual(got, [{'id': 0, 'val': 0}])
+        # Each surviving entry decodes partition + min_key + max_key
+        # (3 ``from_bytes`` calls). Allow a small slack in case the planner
+        # touches extras, but assert it is well below 8 entries × 3 = 24.
+        self.assertLess(
+            calls['n'], 3 * self.NUM_BUCKETS,
+            "early filter should skip from_bytes for pruned entries; "
+            "got {} calls (would be {}+ without the filter)".format(
+                calls['n'], 3 * self.NUM_BUCKETS))
+
+    def test_init_bucket_selector_fails_open_when_bucket_keys_raises(self):
+        """``TableSchema.bucket_keys`` raises if ``bucket-key`` references
+        an unknown column. The pre-Java-alignment selector path used to
+        catch ``Exception`` from instantiating ``FixedBucketRowKeyExtractor``
+        and silently skip pruning; that property must survive the move
+        of bucket-key resolution onto ``TableSchema``. Crashing the scan
+        on a misconfiguration would be worse than skipping the
+        optimisation."""
+        table = self._create_pk_table('init_fails_open')
+        self._write(table, [{'id': 1, 'val': 1}])
+        # Mutate the in-memory schema options to a broken value to
+        # simulate a corrupted/migrated catalog without rewriting it.
+        table.table_schema.options['bucket-key'] = 'nope_no_such_column'
+        # Precondition: the broken option really does make
+        # ``bucket_keys`` raise — otherwise this test passes vacuously.
+        self.assertRaises(ValueError, getattr, table.table_schema,
+                          'bucket_keys')
+
+        rb = table.new_read_builder().with_filter(
+            table.new_read_builder().new_predicate_builder().equal('id', 1))
+        scanner = rb.new_scan().file_scanner
+        # Must NOT raise: the broken option falls back to "no pruning",
+        # and the scan still finds the row.
+        self.assertIsNone(scanner._init_bucket_selector())
+        got, _ = self._read_with(table, scanner.predicate)
+        self.assertEqual(got, [{'id': 1, 'val': 1}])
+
+    # -- Explicit bucket-key option ------------------------------------
+    def test_bucket_key_option_overrides_pk_for_pruning(self):
+        """When the ``bucket-key`` option is set explicitly, the bucket
+        derivation must use it — not the trimmed primary keys. This is
+        the path that catches read/write hash divergence if a refactor
+        forgets the option."""
+        # PK = id, bucket-key = id explicitly (single key but exercises
+        # the explicit-config branch in ``_init_bucket_selector``).
+        table = self._create_pk_table('explicit_bk', bucket_key='id')
+        rows = [{'id': i, 'val': i * 3} for i in range(40)]
+        self._write(table, rows)
+
+        pred = table.new_read_builder().new_predicate_builder().equal('id', 17)
+        got, splits = self._read_with(table, pred)
+        self.assertEqual(got, [{'id': 17, 'val': 51}])
+        self.assertEqual(self._split_buckets(splits),
+                         self._expected_buckets(table, [17]))
+
+# ---------------------------------------------------------------------------
+# Layer 3 — Property: random PK tables, random Equal/In predicates,
+# correctness vs oracle.
+# ---------------------------------------------------------------------------
+class BucketPruningPropertyTest(unittest.TestCase):
+    """Randomised PK tables queried with Equal/In on the key.
+
+    Each trial checks the rows against an oracle and the planned buckets
+    against the writer's placement. Every test method gets its own
+    fixed-seed ``random.Random`` in ``setUp``. A class-wide RNG made the
+    IN test's draws depend on whether the Equal test had run first, so a
+    failure seen in a full run could not be reproduced in isolation.
+    """
+
+    SEED = 0xB0CC
+    TRIALS = 30
+    # Shared arrow schema for table creation, writes and the oracle.
+    PA_SCHEMA = pa.schema([
+        pa.field('k', pa.int64(), nullable=False),
+        ('v', pa.int64()),
+    ])
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
+        cls.catalog.create_database('default', False)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def setUp(self):
+        # Fresh fixed-seed RNG per test: draws no longer depend on test
+        # ordering or on which tests were selected.
+        self.rnd = random.Random(self.SEED)
+
+    def _make_table(self, idx: int, num_buckets: int):
+        """Create PK table ``default.bp_<idx>`` with ``num_buckets``."""
+        schema = Schema.from_pyarrow_schema(
+            self.PA_SCHEMA,
+            primary_keys=['k'],
+            options={'bucket': str(num_buckets), 'file.format': 'parquet'},
+        )
+        name = 'default.bp_{}'.format(idx)
+        self.catalog.create_table(name, schema, False)
+        return self.catalog.get_table(name)
+
+    def _write(self, table, rows):
+        """Write ``rows`` in one batch commit, always closing writer and
+        committer."""
+        wb = table.new_batch_write_builder()
+        w = wb.new_write()
+        c = wb.new_commit()
+        try:
+            w.write_arrow(pa.Table.from_pylist(rows, schema=self.PA_SCHEMA))
+            c.commit(w.prepare_commit())
+        finally:
+            w.close()
+            c.close()
+
+    @classmethod
+    def _expected_buckets(cls, table, keys) -> set:
+        """Independent oracle: writer's bucket placement for the given keys."""
+        ext = FixedBucketRowKeyExtractor(table.table_schema)
+        out = set()
+        for k in keys:
+            arr = pa.RecordBatch.from_pylist(
+                [{'k': k, 'v': 0}], schema=cls.PA_SCHEMA)
+            out.update(ext._extract_buckets_batch(arr))
+        return out
+
+    @staticmethod
+    def _split_buckets(splits) -> set:
+        """Collect the distinct bucket numbers present in a plan."""
+        return {s.bucket for s in splits}
+
+    def test_property_pk_equal_correctness(self):
+        for trial in range(self.TRIALS):
+            num_buckets = self.rnd.choice([2, 4, 8, 16])
+            table = self._make_table(trial, num_buckets)
+            keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100))
+            rows = [{'k': k, 'v': k * 13} for k in keys]
+            self._write(table, rows)
+
+            target = self.rnd.choice(keys)
+            pb = table.new_read_builder().new_predicate_builder()
+            pred = pb.equal('k', target)
+            rb = table.new_read_builder().with_filter(pred)
+            splits = rb.new_scan().plan().splits()
+            if splits:
+                got = rb.new_read().to_arrow(splits).to_pylist()
+            else:
+                got = []
+            self.assertEqual(got, [{'k': target, 'v': target * 13}],
+                             "trial {} buckets={} target={}: result mismatch"
+                             .format(trial, num_buckets, target))
+            # Pruning fired AND picked the writer's bucket. Without this
+            # cross-check a fail-open selector (i.e. no pruning) would
+            # still pass the result-equality assertion above.
+            self.assertEqual(self._split_buckets(splits),
+                             self._expected_buckets(table, [target]),
+                             "trial {}: bucket set != writer's placement"
+                             .format(trial))
+
+    def test_property_pk_in_correctness(self):
+        for trial in range(self.TRIALS):
+            num_buckets = self.rnd.choice([2, 4, 8, 16])
+            offset = self.TRIALS + trial  # avoid name clash with prev test
+            table = self._make_table(offset, num_buckets)
+            keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100))
+            rows = [{'k': k, 'v': k * 13} for k in keys]
+            self._write(table, rows)
+
+            target_n = self.rnd.randint(1, min(10, len(keys)))
+            targets = self.rnd.sample(keys, target_n)
+            pb = table.new_read_builder().new_predicate_builder()
+            pred = pb.is_in('k', targets)
+            rb = table.new_read_builder().with_filter(pred)
+            splits = rb.new_scan().plan().splits()
+            if splits:
+                got = rb.new_read().to_arrow(splits).to_pylist()
+            else:
+                got = []
+            got_sorted = sorted(got, key=lambda r: r['k'])
+            want = sorted(
+                [{'k': k, 'v': k * 13} for k in targets],
+                key=lambda r: r['k'])
+            self.assertEqual(got_sorted, want,
+                             "trial {}: IN result mismatch".format(trial))
+            self.assertEqual(self._split_buckets(splits),
+                             self._expected_buckets(table, targets),
+                             "trial {}: IN bucket set != writer's placement"
+                             .format(trial))
+
+if __name__ == '__main__':
+    # Allow running this module directly instead of via a test runner.
+    unittest.main()
diff --git a/paimon-python/pypaimon/tests/table_schema_test.py b/paimon-python/pypaimon/tests/table_schema_test.py
new file mode 100644
index 0000000000..d42eef2ab8
--- /dev/null
+++ b/paimon-python/pypaimon/tests/table_schema_test.py
@@ -0,0 +1,93 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+################################################################################
+
+import unittest
+
+from pypaimon.schema.data_types import AtomicType, DataField
+from pypaimon.schema.table_schema import TableSchema
+
+
+def _bigint_field(idx: int, name: str) -> DataField:  # non-nullable BIGINT field
+    return DataField(idx, name, AtomicType('BIGINT', nullable=False))
+
+
+def _string_field(idx: int, name: str) -> DataField:  # STRING, default nullability
+    return DataField(idx, name, AtomicType('STRING'))
+
+
+class TableSchemaBucketKeysTest(unittest.TestCase):
+    """Cover the ``bucket-key`` resolution lifted onto TableSchema.
+
+    Mirrors Java ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``.
+    """
+
+    def _schema(self, primary_keys=None, partition_keys=None, options=None):  # id/region/val
+        fields = [
+            _bigint_field(0, 'id'),
+            _string_field(1, 'region'),
+            _bigint_field(2, 'val'),
+        ]
+        return TableSchema(
+            id=0,
+            fields=fields,
+            partition_keys=partition_keys or [],
+            primary_keys=primary_keys or [],
+            options=options or {},
+        )
+
+    def test_explicit_bucket_key_option_returns_those_columns(self):
+        schema = self._schema(
+            primary_keys=['id'],
+            options={'bucket-key': 'region,val'},  # option takes precedence over the PK
+        )
+        self.assertEqual(schema.bucket_keys, ['region', 'val'])
+
+        fields = schema.logical_bucket_key_fields  # same order as bucket_keys
+        self.assertEqual([f.name for f in fields], ['region', 'val'])
+
+    def test_no_bucket_key_falls_back_to_trimmed_primary_keys(self):
+        # PK includes a partition column; trimmed bucket keys drop it.
+        schema = self._schema(
+            primary_keys=['region', 'id'],
+            partition_keys=['region'],
+        )
+        self.assertEqual(schema.bucket_keys, ['id'])
+        self.assertEqual(
+            [f.name for f in schema.logical_bucket_key_fields], ['id'])
+
+    def test_no_bucket_key_no_primary_keys_returns_empty(self):
+        schema = self._schema()  # no PK, no partition keys, no options
+        self.assertEqual(schema.bucket_keys, [])
+        self.assertEqual(schema.logical_bucket_key_fields, [])
+
+    def test_unknown_bucket_key_column_raises(self):
+        schema = self._schema(options={'bucket-key': 'nope'})  # construction must not raise
+        with self.assertRaises(ValueError):
+            _ = schema.bucket_keys  # the error surfaces on property access
+
+    def test_whitespace_only_option_falls_back(self):
+        # Whitespace-only ``bucket-key`` mirrors an unset option.
+        schema = self._schema(
+            primary_keys=['id'],
+            options={'bucket-key': '   '},
+        )
+        self.assertEqual(schema.bucket_keys, ['id'])  # fell back to the PK
+
+
+if __name__ == '__main__':  # allow running this test module directly
+    unittest.main()
diff --git a/paimon-python/pypaimon/write/row_key_extractor.py b/paimon-python/pypaimon/write/row_key_extractor.py
index 2f09e6577b..c30b63e21b 100644
--- a/paimon-python/pypaimon/write/row_key_extractor.py
+++ b/paimon-python/pypaimon/write/row_key_extractor.py
@@ -125,18 +125,12 @@ class FixedBucketRowKeyExtractor(RowKeyExtractor):
         if self.num_buckets <= 0:
             raise ValueError(f"Fixed bucket mode requires bucket > 0, got 
{self.num_buckets}")
 
-        bucket_key_option = options.bucket_key()
-        if bucket_key_option and bucket_key_option.strip():
-            self.bucket_keys = [k.strip() for k in 
bucket_key_option.split(',')]
-        else:
-            self.bucket_keys = [pk for pk in table_schema.primary_keys
-                                if pk not in table_schema.partition_keys]
-
+        # Bucket-key resolution lives on TableSchema (mirrors Java
+        # ``TableSchema.bucketKeys()`` / ``logicalBucketKeyType()``); reuse
+        # it so any reader path that walks the same logic stays in sync.
+        self.bucket_keys = table_schema.bucket_keys
         self.bucket_key_indices = self._get_field_indices(self.bucket_keys)
-        field_map = {f.name: f for f in table_schema.fields}
-        self._bucket_key_fields = [
-            field_map[name] for name in self.bucket_keys if name in field_map
-        ]
+        self._bucket_key_fields = table_schema.logical_bucket_key_fields
 
     def _extract_buckets_batch(self, data: pa.RecordBatch) -> List[int]:
         columns = [data.column(i) for i in self.bucket_key_indices]

Reply via email to