This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 438f643 feat: Fix perf issue in `Map.get` (#341)
438f643 is described below
commit 438f6439148b059d424ce2cc2a348736923f6948
Author: Junru Shao <[email protected]>
AuthorDate: Fri Dec 12 13:08:38 2025 -0800
feat: Fix perf issue in `Map.get` (#341)
Should fix #326
---
python/tvm_ffi/_ffi_api.py | 6 +++++-
python/tvm_ffi/container.py | 17 +++++++++--------
src/ffi/container.cc | 18 ++++++++++++++++--
3 files changed, 30 insertions(+), 11 deletions(-)
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 4b49716..f9850a5 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -23,7 +23,7 @@ from __future__ import annotations
from typing import Any, Callable, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
- from tvm_ffi import Module
+ from tvm_ffi import Module, Object
from tvm_ffi.access_path import AccessPath
# isort: on
# fmt: on
@@ -50,6 +50,8 @@ if TYPE_CHECKING:
def MapCount(_0: Mapping[Any, Any], _1: Any, /) -> int: ...
def MapForwardIterFunctor(_0: Mapping[Any, Any], /) -> Callable[..., Any]:
...
def MapGetItem(_0: Mapping[Any, Any], _1: Any, /) -> Any: ...
+ def MapGetItemOrMissing(_0: Mapping[Any, Any], _1: Any, /) -> Any: ...
+ def MapGetMissingObject() -> Object: ...
def MapSize(_0: Mapping[Any, Any], /) -> int: ...
def ModuleClearImports(_0: Module, /) -> None: ...
def ModuleGetFunction(_0: Module, _1: str, _2: bool, /) -> Callable[...,
Any] | None: ...
@@ -95,6 +97,8 @@ __all__ = [
"MapCount",
"MapForwardIterFunctor",
"MapGetItem",
+ "MapGetItemOrMissing",
+ "MapGetMissingObject",
"MapSize",
"ModuleClearImports",
"ModuleGetFunction",
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 06fb92e..dfa0d22 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -77,6 +77,8 @@ K = TypeVar("K")
V = TypeVar("V")
_DefaultT = TypeVar("_DefaultT")
+MISSING = _ffi_api.MapGetMissingObject()
+
def getitem_helper(
obj: Any,
@@ -254,12 +256,11 @@ class ItemsView(ItemsViewBase[K, V]):
if not isinstance(item, tuple) or len(item) != 2:
return False
key, value = item
- try:
- existing_value = self._backend_map[key]
- except KeyError:
+ actual_value = self._backend_map.get(key, MISSING)
+ if actual_value is MISSING:
return False
- else:
- return existing_value == value
+ # TODO(@junrus): Is `__eq__` the right method to use here?
+ return actual_value == value
@register_object("ffi.Map")
@@ -349,10 +350,10 @@ class Map(core.Object, Mapping[K, V]):
The result value.
"""
- try:
- return self[key]
- except KeyError:
+ ret = _ffi_api.MapGetItemOrMissing(self, key)
+ if MISSING.same_as(ret):
return default
+ return ret
def __repr__(self) -> str:
"""Return a string representation of the map."""
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index b777dc0..57eda37 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -55,6 +55,11 @@ class MapForwardIterFunctor {
ffi::MapObj::iterator end_;
};
+ObjectRef GetMissingObject() {
+ static ObjectRef missing_obj(make_object<Object>());
+ return missing_obj;
+}
+
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
@@ -81,8 +86,17 @@ TVM_FFI_STATIC_INIT_BLOCK() {
[](const ffi::MapObj* n, const Any& k) -> int64_t {
return static_cast<int64_t>(n->count(k));
})
- .def("ffi.MapForwardIterFunctor", [](const ffi::MapObj* n) ->
ffi::Function {
- return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(),
n->end()));
+ .def("ffi.MapForwardIterFunctor",
+ [](const ffi::MapObj* n) -> ffi::Function {
+ return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(),
n->end()));
+ })
+ .def("ffi.MapGetMissingObject", GetMissingObject)
+ .def("ffi.MapGetItemOrMissing", [](const ffi::MapObj* n, const Any& k)
-> Any {
+ try {
+ return n->at(k);
+ } catch (const tvm::ffi::Error& e) {
+ return GetMissingObject();
+ }
});
}
} // namespace ffi