From 3a1b51f42031305efd484ed3ab8d0ebec7f6e50b Mon Sep 17 00:00:00 2001
From: Mark Florisson <markflorisson88@gmail.com>
Date: Wed, 16 Mar 2011 11:10:26 +0100
Subject: [PATCH] Implemented 'with gil:' statement

---
 Cython/Compiler/Code.py                        |   39 ++++++++++++
 Cython/Compiler/Nodes.py                       |   72 ++++++++++++++--------
 Cython/Compiler/ParseTreeTransforms.py         |   22 +++++++
 Cython/Compiler/Parsing.py                     |    2 +-
 Cython/Compiler/Symtab.py                      |    9 +++
 tests/errors/incorrectly_nested_gil_blocks.pyx |   28 +++++++++
 tests/run/with_gil.pyx                         |   77 ++++++++++++++++++++++++
 7 files changed, 223 insertions(+), 26 deletions(-)
 create mode 100644 tests/errors/incorrectly_nested_gil_blocks.pyx
 create mode 100644 tests/run/with_gil.pyx

diff --git a/Cython/Compiler/Code.py b/Cython/Compiler/Code.py
index 41221cf..383471f 100644
--- a/Cython/Compiler/Code.py
+++ b/Cython/Compiler/Code.py
@@ -1315,6 +1315,45 @@ class CCodeWriter(object):
                     doc_code,
                     term))
 
+    # GIL methods
+
+    def put_ensure_gil(self):
+        """
+        Acquire the GIL. The generated code is safe even when no PyThreadState
+        has been allocated for this thread (for threads not initialized by
+        using the Python API). Additionally, the code generated by this method
+        may be called recursively.
+        """
+        from Cython.Compiler import Nodes
+
+        self.globalstate.use_utility_code(Nodes.force_init_threads_utility_code)
+
+        self.putln("#ifdef WITH_THREAD")
+        self.putln("PyGILState_STATE _save = PyGILState_Ensure();")
+        self.putln("#endif")
+
+    def put_release_ensured_gil(self):
+        """
+        Releases the GIL, corresponds to `put_ensure_gil`.
+        """
+        self.putln("#ifdef WITH_THREAD")
+        self.putln("PyGILState_Release(_save);")
+        self.putln("#endif")
+
+    def put_acquire_gil(self):
+        """
+        Acquire the GIL. The thread's thread state must have been initialized
+        by a previous `put_release_gil`
+        """
+        self.putln("Py_BLOCK_THREADS")
+
+    def put_release_gil(self):
+        "Release the GIL, corresponds to `put_acquire_gil`."
+        self.putln("#ifdef WITH_THREAD")
+        self.putln("PyThreadState *_save = NULL;")
+        self.putln("#endif")
+        self.putln("Py_UNBLOCK_THREADS")
+
     # error handling
 
     def put_error_if_neg(self, pos, value):
diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py
index 609489f..bcc126f 100644
--- a/Cython/Compiler/Nodes.py
+++ b/Cython/Compiler/Nodes.py
@@ -1288,9 +1288,11 @@ class FuncDefNode(StatNode, BlockNode):
             code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname))
             code.putln(";")
         self.generate_argument_declarations(lenv, code)
+
         for entry in lenv.var_entries:
             if not entry.in_closure:
                 code.put_var_declaration(entry)
+
         init = ""
         if not self.return_type.is_void:
             if self.return_type.is_pyobject:
@@ -1305,16 +1307,20 @@ class FuncDefNode(StatNode, BlockNode):
             code.put_trace_declarations()
         # ----- Extern library function declarations
         lenv.generate_library_function_declarations(code)
+
         # ----- GIL acquisition
         acquire_gil = self.acquire_gil
