Author: Vincent Michel <[email protected]>
Branch: bpo-35409
Changeset: r96994:3c32e9b24388
Date: 2019-07-14 14:09 +0200
http://bitbucket.org/pypy/pypy/changeset/3c32e9b24388/

Log:    Make sure the athrow coroutine of an asynchronous generator gets
        closed when the underlying generator raises an exception

diff --git a/pypy/interpreter/generator.py b/pypy/interpreter/generator.py
--- a/pypy/interpreter/generator.py
+++ b/pypy/interpreter/generator.py
@@ -661,15 +661,7 @@
         return self.do_send(w_arg)
 
     def descr_throw(self, w_type, w_val=None, w_tb=None):
-        space = self.space
-        if self.state == self.ST_CLOSED:
-            raise OperationError(space.w_StopIteration, space.w_None)
-        try:
-            w_value = self.async_gen.throw(w_type, w_val, w_tb)
-            return self.unwrap_value(w_value)
-        except OperationError as e:
-            self.state = self.ST_CLOSED
-            raise
+        return self.do_throw(w_type, w_val, w_tb)
 
     def descr_close(self):
         self.state = self.ST_CLOSED
@@ -715,6 +707,17 @@
             self.state = self.ST_CLOSED
             raise
 
+    def do_throw(self, w_type, w_val, w_tb):
+        space = self.space
+        if self.state == self.ST_CLOSED:
+            raise OperationError(space.w_StopIteration, space.w_None)
+        try:
+            w_value = self.async_gen.throw(w_type, w_val, w_tb)
+            return self.unwrap_value(w_value)
+        except OperationError as e:
+            self.state = self.ST_CLOSED
+            raise
+
 
 class AsyncGenAThrow(AsyncGenABase):
 
@@ -756,34 +759,31 @@
                 w_value = self.async_gen.send_ex(w_arg_or_err)
             return self.unwrap_value(w_value)
         except OperationError as e:
-            if e.match(space, space.w_StopAsyncIteration):
-                self.state = self.ST_CLOSED
-                if self.w_exc_type is None:
-                    # When aclose() is called we don't want to propagate
-                    # StopAsyncIteration; just raise StopIteration, signalling
-                    # that 'aclose()' is done.
-                    raise OperationError(space.w_StopIteration, space.w_None)
-            if e.match(space, space.w_GeneratorExit):
-                self.state = self.ST_CLOSED
-                # Ignore this error.
-                raise OperationError(space.w_StopIteration, space.w_None)
-            raise
+            self.handle_error(e)
 
-    def descr_throw(self, w_type, w_val=None, w_tb=None):
+    def do_throw(self, w_type, w_val, w_tb):
         space = self.space
         if self.state == self.ST_INIT:
             raise OperationError(self.space.w_RuntimeError,
                 space.newtext("can't do async_generator.athrow().throw()"))
+        if self.state == self.ST_CLOSED:
+            raise OperationError(space.w_StopIteration, space.w_None)
         try:
-            return AsyncGenABase.descr_throw(self, w_type, w_val, w_tb)
+            w_value = self.async_gen.throw(w_type, w_val, w_tb)
+            return self.unwrap_value(w_value)
         except OperationError as e:
-            if e.match(space, space.w_StopAsyncIteration):
-                if self.w_exc_type is None:
-                    # When aclose() is called we don't want to propagate
-                    # StopAsyncIteration; just raise StopIteration, signalling
-                    # that 'aclose()' is done.
-                    raise OperationError(space.w_StopIteration, space.w_None)
-            if e.match(space, space.w_GeneratorExit):
-                # Ignore this error.
+            self.handle_error(e)
+
+    def handle_error(self, e):
+        space = self.space
+        self.state = self.ST_CLOSED
+        if e.match(space, space.w_StopAsyncIteration):
+            if self.w_exc_type is None:
+                # When aclose() is called we don't want to propagate
+                # StopAsyncIteration; just raise StopIteration, signalling
+                # that 'aclose()' is done.
                 raise OperationError(space.w_StopIteration, space.w_None)
-            raise
+        if e.match(space, space.w_GeneratorExit):
+            # Ignore this error.
+            raise OperationError(space.w_StopIteration, space.w_None)
+        raise e
diff --git a/pypy/interpreter/test/test_coroutine.py b/pypy/interpreter/test/test_coroutine.py
--- a/pypy/interpreter/test/test_coroutine.py
+++ b/pypy/interpreter/test/test_coroutine.py
@@ -453,6 +453,15 @@
         assert ex.value.args == expected
         """
 
+    def test_async_yield_athrow_send_after_exception(self): """
+        async def ag():
+            yield 42
+
+        athrow_coro = ag().athrow(ValueError)
+        raises(ValueError, athrow_coro.send, None)
+        raises(StopIteration, athrow_coro.send, None)
+        """
+
     def test_async_yield_athrow_throw(self): """
         async def ag():
             yield 42
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to