This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 1eae362f62720aa98eb483ff7ff3174b852442bd
Author: tqchen <[email protected]>
AuthorDate: Tue Apr 29 20:21:20 2025 -0400

    [FFI] Cython: update Map to conform to collections.abc
---
 ffi/include/tvm/ffi/container/map.h   |  4 +-
 ffi/src/ffi/container.cc              | 37 +++++++++++---
 python/tvm/ffi/container.py           | 95 +++++++++++++++++++++++++++--------
 tests/python/ffi/test_container.py    | 33 +++++++-----
 tests/scripts/task_python_unittest.sh |  1 +
 5 files changed, 128 insertions(+), 42 deletions(-)

diff --git a/ffi/include/tvm/ffi/container/map.h b/ffi/include/tvm/ffi/container/map.h
index 49ec5dadbb..26eeee7925 100644
--- a/ffi/include/tvm/ffi/container/map.h
+++ b/ffi/include/tvm/ffi/container/map.h
@@ -241,7 +241,7 @@ class SmallMapObj : public MapObj, public details::InplaceArrayBase<SmallMapObj,
   const mapped_type& at(const key_type& key) const {
     iterator itr = find(key);
     if (itr.index >= size_) {
-      TVM_FFI_THROW(IndexError) << "key is not in Map";
+      TVM_FFI_THROW(KeyError) << "key is not in Map";
     }
     return itr->second;
   }
@@ -253,7 +253,7 @@ class SmallMapObj : public MapObj, public details::InplaceArrayBase<SmallMapObj,
   mapped_type& at(const key_type& key) {
     iterator itr = find(key);
     if (itr.index >= size_) {
-      TVM_FFI_THROW(IndexError) << "key is not in Map";
+      TVM_FFI_THROW(KeyError) << "key is not in Map";
     }
     return itr->second;
   }
diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc
index e5ffaadd9f..885e8395ed 100644
--- a/ffi/src/ffi/container.cc
+++ b/ffi/src/ffi/container.cc
@@ -60,14 +60,35 @@ TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem")
 TVM_FFI_REGISTER_GLOBAL("ffi.MapCount")
     .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return 
n->count(k); });
 
-TVM_FFI_REGISTER_GLOBAL("ffi.MapItems").set_body_typed([](const ffi::MapObj* 
n) -> Array<Any> {
-  Array<Any> rkvs;
-  for (const auto& kv : *n) {
-    rkvs.push_back(kv.first);
-    rkvs.push_back(kv.second);
-  }
-  return rkvs;
-});
+TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor")
+    .set_body_typed([](const ffi::MapObj* n) -> ffi::Function {
+      class MapForwardIterFunctor {
+       public:
+        MapForwardIterFunctor(ffi::MapObj::iterator iter, 
ffi::MapObj::iterator end)
+            : iter_(iter), end_(end) {}
+        // 0 get current key
+        // 1 get current value
+        // 2 move to next: return true if success, false if end
+        Any operator()(int command) const {
+          if (command == 0) {
+            return (*iter_).first;
+          } else if (command == 1) {
+            return (*iter_).second;
+          } else {
+            ++iter_;
+            if (iter_ == end_) {
+              return false;
+            }
+            return true;
+          }
+        }
+
+       private:
+        mutable ffi::MapObj::iterator iter_;
+        ffi::MapObj::iterator end_;
+      };
+      return ffi::Function::FromUnpacked(MapForwardIterFunctor(n->begin(), 
n->end()));
+    });
 
 }  // namespace ffi
 }  // namespace tvm
diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py
index 829cf8cf23..77495a52ff 100644
--- a/python/tvm/ffi/container.py
+++ b/python/tvm/ffi/container.py
@@ -15,7 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.
 """Container classes."""
-from typing import Any, List, Dict
+import collections.abc
+
+from typing import Any, Mapping, Sequence
 from . import core
 from . import _ffi_api
 from .registry import register_object
@@ -63,10 +65,10 @@ def getitem_helper(obj, elem_getter, length, idx):
 
 
 @register_object("object.Array")
-class Array(core.Object):
+class Array(core.Object, collections.abc.Sequence):
     """Array container"""
 
-    def __init__(self, input_list: List[Any]):
+    def __init__(self, input_list: Sequence[Any]):
         self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
 
     def __getitem__(self, idx):
@@ -76,11 +78,72 @@ class Array(core.Object):
         return _ffi_api.ArraySize(self)
 
 
