This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 9bd9c41cb9 [python] Refactor dicts to static fields to improve
performance (#6436)
9bd9c41cb9 is described below
commit 9bd9c41cb9747bc203fc1bbc7913d75b7b9adca9
Author: Jingsong Lee <[email protected]>
AuthorDate: Mon Oct 20 23:56:04 2025 +0800
[python] Refactor dicts to static fields to improve performance (#6436)
---
paimon-python/pypaimon/common/predicate.py | 298 +++++++++++++++++++++++------
1 file changed, 242 insertions(+), 56 deletions(-)
diff --git a/paimon-python/pypaimon/common/predicate.py
b/paimon-python/pypaimon/common/predicate.py
index a245bb8e1e..6a760e473f 100644
--- a/paimon-python/pypaimon/common/predicate.py
+++ b/paimon-python/pypaimon/common/predicate.py
@@ -16,9 +16,11 @@
# limitations under the License.
################################################################################
+from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from functools import reduce
from typing import Any, Dict, List, Optional
+from typing import ClassVar
import pyarrow
from pyarrow import compute as pyarrow_compute
@@ -35,6 +37,8 @@ class Predicate:
field: Optional[str]
literals: Optional[List[Any]] = None
+ testers: ClassVar[Dict[str, Any]] = {}
+
def new_index(self, index: int):
return Predicate(
method=self.method,
@@ -56,26 +60,10 @@ class Predicate:
t = any(p.test(record) for p in self.literals)
return t
- dispatch = {
- 'equal': lambda val, literals: val == literals[0],
- 'notEqual': lambda val, literals: val != literals[0],
- 'lessThan': lambda val, literals: val < literals[0],
- 'lessOrEqual': lambda val, literals: val <= literals[0],
- 'greaterThan': lambda val, literals: val > literals[0],
- 'greaterOrEqual': lambda val, literals: val >= literals[0],
- 'isNull': lambda val, literals: val is None,
- 'isNotNull': lambda val, literals: val is not None,
- 'startsWith': lambda val, literals: isinstance(val, str) and
val.startswith(literals[0]),
- 'endsWith': lambda val, literals: isinstance(val, str) and
val.endswith(literals[0]),
- 'contains': lambda val, literals: isinstance(val, str) and
literals[0] in val,
- 'in': lambda val, literals: val in literals,
- 'notIn': lambda val, literals: val not in literals,
- 'between': lambda val, literals: literals[0] <= val <= literals[1],
- }
- func = dispatch.get(self.method)
- if func:
- field_value = record.get_field(self.index)
- return func(field_value, self.literals)
+ field_value = record.get_field(self.index)
+ tester = Predicate.testers.get(self.method)
+ if tester:
+ return tester.test_by_value(field_value, self.literals)
raise ValueError(f"Unsupported predicate method: {self.method}")
def test_by_simple_stats(self, stat: SimpleStats, row_count: int) -> bool:
@@ -110,25 +98,9 @@ class Predicate:
# invalid stats, skip validation
return True
- dispatch = {
- 'equal': lambda literals: min_value <= literals[0] <= max_value,
- 'notEqual': lambda literals: not (min_value == literals[0] ==
max_value),
- 'lessThan': lambda literals: literals[0] > min_value,
- 'lessOrEqual': lambda literals: literals[0] >= min_value,
- 'greaterThan': lambda literals: literals[0] < max_value,
- 'greaterOrEqual': lambda literals: literals[0] <= max_value,
- 'in': lambda literals: any(min_value <= l <= max_value for l in
literals),
- 'notIn': lambda literals: not any(min_value == l == max_value for
l in literals),
- 'between': lambda literals: literals[0] <= max_value and
literals[1] >= min_value,
- 'startsWith': lambda literals: ((isinstance(min_value, str) and
isinstance(max_value, str)) and
-
((min_value.startswith(literals[0]) or min_value < literals[0]) and
-
(max_value.startswith(literals[0]) or max_value > literals[0]))),
- 'endsWith': lambda literals: True,
- 'contains': lambda literals: True,
- }
- func = dispatch.get(self.method)
- if func:
- return func(self.literals)
+ tester = Predicate.testers.get(self.method)
+ if tester:
+ return tester.test_by_stats(min_value, max_value, self.literals)
raise ValueError(f"Unsupported predicate method: {self.method}")
def to_arrow(self) -> Any:
@@ -177,22 +149,236 @@ class Predicate:
return pyarrow_dataset.field(self.field).is_valid() |
pyarrow_dataset.field(self.field).is_null()
field = pyarrow_dataset.field(self.field)
- dispatch = {
- 'equal': lambda literals: field == literals[0],
- 'notEqual': lambda literals: field != literals[0],
- 'lessThan': lambda literals: field < literals[0],
- 'lessOrEqual': lambda literals: field <= literals[0],
- 'greaterThan': lambda literals: field > literals[0],
- 'greaterOrEqual': lambda literals: field >= literals[0],
- 'isNull': lambda literals: field.is_null(),
- 'isNotNull': lambda literals: field.is_valid(),
- 'in': lambda literals: field.isin(literals),
- 'notIn': lambda literals: ~field.isin(literals),
- 'between': lambda literals: (field >= literals[0]) & (field <=
literals[1]),
- }
-
- func = dispatch.get(self.method)
- if func:
- return func(self.literals)
+ tester = Predicate.testers.get(self.method)
+ if tester:
+ return tester.test_by_arrow(field, self.literals)
raise ValueError("Unsupported predicate method:
{}".format(self.method))
+
+
+class RegisterMeta(ABCMeta):
+    """
+    Metaclass that registers every concrete Tester subclass.
+
+    When a class with no remaining abstract methods is created, a single
+    instance of it is stored in ``Predicate.testers`` keyed by the class's
+    ``name`` attribute, so predicates look testers up by method string
+    instead of rebuilding dispatch dicts on every call.
+
+    Raises:
+        TypeError: if a concrete subclass does not define ``name``.
+    """
+
+    def __init__(cls, name, bases, dct):
+        super().__init__(name, bases, dct)
+        if cls.__abstractmethods__:
+            # Abstract bases (such as Tester itself) are never registered.
+            return
+        if not cls.name:
+            # Without this check the tester would silently be stored under
+            # the key None and could never be dispatched to.
+            raise TypeError(
+                f"Concrete tester {cls.__qualname__} must define 'name'")
+        Predicate.testers[cls.name] = cls()
+
+
+class Tester(ABC, metaclass=RegisterMeta):
+    """
+    Evaluates one predicate method (e.g. 'equal') in three modes.
+
+    Concrete subclasses set ``name`` to the predicate method string they
+    implement; RegisterMeta then registers one shared instance of each
+    concrete subclass in ``Predicate.testers`` under that name.
+    """
+
+    name: Optional[str] = None
+
+    @abstractmethod
+    def test_by_value(self, val, literals) -> bool:
+        """
+        Return whether a single field value ``val`` satisfies the
+        predicate against ``literals``.
+        """
+
+    @abstractmethod
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        """
+        Return whether some value within [min_v, max_v] may satisfy the
+        predicate; returning True means the data cannot be skipped.
+        """
+
+    @abstractmethod
+    def test_by_arrow(self, val, literals) -> Any:
+        """
+        Build a pyarrow dataset expression for field expression ``val``
+        and ``literals``; the result is an expression, not a bool.
+        """
+
+
+class Equal(Tester):
+    """Tester for the 'equal' predicate method."""
+
+    name = 'equal'
+
+    def test_by_value(self, val, literals) -> bool:
+        return val == literals[0]
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # The literal can only occur if it lies within [min_v, max_v].
+        return min_v <= literals[0] <= max_v
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val == literals[0]
+
+
+class NotEqual(Tester):
+    """Tester for the 'notEqual' predicate method."""
+
+    name = "notEqual"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val != literals[0]
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Only prunable when every value (min == max) equals the literal.
+        return not (min_v == literals[0] == max_v)
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val != literals[0]
+
+
+class LessThan(Tester):
+    """Tester for the 'lessThan' predicate method."""
+
+    name = "lessThan"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val < literals[0]
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Some value may be below the literal only if the minimum is.
+        return literals[0] > min_v
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val < literals[0]
+
+
+class LessOrEqual(Tester):
+    """Tester for the 'lessOrEqual' predicate method."""
+
+    name = "lessOrEqual"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val <= literals[0]
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Some value may be <= the literal only if the minimum is.
+        return literals[0] >= min_v
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val <= literals[0]
+
+
+class GreaterThan(Tester):
+    """Tester for the 'greaterThan' predicate method."""
+
+    name = "greaterThan"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val > literals[0]
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Some value may exceed the literal only if the maximum does.
+        return literals[0] < max_v
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val > literals[0]
+
+
+class GreaterOrEqual(Tester):
+    """Tester for the 'greaterOrEqual' predicate method."""
+
+    name = "greaterOrEqual"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val >= literals[0]
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Some value may be >= the literal only if the maximum is.
+        return literals[0] <= max_v
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val >= literals[0]
+
+
+class In(Tester):
+    """Tester for the 'in' predicate method (membership in literals)."""
+
+    name = "in"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val in literals
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # May match if at least one literal falls inside [min_v, max_v].
+        return any(min_v <= lit <= max_v for lit in literals)
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val.isin(literals)
+
+
+class NotIn(Tester):
+    """Tester for the 'notIn' predicate method (non-membership)."""
+
+    name = "notIn"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val not in literals
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Only prunable when every value (min == max) equals some literal.
+        return not any(min_v == lit == max_v for lit in literals)
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return ~val.isin(literals)
+
+
+class Between(Tester):
+    """Tester for the 'between' predicate method (inclusive range)."""
+
+    name = "between"
+
+    def test_by_value(self, val, literals) -> bool:
+        return literals[0] <= val <= literals[1]
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # May match if [literals[0], literals[1]] overlaps [min_v, max_v].
+        return literals[0] <= max_v and literals[1] >= min_v
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return (val >= literals[0]) & (val <= literals[1])
+
+
+class StartsWith(Tester):
+    """Tester for the 'startsWith' predicate method (string prefix)."""
+
+    name = "startsWith"
+
+    def test_by_value(self, val, literals) -> bool:
+        return isinstance(val, str) and val.startswith(literals[0])
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        prefix = literals[0]
+        # Non-string bounds return False, matching test_by_value, which
+        # never matches non-str values. Otherwise the range of strings
+        # beginning with ``prefix`` must overlap [min_v, max_v].
+        return ((isinstance(min_v, str) and isinstance(max_v, str)) and
+                ((min_v.startswith(prefix) or min_v < prefix) and
+                 (max_v.startswith(prefix) or max_v > prefix)))
+
+    def test_by_arrow(self, val, literals) -> Any:
+        # No prefix expression is built here. Return an always-true
+        # expression (every row is valid or null) instead of a Python
+        # bool, so the result is still a combinable dataset expression.
+        return val.is_valid() | val.is_null()
+
+
+class EndsWith(Tester):
+    """Tester for the 'endsWith' predicate method (string suffix)."""
+
+    name = "endsWith"
+
+    def test_by_value(self, val, literals) -> bool:
+        return isinstance(val, str) and val.endswith(literals[0])
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Min/max bounds cannot rule out a suffix match; never prune.
+        return True
+
+    def test_by_arrow(self, val, literals) -> Any:
+        # Always-true expression (every row is valid or null) rather than
+        # a Python bool, so the result remains a dataset expression.
+        return val.is_valid() | val.is_null()
+
+
+class Contains(Tester):
+    """Tester for the 'contains' predicate method (substring)."""
+
+    name = "contains"
+
+    def test_by_value(self, val, literals) -> bool:
+        return isinstance(val, str) and literals[0] in val
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Min/max bounds cannot rule out a substring match; never prune.
+        return True
+
+    def test_by_arrow(self, val, literals) -> Any:
+        # Always-true expression (every row is valid or null) rather than
+        # a Python bool, so the result remains a dataset expression.
+        return val.is_valid() | val.is_null()
+
+
+class IsNull(Tester):
+    """Tester for the 'isNull' predicate method."""
+
+    name = "isNull"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val is None
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Min/max bounds alone cannot rule out nulls; never prune here.
+        return True
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val.is_null()
+
+
+class IsNotNull(Tester):
+    """Tester for the 'isNotNull' predicate method."""
+
+    name = "isNotNull"
+
+    def test_by_value(self, val, literals) -> bool:
+        return val is not None
+
+    def test_by_stats(self, min_v, max_v, literals) -> bool:
+        # Min/max bounds alone cannot rule out non-null values; never prune.
+        return True
+
+    def test_by_arrow(self, val, literals) -> Any:
+        return val.is_valid()