This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/refactor-s2 by this push:
     new f4eed48b58 [FFI] Cython update map items to use lazy iterator
f4eed48b58 is described below

commit f4eed48b5863c888db98fad85c5b18c37f661e3d
Author: tqchen <[email protected]>
AuthorDate: Tue Apr 29 20:21:20 2025 -0400

    [FFI] Cython update map items to use lazy iterator
---
 ffi/src/ffi/container.cc           | 37 ++++++++++++++----
 python/tvm/ffi/container.py        | 80 +++++++++++++++++++++++++++++++-------
 tests/python/ffi/test_container.py | 23 +++++------
 3 files changed, 106 insertions(+), 34 deletions(-)

diff --git a/ffi/src/ffi/container.cc b/ffi/src/ffi/container.cc
index e5ffaadd9f..885e8395ed 100644
--- a/ffi/src/ffi/container.cc
+++ b/ffi/src/ffi/container.cc
@@ -60,14 +60,35 @@ TVM_FFI_REGISTER_GLOBAL("ffi.MapGetItem")
 TVM_FFI_REGISTER_GLOBAL("ffi.MapCount")
     .set_body_typed([](const ffi::MapObj* n, const Any& k) -> int64_t { return n->count(k); });
 
-TVM_FFI_REGISTER_GLOBAL("ffi.MapItems").set_body_typed([](const ffi::MapObj* n) -> Array<Any> {
-  Array<Any> rkvs;
-  for (const auto& kv : *n) {
-    rkvs.push_back(kv.first);
-    rkvs.push_back(kv.second);
-  }
-  return rkvs;
-});
+TVM_FFI_REGISTER_GLOBAL("ffi.MapForwardIterFunctor")
+    .set_body_typed([](const ffi::MapObj* n) -> ffi::Function {
+      class MapForwardIterFunctor {
+       public:
+        MapForwardIterFunctor(ffi::MapObj::iterator iter, ffi::MapObj::iterator end)
+            : iter_(iter), end_(end) {}
+        // 0 get current key
+        // 1 get current value
+        // 2 move to next: return true if success, false if end
+        Any operator()(int command) const {
+          if (command == 0) {
+            return (*iter_).first;
+          } else if (command == 1) {
+            return (*iter_).second;
+          } else {
+            ++iter_;
+            if (iter_ == end_) {
+              return false;
+            }
+            return true;
+          }
+        }
+
+       private:
+        mutable ffi::MapObj::iterator iter_;
+        ffi::MapObj::iterator end_;
+      };
+      return ffi::Function::FromUnpacked(MapForwardIterFunctor(n->begin(), n->end()));
+    });
 
 }  // namespace ffi
 }  // namespace tvm
diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py
index 829cf8cf23..9dcc737078 100644
--- a/python/tvm/ffi/container.py
+++ b/python/tvm/ffi/container.py
@@ -76,9 +76,70 @@ class Array(core.Object):
         return _ffi_api.ArraySize(self)
 
 
+class MapKeys:
+    """Helper class to return keys view"""
+
+    def __init__(self, map):
+        self._map = map
+
+    def __len__(self):
+        return len(self._map)
+
+    def __iter__(self):
+        if self.__len__() == 0:
+            return
+        functor = _ffi_api.MapForwardIterFunctor(self._map)
+        while True:
+            k = functor(0)
+            yield k
+            if not functor(2):
+                break
+
+
+class MapValues:
+    """Helper class to return values view"""
+
+    def __init__(self, map):
+        self._map = map
+
+    def __len__(self):
+        return len(self._map)
+
+    def __iter__(self):
+        if self.__len__() == 0:
+            return
+        functor = _ffi_api.MapForwardIterFunctor(self._map)
+        while True:
+            v = functor(1)
+            yield v
+            if not functor(2):
+                break
+
+
+class MapItems:
+    """Helper class to return items view"""
+
+    def __init__(self, map):
+        self._map = map
+
+    def __len__(self):
+        return len(self._map)
+
+    def __iter__(self):
+        if self.__len__() == 0:
+            return
+        functor = _ffi_api.MapForwardIterFunctor(self._map)
+        while True:
+            k = functor(0)
+            v = functor(1)
+            yield (k, v)
+            if not functor(2):
+                break
+
+
 @register_object("object.Map")
 class Map(core.Object):
-    """Map container"""
+    """Map container."""
 
     def __init__(self, input_dict: Dict[Any, Any]):
         list_kvs = []
@@ -93,26 +154,15 @@ class Map(core.Object):
     def __contains__(self, k):
         return _ffi_api.MapCount(self, k) != 0
 
-    def __iter__(self):
-        akvs = _ffi_api.MapItems(self)
-        for i in range(len(self)):
-            yield akvs[i * 2]
-
-    def __dir__(self):
-        return sorted(dir(self.__class__) + ["type_key"])
-
     def keys(self):
-        return iter(self)
+        return MapKeys(self)
 
     def values(self):
-        akvs = _ffi_api.MapItems(self)
-        for i in range(len(self)):
-            yield akvs[i * 2 + 1]
+        return MapValues(self)
 
     def items(self):
         """Get the items from the map"""
-        akvs = _ffi_api.MapItems(self)
-        return [(akvs[i], akvs[i + 1]) for i in range(0, len(akvs), 2)]
+        return MapItems(self)
 
     def __len__(self):
         return _ffi_api.MapSize(self)
diff --git a/tests/python/ffi/test_container.py b/tests/python/ffi/test_container.py
index 3a2166dd20..44cb3f321b 100644
--- a/tests/python/ffi/test_container.py
+++ b/tests/python/ffi/test_container.py
@@ -46,17 +46,18 @@ def test_int_map():
     assert 3 in dd
     assert 4 in dd
     assert 5 not in amap
-    assert {x for x in amap} == {3, 4}
-    assert set(amap.keys()) == {3, 4}
-    assert set(amap.values()) == {2, 3}
+    assert tuple(amap.items()) == ((3, 2), (4, 3))
+    assert tuple(amap.keys()) == (3, 4)
+    assert tuple(amap.values()) == (2, 3)
 
 
 def test_str_map():
-    amap = tvm_ffi.convert({"a": 2, "b": 3})
-    assert "a" in amap
-    assert len(amap) == 2
-    dd = dict(amap.items())
-    assert amap["a"] == 2
-    assert amap.get("b") == 3
-    assert "a" in dd
-    assert "b" in dd
+    data = []
+    for i in reversed(range(10)):
+        data.append((f"a{i}", i))
+    amap = tvm_ffi.convert({k: v for k, v in data})
+    assert tuple(amap.items()) == tuple(data)
+    for k, v in data:
+        assert k in amap
+        assert amap[k] == v
+        assert amap.get(k) == v

Reply via email to