+class KeysView(collections.abc.KeysView):
+    """Helper class to return keys view"""
+
+    def __init__(self, backend_map):
+        self._backend_map = backend_map
+
+    def __len__(self):
+        return len(self._backend_map)
+
+    def __iter__(self):
+        if self.__len__() == 0:
+            return
+        functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
+        while True:
+            k = functor(0)
+            yield k
+            if not functor(2):
+                break
+
+
+class ValuesView(collections.abc.ValuesView):
+    """Helper class to return values view"""
+
+    def __init__(self, backend_map):
+        self._backend_map = backend_map
+
+    def __len__(self):
+        return len(self._backend_map)
+
+    def __iter__(self):
+        if self.__len__() == 0:
+            return
+        functor = _ffi_api.MapForwardIterFunctor(self._backend_map)
+        while True:
+            v = functor(1)
+            yield v
+            if not functor(2):
+                break
+
+
+class ItemsView(collections.abc.ItemsView):
+    """Helper class to return items view"""
+
+    def __init__(self, backend_map):
+        self.backend_map = backend_map
+
+    def __len__(self):
+        return len(self.backend_map)
+
+    def __iter__(self):
+        if self.__len__() == 0:
+            return
+        functor = _ffi_api.MapForwardIterFunctor(self.backend_map)
+        while True:
+            k = functor(0)
+            v = functor(1)
+            yield (k, v)
+            if not functor(2):
+                break
+
+
 @register_object("object.Map")
-class Map(core.Object):
-    """Map container"""
+class Map(core.Object, collections.abc.Mapping):
+    """Map container."""
 
-    def __init__(self, input_dict: Dict[Any, Any]):
+    def __init__(self, input_dict: Mapping[Any, Any]):
         list_kvs = []
         for k, v in input_dict.items():
             list_kvs.append(k)
@@ -93,30 +156,22 @@ class Map(core.Object):
     def __contains__(self, k):
         return _ffi_api.MapCount(self, k) != 0
 
-    def __iter__(self):
-        akvs = _ffi_api.MapItems(self)
-        for i in range(len(self)):
-            yield akvs[i * 2]
-
-    def __dir__(self):
-        return sorted(dir(self.__class__) + ["type_key"])
-
     def keys(self):
-        return iter(self)
+        return KeysView(self)
 
     def values(self):
-        akvs = _ffi_api.MapItems(self)
-        for i in range(len(self)):
-            yield akvs[i * 2 + 1]
+        return ValuesView(self)
 
     def items(self):
         """Get the items from the map"""
-        akvs = _ffi_api.MapItems(self)
-        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]
+        return ItemsView(self)
 
     def __len__(self):
         return _ffi_api.MapSize(self)
 
+    def __iter__(self):
+        return iter(self.keys())
+
     def get(self, key, default=None):
         """Get an element with a default value.
 
diff --git a/tests/python/ffi/test_container.py b/tests/python/ffi/test_container.py
index 3a2166dd20..a6e3053c0d 100644
--- a/tests/python/ffi/test_container.py
+++ b/tests/python/ffi/test_container.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+import pytest
 import tvm.ffi as tvm_ffi
 
 
@@ -46,17 +46,26 @@ def test_int_map():
     assert 3 in dd
     assert 4 in dd
     assert 5 not in amap
-    assert {x for x in amap} == {3, 4}
-    assert set(amap.keys()) == {3, 4}
-    assert set(amap.values()) == {2, 3}
+    assert tuple(amap.items()) == ((3, 2), (4, 3))
+    assert tuple(amap.keys()) == (3, 4)
+    assert tuple(amap.values()) == (2, 3)
 
 
 def test_str_map():
-    amap = tvm_ffi.convert({"a": 2, "b": 3})
-    assert "a" in amap
-    assert len(amap) == 2
-    dd = dict(amap.items())
-    assert amap["a"] == 2
-    assert amap.get("b") == 3
-    assert "a" in dd
-    assert "b" in dd
+    data = []
+    for i in reversed(range(10)):
+        data.append((f"a{i}", i))
+    amap = tvm_ffi.convert({k: v for k, v in data})
+    assert tuple(amap.items()) == tuple(data)
+    for k, v in data:
+        assert k in amap
+        assert amap[k] == v
+        assert amap.get(k) == v
+
+    assert tuple(k for k in amap) == tuple(k for k, _ in data)
+
+
+def test_key_not_found():
+    amap = tvm_ffi.convert({3: 2, 4: 3})
+    with pytest.raises(KeyError):
+        amap[5]
diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh
index ddc775933c..5417013353 100755
--- a/tests/scripts/task_python_unittest.sh
+++ b/tests/scripts/task_python_unittest.sh
@@ -36,6 +36,7 @@ run_pytest ${TVM_UNITTEST_TESTSUITE_NAME}-platform-minimal-test tests/python/all
 
 # Then run all unittests on both ctypes and cython.
 TEST_FILES=(
+  "ffi"
   "arith"
   "ci"
   "codegen"

Reply via email to