Stefan wrote:

I agree with Robert. As long as Cython does not support closures, for
example, it cannot come close enough to being a real option for speeding


I couldn't resist the challenge :-)

Attached is a prototype using transforms to add closure support to Cython.

NB! It's not ready for prime time yet. Unfortunately I must leave it for some days now, so I am posting it in its prototype state, mainly so that others don't start on the same thing (though feel free to take this over; just give me a note).

I mainly wrote it to see what it would be like to write a "real" transform, which was not too bad...

It is a bit hacky but I think the approach should give correct results.

Known fatal bugs:
- I don't know the first thing about reference counting, CPython, etc. On at least one occasion I messed up GC-ing. This is probably something that others will be much quicker at spotting.
- No name mangling: inner functions cannot collide in name with outer ones. This is a quick fix, but I don't have time now; also, one might want a more generic "name mangler" support mechanism rather than just checking for collisions, and I'm not sure how to do this.
- No consideration for anything nontrivial: the global keyword and accessing variables in modules come to mind.

Conscious limitations:
- Only Python defs are supported, not inner cdefs, and all bound variables are bound as Python objects. This restriction can be removed later (for optimization), but I think Python-only inner defs should work fine.

Strategy:
- First, run a transform that records used and assigned symbols within a function (I don't know whether this is already done anywhere; it may be redundant and already done in the scope system. Suggestions?).
- Then, run a transform which lifts out inner functions and replaces each with an assignment to a method bound to a tuple containing the variables that should be bound (apparently this works, but one might instead create a specific type containing a tuple as a field).
- Each lifted-out function gets an instruction added at the start to unpack the "self" tuple into the correct names (so that those names shadow the outer scope).

How to run (if you want to help me out or have a look, otherwise don't bother):

Apply patch.

$ cat <<END > test.pyx
def make_adder(n):
   def adder(x):
       return x + n
   return adder


def timesthree(n):
   def util(x):
       return x * 2
   return n + util(n)
END

$ cat <<END > test.sh
python cython.py \
 -Tafter_parse:Cython.Compiler.Transforms.InnerFunctions.FunctionSymbols \
 -Tafter_parse:Cython.Compiler.Transforms.InnerFunctions.InnerFunctions \
-Tafter_analyse_function:Cython.Compiler.Transforms.InnerFunctions.MethodTableIndex \
 test.pyx
gcc -Wall -shared -fPIC -I/usr/include/python2.5 -o test.so test.c
END

$ python
>>> import test
>>> a = test.make_adder(20)
>>> b = test.make_adder(10)
>>> a
<built-in method adder of tuple object at 0x2af646ca8910>
>>> a(13)
33
>>> b(13)
23
>>> test.timesthree(100)
300

--
Dag Sverre

diff -r e005b58d83b8 Cython/Compiler/ModuleNode.py
--- a/Cython/Compiler/ModuleNode.py	Tue Apr 08 01:25:09 2008 -0700
+++ b/Cython/Compiler/ModuleNode.py	Wed Apr 09 11:50:03 2008 +0200
@@ -56,7 +56,7 @@ class ModuleNode(Nodes.Node, Nodes.Block
     #  module_temp_cname    string
     #  full_module_name     string
 
-    children_attrs = ["body"]
+    child_attrs = ["body"]
     
     def analyse_declarations(self, env):
         if Options.embed_pos_in_docstring:
@@ -68,6 +68,7 @@ class ModuleNode(Nodes.Node, Nodes.Block
         self.body.analyse_declarations(env)
     
     def process_implementation(self, env, options, result):
+        options.transforms.run('after_parse', self, env=env)
         self.analyse_declarations(env)
         env.check_c_classes()
         self.body.analyse_expressions(env)
@@ -77,6 +78,7 @@ class ModuleNode(Nodes.Node, Nodes.Block
         if self.has_imported_c_functions():
             self.module_temp_cname = env.allocate_temp_pyobject()
             env.release_temp(self.module_temp_cname)
+        options.transforms.run('before_module_c_code', self, env=env)
         self.generate_c_code(env, options, result)
         self.generate_h_code(env, options, result)
         self.generate_api_code(env, result)
@@ -247,6 +249,7 @@ class ModuleNode(Nodes.Node, Nodes.Block
         self.generate_interned_name_decls(env, code)
         self.generate_py_string_decls(env, code)
         self.generate_cached_builtins_decls(env, code)
+        self.generate_method_table_decl(env, code)
         self.body.generate_function_definitions(env, code, options.transforms)
         code.mark_pos(None)
         self.generate_interned_name_table(env, code)
@@ -1279,6 +1282,15 @@ class ModuleNode(Nodes.Node, Nodes.Block
         code.putln(
             "};")
     
+    
+    def generate_method_table_decl(self, env, code):
+        # Predeclare the method table for use by closures
+        code.putln("")
+        code.putln(
+            "static struct PyMethodDef %s[];" % env.method_table_cname
+        )
+        
+    
     def generate_method_table(self, env, code):
         code.putln("")
         code.putln(
diff -r e005b58d83b8 Cython/Compiler/Nodes.py
--- a/Cython/Compiler/Nodes.py	Tue Apr 08 01:25:09 2008 -0700
+++ b/Cython/Compiler/Nodes.py	Wed Apr 09 11:50:03 2008 +0200
@@ -188,6 +188,7 @@ class Node:
                     self._end_pos = max([child.end_pos() for child in flat])
             return self._end_pos
 
+        
 
 class BlockNode:
     #  Mixin class for nodes representing a declaration block.
diff -r e005b58d83b8 Cython/Compiler/Parsing.py
--- a/Cython/Compiler/Parsing.py	Tue Apr 08 01:25:09 2008 -0700
+++ b/Cython/Compiler/Parsing.py	Wed Apr 09 11:50:03 2008 +0200
@@ -1353,7 +1353,7 @@ def p_statement(s, level, cdef_flag = 0,
             if api:
                 error(s.pos, "'api' not allowed with this statement")
             elif s.sy == 'def':
-                if level not in ('module', 'class', 'c_class', 'property'):
+                if level not in ('module', 'class', 'c_class', 'property', 'function'):
                     s.error('def statement not allowed here')
                 s.level = level
                 return p_def_statement(s)
diff -r e005b58d83b8 Cython/Compiler/Transform.py
--- a/Cython/Compiler/Transform.py	Tue Apr 08 01:25:09 2008 -0700
+++ b/Cython/Compiler/Transform.py	Wed Apr 09 11:50:03 2008 +0200
@@ -3,6 +3,7 @@
 #
 import Nodes
 import ExprNodes
+import inspect
 
 class Transform(object):
     #  parent_stack [Node]       A stack providing information about where in the tree
@@ -45,12 +46,11 @@ class Transform(object):
     def process_list(self, l, name):
         """Calls process_node on all the items in l, using the name one gets when appending
         [idx] to the name. Each item in l is transformed in-place by the item process_node
-        returns, then l is returned."""
-        # Comment: If moving to a copying strategy, it might makes sense to return a
-        # new list instead.
+        returns, then l is returned. If process_node returns None, the item is removed
+        from the list."""
         for idx in xrange(len(l)):
             l[idx] = self.process_node(l[idx], "%s[%d]" % (name, idx))
-        return l
+        return [x for x in l if x is not None]
 
     def process_node(self, node, name):
         """Override this method to process nodes. name specifies which kind of relation the
@@ -58,6 +58,46 @@ class Transform(object):
         should use for this relation, which can either be the same node, None to remove
         the node, or a different node."""
         raise InternalError("Not implemented")
+    
+    def process_tree(self, root):
+        self.process_node(root, "(root)")
+
+
+
+class VisitorTransform(Transform):
+    def __init__(self):
+        super(VisitorTransform, self).__init__()
+        self.visitmethods = {}
+
+    def process_node(self, node, name):
+        # Pass on to calls registered in self.visitmethods
+        if node is None:
+            return None
+            
+        cls = node.__class__
+        mname = "process_" + cls.__name__
+        m = self.visitmethods.get(mname)
+        if m is None:
+            # Must resolve, try entire hierarchy
+            for cls in inspect.getmro(cls):
+                m = getattr(self, "process_" + cls.__name__, None)
+                if m is not None:
+                    break
+            if m is None: raise RuntimeError("Not a Node descendant: " + node)
+            self.visitmethods[mname] = m
+        
+        return m(node, name)
+    
+    def process_Node(self, node, name):
+        self.process_children(node)
+        return node
+
+# Utils
+def ensure_statlist(node):
+    if not isinstance(node, Nodes.StatListNode):
+        node = Nodes.StatListNode(pos=node.pos, stats=[node])
+    return node
+
 
 class PrintTree(Transform):
     """Prints a representation of the tree to standard output.
@@ -92,15 +132,22 @@ class PrintTree(Transform):
             return "(none)"
         else:
             result = node.__class__.__name__
-            if isinstance(node, ExprNodes.ExprNode):
+            if isinstance(node, ExprNodes.NameNode):
+                result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
+            elif isinstance(node, Nodes.DefNode):
+                result += "(name=\"%s\")" % node.name
+            elif isinstance(node, ExprNodes.ExprNode):
                 t = node.type
                 result += "(type=%s)" % repr(t)
+                
             return result
 
 
 PHASES = [
+    'after_parse',             # run in Main.compile
     'before_analyse_function', # run in FuncDefNode.generate_function_definitions
-    'after_analyse_function'   # run in FuncDefNode.generate_function_definitions
+    'after_analyse_function',   # run in FuncDefNode.generate_function_definitions
+    'before_module_c_code',    # run in ModuleNode.process_implementation
 ]
 
 class TransformSet(dict):
@@ -111,6 +158,6 @@ class TransformSet(dict):
         assert name in self
         for transform in self[name]:
             transform.initialize(phase=name, **options)
-            transform.process_node(node, "(root)")
+            transform.process_tree(node)
 
 
diff -r e005b58d83b8 Cython/Compiler/Transforms/InnerFunctions.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/Cython/Compiler/Transforms/InnerFunctions.py	Wed Apr 09 11:50:03 2008 +0200
@@ -0,0 +1,204 @@
+from Cython.Compiler.Transform import Transform, VisitorTransform, ensure_statlist
+from Cython.Compiler.Nodes import *
+from Cython.Compiler.ExprNodes import *
+from sets import Set
+
+class FunctionSymbols(VisitorTransform):
+    """
+    Fills in information regarding used symbols for each function.
+    Each FuncDefNode will have the following attributes added as
+    a sets.Set():
+      used_symbols      - Any symbols used
+      assigned_symbols  - Any symbols assigned to (including def-s)
+      external_symbols  - Symbols that are used but never assigned to
+    Correctly handles nested functions.
+    
+    About transform process phase:
+        Made for running in after_parse. Expects SingleAssignmentNode
+        also for tuple unpacking.
+    """
+    
+    current_func = None
+    in_lhs = False
+    
+    def declare_symbol(self, symbolname):
+        if self.current_func is not None:
+            self.current_func.assigned_symbols.add(symbolname)
+    
+    def use_symbol(self, symbolname):
+        if self.current_func is not None:
+            self.current_func.used_symbols.add(symbolname)
+
+    def process_FuncDefNode(self, node, name):
+        # Push, visit children, pop
+        parent = self.current_func
+        self.current_func = node
+        node.used_symbols = Set()
+        node.assigned_symbols = Set()
+        self.process_children(node)
+        node.external_symbols = node.used_symbols - node.assigned_symbols
+        print node.name, node.assigned_symbols, node.external_symbols
+
+        # If this function is an inner function
+        if parent is not None:
+            # its name can be considered assigned in the parent function
+            parent.assigned_symbols.add(node.name)
+            # and also, parent function must inherit external symbols
+            parent.used_symbols.union_update(node.external_symbols)
+
+        self.current_func = parent
+
+        return node
+    
+    def process_CDeclaratorNode(self, node, name):
+        self.declare_symbol(node.name)
+        return node
+    
+    def process_NameNode(self, node, name):
+        if self.in_lhs:
+            self.declare_symbol(node.name)
+        else:
+            self.use_symbol(node.name)
+        return node
+    
+    def process_SingleAssignmentNode(self, node, name):
+        self.in_lhs = True
+        node.lhs = self.process_node(node.lhs, "lhs")
+        self.in_lhs = False
+        node.rhs = self.process_node(node.rhs, "rhs")
+        return node
+
+
+class InnerFunctions(VisitorTransform):
+    """Prints a representation of the tree to standard output.
+    Subclass and override repr_of to provide more information
+    about nodes. """
+    
+    def __init__(self):
+        super(InnerFunctions, self).__init__()
+        self.current_func = None
+
+    def initialize(self, phase, env):
+        self.env = env
+        self.extracted_functions = []
+
+    def process_tree(self, root):
+        # In-tree processing
+        super(InnerFunctions, self).process_tree(root)
+        # Now, insert the extracted functions at the beginning of the module
+        root.body = ensure_statlist(root.body)
+        root.body.stats = self.extracted_functions + root.body.stats
+
+    def process_DefNode(self, node, name):
+        # Declare attributes we use in a DefNode
+        parent = self.current_func
+        self.current_func = node
+        if parent is not None:
+            result = self.process_inner_function(node)
+        else:
+            self.process_children(node)
+            result = node
+        self.current_func = parent
+        return result
+
+    def process_inner_function(self, node):
+        # Modifications to function body: Any external symbols are unpacked from the
+        # "self" tuple (using ClosureVarsNode).
+        
+        # Modifications to function definition: It is lifted out (will be inserted
+        # again in process_tree), and replaced with an assignment to a bound
+        # method which binds to a tuple closuring up the external vars
+        # (using InnerFunctionRefNode).
+
+        self.extracted_functions.append(node) # will be inserted in module body
+        self.process_children(node)
+
+        # Insert unpacking of closure vars
+        if len(node.external_symbols) > 0:
+            assignment = SingleAssignmentNode(pos=node.pos,
+                lhs=TupleNode(pos=node.pos, args=[
+                    NameNode(pos=node.pos, name=x) for x in node.external_symbols
+                ]),
+                rhs=ClosureVarsNode(pos=node.pos)
+            )
+        
+            node.body = ensure_statlist(node.body)
+            node.body.stats = [assignment] + node.body.stats
+
+        # Insert packing of closure vars. Empty tuple is ok.
+        # InnerFunctionReferences will later
+        # reference this tuple in another transform phase
+        closure_assignment = SingleAssignmentNode(pos=node.pos,
+            lhs=NameNode(pos=node.pos, name=node.name),
+            rhs=InnerFunctionRefNode(
+                    pos=node.pos,
+                    funcname=node.name,
+                    closure=TupleNode(pos=node.pos, args=[
+                        NameNode(pos=node.pos, name=x) for x in node.external_symbols
+                    ])
+            )
+        )
+        
+        return closure_assignment
+        
+class MethodTableIndex(VisitorTransform):
+    """
+    Adds the cmethod_table_entry attribute to each DefNode containing
+    a lookup into the method table (e.g. "__pyx_methods[3]").
+    
+    Also, each InnerFunctionRefNode gets the same attribute set, to the
+    value of the function it references.
+    """
+    def initialize(self, phase, env, genv, lenv):
+        self.method_table_cname = genv.method_table_cname
+        # Create name -> idx map for functions
+        name_to_idx = {}
+        entries = env.pyfunc_entries
+        for idx in range(len(entries)):
+            name_to_idx[entries[idx].name] = idx
+        self.name_to_idx = name_to_idx
+
+    def process_DefNode(self, node, name):
+        node.cmethod_table_entry = \
+            self.method_table_cname + "[%d]" % self.name_to_idx[node.name]
+        self.process_children(node)
+        return node
+
+    def process_InnerFunctionRefNode(self, node, name):
+        node.cmethod_table_entry = \
+            self.method_table_cname + "[%d]" % self.name_to_idx[node.funcname]
+        self.process_children(node)
+        return node
+
+class ClosureVarsNode(ExprNode):
+    """Simply accesses the C "self" variable (__pyx_self) in a C function."""
+    subexprs = []
+    
+    type = py_object_type
+    is_temp = 1
+    
+    def analyse_types(self, env):
+        pass
+    
+    def generate_result_code(self, code):
+        code.putln("%s = __pyx_self;" % self.result_code)
+    
+class InnerFunctionRefNode(ExprNode):
+    # closure          TupleNode       The expressions to store in the closure.
+    # funcname         string          Name of function to call.
+
+    subexprs = ["closure"]
+    
+    type = py_object_type
+    is_temp = 1
+    
+    def analyse_types(self, env):
+        self.closure.analyse_types(env)
+    
+    def generate_result_code(self, code):
+        self.method_table_name = "__pyx_methods"
+        code.putln("%s = PyCFunction_New(&%s, %s); " %
+            (self.result_code, self.cmethod_table_entry, self.closure.result_code)
+        )
+
+
_______________________________________________
Cython-dev mailing list
Cython-dev@codespeak.net
http://codespeak.net/mailman/listinfo/cython-dev

Reply via email to