Author: Armin Rigo <ar...@tunes.org>
Branch: cpyext-ext
Changeset: r82646:3af62800d459
Date: 2016-03-01 21:15 +0100
http://bitbucket.org/pypy/pypy/changeset/3af62800d459/

Log:    Test and fix for tp_iter and tp_iternext (also fixes the latter so
        that, when the iterator is exhausted, it returns NULL without
        raising StopIteration)

diff --git a/pypy/module/cpyext/slotdefs.py b/pypy/module/cpyext/slotdefs.py
--- a/pypy/module/cpyext/slotdefs.py
+++ b/pypy/module/cpyext/slotdefs.py
@@ -336,14 +336,6 @@
     space.get_and_call_args(w_descr, w_self, args)
     return 0
 
-@cpython_api([PyObject], PyObject, header=None)
-def slot_tp_iter(space, w_self):
-    return space.iter(w_self)
-
-@cpython_api([PyObject], PyObject, header=None)
-def slot_tp_iternext(space, w_self):
-    return space.next(w_self)
-
 from rpython.rlib.nonconst import NonConstant
 
 SLOTS = {}
@@ -437,6 +429,33 @@
             return space.call_function(str_fn, w_self)
         api_func = slot_tp_str.api_func
 
+    elif name == 'tp_iter':
+        iter_fn = w_type.getdictvalue(space, '__iter__')
+        if iter_fn is None:
+            return
+
+        @cpython_api([PyObject], PyObject, header=header)
+        @func_renamer("cpyext_%s_%s" % (name.replace('.', '_'), typedef.name))
+        def slot_tp_iter(space, w_self):
+            return space.call_function(iter_fn, w_self)
+        api_func = slot_tp_iter.api_func
+
+    elif name == 'tp_iternext':
+        iternext_fn = w_type.getdictvalue(space, 'next')
+        if iternext_fn is None:
+            return
+
+        @cpython_api([PyObject], PyObject, header=header)
+        @func_renamer("cpyext_%s_%s" % (name.replace('.', '_'), typedef.name))
+        def slot_tp_iternext(space, w_self):
+            try:
+                return space.call_function(iternext_fn, w_self)
+            except OperationError, e:
+                if not e.match(space, space.w_StopIteration):
+                    raise
+                return None
+        api_func = slot_tp_iternext.api_func
+
     else:
         return
 
diff --git a/pypy/module/cpyext/test/test_typeobject.py b/pypy/module/cpyext/test/test_typeobject.py
--- a/pypy/module/cpyext/test/test_typeobject.py
+++ b/pypy/module/cpyext/test/test_typeobject.py
@@ -645,32 +645,49 @@

     def test_tp_iter(self):
         module = self.import_extension('foo', [
-           ("tp_iter", "METH_O",
+           ("tp_iter", "METH_VARARGS",
             '''
-                 if (!args->ob_type->tp_iter)
+                 PyTypeObject *type = (PyTypeObject *)PyTuple_GET_ITEM(args, 0);
+                 PyObject *obj = PyTuple_GET_ITEM(args, 1);
+                 if (!type->tp_iter)
                  {
                      PyErr_SetNone(PyExc_ValueError);
                      return NULL;
                  }
-                 return args->ob_type->tp_iter(args);
+                 return type->tp_iter(obj);
              '''
              ),
-           ("tp_iternext", "METH_O",
+           ("tp_iternext", "METH_VARARGS",
             '''
-                 if (!args->ob_type->tp_iternext)
+                 PyTypeObject *type = (PyTypeObject *)PyTuple_GET_ITEM(args, 0);
+                 PyObject *obj = PyTuple_GET_ITEM(args, 1);
+                 PyObject *result;
+                 if (!type->tp_iternext)
                  {
                      PyErr_SetNone(PyExc_ValueError);
                      return NULL;
                  }
-                 return args->ob_type->tp_iternext(args);
+                 result = type->tp_iternext(obj);
+                 if (!result && !PyErr_Occurred())
+                     result = PyString_FromString("stop!");
+                 return result;
              '''
              )
             ])
         l = [1]
-        it = module.tp_iter(l)
+        it = module.tp_iter(list, l)
         assert type(it) is type(iter([]))
-        assert module.tp_iternext(it) == 1
-        raises(StopIteration, module.tp_iternext, it)
+        assert module.tp_iternext(type(it), it) == 1
+        assert module.tp_iternext(type(it), it) == "stop!"
+        #
+        class LL(list):
+            def __iter__(self):
+                return iter(())
+        ll = LL([1])
+        it = module.tp_iter(list, ll)
+        assert type(it) is type(iter([]))
+        x = list(it)
+        assert x == [1]
 
     def test_bool(self):
         module = self.import_extension('foo', [
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to