-        if acquire_gil:
-            env.use_utility_code(force_init_threads_utility_code)
-            code.putln("#ifdef WITH_THREAD")
-            code.putln("PyGILState_STATE _save = PyGILState_Ensure();")
-            code.putln("#endif")
+        acquire_gil_for_var_decls_only = (lenv.nogil and lenv.has_with_gil_block)
+
+        use_refnanny = not lenv.nogil or acquire_gil_for_var_decls_only
+
+        if acquire_gil or acquire_gil_for_var_decls_only:
+            code.put_ensure_gil()
+
         # ----- set up refnanny
-        if not lenv.nogil:
+        if use_refnanny:
             code.put_setup_refcount_context(self.entry.name)
+
         # ----- Automatic lead-ins for certain special functions
         if is_getbuffer_slot:
             self.getbuffer_init(code)
@@ -1329,8 +1335,12 @@ class FuncDefNode(StatNode, BlockNode):
             code.putln("if (unlikely(!%s)) {" % Naming.cur_scope_cname)
             if is_getbuffer_slot:
                 self.getbuffer_error_cleanup(code)
-            if not lenv.nogil:
+
+            if use_refnanny:
                 code.put_finish_refcount_context()
+                if acquire_gil_for_var_decls_only:
+                    code.put_release_ensured_gil()
+
             # FIXME: what if the error return value is a Python value?
             code.putln("return %s;" % self.error_value())
             code.putln("}")
@@ -1376,6 +1386,9 @@ class FuncDefNode(StatNode, BlockNode):
             if entry.type.is_buffer:
                 Buffer.put_acquire_arg_buffer(entry, code, self.pos)
 
+        if acquire_gil_for_var_decls_only:
+            code.put_release_ensured_gil()
+
         # -------------------------
         # ----- Function body -----
         # -------------------------
@@ -1494,13 +1507,22 @@ class FuncDefNode(StatNode, BlockNode):
                 code.put_trace_return(Naming.retval_cname)
             else:
                 code.put_trace_return("Py_None")
+
         if not lenv.nogil:
+            # GIL holding funcion
+            code.put_finish_refcount_context()
+        elif acquire_gil_for_var_decls_only:
+            # 'nogil' function with 'with gil:' block, tear down refnanny
+            code.putln("#if CYTHON_REFNANNY")
+            code.begin_block()
+            code.put_ensure_gil()
             code.put_finish_refcount_context()
+            code.put_release_ensured_gil()
+            code.end_block()
+            code.putln("#endif")
 
         if acquire_gil:
-            code.putln("#ifdef WITH_THREAD")
-            code.putln("PyGILState_Release(_save);")
-            code.putln("#endif")
+            code.put_release_ensured_gil()
 
         if not self.return_type.is_void:
             code.putln("return %s;" % Naming.retval_cname)
@@ -1722,7 +1744,7 @@ class CFuncDefNode(FuncDefNode):
                 error(self.pos,
                       "Function with Python return type cannot be declared nogil")
             for entry in self.local_scope.var_entries:
-                if entry.type.is_pyobject:
+                if entry.type.is_pyobject and not entry.in_with_gil_block:
                     error(self.pos, "Function declared nogil has Python locals or temporaries")
 
     def analyse_expressions(self, env):
@@ -5132,10 +5154,16 @@ class GILStatNode(TryFinallyStatNode):
             body = body,
             finally_clause = GILExitNode(pos, state = state))
 
+    def analyse_declarations(self, env):
+        env._in_with_gil_block = (self.state == 'gil')
+        if self.state == 'gil':
+            env.has_with_gil_block = True
+        return super(GILStatNode, self).analyse_declarations(env)
+
     def analyse_expressions(self, env):
         env.use_utility_code(force_init_threads_utility_code)
         was_nogil = env.nogil
-        env.nogil = 1
+        env.nogil = self.state == 'nogil'
         TryFinallyStatNode.analyse_expressions(self, env)
         env.nogil = was_nogil
 
@@ -5143,18 +5171,14 @@ class GILStatNode(TryFinallyStatNode):
 
     def generate_execution_code(self, code):
         code.mark_pos(self.pos)
