This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git

The following commit(s) were added to refs/heads/master by this push:
     new 77f130bde0 [Python] Predicate Push Down for Scan / Read (#6166)
77f130bde0 is described below

commit 77f130bde03381e3b26b86e4f0e95cff5c1dc622
Author: ChengHui Chen <27797326+chenghuic...@users.noreply.github.com>
AuthorDate: Fri Aug 29 16:34:24 2025 +0800

    [Python] Predicate Push Down for Scan / Read (#6166)
---
 paimon-python/pypaimon/common/predicate.py          | 103 ++++++++++++-
 .../pypaimon/manifest/manifest_file_manager.py      |  32 ++--
 .../pypaimon/manifest/manifest_list_manager.py      |  12 +-
 .../pypaimon/manifest/schema/simple_stats.py        |   6 +-
 paimon-python/pypaimon/read/push_down_utils.py      |  72 +++++++++
 .../pypaimon/read/reader/format_avro_reader.py      |  24 +--
 .../pypaimon/read/reader/format_pyarrow_reader.py   |  50 +------
 paimon-python/pypaimon/read/split_read.py           |  12 +-
 paimon-python/pypaimon/read/table_read.py           |  35 ++++-
 paimon-python/pypaimon/read/table_scan.py           | 161 +++++++--------------
 .../pypaimon/tests/predicate_push_down_test.py      | 151 +++++++++++++++++++
 paimon-python/pypaimon/write/file_store_commit.py   |   6 +-
 paimon-python/pypaimon/write/writer/data_writer.py  |  34 ++---
 13 files changed, 471 insertions(+), 227 deletions(-)

diff --git a/paimon-python/pypaimon/common/predicate.py b/paimon-python/pypaimon/common/predicate.py
index ee13aca99b..ba56713032 100644
--- a/paimon-python/pypaimon/common/predicate.py
+++ b/paimon-python/pypaimon/common/predicate.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from functools import reduce
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 
 import pyarrow
 from pyarrow import compute as pyarrow_compute
@@ -82,6 +82,107 @@ class Predicate:
         else:
             raise ValueError("Unsupported predicate method: {}".format(self.method))
 
+    def test_by_value(self, value: Any) -> bool:
+        if self.method == 'and':
+            return all(p.test_by_value(value) for p in self.literals)
+        if self.method == 'or':
+            t = any(p.test_by_value(value) for p in self.literals)
+            return t
+
+        if self.method == 'equal':
+            return value == self.literals[0]
+        if self.method == 'notEqual':
+            return value != self.literals[0]
+        if self.method == 'lessThan':
+            return value < self.literals[0]
+        if self.method == 'lessOrEqual':
+            return value <= self.literals[0]
+        if self.method == 'greaterThan':
+            return value > self.literals[0]
+        if self.method == 'greaterOrEqual':
+            return value >= self.literals[0]
+        if self.method == 'isNull':
+            return value is None
+        if self.method == 'isNotNull':
+            return value is not None
+        if self.method == 'startsWith':
+            if not isinstance(value, str):
+                return False
+            return value.startswith(self.literals[0])
+        if self.method == 'endsWith':
+            if not isinstance(value, str):
+                return False
+            return value.endswith(self.literals[0])
+        if self.method == 'contains':
+            if not isinstance(value, str):
+                return False
+            return self.literals[0] in value
+        if self.method == 'in':
+            return value in self.literals
+        if self.method == 'notIn':
+            return value not in self.literals
+        if self.method == 'between':
+            return self.literals[0] <= value <= self.literals[1]
+
+        raise ValueError("Unsupported predicate method: {}".format(self.method))
+
+    def test_by_stats(self, stat: Dict) -> bool:
+        if self.method == 'and':
+            return all(p.test_by_stats(stat) for p in self.literals)
+        if self.method == 'or':
+            t = any(p.test_by_stats(stat) for p in self.literals)
+            return t
+
+        null_count = stat["null_counts"][self.field]
+        row_count = stat["row_count"]
+
+        if self.method == 'isNull':
+            return null_count is not None and null_count > 0
+        if self.method == 'isNotNull':
+            return null_count is None or row_count is None or null_count < row_count
+
+        min_value = stat["min_values"][self.field]
+        max_value = stat["max_values"][self.field]
+
+        if min_value is None or max_value is None or (null_count is not None and null_count == row_count):
+            return False
+
+        if self.method == 'equal':
+            return min_value <= self.literals[0] <= max_value
+        if self.method == 'notEqual':
+            return not (min_value == self.literals[0] == max_value)
+        if self.method == 'lessThan':
+            return self.literals[0] > min_value
+        if self.method == 'lessOrEqual':
+            return self.literals[0] >= min_value
+        if self.method == 'greaterThan':
+            return self.literals[0] < max_value
+        if self.method == 'greaterOrEqual':
+            return self.literals[0] <= max_value
+        if self.method == 'startsWith':
+            if not isinstance(min_value, str) or not isinstance(max_value, str):
+                raise RuntimeError("startsWith predicate on non-str field")
+            return ((min_value.startswith(self.literals[0]) or min_value < self.literals[0])
+                    and (max_value.startswith(self.literals[0]) or max_value > self.literals[0]))
+        if self.method == 'endsWith':
+            return True
+        if self.method == 'contains':
+            return True
+        if self.method == 'in':
+            for literal in self.literals:
+                if min_value <= literal <= max_value:
+                    return True
+            return False
+        if self.method == 'notIn':
+            for literal in self.literals:
+                if min_value == literal == max_value:
+                    return False
+            return True
+        if self.method == 'between':
+            return self.literals[0] <= max_value and self.literals[1] >= min_value
+        else:
+            raise ValueError("Unsupported predicate method: {}".format(self.method))
+
     def to_arrow(self) -> pyarrow_compute.Expression | bool:
         if self.method == 'equal':
             return pyarrow_dataset.field(self.field) == self.literals[0]
diff --git a/paimon-python/pypaimon/manifest/manifest_file_manager.py b/paimon-python/pypaimon/manifest/manifest_file_manager.py
index 7c46b368d2..7c97f7b0ca 100644
--- a/paimon-python/pypaimon/manifest/manifest_file_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_file_manager.py
@@ -55,19 +55,19 @@ class ManifestFileManager:
             file_dict = dict(record['_FILE'])
             key_dict = dict(file_dict['_KEY_STATS'])
             key_stats = SimpleStats(
-                min_value=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
-                                                           self.trimmed_primary_key_fields),
-                max_value=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
-                                                           self.trimmed_primary_key_fields),
-                null_count=key_dict['_NULL_COUNTS'],
+                min_values=BinaryRowDeserializer.from_bytes(key_dict['_MIN_VALUES'],
+                                                            self.trimmed_primary_key_fields),
+                max_values=BinaryRowDeserializer.from_bytes(key_dict['_MAX_VALUES'],
+                                                            self.trimmed_primary_key_fields),
+                null_counts=key_dict['_NULL_COUNTS'],
             )
             value_dict = dict(file_dict['_VALUE_STATS'])
             value_stats = SimpleStats(
-                min_value=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
-                                                           self.table.table_schema.fields),
-                max_value=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
-                                                           self.table.table_schema.fields),
-                null_count=value_dict['_NULL_COUNTS'],
+                min_values=BinaryRowDeserializer.from_bytes(value_dict['_MIN_VALUES'],
+                                                            self.table.table_schema.fields),
+                max_values=BinaryRowDeserializer.from_bytes(value_dict['_MAX_VALUES'],
+                                                            self.table.table_schema.fields),
+                null_counts=value_dict['_NULL_COUNTS'],
             )
             file_meta = DataFileMeta(
                 file_name=file_dict['_FILE_NAME'],
@@ -118,14 +118,14 @@ class ManifestFileManager:
             "_MIN_KEY": BinaryRowSerializer.to_bytes(file.min_key),
             "_MAX_KEY": BinaryRowSerializer.to_bytes(file.max_key),
             "_KEY_STATS": {
-                "_MIN_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.min_value),
-                "_MAX_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.max_value),
-                "_NULL_COUNTS": file.key_stats.null_count,
+                "_MIN_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.min_values),
+                "_MAX_VALUES": BinaryRowSerializer.to_bytes(file.key_stats.max_values),
+                "_NULL_COUNTS": file.key_stats.null_counts,
             },
             "_VALUE_STATS": {
-                "_MIN_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.min_value),
-                "_MAX_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.max_value),
-                "_NULL_COUNTS": file.value_stats.null_count,
+                "_MIN_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.min_values),
+                "_MAX_VALUES": BinaryRowSerializer.to_bytes(file.value_stats.max_values),
+                "_NULL_COUNTS": file.value_stats.null_counts,
             },
             "_MIN_SEQUENCE_NUMBER": file.min_sequence_number,
             "_MAX_SEQUENCE_NUMBER": file.max_sequence_number,
diff --git a/paimon-python/pypaimon/manifest/manifest_list_manager.py b/paimon-python/pypaimon/manifest/manifest_list_manager.py
index 65fd2b21ac..dc9d5db44d 100644
--- a/paimon-python/pypaimon/manifest/manifest_list_manager.py
+++ b/paimon-python/pypaimon/manifest/manifest_list_manager.py
@@ -58,15 +58,15 @@ class ManifestListManager:
         for record in reader:
             stats_dict = dict(record['_PARTITION_STATS'])
             partition_stats = SimpleStats(
-                min_value=BinaryRowDeserializer.from_bytes(
+                min_values=BinaryRowDeserializer.from_bytes(
                     stats_dict['_MIN_VALUES'],
                     self.table.table_schema.get_partition_key_fields()
                 ),
-                max_value=BinaryRowDeserializer.from_bytes(
+                max_values=BinaryRowDeserializer.from_bytes(
                     stats_dict['_MAX_VALUES'],
                     self.table.table_schema.get_partition_key_fields()
                 ),
-                null_count=stats_dict['_NULL_COUNTS'],
+                null_counts=stats_dict['_NULL_COUNTS'],
             )
             manifest_file_meta = ManifestFileMeta(
                 file_name=record['_FILE_NAME'],
@@ -90,9 +90,9 @@ class ManifestListManager:
             "_NUM_ADDED_FILES": meta.num_added_files,
             "_NUM_DELETED_FILES": meta.num_deleted_files,
             "_PARTITION_STATS": {
-                "_MIN_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.min_value),
-                "_MAX_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.max_value),
-                "_NULL_COUNTS": meta.partition_stats.null_count,
+                "_MIN_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.min_values),
+                "_MAX_VALUES": BinaryRowSerializer.to_bytes(meta.partition_stats.max_values),
+                "_NULL_COUNTS": meta.partition_stats.null_counts,
             },
             "_SCHEMA_ID": meta.schema_id,
         }
diff --git a/paimon-python/pypaimon/manifest/schema/simple_stats.py b/paimon-python/pypaimon/manifest/schema/simple_stats.py
index 4a73d3eee4..55b2163e76 100644
--- a/paimon-python/pypaimon/manifest/schema/simple_stats.py
+++ b/paimon-python/pypaimon/manifest/schema/simple_stats.py
@@ -24,9 +24,9 @@ from pypaimon.table.row.binary_row import BinaryRow
 
 @dataclass
 class SimpleStats:
-    min_value: BinaryRow
-    max_value: BinaryRow
-    null_count: Optional[List[int]]
+    min_values: BinaryRow
+    max_values: BinaryRow
+    null_counts: Optional[List[int]]
 
 
 SIMPLE_STATS_SCHEMA = {
diff --git a/paimon-python/pypaimon/read/push_down_utils.py b/paimon-python/pypaimon/read/push_down_utils.py
new file mode 100644
index 0000000000..31e66973c6
--- /dev/null
+++ b/paimon-python/pypaimon/read/push_down_utils.py
@@ -0,0 +1,72 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from typing import Dict, List, Set
+
+from pypaimon.common.predicate import Predicate
+
+
+def extract_predicate_to_list(result: list, input_predicate: 'Predicate', keys: List[str]):
+    if not input_predicate or not keys:
+        return
+
+    if input_predicate.method == 'and':
+        for sub_predicate in input_predicate.literals:
+            extract_predicate_to_list(result, sub_predicate, keys)
+        return
+    elif input_predicate.method == 'or':
+        # condition: involved keys all belong to primary keys
+        involved_fields = _get_all_fields(input_predicate)
+        if involved_fields and involved_fields.issubset(keys):
+            result.append(input_predicate)
+        return
+
+    if input_predicate.field in keys:
+        result.append(input_predicate)
+
+
+def _get_all_fields(predicate: 'Predicate') -> Set[str]:
+    if predicate.field is not None:
+        return {predicate.field}
+    involved_fields = set()
+    if predicate.literals:
+        for sub_predicate in predicate.literals:
+            involved_fields.update(_get_all_fields(sub_predicate))
+    return involved_fields
+
+
+def extract_predicate_to_dict(result: Dict, input_predicate: 'Predicate', keys: List[str]):
+    if not input_predicate or not keys:
+        return
+
+    if input_predicate.method == 'and':
+        for sub_predicate in input_predicate.literals:
+            extract_predicate_to_dict(result, sub_predicate, keys)
+        return
+    elif input_predicate.method == 'or':
+        # ensure no recursive and/or
+        if not input_predicate.literals or any(p.field is None for p in input_predicate.literals):
+            return
+        # condition: only one key for 'or', and the key belongs to keys
+        involved_fields = {p.field for p in input_predicate.literals}
+        if len(involved_fields) == 1 and (field := involved_fields.pop()) in keys:
+            result[field].append(input_predicate)
+        return
+
+    if input_predicate.field in keys:
+        result[input_predicate.field].append(input_predicate)
diff --git a/paimon-python/pypaimon/read/reader/format_avro_reader.py b/paimon-python/pypaimon/read/reader/format_avro_reader.py
index 83e90606e7..4ce7c04ed4 100644
--- a/paimon-python/pypaimon/read/reader/format_avro_reader.py
+++ b/paimon-python/pypaimon/read/reader/format_avro_reader.py
@@ -20,11 +20,11 @@ from typing import List, Optional
 
 import fastavro
 import pyarrow as pa
+import pyarrow.compute as pc
 import pyarrow.dataset as ds
 from pyarrow import RecordBatch
 
 from pypaimon.common.file_io import FileIO
-from pypaimon.common.predicate import Predicate
 from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
 from pypaimon.schema.data_types import DataField, PyarrowFieldParser
 
@@ -35,26 +35,18 @@ class FormatAvroReader(RecordBatchReader):
     provided predicate and projection, and converts Avro records to RecordBatch format.
     """
 
-    def __init__(self, file_io: FileIO, file_path: str, primary_keys: List[str],
-                 fields: List[str], full_fields: List[DataField], predicate: Predicate, batch_size: int = 4096):
+    def __init__(self, file_io: FileIO, file_path: str, read_fields: List[str], full_fields: List[DataField],
+                 push_down_predicate: pc.Expression | bool, batch_size: int = 4096):
         self._file = file_io.filesystem.open_input_file(file_path)
         self._avro_reader = fastavro.reader(self._file)
         self._batch_size = batch_size
-        self._primary_keys = primary_keys
+        self._push_down_predicate = push_down_predicate
 
-        self._fields = fields
+        self._fields = read_fields
         full_fields_map = {field.name: field for field in full_fields}
-        projected_data_fields = [full_fields_map[name] for name in fields]
+        projected_data_fields = [full_fields_map[name] for name in read_fields]
         self._schema = PyarrowFieldParser.from_paimon_schema(projected_data_fields)
 
-        if primary_keys:
-            # TODO: utilize predicate to improve performance
-            predicate = None
-        if predicate is not None:
-            self._predicate = predicate.to_arrow()
-        else:
-            self._predicate = None
-
     def read_arrow_batch(self) -> Optional[RecordBatch]:
         pydict_data = {name: [] for name in self._fields}
         records_in_batch = 0
@@ -68,12 +60,12 @@ class FormatAvroReader(RecordBatchReader):
         if records_in_batch == 0:
             return None
 
-        if self._predicate is None:
+        if self._push_down_predicate is None:
             return pa.RecordBatch.from_pydict(pydict_data, self._schema)
         else:
             pa_batch = pa.Table.from_pydict(pydict_data, self._schema)
             dataset = ds.InMemoryDataset(pa_batch)
-            scanner = dataset.scanner(filter=self._predicate)
+            scanner = dataset.scanner(filter=self._push_down_predicate)
             combine_chunks = scanner.to_table().combine_chunks()
             if combine_chunks.num_rows > 0:
                 return combine_chunks.to_batches()[0]
diff --git a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
index b01f8113b7..ecef589391 100644
--- a/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
+++ b/paimon-python/pypaimon/read/reader/format_pyarrow_reader.py
@@ -18,11 +18,11 @@
 
 from typing import List, Optional
 
+import pyarrow.compute as pc
 import pyarrow.dataset as ds
 from pyarrow import RecordBatch
 
 from pypaimon.common.file_io import FileIO
-from pypaimon.common.predicate import Predicate
 from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
 
 
@@ -32,19 +32,12 @@ class FormatPyArrowReader(RecordBatchReader):
     and filters it based on the provided predicate and projection.
     """
 
-    def __init__(self, file_io: FileIO, file_format: str, file_path: str, primary_keys: List[str],
-                 fields: List[str], predicate: Predicate, batch_size: int = 4096):
-
-        if primary_keys:
-            # TODO: utilize predicate to improve performance
-            predicate = None
-        if predicate is not None:
-            predicate = predicate.to_arrow()
-
+    def __init__(self, file_io: FileIO, file_format: str, file_path: str, read_fields: List[str],
+                 push_down_predicate: pc.Expression | bool, batch_size: int = 4096):
         self.dataset = ds.dataset(file_path, format=file_format, filesystem=file_io.filesystem)
         self.reader = self.dataset.scanner(
-            columns=fields,
-            filter=predicate,
+            columns=read_fields,
+            filter=push_down_predicate,
             batch_size=batch_size
         ).to_reader()
 
@@ -58,36 +51,3 @@ class FormatPyArrowReader(RecordBatchReader):
         if self.reader is not None:
             self.reader.close()
             self.reader = None
-
-
-def _filter_predicate_by_primary_keys(predicate: Predicate, primary_keys):
-    """
-    Filter out predicates that are not related to primary key fields.
-    """
-    if predicate is None or primary_keys is None:
-        return predicate
-
-    if predicate.method in ['and', 'or']:
-        filtered_literals = []
-        for literal in predicate.literals:
-            filtered = _filter_predicate_by_primary_keys(literal, primary_keys)
-            if filtered is not None:
-                filtered_literals.append(filtered)
-
-        if not filtered_literals:
-            return None
-
-        if len(filtered_literals) == 1:
-            return filtered_literals[0]
-
-        return Predicate(
-            method=predicate.method,
-            index=predicate.index,
-            field=predicate.field,
-            literals=filtered_literals
-        )
-
-    if predicate.field in primary_keys:
-        return predicate
-    else:
-        return None
diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py
index 99f8a4da21..1fe0a89d0e 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -49,11 +49,13 @@ NULL_FIELD_INDEX = -1
 
 class SplitRead(ABC):
     """Abstract base class for split reading operations."""
 
-    def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], split: Split):
+    def __init__(self, table, predicate: Optional[Predicate], push_down_predicate,
+                 read_type: List[DataField], split: Split):
         from pypaimon.table.file_store_table import FileStoreTable
 
         self.table: FileStoreTable = table
         self.predicate = predicate
+        self.push_down_predicate = push_down_predicate
         self.split = split
         self.value_arity = len(read_type)
 
@@ -72,11 +74,11 @@ class SplitRead(ABC):
         format_reader: RecordBatchReader
         if file_format == "avro":
-            format_reader = FormatAvroReader(self.table.file_io, file_path, self.table.primary_keys,
-                                             self._get_final_read_data_fields(), self.read_fields, self.predicate)
+            format_reader = FormatAvroReader(self.table.file_io, file_path, self._get_final_read_data_fields(),
+                                             self.read_fields, self.push_down_predicate)
         elif file_format == "parquet" or file_format == "orc":
-            format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path, self.table.primary_keys,
-                                                self._get_final_read_data_fields(), self.predicate)
+            format_reader = FormatPyArrowReader(self.table.file_io, file_format, file_path,
+                                                self._get_final_read_data_fields(), self.push_down_predicate)
         else:
             raise ValueError(f"Unexpected file format: {file_format}")
 
diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py
index b8a28c19d1..621549d832 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -19,8 +19,11 @@ from typing import Iterator, List, Optional
 
 import pandas
 import pyarrow
+import pyarrow.compute as pc
 
 from pypaimon.common.predicate import Predicate
+from pypaimon.common.predicate_builder import PredicateBuilder
+from pypaimon.read.push_down_utils import extract_predicate_to_list
 from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader
 from pypaimon.read.split import Split
 from pypaimon.read.split_read import (MergeFileSplitRead, RawFileSplitRead,
@@ -37,6 +40,7 @@ class TableRead:
 
         self.table: FileStoreTable = table
         self.predicate = predicate
+        self.push_down_predicate = self._push_down_predicate()
         self.read_type = read_type
 
     def to_iterator(self, splits: List[Split]) -> Iterator:
@@ -78,12 +82,12 @@ class TableRead:
                     row_tuple_chunk.append(row.row_tuple[row.offset: row.offset + row.arity])
 
                     if len(row_tuple_chunk) >= chunk_size:
-                        batch = convert_rows_to_arrow_batch(row_tuple_chunk, schema)
+                        batch = self.convert_rows_to_arrow_batch(row_tuple_chunk, schema)
                         yield batch
                         row_tuple_chunk = []
 
                 if row_tuple_chunk:
-                    batch = convert_rows_to_arrow_batch(row_tuple_chunk, schema)
+                    batch = self.convert_rows_to_arrow_batch(row_tuple_chunk, schema)
                     yield batch
             finally:
                 reader.close()
@@ -105,11 +109,27 @@ class TableRead:
 
         return ray.data.from_arrow(self.to_arrow(splits))
 
+    def _push_down_predicate(self) -> pc.Expression | bool:
+        if self.predicate is None:
+            return None
+        elif self.table.is_primary_key_table:
+            result = []
+            extract_predicate_to_list(result, self.predicate, self.table.primary_keys)
+            if result:
+                # the field index is unused for arrow field
+                pk_predicates = (PredicateBuilder(self.table.fields).and_predicates(result)).to_arrow()
+                return pk_predicates
+            else:
+                return None
+        else:
+            return self.predicate.to_arrow()
+
     def _create_split_read(self, split: Split) -> SplitRead:
         if self.table.is_primary_key_table and not split.raw_convertible:
             return MergeFileSplitRead(
                 table=self.table,
                 predicate=self.predicate,
+                push_down_predicate=self.push_down_predicate,
                 read_type=self.read_type,
                 split=split
             )
@@ -117,12 +137,13 @@ class TableRead:
             return RawFileSplitRead(
                 table=self.table,
                 predicate=self.predicate,
+                push_down_predicate=self.push_down_predicate,
                 read_type=self.read_type,
                 split=split
             )
-
-
-def convert_rows_to_arrow_batch(row_tuples: List[tuple], schema: pyarrow.Schema) -> pyarrow.RecordBatch:
-    columns_data = zip(*row_tuples)
-    pydict = {name: list(column) for name, column in zip(schema.names, columns_data)}
-    return pyarrow.RecordBatch.from_pydict(pydict, schema=schema)
+
+    @staticmethod
+    def convert_rows_to_arrow_batch(row_tuples: List[tuple], schema: pyarrow.Schema) -> pyarrow.RecordBatch:
+        columns_data = zip(*row_tuples)
+        pydict = {name: list(column) for name, column in zip(schema.names, columns_data)}
+        return pyarrow.RecordBatch.from_pydict(pydict, schema=schema)
diff --git a/paimon-python/pypaimon/read/table_scan.py b/paimon-python/pypaimon/read/table_scan.py
index d89ddd0bcb..1c2c4f33dc 100644
--- a/paimon-python/pypaimon/read/table_scan.py
+++ b/paimon-python/pypaimon/read/table_scan.py
@@ -20,12 +20,15 @@ from collections import defaultdict
 from typing import Callable, List, Optional
 
 from pypaimon.common.predicate import Predicate
+from pypaimon.common.predicate_builder import PredicateBuilder
 from pypaimon.manifest.manifest_file_manager import ManifestFileManager
 from pypaimon.manifest.manifest_list_manager import ManifestListManager
 from pypaimon.manifest.schema.data_file_meta import DataFileMeta
 from pypaimon.manifest.schema.manifest_entry import ManifestEntry
 from pypaimon.read.interval_partition import IntervalPartition, SortedRun
 from pypaimon.read.plan import Plan
+from pypaimon.read.push_down_utils import (extract_predicate_to_dict,
+                                           extract_predicate_to_list)
 from pypaimon.read.split import Split
 from pypaimon.schema.data_types import DataField
 from pypaimon.snapshot.snapshot_manager import SnapshotManager
@@ -49,15 +52,23 @@ class TableScan:
         self.manifest_list_manager = ManifestListManager(table)
         self.manifest_file_manager = ManifestFileManager(table)
 
-        self.partition_conditions = self._extract_partition_conditions()
+        pk_conditions = []
+        trimmed_pk = [field.name for field in self.table.table_schema.get_trimmed_primary_key_fields()]
+        extract_predicate_to_list(pk_conditions, self.predicate, trimmed_pk)
+        self.primary_key_predicate = PredicateBuilder(self.table.fields).and_predicates(pk_conditions)
+
+        partition_conditions = defaultdict(list)
+        extract_predicate_to_dict(partition_conditions, self.predicate, self.table.partition_keys)
+        self.partition_key_predicate = partition_conditions
+
         self.target_split_size = 128 * 1024 * 1024
         self.open_file_cost = 4 * 1024 * 1024
         self.idx_of_this_subtask = None
         self.number_of_para_subtasks = None
-        self.only_read_real_buckets = True if self.table.options.get('bucket',
-                                                                     -1) == BucketMode.POSTPONE_BUCKET.value else False
+        self.only_read_real_buckets = True \
+            if (self.table.options.get('bucket', -1) == BucketMode.POSTPONE_BUCKET.value) else False
 
     def plan(self) -> Plan:
         latest_snapshot = self.snapshot_manager.get_latest_snapshot()
@@ -129,7 +140,7 @@ class TableScan:
         filtered_files = []
         for file_entry in file_entries:
-            if self.partition_conditions and not self._filter_by_partition(file_entry):
+            if self.partition_key_predicate and not self._filter_by_partition(file_entry):
                 continue
             if not self._filter_by_stats(file_entry):
                 continue
@@ -138,98 +149,31 @@ class TableScan:
         return filtered_files
 
     def _filter_by_partition(self, file_entry: ManifestEntry) -> bool:
-        # TODO: refactor with a better solution
         partition_dict = file_entry.partition.to_dict()
-        for field_name, condition in self.partition_conditions.items():
+        for field_name, conditions in self.partition_key_predicate.items():
             partition_value = partition_dict[field_name]
-            if condition['op'] == '=':
-                if str(partition_value) != str(condition['value']):
-                    return False
-            elif condition['op'] == 'in':
-                if str(partition_value) not in [str(v) for v in condition['values']]:
-                    return False
-            elif condition['op'] == 'notIn':
-                if str(partition_value) in [str(v) for v in condition['values']]:
-                    return False
-            elif condition['op'] == '>':
-                if partition_value <= condition['values']:
-                    return False
-            elif condition['op'] == '>=':
-                if partition_value < condition['values']:
-                    return False
-            elif condition['op'] == '<':
-                if partition_value >= condition['values']:
-                    return False
-            elif condition['op'] == '<=':
-                if partition_value > condition['values']:
+            for predicate in conditions:
+                if not predicate.test_by_value(partition_value):
                     return False
         return True
 
     def _filter_by_stats(self, file_entry: ManifestEntry) -> bool:
-        # TODO: real support for filtering by stat
-        return True
-
-    def _extract_partition_conditions(self) -> dict:
-        if not self.predicate or not self.table.partition_keys:
-            return {}
-
-        conditions = {}
-        self._extract_conditions_from_predicate(self.predicate, conditions, self.table.partition_keys)
-        return conditions
-
-    def _extract_conditions_from_predicate(self, predicate: 'Predicate', conditions: dict,
-                                           partition_keys: List[str]):
-        if predicate.method == 'and':
-            for sub_predicate in predicate.literals:
-                self._extract_conditions_from_predicate(sub_predicate, conditions, partition_keys)
-            return
-        elif predicate.method == 'or':
-            all_partition_conditions = True
-            for sub_predicate in predicate.literals:
-                if sub_predicate.field not in partition_keys:
-                    all_partition_conditions = False
-                    break
-            if all_partition_conditions:
-                for sub_predicate in predicate.literals:
-                    self._extract_conditions_from_predicate(sub_predicate, conditions, partition_keys)
-            return
-
-        if predicate.field in partition_keys:
-            if predicate.method == 'equal':
-                conditions[predicate.field] = {
-                    'op': '=',
-                    'value': predicate.literals[0] if predicate.literals else None
-                }
-            elif predicate.method == 'in':
-                conditions[predicate.field] = {
-                    'op': 'in',
-                    'values': predicate.literals if predicate.literals else []
-                }
-            elif predicate.method == 'notIn':
-                conditions[predicate.field] = {
-                    'op': 'notIn',
-                    'values': predicate.literals if predicate.literals else []
-                }
-            elif predicate.method == 'greaterThan':
-                conditions[predicate.field] = {
-                    'op': '>',
-                    'value': predicate.literals[0] if predicate.literals else None
-                }
-            elif predicate.method == 'greaterOrEqual':
-                conditions[predicate.field] = {
-                    'op': '>=',
-                    'value': predicate.literals[0] if predicate.literals else None
-                }
-            elif predicate.method == 'lessThan':
-                conditions[predicate.field] = {
-                    'op': '<',
-                    'value': predicate.literals[0] if predicate.literals else None
-                }
-            elif predicate.method == 'lessOrEqual':
-                conditions[predicate.field] = {
-                    'op': '<=',
-                    'value': predicate.literals[0] if predicate.literals else None
-                }
+        if file_entry.kind != 0:
+            return False
+        if self.table.is_primary_key_table:
+            predicate = self.primary_key_predicate
+            stats = file_entry.file.key_stats
+        else:
+            predicate = self.predicate
+            stats = file_entry.file.value_stats
+        return predicate.test_by_stats({
+            "min_values": stats.min_values.to_dict(),
+            "max_values": stats.max_values.to_dict(),
+            "null_counts": {
+                stats.min_values.fields[i].name: stats.null_counts[i] for i in range(len(stats.min_values.fields))
+            },
+            "row_count": file_entry.file.row_count,
+        })
 
     def _create_append_only_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
         if not file_entries:
@@ -240,7 +184,7 @@ class TableScan:
         def weight_func(f: DataFileMeta) -> int:
             return max(f.file_size, self.open_file_cost)
 
-        packed_files: List[List[DataFileMeta]] = _pack_for_ordered(data_files, weight_func, self.target_split_size)
+        packed_files: List[List[DataFileMeta]] = self._pack_for_ordered(data_files, weight_func, self.target_split_size)
         return self._build_split_from_pack(packed_files, file_entries, False)
 
     def _create_primary_key_splits(self, file_entries: List[ManifestEntry]) -> List['Split']:
@@ -257,7 +201,8 @@ class TableScan:
         def weight_func(fl: List[DataFileMeta]) -> int:
             return max(sum(f.file_size for f in fl), self.open_file_cost)
 
-        packed_files: List[List[List[DataFileMeta]]] = _pack_for_ordered(sections, weight_func, self.target_split_size)
+        packed_files: List[List[List[DataFileMeta]]] = self._pack_for_ordered(sections, weight_func,
+                                                                              self.target_split_size)
         flatten_packed_files: List[List[DataFileMeta]] = [
             [file for sub_pack in pack for file in sub_pack]
             for pack in packed_files
@@ -295,23 +240,23 @@
             splits.append(split)
         return splits
 
+    @staticmethod
+    def _pack_for_ordered(items: List, weight_func: Callable, target_weight: int) -> List[List]:
+        packed = []
+        bin_items = []
+        bin_weight = 0
 
-def _pack_for_ordered(items: List, weight_func: Callable, target_weight: int) -> List[List]:
-    packed = []
-    bin_items = []
-    bin_weight = 0
+        for item in items:
+            weight = weight_func(item)
+            if bin_weight + weight > target_weight and len(bin_items) > 0:
+                packed.append(bin_items)
+                bin_items.clear()
+                bin_weight = 0
 
-    for item in items:
-        weight = weight_func(item)
-        if bin_weight + weight > target_weight and len(bin_items) > 0:
-            packed.append(bin_items)
-            bin_items.clear()
-            bin_weight = 0
-
-        bin_weight += weight
-        bin_items.append(item)
+            bin_weight += weight
+            bin_items.append(item)
 
-    if len(bin_items) > 0:
-        packed.append(bin_items)
+        if len(bin_items) > 0:
+            packed.append(bin_items)
 
-    return packed
+        return packed
diff --git a/paimon-python/pypaimon/tests/predicate_push_down_test.py b/paimon-python/pypaimon/tests/predicate_push_down_test.py
new file mode 100644
index 0000000000..b5b403f674
--- /dev/null
+++ b/paimon-python/pypaimon/tests/predicate_push_down_test.py
@@ -0,0 +1,151 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon.catalog.catalog_factory import CatalogFactory
+from pypaimon.common.predicate_builder import PredicateBuilder
+from pypaimon.schema.schema import Schema
+
+
+class PredicatePushDownTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog = CatalogFactory.create({
+            'warehouse': cls.warehouse
+        })
+        cls.catalog.create_database('default', False)
+
+        cls.pa_schema = pa.schema([
+            pa.field('key1', pa.int32(), nullable=False),
+            pa.field('key2', pa.string(), nullable=False),
+            ('behavior', pa.string()),
+            pa.field('dt1', pa.string(), nullable=False),
+            pa.field('dt2', pa.int32(), nullable=False)
+        ])
+        cls.expected = pa.Table.from_pydict({
+            'key1': [1, 2, 3, 4, 5, 7, 8],
+            'key2': ['h', 'g', 'f', 'e', 'd', 'b', 'a'],
+            'behavior': ['a', 'b-new', 'c', None, 'e', 'g', 'h'],
+            'dt1': ['p1', 'p1', 'p2', 'p1', 'p2', 'p1', 'p2'],
+            'dt2': [2, 2, 1, 2, 2, 1, 2],
+        }, schema=cls.pa_schema)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def testPkReaderWithFilter(self):
+        schema = Schema.from_pyarrow_schema(self.pa_schema,
+                                            partition_keys=['dt1', 'dt2'],
+                                            primary_keys=['key1', 'key2'],
+                                            options={'bucket': '1'})
+        self.catalog.create_table('default.test_pk_filter', schema, False)
+        table = self.catalog.get_table('default.test_pk_filter')
+
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'key1': [1, 2, 3, 4],
+            'key2': ['h', 'g', 'f', 'e'],
+            'behavior': ['a', 'b', 'c', None],
+            'dt1': ['p1', 'p1', 'p2', 'p1'],
+            'dt2': [2, 2, 1, 2],
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+        data1 = {
+            'key1': [5, 2, 7, 8],
+            'key2': ['d', 'g', 'b', 'a'],
+            'behavior': ['e', 'b-new', 'g', 'h'],
+            'dt1': ['p2', 'p1', 'p1', 'p2'],
+            'dt2': [2, 2, 1, 2]
+        }
+        pa_table = pa.Table.from_pydict(data1, schema=self.pa_schema)
+        table_write.write_arrow(pa_table)
+        table_commit.commit(table_write.prepare_commit())
+        table_write.close()
+        table_commit.close()
+
+        # test filter by partition
+        predicate_builder: PredicateBuilder = table.new_read_builder().new_predicate_builder()
+        p1 = predicate_builder.startswith('dt1', "p1")
+        p2 = predicate_builder.is_in('dt1', ["p2"])
+        p3 = predicate_builder.or_predicates([p1, p2])
+        p4 = predicate_builder.equal('dt2', 2)
+        g1 = predicate_builder.and_predicates([p3, p4])
+        # (dt1 startswith 'p1' or dt1 is_in ["p2"]) and dt2 == 2
+        read_builder = table.new_read_builder().with_filter(g1)
+        splits = read_builder.new_scan().plan().splits()
+        self.assertEqual(len(splits), 2)
+        self.assertEqual(splits[0].partition.to_dict()["dt2"], 2)
+        self.assertEqual(splits[1].partition.to_dict()["dt2"], 2)
+
+        # test filter by stats
+        predicate_builder: PredicateBuilder = table.new_read_builder().new_predicate_builder()
+        p1 = predicate_builder.equal('key1', 7)
+        p2 = predicate_builder.is_in('key2', ["e", "f"])
+        p3 = predicate_builder.or_predicates([p1, p2])
+        p4 = predicate_builder.greater_than('key1', 3)
+        g1 = predicate_builder.and_predicates([p3, p4])
+        # (key1 == 7 or key2 is_in ["e", "f"]) and key1 > 3
+        read_builder = table.new_read_builder().with_filter(g1)
+        splits = read_builder.new_scan().plan().splits()
+        # initial splits meta:
+        # p1, 2 -> 2g, 2g; 1e, 4h
+        # p2, 1 -> 3f, 3f
+        # p2, 2 -> 5a, 8d
+        # p1, 1 -> 7b, 7b
+        self.assertEqual(len(splits), 3)
+        # expect to filter out `p1, 2 -> 2g, 2g` and `p2, 1 -> 3f, 3f`
+        count = 0
+        for split in splits:
+            if split.partition.values == ["p1", 2]:
+                count += 1
+                self.assertEqual(len(split.files), 1)
+                min_values = split.files[0].value_stats.min_values.to_dict()
+                max_values = split.files[0].value_stats.max_values.to_dict()
+                self.assertTrue(min_values["key1"] == 1 and min_values["key2"] == "e"
+                                and max_values["key1"] == 4 and max_values["key2"] == "h")
+            elif split.partition.values == ["p2", 2]:
+                count += 1
+                min_values = split.files[0].value_stats.min_values.to_dict()
+                max_values = split.files[0].value_stats.max_values.to_dict()
+                self.assertTrue(min_values["key1"] == 5 and min_values["key2"] == "a"
+                                and max_values["key1"] == 8 and max_values["key2"] == "d")
+            elif split.partition.values == ["p1", 1]:
+                count += 1
+                min_values = split.files[0].value_stats.min_values.to_dict()
+                max_values = split.files[0].value_stats.max_values.to_dict()
+                self.assertTrue(min_values["key1"] == max_values["key1"] == 7
+                                and max_values["key2"] == max_values["key2"] == "b")
+        self.assertEqual(count, 3)
diff --git a/paimon-python/pypaimon/write/file_store_commit.py b/paimon-python/pypaimon/write/file_store_commit.py
index e7bf7ba534..efe8207ebe 100644
--- a/paimon-python/pypaimon/write/file_store_commit.py
+++ b/paimon-python/pypaimon/write/file_store_commit.py
@@ -81,15 +81,15 @@ class FileStoreCommit:
             num_added_files=sum(len(msg.new_files) for msg in commit_messages),
             num_deleted_files=0,
             partition_stats=SimpleStats(
-                min_value=BinaryRow(
+                min_values=BinaryRow(
                     values=partition_min_stats,
                     fields=self.table.table_schema.get_partition_key_fields(),
                 ),
-                max_value=BinaryRow(
+                max_values=BinaryRow(
                     values=partition_max_stats,
                     fields=self.table.table_schema.get_partition_key_fields(),
                 ),
-                null_count=partition_null_counts,
+                null_counts=partition_null_counts,
             ),
             schema_id=self.table.table_schema.id,
         )
diff --git a/paimon-python/pypaimon/write/writer/data_writer.py b/paimon-python/pypaimon/write/writer/data_writer.py
index 5d9641718c..bc4553847c 100644
--- a/paimon-python/pypaimon/write/writer/data_writer.py
+++ b/paimon-python/pypaimon/write/writer/data_writer.py
@@ -19,7 +19,7 @@ import uuid
 from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import pyarrow as pa
 import pyarrow.compute as pc
@@ -128,13 +128,13 @@ class DataWriter(ABC):
             for field in self.table.table_schema.fields
         }
         all_fields = self.table.table_schema.fields
-        min_value_stats = [column_stats[field.name]['min_value'] for field in all_fields]
-        max_value_stats = [column_stats[field.name]['max_value'] for field in all_fields]
-        value_null_counts = [column_stats[field.name]['null_count'] for field in all_fields]
+        min_value_stats = [column_stats[field.name]['min_values'] for field in all_fields]
+        max_value_stats = [column_stats[field.name]['max_values'] for field in all_fields]
+        value_null_counts = [column_stats[field.name]['null_counts'] for field in all_fields]
         key_fields = self.trimmed_primary_key_fields
-        min_key_stats = [column_stats[field.name]['min_value'] for field in key_fields]
-        max_key_stats = [column_stats[field.name]['max_value'] for field in key_fields]
-        key_null_counts = [column_stats[field.name]['null_count'] for field in key_fields]
+        min_key_stats = [column_stats[field.name]['min_values'] for field in key_fields]
+        max_key_stats = [column_stats[field.name]['max_values'] for field in key_fields]
+        key_null_counts = [column_stats[field.name]['null_counts'] for field in key_fields]
 
         if not all(count == 0 for count in key_null_counts):
             raise RuntimeError("Primary key should not be null")
@@ -203,21 +203,21 @@ class DataWriter(ABC):
         return best_split
 
     @staticmethod
-    def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> dict:
+    def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> Dict:
         column_array = record_batch.column(column_name)
         if column_array.null_count == len(column_array):
             return {
-                "min_value": None,
-                "max_value": None,
-                "null_count": column_array.null_count,
+                "min_values": None,
+                "max_values": None,
+                "null_counts": column_array.null_count,
            }
-        min_value = pc.min(column_array).as_py()
-        max_value = pc.max(column_array).as_py()
-        null_count = column_array.null_count
+        min_values = pc.min(column_array).as_py()
+        max_values = pc.max(column_array).as_py()
+        null_counts = column_array.null_count
         return {
-            "min_value": min_value,
-            "max_value": max_value,
-            "null_count": null_count,
+            "min_values": min_values,
+            "max_values": max_values,
+            "null_counts": null_counts,
        }
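
For readers skimming the change, a minimal usage sketch of the new push-down path follows. It is not part of the commit; the table name, column names and warehouse path are hypothetical, and the calls mirror the API exercised by the new pypaimon/tests/predicate_push_down_test.py:

    from pypaimon.catalog.catalog_factory import CatalogFactory

    # Hypothetical warehouse and table; assumes a primary-key table partitioned
    # by 'dt' with primary key 'order_id' was created beforehand.
    catalog = CatalogFactory.create({'warehouse': '/tmp/warehouse'})
    table = catalog.get_table('default.orders')

    read_builder = table.new_read_builder()
    builder = read_builder.new_predicate_builder()
    # dt == '2025-08-29' and order_id > 100
    predicate = builder.and_predicates([
        builder.equal('dt', '2025-08-29'),
        builder.greater_than('order_id', 100),
    ])
    read_builder = read_builder.with_filter(predicate)

    # Planning now prunes splits using partition values (Predicate.test_by_value)
    # and file stats (Predicate.test_by_stats) ...
    splits = read_builder.new_scan().plan().splits()
    # ... and TableRead converts the remaining filter to a pyarrow.compute
    # expression that is pushed into the Parquet/ORC/Avro readers.
    result = read_builder.new_read().to_arrow(splits)

For primary-key tables, only the primary-key part of the filter is pushed into the file readers (see TableRead._push_down_predicate above), since non-key rows may still be merged away or updated later in the merge read path.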