Author: Armin Rigo <ar...@tunes.org>
Branch: py3.5-corowrapper
Changeset: r87189:72809bf56b82
Date: 2016-09-17 20:15 +0200
http://bitbucket.org/pypy/pypy/changeset/72809bf56b82/

Log:    sys.set_coroutine_wrapper()

diff --git a/pypy/interpreter/executioncontext.py b/pypy/interpreter/executioncontext.py
--- a/pypy/interpreter/executioncontext.py
+++ b/pypy/interpreter/executioncontext.py
@@ -22,7 +22,8 @@
     # XXX [fijal] but they're not. is_being_profiled is guarded a bit all
     #     over the place as well as w_tracefunc
 
-    _immutable_fields_ = ['profilefunc?', 'w_tracefunc?']
+    _immutable_fields_ = ['profilefunc?', 'w_tracefunc?',
+                          'w_coroutine_wrapper_fn?']
 
     def __init__(self, space):
         self.space = space
@@ -33,6 +34,8 @@
         self.profilefunc = None
         self.w_profilefuncarg = None
         self.thread_disappeared = False   # might be set to True after os.fork()
+        self.w_coroutine_wrapper_fn = None
+        self.in_coroutine_wrapper = False
 
     @staticmethod
     def _mark_thread_disappeared(space):
diff --git a/pypy/interpreter/pyframe.py b/pypy/interpreter/pyframe.py
--- a/pypy/interpreter/pyframe.py
+++ b/pypy/interpreter/pyframe.py
@@ -253,17 +253,36 @@
     run._always_inline_ = True
 
     def initialize_as_generator(self, name, qualname):
+        space = self.space
         if self.getcode().co_flags & pycode.CO_COROUTINE:
             from pypy.interpreter.generator import Coroutine
             gen = Coroutine(self, name, qualname)
+            ec = space.getexecutioncontext()
+            w_wrapper = ec.w_coroutine_wrapper_fn
         else:
             from pypy.interpreter.generator import GeneratorIterator
             gen = GeneratorIterator(self, name, qualname)
-        if self.space.config.translation.rweakref:
+            ec = None
+            w_wrapper = None
+
+        if space.config.translation.rweakref:
             self.f_generator_wref = rweakref.ref(gen)
         else:
             self.f_generator_nowref = gen
-        return self.space.wrap(gen)
+        w_gen = space.wrap(gen)
+
+        if w_wrapper is not None:
+            if ec.in_coroutine_wrapper:
+                raise oefmt(space.w_RuntimeError,
+                            "coroutine wrapper %R attempted "
+                            "to recursively wrap %R",
+                            w_wrapper, w_gen)
+            ec.in_coroutine_wrapper = True
+            try:
+                w_gen = space.call_function(w_wrapper, w_gen)
+            finally:
+                ec.in_coroutine_wrapper = False
+        return w_gen
 
     def execute_frame(self, in_generator=None, w_arg_or_err=None):
         """Execute this frame.  Main entry point to the interpreter.
diff --git a/pypy/interpreter/test/test_coroutine.py b/pypy/interpreter/test/test_coroutine.py
--- a/pypy/interpreter/test/test_coroutine.py
+++ b/pypy/interpreter/test/test_coroutine.py
@@ -49,3 +49,20 @@
         cr = f(X())
         assert next(cr.__await__()) == 20
         """
+
+    def test_set_coroutine_wrapper(self): """
+        import sys
+        async def f():
+            pass
+        seen = []
+        def my_wrapper(cr):
+            seen.append(cr)
+            return 42
+        assert sys.get_coroutine_wrapper() is None
+        sys.set_coroutine_wrapper(my_wrapper)
+        assert sys.get_coroutine_wrapper() is my_wrapper
+        cr = f()
+        assert cr == 42
+        sys.set_coroutine_wrapper(None)
+        assert sys.get_coroutine_wrapper() is None
+        """
diff --git a/pypy/module/sys/__init__.py b/pypy/module/sys/__init__.py
--- a/pypy/module/sys/__init__.py
+++ b/pypy/module/sys/__init__.py
@@ -91,6 +91,9 @@
         'float_repr_style'      : 'system.get_float_repr_style(space)',
         'getdlopenflags'        : 'system.getdlopenflags',
         'setdlopenflags'        : 'system.setdlopenflags',
+
+        'get_coroutine_wrapper' : 'vm.get_coroutine_wrapper',
+        'set_coroutine_wrapper' : 'vm.set_coroutine_wrapper',
         }
 
     if sys.platform == 'win32':
diff --git a/pypy/module/sys/vm.py b/pypy/module/sys/vm.py
--- a/pypy/module/sys/vm.py
+++ b/pypy/module/sys/vm.py
@@ -258,3 +258,19 @@
         return space.new_interned_w_str(w_str)
     raise oefmt(space.w_TypeError, "intern() argument must be string.")
 
+def get_coroutine_wrapper(space):
+    "Return the wrapper for coroutine objects set by sys.set_coroutine_wrapper."
+    ec = space.getexecutioncontext()
+    if ec.w_coroutine_wrapper_fn is None:
+        return space.w_None
+    return ec.w_coroutine_wrapper_fn
+
+def set_coroutine_wrapper(space, w_wrapper):
+    "Set a wrapper for coroutine objects."
+    ec = space.getexecutioncontext()
+    if space.is_w(w_wrapper, space.w_None):
+        ec.w_coroutine_wrapper_fn = None
+    elif space.is_true(space.callable(w_wrapper)):
+        ec.w_coroutine_wrapper_fn = w_wrapper
+    else:
+        raise oefmt(space.w_TypeError, "callable expected, got %T", w_wrapper)
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to