This is an automated email from the ASF dual-hosted git repository. lzljs3620320 pushed a commit to branch release-1.3 in repository https://gitbox.apache.org/repos/asf/paimon.git
commit 11360cf31cca5df32c2e27540909ef8ff353a4ba Author: Jingsong Lee <[email protected]> AuthorDate: Mon Oct 20 23:56:04 2025 +0800 [python] Refactor dicts to static fields to improve performance (#6436) --- paimon-python/pypaimon/common/predicate.py | 298 +++++++++++++++++++++++------ 1 file changed, 242 insertions(+), 56 deletions(-) diff --git a/paimon-python/pypaimon/common/predicate.py b/paimon-python/pypaimon/common/predicate.py index a245bb8e1e..6a760e473f 100644 --- a/paimon-python/pypaimon/common/predicate.py +++ b/paimon-python/pypaimon/common/predicate.py @@ -16,9 +16,11 @@ # limitations under the License. ################################################################################ +from abc import ABC, ABCMeta, abstractmethod from dataclasses import dataclass from functools import reduce from typing import Any, Dict, List, Optional +from typing import ClassVar import pyarrow from pyarrow import compute as pyarrow_compute @@ -35,6 +37,8 @@ class Predicate: field: Optional[str] literals: Optional[List[Any]] = None + testers: ClassVar[Dict[str, Any]] = {} + def new_index(self, index: int): return Predicate( method=self.method, @@ -56,26 +60,10 @@ class Predicate: t = any(p.test(record) for p in self.literals) return t - dispatch = { - 'equal': lambda val, literals: val == literals[0], - 'notEqual': lambda val, literals: val != literals[0], - 'lessThan': lambda val, literals: val < literals[0], - 'lessOrEqual': lambda val, literals: val <= literals[0], - 'greaterThan': lambda val, literals: val > literals[0], - 'greaterOrEqual': lambda val, literals: val >= literals[0], - 'isNull': lambda val, literals: val is None, - 'isNotNull': lambda val, literals: val is not None, - 'startsWith': lambda val, literals: isinstance(val, str) and val.startswith(literals[0]), - 'endsWith': lambda val, literals: isinstance(val, str) and val.endswith(literals[0]), - 'contains': lambda val, literals: isinstance(val, str) and literals[0] in val, - 'in': lambda val, literals: val in literals, - 'notIn': lambda val, literals: val not in literals, - 'between': lambda val, literals: literals[0] <= val <= literals[1], - } - func = dispatch.get(self.method) - if func: - field_value = record.get_field(self.index) - return func(field_value, self.literals) + field_value = record.get_field(self.index) + tester = Predicate.testers.get(self.method) + if tester: + return tester.test_by_value(field_value, self.literals) raise ValueError(f"Unsupported predicate method: {self.method}") def test_by_simple_stats(self, stat: SimpleStats, row_count: int) -> bool: @@ -110,25 +98,9 @@ class Predicate: # invalid stats, skip validation return True - dispatch = { - 'equal': lambda literals: min_value <= literals[0] <= max_value, - 'notEqual': lambda literals: not (min_value == literals[0] == max_value), - 'lessThan': lambda literals: literals[0] > min_value, - 'lessOrEqual': lambda literals: literals[0] >= min_value, - 'greaterThan': lambda literals: literals[0] < max_value, - 'greaterOrEqual': lambda literals: literals[0] <= max_value, - 'in': lambda literals: any(min_value <= l <= max_value for l in literals), - 'notIn': lambda literals: not any(min_value == l == max_value for l in literals), - 'between': lambda literals: literals[0] <= max_value and literals[1] >= min_value, - 'startsWith': lambda literals: ((isinstance(min_value, str) and isinstance(max_value, str)) and - ((min_value.startswith(literals[0]) or min_value < literals[0]) and - (max_value.startswith(literals[0]) or max_value > literals[0]))), - 'endsWith': lambda literals: True, - 'contains': lambda literals: True, - } - func = dispatch.get(self.method) - if func: - return func(self.literals) + tester = Predicate.testers.get(self.method) + if tester: + return tester.test_by_stats(min_value, max_value, self.literals) raise ValueError(f"Unsupported predicate method: {self.method}") def to_arrow(self) -> Any: @@ -177,22 +149,236 @@ class Predicate: return pyarrow_dataset.field(self.field).is_valid() | pyarrow_dataset.field(self.field).is_null() field = pyarrow_dataset.field(self.field) - dispatch = { - 'equal': lambda literals: field == literals[0], - 'notEqual': lambda literals: field != literals[0], - 'lessThan': lambda literals: field < literals[0], - 'lessOrEqual': lambda literals: field <= literals[0], - 'greaterThan': lambda literals: field > literals[0], - 'greaterOrEqual': lambda literals: field >= literals[0], - 'isNull': lambda literals: field.is_null(), - 'isNotNull': lambda literals: field.is_valid(), - 'in': lambda literals: field.isin(literals), - 'notIn': lambda literals: ~field.isin(literals), - 'between': lambda literals: (field >= literals[0]) & (field <= literals[1]), - } - - func = dispatch.get(self.method) - if func: - return func(self.literals) + tester = Predicate.testers.get(self.method) + if tester: + return tester.test_by_arrow(field, self.literals) raise ValueError("Unsupported predicate method: {}".format(self.method)) + + +class RegisterMeta(ABCMeta): + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + if not bool(cls.__abstractmethods__): + Predicate.testers[cls.name] = cls() + + +class Tester(ABC, metaclass=RegisterMeta): + + name = None + + @abstractmethod + def test_by_value(self, val, literals) -> bool: + """ + Test based on the specific val and literals. + """ + + @abstractmethod + def test_by_stats(self, min_v, max_v, literals) -> bool: + """ + Test based on the specific min_value and max_value and literals. + """ + + @abstractmethod + def test_by_arrow(self, val, literals) -> bool: + """ + Test based on the specific arrow value and literals. + """ + + +class Equal(Tester): + + name = 'equal' + + def test_by_value(self, val, literals) -> bool: + return val == literals[0] + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return min_v <= literals[0] <= max_v + + def test_by_arrow(self, val, literals) -> bool: + return val == literals[0] + + +class NotEqual(Tester): + + name = "notEqual" + + def test_by_value(self, val, literals) -> bool: + return val != literals[0] + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return not (min_v == literals[0] == max_v) + + def test_by_arrow(self, val, literals) -> bool: + return val != literals[0] + + +class LessThan(Tester): + + name = "lessThan" + + def test_by_value(self, val, literals) -> bool: + return val < literals[0] + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return literals[0] > min_v + + def test_by_arrow(self, val, literals) -> bool: + return val < literals[0] + + +class LessOrEqual(Tester): + + name = "lessOrEqual" + + def test_by_value(self, val, literals) -> bool: + return val <= literals[0] + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return literals[0] >= min_v + + def test_by_arrow(self, val, literals) -> bool: + return val <= literals[0] + + +class GreaterThan(Tester): + + name = "greaterThan" + + def test_by_value(self, val, literals) -> bool: + return val > literals[0] + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return literals[0] < max_v + + def test_by_arrow(self, val, literals) -> bool: + return val > literals[0] + + +class GreaterOrEqual(Tester): + + name = "greaterOrEqual" + + def test_by_value(self, val, literals) -> bool: + return val >= literals[0] + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return literals[0] <= max_v + + def test_by_arrow(self, val, literals) -> bool: + return val >= literals[0] + + +class In(Tester): + + name = "in" + + def test_by_value(self, val, literals) -> bool: + return val in literals + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return any(min_v <= l <= max_v for l in literals) + + def test_by_arrow(self, val, literals) -> bool: + return val.isin(literals) + + +class NotIn(Tester): + + name = "notIn" + + def test_by_value(self, val, literals) -> bool: + return val not in literals + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return not any(min_v == l == max_v for l in literals) + + def test_by_arrow(self, val, literals) -> bool: + return ~val.isin(literals) + + +class Between(Tester): + + name = "between" + + def test_by_value(self, val, literals) -> bool: + return literals[0] <= val <= literals[1] + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return literals[0] <= max_v and literals[1] >= min_v + + def test_by_arrow(self, val, literals) -> bool: + return (val >= literals[0]) & (val <= literals[1]) + + +class StartsWith(Tester): + + name = "startsWith" + + def test_by_value(self, val, literals) -> bool: + return isinstance(val, str) and val.startswith(literals[0]) + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return ((isinstance(min_v, str) and isinstance(max_v, str)) and + ((min_v.startswith(literals[0]) or min_v < literals[0]) and + (max_v.startswith(literals[0]) or max_v > literals[0]))) + + def test_by_arrow(self, val, literals) -> bool: + return True + + +class EndsWith(Tester): + + name = "endsWith" + + def test_by_value(self, val, literals) -> bool: + return isinstance(val, str) and val.endswith(literals[0]) + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return True + + def test_by_arrow(self, val, literals) -> bool: + return True + + +class Contains(Tester): + + name = "contains" + + def test_by_value(self, val, literals) -> bool: + return isinstance(val, str) and literals[0] in val + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return True + + def test_by_arrow(self, val, literals) -> bool: + return True + + +class IsNull(Tester): + + name = "isNull" + + def test_by_value(self, val, literals) -> bool: + return val is None + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return True + + def test_by_arrow(self, val, literals) -> bool: + return val.is_null() + + +class IsNotNull(Tester): + + name = "isNotNull" + + def test_by_value(self, val, literals) -> bool: + return val is not None + + def test_by_stats(self, min_v, max_v, literals) -> bool: + return True + + def test_by_arrow(self, val, literals) -> bool: + return val.is_valid()