-        code.putln("{")
+        code.begin_block()
         if self.state == 'gil':
-            code.putln("#ifdef WITH_THREAD")
-            code.putln("PyGILState_STATE _save = PyGILState_Ensure();")
-            code.putln("#endif")
+            code.put_ensure_gil()
         else:
-            code.putln("#ifdef WITH_THREAD")
-            code.putln("PyThreadState *_save = NULL;")
-            code.putln("#endif")
-            code.putln("Py_UNBLOCK_THREADS")
+            code.put_release_gil()
+
         TryFinallyStatNode.generate_execution_code(self, code)
-        code.putln("}")
+        code.end_block()
 
 
 class GILExitNode(StatNode):
@@ -5169,11 +5193,9 @@ class GILExitNode(StatNode):
 
     def generate_execution_code(self, code):
         if self.state == 'gil':
-            code.putln("#ifdef WITH_THREAD")
-            code.putln("PyGILState_Release(_save);")
-            code.putln("#endif")
+            code.put_release_ensured_gil()
         else:
-            code.putln("Py_BLOCK_THREADS")
+            code.put_acquire_gil()
 
 
 class CImportStatNode(StatNode):
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py
index aeacb79..4d9d6da 100644
--- a/Cython/Compiler/ParseTreeTransforms.py
+++ b/Cython/Compiler/ParseTreeTransforms.py
@@ -1537,15 +1537,28 @@ class GilCheck(VisitorTransform):
     def __call__(self, root):
         self.env_stack = [root.scope]
         self.nogil = False
+
+        # True for 'cdef func() nogil:' functions, as the GIL may be held while
+        # calling this function (thus contained 'nogil' blocks may be valid).
+        self.nogil_declarator_only = False
         return super(GilCheck, self).__call__(root)
 
     def visit_FuncDefNode(self, node):
         self.env_stack.append(node.local_scope)
         was_nogil = self.nogil
         self.nogil = node.local_scope.nogil
+
+        if self.nogil:
+            self.nogil_declarator_only = True
+
         if self.nogil and node.nogil_check:
             node.nogil_check(node.local_scope)
+
         self.visitchildren(node)
+
+        # This cannot be nested, so it doesn't need backup/restore
+        self.nogil_declarator_only = False
+
         self.env_stack.pop()
         self.nogil = was_nogil
         return node
@@ -1555,6 +1568,15 @@ class GilCheck(VisitorTransform):
         if self.nogil and node.nogil_check: node.nogil_check()
         was_nogil = self.nogil
         self.nogil = (node.state == 'nogil')
+
+        if was_nogil == self.nogil and not self.nogil_declarator_only:
+            if not was_nogil:
+                error(node.pos, "Trying to acquire the GIL while it is "
+                                "already held.")
+            else:
+                error(node.pos, "Trying to release the GIL while it was "
+                                "previously released.")
+
         self.visitchildren(node)
         self.nogil = was_nogil
         return node
diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py
index 91ddf6a..34c2d35 100644
--- a/Cython/Compiler/Parsing.py
+++ b/Cython/Compiler/Parsing.py
@@ -1547,7 +1547,7 @@ def p_with_statement(s):
 
 def p_with_items(s):
     pos = s.position()
-    if not s.in_python_file and s.sy == 'IDENT' and s.systring == 'nogil':
+    if not s.in_python_file and s.sy == 'IDENT' and s.systring in ('nogil', 'gil'):
         state = s.systring
         s.next()
         if s.sy == ',':
diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py
index 17fc975..9d55813 100644
--- a/Cython/Compiler/Symtab.py
+++ b/Cython/Compiler/Symtab.py
@@ -176,6 +176,7 @@ class Entry(object):
     buffer_aux = None
     prev_entry = None
     might_overflow = 0
+    in_with_gil_block = 0
 
     def __init__(self, name, cname, type, pos = None, init = None):
         self.name = name
