This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 1eae362f62720aa98eb483ff7ff3174b852442bd Author: tqchen <[email protected]> AuthorDate: Tue Apr 29 20:21:20 2025 -0400 [FFI] Cython update map to conform to abc --- ffi/include/tvm/ffi/container/map.h | 4 +- ffi/src/ffi/container.cc | 37 +++++++++++--- python/tvm/ffi/container.py | 95 +++++++++++++++++++++++++++-------- tests/python/ffi/test_container.py | 33 +++++++----- tests/scripts/task_python_unittest.sh | 1 + 5 files changed, 128 insertions(+), 42 deletions(-) diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h index 49ec5dadbb..26eeee7925 100644 --- a/ffi/include/tvm/ffi/container/map.h +++ b/ffi/include/tvm/ffi/container/map.h @@ -241,7 +241,7 @@ class SmallMapObj : public MapObj, public details::InplaceArrayBase<SmallMapObj, const mapped_type& at(const key_type& key) const { iterator itr = find(key); if (itr.index >= size_) { - TVM_FFI_THROW(IndexError) << "key is not in Map"; + TVM_FFI_THROW(KeyError) << "key is not in Map"; } return itr->second; } @@ -253,7 +253,7 @@ class SmallMapObj : public MapObj, public details::InplaceArrayBase<SmallMapObj, mapped_type& at(const key_type& key) { iterator itr = find(key); if (itr.index >= size_) { - TVM_FFI_THROW(IndexError) << "key is not in Map"; + TVM_FFI_THROW(KeyError) << "key is not in Map"; } return itr->second; } diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc index e5ffaadd9f..885e8395ed 100644 --- a/ffi/src/ffi/container.cc +++ b/ffi/src/ffi/container.cc @@ -60,14 +60,35 @@ TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem") TVM_FFI_REGISTER_GLOBAL("ffi.MapCount") .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); }); -TVM_FFI_REGISTER_GLOBAL("ffi.MapItems").set_body_typed([](const ffi::MapObj* n) -> Array<Any> { - Array<Any> rkvs; - for (const auto& kv : *n) { - rkvs.push_back(kv.first); - rkvs.push_back(kv.second); - } - return rkvs; -}); +TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor") + .set_body_typed([](const 
ffi::MapObj* n) -> ffi::Function { + class MapForwardIterFunctor { + public: + MapForwardIterFunctor(ffi::MapObj::iterator iter, ffi::MapObj::iterator end) + : iter_(iter), end_(end) {} + // 0 get current key + // 1 get current value + // 2 move to next: return true if success, false if end + Any operator()(int command) const { + if (command == 0) { + return (*iter_).first; + } else if (command == 1) { + return (*iter_).second; + } else { + ++iter_; + if (iter_ == end_) { + return false; + } + return true; + } + } + + private: + mutable ffi::MapObj::iterator iter_; + ffi::MapObj::iterator end_; + }; + return ffi::Function::FromUnpacked(MapForwardIterFunctor(n->begin(), n->end())); + }); } // namespace ffi } // namespace tvm diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py index 829cf8cf23..77495a52ff 100644 --- a/python/tvm/ffi/container.py +++ b/python/tvm/ffi/container.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. """Container classes.""" -from typing import Any, List, Dict +import collections.abc + +from typing import Any, Mapping, Sequence from . import core from . 
import _ffi_api from .registry import register_object @@ -63,10 +65,10 @@ def getitem_helper(obj, elem_getter, length, idx): @register_object("object.Array") -class Array(core.Object): +class Array(core.Object, collections.abc.Sequence): """Array container""" - def __init__(self, input_list: List[Any]): + def __init__(self, input_list: Sequence[Any]): self.__init_handle_by_constructor__(_ffi_api.Array, *input_list) def __getitem__(self, idx): @@ -76,11 +78,72 @@ class Array(core.Object): return _ffi_api.ArraySize(self) +class KeysView(collections.abc.KeysView): + """Helper class to return keys view""" + + def __init__(self, backend_map): + self._backend_map = backend_map + + def __len__(self): + return len(self._backend_map) + + def __iter__(self): + if self.__len__() == 0: + return + functor = _ffi_api.MapForwardIterFunctor(self._backend_map) + while True: + k = functor(0) + yield k + if not functor(2): + break + + +class ValuesView(collections.abc.ValuesView): + """Helper class to return values view""" + + def __init__(self, backend_map): + self._backend_map = backend_map + + def __len__(self): + return len(self._backend_map) + + def __iter__(self): + if self.__len__() == 0: + return + functor = _ffi_api.MapForwardIterFunctor(self._backend_map) + while True: + v = functor(1) + yield v + if not functor(2): + break + + +class ItemsView(collections.abc.ItemsView): + """Helper class to return items view""" + + def __init__(self, backend_map): + self.backend_map = backend_map + + def __len__(self): + return len(self.backend_map) + + def __iter__(self): + if self.__len__() == 0: + return + functor = _ffi_api.MapForwardIterFunctor(self.backend_map) + while True: + k = functor(0) + v = functor(1) + yield (k, v) + if not functor(2): + break + + @register_object("object.Map") -class Map(core.Object): - """Map container""" +class Map(core.Object, collections.abc.Mapping): + """Map container.""" - def __init__(self, input_dict: Dict[Any, Any]): + def __init__(self, 
input_dict: Mapping[Any, Any]): list_kvs = [] for k, v in input_dict.items(): list_kvs.append(k) @@ -93,30 +156,22 @@ class Map(core.Object): def __contains__(self, k): return _ffi_api.MapCount(self, k) != 0 - def __iter__(self): - akvs = _ffi_api.MapItems(self) - for i in range(len(self)): - yield akvs[i * 2] - - def __dir__(self): - return sorted(dir(self.__class__) + ["type_key"]) - def keys(self): - return iter(self) + return KeysView(self) def values(self): - akvs = _ffi_api.MapItems(self) - for i in range(len(self)): - yield akvs[i * 2 + 1] + return ValuesView(self) def items(self): """Get the items from the map""" - akvs = _ffi_api.MapItems(self) - return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)] + return ItemsView(self) def __len__(self): return _ffi_api.MapSize(self) + def __iter__(self): + return iter(self.keys()) + def get(self, key, default=None): """Get an element with a default value. diff --git a/tests/python/ffi/test_container.py b/tests/python/ffi/test_container.py index 3a2166dd20..a6e3053c0d 100644 --- a/tests/python/ffi/test_container.py +++ b/tests/python/ffi/test_container.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- +import pytest import tvm.ffi as tvm_ffi @@ -46,17 +46,26 @@ def test_int_map(): assert 3 in dd assert 4 in dd assert 5 not in amap - assert {x for x in amap} == {3, 4} - assert set(amap.keys()) == {3, 4} - assert set(amap.values()) == {2, 3} + assert tuple(amap.items()) == ((3, 2), (4, 3)) + assert tuple(amap.keys()) == (3, 4) + assert tuple(amap.values()) == (2, 3) def test_str_map(): - amap = tvm_ffi.convert({"a": 2, "b": 3}) - assert "a" in amap - assert len(amap) == 2 - dd = dict(amap.items()) - assert amap["a"] == 2 - assert amap.get("b") == 3 - assert "a" in dd - assert "b" in dd + data = [] + for i in reversed(range(10)): + data.append((f"a{i}", i)) + amap = tvm_ffi.convert({k: v for k, v in data}) + assert tuple(amap.items()) == tuple(data) + for k, v in data: + assert k in amap + assert amap[k] == v + assert amap.get(k) == v + + assert tuple(k for k in amap) == tuple(k for k, _ in data) + + +def test_key_not_found(): + amap = tvm_ffi.convert({3: 2, 4: 3}) + with pytest.raises(KeyError): + amap[5] diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index ddc775933c..5417013353 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -36,6 +36,7 @@ run_pytest ${TVM_UNITTEST_TESTSUITE_NAME}-platform-minimal-test tests/python/all # Then run all unittests on both ctypes and cython. TEST_FILES=( + "ffi" "arith" "ci" "codegen"