@@ -1261,6 +1262,12 @@ class ModuleScope(Scope):
 
 class LocalScope(Scope):
 
+    # Does the function have a 'with gil:' block?
+    has_with_gil_block = False
+
+    # Transient attribute, used for symbol table variable declarations
+    _in_with_gil_block = False
+
     def __init__(self, name, outer_scope, parent_scope = None):
         if parent_scope is None:
             parent_scope = outer_scope
@@ -1293,6 +1300,8 @@ class LocalScope(Scope):
             entry.init = "0"
         entry.init_to_none = (type.is_pyobject or type.is_unspecified) and Options.init_local_none
         entry.is_local = 1
+
+        entry.in_with_gil_block = self._in_with_gil_block
         self.var_entries.append(entry)
         return entry
 
diff --git a/tests/errors/incorrectly_nested_gil_blocks.pyx b/tests/errors/incorrectly_nested_gil_blocks.pyx
new file mode 100644
index 0000000..aabd8a7
--- /dev/null
+++ b/tests/errors/incorrectly_nested_gil_blocks.pyx
@@ -0,0 +1,28 @@
+with gil:
+    pass
+
+with nogil:
+    with nogil:
+        pass
+
+cdef void without_gil() nogil:
+   # This is not an error, as 'func' *may* be called without the GIL, but it
+   # may also be held.
+    with nogil:
+        pass
+
+cdef void with_gil() with gil:
+    # This is an error, as the GIL is acquired already
+    with gil:
+        pass
+
+def func():
+    with gil:
+        pass
+
+_ERRORS = u'''
+1:5: Trying to acquire the GIL while it is already held.
+5:9: Trying to release the GIL while it was previously released.
+16:9: Trying to acquire the GIL while it is already held.
+20:9: Trying to acquire the GIL while it is already held.
+'''
diff --git a/tests/run/with_gil.pyx b/tests/run/with_gil.pyx
new file mode 100644
index 0000000..e4cd410
--- /dev/null
+++ b/tests/run/with_gil.pyx
@@ -0,0 +1,77 @@
+"""
+Most of these functions are 'cdef' functions, so we need to test using 'def'
+test wrappers.
+"""
+
+from libc.stdio cimport printf
+from cpython.ref cimport PyObject, Py_INCREF
+
+import sys
+
+try:
+    import StringIO
+except ImportError:
+    import io as StringIO
+
+def simple_func():
+    """
+    >>> simple_func()
+    ['spam', 'ham']
+    ('star', 'twinkle')
+    """
+    with nogil:
+        with gil:
+            print ['spam', 'ham']
+        cdef_simple_func()
+
+cdef void cdef_simple_func() nogil:
+    with gil:
+        print ('star', 'twinkle')
+
+def with_gil():
+    """
+    >>> with_gil()
+    None
+    {'spam': 'ham'}
+    """
+    print x
+    with nogil:
+        with gil:
+            x = dict(spam='ham')
+    print x
+
+
+cdef void without_gil() nogil:
+    with gil:
+        x = list(('foo', 'bar'))
+        raise NameError
+
+    with gil:
+        print "unreachable"
+
+def test_without_gil():
+    """
+    >>> test_without_gil()
+    Exception NameError in 'with_gil.without_gil' ignored
+    """
+    # Doctest doesn't capture-and-match stderr
+    stderr, sys.stderr = sys.stderr, StringIO.StringIO()
+    without_gil()
+    sys.stdout.write(sys.stderr.getvalue())
+    sys.stderr = stderr
+
+cdef PyObject *nogil_propagate_exception() nogil except NULL:
+    with nogil:
+        with gil:
+            raise Exception("This exception propagates!")
+    return <PyObject *> 1
+
+def test_nogil_propagate_exception():
+    """
+    >>> test_nogil_propagate_exception()
+    Traceback (most recent call last):
+        ...
+    Exception: This exception propagates!
+    """
+    nogil_propagate_exception()
+
-- 
1.7.4.1

